File size: 4,411 Bytes
dbd510a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Handles the loading and management of necessary AI models from Hugging Face Hub.

Provides functions to load models once at startup and access them throughout
the application, managing device placement (CPU/GPU) and data types.
Optimized for typical Hugging Face Space GPU environments.
"""

import torch
from diffusers import ControlNetModel
from controlnet_aux import OpenposeDetector
import gc

# --- Configuration ---
# Automatically detect CUDA availability and set appropriate device/dtype
if not torch.cuda.is_available():
    # No usable GPU: run on the CPU in full precision.
    DEVICE, DTYPE = "cpu", torch.float32
    print(f"CUDA not available. Using Device: {DEVICE}, Dtype: {DTYPE}")
else:
    # GPU present: run on CUDA in half precision.
    DEVICE, DTYPE = "cuda", torch.float16
    print(f"CUDA available. Using Device: {DEVICE}, Dtype: {DTYPE}")
    # Reporting the GPU name is informational only; a failure here must not
    # stop the module from importing.
    try:
        print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    except Exception as err:
        print(f"Couldn't get GPU name: {err}")


# Model IDs from Hugging Face Hub
# BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5" # Base SD model ID needed by pipelines
OPENPOSE_DETECTOR_ID = 'lllyasviel/ControlNet' # Preprocessor model repo
CONTROLNET_POSE_MODEL_ID = "lllyasviel/sd-controlnet-openpose" # OpenPose ControlNet weights
CONTROLNET_TILE_MODEL_ID = "lllyasviel/control_v11f1e_sd15_tile" # Tile ControlNet weights

# Module-level cache filled in by load_models(); each entry stays None until
# the corresponding model has been loaded.
_openpose_detector = None
_controlnet_pose = None
_controlnet_tile = None
# Set to True by load_models() only after all three models loaded without error;
# reset to False whenever a load attempt fails.
_models_loaded = False

# --- Loading Function ---

def _load_controlnet(model_id, label):
    """Load one ControlNet checkpoint and move it to the configured device.

    Args:
        model_id (str): Hugging Face Hub repository ID of the ControlNet weights.
        label (str): Short name ("Pose", "Tile") inserted into the log messages.

    Returns:
        The loaded model after `.to(DEVICE)`, or None if loading or device
        placement raised. The error is printed, not propagated.
    """
    try:
        print(f"Loading ControlNet {label} Model from {model_id} to {DEVICE} ({DTYPE})...")
        model = ControlNetModel.from_pretrained(model_id, torch_dtype=DTYPE)
        model.to(DEVICE)
        print(f"ControlNet {label} model loaded successfully.")
        return model
    except Exception as e:
        print(f"ERROR: Failed to load ControlNet {label} Model: {e}")
        return None


def load_models(force_reload=False):
    """
    Loads the OpenPose detector and the two ControlNet models (to configured DEVICE).

    This function should typically be called once when the application starts.
    It checks if models are already loaded to prevent redundant loading unless
    `force_reload` is True.

    Models are loaded in order (OpenPose detector, ControlNet Pose, ControlNet
    Tile). Loading stops at the first failure: the error is printed, the
    loaded flag is cleared, and models loaded earlier in the same call remain
    stored in their module globals. A ControlNet global is only updated once
    its model has been both loaded and moved to DEVICE.

    Args:
        force_reload (bool): If True, forces reloading even if models are already loaded.

    Returns:
        bool: True if all models were loaded successfully (or already were), False otherwise.
    """
    global _openpose_detector, _controlnet_pose, _controlnet_tile, _models_loaded

    if _models_loaded and not force_reload:
        print("Models already loaded.")
        return True

    print("--- Loading Models ---")
    if DEVICE == "cuda":
        # NOTE(review): on force_reload the previously loaded models are still
        # referenced by the module globals at this point, so this clear
        # presumably cannot release their memory -- confirm if reloads are
        # memory-constrained.
        print("Performing initial CUDA cache clear...")
        gc.collect()
        torch.cuda.empty_cache()

    # 1. OpenPose Detector (no explicit device placement; left where
    #    from_pretrained puts it)
    try:
        print(f"Loading OpenPose Detector from {OPENPOSE_DETECTOR_ID} to CPU...")
        _openpose_detector = OpenposeDetector.from_pretrained(OPENPOSE_DETECTOR_ID)
        print("OpenPose detector loaded successfully (on CPU).")
    except Exception as e:
        print(f"ERROR: Failed to load OpenPose Detector: {e}")
        _models_loaded = False
        return False

    # 2. ControlNet Models
    pose_model = _load_controlnet(CONTROLNET_POSE_MODEL_ID, "Pose")
    if pose_model is None:
        _models_loaded = False
        return False
    _controlnet_pose = pose_model

    tile_model = _load_controlnet(CONTROLNET_TILE_MODEL_ID, "Tile")
    if tile_model is None:
        _models_loaded = False
        return False
    _controlnet_tile = tile_model

    _models_loaded = True
    print("--- All prerequisite models loaded successfully. ---")
    if DEVICE == "cuda":
        # Drop temporaries created during loading before the app starts using the GPU.
        print("Performing post-load CUDA cache clear...")
        gc.collect()
        torch.cuda.empty_cache()
    return True

# --- Getter Functions ---

def get_openpose_detector():
    """Return the cached OpenPose detector, calling `load_models()` first if needed.

    The outcome of `load_models()` is not checked, so the result can be None
    when the detector failed to load.
    """
    if not _models_loaded:
        load_models()
    return _openpose_detector

def get_controlnet_pose():
    """Return the cached ControlNet Pose model, calling `load_models()` first if needed.

    The outcome of `load_models()` is not checked, so the result can be None
    when the model failed to load.
    """
    if not _models_loaded:
        load_models()
    return _controlnet_pose

def get_controlnet_tile():
    """Return the cached ControlNet Tile model, calling `load_models()` first if needed.

    The outcome of `load_models()` is not checked, so the result can be None
    when the model failed to load.
    """
    if not _models_loaded:
        load_models()
    return _controlnet_tile

def get_device() -> str:
    """Return the device string chosen at import time ("cuda" or "cpu")."""
    return DEVICE

def get_dtype() -> torch.dtype:
    """Return the torch dtype chosen at import time (float16 on CUDA, float32 on CPU)."""
    return DTYPE

def are_models_loaded() -> bool:
    """Return True if the last `load_models()` call loaded every model successfully."""
    return _models_loaded