# RAG_voice/rag_chain.py
"""
LangChain์„ ํ™œ์šฉํ•œ RAG ์ฒด์ธ ๊ตฌํ˜„
"""
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.chat_models import ChatOllama
from langchain_openai import ChatOpenAI
from config import (
    OLLAMA_HOST, LLM_MODEL, USE_OPENAI, IS_HUGGINGFACE,
    OPENAI_API_KEY, TOP_K_RETRIEVAL, TOP_K_RERANK
)
from vector_store import VectorStore
from reranker import Reranker


class RAGChain:
    def __init__(self, vector_store: VectorStore, use_reranker: bool = True):
        """
        Initialize the RAG chain, selecting the LLM based on the environment.
        Args:
            vector_store: vector store instance
            use_reranker: whether to use the reranker
        """
        try:
            print("Initializing RAGChain...")
            self.vector_store = vector_store
            self.use_reranker = use_reranker
            print(f"Reranker enabled: {use_reranker}")
            if use_reranker:
                try:
                    self.reranker = Reranker()
                    print("Reranker initialized successfully")
                except Exception as e:
                    print(f"Reranker initialization failed: {str(e)}")
                    self.reranker = None
                    self.use_reranker = False
            else:
                self.reranker = None
            # Select the LLM according to the runtime environment
            if USE_OPENAI or IS_HUGGINGFACE:
                print(f"Initializing OpenAI model: {LLM_MODEL}")
                print(f"API key present: {'yes' if OPENAI_API_KEY else 'no'}")
                try:
                    self.llm = ChatOpenAI(
                        model=LLM_MODEL,
                        temperature=0.2,
                        api_key=OPENAI_API_KEY,
                    )
                    print("OpenAI model initialized successfully")
                except Exception as e:
                    print(f"OpenAI model initialization failed: {str(e)}")
                    raise
            else:
                try:
                    print(f"Initializing Ollama model: {LLM_MODEL}")
                    self.llm = ChatOllama(
                        model=LLM_MODEL,
                        temperature=0.2,
                        base_url=OLLAMA_HOST,
                    )
                    print("Ollama model initialized successfully")
                except Exception as e:
                    print(f"Ollama model initialization failed: {str(e)}")
                    raise
            # Build the RAG chain and its prompt
            print("Setting up RAG chain...")
            self.setup_chain()
            print("RAG chain setup complete")
        except Exception as e:
            print(f"Detailed error during RAGChain initialization: {str(e)}")
            import traceback
            traceback.print_exc()
            raise

    def setup_chain(self) -> None:
        """
        Set up the RAG chain and its prompt.
        """
        # Prompt template definition
        template = """
        Answer the question accurately based on the following information.
        Question: {question}
        Reference information:
        {context}
        If the reference information does not contain the answer, reply "The requested information could not be found in the provided documents."
        Provide an accurate and concise answer, explaining it with evidence drawn from the reference information.
        Also mention the sources of the reference information.
        """
        self.prompt = PromptTemplate.from_template(template)
        # RAG chain definition
        self.chain = (
            {"context": self._retrieve, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )
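        # Data flow of the chain above (descriptive note):
        #   query (str)
        #     -> {"context": self._retrieve(query), "question": query}
        #        (RunnablePassthrough forwards the raw query unchanged)
        #     -> prompt rendered with both fields
        #     -> chat model (OpenAI or Ollama, chosen in __init__)
        #     -> answer as a plain string via StrOutputParser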

    def _retrieve(self, query: str) -> str:
        """
        Retrieve documents relevant to the query and build the context string.
        Args:
            query: user question
        Returns:
            context string containing the retrieval results
        """
        # Vector similarity search
        docs = self.vector_store.similarity_search(query, k=TOP_K_RETRIEVAL)
        # Apply the reranker (optional)
        if self.use_reranker and docs:
            docs = self.reranker.rerank(query, docs, top_k=TOP_K_RERANK)
        # Build the context from the retrieved documents
        context_parts = []
        for i, doc in enumerate(docs, 1):
            source = doc.metadata.get("source", "unknown source")
            page = doc.metadata.get("page", "")
            source_info = f"{source}"
            if page:
                source_info += f" (page: {page})"
            context_parts.append(f"[Reference {i}] - source: {source_info}\n{doc.page_content}\n")
        return "\n".join(context_parts)

    def run(self, query: str) -> str:
        """
        Run the RAG pipeline for a user query.
        Args:
            query: user question
        Returns:
            model response string
        """
        return self.chain.invoke(query)
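

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes VectorStore() can be
# constructed with its defaults and already holds indexed documents; adjust the
# construction and the sample question to the actual project setup.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    store = VectorStore()                     # assumption: default constructor
    rag = RAGChain(store, use_reranker=True)  # falls back to no reranker on failure
    print(rag.run("What does the indexed document say about its main topic?"))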