sglin committed on
Commit
9c396a5
·
1 Parent(s): 1a6cb78

Add initial model weights, utils.py, transformer_encoder_MoE.py, and README

Files changed (10)
  1. Arabidopsis.pt +3 -0
  2. CR.pt +3 -0
  3. EscherichiaColi.pt +3 -0
  4. PC.pt +3 -0
  5. README.md +153 -0
  6. TK.pt +3 -0
  7. homo_circ.pt +3 -0
  8. homo_mrna.pt +3 -0
  9. transformer_encoder_MoE.py +555 -0
  10. utils.py +844 -0
Arabidopsis.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:725088277142fa26aa1781806b107155363a903d0eba542bdcf312f5dbde48c8
3
+ size 534071434
CR.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b69bb49ff4a336cc7a4f47ffb438dba298d67b237f678403a43bee511c8c1928
3
+ size 534069804
EscherichiaColi.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6833a55fc82e708e8e5307dd63be0983504d768d32a1b1ec7a83b616bdf49671
3
+ size 534072130
PC.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe6d83b0824ced0f00c43189a58a9eaecd4acf02226dffb057b27166f08d8f1a
3
+ size 534069804
README.md ADDED
@@ -0,0 +1,153 @@
1
+ ---
2
+ tags:
3
+ - generation
4
+ - protein-sequence
5
+ - rna-sequence
6
+ - pytorch
7
+ ---
8
+
9
+ # Protein to RNA CDS Sequence Generation Model
10
+
11
+ This custom PyTorch model generates RNA coding sequences (CDS) from protein sequences. It combines an ESM-2 encoder with a transformer encoder whose layers use a Mixture-of-Experts (MoE) feed-forward block.
12
+
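A generated CDS should translate back to the input protein. The sketch below is one way to check that, assuming the standard genetic code (the same codon-to-amino-acid mapping that appears in `utils.py`). The helper names `translate_cds` and `encodes_protein` are hypothetical and are not part of this repository.

```python
from itertools import product

BASES = "UCAG"
# Standard genetic code, amino acids listed in UCAG x UCAG x UCAG codon order.
AMINO_ACIDS = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"
CODON_TABLE = {"".join(c): aa for c, aa in zip(product(BASES, repeat=3), AMINO_ACIDS)}


def translate_cds(cds: str) -> str:
    """Translate an RNA CDS codon by codon; unknown codons become 'X'."""
    cds = cds.upper().replace("T", "U")
    usable = len(cds) - len(cds) % 3  # ignore an incomplete trailing codon
    return "".join(CODON_TABLE.get(cds[i:i + 3], "X") for i in range(0, usable, 3))


def encodes_protein(cds: str, protein: str) -> bool:
    """True if the CDS translates to the protein (a trailing stop codon is allowed)."""
    return translate_cds(cds).rstrip("*") == protein.rstrip("*")
```

For example, `encodes_protein("AUGGCUUAA", "MA")` compares the translation of the three codons `AUG`, `GCU`, `UAA` against the protein `MA`.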
13
+ ## Model Architecture
14
+
15
+ The model `ActorModel_encoder_esm2` is defined in `utils.py`.
16
+
17
+ The key parameters used for instantiation are:
18
+
19
+ - `d_model`: Dimension of the model's internal representation (768).
20
+ - `nhead`: Number of attention heads (8).
21
+ - `num_encoder_layers`: Number of transformer encoder layers (8).
22
+ - `dim_feedforward`: Dimension of the feedforward network (`d_model * 2`).
23
+ - `esm2_dim`: Dimension of the ESM-2 embeddings (1280 for esm2_t33_650M_UR50D).
24
+ - `dropout`: Dropout rate (0.3).
25
+ - `num_experts`: Number of experts in the MoE layer (6).
26
+ - `top_k_experts`: Number of top experts to use (2).
27
+ - `device`: The device to run the model on.
28
+
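To illustrate what `num_experts` and `top_k_experts` control, here is a minimal pure-Python sketch of top-k routing for a single token: softmax over the expert logits, keep the `top_k` largest probabilities, and renormalize them. It mirrors the routing steps in `MoE.forward` of `transformer_encoder_MoE.py`, but the function name `top_k_route` and the example logits are illustrative assumptions, not repository code.

```python
import math


def top_k_route(logits, top_k=2):
    """Return {expert_index: weight} for the top_k experts of one token."""
    # Softmax over all experts (shifted by the max for numerical stability).
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the top_k most probable experts.
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    # Renormalize so the kept weights sum to 1.
    kept = sum(probs[i] for i in chosen)
    return {i: probs[i] / kept for i in chosen}


# Six experts, two selected per token (hypothetical gate logits).
weights = top_k_route([0.1, 2.0, -1.0, 1.5, 0.0, 0.3], top_k=2)
```

The token's output is then the weighted sum of the selected experts' outputs, using these renormalized weights.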
29
+ ## Files in this Repository
30
+
31
+ - `homo_mrna.pt`: The PyTorch state_dict of the trained model for Homo sapiens mRNA.
32
+ - `homo_circ.pt`: The PyTorch state_dict of the trained model for Homo sapiens circular RNA.
33
+ - `Arabidopsis.pt`: The PyTorch state_dict of the trained model for Arabidopsis thaliana mRNA.
34
+ - `CR.pt`: The PyTorch state_dict of the trained model for Chlamydomonas reinhardtii mRNA.
35
+ - `EscherichiaColi.pt`: The PyTorch state_dict of the trained model for Escherichia coli mRNA.
36
+ - `PC.pt`: The PyTorch state_dict of the trained model for Penicillium chrysogenum mRNA.
37
+ - `TK.pt`: The PyTorch state_dict of the trained model for Thermococcus kodakarensis KOD1 mRNA.
38
+ - `utils.py`: Contains the definition of the `ActorModel_encoder_esm2` class and the `Tokenizer` class.
39
+ - `transformer_encoder_MoE.py`: Contains the definition of the `Encoder` class.
40
+ - `README.md`: This file.
41
+
42
+ ## How to Load the Model
43
+
44
+ Because this is a custom model, you need to download `utils.py`, `transformer_encoder_MoE.py`, and the `.pt` file for your target organism, then instantiate the model class and load the state dictionary.
45
+
46
+ 1. **Download Files:**
47
+ You can download the files using the `huggingface_hub` library:
48
+ ```python
49
+ from huggingface_hub import hf_hub_download
50
+ import os
51
+
52
+ repo_id = "sglin/RNARL"
53
+ local_dir = "./my_RNARL"
54
+
55
+ # Download model weights and utils.py
56
+ hf_hub_download(repo_id=repo_id, filename="homo_mrna.pt", local_dir=local_dir)
57
+ hf_hub_download(repo_id=repo_id, filename="homo_circ.pt", local_dir=local_dir)
58
+ hf_hub_download(repo_id=repo_id, filename="Arabidopsis.pt", local_dir=local_dir)
59
+ hf_hub_download(repo_id=repo_id, filename="CR.pt", local_dir=local_dir)
60
+ hf_hub_download(repo_id=repo_id, filename="EscherichiaColi.pt", local_dir=local_dir)
61
+ hf_hub_download(repo_id=repo_id, filename="PC.pt", local_dir=local_dir)
62
+ hf_hub_download(repo_id=repo_id, filename="TK.pt", local_dir=local_dir)
63
+ hf_hub_download(repo_id=repo_id, filename="utils.py", local_dir=local_dir)
64
+ hf_hub_download(repo_id=repo_id, filename="transformer_encoder_MoE.py", local_dir=local_dir)
65
+
66
+ # Now utils.py, transformer_encoder_MoE.py, and the model weights are in ./my_RNARL
67
+ ```
68
+
69
+ 2. **Import Model Class:**
70
+
71
+ ```python
72
+ # utils.py imports transformer_encoder_MoE, so both files must be importable.
73
+ # Add the download directory from step 1 to the module search path:
74
+ import sys
75
+ sys.path.append("./my_RNARL")
76
+
77
+ # If both files are already in your working directory, the two lines above
78
+ # are unnecessary.
79
+ from utils import Tokenizer, ActorModel_encoder_esm2
80
+ ```
81
+
82
+ 3. **Load ESM-2 (Dependency):**
83
+ The model requires the ESM-2 encoder. You'll need to load it separately, typically from Hugging Face Hub.
84
+ ```python
85
+ from transformers import AutoTokenizer, EsmModel
86
+ import torch
87
+
88
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
89
+
90
+ esm2_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
91
+ esm2_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
92
+ esm2_model.eval()
93
+ esm2_dim = esm2_model.config.hidden_size # Get the actual dimension
94
+ ```
95
+ *Note:* The training script loaded ESM-2 from a local path (`./esm2_model_t33_650M_UR50D`). The ESM-2 files are not included in this repository, so load them from the official `facebook/esm2_t33_650M_UR50D` repository on the Hugging Face Hub, or point `from_pretrained` at your own local copy.
96
+
97
+ 4. **Instantiate Custom Model and Load Weights:**
98
+ Instantiate `ActorModel_encoder_esm2` with the parameters used during training, then load the state dictionary.
99
+ ```python
100
+ # Define the parameters used during training
101
+ d_model = 768
102
+ nhead = 8
103
+ num_encoder_layers = 8
104
+ dim_feedforward = d_model * 2 # 1536
105
+ dropout = 0.3
106
+ num_experts = 6
107
+ top_k_experts = 2
108
+ # vocab_size must match the custom Tokenizer
109
+ tokenizer = Tokenizer() # Custom tokenizer from utils.py
110
+ vocab_size = len(tokenizer.tokens) # 90: 5 special + 21 amino-acid + 64 codon tokens
111
+
112
+ # Instantiate the model
113
+ model = ActorModel_encoder_esm2(
114
+ vocab_size=vocab_size,
115
+ d_model=d_model,
116
+ nhead=nhead,
117
+ num_encoder_layers=num_encoder_layers,
118
+ dim_feedforward=dim_feedforward,
119
+ esm2_dim=esm2_dim, # Use the esm2_model's dimension
120
+ dropout=dropout,
121
+ num_experts=num_experts,
122
+ top_k_experts=top_k_experts,
123
+ device=device
124
+ )
125
+
126
+ # Load the state dictionary
127
+ model_weights_path = os.path.join(local_dir, "homo_mrna.pt")
128
+ model.load_state_dict(torch.load(model_weights_path, map_location=device))
129
+ model.to(device)
130
+ model.eval()
131
+
132
+ print("Model loaded successfully!")
133
+
134
+ # Now you can use the 'model' object for inference
135
+ # Remember you also need your Tokenizer and the ESM-2 tokenizer/model
136
+ ```
137
+
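The `vocab_size` passed to the model comes from the custom tokenizer's vocabulary. As a rough illustration of where that number comes from, the sketch below rebuilds the token list the way `Tokenizer.__init__` in `utils.py` does and encodes a CDS one codon at a time. It is a simplified stand-in: the function name `encode_codons` is hypothetical, and the `max_length` truncation of the real `encode_mrna` is omitted.

```python
SPECIAL_TOKENS = ["[START]", "[END]", "[PAD]", "[UNK]", "[SEG]"]
AMINO_ACIDS = list("ARSILGVTPNDCQEHKFYMW*")
CODONS = [a + b + c for a in "UCAG" for b in "UCAG" for c in "UCAG"]

# Same ordering as utils.py: special tokens, then amino acids, then codons.
TOKENS = SPECIAL_TOKENS + AMINO_ACIDS + CODONS
TOKEN_TO_ID = {tok: i for i, tok in enumerate(TOKENS)}


def encode_codons(cds: str) -> list:
    """Map an RNA CDS to token IDs, one ID per codon, wrapped in [START]/[END]."""
    unk = TOKEN_TO_ID["[UNK]"]
    ids = [TOKEN_TO_ID["[START]"]]
    for i in range(0, len(cds), 3):
        codon = cds[i:i + 3]
        # Only complete, known codons get their own ID; anything else is [UNK].
        if len(codon) == 3 and codon in TOKEN_TO_ID:
            ids.append(TOKEN_TO_ID[codon])
        else:
            ids.append(unk)
    ids.append(TOKEN_TO_ID["[END]"])
    return ids
```

With this layout the vocabulary has 5 + 21 + 64 = 90 entries, which is the value `len(tokenizer.tokens)` should return if the tokenizer in `utils.py` is unchanged.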
138
+ ## Dependencies
139
+
140
+ - `torch`
141
+ - `transformers`
142
+ - `huggingface_hub`
143
+ - `pandas`
144
+ - `numpy`
145
+ - The ESM-2 model used for the protein embeddings (`facebook/esm2_t33_650M_UR50D`).
146
+
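A quick way to confirm the environment is ready is to check that each dependency can be imported. The sketch below uses only the standard library. The module list is an assumption based on the list above plus `torchcrf`, which `utils.py` imports (it is commonly distributed under the package name `pytorch-crf`); the helper name `missing_modules` is hypothetical.

```python
import importlib.util

# Import names (not necessarily the pip package names); torchcrf is imported by utils.py.
REQUIRED_MODULES = ["torch", "transformers", "huggingface_hub", "pandas", "numpy", "torchcrf"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the modules from `names` that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


print("Missing modules:", missing_modules())
```

An empty list means every module was found; any names printed need to be installed before `utils.py` can be imported.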
147
+ ## License
148
+
149
+ [Specify your license here, e.g., MIT, Apache 2.0]
150
+
151
+ ## Contact
152
+
153
+ [Optional: Your email or other contact info]
TK.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3737e4d94a4e266e0c89105a98ddcf8a6e7f444af5ef88810db669cd4ecf084d
3
+ size 534069804
homo_circ.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5249399ad29c6939703e79ef65532f23867ba5ce0ae6b3e4c2e4e6075ae18466
3
+ size 534071260
homo_mrna.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3c825975f8deade11bf02d4783d0703d276f1b5ed7b70a406b01e8843255d66
3
+ size 534069218
transformer_encoder_MoE.py ADDED
@@ -0,0 +1,555 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from torch.nn.parallel import parallel_apply
6
+ from typing import Tuple, List, Optional, Union
7
+ import torch.utils.checkpoint as checkpoint
8
+
9
+
10
+ class MultiHeadAttention(nn.Module):
11
+ """Efficient multi-head attention implementation."""
12
+
13
+ def __init__(self, model_dim: int, n_heads: int):
14
+ super().__init__()
15
+ assert model_dim % n_heads == 0, "model_dim must be divisible by n_heads"
16
+
17
+ self.model_dim = model_dim
18
+ self.d_k = model_dim // n_heads
19
+ self.n_heads = n_heads
20
+
21
+ # Compute the Q, K, and V projections with a single linear layer to reduce overhead
22
+ self.qkv_linear = nn.Linear(model_dim, 3 * model_dim, bias=False)
23
+ self.out_linear = nn.Linear(model_dim, model_dim, bias=False)
24
+
25
+ # Initialize the weights for more stable training
26
+ nn.init.xavier_uniform_(self.qkv_linear.weight)
27
+ nn.init.xavier_uniform_(self.out_linear.weight)
28
+
29
+ self.scale = 1.0 / math.sqrt(self.d_k)
30
+
31
+ def forward(self,
32
+ q: torch.Tensor,
33
+ k: torch.Tensor,
34
+ v: torch.Tensor,
35
+ mask: Optional[torch.Tensor] = None,
36
+ key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
37
+ batch_size = q.size(0)
38
+
39
+ # If q, k, and v are the same tensor, use the more efficient self-attention path
40
+ is_self_attention = q.data_ptr() == k.data_ptr() == v.data_ptr()
41
+ if is_self_attention:
42
+ # [batch, seq, 3*dim] -> 3 x [batch, seq, dim]
43
+ qkv = self.qkv_linear(q).chunk(3, dim=-1)
44
+ q, k, v = map(
45
+ lambda x: x.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2),
46
+ qkv
47
+ )
48
+ else:
49
+ # Cross-attention: project q, k, and v separately
50
+ q = self.qkv_linear(q)[:, :, :self.model_dim]
51
+ k = self.qkv_linear(k)[:, :, self.model_dim:2*self.model_dim]
52
+ v = self.qkv_linear(v)[:, :, 2*self.model_dim:]
53
+
54
+ q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
55
+ k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
56
+ v = v.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
57
+
58
+ # Scaled dot-product attention
59
+ scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
60
+
61
+ # Apply masks (a large finite negative fill value improves numerical stability)
62
+ if mask is not None:
63
+ scores = scores.masked_fill(mask == 0, -6.0e4)
64
+ if key_padding_mask is not None:
65
+ scores = scores.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), -6.0e4)
66
+
67
+ attn_weights = F.softmax(scores, dim=-1)
68
+
69
+ # Aggregate the values using the attention weights
70
+ context = torch.matmul(attn_weights, v)
71
+
72
+ # Merge the heads and apply the output projection
73
+ context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.model_dim)
74
+ output = self.out_linear(context)
75
+
76
+ return output
77
+
78
+
79
+ class MoE(nn.Module):
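80
+ """Optimized Mixture-of-Experts module with parallel expert computation and efficient expert selection.
81
+
82
+ Args:
83
+ d_model (int): Hidden dimension of the model
84
+ num_experts (int): Number of experts
85
+ d_ff (int): Feed-forward dimension
86
+ dropout (float): Dropout probability
87
+ top_k (int): Number of experts selected per token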
88
+ """
89
+
90
+ def __init__(self, d_model: int, num_experts: int, d_ff: int, dropout: float, top_k: int):
91
+ super().__init__()
92
+ # Store hyperparameters
93
+ self.num_experts = num_experts
94
+ self.top_k = min(top_k, num_experts) # Ensure top_k does not exceed the number of experts
95
+ self.d_model = d_model
96
+
97
+ # Gating network: maps each input to expert scores [d_model -> num_experts]
98
+ self.gate = nn.Linear(d_model, num_experts, bias=False)
99
+
100
+ # Expert networks: list of expert modules evaluated in parallel
101
+ self.experts = nn.ModuleList([
102
+ nn.Sequential( # Structure of each expert:
103
+ nn.Linear(d_model, d_ff, bias=False), # [d_model -> d_ff]
104
+ nn.GELU(), # Activation, shape unchanged
105
+ nn.Dropout(dropout), # Shape unchanged
106
+ nn.Linear(d_ff, d_model, bias=False) # [d_ff -> d_model]
107
+ ) for _ in range(num_experts)
108
+ ])
109
+
110
+ # Weight initialization
111
+ for expert in self.experts:
112
+ nn.init.kaiming_uniform_(expert[0].weight, a=math.sqrt(5)) # Initialize the first linear layer
113
+ nn.init.zeros_(expert[3].weight) # Zero-initialize the output layer; weight shape [d_model, d_ff]
114
+
115
+ nn.init.zeros_(self.gate.weight) # Zero-initialize the gate; weight shape [num_experts, d_model]
116
+
117
+ def orthogonal_loss(self) -> torch.Tensor:
118
+ """Compute the orthogonality loss between expert networks to encourage expert diversity.
119
+
120
+ Returns:
121
+ torch.Tensor: Scalar orthogonality loss
122
+ """
123
+ total_loss = 0.0
124
+ num_pairs = 0
125
+
126
+ # Collect the first- and last-layer weights of all experts
127
+ # expert_weights_1 shape: [num_experts, d_ff, d_model]
128
+ expert_weights_1 = torch.stack([expert[0].weight for expert in self.experts])
129
+ # expert_weights_2 shape: [num_experts, d_model, d_ff]
130
+ expert_weights_2 = torch.stack([expert[3].weight for expert in self.experts])
131
+
132
+ # Accumulate the orthogonality loss over all pairs of experts
133
+ for i in range(self.num_experts):
134
+ w1_i = expert_weights_1[i] # [d_ff, d_model]
135
+ w2_i = expert_weights_2[i] # [d_model, d_ff]
136
+
137
+ for j in range(i+1, self.num_experts):
138
+ w1_j = expert_weights_1[j] # [d_ff, d_model]
139
+ w2_j = expert_weights_2[j] # [d_model, d_ff]
140
+
141
+ # Similarity between the first-layer weights
142
+ w1_sim = torch.sum((w1_i @ w1_j.T)**2) / (w1_i.size(0) * w1_j.size(0)) # scalar
143
+ # Similarity between the second-layer weights
144
+ w2_sim = torch.sum((w2_i.T @ w2_j)**2) / (w2_i.size(1) * w2_j.size(1)) # scalar
145
+
146
+ total_loss += (w1_sim + w2_sim) / 2 # Mean of the two similarities
147
+ num_pairs += 1
148
+
149
+ return total_loss / max(num_pairs, 1) # Mean orthogonality loss over expert pairs
150
+
151
+ def entropy_regularization_loss(self, routing_probs: torch.Tensor) -> torch.Tensor:
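152
+ """Compute the entropy regularization loss, which encourages a more uniform routing distribution.
153
+
154
+ Args:
155
+ routing_probs (torch.Tensor): Routing probabilities, shape [batch*seq_len, num_experts]
156
+
157
+ Returns:
158
+ torch.Tensor: Scalar entropy loss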
159
+ """
160
+ # Numerically stable log
161
+ log_probs = torch.log(torch.clamp(routing_probs, min=1e-6)) # shape [batch*seq, num_experts]
162
+ # Per-token entropy over the expert dimension
163
+ entropy = -torch.sum(routing_probs * log_probs, dim=-1) # shape [batch*seq]
164
+ return entropy.mean() # scalar
165
+
166
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
167
+ """MoE forward pass: select experts per token and combine their outputs.
168
+
169
+ Args:
170
+ hidden_states (torch.Tensor): Input tensor, shape [batch_size, seq_len, d_model]
171
+
172
+ Returns:
173
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
174
+ - Output tensor [batch_size, seq_len, d_model]
175
+ - Router logits [batch_size*seq_len, num_experts]
176
+ - Scalar entropy regularization loss
177
+ """
178
+ batch_size, seq_len, d_model = hidden_states.shape
179
+ combined_batch_size = batch_size * seq_len
180
+ # Flatten the input so all tokens are processed together
181
+ flat_hidden = hidden_states.reshape(combined_batch_size, d_model) # [batch*seq, d_model]
182
+
183
+ # Routing
184
+ router_logits = self.gate(flat_hidden) # [batch*seq, num_experts]
185
+ routing_probs = F.softmax(router_logits, dim=-1) # [batch*seq, num_experts]
186
+
187
+ # Select the top-k experts
188
+ routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1) # both [batch*seq, top_k]
189
+ # Renormalize the selected weights
190
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) # [batch*seq, top_k]
191
+
192
+
193
+ # Compute the outputs of all experts in parallel
194
+ flat_expert_inputs = [flat_hidden] * self.num_experts # list of num_experts tensors, each [batch*seq, d_model]
195
+ expert_outputs = parallel_apply(self.experts, flat_expert_inputs) # list of num_experts tensors, each [batch*seq, d_model]
196
+ expert_outputs = torch.stack(expert_outputs, dim=1) # [batch*seq, num_experts, d_model]
197
+
198
+ # Build the expert weight matrix
199
+ expert_weights_matrix = torch.zeros(
200
+ combined_batch_size, self.num_experts, device=hidden_states.device
201
+ ) # [batch*seq, num_experts]
202
+
203
+ # Use scatter_add to aggregate the weights efficiently
204
+ for k in range(self.top_k):
205
+ k_indices = selected_experts[:, k] # [batch*seq]
206
+ k_weights = routing_weights[:, k].unsqueeze(1) # [batch*seq, 1]
207
+ # Accumulate each weight at its expert's position
208
+ expert_weights_matrix.scatter_add_(
209
+ 1,
210
+ k_indices.unsqueeze(1), # [batch*seq, 1]
211
+ k_weights # [batch*seq, 1]
212
+ ) # updates expert_weights_matrix in place
213
+
214
+ # Combine the expert outputs via batched matrix multiplication
215
+ combined_output = torch.bmm(
216
+ expert_weights_matrix.unsqueeze(1), # [batch*seq, 1, num_experts]
217
+ expert_outputs # [batch*seq, num_experts, d_model]
218
+ ).squeeze(1) # [batch*seq, d_model]
219
+
220
+ # Restore the original shape
221
+ output = combined_output.reshape(batch_size, seq_len, d_model) # [batch_size, seq_len, d_model]
222
+
223
+ # Compute the entropy regularization loss
224
+ entropy_loss = self.entropy_regularization_loss(routing_probs)
225
+
226
+ return output, router_logits, entropy_loss
227
+
228
+
229
+
230
+ class EncoderLayer(nn.Module):
231
+ """Optimized encoder layer with gradient checkpointing support and pre-normalized residual connections."""
232
+
233
+ def __init__(self, model_dim: int, n_heads: int, ff_hidden_dim: int,
234
+ dropout: float, num_experts: int, top_k: int):
235
+ super().__init__()
236
+ self.model_dim = model_dim
237
+
238
+ # Pre-LN (pre-normalization) structure for more stable training
239
+ self.norm1 = nn.LayerNorm(model_dim)
240
+ self.norm2 = nn.LayerNorm(model_dim)
241
+
242
+ self.self_attn = MultiHeadAttention(model_dim, n_heads)
243
+ self.moe = MoE(model_dim, num_experts, ff_hidden_dim, dropout, top_k)
244
+
245
+ self.dropout = nn.Dropout(dropout)
246
+ self.dropout1 = nn.Dropout(dropout)
247
+ self.dropout2 = nn.Dropout(dropout)
248
+
249
+ # Optional projection for mismatched residual dimensions (disabled; a learnable residual scale is used instead)
250
+ self.use_projection = False
251
+ if not self.use_projection:
252
+ self.residual_scale = nn.Parameter(torch.ones(1))
253
+
254
+ def _sa_block(self, x: torch.Tensor,
255
+ mask: Optional[torch.Tensor] = None,
256
+ key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
257
+ """Wrap the self-attention computation so it can be used with gradient checkpointing."""
258
+ x = self.self_attn(x, x, x, mask=mask, key_padding_mask=key_padding_mask)
259
+ return self.dropout1(x)
260
+
261
+ def _moe_block(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
262
+ """Wrap the MoE computation so it can be used with gradient checkpointing."""
263
+ return self.moe(x)
264
+
265
+ def forward(self,
266
+ x: torch.Tensor,
267
+ src_mask: Optional[torch.Tensor] = None,
268
+ src_key_padding_mask: Optional[torch.Tensor] = None,
269
+ use_checkpoint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
270
+ """
271
+ Forward pass of the encoder layer.
272
+
273
+ Args:
274
+ x: Input tensor [batch_size, seq_len, model_dim]
275
+ src_mask: Source sequence mask
276
+ src_key_padding_mask: Padding mask
277
+ use_checkpoint: Whether to use gradient checkpointing to save memory
278
+ """
279
+ # Pre-LN structure
280
+ normalized_x = self.norm1(x)
281
+
282
+ # Self-attention block (optional gradient checkpointing)
283
+ if use_checkpoint and self.training:
284
+ attn_output = checkpoint.checkpoint(
285
+ self._sa_block, normalized_x, src_mask, src_key_padding_mask
286
+ )
287
+ else:
288
+ attn_output = self._sa_block(normalized_x, src_mask, src_key_padding_mask)
289
+
290
+ # First residual connection
291
+ x = x + attn_output * self.residual_scale
292
+
293
+ # Pre-normalization
294
+ normalized_x = self.norm2(x)
295
+
296
+ # MoE block (optional gradient checkpointing)
297
+ if use_checkpoint and self.training:
298
+ moe_output, router_logits, entropy_loss = checkpoint.checkpoint(
299
+ self._moe_block, normalized_x
300
+ )
301
+ else:
302
+ moe_output, router_logits, entropy_loss = self._moe_block(normalized_x)
303
+
304
+ # Second residual connection
305
+ x = x + self.dropout2(moe_output) * self.residual_scale
306
+
307
+ return x, router_logits, entropy_loss
308
+
309
+
310
+ class PositionwiseFeedForward(nn.Module):
311
+ def __init__(self, d_model, d_ff, dropout=0.1):
312
+ super(PositionwiseFeedForward, self).__init__()
313
+ self.linear1 = nn.Linear(d_model, d_ff)
314
+ self.linear2 = nn.Linear(d_ff, d_model)
315
+ self.dropout = nn.Dropout(dropout)
316
+ self.relu = nn.ReLU()
317
+
318
+ def forward(self, x):
319
+ return self.linear2(self.dropout(self.relu(self.linear1(x))))
320
+
321
+ class EncoderLayer_nomoe(nn.Module):
322
+
323
+ def __init__(self, model_dim: int, n_heads: int, ff_hidden_dim: int,
324
+ dropout: float):
325
+ super().__init__()
326
+
327
+ # Pre-LN (pre-normalization) structure for more stable training
328
+ self.norm1 = nn.LayerNorm(model_dim)
329
+ self.norm2 = nn.LayerNorm(model_dim)
330
+
331
+ self.self_attn = MultiHeadAttention(model_dim, n_heads)
332
+ self.feed_forward = PositionwiseFeedForward(model_dim, ff_hidden_dim, dropout)
333
+
334
+
335
+ self.dropout1 = nn.Dropout(dropout)
336
+ self.dropout2 = nn.Dropout(dropout)
337
+
338
+ def forward(self,
339
+ x: torch.Tensor,
340
+ src_mask: Optional[torch.Tensor] = None,
341
+ src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
342
+
343
+ # Pre-LN structure
344
+ normalized_x = self.norm1(x)
345
+
346
+ attn_output = self.self_attn(normalized_x, normalized_x, normalized_x, src_mask,src_key_padding_mask)
347
+
348
+ # First residual connection
349
+ x = x + self.dropout1(attn_output)
350
+
351
+ # Pre-normalization
352
+ normalized_x = self.norm2(x)
353
+
354
+ ff_output = self.feed_forward(normalized_x)
355
+
356
+ # Second residual connection
357
+ x = x + self.dropout2(ff_output)
358
+
359
+ return x
360
+
361
+ class PositionalEncoding(nn.Module):
362
+ """Efficient positional encoding implementation."""
363
+
364
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
365
+ super().__init__()
366
+ self.dropout = nn.Dropout(p=dropout)
367
+
368
+ # Compute the positional encodings once and cache them
369
+ pe = torch.zeros(1, max_len, d_model)
370
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
371
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
372
+
373
+ # Vectorized computation: sine on even dimensions, cosine on odd dimensions
374
+ pe[0, :, 0::2] = torch.sin(position * div_term)
375
+ pe[0, :, 1::2] = torch.cos(position * div_term)
376
+
377
+ # Register as a buffer rather than a parameter to save memory
378
+ self.register_buffer('pe', pe)
379
+
380
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
381
+ """
382
+ Add positional encoding to the input.
383
+
384
+ Args:
385
+ x: Input tensor [batch_size, seq_len, model_dim]
386
+ """
387
+ pos_encoding = self.pe[:, :x.size(1)]
388
+ x = x + pos_encoding
389
+ return self.dropout(x)
390
+
391
+
392
+ class Encoder(nn.Module):
393
+ """Optimized encoder architecture with MoE layers."""
394
+
395
+ def __init__(self,
396
+ input_dim: int,
397
+ model_dim: int,
398
+ n_heads: int,
399
+ num_layers: int,
400
+ ff_hidden_dim: int,
401
+ dropout: float,
402
+ num_experts: int,
403
+ top_k: int,
404
+ if_embedding: bool = True,
405
+ if_pos_encoding: bool = True,
406
+ use_checkpointing: bool = False):
407
+ super().__init__()
408
+
409
+ self.model_dim = model_dim
410
+ self.num_layers = num_layers
411
+ self.if_embedding = if_embedding
412
+ self.if_pos_encoding = if_pos_encoding
413
+ self.use_checkpointing = use_checkpointing
414
+
415
+ # Embedding layer
416
+ if if_embedding:
417
+ self.embedding = nn.Embedding(input_dim, model_dim)
418
+ # Improved embedding initialization
419
+ nn.init.normal_(self.embedding.weight, mean=0, std=model_dim**-0.5)
420
+
421
+ # Positional encoding
422
+ if if_pos_encoding:
423
+ self.pos_encoding = PositionalEncoding(model_dim, dropout)
424
+
425
+ # Encoder layers
426
+ self.layers = nn.ModuleList([
427
+ EncoderLayer(
428
+ model_dim, n_heads, ff_hidden_dim, dropout, num_experts, top_k
429
+ ) for _ in range(num_layers)
430
+ ])
431
+
432
+ # Output normalization
433
+ self.final_norm = nn.LayerNorm(model_dim)
434
+
435
+ def forward(self,
436
+ src: torch.Tensor,
437
+ src_mask: Optional[torch.Tensor] = None,
438
+ src_key_padding_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, List, float]:
439
+ """
440
+ Forward pass of the encoder.
441
+
442
+ Args:
443
+ src: Input sequence [batch_size, seq_len] or [batch_size, seq_len, model_dim]
444
+ src_mask: Source sequence mask
445
+ src_key_padding_mask: Padding mask
446
+
447
+ Returns:
448
+ tuple: (output tensor, list of router logits, entropy loss)
449
+ """
450
+ # Embedding
451
+ if self.if_embedding:
452
+ x = self.embedding(src) * math.sqrt(self.model_dim)
453
+ else:
454
+ x = src
455
+
456
+ # Positional encoding
457
+ if self.if_pos_encoding:
458
+ x = self.pos_encoding(x)
459
+
460
+ # Track the entropy loss and router logits
461
+ total_entropy_loss = 0.0
462
+ router_logits_list = []
463
+
464
+ # Pass through the encoder layers
465
+ for layer in self.layers:
466
+ x, router_logits, entropy_loss = layer(
467
+ x,
468
+ src_mask=src_mask,
469
+ src_key_padding_mask=src_key_padding_mask,
470
+ use_checkpoint=self.use_checkpointing
471
+ )
472
+ total_entropy_loss += entropy_loss
473
+
474
+ # Keep only CPU copies of the router logits to reduce memory usage
475
+ if not self.training: # Store router logits only during inference
476
+ router_logits_list.append(router_logits.detach().cpu().tolist())
477
+
478
+ # Apply the final layer normalization
479
+ x = self.final_norm(x)
480
+
481
+ # Average the entropy loss over the layers
482
+ avg_entropy_loss = total_entropy_loss / self.num_layers
483
+
484
+ return x, router_logits_list, avg_entropy_loss
485
+
486
+
487
+
488
+ class Encoder_nomoe(nn.Module):
489
+ """Optimized encoder architecture without MoE layers."""
490
+
491
+ def __init__(self,
492
+ input_dim: int,
493
+ model_dim: int,
494
+ n_heads: int,
495
+ num_layers: int,
496
+ ff_hidden_dim: int,
497
+ dropout: float,
498
+ if_embedding: bool = True,
499
+ if_pos_encoding: bool = True):
500
+ super().__init__()
501
+
502
+ self.model_dim = model_dim
503
+ self.num_layers = num_layers
504
+ self.if_embedding = if_embedding
505
+ self.if_pos_encoding = if_pos_encoding
506
+
507
+ # Embedding layer
508
+ if if_embedding:
509
+ self.embedding = nn.Embedding(input_dim, model_dim)
510
+ # Improved embedding initialization
511
+ nn.init.normal_(self.embedding.weight, mean=0, std=model_dim**-0.5)
512
+
513
+ # Positional encoding
514
+ if if_pos_encoding:
515
+ self.pos_encoding = PositionalEncoding(model_dim, dropout)
516
+
517
+ # Encoder layers
518
+ self.layers = nn.ModuleList([
519
+ EncoderLayer_nomoe(
520
+ model_dim, n_heads, ff_hidden_dim, dropout
521
+ ) for _ in range(num_layers)
522
+ ])
523
+
524
+ # Output normalization
525
+ self.final_norm = nn.LayerNorm(model_dim)
526
+
527
+ def forward(self,
528
+ src: torch.Tensor,
529
+ src_mask: Optional[torch.Tensor] = None,
530
+ src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
531
+
532
+ # Embedding
533
+ if self.if_embedding:
534
+ x = self.embedding(src) * math.sqrt(self.model_dim)
535
+ else:
536
+ x = src
537
+
538
+ # Positional encoding
539
+ if self.if_pos_encoding:
540
+ x = self.pos_encoding(x)
541
+
542
+
543
+
544
+ # Pass through the encoder layers
545
+ for layer in self.layers:
546
+ x = layer(
547
+ x,
548
+ src_mask=src_mask,
549
+ src_key_padding_mask=src_key_padding_mask
550
+ )
551
+
552
+ # Apply the final layer normalization
553
+ x = self.final_norm(x)
554
+
555
+ return x
utils.py ADDED
@@ -0,0 +1,844 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+ from torch.utils.data.distributed import DistributedSampler
7
+ import torch.optim.lr_scheduler as lr_scheduler
8
+ from transformer_encoder_MoE import Encoder,Encoder_nomoe
9
+ from itertools import chain
10
+ from torch.nn.parallel import parallel_apply
11
+ from typing import List, Dict, Tuple, Optional, Union
12
+ from torchcrf import CRF
13
+
14
+
15
+ class Tokenizer:
16
+ """Tokenizer for encoding and decoding sequences; supports protein and mRNA sequences."""
17
+
18
+ def __init__(self):
19
+ # Define the special tokens and the biological sequence tokens
20
+ self.special_tokens = ['[START]', '[END]', '[PAD]', '[UNK]', '[SEG]']
21
+ self.amino_acids = ['A', 'R', 'S', 'I', 'L', 'G', 'V', 'T', 'P', 'N',
22
+ 'D', 'C', 'Q', 'E', 'H', 'K', 'F', 'Y', 'M', 'W', '*']
23
+ self.protein_alphabet = ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I',
24
+ 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
25
+ # Generate all possible codon combinations
26
+ self.codons = [''.join([n1, n2, n3]) for n1 in 'UCAG' for n2 in 'UCAG' for n3 in 'UCAG']
27
+
28
+
29
+
30
+ # Merge all tokens and build the lookup tables
31
+ self.tokens = self.special_tokens + self.amino_acids + self.codons
32
+ self.token_to_id = {token: idx for idx, token in enumerate(self.tokens)}
33
+ self.id_to_token = {idx: token for token, idx in self.token_to_id.items()}
34
+
35
+ # Cache frequently used special-token indices for faster access
36
+ self.padding_idx = self.token_to_id['[PAD]']
37
+ self.start_idx = self.token_to_id['[START]']
38
+ self.end_idx = self.token_to_id['[END]']
39
+ self.unk_idx = self.token_to_id['[UNK]']
40
+ self.seg_idx = self.token_to_id['[SEG]']
41
+
42
+ def encode_pro(self, sequence: str, max_length: int) -> List[int]:
43
+ """Encode a protein sequence.
44
+
45
+ Args:
46
+ sequence: Input protein sequence
47
+ max_length: Maximum length of the encoded sequence
48
+
49
+ Returns:
50
+ List of encoded token IDs
51
+ """
52
+ # Prepend the start token and look up the ID of each character
53
+ ids = [self.start_idx] + [self.token_to_id.get(token, self.unk_idx) for token in sequence]
54
+
55
+ # Truncate if needed and append the end token
56
+ if len(ids) < max_length - 1:
57
+ ids.append(self.end_idx)
58
+ else:
59
+ ids = ids[:max_length-1] + [self.end_idx]
60
+
61
+ return ids
62
+
63
+ def encode_mrna(self, sequence: str, max_length: int) -> List[int]:
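64
+ """Encode an mRNA sequence, treating every three nucleotides as one codon.
65
+
66
+ Args:
67
+ sequence: Input mRNA sequence
68
+ max_length: Maximum length of the encoded sequence
69
+
70
+ Returns:
71
+ List of encoded token IDs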
72
+ """
73
+ ids = [self.start_idx]
74
+
75
+ # Process three characters (one codon) at a time
76
+ for i in range(0, len(sequence), 3):
77
+ codon = sequence[i:i+3]
78
+ if len(codon) == 3 and codon in self.token_to_id:
79
+ ids.append(self.token_to_id[codon])
80
+ else:
81
+ ids.append(self.unk_idx)
82
+
83
+ # Truncate if needed and append the end token
84
+ if len(ids) < max_length - 1:
85
+ ids.append(self.end_idx)
86
+ else:
87
+ ids = ids[:max_length-1] + [self.end_idx]
88
+
89
+ return ids
90
+
91
+ def decode(self, ids: List[int]) -> str:
92
+ """Decode a sequence of IDs into text.
93
+
94
+ Args:
95
+ ids: List of encoded token IDs
96
+
97
+ Returns:
98
+ Decoded text
99
+ """
100
+ return ''.join([self.id_to_token.get(id, '[UNK]') for id in ids])
101
+
102
+ def pad(self, ids: List[int], max_length: int) -> List[int]:
103
+ """Pad a sequence to the given length.
104
+
105
+ Args:
106
+ ids: List of encoded token IDs
107
+ max_length: Target length
108
+
109
+ Returns:
110
+ Padded list of IDs
111
+ """
112
+ padding_length = max_length - len(ids)
113
+ if padding_length > 0:
114
+ return ids + [self.padding_idx] * padding_length
115
+ return ids
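The encode and pad steps above can be sketched in isolation as follows. This is a minimal illustration under stated assumptions, not the tokenizer from this file: the vocabulary is a hypothetical three-codon miniature, and `encode_codons` is an illustrative name. The slicing form `ids[:max_length - 1] + [END]` gives the same result as the if/else branch in `encode_mrna`, because slicing a short list is a no-op.

```python
from typing import List

# Hypothetical miniature vocabulary (assumption); the real tokenizer builds
# its list from special tokens, amino acids, and codons.
SPECIAL = ["[PAD]", "[START]", "[END]", "[UNK]"]
CODONS = ["AUG", "GCU", "UAA"]
TOKEN_TO_ID = {tok: i for i, tok in enumerate(SPECIAL + CODONS)}
PAD, START, END, UNK = (TOKEN_TO_ID[t] for t in SPECIAL)


def encode_codons(sequence: str, max_length: int) -> List[int]:
    """Split an mRNA string into codons and map each one to an ID."""
    ids = [START]
    for i in range(0, len(sequence), 3):
        # Incomplete trailing codons and unknown codons both become [UNK].
        ids.append(TOKEN_TO_ID.get(sequence[i:i + 3], UNK))
    # Keep at most max_length - 1 IDs so the end token always fits.
    return ids[:max_length - 1] + [END]


def pad(ids: List[int], max_length: int) -> List[int]:
    """Right-pad with [PAD] up to max_length; longer inputs are returned as-is."""
    return ids + [PAD] * max(0, max_length - len(ids))


encoded = pad(encode_codons("AUGGCUUAA", max_length=8), max_length=8)
print(encoded)
```

Padding after encoding keeps every sequence in a batch at the same length, which is what the padding mask in the model's forward pass relies on.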
116
+
117
+
118
+ # Codon table and related mappings
+ class BiologicalMappings:
+ """Mapping utilities for biological sequence encoding."""
121
+
122
+ @staticmethod
123
+ def get_codon_table() -> Dict[str, str]:
+ """Return the codon-to-amino-acid mapping table ('*' marks stop codons)."""
125
+ return {
126
+ 'GCU':'A', 'GCC':'A', 'GCA':'A', 'GCG':'A', 'CGU':'R', 'CGC':'R',
127
+ 'CGA':'R', 'CGG':'R', 'AGA':'R', 'AGG':'R', 'UCU':'S', 'UCC':'S',
128
+ 'UCA':'S', 'UCG':'S', 'AGU':'S', 'AGC':'S', 'AUU':'I', 'AUC':'I',
129
+ 'AUA':'I', 'UUA':'L', 'UUG':'L', 'CUU':'L', 'CUC':'L', 'CUA':'L',
130
+ 'CUG':'L', 'GGU':'G', 'GGC':'G', 'GGA':'G', 'GGG':'G', 'GUU':'V',
131
+ 'GUC':'V', 'GUA':'V', 'GUG':'V', 'ACU':'T', 'ACC':'T', 'ACA':'T',
132
+ 'ACG':'T', 'CCU':'P', 'CCC':'P', 'CCA':'P', 'CCG':'P', 'AAU':'N',
133
+ 'AAC':'N', 'GAU':'D', 'GAC':'D', 'UGU':'C', 'UGC':'C', 'CAA':'Q',
134
+ 'CAG':'Q', 'GAA':'E', 'GAG':'E', 'CAU':'H', 'CAC':'H', 'AAA':'K',
135
+ 'AAG':'K', 'UUU':'F', 'UUC':'F', 'UAU':'Y', 'UAC':'Y', 'AUG':'M',
136
+ 'UGG':'W','UAG':'*', 'UGA':'*', 'UAA':'*'}
137
+
138
+ @staticmethod
139
+ def get_amino_acid_to_codon() -> Dict[str, List[str]]:
+ """Return the amino-acid-to-codons mapping table."""
141
+ return {
142
+ 'A':['GCU','GCC','GCA','GCG'], 'R':['CGU','CGC','CGA','CGG','AGA','AGG'],
143
+ 'S':['UCU','UCC','UCA','UCG','AGU','AGC'],'I':['AUU','AUC','AUA'],
144
+ 'L':['UUA','UUG','CUU','CUC','CUA','CUG'],'G':['GGU','GGC','GGA','GGG'],
145
+ 'V':['GUU','GUC','GUA','GUG'],'T':['ACU','ACC','ACA','ACG'],
146
+ 'P':['CCU','CCC','CCA','CCG'],'N':['AAU','AAC'],'D':['GAU','GAC'],
147
+ 'C':['UGU','UGC'],'Q':['CAA','CAG'],'E':['GAA','GAG'],'H':['CAU','CAC'],
148
+ 'K':['AAA','AAG'],'F':['UUU','UUC'],'Y':['UAU','UAC'],'M':['AUG'],'W':['UGG'],
149
+ '*':['UAG','UGA','UAA']
150
+ }
151
+
152
+ @staticmethod
153
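+ def create_token_mapping(tokenizer: Tokenizer) -> torch.Tensor:
+ """Create a tensor that maps codon token IDs to amino-acid token IDs.
+
+ Args:
+ tokenizer: Tokenizer that provides the token-to-ID mapping.
+
+ Returns:
+ Mapping tensor indexed by codon ID; each value is the matching amino-acid ID.
+ Positions that are not codons hold the [UNK] ID.
+ """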
162
+ codon_table = BiologicalMappings.get_codon_table()
163
+ token_codon_to_amino_acid = torch.full((len(tokenizer.tokens),),
164
+ tokenizer.unk_idx,
165
+ dtype=torch.long)
166
+
167
+ for codon, amino_acid in codon_table.items():
168
+ codon_id = tokenizer.token_to_id.get(codon, tokenizer.unk_idx)
169
+ amino_acid_id = tokenizer.token_to_id.get(amino_acid, tokenizer.unk_idx)
170
+ token_codon_to_amino_acid[codon_id] = amino_acid_id
171
+
172
+ return token_codon_to_amino_acid
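The two hand-written tables above must stay consistent with each other. One way to guard against drift, sketched here under assumptions, is to derive the amino-acid-to-codon map from the codon table. The excerpt below copies only a few entries from the table, and `invert_codon_table` is a hypothetical helper that is not part of this file.

```python
from collections import defaultdict
from typing import Dict, List

# A few entries copied from the codon table above; the full table has 64 codons.
CODON_TABLE: Dict[str, str] = {
    "AUG": "M", "UGG": "W",
    "UUU": "F", "UUC": "F",
    "UAG": "*", "UGA": "*", "UAA": "*",
}


def invert_codon_table(codon_table: Dict[str, str]) -> Dict[str, List[str]]:
    """Group codons by the amino acid they encode, preserving table order."""
    grouped: Dict[str, List[str]] = defaultdict(list)
    for codon, amino_acid in codon_table.items():
        grouped[amino_acid].append(codon)
    return dict(grouped)


amino_acid_to_codon = invert_codon_table(CODON_TABLE)
print(amino_acid_to_codon["F"])
```

Comparing the derived map with the literal one (as sets of codons per amino acid) would flag any codon that appears in one table but not the other.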
173
+
174
+
175
+ class ActorModel_encoder_noesm2(nn.Module):
+ """Encoder-based actor model for sequence generation, without ESM-2 embeddings."""
177
+
178
+ def __init__(self, vocab_size: int, d_model: int, nhead: int,
179
+ num_encoder_layers: int, dim_feedforward: int, dropout: float,
180
+ num_experts: int, top_k_experts: int, device: torch.device):
+ """Initialize the model.
+
+ Args:
+ vocab_size: Vocabulary size.
+ d_model: Model (hidden) dimension.
+ nhead: Number of attention heads.
+ num_encoder_layers: Number of encoder layers.
+ dim_feedforward: Feed-forward network dimension.
+ dropout: Dropout rate.
+ num_experts: Number of experts.
+ top_k_experts: Number of top experts to use.
+ device: Compute device.
+ """
194
+ super(ActorModel_encoder_noesm2, self).__init__()
195
+ self.device = device
196
+
197
+ # Load the biological mappings and precompute the masks
198
+ self.amino_acid_to_codon = BiologicalMappings.get_amino_acid_to_codon()
199
+ self.precomputed_masks = self._precompute_masks()
200
+
201
+ # Build the encoder and the output head
202
+ self.encoder = Encoder(vocab_size, d_model, nhead, num_encoder_layers,
203
+ dim_feedforward, dropout, num_experts, top_k_experts)
204
+
205
+ # Output head built with nn.Sequential
206
+ self.mrna_output_layer = nn.Sequential(
207
+ nn.Linear(d_model, d_model//2),
208
+ nn.LayerNorm(d_model//2),
209
+ nn.ReLU(),
210
+ nn.Dropout(dropout),
211
+ nn.Linear(d_model//2, vocab_size)
212
+ )
213
+
214
+
215
+
216
+ def _precompute_masks(self) -> Dict[int, torch.Tensor]:
+ """Precompute, for each amino acid, a boolean mask over its synonymous codons."""
+ tokenizer = Tokenizer() # tokenizer instance for ID lookups
219
+ masks = {}
220
+
221
+ for amino_acid, codons in self.amino_acid_to_codon.items():
222
+ amino_acid_id = tokenizer.token_to_id.get(amino_acid, tokenizer.unk_idx)
223
+ mask = torch.zeros(len(tokenizer.tokens), dtype=torch.bool, device=self.device)
224
+
225
+ for codon in codons:
226
+ codon_id = tokenizer.token_to_id.get(codon, tokenizer.unk_idx)
227
+ if codon_id != tokenizer.unk_idx:
228
+ mask[codon_id] = True
229
+
230
+ masks[amino_acid_id] = mask
231
+
232
+ return masks
233
+
234
+ def forward(self, tokenizer_encoded_proteins: torch.Tensor) -> Tuple[torch.Tensor, list, torch.Tensor]:
+ """Forward pass.
+
+ Args:
+ tokenizer_encoded_proteins: Encoded protein sequences, shape (batch_size, seq_len).
+
+ Returns:
+ logits: Output logits, with codons that do not encode the input amino acid masked out.
+ router_logits_list: List of router logits.
+ entropy_loss: Entropy loss.
+ """
+ # Build the padding mask for the source sequence
+ tokenizer = Tokenizer() # tokenizer instance for ID lookups
247
+ src_padding_mask = (tokenizer_encoded_proteins == tokenizer.padding_idx)
248
+
249
+ # Run the encoder
250
+ x, router_logits_list, entropy_loss = self.encoder(
251
+ tokenizer_encoded_proteins,
252
+ src_key_padding_mask=src_padding_mask
253
+ )
254
+
255
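+ # Build a codon mask for every batch item and sequence position
+ batch_size, seq_len = tokenizer_encoded_proteins.shape
+
+ # Look up the precomputed mask for each token ID; unmapped IDs (e.g. special tokens) get an all-False mask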
259
+ amino_acid_to_codon_mask = torch.stack([
260
+ self.precomputed_masks.get(
261
+ tok.item(),
262
+ torch.zeros(len(tokenizer.tokens), dtype=torch.bool, device=self.device)
263
+ )
264
+ for tok in tokenizer_encoded_proteins.reshape(-1)
265
+ ]).view(batch_size, seq_len, -1)
266
+
267
+ # Compute the output logits
+ mrna_logits = self.mrna_output_layer(x)
+
+ # Mask disallowed codons by filling them with a large negative value
+ mrna_logits = mrna_logits.masked_fill(~amino_acid_to_codon_mask, -6.0e4)
273
+
274
+ return mrna_logits, router_logits_list, entropy_loss
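The effect of filling disallowed positions with `-6.0e4` can be checked with a small worked example. The sketch below uses plain Python lists instead of tensors, and `masked_softmax` is an illustrative helper that does not appear in this file. The fill value is presumably chosen to stay within the float16 range (maximum about 65504), although the source does not say so.

```python
import math
from typing import List

MASK_VALUE = -6.0e4  # same fill value as the masked_fill call above


def masked_softmax(logits: List[float], allowed: List[bool]) -> List[float]:
    """Softmax over logits after replacing disallowed entries with MASK_VALUE."""
    masked = [x if ok else MASK_VALUE for x, ok in zip(logits, allowed)]
    peak = max(masked)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Four hypothetical vocabulary entries; only positions 1 and 2 are
# synonymous codons for the amino acid at this position.
logits = [2.0, 0.5, 1.0, 3.0]
allowed = [False, True, True, False]
probs = masked_softmax(logits, allowed)
print(probs)
```

Because the masked entries underflow to zero probability, sampling or argmax over `probs` can only return a codon that encodes the input amino acid, even when a disallowed position had the largest raw logit.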
275
+
276
+ class ActorModel_encoder_esm2(nn.Module):
+ """Encoder-based actor model for sequence generation, driven by precomputed ESM-2 embeddings."""
278
+
279
+ def __init__(self, vocab_size: int, d_model: int, nhead: int,
280
+ num_encoder_layers: int, dim_feedforward: int, esm2_dim: int,dropout: float,
281
+ num_experts: int, top_k_experts: int, device: torch.device):
282
+
283
+ super(ActorModel_encoder_esm2, self).__init__()
284
+ self.device = device
285
+
286
+ # Load the biological mappings and precompute the masks
287
+ self.amino_acid_to_codon = BiologicalMappings.get_amino_acid_to_codon()
288
+ self.precomputed_masks = self._precompute_masks()
289
+
290
+ # Project ESM-2 embeddings (esm2_dim) to the model width (d_model)
+ self.dim_trans = nn.Linear(esm2_dim, d_model)
+ # Build the encoder and the output head
292
+ self.encoder = Encoder(vocab_size, d_model, nhead, num_encoder_layers,
293
+ dim_feedforward, dropout, num_experts, top_k_experts,if_embedding=False,if_pos_encoding=False)
294
+
295
+ # Output head built with nn.Sequential
296
+ self.mrna_output_layer = nn.Sequential(
297
+ nn.Linear(d_model, d_model//2),
298
+ nn.LayerNorm(d_model//2),
299
+ nn.ReLU(),
300
+ nn.Dropout(dropout),
301
+ nn.Linear(d_model//2, vocab_size)
302
+ )
303
+
304
+
305
+
306
+
307
+ def _precompute_masks(self) -> Dict[int, torch.Tensor]:
+ """Precompute, for each amino acid, a boolean mask over its synonymous codons."""
+ tokenizer = Tokenizer() # tokenizer instance for ID lookups
310
+ masks = {}
311
+
312
+ for amino_acid, codons in self.amino_acid_to_codon.items():
313
+ amino_acid_id = tokenizer.token_to_id.get(amino_acid, tokenizer.unk_idx)
314
+ mask = torch.zeros(len(tokenizer.tokens), dtype=torch.bool, device=self.device)
315
+
316
+ for codon in codons:
317
+ codon_id = tokenizer.token_to_id.get(codon, tokenizer.unk_idx)
318
+ if codon_id != tokenizer.unk_idx:
319
+ mask[codon_id] = True
320
+
321
+ masks[amino_acid_id] = mask
322
+
323
+ return masks
324
+
325
+ def forward(self, tokenizer_encoded_proteins,esm2_encoded_proteins) -> Tuple[torch.Tensor, list, torch.Tensor]:
326
+
327
+ # Build the padding mask for the source sequence
+ tokenizer = Tokenizer() # tokenizer instance for ID lookups
329
+ src_padding_mask = (tokenizer_encoded_proteins == tokenizer.padding_idx)
330
+
331
+ # Project the ESM-2 embeddings, then run the encoder
+ x = self.dim_trans(esm2_encoded_proteins)
333
+
334
+ x, router_logits_list, entropy_loss = self.encoder(
335
+ x,
336
+ src_key_padding_mask=src_padding_mask
337
+ )
338
+
339
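+ # Build a codon mask for every batch item and sequence position
+ batch_size, seq_len = tokenizer_encoded_proteins.shape
+
+ # Look up the precomputed mask for each token ID; unmapped IDs (e.g. special tokens) get an all-False mask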
343
+ amino_acid_to_codon_mask = torch.stack([
344
+ self.precomputed_masks.get(
345
+ tok.item(),
346
+ torch.zeros(len(tokenizer.tokens), dtype=torch.bool, device=self.device)
347
+ )
348
+ for tok in tokenizer_encoded_proteins.reshape(-1)
349
+ ]).view(batch_size, seq_len, -1)
350
+
351
+ # Compute the output logits
+ mrna_logits = self.mrna_output_layer(x)
+
+ # Mask disallowed codons by filling them with a large negative value
+ mrna_logits = mrna_logits.masked_fill(~amino_acid_to_codon_mask, -6.0e4)
357
+
358
+ return mrna_logits, router_logits_list, entropy_loss
359
+
360
+ def get_embedding(self, tokenizer_encoded_proteins,esm2_encoded_proteins):
361
+
362
+ # Build the padding mask for the source sequence
+ tokenizer = Tokenizer() # tokenizer instance for ID lookups
364
+ src_padding_mask = (tokenizer_encoded_proteins == tokenizer.padding_idx)
365
+
366
+ # Project the ESM-2 embeddings, then run the encoder
+ x = self.dim_trans(esm2_encoded_proteins)
368
+
369
+ x, router_logits_list, entropy_loss = self.encoder(
370
+ x,
371
+ src_key_padding_mask=src_padding_mask
372
+ )
373
+ return x
374
+ def get_router_logits(self, tokenizer_encoded_proteins,esm2_encoded_proteins):
375
+
376
+ # Build the padding mask for the source sequence
+ tokenizer = Tokenizer() # tokenizer instance for ID lookups
378
+ src_padding_mask = (tokenizer_encoded_proteins == tokenizer.padding_idx)
379
+
380
+ # Project the ESM-2 embeddings, then run the encoder
+ x = self.dim_trans(esm2_encoded_proteins)
382
+
383
+ x, router_logits_list, entropy_loss = self.encoder(
384
+ x,
385
+ src_key_padding_mask=src_padding_mask
386
+ )
387
+ return router_logits_list
388
+
389
+ class ActorModel_encoder_nomoe(nn.Module):
+ """Encoder-based actor model for sequence generation, using ESM-2 embeddings and no MoE layers."""
391
+
392
+ def __init__(self, vocab_size: int, d_model: int, nhead: int,
393
+ num_encoder_layers: int, dim_feedforward: int, esm2_dim: int,dropout: float, device: torch.device):
394
+ super(ActorModel_encoder_nomoe, self).__init__()
395
+ self.device = device
396
+
397
+ # Load the biological mappings and precompute the masks
398
+ self.amino_acid_to_codon = BiologicalMappings.get_amino_acid_to_codon()
399
+ self.precomputed_masks = self._precompute_masks()
400
+
401
+ self.dim_trans=nn.Linear(esm2_dim, d_model)
402
+ # Build the encoder and the output head
403
+ self.encoder = Encoder_nomoe(vocab_size, d_model, nhead, num_encoder_layers,
404
+ dim_feedforward, dropout,if_embedding=False,if_pos_encoding=False)
405
+
406
+ # Output head built with nn.Sequential
407
+ self.output_layer = nn.Sequential(
408
+ nn.Linear(d_model, d_model//2),
409
+ nn.LayerNorm(d_model//2),
410
+ nn.ReLU(),
411
+ nn.Dropout(dropout),
412
+ nn.Linear(d_model//2, vocab_size)
413
+ )
414
+
415
+ def _precompute_masks(self) -> Dict[int, torch.Tensor]:
+ """Precompute, for each amino acid, a boolean mask over its synonymous codons."""
+ tokenizer = Tokenizer() # tokenizer instance for ID lookups
418
+ masks = {}
419
+
420
+ for amino_acid, codons in self.amino_acid_to_codon.items():
421
+ amino_acid_id = tokenizer.token_to_id.get(amino_acid, tokenizer.unk_idx)
422
+ mask = torch.zeros(len(tokenizer.tokens), dtype=torch.bool, device=self.device)
423
+
424
+ for codon in codons:
425
+ codon_id = tokenizer.token_to_id.get(codon, tokenizer.unk_idx)
426
+ if codon_id != tokenizer.unk_idx:
427
+ mask[codon_id] = True
428
+
429
+ masks[amino_acid_id] = mask
430
+
431
+ return masks
432
+
433
+ def forward(self, tokenizer_encoded_proteins,esm2_encoded_proteins):
+ """Forward pass.
+
+ Args:
+ tokenizer_encoded_proteins: Encoded protein sequences, shape (batch_size, seq_len).
+ esm2_encoded_proteins: ESM-2 embeddings of the same proteins; the last dimension is esm2_dim.
+
+ Returns:
+ logits: Output logits, with codons that do not encode the input amino acid masked out.
+ """
+ # Build the padding mask for the source sequence
+ tokenizer = Tokenizer() # tokenizer instance for ID lookups
446
+ src_padding_mask = (tokenizer_encoded_proteins == tokenizer.padding_idx)
447
+
448
+ x=self.dim_trans(esm2_encoded_proteins)
449
+
450
+ # Run the encoder
451
+ x= self.encoder(
452
+ x,
453
+ src_key_padding_mask=src_padding_mask
454
+ )
455
+
456
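+ # Build a codon mask for every batch item and sequence position
+ batch_size, seq_len = tokenizer_encoded_proteins.shape
+
+ # Look up the precomputed mask for each token ID; unmapped IDs (e.g. special tokens) get an all-False mask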
460
+ amino_acid_to_codon_mask = torch.stack([
461
+ self.precomputed_masks.get(
462
+ tok.item(),
463
+ torch.zeros(len(tokenizer.tokens), dtype=torch.bool, device=self.device)
464
+ )
465
+ for tok in tokenizer_encoded_proteins.reshape(-1)
466
+ ]).view(batch_size, seq_len, -1)
467
+
468
+ # Compute the output logits
+ logits = self.output_layer(x)
+
+ # Mask disallowed codons by filling them with a large negative value
+ logits = logits.masked_fill(~amino_acid_to_codon_mask, -6.0e4)
473
+
474
+ return logits
475
+
476
+ class RewardModel_encoder(nn.Module):
477
+ def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward,dropout,num_experts,top_k_experts,device):
478
+ super(RewardModel_encoder, self).__init__()
479
+ self.tokenizer=Tokenizer()
480
+ self.device=device
481
+
482
+ self.encoder = Encoder(vocab_size, d_model, nhead, num_encoder_layers,
483
+ dim_feedforward, dropout, num_experts, top_k_experts)
484
+ self.reward_output_layer = nn.Sequential(
485
+ nn.Linear(d_model, d_model//2),
486
+ nn.LayerNorm(d_model//2), # normalize the output of the linear layer
487
+ nn.ReLU(),
488
+ nn.Dropout(dropout),
489
+ nn.Linear(d_model//2, 1)
490
+ )
491
+
492
+
493
+ def forward(self, tokenizer_encoded_mrnas):
494
+
495
+ src_padding_mask = (tokenizer_encoded_mrnas==self.tokenizer.padding_idx)
496
+
497
+ x,router_logits_list,entropy_loss = self.encoder(tokenizer_encoded_mrnas, src_key_padding_mask=src_padding_mask)
498
+
499
+
500
+ reward=self.reward_output_layer(x)
501
+ reward = reward[:, 0, :].squeeze(-1) # take the first position; keep the batch dimension even when batch_size == 1
502
+
503
+ return reward,router_logits_list,entropy_loss
504
+
505
+
506
+
507
+ class LengthAwareDistributedSampler_human(DistributedSampler):
508
+ def __init__(self, dataset, lengths, data_num_rat=None,num_replicas=None, rank=None, shuffle=True):
509
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
510
+
511
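+ self.lengths = lengths # per-sample sequence lengths
+ self.weights = self.calculate_weights() # sampling weights derived from the lengths
+ # data_num_rat scales the number of samples drawn per epoch; None is treated as 1.0
+ self.data_num_rat = 1.0 if data_num_rat is None else data_num_rat
+ self.total_size = int(len(dataset) * self.data_num_rat)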
515
+
516
+ def calculate_weights(self):
517
+ # Piecewise weighting by sequence length
518
+ weights = np.ones(len(self.lengths))
519
+ weights[np.array(self.lengths) >= 1300] = 85.64*200
520
+ weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 5.02*200
521
+ weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 4.36*100
522
+ weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 3.63*100
523
+ weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 3.15
524
+ weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 2.20
525
+ weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 1.64
526
+ weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 1.36
527
+ weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
528
+ weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.75
529
+ weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.63
530
+ weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.60
531
+ weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.71
532
+ weights[np.array(self.lengths) < 100] = 3.68*100
533
+
534
+ return weights / np.sum(weights) # normalize the weights so they sum to 1
535
+
536
+ def __iter__(self):
537
+ # Draw indices by weighted sampling with replacement
+ indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)
+
+ # Truncate to a length that is divisible by num_replicas
+ total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
+ indices = indices[:total_size_local] # drop the surplus samples
+
+ # Assign an interleaved slice of the samples to each process
+ indices = indices[self.rank:total_size_local:self.num_replicas]
546
+
547
+ if self.shuffle:
548
+ np.random.shuffle(indices)
549
+
550
+ return iter(indices.tolist())
551
+
552
+ def set_epoch(self, epoch):
553
+ super().set_epoch(epoch)
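The sampler above draws from NumPy's global random state, so whether all ranks see the same draw depends on how that state is seeded elsewhere. One possible variant, shown here as a standard-library sketch and not as the author's method, seeds a local generator from a shared seed plus the epoch so that every rank produces the same global list and then takes a disjoint interleaved slice. The function name and its parameters are assumptions made for illustration.

```python
import random
from typing import List


def sample_for_rank(weights: List[float], total_size: int, rank: int,
                    num_replicas: int, seed: int, epoch: int) -> List[int]:
    """Weighted sampling with replacement, then an interleaved slice per rank."""
    # The same seed and epoch on every rank give the same global draw.
    rng = random.Random(seed + epoch)
    drawn = rng.choices(range(len(weights)), weights=weights, k=total_size)
    # Truncate so that every rank receives the same number of samples.
    usable = (len(drawn) // num_replicas) * num_replicas
    return drawn[rank:usable:num_replicas]


weights = [0.1, 0.1, 5.0, 5.0]  # illustrative per-sample weights
shards = [sample_for_rank(weights, total_size=11, rank=r, num_replicas=2,
                          seed=0, epoch=3) for r in range(2)]
print(shards)
```

With this design, changing the epoch changes the draw on all ranks at once, which is the role `set_epoch` normally plays for a distributed sampler.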
554
+
555
+ class LengthAwareDistributedSampler_Arabidopsis(DistributedSampler):
556
+ def __init__(self, dataset, lengths, data_num_rat=None,num_replicas=None, rank=None, shuffle=True):
557
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
558
+
559
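+ self.lengths = lengths # per-sample sequence lengths
+ self.weights = self.calculate_weights() # sampling weights derived from the lengths
+ # data_num_rat scales the number of samples drawn per epoch; None is treated as 1.0
+ self.data_num_rat = 1.0 if data_num_rat is None else data_num_rat
+ self.total_size = int(len(dataset) * self.data_num_rat)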
563
+
564
+ def calculate_weights(self):
565
+ # Piecewise weighting by sequence length
566
+ weights = np.ones(len(self.lengths))
567
+ weights[np.array(self.lengths) >= 1300] = 630.75*20
568
+ weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 17.05*20
569
+ weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 11.52*20
570
+ weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 7.17*10
571
+ weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 5.56*10
572
+ weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 3.54
573
+ weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 2.51
574
+ weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 1.62
575
+ weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
576
+ weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.68
577
+ weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.49
578
+ weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.49
579
+ weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.49
580
+ weights[np.array(self.lengths) < 100] = 1.23*10
581
+
582
+ return weights / np.sum(weights) # normalize the weights so they sum to 1
583
+
584
+ def __iter__(self):
585
+ # Draw indices by weighted sampling with replacement
+ indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)
+
+ # Truncate to a length that is divisible by num_replicas
+ total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
+ indices = indices[:total_size_local] # drop the surplus samples
+
+ # Assign an interleaved slice of the samples to each process
+ indices = indices[self.rank:total_size_local:self.num_replicas]
594
+
595
+ if self.shuffle:
596
+ np.random.shuffle(indices)
597
+
598
+ return iter(indices.tolist())
599
+
600
+ def set_epoch(self, epoch):
601
+ super().set_epoch(epoch)
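The piecewise weighting used by the samplers above maps each sequence length to a bucket and assigns one weight per bucket. A compact way to express that lookup, sketched here with the standard library, is a sorted list of bucket edges plus `bisect_right`. The edges and weights below are illustrative placeholders and not the values used in this file, and `length_weights` is a hypothetical helper.

```python
from bisect import bisect_right
from typing import List

# Lower edges of the length buckets and one weight per bucket (assumed values).
EDGES = [100, 200, 300, 400, 500]
BUCKET_WEIGHTS = [4.0, 0.7, 0.6, 0.6, 0.8, 1.0]  # len(EDGES) + 1 entries


def length_weights(lengths: List[int]) -> List[float]:
    """Return normalized sampling weights, one per sequence length."""
    # bisect_right gives 0 for n < 100, 1 for 100 <= n < 200, and so on.
    raw = [BUCKET_WEIGHTS[bisect_right(EDGES, n)] for n in lengths]
    total = sum(raw)
    return [w / total for w in raw]


weights = length_weights([50, 100, 199, 500])
print(weights)
```

Keeping the edges and weights in two lists makes the per-species differences a matter of data, so a single sampler class could take them as constructor arguments.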
602
+
603
+
604
+ class LengthAwareDistributedSampler_CR(DistributedSampler):
605
+ def __init__(self, dataset, lengths, data_num_rat=None,num_replicas=None, rank=None, shuffle=True):
606
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
607
+
608
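+ self.lengths = lengths # per-sample sequence lengths
+ self.weights = self.calculate_weights() # sampling weights derived from the lengths
+ # data_num_rat scales the number of samples drawn per epoch; None is treated as 1.0
+ self.data_num_rat = 1.0 if data_num_rat is None else data_num_rat
+ self.total_size = int(len(dataset) * self.data_num_rat)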
612
+
613
+ def calculate_weights(self):
614
+ # Piecewise weighting by sequence length
615
+ weights = np.ones(len(self.lengths))
616
+ weights[np.array(self.lengths) >= 1300] = 61.55*20
617
+ weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 3.66*20
618
+ weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 2.96*10
619
+ weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 2.54*10
620
+ weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 2.11*10
621
+ weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 1.79
622
+ weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 1.39
623
+ weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 1.11
624
+ weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
625
+ weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.82
626
+ weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.73
627
+ weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.67
628
+ weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.66
629
+ weights[np.array(self.lengths) < 100] = 1.18*10
630
+
631
+ return weights / np.sum(weights) # normalize the weights so they sum to 1
632
+
633
+ def __iter__(self):
634
+ # Draw indices by weighted sampling with replacement
+ indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)
+
+ # Truncate to a length that is divisible by num_replicas
+ total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
+ indices = indices[:total_size_local] # drop the surplus samples
+
+ # Assign an interleaved slice of the samples to each process
+ indices = indices[self.rank:total_size_local:self.num_replicas]
643
+
644
+ if self.shuffle:
645
+ np.random.shuffle(indices)
646
+
647
+ return iter(indices.tolist())
648
+
649
+ def set_epoch(self, epoch):
650
+ super().set_epoch(epoch)
651
+
652
+ class LengthAwareDistributedSampler_PC(DistributedSampler):
653
+ def __init__(self, dataset, lengths, data_num_rat=None,num_replicas=None, rank=None, shuffle=True):
654
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
655
+
656
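+ self.lengths = lengths # per-sample sequence lengths
+ self.weights = self.calculate_weights() # sampling weights derived from the lengths
+ # data_num_rat scales the number of samples drawn per epoch; None is treated as 1.0
+ self.data_num_rat = 1.0 if data_num_rat is None else data_num_rat
+ self.total_size = int(len(dataset) * self.data_num_rat)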
660
+
661
+ def calculate_weights(self):
662
+ # Piecewise weighting by sequence length
663
+ weights = np.ones(len(self.lengths))
664
+ weights[np.array(self.lengths) >= 1300] = 318.0*200
665
+ weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 13.98*200
666
+ weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 10.26*100
667
+ weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 7.62*100
668
+ weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 6.14*100
669
+ weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 3.80
670
+ weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 2.67
671
+ weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 1.88
672
+ weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
673
+ weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.88
674
+ weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.75
675
+ weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.76
676
+ weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.83
677
+ weights[np.array(self.lengths) < 100] = 1.87*100
678
+
679
+ return weights / np.sum(weights) # normalize the weights so they sum to 1
680
+
681
+ def __iter__(self):
682
+ # Draw indices by weighted sampling with replacement
+ indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)
+
+ # Truncate to a length that is divisible by num_replicas
+ total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
+ indices = indices[:total_size_local] # drop the surplus samples
+
+ # Assign an interleaved slice of the samples to each process
+ indices = indices[self.rank:total_size_local:self.num_replicas]
691
+
692
+ if self.shuffle:
693
+ np.random.shuffle(indices)
694
+
695
+ return iter(indices.tolist())
696
+
697
+ def set_epoch(self, epoch):
698
+ super().set_epoch(epoch)
699
+
700
+ class LengthAwareDistributedSampler_EscherichiaColi(DistributedSampler):
701
+ def __init__(self, dataset, lengths, data_num_rat=None,num_replicas=None, rank=None, shuffle=True):
702
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
703
+
704
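+ self.lengths = lengths # per-sample sequence lengths
+ self.weights = self.calculate_weights() # sampling weights derived from the lengths
+ # data_num_rat scales the number of samples drawn per epoch; None is treated as 1.0
+ self.data_num_rat = 1.0 if data_num_rat is None else data_num_rat
+ self.total_size = int(len(dataset) * self.data_num_rat)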
708
+
709
+ def calculate_weights(self):
710
+ # Piecewise weighting by sequence length
711
+ weights = np.ones(len(self.lengths))
712
+ weights[np.array(self.lengths) >= 1300] = 211.0*200
713
+ weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 26.38*200
714
+ weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 15.07*100
715
+ weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 11.72*100
716
+ weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 11.11*100
717
+ weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 4.06
718
+ weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 2.81
719
+ weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 2.07
720
+ weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
721
+ weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.46
722
+ weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.30
723
+ weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.25
724
+ weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.25
725
+ weights[np.array(self.lengths) < 100] = 0.47
726
+
727
+ return weights / np.sum(weights) # normalize the weights so they sum to 1
728
+
729
+ def __iter__(self):
730
+ # Draw indices by weighted sampling with replacement
+ indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)
+
+ # Truncate to a length that is divisible by num_replicas
+ total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
+ indices = indices[:total_size_local] # drop the surplus samples
+
+ # Assign an interleaved slice of the samples to each process
+ indices = indices[self.rank:total_size_local:self.num_replicas]
739
+
740
+ if self.shuffle:
741
+ np.random.shuffle(indices)
742
+
743
+ return iter(indices.tolist())
744
+
745
+ def set_epoch(self, epoch):
746
+ super().set_epoch(epoch)
747
+
748
+ class LengthAwareDistributedSampler_TK(DistributedSampler):
749
+ def __init__(self, dataset, lengths, data_num_rat=None,num_replicas=None, rank=None, shuffle=True):
750
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
751
+
752
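+ self.lengths = lengths # per-sample sequence lengths
+ self.weights = self.calculate_weights() # sampling weights derived from the lengths
+ # data_num_rat scales the number of samples drawn per epoch; None is treated as 1.0
+ self.data_num_rat = 1.0 if data_num_rat is None else data_num_rat
+ self.total_size = int(len(dataset) * self.data_num_rat)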
756
+
757
+ def calculate_weights(self):
758
+ # Piecewise weighting by sequence length
759
+ weights = np.ones(len(self.lengths))
760
+
761
+ weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 12.25*10
762
+ weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 8.17*10
763
+ weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 24.5*10
764
+ weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 8.17*10
765
+ weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 3.27
766
+ weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 2.33
767
+ weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 1.09
768
+ weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
769
+ weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.25
770
+ weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.17
771
+ weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.13
772
+ weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.10
773
+ weights[np.array(self.lengths) < 100] = 0.22
774
+
775
+ return weights / np.sum(weights) # normalize the weights so they sum to 1
776
+
777
+ def __iter__(self):
778
+ # Draw indices by weighted sampling with replacement
+ indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)
+
+ # Truncate to a length that is divisible by num_replicas
+ total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
+ indices = indices[:total_size_local] # drop the surplus samples
+
+ # Assign an interleaved slice of the samples to each process
+ indices = indices[self.rank:total_size_local:self.num_replicas]
787
+
788
+ if self.shuffle:
789
+ np.random.shuffle(indices)
790
+
791
+ return iter(indices.tolist())
792
+
793
+ def set_epoch(self, epoch):
794
+ super().set_epoch(epoch)
795
+
796
+
797
+
798
+ class LengthAwareDistributedSampler_human_circ(DistributedSampler):
799
+ def __init__(self, dataset, lengths, data_num_rat=None,num_replicas=None, rank=None, shuffle=True):
800
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
801
+
802
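+ self.lengths = lengths # per-sample sequence lengths
+ self.weights = self.calculate_weights() # sampling weights derived from the lengths
+ # data_num_rat scales the number of samples drawn per epoch; None is treated as 1.0
+ self.data_num_rat = 1.0 if data_num_rat is None else data_num_rat
+ self.total_size = int(len(dataset) * self.data_num_rat)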
806
+
807
+ def calculate_weights(self):
808
+ # Piecewise weighting by sequence length
809
+ weights = np.ones(len(self.lengths))
810
+ weights[np.array(self.lengths) >= 1300] = 89.62*20
811
+ weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 5.24*20
812
+ weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 4.58*10
813
+ weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 3.82*10
814
+ weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 3.30
815
+ weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 2.34
816
+ weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 1.74
817
+ weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 1.36
818
+ weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
819
+ weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.74
820
+ weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.57
821
+ weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.46
822
+ weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.38
823
+ weights[np.array(self.lengths) < 100] = 0.48
824
+
825
+ return weights / np.sum(weights) # normalize the weights so they sum to 1
826
+
827
+ def __iter__(self):
828
+ # Draw indices by weighted sampling with replacement
+ indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)
+
+ # Truncate to a length that is divisible by num_replicas
+ total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
+ indices = indices[:total_size_local] # drop the surplus samples
+
+ # Assign an interleaved slice of the samples to each process
+ indices = indices[self.rank:total_size_local:self.num_replicas]
837
+
838
+ if self.shuffle:
839
+ np.random.shuffle(indices)
840
+
841
+ return iter(indices.tolist())
842
+
843
+ def set_epoch(self, epoch):
844
+ super().set_epoch(epoch)