Bla1r commited on
Commit
a97b66d
·
verified ·
1 Parent(s): 7c70589

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -23
app.py CHANGED
@@ -7,19 +7,19 @@ from torchvision import transforms
7
  import uuid
8
  import os
9
 
10
- # Automatically select device (GPU or CPU)
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  print(f"Using device: {device}")
13
 
14
  torch.set_float32_matmul_precision(["high", "highest"][0])
15
 
16
- # Load the segmentation model
17
  birefnet = AutoModelForImageSegmentation.from_pretrained(
18
  "ZhengPeng7/BiRefNet", trust_remote_code=True
19
  )
20
  birefnet.to(device)
21
 
22
- # Image preprocessing
23
  transform_image = transforms.Compose(
24
  [
25
  transforms.Resize((1024, 1024)),
@@ -40,17 +40,15 @@ def process(image):
40
  image.putalpha(mask)
41
  return image
42
 
43
- # Main function: returns both a preview and a downloadable file
44
  def fn(image):
45
  im = load_img(image, output_type="pil").convert("RGB")
46
  processed_image = process(im)
47
-
48
  filename = f"/tmp/processed_{uuid.uuid4().hex}.png"
49
  processed_image.save(filename)
 
50
 
51
- return processed_image, filename # Preview + downloadable image
52
-
53
- # File-only processing tab
54
  def process_file(f):
55
  name_path = f.rsplit(".", 1)[0] + ".png"
56
  im = load_img(f, output_type="pil").convert("RGB")
@@ -58,26 +56,45 @@ def process_file(f):
58
  transparent.save(name_path)
59
  return name_path
60
 
61
- # Components
62
- preview_image = gr.Image(label="Preview")
63
- download_file = gr.File(label="Download Processed Image")
64
- image_upload = gr.Image(label="Upload an image")
65
- image_file_upload = gr.Image(label="Upload an image", type="filepath")
66
- url_input = gr.Textbox(label="Paste an image URL")
67
- output_file = gr.File(label="Output PNG File")
68
-
69
- # Example image (from URL instead of invalid local path)
70
  url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
71
  chameleon = load_img(url_example, output_type="pil")
72
 
73
- # Interface Tabs
74
- tab1 = gr.Interface(fn, inputs=image_upload, outputs=[preview_image, download_file], examples=[chameleon], api_name="image")
75
- tab2 = gr.Interface(fn, inputs=url_input, outputs=preview_image, examples=[url_example], api_name="text")
76
- tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=[], api_name="png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- # Combine tabs into one app
79
  demo = gr.TabbedInterface(
80
- [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
 
 
81
  )
82
 
83
  if __name__ == "__main__":
 
7
  import uuid
8
  import os
9
 
10
+ # Device selection
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  print(f"Using device: {device}")
13
 
14
  torch.set_float32_matmul_precision(["high", "highest"][0])
15
 
16
+ # Load model
17
  birefnet = AutoModelForImageSegmentation.from_pretrained(
18
  "ZhengPeng7/BiRefNet", trust_remote_code=True
19
  )
20
  birefnet.to(device)
21
 
22
+ # Preprocessing
23
  transform_image = transforms.Compose(
24
  [
25
  transforms.Resize((1024, 1024)),
 
40
  image.putalpha(mask)
41
  return image
42
 
43
+ # Image Upload tab
44
  def fn(image):
45
  im = load_img(image, output_type="pil").convert("RGB")
46
  processed_image = process(im)
 
47
  filename = f"/tmp/processed_{uuid.uuid4().hex}.png"
48
  processed_image.save(filename)
49
+ return processed_image, filename
50
 
51
+ # File tab
 
 
52
  def process_file(f):
53
  name_path = f.rsplit(".", 1)[0] + ".png"
54
  im = load_img(f, output_type="pil").convert("RGB")
 
56
  transparent.save(name_path)
57
  return name_path
58
 
59
+ # Example image
 
 
 
 
 
 
 
 
60
  url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
61
  chameleon = load_img(url_example, output_type="pil")
62
 
63
+ # Tab 1: Upload + Preview + Download
64
+ tab1 = gr.Interface(
65
+ fn,
66
+ inputs=gr.Image(label="Upload an image"),
67
+ outputs=[
68
+ gr.Image(label="Preview"),
69
+ gr.File(label="Download Processed Image")
70
+ ],
71
+ examples=[chameleon],
72
+ api_name="image"
73
+ )
74
+
75
+ # Tab 2: URL input + Preview
76
+ tab2 = gr.Interface(
77
+ fn,
78
+ inputs=gr.Textbox(label="Paste an image URL"),
79
+ outputs=gr.Image(label="Preview"),
80
+ examples=[url_example],
81
+ api_name="text"
82
+ )
83
+
84
+ # Tab 3: File path input + downloadable result
85
+ tab3 = gr.Interface(
86
+ process_file,
87
+ inputs=gr.Image(label="Upload an image", type="filepath"),
88
+ outputs=gr.File(label="Output PNG File"),
89
+ examples=[],
90
+ api_name="png"
91
+ )
92
 
93
+ # Final App
94
  demo = gr.TabbedInterface(
95
+ [tab1, tab2, tab3],
96
+ ["Image Upload", "URL Input", "File Output"],
97
+ title="Background Removal Tool"
98
  )
99
 
100
  if __name__ == "__main__":