Fabrice-TIERCELIN commited on
Commit
d9ebca7
·
verified ·
1 Parent(s): 2856d54
Files changed (1) hide show
  1. RealESRGAN/model.py +92 -92
RealESRGAN/model.py CHANGED
@@ -1,93 +1,93 @@
1
- import os
2
- import torch
3
- from torch.nn import functional as F
4
- from PIL import Image
5
- import numpy as np
6
- import cv2
7
- from huggingface_hub import hf_hub_url, hf_hub_download, cached_download
8
-
9
- from .rrdbnet_arch import RRDBNet
10
- from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, \
11
- unpad_image
12
-
13
- HF_MODELS = {
14
- 2: dict(
15
- repo_id='sberbank-ai/Real-ESRGAN',
16
- filename='RealESRGAN_x2.pth',
17
- ),
18
- 4: dict(
19
- repo_id='sberbank-ai/Real-ESRGAN',
20
- filename='RealESRGAN_x4.pth',
21
- ),
22
- 8: dict(
23
- repo_id='sberbank-ai/Real-ESRGAN',
24
- filename='RealESRGAN_x8.pth',
25
- ),
26
- }
27
-
28
-
29
- class RealESRGAN:
30
- def __init__(self, device, scale=4):
31
- self.device = device
32
- self.scale = scale
33
- self.model = RRDBNet(
34
- num_in_ch=3, num_out_ch=3, num_feat=64,
35
- num_block=23, num_grow_ch=32, scale=scale
36
- )
37
-
38
- def load_weights(self, model_path, download=True):
39
- if not os.path.exists(model_path) and download:
40
- assert self.scale in [2, 4, 8], 'You can download models only with scales: 2, 4, 8'
41
- config = HF_MODELS[self.scale]
42
- cache_dir = os.path.dirname(model_path)
43
- local_filename = os.path.basename(model_path)
44
- config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
45
- htr = hf_hub_download(repo_id=config['repo_id'], cache_dir=cache_dir, local_dir=cache_dir,
46
- filename=config['filename'])
47
- print(htr)
48
- # cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
49
- print('Weights downloaded to:', os.path.join(cache_dir, local_filename))
50
-
51
- loadnet = torch.load(model_path)
52
- if 'params' in loadnet:
53
- self.model.load_state_dict(loadnet['params'], strict=True)
54
- elif 'params_ema' in loadnet:
55
- self.model.load_state_dict(loadnet['params_ema'], strict=True)
56
- else:
57
- self.model.load_state_dict(loadnet, strict=True)
58
- self.model.eval()
59
- self.model.to(self.device)
60
-
61
- # @torch.cuda.amp.autocast()
62
- def predict(self, lr_image, batch_size=4, patches_size=192,
63
- padding=24, pad_size=15):
64
- torch.autocast(device_type=self.device.type)
65
- scale = self.scale
66
- device = self.device
67
- lr_image = np.array(lr_image)
68
- lr_image = pad_reflect(lr_image, pad_size)
69
-
70
- patches, p_shape = split_image_into_overlapping_patches(
71
- lr_image, patch_size=patches_size, padding_size=padding
72
- )
73
- img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach()
74
-
75
- with torch.no_grad():
76
- res = self.model(img[0:batch_size])
77
- for i in range(batch_size, img.shape[0], batch_size):
78
- res = torch.cat((res, self.model(img[i:i + batch_size])), 0)
79
-
80
- sr_image = res.permute((0, 2, 3, 1)).cpu().clamp_(0, 1)
81
- np_sr_image = sr_image.numpy()
82
-
83
- padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
84
- scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
85
- np_sr_image = stich_together(
86
- np_sr_image, padded_image_shape=padded_size_scaled,
87
- target_shape=scaled_image_shape, padding_size=padding * scale
88
- )
89
- sr_img = (np_sr_image * 255).astype(np.uint8)
90
- sr_img = unpad_image(sr_img, pad_size * scale)
91
- sr_img = Image.fromarray(sr_img)
92
-
93
  return sr_img
 
