CV_Project / models /espcn.py
R2454's picture
Updated app.py
5ee74b8
raw
history blame
1.38 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class ESPCN(nn.Module):
    """Three-convolution network followed by a pixel-shuffle upscaler.

    Two ReLU-activated convolutions extract features; a third convolution
    produces ``scale_factor ** 2 * num_channels`` feature maps which
    ``nn.PixelShuffle`` rearranges into an output whose height and width are
    ``scale_factor`` times those of the input.

    Args:
        scale_factor: Spatial upscaling factor. Defaults to 4.
        num_channels: Number of channels in both the input and the output
            image. Defaults to 3.

    Raises:
        ValueError: If ``scale_factor`` or ``num_channels`` is smaller than 1.
    """

    def __init__(self, scale_factor=4, num_channels=3):
        super().__init__()
        if scale_factor < 1:
            raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")
        if num_channels < 1:
            raise ValueError(f"num_channels must be >= 1, got {num_channels}")

        # Paddings are chosen so every convolution preserves height and width.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        # PixelShuffle(r) consumes r**2 feature maps per output channel, so
        # conv3 must emit exactly scale_factor**2 * num_channels maps.
        self.conv3 = nn.Conv2d(
            32, (scale_factor ** 2) * num_channels, kernel_size=3, padding=1
        )
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        """Upscale a batch ``x`` of shape (N, num_channels, H, W).

        Returns:
            Tensor of shape (N, num_channels, H * scale_factor,
            W * scale_factor). No activation is applied after the final
            convolution.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.pixel_shuffle(self.conv3(x))
def _strip_key_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with leading ``prefix`` removed from keys.

    Only occurrences at the start of a key are stripped (repeatedly, so a
    doubly-prefixed key is fully unwrapped). A plain ``str.replace`` would
    also rewrite the substring in the middle of a parameter path, e.g.
    ``sub_module.weight``.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        while prefix and key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


def espcn_upscale(image):
    """Upscale an image with the ESPCN model stored in ``weights/model_espcn.pth``.

    Args:
        image: NumPy array laid out as (height, width, channels) with values
            on a 0-255 scale. The channel count must match what ``ESPCN()``
            expects as input.

    Returns:
        ``uint8`` NumPy array laid out as (height, width, channels) holding
        the network output rescaled to 0-255 and clipped to that range.
    """
    # Normalise to [0, 1], then go from HWC to a single-item NCHW batch.
    normalised = image / 255
    batch = torch.from_numpy(normalised).float().unsqueeze(dim=0).permute(0, 3, 1, 2)

    model = ESPCN()
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be placed at this path. The weights are also re-read on every
    # call; consider caching the loaded model if this is called repeatedly.
    checkpoint = torch.load(
        "weights/model_espcn.pth", map_location=torch.device("cpu")
    )  # or 'cuda' if using GPU
    # Accept both {"state_dict": ...} checkpoints and bare state dicts.
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    # Keys may carry a "module." prefix (presumably from a wrapper such as
    # nn.DataParallel used when saving); drop it so they match this model.
    model.load_state_dict(_strip_key_prefix(state_dict))
    model.eval()

    with torch.no_grad():
        prediction = model(batch)

    # Back to HWC and a 0-255 scale.
    result = prediction.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # NOTE(review): astype("uint8") truncates toward zero rather than rounding.
    return (result * 255.0).clip(0, 255).astype("uint8")