MCC_slim / engine_mcc.py
hugoycj
Initial commit
cacb27a
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# MAE: https://github.com/facebookresearch/mae
# --------------------------------------------------------
import math
from typing import Iterable
import os
import matplotlib.pyplot as plt
import random
import torch
import numpy as np
import time
import base64
from io import BytesIO
import util.misc as misc
import util.lr_sched as lr_sched
from pytorch3d.structures import Pointclouds
from pytorch3d.vis.plotly_vis import plot_scene
from pytorch3d.transforms import RotateAxisAngle
from pytorch3d.io import IO
def evaluate_points(predicted_xyz, gt_xyz, dist_thres):
    """Compute point-cloud precision, recall and F1 at a distance threshold.

    Precision is the fraction of predicted points with at least one
    ground-truth point closer than ``dist_thres``. Recall is the fraction of
    ground-truth points with at least one predicted point closer than
    ``dist_thres``. Pairwise distances are computed in slices of 1000 points
    to bound peak memory.

    Args:
        predicted_xyz: (P, 3) predicted points (tensor or array supporting
            broadcasting and ``.sum(axis=...)``).
        gt_xyz: (G, 3) ground-truth points.
        dist_thres: Euclidean distance below which two points match.

    Returns:
        ``(precision, recall, f1)``. Returns ``(0.0, 0.0, 0.0)`` when either
        cloud is empty.
    """
    if predicted_xyz.shape[0] == 0 or gt_xyz.shape[0] == 0:
        return 0.0, 0.0, 0.0
    slice_size = 1000

    precision = 0.0
    for start in range(0, predicted_xyz.shape[0], slice_size):
        end = start + slice_size
        dist = ((predicted_xyz[start:end, None] - gt_xyz[None]) ** 2.0).sum(axis=-1) ** 0.5
        precision += ((dist < dist_thres).sum(axis=1) > 0).sum()
    precision /= predicted_xyz.shape[0]

    recall = 0.0
    # Slice over the GT points: every GT point must be visited, so the number of
    # slices depends on G, not on the number of predicted points.
    for start in range(0, gt_xyz.shape[0], slice_size):
        end = start + slice_size
        dist = ((predicted_xyz[:, None] - gt_xyz[None, start:end]) ** 2.0).sum(axis=-1) ** 0.5
        recall += ((dist < dist_thres).sum(axis=0) > 0).sum()
    recall /= gt_xyz.shape[0]
    return precision, recall, get_f1(precision, recall)
def aug_xyz(seen_xyz, unseen_xyz, args, is_train):
    """Randomly scale, shift and rotate the seen and unseen point sets together.

    Augmentation happens only when ``is_train`` is true:

    - Per-axis scale factors are drawn from
      ``[1 - args.random_scale_delta, 1 + args.random_scale_delta]``.
    - When ``args.use_hypersim`` is false, integer rotation angles in
      ``[-args.random_rotate_degree, args.random_rotate_degree]`` and a shift
      in ``[-args.random_shift, args.random_shift]`` are drawn per axis as
      well. Hypersim uses no shift and no rotation.

    The same transform is applied to both point sets.

    Returns:
        ``[seen_xyz, unseen_xyz]``. ``seen_xyz`` keeps its (B, H, W, 3) layout.
    """
    device = seen_xyz.device
    degrees = (0, 0, 0)
    if is_train:
        delta = args.random_scale_delta
        scale = torch.tensor(
            [random.uniform(1.0 - delta, 1.0 + delta) for _ in range(3)],
            device=device,
        )
        if not args.use_hypersim:
            max_deg = args.random_rotate_degree
            degrees = tuple(random.randrange(-max_deg, max_deg + 1) for _ in range(3))
            max_shift = args.random_shift
            shift = torch.tensor(
                [[[random.uniform(-max_shift, max_shift) for _ in range(3)]]],
                device=device,
            )
        else:
            shift = 0
        seen_xyz = seen_xyz * scale + shift
        unseen_xyz = unseen_xyz * scale + shift

    batch, height, width, _ = seen_xyz.shape
    # rotate() works on (B, N, 3), so flatten the image grid, then restore it.
    seen_rotated = rotate(seen_xyz.reshape((batch, -1, 3)), *degrees).reshape((batch, height, width, 3))
    unseen_rotated = rotate(unseen_xyz, *degrees)
    return [seen_rotated, unseen_rotated]
