Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
import torch | |
import torch.nn as nn | |
from .base import BaseModel | |
from .feature_extractor import FeatureExtractor | |
import numpy as np | |
# from .embeddings import AttentionWeightedEmbedding, | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.nn.functional as F | |
class ImprovedAttentionEmbedding(nn.Module):
    """Embedding lookup scaled by a learnable, normalized per-index weight.

    Each index owns an embedding vector and a ``weight_dim``-sized weight row.
    On the forward pass the weight row is layer-normalized, passed through a
    softmax over its last axis, multiplied into the embedding, and the result
    goes through dropout.

    NOTE(review): with the default ``weight_dim == 1`` the softmax runs over a
    single element and therefore always yields 1, so the learned weights do
    not change the output -- confirm whether this is intended.

    Args:
        num_embeddings: Number of rows in the embedding and weight tables.
        embedding_dim: Size of each embedding vector.
        weight_dim: Number of weight values stored per index.
        dropout: Dropout probability applied to the weighted embeddings.
        weight_init: ``'normal'`` (standard normal), ``'uniform'`` (uniform on
            [0, 1)); any other value initializes the weights to ones.
    """

    def __init__(self, num_embeddings, embedding_dim, weight_dim=1, dropout=0.1, weight_init='normal'):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.weight_dim = weight_dim

        # Learnable weight table of shape [num_embeddings, weight_dim].
        table_shape = (num_embeddings, weight_dim)
        if weight_init == 'normal':
            initial_weights = torch.randn(*table_shape)
        elif weight_init == 'uniform':
            initial_weights = torch.rand(*table_shape)
        else:
            initial_weights = torch.ones(*table_shape)
        self.weights = nn.Parameter(initial_weights)

        self.weight_norm = nn.LayerNorm(weight_dim)
        self.dropout = nn.Dropout(dropout)

        # Coefficient of the L2 penalty returned by ``get_l2_reg``.
        self.l2_reg = 1e-5

    def forward(self, input):
        """Return the weighted embeddings for the index tensor ``input``.

        The result has the shape of ``input`` plus a trailing ``embedding_dim``
        axis when ``weight_dim == 1``.
        """
        vectors = self.embedding(input)  # [*input.shape, embedding_dim]

        # Look up the per-index weights, then normalize them.
        gate = self.weights[input]  # [*input.shape, weight_dim]
        gate = F.softmax(self.weight_norm(gate), dim=-1)

        if self.weight_dim != 1:
            # NOTE(review): this makes the gate [*input.shape, weight_dim, 1],
            # which broadcasts weight_dim against the last *index* axis of
            # ``vectors`` rather than the embedding axis -- confirm intent.
            gate = gate.unsqueeze(-1)

        return self.dropout(vectors * gate)

    def get_l2_reg(self):
        """Return ``l2_reg`` times the sum of squared entries of the weight table."""
        squared_norm = self.weights.pow(2).sum()
        return self.l2_reg * squared_norm
class AttentionWeightedEmbedding(nn.Module):
    """Embedding lookup re-weighted by attention against a learned query vector.

    Every embedding is scored by its dot product with a trainable query; the
    scores are softmax-normalized over the last axis of the index tensor and
    used to scale the corresponding embeddings.

    Args:
        num_embeddings: Number of rows in the embedding table.
        embedding_dim: Size of each embedding vector (and of the query).
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Trainable query vector used to score each embedding.
        self.query = nn.Parameter(torch.randn(embedding_dim))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        """Return embeddings of ``input`` scaled by their attention weights.

        The output has shape ``[*input.shape, embedding_dim]``.
        """
        vectors = self.embedding(input)  # [*input.shape, embedding_dim]

        # Dot product with the query gives one score per index position.
        scores = vectors @ self.query  # [*input.shape]

        # Normalize along the last index axis only. NOTE(review): for inputs
        # with more than one positional axis (e.g. [batch, H, W]) this
        # normalizes each row independently -- confirm that is intended.
        attention = self.softmax(scores)

        # Broadcast the weights over the embedding axis.
        return vectors * attention.unsqueeze(-1)
class WeightedEmbedding(nn.Module):
    """Embedding lookup scaled by one learnable scalar per index.

    Args:
        num_embeddings: Number of rows in the embedding and weight tables.
        embedding_dim: Size of each embedding vector.
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Learnable scale table of shape [num_embeddings, 1], initialized to
        # ones so the module starts out as a plain embedding lookup.
        self.weights = nn.Parameter(torch.ones(num_embeddings, 1))

    def forward(self, input):
        """Return ``embedding(input)`` multiplied by the per-index scale.

        The output has shape ``[*input.shape, embedding_dim]``.
        """
        vectors = self.embedding(input)  # [*input.shape, embedding_dim]
        # The trailing singleton axis of the scale broadcasts over embedding_dim.
        scale = self.weights[input]  # [*input.shape, 1]
        return vectors * scale
