rgndgn commited on
Commit
59beb43
·
verified ·
1 Parent(s): f700879

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +148 -0
gradio_app.py CHANGED
@@ -63,6 +63,154 @@ example_files = [
63
  os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
64
  ]
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def auto_process(input_image):
67
  if input_image is None:
68
  return None, None, None, None
 
63
  os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
64
  ]
65
 
66
+ def create_zip_file(glb_file, pc_file, illumination_file):
67
+ if not all([glb_file, pc_file, illumination_file]):
68
+ return None
69
+
70
+ # Create a temporary zip file
71
+ temp_dir = tempfile.mkdtemp()
72
+ zip_path = os.path.join(temp_dir, "spar3d_output.zip")
73
+
74
+ with zipfile.ZipFile(zip_path, "w") as zipf:
75
+ zipf.write(glb_file, "mesh.glb")
76
+ zipf.write(pc_file, "points.ply")
77
+ zipf.write(illumination_file, "illumination.hdr")
78
+
79
+ generated_files.append(zip_path)
80
+ return zip_path
81
+
82
+ def forward_model(
83
+ batch,
84
+ system,
85
+ guidance_scale=3.0,
86
+ seed=0,
87
+ device="cuda",
88
+ remesh_option="none",
89
+ vertex_count=-1,
90
+ texture_resolution=1024,
91
+ ):
92
+ batch_size = batch["rgb_cond"].shape[0]
93
+
94
+ # prepare the condition for point cloud generation
95
+ # set seed
96
+ random.seed(seed)
97
+ torch.manual_seed(seed)
98
+ np.random.seed(seed)
99
+ cond_tokens = system.forward_pdiff_cond(batch)
100
+
101
+ if "pc_cond" not in batch:
102
+ sample_iter = system.sampler.sample_batch_progressive(
103
+ batch_size,
104
+ cond_tokens,
105
+ guidance_scale=guidance_scale,
106
+ device=device,
107
+ )
108
+ for x in sample_iter:
109
+ samples = x["xstart"]
110
+ batch["pc_cond"] = samples.permute(0, 2, 1).float()
111
+ batch["pc_cond"] = spar3d_utils.normalize_pc_bbox(batch["pc_cond"])
112
+
113
+ # subsample to the 512 points
114
+ batch["pc_cond"] = batch["pc_cond"][
115
+ :, torch.randperm(batch["pc_cond"].shape[1])[:512]
116
+ ]
117
+
118
+ # get the point cloud
119
+ xyz = batch["pc_cond"][0, :, :3].cpu().numpy()
120
+ color_rgb = (batch["pc_cond"][0, :, 3:6] * 255).cpu().numpy().astype(np.uint8)
121
+ pc_rgb_trimesh = trimesh.PointCloud(vertices=xyz, colors=color_rgb)
122
+
123
+ # forward for the final mesh
124
+ trimesh_mesh, _glob_dict = model.generate_mesh(
125
+ batch,
126
+ texture_resolution,
127
+ remesh=remesh_option,
128
+ vertex_count=vertex_count,
129
+ estimate_illumination=True,
130
+ )
131
+ trimesh_mesh = trimesh_mesh[0]
132
+ illumination = _glob_dict["illumination"]
133
+
134
+ return trimesh_mesh, pc_rgb_trimesh, illumination.cpu().detach().numpy()[0]
135
+
136
+ def process_model_run(
137
+ fr_res,
138
+ guidance_scale,
139
+ random_seed,
140
+ pc_cond,
141
+ remesh_option,
142
+ vertex_count_type,
143
+ vertex_count,
144
+ texture_resolution,
145
+ ):
146
+ start = time.time()
147
+ with torch.no_grad():
148
+ with (
149
+ torch.autocast(device_type=device, dtype=torch.bfloat16)
150
+ if "cuda" in device
151
+ else nullcontext()
152
+ ):
153
+ model_batch = create_batch(fr_res)
154
+ model_batch = {k: v.to(device) for k, v in model_batch.items()}
155
+
156
+ trimesh_mesh, trimesh_pc, illumination_map = forward_model(
157
+ model_batch,
158
+ model,
159
+ guidance_scale=guidance_scale,
160
+ seed=random_seed,
161
+ device="cuda",
162
+ remesh_option=remesh_option.lower(),
163
+ vertex_count=vertex_count,
164
+ texture_resolution=texture_resolution,
165
+ )
166
+
167
+ # Create new tmp file
168
+ temp_dir = tempfile.mkdtemp()
169
+ tmp_file = os.path.join(temp_dir, "mesh.glb")
170
+
171
+ trimesh_mesh.export(tmp_file, file_type="glb", include_normals=True)
172
+ generated_files.append(tmp_file)
173
+
174
+ tmp_file_pc = os.path.join(temp_dir, "points.ply")
175
+ trimesh_pc.export(tmp_file_pc)
176
+ generated_files.append(tmp_file_pc)
177
+
178
+ tmp_file_illumination = os.path.join(temp_dir, "illumination.hdr")
179
+ cv2.imwrite(tmp_file_illumination, illumination_map)
180
+ generated_files.append(tmp_file_illumination)
181
+
182
+ print("Generation took:", time.time() - start, "s")
183
+
184
+ return tmp_file, tmp_file_pc, tmp_file_illumination, trimesh_pc
185
+
186
+ def create_batch(input_image: Image) -> dict[str, Any]:
187
+ img_cond = (
188
+ torch.from_numpy(
189
+ np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
190
+ / 255.0
191
+ )
192
+ .float()
193
+ .clip(0, 1)
194
+ )
195
+ mask_cond = img_cond[:, :, -1:]
196
+ rgb_cond = torch.lerp(
197
+ torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
198
+ )
199
+
200
+ batch_elem = {
201
+ "rgb_cond": rgb_cond,
202
+ "mask_cond": mask_cond,
203
+ "c2w_cond": c2w_cond.unsqueeze(0),
204
+ "intrinsic_cond": intrinsic.unsqueeze(0),
205
+ "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
206
+ }
207
+ # Add batch dim
208
+ batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
209
+ return batched
210
+
211
+ def remove_background(input_image: Image) -> Image:
212
+ return bg_remover.process(input_image.convert("RGB"))
213
+
214
  def auto_process(input_image):
215
  if input_image is None:
216
  return None, None, None, None