sglin committed
Commit 9c396a5 · 1 Parent(s): 1a6cb78
Add initial model weights, utils.py, transformer_encoder_MoE.py, and README
Browse files
- Arabidopsis.pt +3 -0
- CR.pt +3 -0
- EscherichiaColi.pt +3 -0
- PC.pt +3 -0
- README.md +153 -0
- TK.pt +3 -0
- homo_circ.pt +3 -0
- homo_mrna.pt +3 -0
- transformer_encoder_MoE.py +555 -0
- utils.py +844 -0
Arabidopsis.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:725088277142fa26aa1781806b107155363a903d0eba542bdcf312f5dbde48c8
size 534071434
CR.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b69bb49ff4a336cc7a4f47ffb438dba298d67b237f678403a43bee511c8c1928
size 534069804
EscherichiaColi.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6833a55fc82e708e8e5307dd63be0983504d768d32a1b1ec7a83b616bdf49671
size 534072130
PC.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fe6d83b0824ced0f00c43189a58a9eaecd4acf02226dffb057b27166f08d8f1a
size 534069804
README.md
ADDED
@@ -0,0 +1,153 @@
---
tags:
- generation
- protein-sequence
- rna-sequence
- pytorch
---

# Protein to RNA CDS Sequence Generation Model

This model is a custom PyTorch model designed to generate RNA CDS sequences from protein sequences. It utilizes a custom transformer-based architecture incorporating an ESM-2 encoder and a Mixture-of-Experts (MoE) layer.

## Model Architecture

The model `ActorModel_encoder_esm2` is defined in `utils.py`.

The key parameters used for instantiation are:

- `d_model`: Dimension of the model's internal representation (768).
- `nhead`: Number of attention heads (8).
- `num_encoder_layers`: Number of transformer encoder layers (8).
- `dim_feedforward`: Dimension of the feedforward network (`d_model * 2`).
- `esm2_dim`: Dimension of the ESM-2 embeddings (1280 for esm2_t33_650M_UR50D).
- `dropout`: Dropout rate (0.3).
- `num_experts`: Number of experts in the MoE layer (6).
- `top_k_experts`: Number of top experts to use (2).
- `device`: The device to run the model on.

## Files in this Repository

- `homo_mrna.pt`: The PyTorch state_dict of the trained model for Homo sapiens mRNA.
- `homo_circ.pt`: The PyTorch state_dict of the trained model for Homo sapiens circular RNA.
- `Arabidopsis.pt`: The PyTorch state_dict of the trained model for Arabidopsis thaliana mRNA.
- `CR.pt`: The PyTorch state_dict of the trained model for Chlamydomonas reinhardtii mRNA.
- `EscherichiaColi.pt`: The PyTorch state_dict of the trained model for Escherichia coli mRNA.
- `PC.pt`: The PyTorch state_dict of the trained model for Penicillium chrysogenum mRNA.
- `TK.pt`: The PyTorch state_dict of the trained model for Thermococcus kodakarensis KOD1 mRNA.
- `utils.py`: Contains the definition of the `ActorModel_encoder_esm2` class and the `Tokenizer` class.
- `transformer_encoder_MoE.py`: Contains the definition of the `Encoder` class.
- `README.md`: This file.

## How to Load the Model

Since this is a custom model, you need to download `utils.py`, `transformer_encoder_MoE.py`, and the `.pt` weights file, then instantiate the model class and load the state dictionary.

1. **Download Files:**
   You can download the files using the `huggingface_hub` library:
   ```python
   from huggingface_hub import hf_hub_download
   import os

   repo_id = "sglin/RNARL"
   local_dir = "./my_RNARL"

   # Download model weights and utils.py
   hf_hub_download(repo_id=repo_id, filename="homo_mrna.pt", local_dir=local_dir)
   hf_hub_download(repo_id=repo_id, filename="homo_circ.pt", local_dir=local_dir)
   hf_hub_download(repo_id=repo_id, filename="Arabidopsis.pt", local_dir=local_dir)
   hf_hub_download(repo_id=repo_id, filename="CR.pt", local_dir=local_dir)
   hf_hub_download(repo_id=repo_id, filename="EscherichiaColi.pt", local_dir=local_dir)
   hf_hub_download(repo_id=repo_id, filename="PC.pt", local_dir=local_dir)
   hf_hub_download(repo_id=repo_id, filename="TK.pt", local_dir=local_dir)
   hf_hub_download(repo_id=repo_id, filename="utils.py", local_dir=local_dir)
   hf_hub_download(repo_id=repo_id, filename="transformer_encoder_MoE.py", local_dir=local_dir)

   # Now utils.py, transformer_encoder_MoE.py, and the model weights are in ./my_RNARL
   ```

2. **Import Model Class:**

   ```python
   # Assuming you are in, or have added, ./my_RNARL on your path
   # Example: if in local_dir
   # import sys
   # sys.path.append("./my_RNARL")
   # from utils import Tokenizer, ActorModel_encoder_esm2

   # Or if you copied utils.py to your current working directory:
   from utils import Tokenizer, ActorModel_encoder_esm2
   ```

3. **Load ESM-2 (Dependency):**
   The model requires the ESM-2 encoder. You'll need to load it separately, typically from the Hugging Face Hub.
   ```python
   from transformers import AutoTokenizer, EsmModel
   import torch

   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

   esm2_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
   esm2_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
   esm2_model.eval()
   esm2_dim = esm2_model.config.hidden_size  # Get the actual dimension
   ```
   *Note:* The original training script used a local path (`./esm2_model_t33_650M_UR50D`). Users loading from the Hub will likely prefer the official Hugging Face repo unless the ESM-2 files are explicitly provided in this repository, which is usually unnecessary since they are already on the Hub.

4. **Instantiate Custom Model and Load Weights:**
   Instantiate `ActorModel_encoder_esm2` with the parameters used during training and load the state dictionary.
   ```python
   # Define the parameters used during training
   d_model = 768
   nhead = 8
   num_encoder_layers = 8
   dim_feedforward = d_model * 2  # or the exact value you used
   dropout = 0.3
   num_experts = 6
   top_k_experts = 2
   # vocab_size needs to match your Tokenizer
   tokenizer = Tokenizer()  # Instantiate your custom tokenizer
   vocab_size = len(tokenizer.tokens)  # Get vocab size from your tokenizer

   # Instantiate the model
   model = ActorModel_encoder_esm2(
       vocab_size=vocab_size,
       d_model=d_model,
       nhead=nhead,
       num_encoder_layers=num_encoder_layers,
       dim_feedforward=dim_feedforward,
       esm2_dim=esm2_dim,  # Use the esm2_model's dimension
       dropout=dropout,
       num_experts=num_experts,
       top_k_experts=top_k_experts,
       device=device
   )

   # Load the state dictionary
   model_weights_path = os.path.join(local_dir, "homo_mrna.pt")
   model.load_state_dict(torch.load(model_weights_path, map_location=device))
   model.to(device)
   model.eval()

   print("Model loaded successfully!")

   # Now you can use the 'model' object for inference
   # Remember you also need your Tokenizer and the ESM-2 tokenizer/model
   ```
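5. **Run Inference (Sketch):**
   The repository does not ship a generation script, so the following is only a minimal sketch of how the pieces fit together. It assumes the `forward(tokenizer_encoded_proteins, esm2_encoded_proteins)` signature defined in `utils.py`, simple greedy (argmax) decoding, and that the ESM-2 tokenizer and the custom `Tokenizer` both add two special tokens so the two encodings have the same length; the example protein sequence is made up.
   ```python
   protein = "MKTAYIAKQR"            # hypothetical example sequence
   max_length = len(protein) + 2    # [START] + residues + [END]

   # Encode with the custom tokenizer and pad to max_length
   ids = tokenizer.pad(tokenizer.encode_pro(protein, max_length), max_length)
   tokenizer_encoded = torch.tensor([ids], device=device)

   # Per-residue ESM-2 embeddings (cls/eos tokens keep the lengths aligned)
   esm_inputs = esm2_tokenizer(protein, return_tensors="pt").to(device)
   with torch.no_grad():
       esm2_encoded = esm2_model(**esm_inputs).last_hidden_state  # [1, seq_len, esm2_dim]
       logits, router_logits, entropy_loss = model(tokenizer_encoded, esm2_encoded)

   # Greedy codon choice at each residue position (skip [START]/[END])
   pred_ids = logits[0, 1:1 + len(protein)].argmax(dim=-1).tolist()
   print(tokenizer.decode(pred_ids))  # RNA CDS, 3 nt per residue
   ```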

## Dependencies

- `torch`
- `transformers`
- `huggingface_hub`
- `pandas`
- `numpy`
- `pytorch-crf` (provides the `torchcrf` module imported by `utils.py`)
- The specific ESM-2 model used (`facebook/esm2_t33_650M_UR50D`, or the one you used).

## License

[Specify your license here, e.g., MIT, Apache 2.0]

## Contact

[Optional: Your email or other contact info]
TK.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3737e4d94a4e266e0c89105a98ddcf8a6e7f444af5ef88810db669cd4ecf084d
size 534069804
homo_circ.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5249399ad29c6939703e79ef65532f23867ba5ce0ae6b3e4c2e4e6075ae18466
size 534071260
homo_mrna.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c3c825975f8deade11bf02d4783d0703d276f1b5ed7b70a406b01e8843255d66
size 534069218
transformer_encoder_MoE.py
ADDED
@@ -0,0 +1,555 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn.parallel import parallel_apply
from typing import Tuple, List, Optional, Union
import torch.utils.checkpoint as checkpoint


class MultiHeadAttention(nn.Module):
    """Efficient multi-head attention implementation."""

    def __init__(self, model_dim: int, n_heads: int):
        super().__init__()
        assert model_dim % n_heads == 0, "model_dim must be divisible by n_heads"

        self.model_dim = model_dim
        self.d_k = model_dim // n_heads
        self.n_heads = n_heads

        # A single linear layer computes the Q, K, V projections together to reduce overhead
        self.qkv_linear = nn.Linear(model_dim, 3 * model_dim, bias=False)
        self.out_linear = nn.Linear(model_dim, model_dim, bias=False)

        # Initialize parameters to improve training stability
        nn.init.xavier_uniform_(self.qkv_linear.weight)
        nn.init.xavier_uniform_(self.out_linear.weight)

        self.scale = 1.0 / math.sqrt(self.d_k)

    def forward(self,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size = q.size(0)

        # If the inputs are identical, use the more efficient self-attention path
        is_self_attention = q.data_ptr() == k.data_ptr() == v.data_ptr()
        if is_self_attention:
            # [batch, seq, 3*dim] -> 3 x [batch, seq, dim]
            qkv = self.qkv_linear(q).chunk(3, dim=-1)
            q, k, v = map(
                lambda x: x.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2),
                qkv
            )
        else:
            # Use separate slices of the fused projection for cross-attention
            q = self.qkv_linear(q)[:, :, :self.model_dim]
            k = self.qkv_linear(k)[:, :, self.model_dim:2*self.model_dim]
            v = self.qkv_linear(v)[:, :, 2*self.model_dim:]

            q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
            k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
            v = v.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Mask handling (fill value chosen for numerical stability)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -6.0e4)
        if key_padding_mask is not None:
            scores = scores.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), -6.0e4)

        attn_weights = F.softmax(scores, dim=-1)

        # Aggregate the values with the attention weights
        context = torch.matmul(attn_weights, v)

        # Reshape and apply the output projection
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.model_dim)
        output = self.out_linear(context)

        return output


class MoE(nn.Module):
    """Optimized Mixture-of-Experts module with parallel expert computation and efficient expert selection.

    Args:
        d_model (int): Model hidden dimension
        num_experts (int): Number of experts
        d_ff (int): Feed-forward dimension
        dropout (float): Dropout probability
        top_k (int): Number of experts selected per token
    """

    def __init__(self, d_model: int, num_experts: int, d_ff: int, dropout: float, top_k: int):
        super().__init__()
        # Parameter setup
        self.num_experts = num_experts
        self.top_k = min(top_k, num_experts)  # ensure top_k does not exceed the number of experts
        self.d_model = d_model

        # Gating network: maps the input to expert scores [d_model -> num_experts]
        self.gate = nn.Linear(d_model, num_experts, bias=False)

        # Expert networks: list of parallel expert modules
        self.experts = nn.ModuleList([
            nn.Sequential(  # each expert:
                nn.Linear(d_model, d_ff, bias=False),   # [d_model -> d_ff]
                nn.GELU(),                              # activation, no shape change
                nn.Dropout(dropout),                    # no shape change
                nn.Linear(d_ff, d_model, bias=False)    # [d_ff -> d_model]
            ) for _ in range(num_experts)
        ])

        # Parameter initialization
        for expert in self.experts:
            nn.init.kaiming_uniform_(expert[0].weight, a=math.sqrt(5))  # first linear layer
            nn.init.zeros_(expert[3].weight)  # zero-init the output layer, shape [d_ff, d_model]

        nn.init.zeros_(self.gate.weight)  # zero-init the gate, shape [d_model, num_experts]

    def orthogonal_loss(self) -> torch.Tensor:
        """Compute an orthogonality loss between experts to encourage expert diversity.

        Returns:
            torch.Tensor: scalar orthogonality loss
        """
        total_loss = 0.0
        num_pairs = 0

        # First- and last-layer weights of every expert
        # expert_weights_1 shape: [num_experts, d_ff, d_model]
        expert_weights_1 = torch.stack([expert[0].weight for expert in self.experts])
        # expert_weights_2 shape: [num_experts, d_model, d_ff]
        expert_weights_2 = torch.stack([expert[3].weight for expert in self.experts])

        # Orthogonality loss over all pairs of experts
        for i in range(self.num_experts):
            w1_i = expert_weights_1[i]  # [d_ff, d_model]
            w2_i = expert_weights_2[i]  # [d_model, d_ff]

            for j in range(i+1, self.num_experts):
                w1_j = expert_weights_1[j]  # [d_ff, d_model]
                w2_j = expert_weights_2[j]  # [d_model, d_ff]

                # Similarity of the first-layer weights
                w1_sim = torch.sum((w1_i @ w1_j.T)**2) / (w1_i.size(0) * w1_j.size(0))  # scalar
                # Similarity of the second-layer weights
                w2_sim = torch.sum((w2_i.T @ w2_j)**2) / (w2_i.size(1) * w2_j.size(1))  # scalar

                total_loss += (w1_sim + w2_sim) / 2  # mean similarity
                num_pairs += 1

        return total_loss / max(num_pairs, 1)  # average orthogonality loss

    def entropy_regularization_loss(self, routing_probs: torch.Tensor) -> torch.Tensor:
        """Compute an entropy regularization loss that encourages a more uniform routing distribution.

        Args:
            routing_probs (torch.Tensor): routing probabilities, shape [batch*seq_len, num_experts]

        Returns:
            torch.Tensor: scalar entropy loss
        """
        # Numerically stable log
        log_probs = torch.log(torch.clamp(routing_probs, min=1e-6))  # [batch*seq, num_experts]
        # Per-token entropy
        entropy = -torch.sum(routing_probs * log_probs, dim=-1)  # [batch*seq]
        return entropy.mean()  # scalar

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """MoE forward pass with efficient expert selection and combination.

        Args:
            hidden_states (torch.Tensor): input tensor, shape [batch_size, seq_len, d_model]

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - output tensor [batch_size, seq_len, d_model]
                - router logits [batch_size*seq_len, num_experts]
                - scalar entropy regularization loss
        """
        batch_size, seq_len, d_model = hidden_states.shape
        combined_batch_size = batch_size * seq_len
        # Flatten the input for parallel processing
        flat_hidden = hidden_states.reshape(combined_batch_size, d_model)  # [batch*seq, d_model]

        # Routing
        router_logits = self.gate(flat_hidden)            # [batch*seq, num_experts]
        routing_probs = F.softmax(router_logits, dim=-1)  # [batch*seq, num_experts]

        # Select the top-k experts
        routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)  # both [batch*seq, top_k]
        # Normalize the weights
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)  # [batch*seq, top_k]

        # Compute all expert outputs in parallel
        flat_expert_inputs = [flat_hidden] * self.num_experts  # num_experts tensors of [batch*seq, d_model]
        expert_outputs = parallel_apply(self.experts, flat_expert_inputs)  # num_experts tensors of [batch*seq, d_model]
        expert_outputs = torch.stack(expert_outputs, dim=1)  # [batch*seq, num_experts, d_model]

        # Build the expert weight matrix
        expert_weights_matrix = torch.zeros(
            combined_batch_size, self.num_experts, device=hidden_states.device
        )  # [batch*seq, num_experts]

        # Aggregate the weights efficiently with scatter_add
        for k in range(self.top_k):
            k_indices = selected_experts[:, k]              # [batch*seq]
            k_weights = routing_weights[:, k].unsqueeze(1)  # [batch*seq, 1]
            # Accumulate the weight at the selected expert position
            expert_weights_matrix.scatter_add_(
                1,
                k_indices.unsqueeze(1),  # [batch*seq, 1]
                k_weights                # [batch*seq, 1]
            )

        # Combine the expert outputs with a batched matrix product
        combined_output = torch.bmm(
            expert_weights_matrix.unsqueeze(1),  # [batch*seq, 1, num_experts]
            expert_outputs                       # [batch*seq, num_experts, d_model]
        ).squeeze(1)  # [batch*seq, d_model]

        # Restore the original shape
        output = combined_output.reshape(batch_size, seq_len, d_model)  # [batch_size, seq_len, d_model]

        # Entropy regularization loss
        entropy_loss = self.entropy_regularization_loss(routing_probs)

        return output, router_logits, entropy_loss


class EncoderLayer(nn.Module):
    """Optimized encoder layer with gradient checkpointing and pre-norm residual connections."""

    def __init__(self, model_dim: int, n_heads: int, ff_hidden_dim: int,
                 dropout: float, num_experts: int, top_k: int):
        super().__init__()
        self.model_dim = model_dim

        # Pre-LayerNorm (Pre-LN) structure for better training stability
        self.norm1 = nn.LayerNorm(model_dim)
        self.norm2 = nn.LayerNorm(model_dim)

        self.self_attn = MultiHeadAttention(model_dim, n_heads)
        self.moe = MoE(model_dim, num_experts, ff_hidden_dim, dropout, top_k)

        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Optional projection layer for residual connections with mismatched sizes
        self.use_projection = False
        if not self.use_projection:
            self.residual_scale = nn.Parameter(torch.ones(1))

    def _sa_block(self, x: torch.Tensor,
                  mask: Optional[torch.Tensor] = None,
                  key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Wrap the self-attention computation so it can be used with gradient checkpointing."""
        x = self.self_attn(x, x, x, mask=mask, key_padding_mask=key_padding_mask)
        return self.dropout1(x)

    def _moe_block(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Wrap the MoE computation so it can be used with gradient checkpointing."""
        return self.moe(x)

    def forward(self,
                x: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None,
                use_checkpoint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encoder layer forward pass.

        Args:
            x: input tensor [batch_size, seq_len, model_dim]
            src_mask: source sequence mask
            src_key_padding_mask: padding mask
            use_checkpoint: whether to use gradient checkpointing to save memory
        """
        # Pre-LN
        normalized_x = self.norm1(x)

        # Self-attention block (optionally checkpointed)
        if use_checkpoint and self.training:
            attn_output = checkpoint.checkpoint(
                self._sa_block, normalized_x, src_mask, src_key_padding_mask
            )
        else:
            attn_output = self._sa_block(normalized_x, src_mask, src_key_padding_mask)

        # First residual connection
        x = x + attn_output * self.residual_scale

        # Pre-LN
        normalized_x = self.norm2(x)

        # MoE block (optionally checkpointed)
        if use_checkpoint and self.training:
            moe_output, router_logits, entropy_loss = checkpoint.checkpoint(
                self._moe_block, normalized_x
            )
        else:
            moe_output, router_logits, entropy_loss = self._moe_block(normalized_x)

        # Second residual connection
        x = x + self.dropout2(moe_output) * self.residual_scale

        return x, router_logits, entropy_loss


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.linear2(self.dropout(self.relu(self.linear1(x))))


class EncoderLayer_nomoe(nn.Module):

    def __init__(self, model_dim: int, n_heads: int, ff_hidden_dim: int,
                 dropout: float):
        super().__init__()

        # Pre-LayerNorm (Pre-LN) structure for better training stability
        self.norm1 = nn.LayerNorm(model_dim)
        self.norm2 = nn.LayerNorm(model_dim)

        self.self_attn = MultiHeadAttention(model_dim, n_heads)
        self.feed_forward = PositionwiseFeedForward(model_dim, ff_hidden_dim, dropout)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self,
                x: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:

        # Pre-LN
        normalized_x = self.norm1(x)

        attn_output = self.self_attn(normalized_x, normalized_x, normalized_x, src_mask, src_key_padding_mask)

        # First residual connection
        x = x + self.dropout1(attn_output)

        # Pre-LN
        normalized_x = self.norm2(x)

        ff_output = self.feed_forward(normalized_x)

        # Second residual connection
        x = x + self.dropout2(ff_output)

        return x


class PositionalEncoding(nn.Module):
    """Efficient sinusoidal positional encoding."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute and cache the positional encodings once
        pe = torch.zeros(1, max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)

        # Register as a buffer rather than a parameter to save memory
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add the positional encoding to the input.

        Args:
            x: input tensor [batch_size, seq_len, model_dim]
        """
        pos_encoding = self.pe[:, :x.size(1)]
        x = x + pos_encoding
        return self.dropout(x)


class Encoder(nn.Module):
    """Optimized encoder architecture."""

    def __init__(self,
                 input_dim: int,
                 model_dim: int,
                 n_heads: int,
                 num_layers: int,
                 ff_hidden_dim: int,
                 dropout: float,
                 num_experts: int,
                 top_k: int,
                 if_embedding: bool = True,
                 if_pos_encoding: bool = True,
                 use_checkpointing: bool = False):
        super().__init__()

        self.model_dim = model_dim
        self.num_layers = num_layers
        self.if_embedding = if_embedding
        self.if_pos_encoding = if_pos_encoding
        self.use_checkpointing = use_checkpointing

        # Embedding layer
        if if_embedding:
            self.embedding = nn.Embedding(input_dim, model_dim)
            # Improved embedding initialization
            nn.init.normal_(self.embedding.weight, mean=0, std=model_dim**-0.5)

        # Positional encoding
        if if_pos_encoding:
            self.pos_encoding = PositionalEncoding(model_dim, dropout)

        # Encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer(
                model_dim, n_heads, ff_hidden_dim, dropout, num_experts, top_k
            ) for _ in range(num_layers)
        ])

        # Output normalization
        self.final_norm = nn.LayerNorm(model_dim)

    def forward(self,
                src: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, List, float]:
        """
        Encoder forward pass.

        Args:
            src: input sequence [batch_size, seq_len] or [batch_size, seq_len, model_dim]
            src_mask: source sequence mask
            src_key_padding_mask: padding mask

        Returns:
            tuple: (output tensor, list of router logits, entropy loss)
        """
        # Embedding
        if self.if_embedding:
            x = self.embedding(src) * math.sqrt(self.model_dim)
        else:
            x = src

        # Positional encoding
        if self.if_pos_encoding:
            x = self.pos_encoding(x)

        # Track the entropy loss and router logits
        total_entropy_loss = 0.0
        router_logits_list = []

        # Pass through the encoder layers
        for layer in self.layers:
            x, router_logits, entropy_loss = layer(
                x,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
                use_checkpoint=self.use_checkpointing
            )
            total_entropy_loss += entropy_loss

            # Keep only a CPU copy of the router logits to reduce memory usage
            if not self.training:  # store router logits only at inference time
                router_logits_list.append(router_logits.detach().cpu().tolist())

        # Final layer normalization
        x = self.final_norm(x)

        # Average entropy loss over the layers
        avg_entropy_loss = total_entropy_loss / self.num_layers

        return x, router_logits_list, avg_entropy_loss


class Encoder_nomoe(nn.Module):
    """Optimized encoder architecture (without MoE)."""

    def __init__(self,
                 input_dim: int,
                 model_dim: int,
                 n_heads: int,
                 num_layers: int,
                 ff_hidden_dim: int,
                 dropout: float,
                 if_embedding: bool = True,
                 if_pos_encoding: bool = True):
        super().__init__()

        self.model_dim = model_dim
        self.num_layers = num_layers
        self.if_embedding = if_embedding
        self.if_pos_encoding = if_pos_encoding

        # Embedding layer
        if if_embedding:
            self.embedding = nn.Embedding(input_dim, model_dim)
            # Improved embedding initialization
            nn.init.normal_(self.embedding.weight, mean=0, std=model_dim**-0.5)

        # Positional encoding
        if if_pos_encoding:
            self.pos_encoding = PositionalEncoding(model_dim, dropout)

        # Encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer_nomoe(
                model_dim, n_heads, ff_hidden_dim, dropout
            ) for _ in range(num_layers)
        ])

        # Output normalization
        self.final_norm = nn.LayerNorm(model_dim)

    def forward(self,
                src: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:

        # Embedding
        if self.if_embedding:
            x = self.embedding(src) * math.sqrt(self.model_dim)
        else:
            x = src

        # Positional encoding
        if self.if_pos_encoding:
            x = self.pos_encoding(x)

        # Pass through the encoder layers
        for layer in self.layers:
            x = layer(
                x,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask
            )

        # Final layer normalization
        x = self.final_norm(x)

        return x
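

# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the original module): a minimal sketch
# of how the Encoder is wired together, using made-up small dimensions. The
# MoE layer relies on torch.nn.parallel.parallel_apply, so the check is only
# run when a CUDA device is available.
# ---------------------------------------------------------------------------
if __name__ == "__main__" and torch.cuda.is_available():
    enc = Encoder(input_dim=100, model_dim=32, n_heads=4, num_layers=2,
                  ff_hidden_dim=64, dropout=0.1, num_experts=4, top_k=2).cuda()
    enc.eval()  # router logits are only collected outside training mode
    tokens = torch.randint(0, 100, (2, 7), device="cuda")          # [batch, seq]
    pad_mask = torch.zeros(2, 7, dtype=torch.bool, device="cuda")  # True marks padded positions
    out, router_logits, entropy_loss = enc(tokens, src_key_padding_mask=pad_mask)
    print(out.shape)  # expected: torch.Size([2, 7, 32])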
utils.py
ADDED
@@ -0,0 +1,844 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import math
|
6 |
+
from torch.utils.data.distributed import DistributedSampler
|
7 |
+
import torch.optim.lr_scheduler as lr_scheduler
|
8 |
+
from transformer_encoder_MoE import Encoder,Encoder_nomoe
|
9 |
+
from itertools import chain
|
10 |
+
from torch.nn.parallel import parallel_apply
|
11 |
+
from typing import List, Dict, Tuple, Optional, Union
|
12 |
+
from torchcrf import CRF
|
13 |
+
|
14 |
+
|
15 |
+
class Tokenizer:
|
16 |
+
"""处理序列编码和解码的分词器,支持蛋白质序列和mRNA序列。"""
|
17 |
+
|
18 |
+
def __init__(self):
|
19 |
+
# 定义特殊标记和生物序列标记
|
20 |
+
self.special_tokens = ['[START]', '[END]', '[PAD]', '[UNK]', '[SEG]']
|
21 |
+
self.amino_acids = ['A', 'R', 'S', 'I', 'L', 'G', 'V', 'T', 'P', 'N',
|
22 |
+
'D', 'C', 'Q', 'E', 'H', 'K', 'F', 'Y', 'M', 'W', '*']
|
23 |
+
self.protein_alphabet = ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I',
|
24 |
+
'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
|
25 |
+
# 生成所有可能的密码子组合
|
26 |
+
self.codons = [''.join([n1, n2, n3]) for n1 in 'UCAG' for n2 in 'UCAG' for n3 in 'UCAG']
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
# 合并所有标记并创建映射
|
31 |
+
self.tokens = self.special_tokens + self.amino_acids + self.codons
|
32 |
+
self.token_to_id = {token: idx for idx, token in enumerate(self.tokens)}
|
33 |
+
self.id_to_token = {idx: token for token, idx in self.token_to_id.items()}
|
34 |
+
|
35 |
+
# 缓存常用的特殊标记索引以提高性能
|
36 |
+
self.padding_idx = self.token_to_id['[PAD]']
|
37 |
+
self.start_idx = self.token_to_id['[START]']
|
38 |
+
self.end_idx = self.token_to_id['[END]']
|
39 |
+
self.unk_idx = self.token_to_id['[UNK]']
|
40 |
+
self.seg_idx = self.token_to_id['[SEG]']
|
41 |
+
|
42 |
+
def encode_pro(self, sequence: str, max_length: int) -> List[int]:
|
43 |
+
"""编码蛋白质序列。
|
44 |
+
|
45 |
+
Args:
|
46 |
+
sequence: 输入的蛋白质序列
|
47 |
+
max_length: 编码后序列的最大长度
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
编码后的ID列表
|
51 |
+
"""
|
52 |
+
# 添加开始标记,并为每个字符获取ID
|
53 |
+
ids = [self.start_idx] + [self.token_to_id.get(token, self.unk_idx) for token in sequence]
|
54 |
+
|
55 |
+
# 处理序列长度并添加结束标记
|
56 |
+
if len(ids) < max_length - 1:
|
57 |
+
ids.append(self.end_idx)
|
58 |
+
else:
|
59 |
+
ids = ids[:max_length-1] + [self.end_idx]
|
60 |
+
|
61 |
+
return ids
|
62 |
+
|
63 |
+
def encode_mrna(self, sequence: str, max_length: int) -> List[int]:
|
64 |
+
"""编码mRNA序列,每三个核苷酸作为一个密码子。
|
65 |
+
|
66 |
+
Args:
|
67 |
+
sequence: 输入的mRNA序列
|
68 |
+
max_length: 编码后序列的最大长度
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
编码后的ID列表
|
72 |
+
"""
|
73 |
+
ids = [self.start_idx]
|
74 |
+
|
75 |
+
# 每三个字符(一个密码子)作为一个单位处理
|
76 |
+
for i in range(0, len(sequence), 3):
|
77 |
+
codon = sequence[i:i+3]
|
78 |
+
if len(codon) == 3 and codon in self.token_to_id:
|
79 |
+
ids.append(self.token_to_id[codon])
|
80 |
+
else:
|
81 |
+
ids.append(self.unk_idx)
|
82 |
+
|
83 |
+
# 处理序列长度并添加结束标记
|
84 |
+
if len(ids) < max_length - 1:
|
85 |
+
ids.append(self.end_idx)
|
86 |
+
else:
|
87 |
+
ids = ids[:max_length-1] + [self.end_idx]
|
88 |
+
|
89 |
+
return ids
|
90 |
+
|
91 |
+
def decode(self, ids: List[int]) -> str:
|
92 |
+
"""将ID序列解码为文本。
|
93 |
+
|
94 |
+
Args:
|
95 |
+
ids: 编码后的ID列表
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
解码后的文本
|
99 |
+
"""
|
100 |
+
return ''.join([self.id_to_token.get(id, '[UNK]') for id in ids])
|
101 |
+
|
102 |
+
def pad(self, ids: List[int], max_length: int) -> List[int]:
|
103 |
+
"""对序列进行填充至指定长度。
|
104 |
+
|
105 |
+
Args:
|
106 |
+
ids: 编码后的ID列表
|
107 |
+
max_length: 目标长度
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
填充后的ID列表
|
111 |
+
"""
|
112 |
+
padding_length = max_length - len(ids)
|
113 |
+
if padding_length > 0:
|
114 |
+
return ids + [self.padding_idx] * padding_length
|
115 |
+
return ids
|
116 |
+
|
117 |
+
|
118 |
+
# 生成密码子表和相关映射
|
119 |
+
class BiologicalMappings:
|
120 |
+
"""生物序列编码的映射工具类。"""
|
121 |
+
|
122 |
+
@staticmethod
|
123 |
+
def get_codon_table() -> Dict[str, str]:
|
124 |
+
"""返回密码子到氨基酸的映射表。"""
|
125 |
+
return {
|
126 |
+
'GCU':'A', 'GCC':'A', 'GCA':'A', 'GCG':'A', 'CGU':'R', 'CGC':'R',
|
127 |
+
'CGA':'R', 'CGG':'R', 'AGA':'R', 'AGG':'R', 'UCU':'S', 'UCC':'S',
|
128 |
+
'UCA':'S', 'UCG':'S', 'AGU':'S', 'AGC':'S', 'AUU':'I', 'AUC':'I',
|
129 |
+
'AUA':'I', 'UUA':'L', 'UUG':'L', 'CUU':'L', 'CUC':'L', 'CUA':'L',
|
130 |
+
'CUG':'L', 'GGU':'G', 'GGC':'G', 'GGA':'G', 'GGG':'G', 'GUU':'V',
|
131 |
+
'GUC':'V', 'GUA':'V', 'GUG':'V', 'ACU':'T', 'ACC':'T', 'ACA':'T',
|
132 |
+
'ACG':'T', 'CCU':'P', 'CCC':'P', 'CCA':'P', 'CCG':'P', 'AAU':'N',
|
133 |
+
'AAC':'N', 'GAU':'D', 'GAC':'D', 'UGU':'C', 'UGC':'C', 'CAA':'Q',
|
134 |
+
'CAG':'Q', 'GAA':'E', 'GAG':'E', 'CAU':'H', 'CAC':'H', 'AAA':'K',
|
135 |
+
'AAG':'K', 'UUU':'F', 'UUC':'F', 'UAU':'Y', 'UAC':'Y', 'AUG':'M',
|
136 |
+
'UGG':'W','UAG':'*', 'UGA':'*', 'UAA':'*'}
|
137 |
+
|
138 |
+
@staticmethod
|
139 |
+
def get_amino_acid_to_codon() -> Dict[str, List[str]]:
|
140 |
+
"""返回氨基酸到密码子的映射表。"""
|
141 |
+
return {
|
142 |
+
'A':['GCU','GCC','GCA','GCG'], 'R':['CGU','CGC','CGA','CGG','AGA','AGG'],
|
143 |
+
'S':['UCU','UCC','UCA','UCG','AGU','AGC'],'I':['AUU','AUC','AUA'],
|
144 |
+
'L':['UUA','UUG','CUU','CUC','CUA','CUG'],'G':['GGU','GGC','GGA','GGG'],
|
145 |
+
'V':['GUU','GUC','GUA','GUG'],'T':['ACU','ACC','ACA','ACG'],
|
146 |
+
'P':['CCU','CCC','CCA','CCG'],'N':['AAU','AAC'],'D':['GAU','GAC'],
|
147 |
+
'C':['UGU','UGC'],'Q':['CAA','CAG'],'E':['GAA','GAG'],'H':['CAU','CAC'],
|
148 |
+
'K':['AAA','AAG'],'F':['UUU','UUC'],'Y':['UAU','UAC'],'M':['AUG'],'W':['UGG'],
|
149 |
+
'*':['UAG','UGA','UAA']
|
150 |
+
}
|
151 |
+
|
152 |
+
@staticmethod
|
153 |
+
def create_token_mapping(tokenizer: Tokenizer) -> torch.Tensor:
|
154 |
+
"""创建从密码子令牌到氨基酸令牌的映射张量。
|
155 |
+
|
156 |
+
Args:
|
157 |
+
tokenizer: 用于获取令牌到ID映射的分词器
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
映射张量,索引为密码子ID,值为对应的氨基酸ID
|
161 |
+
"""
|
162 |
+
codon_table = BiologicalMappings.get_codon_table()
|
163 |
+
token_codon_to_amino_acid = torch.full((len(tokenizer.tokens),),
|
164 |
+
tokenizer.unk_idx,
|
165 |
+
dtype=torch.long)
|
166 |
+
|
167 |
+
for codon, amino_acid in codon_table.items():
|
168 |
+
codon_id = tokenizer.token_to_id.get(codon, tokenizer.unk_idx)
|
169 |
+
amino_acid_id = tokenizer.token_to_id.get(amino_acid, tokenizer.unk_idx)
|
170 |
+
token_codon_to_amino_acid[codon_id] = amino_acid_id
|
171 |
+
|
172 |
+
return token_codon_to_amino_acid
|
173 |
+
|
174 |
+
|
175 |
+
class ActorModel_encoder_noesm2(nn.Module):
|
176 |
+
"""基于编码器的Actor模型,用于序列生成任务。"""
|
177 |
+
|
178 |
+
def __init__(self, vocab_size: int, d_model: int, nhead: int,
|
179 |
+
num_encoder_layers: int, dim_feedforward: int, dropout: float,
|
180 |
+
num_experts: int, top_k_experts: int, device: torch.device):
|
181 |
+
"""初始化模型。
|
182 |
+
|
183 |
+
Args:
|
184 |
+
vocab_size: 词汇表大小
|
185 |
+
d_model: 模型维度
|
186 |
+
nhead: 注意力头数
|
187 |
+
num_encoder_layers: 编码器层数
|
188 |
+
dim_feedforward: 前馈网络维度
|
189 |
+
dropout: Dropout率
|
190 |
+
num_experts: 专家数量
|
191 |
+
top_k_experts: 使用的顶部专家数量
|
192 |
+
device: 计算设备
|
193 |
+
"""
|
194 |
+
super(ActorModel_encoder_noesm2, self).__init__()
|
195 |
+
self.device = device
|
196 |
+
|
197 |
+
# 获取生物映射并预计算掩码
|
198 |
+
self.amino_acid_to_codon = BiologicalMappings.get_amino_acid_to_codon()
|
199 |
+
self.precomputed_masks = self._precompute_masks()
|
200 |
+
|
201 |
+
# 创建编码器和输出层
|
202 |
+
self.encoder = Encoder(vocab_size, d_model, nhead, num_encoder_layers,
|
203 |
+
dim_feedforward, dropout, num_experts, top_k_experts)
|
204 |
+
|
205 |
+
# 使用序列化的输出层以提高性能
|
206 |
+
self.mrna_output_layer = nn.Sequential(
|
207 |
+
nn.Linear(d_model, d_model//2),
|
208 |
+
nn.LayerNorm(d_model//2),
|
209 |
+
nn.ReLU(),
|
210 |
+
nn.Dropout(dropout),
|
211 |
+
nn.Linear(d_model//2, vocab_size)
|
212 |
+
)
|
213 |
+
|
214 |
+
|
215 |
+
|
216 |
+
def _precompute_masks(self) -> Dict[int, torch.Tensor]:
|
217 |
+
"""预计算每个氨基酸对应的密码子掩码,以提高性能。"""
|
218 |
+
tokenizer = Tokenizer() # 创建分词器实例
|
219 |
+
masks = {}
|
220 |
+
|
221 |
+
for amino_acid, codons in self.amino_acid_to_codon.items():
|
222 |
+
amino_acid_id = tokenizer.token_to_id.get(amino_acid, tokenizer.unk_idx)
|
223 |
+
mask = torch.zeros(len(tokenizer.tokens), dtype=torch.bool, device=self.device)
|
224 |
+
|
225 |
+
for codon in codons:
|
226 |
+
codon_id = tokenizer.token_to_id.get(codon, tokenizer.unk_idx)
|
227 |
+
if codon_id != tokenizer.unk_idx:
|
228 |
+
mask[codon_id] = True
|
229 |
+
|
230 |
+
masks[amino_acid_id] = mask
|
231 |
+
|
232 |
+
return masks
|
233 |
+
|
234 |
+
def forward(self, tokenizer_encoded_proteins: torch.Tensor) -> Tuple[torch.Tensor, list, torch.Tensor]:
|
235 |
+
"""模型前向传播。
|
236 |
+
|
237 |
+
Args:
|
238 |
+
tokenizer_encoded_proteins: 编码后的蛋白质序列,形状为(batch_size, seq_len)
|
239 |
+
|
240 |
+
Returns:
|
241 |
+
logits: 输出逻辑值,表示模型预测
|
242 |
+
router_logits_list: 路由器逻辑值列表
|
243 |
+
entropy_loss: 熵损失
|
244 |
+
"""
|
245 |
+
# 创建源序列的填充掩码
|
246 |
+
tokenizer = Tokenizer() # 创建分词器实例
|
247 |
+
src_padding_mask = (tokenizer_encoded_proteins == tokenizer.padding_idx)
|
248 |
+
|
249 |
+
# 通过编码器处理
|
250 |
+
x, router_logits_list, entropy_loss = self.encoder(
|
251 |
+
tokenizer_encoded_proteins,
|
252 |
+
src_key_padding_mask=src_padding_mask
|
253 |
+
)
|
254 |
+
|
255 |
+
# 为批次中的每个项目和序列位置生成掩码
|
256 |
+
batch_size, seq_len = tokenizer_encoded_proteins.shape
|
257 |
+
|
258 |
+
# 使用索引查询预计算的掩码,通过广播优化性能
|
259 |
+
amino_acid_to_codon_mask = torch.stack([
|
260 |
+
self.precomputed_masks.get(
|
261 |
+
tok.item(),
|
262 |
+
torch.zeros(len(tokenizer.tokens), dtype=torch.bool, device=self.device)
|
263 |
+
)
|
264 |
+
for tok in tokenizer_encoded_proteins.reshape(-1)
|
265 |
+
]).view(batch_size, seq_len, -1)
|
266 |
+
|
267 |
+
# 计算输出逻辑值并应用掩码
|
268 |
+
mrna_logits = self.mrna_output_layer(x)
|
269 |
+
|
270 |
+
|
271 |
+
# 使用masking而不是scatter来提高性能
|
272 |
+
mrna_logits = mrna_logits.masked_fill(~amino_acid_to_codon_mask, -6.0e4)
|
273 |
+
|
274 |
+
return mrna_logits, router_logits_list, entropy_loss
|
275 |
+
|
276 |
+
class ActorModel_encoder_esm2(nn.Module):
|
277 |
+
"""基于编码器的Actor模型,用于序列生成任务。"""
|
278 |
+
|
279 |
+
def __init__(self, vocab_size: int, d_model: int, nhead: int,
|
280 |
+
num_encoder_layers: int, dim_feedforward: int, esm2_dim: int,dropout: float,
|
281 |
+
num_experts: int, top_k_experts: int, device: torch.device):
|
282 |
+
|
283 |
+
super(ActorModel_encoder_esm2, self).__init__()
|
284 |
+
self.device = device
|
285 |
+
|
286 |
+
# 获取生物映射并预计算掩码
|
287 |
+
self.amino_acid_to_codon = BiologicalMappings.get_amino_acid_to_codon()
|
288 |
+
self.precomputed_masks = self._precompute_masks()
|
289 |
+
|
290 |
+
self.dim_trans=nn.Linear(esm2_dim, d_model)
|
291 |
+
# 创建编码器和输出层
|
292 |
+
self.encoder = Encoder(vocab_size, d_model, nhead, num_encoder_layers,
|
293 |
+
dim_feedforward, dropout, num_experts, top_k_experts,if_embedding=False,if_pos_encoding=False)
|
294 |
+
|
295 |
+
# 使用序列化的输出层以提高性能
|
296 |
+
self.mrna_output_layer = nn.Sequential(
|
297 |
+
nn.Linear(d_model, d_model//2),
|
298 |
+
nn.LayerNorm(d_model//2),
|
299 |
+
nn.ReLU(),
|
300 |
+
nn.Dropout(dropout),
|
301 |
+
nn.Linear(d_model//2, vocab_size)
|
302 |
+
)
|
303 |
+
|
304 |
+
|
305 |
+
|
306 |
+
|
307 |
+
def _precompute_masks(self) -> Dict[int, torch.Tensor]:
|
308 |
+
"""预计算每个氨基酸对应的密码子掩码,以提高性能。"""
|
309 |
+
tokenizer = Tokenizer() # 创建分词器实例
|
310 |
+
masks = {}
|
311 |
+
|
312 |
+
for amino_acid, codons in self.amino_acid_to_codon.items():
|
313 |
+
amino_acid_id = tokenizer.token_to_id.get(amino_acid, tokenizer.unk_idx)
|
314 |
+
mask = torch.zeros(len(tokenizer.tokens), dtype=torch.bool, device=self.device)
|
315 |
+
|
316 |
+
for codon in codons:
|
317 |
+
codon_id = tokenizer.token_to_id.get(codon, tokenizer.unk_idx)
|
318 |
+
if codon_id != tokenizer.unk_idx:
|
319 |
+
mask[codon_id] = True
|
320 |
+
|
321 |
+
masks[amino_acid_id] = mask
|
322 |
+
|
323 |
+
return masks
|
324 |
+
|
325 |
+
def forward(self, tokenizer_encoded_proteins,esm2_encoded_proteins) -> Tuple[torch.Tensor, list, torch.Tensor]:
|
326 |
+
|
327 |
+
# 创建源序列的填充掩码
|
328 |
+
tokenizer = Tokenizer() # 创建分词器实例
|
329 |
+
src_padding_mask = (tokenizer_encoded_proteins == tokenizer.padding_idx)
|
330 |
+
|
331 |
+
# 通过编码器处理
|
332 |
+
x=self.dim_trans(esm2_encoded_proteins)
|
333 |
+
|
334 |
+
x, router_logits_list, entropy_loss = self.encoder(
|
335 |
+
x,
|
336 |
+
src_key_padding_mask=src_padding_mask
|
337 |
+
)
|
338 |
+
|
339 |
+
# 为批次中的每个项目和序列位置生成掩码
|
340 |
+
batch_size, seq_len = tokenizer_encoded_proteins.shape
|
341 |
+
|
342 |
+
# 使用索引查询预计算的掩码,通过广播优化性能
|
343 |
+
amino_acid_to_codon_mask = torch.stack([
|
344 |
+
self.precomputed_masks.get(
|
345 |
+
tok.item(),
|
346 |
+
torch.zeros(len(tokenizer.tokens), dtype=torch.bool, device=self.device)
|
347 |
+
)
|
348 |
+
for tok in tokenizer_encoded_proteins.reshape(-1)
|
349 |
+
]).view(batch_size, seq_len, -1)
|
350 |
+
|
351 |
+
# 计算输出逻辑值并应用掩码
|
352 |
+
mrna_logits = self.mrna_output_layer(x)
|
353 |
+
|
354 |
+
|
355 |
+
# 使用masking而不是scatter来提高性能
|
356 |
+
mrna_logits = mrna_logits.masked_fill(~amino_acid_to_codon_mask, -6.0e4)
|
357 |
+
|
358 |
+
return mrna_logits, router_logits_list, entropy_loss
|
359 |
+
|
360 |
+
def get_embedding(self, tokenizer_encoded_proteins,esm2_encoded_proteins):
|
361 |
+
|
362 |
+
# 创建源序列的填充掩码
|
363 |
+
tokenizer = Tokenizer() # 创建分词器实例
|
364 |
+
src_padding_mask = (tokenizer_encoded_proteins == tokenizer.padding_idx)
|
365 |
+
|
366 |
+
# 通过编码器处理
|
367 |
+
x=self.dim_trans(esm2_encoded_proteins)
|
368 |
+
|
369 |
+
x, router_logits_list, entropy_loss = self.encoder(
|
370 |
+
x,
|
371 |
+
src_key_padding_mask=src_padding_mask
|
372 |
+
)
|
373 |
+
return x
|
374 |
+
def get_router_logits(self, tokenizer_encoded_proteins,esm2_encoded_proteins):
|
375 |
+
|
376 |
+
# 创建源序列的填充掩码
|
377 |
+
tokenizer = Tokenizer() # 创建分词器实例
|
378 |
+
src_padding_mask = (tokenizer_encoded_proteins == tokenizer.padding_idx)
|
379 |
+
|
380 |
+
# 通过编码器处理
|
381 |
+
x=self.dim_trans(esm2_encoded_proteins)
|
382 |
+
|
383 |
+
x, router_logits_list, entropy_loss = self.encoder(
|
384 |
+
x,
|
385 |
+
src_key_padding_mask=src_padding_mask
|
386 |
+
)
|
387 |
+
return router_logits_list
|
388 |
+
|
389 |
+
class ActorModel_encoder_nomoe(nn.Module):
|
390 |
+
"""基于编码器的Actor模型,用于序列生成任务。"""
|
391 |
+
|
392 |
+
def __init__(self, vocab_size: int, d_model: int, nhead: int,
|
393 |
+
num_encoder_layers: int, dim_feedforward: int, esm2_dim: int,dropout: float, device: torch.device):
|
394 |
+
super(ActorModel_encoder_nomoe, self).__init__()
|
395 |
+
self.device = device
|
396 |
+
|
397 |
+
# 获取生物映射并预计算掩码
|
398 |
+
self.amino_acid_to_codon = BiologicalMappings.get_amino_acid_to_codon()
|
399 |
+
self.precomputed_masks = self._precompute_masks()
|
400 |
+
|
401 |
+
self.dim_trans=nn.Linear(esm2_dim, d_model)
|
402 |
+
# 创建编码器和输出层
|
403 |
+
self.encoder = Encoder_nomoe(vocab_size, d_model, nhead, num_encoder_layers,
|
404 |
+
dim_feedforward, dropout,if_embedding=False,if_pos_encoding=False)
|
405 |
+
|
406 |
+
# 使用序列化的输出层以提高性能
|
407 |
+
self.output_layer = nn.Sequential(
|
408 |
+
nn.Linear(d_model, d_model//2),
|
409 |
+
nn.LayerNorm(d_model//2),
|
410 |
+
nn.ReLU(),
|
411 |
+
nn.Dropout(dropout),
|
412 |
+
nn.Linear(d_model//2, vocab_size)
|
413 |
+
)
|
414 |
+
|
415 |
+
def _precompute_masks(self) -> Dict[int, torch.Tensor]:
|
416 |
+
"""预计算每个氨基酸对应的密码子掩码,以提高性能。"""
|
417 |
+
tokenizer = Tokenizer() # 创建分词器实例
|
418 |
+
masks = {}
|
419 |
+
|
420 |
+
for amino_acid, codons in self.amino_acid_to_codon.items():
|
421 |
+
amino_acid_id = tokenizer.token_to_id.get(amino_acid, tokenizer.unk_idx)
|
422 |
+
mask = torch.zeros(len(tokenizer.tokens), dtype=torch.bool, device=self.device)
|
423 |
+
|
424 |
+
for codon in codons:
|
425 |
+
codon_id = tokenizer.token_to_id.get(codon, tokenizer.unk_idx)
|
426 |
+
if codon_id != tokenizer.unk_idx:
|
427 |
+
mask[codon_id] = True
|
428 |
+
|
429 |
+
masks[amino_acid_id] = mask
|
430 |
+
|
431 |
+
return masks
|
432 |
+
|
433 |
+
def forward(self, tokenizer_encoded_proteins,esm2_encoded_proteins):
|
434 |
+
"""模型前向传播。
|
435 |
+
|
436 |
+
Args:
|
437 |
+
tokenizer_encoded_proteins: 编码后的蛋白质序列,形状为(batch_size, seq_len)
|
438 |
+
|
439 |
+
Returns:
|
440 |
+
logits: 输出逻辑值,表示模型预测
|
441 |
+
router_logits_list: 路由器逻辑值列表
|
442 |
+
entropy_loss: 熵损失
|
443 |
+
"""
|
444 |
+
# 创建源序列的填充掩码
|
445 |
+
tokenizer = Tokenizer() # 创建分词器实例
|
446 |
+
src_padding_mask = (tokenizer_encoded_proteins == tokenizer.padding_idx)
|
447 |
+
|
448 |
+
x=self.dim_trans(esm2_encoded_proteins)
|
449 |
+
|
450 |
+
# 通过编码器处理
|
451 |
+
x= self.encoder(
|
452 |
+
x,
|
453 |
+
src_key_padding_mask=src_padding_mask
|
454 |
+
)
|
455 |
+
|
456 |
+
# 为批次中的每个项目和序列位置生成掩码
|
457 |
+
batch_size, seq_len = tokenizer_encoded_proteins.shape
|
458 |
+
|
459 |
+
# 使用索引查询预计算的掩码,通过广播优化性能
|
460 |
+
amino_acid_to_codon_mask = torch.stack([
|
461 |
+
self.precomputed_masks.get(
|
462 |
+
tok.item(),
|
463 |
+
torch.zeros(len(tokenizer.tokens), dtype=torch.bool, device=self.device)
|
464 |
+
)
|
465 |
+
for tok in tokenizer_encoded_proteins.reshape(-1)
|
466 |
+
]).view(batch_size, seq_len, -1)
|
467 |
+
|
468 |
+
# 计算输出逻辑值并应用掩码
|
469 |
+
logits = self.output_layer(x)
|
470 |
+
|
471 |
+
# 使用masking而不是scatter来提高性能
|
472 |
+
logits = logits.masked_fill(~amino_acid_to_codon_mask, -6.0e4)
|
473 |
+
|
474 |
+
return logits
|
475 |
+
|
476 |
+
class RewardModel_encoder(nn.Module):
|
477 |
+
def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward,dropout,num_experts,top_k_experts,device):
|
478 |
+
super(RewardModel_encoder, self).__init__()
|
479 |
+
self.tokenizer=Tokenizer()
|
480 |
+
self.device=device
|
481 |
+
|
482 |
+
self.encoder = Encoder(vocab_size, d_model, nhead, num_encoder_layers,
|
483 |
+
dim_feedforward, dropout, num_experts, top_k_experts)
|
484 |
+
self.reward_output_layer = nn.Sequential(
|
485 |
+
nn.Linear(d_model, d_model//2),
|
486 |
+
nn.LayerNorm(d_model//2), # 对线性层的输出进行归一化
|
487 |
+
nn.ReLU(),
|
488 |
+
nn.Dropout(dropout),
|
489 |
+
nn.Linear(d_model//2, 1)
|
490 |
+
)
|
491 |
+
|
492 |
+
|
493 |
+
def forward(self, tokenizer_encoded_mrnas):
|
494 |
+
|
495 |
+
src_padding_mask = (tokenizer_encoded_mrnas==self.tokenizer.padding_idx)
|
496 |
+
|
497 |
+
x,router_logits_list,entropy_loss = self.encoder(tokenizer_encoded_mrnas, src_key_padding_mask=src_padding_mask)
|
498 |
+
|
499 |
+
|
500 |
+
reward=self.reward_output_layer(x)
|
501 |
+
reward=reward[:,0,:].squeeze()
|
502 |
+
|
503 |
+
return reward,router_logits_list,entropy_loss
|
504 |
+
|
505 |
+
|
506 |
+
|
507 |
+
class LengthAwareDistributedSampler_human(DistributedSampler):
|
508 |
+
def __init__(self, dataset, lengths, data_num_rat=None,num_replicas=None, rank=None, shuffle=True):
|
509 |
+
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
|
510 |
+
|
511 |
+
self.lengths = lengths # 每个样本的长度列表
|
512 |
+
self.weights = self.calculate_weights() # 根据长度初始化权重
|
513 |
+
self.data_num_rat=data_num_rat
|
514 |
+
self.total_size = int(len(dataset) * data_num_rat)
|
515 |
+
|
516 |
+
def calculate_weights(self):
|
517 |
+
# 分段式加权策略
|
518 |
+
weights = np.ones(len(self.lengths))
|
519 |
+
weights[np.array(self.lengths) >= 1300] = 85.64*200
|
520 |
+
weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 5.02*200
|
521 |
+
weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 4.36*100
|
522 |
+
weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 3.63*100
|
523 |
+
weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 3.15
|
524 |
+
weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 2.20
|
525 |
+
weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 1.64
|
526 |
+
weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 1.36
|
527 |
+
weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
|
528 |
+
weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.75
|
529 |
+
weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.63
|
530 |
+
weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.60
|
531 |
+
weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.71
|
532 |
+
weights[np.array(self.lengths) < 100] = 3.68*100
|
533 |
+
|
534 |
+
return weights / np.sum(weights) # 将权重归一化
|
535 |
+
|
536 |
+
def __iter__(self):
|
537 |
+
# 根据加权采样进行索引选择
|
538 |
+
indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)
|
539 |
+
|
540 |
+
# 边界处理:截断到可以整除 num_replicas 的长度
|
541 |
+
total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
|
542 |
+
indices = indices[:total_size_local] # 截断多余的样本
|
543 |
+
|
544 |
+
# 将样本分配给不同进程
|
545 |
+
indices = indices[self.rank:total_size_local:self.num_replicas]
|
546 |
+
|
547 |
+
if self.shuffle:
|
548 |
+
np.random.shuffle(indices)
|
549 |
+
|
550 |
+
return iter(indices.tolist())
|
551 |
+
|
552 |
+
def set_epoch(self, epoch):
|
553 |
+
super().set_epoch(epoch)
|
554 |
+

class LengthAwareDistributedSampler_Arabidopsis(DistributedSampler):
    def __init__(self, dataset, lengths, data_num_rat=None, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        self.lengths = lengths  # list of per-sample lengths
        self.weights = self.calculate_weights()  # initialize the sampling weights from the lengths
        self.data_num_rat = data_num_rat
        self.total_size = int(len(dataset) * data_num_rat)

    def calculate_weights(self):
        # piecewise weighting strategy
        weights = np.ones(len(self.lengths))
        weights[np.array(self.lengths) >= 1300] = 630.75 * 20
        weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 17.05 * 20
        weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 11.52 * 20
        weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 7.17 * 10
        weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 5.56 * 10
        weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 3.54
        weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 2.51
        weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 1.62
        weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
        weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.68
        weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.49
        weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.49
        weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.49
        weights[np.array(self.lengths) < 100] = 1.23 * 10

        return weights / np.sum(weights)  # normalize the weights

    def __iter__(self):
        # select indices by weighted sampling
        indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)

        # boundary handling: truncate to a length divisible by num_replicas
        total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
        indices = indices[:total_size_local]  # drop the leftover samples

        # assign the samples to the different processes
        indices = indices[self.rank:total_size_local:self.num_replicas]

        if self.shuffle:
            np.random.shuffle(indices)

        return iter(indices.tolist())

    def set_epoch(self, epoch):
        super().set_epoch(epoch)


class LengthAwareDistributedSampler_CR(DistributedSampler):
    def __init__(self, dataset, lengths, data_num_rat=None, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        self.lengths = lengths  # list of per-sample lengths
        self.weights = self.calculate_weights()  # initialize the sampling weights from the lengths
        self.data_num_rat = data_num_rat
        self.total_size = int(len(dataset) * data_num_rat)

    def calculate_weights(self):
        # piecewise weighting strategy
        weights = np.ones(len(self.lengths))
        weights[np.array(self.lengths) >= 1300] = 61.55 * 20
        weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 3.66 * 20
        weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 2.96 * 10
        weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 2.54 * 10
        weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 2.11 * 10
        weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 1.79
        weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 1.39
        weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 1.11
        weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
        weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.82
        weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.73
        weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.67
        weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.66
        weights[np.array(self.lengths) < 100] = 1.18 * 10

        return weights / np.sum(weights)  # normalize the weights

    def __iter__(self):
        # select indices by weighted sampling
        indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)

        # boundary handling: truncate to a length divisible by num_replicas
        total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
        indices = indices[:total_size_local]  # drop the leftover samples

        # assign the samples to the different processes
        indices = indices[self.rank:total_size_local:self.num_replicas]

        if self.shuffle:
            np.random.shuffle(indices)

        return iter(indices.tolist())

    def set_epoch(self, epoch):
        super().set_epoch(epoch)


class LengthAwareDistributedSampler_PC(DistributedSampler):
    def __init__(self, dataset, lengths, data_num_rat=None, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        self.lengths = lengths  # list of per-sample lengths
        self.weights = self.calculate_weights()  # initialize the sampling weights from the lengths
        self.data_num_rat = data_num_rat
        self.total_size = int(len(dataset) * data_num_rat)

    def calculate_weights(self):
        # piecewise weighting strategy
        weights = np.ones(len(self.lengths))
        weights[np.array(self.lengths) >= 1300] = 318.0 * 200
        weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 13.98 * 200
        weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 10.26 * 100
        weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 7.62 * 100
        weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 6.14 * 100
        weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 3.80
        weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 2.67
        weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 1.88
        weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
        weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.88
        weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.75
        weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.76
        weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.83
        weights[np.array(self.lengths) < 100] = 1.87 * 100

        return weights / np.sum(weights)  # normalize the weights

    def __iter__(self):
        # select indices by weighted sampling
        indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)

        # boundary handling: truncate to a length divisible by num_replicas
        total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
        indices = indices[:total_size_local]  # drop the leftover samples

        # assign the samples to the different processes
        indices = indices[self.rank:total_size_local:self.num_replicas]

        if self.shuffle:
            np.random.shuffle(indices)

        return iter(indices.tolist())

    def set_epoch(self, epoch):
        super().set_epoch(epoch)


class LengthAwareDistributedSampler_EscherichiaColi(DistributedSampler):
    def __init__(self, dataset, lengths, data_num_rat=None, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        self.lengths = lengths  # list of per-sample lengths
        self.weights = self.calculate_weights()  # initialize the sampling weights from the lengths
        self.data_num_rat = data_num_rat
        self.total_size = int(len(dataset) * data_num_rat)

    def calculate_weights(self):
        # piecewise weighting strategy
        weights = np.ones(len(self.lengths))
        weights[np.array(self.lengths) >= 1300] = 211.0 * 200
        weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 26.38 * 200
        weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 15.07 * 100
        weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 11.72 * 100
        weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 11.11 * 100
        weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 4.06
        weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 2.81
        weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 2.07
        weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
        weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.46
        weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.30
        weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.25
        weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.25
        weights[np.array(self.lengths) < 100] = 0.47

        return weights / np.sum(weights)  # normalize the weights

    def __iter__(self):
        # select indices by weighted sampling
        indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)

        # boundary handling: truncate to a length divisible by num_replicas
        total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
        indices = indices[:total_size_local]  # drop the leftover samples

        # assign the samples to the different processes
        indices = indices[self.rank:total_size_local:self.num_replicas]

        if self.shuffle:
            np.random.shuffle(indices)

        return iter(indices.tolist())

    def set_epoch(self, epoch):
        super().set_epoch(epoch)

class LengthAwareDistributedSampler_TK(DistributedSampler):
    def __init__(self, dataset, lengths, data_num_rat=None, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        self.lengths = lengths  # list of per-sample lengths
        self.weights = self.calculate_weights()  # initialize the sampling weights from the lengths
        self.data_num_rat = data_num_rat
        self.total_size = int(len(dataset) * data_num_rat)

    def calculate_weights(self):
        # piecewise weighting strategy
        weights = np.ones(len(self.lengths))
        # note: unlike the other samplers, there is no bucket for lengths >= 1300,
        # so those samples keep the default weight of 1
        weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 12.25 * 10
        weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 8.17 * 10
        weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 24.5 * 10
        weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 8.17 * 10
        weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 3.27
        weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 2.33
        weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 1.09
        weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
        weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.25
        weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.17
        weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.13
        weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.10
        weights[np.array(self.lengths) < 100] = 0.22

        return weights / np.sum(weights)  # normalize the weights

    def __iter__(self):
        # select indices by weighted sampling
        indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)

        # boundary handling: truncate to a length divisible by num_replicas
        total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
        indices = indices[:total_size_local]  # drop the leftover samples

        # assign the samples to the different processes
        indices = indices[self.rank:total_size_local:self.num_replicas]

        if self.shuffle:
            np.random.shuffle(indices)

        return iter(indices.tolist())

    def set_epoch(self, epoch):
        super().set_epoch(epoch)

class LengthAwareDistributedSampler_human_circ(DistributedSampler):
    def __init__(self, dataset, lengths, data_num_rat=None, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        self.lengths = lengths  # list of per-sample lengths
        self.weights = self.calculate_weights()  # initialize the sampling weights from the lengths
        self.data_num_rat = data_num_rat
        self.total_size = int(len(dataset) * data_num_rat)

    def calculate_weights(self):
        # piecewise weighting strategy
        weights = np.ones(len(self.lengths))
        weights[np.array(self.lengths) >= 1300] = 89.62 * 20
        weights[(np.array(self.lengths) >= 1200) & (np.array(self.lengths) < 1300)] = 5.24 * 20
        weights[(np.array(self.lengths) >= 1100) & (np.array(self.lengths) < 1200)] = 4.58 * 10
        weights[(np.array(self.lengths) >= 1000) & (np.array(self.lengths) < 1100)] = 3.82 * 10
        weights[(np.array(self.lengths) >= 900) & (np.array(self.lengths) < 1000)] = 3.30
        weights[(np.array(self.lengths) >= 800) & (np.array(self.lengths) < 900)] = 2.34
        weights[(np.array(self.lengths) >= 700) & (np.array(self.lengths) < 800)] = 1.74
        weights[(np.array(self.lengths) >= 600) & (np.array(self.lengths) < 700)] = 1.36
        weights[(np.array(self.lengths) >= 500) & (np.array(self.lengths) < 600)] = 1.0
        weights[(np.array(self.lengths) >= 400) & (np.array(self.lengths) < 500)] = 0.74
        weights[(np.array(self.lengths) >= 300) & (np.array(self.lengths) < 400)] = 0.57
        weights[(np.array(self.lengths) >= 200) & (np.array(self.lengths) < 300)] = 0.46
        weights[(np.array(self.lengths) >= 100) & (np.array(self.lengths) < 200)] = 0.38
        weights[np.array(self.lengths) < 100] = 0.48

        return weights / np.sum(weights)  # normalize the weights

    def __iter__(self):
        # select indices by weighted sampling
        indices = np.random.choice(len(self.dataset), self.total_size, replace=True, p=self.weights)

        # boundary handling: truncate to a length divisible by num_replicas
        total_size_local = (len(indices) // self.num_replicas) * self.num_replicas
        indices = indices[:total_size_local]  # drop the leftover samples

        # assign the samples to the different processes
        indices = indices[self.rank:total_size_local:self.num_replicas]

        if self.shuffle:
            np.random.shuffle(indices)

        return iter(indices.tolist())

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
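As a reference for the sharding step shared by every `__iter__` above, the strided slice `indices[rank:total:num_replicas]` splits the drawn indices into disjoint per-rank subsets. A small standalone check (the index values are arbitrary and only illustrate the mechanism):

# Illustrative sharding check (not part of the original file)
import numpy as np

num_replicas = 4
indices = np.arange(10)                                 # stand-in for indices drawn by np.random.choice
total = (len(indices) // num_replicas) * num_replicas   # truncate to a multiple of num_replicas
indices = indices[:total]

shards = [indices[rank:total:num_replicas] for rank in range(num_replicas)]
print(shards)                                           # [0 4], [1 5], [2 6], [3 7]
assert sum(len(s) for s in shards) == total             # each kept index goes to exactly one rank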