import argparse
import copy
import os
import sys

import cv2
import numpy as np
import open3d as o3d
import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm

sys.path.append("../../")

from lib.extractMatchTop import getPerspKeypoints, getPerspKeypointsEnsemble, siftMatching
from lib.model_test import D2Net

use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
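

# Example invocation (script and checkpoint paths are illustrative):
#   python evalRT.py --dataset /path/to/RoRD_data/preprocessed/ --sequence data1 \
#       --model_rord ../../models/rord.pth --output_dir out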
parser = argparse.ArgumentParser(description='RoRD ICP evaluation on a DiverseView dataset sequence.')

parser.add_argument('--dataset', type=str, default='/scratch/udit/realsense/RoRD_data/preprocessed/',
                    help='path to the dataset folder')

parser.add_argument('--sequence', type=str, default='data1',
                    help='name of the sequence inside the dataset folder')

parser.add_argument('--output_dir', type=str, default='out',
                    help='output directory for RT estimates')

parser.add_argument('--model_rord', type=str, help='path to the RoRD model for evaluation')

parser.add_argument('--model_d2', type=str, help='path to the vanilla D2-Net model for evaluation')

parser.add_argument('--model_ens', action='store_true',
                    help='use an ensemble of the RoRD and D2-Net models')

parser.add_argument('--sift', action='store_true',
                    help='use SIFT feature matching instead of a learned model')

parser.add_argument('--viz3d', action='store_true',
                    help='visualize the point cloud registrations')

parser.add_argument('--log_interval', type=int, default=9,
                    help='interval (in database images) at which matched-image visualizations are saved')

parser.add_argument('--camera_file', type=str, default='../../configs/camera.txt',
                    help='path to the camera intrinsics file. In order: focal_x, focal_y, center_x, center_y, scaling_factor.')

parser.add_argument('--persp', action='store_true', default=False,
                    help='feature matching on perspective images')

args = parser.parse_args()
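
# The matching modes are tried in this order in the main loop: --model_rord,
# --model_d2, --model_ens, --sift. The ensemble mode uses these checkpoints: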
if args.model_ens:
    model1_ens = '../../models/rord.pth'
    model2_ens = '../../models/d2net.pth'
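

# Visualize the source cloud transformed into the target frame. Note: this
# appends to the global `trgSph` list built in the main loop below.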
def draw_registration_result(source, target, transformation):
    source_temp = copy.deepcopy(source)
    target_temp = copy.deepcopy(target)
    source_temp.transform(transformation)
    trgSph.append(source_temp)
    trgSph.append(target_temp)
    axis1 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
    axis2 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
    axis2.transform(transformation)
    trgSph.append(axis1)
    trgSph.append(axis2)
    o3d.visualization.draw_geometries(trgSph)
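

# Load a depth image and require PIL's 32-bit integer ("I") mode.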
def readDepth(depthFile):
    depth = Image.open(depthFile)
    if depth.mode != "I":
        raise Exception("Depth image is not in intensity format")

    return np.asarray(depth)
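

# Parse camera intrinsics from a whitespace-separated text file, in order:
# focal_x focal_y center_x center_y scaling_factor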
def readCamera(camera):
    with open(camera, "rt") as file:
        contents = file.read().split()

    focalX = float(contents[0])
    focalY = float(contents[1])
    centerX = float(contents[2])
    centerY = float(contents[3])
    scalingFactor = float(contents[4])

    return focalX, focalY, centerX, centerY, scalingFactor
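

# Back-project an RGB-D pair into a colored point cloud with the pinhole
# model, and record the 3D point corresponding to each 2D keypoint in `pts`.
# Relies on the global intrinsics read in the main block.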
def getPointCloud(rgbFile, depthFile, pts):
    thresh = 15.0  # depth cutoff after scaling

    depth = readDepth(depthFile)
    rgb = Image.open(rgbFile)

    points = []
    colors = []

    corIdx = [-1] * len(pts)
    corPts = [None] * len(pts)
    ptIdx = 0

    for v in range(depth.shape[0]):
        for u in range(depth.shape[1]):
            Z = depth[v, u] / scalingFactor
            if Z == 0:
                continue
            if Z > thresh:
                continue

            X = (u - centerX) * Z / focalX
            Y = (v - centerY) * Z / focalY

            points.append((X, Y, Z))
            colors.append(rgb.getpixel((u, v)))

            # Remember which point-cloud index this keypoint landed on.
            if (u, v) in pts:
                index = pts.index((u, v))
                corIdx[index] = ptIdx
                corPts[index] = (X, Y, Z)

            ptIdx = ptIdx + 1

    points = np.asarray(points)
    colors = np.asarray(colors)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors / 255)

    return pcd, corIdx, corPts
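

# Convert matcher output (two parallel coordinate arrays) into a list of
# integer (x, y) pixel tuples.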
def convertPts(A):
    X, Y = A[0], A[1]

    return [(int(float(x)), int(float(y))) for x, y in zip(X, Y)]
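

# Build small red spheres at the given 3D correspondence points, for
# visualizing matched keypoints inside the point cloud.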
def getSphere(pts):
    sphs = []

    for element in pts:
        if element is not None:
            sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.03)
            sphere.paint_uniform_color([0.9, 0.2, 0])

            trans = np.identity(4)
            trans[0, 3] = element[0]
            trans[1, 3] = element[1]
            trans[2, 3] = element[2]

            sphere.transform(trans)
            sphs.append(sphere)

    return sphs
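

# Keep only keypoint pairs that found a valid 3D point in both clouds
# (index -1 marks keypoints without a back-projected 3D point).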
def get3dCor(src, trg):
    corr = []

    for sId, tId in zip(src, trg):
        if sId != -1 and tId != -1:
            corr.append((sId, tId))

    return np.asarray(corr)
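

# Main loop: for every query image, estimate a relative transformation (RT)
# against each database image in the sequence and save the stacked results
# as <output_dir>/<queryId>.npy.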
if __name__ == "__main__":
    camera_file = args.camera_file
    rgb_csv = os.path.join(args.dataset, args.sequence, 'rtImagesRgb.csv')
    depth_csv = os.path.join(args.dataset, args.sequence, 'rtImagesDepth.csv')

    os.makedirs(os.path.join(args.output_dir, 'vis'), exist_ok=True)

    focalX, focalY, centerX, centerY, scalingFactor = readCamera(camera_file)

    df_rgb = pd.read_csv(rgb_csv)
    df_dep = pd.read_csv(depth_csv)

    model1 = D2Net(model_file=args.model_d2).to(device)
    model2 = D2Net(model_file=args.model_rord).to(device)
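
    # The CSVs hold one row per query image (the 'query' column) and one
    # column per database image; iterate queries (rows) x databases (columns).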
    queryId = 0
    for im_q, dep_q in tqdm(zip(df_rgb['query'], df_dep['query']), total=df_rgb.shape[0]):
        filter_list = []
        dbId = 0
        # DataFrame.iteritems() was removed in pandas 2.0; items() is the
        # equivalent (column_name, Series) iterator.
        for im_d, dep_d in tqdm(zip(df_rgb.items(), df_dep.items()), total=df_rgb.shape[1]):
            if im_d[0] == 'query':
                continue

            rgb_name_src = os.path.basename(im_q)
            H_name_src = os.path.splitext(rgb_name_src)[0] + '.npy'
            srcH = os.path.join(args.dataset, args.sequence, 'rgb', H_name_src)
            rgb_name_trg = os.path.basename(im_d[1][1])
            H_name_trg = os.path.splitext(rgb_name_trg)[0] + '.npy'
            trgH = os.path.join(args.dataset, args.sequence, 'rgb', H_name_trg)

            srcImg = srcH.replace('.npy', '.jpg')
            trgImg = trgH.replace('.npy', '.jpg')
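
            # Select the 2D keypoint matcher. getPerspKeypoints matches on the
            # orthographic views when homography files (srcH/trgH) are given,
            # and on the raw perspective images when they are None.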
            if args.model_rord:
                if args.persp:
                    srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, HFile1=None, HFile2=None, model=model2, device=device)
                else:
                    srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, srcH, trgH, model2, device)

            elif args.model_d2:
                if args.persp:
                    # model1 holds the D2-Net weights (model2 is RoRD).
                    srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, HFile1=None, HFile2=None, model=model1, device=device)
                else:
                    srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, srcH, trgH, model1, device)

            elif args.model_ens:
                # Load the ensemble checkpoints (this overrides the models
                # constructed above and reloads on every database image).
                model1 = D2Net(model_file=model1_ens).to(device)
                model2 = D2Net(model_file=model2_ens).to(device)
                srcPts, trgPts, matchImg = getPerspKeypointsEnsemble(model1, model2, srcImg, trgImg, srcH, trgH, device)

            elif args.sift:
                if args.persp:
                    srcPts, trgPts, matchImg, _ = siftMatching(srcImg, trgImg, HFile1=None, HFile2=None, device=device)
                else:
                    srcPts, trgPts, matchImg, _ = siftMatching(srcImg, trgImg, srcH, trgH, device)

            # A list return value signals that matching failed; fall back to
            # the identity transform and move on to the next database image.
            if isinstance(srcPts, list):
                print(np.identity(4))
                filter_list.append(np.identity(4))
                continue

            srcPts = convertPts(srcPts)
            trgPts = convertPts(trgPts)

            # Depth paths in the CSV are relative to the dataset root; the two
            # dirname calls rely on args.dataset ending with a trailing slash.
            depth_name_src = os.path.dirname(os.path.dirname(args.dataset)) + '/' + dep_q
            depth_name_trg = os.path.dirname(os.path.dirname(args.dataset)) + '/' + dep_d[1][1]

            srcCld, srcIdx, srcCor = getPointCloud(srcImg, depth_name_src, srcPts)
            trgCld, trgIdx, trgCor = getPointCloud(trgImg, depth_name_trg, trgPts)

            srcSph = getSphere(srcCor)
            trgSph = getSphere(trgCor)
            axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
            srcSph.append(srcCld)
            srcSph.append(axis)
            trgSph.append(trgCld)
            trgSph.append(axis)

            corr = get3dCor(srcIdx, trgIdx)
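
            # Closed-form point-to-point estimation: Open3D computes the rigid
            # transform aligning the matched 3D keypoints directly from the
            # correspondence set (no iterative ICP refinement is applied here).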
            p2p = o3d.pipelines.registration.TransformationEstimationPointToPoint()
            trans_init = p2p.compute_transformation(srcCld, trgCld, o3d.utility.Vector2iVector(corr))

            filter_list.append(trans_init)

            if args.viz3d:
                o3d.visualization.draw_geometries(srcSph)
                o3d.visualization.draw_geometries(trgSph)
                draw_registration_result(srcCld, trgCld, trans_init)

            if dbId % args.log_interval == 0:
                cv2.imwrite(os.path.join(args.output_dir, 'vis', "matchImg.%02d.%02d.jpg" % (queryId, dbId // args.log_interval)), matchImg)
            dbId += 1

        # Stack the per-database RTs into a (4, 4, numDb) array and save one
        # file per query.
        RT = np.stack(filter_list).transpose(1, 2, 0)
        np.save(os.path.join(args.output_dir, str(queryId) + '.npy'), RT)
        queryId += 1
        print('RT shape:', RT.shape)