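# Sod_Inpaint/inpaint/infer_model.py
# Last commit c7813b3 by wenpeng: "update .gitignore" (file size 3.76 kB).
# This header was recovered from a repository web page; the avatar and the
# "raw" / "history" / "blame" entries were page navigation, not source code.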
import os
import sys

# Cap the thread pools of the native numeric backends (OpenMP, OpenBLAS, MKL,
# Accelerate/vecLib, numexpr) to one thread each. These runtimes typically read
# the variables when they are first loaded, so the assignments must happen
# BEFORE numpy / torch / cv2 are imported. Setting them after those imports
# (as the previous ordering did) leaves already-initialised libraries
# unaffected.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'

# Make the vendored `saicinpainting` package importable.
# NOTE(review): this path is relative to the current working directory, so the
# process is presumably started from the repository root -- confirm.
sys.path.append('inpaint')

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from omegaconf import OmegaConf
from saicinpainting.training.trainers import load_checkpoint
class IVModel():
    """Inpainting wrapper around a ``saicinpainting`` (LaMa-style) checkpoint.

    ``forward`` takes one side-by-side frame: the left half is the image to
    inpaint and the right half carries the mask (channel 0). Any pixel that is
    still positive after dilation and resizing is treated as a hole. The
    inpainted left half is returned as a ``uint8`` array.
    """

    def __init__(self, device=torch.device('cuda:0'),
                 conf_path='inpaint/configs/prediction/default.yaml',
                 model_path='inpaint/weights/big-lama'):
        """Load the checkpoint, move it to ``device`` and run a warm-up pass.

        Args:
            device: torch device the model and working tensors are placed on.
            conf_path: prediction config (YAML). Only ``model.checkpoint`` is
                used from it; ``model.path`` is overridden by ``model_path``.
                Relative paths resolve against the current working directory.
            model_path: directory containing ``config.yaml`` and
                ``models/<checkpoint>``.

        Raises:
            FileNotFoundError: if ``conf_path`` or ``<model_path>/config.yaml``
                does not exist.
        """
        super().__init__()
        self.device = device
        # Check BEFORE loading. Previously the check ran after OmegaConf.load
        # had already succeeded, so it could never trigger, and it only printed.
        if not os.path.exists(conf_path):
            raise FileNotFoundError(f'未找到配置文件! {conf_path}')
        predict_config = OmegaConf.load(conf_path)
        predict_config.model.path = model_path

        train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
        with open(train_config_path, 'r', encoding='utf-8') as f:
            train_config = OmegaConf.create(yaml.safe_load(f))
        train_config.training_model.predict_only = True
        train_config.visualizer.kind = 'noop'

        checkpoint_path = os.path.join(
            predict_config.model.path, 'models', predict_config.model.checkpoint)
        # Load onto the CPU first, then move to the requested device.
        self.model = load_checkpoint(train_config, checkpoint_path,
                                     strict=False, map_location='cpu')
        self.model.freeze()
        self.model.to(device)
        self.__first_forward__()

    def __first_forward__(self, input_size=(2048, 4096, 3)):
        """Run one throw-away ``forward`` on random data as a warm-up.

        Per the original author, calling ``forward()`` here is meant to pin
        down peak GPU memory up front. ``input_size`` is the side-by-side
        (H, W, C) frame, so the default is a 2048x2048 image plus its mask.
        """
        print('initialize Inpaint Model...')
        _ = self.forward(np.random.rand(*input_size) * 255, None)
        print('initialize Complete!')

    def __resize_tensor__(self, image, max_size=1024, scale_factor=8):
        """Bicubic-resize an NCHW tensor to network-friendly dimensions.

        If the longer side exceeds ``max_size``, the tensor is scaled down so
        that side equals ``max_size``, keeping the aspect ratio. Both sides are
        then floored to a multiple of ``scale_factor``, with a minimum of
        ``scale_factor`` so tiny inputs never collapse to a zero-sized tensor.
        NOTE(review): the network presumably requires sizes divisible by 8;
        confirm against the generator architecture.
        """
        h, w = image.size()[2:]
        if max(h, w) > max_size:
            if h < w:
                h, w = int(max_size * h / w), max_size
            else:
                h, w = max_size, int(max_size * w / h)
        h = max(scale_factor, h // scale_factor * scale_factor)
        w = max(scale_factor, w // scale_factor * scale_factor)
        return F.interpolate(image, (h, w), mode='bicubic')

    def input_preprocess_tensor(self, img):
        """Turn an HWC numpy image into two [0, 1]-scaled NCHW tensors on ``self.device``.

        Returns ``(for_net, for_out)``:
            ``for_net``: long side capped at 1024 px; this is fed to the model.
            ``for_out``: long side capped at 2048 px; used when compositing
            the output.
        """
        # astype() copies, so channel-reversed (negative-stride) views are safe here.
        img_t = torch.from_numpy(img.astype(np.float32))
        img_t = img_t.permute(2, 0, 1).unsqueeze(0) / 255.
        # Resize before .to(device) so the full-resolution tensor never
        # reaches the GPU; this bounds peak GPU memory.
        img_t_for_net = self.__resize_tensor__(img_t).to(self.device)
        img_t_for_out = self.__resize_tensor__(img_t, max_size=2048).to(self.device)
        return img_t_for_net, img_t_for_out

    def forward(self, img, json_data):
        """Inpaint the left half of a side-by-side ``[image | mask]`` frame.

        Args:
            img: H x W x 3 array. Columns ``[0, W//2)`` hold the image and
                columns ``[W//2, 2*(W//2))`` hold the mask (only channel 0 is
                used). The image half is channel-reversed before entering the
                network and reversed back on output. Presumably the input is
                BGR and the model expects RGB; confirm with the caller.
            json_data: unused; kept for interface compatibility.

        Returns:
            ``uint8`` array of shape (h, w, 3) at the output resolution chosen
            by ``input_preprocess_tensor``, in the input's channel order.
        """
        half = img.shape[1] // 2
        # Take exactly `half` columns for the mask as well. With an odd frame
        # width, `img[:, w//2:]` would be one column wider than the image
        # half. It could then resize to a different (H, W) and break the
        # image/mask arithmetic below.
        mask = img[:, half:2 * half, 0]
        # Grow the mask (5 passes of a 4x4 dilation), presumably so that
        # hole borders are fully covered.
        kernel = np.ones((4, 4), np.uint8)
        mask = cv2.dilate(mask, kernel, iterations=5)[:, :, np.newaxis]
        image = img[:, :half, ::-1]

        img_t_for_net, img_t_for_out = self.input_preprocess_tensor(image)
        mask_t_for_net, mask_t_for_out = self.input_preprocess_tensor(mask)
        h, w = img_t_for_out.shape[2:]
        # Binarise after the bicubic resize: any positive value is a hole.
        # Use float (same 0/1 values as before) so the mask shares the image dtype.
        mask_t_for_net = (mask_t_for_net > 0).float()
        mask_t_for_out = (mask_t_for_out > 0).float()

        with torch.no_grad():
            res_t = self.model(dict(image=img_t_for_net, mask=mask_t_for_net))['inpainted']
        # Upsample the network result to the output resolution. Paste it only
        # inside the hole; known pixels come from the higher-resolution input.
        res_t = F.interpolate(res_t, (h, w), mode='bicubic')
        res_t = img_t_for_out * (1 - mask_t_for_out) + res_t * mask_t_for_out
        res_t = torch.clip(res_t * 255, min=0, max=255)
        res = res_t.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
        # Undo the channel reversal applied to the input image.
        return res[:, :, (2, 1, 0)].astype(np.uint8)