import trimesh
import torch
import numpy as np
import os
import math
import torchvision
from tqdm import tqdm
import cv2  # Assuming OpenCV is used for image saving
from PIL import Image
import pytorch3d
import random
from PIL import ImageGrab
from torchvision.utils import save_image
from pytorch3d.renderer import (
    PointsRasterizationSettings,
    PointsRenderer,
    PointsRasterizer,
    AlphaCompositor,
    PerspectiveCameras,
)
import imageio
import torch.nn.functional as F
from torchvision.transforms import ToPILImage
import copy
from scipy.interpolate import interp1d
from scipy.interpolate import UnivariateSpline
from scipy.spatial.transform import Rotation as R
from scipy.spatial.transform import Slerp
import sys
sys.path.append('./extern/dust3r')
from dust3r.utils.device import to_numpy
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from torchvision.transforms import CenterCrop, Compose, Resize


def save_video(data, images_path, folder=None):
    """Encode frames as an H.264 video (8 fps, CRF 10) with ``torchvision.io.write_video``.

    Args:
        data: Either an ``np.ndarray`` / ``torch.Tensor`` of frames with values
            in [0, 1], or a list of image file names. Arrays are clamped to
            [0, 1], scaled by 255 and cast to uint8. ``write_video`` expects a
            (T, H, W, C) layout.
        images_path: Output path of the video file.
        folder: Directory that the file names in a list ``data`` are relative
            to. ``None`` means the names are used as given.

    Raises:
        TypeError: If ``data`` is not an array, tensor or list.
    """
    if isinstance(data, np.ndarray):
        # Clamp first: out-of-range floats would otherwise wrap when cast to uint8.
        tensor_data = (torch.from_numpy(data).clamp(0, 1) * 255).to(torch.uint8)
    elif isinstance(data, torch.Tensor):
        tensor_data = (data.detach().cpu().clamp(0, 1) * 255).to(torch.uint8)
    elif isinstance(data, list):
        def _load(path):
            # Context manager so the file handle is closed once pixels are read.
            with Image.open(path) as img:
                return np.array(img)

        paths = [p if folder is None else os.path.join(folder, p) for p in data]
        stacked_images = np.stack([_load(p) for p in paths], axis=0)
        tensor_data = torch.from_numpy(stacked_images).to(torch.uint8)
    else:
        raise TypeError(f'save_video: unsupported data type {type(data)}')
    torchvision.io.write_video(images_path, tensor_data, fps=8, video_codec='h264', options={'crf': '10'})


def get_input_dict(img_tensor, idx, dtype=torch.float32):
    """Build a dust3r-style view dictionary for a single image.

    ``img`` is ``img_tensor`` cast to ``dtype`` and bilinearly resized to
    288x512. Bilinear ``F.interpolate`` needs a 4-D (batch) input.
    ``true_shape`` records that size, and ``img_ori`` keeps the cast but
    un-resized tensor.
    """
    img = img_tensor.to(dtype)
    return {
        'img': F.interpolate(img, size=(288, 512), mode='bilinear', align_corners=False),
        'true_shape': np.array([[288, 512]], dtype=np.int32),
        'idx': idx,
        'instance': str(idx),
        'img_ori': img,
    }


def rotate_theta(c2ws_input, theta, phi, r, device):
    """Rotate camera-to-world poses by ``phi`` degrees about a tilted axis.

    The rotation axis is ``(0, cos(theta), sin(theta))``: the y-axis tilted by
    ``theta`` degrees inside the yOz plane (the image tilt). The poses are
    shifted by ``+r`` along world z, rotated about the origin, and shifted
    back. The net effect is a rotation about the pivot ``(0, 0, -r)``.
    NOTE(review): the original comments suggest ``r`` is passed negative so
    the pivot sits at ``z = |r|`` (e.g. the average scene depth). Confirm
    this with the callers.

    Args:
        c2ws_input: (N, 4, 4) camera-to-world tensor. It is deep-copied and
            not modified.
        theta: Axis tilt in degrees.
        phi: Rotation angle in degrees.
        r: Offset along world z that defines the pivot.
        device: Device on which the rotation is built.

    Returns:
        The rotated (N, 4, 4) poses.
    """
    c2ws = copy.deepcopy(c2ws_input)
    c2ws[:, 2, 3] = c2ws[:, 2, 3] + r  # move the pivot to the origin
    theta = torch.deg2rad(torch.tensor(theta)).to(device)
    phi = torch.deg2rad(torch.tensor(phi)).to(device)
    # Unit rotation axis, built directly on the target device.
    axis = torch.stack([torch.zeros_like(theta), torch.cos(theta), torch.sin(theta)])

    # Skew-symmetric cross-product matrix K of the axis.
    skew = torch.zeros(3, 3, device=device)
    skew[0, 1] = -axis[2]
    skew[0, 2] = axis[1]
    skew[1, 0] = axis[2]
    skew[1, 2] = -axis[0]
    skew[2, 0] = -axis[1]
    skew[2, 1] = axis[0]

    # Rodrigues' formula: I + sin(phi) K + (1 - cos(phi)) K^2.
    rot = torch.eye(3, device=device) + torch.sin(phi) * skew + (1 - torch.cos(phi)) * (skew @ skew)
    rot_h = torch.eye(4, dtype=c2ws.dtype, device=device)
    rot_h[:3, :3] = rot
    c2ws = torch.matmul(rot_h, c2ws)
    c2ws[:, 2, 3] = c2ws[:, 2, 3] - r  # undo the pivot shift
    return c2ws


