"""
Simple RAG chain implementation embedded in the app.
"""

from config import OPENAI_API_KEY, LLM_MODEL, USE_OPENAI, TOP_K_RETRIEVAL

try:
    from langchain_openai import ChatOpenAI
    from langchain.prompts import PromptTemplate
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.runnables import RunnablePassthrough

    LANGCHAIN_IMPORTS_AVAILABLE = True
except ImportError:
    print("[APP_RAG] Failed to load the langchain packages.")
    LANGCHAIN_IMPORTS_AVAILABLE = False


class SimpleRAGChain:
    """Simple RAG chain implementation (embedded in the app)."""

    def __init__(self, vector_store):
        """Initialize the simple RAG chain."""
        print("[APP_RAG] Initializing simple RAG chain...")
        self.vector_store = vector_store

        if not LANGCHAIN_IMPORTS_AVAILABLE:
            print("[APP_RAG] Cannot initialize the RAG chain: the langchain packages could not be found.")
            raise ImportError("The libraries required to initialize the RAG chain are not installed.")

        # Fail fast if OpenAI is selected but no API key is configured.
        if USE_OPENAI and not OPENAI_API_KEY:
            print("[APP_RAG] Warning: OpenAI API key is not set.")
            raise ValueError("OpenAI API key is not set.")

        try:
            if USE_OPENAI:
                self.llm = ChatOpenAI(
                    model=LLM_MODEL,
                    temperature=0.2,
                    api_key=OPENAI_API_KEY,
                )
                print(f"[APP_RAG] Initialized OpenAI model: {LLM_MODEL}")
            else:
                try:
                    from langchain_community.chat_models import ChatOllama
                    from config import OLLAMA_HOST

                    self.llm = ChatOllama(
                        model=LLM_MODEL,
                        temperature=0.2,
                        base_url=OLLAMA_HOST,
                    )
                    print(f"[APP_RAG] Initialized Ollama model: {LLM_MODEL}")
                except ImportError:
                    # Fall back to OpenAI. Note: this path still requires
                    # OPENAI_API_KEY even though USE_OPENAI is False.
                    self.llm = ChatOpenAI(
                        model="gpt-3.5-turbo",
                        temperature=0.2,
                        api_key=OPENAI_API_KEY,
                    )
                    print("[APP_RAG] Ollama is unavailable; falling back to OpenAI.")

            # Prompt template: the retrieved context and the user question
            # are injected into {context} and {question}.
            template = """
Answer the question accurately based on the information below.

Question: {question}

Reference information:
{context}

If the answer is not found in the reference information, reply "The requested information could not be found in the provided documents."
Keep the answer accurate and concise, and explain the supporting evidence from the reference information.
Also cite the sources of the reference information.
"""

            self.prompt = PromptTemplate.from_template(template)

            # LCEL pipeline: retrieve context for the query, fill the prompt,
            # call the LLM, and parse the response into a plain string.
            self.chain = (
                {"context": self._retrieve, "question": RunnablePassthrough()}
                | self.prompt
                | self.llm
                | StrOutputParser()
            )
            print("[APP_RAG] RAG chain initialization complete")
        except Exception as e:
            print(f"[APP_RAG] RAG chain initialization failed: {e}")
            import traceback
            traceback.print_exc()
            raise

    def _retrieve(self, query):
        """Retrieve documents and format them into a context string."""
        try:
            docs = self.vector_store.similarity_search(query, k=TOP_K_RETRIEVAL)

            # Number each document and prepend its source (and page, if any).
            context_parts = []
            for i, doc in enumerate(docs, 1):
                source = doc.metadata.get("source", "unknown source")
                page = doc.metadata.get("page", "")
                source_info = f"{source}"
                if page:
                    source_info += f" (page: {page})"

                context_parts.append(f"[Reference {i}] - Source: {source_info}\n{doc.page_content}\n")

            return "\n".join(context_parts)
        except Exception as e:
            print(f"[APP_RAG] Error during retrieval: {e}")
            import traceback
            traceback.print_exc()
            return "An error occurred while searching the documents."

    def run(self, query):
        """Process a query through the RAG chain."""
        try:
            return self.chain.invoke(query)
        except Exception as e:
            print(f"[APP_RAG] Error during execution: {e}")
            import traceback
            traceback.print_exc()
            return f"Error occurred: {str(e)}"