# app.py - Hugging Face Space file by NightRaven109 (commit ce8552a, "Update app.py", verified)
import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from collections import OrderedDict
from skimage import img_as_ubyte
import spaces
from model.CMFNet import CMFNet
# Fetch the pretrained GoPro deblurring weights the first time the app starts.
_WEIGHTS_URL = (
    'https://github.com/FanChiMao/CMFNet/releases/download/v0.0/'
    'deblur_GoPro_CMFNet.pth'
)
_WEIGHTS_PATH = 'experiments/pretrained_models/deblur_GoPro_CMFNet.pth'

if not os.path.exists(_WEIGHTS_PATH):
    os.makedirs(os.path.dirname(_WEIGHTS_PATH), exist_ok=True)
    # Download in-process instead of shelling out to `wget`: this does not
    # depend on a wget binary being installed, and a failed download raises
    # here instead of being silently ignored until the checkpoint is loaded.
    torch.hub.download_url_to_file(_WEIGHTS_URL, _WEIGHTS_PATH)

# Global model state; both stay None until load_model() has run.
model = None
device = None
def load_model():
    """Build CMFNet, load the pretrained deblurring weights, and publish it.

    On success the module-level ``model`` (in eval mode, on the selected
    device) and ``device`` (CUDA when available, otherwise CPU) are set.
    They are only assigned after the weights have loaded, so a failed load
    leaves them untouched.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: if the weights do not match the model even after a
            leading ``module.`` prefix is stripped from the keys.
    """
    global model, device

    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = CMFNet()
    net = net.to(target_device)
    net.eval()

    # Load checkpoint
    weights_path = 'experiments/pretrained_models/deblur_GoPro_CMFNet.pth'
    checkpoint = torch.load(weights_path, map_location=target_device)
    state_dict = checkpoint["state_dict"]

    try:
        net.load_state_dict(state_dict)
    except RuntimeError:
        # Key mismatch: retry with a leading "module." removed from each key
        # (presumably the checkpoint was saved from a DataParallel wrapper).
        prefix = 'module.'
        stripped = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
        net.load_state_dict(stripped)

    # Publish only a fully loaded model so `model is None` keeps meaning
    # "not usable".
    model, device = net, target_device
    print("Model loaded successfully!")
# Load the model once at import time so the module-level `model` and `device`
# are populated before the interface below starts serving requests.
load_model()
@spaces.GPU
def deblur_image(image: Image.Image) -> Image.Image:
    """
    Deblur an input image using CMFNet.

    The image is converted to RGB, padded on the bottom/right so both sides
    are multiples of 8, run through the model without gradients, cropped back
    to the original size, and returned as an 8-bit image.

    Args:
        image: PIL Image to deblur.

    Returns:
        PIL Image of the deblurred result, same width and height as the input.

    Raises:
        gr.Error: if the model is not loaded, no image was provided, or
            anything fails during preprocessing/inference/postprocessing.
    """
    if model is None:
        raise gr.Error("Model not loaded properly")
    if image is None:
        # Submitting the form without an image would otherwise surface as an
        # opaque type error from the tensor conversion.
        raise gr.Error("Please upload an image to deblur")

    try:
        # Preprocess image. Force RGB so grayscale/RGBA inputs do not reach the
        # network with an unexpected channel count (presumably it expects
        # 3-channel input - confirm against CMFNet).
        input_tensor = TF.to_tensor(image.convert("RGB")).unsqueeze(0).to(device)

        # Pad height and width up to the next multiple of 8 (0 if already aligned)
        h, w = input_tensor.shape[2], input_tensor.shape[3]
        mul = 8
        padh = (-h) % mul
        padw = (-w) % mul
        # Reflect padding requires the pad to be smaller than the dimension;
        # fall back to edge replication for very small images.
        pad_mode = 'reflect' if padh < h and padw < w else 'replicate'
        input_tensor = F.pad(input_tensor, (0, padw, 0, padh), pad_mode)

        # Run inference
        with torch.no_grad():
            output = model(input_tensor)

        # Post-process
        output = torch.clamp(output, 0, 1)
        output = output[:, :, :h, :w]  # Remove padding
        output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
        output = img_as_ubyte(output)

        # Convert back to PIL Image
        return Image.fromarray(output)
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"Error during inference: {str(e)}") from e
# Text and example inputs shown in the Gradio interface
title = "CMFNet Image Deblurring"

# The note below describes what the app actually does: the inference function
# only pads the image and crops the result back; it never resizes the input.
description = """
# Compound Multi-branch Feature Fusion for Image Deblurring
Upload a blurry image to get a deblurred version using CMFNet. The model works best on motion blur and defocus blur.
**Note**: Images are processed at their original resolution, so large images may take longer and need more memory.
"""

article = """
<p style='text-align: center'>
<a href='https://github.com/FanChiMao/CMFNet' target='_blank'>GitHub Repository</a>
</p>
"""

# Example images offered in the UI (paths relative to the app directory)
examples = [
    "images/Blur1.png",
    "images/Blur2.png",
    "images/Blur5.png",
]
# Build the UI: a single PIL image goes in, a single PIL image comes out.
_blurry_input = gr.Image(type="pil", label="Upload Blurry Image")
_deblurred_output = gr.Image(type="pil", label="Deblurred Image")

demo = gr.Interface(
    fn=deblur_image,
    inputs=_blurry_input,
    outputs=_deblurred_output,
    examples=examples,
    cache_examples=True,
    title=title,
    description=description,
    article=article,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(debug=True)