coconutiscoding committed
Commit 5809e83 · 1 Parent(s): ed1cbdb

change model path

Files changed (2):
  1. app.py +174 -174
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,5 +1,108 @@
+import gradio as gr
+import torch
+import numpy as np
+from PIL import Image
+from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
+from huggingface_hub import hf_hub_download
+import cv2
+import zipfile
+
+
+
+# ---------------- Download and load the official LaMa weights ----------------
+repo_id = "JosephCatrambone/big-lama-torchscript"
+model_path = hf_hub_download(repo_id=repo_id, filename="big-lama.pt")
+# zip_path = hf_hub_download(repo_id="smartywu/big-lama", filename="big-lama.zip")
+# import zipfile
+# with zipfile.ZipFile(zip_path, 'r') as z:
+#     z.extractall("./")
+# model_path = "./models/best.ckpt"
+lama_model = torch.jit.load(model_path, map_location="cpu")
+lama_model.eval()
+
+print("torch:", torch.__version__)
+print("numpy:", np.__version__)
+
+# ---- Load the segmentation model (CPU) ----
+device = torch.device("cpu")
+weights = DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
+model = deeplabv3_resnet50(weights=weights).to(device).eval()
+preprocess = weights.transforms()
+
+
+MAX_SIDE = 1024  # cap the longest side for speed and memory
+
+def _resize_if_needed(pil_img: Image.Image, max_side=MAX_SIDE) -> Image.Image:
+    w, h = pil_img.size
+    if max(w, h) <= max_side:
+        return pil_img
+    r = max_side / float(max(w, h))
+    return pil_img.resize((int(w * r), int(h * r)), Image.BILINEAR)
+
+def segment(image: Image.Image):
+    print("DEBUG: type(image) =", type(image), "mode=", getattr(image, "mode", None))
+    if not isinstance(image, Image.Image):
+        image = Image.fromarray(image)
+
+    image = image.convert("RGB")
+    image = _resize_if_needed(image)
+
+    # Preprocess and run inference
+    x = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
+    x = x.unsqueeze(0).to(device)  # [1,3,H,W]
+
+    with torch.no_grad():
+        out = model(x)["out"][0]  # [C,H,W], C=21 (incl. background)
+    pred = out.argmax(0).cpu().numpy()  # [H,W]
+
+    # Foreground = non-background (class 0 is background for the COCO/VOC weights)
+    fg = (pred != 0).astype(np.uint8)
+
+    # ---------------- Mask dilation ----------------
+    kernel = np.ones((19, 19), np.uint8)
+    fg_dilated = cv2.dilate(fg, kernel, iterations=1)
+    print("added mask dilation step")
+
+    mask_img = Image.fromarray((fg_dilated * 255).astype(np.uint8), mode="L")
+
+    # Overlay a colored mask (semi-transparent red)
+    base = image.convert("RGBA")
+    overlay = Image.new("RGBA", base.size, (255, 0, 0, 0))
+    alpha = Image.fromarray((fg_dilated * 120).astype(np.uint8))
+    overlay.putalpha(alpha)
+    blended = Image.alpha_composite(base, overlay).convert("RGB")
+
+    # ---- LaMa erasing ----
+    img_np = np.array(image)      # HWC, uint8
+    mask_np = np.array(mask_img)  # H,W, 0/255
+    img_t = torch.from_numpy(img_np).permute(2, 0, 1).float().unsqueeze(0) / 255.0
+    mask_t = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0).float() / 255.0
+    with torch.no_grad():
+        inpainted_t = lama_model(img_t, mask_t)  # [1,3,H,W]
+    inpainted_np = (inpainted_t[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
+    inpainted_img = Image.fromarray(inpainted_np)
+
+    return blended, mask_img, inpainted_img
+
+# ---- Gradio UI ----
+demo = gr.Interface(
+    fn=segment,
+    inputs=gr.Image(type="pil", label="Upload Image"),
+    outputs=[
+        gr.Image(type="pil", label="Overlay (foreground)"),
+        gr.Image(type="pil", label="Binary Mask (foreground=white)"),
+        gr.Image(type="pil", label="Inpaint Result"),
+    ],
+    title="Semantic Segmentation + LaMa Inpainting",
+    description="DeepLabV3 segmentation + mask dilation + LaMa erasing, running on CPU."
+)
+
+if __name__ == "__main__":
+    demo.launch()
+
 # import gradio as gr
 # import torch
+# import torch.nn as nn
 # import numpy as np
 # from PIL import Image
 # from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
@@ -10,16 +113,82 @@
 
 
 # # ---------------- Download and load the official LaMa weights ----------------
-# # repo_id = "saic-mdal/lama-big"
-# # model_path = hf_hub_download(repo_id=repo_id, filename="big-lama.pt")
 # zip_path = hf_hub_download(repo_id="smartywu/big-lama", filename="big-lama.zip")
-# import zipfile
 # with zipfile.ZipFile(zip_path, 'r') as z:
 #     z.extractall("./")
-# model_path = "./models/best.ckpt"
-# lama_model = torch.jit.load(model_path, map_location="cpu")
+# model_path = "./big-lama/models/best.ckpt"
+
+
+# # ==========================================================
+# # LaMa FBAResUNet definition (official structure)
+# # ==========================================================
+# class GatedConv(nn.Module):
+#     def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
+#         super().__init__()
+#         padding = (kernel_size - 1) // 2 * dilation
+#         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
+#                               padding=padding, dilation=dilation)
+#         self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
+#                                    padding=padding, dilation=dilation)
+#         self.sigmoid = nn.Sigmoid()
+
+#     def forward(self, x):
+#         feat = self.conv(x)
+#         mask = self.sigmoid(self.mask_conv(x))
+#         return feat * mask
+
+
+# class FBAResUNet(nn.Module):
+#     def __init__(self, input_channels=4, output_channels=3, num_filters=64):
+#         super().__init__()
+#         self.enc1 = GatedConv(input_channels, num_filters)
+#         self.enc2 = GatedConv(num_filters, num_filters * 2, stride=2)
+#         self.enc3 = GatedConv(num_filters * 2, num_filters * 4, stride=2)
+#         self.enc4 = GatedConv(num_filters * 4, num_filters * 8, stride=2)
+
+#         self.middle = GatedConv(num_filters * 8, num_filters * 8)
+
+#         self.dec4 = nn.ConvTranspose2d(num_filters * 8, num_filters * 4, kernel_size=4, stride=2, padding=1)
+#         self.dec3 = nn.ConvTranspose2d(num_filters * 8, num_filters * 2, kernel_size=4, stride=2, padding=1)
+#         self.dec2 = nn.ConvTranspose2d(num_filters * 4, num_filters, kernel_size=4, stride=2, padding=1)
+#         self.dec1 = nn.Conv2d(num_filters * 2, output_channels, kernel_size=3, padding=1)
+
+#         self.relu = nn.ReLU(inplace=True)
+
+#     def forward(self, image, mask):
+#         # image: [B,3,H,W], mask: [B,1,H,W]
+#         x = torch.cat([image, mask], dim=1)  # -> [B,4,H,W]
+
+#         e1 = self.enc1(x)
+#         e2 = self.enc2(self.relu(e1))
+#         e3 = self.enc3(self.relu(e2))
+#         e4 = self.enc4(self.relu(e3))
+
+#         m = self.middle(self.relu(e4))
+
+#         d4 = self.relu(self.dec4(m))
+#         d4 = torch.cat([d4, e3], dim=1)
+#         d3 = self.relu(self.dec3(d4))
+#         d3 = torch.cat([d3, e2], dim=1)
+#         d2 = self.relu(self.dec2(d3))
+#         d2 = torch.cat([d2, e1], dim=1)
+#         out = torch.sigmoid(self.dec1(d2))
+#         return out
+
+
+# # ==========================================================
+# # Load the pretrained LaMa weights
+# # ==========================================================
+# checkpoint = torch.load(model_path, map_location="cpu")
+# lama_model = FBAResUNet()
+# if "state_dict" in checkpoint:
+#     state_dict = {k.replace("netG.", ""): v for k, v in checkpoint["state_dict"].items()}
+# else:
+#     state_dict = checkpoint
+# lama_model.load_state_dict(state_dict, strict=False)
 # lama_model.eval()
 
+
 # print("torch:", torch.__version__)
 # print("numpy:", np.__version__)
 
@@ -99,172 +268,3 @@
 
 # if __name__ == "__main__":
 #     demo.launch()
-
-import gradio as gr
-import torch
-import torch.nn as nn
-import numpy as np
-from PIL import Image
-from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
-from huggingface_hub import hf_hub_download
-import cv2
-import zipfile
-
-
-
-# ---------------- Download and load the official LaMa weights ----------------
-zip_path = hf_hub_download(repo_id="smartywu/big-lama", filename="big-lama.zip")
-with zipfile.ZipFile(zip_path, 'r') as z:
-    z.extractall("./")
-model_path = "./big-lama/models/best.ckpt"
-
-
-# ==========================================================
-# LaMa FBAResUNet definition (official structure)
-# ==========================================================
-class GatedConv(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
-        super().__init__()
-        padding = (kernel_size - 1) // 2 * dilation
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
-                              padding=padding, dilation=dilation)
-        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
-                                   padding=padding, dilation=dilation)
-        self.sigmoid = nn.Sigmoid()
-
-    def forward(self, x):
-        feat = self.conv(x)
-        mask = self.sigmoid(self.mask_conv(x))
-        return feat * mask
-
-
-class FBAResUNet(nn.Module):
-    def __init__(self, input_channels=4, output_channels=3, num_filters=64):
-        super().__init__()
-        self.enc1 = GatedConv(input_channels, num_filters)
-        self.enc2 = GatedConv(num_filters, num_filters * 2, stride=2)
-        self.enc3 = GatedConv(num_filters * 2, num_filters * 4, stride=2)
-        self.enc4 = GatedConv(num_filters * 4, num_filters * 8, stride=2)
-
-        self.middle = GatedConv(num_filters * 8, num_filters * 8)
-
-        self.dec4 = nn.ConvTranspose2d(num_filters * 8, num_filters * 4, kernel_size=4, stride=2, padding=1)
-        self.dec3 = nn.ConvTranspose2d(num_filters * 8, num_filters * 2, kernel_size=4, stride=2, padding=1)
-        self.dec2 = nn.ConvTranspose2d(num_filters * 4, num_filters, kernel_size=4, stride=2, padding=1)
-        self.dec1 = nn.Conv2d(num_filters * 2, output_channels, kernel_size=3, padding=1)
-
-        self.relu = nn.ReLU(inplace=True)
-
-    def forward(self, image, mask):
-        # image: [B,3,H,W], mask: [B,1,H,W]
-        x = torch.cat([image, mask], dim=1)  # -> [B,4,H,W]
-
-        e1 = self.enc1(x)
-        e2 = self.enc2(self.relu(e1))
-        e3 = self.enc3(self.relu(e2))
-        e4 = self.enc4(self.relu(e3))
-
-        m = self.middle(self.relu(e4))
-
-        d4 = self.relu(self.dec4(m))
-        d4 = torch.cat([d4, e3], dim=1)
-        d3 = self.relu(self.dec3(d4))
-        d3 = torch.cat([d3, e2], dim=1)
-        d2 = self.relu(self.dec2(d3))
-        d2 = torch.cat([d2, e1], dim=1)
-        out = torch.sigmoid(self.dec1(d2))
-        return out
-
-
-# ==========================================================
-# Load the pretrained LaMa weights
-# ==========================================================
-checkpoint = torch.load(model_path, map_location="cpu")
-lama_model = FBAResUNet()
-if "state_dict" in checkpoint:
-    state_dict = {k.replace("netG.", ""): v for k, v in checkpoint["state_dict"].items()}
-else:
-    state_dict = checkpoint
-lama_model.load_state_dict(state_dict, strict=False)
-lama_model.eval()
-
-
-print("torch:", torch.__version__)
-print("numpy:", np.__version__)
-
-# ---- Load the segmentation model (CPU) ----
-device = torch.device("cpu")
-weights = DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
-model = deeplabv3_resnet50(weights=weights).to(device).eval()
-preprocess = weights.transforms()
-
-
-MAX_SIDE = 1024  # cap the longest side for speed and memory
-
-def _resize_if_needed(pil_img: Image.Image, max_side=MAX_SIDE) -> Image.Image:
-    w, h = pil_img.size
-    if max(w, h) <= max_side:
-        return pil_img
-    r = max_side / float(max(w, h))
-    return pil_img.resize((int(w * r), int(h * r)), Image.BILINEAR)
-
-def segment(image: Image.Image):
-    print("DEBUG: type(image) =", type(image), "mode=", getattr(image, "mode", None))
-    if not isinstance(image, Image.Image):
-        image = Image.fromarray(image)
-
-    image = image.convert("RGB")
-    image = _resize_if_needed(image)
-
-    # Preprocess and run inference
-    x = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
-    x = x.unsqueeze(0).to(device)  # [1,3,H,W]
-
-    with torch.no_grad():
-        out = model(x)["out"][0]  # [C,H,W], C=21 (incl. background)
-    pred = out.argmax(0).cpu().numpy()  # [H,W]
-
-    # Foreground = non-background (class 0 is background for the COCO/VOC weights)
-    fg = (pred != 0).astype(np.uint8)
-
-    # ---------------- Mask dilation ----------------
-    kernel = np.ones((19, 19), np.uint8)
-    fg_dilated = cv2.dilate(fg, kernel, iterations=1)
-    print("added mask dilation step")
-
-    mask_img = Image.fromarray((fg_dilated * 255).astype(np.uint8), mode="L")
-
-    # Overlay a colored mask (semi-transparent red)
-    base = image.convert("RGBA")
-    overlay = Image.new("RGBA", base.size, (255, 0, 0, 0))
-    alpha = Image.fromarray((fg_dilated * 120).astype(np.uint8))
-    overlay.putalpha(alpha)
-    blended = Image.alpha_composite(base, overlay).convert("RGB")
-
-    # ---- LaMa erasing ----
-    img_np = np.array(image)      # HWC, uint8
-    mask_np = np.array(mask_img)  # H,W, 0/255
-    img_t = torch.from_numpy(img_np).permute(2, 0, 1).float().unsqueeze(0) / 255.0
-    mask_t = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0).float() / 255.0
-    with torch.no_grad():
-        inpainted_t = lama_model(img_t, mask_t)  # [1,3,H,W]
-    inpainted_np = (inpainted_t[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
-    inpainted_img = Image.fromarray(inpainted_np)
-
-    return blended, mask_img, inpainted_img
-
-# ---- Gradio UI ----
-demo = gr.Interface(
-    fn=segment,
-    inputs=gr.Image(type="pil", label="Upload Image"),
-    outputs=[
-        gr.Image(type="pil", label="Overlay (foreground)"),
-        gr.Image(type="pil", label="Binary Mask (foreground=white)"),
-        gr.Image(type="pil", label="Inpaint Result"),
-    ],
-    title="Semantic Segmentation + LaMa Inpainting",
-    description="DeepLabV3 segmentation + mask dilation + LaMa erasing, running on CPU."
-)
-
-if __name__ == "__main__":
-    demo.launch()
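The new code hands the `_resize_if_needed` output straight to the TorchScript model, so height and width can be arbitrary. LaMa's official inference configs pad inputs to a multiple of 8, and the circulated big-lama TorchScript export is generally reported to expect the same; if inference errors or border artifacts appear at odd sizes, a reflect-pad-then-crop wrapper is a common workaround. A minimal sketch, not part of this commit (`pad_to_multiple` is a hypothetical helper):

import torch
import torch.nn.functional as F

def pad_to_multiple(img_t: torch.Tensor, mask_t: torch.Tensor, mod: int = 8):
    """Reflect-pad a [1,3,H,W] image and [1,1,H,W] mask so H and W are divisible by `mod`."""
    _, _, h, w = img_t.shape
    pad_h = (mod - h % mod) % mod
    pad_w = (mod - w % mod) % mod
    # F.pad takes (left, right, top, bottom) for the last two dims
    img_p = F.pad(img_t, (0, pad_w, 0, pad_h), mode="reflect")
    mask_p = F.pad(mask_t, (0, pad_w, 0, pad_h), mode="reflect")
    return img_p, mask_p, (h, w)

# Inside segment(), before the model call:
#   img_t, mask_t, (h, w) = pad_to_multiple(img_t, mask_t)
#   inpainted_t = lama_model(img_t, mask_t)[:, :, :h, :w]  # crop back to the original size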
requirements.txt CHANGED
@@ -16,5 +16,6 @@ opencv-python
 gradio>=4.0.0
 
 # ---- needed later for LaMa inpainting ----
-pytorch-lightning
+# pytorch-lightning
+# omegaconf
 huggingface_hub
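Because the TorchScript checkpoint carries its own serialized graph, `torch.jit.load` needs neither pytorch-lightning nor omegaconf to rebuild the model class, which is presumably why both are commented out above. A quick smoke test under the two files in this commit; the dummy 512×512 input and the expected output shape are assumptions, not guarantees:

from huggingface_hub import hf_hub_download
import torch

path = hf_hub_download("JosephCatrambone/big-lama-torchscript", "big-lama.pt")
lama = torch.jit.load(path, map_location="cpu").eval()

img = torch.rand(1, 3, 512, 512)                    # dummy RGB in [0, 1]
mask = (torch.rand(1, 1, 512, 512) > 0.95).float()  # sparse binary mask
with torch.no_grad():
    out = lama(img, mask)                           # same call signature app.py uses
print(out.shape)  # expect [1, 3, 512, 512] if the export preserves input size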