emanuelaboros committed
Commit ce6d631 · 1 Parent(s): 770b039

testin the trick

Files changed (2):
  1. lang_detect.py +2 -2
  2. modeling_stacked.py +14 -16
lang_detect.py CHANGED
@@ -16,8 +16,8 @@ class MultitaskTokenClassificationPipeline(Pipeline):
     def _forward(self, text):
         print(f"Do we arrive here? {text}")
         print(f"Let's check the model: {self.model.get_floret_model()}")
-        predictions, probabilities = self.model.get_floret_model().predict([text], k=1)
-
+        # predictions, probabilities = self.model.get_floret_model().predict([text], k=1)
+        self.model(text)
         return text
 
     def postprocess(self, text, **kwargs):
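
For context only, a minimal sketch (not part of this commit) of what _forward could look like once the floret call is wired back in. It assumes the predict([text], k=1) signature visible in the removed line; returning a dict for postprocess to consume is a hypothetical choice, not the repo's API:

    def _forward(self, text):
        # Sketch: run the floret model and pass its output forward instead of echoing the text.
        floret_model = self.model.get_floret_model()
        predictions, probabilities = floret_model.predict([text], k=1)
        # Hypothetical return shape; postprocess would need to accept a dict.
        return {"text": text, "predictions": predictions, "probabilities": probabilities}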
modeling_stacked.py CHANGED
@@ -42,22 +42,20 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
     def forward(self, input_ids, attention_mask=None, **kwargs):
         # Convert input_ids to strings using tokenizer
         print(f"Check if it arrives here: {input_ids}")
-        if input_ids is not None:
-            tokenizer = kwargs.get("tokenizer")
-            texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
-        else:
-            texts = kwargs.get("text", None)
-
-        if texts:
-            # Floret expects strings, not tensors
-            predictions = [self.model_floret(text) for text in texts]
-            # Convert predictions to tensors for Hugging Face compatibility
-            return torch.tensor(predictions)
-        else:
-            # If no text is found, return dummy output
-            return torch.zeros(
-                (1, 2)
-            )  # Dummy tensor with shape (batch_size, num_classes)
+        # if input_ids is not None:
+        #     tokenizer = kwargs.get("tokenizer")
+        #     texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
+        # else:
+        #     texts = kwargs.get("text", None)
+        #
+        # if texts:
+        #     # Floret expects strings, not tensors
+        #     predictions = [self.model_floret(text) for text in texts]
+        #     # Convert predictions to tensors for Hugging Face compatibility
+        #     return torch.tensor(predictions)
+        # else:
+        # If no text is found, return dummy output
+        return torch.zeros((1, 2))  # Dummy tensor with shape (batch_size, num_classes)
 
     def state_dict(self, *args, **kwargs):
         # Return an empty state dictionary
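
A possible follow-up, sketched here as an assumption rather than anything in this commit: if the dummy output is kept, sizing it to the incoming batch avoids shape mismatches for callers that index by batch. The snippet assumes input_ids is a tensor when provided and keeps num_classes = 2 from the (1, 2) dummy in the diff:

import torch

def forward(self, input_ids, attention_mask=None, **kwargs):
    # Hypothetical variant: match the dummy tensor to the actual batch size.
    batch_size = input_ids.shape[0] if input_ids is not None else 1
    num_classes = 2  # placeholder, mirroring the (1, 2) dummy above
    return torch.zeros((batch_size, num_classes))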