"""Helpers for image pre/post-processing, layered point-cloud export and background removal."""

import argparse
import os
import torch
import numpy as np
import trimesh
from scipy.spatial.transform import Rotation
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from PIL import Image
from src.utils.vis import (
    prob_to_mask,
    colorize,
    denormalize,
)
from src.lari.model import LaRIModel, DinoSegModel
from rembg import remove
from plyfile import PlyData, PlyElement
import torchvision.transforms as transforms


# Per-layer pseudo colors (RGB, 0-255) used to tint the layers of a point cloud.
LAYER_COLOR = [
    [255, 190, 11],  # FFBE0B
    [251, 86, 7],  # FB5607
    [241, 91, 181],  # F15BB5
    [131, 56, 236],  # 8338EC
    [58, 134, 255],  # 3A86FF
]

# Homogeneous transform that negates the Y and Z axes.
OPENGL = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])


def save_point_cloud(pcd, rgb, filename, binary=True):
    """Save an RGB point cloud as a PLY file.

    Args:
        pcd: Nx3 array, the XYZ coordinates of each point.
        rgb: Nx3 array, the RGB color of each point, or ``None`` to color
            every point (128, 128, 128).
        filename: Destination path of the PLY file.
        binary: When exactly ``True`` the vertices are written through
            ``plyfile``; any other value writes an ASCII PLY via
            ``np.savetxt``.

    Raises:
        AssertionError: If ``pcd`` and ``rgb`` have a different number of rows.
    """
    pcd = np.asarray(pcd)
    n_points = pcd.shape[0]
    if rgb is None:
        rgb = np.full((n_points, 3), 128, dtype=np.uint8)
    else:
        rgb = np.asarray(rgb)
        assert pcd.shape[0] == rgb.shape[0]

    # The identity check is kept on purpose so that callers passing a
    # non-bool value keep getting the ASCII writer, as before.
    if binary is True:
        npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
                     ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
        # Fill the structured array column by column instead of converting
        # every point to a Python tuple first.
        vertices_array = np.empty(n_points, dtype=npy_types)
        for col, field in enumerate(('x', 'y', 'z')):
            vertices_array[field] = pcd[:, col]
        for col, field in enumerate(('red', 'green', 'blue')):
            vertices_array[field] = rgb[:, col]

        el = PlyElement.describe(vertices_array, 'vertex')
        PlyData([el]).write(filename)
    else:
        ply_head = (
            'ply\n'
            'format ascii 1.0\n'
            'element vertex %d\n'
            'property float x\n'
            'property float y\n'
            'property float z\n'
            'property uchar red\n'
            'property uchar green\n'
            'property uchar blue\n'
            'end_header'
        ) % n_points

        # np.column_stack must be *called*; the previous subscript form
        # (`np.column_stack[...]`) raised a TypeError on every ASCII save.
        np.savetxt(
            filename,
            np.column_stack((pcd, rgb)),
            fmt='%f %f %f %d %d %d',
            header=ply_head,
            comments='',
        )


def load_model(model_info, ckpt_path, device):
    """Build a model from an expression string and load checkpoint weights.

    Args:
        model_info: Python expression that evaluates to the model instance
            (evaluated in this module's namespace).
        ckpt_path: Path of the checkpoint passed to ``torch.load``.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The model in eval mode with the checkpoint weights loaded
        (non-strictly, so missing/unexpected keys are tolerated).
    """
    # SECURITY: eval() executes arbitrary code. Only pass `model_info` strings
    # that come from trusted configuration, never from user input.
    model = eval(model_info)
    model.to(device)
    model.eval()

    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    # Some checkpoints wrap the state dict under a "model" key.
    state_dict = ckpt["model"] if "model" in ckpt else ckpt
    model.load_state_dict(state_dict, strict=False)
    return model


def process_image_custom(pil_image, resolution=512):
    """Resize the long side to `resolution` and pad the short side with black.

    The image is converted to RGB, resized so that its long side equals
    `resolution` (aspect ratio preserved), then centered on a black
    (resolution x resolution) canvas.

    Returns:
        padded_img (PIL.Image): The processed square image.
        crop_coords (tuple): (top, left, bottom, right) coordinates of the
            valid (non-padded) region inside `padded_img`.
        original_size (tuple): (width, height) of the original image.
        pil_image (PIL.Image): The RGB-converted image at its original size.
    """
    pil_image = pil_image.convert("RGB")
    original_width, original_height = pil_image.size
    original_size = (original_width, original_height)

    # If already at fixed resolution, no processing is needed.
    if original_width == resolution and original_height == resolution:
        return pil_image, (0, 0, resolution, resolution), original_size, pil_image

    # Scale so that the long side becomes `resolution`.
    scale = resolution / float(max(original_width, original_height))
    if original_width >= original_height:
        new_width = resolution
        # Clamp to 1 px: very elongated images would otherwise round to 0
        # and make the resize fail.
        new_height = max(1, int(round(original_height * scale)))
    else:
        new_height = resolution
        new_width = max(1, int(round(original_width * scale)))
    resized_img = pil_image.resize((new_width, new_height), Image.BILINEAR)

    # Center the resized image; the padding is zero along the long side.
    pad_left = (resolution - new_width) // 2
    pad_top = (resolution - new_height) // 2

    # Create new image filled with black.
    padded_img = Image.new("RGB", (resolution, resolution), (0, 0, 0))
    padded_img.paste(resized_img, (pad_left, pad_top))

    # The valid region (crop) is where the resized image was pasted.
    crop_coords = (pad_top, pad_left, pad_top + new_height, pad_left + new_width)
    return padded_img, crop_coords, original_size, pil_image


