# Captured from a Hugging Face Space file view (Space status: Running;
# file size: 4,411 bytes; commit dbd510a).
"""
Handles the loading and management of necessary AI models from Hugging Face Hub.
Provides functions to load models once at startup and access them throughout
the application, managing device placement (CPU/GPU) and data types.
Optimized for typical Hugging Face Space GPU environments.
"""
import torch
from diffusers import ControlNetModel
from controlnet_aux import OpenposeDetector
import gc
# --- Configuration ---
# Choose the compute device once, at import time: float16 on a CUDA GPU,
# float32 when falling back to the CPU.
if not torch.cuda.is_available():
    DEVICE = "cpu"
    DTYPE = torch.float32
    print(f"CUDA not available. Using Device: {DEVICE}, Dtype: {DTYPE}")
else:
    DEVICE = "cuda"
    DTYPE = torch.float16
    print(f"CUDA available. Using Device: {DEVICE}, Dtype: {DTYPE}")
    # The GPU name is informational only, so a failure to read it is
    # reported and otherwise ignored instead of aborting the import.
    try:
        print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    except Exception as err:
        print(f"Couldn't get GPU name: {err}")

# Hugging Face Hub repositories used by this module.
# Base SD checkpoint needed by the pipelines (not loaded in this module):
# BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
OPENPOSE_DETECTOR_ID = 'lllyasviel/ControlNet'  # repo holding the OpenPose preprocessor
CONTROLNET_POSE_MODEL_ID = "lllyasviel/sd-controlnet-openpose"  # OpenPose-conditioned ControlNet weights
CONTROLNET_TILE_MODEL_ID = "lllyasviel/control_v11f1e_sd15_tile"  # Tile ControlNet weights

# Module-level model singletons: filled in by load_models(), read by the
# get_*() accessors. They stay None until the corresponding load succeeds.
_openpose_detector = None
_controlnet_pose = None
_controlnet_tile = None
_models_loaded = False  # set to True only after every model has loaded
# --- Loading Function ---
def _clear_cuda_cache(stage):
    """Run the garbage collector and empty PyTorch's CUDA cache when on GPU.

    Does nothing when DEVICE is not "cuda".

    Args:
        stage (str): Word inserted into the log line, e.g. "initial" or
            "post-load".
    """
    if DEVICE != "cuda":
        return
    print(f"Performing {stage} CUDA cache clear...")
    gc.collect()
    torch.cuda.empty_cache()


def _load_controlnet(model_id, label):
    """Load one ControlNet in DTYPE, move it to DEVICE and return it.

    The model is kept in a local variable until ``.to(DEVICE)`` has returned,
    so an exception part-way through (e.g. during the device move) never
    leaves a half-moved model stored in a module-level global.

    Args:
        model_id (str): Hugging Face Hub repo id of the ControlNet weights.
        label (str): Name used in log messages ("Pose" or "Tile").

    Returns:
        The loaded ControlNetModel, or None if loading or moving it failed
        (the error is printed).
    """
    print(f"Loading ControlNet {label} Model from {model_id} to {DEVICE} ({DTYPE})...")
    try:
        model = ControlNetModel.from_pretrained(model_id, torch_dtype=DTYPE)
        model.to(DEVICE)
    except Exception as e:
        # Broad catch on purpose: any failure is reported and turned into a
        # False return from load_models() instead of crashing the app.
        print(f"ERROR: Failed to load ControlNet {label} Model: {e}")
        return None
    print(f"ControlNet {label} model loaded successfully.")
    return model


def load_models(force_reload=False):
    """
    Loads the OpenPose detector (to CPU) and ControlNet models (to configured DEVICE).

    This function should typically be called once when the application starts.
    It checks if models are already loaded to prevent redundant loading unless
    `force_reload` is True. A ControlNet global is only replaced once the new
    model has been fully moved to DEVICE; if loading fails, the previous value
    of that global is left as it was.

    Args:
        force_reload (bool): If True, forces reloading even if models are already loaded.

    Returns:
        bool: True if all models were loaded successfully (or already were), False otherwise.
    """
    global _openpose_detector, _controlnet_pose, _controlnet_tile, _models_loaded
    if _models_loaded and not force_reload:
        print("Models already loaded.")
        return True

    print("--- Loading Models ---")
    _clear_cuda_cache("initial")

    # 1. OpenPose Detector -- this module never moves it to DEVICE.
    try:
        print(f"Loading OpenPose Detector from {OPENPOSE_DETECTOR_ID} to CPU...")
        _openpose_detector = OpenposeDetector.from_pretrained(OPENPOSE_DETECTOR_ID)
        print("OpenPose detector loaded successfully (on CPU).")
    except Exception as e:
        print(f"ERROR: Failed to load OpenPose Detector: {e}")
        _models_loaded = False
        return False

    # 2. ControlNet Models -- publish each one only after a complete success.
    pose = _load_controlnet(CONTROLNET_POSE_MODEL_ID, "Pose")
    if pose is None:
        _models_loaded = False
        # Free any cached GPU memory left behind by the discarded model.
        _clear_cuda_cache("post-failure")
        return False
    _controlnet_pose = pose

    tile = _load_controlnet(CONTROLNET_TILE_MODEL_ID, "Tile")
    if tile is None:
        _models_loaded = False
        _clear_cuda_cache("post-failure")
        return False
    _controlnet_tile = tile

    _models_loaded = True
    print("--- All prerequisite models loaded successfully. ---")
    _clear_cuda_cache("post-load")
    return True
# --- Getter Functions ---
def get_openpose_detector():
    """Return the shared OpenPose detector, calling load_models() first if needed.

    May be None if loading has not (successfully) reached the detector.
    """
    if _models_loaded:
        return _openpose_detector
    # Lazy initialisation: the first accessor call triggers loading.
    load_models()
    return _openpose_detector
def get_controlnet_pose():
    """Return the shared Pose ControlNet, calling load_models() first if needed.

    May be None if loading has not (successfully) reached the Pose ControlNet.
    """
    if _models_loaded:
        return _controlnet_pose
    # Lazy initialisation: the first accessor call triggers loading.
    load_models()
    return _controlnet_pose
def get_controlnet_tile():
    """Return the shared Tile ControlNet, calling load_models() first if needed.

    May be None if loading has not (successfully) reached the Tile ControlNet.
    """
    if _models_loaded:
        return _controlnet_tile
    # Lazy initialisation: the first accessor call triggers loading.
    load_models()
    return _controlnet_tile
def get_device() -> str:
    """Return the device string selected at import time ("cuda" or "cpu")."""
    return DEVICE
def get_dtype() -> torch.dtype:
    """Return the torch dtype selected at import time (float16 on CUDA, float32 on CPU)."""
    return DTYPE
def are_models_loaded():
    """Return the module's loaded flag without triggering any loading.

    True after load_models() has loaded every model; False before that, or
    after a load_models() call that failed.
    """
    return _models_loaded