ianma2024 commited on
Commit
b05bf4f
·
verified ·
1 Parent(s): d0deb70

support mteb evaluation and update readme

Browse files
README.md CHANGED
@@ -106,9 +106,9 @@ To reduce the discrepancy between training and inference, we propose iterative i
106
 
107
  ## How to use
108
  ```python
109
- from transfoermers import AutoModelForSequenceClassification
110
 
111
- reranker = AutoModelForSequenceClassification('ByteDance/ListConRanker', trust_remote_code=True)
112
 
113
  # [query, passages_1, passage_2, ..., passage_n]
114
  batch = [
@@ -130,31 +130,37 @@ batch = [
130
  # for conventional inference, please manage the batch size by yourself
131
  scores = reranker.multi_passage(batch)
132
  print(scores)
133
- # [[0.5126953125, 0.331298828125, 0.3642578125], [0.63671875, 0.71630859375, 0.42822265625, 0.35302734375]]
134
 
 
 
 
 
 
 
 
135
  inputs = tokenizer(
136
  [
137
  [
138
- "query 1",
139
- "passage_11",
140
  ],
141
  [
142
- "query 1",
143
- "passage_12",
144
  ],
145
  [
146
- "query_2",
147
- "passage_21",
148
  ],
149
  ],
150
  return_tensors="pt",
151
  padding=True,
 
152
  )
153
- probs, logits = model(**inputs)
154
- print(probs)
155
- # tensor([[0.4359], [0.3840]], grad_fn=<ViewBackward0>)
156
  ```
157
- or using the `sentence_transformers` library:
158
  ```python
159
  from sentence_transformers import CrossEncoder
160
 
@@ -162,21 +168,20 @@ model = CrossEncoder('ByteDance/ListConRanker', trust_remote_code=True)
162
 
163
  inputs = [
164
  [
165
- "query 1",
166
- "passage_11",
167
  ],
168
  [
169
- "query 1",
170
- "passage_12",
171
  ],
172
  [
173
- "query_2",
174
- "passage_21",
175
  ],
176
  ]
177
  scores = model.predict(inputs)
178
  print(scores)
179
- # [0.43585014, 0.32305932, 0.38395187]
180
  ```
181
 
182
  To reproduce the results with iterative inference, please run:
 
106
 
107
  ## How to use
108
  ```python
109
+ from transfoermers import AutoModelForSequenceClassification, AutoTokenizer
110
 
111
+ reranker = AutoModelForSequenceClassification.from_pretrained('ByteDance/ListConRanker', trust_remote_code=True)
112
 
113
  # [query, passages_1, passage_2, ..., passage_n]
114
  batch = [
 
130
  # for conventional inference, please manage the batch size by yourself
131
  scores = reranker.multi_passage(batch)
132
  print(scores)
133
+ # [0.5126814246177673, 0.33125635981559753, 0.3642643094062805, 0.6367220282554626, 0.7166246175765991, 0.4281482696533203, 0.3530198335647583]
134
 
135
+ # for iterative inferfence, only a batch size of 1 is supported
136
+ # the scores do not carry similarity meaning and are only used for ranking
137
+ scores = reranker.multi_passage_in_iterative_inference(batch[0])
138
+ print(scores)
139
+ # [0.5126813650131226, 0.3312564790248871, 0.3642643094062805]
140
+
141
+ tokenizer = AutoTokenizer.from_pretrained('ByteDance/ListConRanker')
142
  inputs = tokenizer(
143
  [
144
  [
145
+ "皮蛋是寒性的食物吗",
146
+ "营养医师介绍皮蛋是属于凉性的食物,中医认为皮蛋可治眼疼、牙疼、高血压、耳鸣眩晕等疾病。体虚者要少吃。",
147
  ],
148
  [
149
+ "皮蛋是寒性的食物吗",
150
+ "皮蛋这种食品是在中国地域才常见的传统食品,它的生长汗青也是非常的悠长。",
151
  ],
152
  [
153
+ "月有阴晴圆缺的意义",
154
+ "形容的是月所有的状态,晴朗明媚,阴沉混沌,有月圆时,但多数时总是有缺陷。",
155
  ],
156
  ],
157
  return_tensors="pt",
158
  padding=True,
159
+ truncation=False
160
  )
161
+ # tensor([[0.5070], [0.3334], [0.6294]], device='cuda:0', dtype=torch.float16, grad_fn=<ViewBackward0>)
 
 
162
  ```
