Bla1r commited on
Commit
7c70589
·
verified ·
1 Parent(s): 83d196e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -21
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
- from gradio_imageslider import ImageSlider
3
  from loadimg import load_img
4
  import spaces
5
  from transformers import AutoModelForImageSegmentation
@@ -8,19 +7,19 @@ from torchvision import transforms
8
  import uuid
9
  import os
10
 
11
- # Automatically select device based on availability
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
-
14
- # Optional: Print which device is being used
15
  print(f"Using device: {device}")
16
 
17
  torch.set_float32_matmul_precision(["high", "highest"][0])
18
 
 
19
  birefnet = AutoModelForImageSegmentation.from_pretrained(
20
  "ZhengPeng7/BiRefNet", trust_remote_code=True
21
  )
22
  birefnet.to(device)
23
 
 
24
  transform_image = transforms.Compose(
25
  [
26
  transforms.Resize((1024, 1024)),
@@ -33,7 +32,6 @@ transform_image = transforms.Compose(
33
  def process(image):
34
  image_size = image.size
35
  input_images = transform_image(image).unsqueeze(0).to(device)
36
- # Prediction
37
  with torch.no_grad():
38
  preds = birefnet(input_images)[-1].sigmoid().cpu()
39
  pred = preds[0].squeeze()
@@ -42,47 +40,45 @@ def process(image):
42
  image.putalpha(mask)
43
  return image
44
 
 
45
  def fn(image):
46
- im = load_img(image, output_type="pil")
47
- im = im.convert("RGB")
48
  processed_image = process(im)
49
 
50
- # Save to a temp file
51
  filename = f"/tmp/processed_{uuid.uuid4().hex}.png"
52
  processed_image.save(filename)
53
 
54
- return processed_image, filename # Return both preview and downloadable file
55
 
 
56
  def process_file(f):
57
  name_path = f.rsplit(".", 1)[0] + ".png"
58
- im = load_img(f, output_type="pil")
59
- im = im.convert("RGB")
60
  transparent = process(im)
61
  transparent.save(name_path)
62
  return name_path
63
 
64
  # Components
65
- slider1 = gr.Image(label="Preview") # Changed from ImageSlider to Image
66
- output_file1 = gr.File(label="Download Processed Image")
67
- slider2 = ImageSlider(label="Processed Image from URL", type="pil")
68
  image_upload = gr.Image(label="Upload an image")
69
  image_file_upload = gr.Image(label="Upload an image", type="filepath")
70
  url_input = gr.Textbox(label="Paste an image URL")
71
  output_file = gr.File(label="Output PNG File")
72
 
73
- # Example images
74
- chameleon = load_img("dog.jpeg", output_type="pil")
75
  url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
 
76
 
77
- # Tabs
78
- tab1 = gr.Interface(fn, inputs=image_upload, outputs=[slider1, output_file1], examples=[chameleon], api_name="image")
79
- tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
80
- tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["dog.jpeg"], api_name="png")
81
 
 
82
  demo = gr.TabbedInterface(
83
  [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
84
  )
85
 
86
-
87
  if __name__ == "__main__":
88
  demo.launch(show_error=True)
 
1
  import gradio as gr
 
2
  from loadimg import load_img
3
  import spaces
4
  from transformers import AutoModelForImageSegmentation
 
7
  import uuid
8
  import os
9
 
10
+ # Automatically select device (GPU or CPU)
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
12
  print(f"Using device: {device}")
13
 
14
  torch.set_float32_matmul_precision(["high", "highest"][0])
15
 
16
+ # Load the segmentation model
17
  birefnet = AutoModelForImageSegmentation.from_pretrained(
18
  "ZhengPeng7/BiRefNet", trust_remote_code=True
19
  )
20
  birefnet.to(device)
21
 
22
+ # Image preprocessing
23
  transform_image = transforms.Compose(
24
  [
25
  transforms.Resize((1024, 1024)),
 
32
  def process(image):
33
  image_size = image.size
34
  input_images = transform_image(image).unsqueeze(0).to(device)
 
35
  with torch.no_grad():
36
  preds = birefnet(input_images)[-1].sigmoid().cpu()
37
  pred = preds[0].squeeze()
 
40
  image.putalpha(mask)
41
  return image
42
 
43
+ # Main function: returns both a preview and a downloadable file
44
  def fn(image):
45
+ im = load_img(image, output_type="pil").convert("RGB")
 
46
  processed_image = process(im)
47
 
 
48
  filename = f"/tmp/processed_{uuid.uuid4().hex}.png"
49
  processed_image.save(filename)
50
 
51
+ return processed_image, filename # Preview + downloadable image
52
 
53
+ # File-only processing tab
54
  def process_file(f):
55
  name_path = f.rsplit(".", 1)[0] + ".png"
56
+ im = load_img(f, output_type="pil").convert("RGB")
 
57
  transparent = process(im)
58
  transparent.save(name_path)
59
  return name_path
60
 
61
  # Components
62
+ preview_image = gr.Image(label="Preview")
63
+ download_file = gr.File(label="Download Processed Image")
 
64
  image_upload = gr.Image(label="Upload an image")
65
  image_file_upload = gr.Image(label="Upload an image", type="filepath")
66
  url_input = gr.Textbox(label="Paste an image URL")
67
  output_file = gr.File(label="Output PNG File")
68
 
69
+ # Example image (from URL instead of invalid local path)
 
70
  url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
71
+ chameleon = load_img(url_example, output_type="pil")
72
 
73
+ # Interface Tabs
74
+ tab1 = gr.Interface(fn, inputs=image_upload, outputs=[preview_image, download_file], examples=[chameleon], api_name="image")
75
+ tab2 = gr.Interface(fn, inputs=url_input, outputs=preview_image, examples=[url_example], api_name="text")
76
+ tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=[], api_name="png")
77
 
78
+ # Combine tabs into one app
79
  demo = gr.TabbedInterface(
80
  [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
81
  )
82
 
 
83
  if __name__ == "__main__":
84
  demo.launch(show_error=True)