# Copyright (c) Meta Platforms, Inc. and affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import BaseModel
from .feature_extractor import FeatureExtractor


class ImprovedAttentionEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, weight_dim=1,
                 dropout=0.1, weight_init="normal"):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.weight_dim = weight_dim
        # Learnable weight matrix [num_embeddings, weight_dim].
        if weight_init == "normal":
            self.weights = nn.Parameter(torch.randn(num_embeddings, weight_dim))
        elif weight_init == "uniform":
            self.weights = nn.Parameter(torch.rand(num_embeddings, weight_dim))
        else:
            self.weights = nn.Parameter(torch.ones(num_embeddings, weight_dim))
        self.weight_norm = nn.LayerNorm(weight_dim)
        self.dropout = nn.Dropout(dropout)
        # L2 regularization strength for the weights.
        self.l2_reg = 1e-5

    def forward(self, input):
        embedded = self.embedding(input)  # [batch, H, W, embedding_dim]
        # Look up the per-class weights and normalize them.
        weight = self.weights[input]  # [batch, H, W, weight_dim]
        weight = self.weight_norm(weight)
        # Note: with weight_dim == 1, LayerNorm followed by softmax over a
        # singleton dimension yields a constant weight of 1.
        weight = F.softmax(weight, dim=-1)
        # Scale the embedding vectors by the normalized weights.
        if self.weight_dim == 1:
            weighted_embedded = embedded * weight  # [batch, H, W, embedding_dim]
        else:
            # Broadcasting here assumes weight_dim == embedding_dim.
            weighted_embedded = embedded * weight
        weighted_embedded = self.dropout(weighted_embedded)
        return weighted_embedded

    def get_l2_reg(self):
        return self.l2_reg * (self.weights ** 2).sum()


class AttentionWeightedEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Trainable query vector used to score each embedding.
        self.query = nn.Parameter(torch.randn(embedding_dim))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        # Look up the embedding vectors.
        embedded = self.embedding(input)  # [batch, seq_len, embedding_dim]
        # Attention scores: dot product of each embedding with the query.
        attn_scores = torch.matmul(embedded, self.query)  # [batch, seq_len]
        # Normalize the scores into attention weights.
        attn_weights = self.softmax(attn_scores).unsqueeze(-1)  # [batch, seq_len, 1]
        # Apply the weights to the embeddings.
        weighted_embedded = embedded * attn_weights  # [batch, seq_len, embedding_dim]
        return weighted_embedded


class WeightedEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Learnable per-class weight matrix [num_embeddings, 1].
        self.weights = nn.Parameter(torch.ones(num_embeddings, 1))

    def forward(self, input):
        embedded = self.embedding(input)  # [batch, H, W, embedding_dim]
        # Look up the weights; the trailing singleton dim broadcasts below.
        weight = self.weights[input]  # [batch, H, W, 1]
        # Element-wise scale each embedding vector by its class weight.
        weighted_embedded = embedded * weight  # [batch, H, W, embedding_dim]
        return weighted_embedded
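

# Illustrative sketch (not part of the original module): a minimal smoke test
# for the three weighted-embedding variants above. The table size, embedding
# dimension, and id-map shape are made-up values for demonstration only.
def _demo_weighted_embeddings():
    ids = torch.randint(0, 10, (2, 4, 4))  # fake [batch, H, W] class-id map
    for cls in (WeightedEmbedding, AttentionWeightedEmbedding,
                ImprovedAttentionEmbedding):
        out = cls(num_embeddings=10, embedding_dim=8)(ids)
        # Every variant preserves the layout [batch, H, W, embedding_dim].
        assert out.shape == (2, 4, 4, 8), cls.__name__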
class MapEncoderSingle(BaseModel):
    default_conf = {
        "embedding_dim": "???",
        "output_dim": None,
        "num_classes": "???",
        "backbone": "???",
        "unary_prior": False,
        "weighted_embedding": False,
    }

    def _init(self, conf):
        # num_classes: {'areas': 7, 'ways': 10, 'nodes': 33}
        if not conf.weighted_embedding:
            self.embeddings = torch.nn.ModuleDict(
                {
                    k: torch.nn.Embedding(n + 1, conf.embedding_dim)
                    for k, n in conf.num_classes.items()
                }
            )
        elif conf.weighted_embedding == "AttentionWeightedEmbedding":
            self.embeddings = torch.nn.ModuleDict(
                {
                    k: AttentionWeightedEmbedding(n + 1, conf.embedding_dim)
                    for k, n in conf.num_classes.items()
                }
            )
        elif conf.weighted_embedding == "WeightedEmbedding":
            self.embeddings = torch.nn.ModuleDict(
                {
                    k: WeightedEmbedding(n + 1, conf.embedding_dim)
                    for k, n in conf.num_classes.items()
                }
            )
        elif conf.weighted_embedding == "ImprovedAttentionEmbedding":
            self.embeddings = torch.nn.ModuleDict(
                {
                    k: ImprovedAttentionEmbedding(n + 1, conf.embedding_dim)
                    for k, n in conf.num_classes.items()
                }
            )
        else:
            raise ValueError(
                f"Unknown weighted_embedding: {conf.weighted_embedding}"
            )

        input_dim = len(conf.num_classes) * conf.embedding_dim
        output_dim = conf.output_dim
        if output_dim is None:
            output_dim = conf.backbone.output_dim
        if conf.unary_prior:
            output_dim += 1
        if conf.backbone is None:
            self.encoder = nn.Conv2d(input_dim, output_dim, 1)
        elif conf.backbone == "simple":
            self.encoder = nn.Sequential(
                nn.Conv2d(input_dim, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, output_dim, 3, padding=1),
            )
        else:
            self.encoder = FeatureExtractor(
                {
                    **conf.backbone,
                    "input_dim": input_dim,
                    "output_dim": output_dim,
                }
            )

    def batch_process(self, input_tensor):
        # Collapse the per-layer class maps [batch, 3, H, W] into a single
        # map [batch, H, W] by packing the three layers into disjoint id
        # ranges. Note: the offsets are applied to input_tensor in place.
        batch_size, dim1, dim2, dim3 = input_tensor.shape
        # Offset non-zero ids in channel 0 (areas) by 43 = 33 nodes + 10 ways.
        input_tensor[:, 0, :, :] += torch.where(input_tensor[:, 0, :, :] != 0, 43, 0)
        # Offset non-zero ids in channel 1 (ways) by 33 nodes.
        input_tensor[:, 1, :, :] += torch.where(input_tensor[:, 1, :, :] != 0, 33, 0)
        # Start from an all-zero output map.
        output_tensor = torch.zeros(
            (batch_size, dim2, dim3),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
        )
        # Positions where at least one channel is non-zero.
        nonzero_mask = torch.any(input_tensor != 0, dim=1)
        # Assign by priority: channel 2 first, then channel 1, then channel 0.
        output_tensor[nonzero_mask] = input_tensor[:, 2, :, :][nonzero_mask]
        output_tensor[nonzero_mask] = torch.where(
            input_tensor[:, 2, :, :][nonzero_mask] == 0,
            input_tensor[:, 1, :, :][nonzero_mask],
            output_tensor[nonzero_mask],
        )
        output_tensor[nonzero_mask] = torch.where(
            torch.logical_and(
                input_tensor[:, 2, :, :][nonzero_mask] == 0,
                input_tensor[:, 1, :, :][nonzero_mask] == 0,
            ),
            input_tensor[:, 0, :, :][nonzero_mask],
            output_tensor[nonzero_mask],
        )
        return output_tensor

    def _forward(self, data):
        # Collapse the three map layers into one id map and embed it with the
        # shared "all" embedding table.
        temp = data["map"]
        temp = self.batch_process(temp)
        embeddings = self.embeddings["all"](temp)  # [batch, H, W, embedding_dim]
        embeddings = embeddings.permute(0, 3, 1, 2)  # [batch, embedding_dim, H, W]
        if isinstance(self.encoder, BaseModel):
            features = self.encoder({"image": embeddings})["feature_maps"]
        else:
            features = [self.encoder(embeddings)]
        pred = {}
        if self.conf.unary_prior:
            pred["log_prior"] = [f[:, -1] for f in features]
            features = [f[:, :-1] for f in features]
        pred["map_features"] = features  # list of tensors [batch, output_dim, H, W]
        return pred
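

# Hedged smoke test (not part of the original file): exercises the embedding
# variants and the channel-collapse logic. batch_process reads no state from
# self, so it is called unbound here purely for illustration; the map shape
# and id range are made up. Because this module uses relative imports, the
# guard only runs when invoked as a module (e.g. `python -m <package>.<module>`,
# module path being an assumption).
if __name__ == "__main__":
    _demo_weighted_embeddings()
    maps = torch.randint(0, 8, (2, 3, 16, 16))  # fake [batch, 3, H, W] id maps
    # Clone because batch_process offsets its input in place.
    merged = MapEncoderSingle.batch_process(None, maps.clone())
    assert merged.shape == (2, 16, 16)
    # After the offsets, the three layers occupy disjoint id ranges,
    # so no merged id can exceed the largest area id plus its offset.
    assert merged.max() <= 7 + 43
    print("smoke tests passed")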