Spaces:
Runtime error
Runtime error
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT License. | |
| import os.path | |
| import io | |
| import zipfile | |
| from data.base_dataset import BaseDataset, get_params, get_transform, normalize | |
| from data.image_folder import make_dataset | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| import numpy as np | |
| from data.Load_Bigfile import BigFileMemoryLoader | |
| import random | |
| import cv2 | |
| from io import BytesIO | |
def pil_to_np(img_PIL):
    """Convert a PIL image to a channel-first float32 array scaled to [0, 1].

    ``np.array`` on an image gives ``H x W`` (single band) or ``H x W x C``;
    the result is ``C x H x W``, with a singleton channel axis added for
    2-D input, and values divided by 255.
    """
    pixels = np.array(img_PIL)
    if pixels.ndim == 3:
        chw = pixels.transpose(2, 0, 1)
    else:
        chw = pixels[np.newaxis, ...]
    return chw.astype(np.float32) / 255.
def np_to_pil(img_np):
    """Convert a channel-first float array in [0, 1] back to a PIL image.

    Inverse of ``pil_to_np``: ``img_np`` is ``C x H x W``. A single-channel
    input becomes an ``H x W`` image; otherwise the array is transposed to
    ``H x W x C`` before building the image.

    Values are scaled to [0, 255], rounded to the nearest integer and
    clipped before the uint8 cast. Rounding (instead of the truncation a bare
    ``astype`` performs) avoids a systematic downward bias of up to one grey
    level on every image produced by the noise synthesizers, and keeps a
    ``pil_to_np`` -> ``np_to_pil`` round trip exact even when ``v / 255 * 255``
    lands just below ``v`` in float32.
    """
    ar = np.clip(np.rint(img_np * 255), 0, 255).astype(np.uint8)
    if img_np.shape[0] == 1:
        ar = ar[0]
    else:
        ar = ar.transpose(1, 2, 0)
    return Image.fromarray(ar)
def synthesize_salt_pepper(image, amount, salt_vs_pepper):
    """Add salt-and-pepper noise to a PIL image and return a new PIL image.

    Args:
        image: input PIL image.
        amount: probability that an individual array element (each channel
            of each pixel, independently) is replaced.
        salt_vs_pepper: among replaced elements, probability of setting them
            to white (1.0) rather than black (0.0).
    """
    clean = pil_to_np(image)
    shape = clean.shape
    # Draw order matters under a fixed seed: replacement mask first,
    # then the salt/pepper choice.
    hit = np.random.choice([True, False], size=shape,
                           p=[amount, 1 - amount])
    is_salt = np.random.choice([True, False], size=shape,
                               p=[salt_vs_pepper, 1 - salt_vs_pepper])
    noisy = clean.copy()
    noisy[hit & is_salt] = 1
    noisy[hit & ~is_salt] = 0.
    return np_to_pil(np.clip(noisy, 0, 1).astype(np.float32))
def synthesize_gaussian(image, std_l, std_r):
    """Add zero-mean additive Gaussian noise to a PIL image.

    The noise standard deviation is drawn uniformly between ``std_l`` and
    ``std_r`` (given on the 0-255 scale) and applied to the [0, 1] image;
    the result is clipped to [0, 1] and returned as a new PIL image.
    """
    base = pil_to_np(image)
    sigma = random.uniform(std_l / 255., std_r / 255.)
    noise = np.random.normal(loc=0, scale=sigma, size=base.shape)
    return np_to_pil(np.clip(base + noise, 0, 1).astype(np.float32))
def synthesize_speckle(image, std_l, std_r):
    """Add multiplicative (speckle) Gaussian noise to a PIL image.

    Computes ``x + n * x`` with ``n ~ N(0, sigma)``, where ``sigma`` is drawn
    uniformly between ``std_l / 255`` and ``std_r / 255``. The result is
    clipped to [0, 1] and returned as a new PIL image.
    """
    base = pil_to_np(image)
    sigma = random.uniform(std_l / 255., std_r / 255.)
    noise = np.random.normal(loc=0, scale=sigma, size=base.shape)
    speckled = base + noise * base
    return np_to_pil(np.clip(speckled, 0, 1).astype(np.float32))