class MapEncoderSingle(BaseModel):
    """Encode a multi-channel index map through a single shared embedding table.

    The channels of ``data["map"]`` are first merged into one index map by
    ``batch_process``. The merged indices are embedded with the module stored
    under the ``"all"`` key of ``self.embeddings`` and the result is passed
    through the configured encoder.

    Configuration keys (see ``default_conf``):
        embedding_dim: Size of each embedding vector.
        output_dim: Encoder output channels; when ``None`` it is read from
            ``conf.backbone.output_dim``.
        num_classes: Mapping ``name -> class count``; one embedding module with
            ``count + 1`` rows is created per entry. ``_forward`` looks up the
            ``"all"`` entry, so that key must be present.
        backbone: ``None`` for a 1x1 convolution, ``"simple"`` for a small
            three-layer CNN, otherwise a config forwarded to ``FeatureExtractor``.
        unary_prior: When true, one extra output channel is produced and
            returned separately as ``"log_prior"``.
        weighted_embedding: ``False`` for a plain ``nn.Embedding``, or one of
            ``"AttentionWeightedEmbedding"``, ``"WeightedEmbedding"``,
            ``"ImprovedAttentionEmbedding"``.
    """

    default_conf = {
        "embedding_dim": "???",
        "output_dim": None,
        "num_classes": "???",
        "backbone": "???",
        "unary_prior": False,
        "weighted_embedding": False,
    }

    def _init(self, conf):
        """Build the embedding modules and the encoder described by ``conf``.

        Raises:
            ValueError: If ``conf.weighted_embedding`` is not a recognized value.
                Previously such a value was silently ignored, leaving
                ``self.embeddings`` undefined until the first forward pass.
        """
        weighted = conf.weighted_embedding
        # Equality (not identity/truthiness) is kept on purpose so the accepted
        # "disabled" values are exactly the ones the original comparison allowed.
        if weighted == False:  # noqa: E712
            embedding_cls = torch.nn.Embedding
        elif weighted == "AttentionWeightedEmbedding":
            embedding_cls = AttentionWeightedEmbedding
        elif weighted == "WeightedEmbedding":
            embedding_cls = WeightedEmbedding
        elif weighted == "ImprovedAttentionEmbedding":
            embedding_cls = ImprovedAttentionEmbedding
        else:
            raise ValueError(
                f"Unknown weighted_embedding: {weighted!r}"
            )

        # One extra row per table: index 0 is used as "no class" by
        # ``batch_process`` (zero entries are treated as empty).
        self.embeddings = torch.nn.ModuleDict(
            {
                k: embedding_cls(n + 1, conf.embedding_dim)
                for k, n in conf.num_classes.items()
            }
        )

        input_dim = len(conf.num_classes) * conf.embedding_dim
        output_dim = conf.output_dim
        if output_dim is None:
            output_dim = conf.backbone.output_dim
        if conf.unary_prior:
            # Reserve one additional channel for the prior.
            output_dim += 1

        if conf.backbone is None:
            self.encoder = nn.Conv2d(input_dim, output_dim, 1)
        elif conf.backbone == "simple":
            self.encoder = nn.Sequential(
                nn.Conv2d(input_dim, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, output_dim, 3, padding=1),
            )
        else:
            self.encoder = FeatureExtractor(
                {
                    **conf.backbone,
                    "input_dim": input_dim,
                    "output_dim": output_dim,
                }
            )

    def batch_process(self, input_tensor):
        """Merge the first three channels of an index map into one index map.

        Non-zero entries of channel 0 are shifted by 43 and non-zero entries of
        channel 1 by 33, so the three channels occupy disjoint index ranges
        (presumably areas / ways / nodes with 10 way and 33 node classes --
        TODO confirm against the dataset). Each output position then takes the
        first non-zero value in priority order channel 2, channel 1, channel 0,
        and is 0 when all three are zero.

        The input tensor is NOT modified. The previous implementation added the
        offsets in place, which corrupted ``data["map"]`` for the caller and
        shifted the indices again whenever the same batch was processed twice.

        Args:
            input_tensor: Integer tensor laid out as
                ``[batch, channels (>= 3), H, W]``.

        Returns:
            A new tensor of shape ``[batch, H, W]`` with the dtype of the input.
        """
        ch0 = input_tensor[:, 0]
        ch1 = input_tensor[:, 1]
        ch2 = input_tensor[:, 2]

        # Shift only the non-zero entries; zeros keep meaning "empty".
        ch0 = torch.where(ch0 != 0, ch0 + 43, ch0)
        ch1 = torch.where(ch1 != 0, ch1 + 33, ch1)

        # Priority: channel 2, then channel 1, then channel 0.
        return torch.where(ch2 != 0, ch2, torch.where(ch1 != 0, ch1, ch0))

    def _forward(self, data):
        """Embed and encode ``data["map"]``.

        Returns:
            A dict with ``"map_features"`` (list of feature tensors) and, when
            ``unary_prior`` is enabled, ``"log_prior"`` (the last channel of
            each feature tensor, which is then removed from the features).
        """
        merged = self.batch_process(data["map"])

        # [batch, H, W, embedding_dim] -> [batch, embedding_dim, H, W]
        embeddings = self.embeddings["all"](merged)
        embeddings = embeddings.permute(0, 3, 1, 2)

        if isinstance(self.encoder, BaseModel):
            features = self.encoder({"image": embeddings})["feature_maps"]
        else:
            features = [self.encoder(embeddings)]

        pred = {}
        if self.conf.unary_prior:
            pred["log_prior"] = [f[:, -1] for f in features]
            features = [f[:, :-1] for f in features]
        pred["map_features"] = features
        return pred