ABHISHEKSINGH0204 commited on
Commit
aa30951
·
verified ·
1 Parent(s): 3de41b0

initial commit

Browse files
Files changed (1) hide show
  1. app.py +188 -0
app.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import tempfile
4
+ import time
5
+
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import rembg
10
+ import torch
11
+ from PIL import Image
12
+ from functools import partial
13
+
14
+ from tsr.system import TSR
15
+ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
16
+
17
+ import argparse
18
+
19
+
20
+ if torch.cuda.is_available():
21
+ device = "cuda:0"
22
+ else:
23
+ device = "cpu"
24
+
25
+ model = TSR.from_pretrained(
26
+ "stabilityai/TripoSR",
27
+ config_name="config.yaml",
28
+ weight_name="model.ckpt",
29
+ )
30
+
31
+ # adjust the chunk size to balance between speed and memory usage
32
+ model.renderer.set_chunk_size(8192)
33
+ model.to(device)
34
+
35
+ rembg_session = rembg.new_session()
36
+
37
+
38
+ def check_input_image(input_image):
39
+ if input_image is None:
40
+ raise gr.Error("No image uploaded!")
41
+
42
+
43
+ def preprocess(input_image, do_remove_background, foreground_ratio):
44
+ def fill_background(image):
45
+ image = np.array(image).astype(np.float32) / 255.0
46
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
47
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
48
+ return image
49
+
50
+ if do_remove_background:
51
+ image = input_image.convert("RGB")
52
+ image = remove_background(image, rembg_session)
53
+ image = resize_foreground(image, foreground_ratio)
54
+ image = fill_background(image)
55
+ else:
56
+ image = input_image
57
+ if image.mode == "RGBA":
58
+ image = fill_background(image)
59
+ return image
60
+
61
+
62
+ def generate(image, mc_resolution, formats=["obj", "glb"]):
63
+ scene_codes = model(image, device=device)
64
+ mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
65
+ mesh = to_gradio_3d_orientation(mesh)
66
+ rv = []
67
+ for format in formats:
68
+ mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
69
+ mesh.export(mesh_path.name)
70
+ rv.append(mesh_path.name)
71
+ return rv
72
+
73
+
74
+ def run_example(image_pil):
75
+ preprocessed = preprocess(image_pil, False, 0.9)
76
+ mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
77
+ return preprocessed, mesh_name_obj, mesh_name_glb
78
+
79
+
80
+ with gr.Blocks(title="TripoSR") as interface:
81
+ gr.Markdown(
82
+ """
83
+ # TripoSR Demo
84
+ [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
85
+
86
+ **Tips:**
87
+ 1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
88
+ 2. It's better to disable "Remove Background" for the provided examples (except fot the last one) since they have been already preprocessed.
89
+ 3. Otherwise, please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
90
+ """
91
+ )
92
+ with gr.Row(variant="panel"):
93
+ with gr.Column():
94
+ with gr.Row():
95
+ input_image = gr.Image(
96
+ label="Input Image",
97
+ image_mode="RGBA",
98
+ sources="upload",
99
+ type="pil",
100
+ elem_id="content_image",
101
+ )
102
+ processed_image = gr.Image(label="Processed Image", interactive=False)
103
+ with gr.Row():
104
+ with gr.Group():
105
+ do_remove_background = gr.Checkbox(
106
+ label="Remove Background", value=True
107
+ )
108
+ foreground_ratio = gr.Slider(
109
+ label="Foreground Ratio",
110
+ minimum=0.5,
111
+ maximum=1.0,
112
+ value=0.85,
113
+ step=0.05,
114
+ )
115
+ mc_resolution = gr.Slider(
116
+ label="Marching Cubes Resolution",
117
+ minimum=32,
118
+ maximum=320,
119
+ value=256,
120
+ step=32
121
+ )
122
+ with gr.Row():
123
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
124
+ with gr.Column():
125
+ with gr.Tab("OBJ"):
126
+ output_model_obj = gr.Model3D(
127
+ label="Output Model (OBJ Format)",
128
+ interactive=False,
129
+ )
130
+ gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
131
+ with gr.Tab("GLB"):
132
+ output_model_glb = gr.Model3D(
133
+ label="Output Model (GLB Format)",
134
+ interactive=False,
135
+ )
136
+ gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
137
+ with gr.Row(variant="panel"):
138
+ gr.Examples(
139
+ examples=[
140
+ "examples/hamburger.png",
141
+ "examples/poly_fox.png",
142
+ "examples/robot.png",
143
+ "examples/teapot.png",
144
+ "examples/tiger_girl.png",
145
+ "examples/horse.png",
146
+ "examples/flamingo.png",
147
+ "examples/unicorn.png",
148
+ "examples/chair.png",
149
+ "examples/iso_house.png",
150
+ "examples/marble.png",
151
+ "examples/police_woman.png",
152
+ "examples/captured.jpeg",
153
+ ],
154
+ inputs=[input_image],
155
+ outputs=[processed_image, output_model_obj, output_model_glb],
156
+ cache_examples=False,
157
+ fn=partial(run_example),
158
+ label="Examples",
159
+ examples_per_page=20,
160
+ )
161
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
162
+ fn=preprocess,
163
+ inputs=[input_image, do_remove_background, foreground_ratio],
164
+ outputs=[processed_image],
165
+ ).success(
166
+ fn=generate,
167
+ inputs=[processed_image, mc_resolution],
168
+ outputs=[output_model_obj, output_model_glb],
169
+ )
170
+
171
+
172
+
173
+ if __name__ == '__main__':
174
+ parser = argparse.ArgumentParser()
175
+ parser.add_argument('--username', type=str, default=None, help='Username for authentication')
176
+ parser.add_argument('--password', type=str, default=None, help='Password for authentication')
177
+ parser.add_argument('--port', type=int, default=7860, help='Port to run the server listener on')
178
+ parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
179
+ parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
180
+ parser.add_argument("--queuesize", type=int, default=1, help="launch gradio queue max_size")
181
+ args = parser.parse_args()
182
+ interface.queue(max_size=args.queuesize)
183
+ interface.launch(
184
+ auth=(args.username, args.password) if (args.username and args.password) else None,
185
+ share=args.share,
186
+ server_name="0.0.0.0" if args.listen else None,
187
+ server_port=args.port
188
+ )