Andranik Sargsyan
add demo code
bfd34e9
raw
history blame
No virus
6.06 kB
import os
from functools import partial
from glob import glob
from pathlib import Path as PythonPath
import cv2
import torchvision.transforms.functional as TvF
import torch
import torch.nn as nn
import numpy as np
from inspect import isfunction
from PIL import Image
from lib import smplfusion
from lib.smplfusion import share, router, attentionpatch, transformerpatch
from lib.utils.iimage import IImage
from lib.utils import poisson_blend
from lib.models.sd2_sr import predict_eps_from_z_and_v, predict_start_from_z_and_v
def refine_mask(hr_image, hr_mask, lr_image, sam_predictor):
    """Restrict ``hr_mask`` to the objects SAM finds inside its bounding box.

    SAM is prompted with the mask's bounding box on both ``hr_image`` and
    ``lr_image`` (each resized with ``resize(512)``). The union of the
    predicted masks from both images gates the resized input mask, and the
    gated mask is resized with ``resize(2048)`` before being returned.

    Args:
        hr_image: Image wrapper exposing ``resize`` and ``data``
            (presumably an ``IImage`` -- confirm against callers).
        hr_mask: Mask wrapper with the same interface; channel 0 of
            ``data[0]`` is used to compute the bounding box.
        lr_image: Second image SAM is run on with the same box prompt.
        sam_predictor: Object providing ``set_image`` and ``predict``
            (presumably a segment-anything ``SamPredictor`` -- confirm).

    Returns:
        An ``IImage`` holding the gated mask, bicubically resized with
        ``resize(2048)``.
    """

    def segment_in_box(image, box, kernel_size):
        # Union of all SAM proposals for the box prompt, then dilated with a
        # square kernel of ones.
        sam_predictor.set_image(image.resize(512).data[0])
        proposals, _, _ = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box[None, :],
            multimask_output=True,
        )
        union = (np.sum(proposals, axis=0) > 0).astype(np.uint8)
        return cv2.dilate(union, np.ones((kernel_size, kernel_size)))

    small_mask = hr_mask.resize(512)

    # Bounding box of the mask, shifted one pixel up/left (clamped at 0); the
    # far corner is derived from the already-shifted origin.
    left, top, box_w, box_h = cv2.boundingRect(small_mask.data[0][:, :, 0])
    left = max(left - 1, 0)
    top = max(top - 1, 0)
    prompt_box = np.array([left, top, left + box_w + 1, top + box_h + 1])

    # The object mask from hr_image gets a larger dilation (13) than the one
    # from lr_image (3).
    object_in_original = segment_in_box(hr_image, prompt_box, 13)
    object_in_inpainted = segment_in_box(lr_image, prompt_box, 3)

    keep = ((object_in_original + object_in_inpainted) > 0).astype(np.uint8)
    gated = small_mask.data[0] * keep[:, :, np.newaxis]
    return IImage(gated).resize(2048, resample=Image.BICUBIC)
def _blend_into_original(hr_image, hr_mask, output_tensor, orig_h, orig_w):
    """Poisson-blend ``output_tensor`` into ``hr_image`` and crop to (orig_h, orig_w).

    ``output_tensor`` is indexed as a batched channels-first tensor: it is
    permuted to channels-last and its first item is taken before blending.
    """
    blended = poisson_blend(
        orig_img=hr_image.data[0][:orig_h, :orig_w, :],
        fake_img=output_tensor.cpu().permute(0, 2, 3, 1)[0].numpy()[:orig_h, :orig_w, :],
        mask=hr_mask.alpha().data[0][:orig_h, :orig_w, :],
    )
    return IImage(blended[:orig_h, :orig_w, :])