def sphere2pose(c2ws_input, theta, phi, r, device):
    """Move camera poses on a sphere around the world origin.

    Steps, in order:
    1. Shift the poses by ``r`` along world z. This step is skipped when
       ``r`` is None.
    2. Rotate by ``theta`` degrees about the world x-axis.
    3. Rotate by ``phi`` degrees about the world y-axis.

    Args:
        c2ws_input: (N, 4, 4) camera-to-world tensor. It is deep-copied and
            not modified.
        theta: Rotation about x, in degrees.
        phi: Rotation about y, in degrees.
        r: Offset along world z, or ``None`` for no offset.
        device: Device on which the rotations are built. It must match
            ``c2ws_input``.

    Returns:
        The transformed (N, 4, 4) poses.
    """
    c2ws = copy.deepcopy(c2ws_input)
    # Translate along world z first, then rotate. generate_candidate_poses
    # passes r=None, meaning "no translation".
    if r is not None:
        c2ws[:, 2, 3] += r

    theta = torch.deg2rad(torch.tensor(theta)).to(device)
    rot_x = torch.eye(4, dtype=c2ws.dtype, device=device)
    rot_x[1, 1] = torch.cos(theta)
    rot_x[1, 2] = -torch.sin(theta)
    rot_x[2, 1] = torch.sin(theta)
    rot_x[2, 2] = torch.cos(theta)

    phi = torch.deg2rad(torch.tensor(phi)).to(device)
    rot_y = torch.eye(4, dtype=c2ws.dtype, device=device)
    rot_y[0, 0] = torch.cos(phi)
    rot_y[0, 2] = torch.sin(phi)
    rot_y[2, 0] = -torch.sin(phi)
    rot_y[2, 2] = torch.cos(phi)

    # (4, 4) @ (N, 4, 4) broadcasts over the batch.
    return torch.matmul(rot_y, torch.matmul(rot_x, c2ws))


def _c2ws_to_cameras(c2ws, H, W, fs, c, device):
    """Convert dust3r/COLMAP camera-to-world poses to PyTorch3D cameras.

    COLMAP/dust3r camera axes are right-down-forward (RDF). PyTorch3D uses
    left-up-forward (LUF) with row-vector rotations.
    """
    rot, trans = c2ws[:, :3, :3], c2ws[:, :3, 3:]
    # Flip the camera x and y axes: RDF -> LUF.
    rot = torch.stack([-rot[:, :, 0], -rot[:, :, 1], rot[:, :, 2]], 2)
    new_c2w = torch.cat([rot, trans], 2)
    bottom_row = new_c2w.new_tensor([[[0, 0, 0, 1]]]).repeat(new_c2w.shape[0], 1, 1)
    w2c = torch.linalg.inv(torch.cat((new_c2w, bottom_row), 1))
    # Transpose the rotation: PyTorch3D applies R to row vectors.
    R_new, T_new = w2c[:, :3, :3].permute(0, 2, 1), w2c[:, :3, 3]
    image_size = ((H, W),)  # (h, w)
    return PerspectiveCameras(focal_length=fs, principal_point=c, in_ndc=False,
                              image_size=image_size, R=R_new, T=T_new, device=device)


def _interpolate_keyframes(keyframes, frame):
    """Resample keyframe values to ``frame`` samples.

    More than 3 keyframes use a smoothing spline with the endpoints pinned to
    the original values. Otherwise the values are interpolated linearly.
    """
    if len(keyframes) > 3:
        values = txt_interpolation(keyframes, frame, mode='smooth')
        values[0] = keyframes[0]
        values[-1] = keyframes[-1]
        return values
    return txt_interpolation(keyframes, frame, mode='linear')


