MilaDeepGraph committed
Commit 314a644 • 1 Parent(s): 6821c3c
init from Jiqing's repo

Browse files:
- README.md +108 -3
- config.json +63 -0
- configuration_protst.py +53 -0
- modeling_protst.py +285 -0
- pytorch_model.bin +3 -0
README.md
CHANGED
@@ -1,3 +1,108 @@
## Abstract

Current protein language models (PLMs) learn protein representations mainly based on their sequences, thereby capturing co-evolutionary information well, but they are unable to explicitly acquire protein functions, which is the end goal of protein representation learning. Fortunately, for many proteins, textual property descriptions are available, and these also describe their various functions. Motivated by this fact, we first build the ProtDescribe dataset to augment protein sequences with text descriptions of their functions and other important properties. Based on this dataset, we propose the [ProtST framework](https://arxiv.org/abs/2301.12040) to enhance Protein Sequence pre-training and understanding by biomedical Texts. During pre-training, we design three types of tasks, i.e., unimodal mask prediction, multimodal representation alignment and multimodal mask prediction, to enhance a PLM with protein property information of different granularities and, at the same time, preserve the PLM's original representation power. On downstream tasks, ProtST enables both supervised learning and zero-shot prediction. We verify the superiority of ProtST-induced PLMs over previous ones on diverse representation learning benchmarks. Under the zero-shot setting, we show the effectiveness of ProtST on zero-shot protein classification, and ProtST also enables functional protein retrieval from a large-scale database without any function annotation. Source code and model weights are available at [https://github.com/DeepGraphLearning/ProtST](https://github.com/DeepGraphLearning/ProtST).

![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f0a673f0d40f6aae296b4a/o4F5-Cm-gGdHPpX5rPVKx.png)

## Example

This example shows how to use ProtST for a zero-shot classification task.

```python
import logging
import functools
from tqdm import tqdm
import torch
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)


def tokenize_protein(example, protein_tokenizer=None, padding=None):
    protein_seqs = example["prot_seq"]

    protein_inputs = protein_tokenizer(protein_seqs, padding=padding, add_special_tokens=True)
    example["protein_input_ids"] = protein_inputs.input_ids
    example["protein_attention_mask"] = protein_inputs.attention_mask

    return example


def label_embedding(labels, text_tokenizer, text_model, device):
    # Embed the label descriptions with the text encoder.
    label_feature = []
    with torch.inference_mode():
        for label in labels:
            label_input_ids = text_tokenizer.encode(label, max_length=128,
                                                    truncation=True, add_special_tokens=False)
            label_input_ids = [text_tokenizer.cls_token_id] + label_input_ids
            label_input_ids = torch.tensor(label_input_ids, dtype=torch.long, device=device).unsqueeze(0)
            attention_mask = label_input_ids != text_tokenizer.pad_token_id
            attention_mask = attention_mask.to(device)

            text_outputs = text_model(label_input_ids, attention_mask=attention_mask)

            label_feature.append(text_outputs["text_feature"])
    label_feature = torch.cat(label_feature, dim=0)
    label_feature = label_feature / label_feature.norm(dim=-1, keepdim=True)

    return label_feature


def zero_shot_eval(logger, device,
                   test_dataset, target_field, protein_model, logit_scale, label_feature):
    # Get predictions and targets for every protein in the test set.
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
    preds, targets = [], []
    with torch.inference_mode():
        for data in tqdm(test_dataloader):
            target = data[target_field]
            targets.append(target)

            protein_input_ids = torch.tensor(data["protein_input_ids"], dtype=torch.long, device=device).unsqueeze(0)
            attention_mask = torch.tensor(data["protein_attention_mask"], dtype=torch.long, device=device).unsqueeze(0)

            protein_outputs = protein_model(protein_input_ids, attention_mask=attention_mask)

            protein_feature = protein_outputs["protein_feature"]
            protein_feature = protein_feature / protein_feature.norm(dim=-1, keepdim=True)
            pred = logit_scale * protein_feature @ label_feature.t()
            preds.append(pred)

    preds = torch.cat(preds, dim=0)
    targets = torch.tensor(targets, dtype=torch.long, device=device)
    accuracy = (preds.argmax(dim=-1) == targets).float().mean().item()
    logger.warning("Zero-shot accuracy: %.6f" % accuracy)


if __name__ == "__main__":
    # Get the test split of the subcellular localization dataset.
    raw_datasets = load_dataset("Jiqing/ProtST-SubcellularLocalization", cache_dir="~/.cache/huggingface/datasets", split="test")  # cache_dir defaults to "~/.cache/huggingface/datasets"

    # device = torch.device("cuda:0")
    device = torch.device("cpu")

    protst_model = AutoModel.from_pretrained("Jiqing/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
    protein_model = protst_model.protein_model
    text_model = protst_model.text_model
    logit_scale = protst_model.logit_scale
    logit_scale.requires_grad = False
    logit_scale = logit_scale.to(device)
    logit_scale = logit_scale.exp()

    protein_tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
    text_tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")

    func_tokenize_protein = functools.partial(tokenize_protein, protein_tokenizer=protein_tokenizer, padding=False)
    test_dataset = raw_datasets.map(
        func_tokenize_protein, batched=False,
        remove_columns=["prot_seq"],
        desc="Running tokenize_proteins on dataset",
    )

    labels = load_dataset("Jiqing/subloc_template", cache_dir="~/.cache/huggingface/datasets")["train"]["name"]

    label_feature = label_embedding(labels, text_tokenizer, text_model, device)
    zero_shot_eval(logger, device, test_dataset, "localization",
                   protein_model, logit_scale, label_feature)
```
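The same components can also score a single sequence against the label embeddings. The snippet below is a minimal sketch, not part of the original script: the amino-acid string is a placeholder, and `protein_tokenizer`, `protein_model`, `logit_scale`, `label_feature`, `labels`, and `device` are assumed to be defined as in the example above.

```python
# Minimal sketch: score one (placeholder) sequence against the label embeddings.
seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # placeholder amino-acid string, not a real annotated protein

inputs = protein_tokenizer(seq, add_special_tokens=True, return_tensors="pt").to(device)
with torch.inference_mode():
    feat = protein_model(inputs.input_ids, attention_mask=inputs.attention_mask)["protein_feature"]
    feat = feat / feat.norm(dim=-1, keepdim=True)
    scores = logit_scale * feat @ label_feature.t()  # CLIP-style similarity to each label description

print(labels[scores.argmax(dim=-1).item()])
```

The highest-scoring label description is taken as the zero-shot prediction, mirroring the dataset-level loop in `zero_shot_eval`.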
config.json
ADDED
@@ -0,0 +1,63 @@
{
  "architectures": [
    "ProtSTModel"
  ],
  "auto_map": {
    "AutoModel": "modeling_protst.ProtSTModel",
    "AutoConfig": "configuration_protst.ProtSTConfig"
  },
  "model_type": "protst",
  "protein_config": {
    "_name_or_path": "/tmp/facebook/esm1b_t33_650M_UR50S",
    "architectures": [
      "EsmForMaskedLM"
    ],
    "attention_probs_dropout_prob": 0.0,
    "classifier_dropout": null,
    "emb_layer_norm_before": true,
    "esmfold_config": null,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.0,
    "hidden_size": 1280,
    "initializer_range": 0.02,
    "intermediate_size": 5120,
    "is_folding_model": false,
    "layer_norm_eps": 1e-05,
    "mask_token_id": 32,
    "max_position_embeddings": 1026,
    "model_type": "esm",
    "num_attention_heads": 20,
    "num_hidden_layers": 33,
    "cls_token_id": 0,
    "pad_token_id": 1,
    "eos_token_id": 2,
    "position_embedding_type": "absolute",
    "token_dropout": true,
    "torch_dtype": "float32",
    "use_cache": true,
    "vocab_list": null,
    "vocab_size": 33
  },
  "text_config": {
    "architectures": [
      "BertForMaskedLM"
    ],
    "model_type": "bert",
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "gelu",
    "pad_token_id": 0,
    "cls_token_id": 2,
    "sep_token_id": 3,
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "max_position_embeddings": 512,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "type_vocab_size": 2,
    "vocab_size": 30522
  },
  "torch_dtype": "float32",
  "transformers_version": "4.37.0.dev0"
}
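The `auto_map` block above is what lets the Auto classes resolve this repository's custom code. A minimal sketch of inspecting the nested configuration (assuming the same `Jiqing/ProtST-esm1b` repo id used in the README example):

```python
from transformers import AutoConfig

# auto_map routes AutoConfig to configuration_protst.ProtSTConfig in this repo.
config = AutoConfig.from_pretrained("Jiqing/ProtST-esm1b", trust_remote_code=True)

print(config.model_type)                  # "protst"
print(config.protein_config.hidden_size)  # 1280 (ESM-1b protein encoder)
print(config.text_config.hidden_size)     # 768  (PubMedBERT text encoder)
```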
configuration_protst.py
ADDED
@@ -0,0 +1,53 @@
from transformers import PretrainedConfig
from transformers.utils import logging
from transformers.models.esm import EsmConfig
from transformers.models.bert import BertConfig

logger = logging.get_logger(__name__)


class ProtSTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ProtSTModel`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        protein_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`EsmForProteinRepresentation`].
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`BertForPubMed`].
    """

    model_type = "protst"

    def __init__(
        self,
        protein_config=None,
        text_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if protein_config is None:
            protein_config = {}
            logger.info("`protein_config` is `None`. Initializing the protein config with default values.")

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the text config with default values.")

        self.protein_config = EsmConfig(**protein_config)
        self.text_config = BertConfig(**text_config)

    @classmethod
    def from_protein_text_configs(
        cls, protein_config: EsmConfig, text_config: BertConfig, **kwargs
    ):
        r"""
        Instantiate a [`ProtSTConfig`] (or a derived class) from a ProtST protein model configuration and a text
        model configuration.

        Returns:
            [`ProtSTConfig`]: An instance of a configuration object
        """
        return cls(protein_config=protein_config.to_dict(), text_config=text_config.to_dict(), **kwargs)
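A small illustrative sketch of `from_protein_text_configs` (hypothetical usage: the sub-config values are illustrative rather than the checkpoint's, and `configuration_protst` is assumed to be importable locally, e.g. from a checkout of this repo):

```python
from transformers.models.esm import EsmConfig
from transformers.models.bert import BertConfig
from configuration_protst import ProtSTConfig  # assumes a local checkout of this repository

# Illustrative sub-configs; the shipped config.json carries the checkpoint's real values.
protein_config = EsmConfig(hidden_size=1280, num_hidden_layers=33, num_attention_heads=20,
                           intermediate_size=5120, vocab_size=33, max_position_embeddings=1026)
text_config = BertConfig()  # BERT-base-sized text encoder

config = ProtSTConfig.from_protein_text_configs(protein_config, text_config)
print(config.protein_config.hidden_size, config.text_config.hidden_size)
```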
modeling_protst.py
ADDED
@@ -0,0 +1,285 @@
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from dataclasses import dataclass
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from transformers.models.esm import EsmPreTrainedModel, EsmModel
from transformers.models.bert import BertPreTrainedModel, BertModel
from .configuration_protst import ProtSTConfig


@dataclass
class EsmProteinRepresentationOutput(ModelOutput):

    protein_feature: torch.FloatTensor = None
    residue_feature: torch.FloatTensor = None


@dataclass
class BertTextRepresentationOutput(ModelOutput):

    text_feature: torch.FloatTensor = None
    word_feature: torch.FloatTensor = None


@dataclass
class ProtSTClassificationOutput(ModelOutput):

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None


class ProtSTHead(nn.Module):
    def __init__(self, config, out_dim=512):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, out_dim)

    def forward(self, x):
        x = self.dense(x)
        x = nn.functional.relu(x)
        x = self.out_proj(x)
        return x


class BertForPubMed(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.pad_token_id = config.pad_token_id
        self.cls_token_id = config.cls_token_id
        self.sep_token_id = config.sep_token_id

        self.bert = BertModel(config, add_pooling_layer=False)
        self.text_mlp = ProtSTHead(config)
        self.word_mlp = ProtSTHead(config)

        self.post_init()  # NOTE

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], ModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        word_feature = outputs.last_hidden_state
        # Mean-pool over non-special tokens, then project into the shared embedding space.
        is_special = (input_ids == self.cls_token_id) | (input_ids == self.sep_token_id) | (input_ids == self.pad_token_id)
        special_mask = (~is_special).to(torch.int64).unsqueeze(-1)
        pooled_feature = ((word_feature * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)).to(word_feature.dtype)
        pooled_feature = self.text_mlp(pooled_feature)
        word_feature = self.word_mlp(word_feature)

        if not return_dict:
            return (pooled_feature, word_feature)

        return BertTextRepresentationOutput(text_feature=pooled_feature, word_feature=word_feature)


class EsmForProteinRepresentation(EsmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.cls_token_id = config.cls_token_id
        self.pad_token_id = config.pad_token_id
        self.eos_token_id = config.eos_token_id

        self.esm = EsmModel(config, add_pooling_layer=False)
        self.protein_mlp = ProtSTHead(config)
        self.residue_mlp = ProtSTHead(config)

        self.post_init()  # NOTE

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, EsmProteinRepresentationOutput]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        residue_feature = outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]

        # Mean readout over non-special residues
        is_special = (
            (input_ids == self.cls_token_id) | (input_ids == self.eos_token_id) | (input_ids == self.pad_token_id)
        )
        special_mask = (~is_special).to(torch.int64).unsqueeze(-1)
        protein_feature = ((residue_feature * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)).to(residue_feature.dtype)

        # For ProtST pre-training and zero-shot prediction
        protein_feature = self.protein_mlp(protein_feature)
        residue_feature = self.residue_mlp(residue_feature)

        return EsmProteinRepresentationOutput(
            protein_feature=protein_feature, residue_feature=residue_feature
        )


class ProtSTPreTrainedModel(PreTrainedModel):
    config_class = ProtSTConfig

    def _compute_protein_feature(self,
        protein_input_ids, protein_attention_mask, protein_position_ids,
        output_attentions, output_hidden_states
    ):

        protein_outputs = self.protein_model(
            protein_input_ids,
            attention_mask=protein_attention_mask,
            position_ids=protein_position_ids,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=None,
        )

        return protein_outputs

    def _compute_text_feature(self,
        text_input_ids, text_attention_mask, text_position_ids,
        output_attentions, output_hidden_states
    ):
        text_outputs = self.text_model(
            text_input_ids,
            attention_mask=text_attention_mask,
            position_ids=text_position_ids,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=None,
        )

        return text_outputs


class ProtSTModel(ProtSTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.protein_model = EsmForProteinRepresentation(config.protein_config)
        self.text_model = BertForPubMed(config.text_config)
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))

        self.post_init()  # NOTE

    def forward(self,
        protein_input_ids: Optional[torch.LongTensor] = None,
        text_input_ids: Optional[torch.LongTensor] = None,
        protein_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        protein_position_ids: Optional[torch.LongTensor] = None,
        text_position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        # Not implemented yet
        return None


class ProtSTForProteinPropertyPrediction(ProtSTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.protein_model = EsmForProteinRepresentation(config.protein_config)
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
        self.classifier = ProtSTHead(config.protein_config, out_dim=config.num_labels)

        self.post_init()  # NOTE

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ProtSTClassificationOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the protein classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.protein_model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.classifier(outputs.protein_feature)  # [batch_size, feature_dim] -> [batch_size, num_labels]

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()

            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return ProtSTClassificationOutput(loss=loss, logits=logits)
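Both encoders above pool with a mean readout that ignores special tokens. Below is a tiny self-contained sketch of that masking logic, using toy tensors and the special-token ids from `config.json` (0 = CLS, 1 = PAD, 2 = EOS for ESM-1b); the residue ids and hidden size are made up for illustration.

```python
import torch

cls_id, pad_id, eos_id = 0, 1, 2  # ESM-1b special token ids, as in config.json
input_ids = torch.tensor([[0, 5, 6, 7, 2, 1, 1]])  # [CLS] x x x [EOS] [PAD] [PAD], toy residue ids
residue_feature = torch.randn(1, 7, 4)             # toy hidden states with hidden_dim=4

is_special = (input_ids == cls_id) | (input_ids == eos_id) | (input_ids == pad_id)
special_mask = (~is_special).to(torch.int64).unsqueeze(-1)  # 1 for real residues, 0 otherwise
protein_feature = (residue_feature * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)

print(protein_feature.shape)  # torch.Size([1, 4]): mean over the 3 real residues only
```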
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c59f77e12992626701f6bdfb732b5b9171f753fda86df7f68aa2135ebd421868
size 135