def rotate(sample, degree_x, degree_y, degree_z):
    """Rotate points about the X, then Y, then Z axis using pytorch3d.

    Zero angles are skipped. When all three angles are zero, ``sample`` itself
    is returned unchanged.
    """
    rotations = ((degree_x, "X"), (degree_y, "Y"), (degree_z, "Z"))
    for angle, axis_name in rotations:
        if angle == 0:
            continue
        transform = RotateAxisAngle(angle, axis=axis_name).to(sample.device)
        sample = transform.transform_points(sample)
    return sample
def get_grid(B, device, co3d_world_size, granularity):
    """Return a regular cubic lattice of query points, repeated per batch element.

    The lattice has ``N = ceil(2 * co3d_world_size / granularity)`` points per
    axis. Lattice index ``i`` maps to ``(i - N/2) / ((N/2) / co3d_world_size)``,
    which covers ``[-co3d_world_size, co3d_world_size)``. Points are ordered
    with the z index varying fastest.

    Returns:
        Tensor of shape (B, N**3, 3).
    """
    n_cells = int(np.ceil(2 * co3d_world_size / granularity))
    grid = torch.zeros((n_cells, n_cells, n_cells, 3), device=device)
    axis_idx = torch.arange(n_cells, device=device, dtype=grid.dtype)
    # Broadcast the per-axis index into its coordinate channel.
    grid[..., 0] = axis_idx[:, None, None]
    grid[..., 1] = axis_idx[None, :, None]
    grid[..., 2] = axis_idx[None, None, :]
    half = n_cells / 2.0
    grid -= half
    grid /= half / co3d_world_size
    return grid.reshape((1, -1, 3)).repeat(B, 1, 1)
