Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| # os.chdir("../") | |
| import gradio as gr | |
| import numpy as np | |
| from pathlib import Path | |
| from matplotlib import pyplot as plt | |
| import torch | |
| import tempfile | |
| from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama | |
| from PIL import Image | |
| #sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "segment-anything")) | |
| import argparse | |
| import os | |
| import matplotlib.pyplot as plt | |
| from pylab import imshow, imsave | |
| import detectron2 | |
| from detectron2.utils.logger import setup_logger | |
| setup_logger() | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| from detectron2 import model_zoo | |
| from detectron2.engine import DefaultPredictor | |
| from detectron2.config import get_cfg | |
| from detectron2.utils.visualizer import Visualizer, ColorMode | |
| from detectron2.data import MetadataCatalog | |
| coco_metadata = MetadataCatalog.get("coco_2017_val") | |
| # import PointRend project | |
| from detectron2_repo.projects.PointRend import point_rend | |
# Markdown shown at the top of the Gradio page.
title = "# PeopleRemover"
# Fixed: the warning emoji had been mis-decoded into "β οΈ" (UTF-8 bytes read
# as a Greek code page); restored to the intended "⚠️".
description = """
In this space, you can remove any number of people from a picture.
⚠️ This is just a demo version!
"""
def setup_args(parser):
    """Register the LaMa command-line options on an argparse parser.

    Adds two string options:
      --lama_config  path to the LaMa prediction config (big-lama by default)
      --lama_ckpt    path to the LaMa checkpoint directory

    The parser is modified in place; nothing is returned.
    """
    option_specs = (
        (
            "--lama_config",
            "./third_party/lama/configs/prediction/default.yaml",
            "The path to the config file of lama model. Default: the config of big-lama",
        ),
        (
            "--lama_ckpt",
            "pretrained_models/big-lama",
            "The path to the lama checkpoint.",
        ),
    )
    for flag, default_value, help_text in option_specs:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
# Lazily-built PointRend predictor shared by every get_mask call.
_pointrend_predictor = None


def _get_pointrend_predictor():
    """Build the PointRend instance-segmentation predictor once and reuse it.

    Constructing DefaultPredictor loads the X-101 PointRend weights, so doing
    it on every request (as the previous per-call build did) is very slow.
    """
    global _pointrend_predictor
    if _pointrend_predictor is None:
        cfg = get_cfg()
        # PointRend-specific config keys must exist before merging its YAML.
        point_rend.add_pointrend_config(cfg)
        cfg.merge_from_file("detectron2_repo/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco.yaml")
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # detection confidence threshold
        # Detection is pinned to CPU (LaMa picks CUDA separately when available).
        cfg.MODEL.DEVICE = 'cpu'
        # Weights from the PointRend model zoo:
        # https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend#pretrained-models
        cfg.MODEL.WEIGHTS = "detectron2://PointRend/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco/28119989/model_final_ba17b9.pkl"
        _pointrend_predictor = DefaultPredictor(cfg)
    return _pointrend_predictor


def get_mask(img, num_people_keep, dilate_kernel_size):
    """Return a uint8 mask (0 / 255) covering the people that should be removed.

    Args:
        img: image array of shape (H, W, 3) in RGB order (Gradio's default
            numpy image output).
        num_people_keep: how many detected people to leave untouched. Floats
            are truncated, None counts as 0 and negatives are clamped to 0.
            The first N person detections are kept; detectron2 returns them
            in descending score order, so presumably these are the most
            confident ones (TODO confirm for PointRend).
        dilate_kernel_size: side of the square dilation kernel (applied with
            2 iterations) used to grow the mask; <= 0 disables dilation.

    Returns:
        np.uint8 array of shape (H, W): 255 where a person to remove was
        segmented (after dilation), 0 elsewhere.
    """
    height, width = img.shape[:2]
    keep = max(0, int(num_people_keep or 0))
    kernel_size = int(dilate_kernel_size or 0)

    # DefaultPredictor expects BGR input (cfg.INPUT.FORMAT defaults to "BGR").
    bgr_img = np.ascontiguousarray(img[:, :, ::-1])
    instances = _get_pointrend_predictor()(bgr_img)["instances"]

    # COCO class 0 is "person".
    people = instances[instances.pred_classes == 0]
    to_remove = people[keep:]

    # Union of the masks to remove. Boolean assignment avoids the uint8
    # wrap-around the old "mask + 255 * m" accumulation had on overlaps.
    mask = np.zeros((height, width), dtype=np.uint8)
    if to_remove.has("pred_masks"):
        for instance_mask in to_remove.pred_masks:
            mask[instance_mask.to("cpu").numpy() > 0] = 255

    # Grow the mask so inpainting also covers the person's edges. An empty
    # kernel would make cv2 silently fall back to 3x3, so 0 means "no dilation".
    if kernel_size > 0:
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        mask = cv2.dilate(mask, kernel, iterations=2)
    return mask
