Spaces:
Running
Running
""" | |
Handles the loading and management of necessary AI models from Hugging Face Hub. | |
Provides functions to load models once at startup and access them throughout | |
the application, managing device placement (CPU/GPU) and data types. | |
Optimized for typical Hugging Face Space GPU environments. | |
""" | |
import torch | |
from diffusers import ControlNetModel | |
from controlnet_aux import OpenposeDetector | |
import gc | |
# --- Configuration ---
# Pick the compute device and its matching tensor dtype once, at import time:
# half precision when a CUDA GPU is visible, full precision on CPU otherwise.
if not torch.cuda.is_available():
    DEVICE, DTYPE = "cpu", torch.float32
    print(f"CUDA not available. Using Device: {DEVICE}, Dtype: {DTYPE}")
else:
    DEVICE, DTYPE = "cuda", torch.float16
    print(f"CUDA available. Using Device: {DEVICE}, Dtype: {DTYPE}")
    try:
        # Purely informational; a failure here must not abort the module import.
        print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    except Exception as e:
        print(f"Couldn't get GPU name: {e}")
# Hugging Face Hub repositories for the models handled by this module.
# The base Stable Diffusion checkpoint ("runwayml/stable-diffusion-v1-5") is
# needed by the pipelines but is not loaded here.
OPENPOSE_DETECTOR_ID = 'lllyasviel/ControlNet'  # OpenPose preprocessor repo
CONTROLNET_POSE_MODEL_ID = "lllyasviel/sd-controlnet-openpose"  # OpenPose ControlNet weights
CONTROLNET_TILE_MODEL_ID = "lllyasviel/control_v11f1e_sd15_tile"  # Tile ControlNet weights

# Module-level model cache, filled in by load_models(); None until loaded.
_openpose_detector = _controlnet_pose = _controlnet_tile = None
# Set to True by load_models() only after every model has loaded successfully.
_models_loaded = False
# --- Loading Function ---
def _clear_cuda_cache(reason):
    """Collect garbage and release PyTorch's cached CUDA memory; no-op off CUDA.

    Args:
        reason (str): Label inserted into the log line, e.g. "initial" or "post-load".
    """
    if DEVICE != "cuda":
        return
    print(f"Performing {reason} CUDA cache clear...")
    gc.collect()
    torch.cuda.empty_cache()


def _load_controlnet(model_id, label):
    """Load one ControlNet checkpoint in ``DTYPE`` and move it to ``DEVICE``.

    Args:
        model_id (str): Hugging Face Hub repo ID of the ControlNet weights.
        label (str): Human-readable name used in log messages ("Pose", "Tile").

    Returns:
        The loaded ``ControlNetModel``, or ``None`` if loading or device
        placement failed (the error is printed, not raised).
    """
    print(f"Loading ControlNet {label} Model from {model_id} to {DEVICE} ({DTYPE})...")
    try:
        model = ControlNetModel.from_pretrained(model_id, torch_dtype=DTYPE)
        model = model.to(DEVICE)
    except Exception as e:
        print(f"ERROR: Failed to load ControlNet {label} Model: {e}")
        return None
    print(f"ControlNet {label} model loaded successfully.")
    return model


def load_models(force_reload=False):
    """
    Loads the OpenPose detector (left on CPU) and the ControlNet models (moved to DEVICE).

    This function should typically be called once when the application starts.
    It checks if models are already loaded to prevent redundant loading unless
    `force_reload` is True.

    Before loading, references to any previously loaded instances are dropped so
    that a forced reload does not hold two copies of every model in memory at
    once. If a step fails, the models loaded before it remain set and the
    remaining ones are None.

    Args:
        force_reload (bool): If True, forces reloading even if models are already loaded.

    Returns:
        bool: True if all models were loaded successfully (or already were), False otherwise.
    """
    global _openpose_detector, _controlnet_pose, _controlnet_tile, _models_loaded
    if _models_loaded and not force_reload:
        print("Models already loaded.")
        return True

    print("--- Loading Models ---")
    # Drop references to earlier instances *before* clearing the cache:
    # empty_cache() can only hand back memory no live tensor still occupies,
    # and otherwise old and new weights would sit side by side during reload.
    _openpose_detector = _controlnet_pose = _controlnet_tile = None
    _models_loaded = False
    _clear_cuda_cache("initial")

    # 1. OpenPose Detector (no .to() call: it stays where from_pretrained puts it)
    try:
        print(f"Loading OpenPose Detector from {OPENPOSE_DETECTOR_ID} to CPU...")
        _openpose_detector = OpenposeDetector.from_pretrained(OPENPOSE_DETECTOR_ID)
    except Exception as e:
        print(f"ERROR: Failed to load OpenPose Detector: {e}")
        return False
    print("OpenPose detector loaded successfully (on CPU).")

    # 2. ControlNet Models
    _controlnet_pose = _load_controlnet(CONTROLNET_POSE_MODEL_ID, "Pose")
    if _controlnet_pose is None:
        return False
    _controlnet_tile = _load_controlnet(CONTROLNET_TILE_MODEL_ID, "Tile")
    if _controlnet_tile is None:
        return False

    _models_loaded = True
    print("--- All prerequisite models loaded successfully. ---")
    _clear_cuda_cache("post-load")
    return True
# --- Getter Functions ---
def get_openpose_detector():
    """Return the cached OpenPose detector, triggering load_models() if needed.

    May return None when loading has not succeeded.
    """
    if not _models_loaded:
        load_models()
    return _openpose_detector
def get_controlnet_pose():
    """Return the cached OpenPose ControlNet, triggering load_models() if needed.

    May return None when loading has not succeeded.
    """
    if not _models_loaded:
        load_models()
    return _controlnet_pose
def get_controlnet_tile():
    """Return the cached Tile ControlNet, triggering load_models() if needed.

    May return None when loading has not succeeded.
    """
    if not _models_loaded:
        load_models()
    return _controlnet_tile
def get_device():
    """Return the device string ("cuda" or "cpu") selected at import time."""
    return DEVICE
def get_dtype():
    """Return the torch dtype selected at import time (float16 on CUDA, float32 on CPU)."""
    return DTYPE
def are_models_loaded():
    """Return the module's flag saying whether every model is currently loaded."""
    return _models_loaded