|
""" |
|
์๊ฒฉ ์ฝ๋ ์คํ ์ต์
์ด ์ถ๊ฐ๋ ๋ฆฌ๋ญ์ปค ๋ชจ๋ |
|
""" |
|
from typing import List, Dict, Tuple |
|
import numpy as np |
|
from sentence_transformers import CrossEncoder |
|
from langchain.schema import Document |
|
from config import RERANKER_MODEL |
|
|
|
class Reranker:
    """Re-rank retrieved documents against a query with a Cross-Encoder model.

    NOTE(review): the messages printed by this class look mis-encoded
    (presumably Korean UTF-8 text decoded with the wrong codec). They are
    runtime output, so they are kept exactly as found; confirm the intended
    text against the original file before changing them.
    """

    def __init__(self, model_name: str = RERANKER_MODEL):
        """Load the Cross-Encoder model used for scoring.

        Args:
            model_name: Name of the Cross-Encoder model to load.
        """
        print(f"๋ฆฌ๋ญ์ปค ๋ชจ๋ธ ๋ก๋ ์ค: {model_name}")

        # SECURITY: trust_remote_code=True allows the model repository to
        # execute its own Python code on this machine while loading (see the
        # sentence-transformers / Hugging Face documentation). Only pass
        # model names that come from a trusted, reviewed source.
        self.model = CrossEncoder(model_name, trust_remote_code=True)

        print(f"๋ฆฌ๋ญ์ปค ๋ชจ๋ธ ๋ก๋ ์๋ฃ: {model_name}")

    def rerank(self, query: str, documents: List[Document], top_k: int = 3) -> List[Document]:
        """Re-order search results by Cross-Encoder score, best first.

        Args:
            query: Search query.
            documents: Documents returned by the vector search.
            top_k: Number of top results to return. Values of zero or less
                yield an empty list.

        Returns:
            The highest-scoring documents in descending score order; at most
            ``top_k`` of them. An empty list when ``documents`` is empty.
        """
        if not documents:
            return []

        # The model scores (query, text) pairs; one pair per document.
        query_doc_pairs = [(query, doc.page_content) for doc in documents]

        print(f"๋ฆฌ๋ญํน ์ํ ์ค: {len(documents)}๊ฐ ๋ฌธ์")
        scores = self.model.predict(query_doc_pairs)

        # Stable sort: documents with equal scores keep their input order.
        ranked = sorted(
            zip(documents, scores),
            key=lambda pair: pair[1],
            reverse=True,
        )

        print(f"๋ฆฌ๋ญํน ์๋ฃ: ์์ {top_k}๊ฐ ๋ฌธ์ ์ ํ")

        # A negative top_k would otherwise slice from the end and silently
        # drop the lowest-ranked documents instead of returning nothing.
        limit = max(top_k, 0)
        return [doc for doc, _ in ranked[:limit]]