# Janeka's picture
# Update app.py
# 8f3a72a verified
import gradio as gr
from rembg import remove
from PIL import Image, ImageDraw
import numpy as np
import io
def remove_bg(image):
    """Remove the background of an image automatically using rembg.

    Args:
        image: Path to the source image (the UI's input component uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        The image returned by ``rembg.remove`` for the opened file, or
        ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None
    # Open inside a context manager so the file handle is released even if
    # rembg raises part-way through; rembg returns a new image, so it stays
    # valid after the source file is closed.
    with Image.open(image) as src:
        return remove(src)
def refine_bg(image, points, threshold, mode, radius_scale=10):
    """Mask an image to keep, or cut out, circular regions around picked points.

    Args:
        image: Path to the source image, or ``None``.
        points: Iterable of ``(x, y)`` pixel coordinates, or ``None`` for no
            points.
        threshold: Brush size factor; each point paints a filled circle of
            radius ``threshold * radius_scale`` pixels.
        mode: ``"keep"`` starts from a fully transparent mask and makes the
            painted circles visible. Any other value (the UI uses
            ``"remove"``) starts from a fully opaque mask and makes the
            painted circles transparent.
        radius_scale: Pixels of radius per unit of ``threshold``. Defaults to
            10, the factor previously hard-coded here.

    Returns:
        An RGBA image whose alpha comes from the painted mask, or ``None``
        when ``image`` is ``None``.
    """
    if image is None:
        return None

    # "keep": only the painted circles survive. "remove": everything survives
    # except the painted circles. Starting "remove" from a blank mask would
    # make every stroke a no-op and always yield a fully transparent image.
    if mode == "keep":
        base_value, paint_value = 0, 255
    else:
        base_value, paint_value = 255, 0

    radius = threshold * radius_scale

    # convert() returns an independent copy, so the file can be closed here.
    with Image.open(image) as src:
        rgba = src.convert("RGBA")

    mask = Image.new("L", rgba.size, base_value)
    draw = ImageDraw.Draw(mask)
    for x, y in points or ():
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=paint_value)

    transparent = Image.new("RGBA", rgba.size, (0, 0, 0, 0))
    return Image.composite(rgba, transparent, mask)
# Gradio UI


def show_refine_section(image):
    """Reveal the refine panel, mirror the uploaded image into it, and clear points."""
    return gr.update(visible=True), image, []


def add_refine_point(points, evt: gr.SelectData):
    """Append the clicked pixel of the refine image to the session's point list."""
    # For an Image component, SelectData.index holds the clicked [x, y] pixel.
    x, y = evt.index
    return (points or []) + [(x, y)]


def set_mode_keep():
    """Select keep mode and relabel both mode buttons to show the selection."""
    return "keep", "+ (Keep) Selected", "- (Remove)"


def set_mode_remove():
    """Select remove mode and relabel both mode buttons to show the selection."""
    return "remove", "+ (Keep)", "- (Remove) Selected"


with gr.Blocks() as iface:
    gr.Markdown("# AI Background Remover")
    with gr.Row():
        input_img = gr.Image(type="filepath", label="Upload Image")
        output_img = gr.Image(type="pil", label="Output Image")
    remove_btn = gr.Button("Remove Background")
    remove_btn.click(remove_bg, inputs=input_img, outputs=output_img)

    refine_btn = gr.Button("Refine")
    refine_options = gr.Column(visible=False)
    with refine_options:
        gr.Markdown("### Refine Background")
        refine_editor = gr.Image(type="filepath", label="Tap to Add Points", interactive=True)
        keep_btn = gr.Button("+ (Keep)")
        # Distinct name so it does not shadow the "Remove Background" button.
        remove_mode_btn = gr.Button("- (Remove)")
        threshold_slider = gr.Slider(0.00, 1.00, value=0.5, step=0.01, label="Threshold")
        apply_refine_btn = gr.Button("Apply Refinements")

    refine_mode = gr.State("keep")
    # Per-session list of clicked (x, y) pixel coordinates on the refine image.
    refine_points = gr.State([])

    refine_btn.click(
        show_refine_section,
        inputs=input_img,
        outputs=[refine_options, refine_editor, refine_points],
    )
    refine_editor.select(add_refine_point, inputs=refine_points, outputs=refine_points)
    keep_btn.click(set_mode_keep, outputs=[refine_mode, keep_btn, remove_mode_btn])
    remove_mode_btn.click(set_mode_remove, outputs=[refine_mode, keep_btn, remove_mode_btn])
    # refine_bg needs the collected (x, y) points, not the editor's image path.
    apply_refine_btn.click(
        refine_bg,
        inputs=[input_img, refine_points, threshold_slider, refine_mode],
        outputs=output_img,
    )

iface.launch()