File size: 5,868 Bytes
bfd34e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
736e88e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1df97f6
 
 
 
 
bfd34e9
da1e12f
 
bfd34e9
 
 
 
736e88e
 
 
bfd34e9
 
 
 
 
 
 
736e88e
 
 
 
bfd34e9
 
 
 
736e88e
 
bfd34e9
736e88e
 
bfd34e9
 
 
736e88e
da1e12f
736e88e
bfd34e9
 
 
736e88e
 
bfd34e9
 
 
da1e12f
 
bfd34e9
 
 
da1e12f
bfd34e9
 
 
 
 
 
 
 
736e88e
 
 
 
 
bfd34e9
 
736e88e
 
 
 
bfd34e9
1df97f6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
from functools import partial
from glob import glob
from pathlib import Path as PythonPath

import cv2
import torchvision.transforms.functional as TvF
import torch
import torch.nn as nn
import numpy as np
from inspect import isfunction
from PIL import Image

from lib import smplfusion
from lib.smplfusion import share, router, attentionpatch, transformerpatch
from lib.utils.iimage import IImage
from lib.utils import poisson_blend
from lib.models.sd2_sr import predict_eps_from_z_and_v, predict_start_from_z_and_v


def refine_mask(hr_image, hr_mask, lr_image, sam_predictor):
    """Restrict ``hr_mask`` to the regions SAM segments inside the mask's bounding box.

    The mask is resized with ``resize(512)``, its bounding rectangle (grown by
    one pixel towards the origin and one pixel past the far edge) is used as a
    box prompt for ``sam_predictor`` on both ``hr_image`` and ``lr_image``
    (each resized with ``resize(512)``).  The union of all mask proposals is
    dilated (13x13 kernel for ``hr_image``, 3x3 for ``lr_image``), the two
    results are OR-ed together, and the low-resolution mask is multiplied by
    that union.

    Args:
        hr_image: image object exposing ``resize(...)`` and ``.data``
            (presumably an ``IImage`` -- confirm against callers).
        hr_mask: mask object with the same interface; only channel 0 of
            ``.data[0]`` is used to compute the bounding rectangle.
        lr_image: second image segmented with the same box prompt.
        sam_predictor: object providing ``set_image`` and ``predict`` in the
            style of Segment Anything's predictor.

    Returns:
        An ``IImage`` holding the refined mask, resized with
        ``resize(2048, resample=Image.BICUBIC)``.
    """
    lr_mask = hr_mask.resize(512)

    # Bounding rectangle of the first mask channel, as (x, y, width, height).
    left, top, width, height = cv2.boundingRect(lr_mask.data[0][:, :, 0])
    left = max(left - 1, 0)
    top = max(top - 1, 0)
    input_box = np.array([left, top, left + width + 1, top + height + 1])

    def segment_in_box(image, kernel_size):
        # Union of every SAM proposal for the box prompt, then dilated.
        sam_predictor.set_image(image.resize(512).data[0])
        proposals, _, _ = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],
            multimask_output=True,
        )
        union = (np.sum(proposals, axis=0) > 0).astype(np.uint8)
        return cv2.dilate(union, np.ones((kernel_size, kernel_size)))

    original_object_mask = segment_in_box(hr_image, 13)
    inpainted_object_mask = segment_in_box(lr_image, 3)

    # A pixel survives if either segmentation covers it.
    keep = np.logical_or(original_object_mask, inpainted_object_mask).astype(np.uint8)
    refined = lr_mask.data[0] * keep[:, :, np.newaxis]
    return IImage(refined).resize(2048, resample = Image.BICUBIC)


