| | import os |
| | import glob |
| | import random |
| | import pickle |
| | from data import common |
| | import imageio |
| | import torch.utils.data as data |
| |
|
class SRData(data.Dataset):
    """Map-style dataset yielding ``(lr, hr, filename)`` triples read from disk.

    Files are looked up under ``<args.dir_data>/<name>/HR`` and
    ``<args.dir_data>/<name>/LR_bicubic/X<scale>/``; see ``_set_filesystem``
    and ``_scan`` for the exact naming scheme.
    """

    def __init__(self, args, name='', benchmark=False):
        """Store the configuration and index the image files.

        Args:
            args: Object exposing ``scale`` (iterable of scale factors that
                is also indexed by ``idx_scale``) and ``dir_data`` (root
                directory passed to ``_set_filesystem``).
            name: Sub-directory of ``args.dir_data`` holding this dataset.
            benchmark: Stored as ``self.benchmark``; not read in this class.
        """
        self.args = args
        self.name = name
        self.benchmark = benchmark
        self.scale = args.scale

        # Fixed settings: this class always reports the 'test' split.
        self.split = 'test'
        self.do_eval = True
        self.input_large = False
        self.idx_scale = 0

        # Directories must be resolved before the file lists can be built.
        self._set_filesystem(args.dir_data)
        self.images_hr, self.images_lr = self._scan()

    def _scan(self):
        """Build the sorted HR path list and one LR path list per scale.

        Returns:
            A pair ``(names_hr, names_lr)`` where ``names_hr`` is the sorted
            list of files in ``self.dir_hr`` ending in ``self.ext[0]`` and
            ``names_lr[i][j]`` is the path
            ``<dir_lr>/X<s>/<stem>x<s><ext[1]>`` for scale ``s = scale[i]``
            and the stem of ``names_hr[j]``. LR paths are derived from the
            HR names only; their existence is not checked here.
        """
        hr_pattern = os.path.join(self.dir_hr, '*' + self.ext[0])
        hr_paths = sorted(glob.glob(hr_pattern))
        stems = [
            os.path.splitext(os.path.basename(path))[0] for path in hr_paths
        ]

        lr_paths = [
            [
                os.path.join(
                    self.dir_lr,
                    'X{}/{}x{}{}'.format(factor, stem, factor, self.ext[1]),
                )
                for stem in stems
            ]
            for factor in self.scale
        ]
        return hr_paths, lr_paths

    def _set_filesystem(self, dir_data):
        """Derive the dataset directories and file extensions from *dir_data*.

        Sets ``apath`` (``dir_data/name``), ``dir_hr`` (``apath/HR``),
        ``dir_lr`` (``apath/LR_bicubic``) and ``ext`` (HR and LR extensions,
        both ``'.png'``).
        """
        root = os.path.join(dir_data, self.name)
        self.apath = root
        self.dir_hr = os.path.join(root, 'HR')
        self.dir_lr = os.path.join(root, 'LR_bicubic')
        self.ext = ('.png', '.png')

    def __getitem__(self, idx):
        """Return ``(lr, hr, filename)`` for sample *idx*.

        The loaded images are cropped by ``get_patch`` and then passed
        through ``common.set_channel`` (with ``n_channels=3``) and
        ``common.np2Tensor`` (with ``rgb_range=255``); the first two results
        of the latter are returned together with the file stem.
        """
        lr_image, hr_image, stem = self._load_file(idx)
        patches = self.get_patch(lr_image, hr_image)
        channels = common.set_channel(*patches, n_channels=3)
        tensors = common.np2Tensor(*channels, rgb_range=255)

        return tensors[0], tensors[1], stem

    def __len__(self):
        """Return the number of HR images found by ``_scan``."""
        return len(self.images_hr)

    def _get_index(self, idx):
        """Map a requested index to a storage index (identity here)."""
        return idx

    def _load_file(self, idx):
        """Read the HR image and the LR image for the current scale.

        Returns:
            ``(lr, hr, filename)`` where ``filename`` is the HR file's base
            name without its extension.
        """
        index = self._get_index(idx)
        hr_path = self.images_hr[index]
        lr_path = self.images_lr[self.idx_scale][index]

        stem = os.path.splitext(os.path.basename(hr_path))[0]
        hr_image = imageio.imread(hr_path)
        lr_image = imageio.imread(lr_path)
        return lr_image, hr_image, stem

    def get_patch(self, lr, hr):
        """Crop *hr* so its first two dimensions are at most ``scale`` x *lr*.

        *lr* is returned untouched; *hr* is sliced from the origin to
        ``lr.shape[:2]`` multiplied by the currently selected scale factor.
        """
        factor = self.scale[self.idx_scale]
        lr_rows, lr_cols = lr.shape[:2]
        hr_cropped = hr[0:lr_rows * factor, 0:lr_cols * factor]
        return lr, hr_cropped

    def set_scale(self, idx_scale):
        """Select which entry of ``self.scale`` subsequent loads use.

        When ``self.input_large`` is set, *idx_scale* is ignored and a
        random valid index is chosen instead.
        """
        if self.input_large:
            self.idx_scale = random.randint(0, len(self.scale) - 1)
        else:
            self.idx_scale = idx_scale
| |
|
| |
|