import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from collections import OrderedDict
from skimage import img_as_ubyte
import spaces

from model.CMFNet import CMFNet

# Download model weights on startup
if not os.path.exists('experiments/pretrained_models/deblur_GoPro_CMFNet.pth'):
    os.makedirs('experiments/pretrained_models', exist_ok=True)
    os.system('wget https://github.com/FanChiMao/CMFNet/releases/download/v0.0/deblur_GoPro_CMFNet.pth -P experiments/pretrained_models')
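
# A wget-free alternative (a sketch, in case wget is not installed in the
# runtime image): fetch the same release asset with urllib from the
# standard library.
#
#   import urllib.request
#   url = ('https://github.com/FanChiMao/CMFNet/releases/download/'
#          'v0.0/deblur_GoPro_CMFNet.pth')
#   urllib.request.urlretrieve(
#       url, 'experiments/pretrained_models/deblur_GoPro_CMFNet.pth')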

# Global model and device handles, populated by load_model()
model = None
device = None

def load_model():
    """Load the CMFNet model"""
    global model, device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CMFNet()
    model = model.to(device)
    model.eval()

    # Load checkpoint
    weights_path = 'experiments/pretrained_models/deblur_GoPro_CMFNet.pth'
    checkpoint = torch.load(weights_path, map_location=device)
    try:
        model.load_state_dict(checkpoint["state_dict"])
    except RuntimeError:
        # Checkpoint was saved from a DataParallel model: strip the
        # 'module.' prefix from each key before loading.
        state_dict = checkpoint["state_dict"]
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    print("Model loaded successfully!")

# Load model on startup
load_model()

@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of each call
def deblur_image(image: Image.Image) -> Image.Image:
    """
    Deblur an input image using CMFNet

    Args:
        image: PIL Image to deblur

    Returns:
        PIL Image of deblurred result
    """
    if model is None:
        raise gr.Error("Model not loaded properly")
    try:
        # Resize large inputs so the longest side is at most 512 px,
        # matching the note in the interface description
        if max(image.size) > 512:
            image = image.copy()
            image.thumbnail((512, 512), Image.LANCZOS)

        # Preprocess image
        input_tensor = TF.to_tensor(image).unsqueeze(0).to(device)

        # Pad image so height and width are multiples of 8
        h, w = input_tensor.shape[2], input_tensor.shape[3]
        mul = 8
        H, W = ((h + mul) // mul) * mul, ((w + mul) // mul) * mul
        padh = H - h if h % mul != 0 else 0
        padw = W - w if w % mul != 0 else 0
        input_tensor = F.pad(input_tensor, (0, padw, 0, padh), 'reflect')
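        # Worked example of the padding arithmetic: for h=719, w=1081 and
        # mul=8, H=720 and W=1088, so padh=1 and padw=7; a side that is
        # already a multiple of 8 gets zero padding.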
        # Run inference
        with torch.no_grad():
            output = model(input_tensor)

        # Post-process
        output = torch.clamp(output, 0, 1)
        output = output[:, :, :h, :w]  # Remove padding
        output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
        output = img_as_ubyte(output)

        # Convert back to PIL Image
        result_image = Image.fromarray(output)
        return result_image

    except Exception as e:
        raise gr.Error(f"Error during inference: {str(e)}")

# Gradio interface
title = "CMFNet Image Deblurring"

description = """
# Compound Multi-branch Feature Fusion for Image Deblurring

Upload a blurry image to get a deblurred version using CMFNet. The model works best on motion blur and defocus blur.

**Note**: Images will be resized to have a maximum dimension of 512px for faster processing.
"""

article = """
<p style='text-align: center'>
<a href='https://github.com/FanChiMao/CMFNet' target='_blank'>GitHub Repository</a>
</p>
"""

# Example images
examples = [
    "images/Blur1.png",
    "images/Blur2.png",
    "images/Blur5.png"
]

# Create Gradio interface
demo = gr.Interface(
    fn=deblur_image,
    inputs=gr.Image(type="pil", label="Upload Blurry Image"),
    outputs=gr.Image(type="pil", label="Deblurred Image"),
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=True,
    theme=gr.themes.Soft(),
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch(debug=True)
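
# Calling the deployed Space from another machine (a sketch using
# gradio_client; "username/CMFNet-deblur" is a hypothetical Space id,
# replace it with the real one once the Space is published):
#
#   from gradio_client import Client, handle_file
#   client = Client("username/CMFNet-deblur")
#   result_path = client.predict(handle_file("blurry.jpg"), api_name="/predict")
#   print(result_path)  # local path to the downloaded deblurred image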