# RAG_voice / app_rag.py
# (Hugging Face upload metadata: jeongsoo — "Add application file", commit a76f77b)
"""
์•ฑ์— ๋‚ด์žฅ๋œ ๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ๊ตฌํ˜„
"""
import os
import traceback
from typing import List, Dict, Any, Optional

from config import OPENAI_API_KEY, LLM_MODEL, USE_OPENAI, TOP_K_RETRIEVAL
# ์•ˆ์ „ํ•œ ์ž„ํฌํŠธ
try:
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
LANGCHAIN_IMPORTS_AVAILABLE = True
except ImportError:
print("[APP_RAG] langchain ๊ด€๋ จ ํŒจํ‚ค์ง€๋ฅผ ๋กœ๋“œํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
LANGCHAIN_IMPORTS_AVAILABLE = False
class SimpleRAGChain:
    """Minimal in-app RAG chain.

    Retrieves context passages from a vector store and answers questions
    with an LLM — OpenAI by default, Ollama as an optional local backend
    (falling back to OpenAI if the Ollama bindings are missing).
    """

    def __init__(self, vector_store):
        """Initialize the RAG chain.

        Args:
            vector_store: object exposing ``similarity_search(query, k=...)``
                that returns documents with ``page_content`` and ``metadata``
                attributes (e.g. a langchain vector store).

        Raises:
            ImportError: required langchain packages are not installed.
            ValueError: OpenAI backend selected but no API key configured.
        """
        print("[APP_RAG] 간단한 RAG 체인 초기화 중...")
        self.vector_store = vector_store
        if not LANGCHAIN_IMPORTS_AVAILABLE:
            print("[APP_RAG] langchain 패키지를 찾을 수 없어 RAG 체인을 초기화할 수 없습니다.")
            raise ImportError("RAG 체인 초기화에 필요한 라이브러리가 설치되지 않았습니다.")
        # Fail fast when the OpenAI backend is selected without credentials.
        if not OPENAI_API_KEY and USE_OPENAI:
            print("[APP_RAG] 경고: OpenAI API 키가 설정되지 않았습니다.")
            raise ValueError("OpenAI API 키가 설정되지 않았습니다.")
        try:
            self.llm = self._build_llm()
            # Prompt template — runtime text, kept verbatim (Korean).
            template = """
다음 정보를 기반으로 질문에 정확하게 답변해주세요.
질문: {question}
참고 정보:
{context}
참고 정보에 답이 없는 경우 "제공된 문서에서 해당 정보를 찾을 수 없습니다."라고 답변하세요.
답변은 정확하고 간결하게 제공하되, 참고 정보에서 근거를 찾아 설명해주세요.
참고 정보의 출처도 함께 알려주세요.
"""
            self.prompt = PromptTemplate.from_template(template)
            # LCEL pipeline: retrieve context -> fill prompt -> LLM -> plain string.
            self.chain = (
                {"context": self._retrieve, "question": RunnablePassthrough()}
                | self.prompt
                | self.llm
                | StrOutputParser()
            )
            print("[APP_RAG] RAG 체인 초기화 완료")
        except Exception as e:
            print(f"[APP_RAG] RAG 체인 초기화 실패: {e}")
            traceback.print_exc()
            raise

    def _build_llm(self):
        """Create and return the chat model for the configured backend."""
        if USE_OPENAI:
            llm = ChatOpenAI(
                model_name=LLM_MODEL,
                temperature=0.2,
                api_key=OPENAI_API_KEY,
            )
            print(f"[APP_RAG] OpenAI 모델 초기화: {LLM_MODEL}")
            return llm
        try:
            # Prefer a local Ollama model when OpenAI is disabled.
            from langchain_community.chat_models import ChatOllama
            from config import OLLAMA_HOST
            llm = ChatOllama(
                model=LLM_MODEL,
                temperature=0.2,
                base_url=OLLAMA_HOST,
            )
            print(f"[APP_RAG] Ollama 모델 초기화: {LLM_MODEL}")
            return llm
        except ImportError:
            # Ollama bindings unavailable — fall back to a default OpenAI model.
            llm = ChatOpenAI(
                model_name="gpt-3.5-turbo",
                temperature=0.2,
                api_key=OPENAI_API_KEY,
            )
            print("[APP_RAG] Ollama를 사용할 수 없어 OpenAI로 대체합니다.")
            return llm

    def _retrieve(self, query, k=None):
        """Search the vector store and format hits as a numbered context string.

        Args:
            query: search query text.
            k: number of documents to fetch; defaults to ``TOP_K_RETRIEVAL``.

        Returns:
            Newline-joined context entries, each tagged with its source
            (and page when present), or a Korean error message on failure.
        """
        try:
            if k is None:
                k = TOP_K_RETRIEVAL
            docs = self.vector_store.similarity_search(query, k=k)
            context_parts = []
            for i, doc in enumerate(docs, 1):
                source = doc.metadata.get("source", "알 수 없는 출처")
                page = doc.metadata.get("page", "")
                source_info = f"{source}"
                # Bug fix: `if page:` silently dropped page 0 (PDF loaders
                # commonly 0-index pages); only skip a genuinely absent page.
                if page is not None and page != "":
                    source_info += f" (페이지: {page})"
                context_parts.append(f"[참고자료 {i}] - 출처: {source_info}\n{doc.page_content}\n")
            return "\n".join(context_parts)
        except Exception as e:
            print(f"[APP_RAG] 검색 중 오류: {e}")
            traceback.print_exc()
            return "문서 검색 중 오류가 발생했습니다."

    def run(self, query):
        """Invoke the chain on *query* and return the answer string.

        Returns a Korean error message (never raises) when the chain fails.
        """
        try:
            return self.chain.invoke(query)
        except Exception as e:
            print(f"[APP_RAG] 실행 중 오류: {e}")
            traceback.print_exc()
            return f"오류 발생: {str(e)}"