Update app.py
app.py CHANGED
@@ -360,43 +360,43 @@ def process(input_fg, input_bg, prompt, image_width, image_height, num_samples,
     ).images.to(vae.dtype) / vae.config.scaling_factor
 
     pixels = vae.decode(latents).sample
-    pixels = pytorch2numpy(pixels)
+    pixels = pytorch2numpy(pixels)  # Use default quant=True for first pass
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # Always perform highres processing like the original code
+    pixels = [resize_without_crop(
+        image=p,
+        target_width=int(round(image_width * highres_scale / 64.0) * 64),
+        target_height=int(round(image_height * highres_scale / 64.0) * 64))
+    for p in pixels]
+
+    pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
+    latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
+    latents = latents.to(device=unet.device, dtype=unet.dtype)
+
+    image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
+    fg = resize_and_center_crop(input_fg, image_width, image_height)
+    bg = resize_and_center_crop(input_bg, image_width, image_height)
+    concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
+    concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
+    concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
+
+    latents = i2i_pipe(
+        image=latents,
+        strength=highres_denoise,
+        prompt_embeds=conds,
+        negative_prompt_embeds=unconds,
+        width=image_width,
+        height=image_height,
+        num_inference_steps=int(round(steps / highres_denoise)),
+        num_images_per_prompt=num_samples,
+        generator=rng,
+        output_type='latent',
+        guidance_scale=cfg,
+        cross_attention_kwargs={'concat_conds': concat_conds},
+    ).images.to(vae.dtype) / vae.config.scaling_factor
+
+    pixels = vae.decode(latents).sample
+    pixels = pytorch2numpy(pixels, quant=False)  # Return 0-1 range floats for final result
 
     return pixels, [fg, bg]
 
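A note on the quant flag referenced by the new comments: pytorch2numpy is defined elsewhere in app.py and is not part of this diff. Judging from the two call sites (uint8 output for the intermediate pass, 0-1 floats for the final result), an IC-Light-style helper along the following lines is presumably what the flag toggles; the exact scaling below is a sketch under that assumption, not code from this commit.

import numpy as np
import torch

@torch.inference_mode()
def pytorch2numpy(imgs, quant=True):
    # Assumed convention: each tensor is (C, H, W) with values in [-1, 1].
    results = []
    for x in imgs:
        y = x.movedim(0, -1)  # (C, H, W) -> (H, W, C)
        if quant:
            # Intermediate pass: quantize to uint8 in [0, 255]
            y = (y * 127.5 + 127.5).detach().float().cpu().numpy()
            y = y.clip(0, 255).astype(np.uint8)
        else:
            # Final pass: keep float32 in [0, 1], matching the quant=False comment
            y = (y * 0.5 + 0.5).detach().float().cpu().numpy()
            y = y.clip(0, 1).astype(np.float32)
        results.append(y)
    return results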
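The target_width/target_height expressions snap the highres canvas to a multiple of 64 pixels, which keeps the latent grid (pixels / 8) divisible through the UNet's further downsampling. A quick worked example with illustrative input sizes:

# 500x600 at 1.5x does not land on the 64-pixel grid, so both sides snap to it.
image_width, image_height, highres_scale = 500, 600, 1.5
target_width = int(round(image_width * highres_scale / 64.0) * 64)    # 750 / 64 = 11.72 -> 12 * 64 = 768
target_height = int(round(image_height * highres_scale / 64.0) * 64)  # 900 / 64 = 14.06 -> 14 * 64 = 896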
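The conditioning block VAE-encodes the re-cropped fg and bg images and stacks the two latents channel-wise; the torch.cat([c[None, ...] for c in concat_conds], dim=1) line is easy to misread as a batch concat. A minimal shape walk-through, with 1024x1024 inputs assumed for illustration:

import torch

# Two conditioning images encoded by the VAE -> batch of two 4-channel latents.
concat_conds = torch.randn(2, 4, 128, 128)  # (fg, bg), each 4 x 128 x 128
stacked = torch.cat([c[None, ...] for c in concat_conds], dim=1)
print(stacked.shape)  # torch.Size([1, 8, 128, 128]): fg and bg stacked along channels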
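Finally, num_inference_steps=int(round(steps / highres_denoise)) compensates for how diffusers image-to-image pipelines consume strength: the scheduler skips the first (1 - strength) fraction of the timesteps, so only about strength * num_inference_steps denoising steps actually execute. Dividing by highres_denoise keeps the executed count near the user's steps setting (illustrative values below):

steps, highres_denoise = 25, 0.5
scheduled = int(round(steps / highres_denoise))  # 50 steps handed to the scheduler
executed = int(scheduled * highres_denoise)      # ~25 steps actually run in the highres pass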