# Origin: wangerniu/maplocnet @ commit 629144d (2.99 kB)
# Copyright (c) Meta Platforms, Inc. and affiliates.
import argparse
from pathlib import Path
from typing import Optional, Tuple
from omegaconf import OmegaConf, DictConfig
from .. import logger
from ..conf import data as conf_data_dir
from ..data import MapillaryDataModule
from .run import evaluate
# Per-split configuration overrides, keyed by split name (the name `run`
# receives as `split`). "scenes" is presumably a data-config option naming
# the Mapillary scenes that make up the split — confirm against mapillary.yaml.
split_overrides = dict(
    val=dict(
        scenes=[
            "sanfrancisco_soma",
            "sanfrancisco_hayes",
            "amsterdam",
            "berlin",
            "lemans",
            "montrouge",
            "toulouse",
            "nantes",
            "vilnius",
            "avignon",
            "helsinki",
            "milan",
            "paris",
        ],
    ),
)
# Base data configuration: the `mapillary.yaml` file that sits in the same
# directory as the `conf.data` package.
data_cfg_train = OmegaConf.load(
    Path(conf_data_dir.__file__).parent / "mapillary.yaml"
)

# Evaluation variant of the data config: request GPS and the map mask, cap
# the initial error at 32, and load the "val" split with batch size 1 and
# zero workers.
data_cfg = OmegaConf.merge(
    data_cfg_train,
    dict(
        return_gps=True,
        add_map_mask=True,
        max_init_error=32,
        loading=dict(val=dict(batch_size=1, num_workers=0)),
    ),
)

# Default top-level config for single-frame evaluation: only a `data` section.
default_cfg_single = OmegaConf.create(dict(data=data_cfg))

# Default config for sequential evaluation: the single-frame config plus a
# `chunking` section limiting sequences to 10 frames.
default_cfg_sequential = OmegaConf.create(
    dict(**default_cfg_single, chunking=dict(max_length=10))
)
def run(
    split: str,
    experiment: str,
    cfg: Optional[DictConfig] = None,
    sequential: bool = False,
    thresholds: Tuple[int, ...] = (1, 3, 5),
    **kwargs,
):
    """Evaluate an experiment on a Mapillary split and log recall metrics.

    Args:
        split: Split to evaluate; must be a key of ``split_overrides``.
        experiment: Experiment identifier forwarded to ``evaluate``.
        cfg: Optional user overrides (``DictConfig`` or plain ``dict``),
            merged on top of the default evaluation config.
        sequential: If True, start from ``default_cfg_sequential`` (which
            adds a ``chunking`` section) and also report sequence metrics.
        thresholds: Thresholds at which recall is computed and logged
            (reported as "m/°").
        **kwargs: Forwarded unchanged to ``evaluate``.

    Returns:
        The metrics mapping returned by ``evaluate``.

    Raises:
        KeyError: If ``split`` has no entry in ``split_overrides``.
    """
    cfg = cfg or {}
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)

    try:
        overrides = split_overrides[split]
    except KeyError:
        raise KeyError(
            f"Unknown split {split!r}; expected one of {sorted(split_overrides)}."
        ) from None

    default = default_cfg_sequential if sequential else default_cfg_single
    # The per-split overrides (e.g. "scenes") are data options, so they are
    # merged under "data": only `cfg.data` is handed to the data module
    # below, and a top-level "scenes" key would never reach it.
    default = OmegaConf.merge(default, {"data": overrides})
    # User-provided values take precedence over the defaults.
    cfg = OmegaConf.merge(default, cfg)

    dataset = MapillaryDataModule(cfg.get("data", {}))
    metrics = evaluate(experiment, cfg, dataset, split, sequential=sequential, **kwargs)

    keys = [
        "xy_max_error",
        "xy_gps_error",
        "yaw_max_error",
    ]
    if sequential:
        keys += [
            "xy_seq_error",
            "xy_gps_seq_error",
            "yaw_seq_error",
            "yaw_gps_seq_error",
        ]
    for k in keys:
        # Missing metrics are reported but do not abort the summary.
        if k not in metrics:
            logger.warning("Key %s not in metrics.", k)
            continue
        rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
        logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
    return metrics
if __name__ == "__main__":
    # Command-line entry point: evaluate one experiment on a Mapillary split.
    cli = argparse.ArgumentParser()
    cli.add_argument("--experiment", type=str, required=True)
    cli.add_argument("--split", type=str, default="val", choices=["val"])
    cli.add_argument("--sequential", action="store_true")
    cli.add_argument("--output_dir", type=Path)
    cli.add_argument("--num", type=int)
    # Any remaining positional tokens are OmegaConf dotlist overrides
    # of the form key.subkey=value.
    cli.add_argument("dotlist", nargs="*")
    opts = cli.parse_args()

    run(
        split=opts.split,
        experiment=opts.experiment,
        cfg=OmegaConf.from_cli(opts.dotlist),
        sequential=opts.sequential,
        # Extra options below are passed through `run` to `evaluate`.
        output_dir=opts.output_dir,
        num=opts.num,
    )