ListConRanker

Model

  • We propose the Listwise-encoded Contrastive text reRanker (ListConRanker), which includes a ListTransformer module for listwise encoding. The ListTransformer facilitates learning global contrastive information between passage features, including the clustering of similar passages, the clustering of dissimilar passages, and the distinction between similar and dissimilar passages. In addition, we propose ListAttention to help the ListTransformer preserve the query features while learning global comparative information.
  • The training loss function is Circle Loss [1]. Compared with cross-entropy loss and ranking loss, it mitigates low data efficiency and non-smooth gradient changes.
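As a rough illustration of the loss, here is a minimal, pure-Python sketch of Circle Loss [1] for a single query, using the paper's unified log-sum-exp form with its default margin `m` and scale `gamma`. This is not the model's training code, only the formula written out:

```python
import math

def circle_loss(sp, sn, m=0.25, gamma=32.0):
    """Circle Loss sketch for one query.
    sp: similarity scores of positive passages, sn: of negatives.
    Each score gets an adaptive weight, so pairs far from their
    optimum contribute larger (smoother) gradients than in ranking loss."""
    ap = [max(0.0, 1 + m - s) for s in sp]  # positive weights
    an = [max(0.0, s + m) for s in sn]      # negative weights
    delta_p, delta_n = 1 - m, m             # decision margins
    pos = sum(math.exp(-gamma * a * (s - delta_p)) for a, s in zip(ap, sp))
    neg = sum(math.exp(gamma * a * (s - delta_n)) for a, s in zip(an, sn))
    return math.log1p(pos * neg)
```

When positives score high and negatives low, the loss is near zero; reversed scores blow it up sharply, which is what drives the separation of similar and dissimilar passages.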

Data

The training data consists of approximately 2.6 million queries, each paired with multiple passages. It is drawn from the training sets of several datasets, including cMedQA1.0, cMedQA2.0, MMarcoReranking, T2Reranking, huatuo, MARC, XL-sum, and CSL, among others.

Training

We trained the model in two stages. In the first stage, we freeze the parameters of the embedding model and train only the ListTransformer for 4 epochs with a batch size of 1024. In the second stage, we unfreeze all parameters and train for another 2 epochs with a batch size of 256.
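The two-stage freezing schedule can be sketched as follows. The two `nn.Linear` modules are hypothetical stand-ins for the embedding backbone and the ListTransformer (the real modules live inside the released checkpoint), and AdamW is assumed only for illustration:

```python
import torch
from torch import nn

# Hypothetical stand-ins for the embedding backbone and the ListTransformer.
embedding_model = nn.Linear(16, 16)
list_transformer = nn.Linear(16, 16)

# Stage 1: freeze the embedding model, train only the ListTransformer
# (4 epochs, batch size 1024 in the setup above).
for p in embedding_model.parameters():
    p.requires_grad = False
stage1_optimizer = torch.optim.AdamW(list_transformer.parameters())

# Stage 2: unfreeze everything and fine-tune end-to-end
# (2 more epochs, batch size 256).
for p in embedding_model.parameters():
    p.requires_grad = True
stage2_optimizer = torch.optim.AdamW(
    list(embedding_model.parameters()) + list(list_transformer.parameters())
)
```

Stage 1 lets the randomly initialized ListTransformer adapt to frozen embeddings before stage 2 risks disturbing them.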

Inference

Due to limited GPU memory, we input about 20 passages at a time for each query during training. In actual use, however, far more than 20 passages may be input at once (e.g., MMarcoReranking).

To reduce this discrepancy between training and inference, we propose iterative inference. Iterative inference feeds the passages into ListConRanker multiple times; each pass decides only the ranking of the passage at the end of the list.
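The loop can be sketched as below. `score_fn` is a hypothetical callable standing in for one forward pass of the reranker over a (query, passage list) input; the actual procedure is implemented by `multi_passage_in_iterative_inference` in the released model:

```python
def iterative_rerank(score_fn, query, passages):
    """Sketch of iterative inference: score all remaining passages,
    fix only the lowest-scoring one at the bottom of the ranking,
    and repeat until the whole order is decided."""
    remaining = list(passages)
    bottom_up = []  # collected from worst to best
    while len(remaining) > 1:
        scores = score_fn(query, remaining)
        worst = min(range(len(remaining)), key=scores.__getitem__)
        bottom_up.append(remaining.pop(worst))
    bottom_up.extend(remaining)  # the last survivor ranks first
    return bottom_up[::-1]  # best-to-worst order
```

Each pass re-scores a shrinking list, so the list lengths seen at inference stay closer to the ~20-passage lists seen during training.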

Performance

| Model | cMedQA1.0 | cMedQA2.0 | MMarcoReranking | T2Reranking | Avg. |
|---|---|---|---|---|---|
| LdIR-Qwen2-reranker-1.5B | 86.50 | 87.11 | 39.35 | 68.84 | 70.45 |
| zpoint-large-embedding-zh | 91.11 | 90.07 | 38.87 | 69.29 | 72.34 |
| xiaobu-embedding-v2 | 90.96 | 90.41 | 39.91 | 69.03 | 72.58 |
| Conan-embedding-v1 | 91.39 | 89.72 | 41.58 | 68.36 | 72.76 |
| ListConRanker | 90.55 | 89.38 | 43.88 | 69.17 | 73.25 |
| - w/o Iterative Inference | 90.20 | 89.98 | 37.52 | 69.17 | 71.72 |

How to use

from transformers import AutoModelForSequenceClassification, AutoTokenizer

reranker = AutoModelForSequenceClassification.from_pretrained('ByteDance/ListConRanker', trust_remote_code=True)

# [query, passage_1, passage_2, ..., passage_n]
batch = [
    [
        '็šฎ่›‹ๆ˜ฏๅฏ’ๆ€ง็š„้ฃŸ็‰ฉๅ—', # query
        '่ฅๅ…ปๅŒปๅธˆไป‹็ป็šฎ่›‹ๆ˜ฏๅฑžไบŽๅ‡‰ๆ€ง็š„้ฃŸ็‰ฉ,ไธญๅŒป่ฎคไธบ็šฎ่›‹ๅฏๆฒป็œผ็–ผใ€็‰™็–ผใ€้ซ˜่ก€ๅŽ‹ใ€่€ณ้ธฃ็œฉๆ™•็ญ‰็–พ็—…ใ€‚ไฝ“่™š่€…่ฆๅฐ‘ๅƒใ€‚',  # passage_1
        '็šฎ่›‹่ฟ™็ง้ฃŸๅ“ๆ˜ฏๅœจไธญๅ›ฝๅœฐๅŸŸๆ‰ๅธธ่ง็š„ไผ ็ปŸ้ฃŸๅ“,ๅฎƒ็š„็”Ÿ้•ฟๆฑ—้’ไนŸๆ˜ฏ้žๅธธ็š„ๆ‚ ้•ฟใ€‚',  # passage_2
        'ๅ–œๆฌข็šฎ่›‹็š„ไบบไผš่ง‰ๅพ—็šฎ่›‹ๆ˜ฏๆœ€็พŽๅ‘ณ็š„้ฃŸ็‰ฉ,ไธๅ–œๆฌข็šฎ่›‹็š„ไบบๅˆ™่ง‰ๅพ—็šฎ่›‹ๆ˜ฏ้ป‘ๆš—ๆ–™็†,ๅฐคๅ…ถๅพˆๅคšๅค–ๅ›ฝๆœ‹ๅ‹้ƒฝไธ็†่งฃๆˆ‘ไปฌๅƒ็šฎ่›‹็š„ไน ๆƒฏ' # passage_3
    ],
    [
        'ๆœˆๆœ‰้˜ดๆ™ดๅœ†็ผบ็š„ๆ„ไน‰', # query
        'ๅฝขๅฎน็š„ๆ˜ฏๆœˆๆ‰€ๆœ‰็š„็Šถๆ€,ๆ™ดๆœ—ๆ˜Žๅชš,้˜ดๆฒ‰ๆททๆฒŒ,ๆœ‰ๆœˆๅœ†ๆ—ถ,ไฝ†ๅคšๆ•ฐๆ—ถๆ€ปๆ˜ฏๆœ‰็ผบ้™ทใ€‚',  # passage_1
        'ไบบๆœ‰ๆ‚ฒๆฌข็ฆปๅˆ,ๆœˆๆœ‰้˜ดๆ™ดๅœ†็ผบ่ฟ™ๅฅ่ฏๆ„ๆ€ๆ˜ฏไบบๆœ‰ๆ‚ฒๆฌข็ฆปๅˆ็š„ๅ˜่ฟ,ๆœˆๆœ‰้˜ดๆ™ดๅœ†็ผบ็š„่ฝฌๆขใ€‚',  # passage_2
        'ๆ—ข็„ถๆ˜ฏ่ฏ—ๆญŒ,ๅˆๅ“ช้‡Œไผšๆœ‰็œŸๆญฃๅซไน‰ๅ‘ข? ๅคงๆฆ‚ๅฏไปฅ่ฏด:ไบบ็”Ÿๆœ‰ๅคชๅคšๅŽๅท,่‹ฆ้šพ,ไปŽๅฎนๅฆ่ก้ขๅฏนๅฐฑๅฅฝใ€‚', # passage_3
        'ไธ€้›ถไธƒๅ…ญๅนด่‹่ฝผ่ดฌๅฎ˜ๅฏ†ๅทž,ๆ—ถๅนดๅ››ๅไธ€ๅฒ็š„ไป–ๆ”ฟๆฒปไธŠๅพˆไธๅพ—ๅฟ—,ๆ—ถๅ€ผไธญ็ง‹ไฝณ่Š‚,้žๅธธๆƒณๅฟต่‡ชๅทฑ็š„ๅผŸๅผŸๅญ็”ฑๅ†…ๅฟƒ้ข‡ๆ„Ÿๅฟง้ƒ,ๆƒ…็ปชไฝŽๆฒ‰,ๆœ‰ๆ„Ÿ่€Œๅ‘ๅ†™ไบ†่ฟ™้ฆ–่ฏใ€‚' # passage_4
    ]
]

# for conventional inference, please manage the batch size by yourself
scores = reranker.multi_passage(batch)
print(scores)
# [0.5126814246177673, 0.33125635981559753, 0.3642643094062805, 0.6367220282554626, 0.7166246175765991, 0.4281482696533203, 0.3530198335647583]
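Note that `multi_passage` returns one flat score list covering all queries in input order (here, 3 scores for the first query's passages followed by 4 for the second). A small helper like the one below (hypothetical, not part of the model API) can split the scores back per query:

```python
def regroup_scores(batch, flat_scores):
    """Split the flat score list returned for a batch back into one
    sub-list per query; each batch item is [query, passage_1, ...]."""
    grouped, i = [], 0
    for item in batch:
        n = len(item) - 1  # number of passages (first element is the query)
        grouped.append(flat_scores[i:i + n])
        i += n
    return grouped
```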

# for iterative inference, only a batch size of 1 is supported
# the scores do not carry similarity meaning and are only used for ranking
scores = reranker.multi_passage_in_iterative_inference(batch[0])
print(scores)
# [0.5126813650131226, 0.3312564790248871, 0.3642643094062805]

tokenizer = AutoTokenizer.from_pretrained('ByteDance/ListConRanker')
inputs = tokenizer(
    [
        [
            "็šฎ่›‹ๆ˜ฏๅฏ’ๆ€ง็š„้ฃŸ็‰ฉๅ—",
            "่ฅๅ…ปๅŒปๅธˆไป‹็ป็šฎ่›‹ๆ˜ฏๅฑžไบŽๅ‡‰ๆ€ง็š„้ฃŸ็‰ฉ,ไธญๅŒป่ฎคไธบ็šฎ่›‹ๅฏๆฒป็œผ็–ผใ€็‰™็–ผใ€้ซ˜่ก€ๅŽ‹ใ€่€ณ้ธฃ็œฉๆ™•็ญ‰็–พ็—…ใ€‚ไฝ“่™š่€…่ฆๅฐ‘ๅƒใ€‚",
        ],
        [
            "็šฎ่›‹ๆ˜ฏๅฏ’ๆ€ง็š„้ฃŸ็‰ฉๅ—",
            "็šฎ่›‹่ฟ™็ง้ฃŸๅ“ๆ˜ฏๅœจไธญๅ›ฝๅœฐๅŸŸๆ‰ๅธธ่ง็š„ไผ ็ปŸ้ฃŸๅ“,ๅฎƒ็š„็”Ÿ้•ฟๆฑ—้’ไนŸๆ˜ฏ้žๅธธ็š„ๆ‚ ้•ฟใ€‚",
        ],
        [
            "ๆœˆๆœ‰้˜ดๆ™ดๅœ†็ผบ็š„ๆ„ไน‰",
            "ๅฝขๅฎน็š„ๆ˜ฏๆœˆๆ‰€ๆœ‰็š„็Šถๆ€,ๆ™ดๆœ—ๆ˜Žๅชš,้˜ดๆฒ‰ๆททๆฒŒ,ๆœ‰ๆœˆๅœ†ๆ—ถ,ไฝ†ๅคšๆ•ฐๆ—ถๆ€ปๆ˜ฏๆœ‰็ผบ้™ทใ€‚",
        ],
    ],
    return_tensors="pt",
    padding=True,
    truncation=False
)
scores = reranker(**inputs).logits  # assumes the standard SequenceClassification forward
print(scores)
# tensor([[0.5070], [0.3334], [0.6294]], device='cuda:0', dtype=torch.float16, grad_fn=<ViewBackward0>)

Or use the sentence_transformers library (not recommended: its truncation strategy may not match the model design, which can degrade performance):

from sentence_transformers import CrossEncoder

model = CrossEncoder('ByteDance/ListConRanker', trust_remote_code=True)

inputs = [
  [
    "็šฎ่›‹ๆ˜ฏๅฏ’ๆ€ง็š„้ฃŸ็‰ฉๅ—",
    "่ฅๅ…ปๅŒปๅธˆไป‹็ป็šฎ่›‹ๆ˜ฏๅฑžไบŽๅ‡‰ๆ€ง็š„้ฃŸ็‰ฉ,ไธญๅŒป่ฎคไธบ็šฎ่›‹ๅฏๆฒป็œผ็–ผใ€็‰™็–ผใ€้ซ˜่ก€ๅŽ‹ใ€่€ณ้ธฃ็œฉๆ™•็ญ‰็–พ็—…ใ€‚ไฝ“่™š่€…่ฆๅฐ‘ๅƒใ€‚",
  ],
  [
    "็šฎ่›‹ๆ˜ฏๅฏ’ๆ€ง็š„้ฃŸ็‰ฉๅ—",
    "็šฎ่›‹่ฟ™็ง้ฃŸๅ“ๆ˜ฏๅœจไธญๅ›ฝๅœฐๅŸŸๆ‰ๅธธ่ง็š„ไผ ็ปŸ้ฃŸๅ“,ๅฎƒ็š„็”Ÿ้•ฟๆฑ—้’ไนŸๆ˜ฏ้žๅธธ็š„ๆ‚ ้•ฟใ€‚",
  ],
  [
    "ๆœˆๆœ‰้˜ดๆ™ดๅœ†็ผบ็š„ๆ„ไน‰",
    "ๅฝขๅฎน็š„ๆ˜ฏๆœˆๆ‰€ๆœ‰็š„็Šถๆ€,ๆ™ดๆœ—ๆ˜Žๅชš,้˜ดๆฒ‰ๆททๆฒŒ,ๆœ‰ๆœˆๅœ†ๆ—ถ,ไฝ†ๅคšๆ•ฐๆ—ถๆ€ปๆ˜ฏๆœ‰็ผบ้™ทใ€‚",
  ],
]
scores = model.predict(inputs)
print(scores)

To reproduce the results with iterative inference, please run:

python3 eval_listconranker_iterative_inference.py

To reproduce the results without iterative inference, please run:

python3 eval_listconranker.py

Reference

  1. https://arxiv.org/abs/2002.10857
  2. https://github.com/FlagOpen/FlagEmbedding
  3. https://arxiv.org/abs/2408.15710

License

This work is licensed under the MIT License, and the model weights are licensed under the Creative Commons Attribution-NonCommercial 4.0 International License.
