cocktailpeanut commited on
Commit
4af5c1c
·
1 Parent(s): a0ef064
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -25,8 +25,10 @@ snapshot_download(
25
  MAX_SEED = np.iinfo(np.int32).max
26
  #device = "cuda" if torch.cuda.is_available() else "cpu"
27
  #dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
 
28
  if torch.cuda.is_available():
29
  device = torch.device("cuda")
 
30
  elif torch.backends.mps.is_available():
31
  device = torch.device("mps")
32
  else:
@@ -62,7 +64,7 @@ pipe.unet.load_state_dict(
62
  hf_hub_download(
63
  "ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
64
  ),
65
- device=device,
66
  )
67
  )
68
 
 
25
  MAX_SEED = np.iinfo(np.int32).max
26
  #device = "cuda" if torch.cuda.is_available() else "cpu"
27
  #dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
28
+ load_device = "cpu"
29
  if torch.cuda.is_available():
30
  device = torch.device("cuda")
31
+ load_device = "cuda"
32
  elif torch.backends.mps.is_available():
33
  device = torch.device("mps")
34
  else:
 
64
  hf_hub_download(
65
  "ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
66
  ),
67
+ device=load_device,
68
  )
69
  )
70