File size: 3,756 Bytes
c7813b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af4f502
c7813b3
 
 
 
 
af4f502
c7813b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af4f502
c7813b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import sys

# BLAS / OpenMP runtimes read their thread-count variables when the native
# library is first loaded, so these must be set BEFORE numpy / torch / cv2 are
# imported. Setting them after those imports (as this file previously did) may
# be too late to take effect.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'

# Make the `saicinpainting` package under ./inpaint importable.
# NOTE(review): this path is relative to the current working directory, not
# to this file — the process must be started from the project root.
sys.path.append('inpaint')

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from omegaconf import OmegaConf
from saicinpainting.training.trainers import load_checkpoint


class IVModel():
    """Inference wrapper around a LaMa inpainting checkpoint.

    ``forward`` takes a side-by-side image: the left half is the picture to
    inpaint and channel 0 of the right half is the mask (non-zero = fill).
    The picture's channels are reversed before the network and reversed back
    on output, so the input is presumably BGR (OpenCV order) — confirm with
    callers.
    """

    def __init__(self, device=torch.device('cuda:0'),
                 conf_path='inpaint/configs/prediction/default.yaml',
                 model_path='inpaint/weights/big-lama'):
        """Load the checkpoint, move it to ``device`` and run a warm-up pass.

        Args:
            device: torch device the model and inputs are placed on.
            conf_path: prediction config YAML (provides ``model.checkpoint``).
            model_path: directory holding ``config.yaml`` and ``models/``.

        Raises:
            FileNotFoundError: if ``conf_path`` does not exist.
        """
        super(IVModel, self).__init__()
        self.device = device
        # Check before loading: previously the check ran after
        # OmegaConf.load (which had already failed) and only printed.
        if not os.path.exists(conf_path):
            raise FileNotFoundError(f'Prediction config not found: {conf_path}')
        predict_config = OmegaConf.load(conf_path)
        predict_config.model.path = model_path

        train_config_path = os.path.join(model_path, 'config.yaml')
        with open(train_config_path, 'r', encoding='utf-8') as f:
            train_config = OmegaConf.create(yaml.safe_load(f))

        train_config.training_model.predict_only = True
        train_config.visualizer.kind = 'noop'

        checkpoint_path = os.path.join(model_path, 'models', predict_config.model.checkpoint)
        # Load on CPU first, then move to the target device.
        self.model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
        self.model.freeze()
        self.model.to(device)

        self.__first_forward__()

    def __first_forward__(self, input_size=(512, 1024, 3)):
        """Warm up the model with one forward() call on random data.

        The default (512, 1024, 3) input splits into a 512x512 image, which is
        the largest size the resize helpers allow, presumably so that peak GPU
        memory is claimed at construction time.
        """
        print('initialize Inpaint Model...')
        _ = self.forward(np.random.rand(*input_size) * 255, None)
        print('initialize Complete!')

    def __resize_tensor__(self, image, max_size=512, scale_factor=8):
        """Bicubically resize an (N, C, H, W) tensor for the network.

        The longer side is capped at ``max_size`` (keeping aspect ratio).
        Both sides are then floored to a multiple of ``scale_factor``, but
        never below ``scale_factor``, so tiny or very skewed inputs no longer
        yield a zero-sized dimension.
        """
        h, w = image.size()[2:]
        if max(h, w) > max_size:
            if h < w:
                h, w = int(max_size * h / w), max_size
            else:
                h, w = max_size, int(max_size * w / h)
        h = max(scale_factor, h // scale_factor * scale_factor)
        w = max(scale_factor, w // scale_factor * scale_factor)
        return F.interpolate(image, (h, w), mode='bicubic')

    def input_preprocess_tensor(self, img, net_max_size=512, out_max_size=512):
        """Convert an (H, W, C) array in [0, 255] to resized (1, C, h, w) tensors.

        Returns:
            (tensor_for_net, tensor_for_out), both scaled to [0, 1] and moved
            to ``self.device``. Both sizes are capped to limit GPU memory. When
            the two caps are equal, the same resized tensor is returned twice
            instead of being resized again.
        """
        img_t = torch.from_numpy(img.astype(np.float32))
        img_t = img_t.permute(2, 0, 1).unsqueeze(0) / 255.
        img_t_for_net = self.__resize_tensor__(img_t, max_size=net_max_size).to(self.device)
        if out_max_size == net_max_size:
            return img_t_for_net, img_t_for_net
        img_t_for_out = self.__resize_tensor__(img_t, max_size=out_max_size).to(self.device)
        return img_t_for_net, img_t_for_out

    @staticmethod
    def _to_bgr_uint8(res_t):
        """Convert a (1, 3, H, W) tensor with nominal range [0, 1] to (H, W, 3) uint8.

        Values are clipped to [0, 255] and rounded to nearest, not truncated.
        Truncation turned float round-trip values such as 9.9999995 into 9.
        Channels are reversed, undoing the reversal applied in ``forward``.
        """
        res_t = torch.clip(res_t * 255, min=0, max=255).round()
        res = res_t.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
        return res[:, :, (2, 1, 0)].astype(np.uint8)

    def forward(self, img, json_data):
        """Inpaint the left half of ``img`` using the mask in its right half.

        Args:
            img: (H, W, C) array; the left half is the image and channel 0 of
                the right half is the mask (non-zero = region to fill).
            json_data: unused; kept for interface compatibility.

        Returns:
            (h, w, 3) uint8 array at the resized (capped) resolution, with
            channels in the same order as the input image.
        """
        half = img.shape[1] // 2
        # Crop both halves to the same width so odd-width inputs still give
        # an image and a mask of identical shape.
        mask = img[:, half:2 * half, 0]
        kernel = np.ones((4, 4), np.uint8)
        # Grow the mask so the fill also covers the mask's border pixels.
        mask = cv2.dilate(mask, kernel, iterations=5)[:, :, np.newaxis]
        img = img[:, :half, ::-1]
        img_t_for_net, img_t_for_out = self.input_preprocess_tensor(img)
        mask_t_for_net, mask_t_for_out = self.input_preprocess_tensor(mask)
        h, w = img_t_for_out.shape[2:]
        # Re-binarize after the bicubic resize of the mask.
        mask_t_for_out = (mask_t_for_out > 0).int()
        mask_t_for_net = (mask_t_for_net > 0).int()
        with torch.no_grad():
            res_t = self.model(dict(image=img_t_for_net, mask=mask_t_for_net))['inpainted']
            res_t = F.interpolate(res_t, (h, w), mode='bicubic')
            # Keep original pixels outside the mask; use the model output inside.
            res_t = img_t_for_out * (1 - mask_t_for_out) + res_t * mask_t_for_out
        return self._to_bgr_uint8(res_t)