import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from huggingface_hub import hf_hub_download
import cv2

# ---------------- Download and load the LaMa weights ----------------
repo_id = "JosephCatrambone/big-lama-torchscript"
model_path = hf_hub_download(repo_id=repo_id, filename="big-lama.pt")

# Alternative source (raw checkpoint instead of TorchScript):
# zip_path = hf_hub_download(repo_id="smartywu/big-lama", filename="big-lama.zip")
# import zipfile
# with zipfile.ZipFile(zip_path, 'r') as z:
#     z.extractall("./")
# model_path = "./models/best.ckpt"

lama_model = torch.jit.load(model_path, map_location="cpu")
lama_model.eval()

print("torch:", torch.__version__)
print("numpy:", np.__version__)

# ---- Load the segmentation model (CPU) ----
device = torch.device("cpu")
weights = DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
model = deeplabv3_resnet50(weights=weights).to(device).eval()

MAX_SIDE = 1024  # cap the longest input side for speed and memory


def _resize_if_needed(pil_img: Image.Image, max_side=MAX_SIDE) -> Image.Image:
    w, h = pil_img.size
    if max(w, h) <= max_side:
        return pil_img
    r = max_side / float(max(w, h))
    return pil_img.resize((int(w * r), int(h * r)), Image.BILINEAR)


def segment(image: Image.Image):
    # Depending on the Gradio version, the input may arrive as a numpy array.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    image = _resize_if_needed(image)

    # Preprocess and run inference. The torchvision weights expect
    # ImageNet-normalized input; normalize manually (instead of using
    # weights.transforms(), which also resizes) so the prediction stays
    # at the image's own resolution.
    x = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    x = ((x - mean) / std).unsqueeze(0).to(device)  # [1,3,H,W]
    with torch.no_grad():
        out = model(x)["out"][0]           # [C,H,W], C=21 (incl. background)
    pred = out.argmax(0).cpu().numpy()     # [H,W]

    # Foreground = non-background (class 0 is background for these VOC-style labels).
    fg = (pred != 0).astype(np.uint8)

    # ---------------- Mask dilation ----------------
    # Grow the mask past the object edges so LaMa does not leave halos.
    kernel = np.ones((19, 19), np.uint8)
    fg_dilated = cv2.dilate(fg, kernel, iterations=1)
    mask_img = Image.fromarray((fg_dilated * 255).astype(np.uint8), mode="L")

    # Semi-transparent red overlay for visualization.
    base = image.convert("RGBA")
    overlay = Image.new("RGBA", base.size, (255, 0, 0, 0))
    alpha = Image.fromarray((fg_dilated * 120).astype(np.uint8))
    overlay.putalpha(alpha)
    blended = Image.alpha_composite(base, overlay).convert("RGB")

    # ---- LaMa inpainting ----
    img_np = np.array(image)      # HWC, uint8
    mask_np = np.array(mask_img)  # H,W, values 0/255
    img_t = torch.from_numpy(img_np).permute(2, 0, 1).float().unsqueeze(0) / 255.0
    mask_t = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0).float() / 255.0
    with torch.no_grad():
        inpainted_t = lama_model(img_t, mask_t)  # [1,3,H,W]
    # Clamp before the uint8 cast: float outputs can drift slightly outside [0,1].
    inpainted_np = (inpainted_t[0].permute(1, 2, 0).clamp(0, 1).numpy() * 255).astype(np.uint8)
    inpainted_img = Image.fromarray(inpainted_np)

    return blended, mask_img, inpainted_img


# ---- Gradio UI ----
demo = gr.Interface(
    fn=segment,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Image(type="pil", label="Overlay (foreground)"),
        gr.Image(type="pil", label="Binary Mask (foreground=white)"),
        gr.Image(type="pil", label="Inpainted Result"),
    ],
    title="Semantic Segmentation + LaMa Inpainting",
    description="DeepLabV3 segmentation + mask dilation + LaMa inpainting, running on CPU."
)

if __name__ == "__main__":
    demo.launch()
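# ---------------------------------------------------------------
# Note: many big-lama builds expect H and W to be multiples of 8, and
# _resize_if_needed above does not guarantee that. Below is a minimal
# sketch of reflect-padding before inference and cropping afterwards;
# the _pad_to_multiple helper is illustrative and not wired into segment().
import torch.nn.functional as F


def _pad_to_multiple(img_t, mask_t, multiple=8):
    """Reflect-pad [1,3,H,W] / [1,1,H,W] tensors so H and W divide `multiple`."""
    _, _, h, w = img_t.shape
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    if pad_h == 0 and pad_w == 0:
        return img_t, mask_t, h, w
    img_p = F.pad(img_t, (0, pad_w, 0, pad_h), mode="reflect")
    mask_p = F.pad(mask_t, (0, pad_w, 0, pad_h), mode="reflect")
    return img_p, mask_p, h, w

# Usage inside segment(), replacing the direct call:
#   img_p, mask_p, h, w = _pad_to_multiple(img_t, mask_t)
#   inpainted_t = lama_model(img_p, mask_p)[:, :, :h, :w]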
hf_hub_download(repo_id="smartywu/big-lama", filename="big-lama.zip") # with zipfile.ZipFile(zip_path, 'r') as z: # z.extractall("./") # model_path = "./big-lama/models/best.ckpt" # # ========================================================== # # LaMa FBAResUNet 定义(官方结构) # # ========================================================== # class GatedConv(nn.Module): # def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1): # super().__init__() # padding = (kernel_size - 1) // 2 * dilation # self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, # padding=padding, dilation=dilation) # self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, # padding=padding, dilation=dilation) # self.sigmoid = nn.Sigmoid() # def forward(self, x): # feat = self.conv(x) # mask = self.sigmoid(self.mask_conv(x)) # return feat * mask # class FBAResUNet(nn.Module): # def __init__(self, input_channels=4, output_channels=3, num_filters=64): # super().__init__() # self.enc1 = GatedConv(input_channels, num_filters) # self.enc2 = GatedConv(num_filters, num_filters * 2, stride=2) # self.enc3 = GatedConv(num_filters * 2, num_filters * 4, stride=2) # self.enc4 = GatedConv(num_filters * 4, num_filters * 8, stride=2) # self.middle = GatedConv(num_filters * 8, num_filters * 8) # self.dec4 = nn.ConvTranspose2d(num_filters * 8, num_filters * 4, kernel_size=4, stride=2, padding=1) # self.dec3 = nn.ConvTranspose2d(num_filters * 8, num_filters * 2, kernel_size=4, stride=2, padding=1) # self.dec2 = nn.ConvTranspose2d(num_filters * 4, num_filters, kernel_size=4, stride=2, padding=1) # self.dec1 = nn.Conv2d(num_filters * 2, output_channels, kernel_size=3, padding=1) # self.relu = nn.ReLU(inplace=True) # def forward(self, image, mask): # # image: [B,3,H,W], mask: [B,1,H,W] # x = torch.cat([image, mask], dim=1) # -> [B,4,H,W] # e1 = self.enc1(x) # e2 = self.enc2(self.relu(e1)) # e3 = self.enc3(self.relu(e2)) # e4 = self.enc4(self.relu(e3)) # m = self.middle(self.relu(e4)) # d4 = self.relu(self.dec4(m)) # d4 = torch.cat([d4, e3], dim=1) # d3 = self.relu(self.dec3(d4)) # d3 = torch.cat([d3, e2], dim=1) # d2 = self.relu(self.dec2(d3)) # d2 = torch.cat([d2, e1], dim=1) # out = torch.sigmoid(self.dec1(d2)) # return out # # ========================================================== # # 加载 LaMa 预训练权重 # # ========================================================== # checkpoint = torch.load(model_path, map_location="cpu") # lama_model = FBAResUNet() # if "state_dict" in checkpoint: # state_dict = {k.replace("netG.", ""): v for k, v in checkpoint["state_dict"].items()} # else: # state_dict = checkpoint # lama_model.load_state_dict(state_dict, strict=False) # lama_model.eval() # print("torch:", torch.__version__) # print("numpy:", np.__version__) # # ---- 加载分割模型(CPU) ---- # device = torch.device("cpu") # weights = DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1 # model = deeplabv3_resnet50(weights=weights).to(device).eval() # preprocess = weights.transforms() # MAX_SIDE = 1024 # 为了速度与内存,限制输入最大边 # def _resize_if_needed(pil_img: Image.Image, max_side=MAX_SIDE) -> Image.Image: # w, h = pil_img.size # if max(w, h) <= max_side: # return pil_img # r = max_side / float(max(w, h)) # return pil_img.resize((int(w * r), int(h * r)), Image.BILINEAR) # def segment(image: Image.Image): # print("DEBUG: type(image) =", type(image), "mode=", getattr(image, "mode", None)) # if not isinstance(image, Image.Image): # image = Image.fromarray(image) # image = image.convert("RGB") # image = 
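# ---------------------------------------------------------------
# Note: the square 19x19 dilation kernel grows the mask further along the
# diagonals. An elliptical structuring element expands it more evenly;
# a sketch (illustrative, not wired in):
#
#   kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (19, 19))
#   fg_dilated = cv2.dilate(fg, kernel, iterations=1)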