# SAIL-Recon / app.py
# Hugging Face Space file-viewer header (page residue, not program code):
#   uploaded by poptree; commit 6970df9 "Update app.py" (verified)
import spaces
import gc
import glob
import os
import shutil
import sys
import time
from datetime import datetime
import cv2
import gradio as gr
import numpy as np
import torch
from tqdm import tqdm
from eval.utils.device import to_cpu
from eval.utils.eval_utils import uniform_sample
from sailrecon.models.sail_recon import SailRecon
from sailrecon.utils.geometry import unproject_depth_map_to_point_map
from sailrecon.utils.load_fn import load_and_preprocess_images
from sailrecon.utils.pose_enc import (
extri_intri_to_pose_encoding,
pose_encoding_to_extri_intri,
)
from visual_util import predictions_to_glb
# Module-level device hint. run_model() re-detects the device itself, so this
# value is only informational in the code visible here.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Initializing and loading SailRecon model...")
model = SailRecon(kv_cache=True)
model_dir = "ckpt/sailrecon.pt"  # path to the checkpoint *file*, despite the name

# Deserialize the weights onto the CPU regardless of the device they were
# saved from: the model is still on the CPU here and is only moved to the GPU
# inside run_model(), so loading must not depend on a GPU being visible at
# import time.
if os.path.exists(model_dir):
    model.load_state_dict(torch.load(model_dir, map_location="cpu"))
else:
    # No local checkpoint: download (and cache) the published weights.
    _URL = "https://huggingface.co/HKUST-SAIL/SAIL-Recon/resolve/main/sailrecon.pt"
    model.load_state_dict(
        torch.hub.load_state_dict_from_url(_URL, map_location="cpu")
    )

