import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from collections import OrderedDict
from skimage import img_as_ubyte
import spaces
from model.CMFNet import CMFNet

# Location of the pretrained GoPro deblurring checkpoint and where to fetch it.
WEIGHTS_DIR = 'experiments/pretrained_models'
WEIGHTS_PATH = 'experiments/pretrained_models/deblur_GoPro_CMFNet.pth'
WEIGHTS_URL = 'https://github.com/FanChiMao/CMFNet/releases/download/v0.0/deblur_GoPro_CMFNet.pth'

# Spatial size the network input is padded up to a multiple of.
PAD_MULTIPLE = 8


def _ensure_weights():
    """Download the checkpoint to WEIGHTS_PATH if it is not already on disk.

    Uses torch's own downloader rather than shelling out to ``wget`` so the
    app does not depend on an external binary, and so a failed download
    raises here instead of surfacing later as a missing file in
    ``torch.load``.
    """
    if os.path.exists(WEIGHTS_PATH):
        return
    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    torch.hub.download_url_to_file(WEIGHTS_URL, WEIGHTS_PATH)


# Download model weights on startup
_ensure_weights()

# Global model variable
model = None
device = None


def load_model():
    """Build CMFNet, move it to the available device and load the weights.

    Sets the module-level ``model`` and ``device`` globals. The checkpoint is
    expected to hold the weights under the ``"state_dict"`` key; if the keys
    do not match the model as-is, a leading ``module.`` prefix is stripped
    from each key and loading is retried.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: if the weights do not match the model even after the
            ``module.`` prefix has been stripped.
    """
    global model, device

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CMFNet()
    model = model.to(device)
    model.eval()

    # Load checkpoint
    # NOTE(review): torch.load unpickles the file; only point WEIGHTS_URL at
    # a trusted source.
    checkpoint = torch.load(WEIGHTS_PATH, map_location=device)
    state_dict = checkpoint["state_dict"]

    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # load_state_dict reports missing/unexpected keys as RuntimeError.
        # Retry with the "module." prefix removed from every key.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)

    print("Model loaded successfully!")


# Load model on startup
load_model()


def _pad_to_multiple(tensor, multiple=PAD_MULTIPLE):
    """Pad the last two dims of a 4-D tensor up to a multiple of ``multiple``.

    Padding is added on the bottom and right only, so the original content
    can be recovered by slicing ``[:, :, :h, :w]``.
    """
    h, w = tensor.shape[2], tensor.shape[3]
    padh = (-h) % multiple
    padw = (-w) % multiple
    if padh == 0 and padw == 0:
        return tensor
    # Reflection padding needs the pad to be smaller than the dimension being
    # padded; fall back to edge replication for very small images.
    mode = 'reflect' if (padh < h and padw < w) else 'replicate'
    return F.pad(tensor, (0, padw, 0, padh), mode)


@spaces.GPU
def deblur_image(image: Image.Image) -> Image.Image:
    """
    Deblur an input image using CMFNet

    Args:
        image: PIL Image to deblur

    Returns:
        PIL Image of deblurred result, same width and height as the input

    Raises:
        gr.Error: if the model is not loaded, no image was supplied, or
            inference fails.
    """
    if model is None:
        raise gr.Error("Model not loaded properly")
    if image is None:
        raise gr.Error("Please upload an image first")

    try:
        # Preprocess image
        # NOTE(review): the network presumably expects 3-channel input;
        # normalise grayscale / RGBA images. This only copies RGB inputs.
        if image.mode != "RGB":
            image = image.convert("RGB")
        input_tensor = TF.to_tensor(image).unsqueeze(0).to(device)

        # Pad image to be multiple of 8
        h, w = input_tensor.shape[2], input_tensor.shape[3]
        input_tensor = _pad_to_multiple(input_tensor, PAD_MULTIPLE)

        # Run inference
        with torch.no_grad():
            output = model(input_tensor)

        # Post-process
        output = torch.clamp(output, 0, 1)
        output = output[:, :, :h, :w]  # Remove padding
        output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
        output = img_as_ubyte(output)

        # Convert back to PIL Image
        result_image = Image.fromarray(output)

        return result_image

    except Exception as e:
        # Surface the failure in the Gradio UI while keeping the original
        # traceback attached for the server log.
        raise gr.Error(f"Error during inference: {e}") from e


# Gradio interface
title = "CMFNet Image Deblurring"
description = """
# Compound Multi-branch Feature Fusion for Image Deblurring

Upload a blurry image to get a deblurred version using CMFNet. The model works best on motion blur and defocus blur.

**Note**: Images are processed at their original resolution, so large images will take longer to process.
"""

article = """

GitHub Repository

"""

# Example images
examples = [
    "images/Blur1.png",
    "images/Blur2.png",
    "images/Blur5.png"
]

# Create Gradio interface
demo = gr.Interface(
    fn=deblur_image,
    inputs=gr.Image(type="pil", label="Upload Blurry Image"),
    outputs=gr.Image(type="pil", label="Deblurred Image"),
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=True,
    theme=gr.themes.Soft(),
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch(debug=True)