cavargas10 commited on
Commit
f7b7abd
·
verified ·
1 Parent(s): 6163a25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -53
app.py CHANGED
@@ -15,22 +15,18 @@ from trellis.pipelines import TrellisImageTo3DPipeline
15
  from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
17
 
18
-
19
  MAX_SEED = np.iinfo(np.int32).max
20
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
21
  os.makedirs(TMP_DIR, exist_ok=True)
22
 
23
-
24
  def start_session(req: gr.Request):
25
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
26
  os.makedirs(user_dir, exist_ok=True)
27
 
28
-
29
  def end_session(req: gr.Request):
30
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
31
  shutil.rmtree(user_dir)
32
 
33
-
34
  def preprocess_image(image: Image.Image) -> Image.Image:
35
  processed_image = pipeline.preprocess_image(image)
36
  return processed_image
@@ -51,7 +47,6 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
51
  },
52
  }
53
 
54
-
55
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
56
  gs = Gaussian(
57
  aabb=state['gaussian']['aabb'],
@@ -74,14 +69,9 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
74
 
75
  return gs, mesh
76
 
77
-
78
  def get_seed(randomize_seed: bool, seed: int) -> int:
79
- """
80
- Get the random seed.
81
- """
82
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
83
 
84
-
85
  @spaces.GPU
86
  def image_to_3d(
87
  image: Image.Image,
@@ -117,7 +107,6 @@ def image_to_3d(
117
  torch.cuda.empty_cache()
118
  return state, video_path
119
 
120
-
121
  @spaces.GPU(duration=90)
122
  def extract_glb(
123
  state: dict,
@@ -125,7 +114,6 @@ def extract_glb(
125
  texture_size: int,
126
  req: gr.Request,
127
  ) -> Tuple[str, str]:
128
-
129
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
130
  gs, mesh = unpack_state(state)
131
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
@@ -134,20 +122,7 @@ def extract_glb(
134
  torch.cuda.empty_cache()
135
  return glb_path, glb_path
136
 
137
- @spaces.GPU
138
- def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
139
-
140
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
141
- gs, _ = unpack_state(state)
142
- gaussian_path = os.path.join(user_dir, 'sample.ply')
143
- gs.save_ply(gaussian_path)
144
- torch.cuda.empty_cache()
145
- return gaussian_path, gaussian_path
146
-
147
  def split_image(image: Image.Image) -> List[Image.Image]:
148
- """
149
- Split an image into multiple views.
150
- """
151
  image = np.array(image)
152
  alpha = image[..., 3]
153
  alpha = np.any(alpha>0, axis=0)
@@ -158,14 +133,13 @@ def split_image(image: Image.Image) -> List[Image.Image]:
158
  images.append(Image.fromarray(image[:, s:e+1]))
159
  return [preprocess_image(image) for image in images]
160
 
161
-
162
  with gr.Blocks(delete_cache=(600, 600)) as demo:
163
  gr.Markdown("""
164
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
165
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
166
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
167
 
168
- ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
169
  """)
170
 
171
  with gr.Row():
@@ -192,24 +166,17 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
192
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
193
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
194
 
195
- with gr.Row():
196
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
197
- extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
198
- gr.Markdown("""
199
- *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
200
- """)
201
 
202
  with gr.Column():
203
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
204
- model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
205
 
206
  with gr.Row():
207
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
208
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
209
-
210
  output_buf = gr.State()
211
 
212
- # Handlers
213
  demo.load(start_session)
214
  demo.unload(end_session)
215
 
@@ -228,13 +195,13 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
228
  inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
229
  outputs=[output_buf, video_output],
230
  ).then(
231
- lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
232
- outputs=[extract_glb_btn, extract_gs_btn],
233
  )
234
 
235
  video_output.clear(
236
- lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
237
- outputs=[extract_glb_btn, extract_gs_btn],
238
  )
239
 
240
  extract_glb_btn.click(
@@ -245,23 +212,12 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
245
  lambda: gr.Button(interactive=True),
246
  outputs=[download_glb],
247
  )
248
-
249
- extract_gs_btn.click(
250
- extract_gaussian,
251
- inputs=[output_buf],
252
- outputs=[model_output, download_gs],
253
- ).then(
254
- lambda: gr.Button(interactive=True),
255
- outputs=[download_gs],
256
- )
257
 
258
  model_output.clear(
259
  lambda: gr.Button(interactive=False),
260
  outputs=[download_glb],
261
  )
262
 
263
-
264
- # Launch the Gradio app
265
  if __name__ == "__main__":
266
  pipeline = TrellisImageTo3DPipeline.from_pretrained("cavargas10/TRELLIS")
267
  pipeline.cuda()
@@ -269,4 +225,4 @@ if __name__ == "__main__":
269
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
270
  except:
271
  pass
272
- demo.launch()
 
15
  from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
17
 
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
20
  os.makedirs(TMP_DIR, exist_ok=True)
21
 
 
22
  def start_session(req: gr.Request):
23
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
24
  os.makedirs(user_dir, exist_ok=True)
25
 
 
26
  def end_session(req: gr.Request):
27
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
28
  shutil.rmtree(user_dir)
29
 
 
30
  def preprocess_image(image: Image.Image) -> Image.Image:
31
  processed_image = pipeline.preprocess_image(image)
32
  return processed_image
 
47
  },
48
  }
49
 
 
50
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
51
  gs = Gaussian(
52
  aabb=state['gaussian']['aabb'],
 
69
 
70
  return gs, mesh
71
 
 
72
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
 
 
73
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
74
 
 
75
  @spaces.GPU
76
  def image_to_3d(
77
  image: Image.Image,
 
107
  torch.cuda.empty_cache()
108
  return state, video_path
109
 
 
110
  @spaces.GPU(duration=90)
111
  def extract_glb(
112
  state: dict,
 
114
  texture_size: int,
115
  req: gr.Request,
116
  ) -> Tuple[str, str]:
 
117
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
118
  gs, mesh = unpack_state(state)
119
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
 
122
  torch.cuda.empty_cache()
123
  return glb_path, glb_path
124
 
 
 
 
 
 
 
 
 
 
 
125
  def split_image(image: Image.Image) -> List[Image.Image]:
 
 
 
126
  image = np.array(image)
127
  alpha = image[..., 3]
128
  alpha = np.any(alpha>0, axis=0)
 
133
  images.append(Image.fromarray(image[:, s:e+1]))
134
  return [preprocess_image(image) for image in images]
135
 
 
136
  with gr.Blocks(delete_cache=(600, 600)) as demo:
137
  gr.Markdown("""
138
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
139
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
140
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
141
 
142
+ ✨New: Experimental multi-image support.
143
  """)
144
 
145
  with gr.Row():
 
166
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
167
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
168
 
169
+ extract_glb_btn = gr.Button("Extract GLB", interactive=False)
 
 
 
 
 
170
 
171
  with gr.Column():
172
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
173
+ model_output = LitModel3D(label="Extracted GLB", exposure=10.0, height=300)
174
 
175
  with gr.Row():
176
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
177
+
 
178
  output_buf = gr.State()
179
 
 
180
  demo.load(start_session)
181
  demo.unload(end_session)
182
 
 
195
  inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
196
  outputs=[output_buf, video_output],
197
  ).then(
198
+ lambda: gr.Button(interactive=True),
199
+ outputs=[extract_glb_btn],
200
  )
201
 
202
  video_output.clear(
203
+ lambda: gr.Button(interactive=False),
204
+ outputs=[extract_glb_btn],
205
  )
206
 
207
  extract_glb_btn.click(
 
212
  lambda: gr.Button(interactive=True),
213
  outputs=[download_glb],
214
  )
 
 
 
 
 
 
 
 
 
215
 
216
  model_output.clear(
217
  lambda: gr.Button(interactive=False),
218
  outputs=[download_glb],
219
  )
220
 
 
 
221
  if __name__ == "__main__":
222
  pipeline = TrellisImageTo3DPipeline.from_pretrained("cavargas10/TRELLIS")
223
  pipeline.cuda()
 
225
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
226
  except:
227
  pass
228
+ demo.launch()