Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -52,7 +52,8 @@ def predict(radio, dict, word_mask, prompt=""):
|
|
| 52 |
if(radio == "draw a mask above"):
|
| 53 |
#with autocast("cuda"):
|
| 54 |
#with autocast(device): #enable=(False if device=='cpu' else True)):
|
| 55 |
-
with autocast(enabled=True, dtype=torch.bfloat16):
|
|
|
|
| 56 |
init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
|
| 57 |
mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
|
| 58 |
else:
|
|
@@ -71,7 +72,8 @@ def predict(radio, dict, word_mask, prompt=""):
|
|
| 71 |
os.remove(filename)
|
| 72 |
#with autocast("cuda"):
|
| 73 |
#with autocast(device): #enable=(False if device=='cpu' else True)):
|
| 74 |
-
with autocast(enabled=True, dtype=torch.bfloat16):
|
|
|
|
| 75 |
images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
|
| 76 |
return images[0]
|
| 77 |
|
|
|
|
| 52 |
if(radio == "draw a mask above"):
|
| 53 |
#with autocast("cuda"):
|
| 54 |
#with autocast(device): #enable=(False if device=='cpu' else True)):
|
| 55 |
+
#with autocast(enabled=True, dtype=torch.bfloat16):
|
| 56 |
+
with torch.cuda.amp.autocast(True):
|
| 57 |
init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
|
| 58 |
mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
|
| 59 |
else:
|
|
|
|
| 72 |
os.remove(filename)
|
| 73 |
#with autocast("cuda"):
|
| 74 |
#with autocast(device): #enable=(False if device=='cpu' else True)):
|
| 75 |
+
#with autocast(enabled=True, dtype=torch.bfloat16):
|
| 76 |
+
with torch.cuda.amp.autocast(True):
|
| 77 |
images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
|
| 78 |
return images[0]
|
| 79 |
|