성현 김
Set cache ENV vars in Dockerfile, removed cache_folder arg from HuggingFaceEmbeddings
ba5a5ec
# app/app.py
# 1. sqlite3 패치 코드를 가장 먼저 실행
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
# 2. 로딩 확인 및 디버깅용 print 문 (하나로 통일)
print("<<<<< app/app.py IS BEING LOADED (sqlite3 patched with pysqlite3) >>>>>")
# 3. 기본 인코딩 설정 (보통 파일 첫 줄에 두지만, 위 print문들 다음에 와도 괜찮습니다)
# -*- coding: utf-8 -*-
# 4. 필요한 모든 모듈들 import
import os
import time
import torch
import torch.nn.functional as F
import concurrent.futures
from pathlib import Path
from dotenv import load_dotenv, find_dotenv
import wikipediaapi
from konlpy.tag import Okt
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain.memory import ConversationBufferMemory
import traceback # 상세 오류 로깅용
from fastapi import FastAPI, HTTPException
from .schemas import GenerateRequest, GenerateResponse, ResetMemoryResponse # <--- 이 줄이 올바르게 있어야 합니다.
# schemas 모듈을 현재 디렉토리(app 패키지) 기준으로 상대 경로 import 합니다.
# from .schemas import GenerateRequest, GenerateResponse, ResetMemoryResponse # <--- 이 import 문은 한 번만 있어야 합니다.
# 5. FastAPI 애플리케이션 인스턴스 생성
# print("--- All imports in app.py successful, attempting FastAPI init ---") # 필요 시 디버깅
app = FastAPI(
title="미드저니 프롬프트 생성기 API",
description="사용자 입력, 대화 기록, 검색된 컨텍스트를 기반으로 미드저니 프롬프트를 생성합니다.",
version="1.0.0"
)
# print(f"--- FastAPI instance 'app' IS DEFINED in app.py, type: {type(app)} ---") # 필요 시 디버깅
# --- Configuration (FastAPI 객체 생성 후, @app.on_event("startup") 이전 또는 내부로 이동 가능) ---
# 이 print문들은 애플리케이션 로직의 일부이므로, FastAPI 객체 생성 이후에 위치하는 것이 자연스럽습니다.
# 또는 @app.on_event("startup") 함수 내부로 옮겨서 애플리케이션 시작 시점에 실행되도록 할 수도 있습니다.
# 지금은 순서상 큰 문제는 없어 보이지만, 명확성을 위해 FastAPI 객체 생성 이후에 두겠습니다.
print(f"--- FastAPI instance 'app' IS DEFINED in app.py, type: {type(app)} ---")
# --- Configuration ---
print("🚀 API 스크립트 시작: 설정 로딩 중...")
env_path = find_dotenv()
if env_path:
load_dotenv(env_path)
print(f"✅ .env 로드됨: {env_path}")
else:
print("⚠️ .env 파일이 없습니다. 기본값이나 환경변수를 사용합니다.")
# Docker 환경 내 애플리케이션 루트는 /app 입니다.
BASE_DIR = Path(os.getenv("PROJECT_ROOT", "/app"))
# Chroma DB는 Dockerfile에서 /app/chroma_db_data로 복사될 예정입니다.
CHROMA_DB_DIR = BASE_DIR / "chroma_db_data"
print(f"📂 Chroma DB 경로 (API): {CHROMA_DB_DIR}")
if not CHROMA_DB_DIR.exists():
# 애플리케이션 시작 시 치명적인 오류이므로, 여기서 멈추기보다는 경고 후 계속 진행
# 실제 운영 시에는 DB가 없으면 시작하지 않도록 처리할 수 있습니다.
print(f"🚨🚨🚨 중요 경고: Chroma DB 디렉토리가 존재하지 않습니다: {CHROMA_DB_DIR}. API가 정상 작동하지 않을 수 있습니다. Dockerfile에 chroma_db_data 폴더 복사 구문이 있고, 해당 폴더에 DB 파일이 있는지 확인하세요.")
EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL", "intfloat/e5-large-v2")
TRANSLATION_MODEL_NAME = os.getenv("TRANSLATION_MODEL", "Helsinki-NLP/opus-mt-ko-en")
LLM_MODEL_NAME = os.getenv("LLM_MODEL", "sdgsjlfnjkl/kanana-2.1b-full-v12") # 사용자의 파인튜닝 모델
SBERT_MODEL_NAME = os.getenv("SBERT_MODEL", "snunlp/KR-SBERT-V40K-klueNLI-augSTS")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", 0.80))
RETRIEVER_K = int(os.getenv("RETRIEVER_K", 5))
print(f"⚙️ 사용 디바이스: {DEVICE}")
print(f"📚 임베딩 모델 (Chroma용): {EMBEDDING_MODEL_NAME}")
print(f"✈️ 번역 모델: {TRANSLATION_MODEL_NAME}")
print(f"🧠 LLM 모델: {LLM_MODEL_NAME}")
print(f"🇰🇷 키워드 분석 모델 (SBERT): {SBERT_MODEL_NAME}")
print(f"🎯 유사도 임계값 (Chroma): {SIMILARITY_THRESHOLD}")
print(f"🔍 초기 검색 문서 수(k, Chroma): {RETRIEVER_K}")
# --- Global Variables for Models and Utilities (Load once at startup) ---
okt = None
wiki = None
STOPWORDS = {"하다", "되다", "있다", "없다"}
embedding_model = None
trans_tokenizer = None
trans_model = None
llm_tokenizer = None
llm_model_instance = None # llm_model 변수명 충돌 피하기 위해 변경
sbert_model_instance = None # sbert_model 변수명 충돌 피하기 위해 변경
db = None
retriever = None
memory = ConversationBufferMemory(memory_key="history", return_messages=True)
already_searched_wiki = set() # 전역 위키 검색 기록
inferencer = None
llm_chain = None
# --- Model Loading and Setup Function ---
def load_models_and_setup():
global okt, wiki, embedding_model, trans_tokenizer, trans_model, \
llm_tokenizer, llm_model_instance, sbert_model_instance, db, retriever, \
inferencer, llm_chain
print("\n⏳ 위키피디아 및 키워드 분석기 설정 중...")
try:
okt = Okt()
wiki = wikipediaapi.Wikipedia(user_agent='midjourney_prompt_generator_api/1.0', language='ko')
print("✅ Okt, Wikipedia API 설정 완료.")
except Exception as e:
print(f"🚨 Okt 또는 Wikipedia 설정 실패: {e}")
okt = None
wiki = None # Ensure it's None if setup fails
print("\n⏳ 모델 로딩 시작...")
start_load_time = time.time()
try:
embedding_model = HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={'device': DEVICE}
# cache_folder="/app/cache/huggingface_embeddings" # <--- 이 줄은 일단 삭제 또는 주석 처리
)
print(f"✅ 임베딩 모델 로드 완료 ({EMBEDDING_MODEL_NAME})")
except Exception as e:
print(f"🚨🚨🚨 임베딩 모델 로딩 실패 (치명적): {e}"); raise
try:
trans_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL_NAME)
trans_model = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL_NAME).to(DEVICE)
trans_model.eval()
print(f"✅ 번역 모델 로드 완료 ({TRANSLATION_MODEL_NAME})")
except Exception as e:
print(f"🚨🚨🚨 번역 모델 로딩 실패 (치명적): {e}"); raise
try:
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
# device_map="auto" 사용 시 accelerate 라이브러리 필요
llm_model_instance = AutoModelForCausalLM.from_pretrained(LLM_MODEL_NAME, device_map="auto", torch_dtype=torch.float16)
if llm_tokenizer.pad_token is None:
llm_tokenizer.pad_token = llm_tokenizer.eos_token
print(" ⚠️ LLM 토크나이저의 pad_token이 없어 eos_token으로 설정합니다.")
llm_model_instance.eval()
print(f"✅ LLM 로드 완료 ({LLM_MODEL_NAME})")
if not callable(llm_tokenizer):
print(f"🚨 에러: LLM 토크나이저 로딩 실패 또는 callable 객체가 아닙니다. 타입: {type(llm_tokenizer)}"); raise TypeError("LLM tokenizer not callable")
except Exception as e:
print(f"🚨🚨🚨 LLM 로딩 실패 (치명적): {e}"); raise
try:
sbert_model_instance = SentenceTransformer(SBERT_MODEL_NAME, device=DEVICE)
print(f"✅ SBERT 모델 로드 완료 ({SBERT_MODEL_NAME})")
except Exception as e:
print(f"🚨 SBERT 모델 로딩 실패 (일부 기능 제한될 수 있음): {e}")
sbert_model_instance = None # Ensure it's None if setup fails
print(f"⏱️ 전체 모델 로딩 시간: {time.time() - start_load_time:.2f}초")
print("\n⏳ Chroma DB 연결 중...")
if not CHROMA_DB_DIR.exists():
print(f"🚨🚨🚨 Chroma DB 디렉토리 ({CHROMA_DB_DIR})를 찾을 수 없어 DB 연결을 건너<0xEB><0x81><0xB0>니다. API가 정상 작동하지 않습니다.")
db = None
retriever = None
else:
try:
db = Chroma(embedding_function=embedding_model, persist_directory=str(CHROMA_DB_DIR), collection_name="midjourney-prompts")
retriever = db.as_retriever(search_kwargs={"k": RETRIEVER_K})
print(f"✅ Chroma DB 연결 완료 (Collection: {db._collection.name if db and hasattr(db, '_collection') else 'N/A'}, Retriever k={RETRIEVER_K})")
if db:
try:
sample_docs = db.get(limit=1)
if not sample_docs or not sample_docs.get('ids'): print("⚠️ 경고: Chroma DB 컬렉션이 비어있거나 접근할 수 없습니다.")
else: print(f" ℹ️ DB 샘플 ID 확인: {sample_docs.get('ids')}")
except Exception as db_check_e: print(f"⚠️ Chroma DB 샘플 확인 중 오류: {db_check_e}")
except Exception as e:
print(f"🚨🚨🚨 Chroma DB 연결 실패 (치명적일 수 있음): {e}")
db = None # Ensure it's None
retriever = None # Ensure it's None
# Langchain Components (Models must be loaded first)
if llm_model_instance and callable(llm_tokenizer):
print("\n⛓️ Langchain Chain 구성 중...")
inferencer = LoRAInferencer(llm_model_instance, llm_tokenizer) # 클래스명 변경 없음
llm_chain = (
{
"input": RunnablePassthrough(),
"history": lambda x: format_memory_string(memory.chat_memory.messages),
"already_searched_wiki_ref": lambda x: already_searched_wiki # 전역 set 전달
}
| RunnablePassthrough.assign(
retrieved_wiki_context=lambda x: retrieve_wikipedia_context(x["input"], x["already_searched_wiki_ref"])
)
| RunnablePassthrough.assign(
modified_korean_request=lambda x: x["input"] + " 미드저니 프롬프트 작성해줘",
translated_request=lambda x: translate_ko_to_en(x["input"]),
)
| RunnablePassthrough.assign(
retrieved_chroma_context=lambda x: retrieve_english_context(x["translated_request"])
)
| (lambda x: print(f"""
DEBUG (API): Prompt Inputs Ready:
- Original Korean Input: {x.get('input', '')[:50]}...
- Modified Korean Request: {x.get('modified_korean_request', '')[:50]}...
- Translated Request (for Chroma): {x.get('translated_request', '')[:50]}...
- Retrieved Chroma Context: {x.get('retrieved_chroma_context', '')[:100]}...
- Retrieved Wiki Context: {x.get('retrieved_wiki_context', '')[:100]}...
- History (Formatted): {x.get('history', '')[-200:]}...
""") or x) # 디버깅 후 x 반환 확인
| prompt # 전역 프롬프트 템플릿 사용
| (lambda p: p.to_string())
| inferencer
)
print("✅ Langchain Chain 구성 완료.")
else:
print("🚨🚨🚨 LLM 모델 또는 토크나이저가 제대로 로드되지 않아 Langchain Chain을 구성할 수 없습니다.")
llm_chain = None
# --- Helper Functions (from user script, adapted slightly) ---
def translate_ko_to_en(text: str) -> str:
if not text or not trans_tokenizer or not trans_model: return ""
try:
with torch.no_grad():
inputs = trans_tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
outputs = trans_model.generate(**inputs, max_length=512, num_beams=4, early_stopping=True)
translated_text = trans_tokenizer.decode(outputs[0], skip_special_tokens=True)
return translated_text
except Exception as e: print(f"🚨 번역 중 오류 발생 (입력: {text[:50]}...): {e}"); return f"[Translation Error: {e}]"
def retrieve_english_context(translated_english_text: str) -> str:
if not translated_english_text or "[Translation Error:" in translated_english_text or not retriever or not embedding_model:
print(" ⚠️ 번역 실패, 텍스트 부재, 리트리버/임베딩 모델 미로드로 Chroma 검색을 건너<0xEB><0x81><0xB0>니다.")
return ""
start_retrieval = time.time()
try:
initial_docs: list[Document] = retriever.invoke(translated_english_text) # retriever는 전역 사용
if not initial_docs: print(" ⚠️ Chroma DB에서 초기 문서를 찾지 못했습니다."); return ""
query_embedding_list = embedding_model.embed_query(translated_english_text) # embedding_model은 전역 사용
query_embedding = torch.tensor(query_embedding_list, dtype=torch.float).to(DEVICE).unsqueeze(0)
doc_contents = [doc.page_content for doc in initial_docs if doc.page_content]
if not doc_contents: print(" ⚠️ 검색된 문서에 유효한 내용이 없습니다."); return ""
doc_embeddings_list = embedding_model.embed_documents(doc_contents) # embedding_model은 전역 사용
doc_embeddings = torch.tensor(doc_embeddings_list, dtype=torch.float).to(DEVICE)
similarities = F.cosine_similarity(query_embedding, doc_embeddings, dim=1)
similarities_list = similarities.cpu().tolist()
doc_similarity_pairs = list(zip(initial_docs, similarities_list))
filtered_docs = [(doc, sim) for doc, sim in doc_similarity_pairs if sim >= SIMILARITY_THRESHOLD]
if not filtered_docs: print(f" ❌ 유사도 {SIMILARITY_THRESHOLD} 이상인 Chroma 문서를 찾지 못했습니다."); return ""
filtered_docs.sort(key=lambda item: item[1], reverse=True)
best_doc, best_sim = filtered_docs[0]
print(f" ✅ 가장 유사한 Chroma 문서 선택 (유사도: {best_sim:.4f})")
return best_doc.page_content.strip()
except Exception as e:
print(f"🚨 Chroma 컨텍스트 검색/유사도 계산 중 오류 발생: {e}")
traceback.print_exc()
return "[Context Retrieval/Similarity Error]"
def extract_keywords(text: str) -> list[str]:
if not okt: print(" ⚠️ Okt 형태소 분석기가 로드되지 않아 키워드 추출을 건너<0xEB><0x81><0xB0>니다."); return []
try:
nouns = okt.nouns(text)
verbs_adjectives = [w for w, pos in okt.pos(text, stem=True) if pos in ['Adjective', 'Verb']] # 어간 추출 추가
keywords = [w for w in nouns + verbs_adjectives if w not in STOPWORDS and len(w) > 1]
return list(set(keywords))
except Exception as e: print(f"🚨 키워드 추출 중 오류: {e}"); return []
def sort_by_semantic_importance(text: str, keywords: list[str]) -> list[str]:
if not sbert_model_instance or not keywords: return keywords # sbert_model_instance 사용
try:
text_emb = sbert_model_instance.encode(text, convert_to_tensor=True)
keyword_embs = sbert_model_instance.encode(keywords, convert_to_tensor=True)
scores = util.cos_sim(text_emb, keyword_embs)[0]
sorted_keywords = [kw for kw, _ in sorted(zip(keywords, scores.cpu().tolist()), key=lambda x: -x[1])]
return sorted_keywords
except Exception as e: print(f"🚨 키워드 중요도 정렬 중 오류: {e}"); return keywords
def get_wiki_content(word: str, max_length: int = 200) -> str:
if not wiki: print(" ⚠️ Wikipedia API가 설정되지 않아 검색을 건너<0xEB><0x81><0xB0>니다."); return ""
try:
page = wiki.page(word)
if page.exists():
summary = page.summary[:max_length].replace("\n", " ")
return summary
return ""
except Exception as e: print(f"🚨 위키피디아 검색 중 오류 ('{word}'): {e}"); return ""
def retrieve_wikipedia_context(korean_input: str, current_already_searched_wiki: set) -> str: # 인자로 already_searched_wiki 받음
if not korean_input or not okt or not wiki or not sbert_model_instance:
print(" ⚠️ Wikipedia 컨텍스트 생성에 필요한 요소 부족 (okt, wiki, sbert 중 하나 이상)")
return ""
start_wiki_retrieval = time.time()
keywords = extract_keywords(korean_input)
if not keywords: print(" ⚠️ 키워드를 추출하지 못했습니다."); return ""
sorted_keywords = sort_by_semantic_importance(korean_input, keywords)
keywords_to_search = [kw for kw in sorted_keywords if kw not in current_already_searched_wiki][:3] # 전달받은 set 사용
if not keywords_to_search: print(" ℹ️ 이미 검색했거나 검색할 새로운 위키 키워드가 없습니다."); return ""
wiki_context_parts = []
try:
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
future_to_keyword = {executor.submit(get_wiki_content, kw): kw for kw in keywords_to_search}
for future in concurrent.futures.as_completed(future_to_keyword):
keyword = future_to_keyword[future]
try:
result = future.result()
if result:
wiki_context_parts.append(f"- {keyword}: {result}")
current_already_searched_wiki.add(keyword) # 전달받은 set 업데이트 (전역 set이 업데이트됨)
except Exception as exc: print(f' 🚨 위키 검색 중 예외 발생 ({keyword}): {exc}')
wiki_context = "\n".join(wiki_context_parts)
total_wiki_time = time.time() - start_wiki_retrieval
if wiki_context: print(f" ✅ Wikipedia 컨텍스트 생성 완료 ({total_wiki_time:.2f}초)"); return wiki_context.strip()
else: print(f" ❌ 관련된 Wikipedia 정보를 찾지 못했습니다. ({total_wiki_time:.2f}초)"); return ""
except Exception as e:
print(f"🚨 Wikipedia 컨텍스트 생성 중 오류 발생: {e}")
traceback.print_exc()
return "[Wikipedia Context Retrieval Error]"
# --- Langchain Components ( 정의는 load_models_and_setup 안으로 이동 ) ---
class LoRAInferencer(Runnable): # 스크립트에 있는 클래스명 사용
def __init__(self, model, tokenizer):
self.model = model
if not callable(tokenizer): raise TypeError(f"LoRAInferencer 초기화 실패: 전달된 tokenizer가 callable이 아닙니다.")
self.tokenizer = tokenizer
def invoke(self, input: str, config=None): # input 타입 명시 (str)
if not callable(self.tokenizer): raise TypeError(f"LoRAInferencer invoke 실패: self.tokenizer가 callable이 아닙니다.")
prompt_text = input # input은 이미 prompt.to_string()의 결과인 문자열
try:
inputs = self.tokenizer(prompt_text, return_tensors="pt", padding=True, truncation=True, max_length=1536).to(self.model.device)
if 'token_type_ids' in inputs: del inputs['token_type_ids'] # Gemma 같은 모델은 token_type_ids 불필요
except Exception as e: print(f"🚨 토크나이징 중 에러 발생: {e}"); raise
try:
with torch.no_grad():
outputs = self.model.generate(
**inputs, do_sample=True, temperature=0.7, top_p=0.9, repetition_penalty=1.2,
no_repeat_ngram_size=3, max_new_tokens=300, pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id, num_return_sequences=1
)
# 입력 프롬프트를 제외한 생성된 부분만 디코딩
output_token_ids = outputs[0][inputs['input_ids'].shape[1]:]
decoded_output = self.tokenizer.decode(output_token_ids, skip_special_tokens=True).strip()
# 후처리 (스크립트 내용 반영)
if decoded_output.startswith("[Midjourney Prompt (English)]"):
decoded_output = decoded_output.replace("[Midjourney Prompt (English)]","").strip()
if not decoded_output:
print("⚠️ 생성된 결과가 비어있습니다. 입력 프롬프트의 마지막 부분을 확인합니다.")
answer_marker = "[Midjourney Prompt (English)]\n"
if answer_marker in prompt_text: # prompt_text는 LLM에 들어간 전체 프롬프트
potential_output = prompt_text.split(answer_marker)[-1].strip()
if potential_output and potential_output != input: # input은 chain의 최종 입력이 아닌 초기 입력
decoded_output = potential_output
print(" -> 입력 프롬프트 마지막 부분을 결과로 사용.")
else:
decoded_output = "..."
print(" -> 기본값 '...' 사용 (potential_output이 비었거나 초기 입력과 동일).")
else:
decoded_output = "..."
print(" -> 기본값 '...' 사용 (answer_marker 없음).")
return decoded_output
except Exception as e:
print(f"🚨 모델 생성 중 에러 발생: {e}")
traceback.print_exc()
# 오류 시 입력 프롬프트에서 답변 부분 추출 시도
try: return prompt_text.split("[Midjourney Prompt (English)]\n")[-1].strip()
except: raise # 이것도 실패하면 원래 예외 발생
prompt = PromptTemplate.from_template('''You are a professional prompt engineer creating evolving prompts for Midjourney.
Continuously build upon the previous conversation history.
Generate a concise and effective Midjourney prompt IN ENGLISH ONLY based on the entire conversation flow.
[Reference English Examples from Database (Chroma)]
{retrieved_chroma_context}
[Reference Context from Wikipedia (Korean)]
{retrieved_wiki_context}
[Conversation Context So Far]
{history}
[Now Continue the Prompt for Midjourney based on the above conversation and the following request:]
User Request (Korean): {modified_korean_request}
[Midjourney Prompt (English)]
''')
def format_memory_string(messages):
# print(f"--- DEBUG: format_memory_string called with {len(messages)} messages ---") # API 로그에 너무 많을 수 있음
if not messages: return "No prior conversation."
lines = []
for msg in messages:
role = "User" if msg.type == "human" else "Assistant"
lines.append(f"{role}: {msg.content}")
formatted_history = "\n".join(lines)
# print(f" 🔄 Formatted History (for Prompt):\n{formatted_history[-300:]}...") # API 로그에 너무 많을 수 있음
return formatted_history
# --- Application Startup Event ---
@app.on_event("startup")
async def startup_event():
print("🌟 FastAPI 애플리케이션 시작... 모델 및 설정 로딩...")
try:
load_models_and_setup()
print("✅ 모든 모델 및 설정 로딩 완료.")
except Exception as e:
print(f"🚨🚨🚨 애플리케이션 시작 중 치명적 오류 발생: {e}")
traceback.print_exc()
# 여기서 애플리케이션을 종료하거나, 오류 상태로 실행되도록 할 수 있습니다.
# Hugging Face Space에서는 오류가 나도 계속 실행될 수 있도록 두는 것이 일반적입니다 (로그 확인 가능).
# raise HTTPException(status_code=500, detail=f"서버 초기화 실패: {e}") # 이렇게 하면 서버가 시작 안 될 수 있음
# --- API Endpoints ---
@app.get("/", summary="API 루트", description="API가 실행 중인지 확인합니다.")
async def read_root():
return {"message": "Midjourney 프롬프트 생성기 API가 실행 중입니다. /docs 에서 API 문서를 확인하세요."}
@app.get("/health", summary="헬스 체크", description="API 서버의 상태를 확인합니다.")
async def health_check():
# 모델 로드 상태 등을 더 상세히 체크할 수 있습니다.
if llm_chain and retriever and okt and wiki and sbert_model_instance and trans_model and embedding_model:
return {"status": "healthy", "message": "모든 주요 구성 요소가 로드되었습니다."}
else:
missing = []
if not llm_chain: missing.append("LLM Chain")
if not retriever: missing.append("Chroma Retriever")
if not okt: missing.append("Okt")
if not wiki: missing.append("Wikipedia")
if not sbert_model_instance: missing.append("SBERT model")
if not trans_model: missing.append("Translation model")
if not embedding_model: missing.append("Embedding model")
return {"status": "degraded", "message": f"일부 구성 요소 로드 실패: {', '.join(missing)}"}
@app.post("/generate", response_model=GenerateResponse, summary="미드저니 프롬프트 생성", description="사용자 입력을 기반으로 미드저니 프롬프트를 생성하고 대화 기록을 업데이트합니다.")
async def generate_api_prompt(request: GenerateRequest):
global memory, already_searched_wiki # 전역 변수 사용 명시
if not llm_chain:
print("🚨 /generate 호출: LLM 체인이 로드되지 않았습니다.")
raise HTTPException(status_code=503, detail="서버가 아직 준비되지 않았거나 초기화에 실패했습니다. LLM 체인을 사용할 수 없습니다.")
user_input = request.user_input
if not user_input:
raise HTTPException(status_code=400, detail="user_input 필드는 비워둘 수 없습니다.")
print(f"💬 API /generate 호출 (입력: '{user_input[:50]}...')")
start_time = time.time()
try:
# --- DEBUG PRINT (API) ---
print(f"--- DEBUG (API): Memory before llm_chain.invoke: {len(memory.chat_memory.messages)} messages ---")
# if memory.chat_memory.messages:
# for i, msg in enumerate(memory.chat_memory.messages):
# print(f" DEBUG (API) MSG {i}: type={msg.type}, content='{str(msg.content)[:50]}...'")
# LangChain 실행
result = llm_chain.invoke(user_input) # llm_chain은 전역 변수 사용
end_time = time.time()
processing_time = end_time - start_time
print(f"⏱️ API 생성 및 검색 시간: {processing_time:.2f}초")
history_updated = False
if isinstance(result, str) and not result.startswith("[오류 발생]"):
try:
memory.save_context({"input": user_input}, {"output": result}) # memory는 전역 변수 사용
history_updated = True
print(" ✅ API 대화 기록에 저장되었습니다.")
# --- DEBUG PRINT (API) ---
# print(f"--- DEBUG (API): Memory after save_context: {len(memory.chat_memory.messages)} messages ---")
# if memory.chat_memory.messages:
# for i, msg in enumerate(memory.chat_memory.messages):
# print(f" DEBUG (API) SAVED MSG {i}: type={msg.type}, content='{str(msg.content)[:50]}...'")
except Exception as mem_err:
print(f" 🚨 API 대화 기록 저장 중 오류 발생: {mem_err}")
else:
print(" ⚠️ API 생성 결과가 오류 문자열이거나 유효하지 않아 기록되지 않았습니다.")
return GenerateResponse(
generated_prompt=result,
processing_time_seconds=round(processing_time, 2),
# history_updated=history_updated # 필요하다면 응답에 포함
)
except Exception as e:
print(f"🚨 API /generate 실행 중 심각한 오류 발생: {e}")
traceback.print_exc()
end_time = time.time()
processing_time = end_time - start_time
raise HTTPException(
status_code=500,
detail=f"요청 처리 중 서버 오류 발생: {e}"
)
# @app.post("/reset_memory", response_model=ResetMemoryResponse, summary="대화 기록 초기화", description="서버의 대화 기록과 위키 검색 기록을 초기화합니다.")
# async def reset_api_memory():
# global memory, already_searched_wiki # 전역 변수 사용 명시
# memory_cleared_count = len(memory.chat_memory.messages)
# wiki_cleared_count = len(already_searched_wiki)
memory.clear()
already_searched_wiki.clear()
print(f"🔄 API /reset_memory 호출: 대화 기록 ({memory_cleared_count}개) 및 위키 검색 기록 ({wiki_cleared_count}개) 초기화 완료.")
return ResetMemoryResponse(
message="대화 기록 및 위키 검색 기록이 성공적으로 초기화되었습니다.",
cleared_memory=True,
cleared_wiki_history=True
)
print(f"----- FastAPI instance 'app' IS DEFINED in app.app.py, type: {type(app)}, dir: {dir()}")
# 로컬 테스트용 uvicorn 실행 명령어 (실제 Space에서는 Dockerfile의 CMD 사용)
# if __name__ == "__main__":
# import uvicorn
# # load_models_and_setup() # uvicorn이 시작 시 @app.on_event("startup") 호출
# uvicorn.run(app, host="0.0.0.0", port=8000)