def generate_candidate_poses(c2ws_anchor, H, W, fs, c, theta, phi, num_candidates, device):
    """Create 2 or 3 next-best-view candidate cameras around ``c2ws_anchor``.

    Candidates use elevations ``0``, ``-theta`` and, when there are 3
    candidates, ``theta / 2``. All candidates share azimuth ``phi``.

    Returns:
        A tuple ``(cameras, thetas, phis)``, where ``cameras`` is a
        PyTorch3D ``PerspectiveCameras`` and ``thetas`` / ``phis`` are the
        per-candidate angles in degrees.

    Raises:
        ValueError: If ``num_candidates`` is not 2 or 3.
    """
    if num_candidates == 2:
        thetas = np.array([0, -theta])
        phis = np.array([phi, phi])
    elif num_candidates == 3:
        thetas = np.array([0, -theta, theta / 2.])  # avoid too many downward views
        phis = np.array([phi, phi, phi])
    else:
        raise ValueError("NBV mode only supports 2 or 3 candidates per iteration.")

    c2ws = torch.cat(
        [sphere2pose(c2ws_anchor, np.float32(th), np.float32(ph), r=None, device=device)
         for th, ph in zip(thetas, phis)],
        dim=0,
    )
    cameras = _c2ws_to_cameras(c2ws, H, W, fs, c, device)
    return cameras, thetas, phis


def generate_traj_specified(c2ws_anchor, H, W, fs, c, theta, phi, d_r, frame, device):
    """Build a ``frame``-step trajectory that ramps linearly to (theta, phi, d_r).

    The radius offset ramps from 0 to ``d_r`` times the anchor camera's
    z-translation.

    Returns:
        A tuple ``(cameras, num_views)``.
    """
    thetas = np.linspace(0, theta, frame)
    phis = np.linspace(0, phi, frame)
    radius_end = float(d_r * c2ws_anchor[0, 2, 3].cpu())
    rs = np.linspace(0, radius_end, frame)

    c2ws = torch.cat(
        [sphere2pose(c2ws_anchor, np.float32(th), np.float32(ph), np.float32(rad), device)
         for th, ph, rad in zip(thetas, phis, rs)],
        dim=0,
    )
    cameras = _c2ws_to_cameras(c2ws, H, W, fs, c, device)
    return cameras, c2ws.shape[0]


def generate_traj_txt(c2ws_anchor, H, W, fs, c, phi, theta, r, frame, device, viz_traj=False, save_dir=None):
    """Build a ``frame``-step trajectory from keyframe lists of phi, theta and r.

    Each keyframe list is resampled to ``frame`` values; see
    ``_interpolate_keyframes``. The radius values are scaled by the anchor
    camera's z-translation. When ``viz_traj`` is set, a trajectory
    visualisation video is written to ``save_dir/viz_traj.mp4``.

    Returns:
        A tuple ``(cameras, num_views)``.

    Raises:
        ValueError: If ``viz_traj`` is True but ``save_dir`` is None.
    """
    c2ws_anchor = c2ws_anchor.to(device)
    phis = _interpolate_keyframes(phi, frame)
    thetas = _interpolate_keyframes(theta, frame)
    rs = _interpolate_keyframes(r, frame) * c2ws_anchor[0, 2, 3].cpu().numpy()

    c2ws = torch.cat(
        [sphere2pose(c2ws_anchor, np.float32(th), np.float32(ph), np.float32(rad), device)
         for th, ph, rad in zip(thetas, phis, rs)],
        dim=0,
    )

    if viz_traj:
        if save_dir is None:
            raise ValueError('generate_traj_txt: save_dir is required when viz_traj=True')
        poses = c2ws.cpu().numpy()
        frames = [visualizer_frame(poses, i) for i in range(len(poses))]
        save_video(np.array(frames) / 255., os.path.join(save_dir, 'viz_traj.mp4'))

    cameras = _c2ws_to_cameras(c2ws, H, W, fs, c, device)
    return cameras, c2ws.shape[0]


