GreenGoat committed on
Commit
5e6862e
·
verified ·
1 Parent(s): dc6cc0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -36
app.py CHANGED
@@ -360,43 +360,43 @@ def process(input_fg, input_bg, prompt, image_width, image_height, num_samples,
360
  ).images.to(vae.dtype) / vae.config.scaling_factor
361
 
362
  pixels = vae.decode(latents).sample
363
- pixels = pytorch2numpy(pixels)
364
 
365
- if highres_scale > 1.0:
366
- pixels = [resize_without_crop(
367
- image=p,
368
- target_width=int(round(image_width * highres_scale / 64.0) * 64),
369
- target_height=int(round(image_height * highres_scale / 64.0) * 64))
370
- for p in pixels]
371
-
372
- pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
373
- latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
374
- latents = latents.to(device=unet.device, dtype=unet.dtype)
375
-
376
- image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
377
- fg = resize_and_center_crop(input_fg, image_width, image_height)
378
- bg = resize_and_center_crop(input_bg, image_width, image_height)
379
- concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
380
- concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
381
- concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
382
-
383
- latents = i2i_pipe(
384
- image=latents,
385
- strength=highres_denoise,
386
- prompt_embeds=conds,
387
- negative_prompt_embeds=unconds,
388
- width=image_width,
389
- height=image_height,
390
- num_inference_steps=int(round(steps / highres_denoise)),
391
- num_images_per_prompt=num_samples,
392
- generator=rng,
393
- output_type='latent',
394
- guidance_scale=cfg,
395
- cross_attention_kwargs={'concat_conds': concat_conds},
396
- ).images.to(vae.dtype) / vae.config.scaling_factor
397
-
398
- pixels = vae.decode(latents).sample
399
- pixels = pytorch2numpy(pixels, quant=False)
400
 
401
  return pixels, [fg, bg]
402
 
 
360
  ).images.to(vae.dtype) / vae.config.scaling_factor
361
 
362
  pixels = vae.decode(latents).sample
363
+ pixels = pytorch2numpy(pixels) # Use default quant=True for first pass
364
 
365
+ # Always perform highres processing like the original code
366
+ pixels = [resize_without_crop(
367
+ image=p,
368
+ target_width=int(round(image_width * highres_scale / 64.0) * 64),
369
+ target_height=int(round(image_height * highres_scale / 64.0) * 64))
370
+ for p in pixels]
371
+
372
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
373
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
374
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
375
+
376
+ image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
377
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
378
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
379
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
380
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
381
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
382
+
383
+ latents = i2i_pipe(
384
+ image=latents,
385
+ strength=highres_denoise,
386
+ prompt_embeds=conds,
387
+ negative_prompt_embeds=unconds,
388
+ width=image_width,
389
+ height=image_height,
390
+ num_inference_steps=int(round(steps / highres_denoise)),
391
+ num_images_per_prompt=num_samples,
392
+ generator=rng,
393
+ output_type='latent',
394
+ guidance_scale=cfg,
395
+ cross_attention_kwargs={'concat_conds': concat_conds},
396
+ ).images.to(vae.dtype) / vae.config.scaling_factor
397
+
398
+ pixels = vae.decode(latents).sample
399
+ pixels = pytorch2numpy(pixels, quant=False) # Return 0-1 range floats for final result
400
 
401
  return pixels, [fg, bg]
402