Spaces:
Sleeping
Sleeping
import os | |
from fastapi import FastAPI, File, UploadFile, HTTPException, Body | |
from fastapi.middleware.cors import CORSMiddleware | |
import tempfile | |
from models.models import LostItemAnalyzer | |
from typing import Dict, Any | |
import base64 | |
from PIL import Image | |
import io | |
# FastAPI μ ν리μΌμ΄μ μμ± | |
app = FastAPI(title="λΆμ€λ¬Ό μ΄λ―Έμ§ λΆμ API") | |
# CORS μ€μ | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], # νλ‘λμ μμλ νΉμ λλ©μΈμΌλ‘ μ ννμΈμ | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
# λΆμκΈ° μ΄κΈ°ν | |
analyzer = LostItemAnalyzer() | |
# API λ£¨νΈ κ²½λ‘ νΈλ€λ¬ | |
async def root(): | |
return {"message": "λΆμ€λ¬Ό μ΄λ―Έμ§ λΆμ APIκ° μ€ν μ€μ λλ€."} | |
# νμΌ μ λ‘λλ₯Ό ν΅ν μ΄λ―Έμ§ λΆμ | |
async def analyze_image_upload(file: UploadFile = File(...)): | |
# νμΌ νμ₯μ κ²μ¦ | |
valid_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif'] | |
file_ext = os.path.splitext(file.filename)[1].lower() | |
if file_ext not in valid_extensions: | |
raise HTTPException( | |
status_code=400, | |
detail=f"μ§μλμ§ μλ νμΌ νμμ λλ€. μ§μλλ νμ: {', '.join(valid_extensions)}" | |
) | |
try: | |
# μμ νμΌλ‘ μ μ₯ | |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp: | |
temp_path = temp.name | |
content = await file.read() | |
temp.write(content) | |
# μ΄λ―Έμ§ λΆμ | |
result = analyzer.analyze_lost_item(temp_path) | |
# μμ νμΌ μμ | |
os.unlink(temp_path) | |
if result["success"]: | |
# νκ΅μ΄ λ²μ κ²°κ³Όλ§ λ°ν | |
ko_result = { | |
"status": "success", | |
"data": { | |
"title": result["data"]["translated"]["title"], | |
"category": result["data"]["translated"]["category"], | |
"color": result["data"]["translated"]["color"], | |
"material": result["data"]["translated"]["material"], | |
"brand": result["data"]["translated"]["brand"], | |
"description": result["data"]["translated"]["description"], | |
"distinctive_features": result["data"]["translated"]["distinctive_features"] | |
} | |
} | |
return ko_result | |
else: | |
raise HTTPException(status_code=500, detail=result["error"]) | |
except Exception as e: | |
# μμΈ λ°μ μ μμ νμΌ μμ μλ | |
try: | |
if 'temp_path' in locals() and os.path.exists(temp_path): | |
os.unlink(temp_path) | |
except: | |
pass | |
raise HTTPException(status_code=500, detail=f"μ΄λ―Έμ§ λΆμ μ€ μ€λ₯ λ°μ: {str(e)}") | |
# Base64 μΈμ½λ©λ μ΄λ―Έμ§ λΆμ (Javaμμ μ¬μ©ν μλν¬μΈνΈ) | |
# λ©λͺ¨λ¦¬μμ μ§μ μ²λ¦¬ | |
async def analyze_image_base64(payload: Dict[str, Any] = Body(...)): | |
try: | |
if "image" not in payload: | |
raise HTTPException(status_code=400, detail="μμ²μ 'image' νλκ° νμν©λλ€") | |
base64_str = payload["image"] | |
# Base64 λ¬Έμμ΄μμ ν€λ μ κ±° (μμ κ²½μ°) | |
if "," in base64_str: | |
base64_str = base64_str.split(",")[1] | |
# Base64 λμ½λ© | |
image_data = base64.b64decode(base64_str) | |
image = Image.open(io.BytesIO(image_data)) | |
# λ©λͺ¨λ¦¬ λ΄ μ΄λ―Έμ§λ₯Ό λ°μ΄νΈ μ€νΈλ¦ΌμΌλ‘ μ μ₯ | |
img_byte_arr = io.BytesIO() | |
image.save(img_byte_arr, format='JPEG') | |
img_byte_arr = img_byte_arr.getvalue() | |
# μμ νμΌ κ²½λ‘ λμ μμ± | |
temp_path = f"/tmp/uploads/temp_{os.getpid()}_{id(image)}.jpg" | |
# μ΄λ―Έμ§ μ μ₯ | |
with open(temp_path, 'wb') as f: | |
f.write(img_byte_arr) | |
# μ΄λ―Έμ§ λΆμ | |
result = analyzer.analyze_lost_item(temp_path) | |
# μμ νμΌ μμ | |
os.unlink(temp_path) | |
if result["success"]: | |
# νκ΅μ΄ λ²μ κ²°κ³Όλ§ λ°ν | |
ko_result = { | |
"status": "success", | |
"data": { | |
"title": result["data"]["translated"]["title"], | |
"category": result["data"]["translated"]["category"], | |
"color": result["data"]["translated"]["color"], | |
"material": result["data"]["translated"]["material"], | |
"brand": result["data"]["translated"]["brand"], | |
"description": result["data"]["translated"]["description"], | |
"distinctive_features": result["data"]["translated"]["distinctive_features"] | |
} | |
} | |
return ko_result | |
else: | |
raise HTTPException(status_code=500, detail=result["error"]) | |
except HTTPException: | |
raise | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"μ΄λ―Έμ§ λΆμ μ€ μ€λ₯ λ°μ: {str(e)}") | |
# API μν, νκ²½λ³μ νμΈ | |
async def status(): | |
return { | |
"status": "ok", | |
"papago_api": "active" if analyzer.translator.use_papago else "inactive", | |
"models_loaded": True | |
} | |
# λ©μΈ μ€ν μ½λ (λ‘컬 ν μ€νΈμ©) | |
if __name__ == "__main__": | |
import uvicorn | |
uvicorn.run(app, host="0.0.0.0", port=5001) |