Commit
·
ba6b686
1
Parent(s):
1b37761
lets try to change the pipeline
Browse files- modeling_stacked.py +16 -0
modeling_stacked.py
CHANGED
|
@@ -16,6 +16,17 @@ def get_info(label_map):
|
|
| 16 |
return num_token_labels_dict
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
| 20 |
|
| 21 |
config_class = ImpressoConfig
|
|
@@ -63,6 +74,11 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
|
| 63 |
def get_floret_model(self):
|
| 64 |
return self.model_floret
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
# def forward(
|
| 67 |
# self,
|
| 68 |
# input_ids: Optional[torch.Tensor] = None,
|
|
|
|
| 16 |
return num_token_labels_dict
|
| 17 |
|
| 18 |
|
| 19 |
+
# class MyCustomModel:
|
| 20 |
+
# def __init__(self):
|
| 21 |
+
# # Custom initialization
|
| 22 |
+
# pass
|
| 23 |
+
#
|
| 24 |
+
# @classmethod
|
| 25 |
+
# def from_pretrained(cls, *args, **kwargs):
|
| 26 |
+
# print("Ignoring weights and using custom initialization.")
|
| 27 |
+
# return cls()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
| 31 |
|
| 32 |
config_class = ImpressoConfig
|
|
|
|
| 74 |
def get_floret_model(self):
|
| 75 |
return self.model_floret
|
| 76 |
|
| 77 |
+
@classmethod
|
| 78 |
+
def from_pretrained(cls, *args, **kwargs):
|
| 79 |
+
print("Ignoring weights and using custom initialization.")
|
| 80 |
+
return cls()
|
| 81 |
+
|
| 82 |
# def forward(
|
| 83 |
# self,
|
| 84 |
# input_ids: Optional[torch.Tensor] = None,
|