Patrick Haller committed
Commit efaa79d · 1 Parent(s): fac4c2e
Files changed (1)
  1. modeling_hf_alibaba_nlp_gte.py +7 -1
modeling_hf_alibaba_nlp_gte.py CHANGED
@@ -16,7 +16,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
@@ -994,6 +994,7 @@ class GteForSequenceClassification(GtePreTrainedModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs: Any,
     ) -> SequenceClassifierOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1009,6 +1010,7 @@ class GteForSequenceClassification(GtePreTrainedModel):
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            **kwargs
         )
         if self.config.add_pooling_layer:
             hidden_states = transformer_outputs.pooler_output
@@ -1021,6 +1023,10 @@ class GteForSequenceClassification(GtePreTrainedModel):
         if labels is not None:
             loss = self.loss_function(labels, logits, self.config)
 
+        # if not return_dict:
+        #     output = (logits,) + transformer_outputs[1:]
+        #     return ((loss,) + output) if loss is not None else output
+
         return SequenceClassifierOutputWithPast(
             loss=loss,
             logits=logits,
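
Note on the change: the commit accepts arbitrary extra keyword arguments in GteForSequenceClassification.forward and forwards them to the underlying transformer call. Callers (including newer versions of transformers itself) may inject keyword arguments the remote-code signature does not declare, and without **kwargs such calls fail with a TypeError. The commented-out block keeps the legacy return_dict=False tuple-return path visible for reference. Below is a minimal, self-contained sketch of that forwarding pattern; ToyBackbone, ToyClassifier, and some_new_flag are illustrative stand-ins, not names from the modeling file.

from typing import Any, Optional

import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    """Stand-in for the GTE transformer stack (illustrative only)."""

    def __init__(self, hidden: int = 8) -> None:
        super().__init__()
        self.embed = nn.Embedding(100, hidden)

    def forward(self, input_ids: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        # Unknown keyword arguments are accepted and ignored here; a real
        # backbone would interpret flags such as output_hidden_states.
        return self.embed(input_ids).mean(dim=1)


class ToyClassifier(nn.Module):
    """Stand-in for GteForSequenceClassification (illustrative only)."""

    def __init__(self, hidden: int = 8, num_labels: int = 2) -> None:
        super().__init__()
        self.backbone = ToyBackbone(hidden)
        self.score = nn.Linear(hidden, num_labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        **kwargs: Any,  # accept unknown kwargs instead of raising TypeError
    ) -> torch.Tensor:
        hidden_states = self.backbone(
            input_ids,
            output_hidden_states=output_hidden_states,
            **kwargs,  # forward them to the backbone, as the commit does
        )
        return self.score(hidden_states)


# some_new_flag is a hypothetical extra argument a newer caller might pass.
logits = ToyClassifier()(torch.randint(0, 100, (2, 5)), some_new_flag=True)
print(logits.shape)  # torch.Size([2, 2])

This is the usual design choice for remote-code models: accept unknown keyword arguments at the outermost forward and pass them through, so the model keeps working as the surrounding library adds new call-time options.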