Patrick Haller committed on
Commit
d4d56eb
·
1 Parent(s): 8ac80ca

Making configurable

Browse files
configuration_hf_alibaba_nlp_gte.py CHANGED
@@ -116,6 +116,8 @@ class GteConfig(PretrainedConfig):
116
  use_memory_efficient_attention=False,
117
  logn_attention_scale=False,
118
  logn_attention_clip1=False,
 
 
119
  **kwargs,
120
  ):
121
  super().__init__(**kwargs)
@@ -142,4 +144,7 @@ class GteConfig(PretrainedConfig):
142
  self.unpad_inputs = unpad_inputs
143
  self.use_memory_efficient_attention = use_memory_efficient_attention
144
  self.logn_attention_scale = logn_attention_scale
145
- self.logn_attention_clip1 = logn_attention_clip1
 
 
 
 
116
  use_memory_efficient_attention=False,
117
  logn_attention_scale=False,
118
  logn_attention_clip1=False,
119
+ add_pooling_layer=True,
120
+ num_labels=0,
121
  **kwargs,
122
  ):
123
  super().__init__(**kwargs)
 
144
  self.unpad_inputs = unpad_inputs
145
  self.use_memory_efficient_attention = use_memory_efficient_attention
146
  self.logn_attention_scale = logn_attention_scale
147
+ self.logn_attention_clip1 = logn_attention_clip1
148
+
149
+ self.add_pooling_layer = add_pooling_layer
150
+ self.num_labels = num_labels
modeling_hf_alibaba_nlp_gte.py CHANGED
@@ -970,8 +970,9 @@ class GteForSequenceClassification(GtePreTrainedModel):
970
  def __init__(self, config: GteConfig):
971
  super().__init__(config)
972
  self.config = config
973
- self.num_labels = 1
974
- self.model = GteModel(config, add_pooling_layer=True)
 
975
 
976
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
977
  self.loss_function = nn.MSELoss()
@@ -1010,7 +1011,10 @@ class GteForSequenceClassification(GtePreTrainedModel):
1010
  output_attentions=output_attentions,
1011
  output_hidden_states=output_hidden_states,
1012
  )
1013
- hidden_states = transformer_outputs.pooler_output
 
 
 
1014
 
1015
  logits = self.score(hidden_states)
1016
 
 
970
  def __init__(self, config: GteConfig):
971
  super().__init__(config)
972
  self.config = config
973
+ self.num_labels = config.num_labels
974
+ assert config.num_labels > 0, "num_labels should be greater than 0 for sequence classification"
975
+ self.model = GteModel(config, add_pooling_layer=config.add_pooling_layer)
976
 
977
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
978
  self.loss_function = nn.MSELoss()
 
1011
  output_attentions=output_attentions,
1012
  output_hidden_states=output_hidden_states,
1013
  )
1014
+ if self.config.add_pooling_layer:
1015
+ hidden_states = transformer_outputs.pooler_output
1016
+ else:
1017
+ hidden_states = transformer_outputs.last_hidden_state[:, 0]
1018
 
1019
  logits = self.score(hidden_states)
1020