def run_viz(model, data_loader, device, args, epoch):
    """Write per-sample HTML visualizations of the model's predictions.

    For each of the first ``args.max_n_viz_obj`` batches:

    - The model is queried on the visualization query set from
      ``prepare_data(..., is_viz=True)``, in chunks of 2000 points.
    - An HTML report is appended to
      ``{args.job_dir}/viz/{split}_ep{epoch}_rank{rank}_i{sample_idx}.html``.

    Args:
        model: Model, optionally wrapped so the real model is at ``.module``.
            It exposes ``clear_cache()``, and its forward returns ``(loss, pred)``.
        data_loader: Loader whose ``dataset`` has a ``dataset_split`` attribute.
        device: Device that batches are moved to.
        args: Run configuration (``job_dir``, ``max_n_viz_obj``,
            ``regress_color``, ``run_viz``, ``use_hypersim``, ...).
        epoch: Epoch index used in the output file names.
    """
    epoch_start_time = time.time()
    model.eval()
    # Create the output directory without spawning a shell. exist_ok keeps this
    # quiet on later epochs, when the directory already exists.
    os.makedirs(os.path.join(args.job_dir, 'viz'), exist_ok=True)
    print('Visualization data_loader length:', len(data_loader))
    dataset = data_loader.dataset
    core_model = model.module if hasattr(model, "module") else model
    for sample_idx, samples in enumerate(data_loader):
        if sample_idx >= args.max_n_viz_obj:
            break
        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(
            samples, device, is_train=False, args=args, is_viz=True)
        pred_occupy = []
        pred_colors = []
        core_model.clear_cache()
        # don't forward all at once to avoid oom
        max_n_queries_fwd = 2000
        total_n_passes = int(np.ceil(unseen_xyz.shape[1] / max_n_queries_fwd))
        for p_idx in range(total_n_passes):
            p_start = p_idx * max_n_queries_fwd
            p_end = (p_idx + 1) * max_n_queries_fwd
            cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
            # Color and occupancy targets are zeroed (in place) before the model
            # sees them. Presumably this keeps ground truth out of the
            # visualization forward pass; confirm against the model.
            cur_unseen_rgb = unseen_rgb[:, p_start:p_end].zero_()
            cur_labels = labels[:, p_start:p_end].zero_()
            with torch.no_grad():
                _, pred = model(
                    seen_images=seen_images,
                    seen_xyz=seen_xyz,
                    unseen_xyz=cur_unseen_xyz,
                    unseen_rgb=cur_unseen_rgb,
                    unseen_occupy=cur_labels,
                    cache_enc=args.run_viz,
                    valid_seen_xyz=valid_seen_xyz,
                )
            # Channel 0 holds the occupancy logit; the remaining channels hold color.
            cur_occupy_out = pred[..., 0]
            if args.regress_color:
                cur_color_out = pred[..., 1:].reshape((-1, 3))
            else:
                # Color as 256 bins per channel: take the argmax bin and map it to [0, 1].
                cur_color_out = pred[..., 1:].reshape((-1, 3, 256)).max(dim=2)[1] / 255.0
            pred_occupy.append(cur_occupy_out)
            pred_colors.append(cur_color_out)

        rank = misc.get_rank()
        prefix = f'{args.job_dir}/viz/' + dataset.dataset_split + f'_ep{epoch}_rank{rank}_i{sample_idx}'
        img = (seen_images[0].permute(1, 2, 0) * 255).cpu().numpy().copy().astype(np.uint8)
        gt_xyz = samples[1][0].to(device).reshape(-1, 3)
        gt_rgb = samples[1][1].to(device).reshape(-1, 3)
        mesh_xyz = samples[2].to(device).reshape(-1, 3) if args.use_hypersim else None
        with open(prefix + '.html', 'a', encoding='utf-8') as f:
            generate_html(
                img,
                seen_xyz, seen_images,
                torch.cat(pred_occupy, dim=1),
                torch.cat(pred_colors, dim=0),
                unseen_xyz,
                f,
                gt_xyz=gt_xyz,
                gt_rgb=gt_rgb,
                mesh_xyz=mesh_xyz,
            )
    print("Visualization epoch time:", time.time() - epoch_start_time)
def get_f1(precision, recall):
    """Return the harmonic mean of precision and recall.

    Returns 0.0 when both are zero, to avoid dividing by zero.
    """
    denom = precision + recall
    if denom == 0:
        return 0.0
    return 2.0 * precision * recall / denom
# Maximum number of GT / mesh points drawn in the visualizations.
_VIZ_MAX_GT_POINTS = 10000


def _random_subset(n_points, max_points=_VIZ_MAX_GT_POINTS):
    """Return up to ``max_points`` distinct random indices into ``range(n_points)``."""
    # Clamp so clouds smaller than max_points don't make random.sample raise.
    return random.sample(range(n_points), min(max_points, n_points))


def _build_clouds(seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz,
                  gt_xyz, gt_rgb, mesh_xyz, score_thresholds):
    """Assemble the ``{"MCC Output": {name: Pointclouds}}`` dict for ``plot_scene``.

    Points whose x coordinate equals -100 (the invalid-point sentinel) are
    dropped. One prediction cloud ``pred_{t}`` is added per score threshold
    ``t`` that keeps at least one point.

    Returns:
        ``(clouds, last_pred_pc)``. ``last_pred_pc`` is the prediction cloud for
        the last threshold in ``score_thresholds`` that kept any point, or
        ``None`` if no threshold did.
    """
    clouds = {"MCC Output": {}}
    # Seen
    if seen_xyz is not None:
        seen_xyz = seen_xyz.reshape((-1, 3)).cpu()
        # Resized to 112x112 so there is one color per seen point; assumes the
        # seen point map is 112x112. TODO confirm.
        seen_rgb = torch.nn.functional.interpolate(seen_rgb, (112, 112)).permute(0, 2, 3, 1).reshape((-1, 3)).cpu()
        good_seen = seen_xyz[:, 0] != -100
        clouds["MCC Output"]["seen"] = Pointclouds(
            points=seen_xyz[good_seen][None],
            features=seen_rgb[good_seen][None],
        )
    # GT points
    if gt_xyz is not None:
        subset_gt = _random_subset(gt_xyz.shape[0])
        clouds["MCC Output"]["GT points"] = Pointclouds(
            points=gt_xyz[subset_gt][None],
            features=gt_rgb[subset_gt][None],
        )
    # GT meshes
    if mesh_xyz is not None:
        subset_mesh = _random_subset(mesh_xyz.shape[0])
        clouds["MCC Output"]["GT mesh"] = Pointclouds(
            points=mesh_xyz[subset_mesh][None],
        )

    pred_occ = torch.nn.Sigmoid()(pred_occ).cpu()
    last_pred_pc = None
    for t in score_thresholds:
        pos = pred_occ > t
        points = unseen_xyz[pos].reshape((-1, 3))
        features = pred_rgb[None][pos].reshape((-1, 3))
        good_points = points[:, 0] != -100
        if good_points.sum() == 0:
            continue
        last_pred_pc = Pointclouds(
            points=points[good_points][None].cpu(),
            features=features[good_points][None].cpu(),
        )
        clouds["MCC Output"][f"pred_{t}"] = last_pred_pc
    return clouds, last_pred_pc


