fmoorhof committed
Commit 64ebfda · 1 Parent(s): 5b12ea5

feat: pLM embeddings of protein sequence input

Files changed (3)
  1. app.py +5 -3
  2. embed.py +126 -0
  3. requirements.txt +5 -1
app.py CHANGED
@@ -1,7 +1,9 @@
  import gradio as gr
+ from embed import gen_embedding
 
- def greet(name):
-     return "Hello " + name + "!!"
+ def generate_embeddings(sequences):
+     embeddings = gen_embedding(sequences.splitlines(), plm_model="esm1b")  # one sequence per input line
+     return embeddings.tolist()
 
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+ demo = gr.Interface(fn=generate_embeddings, inputs="text", outputs="text")
  demo.launch()
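For context (not part of the commit): once the Space is running, the new handler can be exercised from Python with the Gradio client. This is a minimal sketch under stated assumptions — the Space identifier is a placeholder, and "/predict" is assumed to be the default endpoint name exposed by a single gr.Interface.

    # Hypothetical client-side check of the embedding endpoint (requires the gradio_client package).
    from gradio_client import Client

    client = Client("owner/space-name")  # placeholder Space id; substitute the real one

    # One amino acid sequence per line of the text box; the handler returns the embeddings as text.
    result = client.predict("MKTAYIAKQR", api_name="/predict")
    print(result)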
embed.py ADDED
@@ -0,0 +1,126 @@
+ from __future__ import annotations
+
+ import logging
+
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+ from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ def gen_embedding(
+     sequences: list[str], plm_model: str = "esm1b", no_pad: bool = False
+ ) -> np.ndarray:
+     """
+     Generate embeddings for a list of sequences using a specified pre-trained language model (PLM).
+
+     Args:
+         sequences (list[str]): List of amino acid sequences.
+         plm_model (str, optional): Pre-trained model name. Options: 'esm1b', 'esm2', 'prott5', 'prostt5'.
+         no_pad (bool, optional): If True, removes padding tokens when calculating mean embedding.
+
+     Returns:
+         np.ndarray: Array of embeddings.
+     """
+     tokenizer, model = _load_model_and_tokenizer(plm_model)
+     logging.info(f"Generating embeddings with {plm_model} on device: {device}")
+
+     formatted_sequences = _format_sequences(sequences, plm_model)
+
+     embeddings = [_generate_sequence_embedding(seq, tokenizer, model, plm_model, no_pad) for seq in tqdm(formatted_sequences)]
+
+     torch.cuda.empty_cache()
+     return np.array(embeddings)
+
+
+ def _load_model_and_tokenizer(plm_model: str) -> tuple[PreTrainedTokenizer, PreTrainedModel]:
+     """Load the tokenizer and model for a given PLM."""
+     if plm_model == "esm1b":
+         tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
+         model = AutoModel.from_pretrained("facebook/esm1b_t33_650M_UR50S").to(device)
+
+     elif plm_model == "esm2":
+         tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
+
+     elif plm_model == "prott5":
+         from transformers import T5EncoderModel, T5Tokenizer
+         tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")
+         model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50").to(device)
+
+     elif plm_model == "prostt5":
+         from transformers import T5EncoderModel, T5Tokenizer
+         tokenizer = T5Tokenizer.from_pretrained("Rostlab/ProstT5")
+         model = T5EncoderModel.from_pretrained("Rostlab/ProstT5").to(device)
+
+     else:
+         raise ValueError(
+             f"Unsupported model '{plm_model}'. Choose from 'esm1b', 'esm2', 'prott5', 'prostt5'."
+         )
+
+     return tokenizer, model
+
+
+ def _format_sequences(sequences: list[str], plm_model: str) -> list[str]:
+     """Format sequences if necessary (e.g., insert spaces for T5 models)."""
+     if plm_model in {"prott5", "prostt5"}:
+         return [" ".join(list(seq)) for seq in sequences]
+     return sequences
+
+
+ def _generate_sequence_embedding(
+     sequence: str,
+     tokenizer: PreTrainedTokenizer,
+     model: PreTrainedModel,
+     plm_model: str,
+     no_pad: bool,
+ ) -> np.ndarray:
+     """Generate embedding for a single sequence."""
+     inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True).to(device)
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     if no_pad:
+         return _extract_no_pad_embedding(outputs, sequence, plm_model)
+     else:
+         return _extract_mean_embedding(outputs, sequence, plm_model)
+
+
+ def _extract_mean_embedding(
+     outputs: torch.nn.Module,
+     sequence: str,
+     plm_model: str,
+ ) -> np.ndarray:
+     """Extract mean embedding including padding."""
+     try:
+         embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
+     except RuntimeError as e:
+         if plm_model == "esm1b":
+             raise RuntimeError(
+                 f"ESM-1b model cannot handle sequences longer than 1024 amino acids.\n"
+                 f"Problematic sequence: {sequence}\n"
+                 "Please filter or truncate long sequences or use 'prott5' instead."
+             ) from e
+         raise
+     return embedding
+
+
+ def _extract_no_pad_embedding(
+     outputs: torch.nn.Module,
+     sequence: str,
+     plm_model: str,
+ ) -> np.ndarray:
+     """Extract mean embedding after removing padding."""
+     seq_len = len(sequence) if plm_model not in {"prott5", "prostt5"} else int(len(sequence) / 2 + 1)
+     return outputs.last_hidden_state[0, :seq_len, :].mean(dim=0).cpu().numpy()
+
+
+
+ if __name__ == "__main__":
+     seqs = ["PRTNN", "PRTN"]
+     embeddings = gen_embedding(seqs, plm_model="prott5")  # , no_pad=True)
+     print(embeddings.shape)
+     print(embeddings)
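A brief usage sketch of embed.py beyond the __main__ block above (not part of the commit). The sequences are made-up examples, and the expected hidden size of 1280 for the 650M ESM-2 checkpoint is an assumption based on its model configuration, not something this diff asserts.

    # Hypothetical usage of gen_embedding with padding positions excluded from the mean.
    from embed import gen_embedding

    seqs = ["MKTAYIAKQRQISFVK", "GSHMASMTGGQQMGRG"]  # example amino acid sequences

    # no_pad=True averages only over the first len(sequence) token positions
    # instead of the full (padded) model output.
    emb = gen_embedding(seqs, plm_model="esm2", no_pad=True)

    print(emb.shape)  # expected: (2, hidden_size), e.g. (2, 1280) for esm2_t33_650M (assumed)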
requirements.txt CHANGED
@@ -1 +1,5 @@
- gradio
+ numpy
+ gradio
+ torch
+ tqdm
+ transformers