def run(
    ddim,
    sam_predictor,
    lr_image,
    hr_image,
    hr_mask,
    prompt = 'high resolution professional photo',
    noise_level=20,
    blend_output = True,
    blend_trick = True,
    dt = 50,
    seed = 1,
    guidance_scale = 7.5,
    negative_prompt = '',
    use_sam_mask = False
):
    """Run the low-resolution-conditioned diffusion sampler and return the result image.

    Side effects: seeds torch's global RNG with ``seed`` and overwrites
    ``router.attention_forward`` / ``router.basic_transformer_forward``.

    Args:
        ddim: model bundle; this function reads ``vae``, ``encoder``, ``unet``,
            ``low_scale_model``, ``schedule`` and ``config.scale_factor``.
        sam_predictor: forwarded to ``refine_mask``; only used when
            ``use_sam_mask`` is true.
        lr_image: conditioning image; wrapped in ``IImage`` and passed through
            ``padx(64, padding_mode='reflect')``.
        hr_image: image whose ``.info`` mapping is copied onto the result
            (presumably a ``PIL.Image`` -- confirm against callers).  It is
            wrapped in ``IImage`` and resized with ``resize(2048)``.
        hr_mask: mask for ``hr_image``; wrapped and resized the same way.
        prompt: text prompt for the conditional branch.
        noise_level: value passed to ``ddim.low_scale_model`` as a long tensor.
        blend_output: if true, the decoded image is combined with the original
            via ``poisson_blend`` using the original (unrefined) mask.
        blend_trick: if true, at every step the predicted ``z0`` is replaced
            by the encoded ``hr_image`` latent where the blurred mask is 0.
        dt: positive step between successive timesteps; the sampler visits
            ``range(999, 0, -dt)``.
        seed: seed for ``torch.manual_seed``.
        guidance_scale: classifier-free guidance weight.
        negative_prompt: text prompt for the unconditional branch.
        use_sam_mask: if true, ``hr_mask`` is first refined by ``refine_mask``.

    Returns:
        A ``PIL.Image`` built from the uint8 result, with ``info`` set to the
        input ``hr_image.info``.

    Raises:
        ValueError: if ``dt`` is not positive.
        AssertionError: if the encoded latent's spatial size differs from the
            padded ``lr_image``.
    """
    if dt <= 0:
        # A non-positive step would yield no sampling iterations (or an opaque
        # range() error), leaving no latent to decode.
        raise ValueError(f"dt must be a positive integer, got {dt!r}")

    hr_image_info = hr_image.info
    lr_image = IImage(lr_image)
    hr_image = IImage(hr_image).resize(2048)
    hr_mask = IImage(hr_mask).resize(2048)

    torch.manual_seed(seed)
    # The VAE's first conv weight decides the dtype/device of everything below.
    reference_weight = ddim.vae.encoder.conv_in.weight
    dtype = reference_weight.dtype
    device = reference_weight.device

    router.attention_forward = attentionpatch.default.forward_xformers
    router.basic_transformer_forward = transformerpatch.default.forward

    # Kept untouched for the final blend, which uses the unpadded image and
    # the unrefined mask.
    hr_image_orig = hr_image
    hr_mask_orig = hr_mask

    if use_sam_mask:
        with torch.no_grad():
            hr_mask = refine_mask(hr_image, hr_mask, lr_image, sam_predictor)

    # Size before padding, used to crop the decoded result back.
    orig_h, orig_w = hr_image.torch().shape[2:4]
    hr_image = hr_image.padx(256, padding_mode='reflect')
    hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)

    lr_image = lr_image.padx(64, padding_mode='reflect').torch()
    lr_mask = hr_mask.resize((lr_image.shape[2:]), resample = Image.BICUBIC)
    lr_mask = lr_mask.alpha().torch(vmin=0).to(device)
    lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)

    # encode hr image
    with torch.no_grad():
        hr_image = hr_image.torch().to(dtype=dtype, device=device)
        hr_z0 = ddim.vae.encode(hr_image).mean * ddim.config.scale_factor

    # The latent is blended with predictions shaped like lr_image, so their
    # spatial sizes have to agree.
    assert hr_z0.shape[2] == lr_image.shape[2], \
        f"latent height {hr_z0.shape[2]} != lr_image height {lr_image.shape[2]}"
    assert hr_z0.shape[3] == lr_image.shape[3], \
        f"latent width {hr_z0.shape[3]} != lr_image width {lr_image.shape[3]}"

    with torch.no_grad():
        context = ddim.encoder.encode([negative_prompt, prompt])

    noise_level = torch.Tensor(1 * [noise_level]).to(device=device).long()
    unet_condition = lr_image.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
    unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)

    # Loop-invariant: the noise level duplicated for the two guidance branches.
    class_labels = torch.cat([noise_level] * 2)

    with torch.autocast('cuda'), torch.no_grad():
        # Noise is drawn on the CPU generator (seeded above) and then moved,
        # so results for a given seed do not depend on the target device.
        zt = torch.randn((1, 4, unet_condition.shape[2], unet_condition.shape[3]))
        zt = zt.to(dtype=dtype, device=device)
        for t in range(999, 0, -dt):
            unet_input = torch.cat([zt, unet_condition], 1)
            # One batched call evaluates both prompts: [negative_prompt, prompt].
            eps_uncond, eps = ddim.unet(
                torch.cat([unet_input, unet_input]).to(dtype=dtype, device=device),
                timesteps = torch.tensor([t, t]).to(device=device),
                context = context,
                y=class_labels
            ).chunk(2)
            ts = torch.full((zt.shape[0],), t, device=device, dtype=torch.long)
            model_output = (eps_uncond + guidance_scale * (eps - eps_uncond))
            eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
            z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
            if blend_trick:
                z0 = z0 * lr_mask + hr_z0 * (1-lr_mask)

            t_next = t - dt
            if t_next <= 0:
                # Last step: only z0 is decoded, so there is no next latent to
                # build.  This also avoids indexing the schedule with a
                # negative value, which would wrap around or go out of range.
                break
            zt = ddim.schedule.sqrt_alphas[t_next] * z0 + ddim.schedule.sqrt_one_minus_alphas[t_next] * eps

    with torch.no_grad():
        hr_result = ddim.vae.decode(z0.to(dtype) / ddim.config.scale_factor)
        # postprocess: map [-1, 1] to uint8 and crop away the padding
        hr_result = (255 * ((hr_result + 1) / 2).clip(0, 1)).to(torch.uint8)
        hr_result = hr_result.cpu().permute(0, 2, 3, 1)[0].numpy()
        hr_result = hr_result[:orig_h, :orig_w, :]

    if blend_output:
        hr_result = poisson_blend(
            orig_img=hr_image_orig.data[0],
            fake_img=hr_result,
            mask=hr_mask_orig.alpha().data[0]
        )
    hr_result = Image.fromarray(hr_result)
    hr_result.info = hr_image_info  # save metadata
    return hr_result