def _plot_clouds(clouds, pointcloud_marker_size):
    """Render ``clouds`` with pytorch3d's plotly backend as a 1000x1000 figure."""
    fig = plot_scene(clouds, pointcloud_marker_size=pointcloud_marker_size, pointcloud_max_points=20000 * 2)
    fig.update_layout(height=1000, width=1000)
    return fig


def generate_plot(img, seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz,
                  gt_xyz=None, gt_rgb=None, mesh_xyz=None, score_thresholds=(0.1, 0.3, 0.5, 0.7, 0.9),
                  pointcloud_marker_size=2,
                  ):
    """Build an interactive plotly figure of the inputs, ground truth and predictions.

    ``img`` is accepted so the signature matches ``generate_html``, but it is
    not rendered here.

    Side effect: the prediction cloud for the last threshold in
    ``score_thresholds`` that kept any point is written to
    ``output_pointcloud.ply`` in the current working directory.

    Returns:
        The plotly figure, or ``None`` if plotting failed (the error is printed).
    """
    clouds, last_pred_pc = _build_clouds(
        seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz,
        gt_xyz, gt_rgb, mesh_xyz, score_thresholds,
    )
    # Written once; the old version rewrote this file for every threshold,
    # but only the last cloud written ever survived.
    if last_pred_pc is not None:
        IO().save_pointcloud(last_pred_pc, "output_pointcloud.ply")
    try:
        return _plot_clouds(clouds, pointcloud_marker_size)
    except Exception as e:
        print('writing failed', e)
        return None


def generate_html(img, seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz, f,
                  gt_xyz=None, gt_rgb=None, mesh_xyz=None, score_thresholds=(0.1, 0.3, 0.5, 0.7, 0.9),
                  pointcloud_marker_size=2,
                  ):
    """Append an HTML report (input image plus interactive point clouds) to ``f``.

    Args:
        img: Optional HxWxC image, embedded as a base64 JPEG ``<img>`` tag.
        f: Writable text file that receives the HTML fragments.
        Other arguments: see ``generate_plot``.

    Returns:
        ``(fig, plt)``: the plotly figure and the ``matplotlib.pyplot`` module,
        kept for backward compatibility. Returns ``None`` if plotting or writing
        failed; the error is printed.
    """
    if img is not None:
        img_fig = plt.figure()
        try:
            plt.imshow(img)
            tmpfile = BytesIO()
            img_fig.savefig(tmpfile, format='jpg')
        finally:
            # Always release the figure, even if rendering fails.
            plt.close(img_fig)
        encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8')
        # The image is encoded as JPEG, so declare the matching MIME type.
        f.write('<img src=\'data:image/jpeg;base64,{}\'>'.format(encoded))

    clouds, _ = _build_clouds(
        seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz,
        gt_xyz, gt_rgb, mesh_xyz, score_thresholds,
    )
    try:
        fig = _plot_clouds(clouds, pointcloud_marker_size)
        # "cdn" makes the report reference plotly.js from the CDN. The previous
        # value "cnd" was a typo that plotly treats as truthy, which inlined the
        # full plotly.js library into every report.
        html_string = fig.to_html(full_html=False, include_plotlyjs="cdn")
        f.write(html_string)
        return fig, plt
    except Exception as e:
        print('writing failed', e)
        return None
