|
""" |
|
LangChain์ ํ์ฉํ RAG ์ฒด์ธ ๊ตฌํ |
|
""" |
|
from typing import List, Dict, Any |
|
from langchain.schema import Document |
|
from langchain.prompts import PromptTemplate |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.runnables import RunnablePassthrough |
|
from langchain_community.chat_models import ChatOllama |
|
from langchain_openai import ChatOpenAI |
|
|
|
from config import ( |
|
OLLAMA_HOST, LLM_MODEL, USE_OPENAI, |
|
OPENAI_API_KEY, TOP_K_RETRIEVAL, TOP_K_RERANK |
|
) |
|
from vector_store import VectorStore |
|
from reranker import Reranker |
|
|
|
|
|
class RAGChain:
    """Question-answering pipeline that retrieves documents and prompts a chat model.

    A query is answered by searching ``vector_store``, optionally reranking the
    hits, formatting them into a context string, filling the prompt template
    and sending the result to the configured chat model.
    """

    def __init__(self, vector_store: "VectorStore", use_reranker: bool = True):
        """Set up the reranker, the chat model and the runnable chain.

        Args:
            vector_store: Store queried through ``similarity_search``.
            use_reranker: Whether retrieved documents go through ``Reranker``
                before being placed in the prompt. If the reranker cannot be
                constructed, reranking is switched off instead of failing.

        Raises:
            Exception: Any error raised while building the chat model or the
                chain is printed together with its traceback and re-raised.
        """
        try:
            print("RAGChain ์ด๊ธฐํ ์์...")
            self.vector_store = vector_store
            self.use_reranker = use_reranker
            print(f"๋ฆฌ๋ญ์ปค ์ฌ์ฉ ์ฌ๋ถ: {use_reranker}")

            self.reranker = None
            if use_reranker:
                self.reranker = self._create_reranker()
                if self.reranker is None:
                    # Reranking is best-effort: carry on without it.
                    self.use_reranker = False

            self.llm = self._create_llm()

            print("RAG ์ฒด์ธ ์ค์ ์์...")
            self.setup_chain()
            print("RAG ์ฒด์ธ ์ค์ ์๋ฃ")
        except Exception as e:
            print(f"RAGChain ์ด๊ธฐํ ์ค ์์ธ ์ค๋ฅ: {e}")
            import traceback

            traceback.print_exc()
            raise

    def _create_reranker(self):
        """Build a ``Reranker``; return ``None`` when construction fails."""
        try:
            reranker = Reranker()
        except Exception as e:
            print(f"๋ฆฌ๋ญ์ปค ์ด๊ธฐํ ์คํจ: {e}")
            return None
        print("๋ฆฌ๋ญ์ปค ์ด๊ธฐํ ์ฑ๊ณต")
        return reranker

    def _create_llm(self):
        """Build the chat model selected by the configuration flags.

        Returns:
            A ``ChatOpenAI`` instance when ``USE_OPENAI`` or the optional
            ``config.IS_HUGGINGFACE`` flag is truthy, otherwise a
            ``ChatOllama`` instance pointed at ``OLLAMA_HOST``.

        Raises:
            Exception: Construction errors are printed and re-raised.
        """
        # IS_HUGGINGFACE is not one of the names imported from config at the
        # top of this module, so using it as a bare global raised NameError
        # whenever USE_OPENAI was falsy. Look it up on the config module and
        # treat a missing flag as False.
        import config as _config

        use_openai = USE_OPENAI or getattr(_config, "IS_HUGGINGFACE", False)

        if use_openai:
            print(f"OpenAI ๋ชจ๋ธ ์ด๊ธฐํ: {LLM_MODEL}")
            print(f"API ํค ์กด์ฌ ์ฌ๋ถ: {'์์' if OPENAI_API_KEY else '์์'}")
            try:
                llm = ChatOpenAI(
                    model_name=LLM_MODEL,
                    temperature=0.2,
                    api_key=OPENAI_API_KEY,
                )
            except Exception as e:
                print(f"OpenAI ๋ชจ๋ธ ์ด๊ธฐํ ์คํจ: {e}")
                raise
            print("OpenAI ๋ชจ๋ธ ์ด๊ธฐํ ์ฑ๊ณต")
            return llm

        try:
            print(f"Ollama ๋ชจ๋ธ ์ด๊ธฐํ: {LLM_MODEL}")
            llm = ChatOllama(
                model=LLM_MODEL,
                temperature=0.2,
                base_url=OLLAMA_HOST,
            )
        except Exception as e:
            print(f"Ollama ๋ชจ๋ธ ์ด๊ธฐํ ์คํจ: {e}")
            raise
        print("Ollama ๋ชจ๋ธ ์ด๊ธฐํ ์ฑ๊ณต")
        return llm

    def setup_chain(self) -> None:
        """Create the prompt template and compose the runnable chain.

        Sets ``self.prompt`` and ``self.chain``. The chain feeds its input to
        ``_retrieve`` (as ``context``) and passes it through unchanged (as
        ``question``), then pipes the filled prompt into ``self.llm`` and
        parses the model output to a string.
        """
        template = """
๋ค์ ์ ๋ณด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ง๋ฌธ์ ์ ํํ๊ฒ ๋ต๋ณํด์ฃผ์ธ์.

์ง๋ฌธ: {question}

์ฐธ๊ณ ์ ๋ณด:
{context}

์ฐธ๊ณ ์ ๋ณด์ ๋ต์ด ์๋ ๊ฒฝ์ฐ "์ ๊ณต๋ ๋ฌธ์์์ ํด๋น ์ ๋ณด๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."๋ผ๊ณ ๋ต๋ณํ์ธ์.
๋ต๋ณ์ ์ ํํ๊ณ ๊ฐ๊ฒฐํ๊ฒ ์ ๊ณตํ๋, ์ฐธ๊ณ ์ ๋ณด์์ ๊ทผ๊ฑฐ๋ฅผ ์ฐพ์ ์ค๋ชํด์ฃผ์ธ์.
์ฐธ๊ณ ์ ๋ณด์ ์ถ์ฒ๋ ํจ๊ป ์๋ ค์ฃผ์ธ์.
"""

        self.prompt = PromptTemplate.from_template(template)

        self.chain = (
            {"context": self._retrieve, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    @staticmethod
    def _format_document(index: int, doc: "Document") -> str:
        """Render one retrieved document as a numbered, source-tagged entry.

        Args:
            index: Number shown in the entry header.
            doc: Object exposing ``metadata`` (a mapping) and ``page_content``.

        Returns:
            A header line carrying the source (and page, when present),
            followed by the document text and a trailing newline.
        """
        source = doc.metadata.get("source", "์ ์ ์๋ ์ถ์ฒ")
        page = doc.metadata.get("page", "")
        source_info = f"{source}"
        # Compare explicitly instead of relying on truthiness: page number 0
        # is a real page and must not be dropped from the citation.
        if page is not None and page != "":
            source_info += f" (ํ์ด์ง: {page})"
        return f"[์ฐธ๊ณ ์๋ฃ {index}] - ์ถ์ฒ: {source_info}\n{doc.page_content}\n"

    def _retrieve(self, query: str) -> str:
        """Fetch documents for ``query`` and join them into a context string.

        Args:
            query: The user's question.

        Returns:
            The formatted entries of the retrieved (and, when enabled,
            reranked) documents joined by newlines; an empty string when
            nothing was retrieved.
        """
        docs = self.vector_store.similarity_search(query, k=TOP_K_RETRIEVAL)

        # Rerank only when it is enabled, a reranker exists and there is
        # something to rank.
        if self.use_reranker and self.reranker is not None and docs:
            docs = self.reranker.rerank(query, docs, top_k=TOP_K_RERANK)

        return "\n".join(
            self._format_document(i, doc) for i, doc in enumerate(docs, 1)
        )

    def run(self, query: str) -> str:
        """Answer ``query`` by invoking the chain built in ``setup_chain``.

        Args:
            query: The user's question.

        Returns:
            The string produced by the chain.
        """
        return self.chain.invoke(query)