# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch
import torch.nn as nn

from .base import BaseModel
from .feature_extractor import FeatureExtractor


class MapEncoder(BaseModel):
    default_conf = {
        "embedding_dim": "???",  # "???" marks a required value
        "output_dim": None,
        "num_classes": "???",
        "backbone": "???",
        "unary_prior": False,
    }

    def _init(self, conf):
        # One embedding table per map layer, e.g. num_classes = {'areas': 7,
        # 'ways': 10, 'nodes': 33}. Each table has n + 1 entries to account
        # for the null class in addition to the n semantic classes.
        self.embeddings = torch.nn.ModuleDict(
            {
                k: torch.nn.Embedding(n + 1, conf.embedding_dim)
                for k, n in conf.num_classes.items()
            }
        )
        # Per-layer embeddings are concatenated along the channel dimension.
        input_dim = len(conf.num_classes) * conf.embedding_dim
        output_dim = conf.output_dim
        if output_dim is None:
            output_dim = conf.backbone.output_dim
        if conf.unary_prior:
            # Reserve one extra channel that is split off as a log-prior in _forward.
            output_dim += 1
        if conf.backbone is None:
            # Linear (1x1 convolution) projection of the embeddings.
            self.encoder = nn.Conv2d(input_dim, output_dim, 1)
        elif conf.backbone == "simple":
            # Small 3-layer convolutional head.
            self.encoder = nn.Sequential(
                nn.Conv2d(input_dim, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, output_dim, 3, padding=1),
            )
        else:
            # Full feature extractor configured by the backbone conf.
            self.encoder = FeatureExtractor(
                {
                    **conf.backbone,
                    "input_dim": input_dim,
                    "output_dim": output_dim,
                }
            )

    def _forward(self, data):
        # data["map"] holds one channel of class indices per layer,
        # in the order (areas, ways, nodes).
        embeddings = [
            self.embeddings[k](data["map"][:, i])
            for i, k in enumerate(("areas", "ways", "nodes"))
        ]
        # Concatenate the per-layer embeddings and move channels first:
        # (B, H, W, C) -> (B, C, H, W).
        embeddings = torch.cat(embeddings, dim=-1).permute(0, 3, 1, 2)
        if isinstance(self.encoder, BaseModel):
            # The feature extractor returns a list of feature maps.
            features = self.encoder({"image": embeddings})["feature_maps"]
        else:
            features = [self.encoder(embeddings)]
        pred = {}
        if self.conf.unary_prior:
            # Split off the last channel as a log-prior map.
            pred["log_prior"] = [f[:, -1] for f in features]
            features = [f[:, :-1] for f in features]
        pred["map_features"] = features
        return pred
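

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module). It assumes the
# usual BaseModel pattern: the constructor takes a config that is merged with
# `default_conf`, and calling the model dispatches to `_forward`. All concrete
# values below (embedding_dim=16, output_dim=8, the raster size) are
# illustrative assumptions, not values taken from the source.
#
#   encoder = MapEncoder(
#       {
#           "embedding_dim": 16,
#           "output_dim": 8,
#           "num_classes": {"areas": 7, "ways": 10, "nodes": 33},
#           "backbone": None,  # use the single 1x1 convolution head
#       }
#   )
#   rasters = torch.randint(0, 8, (1, 3, 128, 128))  # (B, layers, H, W) class indices
#   pred = encoder({"map": rasters})
#   pred["map_features"][0].shape  # -> torch.Size([1, 8, 128, 128])
# ---------------------------------------------------------------------------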