|
|
|
from torch.utils.data import Dataset |
|
from PIL import Image |
|
import torchvision.transforms as transforms |
|
import os |
|
|
|
class DIV2KDataset(Dataset): |
|
def __init__(self, hr_dir, lr_dir, patch_size=96, upscale_factor=4): |
|
self.hr_files = sorted(os.listdir(hr_dir)) |
|
self.lr_files = sorted(os.listdir(lr_dir)) |
|
self.hr_dir = hr_dir |
|
self.lr_dir = lr_dir |
|
self.patch_size = patch_size |
|
self.upscale_factor = upscale_factor |
|
|
|
|
|
self.lr_transform = transforms.Compose([ |
|
transforms.Resize((patch_size//upscale_factor, patch_size//upscale_factor), |
|
interpolation=transforms.InterpolationMode.BICUBIC), |
|
transforms.ToTensor() |
|
]) |
|
|
|
|
|
self.hr_transform = transforms.Compose([ |
|
transforms.Resize((patch_size, patch_size), |
|
interpolation=transforms.InterpolationMode.BICUBIC), |
|
transforms.ToTensor() |
|
]) |
|
|
|
|
|
self.lr_upscale = transforms.Compose([ |
|
transforms.Resize((patch_size, patch_size), |
|
interpolation=transforms.InterpolationMode.BICUBIC), |
|
transforms.ToTensor() |
|
]) |
|
|
|
def __getitem__(self, idx): |
|
|
|
hr_img = Image.open(os.path.join(self.hr_dir, self.hr_files[idx])).convert('YCbCr') |
|
lr_img = Image.open(os.path.join(self.lr_dir, self.lr_files[idx])).convert('YCbCr') |
|
|
|
|
|
hr_y, _, _ = hr_img.split() |
|
lr_y, _, _ = lr_img.split() |
|
|
|
|
|
lr_y_upscaled = self.lr_upscale(lr_y) |
|
hr_y_tensor = self.hr_transform(hr_y) |
|
|
|
return lr_y_upscaled, hr_y_tensor |
|
|
|
def __len__(self): |
|
return len(self.hr_files) |