Spaces:
Runtime error
Runtime error
Update src/models/sd2_sr.py
Browse files- src/models/sd2_sr.py +11 -1
src/models/sd2_sr.py
CHANGED
@@ -187,6 +187,7 @@ def load_model(dtype=torch.bfloat16, device=None):
|
|
187 |
encoder.to(dtype=dtype, device=device)
|
188 |
encoder.device = device
|
189 |
|
|
|
190 |
ddim = DDIM(config, vae, encoder, unet)
|
191 |
|
192 |
params = {
|
@@ -205,4 +206,13 @@ def load_model(dtype=torch.bfloat16, device=None):
|
|
205 |
low_scale_model = low_scale_model.to(dtype=dtype, device=device)
|
206 |
|
207 |
ddim.low_scale_model = low_scale_model
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
encoder.to(dtype=dtype, device=device)
|
188 |
encoder.device = device
|
189 |
|
190 |
+
# Create a new DDIM instance with the updated components
|
191 |
ddim = DDIM(config, vae, encoder, unet)
|
192 |
|
193 |
params = {
|
|
|
206 |
low_scale_model = low_scale_model.to(dtype=dtype, device=device)
|
207 |
|
208 |
ddim.low_scale_model = low_scale_model
|
209 |
+
|
210 |
+
# Add a 'to' method to the DDIM object to make it compatible with the patching code
|
211 |
+
def ddim_to(self, *args, **kwargs):
|
212 |
+
# This is a no-op function since the components are already moved to the correct device
|
213 |
+
return self
|
214 |
+
|
215 |
+
# Add the 'to' method to the DDIM instance
|
216 |
+
ddim.to = ddim_to.__get__(ddim)
|
217 |
+
|
218 |
+
return ddim
|