LaRI / app.py
ruili3's picture
update text
d4d16aa
raw
history blame contribute delete
9.88 kB
import argparse
import gradio
import torch
import torch.backends.cudnn as cudnn
from src.utils.vis import prob_to_mask
from src.lari.model import LaRIModel, DinoSegModel
from tools import load_model, process_image, post_process_output, get_masked_depth, save_to_glb, get_point_cloud, removebg_crop
from huggingface_hub import hf_hub_download
parser = argparse.ArgumentParser("Arguments for deploying a LaRI Demo")
parser.add_argument(
"--model_info_pm",
type=str,
default="LaRIModel(use_pretrained = 'moge_full', num_output_layer = 5, head_type = 'point')",
help="Network parameters to load the model",
)
parser.add_argument(
"--model_info_mask",
type=str,
default="DinoSegModel(use_pretrained = 'dinov2', dim_proj = 256, pretrained_path = '', num_output_layer = 4, output_type = 'ray_stop')",
help="Network parameters to load the model",
)
parser.add_argument(
"--ckpt_path_pm",
type=str,
default="lari_obj_16k_pointmap.pth",
help="Path to pre-trained weights",
)
parser.add_argument(
"--ckpt_path_mask",
type=str,
default="lari_obj_16k_seg.pth",
help="Path to pre-trained weights",
)
parser.add_argument(
"--resolution", type=int, default=512, help="Default model resolution"
)
args = parser.parse_args()
def model_forward(pil_input, layered_id, rembg_checkbox):
"""
Perform LaRI estimation by:
1. image processing
2. network forward
3. save masked layered depth image
4. save point cloud
"""
if pil_input is None:
return (None, None, None, None, None, None)
if rembg_checkbox:
pil_input = removebg_crop(pil_input)
# Process the input image.
input_tensor, ori_img_tensor, crop_coords, original_size = process_image(
pil_input, resolution=512
)
input_tensor = input_tensor.to(device)
# Run inference.
with torch.no_grad():
# lari map
pred_dict = model_pm(input_tensor)
lari_map = -pred_dict["pts3d"].squeeze(
0
) # Expected output shape: (H_reso, W_reso, L, 3)
# mask
if model_mask:
pred_dict = model_mask(input_tensor)
assert "seg_prob" in pred_dict
valid_mask = prob_to_mask(pred_dict["seg_prob"].squeeze(0)) # H W L 1
else:
h, w, l, _ = lari_map.shape
valid_mask = torch.new_ones((h, w, l, 1), device=lari_map.device)
# crop & resize the output to the original resolution.
if original_size[0] != args.resolution or original_size[1] != args.resolution:
lari_map = post_process_output(lari_map, crop_coords, original_size) # H W L 3
valid_mask = post_process_output(
valid_mask.float(), crop_coords, original_size
).bool() # H W L 1
max_n_layer = min(valid_mask.shape[-2], lari_map.shape[-2])
valid_mask = valid_mask[:, :, :max_n_layer, :]
lari_map = lari_map[:, :, :max_n_layer, :]
curr_layer_id = min(max_n_layer - 1, layered_id - 1)
# masked depth list
depth_image = get_masked_depth(
lari_map=lari_map, valid_mask=valid_mask, layer_id=curr_layer_id
)
# point cloud
glb_path, ply_path = get_point_cloud(
lari_map, ori_img_tensor, valid_mask, first_layer_color="pseudo"
)
return (
depth_image,
glb_path,
lari_map,
valid_mask,
0,
max_n_layer - 1,
glb_path,
ply_path,
pil_input,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True
# Download the file
model_path_pm = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_pm, repo_type="model")
model_path_mask = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_mask, repo_type="model")
# Load the model with pretrained weights.
model_pm = load_model(args.model_info_pm, model_path_pm, device)
model_mask = (
load_model(args.model_info_mask, model_path_mask, device)
if args.model_info_mask is not None
else None
)
def change_layer(slider_layer_id, lari_map, valid_mask, min_layer_id, max_layer_id):
if lari_map is None:
return
slider_layer_id = slider_layer_id - 1
curr_layer_id = min(slider_layer_id, max_layer_id)
curr_layer_id = max(curr_layer_id, min_layer_id)
# masked depth list
depth_image = get_masked_depth(
lari_map=lari_map, valid_mask=valid_mask, layer_id=curr_layer_id
)
return depth_image
def clear_everything():
return (
gradio.update(value=None),
gradio.update(value=None),
gradio.update(value=None),
gradio.update(value=None),
gradio.update(value=None),
gradio.update(value=None),
gradio.update(value=None),
)
with gradio.Blocks(
css=""".gradio-container {margin: 0 !important; min-width: 100%};""",
title="LaRI Demo",
) as demo:
gradio.HTML(
"""
<h1 style="text-align: center; font-size: 28px; font-weight: bold; margin-bottom: 1em;">
LaRI: Layered Ray Intersections for Single-view 3D Geometric Reasoning
</h1>
"""
)
gradio.HTML(
"""
<p style="font-size: 16px; line-height: 1.6;">
This is the official demo of Layered Ray Intersections
(<a href="https://ruili3.github.io/lari/index.html" target="_blank" style="color: #42aaf5;">LaRI</a>).
This demo currently supports object-level reconstruction only.
</p>
<h3 style="color: #42aaf5;">Get Started</h3>
<p style="font-size: 16px; line-height: 1.6;">
As a quick start, click one image from `Examples` and press 'Process'. Try it out with your own images by following these steps:
<ul>
<li>Load an image</li>
<li>(Optional) Check the 'Remove Background' box</li>
<li>Click the 'Process' button</li>
<li>Explore the layered depth maps (z-channel of the LaRI point map) by adjusting the 'Layer ID' slider</li>
</ul>
</p>
<p style="font-size: 16px; line-height: 1.6;">
In the '3D Point Cloud' view, different colors represent different intersection layers:
<span style="color: #FFBD1C;">Layer 1</span>,
<span style="color: #FB5607;">Layer 2</span>,
<span style="color: #F15BB5;">Layer 3</span>,
<span style="color: #8338EC;">Layer 4</span>.
</p>
<h3 style="color: #42aaf5;">Contact</h3>
<p style="font-size: 16px; line-height: 1.6;">
If you have any questions, feel free to open an issue on our
<a href="https://github.com/ruili3/lari" target="_blank" style="color: #42aaf5;">GitHub repository</a> ⭐
</p>
"""
)
# , <b style="color: #3A86FF;">layer 5</b>.
lari_map = gradio.State(None)
valid_mask = gradio.State(None)
min_layer_id = gradio.State(None)
max_layer_id = gradio.State(None)
with gradio.Column():
with gradio.Row(equal_height=True):
with gradio.Column(scale=1):
image_input = gradio.Image(
label="Upload an Image", type="pil", height=350
)
with gradio.Row():
rembg_checkbox = gradio.Checkbox(label="Remove background")
clear_button = gradio.Button("Clear")
submit_btn = gradio.Button("Process")
with gradio.Column(scale=1):
depth_output = gradio.Image(
label="LaRI Map at Z-axis (depth)",
type="pil",
interactive=False,
height=300,
)
slider_layer_id = gradio.Slider(
minimum=1,
maximum=4,
step=1,
value=1,
label="Layer ID",
interactive=True,
)
with gradio.Row(scale=1):
outmodel = gradio.Model3D(
label="3D Point Cloud (Color denotes different layers)",
interactive=False,
zoom_speed=0.5,
pan_speed=0.5,
height=450,
)
with gradio.Row():
ply_file_output = gradio.File(label="ply output", elem_classes="small-file")
glb_file_output = gradio.File(label="glb output", elem_classes="small-file")
submit_btn.click(
fn=model_forward,
inputs=[image_input, slider_layer_id, rembg_checkbox],
outputs=[
depth_output,
outmodel,
lari_map,
valid_mask,
min_layer_id,
max_layer_id,
glb_file_output,
ply_file_output,
image_input,
],
)
clear_button.click(
fn=clear_everything,
outputs=[
lari_map,
valid_mask,
min_layer_id,
max_layer_id,
image_input,
depth_output,
outmodel,
],
)
slider_layer_id.change(
fn=change_layer,
inputs=[slider_layer_id, lari_map, valid_mask, min_layer_id, max_layer_id],
outputs=depth_output,
)
gradio.Examples(examples=["./assets/cole_hardware.png",
"./assets/3m_tape.png",
"./assets/horse.png",
"./assets/rhino.png",
"./assets/alphabet.png",
"./assets/martin_wedge.png",
"./assets/d_rose.png",
"./assets/ace.png",
"./assets/bifidus.png",
"./assets/fem.png",
],
inputs=image_input)
demo.launch(share=False)