import math

import torch
import torch.nn as nn
from torchvision.transforms.functional import to_tensor, to_pil_image


# === RCAN MODULES ===

class CALayer(nn.Module):
    """Channel attention: reweights feature channels via a squeeze-and-excitation style block."""

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        w = self.fc(self.avg_pool(x))  # per-channel weights in [0, 1]
        return x * w


class RCAB(nn.Module):
    """Residual Channel Attention Block: conv-ReLU-conv plus channel attention, with a skip connection."""

    def __init__(self, channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1),
            CALayer(channels)
        )

    def forward(self, x):
        return x + self.body(x)


class ResidualGroup(nn.Module):
    """A stack of RCABs followed by a conv, wrapped in a group-level skip connection."""

    def __init__(self, channels, n_rcab):
        super().__init__()
        modules = [RCAB(channels) for _ in range(n_rcab)]
        modules.append(nn.Conv2d(channels, channels, 3, padding=1))
        self.body = nn.Sequential(*modules)

    def forward(self, x):
        return x + self.body(x)


class Upsampler(nn.Sequential):
    """Upscales by a power-of-two factor using repeated conv + PixelShuffle(2) stages."""

    def __init__(self, scale, channels):
        m = []
        for _ in range(int(math.log2(scale))):
            m.append(nn.Conv2d(channels, channels * 4, 3, padding=1))
            m.append(nn.PixelShuffle(2))
        super().__init__(*m)


class RCAN(nn.Module):
    """Residual Channel Attention Network for single-image super-resolution."""

    def __init__(self, in_channels=3, out_channels=3, n_feat=64, n_rg=10, n_rcab=20, scale=4):
        super().__init__()
        self.head = nn.Conv2d(in_channels, n_feat, 3, padding=1)
        self.body = nn.Sequential(
            *[ResidualGroup(n_feat, n_rcab) for _ in range(n_rg)],
            nn.Conv2d(n_feat, n_feat, 3, padding=1)
        )
        self.upsample = Upsampler(scale, n_feat)
        self.tail = nn.Conv2d(n_feat, out_channels, 3, padding=1)

    def forward(self, x):
        x = self.head(x)
        res = self.body(x)
        x = x + res          # long skip connection around all residual groups
        x = self.upsample(x)
        return self.tail(x)


# === INFERENCE ===

def rcan_upscale(lr_img_pil, model_path="weights/rcan_epoch_20.pth", device='cpu'):
    """
    Super-resolves a low-resolution PIL image using the RCAN model.

    Args:
        lr_img_pil (PIL.Image): Low-resolution input image.
        model_path (str): Path to the model weights.
        device (str): 'cuda' or 'cpu'.

    Returns:
        PIL.Image: High-resolution output image.
    """
    # Load model (fall back to CPU if CUDA was requested but is unavailable)
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    model = RCAN(scale=4)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device).eval()

    # Convert PIL image to a normalized tensor and add a batch dimension
    lr_tensor = to_tensor(lr_img_pil).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        sr_tensor = model(lr_tensor).squeeze(0).clamp(0, 1).cpu()  # remove batch dim

    # Convert tensor back to a PIL image
    sr_img_pil = to_pil_image(sr_tensor)
    return sr_img_pil
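

# --- Example usage: a minimal sketch, not part of the original module ---
# The input/output paths below are hypothetical placeholders; substitute your
# own low-resolution image and output location. The weights path matches the
# function's default. Assumes Pillow is available (installed with torchvision).
if __name__ == "__main__":
    from PIL import Image

    lr_img = Image.open("inputs/example_lr.png").convert("RGB")   # hypothetical input path
    sr_img = rcan_upscale(lr_img, model_path="weights/rcan_epoch_20.pth", device="cuda")
    sr_img.save("outputs/example_sr.png")                         # hypothetical output path
    print(f"Upscaled {lr_img.size} -> {sr_img.size}")             # expect 4x per side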