Update mimicmotion/pipelines/pipeline_mimicmotion.py
mimicmotion/pipelines/pipeline_mimicmotion.py
CHANGED
@@ -222,40 +222,33 @@ class MimicMotionPipeline(DiffusionPipeline):
                          decode_chunk_size: int = 8):
         # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
         latents = latents.flatten(0, 1)
+
         latents = 1 / self.vae.config.scaling_factor * latents
-
+
         forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
         accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
-
-        #
-        def process_chunk(start, end, frames_list):
-            decode_kwargs = {}
-            if accepts_num_frames:
-                decode_kwargs["num_frames"] = end - start
-            frame = self.vae.decode(latents[start:end], **decode_kwargs).sample
-            frames_list.append(frame.cpu())
-
-        threads = []
+
+        # decode decode_chunk_size frames at a time to avoid OOM
         frames = []
-
-        # Dividing the work into chunks and creating threads to process them
         for i in range(0, latents.shape[0], decode_chunk_size):
-
-
-
-
-
-
-
-
-        # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
+            num_frames_in = latents[i: i + decode_chunk_size].shape[0]
+            decode_kwargs = {}
+            if accepts_num_frames:
+                # we only pass num_frames_in if it's expected
+                decode_kwargs["num_frames"] = num_frames_in
+
+            frame = self.vae.decode(latents[i: i + decode_chunk_size], **decode_kwargs).sample
+            frames.append(frame.cpu())
         frames = torch.cat(frames, dim=0)
+
+        # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
         frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
-
-        #
+
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         frames = frames.float()
         return frames
 
+
     def check_inputs(self, image, height, width):
         if (
             not isinstance(image, torch.Tensor)
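For reference, the added side of this hunk goes back to decoding a fixed number of frames per VAE call instead of spawning one thread per chunk. Below is a small standalone sketch of that chunked-decode pattern, with a hypothetical fake_decode standing in for self.vae.decode (names and shapes are illustrative only):

import torch
import torch.nn.functional as F

def fake_decode(chunk: torch.Tensor, num_frames: int) -> torch.Tensor:
    # hypothetical stand-in for self.vae.decode(...).sample: upsample 8x spatially
    assert chunk.shape[0] == num_frames
    return F.interpolate(chunk, scale_factor=8)

def decode_in_chunks(latents: torch.Tensor, decode_chunk_size: int = 8) -> torch.Tensor:
    # latents: [batch*frames, channels, height, width]
    frames = []
    for i in range(0, latents.shape[0], decode_chunk_size):
        chunk = latents[i: i + decode_chunk_size]
        num_frames_in = chunk.shape[0]   # the last chunk may be shorter than decode_chunk_size
        decoded = fake_decode(chunk, num_frames=num_frames_in)
        frames.append(decoded.cpu())     # move each decoded chunk off the accelerator right away
    return torch.cat(frames, dim=0)

out = decode_in_chunks(torch.randn(19, 4, 8, 8), decode_chunk_size=8)
print(out.shape)  # torch.Size([19, 4, 64, 64]) -- chunks of 8, 8 and 3 frames

Decoding at most decode_chunk_size frames per call keeps peak VAE memory bounded, which is what the restored "avoid OOM" comment refers to.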
@@ -563,17 +556,21 @@ class MimicMotionPipeline(DiffusionPipeline):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
+
                 # Concatenate image_latents over channels dimension
                 latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
-
+
                 # predict the noise residual
                 noise_pred = torch.zeros_like(image_latents)
                 noise_pred_cnt = image_latents.new_zeros((num_frames,))
                 weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size
                 weight = torch.minimum(weight, 2 - weight)
-
-
+
+                # Parallelize the loop over `indices` using ThreadPoolExecutor
+                def process_index(idx):
+                    nonlocal noise_pred, noise_pred_cnt
+                    result = torch.zeros_like(image_latents[:1, idx])  # Placeholder for thread-safe accumulation
+
                     # classification-free inference
                     pose_latents = self.pose_net(image_pose[idx].to(device))
                     _noise_pred = self.unet(
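The weight ramp kept as context just above process_index controls how overlapping tiles are blended: frames near the centre of a tile contribute more than frames near its edges when the per-tile predictions are summed into noise_pred and later divided by noise_pred_cnt. A quick standalone check (tile_size=8 is only an illustrative value):

import torch

tile_size = 8  # illustrative; the pipeline receives this as an argument
weight = (torch.arange(tile_size) + 0.5) * 2. / tile_size
weight = torch.minimum(weight, 2 - weight)
print(weight)
# tensor([0.1250, 0.3750, 0.6250, 0.8750, 0.8750, 0.6250, 0.3750, 0.1250])

# weight[:, None, None, None] has shape (tile_size, 1, 1, 1), so it broadcasts over the
# channel and spatial dimensions when scaling a (batch, tile_size, C, H, W) noise prediction
print(weight[:, None, None, None].shape)  # torch.Size([8, 1, 1, 1])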
@@ -585,8 +582,8 @@ class MimicMotionPipeline(DiffusionPipeline):
                         image_only_indicator=image_only_indicator,
                         return_dict=False,
                     )[0]
-
-
+                    result[:1] += _noise_pred * weight[:, None, None, None]
+
                     # normal inference
                     _noise_pred = self.unet(
                         latent_model_input[1:, idx],
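For orientation, the two UNet calls in this hunk run the unconditional half (latent_model_input[:1, idx], the "classification-free" pass) and the conditional half (latent_model_input[1:, idx], the "normal" pass) of the classifier-free-guidance batch on the same tile of frame indices. A shape check with purely illustrative sizes:

import torch

# with classifier-free guidance the batch dimension holds the [uncond, cond] copies
latent_model_input = torch.randn(2, 16, 4, 8, 8)  # (2, frames, channels, height, width), sizes made up
idx = [3, 4, 5, 6]                                 # one tile worth of frame indices

uncond_tile = latent_model_input[:1, idx]          # fed to the first ("classification-free") UNet call
cond_tile = latent_model_input[1:, idx]            # fed to the second ("normal") UNet call
print(uncond_tile.shape, cond_tile.shape)          # torch.Size([1, 4, 4, 8, 8]) twice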
@@ -597,26 +594,34 @@ class MimicMotionPipeline(DiffusionPipeline):
                         image_only_indicator=image_only_indicator,
                         return_dict=False,
                     )[0]
-
-
-
-
+                    result[1:] += _noise_pred * weight[:, None, None, None]
+
+                    return result, idx
+
+                with concurrent.futures.ThreadPoolExecutor() as executor:
+                    futures = [executor.submit(process_index, idx) for idx in indices]
+                    for future in concurrent.futures.as_completed(futures):
+                        _noise_pred, idx = future.result()
+                        noise_pred[:, idx] += _noise_pred
+                        noise_pred_cnt[idx] += weight
+                        progress_bar.update()
+
                 noise_pred.div_(noise_pred_cnt[:, None, None, None])
-
+
                 # perform guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
-
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-
+
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
                     for k in callback_on_step_end_tensor_inputs:
                         callback_kwargs[k] = locals()[k]
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
+
                     latents = callback_outputs.pop("latents", latents)
 
         self.pose_net.cpu()
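The added executor block submits one process_index task per tile and folds each returned partial result back into the shared noise_pred / noise_pred_cnt buffers from the main thread as the futures complete. A minimal self-contained sketch of that submit / as_completed accumulation pattern (the slow_square work function and the sizes are made up; in the pipeline the work is a pair of UNet forward passes):

import concurrent.futures
import torch

def slow_square(idx: int, x: torch.Tensor):
    # stand-in for the per-tile work; returns a partial result together with its index
    return x[idx] ** 2, idx

x = torch.arange(6, dtype=torch.float32)
acc = torch.zeros(6)
cnt = torch.zeros(6)

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(slow_square, idx, x) for idx in range(6)]
    for future in concurrent.futures.as_completed(futures):
        partial, idx = future.result()  # results can arrive in any order
        acc[idx] += partial             # accumulation happens only in the main thread
        cnt[idx] += 1

print(acc / cnt)  # tensor([ 0.,  1.,  4.,  9., 16., 25.])

Returning the partial result from the worker and doing the += in the as_completed loop, as the patch does, keeps all writes to the shared accumulators on the main thread.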