# eval() and the move to the GPU are deferred to run_model():
# model.eval()
# model = model.to(device)
# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
@spaces.GPU
def run_model(target_dir, model, anchor_size=100) -> dict:
    """
    Run the SAIL-Recon model on images in the 'target_dir/images' folder and return predictions.

    Args:
        target_dir: Directory that contains an ``images`` sub-folder with the input frames.
        model: The SailRecon model; it is moved to the GPU and switched to eval mode here.
        anchor_size: Upper bound on the number of anchor images handed to
            ``model.tmp_forward``.

    Returns:
        A dict with the keys "extrinsic", "intrinsic", "depth", "depth_conf",
        "images", "pose_enc" (tensors converted to NumPy arrays),
        "pose_enc_list" (always None) and "world_points_from_depth".

    Raises:
        ValueError: If CUDA is unavailable or the images folder is empty.
    """
    print(f"Processing images from {target_dir}")

    # Inference is GPU-only; fail early with a clear message otherwise.
    if not torch.cuda.is_available():
        raise ValueError("CUDA is not available. Check your environment.")
    device = "cuda"

    model = model.to(device)
    model.eval()

    # Collect the frames and thin them out so the hosted demo stays within its
    # budget: keep every `stride`-th file once there are `max_images` or more.
    # The stride is clamped to 1 because for fewer than `max_images` files the
    # integer division is 0, and a zero slice step raises ValueError.
    max_images = 100
    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    stride = max(1, len(image_names) // max_images)
    image_names = image_names[::stride]
    print(f"Found {len(image_names)} images")
    if not image_names:
        raise ValueError("No images found. Check your upload.")

    images = load_and_preprocess_images(image_names).to(device)
    print(f"Preprocessed images shape: {images.shape}")

    # Anchor image selection: at most `anchor_size` frames chosen by uniform_sample.
    select_indices = uniform_sample(
        len(image_names), min(anchor_size, len(image_names))
    )
    anchor_images = images[select_indices]

    print("Running inference...")
    # bfloat16 needs compute capability >= 8; fall back to float16 otherwise.
    dtype = (
        torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    )
    chunk_size = 15  # frames relocalized per reloc() call
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            print("Processing anchor images ...")
            model.tmp_forward(anchor_images)

            # Relocalize all frames chunk by chunk; each result goes to the
            # CPU right away.
            predictions_s = []
            with tqdm(total=len(image_names), desc="Relocalizing") as pbar:
                for img_split in images.split(chunk_size, dim=0):
                    predictions_s += to_cpu(
                        model.reloc(img_split, ret_img=True, memory_save=True)
                    )
                    # Advance by the real chunk length so the bar does not
                    # overshoot on the final, shorter chunk.
                    pbar.update(img_split.shape[0])

            # Output key -> key used in the per-frame prediction dicts.
            # (The model's point-map outputs are intentionally not collected;
            # world points are derived from the depth maps below.)
            key_map = {
                "extrinsic": "extrinsic",
                "intrinsic": "intrinsic",
                "depth": "depth_map",
                "depth_conf": "dpt_cnf",
                "images": "images",
            }
            predictions = {
                out_key: torch.cat([s[src_key] for s in predictions_s], dim=0)
                for out_key, src_key in key_map.items()
            }

            # The encoder is called with a leading batch axis, which is
            # stripped again from its result.
            predictions["pose_enc"] = extri_intri_to_pose_encoding(
                predictions["extrinsic"].unsqueeze(0),
                predictions["intrinsic"].unsqueeze(0),
                images.shape[-2:],
            )[0]
            del predictions_s

    # Convert every tensor to a NumPy array.
    for key, value in list(predictions.items()):
        if isinstance(value, torch.Tensor):
            predictions[key] = value.cpu().numpy()
    predictions["pose_enc_list"] = None  # placeholder entry, no per-step list is produced

    # Generate world points from depth map
    print("Computing world points from depth map...")
    depth_map = predictions["depth"]
    world_points = unproject_depth_map_to_point_map(
        depth_map, predictions["extrinsic"], predictions["intrinsic"]
    )
    predictions["world_points_from_depth"] = world_points

    # Clean up
    torch.cuda.empty_cache()
    return predictions
# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def handle_uploads(input_video, input_images):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images or extracted frames from video into it. Return (target_dir, image_paths).

    Args:
        input_video: Path to a video file, a dict with a "name" entry holding
            that path, or None.
        input_images: Iterable of image paths (or dicts with a "name" entry),
            or None.

    Returns:
        (target_dir, image_paths): the freshly created directory and the
        sorted list of files written to ``target_dir/images``.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Create a unique folder name (microsecond-resolution timestamp).
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images)  # also creates target_dir itself

    image_paths = []

    # --- Handle images ---
    for file_data in input_images or []:
        if isinstance(file_data, dict) and "name" in file_data:
            file_path = file_data["name"]
        else:
            file_path = file_data
        dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
        shutil.copy(file_path, dst_path)
        image_paths.append(dst_path)

    # --- Handle video ---
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        try:
            # Every decoded frame is written out as a zero-padded PNG.
            video_frame_num = 0
            while True:
                gotit, frame = vs.read()
                if not gotit:
                    break
                image_path = os.path.join(
                    target_dir_images, f"{video_frame_num:06}.png"
                )
                cv2.imwrite(image_path, frame)
                image_paths.append(image_path)
                video_frame_num += 1
        finally:
            # Release the capture handle even if decoding or writing fails.
            vs.release()

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(
        f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths
# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(input_video, input_images):
    """
    Stage freshly uploaded files and refresh the preview.

    Returns a 4-tuple; the caller binds it to the 3D viewer, the hidden
    target-dir textbox, the gallery and the log line, in that order:
    (None, target_dir, image_paths, status message). When neither a video nor
    images are present, all four entries are None.
    """
    nothing_uploaded = not (input_video or input_images)
    if nothing_uploaded:
        return None, None, None, None

    staged_dir, staged_paths = handle_uploads(input_video, input_images)
    status = "Upload complete. Click 'Reconstruct' to begin 3D processing."
    return None, staged_dir, staged_paths, status
# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
def gradio_demo(
    target_dir,
    conf_thres=3.0,
    frame_filter="All",
    mask_black_bg=False,
    mask_white_bg=False,
    show_cam=True,
    mask_sky=False,
    downsample_ratio=100.0,
    prediction_mode="Pointmap Regression",
):
    """
    Perform reconstruction using the already-created target_dir/images.

    Args:
        target_dir: Directory produced by handle_uploads (may be None, "" or
            the literal string "None" when nothing has been uploaded).
        conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam,
        mask_sky, prediction_mode: Visualization options forwarded to
            predictions_to_glb.
        downsample_ratio: Percentage; it is divided by 100 before being
            forwarded to predictions_to_glb.

    Returns:
        A 3-tuple (glb_path, log_message, frame_filter_dropdown) matching the
        three output components this handler is bound to. When there is no
        valid target directory the result is (None, error_message, None).
    """
    # Check for a falsy value first: os.path.isdir(None) raises TypeError.
    # The early result has three entries, like the success path, so it lines
    # up with the handler's three outputs.
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No valid target directory found. Please upload first.", None

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Prepare frame_filter dropdown: one "<index>: <filename>" entry per file.
    target_dir_images = os.path.join(target_dir, "images")
    image_files = (
        sorted(os.listdir(target_dir_images))
        if os.path.isdir(target_dir_images)
        else []
    )
    all_files = [f"{i}: {filename}" for i, filename in enumerate(image_files)]
    frame_filter_choices = ["All"] + all_files

    print("Running run_model...")
    with torch.no_grad():
        predictions = run_model(target_dir, model)
    for k, v in predictions.items():
        print(f"{k}: {v.shape if isinstance(v, np.ndarray) else type(v)}")

    # Save predictions so update_visualization can re-render without inference.
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    np.savez(prediction_save_path, **predictions)

    # Handle None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # Build a GLB file name. It encodes every visualization parameter,
    # including the downsample ratio, using the same scheme as
    # update_visualization so that function can reuse this export.
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_dr{downsample_ratio}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    # Convert predictions to GLB
    glbscene = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
        show_cam=show_cam,
        mask_sky=mask_sky,
        target_dir=target_dir,
        downsample_ratio=downsample_ratio / 100.0,
        prediction_mode=prediction_mode,
    )
    glbscene.export(file_obj=glbfile)

    # Cleanup
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
    log_msg = (
        f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
    )

    return (
        glbfile,
        log_msg,
        gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
    )
# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def clear_fields():
    """
    Reset hook run at the start of the "Reconstruct" click chain.

    Yields a single None, which the click chain routes to the 3D viewer to
    blank it before a new reconstruction starts.
    """
    viewer_value = None
    return viewer_value
def update_log():
    """
    Return the interim status text shown while the reconstruction is running.
    """
    status = "Loading and Reconstructing..."
    return status
def update_visualization(
    target_dir,
    conf_thres,
    frame_filter,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    downsample_ratio,
    prediction_mode,
    is_example,
):
    """
    Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
    and return it for the 3D viewer. If is_example == "True", skip.

    Args:
        target_dir: Directory holding ``predictions.npz`` (may be None, "" or
            the literal string "None" when nothing has been reconstructed).
        conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam,
        mask_sky, prediction_mode: Visualization options forwarded to
            predictions_to_glb. A ``frame_filter`` of None is treated as "All".
        downsample_ratio: Percentage; divided by 100 before being forwarded.
        is_example: The string "True" suppresses the update.

    Returns:
        (glb_path, log_message); glb_path is None when nothing can be shown.
    """
    no_reconstruction = (
        None,
        "No reconstruction available. Please click the Reconstruct button first.",
    )

    # If it's an example click, skip as requested
    if is_example == "True":
        return no_reconstruction

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return no_reconstruction

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return (
            None,
            f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
        )

    # A cleared dropdown yields None; fall back to "All" (as gradio_demo does)
    # instead of failing on None.replace(...) below.
    if frame_filter is None:
        frame_filter = "All"

    # The file name encodes every visualization parameter, so an existing file
    # can be reused as-is.
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_dr{downsample_ratio}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    if not os.path.exists(glbfile):
        # Only read the arrays when a new GLB actually has to be built.
        key_list = [
            "pose_enc",
            "depth",
            "depth_conf",
            # "world_points",
            # "world_points_conf",
            "images",
            "extrinsic",
            "intrinsic",
            "world_points_from_depth",
        ]
        # The context manager closes the underlying archive once the arrays
        # have been copied out.
        with np.load(predictions_path) as loaded:
            predictions = {
                key: np.array(loaded[key]) for key in key_list if key in loaded
            }

        glbscene = predictions_to_glb(
            predictions,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            show_cam=show_cam,
            mask_sky=mask_sky,
            target_dir=target_dir,
            downsample_ratio=downsample_ratio * 1.0 / 100.0,
            prediction_mode=prediction_mode,
        )
        glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization"
