Update src/models/sd2_sr.py
src/models/sd2_sr.py  CHANGED  (+9 -5)
@@ -153,16 +153,20 @@ def load_obj(path):
     return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))


-def load_model(dtype=torch.bfloat16, device='cuda:0'):
+def load_model(dtype=torch.bfloat16, device=None):
+    # Auto-detect device if not specified
+    if device is None:
+        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+
     download_file(DOWNLOAD_URL, MODEL_PATH)

     state_dict = safetensors.torch.load_file(MODEL_PATH)

     config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v2-upsample.yaml')

-    unet = load_obj(f'{CONFIG_FOLDER}/unet/upsample/v2.yaml').eval().
-    vae = load_obj(f'{CONFIG_FOLDER}/vae-upsample.yaml').eval().
-    encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().
+    unet = load_obj(f'{CONFIG_FOLDER}/unet/upsample/v2.yaml').eval().to(device)
+    vae = load_obj(f'{CONFIG_FOLDER}/vae-upsample.yaml').eval().to(device)
+    encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().to(device)
     ddim = DDIM(config, vae, encoder, unet)

     extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}

@@ -193,7 +197,7 @@ def load_model(dtype=torch.bfloat16, device='cuda:0'):
         'max_noise_level': 350
     }

-    low_scale_model = ImageConcatWithNoiseAugmentation(**params).eval().to(
+    low_scale_model = ImageConcatWithNoiseAugmentation(**params).eval().to(device)
     low_scale_model.train = disabled_train
     for param in low_scale_model.parameters():
         param.requires_grad = False
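A brief usage sketch of the change (an assumption for illustration, not part of the commit): with the new `device=None` default, `load_model` resolves the device itself, so the same call works on GPU and CPU Spaces. The import path and the idea that `load_model` returns the assembled pipeline are assumed from context.

import torch
from src.models.sd2_sr import load_model  # module path assumed from this repo's layout

# With device=None (the new default) the loader picks 'cuda:0' when a GPU is
# visible and falls back to 'cpu' otherwise, matching the auto-detection added above.
pipeline = load_model(dtype=torch.bfloat16)

# Passing an explicit device still overrides the auto-detection.
pipeline_cpu = load_model(dtype=torch.float32, device='cpu')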