import math |
from typing import Iterable |
import os |
import matplotlib.pyplot as plt |
import random |
import torch |
import numpy as np |
import time |
import base64 |
from io import BytesIO |
import util.misc as misc |
import util.lr_sched as lr_sched |
from pytorch3d.structures import Pointclouds |
from pytorch3d.vis.plotly_vis import plot_scene |
from pytorch3d.transforms import RotateAxisAngle |
from pytorch3d.io import IO |
def _count_matched(query_xyz, ref_xyz, dist_thres, slice_size=1000):
    """Count points of ``query_xyz`` that have a point of ``ref_xyz`` closer than ``dist_thres``.

    ``query_xyz`` is processed ``slice_size`` points at a time so the full
    pairwise distance matrix is never materialized at once.
    """
    n_matched = 0.0
    for start in range(0, query_xyz.shape[0], slice_size):
        chunk = query_xyz[start:start + slice_size]
        dist = ((chunk[:, None] - ref_xyz[None]) ** 2.0).sum(axis=-1) ** 0.5
        # A query point counts once if any reference point is within range.
        n_matched += ((dist < dist_thres).sum(axis=1) > 0).sum()
    return n_matched


def evaluate_points(predicted_xyz, gt_xyz, dist_thres):
    """Score predicted points against ground-truth points at a distance threshold.

    Precision is the fraction of predicted points that have at least one
    ground-truth point closer than ``dist_thres``. Recall is the fraction of
    ground-truth points that have at least one predicted point closer than
    ``dist_thres``. Distances are Euclidean over the last axis.

    Args:
        predicted_xyz: (n_pred, D) predicted points (D == 3 at the call sites in
            this file).
        gt_xyz: (n_gt, D) ground-truth points.
        dist_thres: distance strictly below which two points count as matched.

    Returns:
        ``(precision, recall, f1)``. All three are ``0.0`` when either point set
        is empty. Otherwise precision/recall are scalars of the input's array
        library (0-d tensors for torch inputs).
    """
    if predicted_xyz.shape[0] == 0 or gt_xyz.shape[0] == 0:
        return 0.0, 0.0, 0.0
    precision = _count_matched(predicted_xyz, gt_xyz, dist_thres) / predicted_xyz.shape[0]
    # Slice over the ground-truth set (not the predictions) so that every
    # ground-truth point is checked regardless of how many points were predicted.
    recall = _count_matched(gt_xyz, predicted_xyz, dist_thres) / gt_xyz.shape[0]
    return precision, recall, get_f1(precision, recall)
def aug_xyz(seen_xyz, unseen_xyz, args, is_train):
    """Randomly scale, shift and rotate the seen and unseen points during training.

    With ``is_train`` false, no random numbers are drawn and the points keep
    their values.

    With ``is_train`` true:
      * a per-axis scale is drawn from
        ``[1 - args.random_scale_delta, 1 + args.random_scale_delta]``;
      * unless ``args.use_hypersim`` is set, three integer rotation angles in
        ``[-args.random_rotate_degree, args.random_rotate_degree]`` (about X, Y
        and Z) and a per-axis shift in ``[-args.random_shift, args.random_shift]``
        are drawn as well.

    Both point sets get the same ``xyz * scale + shift`` and the same rotation.
    ``seen_xyz`` is unpacked as ``(B, H, W, _)`` and flattened to ``(B, H*W, 3)``
    for the rotation.

    Returns:
        ``[seen_xyz, unseen_xyz]`` after augmentation.
    """
    angles = (0, 0, 0)
    if is_train:
        lo_scale = 1.0 - args.random_scale_delta
        hi_scale = 1.0 + args.random_scale_delta
        scale = torch.tensor(
            [random.uniform(lo_scale, hi_scale) for _ in range(3)],
            device=seen_xyz.device,
        )
        if not args.use_hypersim:
            max_deg = args.random_rotate_degree
            angles = tuple(random.randrange(-max_deg, max_deg + 1) for _ in range(3))
            max_shift = args.random_shift
            shift = torch.tensor(
                [[[random.uniform(-max_shift, max_shift) for _ in range(3)]]],
                device=seen_xyz.device,
            )
        else:
            shift = 0
        seen_xyz = seen_xyz * scale + shift
        unseen_xyz = unseen_xyz * scale + shift

    batch, height, width, _ = seen_xyz.shape
    seen_flat = seen_xyz.reshape((batch, -1, 3))
    return [
        rotate(seen_flat, *angles).reshape((batch, height, width, 3)),
        rotate(unseen_xyz, *angles),
    ]
