Ashish1227 commited on
Commit
d0c0e24
·
verified ·
1 Parent(s): 1559f22

Update src/models/sd2_sr.py

Browse files
Files changed (1) hide show
  1. src/models/sd2_sr.py +9 -5
src/models/sd2_sr.py CHANGED
@@ -153,16 +153,20 @@ def load_obj(path):
153
  return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
154
 
155
 
156
- def load_model(dtype=torch.bfloat16, device='cuda:0'):
 
 
 
 
157
  download_file(DOWNLOAD_URL, MODEL_PATH)
158
 
159
  state_dict = safetensors.torch.load_file(MODEL_PATH)
160
 
161
  config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v2-upsample.yaml')
162
 
163
- unet = load_obj(f'{CONFIG_FOLDER}/unet/upsample/v2.yaml').eval().cuda()
164
- vae = load_obj(f'{CONFIG_FOLDER}/vae-upsample.yaml').eval().cuda()
165
- encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().cuda()
166
  ddim = DDIM(config, vae, encoder, unet)
167
 
168
  extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
@@ -193,7 +197,7 @@ def load_model(dtype=torch.bfloat16, device='cuda:0'):
193
  'max_noise_level': 350
194
  }
195
 
196
- low_scale_model = ImageConcatWithNoiseAugmentation(**params).eval().to('cuda')
197
  low_scale_model.train = disabled_train
198
  for param in low_scale_model.parameters():
199
  param.requires_grad = False
 
153
  return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
154
 
155
 
156
+ def load_model(dtype=torch.bfloat16, device=None):
157
+ # Auto-detect device if not specified
158
+ if device is None:
159
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
160
+
161
  download_file(DOWNLOAD_URL, MODEL_PATH)
162
 
163
  state_dict = safetensors.torch.load_file(MODEL_PATH)
164
 
165
  config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v2-upsample.yaml')
166
 
167
+ unet = load_obj(f'{CONFIG_FOLDER}/unet/upsample/v2.yaml').eval().to(device)
168
+ vae = load_obj(f'{CONFIG_FOLDER}/vae-upsample.yaml').eval().to(device)
169
+ encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().to(device)
170
  ddim = DDIM(config, vae, encoder, unet)
171
 
172
  extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
 
197
  'max_noise_level': 350
198
  }
199
 
200
+ low_scale_model = ImageConcatWithNoiseAugmentation(**params).eval().to(device)
201
  low_scale_model.train = disabled_train
202
  for param in low_scale_model.parameters():
203
  param.requires_grad = False