# Hugging Face Space exposing four GPU utilities behind one Gradio app:
# background removal (BiRefNet), outpainting and inpainting (FLUX.1 Fill),
# and point-prompted mask generation (SAM 2).
import json

import gradio as gr
import numpy as np
import spaces
import torch
from loadimg import load_img
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from diffusers import FluxFillPipeline
from PIL import Image, ImageOps
from sam2.sam2_image_predictor import SAM2ImagePredictor

torch.set_float32_matmul_precision("high")

# Background-removal model.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")

transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Inpainting/outpainting model.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")


def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    image = load_img(image).convert("RGB")
    # Expand the image with a white border; border order is
    # (left, top, right, bottom).
    background = ImageOps.expand(
        image,
        border=(padding_left, padding_top, padding_right, padding_bottom),
        fill="white",
    )
    # The mask is black (keep) over the original image and white (fill in)
    # over the padded region.
    mask = Image.new("RGB", image.size, "black")
    mask = ImageOps.expand(
        mask,
        border=(padding_left, padding_top, padding_right, padding_bottom),
        fill="white",
    )
    return background, mask


def outpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    # Defaults follow the UI below; the original signature had these two
    # values swapped relative to the interface and its example.
    num_inference_steps=50,
    guidance_scale=28,
):
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )
    result = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    return result.convert("RGBA")


def inpaint(
    image,
    mask,
    prompt="",
    num_inference_steps=50,
    guidance_scale=28,
):
    background = image.convert("RGB")
    mask = mask.convert("L")
    result = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    return result.convert("RGBA")


def rmbg(image=None, url=None):
    if image is None:
        image = url
    image = load_img(image).convert("RGB")
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    # Predict the foreground mask with BiRefNet.
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image


def mask_generation(image=None, d=None):
    # `d` arrives as a JSON string, e.g.
    # '{"input_points": [[500, 375]], "input_labels": [1]}'.
    # json.loads replaces the original eval(), which would execute
    # arbitrary code from the request.
    d = json.loads(d)
    predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
    predictor.set_image(image)
    input_point = np.array(d["input_points"])
    input_label = np.array(d["input_labels"])
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    # Sort the candidate masks by score, best first.
    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]
    logits = logits[sorted_ind]

    out = []
    for i in range(len(masks)):
        m = Image.fromarray((masks[i] * 255).astype(np.uint8)).convert("L")
        # Composite the image over a black background so only the masked
        # region stays visible; Image.composite requires both images to
        # share the same mode and size, so the mask itself cannot be used
        # as the second image as in the original code.
        black = Image.new(image.mode, image.size, "black")
        comp = Image.composite(image, black, m)
        out.append((comp, f"image {i}"))
    return out


@spaces.GPU
def main(*args):
    # The first argument selects which utility handles the request.
    api_num = args[0]
    args = args[1:]
    if api_num == 1:
        return rmbg(*args)
    elif api_num == 2:
        return outpaint(*args)
    elif api_num == 3:
        return inpaint(*args)
    elif api_num == 4:
        return mask_generation(*args)


rmbg_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(1, interactive=False),
        "image",
        gr.Text("", label="url"),
    ],
    outputs=["image"],
    api_name="rmbg",
    examples=[[1, "./assets/Inpainting mask.png", ""]],
    cache_examples=False,
    description="Pass an image or the URL of an image.",
)

outpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(2, interactive=False),
        gr.Image(label="image", type="pil"),
        gr.Number(label="padding top"),
        gr.Number(label="padding bottom"),
        gr.Number(label="padding left"),
        gr.Number(label="padding right"),
        gr.Text(label="prompt"),
        gr.Number(value=50, label="num_inference_steps"),
        gr.Number(value=28, label="guidance_scale"),
    ],
    outputs=["image"],
    api_name="outpainting",
    examples=[[2, "./assets/rocket.png", 100, 0, 0, 0, "", 50, 28]],
    cache_examples=False,
)

inpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(3, interactive=False),
        gr.Image(label="image", type="pil"),
        gr.Image(label="mask", type="pil"),
        gr.Text(label="prompt"),
        gr.Number(value=50, label="num_inference_steps"),
        gr.Number(value=28, label="guidance_scale"),
    ],
    outputs=["image"],
    api_name="inpaint",
    # The original example omitted the prompt, steps, and guidance values;
    # they are filled in here to match the number of inputs.
    examples=[[3, "./assets/rocket.png", "./assets/Inpainting mask.png", "", 50, 28]],
    cache_examples=False,
    description=(
        "It is recommended to create the mask in JS with "
        "https://github.com/la-voliere/react-mask-editor, then invert it "
        "before sending it to this Space."
    ),
)

sam2_tab = gr.Interface(
    main,
    inputs=[
        gr.Number(4, interactive=False),
        gr.Image(type="pil"),
        gr.Text(),
    ],
    outputs=gr.Gallery(),
    examples=[
        [
            4,
            "./assets/truck.jpg",
            '{"input_points": [[500, 375], [1125, 625]], "input_labels": [1, 0]}',
        ]
    ],
    api_name="sam2",
    cache_examples=False,
)

demo = gr.TabbedInterface(
    [rmbg_tab, outpaint_tab, inpaint_tab, sam2_tab],
    ["remove background", "outpainting", "inpainting", "sam2"],
    title="Utilities that require GPU",
)

demo.launch()