# MapLocNet/models/map_encoder.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
import torch
import torch.nn as nn
from .base import BaseModel
from .feature_extractor import FeatureExtractor


class MapEncoder(BaseModel):
    """Encode a rasterized semantic map into dense feature maps.

    Each class group (areas, ways, nodes) gets its own embedding table;
    the concatenated embeddings are refined by a configurable encoder.
    """

    default_conf = {
        "embedding_dim": "???",  # "???" marks a required value (OmegaConf)
        "output_dim": None,  # if None, taken from the backbone conf
        "num_classes": "???",  # e.g. {"areas": 7, "ways": 10, "nodes": 33}
        "backbone": "???",  # None, "simple", or a FeatureExtractor conf
        "unary_prior": False,  # if True, add a log-prior output channel
    }
    def _init(self, conf):
        # One embedding table per class group; n + 1 rows leave an extra
        # index for cells without a class (e.g. background/void).
        self.embeddings = nn.ModuleDict(
            {
                k: nn.Embedding(n + 1, conf.embedding_dim)
                for k, n in conf.num_classes.items()
            }
        )
        # num_classes example: {'areas': 7, 'ways': 10, 'nodes': 33}
        input_dim = len(conf.num_classes) * conf.embedding_dim
        output_dim = conf.output_dim
        if output_dim is None:
            # Only valid for a dict-style backbone conf; with backbone
            # None or "simple", output_dim must be set explicitly.
            output_dim = conf.backbone.output_dim
        if conf.unary_prior:
            # Reserve one extra channel for the per-cell log-prior.
            output_dim += 1
        if conf.backbone is None:
            # Pointwise (1x1) projection of the concatenated embeddings.
            self.encoder = nn.Conv2d(input_dim, output_dim, 1)
        elif conf.backbone == "simple":
            # Shallow fully-convolutional encoder at a single scale.
            self.encoder = nn.Sequential(
                nn.Conv2d(input_dim, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, output_dim, 3, padding=1),
            )
        else:
            # Full multi-scale backbone configured by conf.backbone.
            self.encoder = FeatureExtractor(
                {
                    **conf.backbone,
                    "input_dim": input_dim,
                    "output_dim": output_dim,
                }
            )
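
    # Assumed contract (FeatureExtractor is not defined in this file): it is
    # a BaseModel whose forward takes {"image": (B, C, H, W)} and returns
    # {"feature_maps": [tensor, ...]}, one tensor per scale. The plain
    # nn.Module encoders above return a single tensor, which _forward wraps
    # in a one-element list to match that interface.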
    def _forward(self, data):
        # data["map"]: integer raster of shape (B, 3, H, W), one channel
        # per class group, in the fixed order (areas, ways, nodes).
        embeddings = [
            self.embeddings[k](data["map"][:, i])
            for i, k in enumerate(("areas", "ways", "nodes"))
        ]
        # Concatenate the group embeddings along the channel axis and go
        # from (B, H, W, C) back to (B, C, H, W) for the convolutions.
        embeddings = torch.cat(embeddings, dim=-1).permute(0, 3, 1, 2)
        if isinstance(self.encoder, BaseModel):
            features = self.encoder({"image": embeddings})["feature_maps"]
        else:
            features = [self.encoder(embeddings)]
        pred = {}
        if self.conf.unary_prior:
            # Split off the last channel of each scale as the unary log-prior.
            pred["log_prior"] = [f[:, -1] for f in features]
            features = [f[:, :-1] for f in features]
        pred["map_features"] = features
        return pred
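

# Minimal usage sketch, under the assumption that BaseModel merges
# `default_conf` with the dict passed to the constructor and dispatches
# __call__ to _forward (OmegaConf-style model base). All values below are
# illustrative, not part of the original module.
if __name__ == "__main__":
    encoder = MapEncoder(
        {
            "embedding_dim": 16,
            "output_dim": 8,
            "num_classes": {"areas": 7, "ways": 10, "nodes": 33},
            "backbone": "simple",
            "unary_prior": True,
        }
    )
    # A random map raster: indices in [0, 7] are valid for all three groups.
    rasters = torch.randint(0, 8, (2, 3, 64, 64))
    pred = encoder({"map": rasters})
    print([f.shape for f in pred["map_features"]])  # expect [(2, 8, 64, 64)]
    print([p.shape for p in pred["log_prior"]])  # expect [(2, 64, 64)]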