def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    args=None):
    """Run one training epoch with gradient accumulation under CUDA autocast.

    Every ``args.accum_iter`` steps, the learning rate is set by
    ``lr_sched.adjust_learning_rate`` from a fractional epoch
    (``epoch + step / len(data_loader)``). ``loss_scaler`` is called on every
    step, but with ``update_grad=True`` only on the last micro-step of each
    accumulation window. Gradients are zeroed right after that step.

    Args:
        model: Model whose forward returns ``(loss, pred)``.
        data_loader: Iterable of raw batches consumed by ``prepare_data``.
        optimizer: Optimizer; ``param_groups[0]["lr"]`` is logged.
        device: Device that batches are moved to.
        epoch: Current epoch index, used by the LR schedule.
        loss_scaler: Callable that performs backward, optional gradient
            clipping and the optimizer step. Presumably an AMP grad-scaler
            wrapper from util.misc; TODO confirm.
        args: Run configuration (``accum_iter``, ``clip_grad``, ``debug``, ...).

    Returns:
        Dict mapping each logged meter name to its global average.
    """
    epoch_start_time = time.time()
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    accum_iter = args.accum_iter
    optimizer.zero_grad()
    print('Training data_loader length:', len(data_loader))
    for data_iter_step, samples in enumerate(data_loader):
        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(samples, device, is_train=True, args=args)
        with torch.cuda.amp.autocast():
            loss, _ = model(
                seen_images=seen_images,
                seen_xyz=seen_xyz,
                unseen_xyz=unseen_xyz,
                unseen_rgb=unseen_rgb,
                unseen_occupy=labels,
                valid_seen_xyz=valid_seen_xyz,
            )
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Warning: Loss is {}".format(loss_value))
            # NOTE(review): NaN/inf * 0.0 is still NaN, so this alone does not
            # neutralize the backward pass. Presumably the loss scaler's
            # non-finite-gradient check skips the step; confirm.
            loss *= 0.0
            # Log a fixed large value instead of the non-finite loss.
            loss_value = 100.0
        # Scale each micro-step so accumulated gradients average over the window.
        loss /= accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    clip_grad=args.clip_grad,
                    update_grad=(data_iter_step + 1) % accum_iter == 0,
                    verbose=(data_iter_step % 100) == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()
        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)
        # One-off GPU / host memory snapshot early in the epoch (shell output to stdout).
        if data_iter_step == 30:
            os.system('nvidia-smi')
            os.system('free -g')
        # Debug runs stop after a handful of iterations.
        if args.debug and data_iter_step == 5:
            break
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    print("Training epoch time:", time.time() - epoch_start_time)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def eval_one_epoch(
        model: torch.nn.Module,
        data_loader: Iterable,
        device: torch.device,
        args=None
):
    """Evaluate the model and log loss plus point-cloud precision / recall / F1.

    Query points are forwarded in chunks of 5000. A query whose sigmoid
    occupancy exceeds ``args.eval_score_threshold`` counts as a predicted
    surface point. Each batch element is scored with ``evaluate_points`` at
    ``args.eval_dist_threshold``: against the GT points, and also against the
    mesh points when ``args.use_hypersim`` is set.

    Returns:
        Dict mapping each logged meter name to its global average.
    """
    epoch_start_time = time.time()
    model.train(False)
    metric_logger = misc.MetricLogger(delimiter=" ")
    print('Eval len(data_loader):', len(data_loader))
    for data_iter_step, samples in enumerate(data_loader):
        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(samples, device, is_train=False, args=args)
        # don't forward all at once to avoid oom
        max_n_queries_fwd = 5000
        all_loss, all_preds = [], []
        for p_idx in range(int(np.ceil(unseen_xyz.shape[1] / max_n_queries_fwd))):
            p_start = p_idx * max_n_queries_fwd
            p_end = (p_idx + 1) * max_n_queries_fwd
            cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
            cur_unseen_rgb = unseen_rgb[:, p_start:p_end]
            cur_labels = labels[:, p_start:p_end]
            with torch.no_grad():
                loss, pred = model(
                    seen_images=seen_images,
                    seen_xyz=seen_xyz,
                    unseen_xyz=cur_unseen_xyz,
                    unseen_rgb=cur_unseen_rgb,
                    unseen_occupy=cur_labels,
                    valid_seen_xyz=valid_seen_xyz,
                )
            all_loss.append(loss)
            all_preds.append(pred)
        # NOTE(review): this is an unweighted mean over chunks, and the last chunk
        # may be smaller. If the model's loss is a per-query mean, this is only
        # approximately the full-batch loss.
        loss = sum(all_loss) / len(all_loss)
        pred = torch.cat(all_preds, dim=1)
        B = pred.shape[0]
        gt_xyz = samples[1][0].to(device).reshape((B, -1, 3))
        if args.use_hypersim:
            mesh_xyz = samples[2].to(device).reshape((B, -1, 3))
        s_thres = args.eval_score_threshold
        d_thres = args.eval_dist_threshold
        for b_idx in range(B):
            geometry_metrics = {}
            # Channel 0 of pred is the occupancy logit; threshold its sigmoid.
            predicted_idx = torch.nn.Sigmoid()(pred[b_idx, :, 0]) > s_thres
            predicted_xyz = unseen_xyz[b_idx, predicted_idx]
            precision, recall, f1 = evaluate_points(predicted_xyz, gt_xyz[b_idx], d_thres)
            geometry_metrics[f'd{d_thres}_s{s_thres}_point_pr'] = precision
            geometry_metrics[f'd{d_thres}_s{s_thres}_point_rc'] = recall
            geometry_metrics[f'd{d_thres}_s{s_thres}_point_f1'] = f1
            if args.use_hypersim:
                precision, recall, f1 = evaluate_points(predicted_xyz, mesh_xyz[b_idx], d_thres)
                geometry_metrics[f'd{d_thres}_s{s_thres}_mesh_pr'] = precision
                geometry_metrics[f'd{d_thres}_s{s_thres}_mesh_rc'] = recall
                geometry_metrics[f'd{d_thres}_s{s_thres}_mesh_f1'] = f1
            metric_logger.update(**geometry_metrics)
        loss_value = loss.item()
        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)
        # Debug runs stop after a handful of iterations.
        if args.debug and data_iter_step == 5:
            break
    metric_logger.synchronize_between_processes()
    print("Validation averaged stats:", metric_logger)
    print("Val epoch time:", time.time() - epoch_start_time)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def sample_uniform_semisphere(B, N, semisphere_size, device, max_attempts=100):
    """Sample ``N`` points per batch element uniformly from the upper half ball.

    Uses rejection sampling. Each round draws ``3 * B * N`` candidates
    uniformly in ``[-r, r] x [-r, r] x [0, r]`` (z is folded with ``abs``),
    where ``r = semisphere_size``. Candidates strictly inside radius ``r`` are
    kept, and the first ``B * N`` accepted points are returned.

    Args:
        B: Batch size.
        N: Points per batch element.
        semisphere_size: Radius ``r`` of the half ball.
        device: Device for the returned tensor.
        max_attempts: Number of sampling rounds before giving up.

    Returns:
        Tensor of shape (B, N, 3).

    Raises:
        RuntimeError: if no round yields enough accepted points. Previously
            this case silently returned ``None``.
    """
    n_needed = B * N
    for _ in range(max_attempts):
        points = torch.empty(n_needed * 3, 3, device=device).uniform_(-semisphere_size, semisphere_size)
        points[..., 2] = points[..., 2].abs()
        inside = (points ** 2.0).sum(dim=-1) ** 0.5 < semisphere_size
        if inside.sum() >= n_needed:
            return points[inside][:n_needed].reshape((B, N, 3))
        print('resampling sphere')
    raise RuntimeError(
        f'sample_uniform_semisphere: could not draw {n_needed} points '
        f'after {max_attempts} attempts'
    )
