# Copyright (c) HKUST SAIL-Lab and Horizon Robotics. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. import gc import glob import os import shutil import sys import time from datetime import datetime import cv2 import gradio as gr import numpy as np import torch from tqdm import tqdm from eval.utils.device import to_cpu from eval.utils.eval_utils import uniform_sample from sailrecon.models.sail_recon import SailRecon from sailrecon.utils.geometry import unproject_depth_map_to_point_map from sailrecon.utils.load_fn import load_and_preprocess_images from sailrecon.utils.pose_enc import ( extri_intri_to_pose_encoding, pose_encoding_to_extri_intri, ) from visual_util import predictions_to_glb device = "cuda" if torch.cuda.is_available() else "cpu" print("Initializing and loading SailRecon model...") model = SailRecon(kv_cache=True) # _URL = "https://huggingface.co/HKUST-SAIL/SAIL-Recon/resolve/main/sailrecon.pt" # model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) model_dir = "ckpt/sailrecon.pt" model.load_state_dict(torch.load(model_dir)) model.eval() model = model.to(device) # ------------------------------------------------------------------------- # 1) Core model inference # ------------------------------------------------------------------------- def run_model(target_dir, model, anchor_size=100) -> dict: """ Run the SAIL-Recon model on images in the 'target_dir/images' folder and return predictions. """ print(f"Processing images from {target_dir}") # Device check device = "cuda" if torch.cuda.is_available() else "cpu" if not torch.cuda.is_available(): raise ValueError("CUDA is not available. Check your environment.") # Move model to device model = model.to(device) model.eval() # Load and preprocess images image_names = glob.glob(os.path.join(target_dir, "images", "*")) image_names = sorted(image_names) print(f"Found {len(image_names)} images") if len(image_names) == 0: raise ValueError("No images found. Check your upload.") images = load_and_preprocess_images(image_names).to(device) print(f"Preprocessed images shape: {images.shape}") # anchor image selection select_indices = uniform_sample(len(image_names), min(100, len(image_names))) anchor_images = images[select_indices] # Run inference print("Running inference...") dtype = ( torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 ) with torch.no_grad(): with torch.cuda.amp.autocast(dtype=dtype): print("Processing anchor images ...") model.tmp_forward(anchor_images) # del model.aggregator.global_blocks # relocalization on all images predictions_s = [] with tqdm(total=len(image_names), desc="Relocalizing") as pbar: for img_split in images.split(10, dim=0): pbar.update(10) predictions_s += to_cpu( model.reloc(img_split, ret_img=True, memory_save=False) ) predictions = {} predictions["extrinsic"] = torch.cat( [s["extrinsic"] for s in predictions_s], dim=0 ) # (S, 4, 4) predictions["intrinsic"] = torch.cat( [s["intrinsic"] for s in predictions_s], dim=0 ) # (S, 4, 4) predictions["depth"] = torch.cat( [s["depth_map"] for s in predictions_s], dim=0 ) # (S, H, W, 1) predictions["depth_conf"] = torch.cat( [s["dpt_cnf"] for s in predictions_s], dim=0 ) # (S, H, W, 1) predictions["images"] = torch.cat( [s["images"] for s in predictions_s], dim=0 ) # (S, H, W, 3) predictions["world_points"] = torch.cat( [s["point_map"] for s in predictions_s], dim=0 ) # (S, H, W, 3) predictions["world_points_conf"] = torch.cat( [s["xyz_cnf"] for s in predictions_s], dim=0 ) # (S, H, W, 3) predictions["pose_enc"] = extri_intri_to_pose_encoding( predictions["extrinsic"].unsqueeze(0), predictions["intrinsic"].unsqueeze(0), images.shape[-2:], )[ 0 ] # a del predictions_s # Convert tensors to numpy for key in predictions.keys(): if isinstance(predictions[key], torch.Tensor): predictions[key] = predictions[key].cpu().numpy() # remove batch dimension predictions["pose_enc_list"] = None # remove pose_enc_list # Generate world points from depth map print("Computing world points from depth map...") depth_map = predictions["depth"] # (S, H, W, 1) world_points = unproject_depth_map_to_point_map( depth_map, predictions["extrinsic"], predictions["intrinsic"] ) predictions["world_points_from_depth"] = world_points # Clean up torch.cuda.empty_cache() return predictions # ------------------------------------------------------------------------- # 2) Handle uploaded video/images --> produce target_dir + images # ------------------------------------------------------------------------- def handle_uploads(input_video, input_images): """ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded images or extracted frames from video into it. Return (target_dir, image_paths). """ start_time = time.time() gc.collect() torch.cuda.empty_cache() # Create a unique folder name timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") target_dir = f"input_images_{timestamp}" target_dir_images = os.path.join(target_dir, "images") # Clean up if somehow that folder already exists if os.path.exists(target_dir): shutil.rmtree(target_dir) os.makedirs(target_dir) os.makedirs(target_dir_images) image_paths = [] # --- Handle images --- if input_images is not None: for file_data in input_images: if isinstance(file_data, dict) and "name" in file_data: file_path = file_data["name"] else: file_path = file_data dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) shutil.copy(file_path, dst_path) image_paths.append(dst_path) # --- Handle video --- if input_video is not None: if isinstance(input_video, dict) and "name" in input_video: video_path = input_video["name"] else: video_path = input_video vs = cv2.VideoCapture(video_path) fps = vs.get(cv2.CAP_PROP_FPS) count = 0 video_frame_num = 0 while True: gotit, frame = vs.read() if not gotit: break count += 1 image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png") cv2.imwrite(image_path, frame) image_paths.append(image_path) video_frame_num += 1 # Sort final images for gallery image_paths = sorted(image_paths) end_time = time.time() print( f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds" ) return target_dir, image_paths # ------------------------------------------------------------------------- # 3) Update gallery on upload # ------------------------------------------------------------------------- def update_gallery_on_upload(input_video, input_images): """ Whenever user uploads or changes files, immediately handle them and show in the gallery. Return (target_dir, image_paths). If nothing is uploaded, returns "None" and empty list. """ if not input_video and not input_images: return None, None, None, None target_dir, image_paths = handle_uploads(input_video, input_images) return ( None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing.", ) # ------------------------------------------------------------------------- # 4) Reconstruction: uses the target_dir plus any viz parameters # ------------------------------------------------------------------------- def gradio_demo( target_dir, conf_thres=3.0, frame_filter="All", mask_black_bg=False, mask_white_bg=False, show_cam=True, mask_sky=False, downsample_ratio=100.0, prediction_mode="Pointmap Regression", ): """ Perform reconstruction using the already-created target_dir/images. """ if not os.path.isdir(target_dir) or target_dir == "None": return None, "No valid target directory found. Please upload first.", None, None start_time = time.time() gc.collect() torch.cuda.empty_cache() # Prepare frame_filter dropdown target_dir_images = os.path.join(target_dir, "images") all_files = ( sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else [] ) all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)] frame_filter_choices = ["All"] + all_files print("Running run_model...") with torch.no_grad(): predictions = run_model(target_dir, model) # Save predictions prediction_save_path = os.path.join(target_dir, "predictions.npz") np.savez(prediction_save_path, **predictions) # Handle None frame_filter if frame_filter is None: frame_filter = "All" # Build a GLB file name glbfile = os.path.join( target_dir, f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb", ) # Convert predictions to GLB glbscene = predictions_to_glb( predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg, mask_white_bg=mask_white_bg, show_cam=show_cam, mask_sky=mask_sky, target_dir=target_dir, downsample_ratio=downsample_ratio / 100.0, prediction_mode=prediction_mode, ) glbscene.export(file_obj=glbfile) # Cleanup del predictions gc.collect() torch.cuda.empty_cache() end_time = time.time() print(f"Total time: {end_time - start_time:.2f} seconds (including IO)") log_msg = ( f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization." ) return ( glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True), ) # ------------------------------------------------------------------------- # 5) Helper functions for UI resets + re-visualization # ------------------------------------------------------------------------- def clear_fields(): """ Clears the 3D viewer, the stored target_dir, and empties the gallery. """ return None def update_log(): """ Display a quick log message while waiting. """ return "Loading and Reconstructing..." def update_visualization( target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, downsample_ratio, prediction_mode, is_example, ): """ Reload saved predictions from npz, create (or reuse) the GLB for new parameters, and return it for the 3D viewer. If is_example == "True", skip. """ # If it's an example click, skip as requested if is_example == "True": return ( None, "No reconstruction available. Please click the Reconstruct button first.", ) if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): return ( None, "No reconstruction available. Please click the Reconstruct button first.", ) predictions_path = os.path.join(target_dir, "predictions.npz") if not os.path.exists(predictions_path): return ( None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.", ) key_list = [ "pose_enc", "depth", "depth_conf", "world_points", "world_points_conf", "images", "extrinsic", "intrinsic", "world_points_from_depth", ] loaded = np.load(predictions_path) predictions = {key: np.array(loaded[key]) for key in key_list if key in loaded} print(downsample_ratio) glbfile = os.path.join( target_dir, f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_dr{downsample_ratio}_pred{prediction_mode.replace(' ', '_')}.glb", ) if not os.path.exists(glbfile): glbscene = predictions_to_glb( predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg, mask_white_bg=mask_white_bg, show_cam=show_cam, mask_sky=mask_sky, target_dir=target_dir, downsample_ratio=downsample_ratio * 1.0 / 100.0, prediction_mode=prediction_mode, ) glbscene.export(file_obj=glbfile) return glbfile, "Updating Visualization" # ------------------------------------------------------------------------- # Example images # ------------------------------------------------------------------------- great_wall_video = "examples/videos/great_wall.mp4" colosseum_video = "examples/videos/Colosseum.mp4" room_video = "examples/videos/room.mp4" kitchen_video = "examples/videos/kitchen.mp4" fern_video = "examples/videos/fern.mp4" single_cartoon_video = "examples/videos/single_cartoon.mp4" single_oil_painting_video = "examples/videos/single_oil_painting.mp4" pyramid_video = "examples/videos/pyramid.mp4" # ------------------------------------------------------------------------- # 6) Build Gradio UI # ------------------------------------------------------------------------- theme = gr.themes.Ocean() theme.set( checkbox_label_background_fill_selected="*button_primary_background_fill", checkbox_label_text_color_selected="*button_primary_text_color", ) with gr.Blocks( theme=theme, css=""" .custom-log * { font-style: italic; font-size: 22px !important; background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%); -webkit-background-clip: text; background-clip: text; font-weight: bold !important; color: transparent !important; text-align: center !important; } .example-log * { font-style: italic; font-size: 16px !important; background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%); -webkit-background-clip: text; background-clip: text; color: transparent !important; } #my_radio .wrap { display: flex; flex-wrap: nowrap; justify-content: center; align-items: center; } #my_radio .wrap label { display: flex; width: 50%; justify-content: center; align-items: center; margin: 0; padding: 10px 0; box-sizing: border-box; } """, ) as demo: # Instead of gr.State, we use a hidden Textbox: is_example = gr.Textbox(label="is_example", visible=False, value="None") num_images = gr.Textbox(label="num_images", visible=False, value="None") gr.HTML( """
🐙 GitHub Repository | Project Page
Upload a video or a set of images to create a 3D reconstruction of a scene or object. SAIL-Recon takes these images and generates a 3D point cloud, along with estimated camera poses.
Please note: SAIL-Recon typically reconstructs a scene at 5FPS with full 3D attributes. However, visualizing 3D points may take tens of seconds due to third-party rendering, which is independent of SAIL-Recon's processing time. Using the 'demo.py' can provide much faster processing.