Upload code
Browse files- modeling_prot2text.py +0 -13
modeling_prot2text.py
CHANGED
@@ -123,11 +123,7 @@ class Prot2TextModel(PreTrainedModel):
|
|
123 |
|
124 |
@torch.no_grad()
|
125 |
def generate_protein_description(self,
|
126 |
-
protein_pdbID=None,
|
127 |
protein_sequence=None,
|
128 |
-
edge_index: Optional[torch.LongTensor] = None,
|
129 |
-
x: Optional[torch.FloatTensor] = None,
|
130 |
-
edge_type: Optional[torch.LongTensor] = None,
|
131 |
tokenizer=None,
|
132 |
device='cpu'
|
133 |
):
|
@@ -136,17 +132,8 @@ class Prot2TextModel(PreTrainedModel):
|
|
136 |
raise ValueError(
|
137 |
"The model you are trying to use is based only on protein sequence, please provide an amino-acid protein_sequence"
|
138 |
)
|
139 |
-
if self.config.rgcn and protein_pdbID==None and (x==None or edge_index==None or edge_type==None):
|
140 |
-
raise ValueError(
|
141 |
-
"The model you are trying to use is based on protein structure, please provide a AlphaFold ID (you must have to have internet connection using protein_pdbID, or provide the triplet inputs: x (node features), edge_index and edge_type"
|
142 |
-
)
|
143 |
if self.config.esm:
|
144 |
esmtokenizer = AutoTokenizer.from_pretrained(self.config.esm_model_name)
|
145 |
-
|
146 |
-
if protein_pdbID==None and protein_sequence==None:
|
147 |
-
raise ValueError(
|
148 |
-
"you need to provide either a protein AlphaFold Id or an amino-acid sequence"
|
149 |
-
)
|
150 |
|
151 |
|
152 |
seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")
|
|
|
123 |
|
124 |
@torch.no_grad()
|
125 |
def generate_protein_description(self,
|
|
|
126 |
protein_sequence=None,
|
|
|
|
|
|
|
127 |
tokenizer=None,
|
128 |
device='cpu'
|
129 |
):
|
|
|
132 |
raise ValueError(
|
133 |
"The model you are trying to use is based only on protein sequence, please provide an amino-acid protein_sequence"
|
134 |
)
|
|
|
|
|
|
|
|
|
135 |
if self.config.esm:
|
136 |
esmtokenizer = AutoTokenizer.from_pretrained(self.config.esm_model_name)
|
|
|
|
|
|
|
|
|
|
|
137 |
|
138 |
|
139 |
seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")
|