import base64
import math
import os
import random
import time
from io import BytesIO
from typing import Iterable

import matplotlib.pyplot as plt
import numpy as np
import torch
from pytorch3d.io import IO
from pytorch3d.structures import Pointclouds
from pytorch3d.transforms import RotateAxisAngle
from pytorch3d.vis.plotly_vis import plot_scene

import util.lr_sched as lr_sched
import util.misc as misc


def evaluate_points(predicted_xyz, gt_xyz, dist_thres):
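    """Compute precision, recall, and F1 between predicted and GT point sets.

    A predicted point is a true positive if any GT point lies within
    `dist_thres` (Euclidean distance); recall is defined symmetrically.
    Pairwise distances are computed in slices of 1000 points to bound memory.
    """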
    if predicted_xyz.shape[0] == 0:
        return 0.0, 0.0, 0.0
    slice_size = 1000
    # Precision: fraction of predicted points within dist_thres of some GT point.
    precision = 0.0
    for i in range(int(np.ceil(predicted_xyz.shape[0] / slice_size))):
        start = slice_size * i
        end = slice_size * (i + 1)
        dist = ((predicted_xyz[start:end, None] - gt_xyz[None]) ** 2.0).sum(axis=-1) ** 0.5
        precision += ((dist < dist_thres).sum(axis=1) > 0).sum()
    precision /= predicted_xyz.shape[0]

    # Recall: fraction of GT points within dist_thres of some predicted point.
    # The slices run over the GT points here, so the loop bound must cover gt_xyz.
    recall = 0.0
    for i in range(int(np.ceil(gt_xyz.shape[0] / slice_size))):
        start = slice_size * i
        end = slice_size * (i + 1)
        dist = ((predicted_xyz[:, None] - gt_xyz[None, start:end]) ** 2.0).sum(axis=-1) ** 0.5
        recall += ((dist < dist_thres).sum(axis=0) > 0).sum()
    recall /= gt_xyz.shape[0]
    return precision, recall, get_f1(precision, recall)


def aug_xyz(seen_xyz, unseen_xyz, args, is_train):
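    """Augment seen/unseen coordinates at train time with a random per-axis
    scale, and (except for Hypersim) a random shift and random per-axis
    rotation. At eval time the points pass through unchanged (all rotation
    angles stay zero).
    """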
    degree_x = 0
    degree_y = 0
    degree_z = 0
    if is_train:
        r_delta = args.random_scale_delta
        scale = torch.tensor([
            random.uniform(1.0 - r_delta, 1.0 + r_delta),
            random.uniform(1.0 - r_delta, 1.0 + r_delta),
            random.uniform(1.0 - r_delta, 1.0 + r_delta),
        ], device=seen_xyz.device)

        if args.use_hypersim:
            shift = 0
        else:
            degree_x = random.randrange(-args.random_rotate_degree, args.random_rotate_degree + 1)
            degree_y = random.randrange(-args.random_rotate_degree, args.random_rotate_degree + 1)
            degree_z = random.randrange(-args.random_rotate_degree, args.random_rotate_degree + 1)

            r_shift = args.random_shift
            shift = torch.tensor([[[
                random.uniform(-r_shift, r_shift),
                random.uniform(-r_shift, r_shift),
                random.uniform(-r_shift, r_shift),
            ]]], device=seen_xyz.device)
        seen_xyz = seen_xyz * scale + shift
        unseen_xyz = unseen_xyz * scale + shift

    B, H, W, _ = seen_xyz.shape
    return [
        rotate(seen_xyz.reshape((B, -1, 3)), degree_x, degree_y, degree_z).reshape((B, H, W, 3)),
        rotate(unseen_xyz, degree_x, degree_y, degree_z),
    ]


def rotate(sample, degree_x, degree_y, degree_z):
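    """Rotate points by the given angles (in degrees) about the X, Y, and Z
    axes, skipping zero-degree rotations."""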
    for degree, axis in [(degree_x, "X"), (degree_y, "Y"), (degree_z, "Z")]:
        if degree != 0:
            sample = RotateAxisAngle(degree, axis=axis).to(sample.device).transform_points(sample)
    return sample


def get_grid(B, device, co3d_world_size, granularity):
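    """Build a regular grid of query points covering the cube
    [-co3d_world_size, co3d_world_size]^3 at the given granularity,
    returned as a (B, N^3, 3) batch."""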
    N = int(np.ceil(2 * co3d_world_size / granularity))
    grid_unseen_xyz = torch.zeros((N, N, N, 3), device=device)
    for i in range(N):
        grid_unseen_xyz[i, :, :, 0] = i
    for j in range(N):
        grid_unseen_xyz[:, j, :, 1] = j
    for k in range(N):
        grid_unseen_xyz[:, :, k, 2] = k
    grid_unseen_xyz -= (N / 2.0)
    grid_unseen_xyz /= (N / 2.0) / co3d_world_size
    grid_unseen_xyz = grid_unseen_xyz.reshape((1, -1, 3)).repeat(B, 1, 1)
    return grid_unseen_xyz


def run_viz(model, data_loader, device, args, epoch):
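    """Run inference on the visualization set and write one HTML file per
    object under `{args.job_dir}/viz`. Queries are fed through the model in
    chunks of `max_n_queries_fwd` points; the `clear_cache`/`cache_enc` pair
    suggests the encoder output is cached and reused across chunks."""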
    epoch_start_time = time.time()
    model.eval()
    # Create the output directory robustly (no-op if it already exists).
    os.makedirs(f'{args.job_dir}/viz', exist_ok=True)

    print('Visualization data_loader length:', len(data_loader))
    dataset = data_loader.dataset
    for sample_idx, samples in enumerate(data_loader):
        if sample_idx >= args.max_n_viz_obj:
            break
        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(samples, device, is_train=False, args=args, is_viz=True)

        pred_occupy = []
        pred_colors = []
        (model.module if hasattr(model, "module") else model).clear_cache()

        max_n_queries_fwd = 2000

        total_n_passes = int(np.ceil(unseen_xyz.shape[1] / max_n_queries_fwd))
        for p_idx in range(total_n_passes):
            p_start = p_idx * max_n_queries_fwd
            p_end = (p_idx + 1) * max_n_queries_fwd
            cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
            cur_unseen_rgb = unseen_rgb[:, p_start:p_end].zero_()
            cur_labels = labels[:, p_start:p_end].zero_()

            with torch.no_grad():
                _, pred = model(
                    seen_images=seen_images,
                    seen_xyz=seen_xyz,
                    unseen_xyz=cur_unseen_xyz,
                    unseen_rgb=cur_unseen_rgb,
                    unseen_occupy=cur_labels,
                    cache_enc=args.run_viz,
                    valid_seen_xyz=valid_seen_xyz,
                )

            cur_occupy_out = pred[..., 0]

            if args.regress_color:
                cur_color_out = pred[..., 1:].reshape((-1, 3))
            else:
                cur_color_out = pred[..., 1:].reshape((-1, 3, 256)).max(dim=2)[1] / 255.0
            pred_occupy.append(cur_occupy_out)
            pred_colors.append(cur_color_out)

        rank = misc.get_rank()
        prefix = f'{args.job_dir}/viz/' + dataset.dataset_split + f'_ep{epoch}_rank{rank}_i{sample_idx}'

        img = (seen_images[0].permute(1, 2, 0) * 255).cpu().numpy().copy().astype(np.uint8)

        gt_xyz = samples[1][0].to(device).reshape(-1, 3)
        gt_rgb = samples[1][1].to(device).reshape(-1, 3)
        mesh_xyz = samples[2].to(device).reshape(-1, 3) if args.use_hypersim else None

        with open(prefix + '.html', 'a') as f:
            generate_html(
                img,
                seen_xyz, seen_images,
                torch.cat(pred_occupy, dim=1),
                torch.cat(pred_colors, dim=0),
                unseen_xyz,
                f,
                gt_xyz=gt_xyz,
                gt_rgb=gt_rgb,
                mesh_xyz=mesh_xyz,
            )
    print("Visualization epoch time:", time.time() - epoch_start_time)


def get_f1(precision, recall):
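    """Harmonic mean of precision and recall; 0.0 when both are zero."""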
    if (precision + recall) == 0:
        return 0.0
    return 2.0 * precision * recall / (precision + recall)


def generate_plot(img, seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz,
                  gt_xyz=None, gt_rgb=None, mesh_xyz=None, score_thresholds=[0.1, 0.3, 0.5, 0.7, 0.9],
                  pointcloud_marker_size=2,
                  ):
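    """Build a plotly figure of the seen points, GT points/mesh, and the
    predicted points at each occupancy-score threshold, saving each
    thresholded prediction as a .ply file. Note `img` is accepted for
    signature parity with `generate_html` but is not used here. Returns the
    figure, or None if plotting fails."""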
    clouds = {"MCC Output": {}}

    if seen_xyz is not None:
        seen_xyz = seen_xyz.reshape((-1, 3)).cpu()
        seen_rgb = torch.nn.functional.interpolate(seen_rgb, (112, 112)).permute(0, 2, 3, 1).reshape((-1, 3)).cpu()
        good_seen = seen_xyz[:, 0] != -100

        seen_pc = Pointclouds(
            points=seen_xyz[good_seen][None],
            features=seen_rgb[good_seen][None],
        )
        clouds["MCC Output"]["seen"] = seen_pc

    if gt_xyz is not None:
        subset_gt = random.sample(range(gt_xyz.shape[0]), 10000)
        gt_pc = Pointclouds(
            points=gt_xyz[subset_gt][None],
            features=gt_rgb[subset_gt][None],
        )
        clouds["MCC Output"]["GT points"] = gt_pc

    if mesh_xyz is not None:
        subset_mesh = random.sample(range(mesh_xyz.shape[0]), 10000)
        mesh_pc = Pointclouds(
            points=mesh_xyz[subset_mesh][None],
        )
        clouds["MCC Output"]["GT mesh"] = mesh_pc

    pred_occ = torch.nn.Sigmoid()(pred_occ).cpu()
    for t in score_thresholds:
        pos = pred_occ > t

        points = unseen_xyz[pos].reshape((-1, 3))
        features = pred_rgb[None][pos].reshape((-1, 3))
        good_points = points[:, 0] != -100

        if good_points.sum() == 0:
            continue

        pc = Pointclouds(
            points=points[good_points][None].cpu(),
            features=features[good_points][None].cpu(),
        )

        clouds["MCC Output"][f"pred_{t}"] = pc
        # Write one .ply per threshold so earlier outputs are not overwritten.
        IO().save_pointcloud(pc, f"output_pointcloud_{t}.ply")

    plt.figure()
    try:
        fig = plot_scene(clouds, pointcloud_marker_size=pointcloud_marker_size, pointcloud_max_points=20000 * 2)
        fig.update_layout(height=1000, width=1000)
        return fig
    except Exception as e:
        print('writing failed', e)
        try:
            plt.close()
        except Exception:
            pass


def generate_html(img, seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz, f,
                  gt_xyz=None, gt_rgb=None, mesh_xyz=None, score_thresholds=[0.1, 0.3, 0.5, 0.7, 0.9],
                  pointcloud_marker_size=2,
                  ):
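    """Write a visualization to the open file handle `f`: the input image as
    a base64-embedded <img> tag, followed by an interactive plotly scene
    containing the seen points, GT points/mesh, and one predicted cloud per
    occupancy-score threshold."""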
    if img is not None:
        fig = plt.figure()
        plt.imshow(img)
        tmpfile = BytesIO()
        fig.savefig(tmpfile, format='jpg')
        encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8')

        # The figure is saved as JPEG above, so the data URI must say jpeg, not png.
        html = '<img src=\'data:image/jpeg;base64,{}\'>'.format(encoded)
        f.write(html)
        plt.close()

    clouds = {"MCC Output": {}}

    if seen_xyz is not None:
        seen_xyz = seen_xyz.reshape((-1, 3)).cpu()
        seen_rgb = torch.nn.functional.interpolate(seen_rgb, (112, 112)).permute(0, 2, 3, 1).reshape((-1, 3)).cpu()
        good_seen = seen_xyz[:, 0] != -100

        seen_pc = Pointclouds(
            points=seen_xyz[good_seen][None],
            features=seen_rgb[good_seen][None],
        )
        clouds["MCC Output"]["seen"] = seen_pc

    if gt_xyz is not None:
        subset_gt = random.sample(range(gt_xyz.shape[0]), 10000)
        gt_pc = Pointclouds(
            points=gt_xyz[subset_gt][None],
            features=gt_rgb[subset_gt][None],
        )
        clouds["MCC Output"]["GT points"] = gt_pc

    if mesh_xyz is not None:
        subset_mesh = random.sample(range(mesh_xyz.shape[0]), 10000)
        mesh_pc = Pointclouds(
            points=mesh_xyz[subset_mesh][None],
        )
        clouds["MCC Output"]["GT mesh"] = mesh_pc

    pred_occ = torch.nn.Sigmoid()(pred_occ).cpu()
    for t in score_thresholds:
        pos = pred_occ > t

        points = unseen_xyz[pos].reshape((-1, 3))
        features = pred_rgb[None][pos].reshape((-1, 3))
        good_points = points[:, 0] != -100

        if good_points.sum() == 0:
            continue

        pc = Pointclouds(
            points=points[good_points][None].cpu(),
            features=features[good_points][None].cpu(),
        )

        clouds["MCC Output"][f"pred_{t}"] = pc

    plt.figure()
    try:
        fig = plot_scene(clouds, pointcloud_marker_size=pointcloud_marker_size, pointcloud_max_points=20000 * 2)
        fig.update_layout(height=1000, width=1000)
        html_string = fig.to_html(full_html=False, include_plotlyjs='cdn')
        f.write(html_string)
        return fig, plt
    except Exception as e:
        print('writing failed', e)
        try:
            plt.close()
        except Exception:
            pass


def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    args=None):
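    """Train for one epoch with mixed precision, gradient accumulation over
    `args.accum_iter` steps, and per-iteration learning-rate scheduling.
    Returns the epoch-averaged metrics."""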
    epoch_start_time = time.time()
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    print('Training data_loader length:', len(data_loader))
    for data_iter_step, samples in enumerate(data_loader):
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(samples, device, is_train=True, args=args)

        with torch.cuda.amp.autocast():
            loss, _ = model(
                seen_images=seen_images,
                seen_xyz=seen_xyz,
                unseen_xyz=unseen_xyz,
                unseen_rgb=unseen_rgb,
                unseen_occupy=labels,
                valid_seen_xyz=valid_seen_xyz,
            )

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Warning: Loss is {}".format(loss_value))
            loss *= 0.0
            loss_value = 100.0

        loss /= accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    clip_grad=args.clip_grad,
                    update_grad=(data_iter_step + 1) % accum_iter == 0,
                    verbose=(data_iter_step % 100) == 0)

        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        if data_iter_step == 30:
            os.system('nvidia-smi')
            os.system('free -g')
        if args.debug and data_iter_step == 5:
            break

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    print("Training epoch time:", time.time() - epoch_start_time)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def eval_one_epoch(
    model: torch.nn.Module,
    data_loader: Iterable,
    device: torch.device,
    args=None
):
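    """Evaluate for one epoch: run the model over query chunks, threshold
    the predicted occupancy scores, and log point-level (and, on Hypersim,
    mesh-level) precision/recall/F1 along with the loss. Returns the
    epoch-averaged metrics."""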
    epoch_start_time = time.time()
    model.train(False)

    metric_logger = misc.MetricLogger(delimiter="  ")

    print('Eval len(data_loader):', len(data_loader))

    for data_iter_step, samples in enumerate(data_loader):
        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(samples, device, is_train=False, args=args)

        max_n_queries_fwd = 5000
        all_loss, all_preds = [], []
        for p_idx in range(int(np.ceil(unseen_xyz.shape[1] / max_n_queries_fwd))):
            p_start = p_idx * max_n_queries_fwd
            p_end = (p_idx + 1) * max_n_queries_fwd
            cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
            cur_unseen_rgb = unseen_rgb[:, p_start:p_end]
            cur_labels = labels[:, p_start:p_end]

            with torch.no_grad():
                loss, pred = model(
                    seen_images=seen_images,
                    seen_xyz=seen_xyz,
                    unseen_xyz=cur_unseen_xyz,
                    unseen_rgb=cur_unseen_rgb,
                    unseen_occupy=cur_labels,
                    valid_seen_xyz=valid_seen_xyz,
                )
            all_loss.append(loss)
            all_preds.append(pred)

        loss = sum(all_loss) / len(all_loss)
        pred = torch.cat(all_preds, dim=1)

        B = pred.shape[0]

        gt_xyz = samples[1][0].to(device).reshape((B, -1, 3))
        if args.use_hypersim:
            mesh_xyz = samples[2].to(device).reshape((B, -1, 3))

        s_thres = args.eval_score_threshold
        d_thres = args.eval_dist_threshold

        for b_idx in range(B):
            geometry_metrics = {}
            predicted_idx = torch.nn.Sigmoid()(pred[b_idx, :, 0]) > s_thres
            predicted_xyz = unseen_xyz[b_idx, predicted_idx]

            precision, recall, f1 = evaluate_points(predicted_xyz, gt_xyz[b_idx], d_thres)
            geometry_metrics[f'd{d_thres}_s{s_thres}_point_pr'] = precision
            geometry_metrics[f'd{d_thres}_s{s_thres}_point_rc'] = recall
            geometry_metrics[f'd{d_thres}_s{s_thres}_point_f1'] = f1

            if args.use_hypersim:
                precision, recall, f1 = evaluate_points(predicted_xyz, mesh_xyz[b_idx], d_thres)
                geometry_metrics[f'd{d_thres}_s{s_thres}_mesh_pr'] = precision
                geometry_metrics[f'd{d_thres}_s{s_thres}_mesh_rc'] = recall
                geometry_metrics[f'd{d_thres}_s{s_thres}_mesh_f1'] = f1

            metric_logger.update(**geometry_metrics)

        loss_value = loss.item()

        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

        if args.debug and data_iter_step == 5:
            break

    metric_logger.synchronize_between_processes()
    print("Validation averaged stats:", metric_logger)
    print("Val epoch time:", time.time() - epoch_start_time)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def sample_uniform_semisphere(B, N, semisphere_size, device):
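    """Rejection-sample a (B, N, 3) batch of points uniform in the upper
    half-ball of radius `semisphere_size`, oversampling 3x per attempt and
    retrying up to 100 times."""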
    for _ in range(100):
        # Oversample 3x within the cube, fold to the upper half-space, then
        # keep only points inside the ball.
        points = torch.empty(B * N * 3, 3, device=device).uniform_(-semisphere_size, semisphere_size)
        points[..., 2] = points[..., 2].abs()
        dist = (points ** 2.0).sum(axis=-1) ** 0.5
        if (dist < semisphere_size).sum() >= B * N:
            return points[dist < semisphere_size][:B * N].reshape((B, N, 3))
        else:
            print('resampling sphere')
    # Fail loudly rather than silently returning None if sampling keeps falling short.
    raise RuntimeError('sample_uniform_semisphere: not enough in-ball points after 100 attempts')


def get_grid_semisphere(B, granularity, semisphere_size, device):
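    """Build a regular grid with spacing `granularity` covering the upper
    half-ball of radius `semisphere_size`, replicated B times."""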
    n_grid_pts = int(semisphere_size / granularity) * 2 + 1
    grid_unseen_xyz = torch.zeros((n_grid_pts, n_grid_pts, n_grid_pts // 2 + 1, 3), device=device)
    for i in range(n_grid_pts):
        grid_unseen_xyz[i, :, :, 0] = i
        grid_unseen_xyz[:, i, :, 1] = i
    for i in range(n_grid_pts // 2 + 1):
        grid_unseen_xyz[:, :, i, 2] = i
    grid_unseen_xyz[..., :2] -= (n_grid_pts // 2.0)
    grid_unseen_xyz *= granularity
    dist = (grid_unseen_xyz ** 2.0).sum(axis=-1) ** 0.5
    grid_unseen_xyz = grid_unseen_xyz[dist <= semisphere_size]
    return grid_unseen_xyz[None].repeat(B, 1, 1)


def get_min_dist(a, b, slice_size=1000):
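    """For each point in `a`, return the distance to and index of its nearest
    point in `b`, computed in slices of `slice_size` to bound memory."""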
    all_min, all_idx = [], []
    for i in range(int(np.ceil(a.shape[1] / slice_size))):
        start = slice_size * i
        end = slice_size * (i + 1)

        dist = ((a[:, start:end] - b) ** 2.0).sum(axis=-1) ** 0.5

        cur_min, cur_idx = dist.min(axis=2)
        all_min.append(cur_min)
        all_idx.append(cur_idx)
    return torch.cat(all_min, dim=1), torch.cat(all_idx, dim=1)


def construct_uniform_semisphere(gt_xyz, gt_rgb, semisphere_size, n_queries, dist_threshold, is_train, granularity):
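    """Sample query points for the Hypersim semisphere world: uniform random
    points at train time, a regular grid at eval time. Queries within
    `dist_threshold` of a GT point are labeled occupied and take the color of
    their nearest GT point."""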
    B = gt_xyz.shape[0]
    device = gt_xyz.device
    if is_train:
        unseen_xyz = sample_uniform_semisphere(B, n_queries, semisphere_size, device)
    else:
        unseen_xyz = get_grid_semisphere(B, granularity, semisphere_size, device)
    dist, idx_to_gt = get_min_dist(unseen_xyz[:, :, None], gt_xyz[:, None])
    labels = dist < dist_threshold
    unseen_rgb = torch.zeros_like(unseen_xyz)
    unseen_rgb[labels] = torch.gather(gt_rgb, 1, idx_to_gt.unsqueeze(-1).repeat(1, 1, 3))[labels]
    return unseen_xyz, unseen_rgb, labels.float()


def construct_uniform_grid(gt_xyz, gt_rgb, co3d_world_size, n_queries, dist_threshold, is_train, granularity):
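    """Same as `construct_uniform_semisphere`, but for the cubic CO3D world
    [-co3d_world_size, co3d_world_size]^3."""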
    B = gt_xyz.shape[0]
    device = gt_xyz.device
    if is_train:
        unseen_xyz = torch.empty((B, n_queries, 3), device=device).uniform_(-co3d_world_size, co3d_world_size)
    else:
        unseen_xyz = get_grid(B, device, co3d_world_size, granularity)
    dist, idx_to_gt = get_min_dist(unseen_xyz[:, :, None], gt_xyz[:, None])
    labels = dist < dist_threshold
    unseen_rgb = torch.zeros_like(unseen_xyz)
    unseen_rgb[labels] = torch.gather(gt_rgb, 1, idx_to_gt.unsqueeze(-1).repeat(1, 1, 3))[labels]
    return unseen_xyz, unseen_rgb, labels.float()


def prepare_data(samples, device, is_train, args, is_viz=False):
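    """Move a batch to `device` and build model inputs: mark non-finite seen
    points invalid (sentinel -100), sample unseen queries with occupancy
    labels and colors, and, at train time, apply XYZ augmentation plus a
    random horizontal flip."""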
    seen_xyz, seen_rgb = samples[0][0].to(device), samples[0][1].to(device)
    valid_seen_xyz = torch.isfinite(seen_xyz.sum(axis=-1))
    seen_xyz[~valid_seen_xyz] = -100
    B = seen_xyz.shape[0]

    gt_xyz, gt_rgb = samples[1][0].to(device).reshape(B, -1, 3), samples[1][1].to(device).reshape(B, -1, 3)

    sampling_func = construct_uniform_semisphere if args.use_hypersim else construct_uniform_grid
    unseen_xyz, unseen_rgb, labels = sampling_func(
        gt_xyz, gt_rgb,
        args.semisphere_size if args.use_hypersim else args.co3d_world_size,
        args.n_queries,
        args.train_dist_threshold,
        is_train,
        args.viz_granularity if is_viz else args.eval_granularity,
    )

    if is_train:
        seen_xyz, unseen_xyz = aug_xyz(seen_xyz, unseen_xyz, args, is_train=is_train)

        # Random horizontal flip: negate x and mirror the image-aligned tensors.
        if random.random() < 0.5:
            seen_xyz[..., 0] *= -1
            unseen_xyz[..., 0] *= -1
            seen_xyz = torch.flip(seen_xyz, [2])
            valid_seen_xyz = torch.flip(valid_seen_xyz, [2])
            seen_rgb = torch.flip(seen_rgb, [3])

    return seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_rgb