def process_image(pil_image, resolution=512):
    """Apply the custom resize/pad, then convert to normalized tensors.

    Returns:
        img_tensor (torch.Tensor): Normalized padded image with shape
            (1, 3, resolution, resolution).
        ori_img_tensor (torch.Tensor): Normalized image at its original size,
            with a leading batch dimension of 1.
        crop_coords (tuple): (top, left, bottom, right) coordinates of the
            valid region.
        original_size (tuple): (width, height) of the original image.
    """
    padded_img, crop_coords, original_size, ori_img = process_image_custom(
        pil_image, resolution
    )

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    img_tensor = transform(padded_img).unsqueeze(0)
    ori_img_tensor = transform(ori_img).unsqueeze(0)
    return img_tensor, ori_img_tensor, crop_coords, original_size


def post_process_output(input_tensor, crop_coords, original_size):
    """Crop the input tensor with `crop_coords`, then resize to the original size.

    Args:
        input_tensor (torch.Tensor): Input with shape (H, W, L, C) where C is 1 or 3.
            NOTE(review): presumably must be a dtype `F.interpolate` supports
            (floating point) - confirm against callers.
        crop_coords (tuple): (top, left, bottom, right) coordinates for cropping.
        original_size (tuple): (width, height) of the original image.

    Returns:
        processed_output (torch.Tensor): Output with shape
            (original_height, original_width, L, C).
    """
    top, left, bottom, right = crop_coords

    # Crop the input spatially: resulting shape (crop_h, crop_w, L, C).
    cropped = input_tensor[top:bottom, left:right, ...]
    crop_h, crop_w, L, C = cropped.shape

    # Fold layers and channels together: (1, L * C, crop_h, crop_w).
    reshaped = cropped.permute(2, 3, 0, 1).reshape(1, L * C, crop_h, crop_w)

    # Unpack the original size (width, height).
    new_width, new_height = original_size
    mode = "nearest" if L == 1 else "bilinear"
    # `align_corners` is only accepted by the interpolating modes; passing it
    # together with "nearest" makes F.interpolate raise a ValueError.
    interp_kwargs = {"align_corners": False} if mode == "bilinear" else {}
    resized = F.interpolate(
        reshaped, size=(new_height, new_width), mode=mode, **interp_kwargs
    )

    resized = resized.reshape(L, C, new_height, new_width)
    # Permute to the output shape: (new_height, new_width, L, C).
    processed_output = resized.permute(2, 3, 0, 1)
    return processed_output


def get_masked_depth(lari_map, valid_mask, layer_id):
    """Render one layer of a layered map as a colorized depth image.

    Args:
        lari_map (torch.Tensor): Indexed as [H, W, layer, channel]; channel 2
            of the selected layer is used as the depth value.
        valid_mask (torch.Tensor): Indexed as [H, W, layer, 0]; used as a
            boolean index, so it is assumed to be a bool tensor - TODO confirm.
        layer_id (int): Layer to render; negative values are clamped to 0.

    Returns:
        PIL.Image: RGB image produced by `colorize` with the "Spectral"
        colormap, using the min/max of the valid depths as the value range
        (or [0, 1] when no pixel is valid).
    """
    layer_id = max(0, layer_id)
    lari_depth = lari_map[:, :, layer_id, 2].cpu().numpy()  # H W
    valid_mask = valid_mask[:, :, layer_id, 0].cpu().numpy()  # H W
    valid_values = lari_depth[valid_mask]

    # Handle empty valid values: min()/max() are undefined on an empty array.
    if valid_values.size == 0:
        vis_depth_range = [0, 1]
    else:
        vis_depth_range = [valid_values.min(), valid_values.max()]

    depth_image = Image.fromarray(
        colorize(
            lari_depth,
            vis_depth_range[0],
            vis_depth_range[1],
            invalid_mask=~valid_mask,
            cmap="Spectral",
        )
    ).convert("RGB")
    return depth_image