def synthesize_low_resolution(img):
    """Simulate resolution loss by bicubic down-scaling, then up-scaling back.

    The intermediate width and height are drawn independently from
    ``[size // 2, size]`` (inclusive). The image is restored to its original
    size with nearest-neighbour or bilinear resampling, chosen 50/50.
    """
    orig_w, orig_h = img.size
    # Tuple elements evaluate left to right: width is drawn before height.
    reduced_size = (random.randint(orig_w // 2, orig_w),
                    random.randint(orig_h // 2, orig_h))
    reduced = img.resize(reduced_size, Image.BICUBIC)
    upsample = Image.NEAREST if random.uniform(0, 1) < 0.5 else Image.BILINEAR
    return reduced.resize((orig_w, orig_h), upsample)
def convertToJpeg(im, quality):
    """Return ``im`` after an in-memory JPEG encode/decode at ``quality``.

    Used to inject JPEG compression artifacts; the decoded image is RGB.
    """
    with BytesIO() as buffer:
        im.save(buffer, format='JPEG', quality=quality)
        buffer.seek(0)
        compressed = Image.open(buffer)
        # convert() loads the pixel data while the buffer is still open.
        return compressed.convert('RGB')
def blur_image_v2(img):
    """Gaussian-blur a PIL image with a random kernel size and sigma.

    The kernel is 3x3, 5x5 or 7x7 (uniformly chosen) and sigma is drawn from
    U(1, 5); the blurred result is returned as a uint8 PIL image.
    """
    kernel = random.choice([(3, 3), (5, 5), (7, 7)])
    sigma = random.uniform(1., 5.)
    blurred = cv2.GaussianBlur(np.array(img), kernel, sigma)
    return Image.fromarray(blurred.astype(np.uint8))
def online_add_degradation_v2(img):
    """Apply a random subset of four degradations, in random order.

    The candidates are blur, noise, down/up-sampling and JPEG
    re-compression. Noise is Gaussian, speckle or salt-and-pepper, picked
    uniformly. The candidates are visited in a random permutation, and each
    is applied independently with probability 0.7.
    """
    for step in np.random.permutation(4):
        # Exactly one gate draw per step, matching the per-branch checks.
        if random.uniform(0, 1) >= 0.7:
            continue
        if step == 0:
            img = blur_image_v2(img)
        elif step == 1:
            noise_kind = random.choice([1, 2, 3])
            if noise_kind == 1:
                img = synthesize_gaussian(img, 5, 50)
            elif noise_kind == 2:
                img = synthesize_speckle(img, 5, 50)
            else:
                img = synthesize_salt_pepper(
                    img, random.uniform(0, 0.01), random.uniform(0.3, 0.8))
        elif step == 2:
            img = synthesize_low_resolution(img)
        else:
            img = convertToJpeg(img, random.randint(40, 100))
    return img
def irregular_hole_synthesize(img, mask):
    """Paint the masked region of ``img`` white.

    Each pixel is blended toward 255 by ``mask / 255``: where the mask is 255
    the pixel becomes white, and where it is 0 the pixel is unchanged.

    Returns:
        (holed image converted to RGB, ``mask`` converted to mode 'L')
    """
    pixels = np.array(img).astype('uint8')
    alpha = np.array(mask).astype('uint8') / 255
    filled = pixels * (1 - alpha) + alpha * 255
    holed = Image.fromarray(filled.astype('uint8')).convert("RGB")
    return holed, mask.convert("L")
def zero_mask(size):
    """Return an all-black ``size`` x ``size`` RGB image, i.e. a mask with no hole."""
    blank = np.zeros((size, size, 3), dtype=np.uint8)
    return Image.fromarray(blank).convert("RGB")
class UnPairOldPhotos_SR(BaseDataset):  ## Synthetic + Real Old
    """Unpaired old-photo dataset used for VAE training.

    If ``opt.name`` contains ``'domainA'``, each sample is drawn with equal
    probability from one of two sources:

    * the real old photos, taken 50/50 from the Real_L_old or the
      Real_RGB_old bigfile;
    * the clean VOC images, with synthetic degradation applied on the fly.

    Otherwise every sample is an undegraded clean VOC image. Samples are
    drawn at random; the ``index`` passed to ``__getitem__`` is ignored.
    """

    def initialize(self, opt):
        """Load the image bigfiles found under ``opt.dataroot``.

        ``VOC_RGB_JPEGImages.bigfile`` is always loaded. In domain A,
        ``Real_L_old.bigfile`` and ``Real_RGB_old.bigfile`` are loaded too.
        Clean images with either side below 256 px are excluded from sampling.
        """
        self.opt = opt
        self.isImage = 'domainA' in opt.name
        self.task = 'old_photo_restoration_training_vae'
        self.dir_AB = opt.dataroot
        if self.isImage:
            self.load_img_dir_L_old = os.path.join(self.dir_AB, "Real_L_old.bigfile")
            self.load_img_dir_RGB_old = os.path.join(self.dir_AB, "Real_RGB_old.bigfile")
            self.load_img_dir_clean = os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_L_old = BigFileMemoryLoader(self.load_img_dir_L_old)
            self.loaded_imgs_RGB_old = BigFileMemoryLoader(self.load_img_dir_RGB_old)
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)
        else:
            self.load_img_dir_clean = os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)

        # Training crops are 256x256, so drop clean images too small to crop.
        print("-------------Filter the imgs whose size <256 in VOC-------------")
        self.filtered_imgs_clean = []
        # Indexed access: only len() and [] of the loader are relied upon.
        for i in range(len(self.loaded_imgs_clean)):
            img_name, img = self.loaded_imgs_clean[i]
            w, h = img.size  # PIL reports (width, height)
            if w < 256 or h < 256:
                continue
            self.filtered_imgs_clean.append((img_name, img))
        print("--------Origin image num is [%d], filtered result is [%d]--------" % (
            len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))

        self.pid = os.getpid()

    def __getitem__(self, index):
        """Return one randomly drawn sample; ``index`` is ignored.

        Returns a dict with:

        * 'label' and 'image': both the transformed image tensor;
        * 'inst': 1 for a real old photo, else 0;
        * 'feat': 0;
        * 'path': the image name joined to its bigfile path.
        """
        is_real_old = 0
        degrade = False
        if self.isImage:  # domain A: synthetic + real old photos
            P = random.uniform(0, 2)
            if P < 1:
                if random.uniform(0, 1) < 0.5:
                    sampled_dataset = self.loaded_imgs_L_old
                    load_img_dir = self.load_img_dir_L_old
                else:
                    sampled_dataset = self.loaded_imgs_RGB_old
                    load_img_dir = self.load_img_dir_RGB_old
                is_real_old = 1
            else:
                sampled_dataset = self.filtered_imgs_clean
                load_img_dir = self.load_img_dir_clean
                degrade = True
        else:
            sampled_dataset = self.filtered_imgs_clean
            load_img_dir = self.load_img_dir_clean

        index = random.randint(0, len(sampled_dataset) - 1)
        img_name, img = sampled_dataset[index]
        if degrade:
            img = online_add_degradation_v2(img)
        path = os.path.join(load_img_dir, img_name)

        # With probability 0.1, strip colour (to L and back to RGB).
        if random.uniform(0, 1) < 0.1:
            img = img.convert("L")
            img = img.convert("RGB")

        A = img
        w, h = A.size
        if w < 256 or h < 256:
            # Only 256x256 crops are taken later, so upscale small photos so
            # their shorter side is 256. (transforms.Scale was removed from
            # torchvision; Resize is its replacement.)
            A = transforms.Resize(256, Image.BICUBIC)(A)

        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)

        feat_tensor = 0
        A_tensor = A_transform(A)

        input_dict = {'label': A_tensor, 'inst': is_real_old, 'image': A_tensor,
                      'feat': feat_tensor, 'path': path}
        return input_dict

    def __len__(self):
        return len(self.loaded_imgs_clean)  ## actually, this is useless, since the selected index is just a random number

    def name(self):
        return 'UnPairOldPhotos_SR'