def rotate(sample, degree_x, degree_y, degree_z):
    """Rotate the points in ``sample`` about X, then Y, then Z.

    Each non-zero angle is applied with pytorch3d's ``RotateAxisAngle``, which
    interprets the angle in degrees by default. The transform is moved to the
    device of the points being rotated. Axes whose angle is exactly zero are
    skipped, so an all-zero call returns ``sample`` itself untouched.
    """
    rotated = sample
    for axis_name, angle in zip("XYZ", (degree_x, degree_y, degree_z)):
        if angle == 0:
            continue
        transform = RotateAxisAngle(angle, axis=axis_name).to(rotated.device)
        rotated = transform.transform_points(rotated)
    return rotated
def get_grid(B, device, co3d_world_size, granularity):
    """Build a regular cubic grid of query points, repeated for each batch item.

    The grid has ``N = ceil(2 * co3d_world_size / granularity)`` samples per
    axis. The integer index ``n`` along an axis maps to
    ``(n - N / 2) / ((N / 2) / co3d_world_size)``. Coordinates therefore start
    at ``-co3d_world_size`` and are spaced ``2 * co3d_world_size / N`` apart.

    Returns:
        Tensor of shape ``(B, N**3, 3)`` on ``device``. Points are ordered with
        the X index varying slowest and the Z index fastest.
    """
    n = int(np.ceil(2 * co3d_world_size / granularity))
    idx = torch.arange(n, device=device, dtype=torch.get_default_dtype())
    # Broadcast the 1-D index along each axis so that cell [i, j, k] holds (i, j, k).
    grid = torch.stack(
        torch.broadcast_tensors(idx[:, None, None], idx[None, :, None], idx[None, None, :]),
        dim=-1,
    )
    grid -= (n / 2.0)
    grid /= (n / 2.0) / co3d_world_size
    return grid.reshape((1, -1, 3)).repeat(B, 1, 1)
def run_viz(model, data_loader, device, args, epoch):
    """Write HTML visualizations of model predictions for the first batches of ``data_loader``.

    At most ``args.max_n_viz_obj`` batches are processed. For each batch:
      * the query points from ``prepare_data(..., is_viz=True)`` are fed to the
        model in chunks of 2000 under ``torch.no_grad()``;
      * channel 0 of the prediction (occupancy logits) and the remaining
        channels (colours) are collected;
      * everything is passed, with the seen image and ground truth, to
        ``generate_html``, which writes into
        ``{args.job_dir}/viz/<split>_ep<epoch>_rank<rank>_i<sample_idx>.html``
        (opened in append mode).

    Colours are used directly when ``args.regress_color`` is set. Otherwise the
    remaining channels are reshaped to 256 bins per RGB channel, and the argmax
    bin is scaled to ``[0, 1]``.
    """
    epoch_start_time = time.time()
    model.eval()
    # Create the directory in-process rather than shelling out to ``mkdir``.
    # This avoids shell interpretation of job_dir, does not error when the
    # directory already exists, and creates missing parents.
    os.makedirs(f'{args.job_dir}/viz', exist_ok=True)
    print('Visualization data_loader length:', len(data_loader))
    dataset = data_loader.dataset
    max_n_queries_fwd = 2000
    for sample_idx, samples in enumerate(data_loader):
        if sample_idx >= args.max_n_viz_obj:
            break
        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(
            samples, device, is_train=False, args=args, is_viz=True)
        pred_occupy = []
        pred_colors = []
        # NOTE(review): presumably drops any encoder state cached for the
        # previous object (forward below passes cache_enc=args.run_viz) — TODO confirm.
        (model.module if hasattr(model, "module") else model).clear_cache()
        total_n_passes = int(np.ceil(unseen_xyz.shape[1] / max_n_queries_fwd))
        for p_idx in range(total_n_passes):
            p_start = p_idx * max_n_queries_fwd
            p_end = (p_idx + 1) * max_n_queries_fwd
            cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
            # zero_() acts in place on these slices, so the matching ranges of
            # unseen_rgb / labels are cleared before the model sees them
            # (presumably so visualization cannot use ground truth — TODO confirm).
            cur_unseen_rgb = unseen_rgb[:, p_start:p_end].zero_()
            cur_labels = labels[:, p_start:p_end].zero_()
            with torch.no_grad():
                _, pred = model(
                    seen_images=seen_images,
                    seen_xyz=seen_xyz,
                    unseen_xyz=cur_unseen_xyz,
                    unseen_rgb=cur_unseen_rgb,
                    unseen_occupy=cur_labels,
                    cache_enc=args.run_viz,
                    valid_seen_xyz=valid_seen_xyz,
                )
            cur_occupy_out = pred[..., 0]
            if args.regress_color:
                cur_color_out = pred[..., 1:].reshape((-1, 3))
            else:
                cur_color_out = pred[..., 1:].reshape((-1, 3, 256)).max(dim=2)[1] / 255.0
            pred_occupy.append(cur_occupy_out)
            pred_colors.append(cur_color_out)

        rank = misc.get_rank()
        prefix = f'{args.job_dir}/viz/' + dataset.dataset_split + f'_ep{epoch}_rank{rank}_i{sample_idx}'
        img = (seen_images[0].permute(1, 2, 0) * 255).cpu().numpy().copy().astype(np.uint8)
        # Ground truth is flattened across the batch; presumably the
        # visualization loader uses batch size 1 — TODO confirm.
        gt_xyz = samples[1][0].to(device).reshape(-1, 3)
        gt_rgb = samples[1][1].to(device).reshape(-1, 3)
        mesh_xyz = samples[2].to(device).reshape(-1, 3) if args.use_hypersim else None
        with open(prefix + '.html', 'a') as f:
            generate_html(
                img,
                seen_xyz, seen_images,
                torch.cat(pred_occupy, dim=1),
                torch.cat(pred_colors, dim=0),
                unseen_xyz,
                f,
                gt_xyz=gt_xyz,
                gt_rgb=gt_rgb,
                mesh_xyz=mesh_xyz,
            )
    print("Visualization epoch time:", time.time() - epoch_start_time)