def get_grid_semisphere(B, granularity, semisphere_size, device):
    """Return the lattice points of the upper half ball, repeated per batch element.

    The lattice spacing is ``granularity``. With ``n = int(semisphere_size /
    granularity) * 2 + 1``:

    - x and y take ``(i - n // 2) * granularity`` for ``i`` in ``[0, n)``.
    - z takes ``k * granularity`` for ``k`` in ``[0, n // 2]``.

    Only points within distance ``semisphere_size`` of the origin are kept,
    ordered with z varying fastest.

    Returns:
        Tensor of shape (B, M, 3).
    """
    n_xy = int(semisphere_size / granularity) * 2 + 1
    n_z = n_xy // 2 + 1
    grid = torch.zeros((n_xy, n_xy, n_z, 3), device=device)
    xy_idx = torch.arange(n_xy, device=device, dtype=grid.dtype)
    z_idx = torch.arange(n_z, device=device, dtype=grid.dtype)
    grid[..., 0] = xy_idx[:, None, None]
    grid[..., 1] = xy_idx[None, :, None]
    grid[..., 2] = z_idx[None, None, :]
    # Center x and y only; z stays non-negative (upper half).
    grid[..., :2] -= n_xy // 2.0
    grid *= granularity
    radius = (grid ** 2.0).sum(dim=-1) ** 0.5
    inside = grid[radius <= semisphere_size]
    return inside[None].repeat(B, 1, 1)
