import os
import sys

import cv2
import numpy as np
import torch
import torch.nn.functional as F

sys.path.insert(0, '../')
sys.dont_write_bytecode = True

from .PGNet import PGNet


class Normalize(object):
    """Callable that standardizes an image as ``(image - mean) / std``."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        """Return ``image`` shifted by ``mean`` and scaled by ``std``."""
        image = (image - self.mean) / self.std
        return image


class Config(object):
    """Loose option bag: keyword options plus fixed normalization statistics.

    Options passed as keyword arguments are readable as attributes. Reading
    an option that was never supplied yields ``None`` instead of raising.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        # Per-channel statistics with shape (1, 1, 3) so they broadcast over
        # an (H, W, 3) image. NOTE(review): channel order (RGB vs BGR) is not
        # visible here -- confirm against the caller.
        self.mean = np.array([[[124.55, 118.90, 102.94]]])
        self.std = np.array([[[56.77, 55.97, 57.50]]])
        print('\nParameters...')
        for k, v in self.kwargs.items():
            print('%-10s: %s' % (k, v))

    def __getattr__(self, name):
        """Return the option called ``name``, or ``None`` if it was not given.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: for dunder names, so that protocol probes made by
                ``copy``/``pickle`` and similar see a normally missing
                attribute rather than ``None``.
        """
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        # Go through __dict__ directly: touching ``self.kwargs`` before it is
        # set would re-enter __getattr__ and recurse without bound.
        return self.__dict__.get('kwargs', {}).get(name)


class IVModel:
    """Inference wrapper around ``PGNet``.

    Builds the network in evaluation mode on ``device`` and runs one warm-up
    forward pass at construction time.
    """

    def __init__(self, device=torch.device('cuda:0')):
        super(IVModel, self).__init__()
        self.device = device

        checkpoint_path = 'sod/weights/PGNet_DUT+HR-model-31.pth'
        self.cfg = Config(snapshot=checkpoint_path, mode='test')
        if not os.path.exists(checkpoint_path):
            # Best-effort warning only; construction continues.
            print('未找到模型文件!')

        # presumably PGNet reads cfg.snapshot to load weights -- verify in PGNet
        self.net = PGNet(self.cfg)
        self.net.train(False)
        self.net.to(device)

        self.normalize = Normalize(mean=self.cfg.mean, std=self.cfg.std)
        self.__first_forward__()

    def __first_forward__(self, input_size=(512, 512, 3)):
        """Run one forward pass on random data of shape ``input_size``."""
        # Call forward() once up front to strictly bound the peak GPU memory.
        print('initialize Sod Model...')
        _ = self.forward(np.random.rand(*input_size) * 255, None)
        print('initialize Complete!')

    @staticmethod
    def _scaled_short_side(short, long, max_size):
        """Scale ``short`` by ``max_size / long``, floored to a multiple of 8.

        The result is clamped to at least 8 so that extreme aspect ratios
        never produce a zero-sized dimension.
        """
        return max(8, int(max_size * short / long) // 8 * 8)

    def __resize_tensor__(self, image, max_size=512):
        """Shrink a 4-D tensor so that its longer spatial side is ``max_size``.

        Tensors whose spatial sides are both ``<= max_size`` are returned
        unchanged. Otherwise the longer side becomes ``max_size`` and the
        other side is scaled proportionally and rounded down to a multiple
        of 8 (minimum 8), using area interpolation.
        """
        h, w = image.size()[2:]
        if max(h, w) > max_size:
            if h < w:
                h, w = self._scaled_short_side(h, w, max_size), max_size
            else:
                h, w = max_size, self._scaled_short_side(w, h, max_size)
            image = F.interpolate(image, (h, w), mode='area')
        return image

    def input_preprocess_tensor(self, img):
        """Normalize an ``(H, W, C)`` array and return a ``(1, C, h, w)`` tensor.

        The tensor is resized on the CPU first and only then moved to
        ``self.device``.
        """
        img = self.normalize(img)
        img_t = torch.from_numpy(img.astype(np.float32))
        img_t = img_t.permute(2, 0, 1).unsqueeze(0)
        # Resize before the device transfer to limit peak GPU memory.
        img_t = self.__resize_tensor__(img_t).to(self.device)
        return img_t

    def forward(self, img, json_data):
        """Run the network on ``img`` and return input and prediction side by side.

        Args:
            img: ``(H, W, 3)`` numpy image; values are normalized with
                ``cfg.mean`` / ``cfg.std``.
            json_data: unused by this method.

        Returns:
            A float32 numpy array of shape ``(h, 2 * w, 3)`` where ``(h, w)``
            is the size after ``__resize_tensor__``. Columns ``[:w]`` hold the
            de-normalized (resized) input; columns ``[w:]`` hold the sigmoid
            of the network output scaled by 255 and repeated over the three
            channels.
        """
        img_t = self.input_preprocess_tensor(img)
        h, w = img_t.shape[2], img_t.shape[3]
        # The network receives the target size as a pair of 1-element tensors.
        shape = [torch.as_tensor([h]), torch.as_tensor([w])]

        # The network itself always runs on a fixed 512x512 input.
        net_input = F.interpolate(img_t, (512, 512), mode='area')
        with torch.no_grad():
            res = self.net(net_input, shape=shape)

        # Only the first network output is used; bring it back to (h, w).
        # align_corners=False is the bilinear default, spelled out explicitly.
        res = F.interpolate(res[0], size=(h, w), mode='bilinear',
                            align_corners=False)
        res = torch.sigmoid(res)

        # Place input and prediction next to each other along the width axis.
        res = torch.cat([img_t, res.expand_as(img_t)], dim=3)
        res = res[0].permute(1, 2, 0).cpu().numpy()
        # Left half: undo the normalization. Right half: [0, 1] -> [0, 255].
        res[:, :w, :] = res[:, :w, :] * self.cfg.std + self.cfg.mean
        res[:, w:, :] = res[:, w:, :] * 255
        return res