File size: 3,673 Bytes
57276d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import cv2
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url

from realesrgan import RealESRGANer


def build_sr_model(scale=2, model_name=None, tile=0, tile_pad=10, pre_pad=0, fp32=False, gpu_id=None):
    """Create a RealESRGANer upsampler for one of the supported RRDBNet checkpoints.

    Args:
        scale: Only used to choose a default checkpoint when ``model_name`` is None:
            2 selects 'RealESRGAN_x2plus', anything else selects 'RealESRGAN_x4plus'.
            The upsampler itself is built with the checkpoint's native scale; the final
            magnification is chosen later through ``enhance(..., outscale=...)``
            (sr_inference passes its own ``scale`` there).
        model_name: 'RealESRGAN_x2plus' or 'RealESRGAN_x4plus', or None for the default above.
        tile, tile_pad, pre_pad: Forwarded unchanged to RealESRGANer.
        fp32: When False (the default) the upsampler is created with ``half=True``.
        gpu_id: Forwarded unchanged to RealESRGANer.

    Returns:
        The configured RealESRGANer instance.

    Raises:
        ValueError: If ``model_name`` is not one of the supported checkpoints.

    Weights are looked up as ``weights/<model_name>.pth`` next to this file and are
    downloaded into that directory with load_file_from_url when missing.
    """
    if model_name is None:
        model_name = 'RealESRGAN_x2plus' if scale == 2 else 'RealESRGAN_x4plus'

    def _rrdb_factory(net_scale):
        # Both checkpoints share the same RRDBNet body; only the upscale factor differs.
        # Return a zero-argument constructor so the network is only built when selected.
        return lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=net_scale)

    registry = {
        'RealESRGAN_x2plus': {
            'internal_scale': 2,
            'model': _rrdb_factory(2),
            'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        },
        'RealESRGAN_x4plus': {
            'internal_scale': 4,
            'model': _rrdb_factory(4),
            'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        },
    }

    spec = registry.get(model_name)
    if spec is None:
        raise ValueError(
            f'Unknown model name: {model_name}. Available models: {list(registry)}')

    network = spec['model']()

    weights_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights')
    model_path = os.path.join(weights_dir, model_name + '.pth')
    if not os.path.isfile(model_path):
        # Not cached locally: fetch into the weights directory and use the path the
        # downloader reports back.
        model_path = load_file_from_url(
            url=spec['url'], model_dir=weights_dir, progress=True, file_name=None)

    return RealESRGANer(
        scale=spec['internal_scale'],  # the network's native scale, not the requested output scale
        model_path=model_path,
        dni_weight=None,
        model=network,
        tile=tile,
        tile_pad=tile_pad,
        pre_pad=pre_pad,
        half=not fp32,
        gpu_id=gpu_id)


def sr_inference(input, output_path, upsampler, scale=2, ext='auto', suffix='sr'):
    """Super-resolve one image file and write the result into ``output_path``.

    Before upsampling, the image is padded on the left and right with wrapped-around
    content (a quarter of its width per side, via ``cv2.BORDER_WRAP``). The padding
    is cropped off afterwards, so each side border is processed with context from
    the opposite edge. This presumably targets images whose left and right edges are
    continuous (e.g. 360-degree panoramas); confirm with callers.

    Args:
        input: Path of the source image. It is read with ``cv2.IMREAD_UNCHANGED``,
            so alpha and bit depth are kept.
        output_path: Directory for the result. It is created if missing.
        upsampler: Object with ``enhance(img, outscale=...)`` returning
            ``(output, _)``, e.g. the RealESRGANer built by build_sr_model.
        scale: Final magnification, passed to ``enhance`` as ``outscale``.
        ext: Output extension without the dot, or 'auto' to reuse the input's
            extension. 'auto' falls back to 'png' when the input name has none.
            4-channel images are always saved as png.
        suffix: Appended to the file stem as ``<stem>_<suffix>``. '' keeps the stem.

    Returns:
        The path written, or None when enhancement raised RuntimeError or the
        image could not be written. In both cases a message is printed.

    Raises:
        ValueError: If the input image cannot be read.
    """
    os.makedirs(output_path, exist_ok=True)

    path = input
    imgname, extension = os.path.splitext(os.path.basename(path))

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports missing/unreadable/unsupported files by returning None.
        raise ValueError(f'Failed to read image: {path}')
    width = img.shape[1]

    # Wrap-pad horizontally to eliminate border artifacts at the left/right seams.
    pad_len = width // 4
    img = cv2.copyMakeBorder(img, 0, 0, pad_len, pad_len, cv2.BORDER_WRAP)
    img_mode = 'RGBA' if len(img.shape) == 3 and img.shape[2] == 4 else None

    try:
        # outscale is the final magnification applied to the (padded) input.
        output, _ = upsampler.enhance(img, outscale=scale)
    except RuntimeError as error:
        # Best-effort: report and skip this image rather than abort the caller.
        print('Error', error)
        print(
            'If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        return None

    # Remove the padding. Only the first two axes are sliced so that 2-D (grayscale)
    # outputs work as well as 3-D (colour / alpha) ones.
    output = output[:, int(pad_len * scale):int((width + pad_len) * scale)]

    if ext == 'auto':
        # Without the fallback, an input name with no extension would give a path
        # ending in a bare '.', for which cv2.imwrite has no encoder.
        extension = extension[1:] or 'png'
    else:
        extension = ext
    if img_mode == 'RGBA':  # RGBA images should be saved in png format
        extension = 'png'

    stem = imgname if suffix == '' else f'{imgname}_{suffix}'
    save_path = os.path.join(output_path, f'{stem}.{extension}')
    if not cv2.imwrite(save_path, output):
        # cv2.imwrite signals failure through its return value, not an exception.
        print('Error', f'failed to write {save_path}')
        return None
    return save_path