Ashish1227 commited on
Commit
46bb761
·
verified ·
1 Parent(s): 6c62f0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -104,9 +104,11 @@ original_sr_load_model = models.sd2_sr.load_model
104
  def patched_sr_load_model(*args, **kwargs):
105
  kwargs['device'] = 'cpu'
106
  model = original_sr_load_model(*args, **kwargs)
107
- # Check if the model has a to method before calling it
108
- if hasattr(model, 'to'):
109
- model.to('cpu')
 
 
110
  return model
111
  models.sd2_sr.load_model = patched_sr_load_model
112
 
@@ -114,10 +116,13 @@ original_sam_load_model = models.sam.load_model
114
  def patched_sam_load_model(*args, **kwargs):
115
  kwargs['device'] = 'cpu'
116
  model = original_sam_load_model(*args, **kwargs)
117
- model.to('cpu')
 
 
118
  return model
119
  models.sam.load_model = patched_sam_load_model
120
 
 
121
  sr_model = models.sd2_sr.load_model(device='cpu')
122
  sam_predictor = models.sam.load_model(device='cpu')
123
 
@@ -125,7 +130,6 @@ inp_model_name = list(inpainting_models.keys())[0]
125
  inp_model = models.load_inpainting_model(
126
  inpainting_models[inp_model_name], device='cpu', cache=True)
127
 
128
-
129
  def set_model_from_name(new_inp_model_name):
130
  global inp_model
131
  global inp_model_name
 
104
  def patched_sr_load_model(*args, **kwargs):
105
  kwargs['device'] = 'cpu'
106
  model = original_sr_load_model(*args, **kwargs)
107
+ # Handle DDIM object which doesn't have to() method
108
+ if hasattr(model, 'model'): # If there's a main model component
109
+ model.model.to('cpu')
110
+ if hasattr(model, 'diffusion'): # Some diffusion models have this
111
+ model.diffusion.to('cpu')
112
  return model
113
  models.sd2_sr.load_model = patched_sr_load_model
114
 
 
116
  def patched_sam_load_model(*args, **kwargs):
117
  kwargs['device'] = 'cpu'
118
  model = original_sam_load_model(*args, **kwargs)
119
+ # SAM predictor doesn't have to() method but its model does
120
+ if hasattr(model, 'model'):
121
+ model.model.to('cpu')
122
  return model
123
  models.sam.load_model = patched_sam_load_model
124
 
125
+ # Load models with CPU
126
  sr_model = models.sd2_sr.load_model(device='cpu')
127
  sam_predictor = models.sam.load_model(device='cpu')
128
 
 
130
  inp_model = models.load_inpainting_model(
131
  inpainting_models[inp_model_name], device='cpu', cache=True)
132
 
 
133
  def set_model_from_name(new_inp_model_name):
134
  global inp_model
135
  global inp_model_name