def setup_renderer(cameras, image_size):
    """Create a PyTorch3D point renderer for ``cameras``.

    Points use radius 0.01, up to 10 points per pixel, ``bin_size=0`` (naive
    rasterisation) and alpha compositing.

    Returns:
        A dict with keys ``'cameras'``, ``'raster_settings'`` and ``'renderer'``.
    """
    raster_settings = PointsRasterizationSettings(
        image_size=image_size,
        radius=0.01,
        points_per_pixel=10,
        bin_size=0,
    )
    renderer = PointsRenderer(
        rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),
        compositor=AlphaCompositor(),
    )
    return {'cameras': cameras, 'raster_settings': raster_settings, 'renderer': renderer}


def interpolate_sequence(sequence, k, device):
    """Linearly interpolate between consecutive rows of a 2-D tensor.

    Each adjacent row pair ``(a, b)`` produces ``k + 1`` rows
    ``(1 - w) * a + w * b`` for ``w`` in ``linspace(0, 1, k + 1)``. The last
    row of ``sequence`` is appended at the end, so the output has
    ``(N - 1) * (k + 1) + 1`` rows.
    NOTE(review): both endpoints are included per segment, so every junction
    row (and the final row) appears twice. Callers may rely on this length.
    """
    _, num_cols = sequence.shape
    weights = torch.linspace(0, 1, k + 1, device=device).view(1, -1, 1)
    left = sequence[:-1].unsqueeze(1)
    right = sequence[1:].unsqueeze(1)
    blended = ((1 - weights) * left + weights * right).reshape(-1, num_cols)
    return torch.cat([blended, sequence[-1].view(1, -1)], dim=0)


def focus_point_fn(c2ws: torch.Tensor) -> torch.Tensor:
    """Return the point closest (least squares) to all camera optical axes.

    The optical axis of each camera is the line through its origin
    ``c2ws[:, :3, 3]`` along its z column ``c2ws[:, :3, 2]``.
    """
    directions, origins = c2ws[:, :3, 2:3], c2ws[:, :3, 3:4]
    # Projectors onto the plane orthogonal to each viewing direction: I - d d^T.
    m = torch.eye(3).to(c2ws.device) - directions * torch.transpose(directions, 1, 2)
    mt_m = torch.transpose(m, 1, 2) @ m
    return torch.inverse(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]


def generate_camera_path(c2ws: torch.Tensor, n_inserts: int = 15, device='cuda') -> torch.Tensor:
    """Densify a sequence of c2w poses by inserting ``n_inserts`` poses per gap.

    Each consecutive pair is interpolated with ``interpolate_poses`` around
    the pair's common focus point.
    """
    segments = []
    for start_pose, end_pose in zip(c2ws[:-1], c2ws[1:]):
        focus_point = focus_point_fn(torch.stack([start_pose, end_pose]))
        path = interpolate_poses(start_pose, end_pose, focus_point, n_inserts, device)
        # Drop end_pose here: it starts the next segment (or is appended below).
        segments.append(path[:-1])
    segments.append(c2ws[-1:])
    return torch.cat(segments, dim=0)
def interpolate_poses(start_pose: torch.Tensor, end_pose: torch.Tensor, focus_point: torch.Tensor, n_inserts: int = 15, device='cuda') -> torch.Tensor:
    """Interpolate camera-to-world poses between ``start_pose`` and ``end_pose``.

    Rotations are slerped. Camera centres are interpolated linearly, then
    moved along the ray from ``focus_point`` so that their distance to it is
    interpolated linearly between the start and end distances. This gives an
    arc-like path around the focus point.

    Returns:
        A tensor of shape (n_inserts + 2, 4, 4) on ``device``. It includes
        both endpoints (t = 0 and t = 1).
    """
    dtype = start_pose.dtype
    start_t, end_t = start_pose[:3, 3], end_pose[:3, 3]
    start_distance = torch.linalg.norm(start_t - focus_point)
    end_distance = torch.linalg.norm(end_t - focus_point)

    key_rots = R.from_matrix(np.stack([
        start_pose[:3, :3].detach().cpu().numpy(),
        end_pose[:3, :3].detach().cpu().numpy(),
    ]))
    slerp = Slerp([0, 1], key_rots)

    poses = []
    for t in torch.linspace(0., 1., n_inserts + 2, dtype=dtype):  # includes t = 0 and t = 1
        rot = slerp(float(t)).as_matrix()
        translation = (1 - t) * start_t + t * end_t
        distance = (1 - t) * start_distance + t * end_distance
        offset = translation - focus_point
        # Clamp the norm so a centre coinciding with the focus point yields
        # direction 0 instead of NaN.
        direction = offset / torch.clamp_min(torch.linalg.norm(offset), torch.finfo(offset.dtype).eps)
        pose = torch.eye(4, dtype=dtype).to(device)
        pose[:3, :3] = torch.from_numpy(rot).to(device)
        pose[:3, 3] = focus_point + direction * distance
        poses.append(pose)
    return torch.stack(poses)


