Spaces:
Running
Running
Added Multiple things
Browse files- app.py +12 -4
- models/rcan.py +67 -0
- weights/rcan_epoch_20.pth +3 -0
app.py
CHANGED
@@ -20,10 +20,11 @@ import numpy as np
|
|
20 |
|
21 |
from PIL import Image
|
22 |
import numpy as np
|
|
|
23 |
|
24 |
def edge_directed_interpolation(lr_img_pil, scale=2):
|
25 |
# Ensure input is a PIL Image
|
26 |
-
|
27 |
if isinstance(lr_img_pil, np.ndarray):
|
28 |
lr_img_pil = Image.fromarray(lr_img_pil)
|
29 |
|
@@ -128,11 +129,18 @@ edi_page = gr.Interface(
|
|
128 |
title="Edge Directed Interpolation"
|
129 |
)
|
130 |
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
demo = gr.TabbedInterface(
|
133 |
-
[lancros_page, fourier_page, autoencoder_page, srgan_page, espcn_page, random_forest_page, edi_page],
|
134 |
["Lancros Interpolation", "Fourier Interpolation", "Autoencoder based Super Resolution", "GAN based Super Resolution",
|
135 |
-
"EspCN Super Resolution", "Random Forest based Super Resolution", "Edge Directed Interpolation"],
|
136 |
title="Image Super Resolution"
|
137 |
)
|
138 |
|
|
|
20 |
|
21 |
from PIL import Image
|
22 |
import numpy as np
|
23 |
+
from models.rcan import rcan_upscale
|
24 |
|
25 |
def edge_directed_interpolation(lr_img_pil, scale=2):
|
26 |
# Ensure input is a PIL Image
|
27 |
+
|
28 |
if isinstance(lr_img_pil, np.ndarray):
|
29 |
lr_img_pil = Image.fromarray(lr_img_pil)
|
30 |
|
|
|
129 |
title="Edge Directed Interpolation"
|
130 |
)
|
131 |
|
132 |
+
rcan_page = gr.Interface(
|
133 |
+
fn=rcan_upscale,
|
134 |
+
inputs=[gr.Image(label="Low Resolution Image")],
|
135 |
+
outputs=gr.Image(type="pil", label="High Resolution Image"),
|
136 |
+
title="RCAN based Super Resolution"
|
137 |
+
)
|
138 |
+
|
139 |
+
# Tabs setup
|
140 |
demo = gr.TabbedInterface(
|
141 |
+
[lancros_page, fourier_page, autoencoder_page, srgan_page, espcn_page, random_forest_page, edi_page, rcan_page],
|
142 |
["Lancros Interpolation", "Fourier Interpolation", "Autoencoder based Super Resolution", "GAN based Super Resolution",
|
143 |
+
"EspCN Super Resolution", "Random Forest based Super Resolution", "Edge Directed Interpolation", "RCAN Super Resolution"],
|
144 |
title="Image Super Resolution"
|
145 |
)
|
146 |
|
models/rcan.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
import io
|
4 |
+
from torchvision import transforms
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
# Define your RCAN model architecture here (or import if defined elsewhere)
|
8 |
+
# Example stub (you should replace this with your actual model definition)
|
9 |
+
class RCAN(torch.nn.Module):
|
10 |
+
def __init__(self):
|
11 |
+
super(RCAN, self).__init__()
|
12 |
+
# Dummy model: replace with your actual RCAN layers
|
13 |
+
self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
return self.conv(x)
|
17 |
+
|
18 |
+
# Load model once and reuse (optional: use a global cache)
|
19 |
+
_model = None
|
20 |
+
def load_model(model_path, device='cuda'):
|
21 |
+
global _model
|
22 |
+
if _model is None:
|
23 |
+
_model = torch.load(model_path, map_location=device)
|
24 |
+
_model.eval().to(device)
|
25 |
+
return _model
|
26 |
+
|
27 |
+
def super_resolve(model_path, img_bytes, device='cuda'):
|
28 |
+
"""
|
29 |
+
Perform super-resolution using RCAN on input image bytes and return upscaled image bytes.
|
30 |
+
"""
|
31 |
+
# Load model
|
32 |
+
model = load_model(model_path, device)
|
33 |
+
|
34 |
+
# Decode image bytes to PIL
|
35 |
+
lr_img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
|
36 |
+
|
37 |
+
# Preprocess: PIL -> Tensor
|
38 |
+
transform = transforms.ToTensor()
|
39 |
+
input_tensor = transform(lr_img).unsqueeze(0).to(device) # shape: (1, 3, H, W)
|
40 |
+
|
41 |
+
# Model inference
|
42 |
+
with torch.no_grad():
|
43 |
+
output_tensor = model(input_tensor)
|
44 |
+
|
45 |
+
# Postprocess: Tensor -> PIL
|
46 |
+
output_tensor = output_tensor.squeeze(0).clamp(0, 1).cpu() # shape: (3, H, W)
|
47 |
+
upscaled_img = transforms.ToPILImage()(output_tensor)
|
48 |
+
|
49 |
+
# Encode to bytes
|
50 |
+
byte_io = io.BytesIO()
|
51 |
+
upscaled_img.save(byte_io, format='PNG')
|
52 |
+
return byte_io.getvalue()
|
53 |
+
|
54 |
+
def rcan_upscale(lr_img):
|
55 |
+
"""
|
56 |
+
High-level wrapper to be called from Gradio or anywhere else.
|
57 |
+
Args:
|
58 |
+
lr_img (PIL.Image): Input low-resolution image
|
59 |
+
Returns:
|
60 |
+
PIL.Image: Output high-resolution image
|
61 |
+
"""
|
62 |
+
img_byte_arr = io.BytesIO()
|
63 |
+
lr_img.save(img_byte_arr, format='PNG')
|
64 |
+
img_bytes = img_byte_arr.getvalue()
|
65 |
+
|
66 |
+
upscaled_bytes = super_resolve(model_path="weights/rcan_epoch_20.pth", img_bytes=img_bytes, device='cuda')
|
67 |
+
return Image.open(io.BytesIO(upscaled_bytes))
|
weights/rcan_epoch_20.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1efb22fe861cd418170c0b9527d968dc0f11aee46fa20fd3a58e57c8016418dc
|
3 |
+
size 63015800
|