Spaces:
Runtime error
Runtime error
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT License. | |
| import os | |
| from collections import OrderedDict | |
| from torch.autograd import Variable | |
| from options.test_options import TestOptions | |
| from models.models import create_model | |
| from models.mapping_model import Pix2PixHDModel_Mapping | |
| import util.util as util | |
| from PIL import Image | |
| import torch | |
| import torchvision.utils as vutils | |
| import torchvision.transforms as transforms | |
| import numpy as np | |
| import cv2 | |
| def data_transforms(img, method=Image.BILINEAR, scale=False): | |
| ow, oh = img.size | |
| pw, ph = ow, oh | |
| if scale == True: | |
| if ow < oh: | |
| ow = 256 | |
| oh = ph / pw * 256 | |
| else: | |
| oh = 256 | |
| ow = pw / ph * 256 | |
| h = int(round(oh / 4) * 4) | |
| w = int(round(ow / 4) * 4) | |
| if (h == ph) and (w == pw): | |
| return img | |
| return img.resize((w, h), method) | |
| def data_transforms_rgb_old(img): | |
| w, h = img.size | |
| A = img | |
| if w < 256 or h < 256: | |
| A = transforms.Scale(256, Image.BILINEAR)(img) | |
| return transforms.CenterCrop(256)(A) | |
| def irregular_hole_synthesize(img, mask): | |
| img_np = np.array(img).astype("uint8") | |
| mask_np = np.array(mask).astype("uint8") | |
| mask_np = mask_np / 255 | |
| img_new = img_np * (1 - mask_np) + mask_np * 255 | |
| hole_img = Image.fromarray(img_new.astype("uint8")).convert("RGB") | |
| return hole_img | |
| def parameter_set(opt): | |
| ## Default parameters | |
| opt.serial_batches = True # no shuffle | |
| opt.no_flip = True # no flip | |
| opt.label_nc = 0 | |
| opt.n_downsample_global = 3 | |
| opt.mc = 64 | |
| opt.k_size = 4 | |
| opt.start_r = 1 | |
| opt.mapping_n_block = 6 | |
| opt.map_mc = 512 | |
| opt.no_instance = True | |
| opt.checkpoints_dir = "./checkpoints/restoration" | |
| ## | |
| if opt.Quality_restore: | |
| opt.name = "mapping_quality" | |
| opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality") | |
| opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_quality") | |
| if opt.Scratch_and_Quality_restore: | |
| opt.NL_res = True | |
| opt.use_SN = True | |
| opt.correlation_renormalize = True | |
| opt.NL_use_mask = True | |
| opt.NL_fusion_method = "combine" | |
| opt.non_local = "Setting_42" | |
| opt.name = "mapping_scratch" | |
| opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality") | |
| opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_scratch") | |
| if opt.HR: | |
| opt.mapping_exp = 1 | |
| opt.inference_optimize = True | |
| opt.mask_dilation = 3 | |
| opt.name = "mapping_Patch_Attention" | |
| if __name__ == "__main__": | |
| opt = TestOptions().parse(save=False) | |
| parameter_set(opt) | |
| model = Pix2PixHDModel_Mapping() | |
| model.initialize(opt) | |
| model.eval() | |
| if not os.path.exists(opt.outputs_dir + "/" + "input_image"): | |
| os.makedirs(opt.outputs_dir + "/" + "input_image") | |
| if not os.path.exists(opt.outputs_dir + "/" + "restored_image"): | |
| os.makedirs(opt.outputs_dir + "/" + "restored_image") | |
| if not os.path.exists(opt.outputs_dir + "/" + "origin"): | |
| os.makedirs(opt.outputs_dir + "/" + "origin") | |
| dataset_size = 0 | |
| input_loader = os.listdir(opt.test_input) | |
| dataset_size = len(input_loader) | |
| input_loader.sort() | |
| if opt.test_mask != "": | |
| mask_loader = os.listdir(opt.test_mask) | |
| dataset_size = len(os.listdir(opt.test_mask)) | |
| mask_loader.sort() | |
| img_transform = transforms.Compose( | |
| [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] | |
| ) | |
| mask_transform = transforms.ToTensor() | |
| for i in range(dataset_size): | |
| input_name = input_loader[i] | |
| input_file = os.path.join(opt.test_input, input_name) | |
| if not os.path.isfile(input_file): | |
| print("Skipping non-file %s" % input_name) | |
| continue | |
| input = Image.open(input_file).convert("RGB") | |
| print("Now you are processing %s" % (input_name)) | |
| if opt.NL_use_mask: | |
| mask_name = mask_loader[i] | |
| mask = Image.open(os.path.join(opt.test_mask, mask_name)).convert("RGB") | |
| if opt.mask_dilation != 0: | |
| kernel = np.ones((3,3),np.uint8) | |
| mask = np.array(mask) | |
| mask = cv2.dilate(mask,kernel,iterations = opt.mask_dilation) | |
| mask = Image.fromarray(mask.astype('uint8')) | |
| origin = input | |
| input = irregular_hole_synthesize(input, mask) | |
| mask = mask_transform(mask) | |
| mask = mask[:1, :, :] ## Convert to single channel | |
| mask = mask.unsqueeze(0) | |
| input = img_transform(input) | |
| input = input.unsqueeze(0) | |
| else: | |
| if opt.test_mode == "Scale": | |
| input = data_transforms(input, scale=True) | |
| if opt.test_mode == "Full": | |
| input = data_transforms(input, scale=False) | |
| if opt.test_mode == "Crop": | |
| input = data_transforms_rgb_old(input) | |
| origin = input | |
| input = img_transform(input) | |
| input = input.unsqueeze(0) | |
| mask = torch.zeros_like(input) | |
| ### Necessary input | |
| try: | |
| with torch.no_grad(): | |
| generated = model.inference(input, mask) | |
| except Exception as ex: | |
| print("Skip %s due to an error:\n%s" % (input_name, str(ex))) | |
| continue | |
| if input_name.endswith(".jpg"): | |
| input_name = input_name[:-4] + ".png" | |
| image_grid = vutils.save_image( | |
| (input + 1.0) / 2.0, | |
| opt.outputs_dir + "/input_image/" + input_name, | |
| nrow=1, | |
| padding=0, | |
| normalize=True, | |
| ) | |
| image_grid = vutils.save_image( | |
| (generated.data.cpu() + 1.0) / 2.0, | |
| opt.outputs_dir + "/restored_image/" + input_name, | |
| nrow=1, | |
| padding=0, | |
| normalize=True, | |
| ) | |
| origin.save(opt.outputs_dir + "/origin/" + input_name) |