Commit 390338e
Parent(s): d5a5fa0
Refactor inference configuration and pipeline logic: remove unused parameters and improve the frame-selection process. Update inference settings in inference.yaml and streamline surfel model initialization in pipeline.py.
Files changed:
- configs/inference/inference.yaml  +2 -22
- modeling/pipeline.py  +88 -243
configs/inference/inference.yaml

@@ -4,20 +4,15 @@ model:
 width: 576
 original_height: 288
 original_width: 512
-
-# pretrained_model_path: "stabilityai/stable-diffusion-2-1"
-# pretrained_video_model_path: "stabilityai/stable-video-diffusion-img2vid"
+
 
 context_num_frames: 4
 target_num_frames: 4
 num_frames: 8
 vae_spatial_scale: 8
 latent_channels: 4
-# num_ray_blocks: 2
 vae_scale_factor: 8
-inference_mode: false
 
-temporal_only: false
 use_non_maximum_suppression: true
 translation_distance_weight: 0.1
 
@@ -26,14 +21,7 @@ model:
 cfg_min: 1.2
 cfg: 2.0
 guider_types: 1
-
 samples_dir: "./visualization"
-save_flag: false
-use_wandb: false
-
-
-
-# model_path: "/homes/55/runjia/storage/simview_weights/2025-04-30_12-08-55/checkpoint_230000.pth"
 model_path: "liguang0115/vmem"
 
 
@@ -45,7 +33,7 @@ surfel:
 merge_position_threshold: 0.2
 merge_normal_threshold: 0.6
 lr: 0.01
-niter:
+niter: 400
 model_path: "liguang0115/cut3r"
 width: 512
 height: 288
@@ -54,14 +42,6 @@ inference:
 visualize: true
 visualize_pointcloud: false
 visualize_surfel: false
-save_surfels: false
-image_dir: "/homes/55/runjia/storage/realestate10k/video_data/test"
-meta_info_dir: "/homes/55/runjia/storage/realestate10k/RealEstate10K/test"
-
-
-
-
-
 
 
 
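The config-side changes mainly feed the surfel section that drives the scene optimization added later in pipeline.py (construct_and_store_scene forwards config.surfel.niter and config.surfel.lr). The snippet below is a minimal sketch of reading these values; it assumes the YAML is parsed with OmegaConf, which the attribute-style config access in pipeline.py suggests but this diff does not confirm.

# Minimal sketch, assuming OmegaConf (an assumption; only the attribute-style
# config access in pipeline.py hints at it).
from omegaconf import OmegaConf

config = OmegaConf.load("configs/inference/inference.yaml")
print(config.surfel.niter)        # 400 after this commit
print(config.surfel.lr)           # 0.01, forwarded together with niter to scene reconstruction
print(config.surfel.model_path)   # "liguang0115/cut3r", the Hub repo for the CUT3R checkpoint
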
modeling/pipeline.py

@@ -4,27 +4,19 @@ from copy import deepcopy
 
 import math
 
-# import matplotlib.pyplot as plt
-# from mpl_toolkits.mplot3d.art3d import Poly3DCollection
-
 import PIL
-from PIL import Image, ImageOps
 import numpy as np
 from einops import repeat
-
+
 
 import torch
 import torch.nn.functional as F
-from torch.amp import autocast
 import torchvision.transforms as tvf
 
 
-# from diffusers import AutoencoderKL, DiffusionPipeline
-# from diffusers.schedulers import DDIMScheduler
 from diffusers.utils import export_to_gif
 
 import sys
-# Add CUT3R to Python path for imports
 sys.path.append("./extern/CUT3R")
 from extern.CUT3R.surfel_inference import run_inference_from_pil
 from extern.CUT3R.add_ckpt_path import add_path_to_dust3r
@@ -91,32 +83,23 @@ class VMemPipeline:
 self.device = device
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-# Import CUT3R scene alignment module
-from extern.CUT3R.cloud_opt.dust3r_opt import global_aligner, GlobalAlignerMode
-self.GlobalAlignerMode = GlobalAlignerMode
-self.global_aligner = global_aligner
+surfel_model_path = hf_hub_download(repo_id=self.config.surfel.model_path, filename="cut3r_512_dpt_4_64.pth")
+print(f"Loading model from {surfel_model_path}...")
+add_path_to_dust3r(surfel_model_path)
+self.surfel_model = ARCroco3DStereo.from_pretrained(surfel_model_path).to(device)
+self.surfel_model.eval()
+
+
+# Import CUT3R scene alignment module
+from extern.CUT3R.cloud_opt.dust3r_opt import global_aligner, GlobalAlignerMode
+self.GlobalAlignerMode = GlobalAlignerMode
+self.global_aligner = global_aligner
 
 
 
-
-self.surfel_model = None
+
 
 
-
-self.temporal_only = self.config.model.temporal_only
 self.use_non_maximum_suppression = self.config.model.use_non_maximum_suppression
 
 self.context_num_frames = self.config.model.context_num_frames
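As an aside, the new initialization in the hunk above can be read as the standalone helper below. This is a sketch rather than the repo's API: hf_hub_download is the standard huggingface_hub call already used in the hunk, add_path_to_dust3r mirrors the pipeline's own import, and the ARCroco3DStereo class is passed in as a parameter because its import path inside the CUT3R checkout is not shown in this diff.

# Sketch of the surfel-model bootstrap inlined into VMemPipeline.__init__ by this
# commit (previously self.surfel_model was only set to None here). The model class
# is a parameter because its import path is not shown in the diff.
from huggingface_hub import hf_hub_download

from extern.CUT3R.add_ckpt_path import add_path_to_dust3r  # same helper as in pipeline.py


def load_surfel_model(model_cls, repo_id="liguang0115/cut3r",
                      filename="cut3r_512_dpt_4_64.pth", device="cuda"):
    """Fetch the CUT3R checkpoint from the Hugging Face Hub and load it in eval mode."""
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
    add_path_to_dust3r(checkpoint_path)  # same call as in the hunk above
    model = model_cls.from_pretrained(checkpoint_path).to(device)
    return model.eval()


# Usage mirroring the hunk above (with model_cls=ARCroco3DStereo and
# repo_id taken from config.surfel.model_path):
# self.surfel_model = load_surfel_model(ARCroco3DStereo, repo_id=config.surfel.model_path)
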
@@ -537,33 +520,58 @@ class VMemPipeline:
 embeddings = [torch.from_numpy(self.encoder_embeddings[i]).to(self.device, self.dtype) for i in indices]
 intrinsics = [self.Ks[i] for i in indices]
 return c2ws, latents, embeddings, intrinsics, indices
-
-if self.
-
-
-
-
-
-
-
-
-
+
+if len(self.pil_frames) == 1:
+context_time_indices = [0]
+else:
+# get the average camera pose
+average_c2w = average_camera_pose(target_c2ws[-self.config.model.context_num_frames//4:])
+transformed_average_c2w = self.get_transformed_c2ws(average_c2w)
+target_K = np.mean(self.surfel_Ks, axis=0)
+# Select frames using surfel-based relevance
+retrieved_info = self.render_surfels_to_image(
+self.surfels,
+transformed_average_c2w,
+[target_K*0.65] * 2,
+principal_points=(int(self.config.surfel.width/2), int(self.config.surfel.height/2)),
+image_width=int(self.config.surfel.width),
+image_height=int(self.config.surfel.height)
+)
+_, frame_count = self.process_retrieved_spatial_information(retrieved_info)
+if self.config.inference.visualize:
+visualize_depth(retrieved_info["depth"],
+visualization_dir=self.visualize_dir,
+file_name=f"retrieved_depth_surfels.png",
+size=(self.width, self.height))
+
+
+# Build candidate frames based on relevance count
+candidates = []
+for frame, count in frame_count:
+candidates.extend([frame] * count)
+indices_to_frame = {
+i: frame for i, frame in enumerate(candidates)
+}
+
+# Sort candidates by distance to target view
+distances = [self.geodesic_distance(torch.from_numpy(average_c2w).to(self.device, self.dtype),
+torch.from_numpy(self.c2ws[frame]).to(self.device, self.dtype),
+weight_translation=self.config.model.translation_distance_weight).item()
+for frame in candidates]
 
-
-
-max_frames = min(self.config.model.context_num_frames, len(
+sorted_indices = torch.argsort(torch.tensor(distances))
+sorted_frames = [indices_to_frame[int(i.item())] for i in sorted_indices]
+max_frames = min(self.config.model.context_num_frames, len(candidates), len(self.latents))
 
-
-is_first_step = len(self.pil_frames) <= 1
+
 is_second_step = len(self.pil_frames) == 5
-
+
 
 # Adaptively determine initial threshold based on camera pose distribution
 if use_non_maximum_suppression is None:
 use_non_maximum_suppression = self.use_non_maximum_suppression
 
 if use_non_maximum_suppression:
-
 if is_second_step:
 # Calculate pairwise distances between existing frames
 pairwise_distances = []
@@ -581,32 +589,26 @@
 pairwise_distances.sort()
 percentile_idx = int(len(pairwise_distances) * 0.5) # 25th percentile
 self.initial_threshold = pairwise_distances[percentile_idx]
-
-# Ensure threshold is within reasonable bounds
-# initial_threshold = max(0.00, min(0.001, initial_threshold))
 else:
-self.initial_threshold =
-
-
-
+self.initial_threshold = 1
+
+
+
 else:
 self.initial_threshold = 1e8
-
-
 
 selected_indices = []
+current_threshold = self.initial_threshold
+
+# Always start with the closest pose
+selected_indices.append(sorted_frames[0])
+if not use_non_maximum_suppression:
+selected_indices.append(len(self.c2ws) - 1)
 
 # Try with increasingly relaxed thresholds until we get enough frames
-current_threshold
-while len(selected_indices) < min_required_frames and current_threshold <= 1.0:
-# Reset selection with new threshold
-selected_indices = []
-
-# Always start with the closest pose
-selected_indices.append(sorted_indices[0])
-
+while len(selected_indices) < max_frames and current_threshold >= 1e-5 and use_non_maximum_suppression:
 # Try to add each subsequent pose in order of distance
-for idx in
+for idx in sorted_frames[1:]:
 if len(selected_indices) >= max_frames:
 break
 
@@ -627,148 +629,22 @@
 selected_indices.append(idx)
 
 # If we still don't have enough frames, relax the threshold and try again
-if len(selected_indices) <
-current_threshold
+if len(selected_indices) < max_frames:
+current_threshold /= 1.2
 else:
 break
 
 # If we still don't have enough frames, just take the top frames by distance
-if len(selected_indices) <
+if len(selected_indices) < max_frames:
 available_indices = []
-for idx in
+for idx in sorted_frames:
 if idx not in selected_indices:
 available_indices.append(idx)
-selected_indices.extend(available_indices[:
+selected_indices.extend(available_indices[:max_frames-len(selected_indices)])
 
 # Convert to tensor and maintain original order (don't reverse)
-context_time_indices = torch.
-
-
-else:
-if len(self.pil_frames) == 1:
-context_time_indices = [0]
-else:
-# get the average camera pose
-average_c2w = average_camera_pose(target_c2ws[-self.config.model.context_num_frames//4:])
-transformed_average_c2w = self.get_transformed_c2ws(average_c2w)
-target_K = np.mean(self.surfel_Ks, axis=0)
-# Select frames using surfel-based relevance
-retrieved_info = self.render_surfels_to_image(
-self.surfels,
-transformed_average_c2w,
-[target_K*0.65] * 2,
-principal_points=(int(self.config.surfel.width/2), int(self.config.surfel.height/2)),
-image_width=int(self.config.surfel.width),
-image_height=int(self.config.surfel.height)
-)
-_, frame_count = self.process_retrieved_spatial_information(retrieved_info)
-if self.config.inference.visualize:
-visualize_depth(retrieved_info["depth"],
-visualization_dir=self.visualize_dir,
-file_name=f"retrieved_depth_surfels.png",
-size=(self.width, self.height))
-
-
-# Build candidate frames based on relevance count
-candidates = []
-for frame, count in frame_count:
-candidates.extend([frame] * count)
-indices_to_frame = {
-i: frame for i, frame in enumerate(candidates)
-}
-
-# Sort candidates by distance to target view
-distances = [self.geodesic_distance(torch.from_numpy(average_c2w).to(self.device, self.dtype),
-torch.from_numpy(self.c2ws[frame]).to(self.device, self.dtype),
-weight_translation=self.config.model.translation_distance_weight).item()
-for frame in candidates]
-
-sorted_indices = torch.argsort(torch.tensor(distances))
-sorted_frames = [indices_to_frame[int(i.item())] for i in sorted_indices]
-max_frames = min(self.config.model.context_num_frames, len(candidates), len(self.latents))
-
-
-is_second_step = len(self.pil_frames) == 5
-
-
-# Adaptively determine initial threshold based on camera pose distribution
-if use_non_maximum_suppression is None:
-use_non_maximum_suppression = self.use_non_maximum_suppression
-
-if use_non_maximum_suppression:
-if is_second_step:
-# Calculate pairwise distances between existing frames
-pairwise_distances = []
-for i in range(len(self.c2ws)):
-for j in range(i+1, len(self.c2ws)):
-sim = self.geodesic_distance(
-torch.from_numpy(np.array(self.c2ws[i])).to(self.device, self.dtype),
-torch.from_numpy(np.array(self.c2ws[j])).to(self.device, self.dtype),
-weight_translation=self.config.model.translation_distance_weight
-)
-pairwise_distances.append(sim.item())
-
-if pairwise_distances:
-# Sort distances and take percentile as threshold
-pairwise_distances.sort()
-percentile_idx = int(len(pairwise_distances) * 0.5) # 25th percentile
-self.initial_threshold = pairwise_distances[percentile_idx]
-else:
-self.initial_threshold = 1
-
-
-
-else:
-self.initial_threshold = 1e8
-
-selected_indices = []
-current_threshold = self.initial_threshold
-
-# Always start with the closest pose
-selected_indices.append(sorted_frames[0])
-if not use_non_maximum_suppression:
-selected_indices.append(len(self.c2ws) - 1)
-
-# Try with increasingly relaxed thresholds until we get enough frames
-while len(selected_indices) < max_frames and current_threshold >= 1e-5 and use_non_maximum_suppression:
-# Try to add each subsequent pose in order of distance
-for idx in sorted_frames[1:]:
-if len(selected_indices) >= max_frames:
-break
-
-# Check if this candidate is sufficiently different from all selected frames
-is_too_similar = False
-for selected_idx in selected_indices:
-similarity = self.geodesic_distance(
-torch.from_numpy(np.array(self.c2ws[idx])).to(self.device, self.dtype),
-torch.from_numpy(np.array(self.c2ws[selected_idx])).to(self.device, self.dtype),
-weight_translation=self.config.model.translation_distance_weight
-)
-if similarity < current_threshold:
-is_too_similar = True
-break
-
-# Add to selected frames if not too similar to any existing selection
-if not is_too_similar:
-selected_indices.append(idx)
-
-# If we still don't have enough frames, relax the threshold and try again
-if len(selected_indices) < max_frames:
-current_threshold /= 1.2
-else:
-break
-
-# If we still don't have enough frames, just take the top frames by distance
-if len(selected_indices) < max_frames:
-available_indices = []
-for idx in sorted_frames:
-if idx not in selected_indices:
-available_indices.append(idx)
-selected_indices.extend(available_indices[:max_frames-len(selected_indices)])
-
-# Convert to tensor and maintain original order (don't reverse)
-context_time_indices = torch.from_numpy(np.array(selected_indices))
-context_data = prepare_context_data(context_time_indices)
+context_time_indices = torch.from_numpy(np.array(selected_indices))
+context_data = prepare_context_data(context_time_indices)
 
 (context_c2ws, context_latents, context_encoder_embeddings, context_Ks, context_time_indices) = context_data
 print(f"context_time_indices: {context_time_indices}")
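For readability, here is the frame-selection strategy that the hunks above consolidate, restated as a standalone sketch: rank candidate frames by pose distance to the target view, greedily keep frames under non-maximum suppression, relax the threshold by a factor of 1.2 until enough frames survive, and finally top up with the closest remaining frames. The pose metric below is a generic rotation-angle-plus-weighted-translation distance; the repo's own geodesic_distance (and its surfel-based candidate weighting) may differ in detail.

import numpy as np


def pose_distance(c2w_a, c2w_b, weight_translation=0.1):
    """Rotation angle (radians) plus weighted translation distance between two 4x4 camera-to-world poses."""
    r_rel = c2w_a[:3, :3].T @ c2w_b[:3, :3]
    angle = np.arccos(np.clip((np.trace(r_rel) - 1.0) / 2.0, -1.0, 1.0))
    translation = np.linalg.norm(c2w_a[:3, 3] - c2w_b[:3, 3])
    return float(angle + weight_translation * translation)


def select_context_frames(c2ws, target_c2w, max_frames, initial_threshold):
    """Greedy non-maximum suppression over frames sorted by distance to the target pose."""
    order = sorted(range(len(c2ws)), key=lambda i: pose_distance(target_c2w, c2ws[i]))
    selected = [order[0]]                      # always keep the closest pose
    threshold = initial_threshold
    while len(selected) < max_frames and threshold >= 1e-5:
        for idx in order[1:]:
            if len(selected) >= max_frames:
                break
            too_similar = any(pose_distance(c2ws[idx], c2ws[s]) < threshold for s in selected)
            if not too_similar:
                selected.append(idx)
        if len(selected) < max_frames:
            threshold /= 1.2                   # relax the threshold and retry, as in the diff
        else:
            break
    if len(selected) < max_frames:             # fallback: top up with the closest remaining frames
        remaining = [i for i in order if i not in selected]
        selected.extend(remaining[:max_frames - len(selected)])
    return selected


# Example (hypothetical data): pick 4 context frames from a list of 4x4 c2w matrices.
# selected = select_context_frames(c2ws, target_c2w, max_frames=4, initial_threshold=0.5)
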
@@ -992,11 +868,7 @@
 # Flip Y and Z components of camera poses to match dataset convention
 c2ws_transformed = self.get_transformed_c2ws()
 
-
-if self.global_step == 10:
-visualize = True
-else:
-visualize = False
+
 scene = run_inference_from_pil(
 input_images,
 self.surfel_model,
@@ -1004,8 +876,7 @@
 depths=torch.from_numpy(np.array(self.surfel_depths)) if len(self.surfel_depths) > 0 else None,
 lr = lr,
 niter = niter,
-
-visualize=visualize,
+visualize=self.config.inference.visualize_surfel,
 device=device,
 )
 
@@ -1043,12 +914,10 @@
 )
 confs = confs.squeeze(1)
 
-
-# self.surfel_to_timestep = {}
+
 start_idx = 0 if len(self.surfels) == 0 else len(pointcloud) - self.config.model.target_num_frames
 end_idx = len(pointcloud)
-
-# Create surfels for the current frame
+
 for frame_idx in range(start_idx, end_idx):
 surfels = self.pointmap_to_surfels(
 pointmap=pointcloud[frame_idx],
@@ -1077,30 +946,6 @@
 for surfel_index in range(num_surfels):
 self.surfel_to_timestep[surfel_start_index + surfel_index] = [frame_idx]
 
-# Save surfels if configured
-if self.config.inference.save_surfels and len(self.surfels) > 0:
-positions = np.array([s.position for s in surfels], dtype=np.float32)
-normals = np.array([s.normal for s in surfels], dtype=np.float32)
-radii = np.array([s.radius for s in surfels], dtype=np.float32)
-colors = np.array([s.color for s in surfels], dtype=np.float32)
-
-np.savez(f"{self.config.visualization_dir}/surfels_added.npz",
-positions=positions,
-normals=normals,
-radii=radii,
-colors=colors)
-
-positions = np.array([s.position for s in self.surfels], dtype=np.float32)
-normals = np.array([s.normal for s in self.surfels], dtype=np.float32)
-radii = np.array([s.radius for s in self.surfels], dtype=np.float32)
-colors = np.array([s.color for s in self.surfels], dtype=np.float32)
-
-np.savez(f"{self.config.visualization_dir}/surfels_original.npz",
-positions=positions,
-normals=normals,
-radii=radii,
-colors=colors)
-
 self.surfels.extend(surfels)
 
 if self.config.inference.visualize_surfel:
@@ -1323,12 +1168,12 @@
 self.pil_frames[-1].save(f"{self.config.visualization_dir}/final_{len(self.pil_frames):07d}.png")
 
 # Update scene reconstruction if needed
-
-
-
-
-
-
+
+self.construct_and_store_scene(self.pil_frames,
+time_indices=context_time_indices,
+niter=self.config.surfel.niter,
+lr=self.config.surfel.lr,
+device=self.device)
 self.global_step += 1
 
 if self.config.inference.visualize:
@@ -1386,9 +1231,9 @@
 
 # Handle surfels if using reconstructor
 self.global_step -= frames_to_remove
-
-
-
+
+for _ in range(frames_to_remove):
+self.surfel_depths.pop()
 
 
 # Find surfels that belong only to the removed timesteps