diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..fea7e72b5ee96b7254a67fa1f4fa4accf07b6c0c --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.idea +temp +temp.py +weight diff --git a/conf/maplocnet.yaml b/conf/maplocnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..df0be9a09848e49803bb2049c2cffeffde771084 --- /dev/null +++ b/conf/maplocnet.yaml @@ -0,0 +1,100 @@ +data: + root: '/root/DATASET/UAV2MAP/UAV/' + train_citys: + - Paris + - Berlin + - London + - Tokyo + - NewYork + val_citys: + - Toronto + image_size: 256 + train: + batch_size: 12 + num_workers: 4 + val: + batch_size: ${..train.batch_size} + num_workers: ${.batch_size} + num_classes: + areas: 7 + ways: 10 + nodes: 33 + pixel_per_meter: 1 + crop_size_meters: 64 + max_init_error: 48 + add_map_mask: true + resize_image: 512 + pad_to_square: true + rectify_pitch: true + augmentation: + rot90: true + flip: true + image: + apply: true + brightness: 0.5 + contrast: 0.4 + saturation: 0.4 + hue: 0.159 # ~0.5/3.14 +model: + image_size: ${data.image_size} + latent_dim: 128 + val_citys: ${data.val_citys} + image_encoder: + name: feature_extractor_v2 + backbone: + encoder: resnet50 + pretrained: true + output_dim: 8 + num_downsample: null + remove_stride_from_first_conv: false + name: orienternet + matching_dim: 8 + z_max: 32 + x_max: 32 + pixel_per_meter: 1 + num_scale_bins: 33 + num_rotations: 64 + map_encoder: + embedding_dim: 16 + output_dim: 8 + num_classes: + areas: 7 + ways: 10 + nodes: 33 + backbone: + encoder: vgg19 + pretrained: false + output_scales: + - 0 + num_downsample: 3 + decoder: + - 128 + - 64 + - 64 + padding: replicate + unary_prior: false + bev_net: + num_blocks: 4 + latent_dim: 128 + output_dim: 8 + confidence: true +experiment: + name: maplocanet_0906_diffhight + gpus: 6 + seed: 0 +training: + lr: 0.0001 + lr_scheduler: null + finetune_from_checkpoint: null + trainer: + val_check_interval: 1000 + log_every_n_steps: 100 +# limit_val_batches: 1000 + max_steps: 200000 + devices: ${experiment.gpus} + checkpointing: + monitor: "loss/total/val" + save_top_k: 10 + mode: min + +# filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}' \ No newline at end of file diff --git a/dataset/UAV/dataset.py b/dataset/UAV/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1dbc8121bb355dcbee2ed7f2bc68e85e25b2b808 --- /dev/null +++ b/dataset/UAV/dataset.py @@ -0,0 +1,116 @@ +import torch +from torch.utils.data import Dataset +import os +import cv2 +# @Time : 2023-02-13 22:56 +# @Author : Wang Zhen +# @Email : frozenzhencola@163.com +# @File : SatelliteTool.py +# @Project : TGRS_seqmatch_2023_1 +import numpy as np +import random +from utils.geo import BoundaryBox, Projection +from osm.tiling import TileManager,MapTileManager +from pathlib import Path +from torchvision import transforms +from torch.utils.data import DataLoader + +class UavMapPair(Dataset): + def __init__( + self, + root: Path, + city:str, + training:bool, + transform + ): + super().__init__() + + # self.root = root + + # city = 'Manhattan' + # root = '/root/DATASET/CrossModel/' + # root=Path(root) + self.uav_image_path = root/city/'uav' + self.map_path = root/city/'map' + self.map_vis = root / city / 'map_vis' + info_path = root / city / 'info.csv' + + self.info = np.loadtxt(str(info_path), dtype=str, delimiter=",", skiprows=1) + + self.transform=transform + self.training=training + + def random_center_crop(self,image): + height, width = image.shape[:2] + + # Randomly pick a crop size +
crop_size = random.randint(min(height, width) // 2, min(height, width)) + + # 计算剪裁的起始坐标 + start_x = (width - crop_size) // 2 + start_y = (height - crop_size) // 2 + + # 进行剪裁 + cropped_image = image[start_y:start_y + crop_size, start_x:start_x + crop_size] + + return cropped_image + def __getitem__(self, index: int): + id, uav_name, map_name, \ + uav_long, uav_lat, \ + map_long, map_lat, \ + tile_size_meters, pixel_per_meter, \ + u, v, yaw,dis=self.info[index] + + + uav_image=cv2.imread(str(self.uav_image_path/uav_name)) + if self.training: + uav_image =self.random_center_crop(uav_image) + uav_image=cv2.cvtColor(uav_image,cv2.COLOR_BGR2RGB) + if self.transform: + uav_image=self.transform(uav_image) + map=np.load(str(self.map_path/map_name)) + + return { + 'map':torch.from_numpy(np.ascontiguousarray(map)).long(), + 'image':torch.tensor(uav_image), + 'roll_pitch_yaw':torch.tensor((0, 0, float(yaw))).float(), + 'pixels_per_meter':torch.tensor(float(pixel_per_meter)).float(), + "uv":torch.tensor([float(u), float(v)]).float(), + } + def __len__(self): + return len(self.info) +if __name__ == '__main__': + + root=Path('/root/DATASET/OrienterNet/UavMap/') + city='NewYork' + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize(256), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ]) + + dataset=UavMapPair( + root=root, + city=city, + transform=transform + ) + datasetloder = DataLoader(dataset, batch_size=3) + for batch, i in enumerate(datasetloder): + pass + # 将PyTorch张量转换为PIL图像 + # pil_image = Image.fromarray(i['uav_image'][0].permute(1, 2, 0).byte().numpy()) + + # 显示图像 + # 将PyTorch张量转换为NumPy数组 + # numpy_array = i['uav_image'][0].numpy() + # + # # 显示图像 + # plt.imshow(numpy_array.transpose(1, 2, 0)) + # plt.axis('off') + # plt.show() + # + # map_viz, label = Colormap.apply(i['map'][0]) + # map_viz = map_viz * 255 + # map_viz = map_viz.astype(np.uint8) + # plot_images([map_viz], titles=["OpenStreetMap raster"]) diff --git a/dataset/UAV/prepara_dataset.py b/dataset/UAV/prepara_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7ed607eafd6e8678688570db999755d136f381db --- /dev/null +++ b/dataset/UAV/prepara_dataset.py @@ -0,0 +1,270 @@ +import torch +from torch.utils.data import Dataset +import os +import cv2 +# @Time : 2023-02-13 22:56 +# @Author : Wang Zhen +# @Email : frozenzhencola@163.com +# @File : SatelliteTool.py +# @Project : TGRS_seqmatch_2023_1 +import numpy as np +import random +from utils.geo import BoundaryBox, Projection +from osm.tiling import TileManager,MapTileManager +from pathlib import Path +from torchvision import transforms +from tqdm import tqdm +import time +import math +import random +from geopy import Point, distance +from osm.viz import Colormap, plot_nodes + +def generate_random_coordinate(latitude, longitude, dis): + # 生成一个随机方向角 + random_angle = random.uniform(0, 360) + # print("random_angle",random_angle) + # 计算目标点的经纬度 + start_point = Point(latitude, longitude) + destination = distance.distance(kilometers=dis/1000).destination(start_point, random_angle) + + return destination.latitude, destination.longitude + +def rotate_corp(src,angle): + # 原图的高、宽 以及通道数 + rows, cols, channel = src.shape + + # 绕图像的中心旋转 + # 参数:旋转中心 旋转度数 scale + M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1) + # rows, cols=700,700 + # 自适应图片边框大小 + cos = np.abs(M[0, 0]) + sin = np.abs(M[0, 1]) + new_w = rows * sin + cols * cos + new_h = rows * cos + cols * sin + M[0, 2] += (new_w - cols) * 0.5 + M[1, 2] += (new_h - rows) 
* 0.5 + w = int(np.round(new_w)) + h = int(np.round(new_h)) + rotated = cv2.warpAffine(src, M, (w, h)) + + # rotated = cv2.warpAffine(src, M, (cols, rows)) + + c=int(w / 2) + w=int(rows*math.sqrt(2)/4) + rotated2=rotated[c-w:c+w,c-w:c+w,:] + return rotated2 + +class SatelliteGeoTools: + """ + 用于读取卫星图tfw文件,执行 像素坐标-Mercator-GPS坐标 的转化 + """ + def __init__(self, tfw_path): + self.SatelliteParameter=self.Parsetfw(tfw_path) + def Parsetfw(self, tfw_path): + info = [] + f = open(tfw_path) + for _ in range(6): + line = f.readline() + line = line.strip('\n') + info.append(float(line)) + f.close() + return info + def Pix2Geo(self, x, y): + A, D, B, E, C, F = self.SatelliteParameter + x1 = A * x + B * y + C + y1 = D * x + E * y + F + # print(x1,y1) + s_long, s_lat = self.MercatorTolonlat(x1, y1) + return s_long, s_lat + + def Geo2Pix(self, lon, lat): + """ + https://baike.baidu.com/item/TFW%E6%A0%BC%E5%BC%8F/6273151?fr=aladdin + x'=Ax+By+C + y'=Dx+Ey+F + :return: + """ + x1, y1 = self.LonlatToMercator(lon, lat) + A, D, B, E, C, F = self.SatelliteParameter + M = np.array([[A, B, C], + [D, E, F], + [0, 0, 1]]) + M_INV = np.linalg.inv(M) + XY = np.matmul(M_INV, np.array([x1, y1, 1]).T) + return int(XY[0]), int(XY[1]) + def MercatorTolonlat(self,mx,my): + x = mx/20037508.3427892*180 + y = my/20037508.3427892*180 + # y= 180/math.pi*(2*math.atan(math.exp(y*math.pi/180))-math.pi/2) + y = 180.0 / np.pi * (2.0 * np.arctan(np.exp(y * np.pi / 180.0)) - np.pi / 2.0) + return x,y + def LonlatToMercator(self,lon, lat): + x = lon * 20037508.342789 / 180 + y = np.log(np.tan((90 + lat) * np.pi / 360)) / (np.pi / 180) + y = y * 20037508.34789 / 180 + return x, y + +def geodistance(lng1, lat1, lng2, lat2): + lng1, lat1, lng2, lat2 = map(np.radians, [lng1, lat1, lng2, lat2]) + dlon = lng2 - lng1 + dlat = lat2 - lat1 + a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2 + distance = 2 * np.arcsin(np.sqrt(a)) * 6371 * 1000 # 地球平均半径,6371km + return distance + +class PreparaDataset: + def __init__( + self, + root: Path, + city:str, + patch_size:int, + tile_size_meters:float + ): + super().__init__() + + # self.root = root + + # city = 'Manhattan' + # root = '/root/DATASET/CrossModel/' + imagepath = root/city/ '{}.tif'.format(city) + tfwpath = root/city/'{}.tfw'.format(city) + + self.osmpath = root/city/'{}.osm'.format(city) + + self.TileManager=MapTileManager(self.osmpath) + image = cv2.imread(str(imagepath)) + self.image=cv2.cvtColor(image,cv2.COLOR_BGR2RGB) + + self.ST = SatelliteGeoTools(str(tfwpath)) + + self.patch_size=patch_size + self.tile_size_meters=tile_size_meters + + + + def get_osm(self,prior_latlon,uav_latlon): + latlon = np.array(prior_latlon) + proj = Projection(*latlon) + center = proj.project(latlon) + + uav_latlon=np.array(uav_latlon) + + XY=proj.project(uav_latlon) + # tile_size_meters = 128 + bbox = BoundaryBox(center, center) + self.tile_size_meters + # bbox= BoundaryBox(center, center) + # Query OpenStreetMap for this area + self.pixel_per_meter = 1 + start_time = time.time() + canvas = self.TileManager.from_bbox(proj, bbox, self.pixel_per_meter) + end_time = time.time() + execution_time = end_time - start_time + # print("方法执行时间:", execution_time, "秒") + # canvas = tiler.query(bbox) + XY=[XY[0]+self.tile_size_meters,-XY[1]+self.tile_size_meters] + return canvas,XY + def random_corp(self): + + # 根据随机裁剪尺寸计算出裁剪区域的左上角坐标 + x = random.randint(1000, self.image.shape[1] - self.patch_size-1000) + y = random.randint(1000, self.image.shape[0] - self.patch_size-1000) + x1 = x + 
self.patch_size + y1 = y + self.patch_size + return x,x1,y,y1 + + def generate(self): + x,x1,y,y1 = self.random_corp() + uav_center_x,uav_center_y=int((x+x1)//2),int((y+y1)//2) + uav_center_long,uav_center_lat=self.ST.Pix2Geo(uav_center_x,uav_center_y) + # print(uav_center_long,uav_center_lat) + self.image_patch = self.image[y:y1, x:x1] + + map_center_lat, map_center_long = generate_random_coordinate(uav_center_lat, uav_center_long, self.tile_size_meters) + map,XY=self.get_osm([map_center_lat,map_center_long],[uav_center_lat, uav_center_long]) + + + yaw=np.random.random()*360 + self.image_patch=rotate_corp(self.image_patch,yaw) + # return self.image_patch,self.osm_patch + # XY=[X+self.tile_size_meters + return { + 'uav_image':self.image_patch, + 'uav_long_lat':[uav_center_long,uav_center_lat], + 'map_long_lat': [map_center_long,map_center_lat], + 'tile_size_meters': map.raster.shape[1], + 'pixel_per_meter':self.pixel_per_meter, + 'yaw':yaw, + 'map':map.raster, + "uv":XY + } +if __name__ == '__main__': + + import argparse + + parser = argparse.ArgumentParser(description='manual to this script') + parser.add_argument('--city', type=str, default=None,required=True) + parser.add_argument('--num', type=int, default=10000) + args = parser.parse_args() + + + root=Path('/root/DATASET/OrienterNet/UavMap/') + city=args.city + dataset = PreparaDataset( + root=root, + city=city, + patch_size=512, + tile_size_meters=128, + ) + + uav_path=root/city/'uav' + if not uav_path.exists(): + uav_path.mkdir(parents=True) + + map_path = root / city / 'map' + if not map_path.exists(): + map_path.mkdir(parents=True) + + map_vis_path = root / city / 'map_vis' + if not map_vis_path.exists(): + map_vis_path.mkdir(parents=True) + + info_path = root / city / 'info.csv' + + # num=1000 + num = args.num + info=[['id','uav_name','map_name','uav_long','uav_lat','map_long','map_lat','tile_size_meters','pixel_per_meter','u','v','yaw']] + # info =[] + for i in tqdm(range(num)): + data=dataset.generate() + # print(str(uav_path/"{:05d}.jpg".format(i))) + + cv2.imwrite(str(uav_path/"{:05d}.jpg".format(i)),cv2.cvtColor(data['uav_image'],cv2.COLOR_RGB2BGR)) + + np.save(str(map_path/"{:05d}.npy".format(i)),data['map']) + + map_viz, label = Colormap.apply(data['map']) + map_viz = map_viz * 255 + map_viz = map_viz.astype(np.uint8) + cv2.imwrite(str(map_vis_path / "{:05d}.jpg".format(i)), cv2.cvtColor(map_viz, cv2.COLOR_RGB2BGR)) + + + uav_center_long, uav_center_lat=data['uav_long_lat'] + map_center_long, map_center_lat = data['map_long_lat'] + info.append([ + i, + "{:05d}.jpg".format(i), + "{:05d}.npy".format(i), + uav_center_long, + uav_center_lat, + map_center_long, + map_center_lat, + data["tile_size_meters"], + data["pixel_per_meter"], + data['uv'][0], + data['uv'][1], + data['yaw'] + ]) + # print(info) + np.savetxt(info_path,info,delimiter=',',fmt="%s") \ No newline at end of file diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b51bfd2ed3eaf38719d8ee102df779d53d1ffa4 --- /dev/null +++ b/dataset/__init__.py @@ -0,0 +1,4 @@ +# from .UAV.dataset import UavMapPair +from .dataset import UavMapDatasetModule + +# modules = {"UAV": UavMapPair} diff --git a/dataset/dataset.py b/dataset/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ec79057f559178a4c7508ab2970bb963dcf4e09a --- /dev/null +++ b/dataset/dataset.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
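For reference, here is a minimal sketch of loading one sample exactly as the generation script above stores it (UAV crop as JPEG, rasterized OSM tile as .npy, one metadata row in info.csv). The root path, city, and index are placeholders, not values shipped with this diff:

```python
import cv2
import numpy as np
from pathlib import Path

# Placeholder locations; adjust to wherever prepara_dataset.py wrote its output.
root = Path("/root/DATASET/OrienterNet/UavMap")
city = "NewYork"
idx = 0

# RGB UAV crop (stored as a BGR JPEG by cv2.imwrite above).
uav = cv2.cvtColor(cv2.imread(str(root / city / "uav" / f"{idx:05d}.jpg")), cv2.COLOR_BGR2RGB)
# Rasterized OSM tile saved from the tile manager's canvas.
osm_raster = np.load(str(root / city / "map" / f"{idx:05d}.npy"))
# Metadata row: id, uav_name, map_name, uav_long, uav_lat, map_long, map_lat,
# tile_size_meters, pixel_per_meter, u, v, yaw (column order from the header above).
row = np.loadtxt(str(root / city / "info.csv"), dtype=str, delimiter=",", skiprows=1)[idx]
print(uav.shape, osm_raster.shape, row)
```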
+ +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List +# from logger import logger +import numpy as np +# import torch +# import torch.utils.data as torchdata +# import torchvision.transforms as tvf +from omegaconf import DictConfig, OmegaConf +import pytorch_lightning as pl +from dataset.UAV.dataset import UavMapPair +# from torch.utils.data import Dataset, DataLoader +# from torchvision import transforms +from torch.utils.data import Dataset, ConcatDataset +from torch.utils.data import Dataset, DataLoader, random_split +import torchvision.transforms as tvf + +# 自定义数据模块类,继承自pl.LightningDataModule +class UavMapDatasetModule(pl.LightningDataModule): + + + def __init__(self, cfg: Dict[str, Any]): + super().__init__() + + # default_cfg = OmegaConf.create(self.default_cfg) + # OmegaConf.set_struct(default_cfg, True) # cannot add new keys + # self.cfg = OmegaConf.merge(default_cfg, cfg) + self.cfg=cfg + # self.transform = tvf.Compose([ + # tvf.ToTensor(), + # tvf.Resize(self.cfg.image_size), + # tvf.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + # ]) + + tfs = [] + tfs.append(tvf.ToTensor()) + tfs.append(tvf.Resize(self.cfg.image_size)) + self.val_tfs = tvf.Compose(tfs) + + # transforms.Resize(self.cfg.image_size), + if cfg.augmentation.image.apply: + args = OmegaConf.masked_copy( + cfg.augmentation.image, ["brightness", "contrast", "saturation", "hue"] + ) + tfs.append(tvf.ColorJitter(**args)) + self.train_tfs = tvf.Compose(tfs) + + # self.train_tfs=self.transform + # self.val_tfs = self.transform + self.init() + def init(self): + self.train_dataset = ConcatDataset([ + UavMapPair(root=Path(self.cfg.root),city=city,training=True,transform=self.train_tfs) + for city in self.cfg.train_citys + ]) + + self.val_dataset = ConcatDataset([ + UavMapPair(root=Path(self.cfg.root),city=city,training=False,transform=self.val_tfs) + for city in self.cfg.val_citys + ]) + + # self.val_datasets = { + # city:UavMapPair(root=Path(self.cfg.root),city=city,transform=self.val_tfs) + # for city in self.cfg.val_citys + # } + # logger.info("train data len:{},val data len:{}".format(len(self.train_dataset),len(self.val_dataset))) + # # 定义分割比例 + # train_ratio = 0.8 # 训练集比例 + # # 计算分割的样本数量 + # train_size = int(len(self.dataset) * train_ratio) + # val_size = len(self.dataset) - train_size + # self.train_dataset, self.val_dataset = random_split(self.dataset, [train_size, val_size]) + def train_dataloader(self): + train_loader = DataLoader(self.train_dataset, + batch_size=self.cfg.train.batch_size, + num_workers=self.cfg.train.num_workers, + shuffle=True,pin_memory = True) + return train_loader + + def val_dataloader(self): + val_loader = DataLoader(self.val_dataset, + batch_size=self.cfg.val.batch_size, + num_workers=self.cfg.val.num_workers, + shuffle=True,pin_memory = True) + # + # my_dict = {k: v for k, v in self.val_datasets} + # val_loaders={city: DataLoader(dataset, + # batch_size=self.cfg.val.batch_size, + # num_workers=self.cfg.val.num_workers, + # shuffle=False,pin_memory = True) for city, dataset in self.val_datasets.items()} + return val_loader diff --git a/dataset/image.py b/dataset/image.py new file mode 100644 index 0000000000000000000000000000000000000000..75b3dc68cc2481150c5ff938483ae640956bcf0d --- /dev/null +++ b/dataset/image.py @@ -0,0 +1,140 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
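To make the contract of this datamodule concrete, here is a hedged sketch of driving UavMapDatasetModule from the `data:` block of conf/maplocnet.yaml; the root path and city names are placeholders and assume the corresponding per-city folders already exist on disk:

```python
from omegaconf import OmegaConf
from dataset import UavMapDatasetModule

# Mirrors the keys read in __init__ above; values are illustrative only.
cfg = OmegaConf.create({
    "root": "/root/DATASET/UAV2MAP/UAV/",   # placeholder dataset root
    "train_citys": ["Paris", "Berlin"],
    "val_citys": ["Toronto"],
    "image_size": 256,
    "train": {"batch_size": 12, "num_workers": 4},
    "val": {"batch_size": 12, "num_workers": 4},
    "augmentation": {
        # keep apply=true: train_tfs is only assigned inside that branch above
        "image": {"apply": True, "brightness": 0.5, "contrast": 0.4,
                  "saturation": 0.4, "hue": 0.159},
    },
})

dm = UavMapDatasetModule(cfg)
batch = next(iter(dm.train_dataloader()))
print(batch["image"].shape, batch["map"].shape, batch["uv"], batch["roll_pitch_yaw"])
```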
+ +from typing import Callable, Optional, Union, Sequence + +import numpy as np +import torch +import torchvision.transforms.functional as tvf +import collections +from scipy.spatial.transform import Rotation + +from utils.geometry import from_homogeneous, to_homogeneous +from utils.wrappers import Camera + + +def rectify_image( + image: torch.Tensor, + cam: Camera, + roll: float, + pitch: Optional[float] = None, + valid: Optional[torch.Tensor] = None, +): + *_, h, w = image.shape + grid = torch.meshgrid( + [torch.arange(w, device=image.device), torch.arange(h, device=image.device)], + indexing="xy", + ) + grid = torch.stack(grid, -1).to(image.dtype) + + if pitch is not None: + args = ("ZX", (roll, pitch)) + else: + args = ("Z", roll) + R = Rotation.from_euler(*args, degrees=True).as_matrix() + R = torch.from_numpy(R).to(image) + + grid_rect = to_homogeneous(cam.normalize(grid)) @ R.T + grid_rect = cam.denormalize(from_homogeneous(grid_rect)) + grid_norm = (grid_rect + 0.5) / grid.new_tensor([w, h]) * 2 - 1 + rectified = torch.nn.functional.grid_sample( + image[None], + grid_norm[None], + align_corners=False, + mode="bilinear", + ).squeeze(0) + if valid is None: + valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1) + else: + valid = ( + torch.nn.functional.grid_sample( + valid[None, None].float(), + grid_norm[None], + align_corners=False, + mode="nearest", + )[0, 0] + > 0 + ) + return rectified, valid + + +def resize_image( + image: torch.Tensor, + size: Union[int, Sequence, np.ndarray], + fn: Optional[Callable] = None, + camera: Optional[Camera] = None, + valid: np.ndarray = None, +): + """Resize an image to a fixed size, or according to max or min edge.""" + *_, h, w = image.shape + if fn is not None: + assert isinstance(size, int) + scale = size / fn(h, w) + h_new, w_new = int(round(h * scale)), int(round(w * scale)) + scale = (scale, scale) + else: + if isinstance(size, (collections.abc.Sequence, np.ndarray)): + w_new, h_new = size + elif isinstance(size, int): + w_new = h_new = size + else: + raise ValueError(f"Incorrect new size: {size}") + scale = (w_new / w, h_new / h) + if (w, h) != (w_new, h_new): + mode = tvf.InterpolationMode.BILINEAR + image = tvf.resize(image, (h_new, w_new), interpolation=mode, antialias=True) + image.clip_(0, 1) + if camera is not None: + camera = camera.scale(scale) + if valid is not None: + valid = tvf.resize( + valid.unsqueeze(0), + (h_new, w_new), + interpolation=tvf.InterpolationMode.NEAREST, + ).squeeze(0) + ret = [image, scale] + if camera is not None: + ret.append(camera) + if valid is not None: + ret.append(valid) + return ret + + +def pad_image( + image: torch.Tensor, + size: Union[int, Sequence, np.ndarray], + camera: Optional[Camera] = None, + valid: torch.Tensor = None, + crop_and_center: bool = False, +): + if isinstance(size, int): + w_new = h_new = size + elif isinstance(size, (collections.abc.Sequence, np.ndarray)): + w_new, h_new = size + else: + raise ValueError(f"Incorrect new size: {size}") + *c, h, w = image.shape + if crop_and_center: + diff = np.array([w - w_new, h - h_new]) + left, top = left_top = np.round(diff / 2).astype(int) + right, bottom = diff - left_top + else: + assert h <= h_new + assert w <= w_new + top = bottom = left = right = 0 + slice_out = np.s_[..., : min(h, h_new), : min(w, w_new)] + slice_in = np.s_[ + ..., max(top, 0) : h - max(bottom, 0), max(left, 0) : w - max(right, 0) + ] + if (w, h) == (w_new, h_new): + out = image + else: + out = torch.zeros((*c, h_new, w_new), dtype=image.dtype) + out[slice_out] = 
image[slice_in] + if camera is not None: + camera = camera.crop((max(left, 0), max(top, 0)), (w_new, h_new)) + out_valid = torch.zeros((h_new, w_new), dtype=torch.bool) + out_valid[slice_out] = True if valid is None else valid[slice_in] + if camera is not None: + return out, out_valid, camera + else: + return out, out_valid diff --git a/dataset/torch.py b/dataset/torch.py new file mode 100644 index 0000000000000000000000000000000000000000..9547ca149c606a345e5b8916591e43c26031022c --- /dev/null +++ b/dataset/torch.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import collections +import os + +import torch +from torch.utils.data import get_worker_info +from torch.utils.data._utils.collate import ( + default_collate_err_msg_format, + np_str_obj_array_pattern, +) +from lightning_fabric.utilities.seed import pl_worker_init_function +from lightning_utilities.core.apply_func import apply_to_collection +from lightning_fabric.utilities.apply_func import move_data_to_device + + +def collate(batch): + """Difference with PyTorch default_collate: it can stack other tensor-like objects. + Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich + https://github.com/cvg/pixloc + Released under the Apache License 2.0 + """ + if not isinstance(batch, list): # no batching + return batch + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, torch.Tensor): + out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum(x.numel() for x in batch) + storage = elem.storage()._new_shared(numel, device=elem.device) + out = elem.new(storage).resize_(len(batch), *list(elem.size())) + return torch.stack(batch, 0, out=out) + elif ( + elem_type.__module__ == "numpy" + and elem_type.__name__ != "str_" + and elem_type.__name__ != "string_" + ): + if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap": + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + + return collate([torch.as_tensor(b) for b in batch]) + elif elem.shape == (): # scalars + return torch.as_tensor(batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float64) + elif isinstance(elem, int): + return torch.tensor(batch) + elif isinstance(elem, (str, bytes)): + return batch + elif isinstance(elem, collections.abc.Mapping): + return {key: collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple + return elem_type(*(collate(samples) for samples in zip(*batch))) + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError("each element in list of batch should be of equal size") + transposed = zip(*batch) + return [collate(samples) for samples in transposed] + else: + # try to stack anyway in case the object implements stacking. 
+ try: + return torch.stack(batch, 0) + except TypeError as e: + if "expected Tensor as element" in str(e): + return batch + else: + raise e + + +def set_num_threads(nt): + """Force numpy and other libraries to use a limited number of threads.""" + try: + import mkl + except ImportError: + pass + else: + mkl.set_num_threads(nt) + torch.set_num_threads(1) + os.environ["IPC_ENABLE"] = "1" + for o in [ + "OPENBLAS_NUM_THREADS", + "NUMEXPR_NUM_THREADS", + "OMP_NUM_THREADS", + "MKL_NUM_THREADS", + ]: + os.environ[o] = str(nt) + + +def worker_init_fn(i): + info = get_worker_info() + pl_worker_init_function(info.id) + num_threads = info.dataset.cfg.get("num_threads") + if num_threads is not None: + set_num_threads(num_threads) + + +def unbatch_to_device(data, device="cpu"): + data = move_data_to_device(data, device) + data = apply_to_collection(data, torch.Tensor, lambda x: x.squeeze(0)) + data = apply_to_collection( + data, list, lambda x: x[0] if len(x) == 1 and isinstance(x[0], str) else x + ) + return data diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..32b54b259de97aa0253b5a81379cd080897ace29 --- /dev/null +++ b/demo.py @@ -0,0 +1,354 @@ + +import matplotlib.pyplot as plt +# from demo import Demo, read_input_image,read_input_image_test +from evaluation.viz import plot_example_single +from dataset.torch import unbatch_to_device +import matplotlib.pyplot as plt +from typing import Optional, Tuple +import cv2 +import torch +import numpy as np +import time +from logger import logger +from evaluation.run import resolve_checkpoint_path, pretrained_models +from models.maplocnet import MapLocNet +from models.voting import fuse_gps, argmax_xyr +# from data.image import resize_image, pad_image, rectify_image +from osm.raster import Canvas +from utils.wrappers import Camera +from utils.io import read_image +from utils.geo import BoundaryBox, Projection +from utils.exif import EXIF +import requests +from pathlib import Path +from utils.exif import EXIF +from dataset.image import resize_image, pad_image, rectify_image +# from maploc.demo import Demo, read_input_image +from dataset import UavMapDatasetModule +import torchvision.transforms as tvf +import matplotlib.pyplot as plt +import numpy as np +from sklearn.decomposition import PCA +from PIL import Image +# import pyproj +# Query OpenStreetMap for this area +from osm.tiling import TileManager +from utils.viz_localization import ( + likelihood_overlay, + plot_dense_rotations, + add_circle_inset, +) +# Show the inputs to the model: image and raster map +from osm.viz import Colormap, plot_nodes +from utils.viz_2d import plot_images + +from utils.viz_2d import features_to_RGB +import random +from geopy.distance import geodesic + + +def vis_image_feature(F): + def normalize(x): + return x / np.linalg.norm(x, axis=-1, keepdims=True) + + # F=neural_map.numpy() + F = F[:, 0:180, 0:180] + flatten = [] + c, h, w = F.shape + print(F.shape) + F = np.rollaxis(F, 0, 3) + F_flat = F.reshape(-1, c) + flatten.append(F_flat) + flatten = normalize(flatten)[0] + + flatten = np.nan_to_num(flatten, nan=0) + pca = PCA(n_components=3) + + print(flatten.shape) + flatten = pca.fit_transform(flatten) + flatten = (normalize(flatten) + 1) / 2 + + # h, w = F.shape[-2:] + F_rgb, flatten = np.split(flatten, [h * w], axis=0) + F_rgb = F_rgb.reshape((h, w, 3)) + return F_rgb +def distance(lat1, lon1, lat2, lon2): + point1 = (lat1, lon1) + point2 = (lat2, lon2) + distance_km = geodesic(point1, point2).meters + return distance_km + +# 
# Example +# lat1, lon1 = 39.9, 116.4 # Beijing latitude/longitude +# lat2, lon2 = 31.2, 121.5 # Shanghai latitude/longitude + +# distance_km = distance(lat1, lon1, lat2, lon2) +# print(distance_km) +def show_result(map_vis_image, pre_uv, pre_yaw): + # Create a grey mask with the same size as the original image + gray_mask = np.zeros_like(map_vis_image) + gray_mask.fill(128) # fill with grey + + # Blend the grey mask with the original image + image = cv2.addWeighted(map_vis_image, 1, gray_mask, 0, 0) + # Draw the ground truth + + # Draw the prediction + u, v = pre_uv + x1, y1 = int(u), int(v) # replace with the actual start coordinates + angle = pre_yaw - 90 # replace with the actual arrow angle + # Compute the arrow end point + length = 20 + x2 = int(x1 + length * np.cos(np.radians(angle))) + y2 = int(y1 + length * np.sin(np.radians(angle))) + # Draw the arrow on the image + cv2.arrowedLine(image, (x1, y1), (x2, y2), (0, 0, 0), 2, 5, 0, 0.3) + # cv2.circle(image, (x1, y1), radius=2, color=(255, 0, 255), thickness=-1) + return image + + +def xyz_to_latlon(x, y, z): + # Define the WGS84 projection + wgs84 = pyproj.CRS('EPSG:4326') + + # Define the geocentric XYZ projection + xyz = pyproj.CRS(f'+proj=geocent +datum=WGS84 +units=m +no_defs') + + # Create the coordinate transformer + transformer = pyproj.Transformer.from_crs(xyz, wgs84) + + # Transform the coordinates + lon, lat, _ = transformer.transform(x, y, z) + + return lat, lon + + +class Demo: + def __init__( + self, + experiment_or_path: Optional[str] = "OrienterNet_MGL", + device=None, + **kwargs + ): + if experiment_or_path in pretrained_models: + experiment_or_path, _ = pretrained_models[experiment_or_path] + path = resolve_checkpoint_path(experiment_or_path) + ckpt = torch.load(path, map_location=(lambda storage, loc: storage)) + config = ckpt["hyper_parameters"] + config.model.update(kwargs) + config.model.image_encoder.backbone.pretrained = False + + model = MapLocNet(config.model).eval() + state = {k[len("model."):]: v for k, v in ckpt["state_dict"].items()} + model.load_state_dict(state, strict=True) + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + self.model = model + self.config = config + self.device = device + + def prepare_data( + self, + image: np.ndarray, + camera: Camera, + canvas: Canvas, + roll_pitch: Optional[Tuple[float]] = None, + ): + + image = torch.from_numpy(image).permute(2, 0, 1).float().div_(255) + + return { + 'map': torch.from_numpy(canvas.raster).long(), + 'image': image, + # 'roll_pitch_yaw':torch.tensor((0, 0, float(yaw))).float().unsqueeze(0), + # 'pixels_per_meter':torch.tensor(float(pixel_per_meter)).float().unsqueeze(0), + # "uv":torch.tensor([float(u), float(v)]).float().unsqueeze(0), + } + # return dict( + # image=image, + # map=torch.from_numpy(canvas.raster).long(), + # camera=camera.float(), + # valid=valid, + # ) + + def localize(self, image: np.ndarray, camera: Camera, canvas: Canvas, **kwargs): + + data = self.prepare_data(image, camera, canvas, **kwargs) + data_ = {k: v.to(self.device)[None] for k, v in data.items()} + # data_np = {k: v.cpu().numpy()[None] for k, v in data.items()} + # logger.info(data_) + # np.save(data_np, 'data_.npy') + start = time.time() + with torch.no_grad(): + pred = self.model(data_) + + end = time.time() + xy_gps = canvas.bbox.center + uv_gps = torch.from_numpy(canvas.to_uv(xy_gps)) + + lp_xyr = pred["log_probs"].squeeze(0) + # tile_size = canvas.bbox.size.min() / 2 + # sigma = tile_size - 20 # 20 meters margin + # lp_xyr = fuse_gps( + # lp_xyr, + # uv_gps.to(lp_xyr), + # self.config.model.pixel_per_meter, + # sigma=sigma, + # ) + xyr = argmax_xyr(lp_xyr).cpu() + + prob = lp_xyr.exp().cpu() + neural_map = pred["map"]["map_features"][0].squeeze(0).cpu() + print('total time:', end - start) + return xyr[:2], xyr[2], prob,
neural_map, data["image"], data_, pred + + +def load_test_data( + root: Path, + city: str, + index: int, +): + uav_image_path = root / city / 'uav' + map_path = root / city / 'map' + map_vis = root / city / 'map_vis' + info_path = root / city / 'info.csv' + osm_path = root / city / '{}.osm'.format(city) + + info = np.loadtxt(str(info_path), dtype=str, delimiter=",", skiprows=1) + + id, uav_name, map_name, \ + uav_long, uav_lat, \ + map_long, map_lat, \ + tile_size_meters, pixel_per_meter, \ + u, v, yaw, dis = info[index] + print(info[index]) + uav_image_rgb = cv2.imread(str(uav_image_path / uav_name)) + uav_image_rgb = cv2.cvtColor(uav_image_rgb, cv2.COLOR_BGR2RGB) + + # w,h,c=uav_image_rgb.shape + # # 指定裁剪区域的坐标 + # x = w//2 # 起始横坐标 + # y = h//2 # 起始纵坐标 + # w = 150 # 宽度 + # h = 150 # 高度 + + # # 裁剪图像 + # uav_image_rgb = uav_image_rgb[y-h:y+h, x-w:x+w] + + map_vis_image = cv2.imread(str(map_vis / uav_name)) + map_vis_image = cv2.cvtColor(map_vis_image, cv2.COLOR_BGR2RGB) + + map = np.load(str(map_path / map_name)) + + tfs = [] + tfs.append(tvf.ToTensor()) + tfs.append(tvf.Resize(256)) + val_tfs = tvf.Compose(tfs) + + uav_image = val_tfs(uav_image_rgb) + # print(id, uav_name, map_name, \ + # uav_long, uav_lat, \ + # map_long, map_lat, \ + # tile_size_meters, pixel_per_meter, \ + # u, v, yaw,dis) + uav_path = str(uav_image_path / uav_name) + return { + 'map': torch.from_numpy(np.ascontiguousarray(map)).long().unsqueeze(0), + 'image': torch.tensor(uav_image).unsqueeze(0), + 'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float().unsqueeze(0), + 'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float().unsqueeze(0), + "uv": torch.tensor([float(u), float(v)]).float().unsqueeze(0), + }, uav_image_rgb, map_vis_image, uav_path, [float(map_lat), float(map_long)] + + +def crop_image(image, width, height): + # 计算剪裁区域的起始点坐标 + x = int((image.shape[1] - width) / 2) + y = int((image.shape[0] - height) / 2) + + # 剪裁图像 + cropped_image = image[y:y + height, x:x + width] + return cropped_image + + +def crop_square(image): + # 获取图像的宽度和高度 + height, width = image.shape[:2] + + # 确定最小边的长度 + min_length = min(height, width) + + # 计算剪裁区域的坐标 + top = (height - min_length) // 2 + bottom = top + min_length + left = (width - min_length) // 2 + right = left + min_length + + # 剪裁图像为正方形 + cropped_image = image[top:bottom, left:right] + + return cropped_image +def read_input_image_test( + image, + prior_latlon, + tile_size_meters, +): + # image = read_image(image_path) + # # 剪裁图像 + # # 指定剪裁的宽度和高度 + # width = 1080*2 + # height =1080*2 + # image = crop_square(image) + # # print("input image:",image.shape) + # image = crop_image(image, width, height) + # # print("crop_image:",image.shape) + image = cv2.resize(image,(256,256)) + roll_pitch = None + + + latlon = None + if prior_latlon is not None: + latlon = prior_latlon + logger.info("Using prior latlon %s.", prior_latlon) + + if latlon is None: + with open(image_path, "rb") as fid: + exif = EXIF(fid, lambda: image.shape[:2]) + geo = exif.extract_geo() + if geo: + alt = geo.get("altitude", 0) # read if available + latlon = (geo["latitude"], geo["longitude"], alt) + logger.info("Using prior location from EXIF.") + # print(latlon) + else: + logger.info("Could not find any prior location in the image EXIF metadata.") + + latlon = np.array(latlon) + + proj = Projection(*latlon) + center = proj.project(latlon) + bbox = BoundaryBox(center, center) + float(tile_size_meters) + camera=None + image=cv2.resize(image,(256,256)) + return image, camera, roll_pitch, proj, bbox, latlon 
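After `localize` returns, the predicted cell can be mapped back to geographic coordinates and scored against a reference position. A hedged sketch, reusing the `distance` helper defined above and the `demo`, `proj`, `canvas`, `image`, and `camera` objects built in the `__main__` block below; the ground-truth coordinates are placeholders, and it assumes `proj.unproject` returns a (latitude, longitude) pair as used there:

```python
# Hedged sketch: geographic error of one localization (placeholder ground truth).
uv_pred, yaw_pred, prob, neural_map, image_rect, data_, pred = demo.localize(image, camera, canvas)
lat_pred, lon_pred = proj.unproject(canvas.to_xy(uv_pred))
lat_gt, lon_gt = 37.7570, -122.4359  # placeholder reference position
print(f"yaw {float(yaw_pred):.1f} deg, error {distance(lat_gt, lon_gt, lat_pred, lon_pred):.1f} m")
```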
+if __name__ == '__main__': + experiment_or_path = "weight/last-step-checkpointing.ckpt" + # experiment_or_path="experiments/maplocanet_0906_diffhight/last-step-checkpointing.ckpt" + image_path='images/00000.jpg' + prior_latlon=(37.75704325989902,-122.435941445631) + tile_size_meters=128 + demo = Demo(experiment_or_path=experiment_or_path, num_rotations=128, device='cpu') + image, camera, gravity, proj, bbox, true_prior_latlon = read_input_image_test( + image_path, + prior_latlon=prior_latlon, + tile_size_meters=tile_size_meters, # try 64, 256, etc. + ) + tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10,ppm=1, tile_size=tile_size_meters) + # tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10,ppm=1,path=root/city/'{}.osm'.format(city), tile_size=1) + canvas = tiler.query(bbox) + uv, yaw, prob, neural_map, image_rectified, data_, pred = demo.localize( + image, camera, canvas) + prior_latlon_pred = proj.unproject(canvas.to_xy(uv)) + pass \ No newline at end of file diff --git a/evaluation/kitti.py b/evaluation/kitti.py new file mode 100644 index 0000000000000000000000000000000000000000..e91da069f307a533b0471a3fb43f8622cadc60db --- /dev/null +++ b/evaluation/kitti.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import argparse +from pathlib import Path +from typing import Optional, Tuple + +from omegaconf import OmegaConf, DictConfig + +from .. import logger +from ..data import KittiDataModule +from .run import evaluate + + +default_cfg_single = OmegaConf.create({}) +# For the sequential evaluation, we need to center the map around the GT location, +# since random offsets would accumulate and leave only the GT location with a valid mask. +# This should not have much impact on the results. +default_cfg_sequential = OmegaConf.create( + { + "data": { + "mask_radius": KittiDataModule.default_cfg["max_init_error"], + "prior_range_rotation": KittiDataModule.default_cfg[ + "max_init_error_rotation" + ] + + 1, + "max_init_error": 0, + "max_init_error_rotation": 0, + }, + "chunking": { + "max_length": 100, # about 10s? 
+ }, + } +) + + +def run( + split: str, + experiment: str, + cfg: Optional[DictConfig] = None, + sequential: bool = False, + thresholds: Tuple[int] = (1, 3, 5), + **kwargs, +): + cfg = cfg or {} + if isinstance(cfg, dict): + cfg = OmegaConf.create(cfg) + default = default_cfg_sequential if sequential else default_cfg_single + cfg = OmegaConf.merge(default, cfg) + dataset = KittiDataModule(cfg.get("data", {})) + + metrics = evaluate( + experiment, + cfg, + dataset, + split=split, + sequential=sequential, + viz_kwargs=dict(show_dir_error=True, show_masked_prob=False), + **kwargs, + ) + + keys = ["directional_error", "yaw_max_error"] + if sequential: + keys += ["directional_seq_error", "yaw_seq_error"] + for k in keys: + rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist() + logger.info("Recall %s: %s at %s m/°", k, rec, thresholds) + return metrics + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--experiment", type=str, required=True) + parser.add_argument( + "--split", type=str, default="test", choices=["test", "val", "train"] + ) + parser.add_argument("--sequential", action="store_true") + parser.add_argument("--output_dir", type=Path) + parser.add_argument("--num", type=int) + parser.add_argument("dotlist", nargs="*") + args = parser.parse_args() + cfg = OmegaConf.from_cli(args.dotlist) + run( + args.split, + args.experiment, + cfg, + args.sequential, + output_dir=args.output_dir, + num=args.num, + ) diff --git a/evaluation/mapillary.py b/evaluation/mapillary.py new file mode 100644 index 0000000000000000000000000000000000000000..c45b845bacd6d6a9c995d3b8d7ee9cd9ec2a9f78 --- /dev/null +++ b/evaluation/mapillary.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import argparse +from pathlib import Path +from typing import Optional, Tuple + +from omegaconf import OmegaConf, DictConfig + +from .. 
import logger +from ..conf import data as conf_data_dir +from ..data import MapillaryDataModule +from .run import evaluate + + +split_overrides = { + "val": { + "scenes": [ + "sanfrancisco_soma", + "sanfrancisco_hayes", + "amsterdam", + "berlin", + "lemans", + "montrouge", + "toulouse", + "nantes", + "vilnius", + "avignon", + "helsinki", + "milan", + "paris", + ], + }, +} +data_cfg_train = OmegaConf.load(Path(conf_data_dir.__file__).parent / "mapillary.yaml") +data_cfg = OmegaConf.merge( + data_cfg_train, + { + "return_gps": True, + "add_map_mask": True, + "max_init_error": 32, + "loading": {"val": {"batch_size": 1, "num_workers": 0}}, + }, +) +default_cfg_single = OmegaConf.create({"data": data_cfg}) +default_cfg_sequential = OmegaConf.create( + { + **default_cfg_single, + "chunking": { + "max_length": 10, + }, + } +) + + +def run( + split: str, + experiment: str, + cfg: Optional[DictConfig] = None, + sequential: bool = False, + thresholds: Tuple[int] = (1, 3, 5), + **kwargs, +): + cfg = cfg or {} + if isinstance(cfg, dict): + cfg = OmegaConf.create(cfg) + default = default_cfg_sequential if sequential else default_cfg_single + default = OmegaConf.merge(default, split_overrides[split]) + cfg = OmegaConf.merge(default, cfg) + dataset = MapillaryDataModule(cfg.get("data", {})) + + metrics = evaluate(experiment, cfg, dataset, split, sequential=sequential, **kwargs) + + keys = [ + "xy_max_error", + "xy_gps_error", + "yaw_max_error", + ] + if sequential: + keys += [ + "xy_seq_error", + "xy_gps_seq_error", + "yaw_seq_error", + "yaw_gps_seq_error", + ] + for k in keys: + if k not in metrics: + logger.warning("Key %s not in metrics.", k) + continue + rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist() + logger.info("Recall %s: %s at %s m/°", k, rec, thresholds) + return metrics + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--experiment", type=str, required=True) + parser.add_argument("--split", type=str, default="val", choices=["val"]) + parser.add_argument("--sequential", action="store_true") + parser.add_argument("--output_dir", type=Path) + parser.add_argument("--num", type=int) + parser.add_argument("dotlist", nargs="*") + args = parser.parse_args() + cfg = OmegaConf.from_cli(args.dotlist) + run( + args.split, + args.experiment, + cfg, + args.sequential, + output_dir=args.output_dir, + num=args.num, + ) diff --git a/evaluation/run.py b/evaluation/run.py new file mode 100644 index 0000000000000000000000000000000000000000..ff29688454358303643f3a26f7900486c19dbf22 --- /dev/null +++ b/evaluation/run.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
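Both evaluation entry points rely on the same OmegaConf layering: a per-script default, optional split overrides, and finally the CLI dotlist, with later sources winning. A small self-contained sketch of that merge order (keys and values here are illustrative, not taken from the configs above):

```python
from omegaconf import OmegaConf

# Defaults declared in the script, then overrides parsed from the command line.
default = OmegaConf.create({"data": {"max_init_error": 32, "add_map_mask": True}})
cli = OmegaConf.from_dotlist(["data.max_init_error=48", "chunking.max_length=10"])

cfg = OmegaConf.merge(default, cli)  # later configs take precedence
print(OmegaConf.to_yaml(cfg))
```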
+ +import functools +from itertools import islice +from typing import Callable, Dict, Optional, Tuple +from pathlib import Path + +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +from torchmetrics import MetricCollection +from pytorch_lightning import seed_everything +from tqdm import tqdm + +from logger import logger, EXPERIMENTS_PATH +from dataset.torch import collate, unbatch_to_device +from models.voting import argmax_xyr, fuse_gps +from models.metrics import AngleError, LateralLongitudinalError, Location2DError +# from models.sequential import GPSAligner, RigidAligner +from module import GenericModule +from utils.io import download_file, DATA_URL +from evaluation.viz import plot_example_single, plot_example_sequential +from evaluation.utils import write_dump + + +pretrained_models = dict( + OrienterNet_MGL=("orienternet_mgl.ckpt", dict(num_rotations=256)), +) + + +def resolve_checkpoint_path(experiment_or_path: str) -> Path: + path = Path(experiment_or_path) + if not path.exists(): + # provided name of experiment + path = Path(EXPERIMENTS_PATH, *experiment_or_path.split("/")) + if not path.exists(): + if experiment_or_path in set(p for p, _ in pretrained_models.values()): + download_file(f"{DATA_URL}/{experiment_or_path}", path) + else: + raise FileNotFoundError(path) + if path.is_file(): + return path + # provided only the experiment name + maybe_path = path / "last-step-v1.ckpt" + if not maybe_path.exists(): + maybe_path = path / "last.ckpt" + if not maybe_path.exists(): + raise FileNotFoundError(f"Could not find any checkpoint in {path}.") + return maybe_path + + +@torch.no_grad() +def evaluate_single_image( + dataloader: torch.utils.data.DataLoader, + model: GenericModule, + num: Optional[int] = None, + callback: Optional[Callable] = None, + progress: bool = True, + mask_index: Optional[Tuple[int]] = None, + has_gps: bool = False, +): + ppm = model.model.conf.pixel_per_meter + metrics = MetricCollection(model.model.metrics()) + metrics["directional_error"] = LateralLongitudinalError(ppm) + if has_gps: + metrics["xy_gps_error"] = Location2DError("uv_gps", ppm) + metrics["xy_fused_error"] = Location2DError("uv_fused", ppm) + metrics["yaw_fused_error"] = AngleError("yaw_fused") + metrics = metrics.to(model.device) + + for i, batch_ in enumerate( + islice(tqdm(dataloader, total=num, disable=not progress), num) + ): + batch = model.transfer_batch_to_device(batch_, model.device, i) + # Ablation: mask semantic classes + if mask_index is not None: + mask = batch["map"][0, mask_index[0]] == (mask_index[1] + 1) + batch["map"][0, mask_index[0]][mask] = 0 + pred = model(batch) + + if has_gps: + (uv_gps,) = pred["uv_gps"] = batch["uv_gps"] + pred["log_probs_fused"] = fuse_gps( + pred["log_probs"], uv_gps, ppm, sigma=batch["accuracy_gps"] + ) + uvt_fused = argmax_xyr(pred["log_probs_fused"]) + pred["uv_fused"] = uvt_fused[..., :2] + pred["yaw_fused"] = uvt_fused[..., -1] + del uv_gps, uvt_fused + + results = metrics(pred, batch) + if callback is not None: + callback( + i, model, unbatch_to_device(pred), unbatch_to_device(batch_), results + ) + del batch_, batch, pred, results + + return metrics.cpu() + + +@torch.no_grad() +def evaluate_sequential( + dataset: torch.utils.data.Dataset, + chunk2idx: Dict, + model: GenericModule, + num: Optional[int] = None, + shuffle: bool = False, + callback: Optional[Callable] = None, + progress: bool = True, + num_rotations: int = 512, + mask_index: Optional[Tuple[int]] = None, + has_gps: bool = True, +): + chunk_keys = 
list(chunk2idx) + if shuffle: + chunk_keys = [chunk_keys[i] for i in torch.randperm(len(chunk_keys))] + if num is not None: + chunk_keys = chunk_keys[:num] + lengths = [len(chunk2idx[k]) for k in chunk_keys] + logger.info( + "Min/max/med lengths: %d/%d/%d, total number of images: %d", + min(lengths), + np.median(lengths), + max(lengths), + sum(lengths), + ) + viz = callback is not None + + metrics = MetricCollection(model.model.metrics()) + ppm = model.model.conf.pixel_per_meter + metrics["directional_error"] = LateralLongitudinalError(ppm) + metrics["xy_seq_error"] = Location2DError("uv_seq", ppm) + metrics["yaw_seq_error"] = AngleError("yaw_seq") + metrics["directional_seq_error"] = LateralLongitudinalError(ppm, key="uv_seq") + if has_gps: + metrics["xy_gps_error"] = Location2DError("uv_gps", ppm) + metrics["xy_gps_seq_error"] = Location2DError("uv_gps_seq", ppm) + metrics["yaw_gps_seq_error"] = AngleError("yaw_gps_seq") + metrics = metrics.to(model.device) + + keys_save = ["uvr_max", "uv_max", "yaw_max", "uv_expectation"] + if has_gps: + keys_save.append("uv_gps") + if viz: + keys_save.append("log_probs") + + for chunk_index, key in enumerate(tqdm(chunk_keys, disable=not progress)): + indices = chunk2idx[key] + aligner = RigidAligner(track_priors=viz, num_rotations=num_rotations) + if has_gps: + aligner_gps = GPSAligner(track_priors=viz, num_rotations=num_rotations) + batches = [] + preds = [] + for i in indices: + data = dataset[i] + data = model.transfer_batch_to_device(data, model.device, 0) + pred = model(collate([data])) + + canvas = data["canvas"] + data["xy_geo"] = xy = canvas.to_xy(data["uv"].double()) + data["yaw"] = yaw = data["roll_pitch_yaw"][-1].double() + aligner.update(pred["log_probs"][0], canvas, xy, yaw) + + if has_gps: + (uv_gps) = pred["uv_gps"] = data["uv_gps"][None] + xy_gps = canvas.to_xy(uv_gps.double()) + aligner_gps.update(xy_gps, data["accuracy_gps"], canvas, xy, yaw) + + if not viz: + data.pop("image") + data.pop("map") + batches.append(data) + preds.append({k: pred[k][0] for k in keys_save}) + del pred + + xy_gt = torch.stack([b["xy_geo"] for b in batches]) + yaw_gt = torch.stack([b["yaw"] for b in batches]) + aligner.compute() + xy_seq, yaw_seq = aligner.transform(xy_gt, yaw_gt) + if has_gps: + aligner_gps.compute() + xy_gps_seq, yaw_gps_seq = aligner_gps.transform(xy_gt, yaw_gt) + results = [] + for i in range(len(indices)): + preds[i]["uv_seq"] = batches[i]["canvas"].to_uv(xy_seq[i]).float() + preds[i]["yaw_seq"] = yaw_seq[i].float() + if has_gps: + preds[i]["uv_gps_seq"] = ( + batches[i]["canvas"].to_uv(xy_gps_seq[i]).float() + ) + preds[i]["yaw_gps_seq"] = yaw_gps_seq[i].float() + results.append(metrics(preds[i], batches[i])) + if viz: + callback(chunk_index, model, batches, preds, results, aligner) + del aligner, preds, batches, results + return metrics.cpu() + + +def evaluate( + experiment: str, + cfg: DictConfig, + dataset, + split: str, + sequential: bool = False, + output_dir: Optional[Path] = None, + callback: Optional[Callable] = None, + num_workers: int = 1, + viz_kwargs=None, + **kwargs, +): + if experiment in pretrained_models: + experiment, cfg_override = pretrained_models[experiment] + cfg = OmegaConf.merge(OmegaConf.create(dict(model=cfg_override)), cfg) + + logger.info("Evaluating model %s with config %s", experiment, cfg) + checkpoint_path = resolve_checkpoint_path(experiment) + model = GenericModule.load_from_checkpoint( + checkpoint_path, cfg=cfg, find_best=not experiment.endswith(".ckpt") + ) + model = model.eval() + if 
torch.cuda.is_available(): + model = model.cuda() + + dataset.prepare_data() + dataset.setup() + + if output_dir is not None: + output_dir.mkdir(exist_ok=True, parents=True) + if callback is None: + if sequential: + callback = plot_example_sequential + else: + callback = plot_example_single + callback = functools.partial( + callback, out_dir=output_dir, **(viz_kwargs or {}) + ) + kwargs = {**kwargs, "callback": callback} + + seed_everything(dataset.cfg.seed) + if sequential: + dset, chunk2idx = dataset.sequence_dataset(split, **cfg.chunking) + metrics = evaluate_sequential(dset, chunk2idx, model, **kwargs) + else: + loader = dataset.dataloader(split, shuffle=True, num_workers=num_workers) + metrics = evaluate_single_image(loader, model, **kwargs) + + results = metrics.compute() + logger.info("All results: %s", results) + if output_dir is not None: + write_dump(output_dir, experiment, cfg, results, metrics) + logger.info("Outputs have been written to %s.", output_dir) + return metrics diff --git a/evaluation/utils.py b/evaluation/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6cc954ed20557c351965cad89aa2e249986986ee --- /dev/null +++ b/evaluation/utils.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import numpy as np +from omegaconf import OmegaConf + +from utils.io import write_json + + +def compute_recall(errors): + num_elements = len(errors) + sort_idx = np.argsort(errors) + errors = np.array(errors.copy())[sort_idx] + recall = (np.arange(num_elements) + 1) / num_elements + recall = np.r_[0, recall] + errors = np.r_[0, errors] + return errors, recall + + +def compute_auc(errors, recall, thresholds): + aucs = [] + for t in thresholds: + last_index = np.searchsorted(errors, t, side="right") + r = np.r_[recall[:last_index], recall[last_index - 1]] + e = np.r_[errors[:last_index], t] + auc = np.trapz(r, x=e) / t + aucs.append(auc * 100) + return aucs + + +def write_dump(output_dir, experiment, cfg, results, metrics): + dump = { + "experiment": experiment, + "cfg": OmegaConf.to_container(cfg), + "results": results, + "errors": {}, + } + for k, m in metrics.items(): + if hasattr(m, "get_errors"): + dump["errors"][k] = m.get_errors().numpy() + write_json(output_dir / "log.json", dump) diff --git a/evaluation/viz.py b/evaluation/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..9cd9f7dfd0f2103f2ebdda8cfe8022ad5a2e719b --- /dev/null +++ b/evaluation/viz.py @@ -0,0 +1,178 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
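The recall values logged above ("Recall ... at ... m/°") come from helpers like those in evaluation/utils.py; a short sketch with made-up errors, assuming the repository root is on the Python path:

```python
import numpy as np
from evaluation.utils import compute_recall, compute_auc  # import path is an assumption

# Made-up localization errors in meters.
errors = np.array([0.4, 0.9, 2.5, 4.0, 7.5, 12.0])
errs, recall = compute_recall(errors)
print(compute_auc(errs, recall, thresholds=(1, 3, 5)))  # AUC (%) at 1 / 3 / 5 m
```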
+ +import numpy as np +import torch +import matplotlib.pyplot as plt + +from utils.io import write_torch_image +from utils.viz_2d import plot_images, features_to_RGB, save_plot +from utils.viz_localization import ( + likelihood_overlay, + plot_pose, + plot_dense_rotations, + add_circle_inset, +) +from osm.viz import Colormap, plot_nodes + + +def plot_example_single( + idx, + model, + pred, + data, + results, + plot_bev=True, + out_dir=None, + fig_for_paper=False, + show_gps=False, + show_fused=False, + show_dir_error=False, + show_masked_prob=False, +): + scene, name, rasters, uv_gt = (data[k] for k in ("scene", "name", "map", "uv")) + uv_gps = data.get("uv_gps") + yaw_gt = data["roll_pitch_yaw"][-1].numpy() + image = data["image"].permute(1, 2, 0) + if "valid" in data: + image = image.masked_fill(~data["valid"].unsqueeze(-1), 0.3) + + lp_uvt = lp_uv = pred["log_probs"] + if show_fused and "log_probs_fused" in pred: + lp_uvt = lp_uv = pred["log_probs_fused"] + elif not show_masked_prob and "scores_unmasked" in pred: + lp_uvt = lp_uv = pred["scores_unmasked"] + has_rotation = lp_uvt.ndim == 3 + if has_rotation: + lp_uv = lp_uvt.max(-1).values + if lp_uv.min() > -np.inf: + lp_uv = lp_uv.clip(min=np.percentile(lp_uv, 1)) + prob = lp_uv.exp() + uv_p, yaw_p = pred["uv_max"], pred.get("yaw_max") + if show_fused and "uv_fused" in pred: + uv_p, yaw_p = pred["uv_fused"], pred.get("yaw_fused") + feats_map = pred["map"]["map_features"][0] + (feats_map_rgb,) = features_to_RGB(feats_map.numpy()) + + text1 = rf'$\Delta xy$: {results["xy_max_error"]:.1f}m' + if has_rotation: + text1 += rf', $\Delta\theta$: {results["yaw_max_error"]:.1f}°' + if show_fused and "xy_fused_error" in results: + text1 += rf', $\Delta xy_{{fused}}$: {results["xy_fused_error"]:.1f}m' + text1 += rf', $\Delta\theta_{{fused}}$: {results["yaw_fused_error"]:.1f}°' + if show_dir_error and "directional_error" in results: + err_lat, err_lon = results["directional_error"] + text1 += rf", $\Delta$lateral/longitundinal={err_lat:.1f}m/{err_lon:.1f}m" + if "xy_gps_error" in results: + text1 += rf', $\Delta xy_{{GPS}}$: {results["xy_gps_error"]:.1f}m' + + map_viz = Colormap.apply(rasters) + overlay = likelihood_overlay(prob.numpy(), map_viz.mean(-1, keepdims=True)) + plot_images( + [image, map_viz, overlay, feats_map_rgb], + titles=[text1, "map", "likelihood", "neural map"], + dpi=75, + cmaps="jet", + ) + fig = plt.gcf() + axes = fig.axes + axes[1].images[0].set_interpolation("none") + axes[2].images[0].set_interpolation("none") + Colormap.add_colorbar() + plot_nodes(1, rasters[2]) + + if show_gps and uv_gps is not None: + plot_pose([1], uv_gps, c="blue") + plot_pose([1], uv_gt, yaw_gt, c="red") + plot_pose([1], uv_p, yaw_p, c="k") + plot_dense_rotations(2, lp_uvt.exp()) + inset_center = pred["uv_max"] if results["xy_max_error"] < 5 else uv_gt + axins = add_circle_inset(axes[2], inset_center) + axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50, zorder=15) + axes[0].text( + 0.003, + 0.003, + f"{scene}/{name}", + transform=axes[0].transAxes, + fontsize=3, + va="bottom", + ha="left", + color="w", + ) + plt.show() + if out_dir is not None: + name_ = name.replace("/", "_") + p = str(out_dir / f"{scene}_{name_}_{{}}.pdf") + save_plot(p.format("pred")) + plt.close() + + if fig_for_paper: + # !cp ../datasets/MGL/{scene}/images/{name}.jpg {out_dir}/{scene}_{name}.jpg + plot_images([map_viz]) + plt.gca().images[0].set_interpolation("none") + plot_nodes(0, rasters[2]) + plot_pose([0], uv_gt, yaw_gt, c="red") + plot_pose([0], pred["uv_max"], 
pred["yaw_max"], c="k") + save_plot(p.format("map")) + plt.close() + plot_images([lp_uv], cmaps="jet") + plot_dense_rotations(0, lp_uvt.exp()) + save_plot(p.format("loglikelihood"), dpi=100) + plt.close() + plot_images([overlay]) + plt.gca().images[0].set_interpolation("none") + axins = add_circle_inset(plt.gca(), inset_center) + axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50) + save_plot(p.format("likelihood")) + plt.close() + write_torch_image( + p.format("neuralmap").replace("pdf", "jpg"), feats_map_rgb + ) + write_torch_image(p.format("image").replace("pdf", "jpg"), image.numpy()) + + if not plot_bev: + return + + feats_q = pred["features_bev"] + mask_bev = pred["valid_bev"] + prior = None + if "log_prior" in pred["map"]: + prior = pred["map"]["log_prior"][0].sigmoid() + if "bev" in pred and "confidence" in pred["bev"]: + conf_q = pred["bev"]["confidence"] + else: + conf_q = torch.norm(feats_q, dim=0) + conf_q = conf_q.masked_fill(~mask_bev, np.nan) + (feats_q_rgb,) = features_to_RGB(feats_q.numpy(), masks=[mask_bev.numpy()]) + # feats_map_rgb, feats_q_rgb, = features_to_RGB( + # feats_map.numpy(), feats_q.numpy(), masks=[None, mask_bev]) + norm_map = torch.norm(feats_map, dim=0) + + plot_images( + [conf_q, feats_q_rgb, norm_map] + ([] if prior is None else [prior]), + titles=["BEV confidence", "BEV features", "map norm"] + + ([] if prior is None else ["map prior"]), + dpi=50, + cmaps="jet", + ) + plt.show() + + if out_dir is not None: + save_plot(p.format("bev")) + plt.close() + + +def plot_example_sequential( + idx, + model, + pred, + data, + results, + plot_bev=True, + out_dir=None, + fig_for_paper=False, + show_gps=False, + show_fused=False, + show_dir_error=False, + show_masked_prob=False, +): + return diff --git a/flagged/inp/10d2e4a8712491181c2f48b61f5003b216d2b9f9/tmp48n9eoyh.png b/flagged/inp/10d2e4a8712491181c2f48b61f5003b216d2b9f9/tmp48n9eoyh.png new file mode 100644 index 0000000000000000000000000000000000000000..61bdcaeabd4ec399fa036511592c0c9e3f8628b7 Binary files /dev/null and b/flagged/inp/10d2e4a8712491181c2f48b61f5003b216d2b9f9/tmp48n9eoyh.png differ diff --git a/flagged/inp/e1b18d44d9e381d586209f73a015fed7f688822b/tmp86ith_2q.png b/flagged/inp/e1b18d44d9e381d586209f73a015fed7f688822b/tmp86ith_2q.png new file mode 100644 index 0000000000000000000000000000000000000000..61bdcaeabd4ec399fa036511592c0c9e3f8628b7 Binary files /dev/null and b/flagged/inp/e1b18d44d9e381d586209f73a015fed7f688822b/tmp86ith_2q.png differ diff --git a/flagged/log.csv b/flagged/log.csv new file mode 100644 index 0000000000000000000000000000000000000000..61a8dd134c6a6e7a5cd88ccf2ef430e489e8d4b4 --- /dev/null +++ b/flagged/log.csv @@ -0,0 +1,3 @@ +inp,longitude,latitude,Area,output,flag,username,timestamp +E:\MapLocNetDemo\Demo\flagged\inp\10d2e4a8712491181c2f48b61f5003b216d2b9f9\tmp48n9eoyh.png,70.1,40,256,E:\MapLocNetDemo\Demo\flagged\output\tmp59657zop.json,,,2023-09-22 10:07:17.488625 +E:\MapLocNetDemo\Demo\flagged\inp\e1b18d44d9e381d586209f73a015fed7f688822b\tmp86ith_2q.png,70.1,40,256,E:\MapLocNetDemo\Demo\flagged\output\tmpbs17s28d.json,,,2023-09-22 10:07:21.485967 diff --git a/flagged/output/tmp59657zop.json b/flagged/output/tmp59657zop.json new file mode 100644 index 0000000000000000000000000000000000000000..6da7282f99d84615c8e174ce435ecd85765184f8 --- /dev/null +++ b/flagged/output/tmp59657zop.json @@ -0,0 +1 @@ +{"label": "bull mastiff\n", "confidences": [{"label": "bull mastiff\n", "confidence": 0.24759389460086823}, {"label": "pug\n", "confidence": 0.0916372761130333}, {"label": 
"Great Dane\n", "confidence": 0.08652031421661377}]} \ No newline at end of file diff --git a/flagged/output/tmpbs17s28d.json b/flagged/output/tmpbs17s28d.json new file mode 100644 index 0000000000000000000000000000000000000000..6da7282f99d84615c8e174ce435ecd85765184f8 --- /dev/null +++ b/flagged/output/tmpbs17s28d.json @@ -0,0 +1 @@ +{"label": "bull mastiff\n", "confidences": [{"label": "bull mastiff\n", "confidence": 0.24759389460086823}, {"label": "pug\n", "confidence": 0.0916372761130333}, {"label": "Great Dane\n", "confidence": 0.08652031421661377}]} \ No newline at end of file diff --git a/images/00000.jpg b/images/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c3340fd6671c91ec138a9cef129df4f9ce5adbd6 Binary files /dev/null and b/images/00000.jpg differ diff --git a/images/00011.jpg b/images/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3c9201b98433b985bafd85793bd7992c8e7f55c6 Binary files /dev/null and b/images/00011.jpg differ diff --git a/images/00022.jpg b/images/00022.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1d23d874d54e1f2608d6907fd1ef3416ac6e0716 Binary files /dev/null and b/images/00022.jpg differ diff --git a/images/00033.jpg b/images/00033.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fe8b2a1dac78db32ab979df3ca482a37108d5698 Binary files /dev/null and b/images/00033.jpg differ diff --git a/images/cat_dog.png b/images/cat_dog.png new file mode 100644 index 0000000000000000000000000000000000000000..61bdcaeabd4ec399fa036511592c0c9e3f8628b7 Binary files /dev/null and b/images/cat_dog.png differ diff --git a/label.txt b/label.txt new file mode 100644 index 0000000000000000000000000000000000000000..888d6f51dd77bbf10b76dba58b5d2f1afa8ad5bd --- /dev/null +++ b/label.txt @@ -0,0 +1,1000 @@ +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard 
+ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band 
Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse 
+quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue \ No newline at end of file diff --git a/logger.py b/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..56b4d718091ad22d84d59387ca76628aa242555e --- /dev/null +++ b/logger.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
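+# Module-level logging setup for maploc: a stream handler with a timestamped format that is also attached to the PyTorch Lightning logger, plus repository-relative default paths for experiments and datasets. +# Usage sketch (assuming this file is importable as the top-level module "logger"): from logger import logger, EXPERIMENTS_PATH; logger.info("experiments dir: %s", EXPERIMENTS_PATH)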
+ +from pathlib import Path +import logging + +import pytorch_lightning # noqa: F401 + + +formatter = logging.Formatter( + fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +handler = logging.StreamHandler() +handler.setFormatter(formatter) +handler.setLevel(logging.INFO) + +logger = logging.getLogger("maploc") +logger.setLevel(logging.INFO) +logger.addHandler(handler) +logger.propagate = False + +pl_logger = logging.getLogger("pytorch_lightning") +if len(pl_logger.handlers): + pl_logger.handlers[0].setFormatter(formatter) + +repo_dir = Path(__file__).parent +EXPERIMENTS_PATH = repo_dir / "experiments/" +DATASETS_PATH = repo_dir / "datasets/" diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..311d7687bba3cb830615cfb9e89644e25df3d2ee --- /dev/null +++ b/main.py @@ -0,0 +1,98 @@ +import gradio as gr +import cv2 +import gradio as gr +import torch +from torchvision import transforms +import requests +from PIL import Image +from demo import Demo,read_input_image_test,show_result,vis_image_feature +from osm.tiling import TileManager +from osm.viz import Colormap, plot_nodes +from utils.viz_2d import plot_images +import numpy as np +from utils.viz_2d import features_to_RGB +from utils.viz_localization import ( + likelihood_overlay, + plot_dense_rotations, + add_circle_inset, +) +from osm.viz import GeoPlotter +import matplotlib.pyplot as plt +import random +from geopy.distance import geodesic + +experiment_or_path = "weight/last-step-checkpointing.ckpt" +# experiment_or_path="experiments/maplocanet_0906_diffhight/last-step-checkpointing.ckpt" +image_path = 'images/00000.jpg' + +# prior_latlon = (37.75704325989902, -122.435941445631) +# tile_size_meters = 128 +model = Demo(experiment_or_path=experiment_or_path, num_rotations=128, device='cpu') + +def demo_localize(image,long,lat,tile_size_meters): + # inp = Image.fromarray(inp.astype('uint8'), 'RGB') + # inp = transforms.ToTensor()(inp).unsqueeze(0) + prior_latlon=(lat,long) + image, camera, gravity, proj, bbox, true_prior_latlon = read_input_image_test( + image, + prior_latlon=prior_latlon, + tile_size_meters=tile_size_meters, # try 64, 256, etc. 
+ ) + tiler = TileManager.from_bbox(projection=proj, bbox=bbox, ppm=1, tile_size=tile_size_meters) + # tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10,ppm=1,path=root/city/'{}.osm'.format(city), tile_size=1) + canvas = tiler.query(bbox) + uv, yaw, prob, neural_map, image_rectified, data_, pred = model.localize( + image, camera, canvas) + prior_latlon_pred = proj.unproject(canvas.to_xy(uv)) + + map_viz = Colormap.apply(canvas.raster) + map_vis_image_result = map_viz * 255 + map_vis_image_result =show_result(map_vis_image_result.astype(np.uint8), uv, yaw) + # map_vis_image_result = show_result(map_vis_image_result.astype(np.uint8), True_uv, + # uv, + # 90.0 - yaw_T, + # yaw) + # return prior_latlon_pred + uab_feature_rgb = vis_image_feature(pred['features_image'][0].cpu().numpy()) + map_viz = cv2.resize(map_viz, (prob.numpy().shape[0], prob.numpy().shape[1])) + overlay = likelihood_overlay(prob.numpy().max(-1), map_viz.mean(-1, keepdims=True)) + (neural_map_rgb,) = features_to_RGB(neural_map.numpy()) + fig=plot_images([image, map_vis_image_result / 255, overlay, uab_feature_rgb, neural_map_rgb], + titles=["UAV image", "map","likelihood","UAV feature","map feature"]) + # plot_images([overlay, neural_map_rgb], titles=["prediction", "neural map"]) + # ax = plt.gcf().axes[2] + # ax.scatter(*canvas.to_uv(bbox.center), s=5, c="red") + # plot_dense_rotations(ax, prob, w=0.005, s=1 / 25) + # add_circle_inset(ax, uv) + + # Plot as interactive figure + bbox_latlon = proj.unproject(canvas.bbox) + plot2 = GeoPlotter(zoom=16.5) + plot2.raster(map_viz, bbox_latlon, opacity=0.5) + plot2.raster(likelihood_overlay(prob.numpy().max(-1)), proj.unproject(bbox)) + plot2.points(prior_latlon[:2], "red", name="location prior", size=10) + plot2.points(proj.unproject(canvas.to_xy(uv)), "black", name="argmax", size=10) + plot2.bbox(bbox_latlon, "blue", name="map tile") + # plot2.fig.show() + return fig,plot2.fig,str(prior_latlon_pred) +# model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval() +#标题 +title = "MapLocNet" +#标题下的描述,支持md格式 +description = "UAV Vision-based Geo-Localization Using Vectorized Maps" + +# outputs = gr.outputs.Label(num_top_classes=3) +outputs = gr.Plot() +interface = gr.Interface(fn=demo_localize, + inputs=["image", + gr.Number(label="Prior location-longitude)"), + gr.Number(label="Prior location-longitude)"), + gr.Radio([64, 128, 256], label="Search radius (meters)", info="vectorized map size"), + # gr.inputs.RadioGroup(label="Search radius (meters)",["English", "French", "Spanish"]), + # gr.Slider(64, 512,label='Search radius (meters)') + ], + outputs=["plot","plot","text"], + title=title, + description=description, + examples=[['images/00000.jpg',-122.435941445631,37.75704325989902,128]]) +interface.launch(share=True) \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02c1f950d5d3f84b18ba4178e2549fc328479d3f --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich +# https://github.com/cvg/pixloc +# Released under the Apache License 2.0 + +import inspect + +from .base import BaseModel + + +def get_class(mod_name, base_path, BaseClass): + """Get the class object which inherits from BaseClass and is defined in + the module named mod_name, child of base_path. 
+ """ + mod_path = "{}.{}".format(base_path, mod_name) + mod = __import__(mod_path, fromlist=[""]) + classes = inspect.getmembers(mod, inspect.isclass) + # Filter classes defined in the module + classes = [c for c in classes if c[1].__module__ == mod_path] + # Filter classes inherited from BaseModel + classes = [c for c in classes if issubclass(c[1], BaseClass)] + assert len(classes) == 1, classes + return classes[0][1] + + +def get_model(name): + if name == "localizer": + name = "localizer_basic" + elif name == "rotation_localizer": + name = "localizer_basic_rotation" + elif name == "bev_localizer": + name = "localizer_bev_plane" + return get_class(name, __name__, BaseModel) diff --git a/models/base.py b/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a9978eabb3c32fcd98f12399347f4c864e463494 --- /dev/null +++ b/models/base.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich +# https://github.com/cvg/pixloc +# Released under the Apache License 2.0 + +""" +Base class for trainable models. +""" + +from abc import ABCMeta, abstractmethod +from copy import copy + +import omegaconf +from omegaconf import OmegaConf +from torch import nn + + +class BaseModel(nn.Module, metaclass=ABCMeta): + """ + What the child model is expect to declare: + default_conf: dictionary of the default configuration of the model. + It recursively updates the default_conf of all parent classes, and + it is updated by the user-provided configuration passed to __init__. + Configurations can be nested. + + required_data_keys: list of expected keys in the input data dictionary. + + strict_conf (optional): boolean. If false, BaseModel does not raise + an error when the user provides an unknown configuration entry. + + _init(self, conf): initialization method, where conf is the final + configuration object (also accessible with `self.conf`). Accessing + unknown configuration entries will raise an error. + + _forward(self, data): method that returns a dictionary of batched + prediction tensors based on a dictionary of batched input data tensors. + + loss(self, pred, data): method that returns a dictionary of losses, + computed from model predictions and input data. Each loss is a batch + of scalars, i.e. a torch.Tensor of shape (B,). + The total loss to be optimized has the key `'total'`. + + metrics(self, pred, data): method that returns a dictionary of metrics, + each as a batch of scalars. + """ + + base_default_conf = { + "name": None, + "trainable": True, # if false: do not optimize this model parameters + "freeze_batch_normalization": False, # use test-time statistics + } + default_conf = {} + required_data_keys = [] + strict_conf = True + + def __init__(self, conf): + """Perform some logic and call the _init method of the child model.""" + super().__init__() + default_conf = OmegaConf.merge( + self.base_default_conf, OmegaConf.create(self.default_conf) + ) + if self.strict_conf: + OmegaConf.set_struct(default_conf, True) + + # fixme: backward compatibility + if "pad" in conf and "pad" not in default_conf: # backward compat. 
+ with omegaconf.read_write(conf): + with omegaconf.open_dict(conf): + conf["interpolation"] = {"pad": conf.pop("pad")} + + if isinstance(conf, dict): + conf = OmegaConf.create(conf) + self.conf = conf = OmegaConf.merge(default_conf, conf) + OmegaConf.set_readonly(conf, True) + OmegaConf.set_struct(conf, True) + self.required_data_keys = copy(self.required_data_keys) + self._init(conf) + + if not conf.trainable: + for p in self.parameters(): + p.requires_grad = False + + def train(self, mode=True): + super().train(mode) + + def freeze_bn(module): + if isinstance(module, nn.modules.batchnorm._BatchNorm): + module.eval() + + if self.conf.freeze_batch_normalization: + self.apply(freeze_bn) + + return self + + def forward(self, data): + """Check the data and call the _forward method of the child model.""" + + def recursive_key_check(expected, given): + for key in expected: + assert key in given, f"Missing key {key} in data" + if isinstance(expected, dict): + recursive_key_check(expected[key], given[key]) + + recursive_key_check(self.required_data_keys, data) + return self._forward(data) + + @abstractmethod + def _init(self, conf): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def _forward(self, data): + """To be implemented by the child class.""" + raise NotImplementedError + + def loss(self, pred, data): + """To be implemented by the child class.""" + raise NotImplementedError + + def metrics(self): + return {} # no metrics diff --git a/models/feature_extractor.py b/models/feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..5a06066c8ebce96b859d20fa444833d2b884a7ed --- /dev/null +++ b/models/feature_extractor.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich +# https://github.com/cvg/pixloc +# Released under the Apache License 2.0 + +""" +Flexible UNet model which takes any Torchvision backbone as encoder. +Predicts multi-level feature and makes sure that they are well aligned. +""" + +import torch +import torch.nn as nn +import torchvision + +from .base import BaseModel +from .utils import checkpointed + + +class DecoderBlock(nn.Module): + def __init__( + self, previous, skip, out, num_convs=1, norm=nn.BatchNorm2d, padding="zeros" + ): + super().__init__() + + self.upsample = nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=False + ) + + layers = [] + for i in range(num_convs): + conv = nn.Conv2d( + previous + skip if i == 0 else out, + out, + kernel_size=3, + padding=1, + bias=norm is None, + padding_mode=padding, + ) + layers.append(conv) + if norm is not None: + layers.append(norm(out)) + layers.append(nn.ReLU(inplace=True)) + self.layers = nn.Sequential(*layers) + + def forward(self, previous, skip): + upsampled = self.upsample(previous) + # If the shape of the input map `skip` is not a multiple of 2, + # it will not match the shape of the upsampled map `upsampled`. + # If the downsampling uses ceil_mode=False, we nedd to crop `skip`. + # If it uses ceil_mode=True (not supported here), we should pad it. + _, _, hu, wu = upsampled.shape + _, _, hs, ws = skip.shape + assert (hu <= hs) and (wu <= ws), "Using ceil_mode=True in pooling?" 
+ # assert (hu == hs) and (wu == ws), 'Careful about padding' + skip = skip[:, :, :hu, :wu] + return self.layers(torch.cat([upsampled, skip], dim=1)) + + +class AdaptationBlock(nn.Sequential): + def __init__(self, inp, out): + conv = nn.Conv2d(inp, out, kernel_size=1, padding=0, bias=True) + super().__init__(conv) + + +class FeatureExtractor(BaseModel): + default_conf = { + "pretrained": True, + "input_dim": 3, + "output_scales": [0, 2, 4], # what scales to adapt and output + "output_dim": 128, # # of channels in output feature maps + "encoder": "vgg16", # string (torchvision net) or list of channels + "num_downsample": 4, # how many downsample block (if VGG-style net) + "decoder": [64, 64, 64, 64], # list of channels of decoder + "decoder_norm": "nn.BatchNorm2d", # normalization ind decoder blocks + "do_average_pooling": False, + "checkpointed": False, # whether to use gradient checkpointing + "padding": "zeros", + } + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + + def build_encoder(self, conf): + assert isinstance(conf.encoder, str) + if conf.pretrained: + assert conf.input_dim == 3 + Encoder = getattr(torchvision.models, conf.encoder) + encoder = Encoder(weights="DEFAULT" if conf.pretrained else None) + Block = checkpointed(torch.nn.Sequential, do=conf.checkpointed) + assert max(conf.output_scales) <= conf.num_downsample + + if conf.encoder.startswith("vgg"): + # Parse the layers and pack them into downsampling blocks + # It's easy for VGG-style nets because of their linear structure. + # This does not handle strided convs and residual connections + skip_dims = [] + previous_dim = None + blocks = [[]] + for i, layer in enumerate(encoder.features): + if isinstance(layer, torch.nn.Conv2d): + # Change the first conv layer if the input dim mismatches + if i == 0 and conf.input_dim != layer.in_channels: + args = {k: getattr(layer, k) for k in layer.__constants__} + args.pop("output_padding") + layer = torch.nn.Conv2d( + **{**args, "in_channels": conf.input_dim} + ) + previous_dim = layer.out_channels + elif isinstance(layer, torch.nn.MaxPool2d): + assert previous_dim is not None + skip_dims.append(previous_dim) + if (conf.num_downsample + 1) == len(blocks): + break + blocks.append([]) # start a new block + if conf.do_average_pooling: + assert layer.dilation == 1 + layer = torch.nn.AvgPool2d( + kernel_size=layer.kernel_size, + stride=layer.stride, + padding=layer.padding, + ceil_mode=layer.ceil_mode, + count_include_pad=False, + ) + blocks[-1].append(layer) + encoder = [Block(*b) for b in blocks] + elif conf.encoder.startswith("resnet"): + # Manually define the ResNet blocks such that the downsampling comes first + assert conf.encoder[len("resnet") :] in ["18", "34", "50", "101"] + assert conf.input_dim == 3, "Unsupported for now." 
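+ # Split the torchvision ResNet into five sequential blocks so that each block downsamples exactly once: conv1(+bn+relu), maxpool+layer1, then layer2, layer3 and layer4.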
+ block1 = torch.nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu) + block2 = torch.nn.Sequential(encoder.maxpool, encoder.layer1) + block3 = encoder.layer2 + block4 = encoder.layer3 + block5 = encoder.layer4 + blocks = [block1, block2, block3, block4, block5] + # Extract the output dimension of each block + skip_dims = [encoder.conv1.out_channels] + for i in range(1, 5): + modules = getattr(encoder, f"layer{i}")[-1]._modules + conv = sorted(k for k in modules if k.startswith("conv"))[-1] + skip_dims.append(modules[conv].out_channels) + # Add a dummy block such that the first one does not downsample + encoder = [torch.nn.Identity()] + [Block(b) for b in blocks] + skip_dims = [3] + skip_dims + # Trim based on the requested encoder size + encoder = encoder[: conf.num_downsample + 1] + skip_dims = skip_dims[: conf.num_downsample + 1] + else: + raise NotImplementedError(conf.encoder) + + assert (conf.num_downsample + 1) == len(encoder) + encoder = nn.ModuleList(encoder) + + return encoder, skip_dims + + def _init(self, conf): + # Encoder + self.encoder, skip_dims = self.build_encoder(conf) + self.skip_dims = skip_dims + + def update_padding(module): + if isinstance(module, nn.Conv2d): + module.padding_mode = conf.padding + + if conf.padding != "zeros": + self.encoder.apply(update_padding) + + # Decoder + if conf.decoder is not None: + assert len(conf.decoder) == (len(skip_dims) - 1) + Block = checkpointed(DecoderBlock, do=conf.checkpointed) + norm = eval(conf.decoder_norm) if conf.decoder_norm else None # noqa + + previous = skip_dims[-1] + decoder = [] + for out, skip in zip(conf.decoder, skip_dims[:-1][::-1]): + decoder.append( + Block(previous, skip, out, norm=norm, padding=conf.padding) + ) + previous = out + self.decoder = nn.ModuleList(decoder) + + # Adaptation layers + adaptation = [] + for idx, i in enumerate(conf.output_scales): + if conf.decoder is None or i == (len(self.encoder) - 1): + input_ = skip_dims[i] + else: + input_ = conf.decoder[-1 - i] + + # out_dim can be an int (same for all scales) or a list (per scale) + dim = conf.output_dim + if not isinstance(dim, int): + dim = dim[idx] + + block = AdaptationBlock(input_, dim) + adaptation.append(block) + self.adaptation = nn.ModuleList(adaptation) + self.scales = [2**s for s in conf.output_scales] + + def _forward(self, data): + image = data["image"] + if self.conf.pretrained: + mean, std = image.new_tensor(self.mean), image.new_tensor(self.std) + image = (image - mean[:, None, None]) / std[:, None, None] + + skip_features = [] + features = image + for block in self.encoder: + features = block(features) + skip_features.append(features) + + if self.conf.decoder: + pre_features = [skip_features[-1]] + for block, skip in zip(self.decoder, skip_features[:-1][::-1]): + pre_features.append(block(pre_features[-1], skip)) + pre_features = pre_features[::-1] # fine to coarse + else: + pre_features = skip_features + + out_features = [] + for adapt, i in zip(self.adaptation, self.conf.output_scales): + out_features.append(adapt(pre_features[i])) + pred = {"feature_maps": out_features, "skip_features": skip_features} + return pred + + def loss(self, pred, data): + raise NotImplementedError + + def metrics(self, pred, data): + raise NotImplementedError diff --git a/models/feature_extractor_v2.py b/models/feature_extractor_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..38c910651f63e5394214ea0e2b1909537948da54 --- /dev/null +++ b/models/feature_extractor_v2.py @@ -0,0 +1,192 @@ +import logging + +import numpy 
as np +import torch +import torch.nn as nn +import torchvision +from torchvision.models.feature_extraction import create_feature_extractor + +from .base import BaseModel + +logger = logging.getLogger(__name__) + + +class DecoderBlock(nn.Module): + def __init__( + self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros" + ): + super().__init__() + layers = [] + for i in range(num_convs): + conv = nn.Conv2d( + previous if i == 0 else out, + out, + kernel_size=ksize, + padding=ksize // 2, + bias=norm is None, + padding_mode=padding, + ) + layers.append(conv) + if norm is not None: + layers.append(norm(out)) + layers.append(nn.ReLU(inplace=True)) + self.layers = nn.Sequential(*layers) + + def forward(self, previous, skip): + _, _, hp, wp = previous.shape + _, _, hs, ws = skip.shape + scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp]))) + upsampled = nn.functional.interpolate( + previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False + ) + # If the shape of the input map `skip` is not a multiple of 2, + # it will not match the shape of the upsampled map `upsampled`. + # If the downsampling uses ceil_mode=False, we nedd to crop `skip`. + # If it uses ceil_mode=True (not supported here), we should pad it. + _, _, hu, wu = upsampled.shape + _, _, hs, ws = skip.shape + if (hu <= hs) and (wu <= ws): + skip = skip[:, :, :hu, :wu] + elif (hu >= hs) and (wu >= ws): + skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs]) + else: + raise ValueError( + f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}" + ) + + return self.layers(skip) + upsampled + + +class FPN(nn.Module): + def __init__(self, in_channels_list, out_channels, **kw): + super().__init__() + self.first = nn.Conv2d( + in_channels_list[-1], out_channels, 1, padding=0, bias=True + ) + self.blocks = nn.ModuleList( + [ + DecoderBlock(c, out_channels, ksize=1, **kw) + for c in in_channels_list[::-1][1:] + ] + ) + self.out = nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, layers): + feats = None + for idx, x in enumerate(reversed(layers.values())): + if feats is None: + feats = self.first(x) + else: + feats = self.blocks[idx - 1](feats, x) + out = self.out(feats) + return out + + +def remove_conv_stride(conv): + conv_new = nn.Conv2d( + conv.in_channels, + conv.out_channels, + conv.kernel_size, + bias=conv.bias is not None, + stride=1, + padding=conv.padding, + ) + conv_new.weight = conv.weight + conv_new.bias = conv.bias + return conv_new + + +class FeatureExtractor(BaseModel): + default_conf = { + "pretrained": True, + "input_dim": 3, + "output_dim": 128, # # of channels in output feature maps + "encoder": "resnet50", # torchvision net as string + "remove_stride_from_first_conv": False, + "num_downsample": None, # how many downsample block + "decoder_norm": "nn.BatchNorm2d", # normalization ind decoder blocks + "do_average_pooling": False, + "checkpointed": False, # whether to use gradient checkpointing + } + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + + def build_encoder(self, conf): + assert isinstance(conf.encoder, str) + if conf.pretrained: + assert conf.input_dim == 3 + Encoder = getattr(torchvision.models, conf.encoder) + + kw = {} + if conf.encoder.startswith("resnet"): + layers = ["relu", "layer1", "layer2", "layer3", "layer4"] + kw["replace_stride_with_dilation"] = [False, False, False] + elif conf.encoder == "vgg13": + layers = [ + 
"features.3", + "features.8", + "features.13", + "features.18", + "features.23", + ] + elif conf.encoder == "vgg16": + layers = [ + "features.3", + "features.8", + "features.15", + "features.22", + "features.29", + ] + else: + raise NotImplementedError(conf.encoder) + + if conf.num_downsample is not None: + layers = layers[: conf.num_downsample] + encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw) + encoder = create_feature_extractor(encoder, return_nodes=layers) + if conf.encoder.startswith("resnet") and conf.remove_stride_from_first_conv: + encoder.conv1 = remove_conv_stride(encoder.conv1) + + if conf.do_average_pooling: + raise NotImplementedError + if conf.checkpointed: + raise NotImplementedError + + return encoder, layers + + def _init(self, conf): + # Preprocessing + self.register_buffer("mean_", torch.tensor(self.mean), persistent=False) + self.register_buffer("std_", torch.tensor(self.std), persistent=False) + + # Encoder + self.encoder, self.layers = self.build_encoder(conf) + s = 128 + inp = torch.zeros(1, 3, s, s) + features = list(self.encoder(inp).values()) + self.skip_dims = [x.shape[1] for x in features] + self.layer_strides = [s / f.shape[-1] for f in features] + self.scales = [self.layer_strides[0]] + + # Decoder + norm = eval(conf.decoder_norm) if conf.decoder_norm else None # noqa + self.decoder = FPN(self.skip_dims, out_channels=conf.output_dim, norm=norm) + + logger.debug( + "Built feature extractor with layers {name:dim:stride}:\n" + f"{list(zip(self.layers, self.skip_dims, self.layer_strides))}\n" + f"and output scales {self.scales}." + ) + + def _forward(self, data): + image = data["image"] + image = (image - self.mean_[:, None, None]) / self.std_[:, None, None] + + skip_features = self.encoder(image) + output = self.decoder(skip_features) + pred = {"feature_maps": [output], "skip_features": skip_features} + return pred diff --git a/models/map_encoder.py b/models/map_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e54db926df6cc16d9082826ffee2a8b838dbed21 --- /dev/null +++ b/models/map_encoder.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+ +import torch +import torch.nn as nn + +from .base import BaseModel +from .feature_extractor import FeatureExtractor + + +class MapEncoder(BaseModel): + default_conf = { + "embedding_dim": "???", + "output_dim": None, + "num_classes": "???", + "backbone": "???", + "unary_prior": False, + } + + def _init(self, conf): + self.embeddings = torch.nn.ModuleDict( + { + k: torch.nn.Embedding(n + 1, conf.embedding_dim) + for k, n in conf.num_classes.items() + } + ) + #num_calsses:{'areas': 7, 'ways': 10, 'nodes': 33} + input_dim = len(conf.num_classes) * conf.embedding_dim + output_dim = conf.output_dim + if output_dim is None: + output_dim = conf.backbone.output_dim + if conf.unary_prior: + output_dim += 1 + if conf.backbone is None: + self.encoder = nn.Conv2d(input_dim, output_dim, 1) + elif conf.backbone == "simple": + self.encoder = nn.Sequential( + nn.Conv2d(input_dim, 128, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, output_dim, 3, padding=1), + ) + else: + self.encoder = FeatureExtractor( + { + **conf.backbone, + "input_dim": input_dim, + "output_dim": output_dim, + } + ) + + def _forward(self, data): + embeddings = [ + self.embeddings[k](data["map"][:, i]) + for i, k in enumerate(("areas", "ways", "nodes")) + ] + embeddings = torch.cat(embeddings, dim=-1).permute(0, 3, 1, 2) + if isinstance(self.encoder, BaseModel): + features = self.encoder({"image": embeddings})["feature_maps"] + else: + features = [self.encoder(embeddings)] + pred = {} + if self.conf.unary_prior: + pred["log_prior"] = [f[:, -1] for f in features] + features = [f[:, :-1] for f in features] + pred["map_features"] = features + return pred diff --git a/models/maplocnet.py b/models/maplocnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4d32dc4b78bac1d0c1eb23827be875598489b447 --- /dev/null +++ b/models/maplocnet.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import numpy as np +import torch +from torch.nn.functional import normalize + +from . 
import get_model +from models.base import BaseModel +# from models.bev_net import BEVNet +# from models.bev_projection import CartesianProjection, PolarProjectionDepth +from models.voting import ( + argmax_xyr, + conv2d_fft_batchwise, + expectation_xyr, + log_softmax_spatial, + mask_yaw_prior, + nll_loss_xyr, + nll_loss_xyr_smoothed, + TemplateSampler, + UAVTemplateSampler, + UAVTemplateSamplerFast +) +from .map_encoder import MapEncoder +from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall + + +class MapLocNet(BaseModel): + default_conf = { + "image_size": "???", + "val_citys":"???", + "image_encoder": "???", + "map_encoder": "???", + "bev_net": "???", + "latent_dim": "???", + "matching_dim": "???", + "scale_range": [0, 9], + "num_scale_bins": "???", + "z_min": None, + "z_max": "???", + "x_max": "???", + "pixel_per_meter": "???", + "num_rotations": "???", + "add_temperature": False, + "normalize_features": False, + "padding_matching": "replicate", + "apply_map_prior": True, + "do_label_smoothing": False, + "sigma_xy": 1, + "sigma_r": 2, + # depcreated + "depth_parameterization": "scale", + "norm_depth_scores": False, + "normalize_scores_by_dim": False, + "normalize_scores_by_num_valid": True, + "prior_renorm": True, + "retrieval_dim": None, + } + + def _init(self, conf): + assert not self.conf.norm_depth_scores + assert self.conf.depth_parameterization == "scale" + assert not self.conf.normalize_scores_by_dim + assert self.conf.normalize_scores_by_num_valid + assert self.conf.prior_renorm + + Encoder = get_model(conf.image_encoder.get("name", "feature_extractor_v2")) + self.image_encoder = Encoder(conf.image_encoder.backbone) + self.map_encoder = MapEncoder(conf.map_encoder) + # self.bev_net = None if conf.bev_net is None else BEVNet(conf.bev_net) + + ppm = conf.pixel_per_meter + # self.projection_polar = PolarProjectionDepth( + # conf.z_max, + # ppm, + # conf.scale_range, + # conf.z_min, + # ) + # self.projection_bev = CartesianProjection( + # conf.z_max, conf.x_max, ppm, conf.z_min + # ) + # self.template_sampler = TemplateSampler( + # self.projection_bev.grid_xz, ppm, conf.num_rotations + # ) + # self.template_sampler = UAVTemplateSamplerFast(conf.num_rotations,w=conf.image_size//2) + self.template_sampler = UAVTemplateSampler(conf.num_rotations) + # self.scale_classifier = torch.nn.Linear(conf.latent_dim, conf.num_scale_bins) + # if conf.bev_net is None: + # self.feature_projection = torch.nn.Linear( + # conf.latent_dim, conf.matching_dim + # ) + if conf.add_temperature: + temperature = torch.nn.Parameter(torch.tensor(0.0)) + self.register_parameter("temperature", temperature) + + def exhaustive_voting(self, f_bev, f_map): + if self.conf.normalize_features: + f_bev = normalize(f_bev, dim=1) + f_map = normalize(f_map, dim=1) + + # Build the templates and exhaustively match against the map. + # if confidence_bev is not None: + # f_bev = f_bev * confidence_bev.unsqueeze(1) + # f_bev = f_bev.masked_fill(~valid_bev.unsqueeze(1), 0.0) + # torch.save(f_bev, 'f_bev.pt') + # torch.save(f_map, 'f_map.pt') + + templates = self.template_sampler(f_bev)#[batch,256,8,129,129] + # torch.save(templates, 'templates.pt') + with torch.autocast("cuda", enabled=False): + scores = conv2d_fft_batchwise( + f_map.float(), + templates.float(), + padding_mode=self.conf.padding_matching, + ) + if self.conf.add_temperature: + scores = scores * torch.exp(self.temperature) + + # Reweight the different rotations based on the number of valid pixels + # in each template. 
Axis-aligned rotation have the maximum number of valid pixels. + # valid_templates = self.template_sampler(valid_bev.float()[None]) > (1 - 1e-4) + # num_valid = valid_templates.float().sum((-3, -2, -1)) + # scores = scores / num_valid[..., None, None] + return scores + + def _forward(self, data): + pred = {} + pred_map = pred["map"] = self.map_encoder(data) + f_map = pred_map["map_features"][0]#[batch,8,256,256] + + # Extract image features. + level = 0 + f_image = self.image_encoder(data)["feature_maps"][level]#[batch,128,128,176] + # print("f_map:",f_map.shape) + + scores = self.exhaustive_voting(f_image, f_map)#f_bev:[batch,8,64,129] f_map:[batch,8,256,256] confidence:[1,64,129] + scores = scores.moveaxis(1, -1) # B,H,W,N + if "log_prior" in pred_map and self.conf.apply_map_prior: + scores = scores + pred_map["log_prior"][0].unsqueeze(-1) + # pred["scores_unmasked"] = scores.clone() + if "map_mask" in data: + scores.masked_fill_(~data["map_mask"][..., None], -np.inf) + if "yaw_prior" in data: + mask_yaw_prior(scores, data["yaw_prior"], self.conf.num_rotations) + log_probs = log_softmax_spatial(scores) + # torch.save(scores, 'scores.pt') + with torch.no_grad(): + uvr_max = argmax_xyr(scores).to(scores) + uvr_avg, _ = expectation_xyr(log_probs.exp()) + + return { + **pred, + "scores": scores, + "log_probs": log_probs, + "uvr_max": uvr_max, + "uv_max": uvr_max[..., :2], + "yaw_max": uvr_max[..., 2], + "uvr_expectation": uvr_avg, + "uv_expectation": uvr_avg[..., :2], + "yaw_expectation": uvr_avg[..., 2], + "features_image": f_image, + } + + def loss(self, pred, data): + xy_gt = data["uv"] + yaw_gt = data["roll_pitch_yaw"][..., -1] + if self.conf.do_label_smoothing: + nll = nll_loss_xyr_smoothed( + pred["log_probs"], + xy_gt, + yaw_gt, + self.conf.sigma_xy / self.conf.pixel_per_meter, + self.conf.sigma_r, + mask=data.get("map_mask"), + ) + else: + nll = nll_loss_xyr(pred["log_probs"], xy_gt, yaw_gt) + loss = {"total": nll, "nll": nll} + if self.training and self.conf.add_temperature: + loss["temperature"] = self.temperature.expand(len(nll)) + return loss + + def metrics(self): + return { + "xy_max_error": Location2DError("uv_max", self.conf.pixel_per_meter), + "xy_expectation_error": Location2DError( + "uv_expectation", self.conf.pixel_per_meter + ), + "yaw_max_error": AngleError("yaw_max"), + "xy_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"), + "xy_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"), + "xy_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"), + + # "x_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"), + # "x_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"), + # "x_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"), + # + # "y_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"), + # "y_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"), + # "y_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"), + + "yaw_recall_1°": AngleRecall(1.0, "yaw_max"), + "yaw_recall_3°": AngleRecall(3.0, "yaw_max"), + "yaw_recall_5°": AngleRecall(5.0, "yaw_max"), + } diff --git a/models/metrics.py b/models/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..b50a724f4719853b476d693db25ddbba562a3a51 --- /dev/null +++ b/models/metrics.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
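+# Evaluation metrics used during training/validation: 2D location error and recall at fixed thresholds, yaw angle error and recall, and a lateral/longitudinal decomposition of the position error.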
+ +import torch +import torchmetrics +from torchmetrics.utilities.data import dim_zero_cat + +from .utils import deg2rad, rotmat2d + + +def location_error(uv, uv_gt, ppm=1): + return torch.norm(uv - uv_gt.to(uv), dim=-1) / ppm + +def location_error_single(uv, uv_gt, ppm=1): + return torch.norm(uv - uv_gt.to(uv), dim=-1) / ppm + +def angle_error(t, t_gt): + error = torch.abs(t % 360 - t_gt.to(t) % 360) + error = torch.minimum(error, 360 - error) + return error + + +class Location2DRecall(torchmetrics.MeanMetric): + def __init__(self, threshold, pixel_per_meter, key="uv_max", *args, **kwargs): + self.threshold = threshold + self.ppm = pixel_per_meter + self.key = key + super().__init__(*args, **kwargs) + + def update(self, pred, data): + self.cuda() + error = location_error(pred[self.key], data["uv"], self.ppm) + # print(error,self.threshold) + super().update((error <= torch.tensor(self.threshold,device=error.device)).float()) + +class Location1DRecall(torchmetrics.MeanMetric): + def __init__(self, threshold, pixel_per_meter, key="uv_max", *args, **kwargs): + self.threshold = threshold + self.ppm = pixel_per_meter + self.key = key + super().__init__(*args, **kwargs) + + def update(self, pred, data): + self.cuda() + error = location_error(pred[self.key], data["uv"], self.ppm) + # print(error,self.threshold) + super().update((error <= torch.tensor(self.threshold,device=error.device)).float()) +class AngleRecall(torchmetrics.MeanMetric): + def __init__(self, threshold, key="yaw_max", *args, **kwargs): + self.threshold = threshold + self.key = key + + super().__init__(*args, **kwargs) + + def update(self, pred, data): + self.cuda() + error = angle_error(pred[self.key], data["roll_pitch_yaw"][..., -1]) + super().update((error <= self.threshold).float()) + + +class MeanMetricWithRecall(torchmetrics.Metric): + full_state_update = True + + def __init__(self): + super().__init__() + self.add_state("value", default=[], dist_reduce_fx="cat") + def compute(self): + return dim_zero_cat(self.value).mean(0) + + def get_errors(self): + return dim_zero_cat(self.value) + + def recall(self, thresholds): + self.cuda() + error = self.get_errors() + thresholds = error.new_tensor(thresholds) + return (error.unsqueeze(-1) < thresholds).float().mean(0) * 100 + + +class AngleError(MeanMetricWithRecall): + def __init__(self, key): + super().__init__() + self.key = key + + def update(self, pred, data): + self.cuda() + value = angle_error(pred[self.key], data["roll_pitch_yaw"][..., -1]) + if value.numel(): + self.value.append(value) + + +class Location2DError(MeanMetricWithRecall): + def __init__(self, key, pixel_per_meter): + super().__init__() + self.key = key + self.ppm = pixel_per_meter + + def update(self, pred, data): + self.cuda() + value = location_error(pred[self.key], data["uv"], self.ppm) + if value.numel(): + self.value.append(value) + + +class LateralLongitudinalError(MeanMetricWithRecall): + def __init__(self, pixel_per_meter, key="uv_max"): + super().__init__() + self.ppm = pixel_per_meter + self.key = key + + def update(self, pred, data): + self.cuda() + yaw = deg2rad(data["roll_pitch_yaw"][..., -1]) + shift = (pred[self.key] - data["uv"]) * yaw.new_tensor([-1, 1]) + shift = (rotmat2d(yaw) @ shift.unsqueeze(-1)).squeeze(-1) + error = torch.abs(shift) / self.ppm + value = error.view(-1, 2) + if value.numel(): + self.value.append(value) diff --git a/models/utils.py b/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ec246b7d0adf04cc9307475867650523f67a5063 --- /dev/null 
+++ b/models/utils.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import math +from typing import Optional + +import torch + + +def checkpointed(cls, do=True): + """Adapted from the DISK implementation of Michał Tyszkiewicz.""" + assert issubclass(cls, torch.nn.Module) + + class Checkpointed(cls): + def forward(self, *args, **kwargs): + super_fwd = super(Checkpointed, self).forward + if any((torch.is_tensor(a) and a.requires_grad) for a in args): + return torch.utils.checkpoint.checkpoint(super_fwd, *args, **kwargs) + else: + return super_fwd(*args, **kwargs) + + return Checkpointed if do else cls + + +class GlobalPooling(torch.nn.Module): + def __init__(self, kind): + super().__init__() + if kind == "mean": + self.fn = torch.nn.Sequential( + torch.nn.Flatten(2), torch.nn.AdaptiveAvgPool1d(1), torch.nn.Flatten() + ) + elif kind == "max": + self.fn = torch.nn.Sequential( + torch.nn.Flatten(2), torch.nn.AdaptiveMaxPool1d(1), torch.nn.Flatten() + ) + else: + raise ValueError(f"Unknown pooling type {kind}.") + + def forward(self, x): + return self.fn(x) + + +@torch.jit.script +def make_grid( + w: float, + h: float, + step_x: float = 1.0, + step_y: float = 1.0, + orig_x: float = 0, + orig_y: float = 0, + y_up: bool = False, + device: Optional[torch.device] = None, +) -> torch.Tensor: + x, y = torch.meshgrid( + [ + torch.arange(orig_x, w + orig_x, step_x, device=device), + torch.arange(orig_y, h + orig_y, step_y, device=device), + ], + indexing="xy", + ) + if y_up: + y = y.flip(-2) + grid = torch.stack((x, y), -1) + return grid + + +@torch.jit.script +def rotmat2d(angle: torch.Tensor) -> torch.Tensor: + c = torch.cos(angle) + s = torch.sin(angle) + R = torch.stack([c, -s, s, c], -1).reshape(angle.shape + (2, 2)) + return R + + +@torch.jit.script +def rotmat2d_grad(angle: torch.Tensor) -> torch.Tensor: + c = torch.cos(angle) + s = torch.sin(angle) + R = torch.stack([-s, -c, c, -s], -1).reshape(angle.shape + (2, 2)) + return R + + +def deg2rad(x): + return x * math.pi / 180 + + +def rad2deg(x): + return x * 180 / math.pi diff --git a/models/voting.py b/models/voting.py new file mode 100644 index 0000000000000000000000000000000000000000..b57bc1e86d6f738c060f7ef0fea3698f4fc13dd6 --- /dev/null +++ b/models/voting.py @@ -0,0 +1,365 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+ +from typing import Optional, Tuple + +import numpy as np +import torch +from torch.fft import irfftn, rfftn +from torch.nn.functional import grid_sample, log_softmax, pad + +from .metrics import angle_error +from .utils import make_grid, rotmat2d +from torchvision.transforms.functional import rotate + +class UAVTemplateSamplerFast(torch.nn.Module): + def __init__(self, num_rotations,w=128,optimize=True): + super().__init__() + + h, w = w,w + grid_xy = make_grid( + w=w, + h=h, + step_x=1, + step_y=1, + orig_y=-h//2, + orig_x=-h//2, + y_up=True, + ).cuda() + + if optimize: + assert (num_rotations % 4) == 0 + angles = torch.arange( + 0, 90, 90 / (num_rotations // 4) + ).cuda() + else: + angles = torch.arange( + 0, 360, 360 / num_rotations, device=grid_xz_bev.device + ) + rotmats = rotmat2d(angles / 180 * np.pi) + grid_xy_rot = torch.einsum("...nij,...hwj->...nhwi", rotmats, grid_xy) + + grid_ij_rot = (grid_xy_rot - grid_xy[..., :1, :1, :]) * grid_xy.new_tensor( + [1, -1] + ) + grid_ij_rot = grid_ij_rot + grid_norm = (grid_ij_rot + 0.5) / grid_ij_rot.new_tensor([w, h]) * 2 - 1 + + self.optimize = optimize + self.num_rots = num_rotations + self.register_buffer("angles", angles, persistent=False) + self.register_buffer("grid_norm", grid_norm, persistent=False) + + def forward(self, image_bev): + grid = self.grid_norm + b, c = image_bev.shape[:2] + n, h, w = grid.shape[:3] + grid = grid[None].repeat_interleave(b, 0).reshape(b * n, h, w, 2) + image = ( + image_bev[:, None] + .repeat_interleave(n, 1) + .reshape(b * n, *image_bev.shape[1:]) + ) + # print(image.shape,grid.shape,self.grid_norm.shape) + kernels = grid_sample(image, grid.to(image.dtype), align_corners=False).reshape( + b, n, c, h, w + ) + + if self.optimize: # we have computed only the first quadrant + kernels_quad234 = [torch.rot90(kernels, -i, (-2, -1)) for i in (1, 2, 3)] + kernels = torch.cat([kernels] + kernels_quad234, 1) + + return kernels +class UAVTemplateSampler(torch.nn.Module): + def __init__(self, num_rotations): + super().__init__() + + self.num_rotations = num_rotations + + def Template(self, input_features): + # 角度数量 + num_angles = self.num_rotations + # 扩展第二个维度为旋转角度数量 + input_shape = torch.tensor(input_features.shape) + output_shape = torch.cat((input_shape[:1], torch.tensor([num_angles]), input_shape[1:])).tolist() + expanded_features = torch.zeros(output_shape,device=input_features.device) + + # 生成旋转角度序列 + rotation_angles = torch.linspace(360, 0, 64 + 1)[:-1] + # rotation_angles=torch.flip(rotation_angles, dims=[0]) + # 对扩展后的特征应用不同的旋转角度 + rotated_features = [] + # print(len(rotation_angles)) + for i in range(len(rotation_angles)): + # print(rotation_angles[i].item()) + rotated_feature = rotate(input_features, rotation_angles[i].item(), fill=0) + expanded_features[:, i, :, :, :] = rotated_feature + + # 将所有旋转后的特征堆叠起来形成最终的输出向量 + # output_features = torch.stack(rotated_features, dim=1) + + # 输出向量的维度 + # output_size = [3, num_angles, 8, 128, 128] + return expanded_features # 输出调试信息,验证输出向量的维度是否正确 + def forward(self, image_bev): + + kernels=self.Template(image_bev) + + return kernels +class TemplateSampler(torch.nn.Module): + def __init__(self, grid_xz_bev, ppm, num_rotations, optimize=True): + super().__init__() + + Δ = 1 / ppm + h, w = grid_xz_bev.shape[:2] + ksize = max(w, h * 2 + 1) + radius = ksize * Δ + grid_xy = make_grid( + radius, + radius, + step_x=Δ, + step_y=Δ, + orig_y=(Δ - radius) / 2, + orig_x=(Δ - radius) / 2, + y_up=True, + ) + + if optimize: + assert (num_rotations % 4) == 0 + angles = torch.arange( + 0, 
90, 90 / (num_rotations // 4), device=grid_xz_bev.device + ) + else: + angles = torch.arange( + 0, 360, 360 / num_rotations, device=grid_xz_bev.device + ) + rotmats = rotmat2d(angles / 180 * np.pi) + grid_xy_rot = torch.einsum("...nij,...hwj->...nhwi", rotmats, grid_xy) + + grid_ij_rot = (grid_xy_rot - grid_xz_bev[..., :1, :1, :]) * grid_xy.new_tensor( + [1, -1] + ) + grid_ij_rot = grid_ij_rot / Δ + grid_norm = (grid_ij_rot + 0.5) / grid_ij_rot.new_tensor([w, h]) * 2 - 1 + + self.optimize = optimize + self.num_rots = num_rotations + self.register_buffer("angles", angles, persistent=False) + self.register_buffer("grid_norm", grid_norm, persistent=False) + + def forward(self, image_bev): + grid = self.grid_norm + b, c = image_bev.shape[:2] + n, h, w = grid.shape[:3] + grid = grid[None].repeat_interleave(b, 0).reshape(b * n, h, w, 2) + image = ( + image_bev[:, None] + .repeat_interleave(n, 1) + .reshape(b * n, *image_bev.shape[1:]) + ) + kernels = grid_sample(image, grid.to(image.dtype), align_corners=False).reshape( + b, n, c, h, w + ) + + if self.optimize: # we have computed only the first quadrant + kernels_quad234 = [torch.rot90(kernels, -i, (-2, -1)) for i in (1, 2, 3)] + kernels = torch.cat([kernels] + kernels_quad234, 1) + + return kernels + + +def conv2d_fft_batchwise(signal, kernel, padding="same", padding_mode="constant"): + if padding == "same": + padding = [i // 2 for i in kernel.shape[-2:]] + padding_signal = [p for p in padding[::-1] for _ in range(2)] + signal = pad(signal, padding_signal, mode=padding_mode) + assert signal.size(-1) % 2 == 0 + + padding_kernel = [ + pad for i in [1, 2] for pad in [0, signal.size(-i) - kernel.size(-i)] + ] + kernel_padded = pad(kernel, padding_kernel) + + signal_fr = rfftn(signal, dim=(-1, -2)) + kernel_fr = rfftn(kernel_padded, dim=(-1, -2)) + + kernel_fr.imag *= -1 # flip the kernel + output_fr = torch.einsum("bc...,bdc...->bd...", signal_fr, kernel_fr) + output = irfftn(output_fr, dim=(-1, -2)) + + crop_slices = [slice(0, output.size(0)), slice(0, output.size(1))] + [ + slice(0, (signal.size(i) - kernel.size(i) + 1)) for i in [-2, -1] + ] + output = output[crop_slices].contiguous() + + return output + + +class SparseMapSampler(torch.nn.Module): + def __init__(self, num_rotations): + super().__init__() + angles = torch.arange(0, 360, 360 / self.conf.num_rotations) + rotmats = rotmat2d(angles / 180 * np.pi) + self.num_rotations = num_rotations + self.register_buffer("rotmats", rotmats, persistent=False) + + def forward(self, image_map, p2d_bev): + h, w = image_map.shape[-2:] + locations = make_grid(w, h, device=p2d_bev.device) + p2d_candidates = torch.einsum( + "kji,...i,->...kj", self.rotmats.to(p2d_bev), p2d_bev + ) + p2d_candidates = p2d_candidates[..., None, None, :, :] + locations.unsqueeze(-1) + # ... 
x N x W x H x K x 2 + + p2d_norm = (p2d_candidates / (image_map.new_tensor([w, h]) - 1)) * 2 - 1 + valid = torch.all((p2d_norm >= -1) & (p2d_norm <= 1), -1) + value = grid_sample( + image_map, p2d_norm.flatten(-4, -2), align_corners=True, mode="bilinear" + ) + value = value.reshape(image_map.shape[:2] + valid.shape[-4]) + return valid, value + + +def sample_xyr(volume, xy_grid, angle_grid, nearest_for_inf=False): + # (B, C, H, W, N) to (B, C, H, W, N+1) + volume_padded = pad(volume, [0, 1, 0, 0, 0, 0], mode="circular") + + size = xy_grid.new_tensor(volume.shape[-3:-1][::-1]) + xy_norm = xy_grid / (size - 1) # align_corners=True + angle_norm = (angle_grid / 360) % 1 + grid = torch.concat([angle_norm.unsqueeze(-1), xy_norm], -1) + grid_norm = grid * 2 - 1 + + valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1) + value = grid_sample(volume_padded, grid_norm, align_corners=True, mode="bilinear") + + # if one of the values used for linear interpolation is infinite, + # we fallback to nearest to avoid propagating inf + if nearest_for_inf: + value_nearest = grid_sample( + volume_padded, grid_norm, align_corners=True, mode="nearest" + ) + value = torch.where(~torch.isfinite(value) & valid, value_nearest, value) + + return value, valid + + +def nll_loss_xyr(log_probs, xy, angle): + log_prob, _ = sample_xyr( + log_probs.unsqueeze(1), xy[:, None, None, None], angle[:, None, None, None] + ) + nll = -log_prob.reshape(-1) # remove C,H,W,N + return nll + + +def nll_loss_xyr_smoothed(log_probs, xy, angle, sigma_xy, sigma_r, mask=None): + *_, nx, ny, nr = log_probs.shape + grid_x = torch.arange(nx, device=log_probs.device, dtype=torch.float) + dx = (grid_x - xy[..., None, 0]) / sigma_xy + grid_y = torch.arange(ny, device=log_probs.device, dtype=torch.float) + dy = (grid_y - xy[..., None, 1]) / sigma_xy + dr = ( + torch.arange(0, 360, 360 / nr, device=log_probs.device, dtype=torch.float) + - angle[..., None] + ) % 360 + dr = torch.minimum(dr, 360 - dr) / sigma_r + diff = ( + dx[..., None, :, None] ** 2 + + dy[..., :, None, None] ** 2 + + dr[..., None, None, :] ** 2 + ) + pdf = torch.exp(-diff / 2) + if mask is not None: + pdf.masked_fill_(~mask[..., None], 0) + log_probs = log_probs.masked_fill(~mask[..., None], 0) + pdf /= pdf.sum((-1, -2, -3), keepdim=True) + return -torch.sum(pdf * log_probs.to(torch.float), dim=(-1, -2, -3)) + + +def log_softmax_spatial(x, dims=3): + return log_softmax(x.flatten(-dims), dim=-1).reshape(x.shape) + + +@torch.jit.script +def argmax_xy(scores: torch.Tensor) -> torch.Tensor: + indices = scores.flatten(-2).max(-1).indices + width = scores.shape[-1] + x = indices % width + y = torch.div(indices, width, rounding_mode="floor") + return torch.stack((x, y), -1) + + +@torch.jit.script +def expectation_xy(prob: torch.Tensor) -> torch.Tensor: + h, w = prob.shape[-2:] + grid = make_grid(float(w), float(h), device=prob.device).to(prob) + return torch.einsum("...hw,hwd->...d", prob, grid) + + +@torch.jit.script +def expectation_xyr( + prob: torch.Tensor, covariance: bool = False +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + h, w, num_rotations = prob.shape[-3:] + x, y = torch.meshgrid( + [ + torch.arange(w, device=prob.device, dtype=prob.dtype), + torch.arange(h, device=prob.device, dtype=prob.dtype), + ], + indexing="xy", + ) + grid_xy = torch.stack((x, y), -1) + xy_mean = torch.einsum("...hwn,hwd->...d", prob, grid_xy) + + angles = torch.arange(0, 1, 1 / num_rotations, device=prob.device, dtype=prob.dtype) + angles = angles * 2 * np.pi + grid_cs = 
torch.stack([torch.cos(angles), torch.sin(angles)], -1) + cs_mean = torch.einsum("...hwn,nd->...d", prob, grid_cs) + angle = torch.atan2(cs_mean[..., 1], cs_mean[..., 0]) + angle = (angle * 180 / np.pi) % 360 + + if covariance: + xy_cov = torch.einsum("...hwn,...hwd,...hwk->...dk", prob, grid_xy, grid_xy) + xy_cov = xy_cov - torch.einsum("...d,...k->...dk", xy_mean, xy_mean) + else: + xy_cov = None + + xyr_mean = torch.cat((xy_mean, angle.unsqueeze(-1)), -1) + return xyr_mean, xy_cov + + +@torch.jit.script +def argmax_xyr(scores: torch.Tensor) -> torch.Tensor: + indices = scores.flatten(-3).max(-1).indices + width, num_rotations = scores.shape[-2:] + wr = width * num_rotations + y = torch.div(indices, wr, rounding_mode="floor") + x = torch.div(indices % wr, num_rotations, rounding_mode="floor") + angle_index = indices % num_rotations + angle = angle_index * 360 / num_rotations + xyr = torch.stack((x, y, angle), -1) + return xyr + + +@torch.jit.script +def mask_yaw_prior( + scores: torch.Tensor, yaw_prior: torch.Tensor, num_rotations: int +) -> torch.Tensor: + step = 360 / num_rotations + step_2 = step / 2 + angles = torch.arange(step_2, 360 + step_2, step, device=scores.device) + yaw_init, yaw_range = yaw_prior.chunk(2, dim=-1) + rot_mask = angle_error(angles, yaw_init) < yaw_range + return scores.masked_fill_(~rot_mask[:, None, None], -np.inf) + + +def fuse_gps(log_prob, uv_gps, ppm, sigma=10, gaussian=False): + grid = make_grid(*log_prob.shape[-3:-1][::-1]).to(log_prob) + dist = torch.sum((grid - uv_gps) ** 2, -1) + sigma_pixel = sigma * ppm + if gaussian: + gps_log_prob = -1 / 2 * dist / sigma_pixel**2 + else: + gps_log_prob = torch.where(dist < sigma_pixel**2, 1, -np.inf) + log_prob_fused = log_softmax_spatial(log_prob + gps_log_prob.unsqueeze(-1)) + return log_prob_fused diff --git a/module.py b/module.py new file mode 100644 index 0000000000000000000000000000000000000000..47bc341f83aba111638a5de1fb8c6d88ed900df7 --- /dev/null +++ b/module.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
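# --- Illustrative usage sketch for the pose-voting utilities defined above
# --- (log_softmax_spatial, argmax_xyr, expectation_xyr, fuse_gps); this block is
# --- not part of the original source. Shapes are assumptions: `log_probs` is a
# --- (B, H, W, N) log-probability volume over x/y position and N heading bins.
import torch

B, H, W, N = 2, 128, 128, 64
scores = torch.randn(B, H, W, N)
log_probs = log_softmax_spatial(scores)                      # normalize jointly over (H, W, N)
xyr_hat = argmax_xyr(log_probs)                              # (B, 3): x, y in pixels, yaw in degrees
xyr_mean, xy_cov = expectation_xyr(log_probs.exp(), covariance=True)
# Optionally fuse a coarse GPS prior (uv_gps in pixels, sigma in meters):
fused = fuse_gps(log_probs, torch.tensor([64.0, 64.0]), ppm=1, sigma=10)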
+ +from pathlib import Path + +import pytorch_lightning as pl +import torch +from omegaconf import DictConfig, OmegaConf, open_dict +from torchmetrics import MeanMetric, MetricCollection + +import logger +from models import get_model + + +class AverageKeyMeter(MeanMetric): + def __init__(self, key, *args, **kwargs): + self.key = key + super().__init__(*args, **kwargs) + + def update(self, dict): + value = dict[self.key] + value = value[torch.isfinite(value)] + return super().update(value) + + +class GenericModule(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + name = cfg.model.get("name") + name = "orienternet" if name in ("localizer_bev_depth", None) else name + self.model = get_model(name)(cfg.model) + self.cfg = cfg + self.save_hyperparameters(cfg) + + + + self.metrics_val = MetricCollection(self.model.metrics(), prefix="val/") + self.losses_val = None # we do not know the loss keys in advance + + # self.citys = self.cfg.data.val_citys + # for i in range(len(self.citys)): + # city=self.citys[i] + # setattr(self, "metric_vals_{}".format(i), MetricCollection(self.model.metrics(), prefix="val_{}/".format(city))) + # self.losse_vals = [None for city in self.cfg.data.val_citys] + + + def forward(self, batch): + return self.model(batch) + + def training_step(self, batch): + pred = self(batch) + losses = self.model.loss(pred, batch) + self.log_dict( + {f"loss/{k}/train": v.mean() for k, v in losses.items()}, + prog_bar=True, + rank_zero_only=True, + ) + return losses["total"].mean() + + # def validation_step(self, batch, batch_idx,dataloader_idx): + # city=self.citys[dataloader_idx] + # + # pred = self(batch) + # losses = self.model.loss(pred, batch) + # + # if hasattr(self,"losse_val_{}".format(dataloader_idx)) is False: + # setattr(self,"losse_val_{}".format(dataloader_idx),MetricCollection( + # {k: AverageKeyMeter(k).to(self.device) for k in losses}, + # prefix="loss_{}/".format(city), + # postfix="/val_{}".format(city), + # )) + # + # # print(pred, batch) + # getattr(self,"metric_vals_{}".format(dataloader_idx))(pred, batch) + # self.log_dict(getattr(self,"metric_vals_{}".format(dataloader_idx))(pred, batch), sync_dist=True) + # + # getattr(self,"losse_val_{}".format(dataloader_idx)).update(losses) + # # print(getattr(self,"losse_val_{}".format(dataloader_idx))) + # self.log_dict(getattr(self,"losse_val_{}".format(dataloader_idx)).compute(), sync_dist=True) + def validation_step(self, batch, batch_idx): + pred = self(batch) + losses = self.model.loss(pred, batch) + if self.losses_val is None: + self.losses_val = MetricCollection( + {k: AverageKeyMeter(k).to(self.device) for k in losses}, + prefix="loss/", + postfix="/val", + ) + self.metrics_val(pred, batch) + self.log_dict(self.metrics_val, sync_dist=True) + self.losses_val.update(losses) + self.log_dict(self.losses_val, sync_dist=True) + + def validation_epoch_start(self, batch): + self.losses_val = None + # self.losse_val = [None for city in self.cfg.data.val_citys] + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.training.lr) + ret = {"optimizer": optimizer} + cfg_scheduler = self.cfg.training.get("lr_scheduler") + if cfg_scheduler is not None: + scheduler = getattr(torch.optim.lr_scheduler, cfg_scheduler.name)( + optimizer=optimizer, **cfg_scheduler.get("args", {}) + ) + ret["lr_scheduler"] = { + "scheduler": scheduler, + "interval": "epoch", + "frequency": 1, + "monitor": "loss/total/val", + "strict": True, + "name": "learning_rate", + } + return ret + + 
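# Illustration only (not in the original code): configure_optimizers() above looks
# up `training.lr_scheduler.name` in torch.optim.lr_scheduler and forwards
# `training.lr_scheduler.args` as keyword arguments. A hypothetical config entry
# that enables a step decay could therefore look like this:
from omegaconf import OmegaConf

_example_training_cfg = OmegaConf.create(
    {
        "training": {
            "lr": 1.0e-4,
            "lr_scheduler": {"name": "StepLR", "args": {"step_size": 50, "gamma": 0.5}},
        }
    }
)
# ... which resolves to torch.optim.lr_scheduler.StepLR(optimizer=..., step_size=50, gamma=0.5)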
@classmethod + def load_from_checkpoint( + cls, + checkpoint_path, + map_location=None, + hparams_file=None, + strict=True, + cfg=None, + find_best=False, + ): + assert hparams_file is None, "hparams are not supported." + + checkpoint = torch.load( + checkpoint_path, map_location=map_location or (lambda storage, loc: storage) + ) + if find_best: + best_score, best_name = None, None + modes = {"min": torch.lt, "max": torch.gt} + for key, state in checkpoint["callbacks"].items(): + if not key.startswith("ModelCheckpoint"): + continue + mode = eval(key.replace("ModelCheckpoint", ""))["mode"] + if best_score is None or modes[mode]( + state["best_model_score"], best_score + ): + best_score = state["best_model_score"] + best_name = Path(state["best_model_path"]).name + logger.info("Loading best checkpoint %s", best_name) + if best_name != checkpoint_path: + return cls.load_from_checkpoint( + Path(checkpoint_path).parent / best_name, + map_location, + hparams_file, + strict, + cfg, + find_best=False, + ) + + logger.info( + "Using checkpoint %s from epoch %d and step %d.", + checkpoint_path.name, + checkpoint["epoch"], + checkpoint["global_step"], + ) + cfg_ckpt = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] + if list(cfg_ckpt.keys()) == ["cfg"]: # backward compatibility + cfg_ckpt = cfg_ckpt["cfg"] + cfg_ckpt = OmegaConf.create(cfg_ckpt) + + if cfg is None: + cfg = {} + if not isinstance(cfg, DictConfig): + cfg = OmegaConf.create(cfg) + with open_dict(cfg_ckpt): + cfg = OmegaConf.merge(cfg_ckpt, cfg) + + return pl.core.saving._load_state(cls, checkpoint, strict=strict, cfg=cfg) diff --git a/osm/analysis.py b/osm/analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..a667c21373a31482f7bbcfb41d4fa14681741260 --- /dev/null +++ b/osm/analysis.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
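# Usage sketch for the load_from_checkpoint() override above (path is a placeholder,
# not from the original repo): reload a trained GenericModule for evaluation and let
# `find_best=True` follow the best checkpoint recorded by the ModelCheckpoint
# callback state instead of the file that was passed in. Any extra `cfg` dict is
# merged on top of the config stored in the checkpoint.
from pathlib import Path

model = GenericModule.load_from_checkpoint(
    Path("experiments/<run_name>/last.ckpt"),  # placeholder checkpoint path
    strict=True,
    find_best=True,
)
model.eval()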
+ +from collections import Counter, defaultdict +from typing import Dict + +import matplotlib.pyplot as plt +import numpy as np +import plotly.graph_objects as go + +from .parser import ( + filter_area, + filter_node, + filter_way, + match_to_group, + parse_area, + parse_node, + parse_way, + Patterns, +) +from .reader import OSMData + + +def recover_hierarchy(counter: Counter) -> Dict: + """Recover a two-level hierarchy from the flat group labels.""" + groups = defaultdict(dict) + for k, v in sorted(counter.items(), key=lambda x: -x[1]): + if ":" in k: + prefix, group = k.split(":") + if prefix in groups and isinstance(groups[prefix], int): + groups[prefix] = {} + groups[prefix][prefix] = groups[prefix] + groups[prefix] = {} + groups[prefix][group] = v + else: + groups[k] = v + return dict(groups) + + +def bar_autolabel(rects, fontsize): + """Attach a text label above each bar in *rects*, displaying its height.""" + for rect in rects: + width = rect.get_width() + plt.gca().annotate( + f"{width}", + xy=(width, rect.get_y() + rect.get_height() / 2), + xytext=(3, 0), # 3 points vertical offset + textcoords="offset points", + ha="left", + va="center", + fontsize=fontsize, + ) + + +def plot_histogram(counts, fontsize, dpi): + fig, ax = plt.subplots(dpi=dpi, figsize=(8, 20)) + + labels = [] + for k, v in counts.items(): + if isinstance(v, dict): + labels += list(v.keys()) + v = list(v.values()) + else: + labels.append(k) + v = [v] + bars = plt.barh( + len(labels) + -len(v) + np.arange(len(v)), v, height=0.9, label=k + ) + bar_autolabel(bars, fontsize) + + ax.set_yticklabels(labels, fontsize=fontsize) + ax.axes.xaxis.set_ticklabels([]) + ax.xaxis.tick_top() + ax.invert_yaxis() + plt.yticks(np.arange(len(labels))) + plt.xscale("log") + plt.legend(ncol=len(counts), loc="upper center") + + +def count_elements(elems: Dict[int, str], filter_fn, parse_fn) -> Dict: + """Count the number of elements in each group.""" + counts = Counter() + for elem in filter(filter_fn, elems.values()): + group = parse_fn(elem.tags) + if group is None: + continue + counts[group] += 1 + counts = recover_hierarchy(counts) + return counts + + +def plot_osm_histograms(osm: OSMData, fontsize=8, dpi=150): + counts = count_elements(osm.nodes, filter_node, parse_node) + plot_histogram(counts, fontsize, dpi) + plt.title("nodes") + + counts = count_elements(osm.ways, filter_way, parse_way) + plot_histogram(counts, fontsize, dpi) + plt.title("ways") + + counts = count_elements(osm.ways, filter_area, parse_area) + plot_histogram(counts, fontsize, dpi) + plt.title("areas") + + +def plot_sankey_hierarchy(osm: OSMData): + triplets = [] + for node in filter(filter_node, osm.nodes.values()): + label = parse_node(node.tags) + if label is None: + continue + group = match_to_group(label, Patterns.nodes) + if group is None: + group = match_to_group(label, Patterns.ways) + if group is None: + group = "null" + if ":" in label: + key, tag = label.split(":") + if tag == "yes": + tag = key + else: + key = tag = label + triplets.append((key, tag, group)) + keys, tags, groups = list(zip(*triplets)) + counts_key_tag = Counter(zip(keys, tags)) + counts_key_tag_group = Counter(triplets) + + key2tags = defaultdict(set) + for k, t in zip(keys, tags): + key2tags[k].add(t) + key2tags = {k: sorted(t) for k, t in key2tags.items()} + keytag2group = dict(zip(zip(keys, tags), groups)) + key_names = sorted(set(keys)) + tag_names = [(k, t) for k in key_names for t in key2tags[k]] + + group_names = [] + for k in key_names: + for t in key2tags[k]: + g = 
keytag2group[k, t] + if g not in group_names and g != "null": + group_names.append(g) + group_names += ["null"] + + key2idx = dict(zip(key_names, range(len(key_names)))) + tag2idx = {kt: i + len(key2idx) for i, kt in enumerate(tag_names)} + group2idx = {n: i + len(key2idx) + len(tag2idx) for i, n in enumerate(group_names)} + + key_counts = Counter(keys) + key_text = [f"{k} {key_counts[k]}" for k in key_names] + tag_counts = Counter(list(zip(keys, tags))) + tag_text = [f"{t} {tag_counts[k, t]}" for k, t in tag_names] + group_counts = Counter(groups) + group_text = [f"{k} {group_counts[k]}" for k in group_names] + + fig = go.Figure( + data=[ + go.Sankey( + orientation="h", + node=dict( + pad=15, + thickness=20, + line=dict(color="black", width=0.5), + label=key_text + tag_text + group_text, + x=[0] * len(key_names) + + [1] * len(tag_names) + + [2] * len(group_names), + color="blue", + ), + arrangement="fixed", + link=dict( + source=[key2idx[k] for k, _ in counts_key_tag] + + [tag2idx[k, t] for k, t, _ in counts_key_tag_group], + target=[tag2idx[k, t] for k, t in counts_key_tag] + + [group2idx[g] for _, _, g in counts_key_tag_group], + value=list(counts_key_tag.values()) + + list(counts_key_tag_group.values()), + ), + ) + ] + ) + fig.update_layout(autosize=False, width=800, height=2000, font_size=10) + fig.show() + return fig diff --git a/osm/data.py b/osm/data.py new file mode 100644 index 0000000000000000000000000000000000000000..dafc568f8ad5ac8c72ea9ffbd096838e0693ad84 --- /dev/null +++ b/osm/data.py @@ -0,0 +1,230 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import logging +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Set, Tuple + +import numpy as np + +from .parser import ( + filter_area, + filter_node, + filter_way, + match_to_group, + parse_area, + parse_node, + parse_way, + Patterns, +) +from .reader import OSMData, OSMNode, OSMRelation, OSMWay + + +logger = logging.getLogger(__name__) + + +def glue(ways: List[OSMWay]) -> List[List[OSMNode]]: + result: List[List[OSMNode]] = [] + to_process: Set[Tuple[OSMNode]] = set() + + for way in ways: + if way.is_cycle(): + result.append(way.nodes) + else: + to_process.add(tuple(way.nodes)) + + while to_process: + nodes: List[OSMNode] = list(to_process.pop()) + glued: Optional[List[OSMNode]] = None + other_nodes: Optional[Tuple[OSMNode]] = None + + for other_nodes in to_process: + glued = try_to_glue(nodes, list(other_nodes)) + if glued is not None: + break + + if glued is not None: + to_process.remove(other_nodes) + if is_cycle(glued): + result.append(glued) + else: + to_process.add(tuple(glued)) + else: + result.append(nodes) + + return result + + +def is_cycle(nodes: List[OSMNode]) -> bool: + """Is way a cycle way or an area boundary.""" + return nodes[0] == nodes[-1] + + +def try_to_glue(nodes: List[OSMNode], other: List[OSMNode]) -> Optional[List[OSMNode]]: + """Create new combined way if ways share endpoints.""" + if nodes[0] == other[0]: + return list(reversed(other[1:])) + nodes + if nodes[0] == other[-1]: + return other[:-1] + nodes + if nodes[-1] == other[-1]: + return nodes + list(reversed(other[:-1])) + if nodes[-1] == other[0]: + return nodes + other[1:] + return None + + +def multipolygon_from_relation(rel: OSMRelation, osm: OSMData): + inner_ways = [] + outer_ways = [] + for member in rel.members: + if member.type_ == "way": + if member.role == "inner": + if member.ref in osm.ways: + inner_ways.append(osm.ways[member.ref]) + elif member.role == "outer": + if member.ref in 
osm.ways: + outer_ways.append(osm.ways[member.ref]) + else: + logger.warning(f'Unknown member role "{member.role}".') + if outer_ways: + inners_path = glue(inner_ways) + outers_path = glue(outer_ways) + return inners_path, outers_path + + +@dataclass +class MapElement: + id_: int + label: str + group: str + tags: Optional[Dict[str, str]] + + +@dataclass +class MapNode(MapElement): + xy: np.ndarray + + @classmethod + def from_osm(cls, node: OSMNode, label: str, group: str): + return cls( + node.id_, + label, + group, + node.tags, + xy=node.xy, + ) + + +@dataclass +class MapLine(MapElement): + xy: np.ndarray + + @classmethod + def from_osm(cls, way: OSMWay, label: str, group: str): + xy = np.stack([n.xy for n in way.nodes]) + return cls( + way.id_, + label, + group, + way.tags, + xy=xy, + ) + + +@dataclass +class MapArea(MapElement): + outers: List[np.ndarray] + inners: List[np.ndarray] = field(default_factory=list) + + @classmethod + def from_relation(cls, rel: OSMRelation, label: str, group: str, osm: OSMData): + outers_inners = multipolygon_from_relation(rel, osm) + if outers_inners is None: + return None + outers, inners = outers_inners + outers = [np.stack([n.xy for n in way]) for way in outers] + inners = [np.stack([n.xy for n in way]) for way in inners] + return cls( + rel.id_, + label, + group, + rel.tags, + outers=outers, + inners=inners, + ) + + @classmethod + def from_way(cls, way: OSMWay, label: str, group: str): + xy = np.stack([n.xy for n in way.nodes]) + return cls( + way.id_, + label, + group, + way.tags, + outers=[xy], + ) + + +class MapData: + def __init__(self): + self.nodes: Dict[int, MapNode] = {} + self.lines: Dict[int, MapLine] = {} + self.areas: Dict[int, MapArea] = {} + + @classmethod + def from_osm(cls, osm: OSMData): + self = cls() + + for node in filter(filter_node, osm.nodes.values()): + label = parse_node(node.tags) + if label is None: + continue + group = match_to_group(label, Patterns.nodes) + if group is None: + group = match_to_group(label, Patterns.ways) + if group is None: + continue # missing + self.nodes[node.id_] = MapNode.from_osm(node, label, group) + + for way in filter(filter_way, osm.ways.values()): + label = parse_way(way.tags) + if label is None: + continue + group = match_to_group(label, Patterns.ways) + if group is None: + group = match_to_group(label, Patterns.nodes) + if group is None: + continue # missing + self.lines[way.id_] = MapLine.from_osm(way, label, group) + + for area in filter(filter_area, osm.ways.values()): + label = parse_area(area.tags) + if label is None: + continue + group = match_to_group(label, Patterns.areas) + if group is None: + group = match_to_group(label, Patterns.ways) + if group is None: + group = match_to_group(label, Patterns.nodes) + if group is None: + continue # missing + self.areas[area.id_] = MapArea.from_way(area, label, group) + + for rel in osm.relations.values(): + if rel.tags.get("type") != "multipolygon": + continue + label = parse_area(rel.tags) + if label is None: + continue + group = match_to_group(label, Patterns.areas) + if group is None: + group = match_to_group(label, Patterns.ways) + if group is None: + group = match_to_group(label, Patterns.nodes) + if group is None: + continue # missing + area = MapArea.from_relation(rel, label, group, osm) + assert rel.id_ not in self.areas # not sure if there can be collision + if area is not None: + self.areas[rel.id_] = area + + return self diff --git a/osm/download.py b/osm/download.py new file mode 100644 index 
0000000000000000000000000000000000000000..5a188e513aaaff8e73d6ca60a4caf51ff60bc0fa --- /dev/null +++ b/osm/download.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import json +from pathlib import Path +from typing import Dict, Optional + +import urllib3 + + +from utils.geo import BoundaryBox +import urllib.request +import requests + +def get_osm( + boundary_box: BoundaryBox, + cache_path: Optional[Path] = None, + overwrite: bool = False, +) -> str: + if not overwrite and cache_path is not None and cache_path.is_file(): + with cache_path.open() as fp: + return json.load(fp) + + (bottom, left), (top, right) = boundary_box.min_, boundary_box.max_ + content: bytes = get_web_data( + # "https://api.openstreetmap.org/api/0.6/map.json", + "https://openstreetmap.erniubot.live/api/0.6/map.json", + # 'https://overpass-api.de/api/map', + # 'http://localhost:29505/api/map', + # "https://lz4.overpass-api.de/api/interpreter", + {"bbox": f"{left},{bottom},{right},{top}"}, + ) + + content_str = content.decode("utf-8") + if content_str.startswith("You requested too many nodes"): + raise ValueError(content_str) + + if cache_path is not None: + with cache_path.open("bw+") as fp: + fp.write(content) + a=json.loads(content_str) + return json.loads(content_str) + + +def get_web_data(address: str, parameters: Dict[str, str]) -> bytes: + # logger.info("Getting %s...", address) + # proxy_address = "http://107.173.122.186:3128" + # + # # 设置代理服务器地址和端口 + # proxies = { + # 'http': proxy_address, + # 'https': proxy_address + # } + + # 发送GET请求并返回响应数据 + # response = requests.get(address, params=parameters, timeout=100, proxies=proxies) + print('url:',address) + response = requests.get(address, params=parameters, timeout=100) + return response.content +def get_web_data(address: str, parameters: Dict[str, str]) -> bytes: + # logger.info("Getting %s...", address) + while True: + try: + # proxy_address = "http://107.173.122.186:3128" + # + # # 设置代理服务器地址和端口 + # proxies = { + # 'http': proxy_address, + # 'https': proxy_address + # } + # # 发送GET请求并返回响应数据 + response = requests.get(address, params=parameters, timeout=100) + request = requests.Request('GET', address, params=parameters) + prepared_request = request.prepare() + # 获取完整URL + full_url = prepared_request.url + break + + except Exception as e: + # 打印错误信息 + print(f"发生错误: {e}") + print("重试...") + + return response.content +# def get_web_data_2(address: str, parameters: Dict[str, str]) -> bytes: +# # logger.info("Getting %s...", address) +# proxy_address="http://107.173.122.186:3128" +# http = urllib3.PoolManager(proxy_url=proxy_address) +# result = http.request("GET", address, parameters, timeout=100) +# return result.data +# +# +# def get_web_data_1(address: str, parameters: Dict[str, str]) -> bytes: +# +# # 设置代理服务器地址和端口 +# proxy_address = "http://107.173.122.186:3128" +# +# # 创建ProxyHandler对象 +# proxy_handler = urllib.request.ProxyHandler({'http': proxy_address}) +# +# # 构建查询字符串 +# query_string = urllib.parse.urlencode(parameters) +# +# # 构建完整的URL +# url = address + '?' 
+ query_string +# print(url) +# # 创建OpenerDirector对象,并将ProxyHandler对象作为参数传递 +# opener = urllib.request.build_opener(proxy_handler) +# +# # 使用OpenerDirector对象发送请求 +# response = opener.open(url) +# +# # 发送GET请求 +# # response = urllib.request.urlopen(url, timeout=100) +# +# # 读取响应内容 +# data = response.read() +# print() +# return data \ No newline at end of file diff --git a/osm/parser.py b/osm/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..d235c71bdbdd22d280d60015b5941d25ba0345e1 --- /dev/null +++ b/osm/parser.py @@ -0,0 +1,255 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import logging +import re +from typing import List + +from .reader import OSMData, OSMElement, OSMNode, OSMWay + +IGNORE_TAGS = {"source", "phone", "entrance", "inscription", "note", "name"} + + +def parse_levels(string: str) -> List[float]: + """Parse string representation of level sequence value.""" + try: + cleaned = string.replace(",", ";").replace(" ", "") + return list(map(float, cleaned.split(";"))) + except ValueError: + logging.debug("Cannot parse level description from `%s`.", string) + return [] + + +def filter_level(elem: OSMElement): + level = elem.tags.get("level") + if level is not None: + levels = parse_levels(level) + # In the US, ground floor levels are sometimes marked as level=1 + # so let's be conservative and include it. + if not (0 in levels or 1 in levels): + return False + layer = elem.tags.get("layer") + if layer is not None: + layer = parse_levels(layer) + if len(layer) > 0 and max(layer) < 0: + return False + return ( + elem.tags.get("location") != "underground" + and elem.tags.get("parking") != "underground" + ) + + +def filter_node(node: OSMNode): + return len(node.tags.keys() - IGNORE_TAGS) > 0 and filter_level(node) + + +def is_area(way: OSMWay): + if way.nodes[0] != way.nodes[-1]: + return False + if way.tags.get("area") == "no": + return False + filters = [ + "area", + "building", + "amenity", + "indoor", + "landuse", + "landcover", + "leisure", + "public_transport", + "shop", + ] + for f in filters: + if f in way.tags and way.tags.get(f) != "no": + return True + if way.tags.get("natural") in {"wood", "grassland", "water"}: + return True + return False + + +def filter_area(way: OSMWay): + return len(way.tags.keys() - IGNORE_TAGS) > 0 and is_area(way) and filter_level(way) + + +def filter_way(way: OSMWay): + return not filter_area(way) and way.tags != {} and filter_level(way) + + +def parse_node(tags): + keys = tags.keys() + for key in [ + "amenity", + "natural", + "highway", + "barrier", + "shop", + "tourism", + "public_transport", + "emergency", + "man_made", + ]: + if key in keys: + if "disused" in tags[key]: + continue + return f"{key}:{tags[key]}" + return None + + +def parse_area(tags): + if "building" in tags: + group = "building" + kind = tags["building"] + if kind == "yes": + for key in ["amenity", "tourism"]: + if key in tags: + kind = tags[key] + break + if kind != "yes": + group += f":{kind}" + return group + if "area:highway" in tags: + return f'highway:{tags["area:highway"]}' + for key in [ + "amenity", + "landcover", + "leisure", + "shop", + "highway", + "tourism", + "natural", + "waterway", + "landuse", + ]: + if key in tags: + return f"{key}:{tags[key]}" + return None + + +def parse_way(tags): + keys = tags.keys() + for key in ["highway", "barrier", "natural"]: + if key in keys: + return f"{key}:{tags[key]}" + return None + + +def match_to_group(label, patterns): + for group, pattern in patterns.items(): + if re.match(pattern, 
label): + return group + return None + + +class Patterns: + areas = dict( + building="building($|:.*?)*", + parking="amenity:parking", + playground="leisure:(playground|pitch)", + grass="(landuse:grass|landcover:grass|landuse:meadow|landuse:flowerbed|natural:grassland)", + park="leisure:(park|garden|dog_park)", + forest="(landuse:forest|natural:wood)", + water="(natural:water|waterway:*)", + ) + # + ways: road, path + # + node: fountain, bicycle_parking + + ways = dict( + fence="barrier:(fence|yes)", + wall="barrier:(wall|retaining_wall)", + hedge="barrier:hedge", + kerb="barrier:kerb", + building_outline="building($|:.*?)*", + cycleway="highway:cycleway", + path="highway:(pedestrian|footway|steps|path|corridor)", + road="highway:(motorway|trunk|primary|secondary|tertiary|service|construction|track|unclassified|residential|.*_link)", + busway="highway:busway", + tree_row="natural:tree_row", # maybe merge with node? + ) + # + nodes: bollard + + nodes = dict( + tree="natural:tree", + stone="(natural:stone|barrier:block)", + crossing="highway:crossing", + lamp="highway:street_lamp", + traffic_signal="highway:traffic_signals", + bus_stop="highway:bus_stop", + stop_sign="highway:stop", + junction="highway:motorway_junction", + bus_stop_position="public_transport:stop_position", + gate="barrier:(gate|lift_gate|swing_gate|cycle_barrier)", + bollard="barrier:bollard", + shop="(shop.*?|amenity:(bank|post_office))", + restaurant="amenity:(restaurant|fast_food)", + bar="amenity:(cafe|bar|pub|biergarten)", + pharmacy="amenity:pharmacy", + fuel="amenity:fuel", + bicycle_parking="amenity:(bicycle_parking|bicycle_rental)", + charging_station="amenity:charging_station", + parking_entrance="amenity:parking_entrance", + atm="amenity:atm", + toilets="amenity:toilets", + vending_machine="amenity:vending_machine", + fountain="amenity:fountain", + waste_basket="amenity:(waste_basket|waste_disposal)", + bench="amenity:bench", + post_box="amenity:post_box", + artwork="tourism:artwork", + recycling="amenity:recycling", + give_way="highway:give_way", + clock="amenity:clock", + fire_hydrant="emergency:fire_hydrant", + pole="man_made:(flagpole|utility_pole)", + street_cabinet="man_made:street_cabinet", + ) + # + ways: kerb + + +class Groups: + areas = list(Patterns.areas) + ways = list(Patterns.ways) + nodes = list(Patterns.nodes) + + +def group_elements(osm: OSMData): + elem2group = { + "area": {}, + "way": {}, + "node": {}, + } + + for node in filter(filter_node, osm.nodes.values()): + label = parse_node(node.tags) + if label is None: + continue + group = match_to_group(label, Patterns.nodes) + if group is None: + group = match_to_group(label, Patterns.ways) + if group is None: + continue # missing + elem2group["node"][node.id_] = group + + for way in filter(filter_way, osm.ways.values()): + label = parse_way(way.tags) + if label is None: + continue + group = match_to_group(label, Patterns.ways) + if group is None: + group = match_to_group(label, Patterns.nodes) + if group is None: + continue # missing + elem2group["way"][way.id_] = group + + for area in filter(filter_area, osm.ways.values()): + label = parse_area(area.tags) + if label is None: + continue + group = match_to_group(label, Patterns.areas) + if group is None: + group = match_to_group(label, Patterns.ways) + if group is None: + group = match_to_group(label, Patterns.nodes) + if group is None: + continue # missing + elem2group["area"][area.id_] = group + + return elem2group diff --git a/osm/raster.py b/osm/raster.py new file mode 100644 index 
0000000000000000000000000000000000000000..9e203bc50db0d873a9c8544ed0a1533c84a38df9 --- /dev/null +++ b/osm/raster.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +from typing import Dict, List + +import cv2 +import numpy as np +import torch + +from utils.geo import BoundaryBox +from .data import MapArea, MapLine, MapNode +from .parser import Groups + + +class Canvas: + def __init__(self, bbox: BoundaryBox, ppm: float): + self.bbox = bbox + self.ppm = ppm + self.scaling = bbox.size * ppm + self.w, self.h = np.ceil(self.scaling).astype(int) + self.clear() + + def clear(self): + self.raster = np.zeros((self.h, self.w), np.uint8) + + def to_uv(self, xy: np.ndarray): + xy = self.bbox.normalize(xy) + xy[..., 1] = 1 - xy[..., 1] + s = self.scaling + if isinstance(xy, torch.Tensor): + s = torch.from_numpy(s).to(xy) + return xy * s - 0.5 + + def to_xy(self, uv: np.ndarray): + s = self.scaling + if isinstance(uv, torch.Tensor): + s = torch.from_numpy(s).to(uv) + xy = (uv + 0.5) / s + xy[..., 1] = 1 - xy[..., 1] + return self.bbox.unnormalize(xy) + + def draw_polygon(self, xy: np.ndarray): + uv = self.to_uv(xy) + cv2.fillPoly(self.raster, uv[None].astype(np.int32), 255) + + def draw_multipolygon(self, xys: List[np.ndarray]): + uvs = [self.to_uv(xy).round().astype(np.int32) for xy in xys] + cv2.fillPoly(self.raster, uvs, 255) + + def draw_line(self, xy: np.ndarray, width: float = 1): + uv = self.to_uv(xy) + cv2.polylines( + self.raster, uv[None].round().astype(np.int32), False, 255, thickness=width + ) + + def draw_cell(self, xy: np.ndarray): + if not self.bbox.contains(xy): + return + uv = self.to_uv(xy) + self.raster[tuple(uv.round().astype(int).T[::-1])] = 255 + + +def render_raster_masks( + nodes: List[MapNode], + lines: List[MapLine], + areas: List[MapArea], + canvas: Canvas, +) -> Dict[str, np.ndarray]: + all_groups = Groups.areas + Groups.ways + Groups.nodes + masks = {k: np.zeros((canvas.h, canvas.w), np.uint8) for k in all_groups} + + for area in areas: + canvas.raster = masks[area.group] + outlines = area.outers + area.inners + canvas.draw_multipolygon(outlines) + if area.group == "building": + canvas.raster = masks["building_outline"] + for line in outlines: + canvas.draw_line(line) + + for line in lines: + canvas.raster = masks[line.group] + canvas.draw_line(line.xy) + + for node in nodes: + canvas.raster = masks[node.group] + canvas.draw_cell(node.xy) + + return masks + + +def mask_to_idx(group2mask: Dict[str, np.ndarray], groups: List[str]) -> np.ndarray: + masks = np.stack([group2mask[k] for k in groups]) > 0 + void = ~np.any(masks, 0) + idx = np.argmax(masks, 0) + idx = np.where(void, np.zeros_like(idx), idx + 1) # add background + return idx + + +def render_raster_map(masks: Dict[str, np.ndarray]) -> np.ndarray: + areas = mask_to_idx(masks, Groups.areas) + ways = mask_to_idx(masks, Groups.ways) + nodes = mask_to_idx(masks, Groups.nodes) + return np.stack([areas, ways, nodes]) diff --git a/osm/reader.py b/osm/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..793ad1879f8b2068cd5265bed408be14719e9680 --- /dev/null +++ b/osm/reader.py @@ -0,0 +1,310 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
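# Illustrative only (not part of the original file): the rasterization helpers above
# draw queried map elements onto per-group binary masks and collapse them into a
# 3-channel index raster (areas, ways, nodes), with 0 reserved for void. Here
# `bbox_tile`, `nodes`, `lines`, `areas` are assumed to come from MapIndex.query().
canvas = Canvas(bbox_tile, ppm=2)                         # 2 pixels per meter (assumed)
masks = render_raster_masks(nodes, lines, areas, canvas)  # {group name: (H, W) uint8 mask}
raster = render_raster_map(masks)                         # (3, H, W) integer class indices
assert raster.shape == (3, canvas.h, canvas.w)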
+ +import json +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + +from lxml import etree +import numpy as np + +from utils.geo import BoundaryBox, Projection + +METERS_PATTERN: re.Pattern = re.compile("^(?P\\d*\\.?\\d*)\\s*m$") +KILOMETERS_PATTERN: re.Pattern = re.compile("^(?P\\d*\\.?\\d*)\\s*km$") +MILES_PATTERN: re.Pattern = re.compile("^(?P\\d*\\.?\\d*)\\s*mi$") + + +def parse_float(string: str) -> Optional[float]: + """Parse string representation of a float or integer value.""" + try: + return float(string) + except (TypeError, ValueError): + return None + + +@dataclass(eq=False) +class OSMElement: + """ + Something with tags (string to string mapping). + """ + + id_: int + tags: Dict[str, str] + + def get_float(self, key: str) -> Optional[float]: + """Parse float from tag value.""" + if key in self.tags: + return parse_float(self.tags[key]) + return None + + def get_length(self, key: str) -> Optional[float]: + """Get length in meters.""" + if key not in self.tags: + return None + + value: str = self.tags[key] + + float_value: float = parse_float(value) + if float_value is not None: + return float_value + + for pattern, ratio in [ + (METERS_PATTERN, 1.0), + (KILOMETERS_PATTERN, 1000.0), + (MILES_PATTERN, 1609.344), + ]: + matcher: re.Match = pattern.match(value) + if matcher: + float_value: float = parse_float(matcher.group("value")) + if float_value is not None: + return float_value * ratio + + return None + + def __hash__(self) -> int: + return self.id_ + + +@dataclass(eq=False) +class OSMNode(OSMElement): + """ + OpenStreetMap node. + + See https://wiki.openstreetmap.org/wiki/Node + """ + + geo: np.ndarray + visible: Optional[str] = None + xy: Optional[np.ndarray] = None + + @classmethod + def from_dict(cls, structure: Dict[str, Any]) -> "OSMNode": + """ + Parse node from Overpass-like structure. + + :param structure: input structure + """ + return cls( + structure["id"], + structure.get("tags", {}), + geo=np.array((structure["lat"], structure["lon"])), + visible=structure.get("visible"), + ) + + +@dataclass(eq=False) +class OSMWay(OSMElement): + """ + OpenStreetMap way. + + See https://wiki.openstreetmap.org/wiki/Way + """ + + nodes: Optional[List[OSMNode]] = field(default_factory=list) + visible: Optional[str] = None + + @classmethod + def from_dict( + cls, structure: Dict[str, Any], nodes: Dict[int, OSMNode] + ) -> "OSMWay": + """ + Parse way from Overpass-like structure. + + :param structure: input structure + :param nodes: node structure + """ + return cls( + structure["id"], + structure.get("tags", {}), + [nodes[x] for x in structure["nodes"]], + visible=structure.get("visible"), + ) + + def is_cycle(self) -> bool: + """Is way a cycle way or an area boundary.""" + return self.nodes[0] == self.nodes[-1] + + def __repr__(self) -> str: + return f"Way <{self.id_}> {self.nodes}" + + +@dataclass +class OSMMember: + """ + Member of OpenStreetMap relation. + """ + + type_: str + ref: int + role: str + + +@dataclass(eq=False) +class OSMRelation(OSMElement): + """ + OpenStreetMap relation. + + See https://wiki.openstreetmap.org/wiki/Relation + """ + + members: Optional[List[OSMMember]] + visible: Optional[str] = None + + @classmethod + def from_dict(cls, structure: Dict[str, Any]) -> "OSMRelation": + """ + Parse relation from Overpass-like structure. 
+ + :param structure: input structure + """ + return cls( + structure["id"], + structure["tags"], + [OSMMember(x["type"], x["ref"], x["role"]) for x in structure["members"]], + visible=structure.get("visible"), + ) + + +class OSMData: + """ + The whole OpenStreetMap information about nodes, ways, and relations. + """ + + def __init__(self) -> None: + self.nodes: Dict[int, OSMNode] = {} + self.ways: Dict[int, OSMWay] = {} + self.relations: Dict[int, OSMRelation] = {} + self.box: BoundaryBox = None + + @classmethod + def from_dict(cls, structure: Dict[str, Any]): + data = cls() + bounds = structure.get("bounds") + if bounds is not None: + data.box = BoundaryBox( + np.array([bounds["minlat"], bounds["minlon"]]), + np.array([bounds["maxlat"], bounds["maxlon"]]), + ) + + for element in structure["elements"]: + if element["type"] == "node": + node = OSMNode.from_dict(element) + data.add_node(node) + for element in structure["elements"]: + if element["type"] == "way": + way = OSMWay.from_dict(element, data.nodes) + data.add_way(way) + for element in structure["elements"]: + if element["type"] == "relation": + relation = OSMRelation.from_dict(element) + data.add_relation(relation) + + return data + + @classmethod + def from_json(cls, path: Path): + with path.open(encoding='utf-8') as fid: + structure = json.load(fid) + return cls.from_dict(structure) + + @classmethod + def from_xml(cls, path: Path): + root = etree.parse(str(path)).getroot() + structure = {"elements": []} + from tqdm import tqdm + + for elem in tqdm(root): + if elem.tag == "bounds": + structure["bounds"] = { + k: float(elem.attrib[k]) + for k in ("minlon", "minlat", "maxlon", "maxlat") + } + elif elem.tag in {"node", "way", "relation"}: + if elem.tag == "node": + item = { + "id": int(elem.attrib["id"]), + "lat": float(elem.attrib["lat"]), + "lon": float(elem.attrib["lon"]), + "visible": elem.attrib.get("visible"), + "tags": { + x.attrib["k"]: x.attrib["v"] for x in elem if x.tag == "tag" + }, + } + elif elem.tag == "way": + item = { + "id": int(elem.attrib["id"]), + "visible": elem.attrib.get("visible"), + "tags": { + x.attrib["k"]: x.attrib["v"] for x in elem if x.tag == "tag" + }, + "nodes": [int(x.attrib["ref"]) for x in elem if x.tag == "nd"], + } + elif elem.tag == "relation": + item = { + "id": int(elem.attrib["id"]), + "visible": elem.attrib.get("visible"), + "tags": { + x.attrib["k"]: x.attrib["v"] for x in elem if x.tag == "tag" + }, + "members": [ + { + "type": x.attrib["type"], + "ref": int(x.attrib["ref"]), + "role": x.attrib["role"], + } + for x in elem + if x.tag == "member" + ], + } + item["type"] = elem.tag + structure["elements"].append(item) + elem.clear() + del root + return cls.from_dict(structure) + + @classmethod + def from_file(cls, path: Path): + ext = path.suffix + if ext == ".json": + return cls.from_json(path) + elif ext in {".osm", ".xml"}: + return cls.from_xml(path) + else: + raise ValueError(f"Unknown extension for {path}") + + def add_node(self, node: OSMNode): + """Add node and update map parameters.""" + if node.id_ in self.nodes: + raise ValueError(f"Node with duplicate id {node.id_}.") + self.nodes[node.id_] = node + + def add_way(self, way: OSMWay): + """Add way and update map parameters.""" + if way.id_ in self.ways: + raise ValueError(f"Way with duplicate id {way.id_}.") + self.ways[way.id_] = way + + def add_relation(self, relation: OSMRelation): + """Add relation and update map parameters.""" + if relation.id_ in self.relations: + raise ValueError(f"Relation with duplicate id 
{relation.id_}.") + self.relations[relation.id_] = relation + + def add_xy_to_nodes(self, proj: Projection): + nodes = list(self.nodes.values()) + if len(nodes) == 0: + return + geos = np.stack([n.geo for n in nodes], 0) + if proj.bounds is not None: + # For some reasons few nodes are sometimes very far off the initial bbox. + valid = proj.bounds.contains(geos) + if valid.mean() < 0.9: + print("Many nodes are out of the projection bounds.") + xys = np.zeros_like(geos) + xys[valid] = proj.project(geos[valid]) + else: + xys = proj.project(geos) + for xy, node in zip(xys, nodes): + node.xy = xy diff --git a/osm/tiling.py b/osm/tiling.py new file mode 100644 index 0000000000000000000000000000000000000000..b610f0dc1eb31376e516deaafa1d33dc496aec2c --- /dev/null +++ b/osm/tiling.py @@ -0,0 +1,310 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import io +import pickle +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +from PIL import Image +import rtree + +from utils.geo import BoundaryBox, Projection +from .data import MapData +from .download import get_osm +from .parser import Groups +from .raster import Canvas, render_raster_map, render_raster_masks +from .reader import OSMData, OSMNode, OSMWay + + +class MapIndex: + def __init__( + self, + data: MapData, + ): + self.index_nodes = rtree.index.Index() + for i, node in data.nodes.items(): + self.index_nodes.insert(i, tuple(node.xy) * 2) + + self.index_lines = rtree.index.Index() + for i, line in data.lines.items(): + bbox = tuple(np.r_[line.xy.min(0), line.xy.max(0)]) + self.index_lines.insert(i, bbox) + + self.index_areas = rtree.index.Index() + for i, area in data.areas.items(): + xy = np.concatenate(area.outers + area.inners) + bbox = tuple(np.r_[xy.min(0), xy.max(0)]) + self.index_areas.insert(i, bbox) + + self.data = data + + def query(self, bbox: BoundaryBox) -> Tuple[List[OSMNode], List[OSMWay]]: + query = tuple(np.r_[bbox.min_, bbox.max_]) + ret = [] + for x in ["nodes", "lines", "areas"]: + ids = getattr(self, "index_" + x).intersection(query) + ret.append([getattr(self.data, x)[i] for i in ids]) + return tuple(ret) + + +def bbox_to_slice(bbox: BoundaryBox, canvas: Canvas): + uv_min = np.ceil(canvas.to_uv(bbox.min_)).astype(int) + uv_max = np.ceil(canvas.to_uv(bbox.max_)).astype(int) + slice_ = (slice(uv_max[1], uv_min[1]), slice(uv_min[0], uv_max[0])) + return slice_ + + +def round_bbox(bbox: BoundaryBox, origin: np.ndarray, ppm: int): + bbox = bbox.translate(-origin) + bbox = BoundaryBox(np.round(bbox.min_ * ppm) / ppm, np.round(bbox.max_ * ppm) / ppm) + return bbox.translate(origin) + +class MapTileManager: + def __init__( + self, + osmpath:Path, + ): + + self.osm = OSMData.from_file(osmpath) + + + # @classmethod + def from_bbox( + self, + projection: Projection, + bbox: BoundaryBox, + ppm: int, + tile_size: int = 128, + ): + # bbox_osm = projection.unproject(bbox) + # if path is not None and path.is_file(): + # print(OSMData.from_file) + # osm = OSMData.from_file(path) + # if osm.box is not None: + # assert osm.box.contains(bbox_osm) + # else: + # osm = OSMData.from_dict(get_osm(bbox_osm, path)) + + self.osm.add_xy_to_nodes(projection) + map_data = MapData.from_osm(self.osm) + map_index = MapIndex(map_data) + + bounds_x, bounds_y = [ + np.r_[np.arange(min_, max_, tile_size), max_] + for min_, max_ in zip(bbox.min_, bbox.max_) + ] + bbox_tiles = {} + for i, xmin in enumerate(bounds_x[:-1]): + for j, ymin in enumerate(bounds_y[:-1]): + bbox_tiles[i, j] = BoundaryBox( + [xmin, 
ymin], [bounds_x[i + 1], bounds_y[j + 1]] + ) + + tiles = {} + for ij, bbox_tile in bbox_tiles.items(): + canvas = Canvas(bbox_tile, ppm) + nodes, lines, areas = map_index.query(bbox_tile) + masks = render_raster_masks(nodes, lines, areas, canvas) + canvas.raster = render_raster_map(masks) + tiles[ij] = canvas + + groups = {k: v for k, v in vars(Groups).items() if not k.startswith("__")} + + self.origin = bbox.min_ + self.bbox = bbox + self.tiles = tiles + self.tile_size = tile_size + self.ppm = ppm + self.projection = projection + self.groups = groups + self.map_data = map_data + + return self.query(bbox) + # return cls(tiles, bbox, tile_size, ppm, projection, groups, map_data) + + def query(self, bbox: BoundaryBox) -> Canvas: + bbox = round_bbox(bbox, self.bbox.min_, self.ppm) + canvas = Canvas(bbox, self.ppm) + raster = np.zeros((3, canvas.h, canvas.w), np.uint8) + + bbox_all = bbox & self.bbox + ij_min = np.floor((bbox_all.min_ - self.origin) / self.tile_size).astype(int) + ij_max = np.ceil((bbox_all.max_ - self.origin) / self.tile_size).astype(int) - 1 + for i in range(ij_min[0], ij_max[0] + 1): + for j in range(ij_min[1], ij_max[1] + 1): + tile = self.tiles[i, j] + bbox_select = tile.bbox & bbox + slice_query = bbox_to_slice(bbox_select, canvas) + slice_tile = bbox_to_slice(bbox_select, tile) + raster[(slice(None),) + slice_query] = tile.raster[ + (slice(None),) + slice_tile + ] + canvas.raster = raster + return canvas + + def save(self, path: Path): + dump = { + "bbox": self.bbox.format(), + "tile_size": self.tile_size, + "ppm": self.ppm, + "groups": self.groups, + "tiles_bbox": {}, + "tiles_raster": {}, + } + if self.projection is not None: + dump["ref_latlonalt"] = self.projection.latlonalt + for ij, canvas in self.tiles.items(): + dump["tiles_bbox"][ij] = canvas.bbox.format() + raster_bytes = io.BytesIO() + raster = Image.fromarray(canvas.raster.transpose(1, 2, 0).astype(np.uint8)) + raster.save(raster_bytes, format="PNG") + dump["tiles_raster"][ij] = raster_bytes + with open(path, "wb") as fp: + pickle.dump(dump, fp) + + @classmethod + def load(cls, path: Path): + with path.open("rb") as fp: + dump = pickle.load(fp) + tiles = {} + for ij, bbox in dump["tiles_bbox"].items(): + tiles[ij] = Canvas(BoundaryBox.from_string(bbox), dump["ppm"]) + raster = np.asarray(Image.open(dump["tiles_raster"][ij])) + tiles[ij].raster = raster.transpose(2, 0, 1).copy() + projection = Projection(*dump["ref_latlonalt"]) + return cls( + tiles, + BoundaryBox.from_string(dump["bbox"]), + dump["tile_size"], + dump["ppm"], + projection, + dump["groups"], + ) + +class TileManager: + def __init__( + self, + tiles: Dict, + bbox: BoundaryBox, + tile_size: int, + ppm: int, + projection: Projection, + groups: Dict[str, List[str]], + map_data: Optional[MapData] = None, + ): + self.origin = bbox.min_ + self.bbox = bbox + self.tiles = tiles + self.tile_size = tile_size + self.ppm = ppm + self.projection = projection + self.groups = groups + self.map_data = map_data + assert np.all(tiles[0, 0].bbox.min_ == self.origin) + for tile in tiles.values(): + assert bbox.contains(tile.bbox) + + @classmethod + def from_bbox( + cls, + projection: Projection, + bbox: BoundaryBox, + ppm: int, + path: Optional[Path] = None, + tile_size: int = 128, + ): + bbox_osm = projection.unproject(bbox) + if path is not None and path.is_file(): + print(OSMData.from_file) + osm = OSMData.from_file(path) + if osm.box is not None: + assert osm.box.contains(bbox_osm) + else: + osm = OSMData.from_dict(get_osm(bbox_osm, path)) + + 
osm.add_xy_to_nodes(projection) + map_data = MapData.from_osm(osm) + map_index = MapIndex(map_data) + + bounds_x, bounds_y = [ + np.r_[np.arange(min_, max_, tile_size), max_] + for min_, max_ in zip(bbox.min_, bbox.max_) + ] + bbox_tiles = {} + for i, xmin in enumerate(bounds_x[:-1]): + for j, ymin in enumerate(bounds_y[:-1]): + bbox_tiles[i, j] = BoundaryBox( + [xmin, ymin], [bounds_x[i + 1], bounds_y[j + 1]] + ) + + tiles = {} + for ij, bbox_tile in bbox_tiles.items(): + canvas = Canvas(bbox_tile, ppm) + nodes, lines, areas = map_index.query(bbox_tile) + masks = render_raster_masks(nodes, lines, areas, canvas) + canvas.raster = render_raster_map(masks) + tiles[ij] = canvas + + groups = {k: v for k, v in vars(Groups).items() if not k.startswith("__")} + + return cls(tiles, bbox, tile_size, ppm, projection, groups, map_data) + + def query(self, bbox: BoundaryBox) -> Canvas: + bbox = round_bbox(bbox, self.bbox.min_, self.ppm) + canvas = Canvas(bbox, self.ppm) + raster = np.zeros((3, canvas.h, canvas.w), np.uint8) + + bbox_all = bbox & self.bbox + ij_min = np.floor((bbox_all.min_ - self.origin) / self.tile_size).astype(int) + ij_max = np.ceil((bbox_all.max_ - self.origin) / self.tile_size).astype(int) - 1 + for i in range(ij_min[0], ij_max[0] + 1): + for j in range(ij_min[1], ij_max[1] + 1): + tile = self.tiles[i, j] + bbox_select = tile.bbox & bbox + slice_query = bbox_to_slice(bbox_select, canvas) + slice_tile = bbox_to_slice(bbox_select, tile) + raster[(slice(None),) + slice_query] = tile.raster[ + (slice(None),) + slice_tile + ] + canvas.raster = raster + return canvas + + def save(self, path: Path): + dump = { + "bbox": self.bbox.format(), + "tile_size": self.tile_size, + "ppm": self.ppm, + "groups": self.groups, + "tiles_bbox": {}, + "tiles_raster": {}, + } + if self.projection is not None: + dump["ref_latlonalt"] = self.projection.latlonalt + for ij, canvas in self.tiles.items(): + dump["tiles_bbox"][ij] = canvas.bbox.format() + raster_bytes = io.BytesIO() + raster = Image.fromarray(canvas.raster.transpose(1, 2, 0).astype(np.uint8)) + raster.save(raster_bytes, format="PNG") + dump["tiles_raster"][ij] = raster_bytes + with open(path, "wb") as fp: + pickle.dump(dump, fp) + + @classmethod + def load(cls, path: Path): + with path.open("rb") as fp: + dump = pickle.load(fp) + tiles = {} + for ij, bbox in dump["tiles_bbox"].items(): + tiles[ij] = Canvas(BoundaryBox.from_string(bbox), dump["ppm"]) + raster = np.asarray(Image.open(dump["tiles_raster"][ij])) + tiles[ij].raster = raster.transpose(2, 0, 1).copy() + projection = Projection(*dump["ref_latlonalt"]) + return cls( + tiles, + BoundaryBox.from_string(dump["bbox"]), + dump["tile_size"], + dump["ppm"], + projection, + dump["groups"], + ) diff --git a/osm/viz.py b/osm/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..cd99e3eda0049c2aae35d397018db73b2eb661ae --- /dev/null +++ b/osm/viz.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
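# Usage sketch for TileManager defined above (not part of the original source;
# paths, coordinates, and the Projection instance are assumed): rasterize an area
# of interest once into fixed-size tiles, save them, and later crop local map
# rasters around query positions without re-reading OSM data.
import numpy as np
from pathlib import Path
from utils.geo import BoundaryBox

# proj: utils.geo.Projection centered on the area of interest (assumed given)
# bbox_total: utils.geo.BoundaryBox of the full area in projected XY meters
tiles = TileManager.from_bbox(proj, bbox_total, ppm=2, path=Path("osm_dump.json"))
tiles.save(Path("tiles.pkl"))

xy = np.array([10.0, -30.0])                          # query point in the same XY frame
canvas = tiles.query(BoundaryBox(xy - 64, xy + 64))   # 128 m x 128 m crop
map_raster = canvas.raster                            # (3, H, W): areas / ways / nodes indices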
+ +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import plotly.graph_objects as go +import PIL.Image + +from utils.viz_2d import add_text +from .parser import Groups + + +class GeoPlotter: + def __init__(self, zoom=12, **kwargs): + self.fig = go.Figure() + self.fig.update_layout( + mapbox_style="open-street-map", + autosize=True, + mapbox_zoom=zoom, + margin={"r": 0, "t": 0, "l": 0, "b": 0}, + showlegend=True, + **kwargs, + ) + + def points(self, latlons, color, text=None, name=None, size=5, **kwargs): + latlons = np.asarray(latlons) + self.fig.add_trace( + go.Scattermapbox( + lat=latlons[..., 0], + lon=latlons[..., 1], + mode="markers", + text=text, + marker_color=color, + marker_size=size, + name=name, + **kwargs, + ) + ) + center = latlons.reshape(-1, 2).mean(0) + self.fig.update_layout( + mapbox_center=dict(zip(("lat", "lon"), center)), + ) + + def bbox(self, bbox, color, name=None, **kwargs): + corners = np.stack( + [bbox.min_, bbox.left_top, bbox.max_, bbox.right_bottom, bbox.min_] + ) + self.fig.add_trace( + go.Scattermapbox( + lat=corners[:, 0], + lon=corners[:, 1], + mode="lines", + marker_color=color, + name=name, + **kwargs, + ) + ) + self.fig.update_layout( + mapbox_center=dict(zip(("lat", "lon"), bbox.center)), + ) + + def raster(self, raster, bbox, below="traces", **kwargs): + if not np.issubdtype(raster.dtype, np.integer): + raster = (raster * 255).astype(np.uint8) + raster = PIL.Image.fromarray(raster) + corners = np.stack( + [ + bbox.min_, + bbox.left_top, + bbox.max_, + bbox.right_bottom, + ] + )[::-1, ::-1] + layers = [*self.fig.layout.mapbox.layers] + layers.append( + dict( + sourcetype="image", + source=raster, + coordinates=corners, + below=below, + **kwargs, + ) + ) + self.fig.layout.mapbox.layers = layers + + +map_colors = { + "building": (84, 155, 255), + "parking": (255, 229, 145), + "playground": (150, 133, 125), + "grass": (188, 255, 143), + "park": (0, 158, 16), + "forest": (0, 92, 9), + "water": (184, 213, 255), + "fence": (238, 0, 255), + "wall": (0, 0, 0), + "hedge": (107, 68, 48), + "kerb": (255, 234, 0), + "building_outline": (0, 0, 255), + "cycleway": (0, 251, 255), + "path": (8, 237, 0), + "road": (255, 0, 0), + "tree_row": (0, 92, 9), + "busway": (255, 128, 0), + "void": [int(255 * 0.9)] * 3, +} + + +class Colormap: + colors_areas = np.stack([map_colors[k] for k in ["void"] + Groups.areas]) + colors_ways = np.stack([map_colors[k] for k in ["void"] + Groups.ways]) + + @classmethod + def apply(cls, rasters): + return ( + np.where( + rasters[1, ..., None] > 0, + cls.colors_ways[rasters[1]], + cls.colors_areas[rasters[0]], + ) + / 255.0 + ) + + @classmethod + def add_colorbar(cls): + ax2 = plt.gcf().add_axes([1, 0.1, 0.02, 0.8]) + color_list = np.r_[cls.colors_areas[1:], cls.colors_ways[1:]] / 255.0 + cmap = mpl.colors.ListedColormap(color_list[::-1]) + ticks = np.linspace(0, 1, len(color_list), endpoint=False) + ticks += 1 / len(color_list) / 2 + cb = mpl.colorbar.ColorbarBase( + ax2, + cmap=cmap, + orientation="vertical", + ticks=ticks, + ) + cb.set_ticklabels((Groups.areas + Groups.ways)[::-1]) + ax2.tick_params(labelsize=15) + + +def plot_nodes(idx, raster, fontsize=8, size=15): + ax = plt.gcf().axes[idx] + ax.autoscale(enable=False) + nodes_xy = np.stack(np.where(raster > 0)[::-1], -1) + nodes_val = raster[tuple(nodes_xy.T[::-1])] - 1 + ax.scatter(*nodes_xy.T, c="k", s=size) + for xy, val in zip(nodes_xy, nodes_val): + group = Groups.nodes[val] + add_text( + idx, + group, + xy + 2, + lcolor=None, + fs=fontsize, + 
color="k", + normalized=False, + ha="center", + ) + plt.show() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..be5aaa88e586bcf463c7e9252b45bd4ebcd2f411 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +torch +torchvision +numpy +opencv-python +Pillow +tqdm>=4.36.0 +matplotlib +plotly +scipy +omegaconf +pytorch-lightning +torchmetrics +jupyter +lxml +rtree +scikit-learn +geopy +exifread +gradio_client +urllib3>=2 \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b95c8795faf68414f860ef87dc0f6acb6cadebf2 --- /dev/null +++ b/train.py @@ -0,0 +1,217 @@ +import os.path as osp +import warnings +warnings.filterwarnings('ignore') +from typing import Optional +from pathlib import Path +from models.maplocnet import MapLocNet +import hydra +import pytorch_lightning as pl +import torch +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning.utilities import rank_zero_only +from module import GenericModule +from logger import logger, pl_logger, EXPERIMENTS_PATH +from module import GenericModule +from dataset import UavMapDatasetModule +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +# print(osp.join(osp.dirname(__file__), "conf")) + + +class CleanProgressBar(pl.callbacks.TQDMProgressBar): + def get_metrics(self, trainer, model): + items = super().get_metrics(trainer, model) + items.pop("v_num", None) # don't show the version number + items.pop("loss", None) + return items + + +class SeedingCallback(pl.callbacks.Callback): + def on_epoch_start_(self, trainer, module): + seed = module.cfg.experiment.seed + is_overfit = module.cfg.training.trainer.get("overfit_batches", 0) > 0 + if trainer.training and not is_overfit: + seed = seed + trainer.current_epoch + + # Temporarily disable the logging (does not seem to work?) 
+ pl_logger.disabled = True + try: + pl.seed_everything(seed, workers=True) + finally: + pl_logger.disabled = False + + def on_train_epoch_start(self, *args, **kwargs): + self.on_epoch_start_(*args, **kwargs) + + def on_validation_epoch_start(self, *args, **kwargs): + self.on_epoch_start_(*args, **kwargs) + + def on_test_epoch_start(self, *args, **kwargs): + self.on_epoch_start_(*args, **kwargs) + + +class ConsoleLogger(pl.callbacks.Callback): + @rank_zero_only + def on_train_epoch_start(self, trainer, module): + logger.info( + "New training epoch %d for experiment '%s'.", + module.current_epoch, + module.cfg.experiment.name, + ) + + # @rank_zero_only + # def on_validation_epoch_end(self, trainer, module): + # results = { + # **dict(module.metrics_val.items()), + # **dict(module.losses_val.items()), + # } + # results = [f"{k} {v.compute():.3E}" for k, v in results.items()] + # logger.info(f'[Validation] {{{", ".join(results)}}}') + + +def find_last_checkpoint_path(experiment_dir): + cls = pl.callbacks.ModelCheckpoint + path = osp.join(experiment_dir, cls.CHECKPOINT_NAME_LAST + cls.FILE_EXTENSION) + if osp.exists(path): + return path + else: + return None + + +def prepare_experiment_dir(experiment_dir, cfg, rank): + config_path = osp.join(experiment_dir, "config.yaml") + last_checkpoint_path = find_last_checkpoint_path(experiment_dir) + if last_checkpoint_path is not None: + if rank == 0: + logger.info( + "Resuming the training from checkpoint %s", last_checkpoint_path + ) + if osp.exists(config_path): + with open(config_path, "r") as fp: + cfg_prev = OmegaConf.create(fp.read()) + compare_keys = ["experiment", "data", "model", "training"] + if OmegaConf.masked_copy(cfg, compare_keys) != OmegaConf.masked_copy( + cfg_prev, compare_keys + ): + raise ValueError( + "Attempting to resume training with a different config: " + f"{OmegaConf.masked_copy(cfg, compare_keys)} vs " + f"{OmegaConf.masked_copy(cfg_prev, compare_keys)}" + ) + if rank == 0: + Path(experiment_dir).mkdir(exist_ok=True, parents=True) + with open(config_path, "w") as fp: + OmegaConf.save(cfg, fp) + return last_checkpoint_path + + +def train(cfg: DictConfig) -> None: + torch.set_float32_matmul_precision("medium") + OmegaConf.resolve(cfg) + rank = rank_zero_only.rank + + if rank == 0: + logger.info("Starting training with config:\n%s", OmegaConf.to_yaml(cfg)) + if cfg.experiment.gpus in (None, 0): + logger.warning("Will train on CPU...") + cfg.experiment.gpus = 0 + elif not torch.cuda.is_available(): + raise ValueError("Requested GPU but no NVIDIA drivers found.") + pl.seed_everything(cfg.experiment.seed, workers=True) + + init_checkpoint_path = cfg.training.get("finetune_from_checkpoint") + if init_checkpoint_path is not None: + logger.info("Initializing the model from checkpoint %s.", init_checkpoint_path) + model = GenericModule.load_from_checkpoint( + init_checkpoint_path, strict=True, find_best=False, cfg=cfg + ) + else: + model = GenericModule(cfg) + if rank == 0: + logger.info("Network:\n%s", model.model) + + experiment_dir = osp.join(EXPERIMENTS_PATH, cfg.experiment.name) + last_checkpoint_path = prepare_experiment_dir(experiment_dir, cfg, rank) + checkpointing_epoch = pl.callbacks.ModelCheckpoint( + dirpath=experiment_dir, + filename="checkpoint-epoch-{epoch:02d}-loss-{loss/total/val:02f}", + auto_insert_metric_name=False, + save_last=True, + every_n_epochs=1, + save_on_train_epoch_end=True, + verbose=True, + **cfg.training.checkpointing, + ) + checkpointing_step = pl.callbacks.ModelCheckpoint( + dirpath=experiment_dir, + 
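+        # step-based checkpoints are written every 1000 training steps, in
+        # addition to the per-epoch checkpoints above; both reuse the
+        # monitor / save_top_k / mode settings from cfg.training.checkpointing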
filename="checkpoint-step-{step}-{loss/total/val:02f}", + auto_insert_metric_name=False, + save_last=True, + every_n_train_steps=1000, + verbose=True, + **cfg.training.checkpointing, + ) + checkpointing_step.CHECKPOINT_NAME_LAST = "last-step-checkpointing" + + # 创建 EarlyStopping 回调 + early_stopping_callback = EarlyStopping(monitor=cfg.training.checkpointing.monitor, patience=5) + + strategy = None + if cfg.experiment.gpus > 1: + strategy = pl.strategies.DDPStrategy(find_unused_parameters=False) + for split in ["train", "val"]: + cfg.data[split].batch_size = ( + cfg.data[split].batch_size // cfg.experiment.gpus + ) + cfg.data[split].num_workers = int( + (cfg.data[split].num_workers + cfg.experiment.gpus - 1) + / cfg.experiment.gpus + ) + + # data = data_modules[cfg.data.get("name", "mapillary")](cfg.data) + + datamodule =UavMapDatasetModule(cfg.data) + + tb_args = {"name": cfg.experiment.name, "version": ""} + tb = pl.loggers.TensorBoardLogger(EXPERIMENTS_PATH, **tb_args) + + callbacks = [ + checkpointing_epoch, + checkpointing_step, + # early_stopping_callback, + pl.callbacks.LearningRateMonitor(), + SeedingCallback(), + CleanProgressBar(), + ConsoleLogger(), + ] + if cfg.experiment.gpus > 0: + callbacks.append(pl.callbacks.DeviceStatsMonitor()) + + trainer = pl.Trainer( + default_root_dir=experiment_dir, + detect_anomaly=False, + # strategy=ddp_find_unused_parameters_true, + enable_model_summary=True, + sync_batchnorm=True, + enable_checkpointing=True, + logger=tb, + callbacks=callbacks, + strategy=strategy, + check_val_every_n_epoch=1, + accelerator="gpu", + num_nodes=1, + **cfg.training.trainer, + ) + trainer.fit(model=model, datamodule=datamodule, ckpt_path=last_checkpoint_path) + + +@hydra.main( + config_path=osp.join(osp.dirname(__file__), "conf"), config_name="maplocnet.yaml" +) +def main(cfg: DictConfig) -> None: + OmegaConf.save(config=cfg, f='maplocnet.yaml') + train(cfg) + + +if __name__ == "__main__": + main() + diff --git a/train.sh b/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..3683adf5eb906698baec50d39e1e8d86f0d84bf7 --- /dev/null +++ b/train.sh @@ -0,0 +1 @@ +nohup python train.py > logs/train0907.log 2>&1 & \ No newline at end of file diff --git a/utils/exif.py b/utils/exif.py new file mode 100644 index 0000000000000000000000000000000000000000..c272019b5673dc1e3aab8a3e0e21b630cc629154 --- /dev/null +++ b/utils/exif.py @@ -0,0 +1,356 @@ +"""Copied from opensfm.exif to minimize hard dependencies.""" +from pathlib import Path +import json +import datetime +import logging +from codecs import encode, decode +from typing import Any, Dict, Optional, Tuple + +import exifread + +logger: logging.Logger = logging.getLogger(__name__) + +inch_in_mm = 25.4 +cm_in_mm = 10 +um_in_mm = 0.001 +default_projection = "perspective" +maximum_altitude = 1e4 + + +def sensor_data(): + with (Path(__file__).parent / "sensor_data.json").open() as fid: + data = json.load(fid) + return {k.lower(): v for k, v in data.items()} + + +def eval_frac(value) -> Optional[float]: + try: + return float(value.num) / float(value.den) + except ZeroDivisionError: + return None + + +def gps_to_decimal(values, reference) -> Optional[float]: + sign = 1 if reference in "NE" else -1 + degrees = eval_frac(values[0]) + minutes = eval_frac(values[1]) + seconds = eval_frac(values[2]) + if degrees is not None and minutes is not None and seconds is not None: + return sign * (degrees + minutes / 60 + seconds / 3600) + return None + + +def get_tag_as_float(tags, key, index: int = 0) -> 
Optional[float]: + if key in tags: + val = tags[key].values[index] + if isinstance(val, exifread.utils.Ratio): + ret_val = eval_frac(val) + if ret_val is None: + logger.error( + 'The rational "{2}" of tag "{0:s}" at index {1:d} c' + "aused a division by zero error".format(key, index, val) + ) + return ret_val + else: + return float(val) + else: + return None + + +def compute_focal( + focal_35: Optional[float], focal: Optional[float], sensor_width, sensor_string +) -> Tuple[float, float]: + if focal_35 is not None and focal_35 > 0: + focal_ratio = focal_35 / 36.0 # 35mm film produces 36x24mm pictures. + else: + if not sensor_width: + sensor_width = sensor_data().get(sensor_string, None) + if sensor_width and focal: + focal_ratio = focal / sensor_width + focal_35 = 36.0 * focal_ratio + else: + focal_35 = 0.0 + focal_ratio = 0.0 + return focal_35, focal_ratio + + +def sensor_string(make: str, model: str) -> str: + if make != "unknown": + # remove duplicate 'make' information in 'model' + model = model.replace(make, "") + return (make.strip() + " " + model.strip()).strip().lower() + + +def unescape_string(s) -> str: + return decode(encode(s, "latin-1", "backslashreplace"), "unicode-escape") + + +class EXIF: + def __init__( + self, fileobj, image_size_loader, use_exif_size=True, name=None + ) -> None: + self.image_size_loader = image_size_loader + self.use_exif_size = use_exif_size + self.fileobj = fileobj + self.tags = exifread.process_file(fileobj, details=False) + fileobj.seek(0) + self.fileobj_name = self.fileobj.name if name is None else name + + def extract_image_size(self) -> Tuple[int, int]: + if ( + self.use_exif_size + and "EXIF ExifImageWidth" in self.tags + and "EXIF ExifImageLength" in self.tags + ): + width, height = ( + int(self.tags["EXIF ExifImageWidth"].values[0]), + int(self.tags["EXIF ExifImageLength"].values[0]), + ) + elif ( + self.use_exif_size + and "Image ImageWidth" in self.tags + and "Image ImageLength" in self.tags + ): + width, height = ( + int(self.tags["Image ImageWidth"].values[0]), + int(self.tags["Image ImageLength"].values[0]), + ) + else: + height, width = self.image_size_loader() + return width, height + + def _decode_make_model(self, value) -> str: + """Python 2/3 compatible decoding of make/model field.""" + if hasattr(value, "decode"): + try: + return value.decode("utf-8") + except UnicodeDecodeError: + return "unknown" + else: + return value + + def extract_make(self) -> str: + # Camera make and model + if "EXIF LensMake" in self.tags: + make = self.tags["EXIF LensMake"].values + elif "Image Make" in self.tags: + make = self.tags["Image Make"].values + else: + make = "unknown" + return self._decode_make_model(make) + + def extract_model(self) -> str: + if "EXIF LensModel" in self.tags: + model = self.tags["EXIF LensModel"].values + elif "Image Model" in self.tags: + model = self.tags["Image Model"].values + else: + model = "unknown" + return self._decode_make_model(model) + + def extract_focal(self) -> Tuple[float, float]: + make, model = self.extract_make(), self.extract_model() + focal_35, focal_ratio = compute_focal( + get_tag_as_float(self.tags, "EXIF FocalLengthIn35mmFilm"), + get_tag_as_float(self.tags, "EXIF FocalLength"), + self.extract_sensor_width(), + sensor_string(make, model), + ) + return focal_35, focal_ratio + + def extract_sensor_width(self) -> Optional[float]: + """Compute sensor with from width and resolution.""" + if ( + "EXIF FocalPlaneResolutionUnit" not in self.tags + or "EXIF FocalPlaneXResolution" not in self.tags + ): + return 
None + resolution_unit = self.tags["EXIF FocalPlaneResolutionUnit"].values[0] + mm_per_unit = self.get_mm_per_unit(resolution_unit) + if not mm_per_unit: + return None + pixels_per_unit = get_tag_as_float(self.tags, "EXIF FocalPlaneXResolution") + if pixels_per_unit is None: + return None + if pixels_per_unit <= 0.0: + pixels_per_unit = get_tag_as_float(self.tags, "EXIF FocalPlaneYResolution") + if pixels_per_unit is None or pixels_per_unit <= 0.0: + return None + units_per_pixel = 1 / pixels_per_unit + width_in_pixels = self.extract_image_size()[0] + return width_in_pixels * units_per_pixel * mm_per_unit + + def get_mm_per_unit(self, resolution_unit) -> Optional[float]: + """Length of a resolution unit in millimeters. + + Uses the values from the EXIF specs in + https://www.sno.phy.queensu.ca/~phil/exiftool/TagNames/EXIF.html + + Args: + resolution_unit: the resolution unit value given in the EXIF + """ + if resolution_unit == 2: # inch + return inch_in_mm + elif resolution_unit == 3: # cm + return cm_in_mm + elif resolution_unit == 4: # mm + return 1 + elif resolution_unit == 5: # um + return um_in_mm + else: + logger.warning( + "Unknown EXIF resolution unit value: {}".format(resolution_unit) + ) + return None + + def extract_orientation(self) -> int: + orientation = 1 + if "Image Orientation" in self.tags: + value = self.tags.get("Image Orientation").values[0] + if type(value) == int and value != 0: + orientation = value + return orientation + + def extract_ref_lon_lat(self) -> Tuple[str, str]: + if "GPS GPSLatitudeRef" in self.tags: + reflat = self.tags["GPS GPSLatitudeRef"].values + else: + reflat = "N" + if "GPS GPSLongitudeRef" in self.tags: + reflon = self.tags["GPS GPSLongitudeRef"].values + else: + reflon = "E" + return reflon, reflat + + def extract_lon_lat(self) -> Tuple[Optional[float], Optional[float]]: + if "GPS GPSLatitude" in self.tags: + reflon, reflat = self.extract_ref_lon_lat() + lat = gps_to_decimal(self.tags["GPS GPSLatitude"].values, reflat) + lon = gps_to_decimal(self.tags["GPS GPSLongitude"].values, reflon) + else: + lon, lat = None, None + return lon, lat + + def extract_altitude(self) -> Optional[float]: + if "GPS GPSAltitude" in self.tags: + alt_value = self.tags["GPS GPSAltitude"].values[0] + if isinstance(alt_value, exifread.utils.Ratio): + altitude = eval_frac(alt_value) + elif isinstance(alt_value, int): + altitude = float(alt_value) + else: + altitude = None + + # Check if GPSAltitudeRef is equal to 1, which means GPSAltitude should be negative, reference: http://www.exif.org/Exif2-2.PDF#page=53 + if ( + "GPS GPSAltitudeRef" in self.tags + and self.tags["GPS GPSAltitudeRef"].values[0] == 1 + and altitude is not None + ): + altitude = -altitude + else: + altitude = None + return altitude + + def extract_dop(self) -> Optional[float]: + if "GPS GPSDOP" in self.tags: + return eval_frac(self.tags["GPS GPSDOP"].values[0]) + return None + + def extract_geo(self) -> Dict[str, Any]: + altitude = self.extract_altitude() + dop = self.extract_dop() + lon, lat = self.extract_lon_lat() + d = {} + + if lon is not None and lat is not None: + d["latitude"] = lat + d["longitude"] = lon + if altitude is not None: + d["altitude"] = min([maximum_altitude, altitude]) + if dop is not None: + d["dop"] = dop + return d + + def extract_capture_time(self) -> float: + if ( + "GPS GPSDate" in self.tags + and "GPS GPSTimeStamp" in self.tags # Actually GPSDateStamp + ): + try: + hours_f = get_tag_as_float(self.tags, "GPS GPSTimeStamp", 0) + minutes_f = get_tag_as_float(self.tags, "GPS 
GPSTimeStamp", 1) + if hours_f is None or minutes_f is None: + raise TypeError + hours = int(hours_f) + minutes = int(minutes_f) + seconds = get_tag_as_float(self.tags, "GPS GPSTimeStamp", 2) + gps_timestamp_string = "{0:s} {1:02d}:{2:02d}:{3:02f}".format( + self.tags["GPS GPSDate"].values, hours, minutes, seconds + ) + return ( + datetime.datetime.strptime( + gps_timestamp_string, "%Y:%m:%d %H:%M:%S.%f" + ) + - datetime.datetime(1970, 1, 1) + ).total_seconds() + except (TypeError, ValueError): + logger.info( + 'The GPS time stamp in image file "{0:s}" is invalid. ' + "Falling back to DateTime*".format(self.fileobj_name) + ) + + time_strings = [ + ("EXIF DateTimeOriginal", "EXIF SubSecTimeOriginal", "EXIF Tag 0x9011"), + ("EXIF DateTimeDigitized", "EXIF SubSecTimeDigitized", "EXIF Tag 0x9012"), + ("Image DateTime", "Image SubSecTime", "Image Tag 0x9010"), + ] + for datetime_tag, subsec_tag, offset_tag in time_strings: + if datetime_tag in self.tags: + date_time = self.tags[datetime_tag].values + if subsec_tag in self.tags: + subsec_time = self.tags[subsec_tag].values + else: + subsec_time = "0" + try: + s = "{0:s}.{1:s}".format(date_time, subsec_time) + d = datetime.datetime.strptime(s, "%Y:%m:%d %H:%M:%S.%f") + except ValueError: + logger.debug( + 'The "{1:s}" time stamp or "{2:s}" tag is invalid in ' + 'image file "{0:s}"'.format( + self.fileobj_name, datetime_tag, subsec_tag + ) + ) + continue + # Test for OffsetTimeOriginal | OffsetTimeDigitized | OffsetTime + if offset_tag in self.tags: + offset_time = self.tags[offset_tag].values + try: + d += datetime.timedelta( + hours=-int(offset_time[0:3]), minutes=int(offset_time[4:6]) + ) + except (TypeError, ValueError): + logger.debug( + 'The "{0:s}" time zone offset in image file "{1:s}"' + " is invalid".format(offset_tag, self.fileobj_name) + ) + logger.debug( + 'Naively assuming UTC on "{0:s}" in image file ' + '"{1:s}"'.format(datetime_tag, self.fileobj_name) + ) + else: + logger.debug( + "No GPS time stamp and no time zone offset in image " + 'file "{0:s}"'.format(self.fileobj_name) + ) + logger.debug( + 'Naively assuming UTC on "{0:s}" in image file "{1:s}"'.format( + datetime_tag, self.fileobj_name + ) + ) + return (d - datetime.datetime(1970, 1, 1)).total_seconds() + logger.info( + 'Image file "{0:s}" has no valid time stamp'.format(self.fileobj_name) + ) + return 0.0 diff --git a/utils/geo.py b/utils/geo.py new file mode 100644 index 0000000000000000000000000000000000000000..7f97e501f81092b57fc3ad043aa520779b44faeb --- /dev/null +++ b/utils/geo.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+ +from typing import Union + +import numpy as np +import torch + +from .geo_opensfm import TopocentricConverter + + +class BoundaryBox: + def __init__(self, min_: np.ndarray, max_: np.ndarray): + self.min_ = np.asarray(min_) + self.max_ = np.asarray(max_) + assert np.all(self.min_ <= self.max_) + + @classmethod + def from_string(cls, string: str): + return cls(*np.split(np.array(string.split(","), float), 2)) + + @property + def left_top(self): + return np.stack([self.min_[..., 0], self.max_[..., 1]], -1) + + @property + def right_bottom(self) -> (np.ndarray, np.ndarray): + return np.stack([self.max_[..., 0], self.min_[..., 1]], -1) + + @property + def center(self) -> np.ndarray: + return (self.min_ + self.max_) / 2 + + @property + def size(self) -> np.ndarray: + return self.max_ - self.min_ + + def translate(self, t: float): + return self.__class__(self.min_ + t, self.max_ + t) + + def contains(self, xy: Union[np.ndarray, "BoundaryBox"]): + if isinstance(xy, self.__class__): + return self.contains(xy.min_) and self.contains(xy.max_) + return np.all((xy >= self.min_) & (xy <= self.max_), -1) + + def normalize(self, xy): + min_, max_ = self.min_, self.max_ + if isinstance(xy, torch.Tensor): + min_ = torch.from_numpy(min_).to(xy) + max_ = torch.from_numpy(max_).to(xy) + return (xy - min_) / (max_ - min_) + + def unnormalize(self, xy): + min_, max_ = self.min_, self.max_ + if isinstance(xy, torch.Tensor): + min_ = torch.from_numpy(min_).to(xy) + max_ = torch.from_numpy(max_).to(xy) + return xy * (max_ - min_) + min_ + + def format(self) -> str: + return ",".join(np.r_[self.min_, self.max_].astype(str)) + + def __add__(self, x): + if isinstance(x, (int, float)): + return self.__class__(self.min_ - x, self.max_ + x) + else: + raise TypeError(f"Cannot add {self.__class__.__name__} to {type(x)}.") + + def __and__(self, other): + return self.__class__( + np.maximum(self.min_, other.min_), np.minimum(self.max_, other.max_) + ) + + def __repr__(self): + return self.format() + + +class Projection: + def __init__(self, lat, lon, alt=0, max_extent=25e3): + # The approximation error is |L - radius * tan(L / radius)| + # and is around 13cm for L=25km. + self.latlonalt = (lat, lon, alt) + self.converter = TopocentricConverter(lat, lon, alt) + min_ = self.converter.to_lla(*(-max_extent,) * 2, 0)[:2] + max_ = self.converter.to_lla(*(max_extent,) * 2, 0)[:2] + self.bounds = BoundaryBox(min_, max_) + + @classmethod + def from_points(cls, all_latlon): + assert all_latlon.shape[-1] == 2 + all_latlon = all_latlon.reshape(-1, 2) + latlon_mid = (all_latlon.min(0) + all_latlon.max(0)) / 2 + return cls(*latlon_mid) + + def check_bbox(self, bbox: BoundaryBox): + if self.bounds is not None and not self.bounds.contains(bbox): + raise ValueError( + f"Bbox {bbox.format()} is not contained in " + f"projection with bounds {self.bounds.format()}." + ) + + def project(self, geo, return_z=False): + if isinstance(geo, BoundaryBox): + return BoundaryBox(*self.project(np.stack([geo.min_, geo.max_]))) + geo = np.asarray(geo) + assert geo.shape[-1] in (2, 3) + if self.bounds is not None: + if not np.all(self.bounds.contains(geo[..., :2])): + raise ValueError( + f"Points {geo} are out of the valid bounds " + f"{self.bounds.format()}." 
+ ) + lat, lon = geo[..., 0], geo[..., 1] + if geo.shape[-1] == 3: + alt = geo[..., -1] + else: + alt = np.zeros_like(lat) + x, y, z = self.converter.to_topocentric(lat, lon, alt) + return np.stack([x, y] + ([z] if return_z else []), -1) + + def unproject(self, xy, return_z=False): + if isinstance(xy, BoundaryBox): + return BoundaryBox(*self.unproject(np.stack([xy.min_, xy.max_]))) + xy = np.asarray(xy) + x, y = xy[..., 0], xy[..., 1] + if xy.shape[-1] == 3: + z = xy[..., -1] + else: + z = np.zeros_like(x) + lat, lon, alt = self.converter.to_lla(x, y, z) + return np.stack([lat, lon] + ([alt] if return_z else []), -1) diff --git a/utils/geo_opensfm.py b/utils/geo_opensfm.py new file mode 100644 index 0000000000000000000000000000000000000000..d42145236dd4c65a8764cbe37fa25c527814017e --- /dev/null +++ b/utils/geo_opensfm.py @@ -0,0 +1,180 @@ +"""Copied from opensfm.geo to minimize hard dependencies.""" +import numpy as np +from numpy import ndarray +from typing import Tuple + +WGS84_a = 6378137.0 +WGS84_b = 6356752.314245 + + +def ecef_from_lla(lat, lon, alt: float) -> Tuple[float, ...]: + """ + Compute ECEF XYZ from latitude, longitude and altitude. + + All using the WGS84 model. + Altitude is the distance to the WGS84 ellipsoid. + Check results here http://www.oc.nps.edu/oc2902w/coord/llhxyz.htm + + >>> lat, lon, alt = 10, 20, 30 + >>> x, y, z = ecef_from_lla(lat, lon, alt) + >>> np.allclose(lla_from_ecef(x,y,z), [lat, lon, alt]) + True + """ + a2 = WGS84_a**2 + b2 = WGS84_b**2 + lat = np.radians(lat) + lon = np.radians(lon) + L = 1.0 / np.sqrt(a2 * np.cos(lat) ** 2 + b2 * np.sin(lat) ** 2) + x = (a2 * L + alt) * np.cos(lat) * np.cos(lon) + y = (a2 * L + alt) * np.cos(lat) * np.sin(lon) + z = (b2 * L + alt) * np.sin(lat) + return x, y, z + + +def lla_from_ecef(x, y, z): + """ + Compute latitude, longitude and altitude from ECEF XYZ. + + All using the WGS84 model. + Altitude is the distance to the WGS84 ellipsoid. + """ + a = WGS84_a + b = WGS84_b + ea = np.sqrt((a**2 - b**2) / a**2) + eb = np.sqrt((a**2 - b**2) / b**2) + p = np.sqrt(x**2 + y**2) + theta = np.arctan2(z * a, p * b) + lon = np.arctan2(y, x) + lat = np.arctan2( + z + eb**2 * b * np.sin(theta) ** 3, p - ea**2 * a * np.cos(theta) ** 3 + ) + N = a / np.sqrt(1 - ea**2 * np.sin(lat) ** 2) + alt = p / np.cos(lat) - N + return np.degrees(lat), np.degrees(lon), alt + + +def ecef_from_topocentric_transform(lat, lon, alt: float) -> ndarray: + """ + Transformation from a topocentric frame at reference position to ECEF. + + The topocentric reference frame is a metric one with the origin + at the given (lat, lon, alt) position, with the X axis heading east, + the Y axis heading north and the Z axis vertical to the ellipsoid. + >>> a = ecef_from_topocentric_transform(30, 20, 10) + >>> b = ecef_from_topocentric_transform_finite_diff(30, 20, 10) + >>> np.allclose(a, b) + True + """ + x, y, z = ecef_from_lla(lat, lon, alt) + sa = np.sin(np.radians(lat)) + ca = np.cos(np.radians(lat)) + so = np.sin(np.radians(lon)) + co = np.cos(np.radians(lon)) + return np.array( + [ + [-so, -sa * co, ca * co, x], + [co, -sa * so, ca * so, y], + [0, ca, sa, z], + [0, 0, 0, 1], + ] + ) + + +def ecef_from_topocentric_transform_finite_diff(lat, lon, alt: float) -> ndarray: + """ + Transformation from a topocentric frame at reference position to ECEF. + + The topocentric reference frame is a metric one with the origin + at the given (lat, lon, alt) position, with the X axis heading east, + the Y axis heading north and the Z axis vertical to the ellipsoid. 
+ """ + eps = 1e-2 + x, y, z = ecef_from_lla(lat, lon, alt) + v1 = ( + ( + np.array(ecef_from_lla(lat, lon + eps, alt)) + - np.array(ecef_from_lla(lat, lon - eps, alt)) + ) + / 2 + / eps + ) + v2 = ( + ( + np.array(ecef_from_lla(lat + eps, lon, alt)) + - np.array(ecef_from_lla(lat - eps, lon, alt)) + ) + / 2 + / eps + ) + v3 = ( + ( + np.array(ecef_from_lla(lat, lon, alt + eps)) + - np.array(ecef_from_lla(lat, lon, alt - eps)) + ) + / 2 + / eps + ) + v1 /= np.linalg.norm(v1) + v2 /= np.linalg.norm(v2) + v3 /= np.linalg.norm(v3) + return np.array( + [ + [v1[0], v2[0], v3[0], x], + [v1[1], v2[1], v3[1], y], + [v1[2], v2[2], v3[2], z], + [0, 0, 0, 1], + ] + ) + + +def topocentric_from_lla(lat, lon, alt: float, reflat, reflon, refalt: float): + """ + Transform from lat, lon, alt to topocentric XYZ. + + >>> lat, lon, alt = -10, 20, 100 + >>> np.allclose(topocentric_from_lla(lat, lon, alt, lat, lon, alt), + ... [0,0,0]) + True + >>> x, y, z = topocentric_from_lla(lat, lon, alt, 0, 0, 0) + >>> np.allclose(lla_from_topocentric(x, y, z, 0, 0, 0), + ... [lat, lon, alt]) + True + """ + T = np.linalg.inv(ecef_from_topocentric_transform(reflat, reflon, refalt)) + x, y, z = ecef_from_lla(lat, lon, alt) + tx = T[0, 0] * x + T[0, 1] * y + T[0, 2] * z + T[0, 3] + ty = T[1, 0] * x + T[1, 1] * y + T[1, 2] * z + T[1, 3] + tz = T[2, 0] * x + T[2, 1] * y + T[2, 2] * z + T[2, 3] + return tx, ty, tz + + +def lla_from_topocentric(x, y, z, reflat, reflon, refalt: float): + """ + Transform from topocentric XYZ to lat, lon, alt. + """ + T = ecef_from_topocentric_transform(reflat, reflon, refalt) + ex = T[0, 0] * x + T[0, 1] * y + T[0, 2] * z + T[0, 3] + ey = T[1, 0] * x + T[1, 1] * y + T[1, 2] * z + T[1, 3] + ez = T[2, 0] * x + T[2, 1] * y + T[2, 2] * z + T[2, 3] + return lla_from_ecef(ex, ey, ez) + + +class TopocentricConverter(object): + """Convert to and from a topocentric reference frame.""" + + def __init__(self, reflat, reflon, refalt): + """Init the converter given the reference origin.""" + self.lat = reflat + self.lon = reflon + self.alt = refalt + + def to_topocentric(self, lat, lon, alt): + """Convert lat, lon, alt to topocentric x, y, z.""" + return topocentric_from_lla(lat, lon, alt, self.lat, self.lon, self.alt) + + def to_lla(self, x, y, z): + """Convert topocentric x, y, z to lat, lon, alt.""" + return lla_from_topocentric(x, y, z, self.lat, self.lon, self.alt) + + def __eq__(self, o): + return np.allclose([self.lat, self.lon, self.alt], (o.lat, o.lon, o.alt)) diff --git a/utils/geometry.py b/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcbcba7c41e689e9dd9e35fe33e7787fdd13b03 --- /dev/null +++ b/utils/geometry.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import numpy as np +import torch + + +def from_homogeneous(points, eps: float = 1e-8): + """Remove the homogeneous dimension of N-dimensional points. + Args: + points: torch.Tensor or numpy.ndarray with size (..., N+1). + Returns: + A torch.Tensor or numpy ndarray with size (..., N). + """ + return points[..., :-1] / (points[..., -1:] + eps) + + +def to_homogeneous(points): + """Convert N-dimensional points to homogeneous coordinates. + Args: + points: torch.Tensor or numpy.ndarray with size (..., N). + Returns: + A torch.Tensor or numpy.ndarray with size (..., N+1). 
+ """ + if isinstance(points, torch.Tensor): + pad = points.new_ones(points.shape[:-1] + (1,)) + return torch.cat([points, pad], dim=-1) + elif isinstance(points, np.ndarray): + pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype) + return np.concatenate([points, pad], axis=-1) + else: + raise ValueError + + +@torch.jit.script +def undistort_points(pts, dist): + dist = dist.unsqueeze(-2) # add point dimension + ndist = dist.shape[-1] + undist = pts + valid = torch.ones(pts.shape[:-1], device=pts.device, dtype=torch.bool) + if ndist > 0: + k1, k2 = dist[..., :2].split(1, -1) + r2 = torch.sum(pts**2, -1, keepdim=True) + radial = k1 * r2 + k2 * r2**2 + undist = undist + pts * radial + + # The distortion model is supposedly only valid within the image + # boundaries. Because of the negative radial distortion, points that + # are far outside of the boundaries might actually be mapped back + # within the image. To account for this, we discard points that are + # beyond the inflection point of the distortion model, + # e.g. such that d(r + k_1 r^3 + k2 r^5)/dr = 0 + limited = ((k2 > 0) & ((9 * k1**2 - 20 * k2) > 0)) | ((k2 <= 0) & (k1 > 0)) + limit = torch.abs( + torch.where( + k2 > 0, + (torch.sqrt(9 * k1**2 - 20 * k2) - 3 * k1) / (10 * k2), + 1 / (3 * k1), + ) + ) + valid = valid & torch.squeeze(~limited | (r2 < limit), -1) + + if ndist > 2: + p12 = dist[..., 2:] + p21 = p12.flip(-1) + uv = torch.prod(pts, -1, keepdim=True) + undist = undist + 2 * p12 * uv + p21 * (r2 + 2 * pts**2) + + return undist, valid diff --git a/utils/io.py b/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..e092ef8695a41f2fae6da4e03d7cad74eab5cadf --- /dev/null +++ b/utils/io.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+ +import json +import requests +import shutil +from pathlib import Path + +import cv2 +import numpy as np +import torch +from tqdm.auto import tqdm + +import logger + +DATA_URL = "https://cvg-data.inf.ethz.ch/OrienterNet_CVPR2023" + + +def read_image(path, grayscale=False): + if grayscale: + mode = cv2.IMREAD_GRAYSCALE + else: + mode = cv2.IMREAD_COLOR + image = cv2.imread(str(path), mode) + if image is None: + raise ValueError(f"Cannot read image {path}.") + if not grayscale and len(image.shape) == 3: + image = np.ascontiguousarray(image[:, :, ::-1]) # BGR to RGB + return image + + +def write_torch_image(path, image): + image_cv2 = np.round(image.clip(0, 1) * 255).astype(int)[..., ::-1] + cv2.imwrite(str(path), image_cv2) + + +class JSONEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, (np.ndarray, torch.Tensor)): + return obj.tolist() + elif isinstance(obj, np.generic): + return obj.item() + return json.JSONEncoder.default(self, obj) + + +def write_json(path, data): + with open(path, "w") as f: + json.dump(data, f, cls=JSONEncoder) + + +def download_file(url, path): + path = Path(path) + if path.is_dir(): + path = path / Path(url).name + path.parent.mkdir(exist_ok=True, parents=True) + logger.info("Downloading %s to %s.", url, path) + with requests.get(url, stream=True) as r: + total_length = int(r.headers.get("Content-Length")) + with tqdm.wrapattr(r.raw, "read", total=total_length, desc="") as raw: + with open(path, "wb") as output: + shutil.copyfileobj(raw, output) + return path diff --git a/utils/tools.py b/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..adbc3eae97c599d9ce606426ae5628484b9e2499 --- /dev/null +++ b/utils/tools.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import time + + +class Timer: + def __init__(self, name=None): + self.name = name + + def __enter__(self): + self.tstart = time.time() + return self + + def __exit__(self, type, value, traceback): + self.duration = time.time() - self.tstart + if self.name is not None: + print("[%s] Elapsed: %s" % (self.name, self.duration)) diff --git a/utils/viz_2d.py b/utils/viz_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..ac93d002e2666559f3d77dd54c4f52d465e96869 --- /dev/null +++ b/utils/viz_2d.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Adapted from Hierarchical-Localization, Paul-Edouard Sarlin, ETH Zurich +# https://github.com/cvg/Hierarchical-Localization/blob/master/hloc/utils/viz.py +# Released under the Apache License 2.0 + +import matplotlib +import matplotlib.patheffects as path_effects +import matplotlib.pyplot as plt +import numpy as np + + +def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True): + """Plot a set of images horizontally. + Args: + imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). + titles: a list of strings, as titles for each image. + cmaps: colormaps for monochrome images. + adaptive: whether the figure size should fit the image aspect ratios. 
+ """ + n = len(imgs) + if not isinstance(cmaps, (list, tuple)): + cmaps = [cmaps] * n + + if adaptive: + ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H + else: + ratios = [4 / 3] * n + figsize = [sum(ratios) * 4.5, 4.5] + fig, ax = plt.subplots( + 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios} + ) + if n == 1: + ax = [ax] + for i in range(n): + ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) + ax[i].get_yaxis().set_ticks([]) + ax[i].get_xaxis().set_ticks([]) + ax[i].set_axis_off() + for spine in ax[i].spines.values(): # remove frame + spine.set_visible(False) + if titles: + ax[i].set_title(titles[i]) + fig.tight_layout(pad=pad) + return fig + + +def plot_keypoints(kpts, colors="lime", ps=4): + """Plot keypoints for existing images. + Args: + kpts: list of ndarrays of size (N, 2). + colors: string, or list of list of tuples (one for each keypoints). + ps: size of the keypoints as float. + """ + if not isinstance(colors, list): + colors = [colors] * len(kpts) + axes = plt.gcf().axes + for a, k, c in zip(axes, kpts, colors): + a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0) + + +def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0): + """Plot matches for a pair of existing images. + Args: + kpts0, kpts1: corresponding keypoints of size (N, 2). + color: color of each match, string or RGB tuple. Random if not given. + lw: width of the lines. + ps: size of the end points (no endpoint if ps=0) + indices: indices of the images to draw the matches on. + a: alpha opacity of the match lines. + """ + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + ax0, ax1 = ax[indices[0]], ax[indices[1]] + fig.canvas.draw() + + assert len(kpts0) == len(kpts1) + if color is None: + color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() + elif len(color) > 0 and not isinstance(color[0], (tuple, list)): + color = [color] * len(kpts0) + + if lw > 0: + # transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(ax0.transData.transform(kpts0)) + fkpts1 = transFigure.transform(ax1.transData.transform(kpts1)) + fig.lines += [ + matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + zorder=1, + transform=fig.transFigure, + c=color[i], + linewidth=lw, + alpha=a, + ) + for i in range(len(kpts0)) + ] + + # freeze the axes to prevent the transform to change + ax0.autoscale(enable=False) + ax1.autoscale(enable=False) + + if ps > 0: + ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) + ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) + + +def add_text( + idx, + text, + pos=(0.01, 0.99), + fs=15, + color="w", + lcolor="k", + lwidth=2, + ha="left", + va="top", + normalized=True, + zorder=3, +): + ax = plt.gcf().axes[idx] + tfm = ax.transAxes if normalized else ax.transData + t = ax.text( + *pos, + text, + fontsize=fs, + ha=ha, + va=va, + color=color, + transform=tfm, + clip_on=True, + zorder=zorder, + ) + if lcolor is not None: + t.set_path_effects( + [ + path_effects.Stroke(linewidth=lwidth, foreground=lcolor), + path_effects.Normal(), + ] + ) + + +def save_plot(path, **kw): + """Save the current figure without any white margin.""" + plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw) + + +def features_to_RGB(*Fs, masks=None, skip=1): + """Project a list of d-dimensional feature maps to RGB colors using PCA.""" + from sklearn.decomposition import PCA + + def normalize(x): + return x / np.linalg.norm(x, axis=-1, 
keepdims=True) + + if masks is not None: + assert len(Fs) == len(masks) + + flatten = [] + for i, F in enumerate(Fs): + c, h, w = F.shape + F = np.rollaxis(F, 0, 3) + F_flat = F.reshape(-1, c) + if masks is not None and masks[i] is not None: + mask = masks[i] + assert mask.shape == F.shape[:2] + F_flat = F_flat[mask.reshape(-1)] + flatten.append(F_flat) + flatten = np.concatenate(flatten, axis=0) + flatten = normalize(flatten) + + pca = PCA(n_components=3) + if skip > 1: + pca.fit(flatten[::skip]) + flatten = pca.transform(flatten) + else: + flatten = pca.fit_transform(flatten) + flatten = (normalize(flatten) + 1) / 2 + + Fs_rgb = [] + for i, F in enumerate(Fs): + h, w = F.shape[-2:] + if masks is None or masks[i] is None: + F_rgb, flatten = np.split(flatten, [h * w], axis=0) + F_rgb = F_rgb.reshape((h, w, 3)) + else: + F_rgb = np.zeros((h, w, 3)) + indices = np.where(masks[i]) + F_rgb[indices], flatten = np.split(flatten, [len(indices[0])], axis=0) + F_rgb = np.concatenate([F_rgb, masks[i][..., None]], axis=-1) + Fs_rgb.append(F_rgb) + assert flatten.shape[0] == 0, flatten.shape + return Fs_rgb diff --git a/utils/viz_localization.py b/utils/viz_localization.py new file mode 100644 index 0000000000000000000000000000000000000000..3e15da2329ae2956c4eb663dfcd3a0cf841f2b84 --- /dev/null +++ b/utils/viz_localization.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import copy + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import torch + + +def likelihood_overlay( + prob, map_viz=None, p_rgb=0.2, p_alpha=1 / 15, thresh=None, cmap="jet" +): + prob = prob / prob.max() + cmap = plt.get_cmap(cmap) + rgb = cmap(prob**p_rgb) + alpha = prob[..., None] ** p_alpha + if thresh is not None: + alpha[prob <= thresh] = 0 + if map_viz is not None: + faded = map_viz + (1 - map_viz) * 0.5 + rgb = rgb[..., :3] * alpha + faded * (1 - alpha) + rgb = np.clip(rgb, 0, 1) + else: + rgb[..., -1] = alpha.squeeze(-1) + return rgb + + +def heatmap2rgb(scores, mask=None, clip_min=0.05, alpha=0.8, cmap="jet"): + min_, max_ = np.quantile(scores, [clip_min, 1]) + scores = scores.clip(min=min_) + rgb = plt.get_cmap(cmap)((scores - min_) / (max_ - min_)) + if mask is not None: + if alpha == 0: + rgb[mask] = np.nan + else: + rgb[..., -1] = 1 - (1 - 1.0 * mask) * (1 - alpha) + return rgb + + +def plot_pose(axs, xy, yaw=None, s=1 / 35, c="r", a=1, w=0.015, dot=True, zorder=10): + if yaw is not None: + yaw = np.deg2rad(yaw) + uv = np.array([np.sin(yaw), -np.cos(yaw)]) + xy = np.array(xy) + 0.5 + if not isinstance(axs, list): + axs = [axs] + for ax in axs: + if isinstance(ax, int): + ax = plt.gcf().axes[ax] + if dot: + ax.scatter(*xy, c=c, s=70, zorder=zorder, linewidths=0, alpha=a) + if yaw is not None: + ax.quiver( + *xy, + *uv, + scale=s, + scale_units="xy", + angles="xy", + color=c, + zorder=zorder, + alpha=a, + width=w, + ) + + +def plot_dense_rotations( + ax, prob, thresh=0.01, skip=10, s=1 / 15, k=3, c="k", w=None, **kwargs +): + t = torch.argmax(prob, -1) + yaws = t.numpy() / prob.shape[-1] * 360 + prob = prob.max(-1).values / prob.max() + mask = prob > thresh + masked = prob.masked_fill(~mask, 0) + max_ = torch.nn.functional.max_pool2d( + masked.float()[None, None], k, stride=1, padding=k // 2 + ) + mask = (max_[0, 0] == masked.float()) & mask + indices = np.where(mask.numpy() > 0) + plot_pose( + ax, + indices[::-1], + yaws[indices], + s=s, + c=c, + dot=False, + zorder=0.1, + w=w, + **kwargs, + ) + + +def copy_image(im, ax): + prop = im.properties() + 
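+    # properties() returns every Artist property; drop the read-only and
+    # figure-bound entries so the remaining dict can be passed back to
+    # Axes.imshow as keyword arguments.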
prop.pop("children") + prop.pop("size") + prop.pop("tightbbox") + prop.pop("transformed_clip_path_and_affine") + prop.pop("window_extent") + prop.pop("figure") + prop.pop("transform") + return ax.imshow(im.get_array(), **prop) + + +def add_circle_inset( + ax, + center, + corner=None, + radius_px=10, + inset_size=0.4, + inset_offset=0.005, + color="red", +): + data_t_axes = ax.transAxes + ax.transData.inverted() + if corner is None: + center_axes = np.array(data_t_axes.inverted().transform(center)) + corner = 1 - np.round(center_axes).astype(int) + corner = np.array(corner) + bottom_left = corner * (1 - inset_size - inset_offset) + (1 - corner) * inset_offset + axins = ax.inset_axes([*bottom_left, inset_size, inset_size]) + if ax.yaxis_inverted(): + axins.invert_yaxis() + axins.set_axis_off() + + c = mpl.patches.Circle(center, radius_px, fill=False, color=color) + c1 = mpl.patches.Circle(center, radius_px, fill=False, color=color) + # ax.add_patch(c) + ax.add_patch(c1) + # ax.add_patch(c.frozen()) + axins.add_patch(c) + + radius_inset = radius_px + 1 + axins.set_xlim([center[0] - radius_inset, center[0] + radius_inset]) + ylim = center[1] - radius_inset, center[1] + radius_inset + if axins.yaxis_inverted(): + ylim = ylim[::-1] + axins.set_ylim(ylim) + + for im in ax.images: + im2 = copy_image(im, axins) + im2.set_clip_path(c) + return axins + + +def plot_bev(bev, uv, yaw, ax=None, zorder=10, **kwargs): + if ax is None: + ax = plt.gca() + h, w = bev.shape[:2] + tfm = mpl.transforms.Affine2D().translate(-w / 2, -h) + tfm = tfm.rotate_deg(yaw).translate(*uv + 0.5) + tfm += plt.gca().transData + ax.imshow(bev, transform=tfm, zorder=zorder, **kwargs) + ax.plot( + [0, w - 1, w / 2, 0], + [0, 0, h - 0.5, 0], + transform=tfm, + c="k", + lw=1, + zorder=zorder + 1, + ) diff --git a/utils/wrappers.py b/utils/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..4d8b35a35fdc83569e943735e7433ed72da1343e --- /dev/null +++ b/utils/wrappers.py @@ -0,0 +1,342 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich +# https://github.com/cvg/pixloc +# Released under the Apache License 2.0 + +""" +Convenience classes for an SE3 pose and a pinhole Camera with lens distortion. +Based on PyTorch tensors: differentiable, batched, with GPU support. +""" + +import functools +import inspect +import math +from typing import Dict, List, NamedTuple, Tuple, Union + +import numpy as np +import torch + +from .geometry import undistort_points + + +def autocast(func): + """Cast the inputs of a TensorWrapper method to PyTorch tensors + if they are numpy arrays. Use the device and dtype of the wrapper. 
+ """ + + @functools.wraps(func) + def wrap(self, *args): + device = torch.device("cpu") + dtype = None + if isinstance(self, TensorWrapper): + if self._data is not None: + device = self.device + dtype = self.dtype + elif not inspect.isclass(self) or not issubclass(self, TensorWrapper): + raise ValueError(self) + + cast_args = [] + for arg in args: + if isinstance(arg, np.ndarray): + arg = torch.from_numpy(arg) + arg = arg.to(device=device, dtype=dtype) + cast_args.append(arg) + return func(self, *cast_args) + + return wrap + + +class TensorWrapper: + _data = None + + @autocast + def __init__(self, data: torch.Tensor): + self._data = data + + @property + def shape(self): + return self._data.shape[:-1] + + @property + def device(self): + return self._data.device + + @property + def dtype(self): + return self._data.dtype + + def __getitem__(self, index): + return self.__class__(self._data[index]) + + def __setitem__(self, index, item): + self._data[index] = item.data + + def to(self, *args, **kwargs): + return self.__class__(self._data.to(*args, **kwargs)) + + def cpu(self): + return self.__class__(self._data.cpu()) + + def cuda(self): + return self.__class__(self._data.cuda()) + + def pin_memory(self): + return self.__class__(self._data.pin_memory()) + + def float(self): + return self.__class__(self._data.float()) + + def double(self): + return self.__class__(self._data.double()) + + def detach(self): + return self.__class__(self._data.detach()) + + @classmethod + def stack(cls, objects: List, dim=0, *, out=None): + data = torch.stack([obj._data for obj in objects], dim=dim, out=out) + return cls(data) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func is torch.stack: + return cls.stack(*args, **kwargs) + else: + return NotImplemented + + +class Pose(TensorWrapper): + def __init__(self, data: torch.Tensor): + assert data.shape[-1] == 12 + super().__init__(data) + + @classmethod + @autocast + def from_Rt(cls, R: torch.Tensor, t: torch.Tensor): + """Pose from a rotation matrix and translation vector. + Accepts numpy arrays or PyTorch tensors. + + Args: + R: rotation matrix with shape (..., 3, 3). + t: translation vector with shape (..., 3). + """ + assert R.shape[-2:] == (3, 3) + assert t.shape[-1] == 3 + assert R.shape[:-2] == t.shape[:-1] + data = torch.cat([R.flatten(start_dim=-2), t], -1) + return cls(data) + + @classmethod + def from_4x4mat(cls, T: torch.Tensor): + """Pose from an SE(3) transformation matrix. + Args: + T: transformation matrix with shape (..., 4, 4). 
+ """ + assert T.shape[-2:] == (4, 4) + R, t = T[..., :3, :3], T[..., :3, 3] + return cls.from_Rt(R, t) + + @classmethod + def from_colmap(cls, image: NamedTuple): + """Pose from a COLMAP Image.""" + return cls.from_Rt(image.qvec2rotmat(), image.tvec) + + @property + def R(self) -> torch.Tensor: + """Underlying rotation matrix with shape (..., 3, 3).""" + rvec = self._data[..., :9] + return rvec.reshape(rvec.shape[:-1] + (3, 3)) + + @property + def t(self) -> torch.Tensor: + """Underlying translation vector with shape (..., 3).""" + return self._data[..., -3:] + + def inv(self) -> "Pose": + """Invert an SE(3) pose.""" + R = self.R.transpose(-1, -2) + t = -(R @ self.t.unsqueeze(-1)).squeeze(-1) + return self.__class__.from_Rt(R, t) + + def compose(self, other: "Pose") -> "Pose": + """Chain two SE(3) poses: T_B2C.compose(T_A2B) -> T_A2C.""" + R = self.R @ other.R + t = self.t + (self.R @ other.t.unsqueeze(-1)).squeeze(-1) + return self.__class__.from_Rt(R, t) + + @autocast + def transform(self, p3d: torch.Tensor) -> torch.Tensor: + """Transform a set of 3D points. + Args: + p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3). + """ + assert p3d.shape[-1] == 3 + # assert p3d.shape[:-2] == self.shape # allow broadcasting + return p3d @ self.R.transpose(-1, -2) + self.t.unsqueeze(-2) + + def __matmul__( + self, other: Union["Pose", torch.Tensor] + ) -> Union["Pose", torch.Tensor]: + """Transform a set of 3D points: T_A2B * p3D_A -> p3D_B. + or chain two SE(3) poses: T_B2C @ T_A2B -> T_A2C.""" + if isinstance(other, self.__class__): + return self.compose(other) + else: + return self.transform(other) + + def numpy(self) -> Tuple[np.ndarray]: + return self.R.numpy(), self.t.numpy() + + def magnitude(self) -> Tuple[torch.Tensor]: + """Magnitude of the SE(3) transformation. + Returns: + dr: rotation anngle in degrees. + dt: translation distance in meters. + """ + trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1) + cos = torch.clamp((trace - 1) / 2, -1, 1) + dr = torch.acos(cos).abs() / math.pi * 180 + dt = torch.norm(self.t, dim=-1) + return dr, dt + + def __repr__(self): + return f"Pose: {self.shape} {self.dtype} {self.device}" + + +class Camera(TensorWrapper): + eps = 1e-4 + + def __init__(self, data: torch.Tensor): + assert data.shape[-1] in {6, 8, 10} + super().__init__(data) + + @classmethod + def from_dict(cls, camera: Union[Dict, NamedTuple]): + """Camera from a COLMAP Camera tuple or dictionary. + We assume that the origin (0, 0) is the center of the top-left pixel. + This is different from COLMAP. 
+ """ + if isinstance(camera, tuple): + camera = camera._asdict() + + model = camera["model"] + params = camera["params"] + + if model in ["OPENCV", "PINHOLE"]: + (fx, fy, cx, cy), params = np.split(params, [4]) + elif model in ["SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"]: + (f, cx, cy), params = np.split(params, [3]) + fx = fy = f + if model == "SIMPLE_RADIAL": + params = np.r_[params, 0.0] + else: + raise NotImplementedError(model) + + data = np.r_[ + camera["width"], camera["height"], fx, fy, cx - 0.5, cy - 0.5, params + ] + return cls(data) + + @property + def size(self) -> torch.Tensor: + """Size (width height) of the images, with shape (..., 2).""" + return self._data[..., :2] + + @property + def f(self) -> torch.Tensor: + """Focal lengths (fx, fy) with shape (..., 2).""" + return self._data[..., 2:4] + + @property + def c(self) -> torch.Tensor: + """Principal points (cx, cy) with shape (..., 2).""" + return self._data[..., 4:6] + + @property + def dist(self) -> torch.Tensor: + """Distortion parameters, with shape (..., {0, 2, 4}).""" + return self._data[..., 6:] + + def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]): + """Update the camera parameters after resizing an image.""" + if isinstance(scales, (int, float)): + scales = (scales, scales) + s = self._data.new_tensor(scales) + data = torch.cat( + [self.size * s, self.f * s, (self.c + 0.5) * s - 0.5, self.dist], -1 + ) + return self.__class__(data) + + def crop(self, left_top: Tuple[float], size: Tuple[int]): + """Update the camera parameters after cropping an image.""" + left_top = self._data.new_tensor(left_top) + size = self._data.new_tensor(size) + data = torch.cat([size, self.f, self.c - left_top, self.dist], -1) + return self.__class__(data) + + @autocast + def in_image(self, p2d: torch.Tensor): + """Check if 2D points are within the image boundaries.""" + assert p2d.shape[-1] == 2 + # assert p2d.shape[:-2] == self.shape # allow broadcasting + size = self.size.unsqueeze(-2) + valid = torch.all((p2d >= 0) & (p2d <= (size - 1)), -1) + return valid + + @autocast + def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]: + """Project 3D points into the camera plane and check for visibility.""" + z = p3d[..., -1] + valid = z > self.eps + z = z.clamp(min=self.eps) + p2d = p3d[..., :-1] / z.unsqueeze(-1) + return p2d, valid + + def J_project(self, p3d: torch.Tensor): + x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2] + zero = torch.zeros_like(z) + J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1) + J = J.reshape(p3d.shape[:-1] + (2, 3)) + return J # N x 2 x 3 + + @autocast + def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]: + """Undistort normalized 2D coordinates + and check for validity of the distortion model. 
+ """ + assert pts.shape[-1] == 2 + # assert pts.shape[:-2] == self.shape # allow broadcasting + return undistort_points(pts, self.dist) + + @autocast + def denormalize(self, p2d: torch.Tensor) -> torch.Tensor: + """Convert normalized 2D coordinates into pixel coordinates.""" + return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2) + + @autocast + def normalize(self, p2d: torch.Tensor) -> torch.Tensor: + """Convert pixel coordinates into normalized 2D coordinates.""" + return (p2d - self.c.unsqueeze(-2)) / self.f.unsqueeze(-2) + + def J_denormalize(self): + return torch.diag_embed(self.f).unsqueeze(-3) # 1 x 2 x 2 + + @autocast + def world2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]: + """Transform 3D points into 2D pixel coordinates.""" + p2d, visible = self.project(p3d) + p2d, mask = self.undistort(p2d) + p2d = self.denormalize(p2d) + valid = visible & mask & self.in_image(p2d) + return p2d, valid + + def J_world2image(self, p3d: torch.Tensor): + p2d_dist, valid = self.project(p3d) + J = self.J_denormalize() @ self.J_undistort(p2d_dist) @ self.J_project(p3d) + return J, valid + + def __repr__(self): + return f"Camera {self.shape} {self.dtype} {self.device}"