Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from collections import OrderedDict | |
# SRCNN Model Definition
class SuperResolutionCNN(nn.Module):
    """Convolutional network that upscales RGB images by a factor of 4.

    Pipeline: feature extraction (9x9 conv, then 1x1 conv) -> two stride-2
    transposed convolutions, each doubling height and width -> 5x5
    reconstruction conv back to 3 channels, squashed by a sigmoid.

    The submodule names and the layer order inside each ``nn.Sequential``
    determine the state_dict keys (``features.0``, ``features.2``,
    ``upsample.0``, ``upsample.2``, ``reconstruction``), so they must stay
    in sync with any saved weights.
    """

    def __init__(self):
        super().__init__()

        # 3 -> 64 channels with a wide 9x9 receptive field, then a 1x1
        # projection down to 32 channels; the padding keeps H and W unchanged.
        extract = [
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=1, padding=0),
            nn.ReLU(),
        ]
        self.features = nn.Sequential(*extract)

        # Each transposed conv maps a spatial size S to
        # (S - 1) * 2 - 2 * 1 + 3 + 1 == 2 * S, so two of them give 4x.
        grow = []
        for _ in range(2):
            grow.append(
                nn.ConvTranspose2d(
                    32, 32, kernel_size=3, stride=2, padding=1, output_padding=1
                )
            )
            grow.append(nn.ReLU())
        self.upsample = nn.Sequential(*grow)

        # 5x5 conv with padding 2: 32 channels back to RGB at the same size.
        self.reconstruction = nn.Conv2d(32, 3, kernel_size=5, padding=2)

    def forward(self, x):
        """Map a batch of shape (N, 3, H, W) to (N, 3, 4H, 4W) in [0, 1]."""
        hidden = self.upsample(self.features(x))
        return self.reconstruction(hidden).sigmoid()
def _normalize_state_dict(checkpoint):
    """Return a plain state dict from *checkpoint* with DataParallel prefixes removed.

    Accepts a raw state dict, a dict wrapping it under 'state_dict' or
    'model_state_dict', or a pickled nn.Module.
    """
    if isinstance(checkpoint, nn.Module):
        # A whole pickled model: `'state_dict' in module` would raise TypeError.
        state_dict = checkpoint.state_dict()
    elif 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint  # Assume direct state dict

    prefix = "module."
    # Strip only a *leading* "module." (added by nn.DataParallel); str.replace
    # would also mangle keys that merely contain it, e.g. "submodule.weight".
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def srcnn_upscale(image, model_path="weights/model_cnn.pth", device='cpu'):
    """
    Upscale an RGB image 4x using the SuperResolutionCNN model.

    The network contains two stride-2 transposed convolutions, each doubling
    height and width, so the result is four times larger along each axis.

    Args:
        image: Input image of shape (H, W, 3) with values in the range 0-255;
            anything np.asarray accepts (a uint8 numpy array, a PIL RGB image).
        model_path: Path to the trained model weights. The checkpoint may be a
            raw state dict, a dict holding it under 'state_dict' or
            'model_state_dict', or a pickled nn.Module.
        device: 'cpu' or 'cuda' (or a torch.device) to run inference on.

    Returns:
        numpy.ndarray of shape (4*H, 4*W, 3), dtype uint8, values 0-255.
    """
    import numpy as np  # local import: the module header only imports torch

    # Normalize to [0, 1] and convert HWC -> NCHW float tensor [1, 3, H, W].
    array = np.asarray(image) / 255.0
    tensor = (
        torch.from_numpy(array).float().unsqueeze(0).permute(0, 3, 1, 2).to(device)
    )

    model = SuperResolutionCNN().to(device)
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=torch.device(device))
    model.load_state_dict(_normalize_state_dict(checkpoint))
    model.eval()

    with torch.no_grad():
        output = model(tensor)

    # [1, 3, 4H, 4W] in [0, 1] -> [4H, 4W, 3] uint8. Round rather than truncate
    # so pixel values are not systematically biased downward.
    output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
    return (output * 255.0).round().clip(0, 255).astype("uint8")