def save_to_glb(pts3d, color3d, path):
    """Export a colored point cloud as `res.glb` inside the folder `path`.

    Args:
        pts3d: Point coordinates passed to `trimesh.PointCloud`.
        color3d: Per-point colors passed to `trimesh.PointCloud`.
        path: Existing folder in which `res.glb` is written.

    Returns:
        str: Full path of the written file.
    """
    scene = trimesh.Scene()
    pct = trimesh.PointCloud(pts3d, colors=color3d)
    scene.add_geometry(pct)

    # Apply the inverse of (OPENGL @ 180-degree rotation about Y) to the scene.
    rot_y = np.eye(4)
    rot_y[:3, :3] = Rotation.from_euler("y", np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(OPENGL @ rot_y))

    outfile = os.path.join(path, "res.glb")
    scene.export(file_obj=outfile)
    return outfile


def get_point_cloud(pred, img, mask, first_layer_color="image", target_folder=None):
    """Color the valid points of a layered point map and save them as PLY and GLB.

    Args:
        pred (torch.Tensor): h w l 3 - the layered point cloud.
        img (torch.Tensor): 3 h w (optionally with a leading batch dimension
            of 1) - the normalized color image; it is passed through
            `denormalize` and clipped to [0, 1].
        mask (torch.Tensor): h w l - indicating the valid points of each layer.
        first_layer_color (str): "image" colors the first layer with the image
            pixels; "pseudo" uses the `LAYER_COLOR` palette for every layer.
        target_folder (str | None): Output folder; defaults to the folder
            containing this file.

    Returns:
        tuple: (glb_path, ply_path) of the written files.
    """
    assert first_layer_color in ["image", "pseudo"]

    ori_shape = pred.shape
    n_layers = ori_shape[2]
    pred = pred.cpu().numpy().reshape(-1, 3)  # M 3

    # One color per layer; cycle through the palette when there are more
    # layers than palette entries.
    color_palette = np.array(
        [LAYER_COLOR[i % len(LAYER_COLOR)] for i in range(n_layers)], dtype=np.uint8
    )  # n_layer 3

    img = torch.clip(denormalize(img).squeeze(0), 0.0, 1.0)
    img = img.permute(1, 2, 0).unsqueeze(2).cpu().numpy()  # H W 1 3
    img = (img * 255.0).astype(np.uint8)

    # np.broadcast_to returns a read-only view; copy it so that the first
    # layer can be overwritten below.
    layered_color = np.broadcast_to(
        color_palette[None, None], (img.shape[0], img.shape[1], n_layers, 3)
    ).copy()  # H W n_layer 3
    if first_layer_color == "image":
        layered_color[:, :, :1, :] = img
    layered_color = layered_color.reshape(-1, 3)

    # [H, W, layers] -> [M]; keep only the valid points and their colors.
    valid = mask.squeeze().reshape(-1).cpu().numpy().astype(bool)
    pred = pred[valid]
    layered_color = layered_color[valid]  # V 3

    save_folder = target_folder if target_folder is not None else os.path.dirname(__file__)
    ply_path = os.path.join(save_folder, "res.ply")
    save_point_cloud(pred, layered_color, filename=ply_path)
    glb_path = save_to_glb(pred, layered_color, save_folder)
    return glb_path, ply_path


def removebg_crop(pil_input):
    """Remove the background and, for small objects, crop around the object.

    The image is passed through `rembg.remove`; the result is indexed as
    having an alpha channel at index 3. Pixels with alpha above 0.8 * 255
    count as foreground. When the foreground covers at most 10% of the image,
    the result is cropped to a square window (1.5x the larger side of the
    foreground bounding box, clipped to the image) centered on the object.

    Returns:
        PIL.Image: The background-removed image, cropped when the object is
        small. When no pixel counts as foreground, it is returned uncropped.
    """
    pil_input = remove(pil_input.convert("RGB"))
    alpha = np.array(pil_input)[:, :, 3]
    foreground = alpha > 0.8 * 255
    n_foreground = np.sum(foreground)

    # Object already fills enough of the frame: keep the image as is.
    if n_foreground > 0.1 * (alpha.shape[0] * alpha.shape[1]):
        return pil_input

    # Nothing passed the alpha threshold: there is no bounding box to crop to
    # (np.min on the empty index array would raise).
    if n_foreground == 0:
        return pil_input

    # Adjust object size to fit the image resolution.
    width, height = pil_input.size
    fg_idx = np.argwhere(foreground)  # rows of (y, x)
    bbox = (
        np.min(fg_idx[:, 1]),
        np.min(fg_idx[:, 0]),
        np.max(fg_idx[:, 1]),
        np.max(fg_idx[:, 0]),
    )
    center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
    size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
    size = int(size * 1.5)
    bbox = (
        max(center[0] - size // 2, 0),
        max(center[1] - size // 2, 0),
        min(center[0] + size // 2, width),
        min(center[1] + size // 2, height),
    )
    pil_input = pil_input.crop(bbox)  # type: ignore
    return pil_input