Commit · da75bdd
1 Parent(s): 8c19df9
lets try to change the pipeline
modeling_stacked.py (+2 -3) CHANGED
@@ -36,8 +36,8 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
         self.config = config
 
         # Load floret model
+        self.dummy_param = nn.Parameter(torch.zeros(1))
         self.model_floret = floret.load_model(self.config.filename)
-        print(f"Model loaded: {self.model_floret}")
 
     def forward(self, input_ids, attention_mask=None, **kwargs):
         # Convert input_ids to strings using tokenizer
@@ -81,8 +81,7 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
 
     @property
     def device(self):
-        return
-        # torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        return next(self.parameters()).device
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):