def inv(mat):
    """Invert a torch or numpy matrix.

    Raises:
        ValueError: If ``mat`` is neither a ``torch.Tensor`` nor an ``np.ndarray``.
    """
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')


def save_pointcloud_with_normals(imgs, pts3d, msk, save_path, mask_pc, reduce_pc):
    """Write the point cloud built by ``get_pc`` to an ASCII PLY file.

    No normals are estimated. Every vertex is written with the placeholder
    normal (0, 1, 0). Colours are written as integer RGB.
    """
    pc = get_pc(imgs, pts3d, msk, mask_pc, reduce_pc)
    vertices = pc.vertices
    colors = pc.colors
    nx, ny, nz = 0, 1, 0  # placeholder normal shared by all vertices

    header = (
        "ply\n"
        "format ascii 1.0\n"
        f"element vertex {len(vertices)}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "property uchar red\n"
        "property uchar green\n"
        "property uchar blue\n"
        "property float nx\n"
        "property float ny\n"
        "property float nz\n"
        "end_header\n"
    )
    rows = (
        '{} {} {} {} {} {} {} {} {}\n'.format(
            vertex[0], vertex[1], vertex[2],
            int(color[0]), int(color[1]), int(color[2]),
            nx, ny, nz,
        )
        for vertex, color in zip(vertices, colors)
    )
    with open(save_path, 'w') as ply_file:
        ply_file.write(header)
        ply_file.writelines(rows)


def get_pc(imgs, pts3d, mask, mask_pc=False, reduce_pc=False):
    """Flatten per-view points and colours into a single ``trimesh.PointCloud``.

    Args:
        imgs: Per-view colour arrays, aligned element-wise with ``pts3d``.
        pts3d: Per-view 3-D point arrays. Their last axis is reshaped to 3.
        mask: Per-view boolean masks. They are used only when ``mask_pc`` is
            set.
        mask_pc: Keep only the masked points of each view.
        reduce_pc: Keep every third point.
    """
    imgs = to_numpy(imgs)
    pts3d = to_numpy(pts3d)
    mask = to_numpy(mask)

    if mask_pc:
        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
        col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
    else:
        pts = np.concatenate(list(pts3d))
        col = np.concatenate(list(imgs))

    step = 3 if reduce_pc else 1
    pts = pts.reshape(-1, 3)[::step]
    col = col.reshape(-1, 3)[::step]

    pct = trimesh.PointCloud(pts, colors=col)
    # Mock normals, attached as a plain attribute.
    pct.vertices_normal = np.tile([0, 1, 0], (pts.shape[0], 1))
    return pct


def world_to_kth(poses, k):
    """Re-express (N, 4, 4) poses in the camera frame of ``poses[k]``."""
    return torch.matmul(torch.inverse(poses[k]), poses)


def _transform_poses_and_points(transform, poses, points, device):
    """Left-multiply ``poses`` by ``transform`` and apply it to a (N, A, B, 3) point grid."""
    new_poses = torch.matmul(transform, poses)
    n, a, b, c = points.shape
    flat = points.reshape(n, a * b, 3)
    ones = torch.ones(n, a * b, 1, dtype=flat.dtype, device=device)
    homogeneous = torch.cat([flat, ones], dim=-1)
    # Row vectors: p' = p @ T^T, which is the same as T @ p for column vectors.
    new_points = torch.matmul(homogeneous, transform.transpose(0, 1).to(flat.dtype))[..., :3]
    return new_poses, new_points.reshape(n, a, b, c)


def world_point_to_kth(poses, points, k, device):
    """Re-express poses and a point grid in the camera frame of ``poses[k]``.

    Every pose and point is left-multiplied by ``inverse(poses[k])``.

    Args:
        poses: (N, 4, 4) camera-to-world matrices.
        points: (N, A, B, 3) points.
        k: Index of the reference pose.
        device: Device used for the homogeneous coordinates.

    Returns:
        A tuple ``(new_poses, new_points)`` with the same shapes as the inputs.
    """
    return _transform_poses_and_points(torch.inverse(poses[k]), poses, points, device)


