import torch
import torch.nn as nn

from .base import BaseModel
from .feature_extractor import FeatureExtractor


class MapEncoder(BaseModel):
    # "???" marks required values that the user config must provide
    # (the OmegaConf mandatory-value marker).
    default_conf = {
        "embedding_dim": "???",
        "output_dim": None,
        "num_classes": "???",
        "backbone": "???",
        "unary_prior": False,
    }

    def _init(self, conf):
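        # One embedding table per map layer; the extra row (n + 1) leaves
        # room for an unlabeled/void index in the rasterized map alongside
        # the n labeled classes.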
        self.embeddings = nn.ModuleDict(
            {
                k: nn.Embedding(n + 1, conf.embedding_dim)
                for k, n in conf.num_classes.items()
            }
        )

        # The per-layer embeddings are concatenated channel-wise, so the
        # encoder input has len(num_classes) * embedding_dim channels.
        input_dim = len(conf.num_classes) * conf.embedding_dim
        output_dim = conf.output_dim
        if output_dim is None:
            # Fall back to the backbone's native output dimension.
            output_dim = conf.backbone.output_dim
        if conf.unary_prior:
            # Reserve one extra channel for the unary log-prior.
            output_dim += 1
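        # Pick the encoder head: a 1x1 projection (no backbone), a small
        # three-layer conv stack ("simple"), or a full FeatureExtractor
        # configured from conf.backbone.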
        if conf.backbone is None:
            self.encoder = nn.Conv2d(input_dim, output_dim, 1)
        elif conf.backbone == "simple":
            self.encoder = nn.Sequential(
                nn.Conv2d(input_dim, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, output_dim, 3, padding=1),
            )
        else:
            self.encoder = FeatureExtractor(
                {
                    **conf.backbone,
                    "input_dim": input_dim,
                    "output_dim": output_dim,
                }
            )

    def _forward(self, data):
        # data["map"] holds one integer class index per cell and per layer,
        # shaped (batch, layer, height, width) with layers in the order below.
        embeddings = [
            self.embeddings[k](data["map"][:, i])
            for i, k in enumerate(("areas", "ways", "nodes"))
        ]
        # Each lookup returns (B, H, W, D); concatenate along the last dim
        # and move channels first for the convolutional encoder.
        embeddings = torch.cat(embeddings, dim=-1).permute(0, 3, 1, 2)
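        # A FeatureExtractor (a BaseModel) returns a pyramid of feature maps;
        # the plain convolutional heads return a single tensor, wrapped in a
        # one-element list so both cases look the same downstream.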
        if isinstance(self.encoder, BaseModel):
            features = self.encoder({"image": embeddings})["feature_maps"]
        else:
            features = [self.encoder(embeddings)]
        pred = {}
        if self.conf.unary_prior:
            # The last channel of each feature map is the unary log-prior;
            # split it off from the features proper.
            pred["log_prior"] = [f[:, -1] for f in features]
            features = [f[:, :-1] for f in features]
        pred["map_features"] = features
        return pred
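

# Minimal smoke test of the backbone=None path. This is a sketch: the config
# values are illustrative, and it assumes the usual BaseModel conventions
# (a plain dict merged into default_conf, __call__ dispatching to _forward).
# Run as a module so the relative imports resolve.
if __name__ == "__main__":
    encoder = MapEncoder(
        {
            "embedding_dim": 16,
            "num_classes": {"areas": 7, "ways": 10, "nodes": 33},
            "output_dim": 8,
            "backbone": None,  # use the 1x1 convolution head
            "unary_prior": True,
        }
    )
    # Random rasterized map: indices 0-7 are valid for every layer since the
    # smallest table ("areas") has 7 + 1 rows.
    rasters = torch.randint(0, 8, (1, 3, 64, 64))
    pred = encoder({"map": rasters})
    print(pred["map_features"][0].shape)  # torch.Size([1, 8, 64, 64])
    print(pred["log_prior"][0].shape)  # torch.Size([1, 64, 64])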