def get_f1(precision, recall):
    """Return the F1 score, i.e. the harmonic mean of ``precision`` and ``recall``.

    When ``precision + recall`` is zero the result is ``0.0`` rather than a
    division by zero.
    """
    denominator = precision + recall
    return 0.0 if denominator == 0 else 2.0 * precision * recall / denominator
def generate_plot(img, seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz,
                  gt_xyz=None, gt_rgb=None, mesh_xyz=None, score_thresholds=(0.1, 0.3, 0.5, 0.7, 0.9),
                  pointcloud_marker_size=2,
                  ):
    """Build an interactive plotly figure of the seen, ground-truth and predicted clouds.

    Traces added under the "MCC Output" scene:
      * "seen": ``seen_xyz`` coloured by ``seen_rgb`` resized to 112x112. Points
        whose X coordinate equals -100 are skipped; ``prepare_data`` writes
        -100 into invalid seen points.
      * "GT points" / "GT mesh": random subsets of at most 10000 points of
        ``gt_xyz`` (coloured by ``gt_rgb``) and ``mesh_xyz``.
      * "pred_<t>" for each ``t`` in ``score_thresholds``: query points in
        ``unseen_xyz`` whose sigmoid occupancy exceeds ``t``, coloured by
        ``pred_rgb``. Thresholds that select no valid point are omitted.

    ``img`` is not used by this function.

    Side effect: the last non-empty "pred_<t>" cloud is written to
    ``output_pointcloud.ply`` in the current working directory.

    Returns:
        The plotly figure, or ``None`` if building it fails (the error is printed).
    """
    max_shown = 10000
    clouds = {"MCC Output": {}}
    if seen_xyz is not None:
        seen_xyz = seen_xyz.reshape((-1, 3)).cpu()
        seen_rgb = torch.nn.functional.interpolate(seen_rgb, (112, 112)).permute(0, 2, 3, 1).reshape((-1, 3)).cpu()
        good_seen = seen_xyz[:, 0] != -100
        clouds["MCC Output"]["seen"] = Pointclouds(
            points=seen_xyz[good_seen][None],
            features=seen_rgb[good_seen][None],
        )
    if gt_xyz is not None:
        # random.sample raises ValueError when asked for more items than exist,
        # so cap the subset at the number of available points.
        subset_gt = random.sample(range(gt_xyz.shape[0]), min(gt_xyz.shape[0], max_shown))
        clouds["MCC Output"]["GT points"] = Pointclouds(
            points=gt_xyz[subset_gt][None],
            features=gt_rgb[subset_gt][None],
        )
    if mesh_xyz is not None:
        subset_mesh = random.sample(range(mesh_xyz.shape[0]), min(mesh_xyz.shape[0], max_shown))
        clouds["MCC Output"]["GT mesh"] = Pointclouds(
            points=mesh_xyz[subset_mesh][None],
        )

    pred_occ = torch.nn.Sigmoid()(pred_occ).cpu()
    last_pred_pc = None
    for t in score_thresholds:
        pos = pred_occ > t
        points = unseen_xyz[pos].reshape((-1, 3))
        features = pred_rgb[None][pos].reshape((-1, 3))
        good_points = points[:, 0] != -100
        if good_points.sum() == 0:
            continue
        last_pred_pc = Pointclouds(
            points=points[good_points][None].cpu(),
            features=features[good_points][None].cpu(),
        )
        clouds["MCC Output"][f"pred_{t}"] = last_pred_pc
    # Every threshold would write the same path, so only the last non-empty
    # cloud needs to be saved.
    if last_pred_pc is not None:
        IO().save_pointcloud(last_pred_pc, "output_pointcloud.ply")

    # plot_scene produces a plotly figure. No matplotlib figure is opened here,
    # so there is nothing to leak on the success path.
    try:
        fig = plot_scene(clouds, pointcloud_marker_size=pointcloud_marker_size, pointcloud_max_points=20000 * 2)
        fig.update_layout(height=1000, width=1000)
        return fig
    except Exception as e:
        print('writing failed', e)
    return None
