Patrick Haller
committed on
Commit
·
d4d56eb
1
Parent(s):
8ac80ca
Making configurable
Browse files
configuration_hf_alibaba_nlp_gte.py
CHANGED
@@ -116,6 +116,8 @@ class GteConfig(PretrainedConfig):
|
|
116 |
use_memory_efficient_attention=False,
|
117 |
logn_attention_scale=False,
|
118 |
logn_attention_clip1=False,
|
|
|
|
|
119 |
**kwargs,
|
120 |
):
|
121 |
super().__init__(**kwargs)
|
@@ -142,4 +144,7 @@ class GteConfig(PretrainedConfig):
|
|
142 |
self.unpad_inputs = unpad_inputs
|
143 |
self.use_memory_efficient_attention = use_memory_efficient_attention
|
144 |
self.logn_attention_scale = logn_attention_scale
|
145 |
-
self.logn_attention_clip1 = logn_attention_clip1
|
|
|
|
|
|
|
|
116 |
use_memory_efficient_attention=False,
|
117 |
logn_attention_scale=False,
|
118 |
logn_attention_clip1=False,
|
119 |
+
add_pooling_layer=True,
|
120 |
+
num_labels=0,
|
121 |
**kwargs,
|
122 |
):
|
123 |
super().__init__(**kwargs)
|
|
|
144 |
self.unpad_inputs = unpad_inputs
|
145 |
self.use_memory_efficient_attention = use_memory_efficient_attention
|
146 |
self.logn_attention_scale = logn_attention_scale
|
147 |
+
self.logn_attention_clip1 = logn_attention_clip1
|
148 |
+
|
149 |
+
self.add_pooling_layer = add_pooling_layer
|
150 |
+
self.num_labels = num_labels
|
modeling_hf_alibaba_nlp_gte.py
CHANGED
@@ -970,8 +970,9 @@ class GteForSequenceClassification(GtePreTrainedModel):
|
|
970 |
def __init__(self, config: GteConfig):
|
971 |
super().__init__(config)
|
972 |
self.config = config
|
973 |
-
self.num_labels =
|
974 |
-
|
|
|
975 |
|
976 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
977 |
self.loss_function = nn.MSELoss()
|
@@ -1010,7 +1011,10 @@ class GteForSequenceClassification(GtePreTrainedModel):
|
|
1010 |
output_attentions=output_attentions,
|
1011 |
output_hidden_states=output_hidden_states,
|
1012 |
)
|
1013 |
-
|
|
|
|
|
|
|
1014 |
|
1015 |
logits = self.score(hidden_states)
|
1016 |
|
|
|
970 |
def __init__(self, config: GteConfig):
|
971 |
super().__init__(config)
|
972 |
self.config = config
|
973 |
+
self.num_labels = config.num_labels
|
974 |
+
assert config.num_labels > 0, "num_labels should be greater than 0 for sequence classification"
|
975 |
+
self.model = GteModel(config, add_pooling_layer=config.add_pooling_layer)
|
976 |
|
977 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
978 |
self.loss_function = nn.MSELoss()
|
|
|
1011 |
output_attentions=output_attentions,
|
1012 |
output_hidden_states=output_hidden_states,
|
1013 |
)
|
1014 |
+
if self.config.add_pooling_layer:
|
1015 |
+
hidden_states = transformer_outputs.pooler_output
|
1016 |
+
else:
|
1017 |
+
hidden_states = transformer_outputs.last_hidden_state[:, 0]
|
1018 |
|
1019 |
logits = self.score(hidden_states)
|
1020 |
|