ianma2024 committed on
Commit
b3e02f6
·
verified ·
1 Parent(s): b05bf4f

rm ood files

Browse files
Files changed (1) hide show
  1. listconranker.py +0 -546
listconranker.py DELETED
@@ -1,546 +0,0 @@
1
- # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
- #
3
- # Permission is hereby granted, free of charge, to any person obtaining a copy of this software
4
- # and associated documentation files (the "Software"), to deal in the Software without
5
- # restriction, including without limitation the rights to use, copy, modify, merge, publish,
6
- # distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
7
- # Software is furnished to do so, subject to the following conditions:
8
- #
9
- # The above copyright notice and this permission notice shall be included in all copies or
10
- # substantial portions of the Software.
11
- #
12
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
13
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
14
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
15
- # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
16
- # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
17
- # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
18
- # OTHER DEALINGS IN THE SOFTWARE.
19
- from __future__ import annotations
20
-
21
- import torch
22
- from torch import nn
23
- from torch.nn import functional as F
24
- from transformers import (
25
- PreTrainedModel,
26
- BertModel,
27
- BertConfig,
28
- AutoTokenizer,
29
- )
30
- import os
31
- from transformers.modeling_outputs import SequenceClassifierOutput
32
- from typing import Union, List, Optional
33
- from collections import defaultdict
34
- import numpy as np
35
- import math
36
-
37
-
38
class ListConRankerConfig(BertConfig):
    """Configuration for ListConRanker: a BERT config plus list-transformer settings.

    Args:
        list_transformer_layers: Layer count handed to ``ListTransformer``.
        list_con_hidden_size: Feature width of the list-level transformer and
            of the projection applied to the pooled BERT output.
        num_labels: Number of labels stored on the config.
        cls_token_id: Token id written at the start of each passage segment
            when ``ListConRankerModel`` splits a query/passage pair.
        sep_token_id: Token id used to locate (and terminate) the query and
            passage segments.
        **kwargs: Forwarded both to ``BertConfig.__init__`` of this object and
            to the nested ``bert_config``.
    """

    model_type = "ListConRanker"

    def __init__(
        self,
        list_transformer_layers: int = 2,
        list_con_hidden_size: int = 1792,
        num_labels: int = 1,
        cls_token_id: int = 101,
        sep_token_id: int = 102,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Separate config for the BERT backbone, built from the same kwargs and
        # forced to expose every layer's hidden states.
        backbone_config = BertConfig(**kwargs)
        backbone_config.output_hidden_states = True
        self.bert_config = backbone_config

        # List-wise ranking head settings.
        self.list_transformer_layers = list_transformer_layers
        self.list_con_hidden_size = list_con_hidden_size
        self.num_labels = num_labels

        # Special tokens used when splitting pairs into query / passage.
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
61
-
62
-
63
class QueryEmbedding(nn.Module):
    """Add a learned per-position tag embedding to the features, then LayerNorm.

    The embedding table has two rows; ``ListTransformer`` uses tag ``1`` for
    the query slot (position 0 of every list) and tag ``0`` everywhere else.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.query_embedding = nn.Embedding(2, config.list_con_hidden_size)
        self.layerNorm = nn.LayerNorm(config.list_con_hidden_size)

    def forward(self, x, tags):
        """Return ``LayerNorm(x + embedding(tags))``.

        Args:
            x: Feature tensor whose last dimension is ``list_con_hidden_size``.
            tags: Long tensor of 0/1 tags, broadcast-compatible with
                ``x.shape[:-1]``.

        Returns:
            A new tensor with the same shape as ``x``.
        """
        # Out-of-place add: the previous ``x += ...`` silently overwrote the
        # caller's tensor (and can break autograd on tensors that are needed
        # for the backward pass).
        x = x + self.query_embedding(tags)
        return self.layerNorm(x)
74
-
75
-
76
class ListTransformer(nn.Module):
    """List-wise scoring head: contextualize a query with its passages, then score.

    Each group of rows in the input is ``[query, passage_1, ..., passage_k]``.
    The group is run through a transformer encoder so that passages can attend
    to each other and to the query, and every passage receives one logit
    computed from both its raw and its contextualized features.
    """

    def __init__(self, num_layer, config) -> None:
        super().__init__()
        self.config = config
        self.list_transformer_layer = nn.TransformerEncoderLayer(
            config.list_con_hidden_size,
            self.config.num_attention_heads,
            batch_first=True,
            activation=F.gelu,
            norm_first=False,
        )
        self.list_transformer = nn.TransformerEncoder(
            self.list_transformer_layer, num_layer
        )
        self.relu = nn.ReLU()
        self.query_embedding = QueryEmbedding(config)

        # linear_score2 projects the raw [query ; passage] features and
        # linear_score3 the post-transformer ones; linear_score1 maps their
        # concatenation to a single logit.
        self.linear_score3 = nn.Linear(
            config.list_con_hidden_size * 2, config.list_con_hidden_size
        )
        self.linear_score2 = nn.Linear(
            config.list_con_hidden_size * 2, config.list_con_hidden_size
        )
        self.linear_score1 = nn.Linear(config.list_con_hidden_size * 2, 1)

    def forward(
        self, pair_features: torch.Tensor, pair_nums: List[int]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Score every passage of every group.

        Args:
            pair_features: ``(sum(pair_nums), list_con_hidden_size)`` tensor.
                Rows are grouped consecutively; the first row of each group is
                the query, the remaining rows are its passages.
            pair_nums: Size of each group, query included.

        Returns:
            ``(logits, features)`` where ``logits`` is 1-D with one entry per
            passage (group order, queries excluded) and ``features`` is the
            transformer output for all rows with padding removed, shaped
            ``(sum(pair_nums), list_con_hidden_size)``.
        """
        # Use the input's device instead of relying on a ``device`` attribute
        # that an outer model has to set before calling this module.
        device = pair_features.device
        groups = pair_features.split(pair_nums)

        # Raw (pre-transformer) view: [query ; passage] for each passage.
        raw_pairs = []
        for group, pair_num in zip(groups, pair_nums):
            query_ft = group[0].unsqueeze(0).repeat(pair_num - 1, 1)
            raw_pairs.append(torch.cat([query_ft, group[1:]], dim=1))
        raw_pairs = torch.cat(raw_pairs, dim=0)

        padded = nn.utils.rnn.pad_sequence(groups, batch_first=True)

        # Tag position 0 of every list as the query (1), everything else as 0.
        tags = torch.zeros(
            padded.size(0), padded.size(1), dtype=torch.long, device=device
        )
        tags[:, 0] = 1
        padded = self.query_embedding(padded, tags)

        padding_mask = self.generate_attention_mask(pair_nums, device=device)
        query_mask = self.generate_attention_mask_custom(pair_nums, device=device)
        encoded = self.list_transformer(
            padded, src_key_padding_mask=padding_mask, mask=query_mask
        )

        # Contextualized view: [query_out ; passage_out] per passage, plus the
        # un-padded transformer outputs for every row.
        contextual_pairs = []
        unpadded_features = []
        for idx, pair_num in enumerate(pair_nums):
            query_ft = encoded[idx, 0, :].unsqueeze(0).repeat(pair_num - 1, 1)
            contextual_pairs.append(
                torch.cat([query_ft, encoded[idx, 1:pair_num, :]], dim=1)
            )
            unpadded_features.append(encoded[idx, :pair_num, :])
        contextual_pairs = torch.cat(contextual_pairs, dim=0)

        raw_pairs = self.relu(self.linear_score2(raw_pairs))
        contextual_pairs = self.relu(self.linear_score3(contextual_pairs))
        final_ft = torch.cat([raw_pairs, contextual_pairs], dim=1)

        # squeeze(-1), not squeeze(): with a single passage in the whole batch
        # a bare squeeze() would return a 0-dim tensor and break callers that
        # index or call ``.size(0)`` on the logits.
        logits = self.linear_score1(final_ft).squeeze(-1)
        return logits, torch.cat(unpadded_features, dim=0)

    def generate_attention_mask(self, pair_num, device=None):
        """Build the key-padding mask: ``True`` marks padded list positions.

        Args:
            pair_num: Size of each group (query included).
            device: Target device. Falls back to ``self.device`` when that
                attribute has been set externally, else to the default device.

        Returns:
            Bool tensor of shape ``(len(pair_num), max(pair_num))``.
        """
        if device is None:
            device = getattr(self, "device", None)
        max_len = max(pair_num)
        mask = torch.zeros(len(pair_num), max_len, dtype=torch.bool, device=device)
        for i, length in enumerate(pair_num):
            mask[i, length:] = True
        return mask

    def generate_attention_mask_custom(self, pair_num, device=None):
        """Build the attention mask that blocks the query row from the passages.

        Row 0 (the query) is masked for every column but its own, so the query
        representation does not depend on the passages; passage rows may
        attend everywhere.

        Args:
            pair_num: Size of each group (query included).
            device: Target device. Falls back to ``self.device`` when that
                attribute has been set externally, else to the default device.

        Returns:
            Bool tensor of shape ``(max(pair_num), max(pair_num))``.
        """
        if device is None:
            device = getattr(self, "device", None)
        max_len = max(pair_num)
        mask = torch.zeros(max_len, max_len, dtype=torch.bool, device=device)
        mask[0, 1:] = True
        return mask
194
-
195
-
196
class ListConRankerModel(PreTrainedModel):
    """
    ListConRanker model for sequence classification that's compatible with AutoModelForSequenceClassification.

    Inputs are tokenized ``(query, passage)`` pairs. ``forward`` regroups the
    pairs by query, encodes query and passages separately with BERT, and
    scores each passage list-wise with ``ListTransformer``. Inference only.
    """

    config_class = ListConRankerConfig
    base_model_prefix = "listconranker"

    # Tokenizer used by the convenience methods when the caller passes none.
    # Loaded lazily on first use and cached on the class.
    _default_tokenizer_name = "ByteDance/ListConRanker"
    _default_tokenizer_cache = None

    def __init__(self, config: ListConRankerConfig):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.hf_model = BertModel(config.bert_config)

        self.sigmoid = nn.Sigmoid()

        # Projects pooled BERT features to the list transformer's width.
        self.linear_in_embedding = nn.Linear(
            config.hidden_size, config.list_con_hidden_size
        )
        self.list_transformer = ListTransformer(
            config.list_transformer_layers,
            config,
        )

    @classmethod
    def _get_default_tokenizer(cls):
        """Load (once) and return the default tokenizer."""
        if ListConRankerModel._default_tokenizer_cache is None:
            ListConRankerModel._default_tokenizer_cache = (
                AutoTokenizer.from_pretrained(cls._default_tokenizer_name)
            )
        return ListConRankerModel._default_tokenizer_cache

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        """Score each ``(query, passage)`` row of the batch.

        Only ``input_ids``, ``attention_mask``, ``token_type_ids`` and
        ``return_dict`` are used; the remaining parameters are accepted and
        ignored.

        Returns:
            When ``return_dict`` is falsy (including the default ``None``), the
            tuple ``(probs, logits)``, each of shape ``(batch, 1)`` in the
            original row order, where ``probs = sigmoid(logits)``. Otherwise a
            ``SequenceClassifierOutput`` whose ``logits`` are the raw logits.

        Raises:
            NotImplementedError: If the module is in training mode.
        """
        if self.training:
            raise NotImplementedError("Training not supported; use eval mode.")
        device = input_ids.device
        # ListTransformer reads this attribute when it builds its masks.
        self.list_transformer.device = device

        # Reorganize by unique queries and their passages
        (
            reorganized_input_ids,
            reorganized_attention_mask,
            reorganized_token_type_ids,
            pair_nums,
            group_indices,
        ) = self._reorganize_inputs(input_ids, attention_mask, token_type_ids)

        out = self.hf_model(
            input_ids=reorganized_input_ids,
            attention_mask=reorganized_attention_mask,
            token_type_ids=reorganized_token_type_ids,
            return_dict=True,
        )
        pooled = self.average_pooling(
            out.last_hidden_state, reorganized_attention_mask
        )
        embedded = self.linear_in_embedding(pooled)
        logits, _ = self.list_transformer(embedded, pair_nums)
        # Guarantee a 1-D tensor even if the head squeezed a single score
        # down to 0-dim.
        logits = logits.reshape(-1)
        probs = self.sigmoid(logits)

        # Restore original order
        sorted_probs = self._restore_original_order(probs, group_indices)
        sorted_logits = self._restore_original_order(logits, group_indices)
        if not return_dict:
            return (sorted_probs, sorted_logits)

        return SequenceClassifierOutput(
            loss=None,
            logits=sorted_logits,
            hidden_states=out.hidden_states,
            attentions=out.attentions,
        )

    def _reorganize_inputs(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        token_type_ids: Optional[torch.Tensor],
    ) -> tuple[
        torch.Tensor, torch.Tensor, Optional[torch.Tensor], List[int], List[List[int]]
    ]:
        """
        Group inputs by unique queries: for each query, produce [query] + its passages,
        then flatten, pad, and return pair sizes and original indices mapping.

        Each input row is split at its first two SEP tokens into a query
        segment (up to and including the first SEP) and a passage segment
        (first SEP through second SEP, with the leading SEP replaced by CLS).
        Rows whose query tokens are identical share one group. Token type ids
        of the regrouped rows are all zero.

        Returns:
            ``(input_ids, attention_mask, token_type_ids, pair_nums,
            group_indices)``. The tensors hold one row per query/passage
            segment, zero-padded to a common length; ``token_type_ids`` is
            ``None`` when the argument was ``None``. ``pair_nums[g]`` is the
            row count of group ``g`` (query included) and ``group_indices[g]``
            lists the original batch indices of its passages.

        Raises:
            ValueError: If a row contains fewer than two SEP tokens.
        """
        if attention_mask is None:
            # forward() allows attention_mask=None; treat every token as real.
            attention_mask = torch.ones_like(input_ids)

        max_positions = self.config.max_position_embeddings
        cls_id = self.config.cls_token_id
        sep_id = self.config.sep_token_id

        # Structure: query_key -> {
        #   'query': (seq, mask, tt),
        #   'passages': [(seq, mask, tt), ...],
        #   'indices': [original_index, ...]
        # }
        grouped = {}

        for idx in range(input_ids.size(0)):
            seq = input_ids[idx]
            mask = attention_mask[idx]

            sep_idxs = (seq == sep_id).nonzero(as_tuple=True)[0]
            if sep_idxs.numel() == 0:
                raise ValueError(f"No SEP in sequence {idx}")
            if sep_idxs.numel() < 2:
                raise ValueError(
                    f"Expected two SEP tokens (query and passage) in sequence {idx}"
                )
            first_sep = sep_idxs[0].item()
            second_sep = sep_idxs[1].item()

            # Extract query and passage. Both are cloned because the
            # truncation step below writes into them and must not modify the
            # caller's input_ids.
            q_seq = seq[: first_sep + 1].clone()
            q_mask = mask[: first_sep + 1]
            p_seq = seq[first_sep : second_sep + 1].clone()
            p_mask = mask[first_sep : second_sep + 1]
            p_seq[0] = cls_id

            # Build key excluding CLS/SEP (from the untruncated query)
            key = tuple(q_seq[(q_seq != cls_id) & (q_seq != sep_id)].tolist())

            # truncation: clip to the backbone's position limit and make sure
            # every segment still ends with SEP.
            q_seq = q_seq[:max_positions]
            q_seq[-1] = sep_id
            p_seq = p_seq[:max_positions]
            p_seq[-1] = sep_id
            q_mask = q_mask[:max_positions]
            p_mask = p_mask[:max_positions]
            q_tt = torch.zeros_like(q_seq)
            p_tt = torch.zeros_like(p_seq)

            if key not in grouped:
                grouped[key] = {
                    "query": (q_seq, q_mask, q_tt),
                    "passages": [],
                    "indices": [],
                }
            grouped[key]["passages"].append((p_seq, p_mask, p_tt))
            grouped[key]["indices"].append(idx)

        # Flatten according to group insertion order
        seqs, masks, tts, pair_nums, group_indices = [], [], [], [], []
        for data in grouped.values():
            rows = [data["query"], *data["passages"]]
            pair_nums.append(len(rows))  # passages + 1 for the query
            group_indices.append(data["indices"])
            for row_seq, row_mask, row_tt in rows:
                seqs.append(row_seq)
                masks.append(row_mask)
                tts.append(row_tt)

        # Pad to uniform length (zeros for ids, mask and token types alike)
        rid = nn.utils.rnn.pad_sequence(seqs, batch_first=True)
        ram = nn.utils.rnn.pad_sequence(masks, batch_first=True)
        rtt = (
            nn.utils.rnn.pad_sequence(tts, batch_first=True)
            if token_type_ids is not None
            else None
        )

        return rid, ram, rtt, pair_nums, group_indices

    def _restore_original_order(
        self,
        logits: torch.Tensor,
        group_indices: List[List[int]],
    ) -> torch.Tensor:
        """
        Map flattened logits back so each original index gets its passage score.

        ``logits`` holds one score per passage in group order; the result has
        shape ``(num_passages, 1)`` in original batch order.
        """
        flat = logits.reshape(-1)  # also accepts a 0-dim single score
        order = torch.tensor(
            [idx for indices in group_indices for idx in indices],
            dtype=torch.long,
            device=flat.device,
        )
        out = torch.zeros(flat.size(0), dtype=flat.dtype, device=flat.device)
        out[order] = flat
        return out.reshape(-1, 1)

    def average_pooling(self, hidden_state, attention_mask):
        """Mean of ``hidden_state`` over the sequence axis, ignoring masked tokens.

        Args:
            hidden_state: ``(batch, seq, hidden)`` tensor.
            attention_mask: ``(batch, seq)`` mask, non-zero for real tokens.

        Returns:
            ``(batch, hidden)`` tensor.
        """
        weights = attention_mask.unsqueeze(-1).to(dtype=hidden_state.dtype)
        summed = (hidden_state * weights).sum(dim=1)
        # Guard against an all-zero mask row, which would otherwise yield NaN.
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    @classmethod
    def from_pretrained(
        cls, model_name_or_path, config: Optional[ListConRankerConfig] = None, **kwargs
    ):
        """Load the model from a local checkpoint directory.

        Besides the regular ``from_pretrained`` loading, the BERT backbone is
        reloaded with the nested ``bert_config`` and the two extra heads are
        read from ``linear_in_embedding.pt`` and ``list_transformer.pt`` inside
        ``model_name_or_path``.

        Raises:
            FileNotFoundError: If either extra weight file is missing.
        """
        model = super().from_pretrained(model_name_or_path, config=config, **kwargs)
        model.hf_model = BertModel.from_pretrained(
            model_name_or_path, config=model.config.bert_config, **kwargs
        )

        linear_path = os.path.join(model_name_or_path, "linear_in_embedding.pt")
        transformer_path = os.path.join(model_name_or_path, "list_transformer.pt")

        missing = [
            path for path in (linear_path, transformer_path) if not os.path.isfile(path)
        ]
        if missing:
            raise FileNotFoundError(
                "ListConRanker weight file(s) not found: " + ", ".join(missing)
            )

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source. map_location="cpu" lets checkpoints saved on a
        # GPU load on CPU-only hosts; load_state_dict then copies the values
        # onto whatever device the modules live on.
        model.linear_in_embedding.load_state_dict(
            torch.load(linear_path, map_location="cpu")
        )
        model.list_transformer.load_state_dict(
            torch.load(transformer_path, map_location="cpu")
        )

        return model

    def multi_passage(
        self,
        sentences: List[List[str]],
        batch_size: int = 32,
        tokenizer: Optional[AutoTokenizer] = None,
    ):
        """
        Process multiple passages for each query.

        :param sentences: List of lists; each inner list is
            ``[query, passage_1, ..., passage_k]``.
        :param batch_size: Number of (query, passage) pairs per forward pass.
            List-wise context is limited to the pairs that share a forward
            pass, so a query whose passages span several batches is scored in
            several independent lists.
        :param tokenizer: Tokenizer to use; when ``None`` the
            "ByteDance/ListConRanker" tokenizer is loaded on first use.
        :return: 1-D float tensor with one sigmoid score per passage, in the
            order the passages appear in ``sentences``.
        :raises ValueError: If an inner list has no passage or ``batch_size``
            is not positive.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be greater than 0")

        pairs = []
        for group in sentences:
            if len(group) < 2:
                raise ValueError("Each query must have at least one passage.")
            query = group[0]
            pairs.extend((query, passage) for passage in group[1:])

        # Resolved lazily: a default argument would download the tokenizer as
        # soon as the class is defined, i.e. at import time.
        if tokenizer is None:
            tokenizer = self._get_default_tokenizer()

        total_logits = torch.zeros(len(pairs), dtype=torch.float, device=self.device)
        for start in range(0, len(pairs), batch_size):
            stop = start + batch_size
            inputs = tokenizer(
                pairs[start:stop],
                padding=True,
                truncation=False,
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Element 0 of the tuple output is the sigmoid score.
            scores = self(**inputs)[0]
            total_logits[start:stop] = scores.squeeze(1)
        return total_logits

    def multi_passage_in_iterative_inference(
        self,
        sentences: List[str],
        stop_num: int = 20,
        decrement_rate: float = 0.2,
        min_filter_num: int = 10,
        tokenizer: Optional[AutoTokenizer] = None,
    ):
        """
        Process multiple passages for one query in iterative inference.

        While more than ``stop_num`` passages remain, all remaining passages
        are scored together and the lowest-scoring
        ``max(ceil(remaining * decrement_rate), min_filter_num)`` of them
        (extended over ties at the cut) are dropped. A dropped passage keeps
        ``score + round_index`` as its final score, so passages surviving
        more rounds rank higher. The survivors are scored one last time.

        :param sentences: ``[query, passage_1, ..., passage_k]``.
        :param stop_num: Stop filtering once at most this many passages remain.
        :param decrement_rate: Fraction of remaining passages dropped per round.
        :param min_filter_num: Minimum number of passages dropped per round.
        :param tokenizer: Tokenizer to use; when ``None`` the
            "ByteDance/ListConRanker" tokenizer is loaded on first use.
        :return: List of floats, one per passage, in the input order.
        :raises ValueError: If an argument is out of range.
        """
        if stop_num < 1:
            raise ValueError("stop_num must be greater than 0")
        if decrement_rate <= 0 or decrement_rate >= 1:
            raise ValueError("decrement_rate must be in (0, 1)")
        if min_filter_num < 1:
            raise ValueError("min_filter_num must be greater than 0")

        if tokenizer is None:
            tokenizer = self._get_default_tokenizer()

        query = sentences[0]
        passages = sentences[1:]

        # Scores are tracked by original position rather than by passage text
        # so duplicated passages cannot overwrite each other.
        final_score = [0.0] * len(passages)
        remaining = list(range(len(passages)))
        filter_times = 0

        while len(remaining) > stop_num:
            batch = [[query] + [passages[i] for i in remaining]]
            pred_scores = self.multi_passage(
                batch, batch_size=len(remaining), tokenizer=tokenizer
            ).tolist()
            order = np.argsort(pred_scores).tolist()  # Sort in increasing order

            # Never try to drop more passages than there are.
            cut = min(
                max(math.ceil(len(remaining) * decrement_rate), min_filter_num),
                len(remaining),
            )
            # Passages tied with the last dropped one are dropped as well.
            while (
                cut < len(remaining)
                and pred_scores[order[cut - 1]] == pred_scores[order[cut]]
            ):
                cut += 1

            for pos in order[:cut]:
                final_score[remaining[pos]] = pred_scores[pos] + filter_times
            remaining = [remaining[pos] for pos in order[cut:]]
            filter_times += 1

        # Final scoring round for the survivors (none if ties removed them all).
        if remaining or not passages:
            batch = [[query] + [passages[i] for i in remaining]]
            pred_scores = self.multi_passage(
                batch, batch_size=max(len(remaining), 1), tokenizer=tokenizer
            ).tolist()
            for pos, original_idx in enumerate(remaining):
                final_score[original_idx] = pred_scores[pos] + filter_times

        return final_score