KenjieDec Arhenniuss commited on
Commit
1822728
·
verified ·
1 Parent(s): c8f8b0e

Update app.py (#5)

Browse files

- Update app.py (144ebf93ca1b27c31e62ebdb367e248b70803f41)


Co-authored-by: Arhenniuss Mallikovv <[email protected]>

Files changed (1) hide show
  1. app.py +109 -109
app.py CHANGED
@@ -1,109 +1,109 @@
1
- import gradio as gr
2
- import os
3
- import cv2
4
- from rembg import new_session, remove
5
- from rembg.sessions import sessions_class
6
-
7
- def inference(file, mask, model, x, y):
8
- im = cv2.imread(file, cv2.IMREAD_COLOR)
9
- input_path = "input.png"
10
- output_path = "output.png"
11
- cv2.imwrite(input_path, im)
12
-
13
- with open(input_path, 'rb') as i:
14
- with open(output_path, 'wb') as o:
15
- input = i.read()
16
- session = new_session(model)
17
-
18
- output = remove(
19
- input,
20
- session=session,
21
- **{ "sam_prompt": [{"type": "point", "data": [x, y], "label": 1}] },
22
- only_mask=(mask == "Mask only")
23
- )
24
- o.write(output)
25
-
26
- return output_path
27
-
28
- title = "RemBG"
29
- description = "Gradio demo for **[RemBG](https://github.com/danielgatis/rembg)**. To use it, simply upload your image, select a model, click Process, and wait."
30
- badge = """
31
- <div style="position: fixed; left: 50%; text-align: center;">
32
- <a href="https://github.com/danielgatis/rembg" target="_blank" style="text-decoration: none;">
33
- <img src="https://img.shields.io/badge/RemBG-Github-blue" alt="RemBG Github" />
34
- </a>
35
- </div>
36
- """
37
- def get_coords(evt: gr.SelectData) -> tuple:
38
- return evt.index[0], evt.index[1]
39
-
40
- def show_coords(model: str):
41
- visible = model == "sam"
42
- return gr.update(visible=visible), gr.update(visible=visible), gr.update(visible=visible)
43
-
44
- for session in sessions_class:
45
- session.download_models()
46
-
47
- with gr.Blocks() as app:
48
- gr.Markdown(f"# {title}")
49
- gr.Markdown(description)
50
-
51
- with gr.Row():
52
- inputs = gr.Image(type="filepath", label="Input Image")
53
- outputs = gr.Image(type="filepath", label="Output Image")
54
-
55
- with gr.Row():
56
- mask_option = gr.Radio(
57
- ["Default", "Mask only"],
58
- value="Default",
59
- label="Output Type"
60
- )
61
- model_selector = gr.Dropdown(
62
- [
63
- "u2net",
64
- "u2netp",
65
- "u2net_human_seg",
66
- "u2net_cloth_seg",
67
- "silueta",
68
- "isnet-general-use",
69
- "isnet-anime",
70
- "sam",
71
- "birefnet-general",
72
- "birefnet-general-lite",
73
- "birefnet-portrait",
74
- "birefnet-dis",
75
- "birefnet-hrsod",
76
- "birefnet-cod",
77
- "birefnet-massive"
78
- ],
79
- value="isnet-general-use",
80
- label="Model Selection"
81
- )
82
-
83
- extra = gr.Markdown("## Click on the image to capture coordinates (for SAM model)", visible=False)
84
-
85
- x = gr.Number(label="Mouse X Coordinate", visible=False)
86
- y = gr.Number(label="Mouse Y Coordinate", visible=False)
87
-
88
- model_selector.change(show_coords, inputs=model_selector, outputs=[x, y, extra])
89
- inputs.select(get_coords, None, [x, y])
90
-
91
-
92
- gr.Button("Process Image").click(
93
- inference,
94
- inputs=[inputs, mask_option, model_selector, x, y],
95
- outputs=outputs
96
- )
97
-
98
- gr.Examples(
99
- examples=[
100
- ["lion.png", "Default", "u2net", None, None],
101
- ["girl.jpg", "Default", "u2net", None, None],
102
- ["anime-girl.jpg", "Default", "isnet-anime", None, None]
103
- ],
104
- inputs=[inputs, mask_option, model_selector, x, y],
105
- outputs=outputs
106
- )
107
- gr.HTML(badge)
108
-
109
- app.launch()
 
1
+ import gradio as gr
2
+ import os
3
+ import cv2
4
+ from rembg import new_session, remove
5
+ from rembg.sessions import sessions_class
6
+
7
+ def inference(file, mask, model, x, y):
8
+ im = cv2.imread(file, cv2.IMREAD_COLOR)
9
+ input_path = "input.png"
10
+ output_path = "output.png"
11
+ cv2.imwrite(input_path, im)
12
+
13
+ with open(input_path, 'rb') as i:
14
+ with open(output_path, 'wb') as o:
15
+ input = i.read()
16
+ session = new_session(model)
17
+
18
+ output = remove(
19
+ input,
20
+ session=session,
21
+ **{ "sam_prompt": [{"type": "point", "data": [x, y], "label": 1}] },
22
+ only_mask=(mask == "Mask only")
23
+ )
24
+ o.write(output)
25
+
26
+ return output_path
27
+
28
+ title = "RemBG"
29
+ description = "Gradio demo for **[RemBG](https://github.com/danielgatis/rembg)**. To use it, simply upload your image, select a model, click Process, and wait."
30
+ badge = """
31
+ <div style="position: fixed; left: 50%; text-align: center;">
32
+ <a href="https://github.com/danielgatis/rembg" target="_blank" style="text-decoration: none;">
33
+ <img src="https://img.shields.io/badge/RemBG-Github-blue" alt="RemBG Github" />
34
+ </a>
35
+ </div>
36
+ """
37
+ def get_coords(evt: gr.SelectData) -> tuple:
38
+ return evt.index[0], evt.index[1]
39
+
40
+ def show_coords(model: str):
41
+ visible = model == "sam"
42
+ return gr.update(visible=visible), gr.update(visible=visible), gr.update(visible=visible)
43
+
44
+ for session in sessions_class:
45
+ session.download_models()
46
+
47
+ with gr.Blocks() as app:
48
+ gr.Markdown(f"# {title}")
49
+ gr.Markdown(description)
50
+
51
+ with gr.Row():
52
+ inputs = gr.Image(type="filepath", label="Input Image")
53
+ outputs = gr.Image(type="filepath", label="Output Image")
54
+
55
+ with gr.Row():
56
+ mask_option = gr.Radio(
57
+ ["Default", "Mask only"],
58
+ value="Default",
59
+ label="Output Type"
60
+ )
61
+ model_selector = gr.Dropdown(
62
+ [
63
+ "u2net",
64
+ "u2netp",
65
+ "u2net_human_seg",
66
+ "u2net_cloth_seg",
67
+ "silueta",
68
+ "isnet-general-use",
69
+ "isnet-anime",
70
+ "sam",
71
+ "birefnet-general",
72
+ "birefnet-general-lite",
73
+ "birefnet-portrait",
74
+ "birefnet-dis",
75
+ "birefnet-hrsod",
76
+ "birefnet-cod",
77
+ "birefnet-massive"
78
+ ],
79
+ value="isnet-general-use",
80
+ label="Model Selection"
81
+ )
82
+
83
+ extra = gr.Markdown("## Click on the image to capture coordinates (for SAM model)", visible=False)
84
+
85
+ x = gr.Number(label="Mouse X Coordinate", visible=False)
86
+ y = gr.Number(label="Mouse Y Coordinate", visible=False)
87
+
88
+ model_selector.change(show_coords, inputs=model_selector, outputs=[x, y, extra])
89
+ inputs.select(get_coords, None, [x, y])
90
+
91
+
92
+ gr.Button("Process Image").click(
93
+ inference,
94
+ inputs=[inputs, mask_option, model_selector, x, y],
95
+ outputs=outputs
96
+ )
97
+
98
+ gr.Examples(
99
+ examples=[
100
+ ["lion.png", "Default", "u2net", None, None],
101
+ ["girl.jpg", "Default", "u2net", None, None],
102
+ ["anime-girl.jpg", "Default", "isnet-anime", None, None]
103
+ ],
104
+ inputs=[inputs, mask_option, model_selector, x, y],
105
+ outputs=outputs
106
+ )
107
+ gr.HTML(badge)
108
+
109
+ app.launch(share=True)