Spaces:
Sleeping
Sleeping
""" | |
FastAPI ์ ํ๋ฆฌ์ผ์ด์ ๋ฉ์ธ ๋ชจ๋ | |
""" | |
import os | |
import sys | |
import logging | |
import tempfile | |
from fastapi import FastAPI, Request, HTTPException, Query, Body | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse | |
from typing import List, Dict, Any, Optional, Union | |
import json | |
import base64 | |
from io import BytesIO | |
from PIL import Image | |
# ์บ์ ๋๋ ํ ๋ฆฌ ์ค์ | |
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache' | |
os.environ['HF_HOME'] = '/tmp/huggingface_cache' | |
# ๋๋ ํ ๋ฆฌ ์์ฑ | |
os.makedirs('/tmp/huggingface_cache', exist_ok=True) | |
os.makedirs('/tmp/uploads', exist_ok=True) | |
# ๋ก๊น ์ค์ | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
logger = logging.getLogger(__name__) | |
# ํ์ํ ๋ชจ๋ ์ํฌํธ | |
from models.clip_model import KoreanCLIPModel | |
from utils.similarity import calculate_similarity, find_similar_items | |
from api.routes.matching_routers import LostItemPost, ImageMatchingRequest, MatchingResult, MatchingResponse | |
# ๋ชจ๋ธ ์ด๊ธฐํ (์ฑ๊ธํค์ผ๋ก ๋ก๋) | |
clip_model = None | |
def get_clip_model(): | |
""" | |
ํ๊ตญ์ด CLIP ๋ชจ๋ธ ์ธ์คํด์ค๋ฅผ ๋ฐํ (์ฑ๊ธํค ํจํด) | |
""" | |
global clip_model | |
if clip_model is None: | |
try: | |
clip_model = KoreanCLIPModel() | |
return clip_model | |
except Exception as e: | |
logger.error(f"CLIP ๋ชจ๋ธ ์ด๊ธฐํ ์คํจ: {str(e)}") | |
# ์คํจ ์ None ๋ฐํ (ํ ์คํธ ๊ธฐ๋ฐ ๋งค์นญ๋ง ๊ฐ๋ฅ) | |
return None | |
return clip_model | |
# FastAPI ์ ํ๋ฆฌ์ผ์ด์ ์์ฑ | |
app = FastAPI( | |
title="์ต๋๋ฌผ ์ ์ฌ๋ ๊ฒ์ API", | |
description="ํ๊ตญ์ด CLIP ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ฌ์ฉ์ ๊ฒ์๊ธ๊ณผ ์ต๋๋ฌผ ๊ฐ์ ์ ์ฌ๋๋ฅผ ๊ณ์ฐํ๋ API", | |
version="1.0.0" | |
) | |
# CORS ๋ฏธ๋ค์จ์ด ์ค์ | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
# ์ ์ญ ์์ธ ์ฒ๋ฆฌ | |
async def global_exception_handler(request: Request, exc: Exception): | |
""" | |
์ ์ญ ์์ธ ์ฒ๋ฆฌ๊ธฐ | |
""" | |
logger.error(f"์์ฒญ ์ฒ๋ฆฌ ์ค ์์ธ ๋ฐ์: {str(exc)}") | |
return JSONResponse( | |
status_code=500, | |
content={"success": False, "message": f"์๋ฒ ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(exc)}"} | |
) | |
# API ์๋ํฌ์ธํธ ์ ์ | |
async def find_similar_items_api( | |
request: Union[LostItemPost, ImageMatchingRequest], | |
threshold: float = Query(0.7, description="์ ์ฌ๋ ์๊ณ๊ฐ (0.0 ~ 1.0)"), | |
limit: int = Query(10, description="๋ฐํํ ์ต๋ ํญ๋ชฉ ์") | |
): | |
""" | |
์ฌ์ฉ์ ๊ฒ์๊ธ๊ณผ ์ ์ฌํ ์ต๋๋ฌผ์ ์ฐพ๋ API ์๋ํฌ์ธํธ | |
Args: | |
request: ์ฌ์ฉ์์ ๋ถ์ค๋ฌผ ๊ฒ์๊ธ ๋๋ ์ด๋ฏธ์ง ๋งค์นญ ์์ฒญ | |
threshold: ์ ์ฌ๋ ์๊ณ๊ฐ | |
limit: ๋ฐํํ ์ต๋ ํญ๋ชฉ ์ | |
Returns: | |
MatchingResponse: ๋งค์นญ ๊ฒฐ๊ณผ๊ฐ ํฌํจ๋ ์๋ต | |
""" | |
try: | |
logger.info(f"์ ์ฌ ์ต๋๋ฌผ ๊ฒ์ ์์ฒญ: threshold={threshold}, limit={limit}") | |
# ์์ฒญ ๋ฐ์ดํฐ ๋ณํ | |
user_post = {} | |
if isinstance(request, LostItemPost): | |
user_post = request.dict() | |
else: | |
user_post = request.dict() | |
# Base64 ์ด๋ฏธ์ง๊ฐ ์์ผ๋ฉด ์ด๋ฏธ์ง ์ฒ๋ฆฌ ๋ก์ง ์ถ๊ฐ | |
if user_post.get("image_base64"): | |
try: | |
# Base64 ์ด๋ฏธ์ง ๋์ฝ๋ฉ | |
base64_str = user_post["image_base64"] | |
# Base64 ๋ฌธ์์ด์์ ํค๋ ์ ๊ฑฐ (์์ ๊ฒฝ์ฐ) | |
if "," in base64_str: | |
base64_str = base64_str.split(",")[1] | |
image_data = base64.b64decode(base64_str) | |
image = Image.open(BytesIO(image_data)) | |
# ์ด๋ฏธ์ง ์ฌ์ฉ (CLIP ๋ชจ๋ธ์ ์ ๋ฌ) | |
user_post["image"] = image | |
logger.info("Base64 ์ด๋ฏธ์ง ์ฒ๋ฆฌ ์๋ฃ") | |
except Exception as e: | |
logger.error(f"Base64 ์ด๋ฏธ์ง ์ฒ๋ฆฌ ์คํจ: {str(e)}") | |
# ์ฌ๊ธฐ์ DB ๋์ ์์ฒญ์์ ์ ๋ฌ๋ ์ต๋๋ฌผ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํฉ๋๋ค. | |
lost_items = [] | |
# ์์ฒญ์ ์ต๋๋ฌผ ๋ฐ์ดํฐ๊ฐ ์์ผ๋ฉด ์ฌ์ฉ | |
if hasattr(request, 'lost_items') and request.lost_items: | |
lost_items = request.lost_items | |
if not lost_items: | |
return MatchingResponse( | |
success=False, | |
message="์ต๋๋ฌผ ๋ฐ์ดํฐ๊ฐ ์์ต๋๋ค. ์์ฒญ์ ์ต๋๋ฌผ ๋ฐ์ดํฐ๋ฅผ ํฌํจํด์ฃผ์ธ์.", | |
result=None | |
) | |
# CLIP ๋ชจ๋ธ ๋ก๋ | |
clip_model_instance = get_clip_model() | |
# ์ ์ฌํ ํญ๋ชฉ ์ฐพ๊ธฐ | |
similar_items = find_similar_items(user_post, lost_items, threshold, clip_model_instance) | |
# ๊ฒฐ๊ณผ ์ ํ | |
similar_items = similar_items[:limit] | |
# ์๋ต ๊ตฌ์ฑ | |
result = MatchingResult( | |
total_matches=len(similar_items), | |
similarity_threshold=threshold, | |
matches=[ | |
{ | |
"item": item["item"], | |
"similarity": round(item["similarity"], 4), | |
"details": { | |
"text_similarity": round(item["details"]["text_similarity"], 4), | |
"image_similarity": round(item["details"]["image_similarity"], 4) if item["details"]["image_similarity"] else None, | |
"category_similarity": round(item["details"]["details"]["category"], 4), | |
"item_name_similarity": round(item["details"]["details"]["item_name"], 4), | |
"color_similarity": round(item["details"]["details"]["color"], 4), | |
"content_similarity": round(item["details"]["details"]["content"], 4) | |
} | |
} | |
for item in similar_items | |
] | |
) | |
return MatchingResponse( | |
success=True, | |
message=f"{len(similar_items)}๊ฐ์ ์ ์ฌํ ์ต๋๋ฌผ์ ์ฐพ์์ต๋๋ค.", | |
result=result | |
) | |
except Exception as e: | |
logger.error(f"API ํธ์ถ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}") | |
raise HTTPException(status_code=500, detail=f"์์ฒญ ์ฒ๋ฆฌ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(e)}") | |
async def test_endpoint(): | |
""" | |
API ํ ์คํธ์ฉ ์๋ํฌ์ธํธ | |
Returns: | |
dict: ํ ์คํธ ์๋ต | |
""" | |
return {"message": "API๊ฐ ์ ์์ ์ผ๋ก ์๋ ์ค์ ๋๋ค."} | |
async def status(): | |
""" | |
API ์ํ ์๋ํฌ์ธํธ | |
Returns: | |
dict: API ์ํ ์ ๋ณด | |
""" | |
# CLIP ๋ชจ๋ธ ๋ก๋ ์๋ | |
model = get_clip_model() | |
return { | |
"status": "ok", | |
"models_loaded": model is not None, | |
"version": "1.0.0" | |
} | |
# ๋ฃจํธ ์๋ํฌ์ธํธ | |
async def root(): | |
""" | |
๋ฃจํธ ์๋ํฌ์ธํธ - API ์ ๋ณด ์ ๊ณต | |
""" | |
return { | |
"app_name": "์ต๋๋ฌผ ์ ์ฌ๋ ๊ฒ์ API", | |
"version": "1.0.0", | |
"description": "ํ๊ตญ์ด CLIP ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ฌ์ฉ์ ๊ฒ์๊ธ๊ณผ ์ต๋๋ฌผ ๊ฐ์ ์ ์ฌ๋๋ฅผ ๊ณ์ฐํฉ๋๋ค.", | |
"api_endpoint": "/api/matching/find-similar", | |
"test_endpoint": "/api/matching/test", | |
"status_endpoint": "/api/status" | |
} |