Upload modeling_fastesm.py with huggingface_hub
modeling_fastesm.py (CHANGED: +15 -9)
@@ -4,9 +4,10 @@ from torch.nn import functional as F
 from torch.utils.data import Dataset, DataLoader
 from typing import Optional, Tuple, Union
 from einops import rearrange
+from dataclasses import dataclass
 from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
 from transformers.modeling_outputs import (
-
+    ModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
     SequenceClassifierOutput,
@@ -23,6 +24,15 @@ from transformers.models.esm.modeling_esm import (
 from tqdm.auto import tqdm
 
 
+@dataclass
+class EsmMaskedLMOutput(ModelOutput):
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
 class FastEsmConfig(PretrainedConfig):
     model_type = "fast_esm"
     def __init__(
@@ -656,9 +666,7 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
             Model outputs including hidden states and optionally attention weights
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -739,9 +747,7 @@ class FastEsmModel(FastEsmPreTrainedModel):
             Model outputs including hidden states and optionally attention weights
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -798,7 +804,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
-    ) -> Union[Tuple,
+    ) -> Union[Tuple, EsmMaskedLMOutput]:
         outputs = self.esm(
             input_ids,
             attention_mask=attention_mask,
@@ -815,7 +821,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
             labels = labels.to(prediction_scores.device)
             loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
 
-        return
+        return EsmMaskedLMOutput(
             loss=loss,
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,