# Copyright (c) Meta Platforms, Inc. and affiliates.
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.functional import grid_sample, log_softmax, normalize, pad

from . import get_model
from models.base import BaseModel
from models.bev_net import BEVNet
from models.bev_projection import CartesianProjection, PolarProjectionDepth
from models.voting import (
    argmax_xyr,
    argmax_xyrh,
    conv2d_fft_batchwise,
    expectation_xyr,
    log_softmax_spatial,
    mask_yaw_prior,
    nll_loss_xyr,
    nll_loss_xyr_smoothed,
    TemplateSampler,
    UAVTemplateSampler,
    UAVTemplateSamplerFast,
)
from .map_encoder import MapEncoder
from .map_encoder_single import MapEncoderSingle
from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall


class MapLocNet(BaseModel):
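    """Localize an image against a map by exhaustive voting over (x, y, yaw).

    Image and map encoders produce feature maps in a shared matching space.
    Rotated templates of the image features are correlated with the map
    features via an FFT-based convolution, giving a score volume over
    translation and rotation from which the pose is read out by argmax or
    by expectation.
    """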
    default_conf = {
        "image_size": "???",
        "val_citys": "???",
        "image_encoder": "???",
        "map_encoder": "???",
        "bev_net": "???",
        "latent_dim": "???",
        "matching_dim": "???",
        "scale_range": [0, 9],
        "num_scale_bins": "???",
        "z_min": None,
        "z_max": "???",
        "x_max": "???",
        "pixel_per_meter": "???",
        "num_rotations": "???",
        "add_temperature": False,
        "normalize_features": False,
        "padding_matching": "replicate",
        "apply_map_prior": True,
        "do_label_smoothing": False,
        "sigma_xy": 1,
        "sigma_r": 2,
        # deprecated
        "depth_parameterization": "scale",
        "norm_depth_scores": False,
        "normalize_scores_by_dim": False,
        "normalize_scores_by_num_valid": True,
        "prior_renorm": True,
        "retrieval_dim": None,
    }

    def _init(self, conf):
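        """Build the image encoder, map encoder, and rotation template sampler.

        Config options tied to the BEV projection path are asserted to their
        supported values; the corresponding modules are kept below as
        commented-out code.
        """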
        assert not self.conf.norm_depth_scores
        assert self.conf.depth_parameterization == "scale"
        assert not self.conf.normalize_scores_by_dim
        assert self.conf.normalize_scores_by_num_valid
        assert self.conf.prior_renorm

        # a = conf.image_encoder.get("name", "feature_extractor_v2")
        # b = conf.image_encoder.get("name")
        Encoder = get_model(conf.image_encoder.get("name"))
        self.image_encoder = Encoder(conf.image_encoder.backbone)
        if len(conf.map_encoder.num_classes) == 1:
            self.map_encoder = MapEncoderSingle(conf.map_encoder)
        else:
            self.map_encoder = MapEncoder(conf.map_encoder)

        # self.bev_net = None if conf.bev_net is None else BEVNet(conf.bev_net)
        ppm = conf.pixel_per_meter
        # self.projection_polar = PolarProjectionDepth(
        #     conf.z_max,
        #     ppm,
        #     conf.scale_range,
        #     conf.z_min,
        # )
        # self.projection_bev = CartesianProjection(
        #     conf.z_max, conf.x_max, ppm, conf.z_min
        # )
        # self.template_sampler = TemplateSampler(
        #     self.projection_bev.grid_xz, ppm, conf.num_rotations
        # )
        self.template_sampler = UAVTemplateSamplerFast(
            conf.num_rotations, w=conf.image_size // 2
        )
        # self.template_sampler = UAVTemplateSampler(conf.num_rotations)
        # self.scale_classifier = torch.nn.Linear(conf.latent_dim, conf.num_scale_bins)
        # if conf.bev_net is None:
        #     self.feature_projection = torch.nn.Linear(
        #         conf.latent_dim, conf.matching_dim
        #     )
        if conf.add_temperature:
            temperature = torch.nn.Parameter(torch.tensor(0.0))
            self.register_parameter("temperature", temperature)

    def exhaustive_voting(self, f_bev, f_map):
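        """Correlate rotated templates of the image features against the map features.

        If enabled in the config, features are L2-normalized and the scores are
        scaled by a learned temperature. The result has one score map per
        sampled rotation.
        """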
        if self.conf.normalize_features:
            f_bev = normalize(f_bev, dim=1)
            f_map = normalize(f_map, dim=1)

        # Build the templates and exhaustively match against the map.
        # if confidence_bev is not None:
        #     f_bev = f_bev * confidence_bev.unsqueeze(1)
        # f_bev = f_bev.masked_fill(~valid_bev.unsqueeze(1), 0.0)
        # torch.save(f_bev, 'f_bev.pt')
        # torch.save(f_map, 'f_map.pt')
        f_map = F.interpolate(f_map, size=(256, 256), mode="bilinear", align_corners=False)
        templates = self.template_sampler(f_bev)  # [batch, 256, 8, 129, 129]
        # torch.save(templates, 'templates.pt')
        with torch.autocast("cuda", enabled=False):
            scores = conv2d_fft_batchwise(
                f_map.float(),
                templates.float(),
                padding_mode=self.conf.padding_matching,
            )
        if self.conf.add_temperature:
            scores = scores * torch.exp(self.temperature)

        # Reweight the different rotations based on the number of valid pixels
        # in each template. Axis-aligned rotations have the maximum number of valid pixels.
        # valid_templates = self.template_sampler(valid_bev.float()[None]) > (1 - 1e-4)
        # num_valid = valid_templates.float().sum((-3, -2, -1))
        # scores = scores / num_valid[..., None, None]
        return scores

    def _forward(self, data):
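        """Single-scale localization: vote over (x, y, yaw) and read out the pose.

        The optional map prior, map mask, and yaw prior are applied to the
        score volume before the spatial softmax.
        """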
        pred = {}
        pred_map = pred["map"] = self.map_encoder(data)
        f_map = pred_map["map_features"][0]  # [batch, 8, 256, 256]

        # Extract image features.
        level = 0
        f_image = self.image_encoder(data)["feature_maps"][level]  # [batch, 128, 128, 176]
        # print("f_map:", f_map.shape)

        # f_bev: [batch, 8, 64, 129], f_map: [batch, 8, 256, 256], confidence: [1, 64, 129]
        scores = self.exhaustive_voting(f_image, f_map)
        scores = scores.moveaxis(1, -1)  # B, H, W, N
        if "log_prior" in pred_map and self.conf.apply_map_prior:
            scores = scores + pred_map["log_prior"][0].unsqueeze(-1)
        # pred["scores_unmasked"] = scores.clone()
        if "map_mask" in data:
            scores.masked_fill_(~data["map_mask"][..., None], -np.inf)
        if "yaw_prior" in data:
            mask_yaw_prior(scores, data["yaw_prior"], self.conf.num_rotations)

        log_probs = log_softmax_spatial(scores)
        # torch.save(scores, 'scores.pt')
        with torch.no_grad():
            uvr_max = argmax_xyr(scores).to(scores)
            uvr_avg, _ = expectation_xyr(log_probs.exp())

        return {
            **pred,
            "scores": scores,
            "log_probs": log_probs,
            "uvr_max": uvr_max,
            "uv_max": uvr_max[..., :2],
            "yaw_max": uvr_max[..., 2],
            "uvr_expectation": uvr_avg,
            "uv_expectation": uvr_avg[..., :2],
            "yaw_expectation": uvr_avg[..., 2],
            "features_image": f_image,
        }

    def _forward_scale(self, data, resize=None):
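        """Multi-scale variant: repeat the voting at several image-feature sizes.

        The image features are resized to each size in `resize`, matched
        against the map, and the resulting score volumes are stacked along a
        trailing dimension before a joint softmax and argmax.
        """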
        pred = {}
        pred_map = pred["map"] = self.map_encoder(data)
        f_map = pred_map["map_features"][0]  # [batch, 8, 256, 256]

        # Extract image features.
        level = 0
        f_image = self.image_encoder(data)["feature_maps"][level]  # [batch, 128, 128, 176]
        # print("f_map:", f_map.shape)

        scores_list = []
        for resize_size in resize:
            f_image_re = torch.nn.functional.interpolate(
                f_image, size=resize_size, mode="bilinear", align_corners=False
            )
            # f_bev: [batch, 8, 64, 129], f_map: [batch, 8, 256, 256], confidence: [1, 64, 129]
            scores = self.exhaustive_voting(f_image_re, f_map)
            scores = scores.moveaxis(1, -1)  # B, H, W, N
            scores_list.append(scores)
        scores_list = torch.stack(scores_list, dim=-1)
        log_probs_list = log_softmax(scores_list.flatten(-4), dim=-1).reshape(
            scores_list.shape
        )
        # if "log_prior" in pred_map and self.conf.apply_map_prior:
        #     scores = scores + pred_map["log_prior"][0].unsqueeze(-1)
        # pred["scores_unmasked"] = scores.clone()
        # if "map_mask" in data:
        #     scores.masked_fill_(~data["map_mask"][..., None], -np.inf)
        # if "yaw_prior" in data:
        #     mask_yaw_prior(scores, data["yaw_prior"], self.conf.num_rotations)
        # scores shape: [batch, W, H, 64]
        # log_probs = log_softmax_spatial(scores)
        # torch.save(scores, 'scores.pt')
        with torch.no_grad():
            uvr_max = argmax_xyrh(scores_list)
            # uvr_avg, _ = expectation_xyr(log_probs_list.exp())
            uvr_avg = uvr_max

        return {
            **pred,
            "scores": scores,
            "log_probs": log_probs_list,
            "uvr_max": uvr_max,
            "uv_max": uvr_max[..., :2],
            "yaw_max": uvr_max[..., 2],
            "uvr_expectation": uvr_avg,
            "uv_expectation": uvr_avg[..., :2],
            "yaw_expectation": uvr_avg[..., 2],
            "features_image": f_image,
        }

    def loss(self, pred, data):
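        """Negative log-likelihood of the ground-truth (x, y, yaw), optionally label-smoothed."""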
xy_gt = data["uv"] | |
yaw_gt = data["roll_pitch_yaw"][..., -1] | |
if self.conf.do_label_smoothing: | |
nll = nll_loss_xyr_smoothed( | |
pred["log_probs"], | |
xy_gt, | |
yaw_gt, | |
self.conf.sigma_xy / self.conf.pixel_per_meter, | |
self.conf.sigma_r, | |
mask=data.get("map_mask"), | |
) | |
else: | |
nll = nll_loss_xyr(pred["log_probs"], xy_gt, yaw_gt) | |
loss = {"total": nll, "nll": nll} | |
if self.training and self.conf.add_temperature: | |
loss["temperature"] = self.temperature.expand(len(nll)) | |
return loss | |

    def metrics(self):
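        """Position and yaw errors plus recall at fixed distance/angle thresholds."""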
        return {
            "xy_max_error": Location2DError("uv_max", self.conf.pixel_per_meter),
            "xy_expectation_error": Location2DError(
                "uv_expectation", self.conf.pixel_per_meter
            ),
            "yaw_max_error": AngleError("yaw_max"),
            "xy_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
            "xy_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
            "xy_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
            # "x_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
            # "x_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
            # "x_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
            # "y_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
            # "y_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
            # "y_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
            "yaw_recall_1°": AngleRecall(1.0, "yaw_max"),
            "yaw_recall_3°": AngleRecall(3.0, "yaw_max"),
            "yaw_recall_5°": AngleRecall(5.0, "yaw_max"),
        }