Spaces:
Sleeping
Sleeping
# app/app.py | |
# 1. sqlite3 패치 코드를 가장 먼저 실행 | |
__import__('pysqlite3') | |
import sys | |
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') | |
# 2. 로딩 확인 및 디버깅용 print 문 (하나로 통일) | |
print("<<<<< app/app.py IS BEING LOADED (sqlite3 patched with pysqlite3) >>>>>") | |
# 3. 기본 인코딩 설정 (보통 파일 첫 줄에 두지만, 위 print문들 다음에 와도 괜찮습니다) | |
# -*- coding: utf-8 -*- | |
# 4. 필요한 모든 모듈들 import | |
import os | |
import time | |
import torch | |
import torch.nn.functional as F | |
import concurrent.futures | |
from pathlib import Path | |
from dotenv import load_dotenv, find_dotenv | |
import wikipediaapi | |
from konlpy.tag import Okt | |
from sentence_transformers import SentenceTransformer, util | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM | |
from langchain_core.prompts import PromptTemplate | |
from langchain_core.runnables import Runnable, RunnablePassthrough | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_chroma import Chroma | |
from langchain_core.documents import Document | |
from langchain.memory import ConversationBufferMemory | |
import traceback # 상세 오류 로깅용 | |
from fastapi import FastAPI, HTTPException | |
from .schemas import GenerateRequest, GenerateResponse, ResetMemoryResponse # <--- 이 줄이 올바르게 있어야 합니다. | |
# schemas 모듈을 현재 디렉토리(app 패키지) 기준으로 상대 경로 import 합니다. | |
# from .schemas import GenerateRequest, GenerateResponse, ResetMemoryResponse # <--- 이 import 문은 한 번만 있어야 합니다. | |
# 5. FastAPI 애플리케이션 인스턴스 생성 | |
# print("--- All imports in app.py successful, attempting FastAPI init ---") # 필요 시 디버깅 | |
app = FastAPI( | |
title="미드저니 프롬프트 생성기 API", | |
description="사용자 입력, 대화 기록, 검색된 컨텍스트를 기반으로 미드저니 프롬프트를 생성합니다.", | |
version="1.0.0" | |
) | |
# print(f"--- FastAPI instance 'app' IS DEFINED in app.py, type: {type(app)} ---") # 필요 시 디버깅 | |
# --- Configuration (FastAPI 객체 생성 후, @app.on_event("startup") 이전 또는 내부로 이동 가능) --- | |
# 이 print문들은 애플리케이션 로직의 일부이므로, FastAPI 객체 생성 이후에 위치하는 것이 자연스럽습니다. | |
# 또는 @app.on_event("startup") 함수 내부로 옮겨서 애플리케이션 시작 시점에 실행되도록 할 수도 있습니다. | |
# 지금은 순서상 큰 문제는 없어 보이지만, 명확성을 위해 FastAPI 객체 생성 이후에 두겠습니다. | |
print(f"--- FastAPI instance 'app' IS DEFINED in app.py, type: {type(app)} ---") | |
# --- Configuration --- | |
print("🚀 API 스크립트 시작: 설정 로딩 중...") | |
env_path = find_dotenv() | |
if env_path: | |
load_dotenv(env_path) | |
print(f"✅ .env 로드됨: {env_path}") | |
else: | |
print("⚠️ .env 파일이 없습니다. 기본값이나 환경변수를 사용합니다.") | |
# Docker 환경 내 애플리케이션 루트는 /app 입니다. | |
BASE_DIR = Path(os.getenv("PROJECT_ROOT", "/app")) | |
# Chroma DB는 Dockerfile에서 /app/chroma_db_data로 복사될 예정입니다. | |
CHROMA_DB_DIR = BASE_DIR / "chroma_db_data" | |
print(f"📂 Chroma DB 경로 (API): {CHROMA_DB_DIR}") | |
if not CHROMA_DB_DIR.exists(): | |
# 애플리케이션 시작 시 치명적인 오류이므로, 여기서 멈추기보다는 경고 후 계속 진행 | |
# 실제 운영 시에는 DB가 없으면 시작하지 않도록 처리할 수 있습니다. | |
print(f"🚨🚨🚨 중요 경고: Chroma DB 디렉토리가 존재하지 않습니다: {CHROMA_DB_DIR}. API가 정상 작동하지 않을 수 있습니다. Dockerfile에 chroma_db_data 폴더 복사 구문이 있고, 해당 폴더에 DB 파일이 있는지 확인하세요.") | |
EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL", "intfloat/e5-large-v2") | |
TRANSLATION_MODEL_NAME = os.getenv("TRANSLATION_MODEL", "Helsinki-NLP/opus-mt-ko-en") | |
LLM_MODEL_NAME = os.getenv("LLM_MODEL", "sdgsjlfnjkl/kanana-2.1b-full-v12") # 사용자의 파인튜닝 모델 | |
SBERT_MODEL_NAME = os.getenv("SBERT_MODEL", "snunlp/KR-SBERT-V40K-klueNLI-augSTS") | |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", 0.80)) | |
RETRIEVER_K = int(os.getenv("RETRIEVER_K", 5)) | |
print(f"⚙️ 사용 디바이스: {DEVICE}") | |
print(f"📚 임베딩 모델 (Chroma용): {EMBEDDING_MODEL_NAME}") | |
print(f"✈️ 번역 모델: {TRANSLATION_MODEL_NAME}") | |
print(f"🧠 LLM 모델: {LLM_MODEL_NAME}") | |
print(f"🇰🇷 키워드 분석 모델 (SBERT): {SBERT_MODEL_NAME}") | |
print(f"🎯 유사도 임계값 (Chroma): {SIMILARITY_THRESHOLD}") | |
print(f"🔍 초기 검색 문서 수(k, Chroma): {RETRIEVER_K}") | |
# --- Global Variables for Models and Utilities (Load once at startup) --- | |
okt = None | |
wiki = None | |
STOPWORDS = {"하다", "되다", "있다", "없다"} | |
embedding_model = None | |
trans_tokenizer = None | |
trans_model = None | |
llm_tokenizer = None | |
llm_model_instance = None # llm_model 변수명 충돌 피하기 위해 변경 | |
sbert_model_instance = None # sbert_model 변수명 충돌 피하기 위해 변경 | |
db = None | |
retriever = None | |
memory = ConversationBufferMemory(memory_key="history", return_messages=True) | |
already_searched_wiki = set() # 전역 위키 검색 기록 | |
inferencer = None | |
llm_chain = None | |
# --- Model Loading and Setup Function --- | |
def load_models_and_setup(): | |
global okt, wiki, embedding_model, trans_tokenizer, trans_model, \ | |
llm_tokenizer, llm_model_instance, sbert_model_instance, db, retriever, \ | |
inferencer, llm_chain | |
print("\n⏳ 위키피디아 및 키워드 분석기 설정 중...") | |
try: | |
okt = Okt() | |
wiki = wikipediaapi.Wikipedia(user_agent='midjourney_prompt_generator_api/1.0', language='ko') | |
print("✅ Okt, Wikipedia API 설정 완료.") | |
except Exception as e: | |
print(f"🚨 Okt 또는 Wikipedia 설정 실패: {e}") | |
okt = None | |
wiki = None # Ensure it's None if setup fails | |
print("\n⏳ 모델 로딩 시작...") | |
start_load_time = time.time() | |
try: | |
embedding_model = HuggingFaceEmbeddings( | |
model_name=EMBEDDING_MODEL_NAME, | |
model_kwargs={'device': DEVICE} | |
# cache_folder="/app/cache/huggingface_embeddings" # <--- 이 줄은 일단 삭제 또는 주석 처리 | |
) | |
print(f"✅ 임베딩 모델 로드 완료 ({EMBEDDING_MODEL_NAME})") | |
except Exception as e: | |
print(f"🚨🚨🚨 임베딩 모델 로딩 실패 (치명적): {e}"); raise | |
try: | |
trans_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL_NAME) | |
trans_model = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL_NAME).to(DEVICE) | |
trans_model.eval() | |
print(f"✅ 번역 모델 로드 완료 ({TRANSLATION_MODEL_NAME})") | |
except Exception as e: | |
print(f"🚨🚨🚨 번역 모델 로딩 실패 (치명적): {e}"); raise | |
try: | |
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME) | |
# device_map="auto" 사용 시 accelerate 라이브러리 필요 | |
llm_model_instance = AutoModelForCausalLM.from_pretrained(LLM_MODEL_NAME, device_map="auto", torch_dtype=torch.float16) | |
if llm_tokenizer.pad_token is None: | |
llm_tokenizer.pad_token = llm_tokenizer.eos_token | |
print(" ⚠️ LLM 토크나이저의 pad_token이 없어 eos_token으로 설정합니다.") | |
llm_model_instance.eval() | |
print(f"✅ LLM 로드 완료 ({LLM_MODEL_NAME})") | |
if not callable(llm_tokenizer): | |
print(f"🚨 에러: LLM 토크나이저 로딩 실패 또는 callable 객체가 아닙니다. 타입: {type(llm_tokenizer)}"); raise TypeError("LLM tokenizer not callable") | |
except Exception as e: | |
print(f"🚨🚨🚨 LLM 로딩 실패 (치명적): {e}"); raise | |
try: | |
sbert_model_instance = SentenceTransformer(SBERT_MODEL_NAME, device=DEVICE) | |
print(f"✅ SBERT 모델 로드 완료 ({SBERT_MODEL_NAME})") | |
except Exception as e: | |
print(f"🚨 SBERT 모델 로딩 실패 (일부 기능 제한될 수 있음): {e}") | |
sbert_model_instance = None # Ensure it's None if setup fails | |
print(f"⏱️ 전체 모델 로딩 시간: {time.time() - start_load_time:.2f}초") | |
print("\n⏳ Chroma DB 연결 중...") | |
if not CHROMA_DB_DIR.exists(): | |
print(f"🚨🚨🚨 Chroma DB 디렉토리 ({CHROMA_DB_DIR})를 찾을 수 없어 DB 연결을 건너<0xEB><0x81><0xB0>니다. API가 정상 작동하지 않습니다.") | |
db = None | |
retriever = None | |
else: | |
try: | |
db = Chroma(embedding_function=embedding_model, persist_directory=str(CHROMA_DB_DIR), collection_name="midjourney-prompts") | |
retriever = db.as_retriever(search_kwargs={"k": RETRIEVER_K}) | |
print(f"✅ Chroma DB 연결 완료 (Collection: {db._collection.name if db and hasattr(db, '_collection') else 'N/A'}, Retriever k={RETRIEVER_K})") | |
if db: | |
try: | |
sample_docs = db.get(limit=1) | |
if not sample_docs or not sample_docs.get('ids'): print("⚠️ 경고: Chroma DB 컬렉션이 비어있거나 접근할 수 없습니다.") | |
else: print(f" ℹ️ DB 샘플 ID 확인: {sample_docs.get('ids')}") | |
except Exception as db_check_e: print(f"⚠️ Chroma DB 샘플 확인 중 오류: {db_check_e}") | |
except Exception as e: | |
print(f"🚨🚨🚨 Chroma DB 연결 실패 (치명적일 수 있음): {e}") | |
db = None # Ensure it's None | |
retriever = None # Ensure it's None | |
# Langchain Components (Models must be loaded first) | |
if llm_model_instance and callable(llm_tokenizer): | |
print("\n⛓️ Langchain Chain 구성 중...") | |
inferencer = LoRAInferencer(llm_model_instance, llm_tokenizer) # 클래스명 변경 없음 | |
llm_chain = ( | |
{ | |
"input": RunnablePassthrough(), | |
"history": lambda x: format_memory_string(memory.chat_memory.messages), | |
"already_searched_wiki_ref": lambda x: already_searched_wiki # 전역 set 전달 | |
} | |
| RunnablePassthrough.assign( | |
retrieved_wiki_context=lambda x: retrieve_wikipedia_context(x["input"], x["already_searched_wiki_ref"]) | |
) | |
| RunnablePassthrough.assign( | |
modified_korean_request=lambda x: x["input"] + " 미드저니 프롬프트 작성해줘", | |
translated_request=lambda x: translate_ko_to_en(x["input"]), | |
) | |
| RunnablePassthrough.assign( | |
retrieved_chroma_context=lambda x: retrieve_english_context(x["translated_request"]) | |
) | |
| (lambda x: print(f""" | |
DEBUG (API): Prompt Inputs Ready: | |
- Original Korean Input: {x.get('input', '')[:50]}... | |
- Modified Korean Request: {x.get('modified_korean_request', '')[:50]}... | |
- Translated Request (for Chroma): {x.get('translated_request', '')[:50]}... | |
- Retrieved Chroma Context: {x.get('retrieved_chroma_context', '')[:100]}... | |
- Retrieved Wiki Context: {x.get('retrieved_wiki_context', '')[:100]}... | |
- History (Formatted): {x.get('history', '')[-200:]}... | |
""") or x) # 디버깅 후 x 반환 확인 | |
| prompt # 전역 프롬프트 템플릿 사용 | |
| (lambda p: p.to_string()) | |
| inferencer | |
) | |
print("✅ Langchain Chain 구성 완료.") | |
else: | |
print("🚨🚨🚨 LLM 모델 또는 토크나이저가 제대로 로드되지 않아 Langchain Chain을 구성할 수 없습니다.") | |
llm_chain = None | |
# --- Helper Functions (from user script, adapted slightly) --- | |
def translate_ko_to_en(text: str) -> str: | |
if not text or not trans_tokenizer or not trans_model: return "" | |
try: | |
with torch.no_grad(): | |
inputs = trans_tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(DEVICE) | |
outputs = trans_model.generate(**inputs, max_length=512, num_beams=4, early_stopping=True) | |
translated_text = trans_tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return translated_text | |
except Exception as e: print(f"🚨 번역 중 오류 발생 (입력: {text[:50]}...): {e}"); return f"[Translation Error: {e}]" | |
def retrieve_english_context(translated_english_text: str) -> str: | |
if not translated_english_text or "[Translation Error:" in translated_english_text or not retriever or not embedding_model: | |
print(" ⚠️ 번역 실패, 텍스트 부재, 리트리버/임베딩 모델 미로드로 Chroma 검색을 건너<0xEB><0x81><0xB0>니다.") | |
return "" | |
start_retrieval = time.time() | |
try: | |
initial_docs: list[Document] = retriever.invoke(translated_english_text) # retriever는 전역 사용 | |
if not initial_docs: print(" ⚠️ Chroma DB에서 초기 문서를 찾지 못했습니다."); return "" | |
query_embedding_list = embedding_model.embed_query(translated_english_text) # embedding_model은 전역 사용 | |
query_embedding = torch.tensor(query_embedding_list, dtype=torch.float).to(DEVICE).unsqueeze(0) | |
doc_contents = [doc.page_content for doc in initial_docs if doc.page_content] | |
if not doc_contents: print(" ⚠️ 검색된 문서에 유효한 내용이 없습니다."); return "" | |
doc_embeddings_list = embedding_model.embed_documents(doc_contents) # embedding_model은 전역 사용 | |
doc_embeddings = torch.tensor(doc_embeddings_list, dtype=torch.float).to(DEVICE) | |
similarities = F.cosine_similarity(query_embedding, doc_embeddings, dim=1) | |
similarities_list = similarities.cpu().tolist() | |
doc_similarity_pairs = list(zip(initial_docs, similarities_list)) | |
filtered_docs = [(doc, sim) for doc, sim in doc_similarity_pairs if sim >= SIMILARITY_THRESHOLD] | |
if not filtered_docs: print(f" ❌ 유사도 {SIMILARITY_THRESHOLD} 이상인 Chroma 문서를 찾지 못했습니다."); return "" | |
filtered_docs.sort(key=lambda item: item[1], reverse=True) | |
best_doc, best_sim = filtered_docs[0] | |
print(f" ✅ 가장 유사한 Chroma 문서 선택 (유사도: {best_sim:.4f})") | |
return best_doc.page_content.strip() | |
except Exception as e: | |
print(f"🚨 Chroma 컨텍스트 검색/유사도 계산 중 오류 발생: {e}") | |
traceback.print_exc() | |
return "[Context Retrieval/Similarity Error]" | |
def extract_keywords(text: str) -> list[str]: | |
if not okt: print(" ⚠️ Okt 형태소 분석기가 로드되지 않아 키워드 추출을 건너<0xEB><0x81><0xB0>니다."); return [] | |
try: | |
nouns = okt.nouns(text) | |
verbs_adjectives = [w for w, pos in okt.pos(text, stem=True) if pos in ['Adjective', 'Verb']] # 어간 추출 추가 | |
keywords = [w for w in nouns + verbs_adjectives if w not in STOPWORDS and len(w) > 1] | |
return list(set(keywords)) | |
except Exception as e: print(f"🚨 키워드 추출 중 오류: {e}"); return [] | |
def sort_by_semantic_importance(text: str, keywords: list[str]) -> list[str]: | |
if not sbert_model_instance or not keywords: return keywords # sbert_model_instance 사용 | |
try: | |
text_emb = sbert_model_instance.encode(text, convert_to_tensor=True) | |
keyword_embs = sbert_model_instance.encode(keywords, convert_to_tensor=True) | |
scores = util.cos_sim(text_emb, keyword_embs)[0] | |
sorted_keywords = [kw for kw, _ in sorted(zip(keywords, scores.cpu().tolist()), key=lambda x: -x[1])] | |
return sorted_keywords | |
except Exception as e: print(f"🚨 키워드 중요도 정렬 중 오류: {e}"); return keywords | |
def get_wiki_content(word: str, max_length: int = 200) -> str: | |
if not wiki: print(" ⚠️ Wikipedia API가 설정되지 않아 검색을 건너<0xEB><0x81><0xB0>니다."); return "" | |
try: | |
page = wiki.page(word) | |
if page.exists(): | |
summary = page.summary[:max_length].replace("\n", " ") | |
return summary | |
return "" | |
except Exception as e: print(f"🚨 위키피디아 검색 중 오류 ('{word}'): {e}"); return "" | |
def retrieve_wikipedia_context(korean_input: str, current_already_searched_wiki: set) -> str: # 인자로 already_searched_wiki 받음 | |
if not korean_input or not okt or not wiki or not sbert_model_instance: | |
print(" ⚠️ Wikipedia 컨텍스트 생성에 필요한 요소 부족 (okt, wiki, sbert 중 하나 이상)") | |
return "" | |
start_wiki_retrieval = time.time() | |
keywords = extract_keywords(korean_input) | |
if not keywords: print(" ⚠️ 키워드를 추출하지 못했습니다."); return "" | |
sorted_keywords = sort_by_semantic_importance(korean_input, keywords) | |
keywords_to_search = [kw for kw in sorted_keywords if kw not in current_already_searched_wiki][:3] # 전달받은 set 사용 | |
if not keywords_to_search: print(" ℹ️ 이미 검색했거나 검색할 새로운 위키 키워드가 없습니다."); return "" | |
wiki_context_parts = [] | |
try: | |
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: | |
future_to_keyword = {executor.submit(get_wiki_content, kw): kw for kw in keywords_to_search} | |
for future in concurrent.futures.as_completed(future_to_keyword): | |
keyword = future_to_keyword[future] | |
try: | |
result = future.result() | |
if result: | |
wiki_context_parts.append(f"- {keyword}: {result}") | |
current_already_searched_wiki.add(keyword) # 전달받은 set 업데이트 (전역 set이 업데이트됨) | |
except Exception as exc: print(f' 🚨 위키 검색 중 예외 발생 ({keyword}): {exc}') | |
wiki_context = "\n".join(wiki_context_parts) | |
total_wiki_time = time.time() - start_wiki_retrieval | |
if wiki_context: print(f" ✅ Wikipedia 컨텍스트 생성 완료 ({total_wiki_time:.2f}초)"); return wiki_context.strip() | |
else: print(f" ❌ 관련된 Wikipedia 정보를 찾지 못했습니다. ({total_wiki_time:.2f}초)"); return "" | |
except Exception as e: | |
print(f"🚨 Wikipedia 컨텍스트 생성 중 오류 발생: {e}") | |
traceback.print_exc() | |
return "[Wikipedia Context Retrieval Error]" | |
# --- Langchain Components ( 정의는 load_models_and_setup 안으로 이동 ) --- | |
class LoRAInferencer(Runnable): # 스크립트에 있는 클래스명 사용 | |
def __init__(self, model, tokenizer): | |
self.model = model | |
if not callable(tokenizer): raise TypeError(f"LoRAInferencer 초기화 실패: 전달된 tokenizer가 callable이 아닙니다.") | |
self.tokenizer = tokenizer | |
def invoke(self, input: str, config=None): # input 타입 명시 (str) | |
if not callable(self.tokenizer): raise TypeError(f"LoRAInferencer invoke 실패: self.tokenizer가 callable이 아닙니다.") | |
prompt_text = input # input은 이미 prompt.to_string()의 결과인 문자열 | |
try: | |
inputs = self.tokenizer(prompt_text, return_tensors="pt", padding=True, truncation=True, max_length=1536).to(self.model.device) | |
if 'token_type_ids' in inputs: del inputs['token_type_ids'] # Gemma 같은 모델은 token_type_ids 불필요 | |
except Exception as e: print(f"🚨 토크나이징 중 에러 발생: {e}"); raise | |
try: | |
with torch.no_grad(): | |
outputs = self.model.generate( | |
**inputs, do_sample=True, temperature=0.7, top_p=0.9, repetition_penalty=1.2, | |
no_repeat_ngram_size=3, max_new_tokens=300, pad_token_id=self.tokenizer.pad_token_id, | |
eos_token_id=self.tokenizer.eos_token_id, num_return_sequences=1 | |
) | |
# 입력 프롬프트를 제외한 생성된 부분만 디코딩 | |
output_token_ids = outputs[0][inputs['input_ids'].shape[1]:] | |
decoded_output = self.tokenizer.decode(output_token_ids, skip_special_tokens=True).strip() | |
# 후처리 (스크립트 내용 반영) | |
if decoded_output.startswith("[Midjourney Prompt (English)]"): | |
decoded_output = decoded_output.replace("[Midjourney Prompt (English)]","").strip() | |
if not decoded_output: | |
print("⚠️ 생성된 결과가 비어있습니다. 입력 프롬프트의 마지막 부분을 확인합니다.") | |
answer_marker = "[Midjourney Prompt (English)]\n" | |
if answer_marker in prompt_text: # prompt_text는 LLM에 들어간 전체 프롬프트 | |
potential_output = prompt_text.split(answer_marker)[-1].strip() | |
if potential_output and potential_output != input: # input은 chain의 최종 입력이 아닌 초기 입력 | |
decoded_output = potential_output | |
print(" -> 입력 프롬프트 마지막 부분을 결과로 사용.") | |
else: | |
decoded_output = "..." | |
print(" -> 기본값 '...' 사용 (potential_output이 비었거나 초기 입력과 동일).") | |
else: | |
decoded_output = "..." | |
print(" -> 기본값 '...' 사용 (answer_marker 없음).") | |
return decoded_output | |
except Exception as e: | |
print(f"🚨 모델 생성 중 에러 발생: {e}") | |
traceback.print_exc() | |
# 오류 시 입력 프롬프트에서 답변 부분 추출 시도 | |
try: return prompt_text.split("[Midjourney Prompt (English)]\n")[-1].strip() | |
except: raise # 이것도 실패하면 원래 예외 발생 | |
prompt = PromptTemplate.from_template('''You are a professional prompt engineer creating evolving prompts for Midjourney. | |
Continuously build upon the previous conversation history. | |
Generate a concise and effective Midjourney prompt IN ENGLISH ONLY based on the entire conversation flow. | |
[Reference English Examples from Database (Chroma)] | |
{retrieved_chroma_context} | |
[Reference Context from Wikipedia (Korean)] | |
{retrieved_wiki_context} | |
[Conversation Context So Far] | |
{history} | |
[Now Continue the Prompt for Midjourney based on the above conversation and the following request:] | |
User Request (Korean): {modified_korean_request} | |
[Midjourney Prompt (English)] | |
''') | |
def format_memory_string(messages): | |
# print(f"--- DEBUG: format_memory_string called with {len(messages)} messages ---") # API 로그에 너무 많을 수 있음 | |
if not messages: return "No prior conversation." | |
lines = [] | |
for msg in messages: | |
role = "User" if msg.type == "human" else "Assistant" | |
lines.append(f"{role}: {msg.content}") | |
formatted_history = "\n".join(lines) | |
# print(f" 🔄 Formatted History (for Prompt):\n{formatted_history[-300:]}...") # API 로그에 너무 많을 수 있음 | |
return formatted_history | |
# --- Application Startup Event --- | |
async def startup_event(): | |
print("🌟 FastAPI 애플리케이션 시작... 모델 및 설정 로딩...") | |
try: | |
load_models_and_setup() | |
print("✅ 모든 모델 및 설정 로딩 완료.") | |
except Exception as e: | |
print(f"🚨🚨🚨 애플리케이션 시작 중 치명적 오류 발생: {e}") | |
traceback.print_exc() | |
# 여기서 애플리케이션을 종료하거나, 오류 상태로 실행되도록 할 수 있습니다. | |
# Hugging Face Space에서는 오류가 나도 계속 실행될 수 있도록 두는 것이 일반적입니다 (로그 확인 가능). | |
# raise HTTPException(status_code=500, detail=f"서버 초기화 실패: {e}") # 이렇게 하면 서버가 시작 안 될 수 있음 | |
# --- API Endpoints --- | |
async def read_root(): | |
return {"message": "Midjourney 프롬프트 생성기 API가 실행 중입니다. /docs 에서 API 문서를 확인하세요."} | |
async def health_check(): | |
# 모델 로드 상태 등을 더 상세히 체크할 수 있습니다. | |
if llm_chain and retriever and okt and wiki and sbert_model_instance and trans_model and embedding_model: | |
return {"status": "healthy", "message": "모든 주요 구성 요소가 로드되었습니다."} | |
else: | |
missing = [] | |
if not llm_chain: missing.append("LLM Chain") | |
if not retriever: missing.append("Chroma Retriever") | |
if not okt: missing.append("Okt") | |
if not wiki: missing.append("Wikipedia") | |
if not sbert_model_instance: missing.append("SBERT model") | |
if not trans_model: missing.append("Translation model") | |
if not embedding_model: missing.append("Embedding model") | |
return {"status": "degraded", "message": f"일부 구성 요소 로드 실패: {', '.join(missing)}"} | |
async def generate_api_prompt(request: GenerateRequest): | |
global memory, already_searched_wiki # 전역 변수 사용 명시 | |
if not llm_chain: | |
print("🚨 /generate 호출: LLM 체인이 로드되지 않았습니다.") | |
raise HTTPException(status_code=503, detail="서버가 아직 준비되지 않았거나 초기화에 실패했습니다. LLM 체인을 사용할 수 없습니다.") | |
user_input = request.user_input | |
if not user_input: | |
raise HTTPException(status_code=400, detail="user_input 필드는 비워둘 수 없습니다.") | |
print(f"💬 API /generate 호출 (입력: '{user_input[:50]}...')") | |
start_time = time.time() | |
try: | |
# --- DEBUG PRINT (API) --- | |
print(f"--- DEBUG (API): Memory before llm_chain.invoke: {len(memory.chat_memory.messages)} messages ---") | |
# if memory.chat_memory.messages: | |
# for i, msg in enumerate(memory.chat_memory.messages): | |
# print(f" DEBUG (API) MSG {i}: type={msg.type}, content='{str(msg.content)[:50]}...'") | |
# LangChain 실행 | |
result = llm_chain.invoke(user_input) # llm_chain은 전역 변수 사용 | |
end_time = time.time() | |
processing_time = end_time - start_time | |
print(f"⏱️ API 생성 및 검색 시간: {processing_time:.2f}초") | |
history_updated = False | |
if isinstance(result, str) and not result.startswith("[오류 발생]"): | |
try: | |
memory.save_context({"input": user_input}, {"output": result}) # memory는 전역 변수 사용 | |
history_updated = True | |
print(" ✅ API 대화 기록에 저장되었습니다.") | |
# --- DEBUG PRINT (API) --- | |
# print(f"--- DEBUG (API): Memory after save_context: {len(memory.chat_memory.messages)} messages ---") | |
# if memory.chat_memory.messages: | |
# for i, msg in enumerate(memory.chat_memory.messages): | |
# print(f" DEBUG (API) SAVED MSG {i}: type={msg.type}, content='{str(msg.content)[:50]}...'") | |
except Exception as mem_err: | |
print(f" 🚨 API 대화 기록 저장 중 오류 발생: {mem_err}") | |
else: | |
print(" ⚠️ API 생성 결과가 오류 문자열이거나 유효하지 않아 기록되지 않았습니다.") | |
return GenerateResponse( | |
generated_prompt=result, | |
processing_time_seconds=round(processing_time, 2), | |
# history_updated=history_updated # 필요하다면 응답에 포함 | |
) | |
except Exception as e: | |
print(f"🚨 API /generate 실행 중 심각한 오류 발생: {e}") | |
traceback.print_exc() | |
end_time = time.time() | |
processing_time = end_time - start_time | |
raise HTTPException( | |
status_code=500, | |
detail=f"요청 처리 중 서버 오류 발생: {e}" | |
) | |
# @app.post("/reset_memory", response_model=ResetMemoryResponse, summary="대화 기록 초기화", description="서버의 대화 기록과 위키 검색 기록을 초기화합니다.") | |
# async def reset_api_memory(): | |
# global memory, already_searched_wiki # 전역 변수 사용 명시 | |
# memory_cleared_count = len(memory.chat_memory.messages) | |
# wiki_cleared_count = len(already_searched_wiki) | |
memory.clear() | |
already_searched_wiki.clear() | |
print(f"🔄 API /reset_memory 호출: 대화 기록 ({memory_cleared_count}개) 및 위키 검색 기록 ({wiki_cleared_count}개) 초기화 완료.") | |
return ResetMemoryResponse( | |
message="대화 기록 및 위키 검색 기록이 성공적으로 초기화되었습니다.", | |
cleared_memory=True, | |
cleared_wiki_history=True | |
) | |
print(f"----- FastAPI instance 'app' IS DEFINED in app.app.py, type: {type(app)}, dir: {dir()}") | |
# 로컬 테스트용 uvicorn 실행 명령어 (실제 Space에서는 Dockerfile의 CMD 사용) | |
# if __name__ == "__main__": | |
# import uvicorn | |
# # load_models_and_setup() # uvicorn이 시작 시 @app.on_event("startup") 호출 | |
# uvicorn.run(app, host="0.0.0.0", port=8000) |