def generate_html(img, seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz, f,
                  gt_xyz=None, gt_rgb=None, mesh_xyz=None, score_thresholds=(0.1, 0.3, 0.5, 0.7, 0.9),
                  pointcloud_marker_size=2,
                  ):
    """Write an HTML visualization of the inputs and predictions to the file object ``f``.

    If ``img`` is given, it is rendered with matplotlib and written as an inline
    base64 JPEG ``<img>`` tag. An interactive plotly scene is then written as an
    HTML fragment, with these traces under "MCC Output":
      * "seen": ``seen_xyz`` coloured by ``seen_rgb`` resized to 112x112. Points
        whose X coordinate equals -100 are skipped; ``prepare_data`` writes
        -100 into invalid seen points.
      * "GT points" / "GT mesh": random subsets of at most 10000 points of
        ``gt_xyz`` (coloured by ``gt_rgb``) and ``mesh_xyz``.
      * "pred_<t>" for each ``t`` in ``score_thresholds``: query points in
        ``unseen_xyz`` whose sigmoid occupancy exceeds ``t``, coloured by
        ``pred_rgb``. Thresholds that select no valid point are omitted.

    Returns:
        ``(fig, plt)`` (the plotly figure and the ``matplotlib.pyplot`` module)
        on success, or ``None`` if building or writing the scene fails (the
        error is printed).
    """
    if img is not None:
        img_fig = plt.figure()
        try:
            plt.imshow(img)
            tmpfile = BytesIO()
            img_fig.savefig(tmpfile, format='jpg')
            encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8')
            # The buffer holds JPEG bytes, so the data URI must declare image/jpeg.
            html = '<img src=\'data:image/jpeg;base64,{}\'>'.format(encoded)
            f.write(html)
        finally:
            # Close exactly the figure opened above, even if rendering failed.
            plt.close(img_fig)

    max_shown = 10000
    clouds = {"MCC Output": {}}
    if seen_xyz is not None:
        seen_xyz = seen_xyz.reshape((-1, 3)).cpu()
        seen_rgb = torch.nn.functional.interpolate(seen_rgb, (112, 112)).permute(0, 2, 3, 1).reshape((-1, 3)).cpu()
        good_seen = seen_xyz[:, 0] != -100
        clouds["MCC Output"]["seen"] = Pointclouds(
            points=seen_xyz[good_seen][None],
            features=seen_rgb[good_seen][None],
        )
    if gt_xyz is not None:
        # random.sample raises ValueError when asked for more items than exist,
        # so cap the subset at the number of available points.
        subset_gt = random.sample(range(gt_xyz.shape[0]), min(gt_xyz.shape[0], max_shown))
        clouds["MCC Output"]["GT points"] = Pointclouds(
            points=gt_xyz[subset_gt][None],
            features=gt_rgb[subset_gt][None],
        )
    if mesh_xyz is not None:
        subset_mesh = random.sample(range(mesh_xyz.shape[0]), min(mesh_xyz.shape[0], max_shown))
        clouds["MCC Output"]["GT mesh"] = Pointclouds(
            points=mesh_xyz[subset_mesh][None],
        )

    pred_occ = torch.nn.Sigmoid()(pred_occ).cpu()
    for t in score_thresholds:
        pos = pred_occ > t
        points = unseen_xyz[pos].reshape((-1, 3))
        features = pred_rgb[None][pos].reshape((-1, 3))
        good_points = points[:, 0] != -100
        if good_points.sum() == 0:
            continue
        clouds["MCC Output"][f"pred_{t}"] = Pointclouds(
            points=points[good_points][None].cpu(),
            features=features[good_points][None].cpu(),
        )

    # plot_scene produces a plotly figure. No matplotlib figure is opened here,
    # so there is nothing to leak on the success path.
    try:
        fig = plot_scene(clouds, pointcloud_marker_size=pointcloud_marker_size, pointcloud_max_points=20000 * 2)
        fig.update_layout(height=1000, width=1000)
        # "cdn" references plotly.js from the CDN. Any other truthy string
        # (e.g. a misspelling) makes plotly inline the entire plotly.js bundle
        # into every fragment.
        html_string = fig.to_html(full_html=False, include_plotlyjs="cdn")
        f.write(html_string)
        return fig, plt
    except Exception as e:
        print('writing failed', e)
    return None