163
+ or using the `sentence_transformers` library (We do not recommend using `sentence_transformers`. Because its truncation strategy may not match the model design, which may lead to performance degradation.):
164
  ```python
165
  from sentence_transformers import CrossEncoder
166
 
 
168
 
169
  inputs = [
170
  [
171
+ "皮蛋是寒性的食物吗",
172
+ "营养医师介绍皮蛋是属于凉性的食物,中医认为皮蛋可治眼疼、牙疼、高血压、耳鸣眩晕等疾病。体虚者要少吃。",
173
  ],
174
  [
175
+ "皮蛋是寒性的食物吗",
176
+ "皮蛋这种食品是在中国地域才常见的传统食品,它的生长汗青也是非常的悠长。",
177
  ],
178
  [
179
+ "月有阴晴圆缺的意义",
180
+ "形容的是月所有的状态,晴朗明媚,阴沉混沌,有月圆时,但多数时总是有缺陷。",
181
  ],
182
  ]
183
  scores = model.predict(inputs)
184
  print(scores)
 
185
  ```
186
 
187
  To reproduce the results with iterative inference, please run:
configuration_listconranker.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software
4
+ # and associated documentation files (the "Software"), to deal in the Software without
5
+ # restriction, including without limitation the rights to use, copy, modify, merge, publish,
6
+ # distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
7
+ # Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or
10
+ # substantial portions of the Software.
11
+ #
12
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
13
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
14
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
15
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
16
+ # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
17
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
18
+ # OTHER DEALINGS IN THE SOFTWARE.
19
+ from __future__ import annotations
20
+ from transformers import BertConfig
21
+
22
+ class ListConRankerConfig(BertConfig):
23
+ """Configuration class for ListConRanker model."""
24
+
25
+ model_type = "ListConRanker"
26
+
27
+ def __init__(
28
+ self,
29
+ list_transformer_layers: int = 2,
30
+ list_con_hidden_size: int = 1792,
31
+ num_labels: int = 1,
32
+ cls_token_id: int = 101,
33
+ sep_token_id: int = 102,
34
+ **kwargs,
35
+ ):
36
+ super().__init__(**kwargs)
37
+ self.list_transformer_layers = list_transformer_layers
38
+ self.list_con_hidden_size = list_con_hidden_size
39
+ self.num_labels = num_labels
40
+ self.cls_token_id = cls_token_id
41
+ self.sep_token_id = sep_token_id
42
+
43
+ self.bert_config = BertConfig(**kwargs)
44
+ self.bert_config.output_hidden_states = True
modeling_listconranker.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software
4
+ # and associated documentation files (the "Software"), to deal in the Software without
5
+ # restriction, including without limitation the rights to use, copy, modify, merge, publish,
6
+ # distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
7
+ # Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or
10
+ # substantial portions of the Software.
11
+ #
12
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
13
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
14
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
15
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
16
+ # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
17
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
18
+ # OTHER DEALINGS IN THE SOFTWARE.
19
+ from __future__ import annotations
20
+ import torch
21
+ from torch import nn
22
+ from torch.nn import functional as F
23
+ from transformers import (
24
+ PreTrainedModel,
25
+ BertModel,
26
+ AutoTokenizer,
27
+ )
28
+ import os
29
+ from transformers.modeling_outputs import SequenceClassifierOutput
30
+ from typing import Union, List, Optional
31
+ from collections import defaultdict
32
+ import numpy as np
33
+ import math
34
+ from huggingface_hub import hf_hub_download
35
+ from .configuration_listconranker import ListConRankerConfig
36
+
37
+
38
+ class QueryEmbedding(nn.Module):
39
+ def __init__(self, config) -> None:
40
+ super().__init__()
41
+ self.query_embedding = nn.Embedding(2, config.list_con_hidden_size)
42
+ self.layerNorm = nn.LayerNorm(config.list_con_hidden_size)
43
+
44
+ def forward(self, x, tags):
45
+ query_embeddings = self.query_embedding(tags)
46
+ x += query_embeddings
47
+ x = self.layerNorm(x)
48
+ return x
49
+
50
+
51
+ class ListTransformer(nn.Module):
52
+ def __init__(self, num_layer, config) -> None:
53
+ super().__init__()
54
+ self.config = config
55
+ self.list_transformer_layer = nn.TransformerEncoderLayer(
56
+ config.list_con_hidden_size,
57
+ self.config.num_attention_heads,
58
+ batch_first=True,
59
+ activation=F.gelu,
60
+ norm_first=False,
61
+ )
62
+ self.list_transformer = nn.TransformerEncoder(
63
+ self.list_transformer_layer, num_layer
64
+ )
65
+ self.relu = nn.ReLU()
66
+ self.query_embedding = QueryEmbedding(config)
67
+
68
+ self.linear_score3 = nn.Linear(
69
+ config.list_con_hidden_size * 2, config.list_con_hidden_size
70
+ )
71
+ self.linear_score2 = nn.Linear(
72
+ config.list_con_hidden_size * 2, config.list_con_hidden_size
73
+ )
74
+ self.linear_score1 = nn.Linear(config.list_con_hidden_size * 2, 1)
75
+
76
+ def forward(
77
+ self, pair_features: torch.Tensor, pair_nums: List[int]
78
+ ) -> torch.Tensor:
79
+ batch_pair_features = pair_features.split(pair_nums)
80
+
81
+ pair_feature_query_passage_concat_list = []
82
+ for i in range(len(batch_pair_features)):
83
+ pair_feature_query = (
84
+ batch_pair_features[i][0].unsqueeze(0).repeat(pair_nums[i] - 1, 1)
85
+ )
86
+ pair_feature_passage = batch_pair_features[i][1:]
87
+ pair_feature_query_passage_concat_list.append(
88
+ torch.cat([pair_feature_query, pair_feature_passage], dim=1)
89
+ )
90
+ pair_feature_query_passage_concat = torch.cat(
91
+ pair_feature_query_passage_concat_list, dim=0
92
+ )
93
+
94
+ batch_pair_features = nn.utils.rnn.pad_sequence(
95
+ batch_pair_features, batch_first=True
96
+ )
97
+
98
+ query_embedding_tags = torch.zeros(
99
+ batch_pair_features.size(0),
100
+ batch_pair_features.size(1),
101
+ dtype=torch.long,
102
+ device=self.device,
103
+ )
104
+ query_embedding_tags[:, 0] = 1
105
+ batch_pair_features = self.query_embedding(
106
+ batch_pair_features, query_embedding_tags
107
+ )
108
+
109
+ mask = self.generate_attention_mask(pair_nums)
110
+ query_mask = self.generate_attention_mask_custom(pair_nums)
111
+ pair_list_features = self.list_transformer(
112
+ batch_pair_features, src_key_padding_mask=mask, mask=query_mask
113
+ )
114
+
115
+ output_pair_list_features = []
116
+ output_query_list_features = []
117
+ pair_features_after_transformer_list = []
118
+ for idx, pair_num in enumerate(pair_nums):
119
+ output_pair_list_features.append(pair_list_features[idx, 1:pair_num, :])
120
+ output_query_list_features.append(pair_list_features[idx, 0, :])
121
+ pair_features_after_transformer_list.append(
122
+ pair_list_features[idx, :pair_num, :]
123
+ )
124
+
125
+ pair_features_after_transformer_cat_query_list = []
126
+ for idx, pair_num in enumerate(pair_nums):
127
+ query_ft = (
128
+ output_query_list_features[idx].unsqueeze(0).repeat(pair_num - 1, 1)
129
+ )
130
+ pair_features_after_transformer_cat_query = torch.cat(
131
+ [query_ft, output_pair_list_features[idx]], dim=1
132
+ )
133
+ pair_features_after_transformer_cat_query_list.append(
134
+ pair_features_after_transformer_cat_query
135
+ )
136
+ pair_features_after_transformer_cat_query = torch.cat(
137
+ pair_features_after_transformer_cat_query_list, dim=0
138
+ )
139
+
140
+ pair_feature_query_passage_concat = self.relu(
141
+ self.linear_score2(pair_feature_query_passage_concat)
142
+ )
143
+ pair_features_after_transformer_cat_query = self.relu(
144
+ self.linear_score3(pair_features_after_transformer_cat_query)
145
+ )
146
+ final_ft = torch.cat(
147
+ [
148
+ pair_feature_query_passage_concat,
149
+ pair_features_after_transformer_cat_query,
150
+ ],
151
+ dim=1,
152
+ )
153
+ logits = self.linear_score1(final_ft).squeeze()
154
+ return logits, torch.cat(pair_features_after_transformer_list, dim=0)
155
+
156
+ def generate_attention_mask(self, pair_num):
157
+ max_len = max(pair_num)
158
+ batch_size = len(pair_num)
159
+ mask = torch.zeros(batch_size, max_len, dtype=torch.bool, device=self.device)
160
+ for i, length in enumerate(pair_num):
161
+ mask[i, length:] = True
162
+ return mask
163
+
164
+ def generate_attention_mask_custom(self, pair_num):
165
+ max_len = max(pair_num)
166
+ mask = torch.zeros(max_len, max_len, dtype=torch.bool, device=self.device)
167
+ mask[0, 1:] = True
168
+ return mask
169
+
170
+
171
+ class ListConRankerModel(PreTrainedModel):
172
+ """
173
+ ListConRanker model for sequence classification that's compatible with AutoModelForSequenceClassification.
174
+ """
175
+
176
+ config_class = ListConRankerConfig
177
+ base_model_prefix = "listconranker"
178
+
179
+ def __init__(self, config: ListConRankerConfig):
180
+ super().__init__(config)
181
+ self.config = config
182
+ self.num_labels = config.num_labels
183
+ self.hf_model = BertModel(config.bert_config)
184
+
185
+ self.sigmoid = nn.Sigmoid()
186
+
187
+ self.linear_in_embedding = nn.Linear(
188
+ config.hidden_size, config.list_con_hidden_size
189
+ )
190
+ self.list_transformer = ListTransformer(
191
+ config.list_transformer_layers,
192
+ config,
193
+ )
194
+
195
+ def forward(
196
+ self,
197
+ input_ids: torch.Tensor,
198
+ attention_mask: Optional[torch.Tensor] = None,
199
+ token_type_ids: Optional[torch.Tensor] = None,
200
+ position_ids: Optional[torch.Tensor] = None,
201
+ head_mask: Optional[torch.Tensor] = None,
202
+ inputs_embeds: Optional[torch.Tensor] = None,
203
+ labels: Optional[torch.Tensor] = None,
204
+ output_attentions: Optional[bool] = None,
205
+ output_hidden_states: Optional[bool] = None,
206
+ return_dict: Optional[bool] = None,
207
+ **kwargs,
208
+ ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
209
+ if self.training:
210
+ raise NotImplementedError("Training not supported; use eval mode.")
211
+ device = input_ids.device
212
+ self.list_transformer.device = device
213
+ # Reorganize by unique queries and their passages
214
+ (
215
+ reorganized_input_ids,
216
+ reorganized_attention_mask,
217
+ reorganized_token_type_ids,
218
+ pair_nums,
219
+ group_indices,
220
+ ) = self._reorganize_inputs(input_ids, attention_mask, token_type_ids)
221
+
222
+ out = self.hf_model(
223
+ input_ids=reorganized_input_ids,
224
+ attention_mask=reorganized_attention_mask,
225
+ token_type_ids=reorganized_token_type_ids,
226
+ return_dict=True,
227
+ )
228
+ feats = out.last_hidden_state
229
+ pooled = self.average_pooling(feats, reorganized_attention_mask)
230
+ embedded = self.linear_in_embedding(pooled)
231
+ logits, _ = self.list_transformer(embedded, pair_nums)
232
+ probs = self.sigmoid(logits)
233
+
234
+ # Restore original order
235
+ sorted_probs = self._restore_original_order(probs, group_indices)
236
+ sorted_logits = self._restore_original_order(logits, group_indices)
237
+ if not return_dict:
238
+ return (sorted_probs, sorted_logits)
239
+
240
+ return SequenceClassifierOutput(
241
+ loss=None,
242
+ logits=sorted_logits,
243
+ hidden_states=out.hidden_states,
244
+ attentions=out.attentions,
245
+ )
246
+
247
+ def _reorganize_inputs(
248
+ self,
249
+ input_ids: torch.Tensor,
250
+ attention_mask: torch.Tensor,
251
+ token_type_ids: Optional[torch.Tensor],
252
+ ) -> tuple[
253
+ torch.Tensor, torch.Tensor, Optional[torch.Tensor], List[int], List[List[int]]
254
+ ]:
255
+ """
256
+ Group inputs by unique queries: for each query, produce [query] + its passages,
257
+ then flatten, pad, and return pair sizes and original indices mapping.
258
+ """
259
+ batch_size = input_ids.size(0)
260
+ # Structure: query_key -> {
261
+ # 'query': (seq, mask, tt),
262
+ # 'passages': [(seq, mask, tt), ...],
263
+ # 'indices': [original_index, ...]
264
+ # }
265
+ grouped = {}
266
+
267
+ for idx in range(batch_size):
268
+ seq = input_ids[idx]
269
+ mask = attention_mask[idx]
270
+ token_type_ids[idx] if token_type_ids is not None else torch.zeros_like(seq)
271
+
272
+ sep_idxs = (seq == self.config.sep_token_id).nonzero(as_tuple=True)[0]
273
+ if sep_idxs.numel() == 0:
274
+ raise ValueError(f"No SEP in sequence {idx}")
275
+ first_sep = sep_idxs[0].item()
276
+ second_sep = sep_idxs[1].item()
277
+
278
+ # Extract query and passage
279
+ q_seq = seq[: first_sep + 1]
280
+ q_mask = mask[: first_sep + 1]
281
+ q_tt = torch.zeros_like(q_seq)
282
+
283
+ p_seq = seq[first_sep : second_sep + 1]
284
+ p_mask = mask[first_sep : second_sep + 1]
285
+ p_seq = p_seq.clone()
286
+ p_seq[0] = self.config.cls_token_id
287
+ p_tt = torch.zeros_like(p_seq)
288
+
289
+ # Build key excluding CLS/SEP
290
+ key = tuple(
291
+ q_seq[
292
+ (q_seq != self.config.cls_token_id)
293
+ & (q_seq != self.config.sep_token_id)
294
+ ].tolist()
295
+ )
296
+
297
+ # truncation
298
+ q_seq = q_seq[: self.config.max_position_embeddings]
299
+ q_seq[-1] = self.config.sep_token_id
300
+ p_seq = p_seq[: self.config.max_position_embeddings]
301
+ p_seq[-1] = self.config.sep_token_id
302
+ q_mask = q_mask[: self.config.max_position_embeddings]
303
+ p_mask = p_mask[: self.config.max_position_embeddings]
304
+ q_tt = q_tt[: self.config.max_position_embeddings]
305
+ p_tt = p_tt[: self.config.max_position_embeddings]
306
+
307
+ if key not in grouped:
308
+ grouped[key] = {
309
+ "query": (q_seq, q_mask, q_tt),
310
+ "passages": [],
311
+ "indices": [],
312
+ }
313
+ grouped[key]["passages"].append((p_seq, p_mask, p_tt))
314
+ grouped[key]["indices"].append(idx)
315
+
316
+ # Flatten according to group insertion order
317
+ seqs, masks, tts, pair_nums, group_indices = [], [], [], [], []
318
+ for key, data in grouped.items():
319
+ q_seq, q_mask, q_tt = data["query"]
320
+ passages = data["passages"]
321
+ indices = data["indices"]
322
+ # record sizes and original positions
323
+ pair_nums.append(len(passages) + 1) # +1 for the query
324
+ group_indices.append(indices)
325
+
326
+ # append query then its passages
327
+ seqs.append(q_seq)
328
+ masks.append(q_mask)
329
+ tts.append(q_tt)
330
+ for p_seq, p_mask, p_tt in passages:
331
+ seqs.append(p_seq)
332
+ masks.append(p_mask)
333
+ tts.append(p_tt)
334
+
335
+ # Pad to uniform length
336
+ max_len = max(s.size(0) for s in seqs)
337
+ padded_seqs, padded_masks, padded_tts = [], [], []
338
+ for s, m, t in zip(seqs, masks, tts):
339
+ ps = torch.zeros(max_len, dtype=s.dtype, device=s.device)
340
+ pm = torch.zeros(max_len, dtype=m.dtype, device=m.device)
341
+ pt = torch.zeros(max_len, dtype=t.dtype, device=t.device)
342
+ ps[: s.size(0)] = s
343
+ pm[: m.size(0)] = m
344
+ pt[: t.size(0)] = t
345
+ padded_seqs.append(ps)
346
+ padded_masks.append(pm)
347
+ padded_tts.append(pt)
348
+
349
+ rid = torch.stack(padded_seqs)
350
+ ram = torch.stack(padded_masks)
351
+ rtt = torch.stack(padded_tts) if token_type_ids is not None else None
352
+
353
+ return rid, ram, rtt, pair_nums, group_indices
354
+
355
+ def _restore_original_order(
356
+ self,
357
+ logits: torch.Tensor,
358
+ group_indices: List[List[int]],
359
+ ) -> torch.Tensor:
360
+ """
361
+ Map flattened logits back so each original index gets its passage score.
362
+ """
363
+ out = torch.zeros(logits.size(0), dtype=logits.dtype, device=logits.device)
364
+ i = 0
365
+ for indices in group_indices:
366
+ for idx in indices:
367
+ out[idx] = logits[i]
368
+ i += 1
369
+ return out.reshape(-1, 1)
370
+
371
+ def average_pooling(self, hidden_state, attention_mask):
372
+ extended_attention_mask = (
373
+ attention_mask.unsqueeze(-1)
374
+ .expand(hidden_state.size())
375
+ .to(dtype=hidden_state.dtype)
376
+ )
377
+ masked_hidden_state = hidden_state * extended_attention_mask
378
+ sum_embeddings = torch.sum(masked_hidden_state, dim=1)
379
+ sum_mask = extended_attention_mask.sum(dim=1)
380
+ return sum_embeddings / sum_mask
381
+
382
+ @classmethod
383
+ def from_pretrained(
384
+ cls, model_name_or_path, config: Optional[ListConRankerConfig] = None, **kwargs
385
+ ):
386
+ model = super().from_pretrained(model_name_or_path, config=config, **kwargs)
387
+ model.hf_model = BertModel.from_pretrained(
388
+ model_name_or_path, config=model.config.bert_config, **kwargs
389
+ )
390
+ linear_path = hf_hub_download(
391
+ repo_id = model_name_or_path,
392
+ filename = "linear_in_embedding.pt",
393
+ revision = "main",
394
+ cache_dir = kwargs['cache_dir'] if 'cache_dir' in kwargs else None
395
+ )
396
+ list_transformer_path = hf_hub_download(
397
+ repo_id = "ByteDance/ListConRanker",
398
+ filename = "list_transformer.pt",
399
+ revision = "main",
400
+ cache_dir = kwargs['cache_dir'] if 'cache_dir' in kwargs else None
401
+ )
402
+
403
+ try:
404
+ model.linear_in_embedding.load_state_dict(torch.load(linear_path))
405
+ model.list_transformer.load_state_dict(torch.load(list_transformer_path))
406
+ except FileNotFoundError as e:
407
+ raise e
408
+
409
+ return model
410
+
411
+ def multi_passage(
412
+ self,
413
+ sentences: List[List[str]],
414
+ batch_size: int = 32,
415
+ tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(
416
+ "ByteDance/ListConRanker"
417
+ ),
418
+ ):
419
+ """
420
+ Process multiple passages for each query.
421
+ :param sentences: List of lists, where each inner list contains sentences for a query.
422
+ :return: Tensor of logits for each passage.
423
+ """
424
+ pairs = []
425
+ for batch in sentences:
426
+ if len(batch) < 2:
427
+ raise ValueError("Each query must have at least one passage.")
428
+ query = batch[0]
429
+ passages = batch[1:]
430
+ for passage in passages:
431
+ pairs.append((query, passage))
432
+
433
+ total_batches = (len(pairs) + batch_size - 1) // batch_size
434
+ total_logits = torch.zeros(len(pairs), dtype=torch.float, device=self.device)
435
+ for batch in range(total_batches):
436
+ batch_pairs = pairs[batch * batch_size : (batch + 1) * batch_size]
437
+ inputs = tokenizer(
438
+ batch_pairs,
439
+ padding=True,
440
+ truncation=False,
441
+ return_tensors="pt",
442
+ )
443
+
444
+ for k, v in inputs.items():
445
+ inputs[k] = v.to(self.device)
446
+
447
+ logits = self(**inputs)[0]
448
+ total_logits[batch * batch_size : (batch + 1) * batch_size] = (
449
+ logits.squeeze(1)
450
+ )
451
+ return total_logits.tolist()
452
+
453
+ def multi_passage_in_iterative_inference(
454
+ self,
455
+ sentences: List[str],
456
+ stop_num: int = 20,
457
+ decrement_rate: float = 0.2,
458
+ min_filter_num: int = 10,
459
+ tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(
460
+ "ByteDance/ListConRanker"
461
+ ),
462
+ ):
463
+ """
464
+ Process multiple passages for one query in iterative inference.
465
+ :param sentences: List contains sentences for a query.
466
+ :return: Tensor of logits for each passage.
467
+ """
468
+ if stop_num < 1:
469
+ raise ValueError("stop_num must be greater than 0")
470
+ if decrement_rate <= 0 or decrement_rate >= 1:
471
+ raise ValueError("decrement_rate must be in (0, 1)")
472
+ if min_filter_num < 1:
473
+ raise ValueError("min_filter_num must be greater than 0")
474
+
475
+ query = sentences[0]
476
+ passage = sentences[1:]
477
+
478
+ filter_times = 0
479
+ passage2score = defaultdict(list)
480
+ while len(passage) > stop_num:
481
+ batch = [[query] + passage]
482
+ pred_scores = self.multi_passage(
483
+ batch, batch_size=len(batch[0]) - 1, tokenizer=tokenizer
484
+ )
485
+ pred_scores_argsort = np.argsort(
486
+ pred_scores
487
+ ).tolist() # Sort in increasing order
488
+
489
+ passage_len = len(passage)
490
+ to_filter_num = math.ceil(passage_len * decrement_rate)
491
+ if to_filter_num < min_filter_num:
492
+ to_filter_num = min_filter_num
493
+
494
+ have_filter_num = 0
495
+ while have_filter_num < to_filter_num:
496
+ idx = pred_scores_argsort[have_filter_num]
497
+ passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
498
+ have_filter_num += 1
499
+ while (
500
+ pred_scores[pred_scores_argsort[have_filter_num - 1]]
501
+ == pred_scores[pred_scores_argsort[have_filter_num]]
502
+ ):
503
+ idx = pred_scores_argsort[have_filter_num]
504
+ passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
505
+ have_filter_num += 1
506
+ next_passage = []
507
+ next_passage_idx = have_filter_num
508
+ while next_passage_idx < len(passage):
509
+ idx = pred_scores_argsort[next_passage_idx]
510
+ next_passage.append(passage[idx])
511
+ next_passage_idx += 1
512
+ passage = next_passage
513
+ filter_times += 1
514
+
515
+ batch = [[query] + passage]
516
+ pred_scores = self.multi_passage(
517
+ batch, batch_size=len(batch[0]) - 1, tokenizer=tokenizer
518
+ )
519
+
520
+ cnt = 0
521
+ while cnt < len(passage):
522
+ passage2score[passage[cnt]].append(pred_scores[cnt] + filter_times)
523
+ cnt += 1
524
+
525
+ passage = sentences[1:]
526
+ final_score = []
527
+ for i in range(len(passage)):
528
+ p = passage[i]
529
+ final_score.append(passage2score[p][0])
530
+ return final_score