class PairOldPhotos(BaseDataset):
    """Paired (degraded input A, ground truth B) dataset for the mapping stage.

    Training: B is a clean VOC image (either side >= 256 px) and A is B
    after ``online_add_degradation_v2``.

    Test: images come from the ``opt.test_dataset`` bigfile. With
    ``opt.test_on_synthetic`` A is a degraded copy of B; otherwise the same
    image is used as both A and B.
    """

    def initialize(self, opt):
        """Load the clean training bigfile, or the test bigfile when not training."""
        self.opt = opt
        self.isImage = 'imagegan' in opt.name
        self.task = 'old_photo_restoration_training_mapping'
        self.dir_AB = opt.dataroot
        if opt.isTrain:
            self.load_img_dir_clean = os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)

            print("-------------Filter the imgs whose size <256 in VOC-------------")
            self.filtered_imgs_clean = []
            # Indexed access: only len() and [] of the loader are relied upon.
            for i in range(len(self.loaded_imgs_clean)):
                img_name, img = self.loaded_imgs_clean[i]
                w, h = img.size  # PIL reports (width, height)
                if w < 256 or h < 256:
                    continue
                self.filtered_imgs_clean.append((img_name, img))
            print("--------Origin image num is [%d], filtered result is [%d]--------" % (
                len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
        else:
            self.load_img_dir = os.path.join(self.dir_AB, opt.test_dataset)
            self.loaded_imgs = BigFileMemoryLoader(self.load_img_dir)

        self.pid = os.getpid()

    def __getitem__(self, index):
        """Return the (input A, ground truth B) pair at ``index``.

        Returns a dict with:

        * 'label': tensor of A (the input);
        * 'image': tensor of B (the ground truth);
        * 'inst' and 'feat': 0;
        * 'path': the source image name joined to its bigfile path.

        Raises:
            ValueError: in training when ``opt.use_v2_degradation`` is false.
                No other degradation exists here, so A would otherwise be
                unbound and fail later with an UnboundLocalError.
        """
        if self.opt.isTrain:
            img_name_clean, B = self.filtered_imgs_clean[index]
            path = os.path.join(self.load_img_dir_clean, img_name_clean)
            if not self.opt.use_v2_degradation:
                raise ValueError(
                    "PairOldPhotos training requires opt.use_v2_degradation; "
                    "no other degradation is implemented")
            ### Remind: A is the input and B is corresponding GT
            A = online_add_degradation_v2(B)
        else:
            img_name, B = self.loaded_imgs[index]
            path = os.path.join(self.load_img_dir, img_name)
            if self.opt.test_on_synthetic:
                A = online_add_degradation_v2(B)
            else:
                # No separate target is loaded: the same image is input and target.
                A = B

        # Training only: with probability 0.1 drop colour from both images.
        # The random draw happens first, so it is consumed in test mode too.
        if random.uniform(0, 1) < 0.1 and self.opt.isTrain:
            A = A.convert("L")
            B = B.convert("L")
            A = A.convert("RGB")
            B = B.convert("RGB")

        w, h = A.size
        if w < 256 or h < 256:
            # Upscale so the shorter side is 256 (transforms.Scale was removed
            # from torchvision; Resize is its replacement).
            A = transforms.Resize(256, Image.BICUBIC)(A)
            B = transforms.Resize(256, Image.BICUBIC)(B)

        # apply the same transform to both A and B
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)
        B_transform = get_transform(self.opt, transform_params)

        inst_tensor = feat_tensor = 0
        A_tensor = A_transform(A)
        B_tensor = B_transform(B)

        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
                      'feat': feat_tensor, 'path': path}
        return input_dict

    def __len__(self):
        if self.opt.isTrain:
            return len(self.filtered_imgs_clean)
        else:
            return len(self.loaded_imgs)

    def name(self):
        return 'PairOldPhotos'
