bring back inits
app.py CHANGED
@@ -118,26 +118,25 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
                                      std=[0.26862954, 0.26130258, 0.27577711])


-    #def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompt):
     all_frames = []
     prompts = [text]
-
-
-
-
+    if image_prompts:
+        image_prompts = [image_prompts.name]
+    else:
+        image_prompts = []
     batch_size = 1
     clip_guidance_scale = clip_guidance_scale # Controls how much the image should look like the prompt.
     tv_scale = tv_scale # Controls the smoothness of the final output.
     range_scale = range_scale # Controls how far out of range RGB values are allowed to be.
     cutn = cutn
     n_batches = 1
-
-
-
-
+    if init_image:
+        init_image = init_image.name
+    else:
+        init_image = None # This can be an URL or Colab local path and must be in quotes.
     skip_timesteps = skip_timesteps # This needs to be between approx. 200 and 500 when using an init image.
     # Higher values make the output look more like the init.
-
+    init_scale = init_scale # This enhances the effect of the init image, a good value is 1000.
     seed = seed

     if seed is not None:
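The first hunk restores the unwrapping of the uploaded inputs into plain file paths before the rest of the pipeline uses them. A minimal sketch of that normalization, assuming the inputs are file-like upload objects that expose a .name path (which is what the .name accesses above imply); the as_path helper is illustrative and not part of app.py:

def as_path(upload):
    """Return a local filesystem path for an upload, or None when nothing was provided."""
    if not upload:
        return None
    # Gradio-style upload objects are tempfile wrappers; .name holds the path on disk.
    return getattr(upload, 'name', upload)

# Illustrative use inside inference():
#   init_image = as_path(init_image)            # None when no init image was uploaded
#   image_prompts = [as_path(image_prompts)] if image_prompts else []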
@@ -149,25 +148,25 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
         txt, weight = parse_prompt(prompt)
         target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
         weights.append(weight)
-
-
-
-
-
-
-
-
+    for prompt in image_prompts:
+        path, weight = parse_prompt(prompt)
+        img = Image.open(fetch(path)).convert('RGB')
+        img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
+        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
+        embed = clip_model.encode_image(normalize(batch)).float()
+        target_embeds.append(embed)
+        weights.extend([weight / cutn] * cutn)
     target_embeds = torch.cat(target_embeds)
     weights = torch.tensor(weights, device=device)
     if weights.sum().abs() < 1e-3:
         raise RuntimeError('The weights must not sum to 0.')
     weights /= weights.sum().abs()
     init = None
-
-
-
-
-
+    if init_image is not None:
+        lpips_model = lpips.LPIPS(net='vgg').to(device)
+        init = Image.open(fetch(init_image)).convert('RGB')
+        init = init.resize((side_x, side_y), Image.LANCZOS)
+        init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
     cur_t = None
     def cond_fn(x, t, y=None):
         with torch.enable_grad():
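The restored for prompt in image_prompts loop turns each image prompt into CLIP target embeddings: the image is resized, cut into cutn random crops, encoded with clip_model.encode_image, and each cutout contributes weight / cutn to the prompt weights. A self-contained sketch of the same idea, assuming OpenAI's clip package and torchvision are installed; random_cutouts and the 'prompt.jpg' path are illustrative stand-ins for the app's MakeCutouts module and the uploaded file:

import torch
import clip
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as TF

device = 'cuda' if torch.cuda.is_available() else 'cpu'
clip_model, _ = clip.load('ViT-B/32', device=device)
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])

def random_cutouts(img_batch, cut_size=224, cutn=16):
    # cutn random crops, each resized to CLIP's input resolution; a simplified
    # stand-in for the MakeCutouts used in app.py.
    crop = transforms.RandomResizedCrop(cut_size, scale=(0.5, 1.0))
    return torch.cat([crop(img_batch) for _ in range(cutn)], dim=0)

img = Image.open('prompt.jpg').convert('RGB')                   # hypothetical image prompt
img = TF.resize(img, 256, transforms.InterpolationMode.LANCZOS)
batch = random_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
with torch.no_grad():
    embed = clip_model.encode_image(normalize(batch)).float()   # (cutn, 512) for ViT-B/32

In the diff, the whole (cutn, 512) tensor is appended to target_embeds while weights.extend([weight / cutn] * cutn) splits the prompt's weight evenly across its cutouts.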
@@ -185,10 +184,10 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
             tv_losses = tv_loss(x_in)
             range_losses = range_loss(out['pred_xstart'])
             loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
-
+            if init is not None and init_scale:

-
-
+                init_losses = lpips_model(x_in, init)
+                loss = loss + init_losses.sum() * init_scale
             return -torch.autograd.grad(loss, x)[0]
     if model_config['timestep_respacing'].startswith('ddim'):
         sample_fn = diffusion.ddim_sample_loop_progressive
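The last hunk brings the init-image term back into cond_fn: when an init image is set, the LPIPS perceptual distance between the current denoised estimate and the init is scaled by init_scale and added to the guidance loss, and the guard (if init is not None and init_scale) keeps the term off when no init or a zero scale is given. A minimal sketch of that term, assuming the lpips package is installed and both tensors are already scaled to [-1, 1], as the .mul(2).sub(1) above does; the tensors here are random stand-ins:

import torch
import lpips

device = 'cuda' if torch.cuda.is_available() else 'cpu'
lpips_model = lpips.LPIPS(net='vgg').to(device)

# Stand-in images in [-1, 1], matching the scaling applied to the real init image.
init = torch.rand(1, 3, 256, 256, device=device) * 2 - 1
x_in = (torch.rand(1, 3, 256, 256, device=device) * 2 - 1).requires_grad_()

init_scale = 1000                              # "a good value is 1000" per the restored comment
init_losses = lpips_model(x_in, init)          # perceptual distance, shape (1, 1, 1, 1)
loss = init_losses.sum() * init_scale          # added on top of the CLIP / TV / range terms
grad = -torch.autograd.grad(loss, x_in)[0]     # guidance direction, as cond_fn returns it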