def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    args=None):
    """Train ``model`` for one pass over ``data_loader`` and return averaged stats.

    Each step:
      * ``prepare_data(..., is_train=True)`` builds the inputs;
      * the forward pass runs under ``torch.cuda.amp.autocast()``;
      * the loss, divided by ``args.accum_iter``, is passed to ``loss_scaler``
        with ``update_grad`` true only on every ``args.accum_iter``-th step
        (gradient accumulation).

    The learning rate is set through ``lr_sched.adjust_learning_rate`` at the
    start of each accumulation window, using the fractional epoch
    ``epoch + step / len(data_loader)``.

    Args:
        model: called with keyword arguments ``seen_images``, ``seen_xyz``,
            ``unseen_xyz``, ``unseen_rgb``, ``unseen_occupy`` and
            ``valid_seen_xyz``. Returns a 2-tuple whose first element is the loss.
        data_loader: batches accepted by ``prepare_data``.
        optimizer: gradients are zeroed after each update step. The first
            param group's ``lr`` is logged.
        device: device the batch is moved to.
        epoch: current epoch index, used by the LR schedule.
        loss_scaler: called as ``loss_scaler(loss, optimizer, parameters=...,
            clip_grad=..., update_grad=..., verbose=...)``. Presumably it runs
            backward and, when ``update_grad`` is true, the optimizer step
            — TODO confirm against ``util.misc``.
        args: namespace. This function reads ``accum_iter``, ``clip_grad`` and
            ``debug``; ``prepare_data`` and ``lr_sched`` read more.

    Returns:
        Dict mapping each ``metric_logger`` meter name to its global average.
    """
    epoch_start_time = time.time()
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    accum_iter = args.accum_iter
    optimizer.zero_grad()
    print('Training data_loader length:', len(data_loader))
    for data_iter_step, samples in enumerate(data_loader):
        # Adjust the LR once per accumulation window, using a fractional epoch.
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(samples, device, is_train=True, args=args)
        with torch.cuda.amp.autocast():
            loss, _ = model(
                seen_images=seen_images,
                seen_xyz=seen_xyz,
                unseen_xyz=unseen_xyz,
                unseen_rgb=unseen_rgb,
                unseen_occupy=labels,
                valid_seen_xyz=valid_seen_xyz,
            )
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Warning: Loss is {}".format(loss_value))
            # NOTE(review): under IEEE 754, NaN * 0.0 and inf * 0.0 are both NaN,
            # so this does not turn a non-finite loss into zero. Any protection
            # against the bad step presumably comes from loss_scaler skipping
            # updates with non-finite gradients — TODO confirm.
            loss *= 0.0
            # Log a finite placeholder so the running loss average stays finite.
            loss_value = 100.0
        # Divide by accum_iter so an update accumulated over the window
        # presumably corresponds to the mean loss of its steps.
        loss /= accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    clip_grad=args.clip_grad,
                    update_grad=(data_iter_step + 1) % accum_iter == 0,
                    verbose=(data_iter_step % 100) == 0)
        # Start a fresh accumulation window after each update step.
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()
        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)
        # One-off GPU (nvidia-smi) and host (free -g) memory report at step 30.
        if data_iter_step == 30:
            os.system('nvidia-smi')
            os.system('free -g')
        # In debug mode, stop after a handful of steps.
        if args.debug and data_iter_step == 5:
            break
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    print("Training epoch time:", time.time() - epoch_start_time)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def eval_one_epoch(
        model: torch.nn.Module,
        data_loader: Iterable,
        device: torch.device,
        args=None
):
    """Evaluate ``model`` on ``data_loader`` and return averaged metrics.

    Query points are forwarded in chunks of 5000 under ``torch.no_grad()``. The
    per-batch loss is the unweighted mean of the chunk losses; the final chunk
    may be smaller than the others.

    For every batch element, queries whose sigmoid occupancy exceeds
    ``args.eval_score_threshold`` are scored with ``evaluate_points`` at
    ``args.eval_dist_threshold``. They are scored against the ground-truth
    points and, when ``args.use_hypersim`` is set, also against the mesh points
    in ``samples[2]``.

    Returns:
        Dict mapping each ``metric_logger`` meter name to its global average.
    """
    start_time = time.time()
    model.train(False)
    metric_logger = misc.MetricLogger(delimiter=" ")
    print('Eval len(data_loader):', len(data_loader))
    for step, samples in enumerate(data_loader):
        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(
            samples, device, is_train=False, args=args)

        # Forward the queries in fixed-size chunks to bound memory use.
        chunk_size = 5000
        chunk_losses, chunk_preds = [], []
        for lo in range(0, unseen_xyz.shape[1], chunk_size):
            hi = lo + chunk_size
            query_xyz = unseen_xyz[:, lo:hi]
            query_rgb = unseen_rgb[:, lo:hi]
            query_labels = labels[:, lo:hi]
            with torch.no_grad():
                chunk_loss, chunk_pred = model(
                    seen_images=seen_images,
                    seen_xyz=seen_xyz,
                    unseen_xyz=query_xyz,
                    unseen_rgb=query_rgb,
                    unseen_occupy=query_labels,
                    valid_seen_xyz=valid_seen_xyz,
                )
            chunk_losses.append(chunk_loss)
            chunk_preds.append(chunk_pred)
        loss = sum(chunk_losses) / len(chunk_losses)
        pred = torch.cat(chunk_preds, dim=1)

        batch_size = pred.shape[0]
        gt_xyz = samples[1][0].to(device).reshape((batch_size, -1, 3))
        references = [('point', gt_xyz)]
        if args.use_hypersim:
            mesh_xyz = samples[2].to(device).reshape((batch_size, -1, 3))
            references.append(('mesh', mesh_xyz))
        s_thres = args.eval_score_threshold
        d_thres = args.eval_dist_threshold
        key_prefix = f'd{d_thres}_s{s_thres}'
        for b in range(batch_size):
            occupied = torch.sigmoid(pred[b, :, 0]) > s_thres
            predicted_xyz = unseen_xyz[b, occupied]
            geometry_metrics = {}
            for ref_name, ref_xyz in references:
                pr, rc, f1 = evaluate_points(predicted_xyz, ref_xyz[b], d_thres)
                geometry_metrics[f'{key_prefix}_{ref_name}_pr'] = pr
                geometry_metrics[f'{key_prefix}_{ref_name}_rc'] = rc
                geometry_metrics[f'{key_prefix}_{ref_name}_f1'] = f1
            metric_logger.update(**geometry_metrics)

        loss_value = loss.item()
        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)
        if args.debug and step == 5:
            break

    metric_logger.synchronize_between_processes()
    print("Validation averaged stats:", metric_logger)
    print("Val epoch time:", time.time() - start_time)
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}
def sample_uniform_semisphere(B, N, semisphere_size, device, max_attempts=100):
    """Sample ``B * N`` points uniformly inside the upper half-ball of radius ``semisphere_size``.

    Uses rejection sampling. Each attempt:
      * draws ``3 * B * N`` points uniformly in ``[-r, r]^3``;
      * folds them onto ``z >= 0`` with ``abs``;
      * keeps those strictly closer than ``r`` to the origin.

    The half ball fills pi/6 (about 52%) of the half cube, so a single attempt
    normally yields enough points. "resampling sphere" is printed for each
    failed attempt.

    Args:
        B, N: batch size and number of points per batch item.
        semisphere_size: radius ``r`` of the half ball.
        device: device for the returned tensor.
        max_attempts: number of attempts before giving up.

    Returns:
        Tensor of shape ``(B, N, 3)``.

    Raises:
        RuntimeError: if no attempt kept at least ``B * N`` points.
    """
    n_needed = B * N
    for _ in range(max_attempts):
        points = torch.empty(n_needed * 3, 3, device=device).uniform_(-semisphere_size, semisphere_size)
        points[..., 2] = points[..., 2].abs()
        inside = (points ** 2.0).sum(dim=-1) ** 0.5 < semisphere_size
        if inside.sum() >= n_needed:
            return points[inside][:n_needed].reshape((B, N, 3))
        print('resampling sphere')
    raise RuntimeError(
        f'sample_uniform_semisphere: could not draw {n_needed} points inside radius '
        f'{semisphere_size} after {max_attempts} attempts'
    )