class PairOldPhotos_with_hole(BaseDataset):
    """Paired dataset where the input also has irregular white holes.

    Training: B is a random 256x256 crop of a clean VOC image, and A is B
    after ``online_add_degradation_v2`` with a random mask from
    ``opt.irregular_mask`` painted on as white holes.

    Test: a 256x256 center crop of each test image is used as B, and as A
    after the hole is painted on.

    The mask (first channel) is returned as 'inst'.
    """

    def initialize(self, opt):
        """Load the training or test image bigfile and the irregular-mask bigfile."""
        self.opt = opt
        self.isImage = 'imagegan' in opt.name
        self.task = 'old_photo_restoration_training_mapping'
        self.dir_AB = opt.dataroot
        if opt.isTrain:
            self.load_img_dir_clean = os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)

            print("-------------Filter the imgs whose size <256 in VOC-------------")
            self.filtered_imgs_clean = []
            # Indexed access: only len() and [] of the loader are relied upon.
            for i in range(len(self.loaded_imgs_clean)):
                img_name, img = self.loaded_imgs_clean[i]
                w, h = img.size  # PIL reports (width, height)
                if w < 256 or h < 256:
                    continue
                self.filtered_imgs_clean.append((img_name, img))
            print("--------Origin image num is [%d], filtered result is [%d]--------" % (
                len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
        else:
            self.load_img_dir = os.path.join(self.dir_AB, opt.test_dataset)
            self.loaded_imgs = BigFileMemoryLoader(self.load_img_dir)

        self.loaded_masks = BigFileMemoryLoader(opt.irregular_mask)

        self.pid = os.getpid()

    def __getitem__(self, index):
        """Return the (holed input A, ground truth B, mask) triple at ``index``.

        Returns a dict with:

        * 'label': tensor of A;
        * 'image': tensor of B;
        * 'inst': first channel of the mask as a tensor;
        * 'feat': 0;
        * 'path': the source image name joined to its bigfile path.
        """
        if self.opt.isTrain:
            img_name_clean, B = self.filtered_imgs_clean[index]
            path = os.path.join(self.load_img_dir_clean, img_name_clean)
            B = transforms.RandomCrop(256)(B)
            ### Remind: A is the input and B is corresponding GT
            A = online_add_degradation_v2(B)
        else:
            img_name_A, A = self.loaded_imgs[index]
            path = os.path.join(self.load_img_dir, img_name_A)
            A = transforms.CenterCrop(256)(A)
            B = A

        # Training only: with probability 0.1 drop colour from both images.
        # The random draw happens first, so it is consumed in test mode too.
        if random.uniform(0, 1) < 0.1 and self.opt.isTrain:
            A = A.convert("L")
            B = B.convert("L")
            A = A.convert("RGB")
            B = B.convert("RGB")

        if self.opt.isTrain:
            _, mask = self.loaded_masks[random.randint(0, len(self.loaded_masks) - 1)]
        else:
            # Deterministic mask per test image, cycling over the first 100
            # masks, or over all of them when fewer are loaded. A plain
            # index % 100 raised IndexError in that case.
            _, mask = self.loaded_masks[index % min(100, len(self.loaded_masks))]

        mask = mask.resize((self.opt.loadSize, self.opt.loadSize), Image.NEAREST)

        if self.opt.random_hole and random.uniform(0, 1) > 0.5 and self.opt.isTrain:
            mask = zero_mask(256)

        if self.opt.no_hole:
            mask = zero_mask(256)

        A, _ = irregular_hole_synthesize(A, mask)

        # Optionally hide the mask from the model at test time (A keeps its holes).
        if not self.opt.isTrain and self.opt.hole_image_no_mask:
            mask = zero_mask(256)

        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)
        B_transform = get_transform(self.opt, transform_params)

        if transform_params['flip'] and self.opt.isTrain:
            # Presumably get_transform flips A/B when params['flip'] is set;
            # mirror the mask to stay aligned (confirm in data.base_dataset).
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        mask_tensor = transforms.ToTensor()(mask)

        feat_tensor = 0
        A_tensor = A_transform(A)
        B_tensor = B_transform(B)

        input_dict = {'label': A_tensor, 'inst': mask_tensor[:1], 'image': B_tensor,
                      'feat': feat_tensor, 'path': path}
        return input_dict

    def __len__(self):
        if self.opt.isTrain:
            return len(self.filtered_imgs_clean)
        else:
            return len(self.loaded_imgs)

    def name(self):
        return 'PairOldPhotos_with_hole'