1
+ import os
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from PIL import Image
5
+ import numpy as np
6
+ import cv2
7
+ from huggingface_hub import hf_hub_url, hf_hub_download
8
+
9
+ from .rrdbnet_arch import RRDBNet
10
+ from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, \
11
+ unpad_image
12
+
13
+ HF_MODELS = {
14
+ 2: dict(
15
+ repo_id='sberbank-ai/Real-ESRGAN',
16
+ filename='RealESRGAN_x2.pth',
17
+ ),
18
+ 4: dict(
19
+ repo_id='sberbank-ai/Real-ESRGAN',
20
+ filename='RealESRGAN_x4.pth',
21
+ ),
22
+ 8: dict(
23
+ repo_id='sberbank-ai/Real-ESRGAN',
24
+ filename='RealESRGAN_x8.pth',
25
+ ),
26
+ }
27
+
28
+
29
+ class RealESRGAN:
30
+ def __init__(self, device, scale=4):
31
+ self.device = device
32
+ self.scale = scale
33
+ self.model = RRDBNet(
34
+ num_in_ch=3, num_out_ch=3, num_feat=64,
35
+ num_block=23, num_grow_ch=32, scale=scale
36
+ )
37
+
38
+ def load_weights(self, model_path, download=True):
39
+ if not os.path.exists(model_path) and download:
40
+ assert self.scale in [2, 4, 8], 'You can download models only with scales: 2, 4, 8'
41
+ config = HF_MODELS[self.scale]
42
+ cache_dir = os.path.dirname(model_path)
43
+ local_filename = os.path.basename(model_path)
44
+ config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
45
+ htr = hf_hub_download(repo_id=config['repo_id'], cache_dir=cache_dir, local_dir=cache_dir,
46
+ filename=config['filename'])
47
+ print(htr)
48
+ # cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
49
+ print('Weights downloaded to:', os.path.join(cache_dir, local_filename))
50
+
51
+ loadnet = torch.load(model_path)
52
+ if 'params' in loadnet:
53
+ self.model.load_state_dict(loadnet['params'], strict=True)
54
+ elif 'params_ema' in loadnet:
55
+ self.model.load_state_dict(loadnet['params_ema'], strict=True)
56
+ else:
57
+ self.model.load_state_dict(loadnet, strict=True)
58
+ self.model.eval()
59
+ self.model.to(self.device)
60
+
61
+ # @torch.cuda.amp.autocast()
62
+ def predict(self, lr_image, batch_size=4, patches_size=192,
63
+ padding=24, pad_size=15):
64
+ torch.autocast(device_type=self.device.type)
65
+ scale = self.scale
66
+ device = self.device
67
+ lr_image = np.array(lr_image)
68
+ lr_image = pad_reflect(lr_image, pad_size)
69
+
70
+ patches, p_shape = split_image_into_overlapping_patches(
71
+ lr_image, patch_size=patches_size, padding_size=padding
72
+ )
73
+ img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach()
74
+
75
+ with torch.no_grad():
76
+ res = self.model(img[0:batch_size])
77
+ for i in range(batch_size, img.shape[0], batch_size):
78
+ res = torch.cat((res, self.model(img[i:i + batch_size])), 0)
79
+
80
+ sr_image = res.permute((0, 2, 3, 1)).cpu().clamp_(0, 1)
81
+ np_sr_image = sr_image.numpy()
82
+
83
+ padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
84
+ scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
85
+ np_sr_image = stich_together(
86
+ np_sr_image, padded_image_shape=padded_size_scaled,
87
+ target_shape=scaled_image_shape, padding_size=padding * scale
88
+ )
89
+ sr_img = (np_sr_image * 255).astype(np.uint8)
90
+ sr_img = unpad_image(sr_img, pad_size * scale)
91
+ sr_img = Image.fromarray(sr_img)
92
+
93
  return sr_img