Commit
·
ce6d631
1
Parent(s):
770b039
testin the trick
Browse files- lang_detect.py +2 -2
- modeling_stacked.py +14 -16
lang_detect.py
CHANGED
|
@@ -16,8 +16,8 @@ class MultitaskTokenClassificationPipeline(Pipeline):
|
|
| 16 |
def _forward(self, text):
|
| 17 |
print(f"Do we arrive here? {text}")
|
| 18 |
print(f"Let's check the model: {self.model.get_floret_model()}")
|
| 19 |
-
predictions, probabilities = self.model.get_floret_model().predict([text], k=1)
|
| 20 |
-
|
| 21 |
return text
|
| 22 |
|
| 23 |
def postprocess(self, text, **kwargs):
|
|
|
|
| 16 |
def _forward(self, text):
|
| 17 |
print(f"Do we arrive here? {text}")
|
| 18 |
print(f"Let's check the model: {self.model.get_floret_model()}")
|
| 19 |
+
# predictions, probabilities = self.model.get_floret_model().predict([text], k=1)
|
| 20 |
+
self.model(text)
|
| 21 |
return text
|
| 22 |
|
| 23 |
def postprocess(self, text, **kwargs):
|
modeling_stacked.py
CHANGED
|
@@ -42,22 +42,20 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
|
| 42 |
def forward(self, input_ids, attention_mask=None, **kwargs):
|
| 43 |
# Convert input_ids to strings using tokenizer
|
| 44 |
print(f"Check if it arrives here: {input_ids}")
|
| 45 |
-
if input_ids is not None:
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
else:
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
if texts:
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
else:
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
(1, 2)
|
| 60 |
-
) # Dummy tensor with shape (batch_size, num_classes)
|
| 61 |
|
| 62 |
def state_dict(self, *args, **kwargs):
|
| 63 |
# Return an empty state dictionary
|
|
|
|
| 42 |
def forward(self, input_ids, attention_mask=None, **kwargs):
|
| 43 |
# Convert input_ids to strings using tokenizer
|
| 44 |
print(f"Check if it arrives here: {input_ids}")
|
| 45 |
+
# if input_ids is not None:
|
| 46 |
+
# tokenizer = kwargs.get("tokenizer")
|
| 47 |
+
# texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
|
| 48 |
+
# else:
|
| 49 |
+
# texts = kwargs.get("text", None)
|
| 50 |
+
#
|
| 51 |
+
# if texts:
|
| 52 |
+
# # Floret expects strings, not tensors
|
| 53 |
+
# predictions = [self.model_floret(text) for text in texts]
|
| 54 |
+
# # Convert predictions to tensors for Hugging Face compatibility
|
| 55 |
+
# return torch.tensor(predictions)
|
| 56 |
+
# else:
|
| 57 |
+
# If no text is found, return dummy output
|
| 58 |
+
return torch.zeros((1, 2)) # Dummy tensor with shape (batch_size, num_classes)
|
|
|
|
|
|
|
| 59 |
|
| 60 |
def state_dict(self, *args, **kwargs):
|
| 61 |
# Return an empty state dictionary
|