Commit · 8c639ec
Simon Duerr committed
0 Parent(s)

webapp

Files changed:
- CODEOWNERS +2 -0
- LICENSE +29 -0
- ProteinMPNN +1 -0
- app.py +498 -0
- checkpoints/allatom.yml +69 -0
- checkpoints/allatom_state_dict.pth +3 -0
- checkpoints/backbone.yml +69 -0
- checkpoints/backbone_state_dict.pth +3 -0
- checkpoints/minimpnn_state_dict.pth +3 -0
- configs/allatom.yml +69 -0
- configs/backbone.yml +69 -0
- configs/seqdes.yml +74 -0
- core/__init__.py +0 -0
- core/__pycache__/__init__.cpython-38.pyc +0 -0
- core/__pycache__/__init__.cpython-39.pyc +0 -0
- core/__pycache__/data.cpython-38.pyc +0 -0
- core/__pycache__/data.cpython-39.pyc +0 -0
- core/__pycache__/protein.cpython-38.pyc +0 -0
- core/__pycache__/protein.cpython-39.pyc +0 -0
- core/__pycache__/protein_mpnn.cpython-38.pyc +0 -0
- core/__pycache__/protein_mpnn.cpython-39.pyc +0 -0
- core/__pycache__/residue_constants.cpython-38.pyc +0 -0
- core/__pycache__/residue_constants.cpython-39.pyc +0 -0
- core/__pycache__/utils.cpython-38.pyc +0 -0
- core/__pycache__/utils.cpython-39.pyc +0 -0
- core/data.py +271 -0
- core/protein.py +341 -0
- core/protein_mpnn.py +1886 -0
- core/residue_constants.py +1104 -0
- core/stereo_chemical_props.txt +345 -0
- core/utils.py +1062 -0
- diffusion.py +66 -0
- draw_samples.py +353 -0
- evaluation.py +406 -0
- models.py +778 -0
- modules.py +696 -0
- output_helpers.py +0 -0
- package.txt +1 -0
- protpardelle_pymol.py +159 -0
- requirements.txt +14 -0
- sampling.py +213 -0
CODEOWNERS
ADDED
@@ -0,0 +1,2 @@
# Global owner
* @alexechu
LICENSE
ADDED
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2022, Alex Chu
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
ProteinMPNN
ADDED
@@ -0,0 +1 @@
Subproject commit 8907e6671bfbfc92303b5f79c4b5e6ce47cdef57
app.py
ADDED
@@ -0,0 +1,498 @@
import gradio as gr

import re
import urllib.request

import tempfile

from output_helpers import viewer_html, output_html, load_js, get_js

import json
import os
import shlex
import subprocess
from datetime import datetime

from einops import repeat
import torch

from core import data
from core import utils
import models
import sampling

# from draw_samples import draw_and_save_samples, parse_resample_idx_string


def draw_and_save_samples(
    model,
    samples_per_len=8,
    lengths=range(50, 512),
    save_dir="./",
    mode="backbone",
    **sampling_kwargs,
):
    device = model.device
    sample_files = []
    if mode == "backbone":
        total_sampling_time = 0
        for l in lengths:
            prot_lens = torch.ones(samples_per_len).long() * l
            seq_mask = model.make_seq_mask_for_sampling(prot_lens=prot_lens)
            aux = sampling.draw_backbone_samples(
                model,
                seq_mask=seq_mask,
                pdb_save_path=f"{save_dir}/len{format(l, '03d')}_samp",
                return_aux=True,
                return_sampling_runtime=True,
                **sampling_kwargs,
            )
            total_sampling_time += aux["runtime"]
            sample_files += [
                f"{save_dir}/len{format(l, '03d')}_samp{i}.pdb"
                for i in range(samples_per_len)
            ]
        return sample_files
    elif mode == "allatom":
        total_sampling_time = 0
        for l in lengths:
            prot_lens = torch.ones(samples_per_len).long() * l
            seq_mask = model.make_seq_mask_for_sampling(prot_lens=prot_lens)
            aux = sampling.draw_allatom_samples(
                model,
                seq_mask=seq_mask,
                pdb_save_path=f"{save_dir}/len{format(l, '03d')}",
                return_aux=True,
                **sampling_kwargs,
            )
            total_sampling_time += aux["runtime"]
            sample_files += [
                f"{save_dir}/len{format(l, '03d')}_samp{i}.pdb"
                for i in range(samples_per_len)
            ]
        return sample_files


def parse_idx_string(idx_str):
    spans = idx_str.split(",")
    idxs = []
    for s in spans:
        if "-" in s:
            start, stop = s.split("-")
            idxs.extend(list(range(int(start), int(stop))))
        else:
            idxs.append(int(s))
    return idxs


def changemode(m):
    if m == "unconditional":
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
        )
    else:
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
        )


def fileselection(val):
    if val == "upload":
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
    else:
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)


def update_structuresel(pdb, radio_val):
    pdb_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdb")

    representations = [
        {
            "model": 0,
            "chain": "",
            "resname": "",
            "style": "cartoon",
            "color": "whiteCarbon",
            "residue_range": "",
            "around": 0,
            "byres": False,
            "visible": False,
        }
    ]

    if radio_val == "PDB":
        if len(pdb) != 4:
            return gr.update(open=True), gr.update(), gr.update(value="", visible=False)
        else:
            urllib.request.urlretrieve(
                f"http://files.rcsb.org/download/{pdb.lower()}.pdb1",
                pdb_file.name,
            )
            return (
                gr.update(open=False),
                gr.update(value=pdb_file.name),
                gr.update(
                    value=f"""<iframe style="width: 100%; height: 930px" name="result" allow="midi; geolocation; microphone; camera;
        display-capture; encrypted-media;" sandbox="allow-modals allow-forms
        allow-scripts allow-same-origin allow-popups
        allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
        allowpaymentrequest="" frameborder="0" srcdoc='{viewer_html(pdb_file.name, representations=representations)}'></iframe>""",
                    visible=True,
                ),
            )
    elif radio_val == "AF2 EBI DB":  # must match the pdb_radio choice below (was "AFDB2", which never matched)
        if re.match("[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}", pdb) is not None:
            urllib.request.urlretrieve(
                f"https://alphafold.ebi.ac.uk/files/AF-{pdb}-F1-model_v2.pdb",
                pdb_file.name,
            )
            return (
                gr.update(open=False),
                gr.update(value=pdb_file.name),
                gr.update(
                    value=f"""<iframe style="width: 100%; height: 930px" name="result" allow="midi; geolocation; microphone; camera;
        display-capture; encrypted-media;" sandbox="allow-modals allow-forms
        allow-scripts allow-same-origin allow-popups
        allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
        allowpaymentrequest="" frameborder="0" srcdoc='{viewer_html(pdb_file.name, representations=representations)}'></iframe>""",
                    visible=True,
                ),
            )
        else:
            # three outputs are wired up, so also return a no-op update for path_to_file
            return gr.update(open=True), gr.update(), gr.update(value="regex not matched", visible=True)
    else:
        return (
            gr.update(open=False),
            gr.update(value=f"{pdb.name}"),
            gr.update(
                value=f"""<iframe style="width: 100%; height: 930px" name="result" allow="midi; geolocation; microphone; camera;
        display-capture; encrypted-media;" sandbox="allow-modals allow-forms
        allow-scripts allow-same-origin allow-popups
        allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
        allowpaymentrequest="" frameborder="0" srcdoc='{viewer_html(pdb.name, representations=representations)}'></iframe>""",
                visible=True,
            ),
        )


from Bio.PDB import PDBParser, cealign
from Bio.PDB.PDBIO import PDBIO


class dotdict(dict):
    """dot.notation access to dictionary attributes (missing keys return None)"""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    # Set up params, arguments, sampling config
    ####################

    args = {}
    args["model_checkpoint"] = "checkpoints"  # Path to denoiser model weights and config
    args["mpnnpath"] = "checkpoints/minimpnn_state_dict.pth"  # Path to minimpnn model weights
    args["modeldir"] = None  # Model base directory, ex 'training_logs/other/lemon-shape-51'
    args["modelepoch"] = None  # Model epoch, ex 1000

    args["type"] = modeltype  # Type of model
    if m == "conditional":
        args["param"] = None  # Which sampling param to vary
        args["paramval"] = None  # Which param val to use
        args["parampath"] = None  # Path to json file with params; use either param/paramval or parampath, not both
        args["perlen"] = int(perlen)  # How many samples per sequence length
        args["minlen"] = None  # Minimum sequence length
        args["maxlen"] = None  # Maximum sequence length, not inclusive
        args["steplen"] = int(steplen)  # How frequently to select sequence length; for steplen 2, would be 50, 52, 54, etc.
        args["num_lens"] = None  # If steplen not provided, how many random lengths to sample at
        args["targetdir"] = "."  # Directory to save results
        args["input_pdb"] = path_to_file  # PDB file to condition on
        args["resample_idxs"] = resample_idx[1:-1]  # Indices from PDB file to resample. Zero-indexed, comma-delimited, can use dashes, e.g. 0,2-5,7
    else:
        args["param"] = "n_steps"  # Which sampling param to vary
        args["paramval"] = "100"  # Which param val to use
        args["parampath"] = None  # Path to json file with params; use either param/paramval or parampath, not both
        args["perlen"] = int(perlen)  # How many samples per sequence length
        args["minlen"] = int(minlen)  # Minimum sequence length
        args["maxlen"] = int(maxlen) + 1  # Maximum sequence length
        args["steplen"] = int(steplen)  # How frequently to select sequence length
        args["num_lens"] = None  # If steplen not provided, how many random lengths to sample at
        args["targetdir"] = "."  # Directory to save results
        args["resample_idxs"] = None

    args = dotdict(args)
    is_test_run = False
    seed = 0
    samples_per_len = args.perlen
    min_len = args.minlen
    max_len = args.maxlen
    len_step_size = args.steplen
    device = "cuda:0"

    # setting default sampling config
    if args.type == "backbone":
        sampling_config = sampling.default_backbone_sampling_config()
    elif args.type == "allatom":
        sampling_config = sampling.default_allatom_sampling_config()

    sampling_kwargs = vars(sampling_config)

    # Parse conditioning inputs
    input_pdb_len = None
    if args.input_pdb:
        input_feats = utils.load_feats_from_pdb(args.input_pdb, protein_only=True)
        input_pdb_len = input_feats["aatype"].shape[0]
        if args.resample_idxs:
            print(
                f"Warning: when sampling conditionally, the input pdb length ({input_pdb_len} residues) is used automatically for the sampling lengths."
            )
            resample_idxs = parse_idx_string(args.resample_idxs)
        else:
            resample_idxs = list(range(input_pdb_len))
        cond_idxs = [i for i in range(input_pdb_len) if i not in resample_idxs]
        to_batch_size = lambda x: repeat(x, "... -> b ...", b=samples_per_len).to(device)

        # For unconditional model, center coords on whole structure
        centered_coords = data.apply_random_se3(
            input_feats["atom_positions"],
            atom_mask=input_feats["atom_mask"],
            translation_scale=0.0,
        )
        cond_kwargs = {}
        cond_kwargs["gt_coords"] = to_batch_size(centered_coords)
        cond_kwargs["gt_cond_atom_mask"] = to_batch_size(input_feats["atom_mask"])
        cond_kwargs["gt_cond_atom_mask"][:, resample_idxs] = 0
        cond_kwargs["gt_aatype"] = to_batch_size(input_feats["aatype"])
        cond_kwargs["gt_cond_seq_mask"] = torch.zeros_like(cond_kwargs["gt_aatype"])
        cond_kwargs["gt_cond_seq_mask"][:, cond_idxs] = 1
        sampling_kwargs.update(cond_kwargs)

    print("input_pdb_len", input_pdb_len)

    # Determine lengths to sample at
    if min_len is not None and max_len is not None:
        if len_step_size is not None:
            sampling_lengths = range(min_len, max_len, len_step_size)
        else:
            sampling_lengths = list(torch.randint(min_len, max_len, size=(args.num_lens,)))
    elif input_pdb_len is not None:
        sampling_lengths = [input_pdb_len]
    else:
        raise Exception("Need to provide a set of protein lengths or an input pdb.")

    total_num_samples = len(list(sampling_lengths)) * samples_per_len

    model_directory = args.modeldir
    epoch = args.modelepoch
    base_dir = args.targetdir

    date_string = datetime.now().strftime("%y-%m-%d-%H-%M-%S")
    if is_test_run:
        date_string = f"test-{date_string}"

    # Update sampling config with arguments
    if args.param:
        var_param = args.param
        var_value = args.paramval
        sampling_kwargs[var_param] = (
            None
            if var_value == "None"
            else int(var_value)
            if var_param == "n_steps"
            else float(var_value)
        )
    elif args.parampath:
        with open(args.parampath) as f:
            var_params = json.loads(f.read())
            sampling_kwargs.update(var_params)

    # this is only used for the readme, keep s_min and s_max as params instead of struct_noise_schedule
    sampling_kwargs_readme = list(sampling_kwargs.items())

    print("Base directory:", base_dir)
    save_dir = f"{base_dir}/samples/{date_string}"
    save_init_dir = f"{base_dir}/samples_inits/{date_string}"

    # make dirs if do not exist
    if not os.path.exists(save_dir):
        subprocess.run(shlex.split(f"mkdir -p {save_dir}"))

    if not os.path.exists(save_init_dir):
        subprocess.run(shlex.split(f"mkdir -p {save_init_dir}"))

    print("Samples saved to:", save_dir)
    torch.manual_seed(seed)

    # Load model
    if args.type == "backbone":
        if args.model_checkpoint:
            checkpoint = f"{args.model_checkpoint}/backbone_state_dict.pth"
            cfg_path = f"{args.model_checkpoint}/backbone.yml"
        else:
            checkpoint = f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
            cfg_path = f"{model_directory}/configs/backbone.yml"
        cfg = utils.load_config(cfg_path)
        weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
        model = models.Protpardelle(cfg, device=device)
        model.load_state_dict(weights)
        model.to(device)
        model.eval()
        model.device = device
    elif args.type == "allatom":
        if args.model_checkpoint:
            checkpoint = f"{args.model_checkpoint}/allatom_state_dict.pth"
            cfg_path = f"{args.model_checkpoint}/allatom.yml"
        else:
            checkpoint = f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
            cfg_path = f"{model_directory}/configs/allatom.yml"
        config = utils.load_config(cfg_path)
        weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
        model = models.Protpardelle(config, device=device)
        model.load_state_dict(weights)
        model.load_minimpnn(args.mpnnpath)
        model.to(device)
        model.eval()
        model.device = device

    with open(save_dir + "/run_parameters.txt", "w") as f:
        f.write(f"Sampling run for {date_string}\n")
        f.write(f"Random seed {seed}\n")
        f.write(f"Model checkpoint: {checkpoint}\n")
        f.write(
            f"{samples_per_len} samples per length from {min_len}:{max_len}:{len_step_size}\n"
        )
        f.write("Sampling params:\n")
        for k, v in sampling_kwargs_readme:
            f.write(f"{k}\t{v}\n")

    # Draw samples
    output_files = draw_and_save_samples(
        model,
        samples_per_len=samples_per_len,
        lengths=sampling_lengths,
        save_dir=save_dir,
        mode=args.type,
        **sampling_kwargs,
    )

    return output_files


def api_predict(pdb_content, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    if m == "conditional":
        tempPDB = tempfile.NamedTemporaryFile(delete=False, suffix=".pdb")
        tempPDB.write(pdb_content.encode())
        tempPDB.close()

        path_to_file = tempPDB.name
    else:
        path_to_file = None

    try:
        designs = protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen)
    except Exception as e:
        print(e)
        raise gr.Error(e)

    # load each design as string
    design_str = []
    for d in designs:
        with open(d, "r") as f:
            design_str.append(f.read())

    results = list(zip(designs, design_str))
    return json.dumps(results)


def predict(pdb_radio, path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    print("running predict")
    try:
        designs = protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen)
    except Exception as e:
        print(e)
        raise gr.Error(e)

    parser = PDBParser()
    aligner = cealign.CEAligner()
    io = PDBIO()
    aligned_designs = []
    metrics = []
    if m == "conditional":
        ref = parser.get_structure("ref", path_to_file)
        aligner.set_reference(ref)

        for d in designs:
            design = parser.get_structure("design", d)
            aligner.align(design)
            metrics.append({"rms": f"{aligner.rms:.1f}", "len": len(list(design[0].get_residues()))})
            io.set_structure(design)
            io.save(d.replace(".pdb", "_al.pdb"))
            aligned_designs.append(d.replace(".pdb", "_al.pdb"))
    else:
        for d in designs:
            design = parser.get_structure("design", d)
            metrics.append({"len": len(list(design[0].get_residues()))})
        aligned_designs = designs

    output_view = f"""<iframe style="width: 100%; height: 900px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{output_html(path_to_file, aligned_designs, metrics, resample_idx=resample_idx, mode=m)}'></iframe>"""

    return gr.update(open=False), gr.update(value=output_view, visible=True)


protpardelleDemo = gr.Blocks()

with protpardelleDemo:
    gr.Markdown("# Protpardelle")
    gr.Markdown(
        """An all-atom protein generative model
        Alexander E. Chu, Lucy Cheng, Gina El Nesr, Minkai Xu, Po-Ssu Huang
        doi: https://doi.org/10.1101/2023.05.24.542194"""
    )

    with gr.Accordion(label="Input options", open=True) as input_accordion:
        model = gr.Dropdown(["backbone", "allatom"], value="allatom", label="What to sample?")

        m = gr.Radio(["unconditional", "conditional"], value="unconditional", label="Choose a Mode")

        # unconditional
        with gr.Group(visible=True) as uncond:
            gr.Markdown("Unconditional Sampling")
            # length = gr.Slider(minimum=0, maximum=200, step=1, value=50, label="length")
            # param = gr.Dropdown(["length", "param"], value="length", label="Which sampling param to vary?")
            # paramval = gr.Dropdown(["nsteps"], label="paramval", info="Which param val to use?")

        # conditional
        with gr.Group(visible=False) as cond:
            with gr.Accordion(label="Structure to condition on", open=True) as input_accordion:
                pdb_radio = gr.Radio(["PDB", "AF2 EBI DB", "upload"], value="PDB", label="source of the structure")
                pdbcode = gr.Textbox(label="PDB code or UniProt code to retrieve from the AlphaFold2 Database", visible=True)
                pdbfile = gr.File(label="PDB File", visible=False)
                btn_load = gr.Button("Load PDB")
                pdb_radio.change(fileselection, inputs=pdb_radio, outputs=[pdbcode, pdbfile, btn_load])

            pdb_html = gr.HTML("", visible=False)

            path_to_file = gr.Textbox(label="Path to file", visible=False)
            resample_idxs = gr.Textbox(label="Cond Idxs", interactive=False, info="Zero-indexed list of indices to condition on; select in sequence viewer above")
            btn_load.click(update_structuresel, inputs=[pdbcode, pdb_radio], outputs=[input_accordion, path_to_file, pdb_html])
            pdbfile.change(update_structuresel, inputs=[pdbfile, pdb_radio], outputs=[input_accordion, path_to_file, pdb_html])

        with gr.Accordion(label="Sizes", open=True) as size_uncond:
            with gr.Row():
                minlen = gr.Slider(minimum=2, maximum=200, value=50, step=1, label="minlen", info="Minimum sequence length")
                maxlen = gr.Slider(minimum=3, maximum=200, value=60, step=1, label="maxlen", info="Maximum sequence length")
                steplen = gr.Slider(minimum=1, maximum=50, step=1, value=1, label="steplen", info="How frequently to select sequence length?")
                perlen = gr.Slider(minimum=1, maximum=200, step=1, value=2, label="perlen", info="How many samples per sequence length?")

    btn_conditional = gr.Button("Run conditional", visible=False)
    btn_unconditional = gr.Button("Run unconditional")
    m.change(changemode, inputs=m, outputs=[uncond, cond, btn_unconditional, btn_conditional, size_uncond])
    out = gr.HTML("", visible=True)

    btn_unconditional.click(predict, inputs=[pdb_radio, path_to_file, m, resample_idxs, model, minlen, maxlen, steplen, perlen], outputs=[input_accordion, out])

    btn_conditional.click(
        fn=None,
        inputs=[resample_idxs],
        outputs=[resample_idxs],
        _js=get_js,
    )
    out_text = gr.Textbox(label="Output", visible=False)
    # hidden button for named api route
    pdb_content = gr.Textbox(label="PDB Content", visible=False)
    btn_api = gr.Button("Run API", visible=False)
    btn_api.click(api_predict, inputs=[pdb_content, m, resample_idxs, model, minlen, maxlen, steplen, perlen], outputs=[out_text], api_name="protpardelle")

    resample_idxs.change(predict, inputs=[pdb_radio, path_to_file, m, resample_idxs, model, minlen, maxlen, steplen, perlen], outputs=[input_accordion, out])
protpardelleDemo.load(None, None, None, _js=load_js)
protpardelleDemo.queue()
protpardelleDemo.launch(allowed_paths=["samples"], share=True)
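Note: the hidden "Run API" button above registers the whole pipeline as a named route (api_name="protpardelle"). Below is a minimal sketch of calling it remotely with the gradio_client package; the Space id is a placeholder, the argument order mirrors the btn_api.click inputs, and the route returns a JSON string of (filename, pdb_text) pairs per api_predict.

import json
from gradio_client import Client

client = Client("user/protpardelle-space")  # placeholder Space id, substitute the real one

result = client.predict(
    "",               # pdb_content: PDB file text, only used in conditional mode
    "unconditional",  # m: sampling mode
    "",               # resample_idxs: e.g. "[0,2-5]" in conditional mode
    "allatom",        # model type: "backbone" or "allatom"
    50,               # minlen
    60,               # maxlen
    1,                # steplen
    2,                # perlen
    api_name="/protpardelle",
)
designs = json.loads(result)  # list of [filename, pdb_text] pairs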
checkpoints/allatom.yml
ADDED
@@ -0,0 +1,69 @@
train:
  home_dir: '/home/duerr/phd/08_Code/protpardelle-final'
  seed: 0
  checkpoint: ['', 0]
  batch_size: 32
  max_epochs: 10000
  eval_freq: 7200 # seconds
  checkpoint_freq: 50
  checkpoints: []
  lr: 0.0001
  warmup_steps: 1000
  decay_steps: 2_000_000
  clip_grad_norm: True
  grad_clip_val: 1.0
  weight_decay: 0.0
  n_eval_samples: 8
  sample_length_range: [50, 512]
  sc_num_seqs: 4
  eval_loss_t: [0.1, 0.3, 0.5, 0.7, 0.9]
  self_cond_train_prob: 0.9
  subsample_eval_set: 0.05
  crop_conditional: False

data:
  pdb_path: 'datasets/ingraham_cath_dataset'
  fixed_size: 512
  n_aatype_tokens: 21
  se3_data_augment: True
  sigma_data: 10.0

diffusion:
  training:
    function: 'lognormal'
    psigma_mean: -1.0
    psigma_std: 1.5
  sampling:
    function: 'uniform'
    s_min: 0.001
    s_max: 80

model:
  task: 'allatom' # 'backbone', 'allatom', 'seqdes', 'codesign'
  pretrained_modules: [] # 'struct_model', 'mpnn_model'
  struct_model_checkpoint: ''
  mpnn_model_checkpoint: ''
  crop_conditional: False
  dummy_fill_masked_atoms: False
  struct_model:
    arch: 'uvit'
    n_atoms: 37
    n_channel: 256
    noise_cond_mult: 4
    uvit:
      patch_size: 1
      n_layers: 6
      n_heads: 8
      dim_head: 32
      n_filt_per_layer: []
      n_blocks_per_layer: 2
      cat_pwd_to_conv: False
      conv_skip_connection: False
      position_embedding_type: 'rotary'
  mpnn_model:
    use_self_conditioning: True
    label_smoothing: 0.1
    n_channel: 128
    n_layers: 3
    n_neighbors: 32
    noise_cond_mult: 4
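app.py consumes this file through utils.load_config (defined in core/utils.py, which this commit adds but which is not shown in this section). As a rough, assumed stand-in for that helper, attribute-style access over the parsed YAML could look like the sketch below; the repo's real load_config may differ.

from types import SimpleNamespace

import yaml  # PyYAML


def load_config_sketch(path):
    # Recursively wrap dicts in SimpleNamespace for cfg.model.task-style access
    def to_ns(node):
        if isinstance(node, dict):
            return SimpleNamespace(**{k: to_ns(v) for k, v in node.items()})
        return node

    with open(path) as f:
        return to_ns(yaml.safe_load(f))


cfg = load_config_sketch("checkpoints/allatom.yml")
print(cfg.model.task)                        # 'allatom'
print(cfg.model.struct_model.uvit.n_layers)  # 6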
checkpoints/allatom_state_dict.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c854ce05b3b1b28c45f58ebf6e5cfba5a45b389ea2aa58a6ce25649d90da238f
size 87550006
checkpoints/backbone.yml
ADDED
@@ -0,0 +1,69 @@
train:
  home_dir: '/home/duerr/phd/08_Code/protpardelle-final'
  seed: 0
  checkpoint: ['', 0]
  batch_size: 32
  max_epochs: 10000
  eval_freq: 7200 # seconds
  checkpoint_freq: 50
  checkpoints: []
  lr: 0.0001
  warmup_steps: 1000
  decay_steps: 2_000_000
  clip_grad_norm: True
  grad_clip_val: 1.0
  weight_decay: 0.0
  n_eval_samples: 8
  sample_length_range: [50, 512]
  sc_num_seqs: 4
  eval_loss_t: [0.1, 0.3, 0.5, 0.7, 0.9]
  self_cond_train_prob: 0.9
  subsample_eval_set: 0.05
  crop_conditional: False

data:
  pdb_path: 'datasets/ingraham_cath_dataset'
  fixed_size: 384
  n_aatype_tokens: 21
  se3_data_augment: True
  sigma_data: 10.0

diffusion:
  training:
    function: 'lognormal'
    psigma_mean: -1.2
    psigma_std: 1.2
  sampling:
    function: 'uniform'
    s_min: 0.001
    s_max: 80

model:
  task: 'backbone' # 'backbone', 'allatom', 'seqdes', 'codesign'
  pretrained_modules: [] # 'struct_model', 'mpnn_model'
  struct_model_checkpoint: ''
  mpnn_model_checkpoint: ''
  crop_conditional: False
  dummy_fill_masked_atoms: False
  struct_model:
    arch: 'uvit'
    n_atoms: 37 # keep same shapes, just zero out sidechains
    n_channel: 256
    noise_cond_mult: 4
    uvit:
      patch_size: 1
      n_layers: 6
      n_heads: 8
      dim_head: 32
      n_filt_per_layer: []
      n_blocks_per_layer: 2
      cat_pwd_to_conv: False
      conv_skip_connection: False
      position_embedding_type: 'absolute_residx'
  mpnn_model:
    use_self_conditioning: True
    label_smoothing: 0.1
    n_channel: 128
    n_layers: 3
    n_neighbors: 32
    noise_cond_mult: 4
checkpoints/backbone_state_dict.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2bcbdcca2419beb8f07cc1d43ee4d8c53d7e4ce21b4a144b88218af00ed3b2b9
size 87548437
checkpoints/minimpnn_state_dict.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:86be202225b3769976ef9bcec75029f4352d670d0107db560eec3d35eeacca9f
size 100570633
configs/allatom.yml
ADDED
@@ -0,0 +1,69 @@
train:
  home_dir: '/home/duerr/phd/08_Code/protpardelle-final'
  seed: 0
  checkpoint: ['', 0]
  batch_size: 32
  max_epochs: 10000
  eval_freq: 7200 # seconds
  checkpoint_freq: 50
  checkpoints: []
  lr: 0.0001
  warmup_steps: 1000
  decay_steps: 2_000_000
  clip_grad_norm: True
  grad_clip_val: 1.0
  weight_decay: 0.0
  n_eval_samples: 8
  sample_length_range: [50, 512]
  sc_num_seqs: 4
  eval_loss_t: [0.1, 0.3, 0.5, 0.7, 0.9]
  self_cond_train_prob: 0.9
  subsample_eval_set: 0.05
  crop_conditional: False

data:
  pdb_path: 'datasets/ingraham_cath_dataset'
  fixed_size: 512
  n_aatype_tokens: 21
  se3_data_augment: True
  sigma_data: 10.0

diffusion:
  training:
    function: 'lognormal'
    psigma_mean: -1.0
    psigma_std: 1.5
  sampling:
    function: 'uniform'
    s_min: 0.001
    s_max: 80

model:
  task: 'allatom' # 'backbone', 'allatom', 'seqdes', 'codesign'
  pretrained_modules: [] # 'struct_model', 'mpnn_model'
  struct_model_checkpoint: ''
  mpnn_model_checkpoint: ''
  crop_conditional: False
  dummy_fill_masked_atoms: False
  struct_model:
    arch: 'uvit'
    n_atoms: 37
    n_channel: 256
    noise_cond_mult: 4
    uvit:
      patch_size: 1
      n_layers: 6
      n_heads: 8
      dim_head: 32
      n_filt_per_layer: []
      n_blocks_per_layer: 2
      cat_pwd_to_conv: False
      conv_skip_connection: False
      position_embedding_type: 'rotary'
  mpnn_model:
    use_self_conditioning: True
    label_smoothing: 0.1
    n_channel: 128
    n_layers: 3
    n_neighbors: 32
    noise_cond_mult: 4
configs/backbone.yml
ADDED
@@ -0,0 +1,69 @@
train:
  home_dir: '/scratch/users/alexechu'
  seed: 0
  checkpoint: ['', 0]
  batch_size: 32
  max_epochs: 10000
  eval_freq: 7200 # seconds
  checkpoint_freq: 50
  checkpoints: []
  lr: 0.0001
  warmup_steps: 1000
  decay_steps: 2_000_000
  clip_grad_norm: True
  grad_clip_val: 1.0
  weight_decay: 0.0
  n_eval_samples: 8
  sample_length_range: [50, 512]
  sc_num_seqs: 4
  eval_loss_t: [0.1, 0.3, 0.5, 0.7, 0.9]
  self_cond_train_prob: 0.9
  subsample_eval_set: 0.05
  crop_conditional: False

data:
  pdb_path: 'datasets/ingraham_cath_dataset'
  fixed_size: 384
  n_aatype_tokens: 21
  se3_data_augment: True
  sigma_data: 10.0

diffusion:
  training:
    function: 'lognormal'
    psigma_mean: -1.2
    psigma_std: 1.2
  sampling:
    function: 'uniform'
    s_min: 0.001
    s_max: 80

model:
  task: 'backbone' # 'backbone', 'allatom', 'seqdes', 'codesign'
  pretrained_modules: [] # 'struct_model', 'mpnn_model'
  struct_model_checkpoint: ''
  mpnn_model_checkpoint: ''
  crop_conditional: False
  dummy_fill_masked_atoms: False
  struct_model:
    arch: 'uvit'
    n_atoms: 37 # keep same shapes, just zero out sidechains
    n_channel: 256
    noise_cond_mult: 4
    uvit:
      patch_size: 1
      n_layers: 6
      n_heads: 8
      dim_head: 32
      n_filt_per_layer: []
      n_blocks_per_layer: 2
      cat_pwd_to_conv: False
      conv_skip_connection: False
      position_embedding_type: 'absolute_residx'
  mpnn_model:
    use_self_conditioning: True
    label_smoothing: 0.1
    n_channel: 128
    n_layers: 3
    n_neighbors: 32
    noise_cond_mult: 4
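The diffusion.training block controls how noise levels are drawn during training. Assuming function: 'lognormal' follows the usual EDM-style convention (log noise level drawn from a normal with mean psigma_mean and std psigma_std, scaled by data.sigma_data), sampling a batch of training noise levels could look like the sketch below; the authoritative formula lives in diffusion.py, which this commit also adds but which is not shown in this section.

import torch


def sample_training_sigmas(batch_size, psigma_mean=-1.2, psigma_std=1.2, sigma_data=10.0):
    # Assumed convention: ln(sigma / sigma_data) ~ Normal(psigma_mean, psigma_std)
    return sigma_data * torch.exp(psigma_mean + psigma_std * torch.randn(batch_size))


print(sample_training_sigmas(4))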
configs/seqdes.yml
ADDED
@@ -0,0 +1,74 @@
train:
  home_dir: '/scratch/users/alexechu'
  seed: 0
  checkpoint: ['', 0]
  batch_size: 32
  max_epochs: 10000
  eval_freq: 3600 # seconds
  checkpoint_freq: 20
  checkpoints: []
  lr: 0.0001
  warmup_steps: 1000
  decay_steps: 400_000
  clip_grad_norm: True
  grad_clip_val: 1.0
  weight_decay: 0.0
  n_eval_samples: 8
  sample_length_range: [50, 512]
  sc_num_seqs: 4
  eval_loss_t: [0.1, 0.3, 0.5, 0.7, 0.9]
  self_cond_train_prob: 0.9
  dgram_loss_weight: False
  subsample_eval_set: 0.1
  crop_conditional: False

data:
  pdb_path: 'datasets/ingraham_cath_dataset'
  fixed_size: 512
  n_aatype_tokens: 21
  se3_data_augment: True
  sigma_data: 10.0

diffusion:
  training:
    function: 'mpnn'
    psigma_mean: -1.2
    psigma_std: 1.2
    time_power: 30.0
    constant_val: 0.02
  sampling:
    function: 'uniform'
    s_min: 0.001
    s_max: 60

model:
  task: 'seqdes' # 'backbone', 'allatom', 'seqdes', 'codesign'
  pretrained_modules: ['struct_model'] # 'struct_model', 'mpnn_model'
  struct_model_checkpoint: 'protpardelle/checkpoints/allatom_state_dict.pth'
  mpnn_model_checkpoint: ''
  crop_conditional: False
  dummy_fill_masked_atoms: False
  debug_mpnn: True
  struct_model:
    arch: 'uvit'
    n_channel: 256
    n_atoms: 37
    noise_cond_mult: 4
    uvit:
      patch_size: 1
      n_layers: 6
      n_heads: 8
      dim_head: 32
      n_filt_per_layer: [] # None or [] for vanilla trf
      n_blocks_per_layer: 2
      cat_pwd_to_conv: False
      conv_skip_connection: False # n layers must == 1
      position_embedding_type: 'rotary'
  mpnn_model:
    use_self_conditioning: False
    label_smoothing: 0.0
    n_channel: 128
    n_layers: 3
    n_neighbors: 32
    noise_cond_mult: 4
core/__init__.py
ADDED
Empty file
core/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (169 Bytes)

core/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (169 Bytes)

core/__pycache__/data.cpython-38.pyc
ADDED
Binary file (6.74 kB)

core/__pycache__/data.cpython-39.pyc
ADDED
Binary file (6.66 kB)

core/__pycache__/protein.cpython-38.pyc
ADDED
Binary file (7.97 kB)

core/__pycache__/protein.cpython-39.pyc
ADDED
Binary file (7.94 kB)

core/__pycache__/protein_mpnn.cpython-38.pyc
ADDED
Binary file (53.5 kB)

core/__pycache__/protein_mpnn.cpython-39.pyc
ADDED
Binary file (53.3 kB)

core/__pycache__/residue_constants.cpython-38.pyc
ADDED
Binary file (21.2 kB)

core/__pycache__/residue_constants.cpython-39.pyc
ADDED
Binary file (24 kB)

core/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (30.3 kB)

core/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (30.1 kB)
core/data.py
ADDED
@@ -0,0 +1,271 @@
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu

Dataloader from PDB files.
"""
import copy
import pickle
import json
import numpy as np
import torch
from torch.utils import data
from einops import rearrange, repeat  # needed by get_masked_coords_array and crop-cond masking below

from core import utils
from core import protein
from core import residue_constants


FEATURES_1D = (
    "coords_in",
    "torsions_in",
    "b_factors",
    "atom_positions",
    "aatype",
    "atom_mask",
    "residue_index",
    "chain_index",
)
FEATURES_FLOAT = (
    "coords_in",
    "torsions_in",
    "b_factors",
    "atom_positions",
    "atom_mask",
    "seq_mask",
)
FEATURES_LONG = ("aatype", "residue_index", "chain_index", "orig_size")


def make_fixed_size_1d(data, fixed_size=128):
    data_len = data.shape[0]
    if data_len >= fixed_size:
        extra_len = data_len - fixed_size
        start_idx = np.random.choice(np.arange(extra_len + 1))
        new_data = data[start_idx : (start_idx + fixed_size)]
        mask = torch.ones(fixed_size)
    if data_len < fixed_size:
        pad_size = fixed_size - data_len
        extra_shape = data.shape[1:]
        new_data = torch.cat([data, torch.zeros(pad_size, *extra_shape)], 0)
        mask = torch.cat([torch.ones(data_len), torch.zeros(pad_size)], 0)
    return new_data, mask


def apply_random_se3(coords_in, atom_mask=None, translation_scale=1.0):
    # unbatched. center on the mean of CA coords
    coords_mean = coords_in[:, 1:2].mean(-3, keepdim=True)
    coords_in -= coords_mean
    random_rot, _ = torch.linalg.qr(torch.randn(3, 3))
    coords_in = coords_in @ random_rot
    random_trans = torch.randn_like(coords_mean) * translation_scale
    coords_in += random_trans
    if atom_mask is not None:
        coords_in = coords_in * atom_mask[..., None]
    return coords_in


def get_masked_coords_array(coords, atom_mask):
    ma_mask = repeat(1 - atom_mask[..., None].cpu().numpy(), "... 1 -> ... 3")
    return np.ma.array(coords.cpu().numpy(), mask=ma_mask)


def make_crop_cond_mask_and_recenter_coords(
    atom_mask,
    atom_coords,
    contiguous_prob=0.05,
    discontiguous_prob=0.9,
    sidechain_only_prob=0.8,
    max_span_len=10,
    max_discontiguous_res=8,
    dist_threshold=8.0,
    recenter_coords=True,
):
    b, n, a = atom_mask.shape
    device = atom_mask.device
    seq_mask = atom_mask[..., 1]
    n_res = seq_mask.sum(-1)
    masks = []

    for i, nr in enumerate(n_res):
        nr = nr.int().item()
        mask = torch.zeros((n, a), device=device)
        conditioning_type = torch.distributions.Categorical(
            torch.tensor(
                [
                    contiguous_prob,
                    discontiguous_prob,
                    1.0 - contiguous_prob - discontiguous_prob,
                ]
            )
        ).sample()
        conditioning_type = ["contiguous", "discontiguous", "none"][conditioning_type]

        if conditioning_type == "contiguous":
            span_len = torch.randint(
                1, min(max_span_len, nr), (1,), device=device
            ).item()
            span_start = torch.randint(0, nr - span_len, (1,), device=device)
            mask[span_start : span_start + span_len, :] = 1
        elif conditioning_type == "discontiguous":
            # Extract CB atoms coordinates for the i-th example
            cb_atoms = atom_coords[i, :, 3]
            # Pairwise distances between CB atoms
            cb_distances = torch.cdist(cb_atoms, cb_atoms)
            close_mask = (
                cb_distances <= dist_threshold
            )  # Mask for selecting close CB atoms

            random_residue = torch.randint(0, nr, (1,), device=device).squeeze()
            cb_dist_i = cb_distances[random_residue] + 1e3 * (1 - seq_mask[i])
            close_mask = cb_dist_i <= dist_threshold
            n_neighbors = close_mask.sum().int()

            # pick how many neighbors (up to 10)
            n_sele = torch.randint(
                2,
                n_neighbors.clamp(min=3, max=max_discontiguous_res + 1),
                (1,),
                device=device,
            )

            # Select the indices of CB atoms that are close together
            idxs = torch.arange(n, device=device)[close_mask.bool()]
            idxs = idxs[torch.randperm(len(idxs))[:n_sele]]

            if len(idxs) > 0:
                mask[idxs] = 1

        if np.random.uniform() < sidechain_only_prob:
            mask[:, :5] = 0

        masks.append(mask)

    crop_cond_mask = torch.stack(masks)
    crop_cond_mask = crop_cond_mask * atom_mask
    if recenter_coords:
        motif_masked_array = get_masked_coords_array(atom_coords, crop_cond_mask)
        cond_coords_center = motif_masked_array.mean((1, 2))
        motif_mask = torch.Tensor(1 - cond_coords_center.mask).to(crop_cond_mask)
        means = torch.Tensor(cond_coords_center.data).to(atom_coords) * motif_mask
        coords_out = atom_coords - rearrange(means, "b c -> b 1 1 c")
    else:
        coords_out = atom_coords
    return coords_out, crop_cond_mask


class Dataset(data.Dataset):
    """Loads and processes PDBs into tensors."""

    def __init__(
        self,
        pdb_path,
        fixed_size,
        mode="train",
        overfit=-1,
        short_epoch=False,
        se3_data_augment=True,
    ):
        self.pdb_path = pdb_path
        self.fixed_size = fixed_size
        self.mode = mode
        self.overfit = overfit
        self.short_epoch = short_epoch
        self.se3_data_augment = se3_data_augment

        with open(f"{self.pdb_path}/{mode}_pdb_keys.list") as f:
            self.pdb_keys = np.array(f.read().split("\n")[:-1])

        if overfit > 0:
            n_data = len(self.pdb_keys)
            self.pdb_keys = np.random.choice(
                self.pdb_keys, min(n_data, overfit), replace=False
            ).repeat(n_data // overfit)

    def __len__(self):
        if self.short_epoch:
            return min(len(self.pdb_keys), 256)
        else:
            return len(self.pdb_keys)

    def __getitem__(self, idx):
        pdb_key = self.pdb_keys[idx]
        data = self.get_item(pdb_key)
        # For now, replace dataloading errors with a random pdb. 10 tries
        for _ in range(10):
            if data is not None:
                return data
            pdb_key = self.pdb_keys[np.random.randint(len(self.pdb_keys))]
            data = self.get_item(pdb_key)
        raise Exception("Failed to load data example after 10 tries.")

    def get_item(self, pdb_key):
        example = {}

        if self.pdb_path.endswith("cath_s40_dataset"):  # CATH pdbs
            data_file = f"{self.pdb_path}/dompdb/{pdb_key}"
        elif self.pdb_path.endswith("ingraham_cath_dataset"):  # ingraham splits
            data_file = f"{self.pdb_path}/pdb_store/{pdb_key}"
        else:
            raise Exception("Invalid pdb path.")

        try:
            example = utils.load_feats_from_pdb(data_file)
            coords_in = example["atom_positions"]
        except FileNotFoundError:
            raise Exception(f"File {pdb_key} not found. Check if dataset is corrupted?")
        except RuntimeError:
            return None

        # Apply data augmentation
        if self.se3_data_augment:
            coords_in = apply_random_se3(coords_in, atom_mask=example["atom_mask"])

        orig_size = coords_in.shape[0]
        example["coords_in"] = coords_in
        example["orig_size"] = torch.ones(1) * orig_size

        fixed_size_example = {}
        seq_mask = None
        for k, v in example.items():
            if k in FEATURES_1D:
                fixed_size_example[k], seq_mask = make_fixed_size_1d(
                    v, fixed_size=self.fixed_size
                )
            else:
                fixed_size_example[k] = v
        if seq_mask is not None:
            fixed_size_example["seq_mask"] = seq_mask

        example_out = {}
        for k, v in fixed_size_example.items():
            if k in FEATURES_FLOAT:
                example_out[k] = v.float()
            elif k in FEATURES_LONG:
                example_out[k] = v.long()

        return example_out

    def collate(self, example_list):
        out = {}
        for ex in example_list:
            for k, v in ex.items():
                out.setdefault(k, []).append(v)
        return {k: torch.stack(v) for k, v in out.items()}

    def sample(self, n=1, return_data=True, return_keys=False):
        keys = self.pdb_keys[torch.randperm(self.__len__())[:n].long()]

        if return_keys and not return_data:
            return keys

        if n == 1:
            data = self.collate([self.get_item(keys)])
        else:
            data = self.collate([self.get_item(key) for key in keys])

        if return_data and return_keys:
            return data, keys
        if return_data and not return_keys:
            return data
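A quick shape check for the augmentation and padding helpers above. This sketch assumes the repo root is on the import path; the tensors are random stand-ins for real PDB features, in the usual [num_res, 37, 3] atom37 layout.

import torch

from core.data import apply_random_se3, make_fixed_size_1d

coords = torch.randn(100, 37, 3)  # [num_res, num_atom_type, 3]
atom_mask = torch.ones(100, 37)

# Random rotation plus scaled random translation, centered on the CA atoms
aug = apply_random_se3(coords.clone(), atom_mask=atom_mask, translation_scale=1.0)
print(aug.shape)  # torch.Size([100, 37, 3])

# Pad (or randomly crop) the residue dimension to a fixed size, with a validity mask
fixed, seq_mask = make_fixed_size_1d(coords, fixed_size=128)
print(fixed.shape, seq_mask.sum())  # torch.Size([128, 37, 3]) tensor(100.)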
core/protein.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Protein data type.

Adapted from original code by alexechu.
"""
import dataclasses
import io
from typing import Any, Mapping, Optional

from core import residue_constants
from Bio.PDB import PDBParser
import numpy as np

FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any]  # Is a nested dict.

# Complete sequence of chain IDs supported by the PDB format.
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)  # := 62.


@dataclasses.dataclass(frozen=True)
class Protein:
    """Protein structure representation."""

    # Cartesian coordinates of atoms in angstroms. The atom types correspond to
    # residue_constants.atom_types, i.e. the first three are N, CA, C.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

    # Amino-acid type for each residue represented as an integer between 0 and
    # 20, where 20 is 'X'.
    aatype: np.ndarray  # [num_res]

    # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
    # is present and 0.0 if not. This should be used for loss masking.
    atom_mask: np.ndarray  # [num_res, num_atom_type]

    # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
    residue_index: np.ndarray  # [num_res]

    # 0-indexed number corresponding to the chain in the protein that this residue
    # belongs to.
    chain_index: np.ndarray  # [num_res]

    # B-factors, or temperature factors, of each residue (in sq. angstroms units),
    # representing the displacement of the residue from its ground truth mean
    # value.
    b_factors: np.ndarray  # [num_res, num_atom_type]

    def __post_init__(self):
        if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
            raise ValueError(
                f"Cannot build an instance with more than {PDB_MAX_CHAINS} chains "
                "because these cannot be written to PDB format."
            )


def from_pdb_string(
    pdb_str: str, chain_id: Optional[str] = None, protein_only: bool = False
) -> Protein:
    """Takes a PDB string and constructs a Protein object.

    WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored.

    Args:
      pdb_str: The contents of the pdb file.
      chain_id: If chain_id is specified (e.g. A), then only that chain
        is parsed. Otherwise all chains are parsed.
      protein_only: If True, residues with a non-blank hetero flag
        (heteroatoms and waters) are skipped.

    Returns:
      A new `Protein` parsed from the pdb contents.
    """
    pdb_fh = io.StringIO(pdb_str)
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("none", pdb_fh)
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f"Only single model PDBs are supported. Found {len(models)} models."
        )
    model = models[0]

    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    chain_ids = []
    b_factors = []

    for chain in model:
        if chain_id is not None and chain.id != chain_id:
            continue
        for res in chain:
            if protein_only and res.id[0] != " ":
                continue
            if res.id[2] != " ":
                pass
                # raise ValueError(
                #     f"PDB contains an insertion code at chain {chain.id} and residue "
                #     f"index {res.id[1]}. These are not supported."
                # )
            res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
            restype_idx = residue_constants.restype_order.get(
                res_shortname, residue_constants.restype_num
            )
            pos = np.zeros((residue_constants.atom_type_num, 3))
            mask = np.zeros((residue_constants.atom_type_num,))
            res_b_factors = np.zeros((residue_constants.atom_type_num,))
            for atom in res:
                if atom.name not in residue_constants.atom_types:
                    continue
                pos[residue_constants.atom_order[atom.name]] = atom.coord
                mask[residue_constants.atom_order[atom.name]] = 1.0
                res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
            if np.sum(mask) < 0.5:
                # If no known atom positions are reported for the residue then skip it.
                continue
            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # Chain IDs are usually characters so map these to ints.
    unique_chain_ids = np.unique(chain_ids)
    chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors),
    )
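# Usage sketch (illustrative addition, not part of the original module): read
# a PDB file from disk and parse one chain into a `Protein`. The path
# "example.pdb" is a placeholder.
def _example_from_pdb_string():
    with open("example.pdb") as fh:
        prot = from_pdb_string(fh.read(), chain_id="A")
    # aatype is [num_res]; atom_positions is [num_res, num_atom_type, 3].
    return prot.aatype.shape, prot.atom_positions.shape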
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
    # TER is a columnar PDB record: serial in cols 7-11, resName in 18-20,
    # chainID in 22, resSeq in 23-26, hence the six-space gap.
    chain_end = "TER"
    return (
        f"{chain_end:<6}{atom_index:>5}      {end_resname:>3} "
        f"{chain_name:>1}{residue_index:>4}"
    )


def are_atoms_bonded(res3name, atom1_name, atom2_name):
    lookup_table = residue_constants.standard_residue_bonds
    for bond in lookup_table[res3name]:
        if bond.atom1_name == atom1_name and bond.atom2_name == atom2_name:
            return True
        elif bond.atom1_name == atom2_name and bond.atom2_name == atom1_name:
            return True
    return False


def to_pdb(prot: Protein, conect=False) -> str:
    """Converts a `Protein` instance to a PDB string.

    Args:
      prot: The protein to convert to PDB.
      conect: If True, also return the CONECT records as a second string.

    Returns:
      PDB string, plus a CONECT string if `conect` is True.
    """
    restypes = residue_constants.restypes + ["X"]
    res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
    atom_types = residue_constants.atom_types

    pdb_lines = []

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(np.int32)
    chain_index = prot.chain_index.astype(np.int32)
    b_factors = prot.b_factors

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError("Invalid aatypes.")

    # Construct a mapping from chain integer indices to chain ID strings.
    chain_ids = {}
    for i in np.unique(chain_index):  # np.unique gives sorted output.
        if i >= PDB_MAX_CHAINS:
            raise ValueError(
                f"The PDB format supports at most {PDB_MAX_CHAINS} chains."
            )
        chain_ids[i] = PDB_CHAIN_IDS[i]

    pdb_lines.append("MODEL     1")
    atom_index = 1
    last_chain_index = chain_index[0]
    conect_lines = []
    # Add all atom sites.
    for i in range(aatype.shape[0]):
        # Close the previous chain if in a multichain PDB.
        if last_chain_index != chain_index[i]:
            pdb_lines.append(
                _chain_end(
                    atom_index,
                    res_1to3(aatype[i - 1]),
                    chain_ids[chain_index[i - 1]],
                    residue_index[i - 1],
                )
            )
            last_chain_index = chain_index[i]
            atom_index += 1  # Atom index increases at the TER symbol.

        res_name_3 = res_1to3(aatype[i])
        atoms_appended_for_res = []
        for atom_name, pos, mask, b_factor in zip(
            atom_types, atom_positions[i], atom_mask[i], b_factors[i]
        ):
            if mask < 0.5:
                continue

            record_type = "ATOM"
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            element = atom_name[0]  # Protein supports only C, N, O, S, this works.
            charge = ""
            # PDB is a columnar format, every space matters here!
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1}   "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)

            # Emit intra-residue CONECT records for atoms already written.
            for prev_atom_idx, prev_atom in atoms_appended_for_res:
                if are_atoms_bonded(res_name_3, atom_name, prev_atom):
                    conect_line = f"CONECT{prev_atom_idx:5d}{atom_index:5d}\n"
                    conect_lines.append(conect_line)
            atoms_appended_for_res.append((atom_index, atom_name))
            if atom_name == "N":
                n_atom_idx = atom_index
            if atom_name == "C":
                c_atom_idx = atom_index

            atom_index += 1

        # Connect the previous residue's backbone C to this residue's N.
        if i > 0:
            conect_line = f"CONECT{prev_c_atom_idx:5d}{n_atom_idx:5d}\n"
            conect_lines.append(conect_line)
        prev_c_atom_idx = c_atom_idx

    # Close the final chain.
    pdb_lines.append(
        _chain_end(
            atom_index,
            res_1to3(aatype[-1]),
            chain_ids[chain_index[-1]],
            residue_index[-1],
        )
    )
    pdb_lines.append("ENDMDL")
    pdb_lines.append("END")

    # Pad all lines to 80 characters.
    pdb_lines = [line.ljust(80) for line in pdb_lines]
    pdb_str = "\n".join(pdb_lines) + "\n"  # Add terminating newline.
    if conect:
        conect_str = "".join(conect_lines) + "\n"
        return pdb_str, conect_str
    return pdb_str
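# Usage sketch (illustrative addition, not part of the original module):
# round-trip a parsed structure back to PDB text. With conect=True, `to_pdb`
# returns a (pdb_str, conect_str) tuple instead of a single string.
def _example_to_pdb(prot: Protein) -> str:
    pdb_str, conect_str = to_pdb(prot, conect=True)
    return pdb_str + conect_str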
def ideal_atom_mask(prot: Protein) -> np.ndarray:
    """Computes an ideal atom mask.

    `Protein.atom_mask` typically is defined according to the atoms that are
    reported in the PDB. This function computes a mask according to heavy atoms
    that should be present in the given sequence of amino acids.

    Args:
      prot: `Protein` whose fields are `numpy.ndarray` objects.

    Returns:
      An ideal atom mask.
    """
    return residue_constants.STANDARD_ATOM_MASK[prot.aatype]


def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    remove_leading_feature_dimension: bool = True,
) -> Protein:
    """Assembles a protein from a prediction.

    Args:
      features: Dictionary holding model inputs.
      result: Dictionary holding model outputs.
      b_factors: (Optional) B-factors to use for the protein.
      remove_leading_feature_dimension: Whether to remove the leading dimension
        of the `features` values.

    Returns:
      A protein instance.
    """
    fold_output = result["structure_module"]

    def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
        return arr[0] if remove_leading_feature_dimension else arr

    if "asym_id" in features:
        chain_index = _maybe_remove_leading_dim(features["asym_id"])
    else:
        chain_index = np.zeros_like(_maybe_remove_leading_dim(features["aatype"]))

    if b_factors is None:
        b_factors = np.zeros_like(fold_output["final_atom_mask"])

    return Protein(
        aatype=_maybe_remove_leading_dim(features["aatype"]),
        atom_positions=fold_output["final_atom_positions"],
        atom_mask=fold_output["final_atom_mask"],
        residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
        chain_index=chain_index,
        b_factors=b_factors,
    )
core/protein_mpnn.py
ADDED
@@ -0,0 +1,1886 @@
# MIT License

# Copyright (c) 2022 Justas Dauparas

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

'''
Adapted from original code by alexechu.
'''
import json, time, os, sys, glob
import shutil
import warnings
import copy
import random
import os.path
import subprocess
import itertools

from einops.layers.torch import Rearrange
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset
import torch.nn as nn
import torch.nn.functional as F


def get_mpnn_model(model_name='v_48_020', path_to_model_weights='', ca_only=False, backbone_noise=0.0, verbose=False, device=None):
    hidden_dim = 128
    num_layers = 3
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if path_to_model_weights:
        model_folder_path = path_to_model_weights
        if model_folder_path[-1] != '/':
            model_folder_path = model_folder_path + '/'
    else:
        file_path = os.path.realpath(__file__)
        k = file_path.rfind("/")
        if ca_only:
            model_folder_path = file_path[:k] + '/ca_model_weights/'
        else:
            model_folder_path = file_path[:k] + '/vanilla_model_weights/'

    checkpoint_path = model_folder_path + f'{model_name}.pt'
    checkpoint = torch.load(checkpoint_path, map_location=device)
    noise_level_print = checkpoint['noise_level']
    model = ProteinMPNN(ca_only=ca_only, num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim,
                        num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])
    model.to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    if verbose:
        print(40*'-')
        print('Model loaded...')
        print('Number of edges:', checkpoint['num_edges'])
        print(f'Training noise level: {noise_level_print}A')

    return model
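# Usage sketch (illustrative addition, not part of the original module): load
# the standard backbone model on whatever device is available. Assumes the
# ProteinMPNN checkpoint files (e.g. v_48_020.pt) sit in vanilla_model_weights/
# next to this file, as laid out in the upstream ProteinMPNN repo.
def _example_load_model():
    return get_mpnn_model(model_name='v_48_020', ca_only=False, verbose=True)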
def run_proteinmpnn(model=None, pdb_path='', pdb_path_chains='', path_to_model_weights='', model_name='v_48_020', seed=0, ca_only=False, out_folder='', num_seq_per_target=1, batch_size=1, sampling_temps=[0.1], backbone_noise=0.0, max_length=200000, omit_AAs=[], print_all=False,
                    chain_id_jsonl='', fixed_positions_jsonl='', pssm_jsonl='', omit_AA_jsonl='', bias_AA_jsonl='', tied_positions_jsonl='', bias_by_res_jsonl='', jsonl_path='',
                    pssm_threshold=0.0, pssm_multi=0.0, pssm_log_odds_flag=False, pssm_bias_flag=False, write_output_files=False):

    if model is None:
        model = get_mpnn_model(model_name=model_name, path_to_model_weights=path_to_model_weights, ca_only=ca_only, backbone_noise=backbone_noise, verbose=print_all)

    # Pick a random seed if none was given.
    if not seed:
        seed = int(np.random.randint(0, high=999, size=1, dtype=int)[0])

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    NUM_BATCHES = num_seq_per_target//batch_size
    BATCH_COPIES = batch_size
    temperatures = sampling_temps
    omit_AAs_list = omit_AAs
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    alphabet_dict = dict(zip(alphabet, range(21)))
    omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if os.path.isfile(chain_id_jsonl):
        with open(chain_id_jsonl, 'r') as json_file:
            json_list = list(json_file)
        for json_str in json_list:
            chain_id_dict = json.loads(json_str)
    else:
        chain_id_dict = None
        if print_all:
            print(40*'-')
            print('chain_id_jsonl is NOT loaded')

    if os.path.isfile(fixed_positions_jsonl):
        with open(fixed_positions_jsonl, 'r') as json_file:
            json_list = list(json_file)
        for json_str in json_list:
            fixed_positions_dict = json.loads(json_str)
    else:
        if print_all:
            print(40*'-')
            print('fixed_positions_jsonl is NOT loaded')
        fixed_positions_dict = None

    if os.path.isfile(pssm_jsonl):
        with open(pssm_jsonl, 'r') as json_file:
            json_list = list(json_file)
        pssm_dict = {}
        for json_str in json_list:
            pssm_dict.update(json.loads(json_str))
    else:
        if print_all:
            print(40*'-')
            print('pssm_jsonl is NOT loaded')
        pssm_dict = None

    if os.path.isfile(omit_AA_jsonl):
        with open(omit_AA_jsonl, 'r') as json_file:
            json_list = list(json_file)
        for json_str in json_list:
            omit_AA_dict = json.loads(json_str)
    else:
        if print_all:
            print(40*'-')
            print('omit_AA_jsonl is NOT loaded')
        omit_AA_dict = None

    if os.path.isfile(bias_AA_jsonl):
        with open(bias_AA_jsonl, 'r') as json_file:
            json_list = list(json_file)
        for json_str in json_list:
            bias_AA_dict = json.loads(json_str)
    else:
        if print_all:
            print(40*'-')
            print('bias_AA_jsonl is NOT loaded')
        bias_AA_dict = None

    if os.path.isfile(tied_positions_jsonl):
        with open(tied_positions_jsonl, 'r') as json_file:
            json_list = list(json_file)
        for json_str in json_list:
            tied_positions_dict = json.loads(json_str)
    else:
        if print_all:
            print(40*'-')
            print('tied_positions_jsonl is NOT loaded')
        tied_positions_dict = None

    if os.path.isfile(bias_by_res_jsonl):
        with open(bias_by_res_jsonl, 'r') as json_file:
            json_list = list(json_file)
        for json_str in json_list:
            bias_by_res_dict = json.loads(json_str)
        if print_all:
            print('bias by residue dictionary is loaded')
    else:
        if print_all:
            print(40*'-')
            print('bias by residue dictionary is not loaded, or not provided')
        bias_by_res_dict = None

    if print_all:
        print(40*'-')
    bias_AAs_np = np.zeros(len(alphabet))
    if bias_AA_dict:
        for n, AA in enumerate(alphabet):
            if AA in list(bias_AA_dict.keys()):
                bias_AAs_np[n] = bias_AA_dict[AA]

    if pdb_path:
        pdb_dict_list = parse_PDB(pdb_path, ca_only=ca_only)
        dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)
        all_chain_list = [item[-1:] for item in list(pdb_dict_list[0]) if item[:9] == 'seq_chain']  # ['A','B','C',...]
        if pdb_path_chains:
            designed_chain_list = [str(item) for item in pdb_path_chains.split()]
        else:
            designed_chain_list = all_chain_list
        fixed_chain_list = [letter for letter in all_chain_list if letter not in designed_chain_list]
        chain_id_dict = {}
        chain_id_dict[pdb_dict_list[0]['name']] = (designed_chain_list, fixed_chain_list)
    else:
        dataset_valid = StructureDataset(jsonl_path, truncate=None, max_length=max_length, verbose=print_all)

    # Build paths for experiment
    if write_output_files:
        folder_for_outputs = out_folder
        base_folder = folder_for_outputs
        if base_folder[-1] != '/':
            base_folder = base_folder + '/'
        if not os.path.exists(base_folder):
            os.makedirs(base_folder)
        if not os.path.exists(base_folder + 'seqs'):
            os.makedirs(base_folder + 'seqs')

    # if args.save_score:
    #     if not os.path.exists(base_folder + 'scores'):
    #         os.makedirs(base_folder + 'scores')

    # if args.score_only:
    #     if not os.path.exists(base_folder + 'score_only'):
    #         os.makedirs(base_folder + 'score_only')

    # if args.conditional_probs_only:
    #     if not os.path.exists(base_folder + 'conditional_probs_only'):
    #         os.makedirs(base_folder + 'conditional_probs_only')

    # if args.unconditional_probs_only:
    #     if not os.path.exists(base_folder + 'unconditional_probs_only'):
    #         os.makedirs(base_folder + 'unconditional_probs_only')

    # if args.save_probs:
    #     if not os.path.exists(base_folder + 'probs'):
    #         os.makedirs(base_folder + 'probs')

    # Timing
    start_time = time.time()
    total_residues = 0
    protein_list = []
    total_step = 0
    # Validation epoch
    new_mpnn_seqs = []
    with torch.no_grad():
        test_sum, test_weights = 0., 0.
        for ix, protein in enumerate(dataset_valid):
            score_list = []
            global_score_list = []
            all_probs_list = []
            all_log_probs_list = []
            S_sample_list = []
            batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]
            X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict, ca_only=ca_only)
            pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float()  # 1.0 for true, 0.0 for false
            name_ = batch_clones[0]['name']
            if False:
                pass
            # if args.score_only:
            #     loop_c = 0
            #     if args.path_to_fasta:
            #         fasta_names, fasta_seqs = parse_fasta(args.path_to_fasta, omit=["/"])
            #         loop_c = len(fasta_seqs)
            #     for fc in range(1+loop_c):
            #         if fc == 0:
            #             structure_sequence_score_file = base_folder + '/score_only/' + batch_clones[0]['name'] + f'_pdb'
            #         else:
            #             structure_sequence_score_file = base_folder + '/score_only/' + batch_clones[0]['name'] + f'_fasta_{fc}'
            #         native_score_list = []
            #         global_native_score_list = []
            #         if fc > 0:
            #             input_seq_length = len(fasta_seqs[fc-1])
            #             S_input = torch.tensor([alphabet_dict[AA] for AA in fasta_seqs[fc-1]], device=device)[None,:].repeat(X.shape[0], 1)
            #             S[:,:input_seq_length] = S_input #assumes that S and S_input are alphabetically sorted for masked_chains
            #         for j in range(NUM_BATCHES):
            #             randn_1 = torch.randn(chain_M.shape, device=X.device)
            #             log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
            #             mask_for_loss = mask*chain_M*chain_M_pos
            #             scores = _scores(S, log_probs, mask_for_loss)
            #             native_score = scores.cpu().data.numpy()
            #             native_score_list.append(native_score)
            #             global_scores = _scores(S, log_probs, mask)
            #             global_native_score = global_scores.cpu().data.numpy()
            #             global_native_score_list.append(global_native_score)
            #         native_score = np.concatenate(native_score_list, 0)
            #         global_native_score = np.concatenate(global_native_score_list, 0)
            #         ns_mean = native_score.mean()
            #         ns_mean_print = np.format_float_positional(np.float32(ns_mean), unique=False, precision=4)
            #         ns_std = native_score.std()
            #         ns_std_print = np.format_float_positional(np.float32(ns_std), unique=False, precision=4)

            #         global_ns_mean = global_native_score.mean()
            #         global_ns_mean_print = np.format_float_positional(np.float32(global_ns_mean), unique=False, precision=4)
            #         global_ns_std = global_native_score.std()
            #         global_ns_std_print = np.format_float_positional(np.float32(global_ns_std), unique=False, precision=4)

            #         ns_sample_size = native_score.shape[0]
            #         seq_str = _S_to_seq(S[0,], chain_M[0,])
            #         np.savez(structure_sequence_score_file, score=native_score, global_score=global_native_score, S=S[0,].cpu().numpy(), seq_str=seq_str)
            #         if print_all:
            #             if fc == 0:
            #                 print(f'Score for {name_} from PDB, mean: {ns_mean_print}, std: {ns_std_print}, sample size: {ns_sample_size}, global score, mean: {global_ns_mean_print}, std: {global_ns_std_print}, sample size: {ns_sample_size}')
            #             else:
            #                 print(f'Score for {name_}_{fc} from FASTA, mean: {ns_mean_print}, std: {ns_std_print}, sample size: {ns_sample_size}, global score, mean: {global_ns_mean_print}, std: {global_ns_std_print}, sample size: {ns_sample_size}')
            # elif args.conditional_probs_only:
            #     if print_all:
            #         print(f'Calculating conditional probabilities for {name_}')
            #     conditional_probs_only_file = base_folder + '/conditional_probs_only/' + batch_clones[0]['name']
            #     log_conditional_probs_list = []
            #     for j in range(NUM_BATCHES):
            #         randn_1 = torch.randn(chain_M.shape, device=X.device)
            #         log_conditional_probs = model.conditional_probs(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1, args.conditional_probs_only_backbone)
            #         log_conditional_probs_list.append(log_conditional_probs.cpu().numpy())
            #     concat_log_p = np.concatenate(log_conditional_probs_list, 0) #[B, L, 21]
            #     mask_out = (chain_M*chain_M_pos*mask)[0,].cpu().numpy()
            #     np.savez(conditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(), mask=mask[0,].cpu().numpy(), design_mask=mask_out)
            # elif args.unconditional_probs_only:
            #     if print_all:
            #         print(f'Calculating sequence unconditional probabilities for {name_}')
            #     unconditional_probs_only_file = base_folder + '/unconditional_probs_only/' + batch_clones[0]['name']
            #     log_unconditional_probs_list = []
            #     for j in range(NUM_BATCHES):
            #         log_unconditional_probs = model.unconditional_probs(X, mask, residue_idx, chain_encoding_all)
            #         log_unconditional_probs_list.append(log_unconditional_probs.cpu().numpy())
            #     concat_log_p = np.concatenate(log_unconditional_probs_list, 0) #[B, L, 21]
            #     mask_out = (chain_M*chain_M_pos*mask)[0,].cpu().numpy()
            #     np.savez(unconditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(), mask=mask[0,].cpu().numpy(), design_mask=mask_out)
            else:
                randn_1 = torch.randn(chain_M.shape, device=X.device)
                log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
                mask_for_loss = mask*chain_M*chain_M_pos
                scores = _scores(S, log_probs, mask_for_loss)  # score only the redesigned part
                native_score = scores.cpu().data.numpy()
                global_scores = _scores(S, log_probs, mask)  # score the whole structure-sequence
                global_native_score = global_scores.cpu().data.numpy()
                # Generate some sequences
                if write_output_files:
                    ali_file = base_folder + '/seqs/' + batch_clones[0]['name'] + '.fa'
                    # score_file/probs_file are only used by the commented-out save paths below.
                    score_file = base_folder + '/scores/' + batch_clones[0]['name'] + '.npz'
                    probs_file = base_folder + '/probs/' + batch_clones[0]['name'] + '.npz'
                    f = open(ali_file, 'w')
                if print_all:
                    print(f'Generating sequences for: {name_}')
                t0 = time.time()
                for temp in temperatures:
                    for j in range(NUM_BATCHES):
                        randn_2 = torch.randn(chain_M.shape, device=X.device)
                        if tied_positions_dict is None:
                            sample_dict = model.sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), bias_by_res=bias_by_res_all)
                            S_sample = sample_dict["S"]
                        else:
                            sample_dict = model.tied_sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), tied_pos=tied_pos_list_of_lists_list[0], tied_beta=tied_beta, bias_by_res=bias_by_res_all)
                        # Compute scores
                        S_sample = sample_dict["S"]
                        log_probs = model(X, S_sample, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_2, use_input_decoding_order=True, decoding_order=sample_dict["decoding_order"])
                        mask_for_loss = mask*chain_M*chain_M_pos
                        scores = _scores(S_sample, log_probs, mask_for_loss)
                        scores = scores.cpu().data.numpy()

                        global_scores = _scores(S_sample, log_probs, mask)  # score the whole structure-sequence
                        global_scores = global_scores.cpu().data.numpy()

                        all_probs_list.append(sample_dict["probs"].cpu().data.numpy())
                        all_log_probs_list.append(log_probs.cpu().data.numpy())
                        S_sample_list.append(S_sample.cpu().data.numpy())
                        for b_ix in range(BATCH_COPIES):
                            masked_chain_length_list = masked_chain_length_list_list[b_ix]
                            masked_list = masked_list_list[b_ix]
                            seq_recovery_rate = torch.sum(torch.sum(torch.nn.functional.one_hot(S[b_ix], 21)*torch.nn.functional.one_hot(S_sample[b_ix], 21), axis=-1)*mask_for_loss[b_ix])/torch.sum(mask_for_loss[b_ix])
                            seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
                            new_mpnn_seqs.append(seq)
                            score = scores[b_ix]
                            score_list.append(score)
                            global_score = global_scores[b_ix]
                            global_score_list.append(global_score)
                            native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])
                            if b_ix == 0 and j == 0 and temp == temperatures[0]:
                                # Write the native sequence once, with chains reordered and '/'-separated.
                                start = 0
                                end = 0
                                list_of_AAs = []
                                for mask_l in masked_chain_length_list:
                                    end += mask_l
                                    list_of_AAs.append(native_seq[start:end])
                                    start = end
                                native_seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
                                l0 = 0
                                for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
                                    l0 += mc_length
                                    native_seq = native_seq[:l0] + '/' + native_seq[l0:]
                                    l0 += 1
                                sorted_masked_chain_letters = np.argsort(masked_list_list[0])
                                print_masked_chains = [masked_list_list[0][i] for i in sorted_masked_chain_letters]
                                sorted_visible_chain_letters = np.argsort(visible_list_list[0])
                                print_visible_chains = [visible_list_list[0][i] for i in sorted_visible_chain_letters]
                                native_score_print = np.format_float_positional(np.float32(native_score.mean()), unique=False, precision=4)
                                global_native_score_print = np.format_float_positional(np.float32(global_native_score.mean()), unique=False, precision=4)
                                script_dir = os.path.dirname(os.path.realpath(__file__))
                                try:
                                    commit_str = subprocess.check_output(f'git --git-dir {script_dir}/.git rev-parse HEAD', shell=True, stderr=subprocess.DEVNULL).decode().strip()
                                except subprocess.CalledProcessError:
                                    commit_str = 'unknown'
                                if ca_only:
                                    print_model_name = 'CA_model_name'
                                else:
                                    print_model_name = 'model_name'
                                if write_output_files:
                                    f.write('>{}, score={}, global_score={}, fixed_chains={}, designed_chains={}, {}={}, git_hash={}, seed={}\n{}\n'.format(name_, native_score_print, global_native_score_print, print_visible_chains, print_masked_chains, print_model_name, model_name, commit_str, seed, native_seq))  # write the native sequence
                            start = 0
                            end = 0
                            list_of_AAs = []
                            for mask_l in masked_chain_length_list:
                                end += mask_l
                                list_of_AAs.append(seq[start:end])
                                start = end

                            seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
                            l0 = 0
                            for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
                                l0 += mc_length
                                seq = seq[:l0] + '/' + seq[l0:]
                                l0 += 1
                            score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)
                            global_score_print = np.format_float_positional(np.float32(global_score), unique=False, precision=4)
                            seq_rec_print = np.format_float_positional(np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)
                            sample_number = j*BATCH_COPIES+b_ix+1
                            if write_output_files:
                                f.write('>T={}, sample={}, score={}, global_score={}, seq_recovery={}\n{}\n'.format(temp, sample_number, score_print, global_score_print, seq_rec_print, seq))  # write generated sequence
                # if args.save_score:
                #     np.savez(score_file, score=np.array(score_list, np.float32), global_score=np.array(global_score_list, np.float32))
                # if args.save_probs:
                #     all_probs_concat = np.concatenate(all_probs_list)
                #     all_log_probs_concat = np.concatenate(all_log_probs_list)
                #     S_sample_concat = np.concatenate(S_sample_list)
                #     np.savez(probs_file, probs=np.array(all_probs_concat, np.float32), log_probs=np.array(all_log_probs_concat, np.float32), S=np.array(S_sample_concat, np.int32), mask=mask_for_loss.cpu().data.numpy(), chain_order=chain_list_list)
                t1 = time.time()
                dt = round(float(t1-t0), 4)
                num_seqs = len(temperatures)*NUM_BATCHES*BATCH_COPIES
                total_length = X.shape[1]
                if print_all:
                    print(f'{num_seqs} sequences of length {total_length} generated in {dt} seconds')
                if write_output_files:
                    f.close()

    return new_mpnn_seqs
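# Usage sketch (illustrative addition, not part of the original module):
# design four sequences for chain A of a placeholder input PDB. Note that
# NUM_BATCHES is num_seq_per_target // batch_size, so num_seq_per_target
# should be a multiple of batch_size or the remainder is silently dropped.
def _example_run_proteinmpnn(model):
    return run_proteinmpnn(model=model, pdb_path='input.pdb',
                           pdb_path_chains='A', num_seq_per_target=4,
                           batch_size=2, sampling_temps=[0.1], seed=37)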
def parse_fasta(filename, limit=-1, omit=[]):
    header = []
    sequence = []
    lines = open(filename, "r")
    for line in lines:
        line = line.rstrip()
        if not line:
            continue  # skip blank lines so line[0] below cannot fail
        if line[0] == ">":
            if len(header) == limit:
                break
            header.append(line[1:])
            sequence.append([])
        else:
            if omit:
                line = [item for item in line if item not in omit]
                line = ''.join(line)
            sequence[-1].append(line)
    lines.close()
    sequence = [''.join(seq) for seq in sequence]
    return np.array(header), np.array(sequence)


def _scores(S, log_probs, mask):
    """ Negative log probabilities, averaged over the masked positions. """
    criterion = torch.nn.NLLLoss(reduction='none')
    loss = criterion(
        log_probs.contiguous().view(-1, log_probs.size(-1)),
        S.contiguous().view(-1)
    ).view(S.size())
    scores = torch.sum(loss * mask, dim=-1) / torch.sum(mask, dim=-1)
    return scores


def _S_to_seq(S, mask):
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    seq = ''.join([alphabet[c] for c, m in zip(S.tolist(), mask.tolist()) if m > 0])
    return seq
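# Minimal sketch (illustrative addition, not part of the original module) of
# the two helpers above: `_S_to_seq` drops positions where the mask is 0, and
# `_scores` averages per-position negative log-likelihoods over the mask.
def _example_score_and_decode():
    S = torch.tensor([[0, 2, 4]])                       # "ADF" in the alphabet
    log_probs = torch.log(torch.full((1, 3, 21), 1/21))  # uniform predictions
    mask = torch.tensor([[1.0, 1.0, 0.0]])              # last position masked out
    # Returns ('AD', tensor([3.0445])) since -log(1/21) ~= 3.0445.
    return _S_to_seq(S[0], mask[0]), _scores(S, log_probs, mask)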
def parse_PDB_biounits(x, atoms=['N','CA','C'], chain=None):
    '''
    input:  x = PDB filename
            atoms = atoms to extract (optional)
    output: (length, atoms, coords=(x,y,z)), sequence
    '''

    alpha_1 = list("ARNDCQEGHILKMFPSTWYV-")
    states = len(alpha_1)
    alpha_3 = ['ALA','ARG','ASN','ASP','CYS','GLN','GLU','GLY','HIS','ILE',
               'LEU','LYS','MET','PHE','PRO','SER','THR','TRP','TYR','VAL','GAP']

    aa_1_N = {a: n for n, a in enumerate(alpha_1)}
    aa_3_N = {a: n for n, a in enumerate(alpha_3)}
    aa_N_1 = {n: a for n, a in enumerate(alpha_1)}
    aa_1_3 = {a: b for a, b in zip(alpha_1, alpha_3)}
    aa_3_1 = {b: a for a, b in zip(alpha_1, alpha_3)}

    def AA_to_N(x):
        # ["ARND"] -> [[0,1,2,3]]
        x = np.array(x)
        if x.ndim == 0:
            x = x[None]
        return [[aa_1_N.get(a, states-1) for a in y] for y in x]

    def N_to_AA(x):
        # [[0,1,2,3]] -> ["ARND"]
        x = np.array(x)
        if x.ndim == 1:
            x = x[None]
        return ["".join([aa_N_1.get(a, "-") for a in y]) for y in x]

    xyz, seq, min_resn, max_resn = {}, {}, 1e6, -1e6
    for line in open(x, "rb"):
        line = line.decode("utf-8", "ignore").rstrip()

        # Treat selenomethionine as methionine; "ATOM  " keeps the columns aligned.
        if line[:6] == "HETATM" and line[17:17+3] == "MSE":
            line = line.replace("HETATM", "ATOM  ")
            line = line.replace("MSE", "MET")

        if line[:4] == "ATOM":
            ch = line[21:22]
            if ch == chain or chain is None:
                atom = line[12:12+4].strip()
                resi = line[17:17+3]
                resn = line[22:22+5].strip()
                x, y, z = [float(line[i:(i+8)]) for i in [30, 38, 46]]

                if resn[-1].isalpha():
                    resa, resn = resn[-1], int(resn[:-1])-1
                else:
                    resa, resn = "", int(resn)-1
                # resn = int(resn)
                if resn < min_resn:
                    min_resn = resn
                if resn > max_resn:
                    max_resn = resn
                if resn not in xyz:
                    xyz[resn] = {}
                if resa not in xyz[resn]:
                    xyz[resn][resa] = {}
                if resn not in seq:
                    seq[resn] = {}
                if resa not in seq[resn]:
                    seq[resn][resa] = resi

                if atom not in xyz[resn][resa]:
                    xyz[resn][resa][atom] = np.array([x, y, z])

    # convert to numpy arrays, fill in missing values
    seq_, xyz_ = [], []
    try:
        for resn in range(min_resn, max_resn+1):
            if resn in seq:
                for k in sorted(seq[resn]):
                    seq_.append(aa_3_N.get(seq[resn][k], 20))
            else:
                seq_.append(20)
            if resn in xyz:
                for k in sorted(xyz[resn]):
                    for atom in atoms:
                        if atom in xyz[resn][k]:
                            xyz_.append(xyz[resn][k][atom])
                        else:
                            xyz_.append(np.full(3, np.nan))
            else:
                for atom in atoms:
                    xyz_.append(np.full(3, np.nan))
        return np.array(xyz_).reshape(-1, len(atoms), 3), N_to_AA(np.array(seq_))
    except TypeError:
        # min_resn/max_resn stayed floats, i.e. the requested chain was empty.
        return 'no_chain', 'no_chain'


def parse_PDB(path_to_pdb, input_chain_list=None, ca_only=False):
    c = 0
    pdb_dict_list = []
    init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
    extra_alphabet = [str(item) for item in list(np.arange(300))]
    chain_alphabet = init_alphabet + extra_alphabet

    if input_chain_list:
        chain_alphabet = input_chain_list

    biounit_names = [path_to_pdb]
    for biounit in biounit_names:
        my_dict = {}
        s = 0
        concat_seq = ''
        concat_N = []
        concat_CA = []
        concat_C = []
        concat_O = []
        concat_mask = []
        coords_dict = {}
        for letter in chain_alphabet:
            if ca_only:
                sidechain_atoms = ['CA']
            else:
                sidechain_atoms = ['N', 'CA', 'C', 'O']
            xyz, seq = parse_PDB_biounits(biounit, atoms=sidechain_atoms, chain=letter)
            if type(xyz) != str:
                concat_seq += seq[0]
                my_dict['seq_chain_' + letter] = seq[0]
                coords_dict_chain = {}
                if ca_only:
                    coords_dict_chain['CA_chain_' + letter] = xyz.tolist()
                else:
                    coords_dict_chain['N_chain_' + letter] = xyz[:, 0, :].tolist()
                    coords_dict_chain['CA_chain_' + letter] = xyz[:, 1, :].tolist()
                    coords_dict_chain['C_chain_' + letter] = xyz[:, 2, :].tolist()
                    coords_dict_chain['O_chain_' + letter] = xyz[:, 3, :].tolist()
                my_dict['coords_chain_' + letter] = coords_dict_chain
                s += 1
        fi = biounit.rfind("/")
        my_dict['name'] = biounit[(fi+1):-4]
        my_dict['num_of_chains'] = s
        my_dict['seq'] = concat_seq
        if s <= len(chain_alphabet):
            pdb_dict_list.append(my_dict)
            c += 1
    return pdb_dict_list
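# Sketch of the dictionary layout produced by parse_PDB (illustrative addition,
# not part of the original module). For a file "input.pdb" (placeholder path)
# containing a chain A, each entry carries the per-chain sequence and backbone
# coordinates:
def _example_parse_pdb():
    pdb_dict = parse_PDB('input.pdb')[0]
    seq_A = pdb_dict['seq_chain_A']                    # one-letter sequence, '-' for gaps
    n_xyz = pdb_dict['coords_chain_A']['N_chain_A']    # [L][3] nested lists
    return pdb_dict['name'], pdb_dict['num_of_chains'], seq_A, n_xyz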
def tied_featurize(batch, device, chain_dict, fixed_position_dict=None, omit_AA_dict=None, tied_positions_dict=None, pssm_dict=None, bias_by_res_dict=None, ca_only=False):
    """ Pack and pad batch into torch tensors """
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    B = len(batch)
    lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)  # sum of chain seq lengths
    L_max = max([len(b['seq']) for b in batch])
    if ca_only:
        X = np.zeros([B, L_max, 1, 3])
    else:
        X = np.zeros([B, L_max, 4, 3])
    residue_idx = -100*np.ones([B, L_max], dtype=np.int32)
    chain_M = np.zeros([B, L_max], dtype=np.int32)  # 1.0 for the bits that need to be predicted
    pssm_coef_all = np.zeros([B, L_max], dtype=np.float32)
    pssm_bias_all = np.zeros([B, L_max, 21], dtype=np.float32)
    pssm_log_odds_all = 10000.0*np.ones([B, L_max, 21], dtype=np.float32)
    chain_M_pos = np.zeros([B, L_max], dtype=np.int32)  # 1.0 for positions that may be redesigned (fixed positions get 0)
    bias_by_res_all = np.zeros([B, L_max, 21], dtype=np.float32)
    chain_encoding_all = np.zeros([B, L_max], dtype=np.int32)  # integer chain id per residue
    S = np.zeros([B, L_max], dtype=np.int32)
    omit_AA_mask = np.zeros([B, L_max, len(alphabet)], dtype=np.int32)
    # Build the batch
    letter_list_list = []
    visible_list_list = []
    masked_list_list = []
    masked_chain_length_list_list = []
    tied_pos_list_of_lists_list = []
    for i, b in enumerate(batch):
        if chain_dict is not None:
            masked_chains, visible_chains = chain_dict[b['name']]  # masked_chains is a list of chain letters to predict [A, D, F]
        else:
            masked_chains = [item[-1:] for item in list(b) if item[:10] == 'seq_chain_']
            visible_chains = []
        masked_chains.sort()  # sort masked_chains
        visible_chains.sort()  # sort visible_chains
        all_chains = masked_chains + visible_chains
    # NOTE: the chain lists from the last batch element are reused below; as
    # produced by run_proteinmpnn, all batch elements are clones of one protein.
    for i, b in enumerate(batch):
        mask_dict = {}
        a = 0
        x_chain_list = []
        chain_mask_list = []
        chain_seq_list = []
        chain_encoding_list = []
        c = 1
        letter_list = []
        global_idx_start_list = [0]
        visible_list = []
        masked_list = []
        masked_chain_length_list = []
        fixed_position_mask_list = []
        omit_AA_mask_list = []
        pssm_coef_list = []
        pssm_bias_list = []
        pssm_log_odds_list = []
        bias_by_res_list = []
        l0 = 0
        l1 = 0
        for step, letter in enumerate(all_chains):
            if letter in visible_chains:
                letter_list.append(letter)
                visible_list.append(letter)
                chain_seq = b[f'seq_chain_{letter}']
                chain_seq = ''.join([a if a != '-' else 'X' for a in chain_seq])
                chain_length = len(chain_seq)
                global_idx_start_list.append(global_idx_start_list[-1]+chain_length)
                chain_coords = b[f'coords_chain_{letter}']  # this is a dictionary
                chain_mask = np.zeros(chain_length)  # 0.0 for visible chains
                if ca_only:
                    x_chain = np.array(chain_coords[f'CA_chain_{letter}'])  # [chain_length,1,3] #CA_diff
                    if len(x_chain.shape) == 2:
                        x_chain = x_chain[:, None, :]
                else:
                    x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1)  # [chain_length,4,3]
                x_chain_list.append(x_chain)
                chain_mask_list.append(chain_mask)
                chain_seq_list.append(chain_seq)
                chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
                l1 += chain_length
                residue_idx[i, l0:l1] = 100*(c-1)+np.arange(l0, l1)
                l0 += chain_length
                c += 1
                fixed_position_mask = np.ones(chain_length)
                fixed_position_mask_list.append(fixed_position_mask)
                omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
                omit_AA_mask_list.append(omit_AA_mask_temp)
                pssm_coef = np.zeros(chain_length)
                pssm_bias = np.zeros([chain_length, 21])
                pssm_log_odds = 10000.0*np.ones([chain_length, 21])
                pssm_coef_list.append(pssm_coef)
                pssm_bias_list.append(pssm_bias)
                pssm_log_odds_list.append(pssm_log_odds)
                bias_by_res_list.append(np.zeros([chain_length, 21]))
            if letter in masked_chains:
                masked_list.append(letter)
                letter_list.append(letter)
                chain_seq = b[f'seq_chain_{letter}']
                chain_seq = ''.join([a if a != '-' else 'X' for a in chain_seq])
                chain_length = len(chain_seq)
                global_idx_start_list.append(global_idx_start_list[-1]+chain_length)
                masked_chain_length_list.append(chain_length)
                chain_coords = b[f'coords_chain_{letter}']  # this is a dictionary
                chain_mask = np.ones(chain_length)  # 1.0 for masked
                if ca_only:
                    x_chain = np.array(chain_coords[f'CA_chain_{letter}'])  # [chain_length,1,3] #CA_diff
                    if len(x_chain.shape) == 2:
                        x_chain = x_chain[:, None, :]
                else:
                    x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1)  # [chain_length,4,3]
                x_chain_list.append(x_chain)
                chain_mask_list.append(chain_mask)
                chain_seq_list.append(chain_seq)
                chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
                l1 += chain_length
                residue_idx[i, l0:l1] = 100*(c-1)+np.arange(l0, l1)
                l0 += chain_length
                c += 1
                fixed_position_mask = np.ones(chain_length)
                if fixed_position_dict is not None:
                    fixed_pos_list = fixed_position_dict[b['name']][letter]
                    if fixed_pos_list:
                        fixed_position_mask[np.array(fixed_pos_list)-1] = 0.0
                fixed_position_mask_list.append(fixed_position_mask)
                omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
                if omit_AA_dict is not None:
                    for item in omit_AA_dict[b['name']][letter]:
                        idx_AA = np.array(item[0])-1
                        AA_idx = np.array([np.argwhere(np.array(list(alphabet)) == AA)[0][0] for AA in item[1]]).repeat(idx_AA.shape[0])
                        idx_ = np.array([[a, b] for a in idx_AA for b in AA_idx])
                        omit_AA_mask_temp[idx_[:, 0], idx_[:, 1]] = 1
                omit_AA_mask_list.append(omit_AA_mask_temp)
                pssm_coef = np.zeros(chain_length)
                pssm_bias = np.zeros([chain_length, 21])
                pssm_log_odds = 10000.0*np.ones([chain_length, 21])
                if pssm_dict:
                    if pssm_dict[b['name']][letter]:
                        pssm_coef = pssm_dict[b['name']][letter]['pssm_coef']
                        pssm_bias = pssm_dict[b['name']][letter]['pssm_bias']
                        pssm_log_odds = pssm_dict[b['name']][letter]['pssm_log_odds']
                pssm_coef_list.append(pssm_coef)
                pssm_bias_list.append(pssm_bias)
                pssm_log_odds_list.append(pssm_log_odds)
                if bias_by_res_dict:
                    bias_by_res_list.append(bias_by_res_dict[b['name']][letter])
                else:
                    bias_by_res_list.append(np.zeros([chain_length, 21]))

        letter_list_np = np.array(letter_list)
        tied_pos_list_of_lists = []
        tied_beta = np.ones(L_max)
        if tied_positions_dict is not None:
            tied_pos_list = tied_positions_dict[b['name']]
            if tied_pos_list:
                set_chains_tied = set(list(itertools.chain(*[list(item) for item in tied_pos_list])))
                for tied_item in tied_pos_list:
                    one_list = []
                    for k, v in tied_item.items():
                        start_idx = global_idx_start_list[np.argwhere(letter_list_np == k)[0][0]]
                        if isinstance(v[0], list):
                            for v_count in range(len(v[0])):
                                one_list.append(start_idx+v[0][v_count]-1)  # make 0 be the first
                                tied_beta[start_idx+v[0][v_count]-1] = v[1][v_count]
                        else:
                            for v_ in v:
                                one_list.append(start_idx+v_-1)  # make 0 be the first
                    tied_pos_list_of_lists.append(one_list)
        tied_pos_list_of_lists_list.append(tied_pos_list_of_lists)

        x = np.concatenate(x_chain_list, 0)  # [L, 4, 3]
        all_sequence = "".join(chain_seq_list)
        m = np.concatenate(chain_mask_list, 0)  # [L,], 1.0 for places that need to be predicted
        chain_encoding = np.concatenate(chain_encoding_list, 0)
        m_pos = np.concatenate(fixed_position_mask_list, 0)  # [L,], 1.0 for places that need to be predicted

        pssm_coef_ = np.concatenate(pssm_coef_list, 0)  # [L,]
        pssm_bias_ = np.concatenate(pssm_bias_list, 0)  # [L, 21]
        pssm_log_odds_ = np.concatenate(pssm_log_odds_list, 0)  # [L, 21]

        bias_by_res_ = np.concatenate(bias_by_res_list, 0)  # [L, 21], 0.0 for places where AA frequencies don't need to be tweaked

        l = len(all_sequence)
        x_pad = np.pad(x, [[0, L_max-l], [0, 0], [0, 0]], 'constant', constant_values=(np.nan,))
        X[i, :, :, :] = x_pad

        m_pad = np.pad(m, [[0, L_max-l]], 'constant', constant_values=(0.0,))
        m_pos_pad = np.pad(m_pos, [[0, L_max-l]], 'constant', constant_values=(0.0,))
        omit_AA_mask_pad = np.pad(np.concatenate(omit_AA_mask_list, 0), [[0, L_max-l]], 'constant', constant_values=(0.0,))
        chain_M[i, :] = m_pad
        chain_M_pos[i, :] = m_pos_pad
        omit_AA_mask[i, ] = omit_AA_mask_pad

        chain_encoding_pad = np.pad(chain_encoding, [[0, L_max-l]], 'constant', constant_values=(0.0,))
        chain_encoding_all[i, :] = chain_encoding_pad

        pssm_coef_pad = np.pad(pssm_coef_, [[0, L_max-l]], 'constant', constant_values=(0.0,))
        pssm_bias_pad = np.pad(pssm_bias_, [[0, L_max-l], [0, 0]], 'constant', constant_values=(0.0,))
        pssm_log_odds_pad = np.pad(pssm_log_odds_, [[0, L_max-l], [0, 0]], 'constant', constant_values=(0.0,))

        pssm_coef_all[i, :] = pssm_coef_pad
        pssm_bias_all[i, :] = pssm_bias_pad
        pssm_log_odds_all[i, :] = pssm_log_odds_pad

        bias_by_res_pad = np.pad(bias_by_res_, [[0, L_max-l], [0, 0]], 'constant', constant_values=(0.0,))
        bias_by_res_all[i, :] = bias_by_res_pad

        # Convert to labels
        indices = np.asarray([alphabet.index(a) for a in all_sequence], dtype=np.int32)
        S[i, :l] = indices
        letter_list_list.append(letter_list)
        visible_list_list.append(visible_list)
        masked_list_list.append(masked_list)
        masked_chain_length_list_list.append(masked_chain_length_list)

    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    X[isnan] = 0.

    # Conversion
    pssm_coef_all = torch.from_numpy(pssm_coef_all).to(dtype=torch.float32, device=device)
    pssm_bias_all = torch.from_numpy(pssm_bias_all).to(dtype=torch.float32, device=device)
    pssm_log_odds_all = torch.from_numpy(pssm_log_odds_all).to(dtype=torch.float32, device=device)

    tied_beta = torch.from_numpy(tied_beta).to(dtype=torch.float32, device=device)

    jumps = ((residue_idx[:, 1:]-residue_idx[:, :-1]) == 1).astype(np.float32)
    bias_by_res_all = torch.from_numpy(bias_by_res_all).to(dtype=torch.float32, device=device)
    phi_mask = np.pad(jumps, [[0, 0], [1, 0]])
|
859 |
+
psi_mask = np.pad(jumps, [[0,0],[0,1]])
|
860 |
+
omega_mask = np.pad(jumps, [[0,0],[0,1]])
|
861 |
+
dihedral_mask = np.concatenate([phi_mask[:,:,None], psi_mask[:,:,None], omega_mask[:,:,None]], -1) #[B,L,3]
|
862 |
+
dihedral_mask = torch.from_numpy(dihedral_mask).to(dtype=torch.float32, device=device)
|
863 |
+
residue_idx = torch.from_numpy(residue_idx).to(dtype=torch.long,device=device)
|
864 |
+
S = torch.from_numpy(S).to(dtype=torch.long,device=device)
|
865 |
+
X = torch.from_numpy(X).to(dtype=torch.float32, device=device)
|
866 |
+
mask = torch.from_numpy(mask).to(dtype=torch.float32, device=device)
|
867 |
+
chain_M = torch.from_numpy(chain_M).to(dtype=torch.float32, device=device)
|
868 |
+
chain_M_pos = torch.from_numpy(chain_M_pos).to(dtype=torch.float32, device=device)
|
869 |
+
omit_AA_mask = torch.from_numpy(omit_AA_mask).to(dtype=torch.float32, device=device)
|
870 |
+
chain_encoding_all = torch.from_numpy(chain_encoding_all).to(dtype=torch.long, device=device)
|
871 |
+
if ca_only:
|
872 |
+
X_out = X[:,:,0]
|
873 |
+
else:
|
874 |
+
X_out = X
|
875 |
+
return X_out, S, mask, lengths, chain_M, chain_encoding_all, letter_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef_all, pssm_bias_all, pssm_log_odds_all, bias_by_res_all, tied_beta
|
876 |
+
|
877 |
+
|
878 |
+
|
879 |
+
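

# --- Editor's note: usage sketch for the featurization function above (added
# for documentation, not part of the original file; `pdb_dict` is a
# hypothetical parsed-structure dict, and the function carries the name
# tied_featurize in upstream ProteinMPNN). It packs a batch of parsed chains
# into padded tensors, builds residue indices offset by 100 between chains,
# and returns the masks used for fixed/tied/omitted positions:
#
#   X, S, mask, lengths, chain_M, chain_encoding_all, *rest = tied_featurize(
#       [pdb_dict], device, chain_dict)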


def loss_nll(S, log_probs, mask):
    """ Negative log probabilities """
    criterion = torch.nn.NLLLoss(reduction='none')
    loss = criterion(
        log_probs.contiguous().view(-1, log_probs.size(-1)), S.contiguous().view(-1)
    ).view(S.size())
    loss_av = torch.sum(loss * mask) / torch.sum(mask)
    return loss, loss_av


def loss_smoothed(S, log_probs, mask, weight=0.1):
    """ Negative log probabilities with label smoothing """
    S_onehot = torch.nn.functional.one_hot(S, 21).float()

    # Label smoothing
    S_onehot = S_onehot + weight / float(S_onehot.size(-1))
    S_onehot = S_onehot / S_onehot.sum(-1, keepdim=True)

    loss = -(S_onehot * log_probs).sum(-1)
    loss_av = torch.sum(loss * mask) / torch.sum(mask)
    return loss, loss_av
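

# --- Editor's note: the block below is an illustrative sketch added for
# documentation; it is not part of the original file. It shows the effect of
# the label smoothing in loss_smoothed: with 21 classes and weight=0.1, the
# true-class target becomes (1 + 0.1/21)/1.1 ~ 0.913 and every other class
# gets (0.1/21)/1.1 ~ 0.004.
def _example_loss_smoothed():
    S = torch.randint(0, 21, (2, 50))  # integer sequence labels [B, L]
    log_probs = torch.log_softmax(torch.randn(2, 50, 21), dim=-1)
    mask = torch.ones(2, 50)  # 1.0 for valid residues
    per_residue, average = loss_smoothed(S, log_probs, mask)
    return per_residue.shape, average  # torch.Size([2, 50]), scalar tensor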


class StructureDataset():
    def __init__(self, jsonl_file, verbose=True, truncate=None, max_length=100,
                 alphabet='ACDEFGHIKLMNPQRSTVWYX-'):
        alphabet_set = set([a for a in alphabet])
        discard_count = {
            'bad_chars': 0,
            'too_long': 0,
            'bad_seq_length': 0
        }

        with open(jsonl_file) as f:
            self.data = []

            lines = f.readlines()
            start = time.time()
            for i, line in enumerate(lines):
                entry = json.loads(line)
                seq = entry['seq']
                name = entry['name']

                # Convert raw coords to np arrays
                # for key, val in entry['coords'].items():
                #     entry['coords'][key] = np.asarray(val)

                # Check if in alphabet
                bad_chars = set([s for s in seq]).difference(alphabet_set)
                if len(bad_chars) == 0:
                    if len(entry['seq']) <= max_length:
                        if True:
                            self.data.append(entry)
                        else:
                            discard_count['bad_seq_length'] += 1
                    else:
                        discard_count['too_long'] += 1
                else:
                    if verbose:
                        print(name, bad_chars, entry['seq'])
                    discard_count['bad_chars'] += 1

                # Truncate early
                if truncate is not None and len(self.data) == truncate:
                    return

                if verbose and (i + 1) % 1000 == 0:
                    elapsed = time.time() - start
                    print('{} entries ({} loaded) in {:.1f} s'.format(len(self.data), i + 1, elapsed))
            if verbose:
                print('discarded', discard_count)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class StructureDatasetPDB():
    def __init__(self, pdb_dict_list, verbose=True, truncate=None, max_length=100,
                 alphabet='ACDEFGHIKLMNPQRSTVWYX-'):
        alphabet_set = set([a for a in alphabet])
        discard_count = {
            'bad_chars': 0,
            'too_long': 0,
            'bad_seq_length': 0
        }

        self.data = []

        start = time.time()
        for i, entry in enumerate(pdb_dict_list):
            seq = entry['seq']
            name = entry['name']

            bad_chars = set([s for s in seq]).difference(alphabet_set)
            if len(bad_chars) == 0:
                if len(entry['seq']) <= max_length:
                    self.data.append(entry)
                else:
                    discard_count['too_long'] += 1
            else:
                discard_count['bad_chars'] += 1

            # Truncate early
            if truncate is not None and len(self.data) == truncate:
                return

            if verbose and (i + 1) % 1000 == 0:
                elapsed = time.time() - start

        # print('Discarded', discard_count)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class StructureLoader():
    def __init__(self, dataset, batch_size=100, shuffle=True,
                 collate_fn=lambda x: x, drop_last=False):
        self.dataset = dataset
        self.size = len(dataset)
        self.lengths = [len(dataset[i]['seq']) for i in range(self.size)]
        self.batch_size = batch_size
        sorted_ix = np.argsort(self.lengths)

        # Cluster into batches of similar sizes
        clusters, batch = [], []
        batch_max = 0
        for ix in sorted_ix:
            size = self.lengths[ix]
            if size * (len(batch) + 1) <= self.batch_size:
                batch.append(ix)
                batch_max = size
            else:
                clusters.append(batch)
                batch, batch_max = [], 0
        if len(batch) > 0:
            clusters.append(batch)
        self.clusters = clusters

    def __len__(self):
        return len(self.clusters)

    def __iter__(self):
        np.random.shuffle(self.clusters)
        for b_idx in self.clusters:
            batch = [self.dataset[i] for i in b_idx]
            yield batch
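

# --- Editor's note: usage sketch (added for documentation, not part of the
# original file; `pdb_dict_list` is a hypothetical list of parsed-PDB dicts).
# StructureLoader batches by total residue count, so batch_size is a residue
# budget rather than a number of structures:
#
#   dataset = StructureDatasetPDB(pdb_dict_list, max_length=200)
#   loader = StructureLoader(dataset, batch_size=10000)
#   for batch in loader:  # batch is a list of entry dicts
#       ...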


# The following gather functions select per-neighbor features by index.
def gather_edges(edges, neighbor_idx):
    # Features [B,N,N,C] at Neighbor indices [B,N,K] => Neighbor features [B,N,K,C]
    neighbors = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, edges.size(-1))
    edge_features = torch.gather(edges, 2, neighbors)
    return edge_features


def gather_nodes(nodes, neighbor_idx):
    # Features [B,N,C] at Neighbor indices [B,N,K] => [B,N,K,C]
    # Flatten and expand indices per batch [B,N,K] => [B,NK] => [B,NK,C]
    neighbors_flat = neighbor_idx.view((neighbor_idx.shape[0], -1))
    neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2))
    # Gather and re-pack
    neighbor_features = torch.gather(nodes, 1, neighbors_flat)
    neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3] + [-1])
    return neighbor_features


def gather_nodes_t(nodes, neighbor_idx):
    # Features [B,N,C] at Neighbor index [B,K] => Neighbor features [B,K,C]
    idx_flat = neighbor_idx.unsqueeze(-1).expand(-1, -1, nodes.size(2))
    neighbor_features = torch.gather(nodes, 1, idx_flat)
    return neighbor_features


def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx):
    h_nodes = gather_nodes(h_nodes, E_idx)
    h_nn = torch.cat([h_neighbors, h_nodes], -1)
    return h_nn
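

# --- Editor's note: illustrative shape check for the gather helpers above
# (added sketch, not part of the original file). gather_nodes looks up the
# C-dimensional feature of each of the K neighbors of every node.
def _example_gather_nodes():
    nodes = torch.randn(2, 10, 8)             # node features [B, N, C]
    E_idx = torch.randint(0, 10, (2, 10, 4))  # neighbor indices [B, N, K]
    out = gather_nodes(nodes, E_idx)
    return out.shape  # torch.Size([2, 10, 4, 8])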


class EncLayer(nn.Module):
    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30, time_cond_dim=None):
        super(EncLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(num_hidden)
        self.norm2 = nn.LayerNorm(num_hidden)
        self.norm3 = nn.LayerNorm(num_hidden)

        if time_cond_dim is not None:
            self.time_block1 = nn.Sequential(
                Rearrange('b 1 d -> b 1 1 d'),
                nn.SiLU(),
                nn.Linear(time_cond_dim, num_hidden * 2))
            self.time_block2 = nn.Sequential(
                Rearrange('b 1 d -> b 1 1 d'),
                nn.SiLU(),
                nn.Linear(time_cond_dim, num_hidden * 2))

        self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W11 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W12 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W13 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, E_idx, mask_V=None, mask_attend=None, time_cond=None):
        """ Parallel computation of full transformer layer """

        h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
        h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_EV.size(-2), -1)
        h_EV = torch.cat([h_V_expand, h_EV], -1)

        h_message = self.act(self.W2(self.act(self.W1(h_EV))))
        if time_cond is not None:
            scale, shift = self.time_block1(time_cond).chunk(2, dim=-1)
            h_message = h_message * (scale + 1) + shift
        h_message = self.W3(h_message)

        if mask_attend is not None:
            h_message = mask_attend.unsqueeze(-1) * h_message
        dh = torch.sum(h_message, -2) / self.scale
        h_V = self.norm1(h_V + self.dropout1(dh))

        dh = self.dense(h_V)
        h_V = self.norm2(h_V + self.dropout2(dh))
        if mask_V is not None:
            mask_V = mask_V.unsqueeze(-1)
            h_V = mask_V * h_V

        h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
        h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_EV.size(-2), -1)
        h_EV = torch.cat([h_V_expand, h_EV], -1)

        h_message = self.act(self.W12(self.act(self.W11(h_EV))))
        if time_cond is not None:
            scale, shift = self.time_block2(time_cond).chunk(2, dim=-1)
            h_message = h_message * (scale + 1) + shift
        h_message = self.W13(h_message)

        h_E = self.norm3(h_E + self.dropout3(h_message))
        return h_V, h_E


class DecLayer(nn.Module):
    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30, time_cond_dim=None):
        super(DecLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(num_hidden)
        self.norm2 = nn.LayerNorm(num_hidden)

        if time_cond_dim is not None:
            self.time_block = nn.Sequential(
                Rearrange('b 1 d -> b 1 1 d'),
                nn.SiLU(),
                nn.Linear(time_cond_dim, num_hidden * 2))

        self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None, time_cond=None):
        """ Parallel computation of full transformer layer """

        # Concatenate h_V_i to h_E_ij
        h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_E.size(-2), -1)
        h_EV = torch.cat([h_V_expand, h_E], -1)

        h_message = self.act(self.W2(self.act(self.W1(h_EV))))
        if time_cond is not None:
            scale, shift = self.time_block(time_cond).chunk(2, dim=-1)
            h_message = h_message * (scale + 1) + shift
        h_message = self.W3(h_message)

        if mask_attend is not None:
            h_message = mask_attend.unsqueeze(-1) * h_message
        dh = torch.sum(h_message, -2) / self.scale

        h_V = self.norm1(h_V + self.dropout1(dh))

        # Position-wise feedforward
        dh = self.dense(h_V)
        h_V = self.norm2(h_V + self.dropout2(dh))

        if mask_V is not None:
            mask_V = mask_V.unsqueeze(-1)
            h_V = mask_V * h_V
        return h_V
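

# --- Editor's note: illustrative sketch (not part of the original file).
# The optional time_block in EncLayer/DecLayer is FiLM-style conditioning: a
# time embedding is projected to per-channel (scale, shift) and applied as
# h = h * (scale + 1) + shift inside the message computation.
def _example_time_conditioning():
    layer = DecLayer(num_hidden=128, num_in=384, time_cond_dim=64)
    time_cond = torch.randn(2, 1, 64)  # [B, 1, time_cond_dim]
    scale, shift = layer.time_block(time_cond).chunk(2, dim=-1)
    return scale.shape, shift.shape  # each torch.Size([2, 1, 1, 128])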


class PositionWiseFeedForward(nn.Module):
    def __init__(self, num_hidden, num_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.W_in = nn.Linear(num_hidden, num_ff, bias=True)
        self.W_out = nn.Linear(num_ff, num_hidden, bias=True)
        self.act = torch.nn.GELU()

    def forward(self, h_V):
        h = self.act(self.W_in(h_V))
        h = self.W_out(h)
        return h


class PositionalEncodings(nn.Module):
    def __init__(self, num_embeddings, max_relative_feature=32):
        super(PositionalEncodings, self).__init__()
        self.num_embeddings = num_embeddings
        self.max_relative_feature = max_relative_feature
        self.linear = nn.Linear(2 * max_relative_feature + 1 + 1, num_embeddings)

    def forward(self, offset, mask):
        d = torch.clip(offset + self.max_relative_feature, 0, 2 * self.max_relative_feature) * mask + (1 - mask) * (2 * self.max_relative_feature + 1)
        d_onehot = torch.nn.functional.one_hot(d, 2 * self.max_relative_feature + 1 + 1)
        E = self.linear(d_onehot.float())
        return E
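

# --- Editor's note: illustrative sketch (not part of the original file).
# PositionalEncodings clips the signed residue-index offset to
# [-max_relative_feature, max_relative_feature] (65 bins for the default 32)
# and reserves one extra bin for pairs on different chains (mask == 0).
def _example_positional_encodings():
    pe = PositionalEncodings(num_embeddings=16)
    offset = torch.tensor([[[-100, 0, 5]]])   # [B, L, K] index differences
    same_chain = torch.tensor([[[1, 1, 0]]])  # 0 marks a cross-chain pair
    E = pe(offset, same_chain)
    return E.shape  # torch.Size([1, 1, 3, 16])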


class CA_ProteinFeatures(nn.Module):
    def __init__(self, edge_features, node_features, num_positional_embeddings=16,
                 num_rbf=16, top_k=30, augment_eps=0., num_chain_embeddings=16):
        """ Extract protein features """
        super(CA_ProteinFeatures, self).__init__()
        self.edge_features = edge_features
        self.node_features = node_features
        self.top_k = top_k
        self.augment_eps = augment_eps
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings

        # Positional encoding
        self.embeddings = PositionalEncodings(num_positional_embeddings)
        # Normalization and embedding
        node_in, edge_in = 3, num_positional_embeddings + num_rbf * 9 + 7
        self.node_embedding = nn.Linear(node_in, node_features, bias=False)  # NOT USED
        self.edge_embedding = nn.Linear(edge_in, edge_features, bias=False)
        self.norm_nodes = nn.LayerNorm(node_features)
        self.norm_edges = nn.LayerNorm(edge_features)

    def _quaternions(self, R):
        """ Convert a batch of 3D rotations [R] to quaternions [Q]
            R [...,3,3]
            Q [...,4]
        """
        # Simple Wikipedia version
        # en.wikipedia.org/wiki/Rotation_matrix#Quaternion
        # For other options see math.stackexchange.com/questions/2074316/calculating-rotation-axis-from-rotation-matrix
        diag = torch.diagonal(R, dim1=-2, dim2=-1)
        Rxx, Ryy, Rzz = diag.unbind(-1)
        magnitudes = 0.5 * torch.sqrt(torch.abs(1 + torch.stack([
            Rxx - Ryy - Rzz,
            -Rxx + Ryy - Rzz,
            -Rxx - Ryy + Rzz
        ], -1)))
        _R = lambda i, j: R[:, :, :, i, j]
        signs = torch.sign(torch.stack([
            _R(2, 1) - _R(1, 2),
            _R(0, 2) - _R(2, 0),
            _R(1, 0) - _R(0, 1)
        ], -1))
        xyz = signs * magnitudes
        # The relu enforces a non-negative trace
        w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
        Q = torch.cat((xyz, w), -1)
        Q = F.normalize(Q, dim=-1)
        return Q

    def _orientations_coarse(self, X, E_idx, eps=1e-6):
        dX = X[:, 1:, :] - X[:, :-1, :]
        dX_norm = torch.norm(dX, dim=-1)
        dX_mask = (3.6 < dX_norm) & (dX_norm < 4.0)  # exclude CA-CA jumps
        dX = dX * dX_mask[:, :, None]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]
        # Backbone normals (cross products taken over the coordinate axis)
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Bond angle calculation
        cosA = -(u_1 * u_0).sum(-1)
        cosA = torch.clamp(cosA, -1 + eps, 1 - eps)
        A = torch.acos(cosA)
        # Angle between normals
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
        # Backbone features
        AD_features = torch.stack((torch.cos(A), torch.sin(A) * torch.cos(D), torch.sin(A) * torch.sin(D)), 2)
        AD_features = F.pad(AD_features, (0, 0, 1, 2), 'constant', 0)

        # Build relative orientations
        o_1 = F.normalize(u_2 - u_1, dim=-1)
        O = torch.stack((o_1, n_2, torch.cross(o_1, n_2, dim=-1)), 2)
        O = O.view(list(O.shape[:2]) + [9])
        O = F.pad(O, (0, 0, 1, 2), 'constant', 0)
        O_neighbors = gather_nodes(O, E_idx)
        X_neighbors = gather_nodes(X, E_idx)

        # Re-view as rotation matrices
        O = O.view(list(O.shape[:2]) + [3, 3])
        O_neighbors = O_neighbors.view(list(O_neighbors.shape[:3]) + [3, 3])

        # Rotate into local reference frames
        dX = X_neighbors - X.unsqueeze(-2)
        dU = torch.matmul(O.unsqueeze(2), dX.unsqueeze(-1)).squeeze(-1)
        dU = F.normalize(dU, dim=-1)
        R = torch.matmul(O.unsqueeze(2).transpose(-1, -2), O_neighbors)
        Q = self._quaternions(R)

        # Orientation features
        O_features = torch.cat((dU, Q), dim=-1)
        return AD_features, O_features

    def _dist(self, X, mask, eps=1E-6):
        """ Pairwise euclidean distances """
        mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)

        # Identify k nearest neighbors (including self)
        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1. - mask_2D) * D_max
        D_neighbors, E_idx = torch.topk(D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False)
        mask_neighbors = gather_edges(mask_2D.unsqueeze(-1), E_idx)
        return D_neighbors, E_idx, mask_neighbors

    def _rbf(self, D):
        # Distance radial basis function
        device = D.device
        D_min, D_max, D_count = 2., 22., self.num_rbf
        D_mu = torch.linspace(D_min, D_max, D_count).to(device)
        D_mu = D_mu.view([1, 1, 1, -1])
        D_sigma = (D_max - D_min) / D_count
        D_expand = torch.unsqueeze(D, -1)
        RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
        return RBF

    def _get_rbf(self, A, B, E_idx):
        D_A_B = torch.sqrt(torch.sum((A[:, :, None, :] - B[:, None, :, :])**2, -1) + 1e-6)  # [B, L, L]
        D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]
        RBF_A_B = self._rbf(D_A_B_neighbors)
        return RBF_A_B

    def forward(self, Ca, mask, residue_idx, chain_labels):
        """ Featurize coordinates as an attributed graph """
        if self.augment_eps > 0:
            Ca = Ca + self.augment_eps * torch.randn_like(Ca)

        D_neighbors, E_idx, mask_neighbors = self._dist(Ca, mask)

        Ca_0 = torch.zeros(Ca.shape, device=Ca.device)
        Ca_2 = torch.zeros(Ca.shape, device=Ca.device)
        Ca_0[:, 1:, :] = Ca[:, :-1, :]
        Ca_1 = Ca
        Ca_2[:, :-1, :] = Ca[:, 1:, :]

        V, O_features = self._orientations_coarse(Ca, E_idx)

        RBF_all = []
        RBF_all.append(self._rbf(D_neighbors))  # Ca_1-Ca_1
        RBF_all.append(self._get_rbf(Ca_0, Ca_0, E_idx))
        RBF_all.append(self._get_rbf(Ca_2, Ca_2, E_idx))

        RBF_all.append(self._get_rbf(Ca_0, Ca_1, E_idx))
        RBF_all.append(self._get_rbf(Ca_0, Ca_2, E_idx))

        RBF_all.append(self._get_rbf(Ca_1, Ca_0, E_idx))
        RBF_all.append(self._get_rbf(Ca_1, Ca_2, E_idx))

        RBF_all.append(self._get_rbf(Ca_2, Ca_0, E_idx))
        RBF_all.append(self._get_rbf(Ca_2, Ca_1, E_idx))

        RBF_all = torch.cat(tuple(RBF_all), dim=-1)

        offset = residue_idx[:, :, None] - residue_idx[:, None, :]
        offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]

        d_chains = ((chain_labels[:, :, None] - chain_labels[:, None, :]) == 0).long()
        E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0]
        E_positional = self.embeddings(offset.long(), E_chains)
        E = torch.cat((E_positional, RBF_all, O_features), -1)

        E = self.edge_embedding(E)
        E = self.norm_edges(E)

        return E, E_idx


def get_closest_neighbors(X, mask, top_k, eps=1e-6):
    # X is CA coords (b, n, 3); mask is the sequence mask
    mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
    dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
    D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
    D_max, _ = torch.max(D, -1, keepdim=True)
    D_adjust = D + (1. - mask_2D) * D_max
    D_neighbors, E_idx = torch.topk(D_adjust, np.minimum(top_k, X.shape[1]), dim=-1, largest=False)
    return D_neighbors, E_idx
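

# --- Editor's note: illustrative sketch (not part of the original file).
# get_closest_neighbors is a masked k-nearest-neighbor lookup: padded
# positions are assigned the row-maximum distance, so they are only selected
# once the real neighbors run out.
def _example_get_closest_neighbors():
    X = torch.randn(1, 50, 3)  # CA coordinates [B, N, 3]
    mask = torch.ones(1, 50)   # 1.0 for real residues
    D_neighbors, E_idx = get_closest_neighbors(X, mask, top_k=30)
    return E_idx.shape  # torch.Size([1, 50, 30])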


class ProteinFeatures(nn.Module):
    def __init__(self, edge_features, node_features, num_positional_embeddings=16,
                 num_rbf=16, top_k=30, augment_eps=0., num_chain_embeddings=16):
        """ Extract protein features """
        super(ProteinFeatures, self).__init__()
        self.edge_features = edge_features
        self.node_features = node_features
        self.top_k = top_k
        self.augment_eps = augment_eps
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings

        self.embeddings = PositionalEncodings(num_positional_embeddings)
        node_in, edge_in = 6, num_positional_embeddings + num_rbf * 25
        self.edge_embedding = nn.Linear(edge_in, edge_features, bias=False)
        self.norm_edges = nn.LayerNorm(edge_features)

    def _dist(self, X, mask, eps=1E-6):
        return get_closest_neighbors(X, mask, self.top_k, eps=eps)

    def _rbf(self, D):
        device = D.device
        D_min, D_max, D_count = 2., 22., self.num_rbf
        D_mu = torch.linspace(D_min, D_max, D_count, device=device)
        D_mu = D_mu.view([1, 1, 1, -1])
        D_sigma = (D_max - D_min) / D_count
        D_expand = torch.unsqueeze(D, -1)
        RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
        return RBF

    def _get_rbf(self, A, B, E_idx):
        D_A_B = torch.sqrt(torch.sum((A[:, :, None, :] - B[:, None, :, :])**2, -1) + 1e-6)  # [B, L, L]
        D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]
        RBF_A_B = self._rbf(D_A_B_neighbors)
        return RBF_A_B

    def forward(self, X, mask, residue_idx, chain_labels):
        if self.augment_eps > 0:
            X = X + self.augment_eps * torch.randn_like(X)

        b = X[:, :, 1, :] - X[:, :, 0, :]
        c = X[:, :, 2, :] - X[:, :, 1, :]
        a = torch.cross(b, c, dim=-1)
        Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + X[:, :, 1, :]
        Ca = X[:, :, 1, :]
        N = X[:, :, 0, :]
        C = X[:, :, 2, :]
        O = X[:, :, 3, :]

        D_neighbors, E_idx = self._dist(Ca, mask)

        RBF_all = []
        RBF_all.append(self._rbf(D_neighbors))  # Ca-Ca
        RBF_all.append(self._get_rbf(N, N, E_idx))  # N-N
        RBF_all.append(self._get_rbf(C, C, E_idx))  # C-C
        RBF_all.append(self._get_rbf(O, O, E_idx))  # O-O
        RBF_all.append(self._get_rbf(Cb, Cb, E_idx))  # Cb-Cb
        RBF_all.append(self._get_rbf(Ca, N, E_idx))  # Ca-N
        RBF_all.append(self._get_rbf(Ca, C, E_idx))  # Ca-C
        RBF_all.append(self._get_rbf(Ca, O, E_idx))  # Ca-O
        RBF_all.append(self._get_rbf(Ca, Cb, E_idx))  # Ca-Cb
        RBF_all.append(self._get_rbf(N, C, E_idx))  # N-C
        RBF_all.append(self._get_rbf(N, O, E_idx))  # N-O
        RBF_all.append(self._get_rbf(N, Cb, E_idx))  # N-Cb
        RBF_all.append(self._get_rbf(Cb, C, E_idx))  # Cb-C
        RBF_all.append(self._get_rbf(Cb, O, E_idx))  # Cb-O
        RBF_all.append(self._get_rbf(O, C, E_idx))  # O-C
        RBF_all.append(self._get_rbf(N, Ca, E_idx))  # N-Ca
        RBF_all.append(self._get_rbf(C, Ca, E_idx))  # C-Ca
        RBF_all.append(self._get_rbf(O, Ca, E_idx))  # O-Ca
        RBF_all.append(self._get_rbf(Cb, Ca, E_idx))  # Cb-Ca
        RBF_all.append(self._get_rbf(C, N, E_idx))  # C-N
        RBF_all.append(self._get_rbf(O, N, E_idx))  # O-N
        RBF_all.append(self._get_rbf(Cb, N, E_idx))  # Cb-N
        RBF_all.append(self._get_rbf(C, Cb, E_idx))  # C-Cb
        RBF_all.append(self._get_rbf(O, Cb, E_idx))  # O-Cb
        RBF_all.append(self._get_rbf(C, O, E_idx))  # C-O
        RBF_all = torch.cat(tuple(RBF_all), dim=-1)

        offset = residue_idx[:, :, None] - residue_idx[:, None, :]
        offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]

        d_chains = ((chain_labels[:, :, None] - chain_labels[:, None, :]) == 0).long()  # find self vs non-self interaction
        E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0]
        E_positional = self.embeddings(offset.long(), E_chains)
        E = torch.cat((E_positional, RBF_all), -1)
        E = self.edge_embedding(E)
        E = self.norm_edges(E)
        return E, E_idx
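

# --- Editor's note: illustrative sketch (not part of the original file).
# ProteinFeatures encodes each edge as 25 inter-atom distances (N, CA, C, O
# plus a virtual CB) expanded in num_rbf radial basis functions, concatenated
# with relative-position embeddings, then projected to edge_features channels.
def _example_protein_features():
    feats = ProteinFeatures(edge_features=128, node_features=128, top_k=30)
    X = torch.randn(1, 50, 4, 3)          # backbone atoms N, CA, C, O
    mask = torch.ones(1, 50)
    residue_idx = torch.arange(50)[None]  # [B, L]
    chain_labels = torch.zeros(1, 50)
    E, E_idx = feats(X, mask, residue_idx, chain_labels)
    return E.shape  # torch.Size([1, 50, 30, 128])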


class ProteinMPNN(nn.Module):
    def __init__(self, num_letters, node_features, edge_features,
                 hidden_dim, num_encoder_layers=3, num_decoder_layers=3,
                 vocab=21, k_neighbors=64, augment_eps=0.05, dropout=0.1, ca_only=False, time_cond_dim=None, input_S_is_embeddings=False):
        super(ProteinMPNN, self).__init__()

        # Hyperparameters
        self.node_features = node_features
        self.edge_features = edge_features
        self.hidden_dim = hidden_dim

        # Featurization layers
        if ca_only:
            self.features = CA_ProteinFeatures(node_features, edge_features, top_k=k_neighbors, augment_eps=augment_eps)
            self.W_v = nn.Linear(node_features, hidden_dim, bias=True)
        else:
            self.features = ProteinFeatures(node_features, edge_features, top_k=k_neighbors, augment_eps=augment_eps)

        self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
        self.input_S_is_embeddings = input_S_is_embeddings
        if not self.input_S_is_embeddings:
            self.W_s = nn.Embedding(vocab, hidden_dim)

        if time_cond_dim is not None:
            self.time_block = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, hidden_dim)
            )

        # Encoder layers
        self.encoder_layers = nn.ModuleList([
            EncLayer(hidden_dim, hidden_dim * 2, dropout=dropout, time_cond_dim=time_cond_dim)
            for _ in range(num_encoder_layers)
        ])

        # Decoder layers
        self.decoder_layers = nn.ModuleList([
            DecLayer(hidden_dim, hidden_dim * 3, dropout=dropout, time_cond_dim=time_cond_dim)
            for _ in range(num_decoder_layers)
        ])
        self.W_out = nn.Linear(hidden_dim, num_letters, bias=True)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, X, S, mask, chain_M, residue_idx, chain_encoding_all, randn, use_input_decoding_order=False, decoding_order=None, causal_mask=True, time_cond=None, return_node_embs=False):
        """ Graph-conditioned sequence model """
        device = X.device
        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
        if time_cond is not None:
            time_cond_nodes = self.time_block(time_cond)
            h_V += time_cond_nodes  # time_cond is [B, 1, C]
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend, time_cond=time_cond)

        encoder_embs = h_V

        # Concatenate sequence embeddings for autoregressive decoder
        if self.input_S_is_embeddings:
            h_S = S
        else:
            h_S = self.W_s(S)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)

        chain_M = chain_M * mask  # update chain_M to include missing regions
        mask_size = E_idx.shape[1]
        if causal_mask:
            if not use_input_decoding_order:
                decoding_order = torch.argsort((chain_M + 0.0001) * (torch.abs(randn)))  # values are smaller where chain_M == 0.0 and larger where chain_M == 1.0
            permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
            order_mask_backward = torch.einsum('ij, biq, bjp->bqp', (1 - torch.triu(torch.ones(mask_size, mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
        else:
            order_mask_backward = torch.ones(X.shape[0], mask_size, mask_size, device=device)
        mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for layer in self.decoder_layers:
            # Masked positions attend to encoder information; unmasked positions also see the sequence.
            h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
            h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
            h_V = layer(h_V, h_ESV, mask, time_cond=time_cond)

        if return_node_embs:
            return h_V, encoder_embs
        else:
            logits = self.W_out(h_V)
            log_probs = F.log_softmax(logits, dim=-1)
            return log_probs
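
    # --- Editor's note: teacher-forced scoring sketch for forward() (added
    # for documentation, not part of the original file; the tensors are
    # assumed to come from the featurizer above):
    #
    #   model = ProteinMPNN(num_letters=21, node_features=128,
    #                       edge_features=128, hidden_dim=128, k_neighbors=48)
    #   randn = torch.randn(chain_M.shape, device=X.device)
    #   log_probs = model(X, S, mask, chain_M, residue_idx,
    #                     chain_encoding_all, randn)
    #   _, loss_av = loss_smoothed(S, log_probs, mask)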

    def sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx, mask=None, temperature=1.0, omit_AAs_np=None, bias_AAs_np=None, chain_M_pos=None, omit_AA_mask=None, pssm_coef=None, pssm_bias=None, pssm_multi=None, pssm_log_odds_flag=None, pssm_log_odds_mask=None, pssm_bias_flag=None, bias_by_res=None):
        device = X.device
        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Decoder uses masked self-attention
        chain_mask = chain_mask * chain_M_pos * mask  # update chain_mask to include missing regions
        decoding_order = torch.argsort((chain_mask + 0.0001) * (torch.abs(randn)))  # values are smaller where chain_mask == 0.0 and larger where chain_mask == 1.0
        mask_size = E_idx.shape[1]
        permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
        order_mask_backward = torch.einsum('ij, biq, bjp->bqp', (1 - torch.triu(torch.ones(mask_size, mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
        mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        N_batch, N_nodes = X.size(0), X.size(1)
        log_probs = torch.zeros((N_batch, N_nodes, 21), device=device)
        all_probs = torch.zeros((N_batch, N_nodes, 21), device=device, dtype=torch.float32)
        h_S = torch.zeros_like(h_V, device=device)
        S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=device)
        h_V_stack = [h_V] + [torch.zeros_like(h_V, device=device) for _ in range(len(self.decoder_layers))]
        constant = torch.tensor(omit_AAs_np, device=device)
        constant_bias = torch.tensor(bias_AAs_np, device=device)
        # chain_mask_combined = chain_mask*chain_M_pos
        omit_AA_mask_flag = omit_AA_mask is not None

        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for t_ in range(N_nodes):
            t = decoding_order[:, t_]  # [B]
            chain_mask_gathered = torch.gather(chain_mask, 1, t[:, None])  # [B]
            mask_gathered = torch.gather(mask, 1, t[:, None])  # [B]
            bias_by_res_gathered = torch.gather(bias_by_res, 1, t[:, None, None].repeat(1, 1, 21))[:, 0, :]  # [B, 21]
            if (mask_gathered == 0).all():  # for padded or missing regions only
                S_t = torch.gather(S_true, 1, t[:, None])
            else:
                # Hidden layers
                E_idx_t = torch.gather(E_idx, 1, t[:, None, None].repeat(1, 1, E_idx.shape[-1]))
                h_E_t = torch.gather(h_E, 1, t[:, None, None, None].repeat(1, 1, h_E.shape[-2], h_E.shape[-1]))
                h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                h_EXV_encoder_t = torch.gather(h_EXV_encoder_fw, 1, t[:, None, None, None].repeat(1, 1, h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1]))
                mask_t = torch.gather(mask, 1, t[:, None])
                for l, layer in enumerate(self.decoder_layers):
                    # Updated relational features for future states
                    h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
                    h_V_t = torch.gather(h_V_stack[l], 1, t[:, None, None].repeat(1, 1, h_V_stack[l].shape[-1]))
                    h_ESV_t = torch.gather(mask_bw, 1, t[:, None, None, None].repeat(1, 1, mask_bw.shape[-2], mask_bw.shape[-1])) * h_ESV_decoder_t + h_EXV_encoder_t
                    h_V_stack[l + 1].scatter_(1, t[:, None, None].repeat(1, 1, h_V.shape[-1]), layer(h_V_t, h_ESV_t, mask_V=mask_t))
                # Sampling step
                h_V_t = torch.gather(h_V_stack[-1], 1, t[:, None, None].repeat(1, 1, h_V_stack[-1].shape[-1]))[:, 0]
                logits = self.W_out(h_V_t) / temperature
                probs = F.softmax(logits - constant[None, :] * 1e8 + constant_bias[None, :] / temperature + bias_by_res_gathered / temperature, dim=-1)
                if pssm_bias_flag:
                    pssm_coef_gathered = torch.gather(pssm_coef, 1, t[:, None])[:, 0]
                    pssm_bias_gathered = torch.gather(pssm_bias, 1, t[:, None, None].repeat(1, 1, pssm_bias.shape[-1]))[:, 0]
                    probs = (1 - pssm_multi * pssm_coef_gathered[:, None]) * probs + pssm_multi * pssm_coef_gathered[:, None] * pssm_bias_gathered
                if pssm_log_odds_flag:
                    pssm_log_odds_mask_gathered = torch.gather(pssm_log_odds_mask, 1, t[:, None, None].repeat(1, 1, pssm_log_odds_mask.shape[-1]))[:, 0]  # [B, 21]
                    probs_masked = probs * pssm_log_odds_mask_gathered
                    probs_masked += probs * 0.001
                    probs = probs_masked / torch.sum(probs_masked, dim=-1, keepdim=True)  # [B, 21]
                if omit_AA_mask_flag:
                    omit_AA_mask_gathered = torch.gather(omit_AA_mask, 1, t[:, None, None].repeat(1, 1, omit_AA_mask.shape[-1]))[:, 0]  # [B, 21]
                    probs_masked = probs * (1.0 - omit_AA_mask_gathered)
                    probs = probs_masked / torch.sum(probs_masked, dim=-1, keepdim=True)  # [B, 21]
                S_t = torch.multinomial(probs, 1)
                all_probs.scatter_(1, t[:, None, None].repeat(1, 1, 21), (chain_mask_gathered[:, :, None] * probs[:, None, :]).float())
            S_true_gathered = torch.gather(S_true, 1, t[:, None])
            S_t = (S_t * chain_mask_gathered + S_true_gathered * (1.0 - chain_mask_gathered)).long()
            temp1 = self.W_s(S_t)
            h_S.scatter_(1, t[:, None, None].repeat(1, 1, temp1.shape[-1]), temp1)
            S.scatter_(1, t[:, None], S_t)
        output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
        return output_dict
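
    # --- Editor's note: decoding sketch for sample() (added for
    # documentation, not part of the original file). Decoding follows the
    # random order induced by `randn`; positions with chain_mask == 0 keep
    # their input identity:
    #
    #   omit = np.zeros(21, dtype=np.float32)  # no residue types forbidden
    #   bias = np.zeros(21, dtype=np.float32)
    #   out = model.sample(X, randn, S, chain_M, chain_encoding_all,
    #                      residue_idx, mask=mask, temperature=0.1,
    #                      omit_AAs_np=omit, bias_AAs_np=bias,
    #                      chain_M_pos=chain_M_pos, omit_AA_mask=None,
    #                      pssm_bias_flag=False, pssm_log_odds_flag=False,
    #                      bias_by_res=bias_by_res_all)
    #   sampled = out["S"]  # [B, L] integer tokens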

    def tied_sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx, mask=None, temperature=1.0, omit_AAs_np=None, bias_AAs_np=None, chain_M_pos=None, omit_AA_mask=None, pssm_coef=None, pssm_bias=None, pssm_multi=None, pssm_log_odds_flag=None, pssm_log_odds_mask=None, pssm_bias_flag=None, tied_pos=None, tied_beta=None, bias_by_res=None):
        device = X.device
        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
        h_E = self.W_e(E)
        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Decoder uses masked self-attention
        chain_mask = chain_mask * chain_M_pos * mask  # update chain_mask to include missing regions
        decoding_order = torch.argsort((chain_mask + 0.0001) * (torch.abs(randn)))  # values are smaller where chain_mask == 0.0 and larger where chain_mask == 1.0

        new_decoding_order = []
        for t_dec in list(decoding_order[0, ].cpu().data.numpy()):
            if t_dec not in list(itertools.chain(*new_decoding_order)):
                list_a = [item for item in tied_pos if t_dec in item]
                if list_a:
                    new_decoding_order.append(list_a[0])
                else:
                    new_decoding_order.append([t_dec])
        decoding_order = torch.tensor(list(itertools.chain(*new_decoding_order)), device=device)[None, ].repeat(X.shape[0], 1)

        mask_size = E_idx.shape[1]
        permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
        order_mask_backward = torch.einsum('ij, biq, bjp->bqp', (1 - torch.triu(torch.ones(mask_size, mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
        mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        N_batch, N_nodes = X.size(0), X.size(1)
        log_probs = torch.zeros((N_batch, N_nodes, 21), device=device)
        all_probs = torch.zeros((N_batch, N_nodes, 21), device=device, dtype=torch.float32)
        h_S = torch.zeros_like(h_V, device=device)
        S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=device)
        h_V_stack = [h_V] + [torch.zeros_like(h_V, device=device) for _ in range(len(self.decoder_layers))]
        constant = torch.tensor(omit_AAs_np, device=device)
        constant_bias = torch.tensor(bias_AAs_np, device=device)
        omit_AA_mask_flag = omit_AA_mask is not None

        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for t_list in new_decoding_order:
            logits = 0.0
            logit_list = []
            done_flag = False
            for t in t_list:
                if (mask[:, t] == 0).all():
                    S_t = S_true[:, t]
                    for t in t_list:
                        h_S[:, t, :] = self.W_s(S_t)
                        S[:, t] = S_t
                    done_flag = True
                    break
                else:
                    E_idx_t = E_idx[:, t:t + 1, :]
                    h_E_t = h_E[:, t:t + 1, :, :]
                    h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                    h_EXV_encoder_t = h_EXV_encoder_fw[:, t:t + 1, :, :]
                    mask_t = mask[:, t:t + 1]
                    for l, layer in enumerate(self.decoder_layers):
                        h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
                        h_V_t = h_V_stack[l][:, t:t + 1, :]
                        h_ESV_t = mask_bw[:, t:t + 1, :, :] * h_ESV_decoder_t + h_EXV_encoder_t
                        h_V_stack[l + 1][:, t, :] = layer(h_V_t, h_ESV_t, mask_V=mask_t).squeeze(1)
                    h_V_t = h_V_stack[-1][:, t, :]
                    logit_list.append((self.W_out(h_V_t) / temperature) / len(t_list))
                    logits += tied_beta[t] * (self.W_out(h_V_t) / temperature) / len(t_list)
            if done_flag:
                pass
            else:
                bias_by_res_gathered = bias_by_res[:, t, :]  # [B, 21]
                probs = F.softmax(logits - constant[None, :] * 1e8 + constant_bias[None, :] / temperature + bias_by_res_gathered / temperature, dim=-1)
                if pssm_bias_flag:
                    pssm_coef_gathered = pssm_coef[:, t]
                    pssm_bias_gathered = pssm_bias[:, t]
                    probs = (1 - pssm_multi * pssm_coef_gathered[:, None]) * probs + pssm_multi * pssm_coef_gathered[:, None] * pssm_bias_gathered
                if pssm_log_odds_flag:
                    pssm_log_odds_mask_gathered = pssm_log_odds_mask[:, t]
                    probs_masked = probs * pssm_log_odds_mask_gathered
                    probs_masked += probs * 0.001
                    probs = probs_masked / torch.sum(probs_masked, dim=-1, keepdim=True)  # [B, 21]
                if omit_AA_mask_flag:
                    omit_AA_mask_gathered = omit_AA_mask[:, t]
                    probs_masked = probs * (1.0 - omit_AA_mask_gathered)
                    probs = probs_masked / torch.sum(probs_masked, dim=-1, keepdim=True)  # [B, 21]
                S_t_repeat = torch.multinomial(probs, 1).squeeze(-1)
                S_t_repeat = (chain_mask[:, t] * S_t_repeat + (1 - chain_mask[:, t]) * S_true[:, t]).long()  # hard pick fixed positions
                for t in t_list:
                    h_S[:, t, :] = self.W_s(S_t_repeat)
                    S[:, t] = S_t_repeat
                    all_probs[:, t, :] = probs.float()
        output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
        return output_dict

    def conditional_probs(self, X, S, mask, chain_M, residue_idx, chain_encoding_all, randn, backbone_only=False):
        """ Graph-conditioned sequence model: per-position probabilities conditioned on the rest of the sequence """
        device = X.device
        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V_enc = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V_enc, h_E = layer(h_V_enc, h_E, E_idx, mask, mask_attend)

        # Concatenate sequence embeddings for autoregressive decoder
        h_S = self.W_s(S)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V_enc, h_EX_encoder, E_idx)

        chain_M = chain_M * mask  # update chain_M to include missing regions

        chain_M_np = chain_M.cpu().numpy()
        idx_to_loop = np.argwhere(chain_M_np[0, :] == 1)[:, 0]
        log_conditional_probs = torch.zeros([X.shape[0], chain_M.shape[1], 21], device=device).float()

        for idx in idx_to_loop:
            h_V = torch.clone(h_V_enc)
            if backbone_only:
                order_mask = torch.ones(chain_M.shape[1], device=device).float()
                order_mask[idx] = 0.
            else:
                order_mask = torch.zeros(chain_M.shape[1], device=device).float()
                order_mask[idx] = 1.
            decoding_order = torch.argsort((order_mask[None, ] + 0.0001) * (torch.abs(randn)))  # values are smaller where order_mask == 0.0 and larger where order_mask == 1.0
            mask_size = E_idx.shape[1]
            permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
            order_mask_backward = torch.einsum('ij, biq, bjp->bqp', (1 - torch.triu(torch.ones(mask_size, mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
            mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
            mask_bw = mask_1D * mask_attend
            mask_fw = mask_1D * (1. - mask_attend)

            h_EXV_encoder_fw = mask_fw * h_EXV_encoder
            for layer in self.decoder_layers:
                # Masked positions attend to encoder information; unmasked positions also see the sequence.
                h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
                h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
                h_V = layer(h_V, h_ESV, mask)

            logits = self.W_out(h_V)
            log_probs = F.log_softmax(logits, dim=-1)
            log_conditional_probs[:, idx, :] = log_probs[:, idx, :]
        return log_conditional_probs

    def unconditional_probs(self, X, mask, residue_idx, chain_encoding_all):
        """ Graph-conditioned sequence model, decoded with no sequence context """
        device = X.device
        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_V), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)

        order_mask_backward = torch.zeros([X.shape[0], X.shape[1], X.shape[1]], device=device)
        mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for layer in self.decoder_layers:
            h_V = layer(h_V, h_EXV_encoder_fw, mask)

        logits = self.W_out(h_V)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs

core/residue_constants.py
ADDED
@@ -0,0 +1,1104 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Constants used in AlphaFold.
Adapted from original code by alexechu.
"""

import collections
import functools
import os
from typing import List, Mapping, Tuple

import numpy as np
import tree

# Internal import (35fd).


# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096

# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
chi_angles_atoms = {
    "ALA": [],
    # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
    "ARG": [
        ["N", "CA", "CB", "CG"],
        ["CA", "CB", "CG", "CD"],
        ["CB", "CG", "CD", "NE"],
        ["CG", "CD", "NE", "CZ"],
    ],
    "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
    "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
    "CYS": [["N", "CA", "CB", "SG"]],
    "GLN": [
        ["N", "CA", "CB", "CG"],
        ["CA", "CB", "CG", "CD"],
        ["CB", "CG", "CD", "OE1"],
    ],
    "GLU": [
        ["N", "CA", "CB", "CG"],
        ["CA", "CB", "CG", "CD"],
        ["CB", "CG", "CD", "OE1"],
    ],
    "GLY": [],
    "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
    "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
    "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
    "LYS": [
        ["N", "CA", "CB", "CG"],
        ["CA", "CB", "CG", "CD"],
        ["CB", "CG", "CD", "CE"],
        ["CG", "CD", "CE", "NZ"],
    ],
    "MET": [
        ["N", "CA", "CB", "CG"],
        ["CA", "CB", "CG", "SD"],
        ["CB", "CG", "SD", "CE"],
    ],
    "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
    "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
    "SER": [["N", "CA", "CB", "OG"]],
    "THR": [["N", "CA", "CB", "OG1"]],
    "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
    "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
    "VAL": [["N", "CA", "CB", "CG1"]],
}

# If chi angles given in fixed-length array, this matrix determines how to mask
# them for each AA type. The order is as per restype_order (see below).
chi_angles_mask = [
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [1.0, 1.0, 1.0, 1.0],  # ARG
    [1.0, 1.0, 0.0, 0.0],  # ASN
    [1.0, 1.0, 0.0, 0.0],  # ASP
    [1.0, 0.0, 0.0, 0.0],  # CYS
    [1.0, 1.0, 1.0, 0.0],  # GLN
    [1.0, 1.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [1.0, 1.0, 0.0, 0.0],  # HIS
    [1.0, 1.0, 0.0, 0.0],  # ILE
    [1.0, 1.0, 0.0, 0.0],  # LEU
    [1.0, 1.0, 1.0, 1.0],  # LYS
    [1.0, 1.0, 1.0, 0.0],  # MET
    [1.0, 1.0, 0.0, 0.0],  # PHE
    [1.0, 1.0, 0.0, 0.0],  # PRO
    [1.0, 0.0, 0.0, 0.0],  # SER
    [1.0, 0.0, 0.0, 0.0],  # THR
    [1.0, 1.0, 0.0, 0.0],  # TRP
    [1.0, 1.0, 0.0, 0.0],  # TYR
    [1.0, 0.0, 0.0, 0.0],  # VAL
]

# The following chi angles are pi periodic: they can be rotated by a multiple
# of pi without affecting the structure.
chi_pi_periodic = [
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [0.0, 0.0, 0.0, 0.0],  # ARG
    [0.0, 0.0, 0.0, 0.0],  # ASN
    [0.0, 1.0, 0.0, 0.0],  # ASP
    [0.0, 0.0, 0.0, 0.0],  # CYS
    [0.0, 0.0, 0.0, 0.0],  # GLN
    [0.0, 0.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [0.0, 0.0, 0.0, 0.0],  # HIS
    [0.0, 0.0, 0.0, 0.0],  # ILE
    [0.0, 0.0, 0.0, 0.0],  # LEU
    [0.0, 0.0, 0.0, 0.0],  # LYS
    [0.0, 0.0, 0.0, 0.0],  # MET
    [0.0, 1.0, 0.0, 0.0],  # PHE
    [0.0, 0.0, 0.0, 0.0],  # PRO
    [0.0, 0.0, 0.0, 0.0],  # SER
    [0.0, 0.0, 0.0, 0.0],  # THR
    [0.0, 0.0, 0.0, 0.0],  # TRP
    [0.0, 1.0, 0.0, 0.0],  # TYR
    [0.0, 0.0, 0.0, 0.0],  # VAL
    [0.0, 0.0, 0.0, 0.0],  # UNK
]

# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
# psi and chi angles:
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
# The atom positions are relative to the axis-end-atom of the corresponding
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
# is defined such that the dihedral-angle-defining atom (the last entry in
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# format: [atomname, group_idx, rel_position]
rigid_group_atom_positions = {
    "ALA": [
        ["N", 0, (-0.525, 1.363, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.526, -0.000, -0.000)],
        ["CB", 0, (-0.529, -0.774, -1.205)],
        ["O", 3, (0.627, 1.062, 0.000)],
    ],
    "ARG": [
        ["N", 0, (-0.524, 1.362, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.525, -0.000, -0.000)],
        ["CB", 0, (-0.524, -0.778, -1.209)],
        ["O", 3, (0.626, 1.062, 0.000)],
        ["CG", 4, (0.616, 1.390, -0.000)],
        ["CD", 5, (0.564, 1.414, 0.000)],
        ["NE", 6, (0.539, 1.357, -0.000)],
        ["NH1", 7, (0.206, 2.301, 0.000)],
        ["NH2", 7, (2.078, 0.978, -0.000)],
        ["CZ", 7, (0.758, 1.093, -0.000)],
    ],
    "ASN": [
        ["N", 0, (-0.536, 1.357, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.526, -0.000, -0.000)],
        ["CB", 0, (-0.531, -0.787, -1.200)],
        ["O", 3, (0.625, 1.062, 0.000)],
        ["CG", 4, (0.584, 1.399, 0.000)],
        ["ND2", 5, (0.593, -1.188, 0.001)],
        ["OD1", 5, (0.633, 1.059, 0.000)],
    ],
    "ASP": [
        ["N", 0, (-0.525, 1.362, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.527, 0.000, -0.000)],
        ["CB", 0, (-0.526, -0.778, -1.208)],
        ["O", 3, (0.626, 1.062, -0.000)],
        ["CG", 4, (0.593, 1.398, -0.000)],
        ["OD1", 5, (0.610, 1.091, 0.000)],
        ["OD2", 5, (0.592, -1.101, -0.003)],
    ],
    "CYS": [
        ["N", 0, (-0.522, 1.362, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.524, 0.000, 0.000)],
        ["CB", 0, (-0.519, -0.773, -1.212)],
        ["O", 3, (0.625, 1.062, -0.000)],
        ["SG", 4, (0.728, 1.653, 0.000)],
    ],
    "GLN": [
        ["N", 0, (-0.526, 1.361, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.526, 0.000, 0.000)],
        ["CB", 0, (-0.525, -0.779, -1.207)],
        ["O", 3, (0.626, 1.062, -0.000)],
        ["CG", 4, (0.615, 1.393, 0.000)],
        ["CD", 5, (0.587, 1.399, -0.000)],
        ["NE2", 6, (0.593, -1.189, -0.001)],
        ["OE1", 6, (0.634, 1.060, 0.000)],
    ],
    "GLU": [
        ["N", 0, (-0.528, 1.361, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.526, -0.000, -0.000)],
        ["CB", 0, (-0.526, -0.781, -1.207)],
        ["O", 3, (0.626, 1.062, 0.000)],
        ["CG", 4, (0.615, 1.392, 0.000)],
        ["CD", 5, (0.600, 1.397, 0.000)],
        ["OE1", 6, (0.607, 1.095, -0.000)],
        ["OE2", 6, (0.589, -1.104, -0.001)],
    ],
    "GLY": [
        ["N", 0, (-0.572, 1.337, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.517, -0.000, -0.000)],
        ["O", 3, (0.626, 1.062, -0.000)],
    ],
    "HIS": [
        ["N", 0, (-0.527, 1.360, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.525, 0.000, 0.000)],
        ["CB", 0, (-0.525, -0.778, -1.208)],
        ["O", 3, (0.625, 1.063, 0.000)],
        ["CG", 4, (0.600, 1.370, -0.000)],
        ["CD2", 5, (0.889, -1.021, 0.003)],
        ["ND1", 5, (0.744, 1.160, -0.000)],
        ["CE1", 5, (2.030, 0.851, 0.002)],
        ["NE2", 5, (2.145, -0.466, 0.004)],
    ],
    "ILE": [
        ["N", 0, (-0.493, 1.373, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.527, -0.000, -0.000)],
        ["CB", 0, (-0.536, -0.793, -1.213)],
        ["O", 3, (0.627, 1.062, -0.000)],
        ["CG1", 4, (0.534, 1.437, -0.000)],
        ["CG2", 4, (0.540, -0.785, -1.199)],
        ["CD1", 5, (0.619, 1.391, 0.000)],
    ],
    "LEU": [
        ["N", 0, (-0.520, 1.363, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.525, -0.000, -0.000)],
        ["CB", 0, (-0.522, -0.773, -1.214)],
        ["O", 3, (0.625, 1.063, -0.000)],
        ["CG", 4, (0.678, 1.371, 0.000)],
        ["CD1", 5, (0.530, 1.430, -0.000)],
        ["CD2", 5, (0.535, -0.774, 1.200)],
    ],
    "LYS": [
        ["N", 0, (-0.526, 1.362, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.526, 0.000, 0.000)],
        ["CB", 0, (-0.524, -0.778, -1.208)],
        ["O", 3, (0.626, 1.062, -0.000)],
        ["CG", 4, (0.619, 1.390, 0.000)],
        ["CD", 5, (0.559, 1.417, 0.000)],
        ["CE", 6, (0.560, 1.416, 0.000)],
        ["NZ", 7, (0.554, 1.387, 0.000)],
    ],
    "MET": [
        ["N", 0, (-0.521, 1.364, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.525, 0.000, 0.000)],
        ["CB", 0, (-0.523, -0.776, -1.210)],
        ["O", 3, (0.625, 1.062, -0.000)],
        ["CG", 4, (0.613, 1.391, -0.000)],
        ["SD", 5, (0.703, 1.695, 0.000)],
        ["CE", 6, (0.320, 1.786, -0.000)],
    ],
    "PHE": [
        ["N", 0, (-0.518, 1.363, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.524, 0.000, -0.000)],
        ["CB", 0, (-0.525, -0.776, -1.212)],
        ["O", 3, (0.626, 1.062, -0.000)],
        ["CG", 4, (0.607, 1.377, 0.000)],
        ["CD1", 5, (0.709, 1.195, -0.000)],
        ["CD2", 5, (0.706, -1.196, 0.000)],
        ["CE1", 5, (2.102, 1.198, -0.000)],
        ["CE2", 5, (2.098, -1.201, -0.000)],
        ["CZ", 5, (2.794, -0.003, -0.001)],
    ],
    "PRO": [
        ["N", 0, (-0.566, 1.351, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.527, -0.000, 0.000)],
        ["CB", 0, (-0.546, -0.611, -1.293)],
        ["O", 3, (0.621, 1.066, 0.000)],
        ["CG", 4, (0.382, 1.445, 0.0)],
        # ['CD', 5, (0.427, 1.440, 0.0)],
        ["CD", 5, (0.477, 1.424, 0.0)],  # manually made angle 2 degrees larger
    ],
    "SER": [
        ["N", 0, (-0.529, 1.360, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.525, -0.000, -0.000)],
        ["CB", 0, (-0.518, -0.777, -1.211)],
        ["O", 3, (0.626, 1.062, -0.000)],
        ["OG", 4, (0.503, 1.325, 0.000)],
    ],
    "THR": [
        ["N", 0, (-0.517, 1.364, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.526, 0.000, -0.000)],
        ["CB", 0, (-0.516, -0.793, -1.215)],
        ["O", 3, (0.626, 1.062, 0.000)],
        ["CG2", 4, (0.550, -0.718, -1.228)],
        ["OG1", 4, (0.472, 1.353, 0.000)],
    ],
    "TRP": [
        ["N", 0, (-0.521, 1.363, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.525, -0.000, 0.000)],
        ["CB", 0, (-0.523, -0.776, -1.212)],
        ["O", 3, (0.627, 1.062, 0.000)],
        ["CG", 4, (0.609, 1.370, -0.000)],
        ["CD1", 5, (0.824, 1.091, 0.000)],
        ["CD2", 5, (0.854, -1.148, -0.005)],
        ["CE2", 5, (2.186, -0.678, -0.007)],
        ["CE3", 5, (0.622, -2.530, -0.007)],
        ["NE1", 5, (2.140, 0.690, -0.004)],
        ["CH2", 5, (3.028, -2.890, -0.013)],
        ["CZ2", 5, (3.283, -1.543, -0.011)],
        ["CZ3", 5, (1.715, -3.389, -0.011)],
    ],
    "TYR": [
        ["N", 0, (-0.522, 1.362, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.524, -0.000, -0.000)],
        ["CB", 0, (-0.522, -0.776, -1.213)],
        ["O", 3, (0.627, 1.062, -0.000)],
        ["CG", 4, (0.607, 1.382, -0.000)],
        ["CD1", 5, (0.716, 1.195, -0.000)],
        ["CD2", 5, (0.713, -1.194, -0.001)],
        ["CE1", 5, (2.107, 1.200, -0.002)],
        ["CE2", 5, (2.104, -1.201, -0.003)],
        ["OH", 5, (4.168, -0.002, -0.005)],
        ["CZ", 5, (2.791, -0.001, -0.003)],
    ],
    "VAL": [
        ["N", 0, (-0.494, 1.373, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.527, -0.000, -0.000)],
        ["CB", 0, (-0.533, -0.795, -1.213)],
        ["O", 3, (0.627, 1.062, -0.000)],
        ["CG1", 4, (0.540, 1.429, -0.000)],
        ["CG2", 4, (0.533, -0.776, 1.203)],
    ],
}

# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
residue_atoms = {
    "ALA": ["C", "CA", "CB", "N", "O"],
    "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
    "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
    "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
    "CYS": ["C", "CA", "CB", "N", "O", "SG"],
    "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
    "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
    "GLY": ["C", "CA", "N", "O"],
    "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
    "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
    "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
    "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
    "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
    "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
    "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
    "SER": ["C", "CA", "CB", "N", "O", "OG"],
    "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
    "TRP": [
        "C",
        "CA",
        "CB",
        "CG",
        "CD1",
        "CD2",
        "CE2",
        "CE3",
        "CZ2",
        "CZ3",
        "CH2",
        "N",
        "NE1",
        "O",
    ],
    "TYR": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O", "OH"],
    "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
}

# Naming swaps for ambiguous atom names.
# Due to symmetries in the amino acids the naming of atoms is ambiguous in
# 4 of the 20 amino acids.
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
# the 'ambiguous' atoms and their neighbours)
residue_atom_renaming_swaps = {
    "ASP": {"OD1": "OD2"},
    "GLU": {"OE1": "OE2"},
    "PHE": {"CD1": "CD2", "CE1": "CE2"},
    "TYR": {"CD1": "CD2", "CE1": "CE2"},
}

# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
van_der_waals_radius = {
    "C": 1.7,
    "N": 1.55,
    "O": 1.52,
    "S": 1.8,
}

Bond = collections.namedtuple("Bond", ["atom1_name", "atom2_name", "length", "stddev"])
BondAngle = collections.namedtuple(
    "BondAngle", ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"]
)


@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> (
    Tuple[
        Mapping[str, List[Bond]],
        Mapping[str, List[Bond]],
        Mapping[str, List[BondAngle]],
    ]
):
    """Load stereo_chemical_props.txt into a nice structure.

    Load literature values for bond lengths and bond angles and translate
    bond angles into the length of the opposite edge of the triangle
    ("residue_virtual_bonds").

    Returns:
      residue_bonds: Dict that maps resname -> list of Bond tuples.
      residue_virtual_bonds: Dict that maps resname -> list of Bond tuples.
      residue_bond_angles: Dict that maps resname -> list of BondAngle tuples.
    """
    stereo_chemical_props_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "stereo_chemical_props.txt"
    )
    with open(stereo_chemical_props_path, "rt") as f:
        stereo_chemical_props = f.read()
    lines_iter = iter(stereo_chemical_props.splitlines())
    # Load bond lengths.
    residue_bonds = {}
    next(lines_iter)  # Skip header line.
    for line in lines_iter:
        if line.strip() == "-":
            break
        bond, resname, length, stddev = line.split()
        atom1, atom2 = bond.split("-")
        if resname not in residue_bonds:
            residue_bonds[resname] = []
        residue_bonds[resname].append(Bond(atom1, atom2, float(length), float(stddev)))
    residue_bonds["UNK"] = []

    # Load bond angles.
    residue_bond_angles = {}
    next(lines_iter)  # Skip empty line.
    next(lines_iter)  # Skip header line.
    for line in lines_iter:
        if line.strip() == "-":
            break
        bond, resname, angle_degree, stddev_degree = line.split()
        atom1, atom2, atom3 = bond.split("-")
        if resname not in residue_bond_angles:
            residue_bond_angles[resname] = []
        residue_bond_angles[resname].append(
            BondAngle(
                atom1,
                atom2,
                atom3,
                float(angle_degree) / 180.0 * np.pi,
                float(stddev_degree) / 180.0 * np.pi,
            )
        )
    residue_bond_angles["UNK"] = []

    def make_bond_key(atom1_name, atom2_name):
        """Unique key to lookup bonds."""
        return "-".join(sorted([atom1_name, atom2_name]))

    # Translate bond angles into distances ("virtual bonds").
    residue_virtual_bonds = {}
    for resname, bond_angles in residue_bond_angles.items():
        # Create a fast lookup dict for bond lengths.
        bond_cache = {}
        for b in residue_bonds[resname]:
            bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
        residue_virtual_bonds[resname] = []
        for ba in bond_angles:
            bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
            bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]

            # Compute distance between atom1 and atom3 using the law of cosines
            # c^2 = a^2 + b^2 - 2ab*cos(gamma).
            gamma = ba.angle_rad
            length = np.sqrt(
                bond1.length**2
                + bond2.length**2
                - 2 * bond1.length * bond2.length * np.cos(gamma)
            )

            # Propagation of uncertainty assuming uncorrelated errors.
            dl_outer = 0.5 / length
            dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
            dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
            dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
            stddev = np.sqrt(
                (dl_dgamma * ba.stddev) ** 2
                + (dl_db1 * bond1.stddev) ** 2
                + (dl_db2 * bond2.stddev) ** 2
            )
            residue_virtual_bonds[resname].append(
                Bond(ba.atom1_name, ba.atom3name, length, stddev)
            )

    return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
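
# --- Illustrative sketch (not part of the original file): querying the tables
# returned above. Reads the ALA N-CA bond, whose literature values in
# stereo_chemical_props.txt are 1.459 +/- 0.020 Angstroem.
# residue_bonds, residue_virtual_bonds, residue_bond_angles = (
#     load_stereo_chemical_props()
# )
# ala_n_ca = next(
#     b
#     for b in residue_bonds["ALA"]
#     if {b.atom1_name, b.atom2_name} == {"N", "CA"}
# )
# print(ala_n_ca.length, ala_n_ca.stddev)  # 1.459 0.02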


# Between-residue bond lengths for general bonds (first element) and for Proline
# (second element).
between_res_bond_length_c_n = [1.329, 1.341]
between_res_bond_length_stddev_c_n = [0.014, 0.016]

# Between-residue cos_angles.
between_res_cos_angles_c_n_ca = [-0.5203, 0.0353]  # degrees: 121.352 +- 2.315
between_res_cos_angles_ca_c_n = [-0.4473, 0.0311]  # degrees: 116.568 +- 1.995

# This mapping is used when we need to store atom data in a format that requires
# fixed atom data size for every residue (e.g. a numpy array).
atom_types = [
    "N",
    "CA",
    "C",
    "CB",
    "O",
    "CG",
    "CG1",
    "CG2",
    "OG",
    "OG1",
    "SG",
    "CD",
    "CD1",
    "CD2",
    "ND1",
    "ND2",
    "OD1",
    "OD2",
    "SD",
    "CE",
    "CE1",
    "CE2",
    "CE3",
    "NE",
    "NE1",
    "NE2",
    "OE1",
    "OE2",
    "CH2",
    "NH1",
    "NH2",
    "OH",
    "CZ",
    "CZ2",
    "CZ3",
    "NZ",
    "OXT",
]
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types)  # := 37.

# A compact atom encoding with 14 columns
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
restype_name_to_atom14_names = {
    "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
    "ARG": [
        "N",
        "CA",
        "C",
        "O",
        "CB",
        "CG",
        "CD",
        "NE",
        "CZ",
        "NH1",
        "NH2",
        "",
        "",
        "",
    ],
    "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""],
    "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""],
    "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
    "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", ""],
    "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", ""],
    "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
    "HIS": [
        "N",
        "CA",
        "C",
        "O",
        "CB",
        "CG",
        "ND1",
        "CD2",
        "CE1",
        "NE2",
        "",
        "",
        "",
        "",
    ],
    "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""],
    "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""],
    "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""],
    "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""],
    "PHE": [
        "N",
        "CA",
        "C",
        "O",
        "CB",
        "CG",
        "CD1",
        "CD2",
        "CE1",
        "CE2",
        "CZ",
        "",
        "",
        "",
    ],
    "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
    "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
    "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""],
    "TRP": [
        "N",
        "CA",
        "C",
        "O",
        "CB",
        "CG",
        "CD1",
        "CD2",
        "NE1",
        "CE2",
        "CE3",
        "CZ2",
        "CZ3",
        "CH2",
    ],
    "TYR": [
        "N",
        "CA",
        "C",
        "O",
        "CB",
        "CG",
        "CD1",
        "CD2",
        "CE1",
        "CE2",
        "CZ",
        "OH",
        "",
        "",
    ],
    "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""],
    "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
}
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace


# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = [
    "A",
    "R",
    "N",
    "D",
    "C",
    "Q",
    "E",
    "G",
    "H",
    "I",
    "L",
    "K",
    "M",
    "F",
    "P",
    "S",
    "T",
    "W",
    "Y",
    "V",
]
restype_order = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes)  # := 20.
unk_restype_index = restype_num  # Catch-all index for unknown restypes.

restypes_with_x = restypes + ["X"]
restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}


def sequence_to_onehot(
    sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
) -> np.ndarray:
    """Maps the given sequence into a one-hot encoded matrix.

    Args:
      sequence: An amino acid sequence.
      mapping: A dictionary mapping amino acids to integers.
      map_unknown_to_x: If True, any amino acid that is not in the mapping will be
        mapped to the unknown amino acid 'X'. If the mapping doesn't contain
        amino acid 'X', an error will be thrown. If False, any amino acid not in
        the mapping will throw an error.

    Returns:
      A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
      the sequence.

    Raises:
      ValueError: If the mapping doesn't contain values from 0 to
        num_unique_aas - 1 without any gaps.
    """
    num_entries = max(mapping.values()) + 1

    if sorted(set(mapping.values())) != list(range(num_entries)):
        raise ValueError(
            "The mapping must have values from 0 to num_unique_aas-1 "
            "without any gaps. Got: %s" % sorted(mapping.values())
        )

    one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)

    for aa_index, aa_type in enumerate(sequence):
        if map_unknown_to_x:
            if aa_type.isalpha() and aa_type.isupper():
                aa_id = mapping.get(aa_type, mapping["X"])
            else:
                raise ValueError(f"Invalid character in the sequence: {aa_type}")
        else:
            aa_id = mapping[aa_type]
        one_hot_arr[aa_index, aa_id] = 1

    return one_hot_arr
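
# --- Illustrative sketch (not part of the original file): one-hot encoding a
# short sequence with the mappings defined above. With restype_order_with_x
# and map_unknown_to_x=True, an unmapped letter such as 'B' falls into the
# 'X' column instead of raising.
# onehot = sequence_to_onehot("ACDG", restype_order)  # shape (4, 20)
# onehot_x = sequence_to_onehot(
#     "ACDB", restype_order_with_x, map_unknown_to_x=True
# )  # shape (4, 21)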


restype_1to3 = {
    "A": "ALA",
    "R": "ARG",
    "N": "ASN",
    "D": "ASP",
    "C": "CYS",
    "Q": "GLN",
    "E": "GLU",
    "G": "GLY",
    "H": "HIS",
    "I": "ILE",
    "L": "LEU",
    "K": "LYS",
    "M": "MET",
    "F": "PHE",
    "P": "PRO",
    "S": "SER",
    "T": "THR",
    "W": "TRP",
    "Y": "TYR",
    "V": "VAL",
}


# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
restype_3to1 = {v: k for k, v in restype_1to3.items()}

# Define a restype name for all unknown residues.
unk_restype = "UNK"

resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}


# Define exploded all-atom representation (atom73)
atom73_names = ['N', 'CA', 'C', 'CB', 'O']
for aa1 in restypes:
    aa3 = restype_1to3[aa1]
    atom_list = residue_atoms[aa3]
    for atom in atom_types:
        if atom in atom_list and atom not in atom73_names:
            atom73_names.append(f'{aa1}{atom}')

atom73_names_to_idx = {a: i for i, a in enumerate(atom73_names)}

restype_atom73_mask = np.zeros((22, 73))
for i, restype in enumerate(restypes):
    for atom_name in atom_types:
        atom73_name = atom_name
        if atom_name not in ['N', 'CA', 'C', 'CB', 'O']:
            atom73_name = restype + atom_name
        if atom73_name in atom73_names_to_idx:
            atom73_idx = atom73_names_to_idx[atom73_name]
            restype_atom73_mask[i, atom73_idx] = 1
# Remove CB for glycine
restype_atom73_mask[restype_order["G"], 3] = 0
# Backbone atoms for unk and mask
restype_atom73_mask[-2:, [0, 1, 2, 4]] = 1
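
# --- Illustrative sketch (not part of the original file): the atom73 naming
# convention built above. The five shared backbone/CB names keep their plain
# atom names; every other atom is keyed by one-letter restype + atom name,
# e.g. arginine's CG is stored as "RCG".
# assert atom73_names_to_idx["CA"] == 1
# assert "RCG" in atom73_names_to_idx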


# The mapping here uses hhblits convention, so that B is mapped to D, J and O
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
# remaining 20 amino acids are kept in alphabetical order.
# There are 2 non-amino acid codes, X (representing any amino acid) and
# "-" representing a missing amino acid in an alignment. The id for these
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
HHBLITS_AA_TO_ID = {
    "A": 0,
    "B": 2,
    "C": 1,
    "D": 2,
    "E": 3,
    "F": 4,
    "G": 5,
    "H": 6,
    "I": 7,
    "J": 20,
    "K": 8,
    "L": 9,
    "M": 10,
    "N": 11,
    "O": 20,
    "P": 12,
    "Q": 13,
    "R": 14,
    "S": 15,
    "T": 16,
    "U": 1,
    "V": 17,
    "W": 18,
    "X": 20,
    "Y": 19,
    "Z": 3,
    "-": 21,
}

# Partial inversion of HHBLITS_AA_TO_ID.
ID_TO_HHBLITS_AA = {
    0: "A",
    1: "C",  # Also U.
    2: "D",  # Also B.
    3: "E",  # Also Z.
    4: "F",
    5: "G",
    6: "H",
    7: "I",
    8: "K",
    9: "L",
    10: "M",
    11: "N",
    12: "P",
    13: "Q",
    14: "R",
    15: "S",
    16: "T",
    17: "V",
    18: "W",
    19: "Y",
    20: "X",  # Includes J and O.
    21: "-",
}

restypes_with_x_and_gap = restypes + ["X", "-"]
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
    restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
    for i in range(len(restypes_with_x_and_gap))
)


def _make_standard_atom_mask() -> np.ndarray:
    """Returns [num_res_types, num_atom_types] mask array."""
    # +1 to account for unknown (all 0s).
    mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
    for restype, restype_letter in enumerate(restypes):
        restype_name = restype_1to3[restype_letter]
        atom_names = residue_atoms[restype_name]
        for atom_name in atom_names:
            atom_type = atom_order[atom_name]
            mask[restype, atom_type] = 1
    return mask


STANDARD_ATOM_MASK = _make_standard_atom_mask()


# A one hot representation for the first and second atoms defining the axis
# of rotation for each chi-angle in each residue.
def chi_angle_atom(atom_index: int) -> np.ndarray:
    """Define chi-angle rigid groups via one-hot representations."""
    chi_angles_index = {}
    one_hots = []

    for k, v in chi_angles_atoms.items():
        indices = [atom_types.index(s[atom_index]) for s in v]
        indices.extend([-1] * (4 - len(indices)))
        chi_angles_index[k] = indices

    for r in restypes:
        res3 = restype_1to3[r]
        one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
        one_hots.append(one_hot)

    one_hots.append(np.zeros([4, atom_type_num]))  # Add zeros for residue `X`.
    one_hot = np.stack(one_hots, axis=0)
    one_hot = np.transpose(one_hot, [0, 2, 1])

    return one_hot


chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)

# An array like chi_angles_atoms but using indices rather than names.
chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
chi_angles_atom_indices = tree.map_structure(
    lambda atom_name: atom_order[atom_name], chi_angles_atom_indices
)
chi_angles_atom_indices = np.array(
    [
        chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
        for chi_atoms in chi_angles_atom_indices
    ]
)

# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# and atom index within that group.
chi_groups_for_atom = collections.defaultdict(list)
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
    for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
        for atom_i, atom in enumerate(chi_group):
            chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
chi_groups_for_atom = dict(chi_groups_for_atom)


def _make_rigid_transformation_4x4(ex, ey, translation):
    """Create a rigid 4x4 transformation matrix from two axes and transl."""
    # Normalize ex.
    ex_normalized = ex / np.linalg.norm(ex)

    # make ey perpendicular to ex
    ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
    ey_normalized /= np.linalg.norm(ey_normalized)

    # compute ez as cross product
    eznorm = np.cross(ex_normalized, ey_normalized)
    m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
    m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
    return m
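
# --- Illustrative sketch (not part of the original file): the 4x4 returned by
# _make_rigid_transformation_4x4 has ex/ey/ez as its rotation columns and the
# translation as its fourth column, so it sends the local origin to
# `translation` in the parent frame.
# frame = _make_rigid_transformation_4x4(
#     ex=np.array([0.0, 0.0, 1.0]),
#     ey=np.array([0.0, 1.0, 0.0]),
#     translation=np.array([1.0, 0.0, 0.0]),
# )
# print(frame @ np.array([0.0, 0.0, 0.0, 1.0]))  # -> [1. 0. 0. 1.]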


# create an array with (restype, atomtype) --> rigid_group_idx
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)


def _make_rigid_group_constants():
    """Fill the arrays above."""
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]:
            atomtype = atom_order[atomname]
            restype_atom37_to_rigid_group[restype, atomtype] = group_idx
            restype_atom37_mask[restype, atomtype] = 1
            restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position

            atom14idx = restype_name_to_atom14_names[resname].index(atomname)
            restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
            restype_atom14_mask[restype, atom14idx] = 1
            restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position

    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        atom_positions = {
            name: np.array(pos) for name, _, pos in rigid_group_atom_positions[resname]
        }

        # backbone to backbone is the identity transform
        restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)

        # pre-omega-frame to backbone (currently dummy identity matrix)
        restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)

        # phi-frame to backbone
        mat = _make_rigid_transformation_4x4(
            ex=atom_positions["N"] - atom_positions["CA"],
            ey=np.array([1.0, 0.0, 0.0]),
            translation=atom_positions["N"],
        )
        restype_rigid_group_default_frame[restype, 2, :, :] = mat

        # psi-frame to backbone
        mat = _make_rigid_transformation_4x4(
            ex=atom_positions["C"] - atom_positions["CA"],
            ey=atom_positions["CA"] - atom_positions["N"],
            translation=atom_positions["C"],
        )
        restype_rigid_group_default_frame[restype, 3, :, :] = mat

        # chi1-frame to backbone
        if chi_angles_mask[restype][0]:
            base_atom_names = chi_angles_atoms[resname][0]
            base_atom_positions = [atom_positions[name] for name in base_atom_names]
            mat = _make_rigid_transformation_4x4(
                ex=base_atom_positions[2] - base_atom_positions[1],
                ey=base_atom_positions[0] - base_atom_positions[1],
                translation=base_atom_positions[2],
            )
            restype_rigid_group_default_frame[restype, 4, :, :] = mat

        # chi2-frame to chi1-frame
        # chi3-frame to chi2-frame
        # chi4-frame to chi3-frame
        # luckily all rotation axes for the next frame start at (0,0,0) of the
        # previous frame
        for chi_idx in range(1, 4):
            if chi_angles_mask[restype][chi_idx]:
                axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
                axis_end_atom_position = atom_positions[axis_end_atom_name]
                mat = _make_rigid_transformation_4x4(
                    ex=axis_end_atom_position,
                    ey=np.array([-1.0, 0.0, 0.0]),
                    translation=axis_end_atom_position,
                )
                restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat


_make_rigid_group_constants()


def make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=15):
    """compute upper and lower bounds for bonds to assess violations."""
    restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
    restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
    restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
    residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        atom_list = restype_name_to_atom14_names[resname]

        # create lower and upper bounds for clashes
        for atom1_idx, atom1_name in enumerate(atom_list):
            if not atom1_name:
                continue
            atom1_radius = van_der_waals_radius[atom1_name[0]]
            for atom2_idx, atom2_name in enumerate(atom_list):
                if (not atom2_name) or atom1_idx == atom2_idx:
                    continue
                atom2_radius = van_der_waals_radius[atom2_name[0]]
                lower = atom1_radius + atom2_radius - overlap_tolerance
                upper = 1e10
                restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
                restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
                restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
                restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper

        # overwrite lower and upper bounds for bonds and angles
        for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
            atom1_idx = atom_list.index(b.atom1_name)
            atom2_idx = atom_list.index(b.atom2_name)
            lower = b.length - bond_length_tolerance_factor * b.stddev
            upper = b.length + bond_length_tolerance_factor * b.stddev
            restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
            restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
            restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
            restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
            restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
            restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
    return {
        "lower_bound": restype_atom14_bond_lower_bound,  # shape (21,14,14)
        "upper_bound": restype_atom14_bond_upper_bound,  # shape (21,14,14)
        "stddev": restype_atom14_bond_stddev,  # shape (21,14,14)
    }
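
# --- Illustrative sketch (not part of the original file): querying the bound
# tables. Restype 0 is ALA, whose atom14 slots 0 and 1 are N and CA, so the
# entries below bracket the literature N-CA length of ~1.459 Angstroem with
# +/- 15 stddev by default.
# bounds = make_atom14_dists_bounds()
# lower = bounds["lower_bound"][0, 0, 1]  # ~1.159
# upper = bounds["upper_bound"][0, 0, 1]  # ~1.759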


standard_residue_bonds, _, standard_residue_bond_angles = load_stereo_chemical_props()
core/stereo_chemical_props.txt
ADDED
@@ -0,0 +1,345 @@
Bond Residue Mean StdDev
CA-CB ALA 1.520 0.021
N-CA ALA 1.459 0.020
CA-C ALA 1.525 0.026
C-O ALA 1.229 0.019
CA-CB ARG 1.535 0.022
CB-CG ARG 1.521 0.027
CG-CD ARG 1.515 0.025
CD-NE ARG 1.460 0.017
NE-CZ ARG 1.326 0.013
CZ-NH1 ARG 1.326 0.013
CZ-NH2 ARG 1.326 0.013
N-CA ARG 1.459 0.020
CA-C ARG 1.525 0.026
C-O ARG 1.229 0.019
CA-CB ASN 1.527 0.026
CB-CG ASN 1.506 0.023
CG-OD1 ASN 1.235 0.022
CG-ND2 ASN 1.324 0.025
N-CA ASN 1.459 0.020
CA-C ASN 1.525 0.026
C-O ASN 1.229 0.019
CA-CB ASP 1.535 0.022
CB-CG ASP 1.513 0.021
CG-OD1 ASP 1.249 0.023
CG-OD2 ASP 1.249 0.023
N-CA ASP 1.459 0.020
CA-C ASP 1.525 0.026
C-O ASP 1.229 0.019
CA-CB CYS 1.526 0.013
CB-SG CYS 1.812 0.016
N-CA CYS 1.459 0.020
CA-C CYS 1.525 0.026
C-O CYS 1.229 0.019
CA-CB GLU 1.535 0.022
CB-CG GLU 1.517 0.019
CG-CD GLU 1.515 0.015
CD-OE1 GLU 1.252 0.011
CD-OE2 GLU 1.252 0.011
N-CA GLU 1.459 0.020
CA-C GLU 1.525 0.026
C-O GLU 1.229 0.019
CA-CB GLN 1.535 0.022
CB-CG GLN 1.521 0.027
CG-CD GLN 1.506 0.023
CD-OE1 GLN 1.235 0.022
CD-NE2 GLN 1.324 0.025
N-CA GLN 1.459 0.020
CA-C GLN 1.525 0.026
C-O GLN 1.229 0.019
N-CA GLY 1.456 0.015
CA-C GLY 1.514 0.016
C-O GLY 1.232 0.016
CA-CB HIS 1.535 0.022
CB-CG HIS 1.492 0.016
CG-ND1 HIS 1.369 0.015
CG-CD2 HIS 1.353 0.017
ND1-CE1 HIS 1.343 0.025
CD2-NE2 HIS 1.415 0.021
CE1-NE2 HIS 1.322 0.023
N-CA HIS 1.459 0.020
CA-C HIS 1.525 0.026
C-O HIS 1.229 0.019
CA-CB ILE 1.544 0.023
CB-CG1 ILE 1.536 0.028
CB-CG2 ILE 1.524 0.031
CG1-CD1 ILE 1.500 0.069
N-CA ILE 1.459 0.020
CA-C ILE 1.525 0.026
C-O ILE 1.229 0.019
CA-CB LEU 1.533 0.023
CB-CG LEU 1.521 0.029
CG-CD1 LEU 1.514 0.037
CG-CD2 LEU 1.514 0.037
N-CA LEU 1.459 0.020
CA-C LEU 1.525 0.026
C-O LEU 1.229 0.019
CA-CB LYS 1.535 0.022
CB-CG LYS 1.521 0.027
CG-CD LYS 1.520 0.034
CD-CE LYS 1.508 0.025
CE-NZ LYS 1.486 0.025
N-CA LYS 1.459 0.020
CA-C LYS 1.525 0.026
C-O LYS 1.229 0.019
CA-CB MET 1.535 0.022
CB-CG MET 1.509 0.032
CG-SD MET 1.807 0.026
SD-CE MET 1.774 0.056
N-CA MET 1.459 0.020
CA-C MET 1.525 0.026
C-O MET 1.229 0.019
CA-CB PHE 1.535 0.022
CB-CG PHE 1.509 0.017
CG-CD1 PHE 1.383 0.015
CG-CD2 PHE 1.383 0.015
CD1-CE1 PHE 1.388 0.020
CD2-CE2 PHE 1.388 0.020
CE1-CZ PHE 1.369 0.019
CE2-CZ PHE 1.369 0.019
N-CA PHE 1.459 0.020
CA-C PHE 1.525 0.026
C-O PHE 1.229 0.019
CA-CB PRO 1.531 0.020
CB-CG PRO 1.495 0.050
CG-CD PRO 1.502 0.033
CD-N PRO 1.474 0.014
N-CA PRO 1.468 0.017
CA-C PRO 1.524 0.020
C-O PRO 1.228 0.020
CA-CB SER 1.525 0.015
CB-OG SER 1.418 0.013
N-CA SER 1.459 0.020
CA-C SER 1.525 0.026
C-O SER 1.229 0.019
CA-CB THR 1.529 0.026
CB-OG1 THR 1.428 0.020
CB-CG2 THR 1.519 0.033
N-CA THR 1.459 0.020
CA-C THR 1.525 0.026
C-O THR 1.229 0.019
CA-CB TRP 1.535 0.022
CB-CG TRP 1.498 0.018
CG-CD1 TRP 1.363 0.014
CG-CD2 TRP 1.432 0.017
CD1-NE1 TRP 1.375 0.017
NE1-CE2 TRP 1.371 0.013
CD2-CE2 TRP 1.409 0.012
CD2-CE3 TRP 1.399 0.015
CE2-CZ2 TRP 1.393 0.017
CE3-CZ3 TRP 1.380 0.017
CZ2-CH2 TRP 1.369 0.019
CZ3-CH2 TRP 1.396 0.016
N-CA TRP 1.459 0.020
CA-C TRP 1.525 0.026
C-O TRP 1.229 0.019
CA-CB TYR 1.535 0.022
CB-CG TYR 1.512 0.015
CG-CD1 TYR 1.387 0.013
CG-CD2 TYR 1.387 0.013
CD1-CE1 TYR 1.389 0.015
CD2-CE2 TYR 1.389 0.015
CE1-CZ TYR 1.381 0.013
CE2-CZ TYR 1.381 0.013
CZ-OH TYR 1.374 0.017
N-CA TYR 1.459 0.020
CA-C TYR 1.525 0.026
C-O TYR 1.229 0.019
CA-CB VAL 1.543 0.021
CB-CG1 VAL 1.524 0.021
CB-CG2 VAL 1.524 0.021
N-CA VAL 1.459 0.020
CA-C VAL 1.525 0.026
C-O VAL 1.229 0.019
-

Angle Residue Mean StdDev
N-CA-CB ALA 110.1 1.4
CB-CA-C ALA 110.1 1.5
N-CA-C ALA 111.0 2.7
CA-C-O ALA 120.1 2.1
N-CA-CB ARG 110.6 1.8
CB-CA-C ARG 110.4 2.0
CA-CB-CG ARG 113.4 2.2
CB-CG-CD ARG 111.6 2.6
CG-CD-NE ARG 111.8 2.1
CD-NE-CZ ARG 123.6 1.4
NE-CZ-NH1 ARG 120.3 0.5
NE-CZ-NH2 ARG 120.3 0.5
NH1-CZ-NH2 ARG 119.4 1.1
N-CA-C ARG 111.0 2.7
CA-C-O ARG 120.1 2.1
N-CA-CB ASN 110.6 1.8
CB-CA-C ASN 110.4 2.0
CA-CB-CG ASN 113.4 2.2
CB-CG-ND2 ASN 116.7 2.4
CB-CG-OD1 ASN 121.6 2.0
ND2-CG-OD1 ASN 121.9 2.3
N-CA-C ASN 111.0 2.7
CA-C-O ASN 120.1 2.1
N-CA-CB ASP 110.6 1.8
CB-CA-C ASP 110.4 2.0
CA-CB-CG ASP 113.4 2.2
CB-CG-OD1 ASP 118.3 0.9
CB-CG-OD2 ASP 118.3 0.9
OD1-CG-OD2 ASP 123.3 1.9
N-CA-C ASP 111.0 2.7
CA-C-O ASP 120.1 2.1
N-CA-CB CYS 110.8 1.5
CB-CA-C CYS 111.5 1.2
CA-CB-SG CYS 114.2 1.1
N-CA-C CYS 111.0 2.7
CA-C-O CYS 120.1 2.1
N-CA-CB GLU 110.6 1.8
CB-CA-C GLU 110.4 2.0
CA-CB-CG GLU 113.4 2.2
CB-CG-CD GLU 114.2 2.7
CG-CD-OE1 GLU 118.3 2.0
CG-CD-OE2 GLU 118.3 2.0
OE1-CD-OE2 GLU 123.3 1.2
N-CA-C GLU 111.0 2.7
CA-C-O GLU 120.1 2.1
N-CA-CB GLN 110.6 1.8
CB-CA-C GLN 110.4 2.0
CA-CB-CG GLN 113.4 2.2
CB-CG-CD GLN 111.6 2.6
CG-CD-OE1 GLN 121.6 2.0
CG-CD-NE2 GLN 116.7 2.4
OE1-CD-NE2 GLN 121.9 2.3
N-CA-C GLN 111.0 2.7
CA-C-O GLN 120.1 2.1
N-CA-C GLY 113.1 2.5
CA-C-O GLY 120.6 1.8
N-CA-CB HIS 110.6 1.8
CB-CA-C HIS 110.4 2.0
CA-CB-CG HIS 113.6 1.7
CB-CG-ND1 HIS 123.2 2.5
CB-CG-CD2 HIS 130.8 3.1
CG-ND1-CE1 HIS 108.2 1.4
ND1-CE1-NE2 HIS 109.9 2.2
CE1-NE2-CD2 HIS 106.6 2.5
NE2-CD2-CG HIS 109.2 1.9
CD2-CG-ND1 HIS 106.0 1.4
N-CA-C HIS 111.0 2.7
CA-C-O HIS 120.1 2.1
N-CA-CB ILE 110.8 2.3
CB-CA-C ILE 111.6 2.0
CA-CB-CG1 ILE 111.0 1.9
CB-CG1-CD1 ILE 113.9 2.8
CA-CB-CG2 ILE 110.9 2.0
CG1-CB-CG2 ILE 111.4 2.2
N-CA-C ILE 111.0 2.7
CA-C-O ILE 120.1 2.1
N-CA-CB LEU 110.4 2.0
CB-CA-C LEU 110.2 1.9
CA-CB-CG LEU 115.3 2.3
CB-CG-CD1 LEU 111.0 1.7
CB-CG-CD2 LEU 111.0 1.7
CD1-CG-CD2 LEU 110.5 3.0
N-CA-C LEU 111.0 2.7
CA-C-O LEU 120.1 2.1
N-CA-CB LYS 110.6 1.8
CB-CA-C LYS 110.4 2.0
CA-CB-CG LYS 113.4 2.2
CB-CG-CD LYS 111.6 2.6
CG-CD-CE LYS 111.9 3.0
CD-CE-NZ LYS 111.7 2.3
N-CA-C LYS 111.0 2.7
CA-C-O LYS 120.1 2.1
N-CA-CB MET 110.6 1.8
CB-CA-C MET 110.4 2.0
CA-CB-CG MET 113.3 1.7
CB-CG-SD MET 112.4 3.0
CG-SD-CE MET 100.2 1.6
N-CA-C MET 111.0 2.7
CA-C-O MET 120.1 2.1
N-CA-CB PHE 110.6 1.8
CB-CA-C PHE 110.4 2.0
CA-CB-CG PHE 113.9 2.4
CB-CG-CD1 PHE 120.8 0.7
CB-CG-CD2 PHE 120.8 0.7
CD1-CG-CD2 PHE 118.3 1.3
CG-CD1-CE1 PHE 120.8 1.1
CG-CD2-CE2 PHE 120.8 1.1
CD1-CE1-CZ PHE 120.1 1.2
CD2-CE2-CZ PHE 120.1 1.2
CE1-CZ-CE2 PHE 120.0 1.8
N-CA-C PHE 111.0 2.7
CA-C-O PHE 120.1 2.1
N-CA-CB PRO 103.3 1.2
CB-CA-C PRO 111.7 2.1
CA-CB-CG PRO 104.8 1.9
CB-CG-CD PRO 106.5 3.9
CG-CD-N PRO 103.2 1.5
CA-N-CD PRO 111.7 1.4
N-CA-C PRO 112.1 2.6
CA-C-O PRO 120.2 2.4
N-CA-CB SER 110.5 1.5
CB-CA-C SER 110.1 1.9
CA-CB-OG SER 111.2 2.7
N-CA-C SER 111.0 2.7
CA-C-O SER 120.1 2.1
N-CA-CB THR 110.3 1.9
CB-CA-C THR 111.6 2.7
CA-CB-OG1 THR 109.0 2.1
CA-CB-CG2 THR 112.4 1.4
OG1-CB-CG2 THR 110.0 2.3
N-CA-C THR 111.0 2.7
CA-C-O THR 120.1 2.1
N-CA-CB TRP 110.6 1.8
CB-CA-C TRP 110.4 2.0
CA-CB-CG TRP 113.7 1.9
CB-CG-CD1 TRP 127.0 1.3
CB-CG-CD2 TRP 126.6 1.3
CD1-CG-CD2 TRP 106.3 0.8
CG-CD1-NE1 TRP 110.1 1.0
CD1-NE1-CE2 TRP 109.0 0.9
NE1-CE2-CD2 TRP 107.3 1.0
CE2-CD2-CG TRP 107.3 0.8
CG-CD2-CE3 TRP 133.9 0.9
NE1-CE2-CZ2 TRP 130.4 1.1
CE3-CD2-CE2 TRP 118.7 1.2
CD2-CE2-CZ2 TRP 122.3 1.2
CE2-CZ2-CH2 TRP 117.4 1.0
CZ2-CH2-CZ3 TRP 121.6 1.2
CH2-CZ3-CE3 TRP 121.2 1.1
CZ3-CE3-CD2 TRP 118.8 1.3
N-CA-C TRP 111.0 2.7
CA-C-O TRP 120.1 2.1
N-CA-CB TYR 110.6 1.8
CB-CA-C TYR 110.4 2.0
CA-CB-CG TYR 113.4 1.9
CB-CG-CD1 TYR 121.0 0.6
CB-CG-CD2 TYR 121.0 0.6
CD1-CG-CD2 TYR 117.9 1.1
CG-CD1-CE1 TYR 121.3 0.8
CG-CD2-CE2 TYR 121.3 0.8
CD1-CE1-CZ TYR 119.8 0.9
CD2-CE2-CZ TYR 119.8 0.9
CE1-CZ-CE2 TYR 119.8 1.6
CE1-CZ-OH TYR 120.1 2.7
CE2-CZ-OH TYR 120.1 2.7
N-CA-C TYR 111.0 2.7
CA-C-O TYR 120.1 2.1
N-CA-CB VAL 111.5 2.2
CB-CA-C VAL 111.4 1.9
CA-CB-CG1 VAL 110.9 1.5
CA-CB-CG2 VAL 110.9 1.5
CG1-CB-CG2 VAL 110.9 1.6
N-CA-C VAL 111.0 2.7
CA-C-O VAL 120.1 2.1
-

Non-bonded distance Minimum Dist Tolerance
C-C 3.4 1.5
C-N 3.25 1.5
C-S 3.5 1.5
C-O 3.22 1.5
N-N 3.1 1.5
N-S 3.35 1.5
N-O 3.07 1.5
O-S 3.32 1.5
O-O 3.04 1.5
S-S 2.03 1.0
-
core/utils.py
ADDED
@@ -0,0 +1,1062 @@
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu

Various utils for handling protein data.
"""

import os
import shlex
import subprocess
import sys
import yaml
import argparse

from einops import rearrange, repeat
import numpy as np
import torch
import torch.nn.functional as F
import Bio
from Bio.PDB.DSSP import DSSP

from core import protein
from core import protein_mpnn
from core import residue_constants


PATH_TO_TMALIGN = "/home/alexechu/essentials_kit/ml_utils/align/TMalign/TMalign"


################ STRUCTURE/FORMAT UTILS #############################


def aatype_to_seq(aatype, seq_mask=None):
    if seq_mask is None:
        seq_mask = torch.ones_like(aatype)

    mapping = residue_constants.restypes_with_x
    mapping = mapping + ["<mask>"]

    unbatched = False
    if len(aatype.shape) == 1:
        unbatched = True
        aatype = [aatype]
        seq_mask = [seq_mask]

    seqs = []
    for i, ai in enumerate(aatype):
        seq = []
        for j, aa in enumerate(ai):
            if seq_mask[i][j] == 1:
                try:
                    seq.append(mapping[aa])
                except IndexError:
                    print(aatype[i])
                    raise Exception(f"Error in mapping {aa} at {i},{j}")
        seqs.append("".join(seq))

    if unbatched:
        seqs = seqs[0]
    return seqs


def seq_to_aatype(seq, num_tokens=21):
    if num_tokens == 20:
        mapping = residue_constants.restype_order
    if num_tokens == 21:
        mapping = residue_constants.restype_order_with_x
    if num_tokens == 22:
        mapping = residue_constants.restype_order_with_x
        mapping["<mask>"] = 21
    return torch.Tensor([mapping[aa] for aa in seq]).long()


def batched_seq_to_aatype_and_mask(seqs, max_len=None):
    if max_len is None:
        max_len = max([len(s) for s in seqs])
    aatypes = []
    seq_mask = []
    for s in seqs:
        pad_size = max_len - len(s)
        aatype = seq_to_aatype(s)
        aatypes.append(F.pad(aatype, (0, pad_size)))
        mask = torch.ones_like(aatype).float()
        seq_mask.append(F.pad(mask, (0, pad_size)))
    return torch.stack(aatypes), torch.stack(seq_mask)

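# Usage sketch (illustrative, not in the original file; values assume the
# AlphaFold-style residue ordering used by residue_constants):
#   seq_to_aatype("ACDEFG")                 # -> tensor([0, 4, 3, 6, 13, 7])
#   aatype_to_seq(seq_to_aatype("ACDEFG"))  # -> "ACDEFG"
#   batched_seq_to_aatype_and_mask(["ACD", "ACDEFG"])  # pads both to length 6
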
def atom37_mask_from_aatype(aatype, seq_mask=None):
    # source_mask is (21, 37) originally
    source_mask = torch.Tensor(residue_constants.restype_atom37_mask).to(aatype.device)
    bb_atoms = source_mask[residue_constants.restype_order["G"]][None]
    # Use only the first 20 plus bb atoms for X, mask
    source_mask = torch.cat([source_mask[:-1], bb_atoms, bb_atoms], 0)
    atom_mask = source_mask[aatype]
    if seq_mask is not None:
        atom_mask *= seq_mask[..., None]
    return atom_mask


def atom37_coords_from_atom14(atom14_coords, aatype, return_mask=False):
    # Unbatched
    device = atom14_coords.device
    atom37_coords = torch.zeros((atom14_coords.shape[0], 37, 3)).to(device)
    for i in range(atom14_coords.shape[0]):  # per residue
        aa = aatype[i]
        aa_3name = residue_constants.restype_1to3[residue_constants.restypes[aa]]
        atom14_atoms = residue_constants.restype_name_to_atom14_names[aa_3name]
        for j in range(14):
            atom_name = atom14_atoms[j]
            if atom_name != "":
                atom37_idx = residue_constants.atom_order[atom_name]
                atom37_coords[i, atom37_idx, :] = atom14_coords[i, j, :]

    if return_mask:
        atom37_mask = atom37_mask_from_aatype(aatype)
        return atom37_coords, atom37_mask
    return atom37_coords


def atom73_mask_from_aatype(aatype, seq_mask=None):
    source_mask = torch.Tensor(residue_constants.restype_atom73_mask).to(aatype.device)
    atom_mask = source_mask[aatype]
    if seq_mask is not None:
        atom_mask *= seq_mask[..., None]
    return atom_mask


def atom37_to_atom73(atom37, aatype, return_mask=False):
    # Unbatched
    atom73 = torch.zeros((atom37.shape[0], 73, 3)).to(atom37)
    for i in range(atom37.shape[0]):
        aa = aatype[i]
        aa1 = residue_constants.restypes[aa]
        for j, atom37_name in enumerate(residue_constants.atom_types):
            atom73_name = atom37_name
            if atom37_name not in ["N", "CA", "C", "O", "CB"]:
                atom73_name = aa1 + atom73_name
            if atom73_name in residue_constants.atom73_names_to_idx:
                atom73_idx = residue_constants.atom73_names_to_idx[atom73_name]
                atom73[i, atom73_idx, :] = atom37[i, j, :]

    if return_mask:
        atom73_mask = atom73_mask_from_aatype(aatype)
        return atom73, atom73_mask
    return atom73


def atom73_to_atom37(atom73, aatype, return_mask=False):
    # Unbatched
    atom37_coords = torch.zeros((atom73.shape[0], 37, 3)).to(atom73)
    for i in range(atom73.shape[0]):  # per residue
        aa = aatype[i]
        aa1 = residue_constants.restypes[aa]
        for j, atom_type in enumerate(residue_constants.atom_types):
            atom73_name = atom_type
            if atom73_name not in ["N", "CA", "C", "O", "CB"]:
                atom73_name = aa1 + atom73_name
            if atom73_name in residue_constants.atom73_names_to_idx:
                atom73_idx = residue_constants.atom73_names_to_idx[atom73_name]
                atom37_coords[i, j, :] = atom73[i, atom73_idx, :]

    if return_mask:
        atom37_mask = atom37_mask_from_aatype(aatype)
        return atom37_coords, atom37_mask
    return atom37_coords

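# Conversion sketch (illustrative, not in the original file): atom14/atom37/
# atom73 are alternative dense atom layouts for the same residues.
#   atom37, mask37 = atom37_coords_from_atom14(atom14, aatype, return_mask=True)
#   atom73, mask73 = atom37_to_atom73(atom37, aatype, return_mask=True)
#   atom37_back = atom73_to_atom37(atom73, aatype)
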
def get_dmap(pdb, atoms=["N", "CA", "C", "O"], batched=True, out="torch", device=None):
    def _dmap_from_coords(coords):
        coords = coords.contiguous()
        dmaps = torch.cdist(coords, coords).unsqueeze(1)
        if out == "numpy":
            return dmaps.detach().cpu().numpy()
        elif out == "torch":
            if device is not None:
                return dmaps.to(device)
            else:
                return dmaps

    if isinstance(pdb, str):  # input is pdb file
        coords = load_coords_from_pdb(pdb, atoms=atoms).view(1, -1, 3)
        return _dmap_from_coords(coords)
    elif len(pdb.shape) == 2:  # single set of coords
        if isinstance(pdb, np.ndarray):
            pdb = torch.Tensor(pdb)
        return _dmap_from_coords(pdb.unsqueeze(0))
    elif len(pdb.shape) == 3 and batched:
        return _dmap_from_coords(pdb)
    elif len(pdb.shape) == 3 and not batched:
        return _dmap_from_coords(pdb.view(1, -1, 3))
    elif len(pdb.shape) == 4:
        return _dmap_from_coords(pdb.view(pdb.size(0), -1, 3))


def get_channeled_dmap(coords):
    # coords is b, nres, natom, 3
    coords = coords.permute(0, 2, 1, 3)
    dvecs = coords[..., None, :] - coords[..., None, :, :]  # b, natom, nres, nres, 3
    dists = torch.sqrt(dvecs.pow(2).sum(-1) + 1e-8)
    return dists


def fill_in_cbeta_for_atom37(coords):
    b = coords[..., 1, :] - coords[..., 0, :]
    c = coords[..., 2, :] - coords[..., 1, :]
    a = torch.cross(b, c, dim=-1)
    cbeta = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + coords[..., 1, :]
    new_coords = torch.clone(coords)
    new_coords[..., 3, :] = cbeta
    return new_coords


def get_distogram(coords, n_bins=20, start=2, return_onehot=True, seq_mask=None):
    # coords is b, nres, natom, 3
    # distogram for cb atom (assume 3rd atom)
    coords_with_cb = fill_in_cbeta_for_atom37(coords)
    dists = get_channeled_dmap(coords_with_cb[:, :, 3:4]).squeeze(1)
    bins = torch.arange(start, start + n_bins - 1).to(dists.device)
    dgram = torch.bucketize(dists, bins)
    dgram_oh = F.one_hot(dgram, n_bins)
    if seq_mask is not None:
        mask_2d = seq_mask[:, :, None] * seq_mask[:, None, :]
        dgram = dgram * mask_2d
        dgram_oh = dgram_oh * mask_2d[..., None]

    if return_onehot:
        return dgram_oh
    return dgram


def get_contacts(coords=None, distogram=None, seq_mask=None):
    if distogram is None:
        distogram = get_distogram(coords)
    contacts = (distogram.argmax(-1) < 6).float()
    if seq_mask is not None:
        contacts *= seq_mask[..., None] * seq_mask[..., None, :]
    return contacts


def dihedral(a, b, c, d):
    # inputs can be (1,3), (n,3), or (bs,n,3)
    b1 = a - b
    b2 = b - c
    b3 = c - d
    n1 = F.normalize(torch.cross(b1, b2), dim=-1)
    n2 = F.normalize(torch.cross(b2, b3), dim=-1)
    m1 = torch.cross(n1, b2 / b2.norm(dim=-1).unsqueeze(-1))
    y = (m1 * n2).sum(dim=-1)
    x = (n1 * n2).sum(dim=-1)
    return torch.atan2(y, x)

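# Illustrative note (not in the original file): for backbone torsions, phi at
# residue i is the dihedral over C(i-1), N(i), CA(i), C(i), e.g.
#   phi_deg = dihedral(c_prev, n, ca, c) * 180 / np.pi
# get_torsions_from_coords below applies this over whole backbones.
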
def get_torsions_from_coords(
    coords, atoms=["N", "CA", "C", "O"], batched=True, out="torch", device=None
):
    """
    Returns a n-dim array of shape (bs, nres, ntors), where ntors is the
    number of torsion angles (e.g. 2 if using phi and psi), with units of radians.
    """
    if isinstance(coords, np.ndarray):
        coords = torch.Tensor(coords)
    if len(coords.shape) == 2:
        coords = coords.unsqueeze(0)
    if len(coords.shape) == 4:
        coords = coords.view(coords.size(0), -1, 3)
    if len(coords.shape) == 3 and not batched:
        coords = coords.view(1, -1, 3)
    if len(coords.shape) == 3:
        bs = coords.size(0)
        if "O" in atoms:
            idxs = [
                i for i in range(coords.size(1)) if i % 4 != 3
            ]  # deselect O atoms for N-Ca-C-O coords
            coords = coords[:, idxs, :]
        a, b, c, d = (
            coords[:, :-3, :],
            coords[:, 1:-2, :],
            coords[:, 2:-1, :],
            coords[:, 3:, :],
        )
        torsions = dihedral(
            a, b, c, d
        )  # output order is psi-omega-phi, reorganize to (bs, nres, 3)
        torsions = torsions.view(bs, torsions.size(1) // 3, 3)
        omegaphi = torch.cat(
            (torch.zeros(bs, 1, 2).to(coords.device), torsions[:, :, 1:]), 1
        )
        psi = torch.cat((torsions[:, :, 0], torch.zeros(bs, 1).to(coords.device)), 1)
        torsions = torch.cat(
            (
                omegaphi[:, :, 1].unsqueeze(-1),
                psi.unsqueeze(-1),
                omegaphi[:, :, 0].unsqueeze(-1),
            ),
            -1,
        )
    else:
        raise Exception("input coords not of correct dims")

    if out == "numpy":
        return torsions.detach().cpu().numpy()
    elif out == "torch":
        if device is not None:
            return torsions.to(device)
        else:
            return torsions


def get_trig_from_torsions(torsions, out="torch", device=None):
    """
    Calculate unit circle projections from coords input.

    Returns a n-dim array of shape (bs, nres, ntors, 2), where ntors is the
    number of torsion angles (e.g. 2 if using phi and psi), and the last
    dimension is the xy unit-circle coordinates of the corresponding angle.
    """
    if isinstance(torsions, np.ndarray):
        torsions = torch.Tensor(torsions)
    x = torsions.cos()
    y = torsions.sin()
    trig = torch.cat((x.unsqueeze(-1), y.unsqueeze(-1)), -1)
    if out == "numpy":
        return trig.detach().cpu().numpy()
    elif out == "torch":
        if device is not None:
            return trig.to(device)
        else:
            return trig


def get_abego_string_from_torsions(torsions):
    A_bin = (-75, 50)
    G_bin = (-100, 100)
    torsions = torsions * 180.0 / np.pi
    phi, psi = torsions[:, :, 0], torsions[:, :, 1]
    abego_vec = np.zeros((torsions.size(0), torsions.size(1))).astype(str)
    A = (phi <= 0) & (psi <= A_bin[1]) & (psi > A_bin[0])
    B = (phi <= 0) & ((psi > A_bin[1]) | (psi <= A_bin[0]))
    G = (phi > 0) & (psi <= G_bin[1]) & (psi > G_bin[0])
    E = (phi > 0) & ((psi > G_bin[1]) | (psi <= G_bin[0]))
    abego_vec[A] = "A"
    abego_vec[B] = "B"
    abego_vec[G] = "G"
    abego_vec[E] = "E"
    abego_strs = ["".join(v) for v in abego_vec]
    return abego_strs


def get_bond_lengths_from_coords(coords, batched=True, out="torch", device=None):
    """
    Returns array of shape (bs, n_res, 4), where final dim is bond lengths
    in order of N-Ca, Ca-C, C-O, C-N (none for last residue)
    """
    if isinstance(coords, np.ndarray):
        coords = torch.Tensor(coords)
    if len(coords.shape) == 2:
        coords = coords.unsqueeze(0)
    if len(coords.shape) == 3 and not batched:
        coords = coords.view(1, -1, 3)
    if len(coords.shape) == 4:
        coords = coords.view(coords.size(0), -1, 3)
    N = coords[:, ::4, :]
    Ca = coords[:, 1::4, :]
    C = coords[:, 2::4, :]
    O = coords[:, 3::4, :]
    NCa = (Ca - N).norm(dim=-1).unsqueeze(-1)
    CaC = (C - Ca).norm(dim=-1).unsqueeze(-1)
    CO = (O - C).norm(dim=-1).unsqueeze(-1)
    CN = (N[:, 1:] - C[:, :-1]).norm(dim=-1)
    CN = torch.cat([CN, torch.zeros(CN.size(0), 1).to(CN.device)], 1).unsqueeze(-1)
    blengths = torch.cat((NCa, CaC, CO, CN), -1)
    if out == "numpy":
        return blengths.detach().cpu().numpy()
    elif out == "torch":
        if device is not None:
            return blengths.to(device)
        else:
            return blengths


def get_bond_angles_from_coords(coords, batched=True, out="torch", device=None):
    """
    Returns array of shape (bs, n_res, 5), where final dim is bond angles
    in order of N-Ca-C, Ca-C-O, Ca-C-N, O-C-N, C-N-Ca (none for last residue)
    """

    def _angle(v1, v2):
        cos = (v1 * v2).sum(-1) / (v1.norm(dim=-1) * v2.norm(dim=-1))
        return cos.acos()

    if isinstance(coords, np.ndarray):
        coords = torch.Tensor(coords)
    if len(coords.shape) == 2:
        coords = coords.unsqueeze(0)
    if len(coords.shape) == 3 and not batched:
        coords = coords.view(1, -1, 3)
    if len(coords.shape) == 4:
        coords = coords.view(coords.size(0), -1, 3)
    N = coords[:, ::4, :]
    Nnext = coords[:, 4::4, :]
    Ca = coords[:, 1::4, :]
    Canext = coords[:, 5::4, :]
    C = coords[:, 2::4, :]
    O = coords[:, 3::4, :]
    CaN = N - Ca
    CaC = C - Ca
    CCa = Ca - C
    CO = O - C
    CNnext = Nnext - C[:, :-1, :]
    NnextC = -1 * CNnext
    NnextCanext = Canext - Nnext
    NCaC = _angle(CaN, CaC).unsqueeze(-1)
    CaCO = _angle(CCa, CO).unsqueeze(-1)
    CaCN = _angle(CCa[:, :-1], CNnext).unsqueeze(-1)
    CaCN = _extend(CaCN)  # _extend is defined later in this module
    OCN = _angle(CO[:, :-1], CNnext).unsqueeze(-1)
    OCN = _extend(OCN)
    CNCa = _angle(NnextC, NnextCanext).unsqueeze(-1)
    CNCa = _extend(CNCa)
    bangles = torch.cat((NCaC, CaCO, CaCN, OCN, CNCa), -1)
    if out == "numpy":
        return bangles.detach().cpu().numpy()
    elif out == "torch":
        if device is not None:
            return bangles.to(device)
        else:
            return bangles

|
434 |
+
ca_idx = residue_constants.atom_order["CA"] # typically 1
|
435 |
+
cb_idx = residue_constants.atom_order["CB"] # typically 3
|
436 |
+
if seq_mask is None:
|
437 |
+
seq_mask = torch.ones_like(coords)[..., 0, 0]
|
438 |
+
coords = fill_in_cbeta_for_atom37(coords)
|
439 |
+
|
440 |
+
# get 8 closest neighbors by CB
|
441 |
+
neighbor_coords = coords[:, :, cb_idx]
|
442 |
+
|
443 |
+
ca_neighbor_dists, edge_index = protein_mpnn.get_closest_neighbors(
|
444 |
+
neighbor_coords, seq_mask, 9
|
445 |
+
)
|
446 |
+
edge_index = edge_index[..., 1:].contiguous()
|
447 |
+
|
448 |
+
# compute avg CB distance
|
449 |
+
cb_coords = coords[:, :, cb_idx]
|
450 |
+
neighbor_cb = protein_mpnn.gather_nodes(cb_coords, edge_index)
|
451 |
+
avg_cb_dist = (neighbor_cb - cb_coords[..., None, :]).pow(2).sum(-1).sqrt().mean(-1)
|
452 |
+
|
453 |
+
buried_positions_mask = (avg_cb_dist < threshold).float() * seq_mask
|
454 |
+
return buried_positions_mask
|
455 |
+
|
456 |
+
|
457 |
+
def get_fullatom_bond_lengths_from_coords(
|
458 |
+
coords, aatype, atom_mask=None, return_format="per_aa"
|
459 |
+
):
|
460 |
+
# Also return sidechain bond angles. All unbatched. return list of dicts
|
461 |
+
def dist(xyz1, xyz2):
|
462 |
+
return (xyz1 - xyz2).pow(2).sum().sqrt().detach().cpu().item()
|
463 |
+
|
464 |
+
assert aatype.max() <= 19
|
465 |
+
seq = aatype_to_seq(aatype)
|
466 |
+
# residue-wise list of dicts [{'N-CA': a, 'CA-C': b}, {'N-CA': a, 'CA-C': b}]
|
467 |
+
all_bond_lens_by_pos = []
|
468 |
+
# aa-wise dict of dicts of lists {'A': {'N-CA': [a, b, c], 'CA-C': [a, b, c]}}
|
469 |
+
all_bond_lens_by_aa = {aa: {} for aa in residue_constants.restypes}
|
470 |
+
for i, res in enumerate(coords):
|
471 |
+
aa3 = residue_constants.restype_1to3[seq[i]]
|
472 |
+
res_bond_lens = {}
|
473 |
+
for j, atom1 in enumerate(residue_constants.atom_types):
|
474 |
+
for k, atom2 in enumerate(residue_constants.atom_types):
|
475 |
+
if j < k and protein.are_atoms_bonded(aa3, atom1, atom2):
|
476 |
+
if atom_mask is None or (
|
477 |
+
atom_mask[i, j] > 0.5 and atom_mask[i, k] > 0.5
|
478 |
+
):
|
479 |
+
bond_name = f"{atom1}-{atom2}"
|
480 |
+
bond_len = dist(res[j], res[k])
|
481 |
+
res_bond_lens[bond_name] = bond_len
|
482 |
+
all_bond_lens_by_pos.append(res_bond_lens)
|
483 |
+
for key, val in res_bond_lens.items():
|
484 |
+
all_bond_lens_by_aa[seq[i]].setdefault(key, []).append(val)
|
485 |
+
|
486 |
+
if return_format == "per_aa":
|
487 |
+
return all_bond_lens_by_aa
|
488 |
+
elif return_format == "per_position":
|
489 |
+
return all_bond_lens_by_pos
|
490 |
+
|
491 |
+
|
492 |
+
def batched_fullatom_bond_lengths_from_coords(
|
493 |
+
coords, aatype, atom_mask=None, return_format="per_aa"
|
494 |
+
):
|
495 |
+
# Expects trimmed coords (no mask)
|
496 |
+
if return_format == "per_position":
|
497 |
+
batched_bond_lens = []
|
498 |
+
elif return_format == "per_aa":
|
499 |
+
batched_bond_lens = {aa: {} for aa in residue_constants.restypes}
|
500 |
+
for i, c in enumerate(coords):
|
501 |
+
atom_mask_i = None if atom_mask is None else atom_mask[i]
|
502 |
+
bond_lens = get_fullatom_bond_lengths_from_coords(
|
503 |
+
c, aatype[i], atom_mask=atom_mask_i, return_format=return_format
|
504 |
+
)
|
505 |
+
if return_format == "per_position":
|
506 |
+
batched_bond_lens.extend(bond_lens)
|
507 |
+
elif return_format == "per_aa":
|
508 |
+
for aa, d in bond_lens.items():
|
509 |
+
for bond, lengths in d.items():
|
510 |
+
batched_bond_lens[aa].setdefault(bond, []).extend(lengths)
|
511 |
+
return batched_bond_lens
|
512 |
+
|
513 |
+
|
514 |
+
def batched_fullatom_bond_angles_from_coords(coords, aatype, return_format="per_aa"):
|
515 |
+
# Expects trimmed coords (no mask)
|
516 |
+
if return_format == "per_position":
|
517 |
+
batched_bond_angles = []
|
518 |
+
elif return_format == "per_aa":
|
519 |
+
batched_bond_angles = {aa: {} for aa in residue_constants.restypes}
|
520 |
+
for i, c in enumerate(coords):
|
521 |
+
bond_angles = get_fullatom_bond_angles_from_coords(
|
522 |
+
c, aatype[i], return_format=return_format
|
523 |
+
)
|
524 |
+
if return_format == "per_position":
|
525 |
+
batched_bond_angles.extend(bond_angles)
|
526 |
+
elif return_format == "per_aa":
|
527 |
+
for aa, d in bond_angles.items():
|
528 |
+
for bond, lengths in d.items():
|
529 |
+
batched_bond_angles[aa].setdefault(bond, []).extend(lengths)
|
530 |
+
return batched_bond_angles
|
531 |
+
|
532 |
+
|
533 |
+
def get_chi_angles(coords, aatype, atom_mask=None, seq_mask=None):
|
534 |
+
# unbatched
|
535 |
+
# return (n, 4) chis in degrees and mask
|
536 |
+
chis = []
|
537 |
+
chi_mask = []
|
538 |
+
atom_order = residue_constants.atom_order
|
539 |
+
|
540 |
+
seq = aatype_to_seq(aatype, seq_mask=seq_mask)
|
541 |
+
|
542 |
+
for i, aa1 in enumerate(seq): # per residue
|
543 |
+
if seq_mask is not None and seq_mask[i] == 0:
|
544 |
+
chis.append([0, 0, 0, 0])
|
545 |
+
chi_mask.append([0, 0, 0, 0])
|
546 |
+
else:
|
547 |
+
chi = []
|
548 |
+
mask = []
|
549 |
+
chi_atoms = residue_constants.chi_angles_atoms[
|
550 |
+
residue_constants.restype_1to3[aa1]
|
551 |
+
]
|
552 |
+
for j in range(4): # per chi angle
|
553 |
+
if j > len(chi_atoms) - 1:
|
554 |
+
chi.append(0)
|
555 |
+
mask.append(0)
|
556 |
+
elif atom_mask is not None and any(
|
557 |
+
[atom_mask[i, atom_order[a]] < 0.5 for a in chi_atoms[j]]
|
558 |
+
):
|
559 |
+
chi.append(0)
|
560 |
+
mask.append(0)
|
561 |
+
else:
|
562 |
+
# Four atoms per dihedral
|
563 |
+
xyz4 = [coords[i, atom_order[a]] for a in chi_atoms[j]]
|
564 |
+
angle = dihedral(*xyz4) * 180 / np.pi
|
565 |
+
chi.append(angle)
|
566 |
+
mask.append(1)
|
567 |
+
chis.append(chi)
|
568 |
+
chi_mask.append(mask)
|
569 |
+
|
570 |
+
chis = torch.Tensor(chis)
|
571 |
+
chi_mask = torch.Tensor(chi_mask)
|
572 |
+
|
573 |
+
return chis, chi_mask
|
574 |
+
|
575 |
+
|
576 |
+
def fill_Os_from_NCaC_coords(
    coords: torch.Tensor, out: str = "torch", device: str = None
):
    """Given NCaC coords, add O atom coordinates in.
    (bs, 3n, 3) -> (bs, 4n, 3)
    """
    CO_LEN = 1.231
    if len(coords.shape) == 2:
        coords = coords.unsqueeze(0)
    Cs = coords[:, 2:-1:3, :]  # all but last C
    CCa_norm = F.normalize(coords[:, 1:-2:3, :] - Cs, dim=-1)  # all but last Ca
    CN_norm = F.normalize(coords[:, 3::3, :] - Cs, dim=-1)  # all but first N
    Os = F.normalize(CCa_norm + CN_norm, dim=-1) * -CO_LEN
    Os += Cs
    # TODO place C-term O atom properly
    Os = torch.cat([Os, coords[:, -1, :].view(-1, 1, 3) + 1], 1)
    coords_out = []
    for i in range(Os.size(1)):
        coords_out.append(coords[:, i * 3 : (i + 1) * 3, :])
        coords_out.append(Os[:, i, :].view(-1, 1, 3))
    coords_out = torch.cat(coords_out, 1)
    if out == "numpy":
        return coords_out.detach().cpu().numpy()
    elif out == "torch":
        if device is not None:
            return coords_out.to(device)
        else:
            return coords_out


def _extend(x, axis=1, n=1, prepend=False):
    # Add an extra zeros 'residue' to the end (or beginning, prepend=True) of a Tensor
    # Used to extend torsions when there is no 'psi' for last residue
    shape = list(x.shape)
    shape[axis] = n
    if prepend:
        return torch.cat([torch.zeros(shape).to(x.device), x], axis)
    else:
        return torch.cat([x, torch.zeros(shape).to(x.device)], axis)


def trim_coords(coords, n_res, batched=True):
    if batched:  # Return list of tensors
        front = (coords.shape[1] - n_res) // 2
        return [
            coords[i, front[i] : front[i] + n_res[i]] for i in range(coords.shape[0])
        ]
    else:
        if isinstance(n_res, torch.Tensor):
            n_res = n_res.int()
        front_pad = (coords.shape[0] - n_res) // 2
        return coords[front_pad : front_pad + n_res]


def batch_align_on_calpha(x, y):
    aligned_x = []
    for i, xi in enumerate(x):
        xi_calpha = xi[:, 1, :]
        _, (R, t) = kabsch_align(xi_calpha, y[i, :, 1, :])
        xi_ctr = xi - xi_calpha.mean(0, keepdim=True)
        xi_aligned = xi_ctr @ R.t() + t
        aligned_x.append(xi_aligned)
    return torch.stack(aligned_x)


def kabsch_align(p, q):
    if len(p.shape) > 2:
        p = p.reshape(-1, 3)
    if len(q.shape) > 2:
        q = q.reshape(-1, 3)
    p_ctr = p - p.mean(0, keepdim=True)
    t = q.mean(0, keepdim=True)
    q_ctr = q - t
    H = p_ctr.t() @ q_ctr
    U, S, V = torch.svd(H)
    R = V @ U.t()
    I_ = torch.eye(3).to(p)
    I_[-1, -1] = R.det().sign()
    R = V @ I_ @ U.t()
    p_aligned = p_ctr @ R.t() + t
    return p_aligned, (R, t)

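# Usage sketch (illustrative, not in the original file): superimpose point
# cloud p onto q, both (n, 3):
#   p_aligned, (R, t) = kabsch_align(p, q)
#   rmsd = (p_aligned - q).pow(2).sum(-1).mean().sqrt()
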
def get_dssp_string(pdb):
    try:
        structure = Bio.PDB.PDBParser(QUIET=True).get_structure(pdb[:-3], pdb)
        dssp = DSSP(structure[0], pdb, dssp="mkdssp")
        dssp_string = "".join([dssp[k][2] for k in dssp.keys()])
        return dssp_string
    except Exception as e:
        print(e)
        return None


def pool_dssp_symbols(dssp_string, newchar=None, chars=["-", "T", "S", "C", " "]):
    """Replaces all instances of chars with newchar. DSSP chars are helix=GHI, strand=EB, loop=- TSC"""
    if newchar is None:
        newchar = chars[0]
    string_out = dssp_string
    for c in chars:
        string_out = string_out.replace(c, newchar)
    return string_out


def get_3state_dssp(pdb=None, coords=None):
    if coords is not None:
        pdb = "temp_dssp.pdb"
        write_coords_to_pdb(coords, pdb, batched=False)
    dssp_string = get_dssp_string(pdb)
    if dssp_string is not None:
        dssp_string = pool_dssp_symbols(dssp_string, newchar="L")
        dssp_string = pool_dssp_symbols(dssp_string, chars=["H", "G", "I"])
        dssp_string = pool_dssp_symbols(dssp_string, chars=["E", "B"])
    if coords is not None:
        subprocess.run(shlex.split(f"rm {pdb}"))
    return dssp_string


############## SAVE/LOAD UTILS #################################


def load_feats_from_pdb(
    pdb, bb_atoms=["N", "CA", "C", "O"], load_atom73=False, **kwargs
):
    feats = {}
    with open(pdb, "r") as f:
        pdb_str = f.read()
    protein_obj = protein.from_pdb_string(pdb_str, **kwargs)
    bb_idxs = [residue_constants.atom_order[a] for a in bb_atoms]
    bb_coords = torch.from_numpy(protein_obj.atom_positions[:, bb_idxs])
    feats["bb_coords"] = bb_coords.float()
    for k, v in vars(protein_obj).items():
        feats[k] = torch.Tensor(v)
    feats["aatype"] = feats["aatype"].long()
    if load_atom73:
        feats["atom73_coords"], feats["atom73_mask"] = atom37_to_atom73(
            feats["atom_positions"], feats["aatype"], return_mask=True
        )
    return feats


def load_coords_from_pdb(
    pdb,
    atoms=["N", "CA", "C", "O"],
    method="raw",
    also_bfactors=False,
    normalize_bfactors=True,
):
    """Returns array of shape (1, n_res, len(atoms), 3)"""
    coords = []
    bfactors = []
    if method == "raw":  # Raw numpy implementation, faster than biopdb
        # Indexing into PDB format, allowing XXXX.XXX
        coords_in_pdb = [slice(30, 38), slice(38, 46), slice(46, 54)]
        # Indexing into PDB format, allowing XXX.XX
        bfactor_in_pdb = slice(60, 66)

        with open(pdb, "r") as f:
            resi_prev = 1
            counter = 0
            for l in f:
                l_split = l.rstrip("\n").split()
                if len(l_split) > 0 and l_split[0] == "ATOM" and l_split[2] in atoms:
                    resi = l_split[5]
                    if resi == resi_prev:
                        counter += 1
                    else:
                        counter = 0
                    if counter < len(atoms):
                        xyz = [
                            np.array(l[s].strip()).astype(float) for s in coords_in_pdb
                        ]
                        coords.append(xyz)
                        if also_bfactors:
                            bfactor = np.array(l[bfactor_in_pdb].strip()).astype(float)
                            bfactors.append(bfactor)
                    resi_prev = resi
        coords = torch.Tensor(np.array(coords)).view(1, -1, len(atoms), 3)
        if also_bfactors:
            bfactors = torch.Tensor(np.array(bfactors)).view(1, -1, len(atoms))
    elif method == "biopdb":
        structure = Bio.PDB.PDBParser(QUIET=True).get_structure(pdb[:-3], pdb)
        for model in structure:
            for chain in model:
                for res in chain:
                    for atom in atoms:
                        try:
                            coords.append(np.asarray(res[atom].get_coord()))
                            if also_bfactors:
                                bfactors.append(np.asarray(res[atom].get_bfactor()))
                        except KeyError:  # atom missing from this residue
                            continue
    else:
        raise NotImplementedError(f"Invalid method for reading coords: {method}")
    if also_bfactors:
        if normalize_bfactors:  # Normalize over Calphas
            mean_b = bfactors[..., 1].mean()
            std_b = bfactors[..., 1].var().sqrt()
            bfactors = (bfactors - mean_b) / (std_b + 1e-6)
        return coords, bfactors
    return coords


def feats_to_pdb_str(
    atom_positions,
    aatype=None,
    atom_mask=None,
    residue_index=None,
    chain_index=None,
    b_factors=None,
    atom_lines_only=True,
    conect=False,
    **kwargs,
):
    # Expects unbatched, cropped inputs. Needs at least one of atom_mask, aatype.
    # Uses all-GLY aatype if aatype not given: does not infer from atom_mask
    assert aatype is not None or atom_mask is not None
    if atom_mask is None:
        aatype = aatype.cpu()
        atom_mask = atom37_mask_from_aatype(aatype, torch.ones_like(aatype))
    if aatype is None:
        seq_mask = atom_mask[:, residue_constants.atom_order["CA"]].cpu()
        aatype = seq_mask * residue_constants.restype_order["G"]
    if residue_index is None:
        residue_index = torch.arange(aatype.shape[-1])
    if chain_index is None:
        chain_index = torch.ones_like(aatype)
    if b_factors is None:
        b_factors = torch.ones_like(atom_mask)

    cast = lambda x: np.array(x.detach().cpu()) if isinstance(x, torch.Tensor) else x
    prot = protein.Protein(
        atom_positions=cast(atom_positions),
        atom_mask=cast(atom_mask),
        aatype=cast(aatype),
        residue_index=cast(residue_index),
        chain_index=cast(chain_index),
        b_factors=cast(b_factors),
    )
    pdb_str = protein.to_pdb(prot, conect=conect)
    if conect:
        pdb_str, conect_str = pdb_str
    if atom_lines_only:
        pdb_lines = pdb_str.split("\n")
        atom_lines = [
            l for l in pdb_lines if len(l.split()) > 1 and l.split()[0] == "ATOM"
        ]
        pdb_str = "\n".join(atom_lines) + "\n"
    if conect:
        pdb_str = pdb_str + conect_str
    return pdb_str


def bb_coords_to_pdb_str(coords, atoms=["N", "CA", "C", "O"]):
    def _bb_pdb_line(atom, atomnum, resnum, coords, elem, res="GLY"):
        atm = "ATOM".ljust(6)
        atomnum = str(atomnum).rjust(5)
        atomname = atom.center(4)
        resname = res.ljust(3)
        chain = "A".rjust(1)
        resnum = str(resnum).rjust(4)
        x = str("%8.3f" % (float(coords[0]))).rjust(8)
        y = str("%8.3f" % (float(coords[1]))).rjust(8)
        z = str("%8.3f" % (float(coords[2]))).rjust(8)
        occ = str("%6.2f" % (float(1))).rjust(6)
        temp = str("%6.2f" % (float(20))).ljust(6)
        elname = elem.rjust(12)
        return "%s%s %s %s %s%s %s%s%s%s%s%s\n" % (
            atm, atomnum, atomname, resname, chain, resnum,
            x, y, z, occ, temp, elname,
        )

    n = coords.shape[0]
    na = len(atoms)
    pdb_str = ""
    for j in range(0, n, na):
        for idx, atom in enumerate(atoms):
            pdb_str += _bb_pdb_line(
                atom,
                j + idx + 1,
                (j + na) // na,
                coords[j + idx],
                atom[0],
            )
    return pdb_str


def write_coords_to_pdb(
    coords_in,
    filename,
    batched=True,
    write_to_frames=False,
    conect=False,
    **all_atom_feats,
):
    def _write_pdb_string(pdb_str, filename, append=False):
        write_mode = "a" if append else "w"
        with open(filename, write_mode) as f:
            if write_to_frames:
                f.write("MODEL\n")
            f.write(pdb_str)
            if write_to_frames:
                f.write("ENDMDL\n")

    if not (batched or write_to_frames):
        coords_in = [coords_in]
        filename = [filename]
        all_atom_feats = {k: [v] for k, v in all_atom_feats.items()}

    n_atoms_in = coords_in[0].shape[-2]
    is_bb_or_ca_pdb = n_atoms_in <= 4
    for i, c in enumerate(coords_in):
        n_res = c.shape[0]
        if isinstance(filename, list):
            fname = filename[i]
        elif write_to_frames or len(coords_in) == 1:
            fname = filename
        else:
            fname = f"{filename[:-4]}_{i}.pdb"

        if is_bb_or_ca_pdb:
            c_flat = rearrange(c, "n a c -> (n a) c")
            if n_atoms_in == 1:
                atoms = ["CA"]
            if n_atoms_in == 3:
                atoms = ["N", "CA", "C"]
            if n_atoms_in == 4:
                atoms = ["N", "CA", "C", "O"]
            pdb_str = bb_coords_to_pdb_str(c_flat, atoms)
        else:
            feats_i = {k: v[i][:n_res] for k, v in all_atom_feats.items()}
            pdb_str = feats_to_pdb_str(c, conect=conect, **feats_i)
        _write_pdb_string(pdb_str, fname, append=write_to_frames and i > 0)

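# Usage sketch (illustrative; "example.pdb" is a hypothetical path):
#   feats = load_feats_from_pdb("example.pdb")
#   write_coords_to_pdb(
#       feats["atom_positions"][None], "out.pdb",
#       aatype=feats["aatype"][None], atom_mask=feats["atom_mask"][None],
#   )
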
###################### LOSSES ###################################


def masked_cross_entropy(logprobs, target, loss_mask):
    # target is onehot
    cel = -(target * logprobs)
    cel = cel * loss_mask[..., None]
    cel = cel.sum((-1, -2)) / loss_mask.sum(-1).clamp(min=1e-6)
    return cel


def masked_mse(x, y, mask, weight=None):
    data_dims = tuple(range(1, len(x.shape)))
    mse = (x - y).pow(2) * mask
    if weight is not None:
        mse = mse * expand(weight, mse)
    mse = mse.sum(data_dims) / mask.sum(data_dims).clamp(min=1e-6)
    return mse

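# Usage sketch (illustrative shapes, not in the original file):
# masked_cross_entropy expects log-probs and a one-hot target.
#   logprobs = F.log_softmax(logits, dim=-1)                 # (b, n, 21)
#   target = F.one_hot(aatype, 21).float()                   # (b, n, 21)
#   loss = masked_cross_entropy(logprobs, target, seq_mask)  # (b,)
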
###################### ALIGN ###################################


def quick_tmalign(
    p, p_sele, q_sele, tmscore_type="avg", differentiable_rmsd=False, rmsd_type="ca"
):
    # sota 210712
    write_coords_to_pdb(p_sele[:, 1:2], "temp_p.pdb", atoms=["CA"], batched=False)
    write_coords_to_pdb(q_sele[:, 1:2], "temp_q.pdb", atoms=["CA"], batched=False)
    cmd = f"{PATH_TO_TMALIGN} temp_p.pdb temp_q.pdb -m temp_matrix.txt"
    outputs = subprocess.run(shlex.split(cmd), capture_output=True, text=True)

    # Get RMSD and TM scores
    tmout = outputs.stdout.split("\n")
    rmsd = float(tmout[16].split()[4][:-1])
    tmscore1 = float(tmout[17].split()[1])
    tmscore2 = float(tmout[18].split()[1])
    if tmscore_type == "avg":
        tmscore = (tmscore1 + tmscore2) / 2
    elif tmscore_type == "1" or tmscore_type == "query":
        tmscore = tmscore1
    elif tmscore_type == "2":
        tmscore = tmscore2
    elif tmscore_type == "both":
        tmscore = (tmscore1, tmscore2)

    # Get R, t and transform p coords
    m = open("temp_matrix.txt", "r").readlines()[2:5]
    m = [l.strip()[1:].strip() for l in m]
    m = torch.Tensor([[float(i) for i in l.split()] for l in m]).to(p_sele.device)
    R = m[:, 1:].t()
    t = m[:, 0]
    aligned_psele = p_sele @ R + t
    aligned = p @ R + t

    # Option 2 for rms - MSE of aligned against target coords using TMalign seq alignment. Differentiable
    if differentiable_rmsd:
        pi, qi = 0, 0
        p_idxs, q_idxs = [], []
        for i, c in enumerate(tmout[23]):
            if c in [":", "."]:
                p_idxs.append(pi)
                q_idxs.append(qi)
            if tmout[22][i] != "-":
                pi += 1
            if tmout[24][i] != "-":
                qi += 1
        tmalign_seq_p = p_sele[p_idxs]
        tmalign_seq_q = q_sele[q_idxs]
        if rmsd_type == "ca":
            tmalign_seq_p = tmalign_seq_p[:, 1]
            tmalign_seq_q = tmalign_seq_q[:, 1]
        elif rmsd_type == "bb":
            pass
        rmsd = (tmalign_seq_p - tmalign_seq_q).pow(2).sum(-1).sqrt().mean()

    # Delete temp files: p.pdb, q.pdb, matrix.txt, tmalign.out
    subprocess.run(shlex.split("rm temp_p.pdb"))
    subprocess.run(shlex.split("rm temp_q.pdb"))
    subprocess.run(shlex.split("rm temp_matrix.txt"))

    return {"aligned": aligned, "rmsd": rmsd, "tm_score": tmscore, "R": R, "t": t}

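# Usage sketch (illustrative; requires a TMalign binary at PATH_TO_TMALIGN):
#   res = quick_tmalign(p, p, q)  # p, q: (n, natom, 3) coords with CA at index 1
#   res["tm_score"], res["rmsd"], res["aligned"]
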
###################### OTHER ###################################


def expand(x, tgt=None, dim=1):
    if tgt is None:
        for _ in range(dim):
            x = x[..., None]
    else:
        while len(x.shape) < len(tgt.shape):
            x = x[..., None]
    return x


def hookfn(name, verbose=False):
    def f(grad):
        if check_nan_inf(grad) > 0:
            print(name, "grad nan/infs", grad.shape, check_nan_inf(grad), grad)
        if verbose:
            print(name, "grad shape", grad.shape, "norm", grad.norm())

    return f


def trigger_nan_check(name, x):
    if check_nan_inf(x) > 0:
        print(name, check_nan_inf(x))
        raise Exception


def check_nan_inf(x):
    return torch.isinf(x).sum() + torch.isnan(x).sum()


def directory_find(atom, root="."):
    for path, dirs, files in os.walk(root):
        if atom in dirs:
            return os.path.join(path, atom)


def dict2namespace(config):
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace


def load_config(path, return_dict=False):
    with open(path, "r") as f:
        config_dict = yaml.safe_load(f)
    config = dict2namespace(config_dict)
    if return_dict:
        return config, config_dict
    else:
        return config
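# Usage sketch (illustrative; attribute names depend on the yml contents):
#   cfg = load_config("configs/allatom.yml")
#   cfg, cfg_dict = load_config("configs/allatom.yml", return_dict=True)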
diffusion.py
ADDED
@@ -0,0 +1,66 @@
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu

Noise and diffusion utils.
"""
from scipy.stats import norm
import torch
from torchtyping import TensorType

from core import utils


def noise_schedule(
    time: TensorType[float],
    function: str = "uniform",
    sigma_data: float = 10.0,
    psigma_mean: float = -1.2,
    psigma_std: float = 1.2,
    s_min: float = 0.001,
    s_max: float = 60,
    rho: float = 7.0,
    time_power: float = 4.0,
    constant_val: float = 0.0,
):
    def sampling_noise(time):
        # high noise = 1; low noise = 0. opposite of Karras et al. schedule
        term1 = s_max ** (1 / rho)
        term2 = (1 - time) * (s_min ** (1 / rho) - s_max ** (1 / rho))
        noise_level = sigma_data * ((term1 + term2) ** rho)
        return noise_level

    if function == "lognormal":
        normal_sample = torch.Tensor(norm.ppf(time.cpu())).to(time)
        noise_level = sigma_data * torch.exp(psigma_mean + psigma_std * normal_sample)
    elif function == "uniform":
        noise_level = sampling_noise(time)
    elif function == "mpnn":
        time = time**time_power
        noise_level = sampling_noise(time)
    elif function == "constant":
        noise_level = torch.ones_like(time) * constant_val
    return noise_level


def noise_coords(
    coords: TensorType["b n a x", float],
    noise_level: TensorType["b", float],
    dummy_fill_masked_atoms: bool = False,
    atom_mask: TensorType["b n a"] = None,
):
    # Does not apply atom mask after adding noise
    if dummy_fill_masked_atoms:
        assert atom_mask is not None
        dummy_fill_mask = 1 - atom_mask
        dummy_fill_value = coords[..., 1:2, :]  # CA
        # dummy_fill_value = utils.fill_in_cbeta_for_atom37(coords)[..., 3:4, :]  # CB
        coords = (
            coords * atom_mask[..., None]
            + dummy_fill_value * dummy_fill_mask[..., None]
        )

    noise = torch.randn_like(coords) * utils.expand(noise_level, coords)
    noisy_coords = coords + noise
    return noisy_coords
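# Usage sketch (illustrative; shapes assumed): one forward-noising step at
# uniformly sampled diffusion times.
#   t = torch.rand(b)                              # b = batch size
#   sigma = noise_schedule(t, function="uniform")  # (b,) noise levels
#   noisy = noise_coords(coords, sigma)            # coords: (b, n, 37, 3)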
draw_samples.py
ADDED
@@ -0,0 +1,353 @@
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu

Entry point for unconditional or simple conditional sampling.
"""
import argparse
from datetime import datetime
import json
import os
import shlex
import subprocess
import sys
import time

from einops import repeat
import torch

from core import data
from core import residue_constants
from core import utils
import diffusion
import models
import sampling


def draw_and_save_samples(
    model,
    samples_per_len=8,
    lengths=range(50, 512),
    save_dir="./",
    mode="backbone",
    **sampling_kwargs,
):
    device = model.device
    if mode == "backbone":
        total_sampling_time = 0
        for l in lengths:
            prot_lens = torch.ones(samples_per_len).long() * l
            seq_mask = model.make_seq_mask_for_sampling(prot_lens=prot_lens)
            aux = sampling.draw_backbone_samples(
                model,
                seq_mask=seq_mask,
                pdb_save_path=f"{save_dir}/len{format(l, '03d')}_samp",
                return_aux=True,
                return_sampling_runtime=True,
                **sampling_kwargs,
            )
            total_sampling_time += aux["runtime"]
            print("Samples drawn for length", l)
        return total_sampling_time
    elif mode == "allatom":
        total_sampling_time = 0
        for l in lengths:
            prot_lens = torch.ones(samples_per_len).long() * l
            seq_mask = model.make_seq_mask_for_sampling(prot_lens=prot_lens)
            aux = sampling.draw_allatom_samples(
                model,
                seq_mask=seq_mask,
                pdb_save_path=f"{save_dir}/len{format(l, '03d')}",
                return_aux=True,
                **sampling_kwargs,
            )
            total_sampling_time += aux["runtime"]
            print("Samples drawn for length", l)
        return total_sampling_time


def parse_idx_string(idx_str):
    spans = idx_str.split(",")
    idxs = []
    for s in spans:
        if "-" in s:
            start, stop = s.split("-")
            idxs.extend(list(range(int(start), int(stop))))
        else:
            idxs.append(int(s))
    return idxs

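# Note (illustrative): dash spans are half-open, i.e. range(start, stop), so
#   parse_idx_string("0,2-5,7")  # -> [0, 2, 3, 4, 7]
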
class Manager(object):
    def __init__(self):
        self.parser = argparse.ArgumentParser(
            formatter_class=argparse.RawTextHelpFormatter
        )

        self.parser.add_argument(
            "--model_checkpoint",
            type=str,
            default="checkpoints",
            help="Path to denoiser model weights and config",
        )
        self.parser.add_argument(
            "--mpnnpath",
            type=str,
            default="checkpoints/minimpnn_state_dict.pth",
            help="Path to minimpnn model weights",
        )
        self.parser.add_argument(
            "--modeldir",
            type=str,
            help="Model base directory, ex 'training_logs/other/lemon-shape-51'",
        )
        self.parser.add_argument("--modelepoch", type=int, help="Model epoch, ex 1000")
        self.parser.add_argument(
            "--type", type=str, default="allatom", help="Type of model"
        )
        self.parser.add_argument(
            "--param", type=str, default=None, help="Which sampling param to vary"
        )
        self.parser.add_argument(
            "--paramval", type=str, default=None, help="Which param val to use"
        )
        self.parser.add_argument(
            "--parampath",
            type=str,
            default=None,
            help="Path to json file with params, either use param/paramval or parampath, not both",
        )
        self.parser.add_argument(
            "--perlen", type=int, default=2, help="How many samples per sequence length"
        )
        self.parser.add_argument(
            "--minlen", type=int, required=False, help="Minimum sequence length"
        )
        self.parser.add_argument(
            "--maxlen",
            type=int,
            required=False,
            help="Maximum sequence length, not inclusive",
        )
        self.parser.add_argument(
            "--steplen",
            type=int,
            required=False,
            help="How frequently to select sequence length, for steplen 2, would be 50, 52, 54, etc",
        )
        self.parser.add_argument(
            "--num_lens",
            type=int,
            required=False,
            help="If steplen not provided, how many random lengths to sample at",
        )
        self.parser.add_argument(
            "--targetdir", type=str, default=".", help="Directory to save results"
        )
        self.parser.add_argument(
            "--input_pdb", type=str, required=False, help="PDB file to condition on"
        )
        self.parser.add_argument(
            "--resample_idxs",
            type=str,
            required=False,
            help="Indices from PDB file to resample. Zero-indexed, comma-delimited, can use dashes, eg 0,2-5,7",
        )

    def add_argument(self, *args, **kwargs):
        self.parser.add_argument(*args, **kwargs)

    def parse_args(self):
        self.args = self.parser.parse_args()
        return self.args


def main():
    # Set up params, arguments, sampling config
    ####################
    manager = Manager()
    manager.parse_args()
    args = manager.args
    print(args)
    is_test_run = False
    seed = 0
    samples_per_len = args.perlen
    min_len = args.minlen
    max_len = args.maxlen
    len_step_size = args.steplen
    device = "cuda:0"

    # setting default sampling config
    if args.type == "backbone":
        sampling_config = sampling.default_backbone_sampling_config()
    elif args.type == "allatom":
        sampling_config = sampling.default_allatom_sampling_config()

    sampling_kwargs = vars(sampling_config)

    # Parse conditioning inputs
    input_pdb_len = None
    if args.input_pdb:
        input_feats = utils.load_feats_from_pdb(args.input_pdb, protein_only=True)
        input_pdb_len = input_feats["aatype"].shape[0]
        if args.resample_idxs:
            print(
                f"Warning: when sampling conditionally, the input pdb length ({input_pdb_len} residues) is used automatically for the sampling lengths."
            )
            resample_idxs = parse_idx_string(args.resample_idxs)
        else:
            resample_idxs = list(range(input_pdb_len))
        cond_idxs = [i for i in range(input_pdb_len) if i not in resample_idxs]
        to_batch_size = lambda x: repeat(x, "... -> b ...", b=samples_per_len).to(
            device
        )

        # For unconditional model, center coords on whole structure
        centered_coords = data.apply_random_se3(
            input_feats["atom_positions"],
            atom_mask=input_feats["atom_mask"],
            translation_scale=0.0,
        )
        cond_kwargs = {}
        cond_kwargs["gt_coords"] = to_batch_size(centered_coords)
        cond_kwargs["gt_cond_atom_mask"] = to_batch_size(input_feats["atom_mask"])
        cond_kwargs["gt_cond_atom_mask"][:, resample_idxs] = 0
        cond_kwargs["gt_aatype"] = to_batch_size(input_feats["aatype"])
        cond_kwargs["gt_cond_seq_mask"] = torch.zeros_like(cond_kwargs["gt_aatype"])
        cond_kwargs["gt_cond_seq_mask"][:, cond_idxs] = 1
        sampling_kwargs.update(cond_kwargs)

    # Determine lengths to sample at
    if min_len is not None and max_len is not None:
        if len_step_size is not None:
            sampling_lengths = range(min_len, max_len, len_step_size)
        else:
            sampling_lengths = list(
                torch.randint(min_len, max_len, size=(args.num_lens,))
            )
    elif input_pdb_len is not None:
        sampling_lengths = [input_pdb_len]
    else:
        raise Exception("Need to provide a set of protein lengths or an input pdb.")

    total_num_samples = len(list(sampling_lengths)) * samples_per_len

    model_directory = args.modeldir
    epoch = args.modelepoch
    base_dir = args.targetdir

    date_string = datetime.now().strftime("%y-%m-%d-%H-%M-%S")
    if is_test_run:
        date_string = f"test-{date_string}"

    # Update sampling config with arguments
    if args.param:
        var_param = args.param
        var_value = args.paramval
        sampling_kwargs[var_param] = (
            None
            if var_value == "None"
            else int(var_value)
            if var_param == "n_steps"
            else float(var_value)
        )
    elif args.parampath:
        with open(args.parampath) as f:
            var_params = json.loads(f.read())
        sampling_kwargs.update(var_params)

    # this is only used for the readme, keep s_min and s_max as params instead of struct_noise_schedule
    sampling_kwargs_readme = list(sampling_kwargs.items())

    print("Base directory:", base_dir)
    save_dir = f"{base_dir}/samples"
    save_init_dir = f"{base_dir}/samples_inits"

    print("Samples saved to:", save_dir)
    ####################

    torch.manual_seed(seed)
    if not os.path.exists(save_dir):
        subprocess.run(shlex.split(f"mkdir -p {save_dir}"))

    if not os.path.exists(save_init_dir):
        subprocess.run(shlex.split(f"mkdir -p {save_init_dir}"))

    # Load model
    if args.type == "backbone":
        if args.model_checkpoint:
            checkpoint = f"{args.model_checkpoint}/backbone_state_dict.pth"
            cfg_path = f"{args.model_checkpoint}/backbone.yml"
        else:
            checkpoint = (
                f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
            )
            cfg_path = f"{model_directory}/configs/backbone.yml"
        cfg = utils.load_config(cfg_path)
        weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
        model = models.Protpardelle(cfg, device=device)
        model.load_state_dict(weights)
        model.to(device)
        model.eval()
        model.device = device
    elif args.type == "allatom":
        if args.model_checkpoint:
            checkpoint = f"{args.model_checkpoint}/allatom_state_dict.pth"
            cfg_path = f"{args.model_checkpoint}/allatom.yml"
        else:
            checkpoint = (
                f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
            )
            cfg_path = f"{model_directory}/configs/allatom.yml"
        config = utils.load_config(cfg_path)
        weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
        model = models.Protpardelle(config, device=device)
        model.load_state_dict(weights)
        model.load_minimpnn(args.mpnnpath)
        model.to(device)
        model.eval()
        model.device = device

    # Sampling
    with open(base_dir + "/readme.txt", "w") as f:
        f.write(f"Sampling run for {date_string}\n")
        f.write(f"Random seed {seed}\n")
        f.write(f"Model checkpoint: {checkpoint}\n")
        f.write(
            f"{samples_per_len} samples per length from {min_len}:{max_len}:{len_step_size}\n"
        )
        f.write("Sampling params:\n")
        for k, v in sampling_kwargs_readme:
            f.write(f"{k}\t{v}\n")

    print(f"Model loaded from {checkpoint}")
    print(f"Beginning sampling for {date_string}...")
|
327 |
+
|
328 |
+
# Draw samples
|
329 |
+
start_time = time.time()
|
330 |
+
sampling_time = draw_and_save_samples(
|
331 |
+
model,
|
332 |
+
samples_per_len=samples_per_len,
|
333 |
+
lengths=sampling_lengths,
|
334 |
+
save_dir=save_dir,
|
335 |
+
mode=args.type,
|
336 |
+
**sampling_kwargs,
|
337 |
+
)
|
338 |
+
time_elapsed = time.time() - start_time
|
339 |
+
|
340 |
+
print(f"Sampling concluded after {time_elapsed} seconds.")
|
341 |
+
print(f"Of this, {sampling_time} seconds were for actual sampling.")
|
342 |
+
print(f"{total_num_samples} total samples were drawn.")
|
343 |
+
|
344 |
+
with open(base_dir + "/readme.txt", "a") as f:
|
345 |
+
f.write(f"Total job time: {time_elapsed} seconds\n")
|
346 |
+
f.write(f"Model run time: {sampling_time} seconds\n")
|
347 |
+
f.write(f"Total samples drawn: {total_num_samples}\n")
|
348 |
+
|
349 |
+
return
|
350 |
+
|
351 |
+
|
352 |
+
if __name__ == "__main__":
|
353 |
+
main()
|
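The `--resample_idxs` string is parsed by `parse_idx_string`, which is defined earlier in draw_samples.py and is not part of this hunk. As a minimal sketch of the format the help text describes ("0,2-5,7", zero-indexed, dash ranges assumed end-inclusive), a hypothetical reimplementation could look like the following; `parse_idx_string_sketch` is an illustrative name, not the actual function:

# Hedged sketch of the "0,2-5,7" index format; not the actual parse_idx_string.
def parse_idx_string_sketch(idx_str):
    idxs = []
    for token in idx_str.split(","):
        if "-" in token:
            # Dash ranges are assumed end-inclusive, so "2-5" -> 2, 3, 4, 5
            start, end = token.split("-")
            idxs.extend(range(int(start), int(end) + 1))
        else:
            idxs.append(int(token))
    return idxs

assert parse_idx_string_sketch("0,2-5,7") == [0, 2, 3, 4, 5, 7]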
evaluation.py
ADDED
@@ -0,0 +1,406 @@
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu

Utils for computing evaluation metrics.
"""
import argparse
import os
import warnings
from typing import Tuple

from Bio.Align import substitution_matrices
import numpy as np
import torch
from transformers import AutoTokenizer, EsmForProteinFolding
from torchtyping import TensorType

from core import residue_constants
from core import utils
from core import protein_mpnn as mpnn
import modules
import sampling


def mean(x):
    if len(x) == 0:
        return 0
    return sum(x) / len(x)


def calculate_seq_identity(seq1, seq2, seq_mask=None):
    identity = (seq1 == seq2.to(seq1)).float()
    if seq_mask is not None:
        identity *= seq_mask.to(seq1)
        return identity.sum(-1) / seq_mask.to(seq1).sum(-1).clamp(min=1)
    else:
        return identity.mean(-1)


def design_sequence(coords, model=None, num_seqs=1, disallow_aas=["C"]):
    # Returns list of strs; seqs like 'MKRLLDS', not aatypes
    if model is None:
        model = mpnn.get_mpnn_model()
    if isinstance(coords, str):
        temp_pdb = False
        pdb_fn = coords
    else:
        temp_pdb = True
        pdb_fn = f"tmp{np.random.randint(0, 1e8)}.pdb"
        gly_idx = residue_constants.restype_order["G"]
        gly_aatype = (torch.ones(coords.shape[0]) * gly_idx).long()
        utils.write_coords_to_pdb(coords, pdb_fn, batched=False, aatype=gly_aatype)

    with torch.no_grad():
        designed_seqs = mpnn.run_proteinmpnn(
            model=model,
            pdb_path=pdb_fn,
            num_seq_per_target=num_seqs,
            omit_AAs=disallow_aas,
        )

    if temp_pdb:
        os.system("rm " + pdb_fn)
    return designed_seqs


def get_esmfold_model(device=None):
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").to(device)
    model.esm = model.esm.half()
    return model


def inference_esmfold(sequence_list, model, tokenizer):
    inputs = tokenizer(
        sequence_list,
        return_tensors="pt",
        padding=True,
        add_special_tokens=False,
    ).to(model.device)
    outputs = model(**inputs)
    # positions is shape (l, b, n, a, c)
    pred_coords = outputs.positions[-1].contiguous()
    plddts = (outputs.plddt[:, :, 1] * inputs.attention_mask).sum(
        -1
    ) / inputs.attention_mask.sum(-1).clamp(min=1e-3)
    return pred_coords, plddts


def predict_structures(sequences, model="esmfold", tokenizer=None, force_unk_to_X=True):
    # Expects seqs like 'MKRLLDS', not aatypes
    # model can be a model, or a string describing which pred model to load
    if isinstance(sequences, str):
        sequences = [sequences]
    if model == "esmfold":
        model = get_esmfold_model()
    device = model.device
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")

    aatype = [utils.seq_to_aatype(seq).to(device) for seq in sequences]

    with torch.no_grad():
        if isinstance(model, EsmForProteinFolding):
            pred_coords, plddts = inference_esmfold(sequences, model, tokenizer)

    seq_lens = [len(s) for s in sequences]
    trimmed_coords = [c[: seq_lens[i]] for i, c in enumerate(pred_coords)]
    trimmed_coords_atom37 = [
        utils.atom37_coords_from_atom14(c, aatype[i])
        for i, c in enumerate(trimmed_coords)
    ]
    return trimmed_coords_atom37, plddts


def compute_structure_metric(coords1, coords2, metric="ca_rmsd", atom_mask=None):
    # coords1 tensor[l][a][3]
    def _tmscore(a, b, mask=None):
        length = len(b)
        dists = (a - b).pow(2).sum(-1)
        d0 = 1.24 * ((length - 15) ** (1 / 3)) - 1.8
        term = 1 / (1 + ((dists) / (d0**2)))
        if mask is None:
            return term.mean()
        else:
            term = term * mask
            return term.sum() / mask.sum().clamp(min=1)

    aligned_coords1_ca, (R, t) = utils.kabsch_align(coords1[:, 1], coords2[:, 1])
    aligned_coords1 = coords1 - coords1[:, 1:2].mean(0, keepdim=True)
    aligned_coords1 = aligned_coords1 @ R.t() + t
    if metric == "ca_rmsd":
        return (aligned_coords1_ca - coords2[:, 1]).pow(2).sum(-1).sqrt().mean()
    elif metric == "tm_score":
        tm = _tmscore(aligned_coords1_ca, coords2[:, 1])
        # TODO: return 1 - tm score for now so sorts work properly
        return 1 - tm
    elif metric == "allatom_tm":
        # Align on Ca, compute allatom TM
        assert atom_mask is not None
        return _tmscore(aligned_coords1, coords2, mask=atom_mask)
    elif metric == "allatom_lddt":
        assert atom_mask is not None
        lddt = modules.lddt(
            coords1.reshape(-1, 3),
            coords2.reshape(-1, 3),
            atom_mask.reshape(-1, 1),
            per_residue=False,
        )
        return lddt
    else:
        raise NotImplementedError

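The `_tmscore` helper above implements the standard TM-score with the Zhang-Skolnick length normalization; since `dists` holds squared distances, `dists / d0**2` corresponds to $(d_i/d_0)^2$:

$$\mathrm{TM} = \frac{1}{L} \sum_{i=1}^{L} \frac{1}{1 + \left(d_i / d_0(L)\right)^2}, \qquad d_0(L) = 1.24\,(L - 15)^{1/3} - 1.8$$

where $d_i$ is the distance between the $i$-th aligned atom pair and $L$ is the target length. Note that `compute_structure_metric` returns $1 - \mathrm{TM}$ for the `tm_score` metric (per the TODO above), so lower is better for every metric it reports.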
def compute_self_consistency(
    comparison_structures,  # can be sampled or ground truth
    sampled_sequences=None,
    mpnn_model=None,
    struct_pred_model=None,
    tokenizer=None,
    num_seqs=1,
    return_aux=False,
    metric="ca_rmsd",
    output_file=None,
):
    # Typically used for eval of backbone sampling or sequence design or joint sampling
    # (Maybe MPNN) + Fold + TM/RMSD
    # Expects seqs like 'MKRLLDS', not aatypes
    per_sample_primary_metrics = []
    per_sample_secondary_metrics = []
    per_sample_plddts = []
    per_sample_coords = []
    per_sample_seqs = []
    aux = {}
    for i, coords in enumerate(comparison_structures):
        if sampled_sequences is None:
            seqs_to_predict = design_sequence(
                coords, model=mpnn_model, num_seqs=num_seqs
            )
        else:
            seqs_to_predict = sampled_sequences[i]
        pred_coords, plddts = predict_structures(
            seqs_to_predict, model=struct_pred_model, tokenizer=tokenizer
        )
        primary_metric_name = "tm_score" if metric == "tm_score" else "ca_rmsd"
        secondary_metric_name = "tm_score" if metric == "both" else None
        primary_metrics = [
            compute_structure_metric(coords.to(pred), pred, metric=primary_metric_name)
            for pred in pred_coords
        ]
        if secondary_metric_name:
            secondary_metrics = [
                compute_structure_metric(
                    coords.to(pred), pred, metric=secondary_metric_name
                )
                for pred in pred_coords
            ]
            aux.setdefault(secondary_metric_name, []).extend(secondary_metrics)
        else:
            secondary_metrics = primary_metrics

        aux.setdefault("pred", []).extend(pred_coords)
        seqs_to_predict_arr = seqs_to_predict
        if isinstance(seqs_to_predict_arr, str):
            seqs_to_predict_arr = [seqs_to_predict_arr]

        aux.setdefault("seqs", []).extend(seqs_to_predict_arr)
        aux.setdefault("plddt", []).extend(plddts)
        aux.setdefault("rmsd", []).extend(primary_metrics)

        # Report best rmsd design only among MPNN reps
        all_designs = [
            (m, p, t, c, s)
            for m, p, t, c, s in zip(
                primary_metrics,
                plddts,
                secondary_metrics,
                pred_coords,
                seqs_to_predict_arr,
            )
        ]
        best_rmsd_design = min(all_designs, key=lambda x: x[0])
        per_sample_primary_metrics.append(best_rmsd_design[0].detach().cpu())
        per_sample_plddts.append(best_rmsd_design[1].detach().cpu())
        per_sample_secondary_metrics.append(best_rmsd_design[2].detach().cpu())
        per_sample_coords.append(best_rmsd_design[3])
        per_sample_seqs.append(best_rmsd_design[4])
    best_idx = np.argmin(per_sample_primary_metrics)
    metrics = {
        "sc_rmsd_best": per_sample_primary_metrics[best_idx],
        "sc_plddt_best": per_sample_plddts[best_idx],
        "sc_rmsd_mean": mean(per_sample_primary_metrics),
        "sc_plddt_mean": mean(per_sample_plddts),
    }
    if metric == "both":
        metrics["sc_tmscore_best"] = per_sample_secondary_metrics[best_idx]
        metrics["sc_tmscore_mean"] = mean(per_sample_secondary_metrics)

    if output_file:
        pred_coords = per_sample_coords
        designed_seqs = per_sample_seqs

        if torch.isnan(pred_coords[best_idx]).sum() == 0:
            designed_seq = utils.seq_to_aatype(designed_seqs[best_idx])
            utils.write_coords_to_pdb(
                pred_coords[best_idx],
                output_file,
                batched=False,
                aatype=designed_seq,
            )

    if return_aux:
        return metrics, best_idx, aux
    else:
        return metrics, best_idx


def compute_secondary_structure_content(coords_batch):
    dssp_sample = []
    for i, c in enumerate(coords_batch):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            dssp_str = utils.get_3state_dssp(coords=c)
        if dssp_str is None or len(dssp_str) == 0:
            pass
        else:
            dssp_sample.append(dssp_str)
    dssp_sample = "".join(dssp_sample)
    metrics = {}
    metrics["sample_pct_beta"] = mean([c == "E" for c in dssp_sample])
    metrics["sample_pct_alpha"] = mean([c == "H" for c in dssp_sample])
    return metrics


def compute_bond_length_metric(
    cropped_coords_list, cropped_aatypes_list, atom_mask=None
):
    bond_length_dict = utils.batched_fullatom_bond_lengths_from_coords(
        cropped_coords_list, cropped_aatypes_list, atom_mask=atom_mask
    )
    all_errors = {}
    for aa1, d in bond_length_dict.items():
        aa3 = residue_constants.restype_1to3[aa1]
        per_bond_errors = []
        for bond, lengths in d.items():
            a1, a2 = bond.split("-")
            ideal_val = None
            for bond in residue_constants.standard_residue_bonds[aa3]:
                if (
                    bond.atom1_name == a1
                    and bond.atom2_name == a2
                    or bond.atom1_name == a2
                    and bond.atom2_name == a1
                ):
                    ideal_val = bond.length
                    break
            error = (np.array(lengths) - ideal_val) ** 2
            per_bond_errors.append(error.mean() ** 0.5)
        if len(per_bond_errors) > 0:  # often no Cys
            per_res_errors = np.mean(per_bond_errors)
            all_errors[aa1] = per_res_errors
    return np.mean(list(all_errors.values()))

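In `compute_bond_length_metric` above, the error for each residue type is a root-mean-square deviation of observed bond lengths from the ideal lengths in `residue_constants.standard_residue_bonds`, averaged first over bonds and then over residue types:

$$\mathrm{RMSE}_{aa} = \frac{1}{|B_{aa}|} \sum_{b \in B_{aa}} \sqrt{\frac{1}{N_b} \sum_{i=1}^{N_b} \left(\ell_{b,i} - \ell_b^{\mathrm{ideal}}\right)^2}, \qquad \mathrm{RMSE} = \operatorname{mean}_{aa}\, \mathrm{RMSE}_{aa}$$

where $B_{aa}$ is the set of bond types for residue type $aa$ and $\ell_{b,i}$ are the observed lengths of bond type $b$.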
def evaluate_backbone_generation(
    model,
    n_samples=1,
    mpnn_model=None,
    struct_pred_model=None,
    tokenizer=None,
    sample_length_range=(50, 512),
):
    sampling_config = sampling.default_backbone_sampling_config()
    trimmed_coords, seq_mask = sampling.draw_backbone_samples(
        model,
        n_samples=n_samples,
        sample_length_range=sample_length_range,
        **vars(sampling_config),
    )
    sc_metrics, best_idx, aux = compute_self_consistency(
        trimmed_coords,
        mpnn_model=mpnn_model,
        struct_pred_model=struct_pred_model,
        tokenizer=tokenizer,
        return_aux=True,
    )
    dssp_metrics = compute_secondary_structure_content(trimmed_coords)
    all_metrics = {**sc_metrics, **dssp_metrics}
    all_metrics = {f"bb_{k}": v for k, v in all_metrics.items()}
    return all_metrics, (trimmed_coords, seq_mask, best_idx, aux["pred"], aux["seqs"])


def evaluate_allatom_generation(
    model,
    n_samples,
    two_stage_sampling=True,
    struct_pred_model=None,
    tokenizer=None,
    sample_length_range=(50, 512),
):
    # Convert allatom model to codesign model by loading miniMPNN
    model.task = "codesign"
    model.load_minimpnn()
    model.eval()

    sampling_config = sampling.default_allatom_sampling_config()
    ret = sampling.draw_allatom_samples(
        model,
        n_samples=n_samples,
        two_stage_sampling=two_stage_sampling,
        **vars(sampling_config),
    )
    (
        cropped_samp_coords,
        cropped_samp_aatypes,
        samp_atom_mask,
        stage1_coords,
        seq_mask,
    ) = ret

    # Compute self consistency
    if struct_pred_model is None:
        struct_pred_model = EsmForProteinFolding.from_pretrained(
            "facebook/esmfold_v1"
        ).to(model.device)  # send ESMFold to the sampling model's device
        struct_pred_model.esm = struct_pred_model.esm.half()
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    designed_seqs = [utils.aatype_to_seq(a) for a in cropped_samp_aatypes]
    sc_metrics, best_idx, sc_aux = compute_self_consistency(
        comparison_structures=cropped_samp_coords,
        sampled_sequences=designed_seqs,
        struct_pred_model=struct_pred_model,
        tokenizer=tokenizer,
        return_aux=True,
    )
    aa_metrics_out = {f"aa_{k}": v for k, v in sc_metrics.items()}

    # Compute secondary structure content
    cropped_bb_coords = [c[..., [0, 1, 2, 4], :] for c in cropped_samp_coords]
    dssp_metrics = compute_secondary_structure_content(cropped_bb_coords)
    aa_metrics_out = {**aa_metrics_out, **dssp_metrics}

    # Compute bond length RMSE
    if two_stage_sampling:  # compute on original sample
        bond_rmse_coords = stage1_coords
    else:
        bond_rmse_coords = cropped_samp_coords
    bond_rmse = compute_bond_length_metric(
        bond_rmse_coords, cropped_samp_aatypes, samp_atom_mask
    )
    aa_metrics_out["aa_bond_rmse"] = bond_rmse

    # Convert codesign model back to allatom model and return metrics
    model.task = "allatom"
    model.remove_minimpnn()
    aa_aux_out = (
        cropped_samp_coords,
        cropped_samp_aatypes,
        samp_atom_mask,
        sc_aux["pred"],
        best_idx,
    )
    return aa_metrics_out, aa_aux_out
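A minimal usage sketch of the self-consistency pipeline above (not part of the diff): design sequences for a structure with ProteinMPNN, re-predict them with ESMFold, and score the agreement. It assumes a GPU and downloadable ESMFold weights; "sample.pdb" is a hypothetical input file.

# Hedged sketch: self-consistency evaluation of one structure, best of num_seqs
# ProteinMPNN designs (design_sequence -> predict_structures -> scRMSD/pLDDT).
import evaluation
from core import utils

coords = utils.load_feats_from_pdb("sample.pdb")["atom_positions"]  # (n, 37, 3)
metrics, best_idx = evaluation.compute_self_consistency(
    [coords],          # list of structures to evaluate
    num_seqs=8,        # ProteinMPNN sequences per structure
    metric="ca_rmsd",  # Ca RMSD between sample and ESMFold re-prediction
)
print(metrics["sc_rmsd_best"], metrics["sc_plddt_best"])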
models.py
ADDED
@@ -0,0 +1,778 @@
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu

Top-level model definitions.
Typically these are initialized with config rather than arguments.
"""
import argparse
from functools import partial
import os
from typing import Callable, List, Optional

from einops import rearrange, repeat
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtyping import TensorType

from core import protein_mpnn
from core import residue_constants
from core import utils
import diffusion
import evaluation
import modules


class MiniMPNN(nn.Module):
    """Wrapper for ProteinMPNN network to predict sequence from structure."""

    def __init__(self, config: argparse.Namespace):
        super().__init__()
        self.config = config
        self.model_config = cfg = config.model.mpnn_model
        self.n_tokens = config.data.n_aatype_tokens
        self.seq_emb_dim = cfg.n_channel
        time_cond_dim = cfg.n_channel * cfg.noise_cond_mult

        self.noise_block = modules.NoiseConditioningBlock(cfg.n_channel, time_cond_dim)
        self.token_embedding = nn.Linear(self.n_tokens, self.seq_emb_dim)
        self.mpnn_net = modules.NoiseConditionalProteinMPNN(
            n_channel=cfg.n_channel,
            n_layers=cfg.n_layers,
            n_neighbors=cfg.n_neighbors,
            time_cond_dim=time_cond_dim,
            vocab_size=config.data.n_aatype_tokens,
            input_S_is_embeddings=True,
        )
        self.proj_out = nn.Linear(cfg.n_channel, self.n_tokens)

    def forward(
        self,
        denoised_coords: TensorType["b n a x", float],
        coords_noise_level: TensorType["b", float],
        seq_mask: TensorType["b n", float],
        residue_index: TensorType["b n", int],
        seq_self_cond: Optional[TensorType["b n t", float]] = None,  # logprobs
        return_embeddings: bool = False,
    ):
        coords_noise_level_scaled = 0.25 * torch.log(coords_noise_level)
        noise_cond = self.noise_block(coords_noise_level_scaled)

        b, n, _, _ = denoised_coords.shape
        if seq_self_cond is None or not self.model_config.use_self_conditioning:
            seq_emb_in = torch.zeros(b, n, self.seq_emb_dim).to(denoised_coords)
        else:
            seq_emb_in = self.token_embedding(seq_self_cond.exp())

        node_embs, encoder_embs = self.mpnn_net(
            denoised_coords, seq_emb_in, seq_mask, residue_index, noise_cond
        )

        logits = self.proj_out(node_embs)
        pred_logprobs = F.log_softmax(logits, -1)

        if return_embeddings:
            return pred_logprobs, node_embs, encoder_embs
        return pred_logprobs


class CoordinateDenoiser(nn.Module):
    """Wrapper for U-ViT module to denoise structure coordinates."""

    def __init__(self, config: argparse.Namespace):
        super().__init__()
        self.config = config

        # Configuration
        self.sigma_data = config.data.sigma_data
        m_cfg = config.model.struct_model
        nc = m_cfg.n_channel
        bb_atoms = ["N", "CA", "C", "O"]
        n_atoms = config.model.struct_model.n_atoms
        self.use_conv = len(m_cfg.uvit.n_filt_per_layer) > 0
        if self.use_conv and n_atoms == 37:
            n_atoms += 1  # make it an even number
        self.n_atoms = n_atoms
        self.bb_idxs = [residue_constants.atom_order[a] for a in bb_atoms]
        n_xyz = 9 if config.model.crop_conditional else 6
        nc_in = n_xyz * n_atoms  # xyz + selfcond xyz + maybe cropcond xyz

        # Neural networks
        n_noise_channel = nc * m_cfg.noise_cond_mult
        self.net = modules.TimeCondUViT(
            seq_len=config.data.fixed_size,
            patch_size=m_cfg.uvit.patch_size,
            dim=nc,
            depth=m_cfg.uvit.n_layers,
            n_filt_per_layer=m_cfg.uvit.n_filt_per_layer,
            heads=m_cfg.uvit.n_heads,
            dim_head=m_cfg.uvit.dim_head,
            conv_skip_connection=m_cfg.uvit.conv_skip_connection,
            n_atoms=n_atoms,
            channels_per_atom=n_xyz,
            time_cond_dim=n_noise_channel,
            position_embedding_type=m_cfg.uvit.position_embedding_type,
        )
        self.noise_block = modules.NoiseConditioningBlock(nc, n_noise_channel)

    def forward(
        self,
        noisy_coords: TensorType["b n a x", float],
        noise_level: TensorType["b", float],
        seq_mask: TensorType["b n", float],
        residue_index: Optional[TensorType["b n", int]] = None,
        struct_self_cond: Optional[TensorType["b n a x", float]] = None,
        struct_crop_cond: Optional[TensorType["b n a x", float]] = None,
    ):
        # Prep inputs and time conditioning
        actual_var_data = self.sigma_data**2
        var_noisy_coords = noise_level**2 + actual_var_data
        emb = noisy_coords / utils.expand(var_noisy_coords.sqrt(), noisy_coords)
        struct_noise_scaled = 0.25 * torch.log(noise_level)
        noise_cond = self.noise_block(struct_noise_scaled)

        # Prepare self- and crop-conditioning and concatenate along channels
        if struct_self_cond is None:
            struct_self_cond = torch.zeros_like(noisy_coords)
        if self.config.model.crop_conditional:
            if struct_crop_cond is None:
                struct_crop_cond = torch.zeros_like(noisy_coords)
            else:
                struct_crop_cond = struct_crop_cond / self.sigma_data
            emb = torch.cat([emb, struct_self_cond, struct_crop_cond], -1)
        else:
            emb = torch.cat([emb, struct_self_cond], -1)

        # Run neural network
        emb = self.net(emb, noise_cond, seq_mask=seq_mask, residue_index=residue_index)

        # Preconditioning from Karras et al.
        out_scale = noise_level * actual_var_data**0.5 / torch.sqrt(var_noisy_coords)
        skip_scale = actual_var_data / var_noisy_coords
        emb = emb * utils.expand(out_scale, emb)
        skip_info = noisy_coords * utils.expand(skip_scale, noisy_coords)
        denoised_coords_x0 = emb + skip_info

        # Don't use atom mask; denoise all atoms
        denoised_coords_x0 *= utils.expand(seq_mask, denoised_coords_x0)
        return denoised_coords_x0

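The scalings in `CoordinateDenoiser.forward` follow the EDM preconditioning of Karras et al. (2022): with noise level $\sigma$ and data scale $\sigma_{\mathrm{data}}$, the denoised prediction is

$$D_\theta(x;\sigma) = c_{\mathrm{skip}}(\sigma)\, x + c_{\mathrm{out}}(\sigma)\, F_\theta\!\left(c_{\mathrm{in}}(\sigma)\, x;\; c_{\mathrm{noise}}(\sigma)\right)$$

with

$$c_{\mathrm{in}} = \frac{1}{\sqrt{\sigma^2 + \sigma_{\mathrm{data}}^2}}, \quad c_{\mathrm{skip}} = \frac{\sigma_{\mathrm{data}}^2}{\sigma^2 + \sigma_{\mathrm{data}}^2}, \quad c_{\mathrm{out}} = \frac{\sigma\,\sigma_{\mathrm{data}}}{\sqrt{\sigma^2 + \sigma_{\mathrm{data}}^2}}, \quad c_{\mathrm{noise}} = \tfrac{1}{4}\ln\sigma,$$

which map onto `emb = noisy_coords / var_noisy_coords.sqrt()` ($c_{\mathrm{in}}$), `skip_scale` ($c_{\mathrm{skip}}$), `out_scale` ($c_{\mathrm{out}}$), and `struct_noise_scaled` ($c_{\mathrm{noise}}$) in the code above.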
164 |
+
class Protpardelle(nn.Module):
|
165 |
+
"""All-atom protein diffusion-based generative model.
|
166 |
+
|
167 |
+
This class wraps a structure denoising network and a sequence prediction network
|
168 |
+
to do structure/sequence co-design (for all-atom generation), or backbone generation.
|
169 |
+
|
170 |
+
It can be trained for one of four main tasks. To produce the all-atom (co-design)
|
171 |
+
Protpardelle model, we will typically pretrain an 'allatom' model, then use this
|
172 |
+
to train a 'seqdes' model. A 'seqdes' model can be trained with either a backbone
|
173 |
+
or allatom denoiser. The two can be combined to yield all-atom (co-design) Protpardelle
|
174 |
+
without further training.
|
175 |
+
'backbone': train only a backbone coords denoiser.
|
176 |
+
'seqdes': train only a mini-MPNN, using a pretrained coords denoiser.
|
177 |
+
'allatom': train only an allatom coords denoiser (cannot do all-atom generation
|
178 |
+
by itself).
|
179 |
+
'codesign': train both an allatom denoiser and mini-MPNN at once.
|
180 |
+
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, config: argparse.Namespace, device: str = "cpu"):
|
184 |
+
super().__init__()
|
185 |
+
self.config = config
|
186 |
+
self.device = device
|
187 |
+
self.task = config.model.task
|
188 |
+
self.n_tokens = config.data.n_aatype_tokens
|
189 |
+
|
190 |
+
self.use_mpnn_model = self.task in ["seqdes", "codesign"]
|
191 |
+
|
192 |
+
# Modules
|
193 |
+
self.all_modules = {}
|
194 |
+
self.bb_idxs = [0, 1, 2, 4]
|
195 |
+
self.n_atoms = 37
|
196 |
+
self.struct_model = CoordinateDenoiser(config)
|
197 |
+
self.all_modules["struct_model"] = self.struct_model
|
198 |
+
self.bb_idxs = self.struct_model.bb_idxs
|
199 |
+
self.n_atoms = self.struct_model.n_atoms
|
200 |
+
|
201 |
+
if self.use_mpnn_model:
|
202 |
+
self.mpnn_model = MiniMPNN(config)
|
203 |
+
self.all_modules["mpnn_model"] = self.mpnn_model
|
204 |
+
|
205 |
+
# Load any pretrained modules
|
206 |
+
for module_name in self.config.model.pretrained_modules:
|
207 |
+
self.load_pretrained_module(module_name)
|
208 |
+
|
209 |
+
# Diffusion-related
|
210 |
+
self.sigma_data = self.struct_model.sigma_data
|
211 |
+
self.training_noise_schedule = partial(
|
212 |
+
diffusion.noise_schedule,
|
213 |
+
sigma_data=self.sigma_data,
|
214 |
+
**vars(config.diffusion.training),
|
215 |
+
)
|
216 |
+
self.sampling_noise_schedule_default = self.make_sampling_noise_schedule()
|
217 |
+
|
218 |
+
def load_pretrained_module(self, module_name: str, ckpt_path: Optional[str] = None):
|
219 |
+
"""Load pretrained weights for a given module name."""
|
220 |
+
assert module_name in ["struct_model", "mpnn_model"], module_name
|
221 |
+
|
222 |
+
# Load pretrained checkpoint
|
223 |
+
if ckpt_path is None:
|
224 |
+
ckpt_path = getattr(self.config.model, f"{module_name}_checkpoint")
|
225 |
+
ckpt_path = os.path.join(self.config.train.home_dir, ckpt_path)
|
226 |
+
ckpt_dict = torch.load(ckpt_path, map_location=self.device)
|
227 |
+
model_state_dict = ckpt_dict["model_state_dict"]
|
228 |
+
|
229 |
+
# Get only submodule state_dict
|
230 |
+
submodule_state_dict = {
|
231 |
+
sk[len(module_name) + 1 :]: sv
|
232 |
+
for sk, sv in model_state_dict.items()
|
233 |
+
if sk.startswith(module_name)
|
234 |
+
}
|
235 |
+
|
236 |
+
# Load into module
|
237 |
+
module = dict(self.named_modules())[module_name]
|
238 |
+
module.load_state_dict(submodule_state_dict)
|
239 |
+
|
240 |
+
# Freeze unneeded modules
|
241 |
+
if module_name == "struct_model":
|
242 |
+
self.struct_model = module
|
243 |
+
if self.task == "seqdes":
|
244 |
+
for p in module.parameters():
|
245 |
+
p.requires_grad = False
|
246 |
+
if module_name == "mpnn_model":
|
247 |
+
self.mpnn_model = module
|
248 |
+
if self.task not in ["codesign", "seqdes"]:
|
249 |
+
for p in module.parameters():
|
250 |
+
p.requires_grad = False
|
251 |
+
|
252 |
+
return module
|
253 |
+
|
254 |
+
def load_minimpnn(self, mpnn_ckpt_path: Optional[str] = None):
|
255 |
+
"""Convert an allatom model to a codesign model."""
|
256 |
+
if mpnn_ckpt_path is None:
|
257 |
+
mpnn_ckpt_path = "checkpoints/minimpnn_state_dict.pth"
|
258 |
+
self.mpnn_model = MiniMPNN(self.config).to(self.device)
|
259 |
+
self.load_pretrained_module("mpnn_model", ckpt_path=mpnn_ckpt_path)
|
260 |
+
self.use_mpnn_model = True
|
261 |
+
return
|
262 |
+
|
263 |
+
def remove_minimpnn(self):
|
264 |
+
"""Revert a codesign model to an allatom model to a codesign model."""
|
265 |
+
self.use_mpnn_model = False
|
266 |
+
self.mpnn_model = None
|
267 |
+
self.all_modules["mpnn_model"] = None
|
268 |
+
|
269 |
+
def make_sampling_noise_schedule(self, **noise_kwargs):
|
270 |
+
"""Make the default sampling noise schedule function."""
|
271 |
+
noise_schedule_kwargs = vars(self.config.diffusion.sampling)
|
272 |
+
if len(noise_kwargs) > 0:
|
273 |
+
noise_schedule_kwargs.update(noise_kwargs)
|
274 |
+
return partial(diffusion.noise_schedule, **noise_schedule_kwargs)
|
275 |
+
|
276 |
+
def forward(
|
277 |
+
self,
|
278 |
+
*,
|
279 |
+
noisy_coords: TensorType["b n a x", float],
|
280 |
+
noise_level: TensorType["b", float],
|
281 |
+
seq_mask: TensorType["b n", float],
|
282 |
+
residue_index: TensorType["b n", int],
|
283 |
+
struct_self_cond: Optional[TensorType["b n a x", float]] = None,
|
284 |
+
struct_crop_cond: Optional[TensorType["b n a x", float]] = None,
|
285 |
+
seq_self_cond: Optional[TensorType["b n t", float]] = None, # logprobs
|
286 |
+
run_struct_model: bool = True,
|
287 |
+
run_mpnn_model: bool = True,
|
288 |
+
):
|
289 |
+
"""Main forward function for denoising/co-design.
|
290 |
+
|
291 |
+
Arguments:
|
292 |
+
noisy_coords: noisy array of xyz coordinates.
|
293 |
+
noise_level: std of noise for each example in the batch.
|
294 |
+
seq_mask: mask indicating which indexes contain data.
|
295 |
+
residue_index: residue ordering. This is used by proteinMPNN, but currently
|
296 |
+
only used by the diffusion model when the 'absolute_residx' or
|
297 |
+
'relative' position_embedding_type is specified.
|
298 |
+
struct_self_cond: denoised coordinates from the previous step, scaled
|
299 |
+
down by sigma data.
|
300 |
+
struct_crop_cond: unnoised coordinates. unscaled (scaled down by sigma
|
301 |
+
data inside the denoiser)
|
302 |
+
seq_self_cond: mpnn-predicted sequence logprobs from the previous step.
|
303 |
+
run_struct_model: flag to optionally not run structure denoiser.
|
304 |
+
run_mpnn_model: flag to optionally not run mini-mpnn.
|
305 |
+
"""
|
306 |
+
|
307 |
+
# Coordinate denoiser
|
308 |
+
denoised_x0 = noisy_coords
|
309 |
+
if run_struct_model:
|
310 |
+
denoised_x0 = self.struct_model(
|
311 |
+
noisy_coords,
|
312 |
+
noise_level,
|
313 |
+
seq_mask,
|
314 |
+
residue_index=residue_index,
|
315 |
+
struct_self_cond=struct_self_cond,
|
316 |
+
struct_crop_cond=struct_crop_cond,
|
317 |
+
)
|
318 |
+
|
319 |
+
# Mini-MPNN
|
320 |
+
aatype_logprobs = None
|
321 |
+
if self.use_mpnn_model and run_mpnn_model:
|
322 |
+
aatype_logprobs = self.mpnn_model(
|
323 |
+
denoised_x0.detach(),
|
324 |
+
noise_level,
|
325 |
+
seq_mask,
|
326 |
+
residue_index,
|
327 |
+
seq_self_cond=seq_self_cond,
|
328 |
+
return_embeddings=False,
|
329 |
+
)
|
330 |
+
aatype_logprobs = aatype_logprobs * seq_mask[..., None]
|
331 |
+
|
332 |
+
# Process outputs
|
333 |
+
if aatype_logprobs is None:
|
334 |
+
aatype_logprobs = repeat(seq_mask, "b n -> b n t", t=self.n_tokens)
|
335 |
+
aatype_logprobs = torch.ones_like(aatype_logprobs)
|
336 |
+
aatype_logprobs = F.log_softmax(aatype_logprobs, -1)
|
337 |
+
struct_self_cond_out = denoised_x0.detach() / self.sigma_data
|
338 |
+
seq_self_cond_out = aatype_logprobs.detach()
|
339 |
+
|
340 |
+
return denoised_x0, aatype_logprobs, struct_self_cond_out, seq_self_cond_out
|
341 |
+
|
342 |
+
def make_seq_mask_for_sampling(
|
343 |
+
self,
|
344 |
+
prot_lens: Optional[TensorType["b", int]] = None,
|
345 |
+
n_samples: int = 1,
|
346 |
+
min_len: int = 50,
|
347 |
+
max_len: Optional[int] = None,
|
348 |
+
):
|
349 |
+
"""Makes a sequence mask of varying protein lengths (only input required
|
350 |
+
to begin sampling).
|
351 |
+
"""
|
352 |
+
if max_len is None:
|
353 |
+
max_len = self.config.data.fixed_size
|
354 |
+
if prot_lens is None:
|
355 |
+
possible_lens = np.arange(min_len, max_len)
|
356 |
+
prot_lens = torch.Tensor(np.random.choice(possible_lens, n_samples))
|
357 |
+
else:
|
358 |
+
n_samples = len(prot_lens)
|
359 |
+
max_len = max(prot_lens)
|
360 |
+
mask = repeat(torch.arange(max_len), "n -> b n", b=n_samples)
|
361 |
+
mask = (mask < prot_lens[:, None]).float().to(self.device)
|
362 |
+
return mask
|
363 |
+
|
364 |
+
def sample(
|
365 |
+
self,
|
366 |
+
*,
|
367 |
+
seq_mask: TensorType["b n", float] = None,
|
368 |
+
n_samples: int = 1,
|
369 |
+
min_len: int = 50,
|
370 |
+
max_len: int = 512,
|
371 |
+
residue_index: TensorType["b n", int] = None,
|
372 |
+
gt_coords: TensorType["b n a x", float] = None,
|
373 |
+
gt_coords_traj: List[TensorType["b n a x", float]] = None,
|
374 |
+
gt_cond_atom_mask: TensorType["b n a", float] = None,
|
375 |
+
gt_aatype: TensorType["b n", int] = None,
|
376 |
+
gt_cond_seq_mask: TensorType["b n", float] = None,
|
377 |
+
apply_cond_proportion: float = 1.0,
|
378 |
+
n_steps: int = 200,
|
379 |
+
step_scale: float = 1.2,
|
380 |
+
s_churn: float = 50.0,
|
381 |
+
noise_scale: float = 1.0,
|
382 |
+
s_t_min: float = 0.01,
|
383 |
+
s_t_max: float = 50.0,
|
384 |
+
temperature: float = 1.0,
|
385 |
+
top_p: float = 1.0,
|
386 |
+
disallow_aas: List[int] = [4, 20], # cys, unk
|
387 |
+
sidechain_mode: bool = False,
|
388 |
+
skip_mpnn_proportion: float = 0.7,
|
389 |
+
anneal_seq_resampling_rate: Optional[str] = None, # linear, cosine
|
390 |
+
use_fullmpnn: bool = False,
|
391 |
+
use_fullmpnn_for_final: bool = True,
|
392 |
+
use_reconstruction_guidance: bool = False,
|
393 |
+
use_classifier_free_guidance: bool = False, # defaults to replacement guidance if these are all false
|
394 |
+
guidance_scale: float = 1.0,
|
395 |
+
noise_schedule: Optional[Callable] = None,
|
396 |
+
tqdm_pbar: Optional[Callable] = None,
|
397 |
+
return_last: bool = True,
|
398 |
+
return_aux: bool = False,
|
399 |
+
):
|
400 |
+
"""Sampling function for backbone or all-atom diffusion. All arguments are optional.
|
401 |
+
|
402 |
+
Arguments:
|
403 |
+
seq_mask: mask defining the number and lengths of proteins to be sampled.
|
404 |
+
n_samples: number of samples to draw (if seq_mask not provided).
|
405 |
+
min_len: minimum length of proteins to be sampled (if seq_mask not provided).
|
406 |
+
max_len: maximum length of proteins to be sampled (if seq_mask not provided).
|
407 |
+
residue_index: residue index of proteins to be sampled.
|
408 |
+
gt_coords: conditioning information for coords.
|
409 |
+
gt_coords_traj: conditioning information for coords specified for each timestep
|
410 |
+
(if gt_coords is not provided).
|
411 |
+
gt_cond_atom_mask: mask identifying atoms to apply gt_coords.
|
412 |
+
gt_aatype: conditioning information for sequence.
|
413 |
+
gt_cond_seq_mask: sequence positions to apply gt_aatype.
|
414 |
+
apply_cond_proportion: the proportion of timesteps to apply the conditioning.
|
415 |
+
e.g. if 0.5, then the first 50% of steps use conditioning, and the last 50%
|
416 |
+
are unconditional.
|
417 |
+
n_steps: number of denoising steps (ODE discretizations).
|
418 |
+
step_scale: scale to apply to the score.
|
419 |
+
s_churn: gamma = s_churn / n_steps describes the additional noise to add
|
420 |
+
relatively at each denoising step. Use 0.0 for deterministic sampling or
|
421 |
+
0.2 * n_steps as a rough default for stochastic sampling.
|
422 |
+
noise_scale: scale to apply to gamma.
|
423 |
+
s_t_min: don't apply s_churn below this noise level.
|
424 |
+
s_t_max: don't apply s_churn above this noise level.
|
425 |
+
temperature: scale to apply to aatype logits.
|
426 |
+
top_p: don't tokens which fall outside this proportion of the total probability.
|
427 |
+
disallow_aas: don't sample these token indices.
|
428 |
+
sidechain_mode: whether to do all-atom sampling (False for backbone-only).
|
429 |
+
skip_mpnn_proportion: proportion of timesteps from the start to skip running
|
430 |
+
mini-MPNN.
|
431 |
+
anneal_seq_resampling_rate: whether and how to decay the probability of
|
432 |
+
running mini-MPNN. None, 'linear', or 'cosine'
|
433 |
+
use_fullmpnn: use "full" ProteinMPNN at each step.
|
434 |
+
use_fullmpnn_for_final: use "full" ProteinMPNN at the final step.
|
435 |
+
use_reconstruction_guidance: use reconstruction guidance on the conditioning.
|
436 |
+
use_classifier_free_guidance: use classifier-free guidance on the conditioning.
|
437 |
+
guidance_scale: weight for reconstruction/classifier-free guidance.
|
438 |
+
noise_schedule: specify the noise level timesteps for sampling.
|
439 |
+
tqdm_pbar: progress bar in interactive contexts.
|
440 |
+
return_last: return only the sampled structure and sequence.
|
441 |
+
return_aux: return a dict of everything associated with the sampling run.
|
442 |
+
"""
|
443 |
+
|
444 |
+
def ode_step(sigma_in, sigma_next, xt_in, x0_pred, gamma, guidance_in=None):
|
445 |
+
if gamma > 0:
|
446 |
+
t_hat = sigma_in + gamma * sigma_in
|
447 |
+
sigma_delta = torch.sqrt(t_hat**2 - sigma_in**2)
|
448 |
+
noisier_x = xt_in + utils.expand(
|
449 |
+
sigma_delta, xt_in
|
450 |
+
) * noise_scale * torch.randn_like(xt_in).to(xt_in)
|
451 |
+
xt_in = noisier_x * utils.expand(seq_mask, noisier_x)
|
452 |
+
sigma_in = t_hat
|
453 |
+
|
454 |
+
mask = (sigma_in > 0).float()
|
455 |
+
score = (xt_in - x0_pred) / utils.expand(sigma_in.clamp(min=1e-6), xt_in)
|
456 |
+
score = score * utils.expand(mask, score)
|
457 |
+
if use_reconstruction_guidance:
|
458 |
+
guidance, guidance_mask = guidance_in
|
459 |
+
guidance = guidance * guidance_mask[..., None]
|
460 |
+
guidance_std = guidance[guidance_mask.bool()].var().sqrt()
|
461 |
+
score_std = score[guidance_mask.bool()].var().sqrt()
|
462 |
+
score = score + guidance * guidance_scale
|
463 |
+
if use_classifier_free_guidance:
|
464 |
+
# guidance_in is the unconditional x0 (x0_pred is the conditional x0)
|
465 |
+
# guidance_scale = 1 + w from Ho paper
|
466 |
+
# ==0: use only unconditional score; <1: interpolate the scores;
|
467 |
+
# ==1: use only conditional score; >1: skew towards conditional score
|
468 |
+
uncond_x0 = guidance_in
|
469 |
+
uncond_score = (xt_in - uncond_x0) / utils.expand(
|
470 |
+
sigma_in.clamp(min=1e-6), xt_in
|
471 |
+
)
|
472 |
+
uncond_score = uncond_score * utils.expand(mask, uncond_score)
|
473 |
+
score = guidance_scale * score + (1 - guidance_scale) * uncond_score
|
474 |
+
step = score * step_scale * utils.expand(sigma_next - sigma_in, score)
|
475 |
+
new_xt = xt_in + step
|
476 |
+
return new_xt
|
477 |
+
|
478 |
+
def sample_aatype(logprobs):
|
479 |
+
# Top-p truncation
|
480 |
+
probs = F.softmax(logprobs.clone(), dim=-1)
|
481 |
+
sorted_prob, sorted_idxs = torch.sort(probs, descending=True)
|
482 |
+
cumsum_prob = torch.cumsum(sorted_prob, dim=-1)
|
483 |
+
sorted_indices_to_remove = cumsum_prob > top_p
|
484 |
+
sorted_indices_to_remove[..., 0] = 0
|
485 |
+
sorted_prob[sorted_indices_to_remove] = 0
|
486 |
+
orig_probs = torch.scatter(
|
487 |
+
torch.zeros_like(sorted_prob),
|
488 |
+
dim=-1,
|
489 |
+
index=sorted_idxs,
|
490 |
+
src=sorted_prob,
|
491 |
+
)
|
492 |
+
|
493 |
+
# Apply temperature and disallowed AAs and sample
|
494 |
+
assert temperature >= 0.0
|
495 |
+
scaled_logits = orig_probs.clamp(min=1e-9).log() / (temperature + 1e-4)
|
496 |
+
if disallow_aas:
|
497 |
+
unwanted_mask = torch.zeros(scaled_logits.shape[-1]).to(scaled_logits)
|
498 |
+
unwanted_mask[disallow_aas] = 1
|
499 |
+
scaled_logits -= unwanted_mask * 1e10
|
500 |
+
orig_probs = F.softmax(scaled_logits, dim=-1)
|
501 |
+
categorical = torch.distributions.Categorical(probs=orig_probs)
|
502 |
+
samp_aatype = categorical.sample()
|
503 |
+
return samp_aatype
|
504 |
+
|
505 |
+
def design_with_fullmpnn(batched_coords, seq_mask):
|
506 |
+
seq_lens = seq_mask.sum(-1).long()
|
507 |
+
designed_seqs = [
|
508 |
+
evaluation.design_sequence(c[: seq_lens[i]], model=fullmpnn_model)[0]
|
509 |
+
for i, c in enumerate(batched_coords)
|
510 |
+
]
|
511 |
+
designed_aatypes, _ = utils.batched_seq_to_aatype_and_mask(
|
512 |
+
designed_seqs, max_len=seq_mask.shape[-1]
|
513 |
+
)
|
514 |
+
return designed_aatypes
|
515 |
+
|
516 |
+
# Initialize masks/features
|
517 |
+
if seq_mask is None: # Sample random lengths
|
518 |
+
assert gt_aatype is None # Don't condition on aatype without seq_mask
|
519 |
+
seq_mask = self.make_seq_mask_for_sampling(
|
520 |
+
n_samples=n_samples,
|
521 |
+
min_len=min_len,
|
522 |
+
max_len=max_len,
|
523 |
+
)
|
524 |
+
if residue_index is None:
|
525 |
+
residue_index = torch.arange(seq_mask.shape[-1])
|
526 |
+
residue_index = repeat(residue_index, "n -> b n", b=seq_mask.shape[0])
|
527 |
+
residue_index = residue_index.to(seq_mask) * seq_mask
|
528 |
+
if use_fullmpnn or use_fullmpnn_for_final:
|
529 |
+
fullmpnn_model = protein_mpnn.get_mpnn_model(
|
530 |
+
path_to_model_weights=self.config.train.home_dir
|
531 |
+
+ "/ProteinMPNN/vanilla_model_weights",
|
532 |
+
device=self.device,
|
533 |
+
)
|
534 |
+
|
535 |
+
# Initialize noise schedule/parameters
|
536 |
+
to_batch_size = lambda x: x * torch.ones(seq_mask.shape[0]).to(self.device)
|
537 |
+
s_t_min = s_t_min * self.sigma_data
|
538 |
+
s_t_max = s_t_max * self.sigma_data
|
539 |
+
if noise_schedule is None:
|
540 |
+
noise_schedule = self.sampling_noise_schedule_default
|
541 |
+
sigma = noise_schedule(1)
|
542 |
+
timesteps = torch.linspace(1, 0, n_steps + 1)
|
543 |
+
|
544 |
+
# Set up conditioning/guidance information
|
545 |
+
crop_cond_coords = None
|
546 |
+
if gt_coords is None:
|
547 |
+
coords_shape = seq_mask.shape + (self.n_atoms, 3)
|
548 |
+
xt = torch.randn(*coords_shape).to(self.device) * sigma
|
549 |
+
xt *= utils.expand(seq_mask, xt)
|
550 |
+
else:
|
551 |
+
assert gt_coords_traj is None
|
552 |
+
noise_levels = [to_batch_size(noise_schedule(t)) for t in timesteps]
|
553 |
+
gt_coords_traj = [
|
554 |
+
diffusion.noise_coords(gt_coords, nl) for nl in noise_levels
|
555 |
+
]
|
556 |
+
xt = gt_coords_traj[0]
|
557 |
+
if gt_cond_atom_mask is not None:
|
558 |
+
crop_cond_coords = gt_coords * gt_cond_atom_mask[..., None]
|
559 |
+
gt_atom_mask = None
|
560 |
+
if gt_aatype is not None:
|
561 |
+
gt_atom_mask = utils.atom37_mask_from_aatype(gt_aatype, seq_mask)
|
562 |
+
fake_logits = repeat(seq_mask, "b n -> b n t", t=self.n_tokens)
|
563 |
+
s_hat = (sample_aatype(fake_logits) * seq_mask).long()
|
564 |
+
|
565 |
+
# Initialize superposition for all-atom sampling
|
566 |
+
if sidechain_mode:
|
567 |
+
b, n = seq_mask.shape[:2]
|
568 |
+
|
569 |
+
# Latest predicted x0 for sidechain superpositions
|
570 |
+
atom73_state_0 = torch.zeros(b, n, 73, 3).to(xt)
|
571 |
+
|
572 |
+
# Current state xt for sidechain superpositions (denoised to different levels)
|
573 |
+
atom73_state_t = torch.randn(b, n, 73, 3).to(xt) * sigma
|
574 |
+
|
575 |
+
# Noise level of xt
|
576 |
+
sigma73_last = torch.ones(b, n, 73).to(xt) * sigma
|
577 |
+
|
578 |
+
# Seqhat and mask used to choose sidechains for euler step (b, n)
|
579 |
+
s_hat = (seq_mask * 7).long()
|
580 |
+
mask37 = utils.atom37_mask_from_aatype(s_hat, seq_mask).bool()
|
581 |
+
mask73 = utils.atom73_mask_from_aatype(s_hat, seq_mask).bool()
|
582 |
+
begin_mpnn_step = int(n_steps * skip_mpnn_proportion)
|
583 |
+
|
584 |
+
# Prepare to run sampling trajectory
|
585 |
+
sigma = to_batch_size(sigma)
|
586 |
+
x0 = None
|
587 |
+
x0_prev = None
|
588 |
+
x_self_cond = None
|
589 |
+
s_logprobs = None
|
590 |
+
s_self_cond = None
|
591 |
+
if tqdm_pbar is None:
|
592 |
+
tqdm_pbar = lambda x: x
|
593 |
+
torch.set_grad_enabled(False)
|
594 |
+
|
595 |
+
# *t_traj is the denoising trajectory; *0_traj is the evolution of predicted clean data
|
596 |
+
# s0 are aatype probs of shape (b n t); s_hat are discrete aatype of shape (b n)
|
597 |
+
xt_traj, x0_traj, st_traj, s0_traj = [], [], [], []
|
598 |
+
|
599 |
+
# Sampling trajectory
|
600 |
+
for i, t in tqdm_pbar(enumerate(iter(timesteps[1:]))):
|
601 |
+
# Set up noise levels
|
602 |
+
sigma_next = noise_schedule(t)
|
603 |
+
if i == n_steps - 1:
|
604 |
+
sigma_next *= 0
|
605 |
+
gamma = (
|
606 |
+
s_churn / n_steps
|
607 |
+
if (sigma_next >= s_t_min and sigma_next <= s_t_max)
|
608 |
+
else 0.0
|
609 |
+
)
|
610 |
+
sigma_next = to_batch_size(sigma_next)
|
611 |
+
|
612 |
+
if sidechain_mode:
|
613 |
+
# Fill in noise for masked positions since xt is initialized to zeros at each step
|
614 |
+
dummy_fill_noise = torch.randn_like(xt) * utils.expand(sigma, xt)
|
615 |
+
zero_atom_mask = utils.atom37_mask_from_aatype(s_hat, seq_mask)
|
616 |
+
dummy_fill_mask = 1 - zero_atom_mask[..., None]
|
617 |
+
xt = xt * zero_atom_mask[..., None] + dummy_fill_noise * dummy_fill_mask
|
618 |
+
else: # backbone only
|
619 |
+
bb_seq = (seq_mask * residue_constants.restype_order["G"]).long()
|
620 |
+
bb_atom_mask = utils.atom37_mask_from_aatype(bb_seq, seq_mask)
|
621 |
+
xt *= bb_atom_mask[..., None]
|
622 |
+
|
623 |
+
# Enable grad for reconstruction guidance
|
624 |
+
if use_reconstruction_guidance:
|
625 |
+
torch.set_grad_enabled(True)
|
626 |
+
xt.requires_grad = True
|
627 |
+
|
628 |
+
# Run denoising network
|
629 |
+
run_mpnn = not sidechain_mode or i > begin_mpnn_step
|
630 |
+
x0, s_logprobs, x_self_cond, s_self_cond = self.forward(
|
631 |
+
noisy_coords=xt,
|
632 |
+
noise_level=sigma,
|
633 |
+
seq_mask=seq_mask,
|
634 |
+
residue_index=residue_index,
|
635 |
+
struct_self_cond=x_self_cond,
|
636 |
+
struct_crop_cond=crop_cond_coords,
|
637 |
+
seq_self_cond=s_self_cond,
|
638 |
+
run_mpnn_model=run_mpnn,
|
639 |
+
)
|
640 |
+
|
641 |
+
# Compute additional stuff for guidance
|
642 |
+
if use_reconstruction_guidance:
|
                    loss = (x0 - gt_coords).pow(2).sum(-1)
                    loss = loss * gt_cond_atom_mask
                    loss = loss.sum() / gt_cond_atom_mask.sum().clamp(min=1)
                    xt.retain_grad()
                    loss.backward()
                    guidance = xt.grad.clone()
                    xt.grad *= 0
                    torch.set_grad_enabled(False)
                if use_classifier_free_guidance:
                    assert not use_reconstruction_guidance
                    uncond_x0, _, _, _ = self.forward(
                        noisy_coords=xt,
                        noise_level=sigma,
                        seq_mask=seq_mask,
                        residue_index=residue_index,
                        struct_self_cond=x_self_cond,
                        seq_self_cond=s_self_cond,
                        run_mpnn_model=run_mpnn,
                    )

                # Structure denoising step
                if not sidechain_mode:  # backbone
                    if sigma[0] > 0:
                        xt = ode_step(sigma, sigma_next, xt, x0, gamma)
                    else:
                        xt = x0
                else:  # allatom
                    # Write x0 into atom73_state_0 for atoms corresponding to old seqhat
                    atom73_state_0[mask73] = x0[mask37]

                    # Determine sequence resampling probability
                    if anneal_seq_resampling_rate is not None:
                        step_time = 1 - (i - begin_mpnn_step) / max(
                            1, n_steps - begin_mpnn_step
                        )
                        if anneal_seq_resampling_rate == "linear":
                            resampling_rate = step_time
                        elif anneal_seq_resampling_rate == "cosine":
                            k = 2
                            resampling_rate = (
                                1 + np.cos(2 * np.pi * (step_time - 0.5))
                            ) / k
                        resample_this_step = np.random.uniform() < resampling_rate

                    # Resample sequence or design with full ProteinMPNN
                    if i == n_steps - 1 and use_fullmpnn_for_final:
                        s_hat = design_with_fullmpnn(x0, seq_mask).to(x0.device)
                    elif anneal_seq_resampling_rate is None or resample_this_step:
                        if run_mpnn and use_fullmpnn:
                            s_hat = design_with_fullmpnn(x0, seq_mask).to(x0.device)
                        else:
                            s_hat = sample_aatype(s_logprobs)

                    # Overwrite s_hat with any conditioning information
                    if (i + 1) / n_steps <= apply_cond_proportion:
                        if gt_cond_seq_mask is not None and gt_aatype is not None:
                            s_hat = (
                                1 - gt_cond_seq_mask
                            ) * s_hat + gt_cond_seq_mask * gt_aatype
                            s_hat = s_hat.long()

                    # Set masks for collapsing superposition using new sequence
                    mask37 = utils.atom37_mask_from_aatype(s_hat, seq_mask).bool()
                    mask73 = utils.atom73_mask_from_aatype(s_hat, seq_mask).bool()

                    # Determine prev noise levels for atoms corresponding to new sequence
                    step_sigma_prev = (
                        torch.ones(*xt.shape[:-1]).to(xt) * sigma[..., None, None]
                    )
                    step_sigma_prev[mask37] = sigma73_last[mask73]  # b, n, 37
                    step_sigma_next = sigma_next[..., None, None]  # b, 1, 1

                    # Denoising step on atoms corresponding to new sequence
                    b, n = mask37.shape[:2]
                    step_xt = torch.zeros(b, n, 37, 3).to(xt)
                    step_x0 = torch.zeros(b, n, 37, 3).to(xt)
                    step_xt[mask37] = atom73_state_t[mask73]
                    step_x0[mask37] = atom73_state_0[mask73]

                    guidance_in = None
                    if (i + 1) / n_steps <= apply_cond_proportion:
                        if use_reconstruction_guidance:
                            guidance_in = (guidance, mask37.float())
                        elif use_classifier_free_guidance:
                            guidance_in = uncond_x0

                    step_xt = ode_step(
                        step_sigma_prev,
                        step_sigma_next,
                        step_xt,
                        step_x0,
                        gamma,
                        guidance_in=guidance_in,
                    )
                    xt = step_xt

                    # Write new xt into atom73_state_t for atoms corresponding to new seqhat and update sigma_last
                    atom73_state_t[mask73] = step_xt[mask37]
                    sigma73_last[mask73] = step_sigma_next[0].item()

                # Replacement guidance if conditioning information provided
                if (i + 1) / n_steps <= apply_cond_proportion:
                    if gt_coords_traj is not None:
                        if gt_cond_atom_mask is None:
                            xt = gt_coords_traj[i + 1]
                        else:
                            xt = (1 - gt_cond_atom_mask)[
                                ..., None
                            ] * xt + gt_cond_atom_mask[..., None] * gt_coords_traj[i + 1]

                sigma = sigma_next

                # Logging
                xt_scale = self.sigma_data / utils.expand(
                    torch.sqrt(sigma_next**2 + self.sigma_data**2), xt
                )
                scaled_xt = xt * xt_scale
                xt_traj.append(scaled_xt.cpu())
                x0_traj.append(x0.cpu())
                st_traj.append(s_hat.cpu())
                s0_traj.append(s_logprobs.cpu())

        if return_last:
            return xt, s_hat, seq_mask
        elif return_aux:
            return {
                "x": xt,
                "s": s_hat,
                "seq_mask": seq_mask,
                "xt_traj": xt_traj,
                "x0_traj": x0_traj,
                "st_traj": st_traj,
                "s0_traj": s0_traj,
            }
        else:
            return xt_traj, x0_traj, st_traj, s0_traj, seq_mask
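The denoising loop above delegates the actual coordinate update to ode_step, which is defined earlier in models.py and not shown in this hunk. For orientation, here is a minimal, self-contained sketch of an EDM-style churn-plus-Euler update of the kind the loop performs; the function name, the gamma handling, and the tensor-shaped sigmas are illustrative assumptions, not the repository's exact implementation.

import torch

def ode_step_sketch(sigma, sigma_next, xt, x0, gamma=1.0):
    """Illustrative EDM-style denoising update (assumed, not the repo's ode_step).

    sigma, sigma_next: current and next noise levels, tensors broadcastable to xt;
    xt: current noisy coordinates; x0: the model's denoised prediction.
    """
    # "Churn": temporarily raise the noise level by gamma >= 1, adding fresh
    # noise to match the higher sigma before stepping back down.
    sigma_hat = sigma * gamma
    if gamma > 1.0:
        xt = xt + torch.randn_like(xt) * (sigma_hat**2 - sigma**2).sqrt()
    # Euler step along the probability-flow ODE from sigma_hat to sigma_next,
    # moving along the direction that points from the denoised x0 toward xt.
    d = (xt - x0) / sigma_hat
    return xt + (sigma_next - sigma_hat) * d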
modules.py
ADDED
@@ -0,0 +1,696 @@
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu

Neural network modules. Many of these are adapted from open source modules.
"""
from typing import List, Sequence, Optional

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
import numpy as np
from rotary_embedding_torch import RotaryEmbedding
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, EsmModel

from core import protein_mpnn
from core import residue_constants
from core import utils


########################################
# Adapted from https://github.com/ermongroup/ddim


def downsample(x):
    return nn.functional.avg_pool2d(x, 2, 2, ceil_mode=True)


def upsample_coords(x, shape):
    new_l, new_w = shape
    return nn.functional.interpolate(x, size=(new_l, new_w), mode="nearest")


########################################
# Adapted from https://github.com/aqlaboratory/openfold


def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.contiguous().permute(first_inds + [zero_index + i for i in inds])


def lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    n = all_atom_mask.shape[-2]
    dmat_true = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])
            ** 2,
            dim=-1,
        )
    )

    dmat_pred = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2,
            dim=-1,
        )
    )
    dists_to_score = (
        (dmat_true < cutoff)
        * all_atom_mask
        * permute_final_dims(all_atom_mask, (1, 0))
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )

    dist_l1 = torch.abs(dmat_true - dmat_pred)

    score = (
        (dist_l1 < 0.5).type(dist_l1.dtype)
        + (dist_l1 < 1.0).type(dist_l1.dtype)
        + (dist_l1 < 2.0).type(dist_l1.dtype)
        + (dist_l1 < 4.0).type(dist_l1.dtype)
    )
    score = score * 0.25

    dims = (-1,) if per_residue else (-2, -1)
    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
    score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))

    return score

+
class RelativePositionalEncoding(nn.Module):
|
97 |
+
def __init__(self, attn_dim=8, max_rel_idx=32):
|
98 |
+
super().__init__()
|
99 |
+
self.max_rel_idx = max_rel_idx
|
100 |
+
self.n_rel_pos = 2 * self.max_rel_idx + 1
|
101 |
+
self.linear = nn.Linear(self.n_rel_pos, attn_dim)
|
102 |
+
|
103 |
+
def forward(self, residue_index):
|
104 |
+
d_ij = residue_index[..., None] - residue_index[..., None, :]
|
105 |
+
v_bins = torch.arange(self.n_rel_pos).to(d_ij.device) - self.max_rel_idx
|
106 |
+
idxs = (d_ij[..., None] - v_bins[None, None]).abs().argmin(-1)
|
107 |
+
p_ij = nn.functional.one_hot(idxs, num_classes=self.n_rel_pos)
|
108 |
+
embeddings = self.linear(p_ij.float())
|
109 |
+
return embeddings
|
110 |
+
|
111 |
+
|
112 |
+
########################################
|
113 |
+
# Adapted from https://github.com/NVlabs/edm
|
114 |
+
|
115 |
+
|
116 |
+
class Noise_Embedding(nn.Module):
|
117 |
+
def __init__(self, num_channels, max_positions=10000, endpoint=False):
|
118 |
+
super().__init__()
|
119 |
+
self.num_channels = num_channels
|
120 |
+
self.max_positions = max_positions
|
121 |
+
self.endpoint = endpoint
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
freqs = torch.arange(
|
125 |
+
start=0, end=self.num_channels // 2, dtype=torch.float32, device=x.device
|
126 |
+
)
|
127 |
+
freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
|
128 |
+
freqs = (1 / self.max_positions) ** freqs
|
129 |
+
x = x.outer(freqs.to(x.dtype))
|
130 |
+
x = torch.cat([x.cos(), x.sin()], dim=1)
|
131 |
+
return x
|
132 |
+
|
133 |
+
|
134 |
+
########################################
|
135 |
+
# Adapted from github.com/lucidrains
|
136 |
+
# https://github.com/lucidrains/denoising-diffusion-pytorch
|
137 |
+
# https://github.com/lucidrains/recurrent-interface-network-pytorch
|
138 |
+
|
139 |
+
|
140 |
+
def exists(x):
|
141 |
+
return x is not None
|
142 |
+
|
143 |
+
|
144 |
+
def default(val, d):
|
145 |
+
if exists(val):
|
146 |
+
return val
|
147 |
+
return d() if callable(d) else d
|
148 |
+
|
149 |
+
|
150 |
+
def posemb_sincos_1d(patches, temperature=10000, residue_index=None):
|
151 |
+
_, n, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
152 |
+
|
153 |
+
n = torch.arange(n, device=device) if residue_index is None else residue_index
|
154 |
+
assert (dim % 2) == 0, "feature dimension must be multiple of 2 for sincos emb"
|
155 |
+
omega = torch.arange(dim // 2, device=device) / (dim // 2 - 1)
|
156 |
+
omega = 1.0 / (temperature**omega)
|
157 |
+
|
158 |
+
n = n[..., None] * omega
|
159 |
+
pe = torch.cat((n.sin(), n.cos()), dim=-1)
|
160 |
+
return pe.type(dtype)
|
161 |
+
|
162 |
+
|
163 |
+
class LayerNorm(nn.Module):
|
164 |
+
def __init__(self, dim):
|
165 |
+
super().__init__()
|
166 |
+
self.gamma = nn.Parameter(torch.ones(dim))
|
167 |
+
self.register_buffer("beta", torch.zeros(dim))
|
168 |
+
|
169 |
+
def forward(self, x):
|
170 |
+
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
|
171 |
+
|
172 |
+
|
173 |
+
class NoiseConditioningBlock(nn.Module):
|
174 |
+
def __init__(self, n_in_channel, n_out_channel):
|
175 |
+
super().__init__()
|
176 |
+
self.block = nn.Sequential(
|
177 |
+
Noise_Embedding(n_in_channel),
|
178 |
+
nn.Linear(n_in_channel, n_out_channel),
|
179 |
+
nn.SiLU(),
|
180 |
+
nn.Linear(n_out_channel, n_out_channel),
|
181 |
+
Rearrange("b d -> b 1 d"),
|
182 |
+
)
|
183 |
+
|
184 |
+
def forward(self, noise_level):
|
185 |
+
return self.block(noise_level)
|
186 |
+
|
187 |
+
|
188 |
+
class TimeCondResnetBlock(nn.Module):
|
189 |
+
def __init__(
|
190 |
+
self, nic, noc, cond_nc, conv_layer=nn.Conv2d, dropout=0.1, n_norm_in_groups=4
|
191 |
+
):
|
192 |
+
super().__init__()
|
193 |
+
self.block1 = nn.Sequential(
|
194 |
+
nn.GroupNorm(num_groups=nic // n_norm_in_groups, num_channels=nic),
|
195 |
+
nn.SiLU(),
|
196 |
+
conv_layer(nic, noc, 3, 1, 1),
|
197 |
+
)
|
198 |
+
self.cond_proj = nn.Linear(cond_nc, noc * 2)
|
199 |
+
self.mid_norm = nn.GroupNorm(num_groups=noc // 4, num_channels=noc)
|
200 |
+
self.dropout = dropout if dropout is None else nn.Dropout(dropout)
|
201 |
+
self.block2 = nn.Sequential(
|
202 |
+
nn.GroupNorm(num_groups=noc // 4, num_channels=noc),
|
203 |
+
nn.SiLU(),
|
204 |
+
conv_layer(noc, noc, 3, 1, 1),
|
205 |
+
)
|
206 |
+
self.mismatch = False
|
207 |
+
if nic != noc:
|
208 |
+
self.mismatch = True
|
209 |
+
self.conv_match = conv_layer(nic, noc, 1, 1, 0)
|
210 |
+
|
211 |
+
def forward(self, x, time=None):
|
212 |
+
h = self.block1(x)
|
213 |
+
|
214 |
+
if time is not None:
|
215 |
+
h = self.mid_norm(h)
|
216 |
+
scale, shift = self.cond_proj(time).chunk(2, dim=-1)
|
217 |
+
h = (h * (utils.expand(scale, h) + 1)) + utils.expand(shift, h)
|
218 |
+
|
219 |
+
if self.dropout is not None:
|
220 |
+
h = self.dropout(h)
|
221 |
+
|
222 |
+
h = self.block2(h)
|
223 |
+
|
224 |
+
if self.mismatch:
|
225 |
+
x = self.conv_match(x)
|
226 |
+
|
227 |
+
return x + h
|
228 |
+
|
229 |
+
|
230 |
+
class TimeCondAttention(nn.Module):
|
231 |
+
def __init__(
|
232 |
+
self,
|
233 |
+
dim,
|
234 |
+
dim_context=None,
|
235 |
+
heads=4,
|
236 |
+
dim_head=32,
|
237 |
+
norm=False,
|
238 |
+
norm_context=False,
|
239 |
+
time_cond_dim=None,
|
240 |
+
attn_bias_dim=None,
|
241 |
+
rotary_embedding_module=None,
|
242 |
+
):
|
243 |
+
super().__init__()
|
244 |
+
hidden_dim = dim_head * heads
|
245 |
+
dim_context = default(dim_context, dim)
|
246 |
+
|
247 |
+
self.time_cond = None
|
248 |
+
|
249 |
+
if exists(time_cond_dim):
|
250 |
+
self.time_cond = nn.Sequential(nn.SiLU(), nn.Linear(time_cond_dim, dim * 2))
|
251 |
+
|
252 |
+
nn.init.zeros_(self.time_cond[-1].weight)
|
253 |
+
nn.init.zeros_(self.time_cond[-1].bias)
|
254 |
+
|
255 |
+
self.scale = dim_head**-0.5
|
256 |
+
self.heads = heads
|
257 |
+
|
258 |
+
self.norm = LayerNorm(dim) if norm else nn.Identity()
|
259 |
+
self.norm_context = LayerNorm(dim_context) if norm_context else nn.Identity()
|
260 |
+
|
261 |
+
self.attn_bias_proj = None
|
262 |
+
if attn_bias_dim is not None:
|
263 |
+
self.attn_bias_proj = nn.Sequential(
|
264 |
+
Rearrange("b a i j -> b i j a"),
|
265 |
+
nn.Linear(attn_bias_dim, heads),
|
266 |
+
Rearrange("b i j a -> b a i j"),
|
267 |
+
)
|
268 |
+
|
269 |
+
self.to_q = nn.Linear(dim, hidden_dim, bias=False)
|
270 |
+
self.to_kv = nn.Linear(dim_context, hidden_dim * 2, bias=False)
|
271 |
+
self.to_out = nn.Linear(hidden_dim, dim, bias=False)
|
272 |
+
nn.init.zeros_(self.to_out.weight)
|
273 |
+
|
274 |
+
self.use_rope = False
|
275 |
+
if rotary_embedding_module is not None:
|
276 |
+
self.use_rope = True
|
277 |
+
self.rope = rotary_embedding_module
|
278 |
+
|
279 |
+
def forward(self, x, context=None, time=None, attn_bias=None, seq_mask=None):
|
280 |
+
# attn_bias is b, c, i, j
|
281 |
+
h = self.heads
|
282 |
+
has_context = exists(context)
|
283 |
+
|
284 |
+
context = default(context, x)
|
285 |
+
|
286 |
+
if x.shape[-1] != self.norm.gamma.shape[-1]:
|
287 |
+
print(context.shape, x.shape, self.norm.gamma.shape)
|
288 |
+
|
289 |
+
x = self.norm(x)
|
290 |
+
|
291 |
+
if exists(time):
|
292 |
+
scale, shift = self.time_cond(time).chunk(2, dim=-1)
|
293 |
+
x = (x * (scale + 1)) + shift
|
294 |
+
|
295 |
+
if has_context:
|
296 |
+
context = self.norm_context(context)
|
297 |
+
|
298 |
+
if seq_mask is not None:
|
299 |
+
x = x * seq_mask[..., None]
|
300 |
+
|
301 |
+
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
|
302 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
|
303 |
+
|
304 |
+
q = q * self.scale
|
305 |
+
|
306 |
+
if self.use_rope:
|
307 |
+
q = self.rope.rotate_queries_or_keys(q)
|
308 |
+
k = self.rope.rotate_queries_or_keys(k)
|
309 |
+
|
310 |
+
sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)
|
311 |
+
if attn_bias is not None:
|
312 |
+
if self.attn_bias_proj is not None:
|
313 |
+
attn_bias = self.attn_bias_proj(attn_bias)
|
314 |
+
sim += attn_bias
|
315 |
+
if seq_mask is not None:
|
316 |
+
attn_mask = torch.einsum("b i, b j -> b i j", seq_mask, seq_mask)[:, None]
|
317 |
+
sim -= (1 - attn_mask) * 1e6
|
318 |
+
attn = sim.softmax(dim=-1)
|
319 |
+
|
320 |
+
out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
|
321 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
322 |
+
out = self.to_out(out)
|
323 |
+
if seq_mask is not None:
|
324 |
+
out = out * seq_mask[..., None]
|
325 |
+
return out
|
326 |
+
|
327 |
+
|
328 |
+
class TimeCondFeedForward(nn.Module):
|
329 |
+
def __init__(self, dim, mult=4, dim_out=None, time_cond_dim=None, dropout=0.1):
|
330 |
+
super().__init__()
|
331 |
+
if dim_out is None:
|
332 |
+
dim_out = dim
|
333 |
+
self.norm = LayerNorm(dim)
|
334 |
+
|
335 |
+
self.time_cond = None
|
336 |
+
self.dropout = None
|
337 |
+
inner_dim = int(dim * mult)
|
338 |
+
|
339 |
+
if exists(time_cond_dim):
|
340 |
+
self.time_cond = nn.Sequential(
|
341 |
+
nn.SiLU(),
|
342 |
+
nn.Linear(time_cond_dim, inner_dim * 2),
|
343 |
+
)
|
344 |
+
|
345 |
+
nn.init.zeros_(self.time_cond[-1].weight)
|
346 |
+
nn.init.zeros_(self.time_cond[-1].bias)
|
347 |
+
|
348 |
+
self.linear_in = nn.Linear(dim, inner_dim)
|
349 |
+
self.nonlinearity = nn.SiLU()
|
350 |
+
if dropout is not None:
|
351 |
+
self.dropout = nn.Dropout(dropout)
|
352 |
+
self.linear_out = nn.Linear(inner_dim, dim_out)
|
353 |
+
nn.init.zeros_(self.linear_out.weight)
|
354 |
+
nn.init.zeros_(self.linear_out.bias)
|
355 |
+
|
356 |
+
def forward(self, x, time=None):
|
357 |
+
x = self.norm(x)
|
358 |
+
x = self.linear_in(x)
|
359 |
+
x = self.nonlinearity(x)
|
360 |
+
|
361 |
+
if exists(time):
|
362 |
+
scale, shift = self.time_cond(time).chunk(2, dim=-1)
|
363 |
+
x = (x * (scale + 1)) + shift
|
364 |
+
|
365 |
+
if exists(self.dropout):
|
366 |
+
x = self.dropout(x)
|
367 |
+
|
368 |
+
return self.linear_out(x)
|
369 |
+
|
370 |
+
|
371 |
+
class TimeCondTransformer(nn.Module):
|
372 |
+
def __init__(
|
373 |
+
self,
|
374 |
+
dim,
|
375 |
+
depth,
|
376 |
+
heads,
|
377 |
+
dim_head,
|
378 |
+
time_cond_dim,
|
379 |
+
attn_bias_dim=None,
|
380 |
+
mlp_inner_dim_mult=4,
|
381 |
+
position_embedding_type: str = "rotary",
|
382 |
+
):
|
383 |
+
super().__init__()
|
384 |
+
|
385 |
+
self.rope = None
|
386 |
+
self.pos_emb_type = position_embedding_type
|
387 |
+
if position_embedding_type == "rotary":
|
388 |
+
self.rope = RotaryEmbedding(dim=32)
|
389 |
+
elif position_embedding_type == "relative":
|
390 |
+
self.relpos = nn.Sequential(
|
391 |
+
RelativePositionalEncoding(attn_dim=heads),
|
392 |
+
Rearrange("b i j d -> b d i j"),
|
393 |
+
)
|
394 |
+
|
395 |
+
self.layers = nn.ModuleList([])
|
396 |
+
for _ in range(depth):
|
397 |
+
self.layers.append(
|
398 |
+
nn.ModuleList(
|
399 |
+
[
|
400 |
+
TimeCondAttention(
|
401 |
+
dim,
|
402 |
+
heads=heads,
|
403 |
+
dim_head=dim_head,
|
404 |
+
norm=True,
|
405 |
+
time_cond_dim=time_cond_dim,
|
406 |
+
attn_bias_dim=attn_bias_dim,
|
407 |
+
rotary_embedding_module=self.rope,
|
408 |
+
),
|
409 |
+
TimeCondFeedForward(
|
410 |
+
dim, mlp_inner_dim_mult, time_cond_dim=time_cond_dim
|
411 |
+
),
|
412 |
+
]
|
413 |
+
)
|
414 |
+
)
|
415 |
+
|
416 |
+
def forward(
|
417 |
+
self,
|
418 |
+
x,
|
419 |
+
time=None,
|
420 |
+
attn_bias=None,
|
421 |
+
context=None,
|
422 |
+
seq_mask=None,
|
423 |
+
residue_index=None,
|
424 |
+
):
|
425 |
+
if self.pos_emb_type == "absolute":
|
426 |
+
pos_emb = posemb_sincos_1d(x)
|
427 |
+
x = x + pos_emb
|
428 |
+
elif self.pos_emb_type == "absolute_residx":
|
429 |
+
assert residue_index is not None
|
430 |
+
pos_emb = posemb_sincos_1d(x, residue_index=residue_index)
|
431 |
+
x = x + pos_emb
|
432 |
+
elif self.pos_emb_type == "relative":
|
433 |
+
assert residue_index is not None
|
434 |
+
pos_emb = self.relpos(residue_index)
|
435 |
+
attn_bias = pos_emb if attn_bias is None else attn_bias + pos_emb
|
436 |
+
if seq_mask is not None:
|
437 |
+
x = x * seq_mask[..., None]
|
438 |
+
|
439 |
+
for i, (attn, ff) in enumerate(self.layers):
|
440 |
+
x = x + attn(
|
441 |
+
x, context=context, time=time, attn_bias=attn_bias, seq_mask=seq_mask
|
442 |
+
)
|
443 |
+
x = x + ff(x, time=time)
|
444 |
+
if seq_mask is not None:
|
445 |
+
x = x * seq_mask[..., None]
|
446 |
+
|
447 |
+
return x
|
448 |
+
|
449 |
+
|
450 |
+
class TimeCondUViT(nn.Module):
|
451 |
+
def __init__(
|
452 |
+
self,
|
453 |
+
*,
|
454 |
+
seq_len: int,
|
455 |
+
dim: int,
|
456 |
+
patch_size: int = 1,
|
457 |
+
depth: int = 6,
|
458 |
+
heads: int = 8,
|
459 |
+
dim_head: int = 32,
|
460 |
+
n_filt_per_layer: List[int] = [],
|
461 |
+
n_blocks_per_layer: int = 2,
|
462 |
+
n_atoms: int = 37,
|
463 |
+
channels_per_atom: int = 6,
|
464 |
+
attn_bias_dim: int = None,
|
465 |
+
time_cond_dim: int = None,
|
466 |
+
conv_skip_connection: bool = False,
|
467 |
+
position_embedding_type: str = "rotary",
|
468 |
+
):
|
469 |
+
super().__init__()
|
470 |
+
|
471 |
+
# Initialize configuration params
|
472 |
+
if time_cond_dim is None:
|
473 |
+
time_cond_dim = dim * 4
|
474 |
+
self.position_embedding_type = position_embedding_type
|
475 |
+
channels = channels_per_atom
|
476 |
+
self.n_conv_layers = n_conv_layers = len(n_filt_per_layer)
|
477 |
+
if n_conv_layers > 0:
|
478 |
+
post_conv_filt = n_filt_per_layer[-1]
|
479 |
+
self.conv_skip_connection = conv_skip_connection and n_conv_layers == 1
|
480 |
+
transformer_seq_len = seq_len // (2**n_conv_layers)
|
481 |
+
assert transformer_seq_len % patch_size == 0
|
482 |
+
num_patches = transformer_seq_len // patch_size
|
483 |
+
dim_a = post_conv_atom_dim = max(1, n_atoms // (2 ** (n_conv_layers - 1)))
|
484 |
+
if n_conv_layers == 0:
|
485 |
+
patch_dim = patch_size * n_atoms * channels_per_atom
|
486 |
+
patch_dim_out = patch_size * n_atoms * 3
|
487 |
+
dim_a = n_atoms
|
488 |
+
elif conv_skip_connection and n_conv_layers == 1:
|
489 |
+
patch_dim = patch_size * (channels + post_conv_filt) * post_conv_atom_dim
|
490 |
+
patch_dim_out = patch_size * post_conv_filt * post_conv_atom_dim
|
491 |
+
elif n_conv_layers > 0:
|
492 |
+
patch_dim = patch_dim_out = patch_size * post_conv_filt * post_conv_atom_dim
|
493 |
+
|
494 |
+
# Make downsampling conv
|
495 |
+
# Downsamples n-1 times where n is n_conv_layers
|
496 |
+
down_conv = []
|
497 |
+
block_in = channels
|
498 |
+
for i, nf in enumerate(n_filt_per_layer):
|
499 |
+
block_out = nf
|
500 |
+
layer = []
|
501 |
+
for j in range(n_blocks_per_layer):
|
502 |
+
n_groups = 2 if i == 0 and j == 0 else 4
|
503 |
+
layer.append(
|
504 |
+
TimeCondResnetBlock(
|
505 |
+
block_in, block_out, time_cond_dim, n_norm_in_groups=n_groups
|
506 |
+
)
|
507 |
+
)
|
508 |
+
block_in = block_out
|
509 |
+
down_conv.append(nn.ModuleList(layer))
|
510 |
+
self.down_conv = nn.ModuleList(down_conv)
|
511 |
+
|
512 |
+
# Make transformer
|
513 |
+
self.to_patch_embedding = nn.Sequential(
|
514 |
+
Rearrange("b c (n p) a -> b n (p c a)", p=patch_size),
|
515 |
+
nn.Linear(patch_dim, dim),
|
516 |
+
LayerNorm(dim),
|
517 |
+
)
|
518 |
+
self.transformer = TimeCondTransformer(
|
519 |
+
dim,
|
520 |
+
depth,
|
521 |
+
heads,
|
522 |
+
dim_head,
|
523 |
+
time_cond_dim,
|
524 |
+
attn_bias_dim=attn_bias_dim,
|
525 |
+
position_embedding_type=position_embedding_type,
|
526 |
+
)
|
527 |
+
self.from_patch = nn.Sequential(
|
528 |
+
LayerNorm(dim),
|
529 |
+
nn.Linear(dim, patch_dim_out),
|
530 |
+
Rearrange("b n (p c a) -> b c (n p) a", p=patch_size, a=dim_a),
|
531 |
+
)
|
532 |
+
nn.init.zeros_(self.from_patch[-2].weight)
|
533 |
+
nn.init.zeros_(self.from_patch[-2].bias)
|
534 |
+
|
535 |
+
# Make upsampling conv
|
536 |
+
up_conv = []
|
537 |
+
for i, nf in enumerate(reversed(n_filt_per_layer)):
|
538 |
+
skip_in = nf
|
539 |
+
block_out = nf
|
540 |
+
layer = []
|
541 |
+
for j in range(n_blocks_per_layer):
|
542 |
+
layer.append(
|
543 |
+
TimeCondResnetBlock(block_in + skip_in, block_out, time_cond_dim)
|
544 |
+
)
|
545 |
+
block_in = block_out
|
546 |
+
up_conv.append(nn.ModuleList(layer))
|
547 |
+
self.up_conv = nn.ModuleList(up_conv)
|
548 |
+
|
549 |
+
# Conv out
|
550 |
+
if n_conv_layers > 0:
|
551 |
+
self.conv_out = nn.Sequential(
|
552 |
+
nn.GroupNorm(num_groups=block_out // 4, num_channels=block_out),
|
553 |
+
nn.SiLU(),
|
554 |
+
nn.Conv2d(block_out, channels // 2, 3, 1, 1),
|
555 |
+
)
|
556 |
+
|
557 |
+
def forward(
|
558 |
+
self, coords, time_cond, pair_bias=None, seq_mask=None, residue_index=None
|
559 |
+
):
|
560 |
+
if self.n_conv_layers > 0: # pad up to even dims
|
561 |
+
coords = F.pad(coords, (0, 0, 0, 0, 0, 1, 0, 0))
|
562 |
+
|
563 |
+
x = rearr_coords = rearrange(coords, "b n a c -> b c n a")
|
564 |
+
hiddens = []
|
565 |
+
for i, layer in enumerate(self.down_conv):
|
566 |
+
for block in layer:
|
567 |
+
x = block(x, time=time_cond)
|
568 |
+
hiddens.append(x)
|
569 |
+
if i != self.n_conv_layers - 1:
|
570 |
+
x = downsample(x)
|
571 |
+
|
572 |
+
if self.conv_skip_connection:
|
573 |
+
x = torch.cat([x, rearr_coords], 1)
|
574 |
+
|
575 |
+
x = self.to_patch_embedding(x)
|
576 |
+
# if self.position_embedding_type == 'absolute':
|
577 |
+
# pos_emb = posemb_sincos_1d(x)
|
578 |
+
# x = x + pos_emb
|
579 |
+
if seq_mask is not None and x.shape[1] == seq_mask.shape[1]:
|
580 |
+
x *= seq_mask[..., None]
|
581 |
+
x = self.transformer(
|
582 |
+
x,
|
583 |
+
time=time_cond,
|
584 |
+
attn_bias=pair_bias,
|
585 |
+
seq_mask=seq_mask,
|
586 |
+
residue_index=residue_index,
|
587 |
+
)
|
588 |
+
x = self.from_patch(x)
|
589 |
+
|
590 |
+
for i, layer in enumerate(self.up_conv):
|
591 |
+
for block in layer:
|
592 |
+
x = torch.cat([x, hiddens.pop()], 1)
|
593 |
+
x = block(x, time=time_cond)
|
594 |
+
if i != self.n_conv_layers - 1:
|
595 |
+
x = upsample_coords(x, hiddens[-1].shape[2:])
|
596 |
+
|
597 |
+
if self.n_conv_layers > 0:
|
598 |
+
x = self.conv_out(x)
|
599 |
+
x = x[..., :-1, :] # drop even-dims padding
|
600 |
+
|
601 |
+
x = rearrange(x, "b c n a -> b n a c")
|
602 |
+
return x
|
603 |
+
|
604 |
+
|
605 |
+
########################################
|
606 |
+
|
607 |
+
|
608 |
+
class LinearWarmupCosineDecay(torch.optim.lr_scheduler._LRScheduler):
|
609 |
+
def __init__(
|
610 |
+
self,
|
611 |
+
optimizer,
|
612 |
+
max_lr,
|
613 |
+
warmup_steps=1000,
|
614 |
+
decay_steps=int(1e6),
|
615 |
+
min_lr=1e-6,
|
616 |
+
**kwargs,
|
617 |
+
):
|
618 |
+
self.max_lr = max_lr
|
619 |
+
self.min_lr = min_lr
|
620 |
+
self.warmup_steps = warmup_steps
|
621 |
+
self.decay_steps = decay_steps
|
622 |
+
self.total_steps = warmup_steps + decay_steps
|
623 |
+
super(LinearWarmupCosineDecay, self).__init__(optimizer, **kwargs)
|
624 |
+
|
625 |
+
def get_lr(self):
|
626 |
+
# TODO double check for off-by-one errors
|
627 |
+
if self.last_epoch < self.warmup_steps:
|
628 |
+
curr_lr = self.last_epoch / self.warmup_steps * self.max_lr
|
629 |
+
return [curr_lr for group in self.optimizer.param_groups]
|
630 |
+
elif self.last_epoch < self.total_steps:
|
631 |
+
time = (self.last_epoch - self.warmup_steps) / self.decay_steps * np.pi
|
632 |
+
curr_lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (
|
633 |
+
1 + np.cos(time)
|
634 |
+
)
|
635 |
+
return [curr_lr for group in self.optimizer.param_groups]
|
636 |
+
else:
|
637 |
+
return [self.min_lr for group in self.optimizer.param_groups]
|
638 |
+
|
639 |
+
|
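# Illustrative usage (an editor's sketch, not part of the original file):
# step the scheduler once per optimizer step; the LR ramps linearly from 0 to
# max_lr over warmup_steps, then decays along a half-cosine down to min_lr.
#
#   opt = torch.optim.Adam(model.parameters(), lr=1e-4)
#   sched = LinearWarmupCosineDecay(opt, max_lr=1e-4, warmup_steps=1000)
#   for batch in loader:
#       ...  # forward, loss, backward
#       opt.step()
#       sched.step()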
class NoiseConditionalProteinMPNN(nn.Module):
    def __init__(
        self,
        n_channel=128,
        n_layers=3,
        n_neighbors=32,
        time_cond_dim=None,
        vocab_size=21,
        input_S_is_embeddings=False,
    ):
        super().__init__()
        self.n_channel = n_channel
        self.n_layers = n_layers
        self.n_neighbors = n_neighbors
        self.time_cond_dim = time_cond_dim
        self.vocab_size = vocab_size
        self.bb_idxs_if_atom37 = [
            residue_constants.atom_order[a] for a in ["N", "CA", "C", "O"]
        ]

        self.mpnn = protein_mpnn.ProteinMPNN(
            num_letters=vocab_size,
            node_features=n_channel,
            edge_features=n_channel,
            hidden_dim=n_channel,
            num_encoder_layers=n_layers,
            num_decoder_layers=n_layers,
            vocab=vocab_size,
            k_neighbors=n_neighbors,
            augment_eps=0.0,
            dropout=0.1,
            ca_only=False,
            time_cond_dim=time_cond_dim,
            input_S_is_embeddings=input_S_is_embeddings,
        )

    def forward(
        self, denoised_coords, noisy_aatype, seq_mask, residue_index, time_cond
    ):
        if denoised_coords.shape[-2] == 37:
            denoised_coords = denoised_coords[:, :, self.bb_idxs_if_atom37]

        node_embs, encoder_embs = self.mpnn(
            X=denoised_coords,
            S=noisy_aatype,
            mask=seq_mask,
            chain_M=seq_mask,
            residue_idx=residue_index,
            chain_encoding_all=seq_mask,
            randn=None,
            use_input_decoding_order=False,
            decoding_order=None,
            causal_mask=False,
            time_cond=time_cond,
            return_node_embs=True,
        )
        return node_embs, encoder_embs
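As a quick end-to-end check of how these modules fit together, the sketch below wires a NoiseConditioningBlock into a conv-free TimeCondUViT and runs a shape-level smoke test. The batch size, sequence length, and the reading of the 6 input channels per atom as xyz plus structure self-conditioning are assumptions for illustration, not values taken from the training configs.

import torch

b, n, dim = 2, 64, 128
noise_cond = NoiseConditioningBlock(128, dim * 4)  # dim * 4 matches TimeCondUViT's default time_cond_dim
uvit = TimeCondUViT(seq_len=n, dim=dim)            # n_filt_per_layer=[] -> pure patch transformer

coords = torch.randn(b, n, 37, 6)    # 37 atoms, 6 channels per atom (assumed: xyz + self-cond)
noise_level = torch.rand(b)          # one sigma per example
time_cond = noise_cond(noise_level)  # (b, 1, dim * 4)
seq_mask = torch.ones(b, n)

out = uvit(coords, time_cond, seq_mask=seq_mask)
print(out.shape)  # torch.Size([2, 64, 37, 3]): denoised xyz per atom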
output_helpers.py
ADDED
The diff for this file is too large to render. See raw diff.
package.txt
ADDED
@@ -0,0 +1 @@
dssp
protpardelle_pymol.py
ADDED
@@ -0,0 +1,159 @@
from pymol import cmd

import os
import json
import time
import threading

try:
    from gradio_client import Client
except ImportError:
    print("gradio_client not installed, trying install:")
    import pip

    pip.main(["install", "gradio_client"])
    from gradio_client import Client


if os.environ.get("GRADIO_LOCAL") is not None:
    public_link = "http://127.0.0.1:7862"
else:
    public_link = "spacesplaceholder"


def thread_protpardelle(
    input_pdb,
    resample_idxs,
    modeltype,
    mode,
    minlen=50,
    maxlen=60,
    steplen=2,
    per_len=2,
):
    client = Client(public_link)

    job = client.submit(
        input_pdb,  # str in 'PDB Content' Textbox component
        modeltype,  # str in 'Choose a Mode' Radio component
        f'"{resample_idxs}"',  # str in 'Resampled Idxs' Textbox component
        mode,  # str (Option from: ['backbone', 'allatom'])
        minlen,  # int | float (numeric value between 2 and 200) in 'minlen' Slider component
        maxlen,  # int | float (numeric value between 3 and 200) in 'maxlen' Slider component
        steplen,  # int | float (numeric value between 1 and 50) in 'steplen' Slider component
        per_len,  # int | float (numeric value between 1 and 200) in 'perlen' Slider component
        api_name="/protpardelle",
    )
    start = time.time()

    while not job.done():
        elapsed = time.time() - start
        # format elapsed time as hh:mm:ss
        elapsed = time.strftime("%H:%M:%S", time.gmtime(elapsed))
        print(f"\r protpardelle running since {elapsed}", end="")
        time.sleep(1)
    results = job.result()

    # load each result into pymol
    results = json.loads(results)

    for name, pdb_content in results:
        print(name)
        cmd.read_pdbstr(pdb_content, os.path.basename(name))


def query_protpardelle(
    name_of_input: str,
    selection_resample_idxs: str = "",
    per_len: int = 2,
    mode: str = "allatom",
):
    """
    AUTHOR
    Simon Duerr
    https://twitter.com/simonduerr
    DESCRIPTION
    Run Protpardelle
    USAGE
    protpardelle name_of_input, selection_resample_idxs, per_len, mode
    PARAMETERS
    name_of_input = string: name of input object
    selection_resample_idxs = string: selection of resampled protein residues
    per_len = int: per_len (default: 2)
    mode = string: mode (default: 'allatom')
    """
    if name_of_input != "":
        input_pdb = cmd.get_pdbstr(name_of_input)

    all_aa = cmd.index(name_of_input + " and name CA")
    idx = cmd.index(selection_resample_idxs + " and name CA")

    # map residue indices to zero-indexed values
    aa_mapping = {aa[1]: i for i, aa in enumerate(all_aa)}

    idx = ",".join([str(aa_mapping[aa[1]]) for aa in idx])

    print("resampling", idx, "(zero indexed) from", name_of_input)

    t = threading.Thread(
        target=thread_protpardelle,
        args=(input_pdb, idx, "conditional", mode),
        kwargs={"per_len": per_len},
        daemon=True,
    )
    t.start()


def query_protpardelle_uncond(
    minlen: int = 50,
    maxlen: int = 60,
    steplen: int = 2,
    per_len: int = 2,
    mode: str = "allatom",
):
    """
    AUTHOR
    Simon Duerr
    https://twitter.com/simonduerr
    DESCRIPTION
    Run Protpardelle unconditionally
    USAGE
    protpardelle_uncond minlen, maxlen, steplen, per_len, mode
    PARAMETERS
    minlen = int: minlen
    maxlen = int: maxlen
    steplen = int: steplen
    per_len = int: per_len
    mode = string: mode (default: 'allatom')
    """
    modeltype = "unconditional"
    idx = None
    input_pdb = None

    t = threading.Thread(
        target=thread_protpardelle,
        args=(input_pdb, idx, modeltype, mode),
        kwargs={
            "minlen": minlen,
            "maxlen": maxlen,
            "steplen": steplen,
            "per_len": per_len,
        },
        daemon=True,
    )
    t.start()


def setprotpardellelink(link: str):
    global public_link
    try:
        client = Client(link)  # verify the endpoint is reachable before storing it
        public_link = link
    except Exception:
        print("could not connect to:", link)


cmd.extend("protpardelle_setlink", setprotpardellelink)
cmd.extend("protpardelle", query_protpardelle)
cmd.extend("protpardelle_uncond", query_protpardelle_uncond)
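Once this script has been loaded into PyMOL (run protpardelle_pymol.py), the registered commands can be driven from the PyMOL prompt or programmatically via its Python interpreter. A short illustrative session follows; the object name, selection, and local URL are example values, not defaults shipped with the script.

from pymol import cmd

setprotpardellelink("http://127.0.0.1:7862")  # point at a running Space or local app
cmd.fetch("1ubq")
cmd.select("to_resample", "resi 10-20")
# resample the selected residues of 1ubq with the all-atom model
query_protpardelle("1ubq", "to_resample", per_len=2, mode="allatom")
# or draw unconditional samples across a range of lengths
query_protpardelle_uncond(minlen=50, maxlen=60, steplen=2, per_len=2, mode="allatom")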
requirements.txt
ADDED
@@ -0,0 +1,14 @@
torch==1.12.1+cu116
transformers==4.29.1
einops
tqdm
wandb
rotary-embedding-torch
biopython
scipy
dm-tree
matplotlib
seaborn
black
ipython
--extra-index-url https://download.pytorch.org/whl/cu116
sampling.py
ADDED
@@ -0,0 +1,213 @@
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu

Configs and convenience functions for wrapping the model sample() function.
"""
import argparse
import time
from typing import Optional, Tuple

import torch
from torchtyping import TensorType

from core import residue_constants
from core import utils
import diffusion


def default_backbone_sampling_config():
    config = argparse.Namespace(
        n_steps=500,
        s_churn=200,
        step_scale=1.2,
        sidechain_mode=False,
        noise_schedule=lambda t: diffusion.noise_schedule(t, s_max=80, s_min=0.001),
    )
    return config


def default_allatom_sampling_config():
    noise_schedule = lambda t: diffusion.noise_schedule(t, s_max=80, s_min=0.001)
    stage2 = argparse.Namespace(
        apply_cond_proportion=1.0,
        n_steps=200,
        s_churn=100,
        step_scale=1.2,
        sidechain_mode=True,
        skip_mpnn_proportion=1.0,
        noise_schedule=noise_schedule,
    )
    config = argparse.Namespace(
        n_steps=500,
        s_churn=200,
        step_scale=1.2,
        sidechain_mode=True,
        skip_mpnn_proportion=0.6,
        use_fullmpnn=False,
        use_fullmpnn_for_final=True,
        anneal_seq_resampling_rate="linear",
        noise_schedule=noise_schedule,
        stage_2=stage2,
    )
    return config


def draw_backbone_samples(
    model: torch.nn.Module,
    seq_mask: TensorType["b n", float] = None,
    n_samples: int = None,
    sample_length_range: Tuple[int] = (50, 512),
    pdb_save_path: Optional[str] = None,
    return_aux: bool = False,
    return_sampling_runtime: bool = False,
    **sampling_kwargs,
):
    device = model.device
    if seq_mask is None:
        assert n_samples is not None
        seq_mask = model.make_seq_mask_for_sampling(
            n_samples=n_samples,
            min_len=sample_length_range[0],
            max_len=sample_length_range[1],
        )

    start = time.time()
    aux = model.sample(
        seq_mask=seq_mask, return_last=False, return_aux=True, **sampling_kwargs
    )
    aux["runtime"] = time.time() - start
    seq_lens = seq_mask.sum(-1).long()
    cropped_samp_coords = [
        s[: seq_lens[i], model.bb_idxs] for i, s in enumerate(aux["xt_traj"][-1])
    ]

    if pdb_save_path is not None:
        gly_aatype = (seq_mask * residue_constants.restype_order["G"]).long()
        trimmed_aatype = [a[: seq_lens[i]] for i, a in enumerate(gly_aatype)]
        atom_mask = utils.atom37_mask_from_aatype(gly_aatype, seq_mask).cpu()
        for i in range(len(cropped_samp_coords)):
            utils.write_coords_to_pdb(
                cropped_samp_coords[i],
                f"{pdb_save_path}{i}.pdb",
                batched=False,
                aatype=trimmed_aatype[i],
                atom_mask=atom_mask[i],
            )

    if return_aux:
        return aux
    else:
        if return_sampling_runtime:
            return cropped_samp_coords, seq_mask, aux["runtime"]
        else:
            return cropped_samp_coords, seq_mask


def draw_allatom_samples(
    model: torch.nn.Module,
    seq_mask: TensorType["b n", float] = None,
    n_samples: int = None,
    sample_length_range: Tuple[int] = (50, 512),
    two_stage_sampling: bool = True,
    pdb_save_path: Optional[str] = None,
    return_aux: bool = False,
    return_sampling_runtime: bool = False,
    **sampling_kwargs,
):
    """Implement the default 2-stage all-atom sampling routine."""

    def save_allatom_samples(aux, path):
        seq_lens = aux["seq_mask"].sum(-1).long()
        cropped_samp_coords = [
            c[: seq_lens[i]] for i, c in enumerate(aux["xt_traj"][-1])
        ]
        cropped_samp_aatypes = [
            s[: seq_lens[i]] for i, s in enumerate(aux["st_traj"][-1])
        ]
        samp_atom_mask = utils.atom37_mask_from_aatype(
            aux["st_traj"][-1].to(device), seq_mask
        )
        samp_atom_mask = [m[: seq_lens[i]] for i, m in enumerate(samp_atom_mask)]
        for i, c in enumerate(cropped_samp_coords):
            utils.write_coords_to_pdb(
                c,
                f"{path}{i}.pdb",
                batched=False,
                aatype=cropped_samp_aatypes[i],
                atom_mask=samp_atom_mask[i],
                conect=True,
            )

    device = model.device
    if seq_mask is None:
        assert n_samples is not None
        seq_mask = model.make_seq_mask_for_sampling(
            n_samples=n_samples,
            min_len=sample_length_range[0],
            max_len=sample_length_range[1],
        )
    sampling_runtime = 0.0

    # Stage 1 sampling
    start = time.time()
    if "stage_2" in sampling_kwargs:
        stage_2_kwargs = vars(sampling_kwargs.pop("stage_2"))
    aux = model.sample(
        seq_mask=seq_mask,
        return_last=False,
        return_aux=True,
        **sampling_kwargs,
    )
    sampling_runtime = time.time() - start
    if pdb_save_path is not None and two_stage_sampling:
        save_allatom_samples(aux, pdb_save_path + "_init")

    # Stage 2 sampling (sidechain refinement only)
    if two_stage_sampling:
        samp_seq = aux["st_traj"][-1]
        samp_coords = aux["xt_traj"][-1]
        cond_atom_mask = utils.atom37_mask_from_aatype((seq_mask * 7).long(), seq_mask)
        aux = {f"stage1_{k}": v for k, v in aux.items()}
        start = time.time()
        stage2_aux = model.sample(
            gt_cond_atom_mask=cond_atom_mask.to(device),  # condition on backbone
            gt_cond_seq_mask=seq_mask.to(device),
            gt_coords=samp_coords.to(device),
            gt_aatype=samp_seq.to(device),
            seq_mask=seq_mask,
            return_last=False,
            return_aux=True,
            **stage_2_kwargs,
        )
        sampling_runtime += time.time() - start
        aux = {**aux, **stage2_aux}
        if pdb_save_path is not None:
            save_allatom_samples(aux, pdb_save_path + "_samp")
    aux["runtime"] = sampling_runtime

    # Process outputs, crop to correct length
    if return_aux:
        return aux
    else:
        xt_traj = aux["xt_traj"]
        st_traj = aux["st_traj"]
        seq_mask = aux["seq_mask"]
        seq_lens = seq_mask.sum(-1).long()
        cropped_samp_coords = [c[: seq_lens[i]] for i, c in enumerate(xt_traj[-1])]
        cropped_samp_aatypes = [s[: seq_lens[i]] for i, s in enumerate(st_traj[-1])]
        samp_atom_mask = utils.atom37_mask_from_aatype(st_traj[-1].to(device), seq_mask)
        samp_atom_mask = [m[: seq_lens[i]] for i, m in enumerate(samp_atom_mask)]
        orig_xt_traj = aux["stage1_xt_traj"]
        stage1_coords = [c[: seq_lens[i]] for i, c in enumerate(orig_xt_traj[-1])]
        ret = (
            cropped_samp_coords,
            cropped_samp_aatypes,
            samp_atom_mask,
            stage1_coords,
            seq_mask,
        )
        if return_sampling_runtime:
            ret = ret + (sampling_runtime,)
        return ret
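A typical call site pairs one of the default configs with the matching draw_* wrapper. Below is a minimal sketch, assuming model is an already-loaded all-atom Protpardelle model exposing the .device, .sample(), and make_seq_mask_for_sampling() interface used above; checkpoint loading itself lives elsewhere in the repo, and the save path is an example value.

import sampling

cfg = sampling.default_allatom_sampling_config()
coords, aatypes, atom_mask, stage1_coords, seq_mask = sampling.draw_allatom_samples(
    model,                            # assumed: loaded all-atom model
    n_samples=4,
    sample_length_range=(50, 60),
    pdb_save_path="samples/protein",  # writes samples/protein_samp{i}.pdb
    **vars(cfg),                      # n_steps, s_churn, stage_2, ...
)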