def get_min_dist(a, b, slice_size=1000):
    """Find, for each query in ``a``, the nearest entry of ``b`` along the last broadcast axis.

    ``a`` is processed in slices of ``slice_size`` along dim 1 to bound memory.
    As used in this file, ``a`` is (B, Q, 1, 3) and ``b`` is (B, 1, G, 3).

    Returns:
        ``(min_dist, min_idx)``, each of shape (B, Q).
    """
    mins, argmins = [], []
    n_queries = a.shape[1]
    for start in range(0, n_queries, slice_size):
        chunk = a[:, start:start + slice_size]
        # (B, chunk, G) pairwise Euclidean distances
        dist = ((chunk - b) ** 2.0).sum(dim=-1) ** 0.5
        chunk_min, chunk_idx = dist.min(dim=2)
        mins.append(chunk_min)
        argmins.append(chunk_idx)
    return torch.cat(mins, dim=1), torch.cat(argmins, dim=1)
def construct_uniform_semisphere(gt_xyz, gt_rgb, semisphere_size, n_queries, dist_threshold, is_train, granularity):
    """Build query points and occupancy/color targets inside the upper half ball.

    Queries are sampled differently for training and evaluation:

    - Training: ``n_queries`` random points per batch element.
    - Evaluation: the half-ball lattice with spacing ``granularity``.

    A query is labeled occupied when its nearest GT point is closer than
    ``dist_threshold``. Occupied queries copy that GT point's color; all other
    queries get color 0.

    Returns:
        ``(unseen_xyz, unseen_rgb, labels)``, where ``labels`` is a float tensor.
    """
    batch_size, device = gt_xyz.shape[0], gt_xyz.device
    unseen_xyz = (
        sample_uniform_semisphere(batch_size, n_queries, semisphere_size, device)
        if is_train
        else get_grid_semisphere(batch_size, granularity, semisphere_size, device)
    )
    nearest_dist, nearest_idx = get_min_dist(unseen_xyz[:, :, None], gt_xyz[:, None])
    occupied = nearest_dist < dist_threshold
    nearest_rgb = torch.gather(gt_rgb, 1, nearest_idx.unsqueeze(-1).repeat(1, 1, 3))
    unseen_rgb = torch.zeros_like(unseen_xyz)
    unseen_rgb[occupied] = nearest_rgb[occupied]
    return unseen_xyz, unseen_rgb, occupied.float()