def run(ddim, sam_predictor, lr_image, hr_image, hr_mask, prompt = 'high resolution professional photo', noise_level=20,
        blend_output = True, blend_trick = True, no_superres = False,
        dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = False, dtype=torch.bfloat16):
    """Upscale the inpainted region of ``lr_image`` and merge it into ``hr_image``.

    Side effects: seeds the global torch RNG with ``seed`` and overwrites
    ``router.attention_forward`` / ``router.basic_transformer_forward``.
    All tensors are placed on the ``'cuda'`` device.

    Args:
        ddim: Model bundle; this function uses its ``vae``, ``encoder``,
            ``unet``, ``low_scale_model``, ``schedule`` and ``config``.
        sam_predictor: Passed to ``refine_mask`` when ``use_sam_mask`` is set.
        lr_image: Low-resolution image wrapper, used as the UNet conditioning.
        hr_image: High-resolution original image wrapper.
        hr_mask: Mask wrapper for the region to replace in ``hr_image``.
        prompt: Text prompt for the conditional branch.
        noise_level: Noise level forwarded to ``ddim.low_scale_model``.
        blend_output: If true, Poisson-blend the decoded result into
            ``hr_image`` and return the blended image.
        blend_trick: If true, at every step the predicted clean latent is
            mixed with the encoded ``hr_image`` latent using the blurred mask.
        no_superres: If true, skip diffusion entirely and blend a bicubic
            upscale of ``lr_image`` instead.
        dt: Timestep stride; steps are ``range(999, 0, -dt)``. Must be > 0.
        seed: Seed for ``torch.manual_seed``.
        guidance_scale: Classifier-free guidance weight.
        negative_prompt: Text prompt for the unconditional branch.
        use_sam_mask: If true, refine ``hr_mask`` with ``refine_mask`` first.
        dtype: Dtype the latents and model inputs are cast to.

    Returns:
        An ``IImage`` cropped to the unpadded height/width of ``hr_image``.

    Raises:
        ValueError: If ``dt`` is not positive.
        AssertionError: If the encoded ``hr_image`` latent does not have the
            same spatial size as the padded ``lr_image``.
    """
    # A non-positive stride would yield no timesteps (or an invalid range),
    # so fail before touching any model or GPU state.
    if dt <= 0:
        raise ValueError(f"dt must be a positive integer, got {dt!r}")

    torch.manual_seed(seed)

    router.attention_forward = attentionpatch.default.forward_xformers
    router.basic_transformer_forward = transformerpatch.default.forward

    if use_sam_mask:
        with torch.no_grad():
            hr_mask = refine_mask(hr_image, hr_mask, lr_image, sam_predictor)

    # Remember the unpadded size so outputs can be cropped back to it.
    hr_shape = hr_image.torch().shape
    orig_h, orig_w = hr_shape[2], hr_shape[3]

    hr_image = hr_image.padx(256, padding_mode='reflect')
    hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)

    lr_image = lr_image.padx(64, padding_mode='reflect')
    lr_shape = lr_image.torch().shape
    lr_mask = hr_mask.resize((lr_shape[2], lr_shape[3]), resample = Image.BICUBIC).alpha().torch(vmin=0).cuda()
    # Blur the mask so the latent blend below has a soft boundary.
    lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)

    if no_superres:
        padded_shape = hr_image.torch().shape
        output_tensor = lr_image.resize((padded_shape[2], padded_shape[3]), resample = Image.BICUBIC).torch().cuda()
        # Map [-1, 1] to uint8 [0, 255].
        output_tensor = (255*((output_tensor.clip(-1, 1) + 1) / 2)).to(torch.uint8)
        return _blend_into_original(hr_image, hr_mask, output_tensor, orig_h, orig_w)

    # encode hr image
    with torch.no_grad():
        hr_z0 = ddim.vae.encode(hr_image.torch().cuda().to(dtype)).mean * ddim.config.scale_factor

    assert hr_z0.shape[2] == lr_image.torch().shape[2]
    assert hr_z0.shape[3] == lr_image.torch().shape[3]

    unet_condition = lr_image.cuda().torch().to(memory_format=torch.contiguous_format).to(dtype)
    # Initial latent: Gaussian noise with the conditioning image's spatial size.
    zT = torch.randn((1,4,unet_condition.shape[2], unet_condition.shape[3])).cuda().to(dtype)

    with torch.no_grad():
        # Order matters: index 0 is the negative prompt, index 1 the prompt,
        # matching the (eps_uncond, eps) chunk order below.
        context = ddim.encoder.encode([negative_prompt, prompt])

    noise_level = torch.Tensor(1 * [noise_level]).to('cuda').long()
    unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)

    with torch.autocast('cuda'), torch.no_grad():
        zt = zT
        for t in range(999, 0, -dt):
            _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)

            # One batched UNet call for both guidance branches.
            eps_uncond, eps = ddim.unet(
                torch.cat([_zt, _zt]).to(dtype),
                timesteps = torch.tensor([t, t]).cuda(),
                context = context,
                y=torch.cat([noise_level]*2)
            ).chunk(2)

            ts = torch.full((zt.shape[0],), t, device='cuda', dtype=torch.long)
            # Classifier-free guidance on the raw model output.
            model_output = (eps_uncond + guidance_scale * (eps - eps_uncond))
            eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
            z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)

            if blend_trick:
                # Keep the encoded original latent outside the (blurred) mask.
                z0 = z0 * lr_mask + hr_z0 * (1-lr_mask)

            # Re-noise only when another step follows. On the last step the
            # index t - dt is <= 0, which would wrap around (or overflow) the
            # schedule tables, and the resulting latent is never used.
            next_t = t - dt
            if next_t > 0:
                zt = ddim.schedule.sqrt_alphas[next_t] * z0 + ddim.schedule.sqrt_one_minus_alphas[next_t] * eps

    with torch.no_grad():
        output_tensor = ddim.vae.decode(z0.to(dtype) / ddim.config.scale_factor)

    if blend_output:
        # Map [-1, 1] to uint8 [0, 255].
        output_tensor = (255*((output_tensor + 1) / 2).clip(0, 1)).to(torch.uint8)
        return _blend_into_original(hr_image, hr_mask, output_tensor, orig_h, orig_w)

    return IImage(output_tensor[:, :, :orig_h, :orig_w])