def get_grid_semisphere(B, granularity, semisphere_size, device):
    """Regular grid of points inside the upper half-ball of radius ``semisphere_size``.

    The grid spacing is ``granularity`` on all axes:
      * X and Y each take ``n = 2 * int(semisphere_size / granularity) + 1``
        values, centred on zero;
      * Z takes ``n // 2 + 1`` values starting at zero.

    Points farther than ``semisphere_size`` from the origin are dropped.

    Returns:
        Tensor of shape ``(B, M, 3)``, where M is the number of kept points.
        Points are ordered with X varying slowest and Z fastest, and are
        identical for every batch item.
    """
    n_xy = int(semisphere_size / granularity) * 2 + 1
    n_z = n_xy // 2 + 1
    dtype = torch.get_default_dtype()
    xy_idx = torch.arange(n_xy, device=device, dtype=dtype)
    z_idx = torch.arange(n_z, device=device, dtype=dtype)
    grid = torch.stack(
        torch.broadcast_tensors(xy_idx[:, None, None], xy_idx[None, :, None], z_idx[None, None, :]),
        dim=-1,
    )
    # Centre X and Y on zero; Z already starts at zero.
    grid[..., :2] -= (n_xy // 2.0)
    grid *= granularity
    radius = (grid ** 2.0).sum(dim=-1) ** 0.5
    kept = grid[radius <= semisphere_size]
    return kept[None].repeat(B, 1, 1)
def get_min_dist(a, b, slice_size=1000):
    """Nearest-neighbour distance from each point of ``a`` to the points of ``b``.

    ``a`` is processed along dim 1 in chunks of ``slice_size`` and broadcast
    against ``b``. At the call sites in this file, ``a`` is ``(B, Na, 1, 3)`` and
    ``b`` is ``(B, 1, Nb, 3)``. Each chunk then yields a ``(B, chunk, Nb)``
    block of Euclidean distances, which is reduced over dim 2.

    Returns:
        ``(min_dist, argmin)``, each of shape ``(B, Na)`` for those inputs: the
        distance to the closest point of ``b`` and that point's index.
    """
    mins, argmins = [], []
    for lo in range(0, a.shape[1], slice_size):
        chunk = a[:, lo:lo + slice_size]
        pairwise = ((chunk - b) ** 2.0).sum(dim=-1) ** 0.5
        chunk_min, chunk_idx = pairwise.min(dim=2)
        mins.append(chunk_min)
        argmins.append(chunk_idx)
    return torch.cat(mins, dim=1), torch.cat(argmins, dim=1)
def _label_queries(unseen_xyz, gt_xyz, gt_rgb, dist_threshold):
    """Label query points by proximity to the ground truth and copy the nearest GT colour.

    A query is positive when its nearest ground-truth point is closer than
    ``dist_threshold``. Positive queries take that point's colour; negative
    queries get zeros.

    Returns:
        ``(unseen_rgb, labels)``, where ``labels`` is a float tensor of 0/1.
    """
    dist, idx_to_gt = get_min_dist(unseen_xyz[:, :, None], gt_xyz[:, None])
    labels = dist < dist_threshold
    # gather accepts a broadcast (expanded) index, so no index copy is needed.
    nearest_rgb = torch.gather(gt_rgb, 1, idx_to_gt.unsqueeze(-1).expand(-1, -1, 3))
    unseen_rgb = torch.zeros_like(unseen_xyz)
    unseen_rgb[labels] = nearest_rgb[labels]
    return unseen_rgb, labels.float()


def construct_uniform_semisphere(gt_xyz, gt_rgb, semisphere_size, n_queries, dist_threshold, is_train, granularity):
    """Build labelled query points in the upper half-ball of radius ``semisphere_size``.

    When training, ``n_queries`` random points per batch item are drawn with
    ``sample_uniform_semisphere``. Otherwise the fixed ``get_grid_semisphere``
    grid with spacing ``granularity`` is used.

    Returns:
        ``(unseen_xyz, unseen_rgb, labels)``.
    """
    B = gt_xyz.shape[0]
    device = gt_xyz.device
    if is_train:
        unseen_xyz = sample_uniform_semisphere(B, n_queries, semisphere_size, device)
    else:
        unseen_xyz = get_grid_semisphere(B, granularity, semisphere_size, device)
    unseen_rgb, labels = _label_queries(unseen_xyz, gt_xyz, gt_rgb, dist_threshold)
    return unseen_xyz, unseen_rgb, labels


def construct_uniform_grid(gt_xyz, gt_rgb, co3d_world_size, n_queries, dist_threshold, is_train, granularity):
    """Build labelled query points in the cube ``[-co3d_world_size, co3d_world_size]^3``.

    When training, ``n_queries`` points per batch item are drawn uniformly in
    the cube. Otherwise the fixed ``get_grid`` grid with spacing
    ``granularity`` is used.

    Returns:
        ``(unseen_xyz, unseen_rgb, labels)``.
    """
    B = gt_xyz.shape[0]
    device = gt_xyz.device
    if is_train:
        unseen_xyz = torch.empty((B, n_queries, 3), device=device).uniform_(-co3d_world_size, co3d_world_size)
    else:
        unseen_xyz = get_grid(B, device, co3d_world_size, granularity)
    unseen_rgb, labels = _label_queries(unseen_xyz, gt_xyz, gt_rgb, dist_threshold)
    return unseen_xyz, unseen_rgb, labels
def prepare_data(samples, device, is_train, args, is_viz=False):
    """Move a batch to ``device`` and build the model inputs and labelled query points.

    ``samples[0]`` holds ``(seen_xyz, seen_rgb)`` and ``samples[1]`` holds
    ``(gt_xyz, gt_rgb)``.

    Seen points with a non-finite coordinate are marked false in
    ``valid_seen_xyz`` and overwritten in place with ``-100``. Note that
    ``.to(device)`` returns the same tensor when it is already on ``device``,
    so this can modify the caller's batch.

    Query points come from ``construct_uniform_semisphere`` when
    ``args.use_hypersim`` is set, otherwise from ``construct_uniform_grid``.
    The grid granularity is ``args.viz_granularity`` for visualization and
    ``args.eval_granularity`` otherwise.

    During training:
      * the points are augmented with ``aug_xyz``;
      * with probability 0.5 they are mirrored: X is negated for seen and query
        points, and the seen points, validity mask and image are flipped along
        their width axis.

    Returns:
        ``(seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_rgb)``.
    """
    seen_xyz = samples[0][0].to(device)
    seen_rgb = samples[0][1].to(device)
    valid_seen_xyz = torch.isfinite(seen_xyz.sum(dim=-1))
    seen_xyz[~valid_seen_xyz] = -100
    batch_size = seen_xyz.shape[0]
    gt_xyz = samples[1][0].to(device).reshape(batch_size, -1, 3)
    gt_rgb = samples[1][1].to(device).reshape(batch_size, -1, 3)

    if args.use_hypersim:
        sampler, extent = construct_uniform_semisphere, args.semisphere_size
    else:
        sampler, extent = construct_uniform_grid, args.co3d_world_size
    granularity = args.viz_granularity if is_viz else args.eval_granularity
    unseen_xyz, unseen_rgb, labels = sampler(
        gt_xyz, gt_rgb,
        extent,
        args.n_queries,
        args.train_dist_threshold,
        is_train,
        granularity,
    )

    if not is_train:
        return seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_rgb

    seen_xyz, unseen_xyz = aug_xyz(seen_xyz, unseen_xyz, args, is_train=is_train)
    if random.random() < 0.5:
        # Horizontal mirror: negate X and flip the width axis of the seen data.
        seen_xyz[..., 0] *= -1
        unseen_xyz[..., 0] *= -1
        seen_xyz = torch.flip(seen_xyz, [2])
        valid_seen_xyz = torch.flip(valid_seen_xyz, [2])
        seen_rgb = torch.flip(seen_rgb, [3])
    return seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_rgb