# ListConRanker

## Model

- We propose a Listwise-encoded Contrastive text reRanker (ListConRanker), which includes a ListTransformer module for listwise encoding. The ListTransformer facilitates learning global contrastive information between passage features, including the clustering of similar passages, the clustering among dissimilar passages, and the distinction between similar and dissimilar passages. In addition, we propose ListAttention to help the ListTransformer maintain the features of the query while learning global comparative information. (An illustrative sketch follows this list.)
- The training loss function is Circle Loss [1]. Compared with cross-entropy loss and ranking loss, it avoids the problems of low data efficiency and non-smooth gradient changes. (A sketch of the loss also follows this list.)
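As a rough illustration of what listwise encoding means (this is not the actual ListConRanker architecture; the dimensions, layer counts, and scoring head below are assumptions), a transformer encoder applied across the concatenated query and passage features lets every passage feature attend to every other one:

```python
import torch
import torch.nn as nn

class ListwiseEncoder(nn.Module):
    """Illustrative stand-in for the ListTransformer idea (not the real module)."""

    def __init__(self, dim: int = 1024, heads: int = 8, layers: int = 2):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=layers)
        self.score = nn.Linear(dim, 1)  # hypothetical per-passage scoring head

    def forward(self, query_feat: torch.Tensor, passage_feats: torch.Tensor) -> torch.Tensor:
        # query_feat: (B, 1, dim); passage_feats: (B, N, dim)
        seq = torch.cat([query_feat, passage_feats], dim=1)  # (B, 1 + N, dim)
        out = self.encoder(seq)  # passages exchange global contrastive information
        return self.score(out[:, 1:, :]).squeeze(-1)  # (B, N) relevance scores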
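```

For reference, here is a minimal PyTorch sketch of the pairwise Circle Loss from [1]. The margin `m` and scale `gamma` use the paper's defaults; the values used to train ListConRanker are not stated here.

```python
import torch
import torch.nn.functional as F

def circle_loss(sp: torch.Tensor, sn: torch.Tensor, m: float = 0.25, gamma: float = 64.0) -> torch.Tensor:
    """Pairwise Circle Loss over similarity scores (Sun et al., 2020 [1]).

    sp: similarities between the query and positive passages, shape (K,)
    sn: similarities between the query and negative passages, shape (L,)
    """
    ap = torch.clamp(1 + m - sp, min=0.0)  # self-paced weights for positives
    an = torch.clamp(sn + m, min=0.0)      # self-paced weights for negatives
    delta_p, delta_n = 1 - m, m            # decision margins
    logit_p = -gamma * ap * (sp - delta_p)
    logit_n = gamma * an * (sn - delta_n)
    # log(1 + sum_i exp(logit_p_i) * sum_j exp(logit_n_j))
    return F.softplus(torch.logsumexp(logit_p, dim=0) + torch.logsumexp(logit_n, dim=0))

loss = circle_loss(sp=torch.tensor([0.9, 0.7]), sn=torch.tensor([0.4, 0.3, 0.2]))
```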
## Data

The training data consists of approximately 2.6 million queries, each paired with multiple passages. The data is drawn from the training sets of several datasets, including cMedQA1.0, cMedQA2.0, MMarcoReranking, T2Reranking, huatuo, MARC, XL-sum, and CSL.
## Training

We trained the model in two stages. In the first stage, we froze the parameters of the embedding model and trained only the ListTransformer for 4 epochs with a batch size of 1024. In the second stage, we unfroze all parameters and trained for another 2 epochs with a batch size of 256.
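A minimal sketch of this two-stage schedule (the module names below are hypothetical stand-ins, not the actual training code):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the two sub-modules (illustrative names only).
embedding_model = nn.Linear(1024, 1024)
list_transformer = nn.TransformerEncoderLayer(d_model=1024, nhead=8, batch_first=True)

# Stage 1 (4 epochs, batch size 1024): freeze the embedding model,
# optimize only the ListTransformer.
for param in embedding_model.parameters():
    param.requires_grad = False
stage1_optimizer = torch.optim.AdamW(list_transformer.parameters())

# Stage 2 (2 epochs, batch size 256): unfreeze everything and train end to end.
for param in embedding_model.parameters():
    param.requires_grad = True
stage2_optimizer = torch.optim.AdamW(
    list(embedding_model.parameters()) + list(list_transformer.parameters())
)
```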
## Inference

Due to limited GPU memory, we input about 20 passages at a time for each query during training. In actual use, however, far more than 20 passages may be input at once (e.g., MMarcoReranking).

To reduce this discrepancy between training and inference, we propose iterative inference. Iterative inference feeds the passages into ListConRanker multiple times, and each pass only decides the ranking of the passage at the end of the list.
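As a minimal sketch of the idea (not the shipped implementation; `score_fn` is a hypothetical callable that jointly scores one query's passages, e.g. `lambda q, ps: reranker.multi_passage([[q] + ps])` — in practice, use the `multi_passage_in_iterative_inference` method shown in the usage section):

```python
def iterative_rank(score_fn, query, passages, window=20):
    """Re-score the remaining passages each round and fix only the rank of
    the worst one, until the list fits into a training-sized window that
    can be ranked in a single pass."""
    remaining = list(range(len(passages)))
    tail = []  # indices whose final ranks are fixed, worst first
    while len(remaining) > window:
        scores = score_fn(query, [passages[i] for i in remaining])
        worst = min(range(len(remaining)), key=lambda i: scores[i])
        tail.append(remaining.pop(worst))  # decide only the end of the list
    scores = score_fn(query, [passages[i] for i in remaining])
    head = [i for _, i in sorted(zip(scores, remaining), reverse=True)]
    return head + tail[::-1]  # passage indices, best first
```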
## Performance

| Model | cMedQA1.0 | cMedQA2.0 | MMarcoReranking | T2Reranking | Avg. |
|---|---|---|---|---|---|
| LdIR-Qwen2-reranker-1.5B | 86.50 | 87.11 | 39.35 | 68.84 | 70.45 |
| zpoint-large-embedding-zh | 91.11 | 90.07 | 38.87 | 69.29 | 72.34 |
| xiaobu-embedding-v2 | 90.96 | 90.41 | 39.91 | 69.03 | 72.58 |
| Conan-embedding-v1 | 91.39 | 89.72 | 41.58 | 68.36 | 72.76 |
| ListConRanker | 90.55 | 89.38 | 43.88 | 69.17 | 73.25 |
| - w/o Iterative Inference | 90.20 | 89.98 | 37.52 | 69.17 | 71.72 |
## How to use

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

reranker = AutoModelForSequenceClassification.from_pretrained('ByteDance/ListConRanker', trust_remote_code=True)

# each group is [query, passage_1, passage_2, ..., passage_n]
batch = [
    [
        '็ฎ่ๆฏๅฏๆง็้ฃ็ฉๅ',  # query
        '่ฅๅ
ปๅปๅธไป็ป็ฎ่ๆฏๅฑไบๅๆง็้ฃ็ฉ,ไธญๅป่ฎคไธบ็ฎ่ๅฏๆฒป็ผ็ผใ็็ผใ้ซ่กๅใ่ณ้ธฃ็ฉๆ็ญ็พ็
ใไฝ่่
่ฆๅฐๅใ',  # passage_1
        '็ฎ่่ฟ็ง้ฃๅๆฏๅจไธญๅฝๅฐๅๆๅธธ่ง็ไผ ็ป้ฃๅ,ๅฎ็็้ฟๆฑ้ไนๆฏ้ๅธธ็ๆ ้ฟใ',  # passage_2
        'ๅๆฌข็ฎ่็ไบบไผ่งๅพ็ฎ่ๆฏๆ็พๅณ็้ฃ็ฉ,ไธๅๆฌข็ฎ่็ไบบๅ่งๅพ็ฎ่ๆฏ้ปๆๆ็,ๅฐคๅ
ถๅพๅคๅคๅฝๆๅ้ฝไธ็่งฃๆไปฌๅ็ฎ่็ไน ๆฏ'  # passage_3
    ],
    [
        'ๆๆ้ดๆดๅ็ผบ็ๆไน',  # query
        'ๅฝขๅฎน็ๆฏๆๆๆ็็ถๆ,ๆดๆๆๅช,้ดๆฒๆททๆฒ,ๆๆๅๆถ,ไฝๅคๆฐๆถๆปๆฏๆ็ผบ้ทใ',  # passage_1
        'ไบบๆๆฒๆฌข็ฆปๅ,ๆๆ้ดๆดๅ็ผบ่ฟๅฅ่ฏๆๆๆฏไบบๆๆฒๆฌข็ฆปๅ็ๅ่ฟ,ๆๆ้ดๆดๅ็ผบ็่ฝฌๆขใ',  # passage_2
        'ๆข็ถๆฏ่ฏๆญ,ๅๅช้ไผๆ็ๆญฃๅซไนๅข? ๅคงๆฆๅฏไปฅ่ฏด:ไบบ็ๆๅคชๅคๅๅท,่ฆ้พ,ไปๅฎนๅฆ่ก้ขๅฏนๅฐฑๅฅฝใ',  # passage_3
        'ไธ้ถไธๅ
ญๅนด่่ฝผ่ดฌๅฎๅฏๅท,ๆถๅนดๅๅไธๅฒ็ไปๆฟๆฒปไธๅพไธๅพๅฟ,ๆถๅผไธญ็งไฝณ่,้ๅธธๆณๅฟต่ชๅทฑ็ๅผๅผๅญ็ฑๅ
ๅฟ้ขๆๅฟง้,ๆ
็ปชไฝๆฒ,ๆๆ่ๅๅไบ่ฟ้ฆ่ฏใ'  # passage_4
    ]
]

# for conventional inference, please manage the batch size by yourself
scores = reranker.multi_passage(batch)
print(scores)
# [0.5126814246177673, 0.33125635981559753, 0.3642643094062805, 0.6367220282554626, 0.7166246175765991, 0.4281482696533203, 0.3530198335647583]
```
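One simple way to manage the batch size is to split long passage lists into chunks; `score_in_chunks` below is a hypothetical helper, not part of the released API. Note that each chunk forms its own listwise context, so scores from different chunks are not strictly comparable; iterative inference (next) addresses exactly this.

```python
def score_in_chunks(reranker, query, passages, chunk_size=20):
    # Hypothetical helper: score at most `chunk_size` passages per call to
    # bound GPU memory. Each chunk is scored in its own listwise context.
    scores = []
    for start in range(0, len(passages), chunk_size):
        chunk = passages[start:start + chunk_size]
        scores.extend(reranker.multi_passage([[query] + chunk]))
    return scores
```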
```python
# for iterative inference, only a batch size of 1 is supported
# the scores do not carry similarity meaning and are only used for ranking
scores = reranker.multi_passage_in_iterative_inference(batch[0])
print(scores)
# [0.5126813650131226, 0.3312564790248871, 0.3642643094062805]
```
You can also call the model directly with the tokenizer:

```python
tokenizer = AutoTokenizer.from_pretrained('ByteDance/ListConRanker')
inputs = tokenizer(
    [
        [
            "็ฎ่ๆฏๅฏๆง็้ฃ็ฉๅ",
            "่ฅๅ
ปๅปๅธไป็ป็ฎ่ๆฏๅฑไบๅๆง็้ฃ็ฉ,ไธญๅป่ฎคไธบ็ฎ่ๅฏๆฒป็ผ็ผใ็็ผใ้ซ่กๅใ่ณ้ธฃ็ฉๆ็ญ็พ็
ใไฝ่่
่ฆๅฐๅใ",
        ],
        [
            "็ฎ่ๆฏๅฏๆง็้ฃ็ฉๅ",
            "็ฎ่่ฟ็ง้ฃๅๆฏๅจไธญๅฝๅฐๅๆๅธธ่ง็ไผ ็ป้ฃๅ,ๅฎ็็้ฟๆฑ้ไนๆฏ้ๅธธ็ๆ ้ฟใ",
        ],
        [
            "ๆๆ้ดๆดๅ็ผบ็ๆไน",
            "ๅฝขๅฎน็ๆฏๆๆๆ็็ถๆ,ๆดๆๆๅช,้ดๆฒๆททๆฒ,ๆๆๅๆถ,ไฝๅคๆฐๆถๆปๆฏๆ็ผบ้ทใ",
        ],
    ],
    return_tensors="pt",
    padding=True,
    truncation=False,
)
# forward pass to obtain the scores (sketch: the exact call for this
# remote-code model may differ, e.g. moving inputs to the model device)
scores = reranker(**inputs)
print(scores)
# tensor([[0.5070], [0.3334], [0.6294]], device='cuda:0', dtype=torch.float16, grad_fn=<ViewBackward0>)
```
Or use the sentence_transformers library (we do not recommend this: its truncation strategy may not match the model design, which may lead to performance degradation):
```python
from sentence_transformers import CrossEncoder

model = CrossEncoder('ByteDance/ListConRanker', trust_remote_code=True)
inputs = [
    [
        "็ฎ่ๆฏๅฏๆง็้ฃ็ฉๅ",
        "่ฅๅ
ปๅปๅธไป็ป็ฎ่ๆฏๅฑไบๅๆง็้ฃ็ฉ,ไธญๅป่ฎคไธบ็ฎ่ๅฏๆฒป็ผ็ผใ็็ผใ้ซ่กๅใ่ณ้ธฃ็ฉๆ็ญ็พ็
ใไฝ่่
่ฆๅฐๅใ",
    ],
    [
        "็ฎ่ๆฏๅฏๆง็้ฃ็ฉๅ",
        "็ฎ่่ฟ็ง้ฃๅๆฏๅจไธญๅฝๅฐๅๆๅธธ่ง็ไผ ็ป้ฃๅ,ๅฎ็็้ฟๆฑ้ไนๆฏ้ๅธธ็ๆ ้ฟใ",
    ],
    [
        "ๆๆ้ดๆดๅ็ผบ็ๆไน",
        "ๅฝขๅฎน็ๆฏๆๆๆ็็ถๆ,ๆดๆๆๅช,้ดๆฒๆททๆฒ,ๆๆๅๆถ,ไฝๅคๆฐๆถๆปๆฏๆ็ผบ้ทใ",
    ],
]
scores = model.predict(inputs)
print(scores)
```
To reproduce the results with iterative inference, please run:

```bash
python3 eval_listconranker_iterative_inference.py
```

To reproduce the results without iterative inference, please run:

```bash
python3 eval_listconranker.py
```
## Reference

- [1] Circle Loss: A Unified Perspective of Pair Similarity Optimization. https://arxiv.org/abs/2002.10857
- https://github.com/FlagOpen/FlagEmbedding
- https://arxiv.org/abs/2408.15710
## License

This work is licensed under the MIT License, and the model weights are licensed under the Creative Commons Attribution-NonCommercial 4.0 International License.
## Evaluation results

Self-reported results on MTEB:

| Dataset | map | mrr_1 | mrr_5 | mrr_10 |
|---|---|---|---|---|
| CMedQAv1 test set | 90.554 | 87.800 | 92.325 | 92.451 |
| CMedQAv2 test set | 89.381 | 85.900 | 91.090 | 91.288 |
| MMarcoReranking | 43.881 | 32.000 | | |