R2454 commited on
Commit
f8a03f8
·
1 Parent(s): 73408a5

Added Multiple things

Browse files
Files changed (3) hide show
  1. app.py +12 -4
  2. models/rcan.py +67 -0
  3. weights/rcan_epoch_20.pth +3 -0
app.py CHANGED
@@ -20,10 +20,11 @@ import numpy as np
20
 
21
  from PIL import Image
22
  import numpy as np
 
23
 
24
  def edge_directed_interpolation(lr_img_pil, scale=2):
25
  # Ensure input is a PIL Image
26
-
27
  if isinstance(lr_img_pil, np.ndarray):
28
  lr_img_pil = Image.fromarray(lr_img_pil)
29
 
@@ -128,11 +129,18 @@ edi_page = gr.Interface(
128
  title="Edge Directed Interpolation"
129
  )
130
 
131
- # === Tabs === #
 
 
 
 
 
 
 
132
  demo = gr.TabbedInterface(
133
- [lancros_page, fourier_page, autoencoder_page, srgan_page, espcn_page, random_forest_page, edi_page],
134
  ["Lancros Interpolation", "Fourier Interpolation", "Autoencoder based Super Resolution", "GAN based Super Resolution",
135
- "EspCN Super Resolution", "Random Forest based Super Resolution", "Edge Directed Interpolation"],
136
  title="Image Super Resolution"
137
  )
138
 
 
20
 
21
  from PIL import Image
22
  import numpy as np
23
+ from models.rcan import rcan_upscale
24
 
25
  def edge_directed_interpolation(lr_img_pil, scale=2):
26
  # Ensure input is a PIL Image
27
+
28
  if isinstance(lr_img_pil, np.ndarray):
29
  lr_img_pil = Image.fromarray(lr_img_pil)
30
 
 
129
  title="Edge Directed Interpolation"
130
  )
131
 
132
+ rcan_page = gr.Interface(
133
+ fn=rcan_upscale,
134
+ inputs=[gr.Image(label="Low Resolution Image")],
135
+ outputs=gr.Image(type="pil", label="High Resolution Image"),
136
+ title="RCAN based Super Resolution"
137
+ )
138
+
139
+ # Tabs setup
140
  demo = gr.TabbedInterface(
141
+ [lancros_page, fourier_page, autoencoder_page, srgan_page, espcn_page, random_forest_page, edi_page, rcan_page],
142
  ["Lancros Interpolation", "Fourier Interpolation", "Autoencoder based Super Resolution", "GAN based Super Resolution",
143
+ "EspCN Super Resolution", "Random Forest based Super Resolution", "Edge Directed Interpolation", "RCAN Super Resolution"],
144
  title="Image Super Resolution"
145
  )
146
 
models/rcan.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import io
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+
7
+ # Define your RCAN model architecture here (or import if defined elsewhere)
8
+ # Example stub (you should replace this with your actual model definition)
9
+ class RCAN(torch.nn.Module):
10
+ def __init__(self):
11
+ super(RCAN, self).__init__()
12
+ # Dummy model: replace with your actual RCAN layers
13
+ self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
14
+
15
+ def forward(self, x):
16
+ return self.conv(x)
17
+
18
+ # Load model once and reuse (optional: use a global cache)
19
+ _model = None
20
+ def load_model(model_path, device='cuda'):
21
+ global _model
22
+ if _model is None:
23
+ _model = torch.load(model_path, map_location=device)
24
+ _model.eval().to(device)
25
+ return _model
26
+
27
+ def super_resolve(model_path, img_bytes, device='cuda'):
28
+ """
29
+ Perform super-resolution using RCAN on input image bytes and return upscaled image bytes.
30
+ """
31
+ # Load model
32
+ model = load_model(model_path, device)
33
+
34
+ # Decode image bytes to PIL
35
+ lr_img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
36
+
37
+ # Preprocess: PIL -> Tensor
38
+ transform = transforms.ToTensor()
39
+ input_tensor = transform(lr_img).unsqueeze(0).to(device) # shape: (1, 3, H, W)
40
+
41
+ # Model inference
42
+ with torch.no_grad():
43
+ output_tensor = model(input_tensor)
44
+
45
+ # Postprocess: Tensor -> PIL
46
+ output_tensor = output_tensor.squeeze(0).clamp(0, 1).cpu() # shape: (3, H, W)
47
+ upscaled_img = transforms.ToPILImage()(output_tensor)
48
+
49
+ # Encode to bytes
50
+ byte_io = io.BytesIO()
51
+ upscaled_img.save(byte_io, format='PNG')
52
+ return byte_io.getvalue()
53
+
54
+ def rcan_upscale(lr_img):
55
+ """
56
+ High-level wrapper to be called from Gradio or anywhere else.
57
+ Args:
58
+ lr_img (PIL.Image): Input low-resolution image
59
+ Returns:
60
+ PIL.Image: Output high-resolution image
61
+ """
62
+ img_byte_arr = io.BytesIO()
63
+ lr_img.save(img_byte_arr, format='PNG')
64
+ img_bytes = img_byte_arr.getvalue()
65
+
66
+ upscaled_bytes = super_resolve(model_path="weights/rcan_epoch_20.pth", img_bytes=img_bytes, device='cuda')
67
+ return Image.open(io.BytesIO(upscaled_bytes))
weights/rcan_epoch_20.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1efb22fe861cd418170c0b9527d968dc0f11aee46fa20fd3a58e57c8016418dc
3
+ size 63015800