# -------------------------------------------------------------------------
# Example images
# -------------------------------------------------------------------------
# Paths of the bundled example videos, all located in one folder.
_EXAMPLE_VIDEO_DIR = "examples/videos"

great_wall_video = f"{_EXAMPLE_VIDEO_DIR}/great_wall.mp4"
colosseum_video = f"{_EXAMPLE_VIDEO_DIR}/Colosseum.mp4"
room_video = f"{_EXAMPLE_VIDEO_DIR}/room.mp4"
kitchen_video = f"{_EXAMPLE_VIDEO_DIR}/kitchen.mp4"
fern_video = f"{_EXAMPLE_VIDEO_DIR}/fern.mp4"
single_cartoon_video = f"{_EXAMPLE_VIDEO_DIR}/single_cartoon.mp4"
single_oil_painting_video = f"{_EXAMPLE_VIDEO_DIR}/single_oil_painting.mp4"
pyramid_video = f"{_EXAMPLE_VIDEO_DIR}/pyramid.mp4"
# -------------------------------------------------------------------------
# 6) Build Gradio UI
# -------------------------------------------------------------------------
# Ocean theme with the selected-checkbox label recoloured to match the
# primary button (background and text).
theme = gr.themes.Ocean().set(
    checkbox_label_background_fill_selected="*button_primary_background_fill",
    checkbox_label_text_color_selected="*button_primary_text_color",
)
# Top-level Blocks container. The inline CSS styles three things used further
# down in this file: the log line (.custom-log), the examples hint
# (.example-log) and the prediction-mode radio (#my_radio).
with gr.Blocks(
theme=theme,
css="""
.custom-log * {
font-style: italic;
font-size: 22px !important;
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
-webkit-background-clip: text;
background-clip: text;
font-weight: bold !important;
color: transparent !important;
text-align: center !important;
}
.example-log * {
font-style: italic;
font-size: 16px !important;
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
-webkit-background-clip: text;
background-clip: text;
color: transparent !important;
}
#my_radio .wrap {
display: flex;
flex-wrap: nowrap;
justify-content: center;
align-items: center;
}
#my_radio .wrap label {
display: flex;
width: 50%;
justify-content: center;
align-items: center;
margin: 0;
padding: 10px 0;
box-sizing: border-box;
}
""",
) as demo:
# Instead of gr.State, we use a hidden Textbox:
# - is_example is read by update_visualization (the string "True" suppresses
#   the update) and is set to "False" after a reconstruction has run.
# - num_images only serves as an input slot for the gr.Examples table.
is_example = gr.Textbox(label="is_example", visible=False, value="None")
num_images = gr.Textbox(label="num_images", visible=False, value="None")
# Static introduction and usage instructions rendered at the top of the page.
# NOTE(review): the text says videos are split at one frame per second, but
# handle_uploads in this file writes every decoded frame; it also mentions a
# "Pointmap Branch" mode while the prediction_mode radio offers only
# "Depthmap and Camera Branch". Confirm and update the wording.
gr.HTML(
"""
<h1>🏛️ SAIL-Recon: Large SfM by Augmenting Scene Regression with Localization</h1>
<p>
<a href="https://github.com/HKUST-SAIL/sail-recon">🐙 GitHub Repository</a> |
<a href="https://hkust-sail.github.io/sail-recon/">Project Page</a>
</p>
<div style="font-size: 16px; line-height: 1.5;">
<p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. SAIL-Recon takes these images and generates a 3D point cloud, along with estimated camera poses.</p>
<h3>Getting Started:</h3>
<ol>
<li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
<li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
<li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
<li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note that the visualization of 3D points may be slow for a large number of input images.</li>
<li>
<strong>Adjust Visualization (Optional):</strong>
After reconstruction, you can fine-tune the visualization using the options below
<details style="display:inline;">
<summary style="display:inline;">(<strong>click to expand</strong>):</summary>
<ul>
<li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
<li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
<li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
<li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
<li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
</ul>
</details>
</li>
</ol>
<p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">SAIL-Recon typically reconstructs a scene at 5FPS with full 3D attributes. However, visualizing 3D points may take tens of seconds due to third-party rendering, which is independent of SAIL-Recon's processing time. Using the 'demo.py' can provide much faster processing.</span></p>
<h3><strong style="color: #0ea5e9;">NOTE:</strong> <span style="color: #0ea5e9; font-weight: bold;">due to the limitation of zerogpu provided by Huggingface, this web application can only process 50-200 images depending on the dynamic GPU resources. Please run `demo_gradio.py` LOCALLY if you want to process the video input.</span></h3>
</div>
"""
)
# Hidden textbox carrying the current upload directory between callbacks;
# its initial value is the literal string "None".
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
with gr.Row():
# Left column: inputs (video / image files) and the preview gallery.
with gr.Column(scale=2):
input_video = gr.Video(label="Upload Video", interactive=True)
input_images = gr.File(
file_count="multiple", label="Upload Images", interactive=True
)
image_gallery = gr.Gallery(
label="Preview",
columns=4,
height="300px",
show_download_button=True,
object_fit="contain",
preview=True,
)
# Right column: status line, 3D viewer, action buttons and visualization options.
with gr.Column(scale=4):
with gr.Column():
gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
log_output = gr.Markdown(
"Please upload a video or images, then click Reconstruct.",
elem_classes=["custom-log"],
)
reconstruction_output = gr.Model3D(
height=520, zoom_speed=0.5, pan_speed=0.5
)
with gr.Row():
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
# The clear button resets the inputs, the viewer, the log line, the hidden
# target dir and the gallery.
clear_btn = gr.ClearButton(
[
input_video,
input_images,
reconstruction_output,
log_output,
target_dir_output,
image_gallery,
],
scale=1,
)
with gr.Row():
# Only one prediction mode is offered here.
prediction_mode = gr.Radio(
["Depthmap and Camera Branch"],
label="Select a Prediction Mode",
value="Depthmap and Camera Branch",
scale=1,
elem_id="my_radio",
)
with gr.Row():
conf_thres = gr.Slider(
minimum=0,
maximum=100,
value=70,
step=0.1,
label="Confidence Threshold (%)",
)
# Percentage value; the callbacks divide it by 100 before passing it on.
downsample_ratio = gr.Slider(
minimum=1.0,
maximum=100,
value=20,
step=0.1,
label="Downsample Ratio(%)",
)
# Choices are replaced with the per-frame entries returned by gradio_demo.
frame_filter = gr.Dropdown(
choices=["All"], value="All", label="Show Points from Frame"
)
with gr.Column():
show_cam = gr.Checkbox(label="Show Camera", value=True)
mask_sky = gr.Checkbox(label="Filter Sky", value=False)
mask_black_bg = gr.Checkbox(
label="Filter Black Background", value=False
)
mask_white_bg = gr.Checkbox(
label="Filter White Background", value=False
)
# ---------------------- Examples section ----------------------
# Every row differs only in its video path and its frame-count string; the
# remaining columns (images, confidence, mask/camera/sky flags, downsample
# ratio, prediction mode, is_example marker) are identical for all rows.
examples = [
    [
        video_path,
        frame_count,
        None,
        "5.0",
        False,
        False,
        True,
        False,
        50.0,
        "Depthmap and Camera Branch",
        "True",
    ]
    for video_path, frame_count in (
        ("samples/kitchen.mp4", "25"),
        ("samples/barn_images.mp4", "410"),
        ("samples/rgbd_dataset_freiburg1_360.mp4", "410"),
    )
]
def example_pipeline(
    input_video: str,
    num_images_str: str|None,
    input_images: str|None,
    conf_thres: str,
    mask_black_bg: bool,
    mask_white_bg: bool,
    show_cam: bool,
    mask_sky: bool,
    downsample_ratio: float,
    prediction_mode: str,
    is_example_str: str,
):
    """
    Stage an example's video/images into a fresh target directory.

    Only the staging step is active: the follow-up reconstruction (a
    gradio_demo call with frame_filter fixed to "All") and the 5-value return
    (model3D, log, new dir, dropdown, gallery) are currently disabled, so this
    function returns None. The visualization parameters are accepted but not
    used.
    """
    handle_uploads(input_video, input_images)
gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])

# Only the first three example columns are bound to input components; the
# visualization controls are deliberately left out of the binding.
example_inputs = [input_video, num_images, input_images]
example_outputs = [
    reconstruction_output,
    log_output,
    target_dir_output,
    frame_filter,
    image_gallery,
]
gr.Examples(
    examples=examples,
    inputs=example_inputs,
    outputs=example_outputs,
    fn=example_pipeline,
    cache_examples=False,
    examples_per_page=50,
)
# -------------------------------------------------------------------------
# "Reconstruct" button: four steps executed one after another
#   1. blank the 3D viewer
#   2. show the interim log message
#   3. run gradio_demo(...) on the existing target_dir
#   4. set is_example to "False"
# -------------------------------------------------------------------------
reconstruct_inputs = [
    target_dir_output,
    conf_thres,
    frame_filter,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    downsample_ratio,
    prediction_mode,
]
viewer_cleared = submit_btn.click(
    fn=clear_fields, inputs=[], outputs=[reconstruction_output]
)
log_updated = viewer_cleared.then(fn=update_log, inputs=[], outputs=[log_output])
reconstructed = log_updated.then(
    fn=gradio_demo,
    inputs=reconstruct_inputs,
    outputs=[reconstruction_output, log_output, frame_filter],
)
reconstructed.then(fn=lambda: "False", inputs=[], outputs=[is_example])
# -------------------------------------------------------------------------
# Real-time Visualization Updates
# Every visualization control re-renders through update_visualization with
# the same inputs and outputs, so the handlers are registered in one loop
# (same registration order as the controls listed here).
# -------------------------------------------------------------------------
visualization_controls = (
    conf_thres,
    downsample_ratio,
    frame_filter,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    prediction_mode,
)
for control in visualization_controls:
    control.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            downsample_ratio,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
# -------------------------------------------------------------------------
# Auto-update gallery whenever user uploads or changes their files:
# both upload widgets share one handler with identical inputs and outputs.
# -------------------------------------------------------------------------
for upload_widget in (input_video, input_images):
    upload_widget.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )
# Enable the request queue (at most 20 waiting requests), then start the app.
queued_demo = demo.queue(max_size=20)
queued_demo.launch(show_error=True, share=True, ssr_mode=False)