def world_point_to_obj(poses, points, k, r, elevation, device):
    """Re-express poses and points in an object-centred frame.

    The poses and points are first moved into the frame of camera ``k``.
    They are then expressed relative to a target frame whose origin is at
    (0, 0, r) in that camera frame and which is rotated about x by
    ``180 - elevation`` degrees. The original comments describe this frame as
    centred on the object, with Y up, Z pointing out of the screen and X to
    the right.

    Returns:
        A tuple ``(new_poses, new_points)`` with the same shapes as the inputs.
    """
    poses, points = world_point_to_kth(poses, points, k, device)

    elevation_rad = torch.deg2rad(torch.tensor(180 - elevation)).to(device)
    cos_x, sin_x = torch.cos(elevation_rad), torch.sin(elevation_rad)
    pose_obj = torch.eye(4, device=device)
    pose_obj[1, 1] = cos_x
    pose_obj[1, 2] = sin_x
    pose_obj[2, 1] = -sin_x
    pose_obj[2, 2] = cos_x
    pose_obj[2, 3] = r

    # Apply the target frame's inverse (world -> object) to every pose and point.
    inv_obj_pose = torch.inverse(pose_obj).to(poses.dtype)
    return _transform_poses_and_points(inv_obj_pose, poses, points, device)


def txt_interpolation(input_list, n, mode='smooth'):
    """Resample keyframe values to ``n`` evenly spaced samples.

    Args:
        input_list: Keyframe values, assumed evenly spaced on [0, 1].
        n: Number of output samples.
        mode: ``'smooth'`` fits a cubic ``UnivariateSpline``, which needs at
            least 4 keyframes and does not pass exactly through them.
            ``'linear'`` uses ``interp1d``.

    Returns:
        A float ndarray of length ``n``. A single keyframe yields a constant
        array.

    Raises:
        KeyError: If ``mode`` is not ``'smooth'`` or ``'linear'``.
    """
    if mode not in ('smooth', 'linear'):
        raise KeyError(f"Invalid txt interpolation mode: {mode}")
    values = np.asarray(input_list, dtype=float).ravel()
    if values.size == 1:
        # Both interpolators need at least 2 points; a single keyframe is constant.
        return np.full(n, values[0])
    x = np.linspace(0, 1, len(values))
    f = UnivariateSpline(x, values, k=3) if mode == 'smooth' else interp1d(x, values)
    return f(np.linspace(0, 1, n))


def visualizer(camera_poses, save_path="out.png"):
    """Scatter-plot camera centres ``pose[:3, 3]`` in 3-D and save the figure to ``save_path``."""
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    for pose in camera_poses:
        x, y, z = pose[:3, 3]
        ax.scatter(x, y, z, c="blue", marker="o")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.set_title("Camera trajectory")
    plt.savefig(save_path)
    plt.close(fig)


def visualizer_frame(camera_poses, highlight_index):
    """Render one frame of a camera-trajectory animation as an RGB image.

    Camera centres are coloured by their z value (dark blue to light blue).
    The camera at ``highlight_index`` is drawn larger (marker size 100
    instead of 25).

    Returns:
        A (height, width, 3) uint8 array of the rendered figure.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    z_values = [pose[:3, 3][2] for pose in camera_poses]
    cmap = mcolors.LinearSegmentedColormap.from_list("mycmap", ["#00008B", "#ADD8E6"])
    norm = mcolors.Normalize(vmin=min(z_values), vmax=max(z_values))
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

    for i, pose in enumerate(camera_poses):
        x, y, z = pose[:3, 3]
        ax.scatter(
            x, y, z,
            c=sm.to_rgba(z),  # colour encodes the camera's z position
            marker="o",
            s=100 if i == highlight_index else 25,
        )

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.view_init(90 + 30, -90)
    plt.ylim(-0.1, 0.2)

    fig.canvas.draw()
    # tostring_rgb() was removed in Matplotlib 3.10. buffer_rgba() also avoids
    # width/height mismatches on HiDPI canvases. Copy before closing the figure.
    img = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    return img


def center_crop_image(input_image, height=576, width=1024):
    """Resize a (B, C, H, W) image so it covers ``height`` x ``width``, then center-crop.

    The image is scaled so that the limiting side matches the target, keeping
    the aspect ratio. The other side is never smaller than its target.
    """
    _, _, h, w = input_image.shape
    h_ratio = h / height
    w_ratio = w / width
    if h_ratio > w_ratio:
        new_h = max(int(h / w_ratio), height)
        input_image = Resize((new_h, width))(input_image)
    else:
        new_w = max(int(w / h_ratio), width)
        input_image = Resize((height, new_w))(input_image)
    return CenterCrop((height, width))(input_image)