Janeka commited on
Commit
888c15e
·
verified ·
1 Parent(s): 5dadb9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -41
app.py CHANGED
@@ -1,54 +1,55 @@
 
1
  import numpy as np
2
  from PIL import Image
3
- from rembg import new_session, remove
4
- from skimage import filters
5
- import gradio as gr
6
- import cv2 # Added for dilation and GaussianBlur
7
 
8
- # Initialize models
9
- isnet_session = new_session("isnet-general-use")
10
- u2net_session = new_session("u2net")
 
 
 
 
 
11
 
12
- def perfect_remove_bg(img):
13
  try:
14
- # Convert input
15
- if isinstance(img, np.ndarray):
16
- img = Image.fromarray(img)
17
-
18
- # ISNet for details
19
- result = remove(img, session=isnet_session)
20
- mask = np.array(result.split()[-1])
21
-
22
- # Edge refinement
23
- mask = filters.rank.mean(
24
- mask.astype(np.uint8),
25
- footprint=np.ones((3,3), np.uint8)
26
- )
27
-
28
- # Enhanced edge preservation
29
- mask = cv2.dilate(mask, np.ones((1,1), np.uint8), iterations=1)
30
- mask = cv2.GaussianBlur(mask, (3,3), 0)
31
-
32
- # U²Net mask
33
- u2net_result = remove(img, session=u2net_session)
34
- u2net_mask = np.array(u2net_result.split()[-1])
35
-
36
- # Combine masks
37
- final_mask = np.where(u2net_mask > 200, mask, u2net_mask)
38
- result.putalpha(Image.fromarray(final_mask))
39
 
 
 
40
  return result
41
-
42
  except Exception as e:
43
- print(f"Error: {e}")
44
- return remove(img, session=u2net_session)
45
 
46
- # Gradio interface
47
  with gr.Blocks() as demo:
48
- gr.Markdown("## Professional BG Remover")
 
49
  with gr.Row():
50
- input_img = gr.Image(label="Input", type="pil")
51
- output_img = gr.Image(label="Output", type="pil")
52
- gr.Button("Process").click(perfect_remove_bg, inputs=input_img, outputs=output_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  demo.launch()
 
1
+ import gradio as gr
2
  import numpy as np
3
  from PIL import Image
4
+ from rembg import remove, new_session
 
 
 
5
 
6
+ # Initialize session with proper settings to prevent cropping
7
+ session = new_session("u2net")
8
+ bg_removal_kwargs = {
9
+ "alpha_matting": False, # Disable advanced features that cause cropping
10
+ "session": session,
11
+ "only_mask": False,
12
+ "post_process_mask": True # Clean edges without cropping
13
+ }
14
 
15
+ def remove_background(input_image):
16
  try:
17
+ # Convert any input to PIL Image
18
+ if isinstance(input_image, np.ndarray):
19
+ img = Image.fromarray(input_image)
20
+ elif isinstance(input_image, dict): # Handle paste/drop
21
+ img = Image.open(input_image["name"])
22
+ else:
23
+ img = input_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # Preserve original size (disable auto-resizing)
26
+ result = remove(img, **bg_removal_kwargs)
27
  return result
28
+
29
  except Exception as e:
30
+ print(f"Error: {str(e)}")
31
+ return input_image # Return original if fails
32
 
33
+ # Gradio interface with proper image handling
34
  with gr.Blocks() as demo:
35
+ gr.Markdown("## 🖼️ Background Remover (No Cropping)")
36
+
37
  with gr.Row():
38
+ input_img = gr.Image(
39
+ label="Original",
40
+ type="pil", # Ensures we get PIL Images
41
+ height=400
42
+ )
43
+ output_img = gr.Image(
44
+ label="Result",
45
+ type="pil",
46
+ height=400
47
+ )
48
+
49
+ gr.Button("Remove Background").click(
50
+ remove_background,
51
+ inputs=input_img,
52
+ outputs=output_img
53
+ )
54
 
55
  demo.launch()