def construct_uniform_grid(gt_xyz, gt_rgb, co3d_world_size, n_queries, dist_threshold, is_train, granularity):
    """Build query points and occupancy/color targets inside the cube ``[-s, s]^3``.

    Queries are sampled differently for training and evaluation:

    - Training: ``n_queries`` points per batch element, drawn uniformly from
      the cube.
    - Evaluation: the regular lattice from ``get_grid``.

    A query is labeled occupied when its nearest GT point is closer than
    ``dist_threshold``. Occupied queries copy that GT point's color; all other
    queries get color 0.

    Returns:
        ``(unseen_xyz, unseen_rgb, labels)``, where ``labels`` is a float tensor.
    """
    batch_size = gt_xyz.shape[0]
    device = gt_xyz.device
    if not is_train:
        unseen_xyz = get_grid(batch_size, device, co3d_world_size, granularity)
    else:
        unseen_xyz = torch.empty((batch_size, n_queries, 3), device=device).uniform_(-co3d_world_size, co3d_world_size)
    nearest_dist, nearest_idx = get_min_dist(unseen_xyz[:, :, None], gt_xyz[:, None])
    occupied = nearest_dist < dist_threshold
    nearest_rgb = torch.gather(gt_rgb, 1, nearest_idx.unsqueeze(-1).repeat(1, 1, 3))
    unseen_rgb = torch.zeros_like(unseen_xyz)
    unseen_rgb[occupied] = nearest_rgb[occupied]
    return unseen_xyz, unseen_rgb, occupied.float()
def prepare_data(samples, device, is_train, args, is_viz=False):
    """Move a raw batch to ``device`` and build the model inputs and query targets.

    Args:
        samples: Batch with ``samples[0] = (seen_xyz, seen_rgb)`` and
            ``samples[1] = (gt_xyz, gt_rgb)``.
        device: Target device.
        is_train: If true, query points are sampled randomly, and the data is
            augmented with ``aug_xyz`` plus a random mirror along x. The
            mirror also flips dim 2 of the seen points and dim 3 of the images.
        args: Run configuration.
        is_viz: Use ``args.viz_granularity`` instead of
            ``args.eval_granularity`` for the evaluation query lattice.

    Returns:
        ``(seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_rgb)``.
        Invalid seen points are set to -100 before any augmentation.
    """
    # Seen
    seen_xyz, seen_rgb = samples[0][0].to(device), samples[0][1].to(device)
    # A point is invalid if its coordinate sum is non-finite (any NaN/inf coordinate).
    valid_seen_xyz = torch.isfinite(seen_xyz.sum(axis=-1))
    # Fill invalid points out of place. `.to(device)` returns the same tensor when
    # it is already on `device`, so an in-place write here would alter the
    # caller's batch (and any dataset-side cache of it).
    seen_xyz = seen_xyz.masked_fill(~valid_seen_xyz[..., None], -100)
    B = seen_xyz.shape[0]
    # Gt
    gt_xyz = samples[1][0].to(device).reshape(B, -1, 3)
    gt_rgb = samples[1][1].to(device).reshape(B, -1, 3)

    sampling_func = construct_uniform_semisphere if args.use_hypersim else construct_uniform_grid
    unseen_xyz, unseen_rgb, labels = sampling_func(
        gt_xyz, gt_rgb,
        args.semisphere_size if args.use_hypersim else args.co3d_world_size,
        args.n_queries,
        args.train_dist_threshold,
        is_train,
        args.viz_granularity if is_viz else args.eval_granularity,
    )

    if is_train:
        seen_xyz, unseen_xyz = aug_xyz(seen_xyz, unseen_xyz, args, is_train=is_train)

        # Random Flip: negate x and mirror the image-aligned tensors to match.
        if random.random() < 0.5:
            seen_xyz[..., 0] *= -1
            unseen_xyz[..., 0] *= -1
            seen_xyz = torch.flip(seen_xyz, [2])
            valid_seen_xyz = torch.flip(valid_seen_xyz, [2])
            seen_rgb = torch.flip(seen_rgb, [3])

    return seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_rgb