Roman Solomatin committed
Commit f7a361f · unverified · 1 parent: 8c3030b

fix dimensions again
Files changed (2):
  1. config.json (+2 -2)
  2. listconranker.py (+134 -75)
config.json CHANGED
@@ -12,8 +12,8 @@
   "gradient_checkpointing": false,
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
-  "hidden_size": 1792,
-  "base_hidden_size": 1024,
+  "hidden_size": 1024,
+  "list_con_hidden_size": 1792,
   "id2label": {
     "0": "LABEL_0"
   },
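The config change stops overloading hidden_size: hidden_size goes back to the 1024-dimensional width of the underlying BERT encoder, while the 1792-dimensional width used by the list-wise head moves to the new list_con_hidden_size field. In listconranker.py the bridge between the two spaces is linear_in_embedding. Below is a minimal shape-check sketch of that projection using plain torch.nn modules and the two values from this config; the variable names are illustrative only, not part of the repository.

import torch
from torch import nn

# Widths taken from the updated config.json.
bert_hidden_size = 1024        # "hidden_size": width of the BERT encoder
list_con_hidden_size = 1792    # "list_con_hidden_size": width of the list-wise head

# Mirrors linear_in_embedding in listconranker.py: pooled BERT features are
# projected into the wider space consumed by the ListTransformer.
linear_in_embedding = nn.Linear(bert_hidden_size, list_con_hidden_size)

pooled = torch.randn(8, bert_hidden_size)   # 8 query/passage pairs after average pooling
projected = linear_in_embedding(pooled)
print(projected.shape)                      # torch.Size([8, 1792])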
listconranker.py CHANGED
@@ -1,20 +1,20 @@
 # Copyright 2024 Bytedance Ltd. and/or its affiliates
 #
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
-# and associated documentation files (the "Software"), to deal in the Software without
-# restriction, including without limitation the rights to use, copy, modify, merge, publish,
-# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
+# and associated documentation files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
 # Software is furnished to do so, subject to the following conditions:
 #
-# The above copyright notice and this permission notice shall be included in all copies or
+# The above copyright notice and this permission notice shall be included in all copies or
 # substantial portions of the Software.
 #
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
-# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
-# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
+# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
 # OTHER DEALINGS IN THE SOFTWARE.
 
 import math
@@ -23,47 +23,45 @@ from torch import nn
 from torch.nn import functional as F
 import numpy as np
 from transformers import (
-    AutoTokenizer,
-    is_torch_npu_available,
-    AutoModel,
-    PreTrainedModel,
+    AutoTokenizer,
+    is_torch_npu_available,
+    AutoModel,
+    PreTrainedModel,
     PretrainedConfig,
     AutoConfig,
     BertModel,
-    BertConfig
+    BertConfig,
 )
 from transformers.modeling_outputs import SequenceClassifierOutput
 from typing import Union, List, Optional
 
 
-class ListConRankerConfig(PretrainedConfig):
+class ListConRankerConfig(BertConfig):
     """Configuration class for ListConRanker model."""
-
+
     model_type = "ListConRanker"
-
+
     def __init__(
         self,
         list_transformer_layers: int = 2,
-        hidden_size: int = 1792,
-        base_hidden_size: int = 1024,
+        list_con_hidden_size: int = 1792,
         num_labels: int = 1,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.list_transformer_layers = list_transformer_layers
-        self.hidden_size = hidden_size
-        self.base_hidden_size = base_hidden_size
+        self.list_con_hidden_size = list_con_hidden_size
         self.num_labels = num_labels
 
         self.bert_config = BertConfig(**kwargs)
-        self.bert_config.hidden_size = self.base_hidden_size
        self.bert_config.output_hidden_states = True
 
+
 class QueryEmbedding(nn.Module):
     def __init__(self, config) -> None:
         super().__init__()
-        self.query_embedding = nn.Embedding(2, config.hidden_size)
-        self.layerNorm = nn.LayerNorm(config.hidden_size)
+        self.query_embedding = nn.Embedding(2, config.list_con_hidden_size)
+        self.layerNorm = nn.LayerNorm(config.list_con_hidden_size)
 
     def forward(self, x, tags):
         query_embeddings = self.query_embedding(tags)
@@ -71,40 +69,70 @@ class QueryEmbedding(nn.Module):
         x = self.layerNorm(x)
         return x
 
+
 class ListTransformer(nn.Module):
     def __init__(self, num_layer, config) -> None:
         super().__init__()
         self.config = config
-        self.list_transformer_layer = nn.TransformerEncoderLayer(1792, self.config.num_attention_heads, batch_first=True, activation=F.gelu, norm_first=False)
-        self.list_transformer = nn.TransformerEncoder(self.list_transformer_layer, num_layer)
+        self.list_transformer_layer = nn.TransformerEncoderLayer(
+            1792,
+            self.config.num_attention_heads,
+            batch_first=True,
+            activation=F.gelu,
+            norm_first=False,
+        )
+        self.list_transformer = nn.TransformerEncoder(
+            self.list_transformer_layer, num_layer
+        )
         self.relu = nn.ReLU()
         self.query_embedding = QueryEmbedding(config)
 
-        self.linear_score3 = nn.Linear(config.hidden_size * 2, config.hidden_size)
-        self.linear_score2 = nn.Linear(config.hidden_size * 2, config.hidden_size)
-        self.linear_score1 = nn.Linear(config.hidden_size * 2, 1)
+        self.linear_score3 = nn.Linear(
+            config.list_con_hidden_size * 2, config.list_con_hidden_size
+        )
+        self.linear_score2 = nn.Linear(
+            config.list_con_hidden_size * 2, config.list_con_hidden_size
+        )
+        self.linear_score1 = nn.Linear(config.list_con_hidden_size * 2, 1)
 
-    def forward(self, pair_features: torch.Tensor):
-        pair_nums = pair_features.size(0)
-        pair_nums = [x + 1 for x in pair_nums]
+    def forward(
+        self, pair_features: torch.Tensor, pair_nums: List[int]
+    ) -> torch.Tensor:
        batch_pair_features = pair_features.split(pair_nums)
 
         pair_feature_query_passage_concat_list = []
         for i in range(len(batch_pair_features)):
-            pair_feature_query = batch_pair_features[i][0].unsqueeze(0).repeat(pair_nums[i] - 1, 1)
+            pair_feature_query = (
+                batch_pair_features[i][0].unsqueeze(0).repeat(pair_nums[i] - 1, 1)
+            )
             pair_feature_passage = batch_pair_features[i][1:]
-            pair_feature_query_passage_concat_list.append(torch.cat([pair_feature_query, pair_feature_passage], dim=1))
-        pair_feature_query_passage_concat = torch.cat(pair_feature_query_passage_concat_list, dim=0)
+            pair_feature_query_passage_concat_list.append(
+                torch.cat([pair_feature_query, pair_feature_passage], dim=1)
+            )
+        pair_feature_query_passage_concat = torch.cat(
+            pair_feature_query_passage_concat_list, dim=0
+        )
 
-        batch_pair_features = nn.utils.rnn.pad_sequence(batch_pair_features, batch_first=True)
+        batch_pair_features = nn.utils.rnn.pad_sequence(
+            batch_pair_features, batch_first=True
+        )
 
-        query_embedding_tags = torch.zeros(batch_pair_features.size(0), batch_pair_features.size(1), dtype=torch.long, device=self.device)
+        query_embedding_tags = torch.zeros(
+            batch_pair_features.size(0),
+            batch_pair_features.size(1),
+            dtype=torch.long,
+            device=self.device,
+        )
         query_embedding_tags[:, 0] = 1
-        batch_pair_features = self.query_embedding(batch_pair_features, query_embedding_tags)
+        batch_pair_features = self.query_embedding(
+            batch_pair_features, query_embedding_tags
+        )
 
         mask = self.generate_attention_mask(pair_nums)
         query_mask = self.generate_attention_mask_custom(pair_nums)
-        pair_list_features = self.list_transformer(batch_pair_features, src_key_padding_mask=mask, mask=query_mask)
+        pair_list_features = self.list_transformer(
+            batch_pair_features, src_key_padding_mask=mask, mask=query_mask
+        )
 
         output_pair_list_features = []
         output_query_list_features = []
@@ -112,20 +140,39 @@ class ListTransformer(nn.Module):
         for idx, pair_num in enumerate(pair_nums):
             output_pair_list_features.append(pair_list_features[idx, 1:pair_num, :])
             output_query_list_features.append(pair_list_features[idx, 0, :])
-            pair_features_after_transformer_list.append(pair_list_features[idx, :pair_num, :])
+            pair_features_after_transformer_list.append(
+                pair_list_features[idx, :pair_num, :]
+            )
 
         pair_features_after_transformer_cat_query_list = []
         for idx, pair_num in enumerate(pair_nums):
-            query_ft = output_query_list_features[idx].unsqueeze(0).repeat(pair_num - 1, 1)
-            pair_features_after_transformer_cat_query = torch.cat([query_ft, output_pair_list_features[idx]], dim=1)
-            pair_features_after_transformer_cat_query_list.append(pair_features_after_transformer_cat_query)
-        pair_features_after_transformer_cat_query = torch.cat(pair_features_after_transformer_cat_query_list, dim=0)
-
-        pair_feature_query_passage_concat = self.relu(self.linear_score2(pair_feature_query_passage_concat))
-        pair_features_after_transformer_cat_query = self.relu(self.linear_score3(pair_features_after_transformer_cat_query))
-        final_ft = torch.cat([pair_feature_query_passage_concat, pair_features_after_transformer_cat_query], dim=1)
-        logits = self.linear_score1(final_ft).squeeze()
+            query_ft = (
+                output_query_list_features[idx].unsqueeze(0).repeat(pair_num - 1, 1)
+            )
+            pair_features_after_transformer_cat_query = torch.cat(
+                [query_ft, output_pair_list_features[idx]], dim=1
+            )
+            pair_features_after_transformer_cat_query_list.append(
+                pair_features_after_transformer_cat_query
+            )
+        pair_features_after_transformer_cat_query = torch.cat(
+            pair_features_after_transformer_cat_query_list, dim=0
+        )
 
+        pair_feature_query_passage_concat = self.relu(
+            self.linear_score2(pair_feature_query_passage_concat)
+        )
+        pair_features_after_transformer_cat_query = self.relu(
+            self.linear_score3(pair_features_after_transformer_cat_query)
+        )
+        final_ft = torch.cat(
+            [
+                pair_feature_query_passage_concat,
+                pair_features_after_transformer_cat_query,
+            ],
+            dim=1,
+        )
+        logits = self.linear_score1(final_ft).squeeze()
         return logits, torch.cat(pair_features_after_transformer_list, dim=0)
 
     def generate_attention_mask(self, pair_num):
@@ -147,6 +194,7 @@ class ListConRankerModel(PreTrainedModel):
     """
     ListConRanker model for sequence classification that's compatible with AutoModelForSequenceClassification.
     """
+
     config_class = ListConRankerConfig
     base_model_prefix = "listconranker"
 
@@ -155,14 +203,17 @@ class ListConRankerModel(PreTrainedModel):
         self.config = config
         self.num_labels = config.num_labels
         self.hf_model = BertModel(config.bert_config)
-
+
         self.sigmoid = nn.Sigmoid()
 
-        self.linear_in_embedding = nn.Linear(config.base_hidden_size, config.hidden_size)
+        self.linear_in_embedding = nn.Linear(
+            config.hidden_size, config.list_con_hidden_size
+        )
         self.list_transformer = ListTransformer(
-            config.list_transformer_layers,
-            config,
+            config.list_transformer_layers,
+            config,
         )
+        self.sep_token_id = 102  # [SEP] token ID
 
     def forward(
         self,
@@ -176,55 +227,63 @@ class ListConRankerModel(PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **kwargs
-    ) -> Union[SequenceClassifierOutput, tuple]:
+        **kwargs,
+    ) -> Union[SequenceClassifierOutput, tuple]:
         # Get device
         device = input_ids.device if input_ids is not None else inputs_embeds.device
         self.list_transformer.device = device
-
+
         # Forward through base model
         if self.training:
             pass
         else:
             ranker_out = self.hf_model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                token_type_ids=token_type_ids,
-                position_ids=position_ids,
-                head_mask=head_mask,
-                inputs_embeds=inputs_embeds,
-                output_attentions=output_attentions,
-                return_dict=True)
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                position_ids=position_ids,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                return_dict=True,
+            )
             last_hidden_state = ranker_out.last_hidden_state
 
             pair_features = self.average_pooling(last_hidden_state, attention_mask)
             pair_features = self.linear_in_embedding(pair_features)
 
-        logits, pair_features_after_list_transformer = self.list_transformer(pair_features)
+        logits, pair_features_after_list_transformer = self.list_transformer(
+            pair_features
+        )
         logits = self.sigmoid(logits)
 
         return logits
 
     def average_pooling(self, hidden_state, attention_mask):
-        extended_attention_mask = attention_mask.unsqueeze(-1).expand(hidden_state.size()).to(dtype=hidden_state.dtype)
+        extended_attention_mask = (
+            attention_mask.unsqueeze(-1)
+            .expand(hidden_state.size())
+            .to(dtype=hidden_state.dtype)
+        )
         masked_hidden_state = hidden_state * extended_attention_mask
         sum_embeddings = torch.sum(masked_hidden_state, dim=1)
         sum_mask = extended_attention_mask.sum(dim=1)
         return sum_embeddings / sum_mask
 
     @classmethod
-    def from_pretrained(cls, model_name_or_path, config: Optional[ListConRankerConfig] = None, **kwargs):
-        model = super().from_pretrained(
-            model_name_or_path,config=config, **kwargs)
-
+    def from_pretrained(
+        cls, model_name_or_path, config: Optional[ListConRankerConfig] = None, **kwargs
+    ):
+        model = super().from_pretrained(model_name_or_path, config=config, **kwargs)
+
         # Load custom weights
         linear_path = f"{model_name_or_path}/linear_in_embedding.pt"
         transformer_path = f"{model_name_or_path}/list_transformer.pt"
-
+
         try:
             model.linear_in_embedding.load_state_dict(torch.load(linear_path))
             model.list_transformer.load_state_dict(torch.load(transformer_path))
         except FileNotFoundError:
            print(f"Warning: Could not load custom weights from {model_name_or_path}")
-
+
         return model
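For reference, a small sketch of the data layout the new ListTransformer.forward signature expects. It assumes, following the slicing in the diff (batch_pair_features[i][0] is the query, [1:] are its passages), that each entry of pair_nums counts one query plus its passages for that group; the tensors are random placeholders, not model outputs.

import torch
from torch import nn

# Two query groups: the first with 3 passages, the second with 2.
# Each pair_nums entry is 1 (query) + number of passages in that group.
pair_nums = [4, 3]
list_con_hidden_size = 1792

# Stand-in for the pooled-and-projected features: one 1792-dim vector per
# query/passage row, concatenated over all groups.
pair_features = torch.randn(sum(pair_nums), list_con_hidden_size)

# First steps of ListTransformer.forward: split per group, pad to the longest
# group, and tag position 0 of every group as the query for QueryEmbedding.
batch_pair_features = pair_features.split(pair_nums)
padded = nn.utils.rnn.pad_sequence(batch_pair_features, batch_first=True)
query_embedding_tags = torch.zeros(padded.size(0), padded.size(1), dtype=torch.long)
query_embedding_tags[:, 0] = 1

print(padded.shape)  # torch.Size([2, 4, 1792])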