Patrick Haller committed
Commit · efaa79d
Parent(s): fac4c2e
Fix
modeling_hf_alibaba_nlp_gte.py CHANGED
@@ -16,7 +16,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
@@ -994,6 +994,7 @@ class GteForSequenceClassification(GtePreTrainedModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs: Any,
     ) -> SequenceClassifierOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1009,6 +1010,7 @@ class GteForSequenceClassification(GtePreTrainedModel):
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            **kwargs
         )
         if self.config.add_pooling_layer:
             hidden_states = transformer_outputs.pooler_output
@@ -1021,6 +1023,10 @@ class GteForSequenceClassification(GtePreTrainedModel):
         if labels is not None:
             loss = self.loss_function(labels, logits, self.config)
 
+        # if not return_dict:
+        #     output = (logits,) + transformer_outputs[1:]
+        #     return ((loss,) + output) if loss is not None else output
+
         return SequenceClassifierOutputWithPast(
             loss=loss,
             logits=logits,
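The substance of the change is that forward() now accepts **kwargs and forwards them to the underlying transformer call, so extra keyword arguments from a caller (for example, a training loop) no longer raise a TypeError. Below is a minimal, self-contained sketch of that pattern; the TinyClassifier class and the num_items_in_batch argument are illustrative assumptions, not part of the GTE code.

    # Sketch of accepting and tolerating arbitrary keyword arguments in forward().
    from typing import Any, Optional

    import torch
    import torch.nn as nn


    class TinyClassifier(nn.Module):
        def __init__(self, hidden: int = 8, num_labels: int = 2) -> None:
            super().__init__()
            self.backbone = nn.Linear(hidden, hidden)
            self.score = nn.Linear(hidden, num_labels)

        def forward(
            self,
            inputs_embeds: torch.Tensor,
            labels: Optional[torch.Tensor] = None,
            **kwargs: Any,  # absorb extra args a caller may inject (hypothetical here)
        ) -> torch.Tensor:
            # Mean-pool over the sequence dimension, then classify.
            hidden_states = self.backbone(inputs_embeds)
            return self.score(hidden_states.mean(dim=1))


    model = TinyClassifier()
    x = torch.randn(4, 16, 8)  # (batch, seq_len, hidden)
    # Without **kwargs in the signature, this call would raise a TypeError.
    logits = model(inputs_embeds=x, num_items_in_batch=4)
    print(logits.shape)  # torch.Size([4, 2])

In the committed code the kwargs are additionally passed through to the inner model call rather than discarded, which only works if the wrapped module's forward accepts them as well.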