def get_inpainted_img(img, mask):
    """Fill the masked region of *img* using the prebuilt LaMa model.

    Uses the module-level ``model['lama']`` and ``args.lama_config``. The
    model runs on CUDA when available, otherwise on CPU. Returns whatever
    ``inpaint_img_with_builded_lama`` produces.
    """
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    return inpaint_img_with_builded_lama(
        model['lama'],
        img,
        mask,
        args.lama_config,
        device=run_device,
    )
def remove_people(img, num_people_keep, dilate_kernel_size):
    """Remove detected people from *img* and return the inpainted image.

    Gradio handler for the "Remove people" button and the examples. It builds
    a removal mask with get_mask, then inpaints the masked area with LaMa via
    get_inpainted_img.

    Args:
        img: input image from the Gradio image component (None if empty).
        num_people_keep: number of detected people to keep.
        dilate_kernel_size: dilation kernel size used to grow the mask.

    Returns:
        The image returned by get_inpainted_img.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        # Fail early with a message the user can act on, instead of an
        # opaque error from inside the detector.
        raise gr.Error("Please upload an image first.")
    print('Obtaining mask...')
    mask = get_mask(img, num_people_keep, dilate_kernel_size)
    print('Mask obtained')
    print('Inpainting with LAMA...')
    out = get_inpainted_img(img, mask)
    print('Image Inpainted!')
    return out
# Command-line options (LaMa config / checkpoint paths); parse_args() reads
# sys.argv[1:] by default.
parser = argparse.ArgumentParser()
setup_args(parser)
args = parser.parse_args()

# Build the LaMa inpainting model once at module load; every request reuses
# model['lama'].
lama_config, lama_ckpt = args.lama_config, args.lama_ckpt
device = "cuda" if torch.cuda.is_available() else "cpu"
model = {'lama': build_lama_model(lama_config, lama_ckpt, device=device)}
# Gradio UI: input image and controls on the left, result on the right.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    # Placeholder state; only touched by the Reset button.
    features = gr.State(None)
    with gr.Row():
        with gr.Column(scale=1):
            img = gr.Image(height=300)
            # precision=0 makes Gradio deliver an int. The default (float, or
            # None when the field is left empty) breaks slicing in get_mask.
            num_people_keep = gr.Number(
                label="Number of people to keep",
                value=0,
                minimum=0,
                maximum=100,
                precision=0,
            )
            dilate_kernel_size = gr.Slider(label="Dilate Kernel Size", minimum=0, maximum=30, step=1, value=5)
            lama = gr.Button(value="Remove people", variant="primary", size="sm")
            clear_button_image = gr.Button(value="Reset", variant="secondary", size="sm")
        with gr.Column(scale=1):
            img_out = gr.Image(interactive=False, show_download_button=True)

    lama.click(
        remove_people,
        [img, num_people_keep, dilate_kernel_size],
        [img_out],
    )

    def reset(*components):
        """Return None for each component passed in, clearing them all."""
        # Parameter renamed from *args so it no longer shadows the global args.
        return [None for _ in components]

    clear_button_image.click(
        reset,
        [img, features, img_out],
        [img, features, img_out],
    )

    # Each example is [image path, people to keep, dilate kernel size].
    # cache_examples=True runs remove_people on every example at startup.
    gr.Examples(
        examples=[[os.path.join(os.getcwd(), "examples/002.jpg"), 2, 15],
                  [os.path.join(os.getcwd(), "examples/013.jpg"), 1, 15],
                  [os.path.join(os.getcwd(), "examples/014.jpg"), 1, 15],
                  [os.path.join(os.getcwd(), "examples/015.jpg"), 1, 15],
                  [os.path.join(os.getcwd(), "examples/002.jpg"), 0, 15]],
        inputs=[img, num_people_keep, dilate_kernel_size],
        outputs=img_out,
        fn=remove_people,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()