NightRaven109 commited on
Commit
ce8552a
·
verified ·
1 Parent(s): 26122c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -34
app.py CHANGED
@@ -1,46 +1,132 @@
1
- import os
 
2
  import gradio as gr
 
3
  from PIL import Image
4
- import torch
 
 
 
5
  import spaces
6
 
7
- os.system(
8
- 'wget https://github.com/FanChiMao/CMFNet/releases/download/v0.0/deblur_GoPro_CMFNet.pth -P experiments/pretrained_models')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
 
 
10
 
11
  @spaces.GPU
12
- def inference(img):
13
- os.makedirs("./test", exist_ok=True)
14
- os.makedirs("./results", exist_ok=True)
15
- basewidth = 512
16
- wpercent = (basewidth / float(img.size[0]))
17
- hsize = int((float(img.size[1]) * float(wpercent)))
18
- img = img.resize((basewidth, hsize), Image.BILINEAR)
19
- img.save("test/1.png", "PNG")
20
 
21
- result = os.system(
22
- 'python main_test_CMFNet.py --input_dir test --weights experiments/pretrained_models/deblur_GoPro_CMFNet.pth')
 
 
 
 
 
 
23
 
24
- output_path = 'results/1.png'
25
- if os.path.exists(output_path):
26
- return output_path
27
- else:
28
- print(f"Error: Output file not found at {output_path}")
29
- print(f"Command exit code: {result}")
30
- return None
31
-
32
-
33
- title = "Compound Multi-branch Feature Fusion for Image Restoration (Deblur)"
34
- description = "Gradio demo for CMFNet. CMFNet achieves competitive performance on three tasks: image deblurring, image dehazing and image deraindrop. Here, we provide a demo for image deblur. To use it, simply upload your image, or click one of the examples to load them. Reference from: https://huggingface.co/akhaliq"
35
- article = "<p style='text-align: center'><a href='https://' target='_blank'>Compound Multi-branch Feature Fusion for Real Image Restoration</a> | <a href='https://github.com/FanChiMao/CMFNet' target='_blank'>Github Repo</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=52Hz_CMFNet_deblurring' alt='visitor badge'></center>"
36
-
37
- examples = [['images/Blur1.png'], ['images/Blur2.png'], ['images/Blur5.png'],]
38
- gr.Interface(
39
- inference,
40
- [gr.components.Image(type="pil", label="Input")],
41
- gr.components.Image(type="filepath", label="Output"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  title=title,
43
  description=description,
44
  article=article,
45
- examples=examples
46
- ).launch(debug=True)
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
  import gradio as gr
4
+ import numpy as np
5
  from PIL import Image
6
+ import torchvision.transforms.functional as TF
7
+ import torch.nn.functional as F
8
+ from collections import OrderedDict
9
+ from skimage import img_as_ubyte
10
  import spaces
11
 
12
+ from model.CMFNet import CMFNet
13
+
14
+ # Download model weights on startup
15
+ if not os.path.exists('experiments/pretrained_models/deblur_GoPro_CMFNet.pth'):
16
+ os.makedirs('experiments/pretrained_models', exist_ok=True)
17
+ os.system('wget https://github.com/FanChiMao/CMFNet/releases/download/v0.0/deblur_GoPro_CMFNet.pth -P experiments/pretrained_models')
18
+
19
+ # Global model variable
20
+ model = None
21
+ device = None
22
+
23
+ def load_model():
24
+ """Load the CMFNet model"""
25
+ global model, device
26
+
27
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28
+ model = CMFNet()
29
+ model = model.to(device)
30
+ model.eval()
31
+
32
+ # Load checkpoint
33
+ weights_path = 'experiments/pretrained_models/deblur_GoPro_CMFNet.pth'
34
+ checkpoint = torch.load(weights_path, map_location=device)
35
+
36
+ try:
37
+ model.load_state_dict(checkpoint["state_dict"])
38
+ except:
39
+ state_dict = checkpoint["state_dict"]
40
+ new_state_dict = OrderedDict()
41
+ for k, v in state_dict.items():
42
+ name = k[7:] if k.startswith('module.') else k
43
+ new_state_dict[name] = v
44
+ model.load_state_dict(new_state_dict)
45
+
46
+ print("Model loaded successfully!")
47
 
48
+ # Load model on startup
49
+ load_model()
50
 
51
  @spaces.GPU
52
+ def deblur_image(image: Image.Image) -> Image.Image:
53
+ """
54
+ Deblur an input image using CMFNet
 
 
 
 
 
55
 
56
+ Args:
57
+ image: PIL Image to deblur
58
+
59
+ Returns:
60
+ PIL Image of deblurred result
61
+ """
62
+ if model is None:
63
+ raise gr.Error("Model not loaded properly")
64
 
65
+ try:
66
+ # Preprocess image
67
+ input_tensor = TF.to_tensor(image).unsqueeze(0).to(device)
68
+
69
+ # Pad image to be multiple of 8
70
+ h, w = input_tensor.shape[2], input_tensor.shape[3]
71
+ mul = 8
72
+ H, W = ((h + mul) // mul) * mul, ((w + mul) // mul) * mul
73
+ padh = H - h if h % mul != 0 else 0
74
+ padw = W - w if w % mul != 0 else 0
75
+ input_tensor = F.pad(input_tensor, (0, padw, 0, padh), 'reflect')
76
+
77
+ # Run inference
78
+ with torch.no_grad():
79
+ output = model(input_tensor)
80
+
81
+ # Post-process
82
+ output = torch.clamp(output, 0, 1)
83
+ output = output[:, :, :h, :w] # Remove padding
84
+ output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
85
+ output = img_as_ubyte(output)
86
+
87
+ # Convert back to PIL Image
88
+ result_image = Image.fromarray(output)
89
+ return result_image
90
+
91
+ except Exception as e:
92
+ raise gr.Error(f"Error during inference: {str(e)}")
93
+
94
+ # Gradio interface
95
+ title = "CMFNet Image Deblurring"
96
+ description = """
97
+ # Compound Multi-branch Feature Fusion for Image Deblurring
98
+
99
+ Upload a blurry image to get a deblurred version using CMFNet. The model works best on motion blur and defocus blur.
100
+
101
+ **Note**: Images will be resized to have a maximum dimension of 512px for faster processing.
102
+ """
103
+
104
+ article = """
105
+ <p style='text-align: center'>
106
+ <a href='https://github.com/FanChiMao/CMFNet' target='_blank'>GitHub Repository</a>
107
+ </p>
108
+ """
109
+
110
+ # Example images
111
+ examples = [
112
+ "images/Blur1.png",
113
+ "images/Blur2.png",
114
+ "images/Blur5.png"
115
+ ]
116
+
117
+ # Create Gradio interface
118
+ demo = gr.Interface(
119
+ fn=deblur_image,
120
+ inputs=gr.Image(type="pil", label="Upload Blurry Image"),
121
+ outputs=gr.Image(type="pil", label="Deblurred Image"),
122
  title=title,
123
  description=description,
124
  article=article,
125
+ examples=examples,
126
+ cache_examples=True,
127
+ theme=gr.themes.Soft(),
128
+ allow_flagging="never"
129
+ )
130
+
131
+ if __name__ == "__main__":
132
+ demo.launch(debug=True)