File size: 367 Bytes
bb1f9ea
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
def rerank_results(cross_encoder_model, query, candidate_passages, top_k=10):
    """Rerank candidate passages for a query using a cross-encoder model.

    Each passage is paired with the query as ``(query, passage)``. All pairs are
    scored in a single ``cross_encoder_model.predict`` call. Passages are then
    ordered by score, highest first.

    Args:
        cross_encoder_model: Object exposing ``predict(pairs)``. It must return
            one score per ``(query, passage)`` pair, in input order.
        query: The query paired with every passage.
        candidate_passages: Any iterable of passages, including one-shot
            iterables such as generators. It is consumed exactly once.
        top_k: Maximum number of results to return. It is applied as a slice,
            so ``None`` returns every result.

    Returns:
        A list of ``(passage, score)`` tuples sorted by descending score. Ties
        keep their original input order because ``sorted`` is stable. Returns
        ``[]`` when there are no passages; the model is not called in that case.

    Raises:
        ValueError: If the model returns a different number of scores than
            there are passages.
    """
    # Materialize once. A generator would otherwise be exhausted by the pair
    # construction, leaving nothing to zip with the scores. A NumPy array would
    # raise on the truthiness test.
    passages = list(candidate_passages)
    if not passages:
        return []

    pairs = [(query, passage) for passage in passages]
    scores = list(cross_encoder_model.predict(pairs))
    if len(scores) != len(passages):
        # zip() would silently truncate; surface the mismatch instead.
        raise ValueError(
            f"model returned {len(scores)} scores for {len(passages)} passages"
        )

    reranked = sorted(zip(passages, scores), key=lambda item: item[1], reverse=True)
    return reranked[:top_k]