Bla1r commited on
Commit
52832e3
·
verified ·
1 Parent(s): a97b66d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -45
app.py CHANGED
@@ -7,13 +7,13 @@ from torchvision import transforms
7
  import uuid
8
  import os
9
 
10
- # Device selection
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  print(f"Using device: {device}")
13
 
14
  torch.set_float32_matmul_precision(["high", "highest"][0])
15
 
16
- # Load model
17
  birefnet = AutoModelForImageSegmentation.from_pretrained(
18
  "ZhengPeng7/BiRefNet", trust_remote_code=True
19
  )
@@ -40,60 +40,24 @@ def process(image):
40
  image.putalpha(mask)
41
  return image
42
 
43
- # Image Upload tab
44
  def fn(image):
45
  im = load_img(image, output_type="pil").convert("RGB")
46
  processed_image = process(im)
 
47
  filename = f"/tmp/processed_{uuid.uuid4().hex}.png"
48
  processed_image.save(filename)
49
- return processed_image, filename
50
-
51
- # File tab
52
- def process_file(f):
53
- name_path = f.rsplit(".", 1)[0] + ".png"
54
- im = load_img(f, output_type="pil").convert("RGB")
55
- transparent = process(im)
56
- transparent.save(name_path)
57
- return name_path
58
 
59
- # Example image
60
- url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
61
- chameleon = load_img(url_example, output_type="pil")
62
 
63
- # Tab 1: Upload + Preview + Download
64
- tab1 = gr.Interface(
65
  fn,
66
  inputs=gr.Image(label="Upload an image"),
67
  outputs=[
68
- gr.Image(label="Preview"),
69
- gr.File(label="Download Processed Image")
70
  ],
71
- examples=[chameleon],
72
- api_name="image"
73
- )
74
-
75
- # Tab 2: URL input + Preview
76
- tab2 = gr.Interface(
77
- fn,
78
- inputs=gr.Textbox(label="Paste an image URL"),
79
- outputs=gr.Image(label="Preview"),
80
- examples=[url_example],
81
- api_name="text"
82
- )
83
-
84
- # Tab 3: File path input + downloadable result
85
- tab3 = gr.Interface(
86
- process_file,
87
- inputs=gr.Image(label="Upload an image", type="filepath"),
88
- outputs=gr.File(label="Output PNG File"),
89
- examples=[],
90
- api_name="png"
91
- )
92
-
93
- # Final App
94
- demo = gr.TabbedInterface(
95
- [tab1, tab2, tab3],
96
- ["Image Upload", "URL Input", "File Output"],
97
  title="Background Removal Tool"
98
  )
99
 
 
7
  import uuid
8
  import os
9
 
10
+ # Select device
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  print(f"Using device: {device}")
13
 
14
  torch.set_float32_matmul_precision(["high", "highest"][0])
15
 
16
+ # Load BiRefNet model
17
  birefnet = AutoModelForImageSegmentation.from_pretrained(
18
  "ZhengPeng7/BiRefNet", trust_remote_code=True
19
  )
 
40
  image.putalpha(mask)
41
  return image
42
 
43
+ # Main function: image upload → preview + downloadable PNG
44
  def fn(image):
45
  im = load_img(image, output_type="pil").convert("RGB")
46
  processed_image = process(im)
47
+
48
  filename = f"/tmp/processed_{uuid.uuid4().hex}.png"
49
  processed_image.save(filename)
 
 
 
 
 
 
 
 
 
50
 
51
+ return processed_image, filename
 
 
52
 
53
+ # Gradio interface
54
+ demo = gr.Interface(
55
  fn,
56
  inputs=gr.Image(label="Upload an image"),
57
  outputs=[
58
+ gr.Image(label="Processed Preview"),
59
+ gr.File(label="Download PNG")
60
  ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  title="Background Removal Tool"
62
  )
63