Ashish1227 commited on
Commit
e552ad0
·
verified ·
1 Parent(s): d0c0e24

Update src/models/sd2_sr.py

Browse files
Files changed (1) hide show
  1. src/models/sd2_sr.py +11 -1
src/models/sd2_sr.py CHANGED
@@ -187,6 +187,7 @@ def load_model(dtype=torch.bfloat16, device=None):
187
  encoder.to(dtype=dtype, device=device)
188
  encoder.device = device
189
 
 
190
  ddim = DDIM(config, vae, encoder, unet)
191
 
192
  params = {
@@ -205,4 +206,13 @@ def load_model(dtype=torch.bfloat16, device=None):
205
  low_scale_model = low_scale_model.to(dtype=dtype, device=device)
206
 
207
  ddim.low_scale_model = low_scale_model
208
- return ddim
 
 
 
 
 
 
 
 
 
 
187
  encoder.to(dtype=dtype, device=device)
188
  encoder.device = device
189
 
190
+ # Create a new DDIM instance with the updated components
191
  ddim = DDIM(config, vae, encoder, unet)
192
 
193
  params = {
 
206
  low_scale_model = low_scale_model.to(dtype=dtype, device=device)
207
 
208
  ddim.low_scale_model = low_scale_model
209
+
210
+ # Add a 'to' method to the DDIM object to make it compatible with the patching code
211
+ def ddim_to(self, *args, **kwargs):
212
+ # This is a no-op function since the components are already moved to the correct device
213
+ return self
214
+
215
+ # Add the 'to' method to the DDIM instance
216
+ ddim.to = ddim_to.__get__(ddim)
217
+
218
+ return ddim