์†์„œํ˜„
feat: initial file
ed93606
raw
history blame
7.7 kB
"""
FastAPI ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜ ๋ฉ”์ธ ๋ชจ๋“ˆ
"""
import os
import sys
import logging
import tempfile
from fastapi import FastAPI, Request, HTTPException, Query, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from typing import List, Dict, Any, Optional, Union
import json
import base64
from io import BytesIO
from PIL import Image
# ์บ์‹œ ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ •
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
os.environ['HF_HOME'] = '/tmp/huggingface_cache'
# ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ
os.makedirs('/tmp/huggingface_cache', exist_ok=True)
os.makedirs('/tmp/uploads', exist_ok=True)
# ๋กœ๊น… ์„ค์ •
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# ํ•„์š”ํ•œ ๋ชจ๋“ˆ ์ž„ํฌํŠธ
from models.clip_model import KoreanCLIPModel
from utils.similarity import calculate_similarity, find_similar_items
from api.routes.matching_routers import LostItemPost, ImageMatchingRequest, MatchingResult, MatchingResponse
# ๋ชจ๋ธ ์ดˆ๊ธฐํ™” (์‹ฑ๊ธ€ํ†ค์œผ๋กœ ๋กœ๋“œ)
clip_model = None
def get_clip_model():
"""
ํ•œ๊ตญ์–ด CLIP ๋ชจ๋ธ ์ธ์Šคํ„ด์Šค๋ฅผ ๋ฐ˜ํ™˜ (์‹ฑ๊ธ€ํ†ค ํŒจํ„ด)
"""
global clip_model
if clip_model is None:
try:
clip_model = KoreanCLIPModel()
return clip_model
except Exception as e:
logger.error(f"CLIP ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์‹คํŒจ: {str(e)}")
# ์‹คํŒจ ์‹œ None ๋ฐ˜ํ™˜ (ํ…์ŠคํŠธ ๊ธฐ๋ฐ˜ ๋งค์นญ๋งŒ ๊ฐ€๋Šฅ)
return None
return clip_model
# FastAPI ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜ ์ƒ์„ฑ
app = FastAPI(
title="์Šต๋“๋ฌผ ์œ ์‚ฌ๋„ ๊ฒ€์ƒ‰ API",
description="ํ•œ๊ตญ์–ด CLIP ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ์‚ฌ์šฉ์ž ๊ฒŒ์‹œ๊ธ€๊ณผ ์Šต๋“๋ฌผ ๊ฐ„์˜ ์œ ์‚ฌ๋„๋ฅผ ๊ณ„์‚ฐํ•˜๋Š” API",
version="1.0.0"
)
# CORS ๋ฏธ๋“ค์›จ์–ด ์„ค์ •
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ์ „์—ญ ์˜ˆ์™ธ ์ฒ˜๋ฆฌ
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
"""
์ „์—ญ ์˜ˆ์™ธ ์ฒ˜๋ฆฌ๊ธฐ
"""
logger.error(f"์š”์ฒญ ์ฒ˜๋ฆฌ ์ค‘ ์˜ˆ์™ธ ๋ฐœ์ƒ: {str(exc)}")
return JSONResponse(
status_code=500,
content={"success": False, "message": f"์„œ๋ฒ„ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(exc)}"}
)
# API ์—”๋“œํฌ์ธํŠธ ์ •์˜
@app.post("/api/matching/find-similar", response_model=MatchingResponse)
async def find_similar_items_api(
request: Union[LostItemPost, ImageMatchingRequest],
threshold: float = Query(0.7, description="์œ ์‚ฌ๋„ ์ž„๊ณ„๊ฐ’ (0.0 ~ 1.0)"),
limit: int = Query(10, description="๋ฐ˜ํ™˜ํ•  ์ตœ๋Œ€ ํ•ญ๋ชฉ ์ˆ˜")
):
"""
์‚ฌ์šฉ์ž ๊ฒŒ์‹œ๊ธ€๊ณผ ์œ ์‚ฌํ•œ ์Šต๋“๋ฌผ์„ ์ฐพ๋Š” API ์—”๋“œํฌ์ธํŠธ
Args:
request: ์‚ฌ์šฉ์ž์˜ ๋ถ„์‹ค๋ฌผ ๊ฒŒ์‹œ๊ธ€ ๋˜๋Š” ์ด๋ฏธ์ง€ ๋งค์นญ ์š”์ฒญ
threshold: ์œ ์‚ฌ๋„ ์ž„๊ณ„๊ฐ’
limit: ๋ฐ˜ํ™˜ํ•  ์ตœ๋Œ€ ํ•ญ๋ชฉ ์ˆ˜
Returns:
MatchingResponse: ๋งค์นญ ๊ฒฐ๊ณผ๊ฐ€ ํฌํ•จ๋œ ์‘๋‹ต
"""
try:
logger.info(f"์œ ์‚ฌ ์Šต๋“๋ฌผ ๊ฒ€์ƒ‰ ์š”์ฒญ: threshold={threshold}, limit={limit}")
# ์š”์ฒญ ๋ฐ์ดํ„ฐ ๋ณ€ํ™˜
user_post = {}
if isinstance(request, LostItemPost):
user_post = request.dict()
else:
user_post = request.dict()
# Base64 ์ด๋ฏธ์ง€๊ฐ€ ์žˆ์œผ๋ฉด ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ๋กœ์ง ์ถ”๊ฐ€
if user_post.get("image_base64"):
try:
# Base64 ์ด๋ฏธ์ง€ ๋””์ฝ”๋”ฉ
base64_str = user_post["image_base64"]
# Base64 ๋ฌธ์ž์—ด์—์„œ ํ—ค๋” ์ œ๊ฑฐ (์žˆ์„ ๊ฒฝ์šฐ)
if "," in base64_str:
base64_str = base64_str.split(",")[1]
image_data = base64.b64decode(base64_str)
image = Image.open(BytesIO(image_data))
# ์ด๋ฏธ์ง€ ์‚ฌ์šฉ (CLIP ๋ชจ๋ธ์— ์ „๋‹ฌ)
user_post["image"] = image
logger.info("Base64 ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ์™„๋ฃŒ")
except Exception as e:
logger.error(f"Base64 ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ์‹คํŒจ: {str(e)}")
# ์—ฌ๊ธฐ์„œ DB ๋Œ€์‹  ์š”์ฒญ์—์„œ ์ „๋‹ฌ๋œ ์Šต๋“๋ฌผ ๋ฐ์ดํ„ฐ๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
lost_items = []
# ์š”์ฒญ์— ์Šต๋“๋ฌผ ๋ฐ์ดํ„ฐ๊ฐ€ ์žˆ์œผ๋ฉด ์‚ฌ์šฉ
if hasattr(request, 'lost_items') and request.lost_items:
lost_items = request.lost_items
if not lost_items:
return MatchingResponse(
success=False,
message="์Šต๋“๋ฌผ ๋ฐ์ดํ„ฐ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. ์š”์ฒญ์— ์Šต๋“๋ฌผ ๋ฐ์ดํ„ฐ๋ฅผ ํฌํ•จํ•ด์ฃผ์„ธ์š”.",
result=None
)
# CLIP ๋ชจ๋ธ ๋กœ๋“œ
clip_model_instance = get_clip_model()
# ์œ ์‚ฌํ•œ ํ•ญ๋ชฉ ์ฐพ๊ธฐ
similar_items = find_similar_items(user_post, lost_items, threshold, clip_model_instance)
# ๊ฒฐ๊ณผ ์ œํ•œ
similar_items = similar_items[:limit]
# ์‘๋‹ต ๊ตฌ์„ฑ
result = MatchingResult(
total_matches=len(similar_items),
similarity_threshold=threshold,
matches=[
{
"item": item["item"],
"similarity": round(item["similarity"], 4),
"details": {
"text_similarity": round(item["details"]["text_similarity"], 4),
"image_similarity": round(item["details"]["image_similarity"], 4) if item["details"]["image_similarity"] else None,
"category_similarity": round(item["details"]["details"]["category"], 4),
"item_name_similarity": round(item["details"]["details"]["item_name"], 4),
"color_similarity": round(item["details"]["details"]["color"], 4),
"content_similarity": round(item["details"]["details"]["content"], 4)
}
}
for item in similar_items
]
)
return MatchingResponse(
success=True,
message=f"{len(similar_items)}๊ฐœ์˜ ์œ ์‚ฌํ•œ ์Šต๋“๋ฌผ์„ ์ฐพ์•˜์Šต๋‹ˆ๋‹ค.",
result=result
)
except Exception as e:
logger.error(f"API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
raise HTTPException(status_code=500, detail=f"์š”์ฒญ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}")
@app.get("/api/matching/test")
async def test_endpoint():
"""
API ํ…Œ์ŠคํŠธ์šฉ ์—”๋“œํฌ์ธํŠธ
Returns:
dict: ํ…Œ์ŠคํŠธ ์‘๋‹ต
"""
return {"message": "API๊ฐ€ ์ •์ƒ์ ์œผ๋กœ ์ž‘๋™ ์ค‘์ž…๋‹ˆ๋‹ค."}
@app.get("/api/status")
async def status():
"""
API ์ƒํƒœ ์—”๋“œํฌ์ธํŠธ
Returns:
dict: API ์ƒํƒœ ์ •๋ณด
"""
# CLIP ๋ชจ๋ธ ๋กœ๋“œ ์‹œ๋„
model = get_clip_model()
return {
"status": "ok",
"models_loaded": model is not None,
"version": "1.0.0"
}
# ๋ฃจํŠธ ์—”๋“œํฌ์ธํŠธ
@app.get("/")
async def root():
"""
๋ฃจํŠธ ์—”๋“œํฌ์ธํŠธ - API ์ •๋ณด ์ œ๊ณต
"""
return {
"app_name": "์Šต๋“๋ฌผ ์œ ์‚ฌ๋„ ๊ฒ€์ƒ‰ API",
"version": "1.0.0",
"description": "ํ•œ๊ตญ์–ด CLIP ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ์‚ฌ์šฉ์ž ๊ฒŒ์‹œ๊ธ€๊ณผ ์Šต๋“๋ฌผ ๊ฐ„์˜ ์œ ์‚ฌ๋„๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.",
"api_endpoint": "/api/matching/find-similar",
"test_endpoint": "/api/matching/test",
"status_endpoint": "/api/status"
}