Upload modeling_esm_plusplus.py with huggingface_hub
modeling_esm_plusplus.py  (+7 -3)
--- a/modeling_esm_plusplus.py
+++ b/modeling_esm_plusplus.py
@@ -619,9 +619,6 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             Dictionary mapping sequences to embeddings, or None if sql=True
         """
         sequences = list(set([seq[:max_len] for seq in sequences]))
-        sequences = sorted(sequences, key=len, reverse=True)
-        dataset = ProteinDataset(sequences)
-        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
         device = self.device
 
         def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -636,6 +633,7 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             else:
                 raise ValueError(f"Invalid pooling type: {pooling_type}")
 
+        sequences = list(set([seq[:max_len] for seq in sequences]))
         if sql:
             import sqlite3
             conn = sqlite3.connect(sql_db_path)
@@ -646,6 +644,9 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
             print(f"Embedding {len(to_embed)} new sequences")
             if len(to_embed) > 0:
+                to_embed = sorted(to_embed, key=len, reverse=True)
+                dataset = ProteinDataset(to_embed)
+                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
                 with torch.no_grad():
                     for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                         seqs = to_embed[i * batch_size:(i + 1) * batch_size]
@@ -668,6 +669,9 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             return None
 
         embeddings_dict = {}
+        sequences = sorted(sequences, key=len, reverse=True)
+        dataset = ProteinDataset(sequences)
+        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
         with torch.no_grad():
             for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                 seqs = sequences[i * batch_size:(i + 1) * batch_size]