Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from collections import OrderedDict | |
class ESPCN(nn.Module):
    """Efficient Sub-Pixel Convolutional Network for single-image super-resolution.

    Three convolutions keep the input resolution and produce
    ``num_channels * scale_factor ** 2`` feature maps. ``nn.PixelShuffle`` then
    rearranges those into a ``num_channels``-channel image that is
    ``scale_factor`` times larger in each spatial dimension.

    Args:
        scale_factor: Upscaling factor per spatial dimension. Defaults to 4,
            the value the architecture was previously hard-coded to.
        num_channels: Number of input and output image channels. Defaults to 3.
    """

    def __init__(self, scale_factor=4, num_channels=3):
        super().__init__()
        # Padding keeps H and W unchanged through every convolution.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        # Emit scale_factor**2 sub-pixel values per output channel so that
        # PixelShuffle can fold them into a num_channels-channel output image.
        self.conv3 = nn.Conv2d(
            32, num_channels * scale_factor ** 2, kernel_size=3, padding=1
        )
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        """Upscale a batch of images.

        Args:
            x: Tensor of shape (N, num_channels, H, W).

        Returns:
            Tensor of shape (N, num_channels, H * scale_factor, W * scale_factor).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.pixel_shuffle(self.conv3(x))
# Loaded ESPCN models keyed by weights path. The checkpoint is read from disk
# once per process rather than on every espcn_upscale call. A checkpoint file
# that changes on disk after the first load is NOT picked up again.
_ESPCN_MODEL_CACHE = {}


def _load_espcn_model(weights_path):
    """Build a default ESPCN, load weights from *weights_path* on CPU, and cache it."""
    model = _ESPCN_MODEL_CACHE.get(weights_path)
    if model is not None:
        return model

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only point weights_path at trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=torch.device("cpu"))

    # Accept both {"state_dict": ...} training checkpoints and bare state dicts.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    # Strip only a *leading* "module." prefix (presumably added by wrapping the
    # model in nn.DataParallel during training) so the keys match ESPCN's own.
    prefix = "module."
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

    model = ESPCN()
    model.load_state_dict(new_state_dict)
    model.eval()
    _ESPCN_MODEL_CACHE[weights_path] = model
    return model


def espcn_upscale(image, weights_path="weights/model_espcn.pth"):
    """Upscale an image with the pretrained ESPCN model.

    Args:
        image: NumPy array in channels-last (H, W, C) layout. Presumably uint8
            RGB with values in [0, 255] — TODO confirm against callers. C must
            match the model's input channels (3 for the default ESPCN).
        weights_path: Checkpoint to load. Defaults to the original hard-coded
            path, which is resolved relative to the current working directory.

    Returns:
        uint8 NumPy array of shape (H * 4, W * 4, 3) for the default ESPCN.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing or
        incompatible checkpoint.
    """
    # Scale to [0, 1], then HWC -> NCHW with a batch dimension of 1. Dividing
    # first also yields a fresh contiguous array that torch.from_numpy accepts.
    tensor = torch.from_numpy(image / 255).float().unsqueeze(0).permute(0, 3, 1, 2)

    model = _load_espcn_model(weights_path)
    with torch.no_grad():
        output = model(tensor)

    # NCHW -> HWC, dropping the batch dimension.
    output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round (rather than truncate) when converting back to 8-bit so the output
    # is not biased downward by up to one intensity level.
    return (output * 255.0).round().clip(0, 255).astype("uint8")