RAG_voice / reranker.py
jeongsoo's picture
Add application file
4a98f26
"""
Reranker module with the remote-code-execution option (trust_remote_code) added.
(원격 코드 실행 옵션이 추가된 리랭커 모듈)
"""
from typing import List, Dict, Tuple
import numpy as np
from sentence_transformers import CrossEncoder
from langchain.schema import Document
from config import RERANKER_MODEL
class Reranker:
    """Rerank retrieved documents with a sentence-transformers Cross-Encoder.

    Each (query, document text) pair is scored jointly by the Cross-Encoder,
    and documents are returned in descending score order.
    """

    def __init__(self, model_name: str = RERANKER_MODEL, *, trust_remote_code: bool = True):
        """
        Load the Cross-Encoder reranker model.

        Args:
            model_name: Name (or path) of the Cross-Encoder model to load.
            trust_remote_code: Forwarded to ``CrossEncoder``. When True, custom
                modelling code shipped with the model repository is executed
                locally. Defaults to True to preserve the previous hard-coded
                behavior (the original author marked it as required for the
                configured model); pass False for models that need no custom code.
        """
        print(f"리랭커 모델 로드 중: {model_name}")
        # SECURITY: trust_remote_code=True executes arbitrary Python code from
        # the model repository. Only enable it for model sources you trust.
        self.model = CrossEncoder(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        print(f"리랭커 모델 로드 완료: {model_name}")

    def rerank(self, query: str, documents: List[Document], top_k: int = 3) -> List[Document]:
        """
        Reorder retrieved documents by Cross-Encoder relevance to the query.

        Args:
            query: Search query.
            documents: Documents returned by the vector search; each one's
                ``page_content`` is scored against ``query``.
            top_k: Maximum number of documents to return. ``None`` returns
                every document in reranked order (kept for backward
                compatibility with the previous slicing behavior).

        Returns:
            Up to ``top_k`` documents, highest score first. Documents with
            equal scores keep their original relative order (stable sort).
            An empty ``documents`` list yields an empty list without calling
            the model.

        Raises:
            ValueError: If ``top_k`` is negative.
        """
        if top_k is not None and top_k < 0:
            # A negative slice bound would silently drop documents from the
            # bottom of the ranking instead of selecting the top results.
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        if not documents:
            return []

        # Build (query, passage) input pairs for the Cross-Encoder.
        query_doc_pairs = [(query, doc.page_content) for doc in documents]

        print(f"리랭킹 수행 중: {len(documents)}개 문서")
        scores = self.model.predict(query_doc_pairs)

        # sorted() is stable even with reverse=True, so ties keep the
        # incoming (vector-search) order.
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
        print(f"리랭킹 완료: 상위 {top_k}개 문서 선택")

        return [doc for doc, _score in ranked[:top_k]]