Commit b648b64 (parent: 1cdb04b): Update README.md
tags:
- protein language model
- binding sites
---

# ESM-2 for Binding Site Prediction

This model is a finetuned version of the 35M parameter `esm2_t12_35M_UR50D` ([see here](https://huggingface.co/facebook/esm2_t12_35M_UR50D) and [here](https://huggingface.co/docs/transformers/model_doc/esm) for more details). The model was finetuned with LoRA for the binary token classification task of predicting binding sites (and active sites) of protein sequences based on sequence alone. The model may be underfit and undertrained; however, it still achieved better performance on the test set in terms of loss, accuracy, precision, recall, F1 score, ROC AUC, and Matthews correlation coefficient (MCC) than the models trained on the smaller dataset [found here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family) of ~209K protein sequences. Note that this model has high recall, meaning it is likely to detect binding sites, but low precision, meaning it will likely return false positives as well.

## Training procedure

This model was finetuned on ~549K protein sequences from the UniProt database. The dataset can be found [here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family_550K). The model obtains the following test metrics:

```python
Epoch: 3
Training Loss: 0.029100
...
Mcc: 0.560612
```
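
The training script itself is not part of this card. The snippet below is a minimal sketch of how a LoRA token-classification finetune of `esm2_t12_35M_UR50D` might be set up with PEFT; the LoRA rank, target modules, and training arguments are illustrative assumptions, not the settings used for this checkpoint, and `train_dataset`/`eval_dataset` are assumed to be tokenized versions of the dataset linked above.

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, TaskType, get_peft_model

base_model_path = "facebook/esm2_t12_35M_UR50D"

# Two labels: 0 = not a binding site, 1 = binding site
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Illustrative LoRA configuration (rank, alpha, dropout, and target modules are assumptions)
lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "key", "value"],  # attention projections in the ESM-2 encoder
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

# Illustrative training setup; hyperparameters are assumptions
training_args = TrainingArguments(
    output_dir="esm2_t12_35M_lora_binding_sites",
    learning_rate=3e-4,
    num_train_epochs=3,
    per_device_train_batch_size=8,
)
# trainer = Trainer(model=model, args=training_args,
#                   train_dataset=train_dataset, eval_dataset=eval_dataset)
# trainer.train()
```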

### Framework versions

- PEFT 0.5.0

## Using the model

To use the model on one of your protein sequences, try running the following:

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
```
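
Because the model favors recall over precision, it can be useful to inspect the predicted probabilities rather than the argmax labels. The snippet below is a minimal sketch that reuses `loaded_model`, `loaded_tokenizer`, and `inputs` from the example above and lists the residues whose binding-site probability exceeds a threshold; the 0.9 threshold is an illustrative assumption and should be tuned for your use case.

```python
import torch

# Probability of the "Binding site" class (index 1) for every token
with torch.no_grad():
    logits = loaded_model(**inputs).logits
probs = torch.softmax(logits, dim=-1)[0, :, 1]

# Illustrative threshold: raising it trades recall for precision
threshold = 0.9

# Map token positions back to residue positions, skipping special and padding tokens
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
binding_site_positions = []
residue_index = 0
for token, p in zip(tokens, probs.tolist()):
    if token in ['<pad>', '<cls>', '<eos>']:
        continue
    residue_index += 1  # 1-based position in the input protein sequence
    if p > threshold:
        binding_site_positions.append((residue_index, token, round(p, 3)))

print(binding_site_positions)
```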