์์ํ committed · ed93606
1 Parent(s): c4d447f
feat: initial file

Files changed:
- Dockerfile +39 -0
- api/__init__.py +0 -0
- api/routes/__init__.py +6 -0
- api/routes/matching_routers.py +79 -0
- main.py +221 -0
- models/__init__.py +6 -0
- models/clip_model.py +178 -0
- requirements.txt +13 -0
- utils/__init__.py +6 -0
- utils/similarity.py +355 -0
Dockerfile
ADDED
@@ -0,0 +1,39 @@
FROM python:3.9

# Set cache directory environment variables
ENV TRANSFORMERS_CACHE=/tmp/huggingface_cache
ENV HF_HOME=/tmp/huggingface_cache
ENV PYTHONUNBUFFERED=1

WORKDIR /app

# Install system packages
RUN apt-get update && apt-get install -y \
    build-essential \
    libgl1-mesa-glx \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Create the cache directory and set permissions
RUN mkdir -p $TRANSFORMERS_CACHE && chmod -R 777 $TRANSFORMERS_CACHE

# Set permissions on the temporary upload directory
RUN mkdir -p /tmp/uploads && chmod 777 /tmp/uploads
ENV TMPDIR=/tmp/uploads

# Copy and install the requirements file
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application files
COPY . .

# Set environment variables
ENV PYTHONPATH=/app

# Download kiwipiepy initialization files - avoids download issues at runtime
RUN python -c "from kiwipiepy import Kiwi; Kiwi()"

# Run the application (port changed to 7861)
EXPOSE 7861
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7861"]
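For local debugging outside Docker, the same server can be started programmatically. This is a minimal sketch, not part of this commit; it assumes a hypothetical helper file saved next to main.py and mirrors the CMD line above.

# run_local.py - hypothetical helper, not part of this commit.
# Starts the same ASGI app locally without Docker, matching
# "uvicorn main:app --host 0.0.0.0 --port 7861" from the Dockerfile.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=7861)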
api/__init__.py
ADDED
File without changes
api/routes/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""
API route package
"""
from .matching_routers import router as matching_router

__all__ = ['matching_router']
api/routes/matching_routers.py
ADDED
@@ -0,0 +1,79 @@
"""
API routes for found-item matching
"""
import os
import sys
import logging
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, validator
import base64
from io import BytesIO
from PIL import Image

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Create the router
router = APIRouter(
    prefix="/api/matching",
    tags=["matching"],
    responses={404: {"description": "Not found"}},
)

# Pydantic model definitions
class LostItemPost(BaseModel):
    """Model for a user's lost-item post"""
    category: str = Field(..., description="Lost-item category (e.g. wallet, bag, electronics)")
    item_name: str = Field(..., description="Item name (e.g. black leather wallet)")
    color: Optional[str] = Field(None, description="Item color")
    content: str = Field(..., description="Post content")
    location: Optional[str] = Field(None, description="Place where the item was lost")
    image_url: Optional[str] = Field(None, description="Image URL (if available)")
    lost_items: Optional[List[Dict[str, Any]]] = Field(None, description="Found-item data to compare against (for API testing)")

    class Config:
        schema_extra = {
            "example": {
                "category": "지갑",  # wallet
                "item_name": "검정색 가죽 지갑",  # black leather wallet
                "color": "검은색",  # black
                "content": "지난주 토요일 강남역 근처에서 검은색 가죽 지갑을 잃어버렸습니다. 현금과 카드가 들어있어요.",  # lost a black leather wallet near Gangnam Station last Saturday; cash and cards inside
                "location": "강남역",  # Gangnam Station
                "image_url": None
            }
        }

class ImageMatchingRequest(BaseModel):
    """Request model for image-based matching"""
    category: Optional[str] = Field(None, description="Lost-item category")
    item_name: Optional[str] = Field(None, description="Item name")
    color: Optional[str] = Field(None, description="Color")
    content: Optional[str] = Field(None, description="Content")
    image_base64: Optional[str] = Field(None, description="Base64-encoded image")
    lost_items: Optional[List[Dict[str, Any]]] = Field(None, description="Found-item data to compare against (for API testing)")

    class Config:
        schema_extra = {
            "example": {
                "category": "지갑",  # wallet
                "item_name": "검정색 가죽 지갑",  # black leather wallet
                "color": "검은색",  # black
                "content": "지난주 토요일 강남역 근처에서 검은색 가죽 지갑을 잃어버렸습니다.",  # lost a black leather wallet near Gangnam Station last Saturday
                "image_base64": "[base64 encoded image string]"
            }
        }

class MatchingResult(BaseModel):
    """Matching result model"""
    total_matches: int = Field(..., description="Number of matched items")
    similarity_threshold: float = Field(..., description="Similarity threshold")
    matches: List[Dict[str, Any]] = Field(..., description="List of matched items")

class MatchingResponse(BaseModel):
    """API response model"""
    success: bool = Field(..., description="Whether the request succeeded")
    message: str = Field(..., description="Response message")
    result: Optional[MatchingResult] = Field(None, description="Matching result")
main.py
ADDED
@@ -0,0 +1,221 @@
"""
FastAPI application main module
"""
import os
import sys
import logging
import tempfile
from fastapi import FastAPI, Request, HTTPException, Query, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from typing import List, Dict, Any, Optional, Union
import json
import base64
from io import BytesIO
from PIL import Image

# Cache directory configuration
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
os.environ['HF_HOME'] = '/tmp/huggingface_cache'

# Create directories
os.makedirs('/tmp/huggingface_cache', exist_ok=True)
os.makedirs('/tmp/uploads', exist_ok=True)

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Import required modules
from models.clip_model import KoreanCLIPModel
from utils.similarity import calculate_similarity, find_similar_items
from api.routes.matching_routers import LostItemPost, ImageMatchingRequest, MatchingResult, MatchingResponse

# Model initialization (loaded as a singleton)
clip_model = None

def get_clip_model():
    """
    Return the Korean CLIP model instance (singleton pattern)
    """
    global clip_model
    if clip_model is None:
        try:
            clip_model = KoreanCLIPModel()
            return clip_model
        except Exception as e:
            logger.error(f"Failed to initialize CLIP model: {str(e)}")
            # Return None on failure (only text-based matching will be available)
            return None
    return clip_model

# Create the FastAPI application
app = FastAPI(
    title="Found-Item Similarity Search API",
    description="API that uses a Korean CLIP model to compute similarity between user posts and found items",
    version="1.0.0"
)

# CORS middleware configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global exception handling
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """
    Global exception handler
    """
    logger.error(f"Exception while handling request: {str(exc)}")
    return JSONResponse(
        status_code=500,
        content={"success": False, "message": f"A server error occurred: {str(exc)}"}
    )

# API endpoint definitions
@app.post("/api/matching/find-similar", response_model=MatchingResponse)
async def find_similar_items_api(
    request: Union[LostItemPost, ImageMatchingRequest],
    threshold: float = Query(0.7, description="Similarity threshold (0.0 ~ 1.0)"),
    limit: int = Query(10, description="Maximum number of items to return")
):
    """
    API endpoint that finds found items similar to a user's post

    Args:
        request: The user's lost-item post or image-matching request
        threshold: Similarity threshold
        limit: Maximum number of items to return

    Returns:
        MatchingResponse: Response containing the matching result
    """
    try:
        logger.info(f"Similar found-item search request: threshold={threshold}, limit={limit}")

        # Convert the request data
        user_post = {}
        if isinstance(request, LostItemPost):
            user_post = request.dict()
        else:
            user_post = request.dict()
            # If a Base64 image is present, add the image-processing logic
            if user_post.get("image_base64"):
                try:
                    # Decode the Base64 image
                    base64_str = user_post["image_base64"]

                    # Strip the header from the Base64 string (if present)
                    if "," in base64_str:
                        base64_str = base64_str.split(",")[1]

                    image_data = base64.b64decode(base64_str)
                    image = Image.open(BytesIO(image_data))

                    # Use the image (passed to the CLIP model)
                    user_post["image"] = image

                    logger.info("Base64 image processed")
                except Exception as e:
                    logger.error(f"Failed to process Base64 image: {str(e)}")

        # Found-item data passed in the request is used here instead of a DB.
        lost_items = []

        # Use found-item data from the request if present
        if hasattr(request, 'lost_items') and request.lost_items:
            lost_items = request.lost_items

        if not lost_items:
            return MatchingResponse(
                success=False,
                message="No found-item data. Please include found-item data in the request.",
                result=None
            )

        # Load the CLIP model
        clip_model_instance = get_clip_model()

        # Find similar items
        similar_items = find_similar_items(user_post, lost_items, threshold, clip_model_instance)

        # Limit the results
        similar_items = similar_items[:limit]

        # Build the response
        result = MatchingResult(
            total_matches=len(similar_items),
            similarity_threshold=threshold,
            matches=[
                {
                    "item": item["item"],
                    "similarity": round(item["similarity"], 4),
                    "details": {
                        "text_similarity": round(item["details"]["text_similarity"], 4),
                        "image_similarity": round(item["details"]["image_similarity"], 4) if item["details"]["image_similarity"] else None,
                        "category_similarity": round(item["details"]["details"]["category"], 4),
                        "item_name_similarity": round(item["details"]["details"]["item_name"], 4),
                        "color_similarity": round(item["details"]["details"]["color"], 4),
                        "content_similarity": round(item["details"]["details"]["content"], 4)
                    }
                }
                for item in similar_items
            ]
        )

        return MatchingResponse(
            success=True,
            message=f"Found {len(similar_items)} similar found items.",
            result=result
        )

    except Exception as e:
        logger.error(f"Error during API call: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred while processing the request: {str(e)}")

@app.get("/api/matching/test")
async def test_endpoint():
    """
    Test endpoint for the API

    Returns:
        dict: Test response
    """
    return {"message": "The API is running normally."}

@app.get("/api/status")
async def status():
    """
    API status endpoint

    Returns:
        dict: API status information
    """
    # Try to load the CLIP model
    model = get_clip_model()

    return {
        "status": "ok",
        "models_loaded": model is not None,
        "version": "1.0.0"
    }

# Root endpoint
@app.get("/")
async def root():
    """
    Root endpoint - provides API information
    """
    return {
        "app_name": "Found-Item Similarity Search API",
        "version": "1.0.0",
        "description": "Computes similarity between user posts and found items using a Korean CLIP model.",
        "api_endpoint": "/api/matching/find-similar",
        "test_endpoint": "/api/matching/test",
        "status_endpoint": "/api/status"
    }
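A hedged usage sketch (not part of the commit) of calling the /api/matching/find-similar endpoint with httpx, which is already listed in requirements.txt. It assumes the server is running locally on port 7861 as configured in the Dockerfile, and passes the found-item candidates inline because the endpoint has no database. All field values are illustrative.

# Sketch only: query the matching endpoint over HTTP.
import httpx

payload = {
    "category": "지갑",
    "item_name": "검정색 가죽 지갑",
    "content": "강남역 근처에서 검정색 가죽 지갑을 잃어버렸습니다.",
    "lost_items": [
        {"category": "지갑", "item_name": "검정 가죽 지갑", "color": "검은색",
         "content": "강남역에서 검은 지갑을 주웠습니다."}
    ],
}

resp = httpx.post(
    "http://localhost:7861/api/matching/find-similar",
    params={"threshold": 0.5, "limit": 5},  # query parameters defined on the endpoint
    json=payload,
    timeout=60.0,
)
print(resp.json())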
models/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""
Package for model-related modules
"""
from .clip_model import KoreanCLIPModel

__all__ = ['KoreanCLIPModel']
models/clip_model.py
ADDED
@@ -0,0 +1,178 @@
"""
Korean CLIP model implementation
This module uses a Korean CLIP model from HuggingFace to generate text and image embeddings
"""
import os
import sys
import logging
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
from io import BytesIO
import numpy as np

# Cache directory configuration
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
os.environ['HF_HOME'] = '/tmp/huggingface_cache'

# Create directories
os.makedirs('/tmp/huggingface_cache', exist_ok=True)
os.makedirs('/tmp/uploads', exist_ok=True)

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Model settings - taken from environment variables or defaults
CLIP_MODEL_NAME = os.getenv('CLIP_MODEL_NAME', 'Bingsu/clip-vit-large-patch14-ko')
DEVICE = "cuda" if torch.cuda.is_available() and os.getenv('USE_GPU', 'True').lower() == 'true' else "cpu"

class KoreanCLIPModel:
    """
    Korean CLIP model class
    Provides text and image embedding and similarity calculation
    """

    def __init__(self, model_name=CLIP_MODEL_NAME, device=DEVICE):
        """
        Initialize the CLIP model

        Args:
            model_name (str): Name or path of the CLIP model to use
            device (str): Device to use ('cuda' or 'cpu')
        """
        self.device = device
        self.model_name = model_name

        logger.info(f"Loading CLIP model '{model_name}' (device: {device})...")

        try:
            # Cache directory configuration
            os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
            os.makedirs("/tmp/transformers_cache", exist_ok=True)

            self.model = CLIPModel.from_pretrained(model_name).to(device)
            self.processor = CLIPProcessor.from_pretrained(model_name)
            logger.info("CLIP model loaded")
        except Exception as e:
            logger.error(f"Failed to load CLIP model: {str(e)}")
            raise

    def encode_text(self, text):
        """
        Convert text into an embedding vector

        Args:
            text (str or list): Text or list of texts to encode

        Returns:
            numpy.ndarray: Embedding vector
        """
        if isinstance(text, str):
            text = [text]

        try:
            with torch.no_grad():
                # Encode the text
                inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True).to(self.device)
                text_features = self.model.get_text_features(**inputs)

                # Normalize the text features
                text_embeddings = text_features / text_features.norm(dim=1, keepdim=True)

                return text_embeddings.cpu().numpy()
        except Exception as e:
            logger.error(f"Error while encoding text: {str(e)}")
            return np.zeros((len(text), self.model.text_embed_dim))

    def encode_image(self, image_source):
        """
        Convert an image into an embedding vector

        Args:
            image_source: Image to encode (PIL Image, URL, or image path)

        Returns:
            numpy.ndarray: Embedding vector
        """
        try:
            # Load the image (URL, file path, PIL image object, or Base64)
            if isinstance(image_source, str):
                if image_source.startswith('http'):
                    # Load the image from a URL
                    response = requests.get(image_source)
                    image = Image.open(BytesIO(response.content)).convert('RGB')
                else:
                    # Load the image from a local file
                    image = Image.open(image_source).convert('RGB')
            else:
                # Already a PIL image object
                image = image_source.convert('RGB')

            with torch.no_grad():
                # Encode the image
                inputs = self.processor(images=image, return_tensors="pt").to(self.device)
                image_features = self.model.get_image_features(**inputs)

                # Normalize the image features
                image_embeddings = image_features / image_features.norm(dim=1, keepdim=True)

                return image_embeddings.cpu().numpy()
        except Exception as e:
            logger.error(f"Error while encoding image: {str(e)}")
            return np.zeros((1, self.model.vision_embed_dim))

    def calculate_similarity(self, text_embedding, image_embedding=None):
        """
        Calculate the similarity between text and image embeddings

        Args:
            text_embedding (numpy.ndarray): Text embedding
            image_embedding (numpy.ndarray, optional): Image embedding (if absent, only text is compared)

        Returns:
            float: Similarity score (between 0 and 1)
        """
        if image_embedding is None:
            # Text-text similarity (cosine similarity)
            similarity = np.dot(text_embedding, text_embedding.T)[0, 0]
        else:
            # Text-image similarity (cosine similarity)
            similarity = np.dot(text_embedding, image_embedding.T)[0, 0]

        # Normalize the similarity to the 0~1 range
        similarity = (similarity + 1) / 2
        return float(similarity)

    def encode_batch_texts(self, texts):
        """
        Embed multiple texts at once

        Args:
            texts (list): List of texts

        Returns:
            numpy.ndarray: Array of embedding vectors
        """
        # Code for batch processing
        # A real implementation should adjust the batch size to the available memory
        return self.encode_text(texts)

# Module test code
if __name__ == "__main__":
    # Initialize the model
    clip_model = KoreanCLIPModel()

    # Encode a sample text
    sample_text = "검정색 지갑을 잃어버렸습니다. 현금과 카드가 들어있어요."  # "I lost a black wallet; it had cash and cards inside."
    text_embedding = clip_model.encode_text(sample_text)

    print(f"Text embedding shape: {text_embedding.shape}")

    # Calculate similarity (text only)
    sample_text2 = "검정색 지갑을 찾았습니다. 안에 현금과 카드가 있습니다."  # "I found a black wallet; there are cash and cards inside."
    text_embedding2 = clip_model.encode_text(sample_text2)

    similarity = clip_model.calculate_similarity(text_embedding, text_embedding2)
    print(f"Similarity between texts: {similarity:.4f}")
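The __main__ block above only exercises text encoding. Below is a minimal sketch, not part of the commit, of the image path through KoreanCLIPModel; "wallet.jpg" is a hypothetical local file used only for illustration (a URL would also be accepted by encode_image).

# Sketch only: score an image against a caption with the class above.
from models.clip_model import KoreanCLIPModel

model = KoreanCLIPModel()

text_emb = model.encode_text("검정색 가죽 지갑")   # black leather wallet
image_emb = model.encode_image("wallet.jpg")      # hypothetical local path

score = model.calculate_similarity(text_emb, image_emb)
print(f"text-image similarity: {score:.4f}")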
requirements.txt
ADDED
@@ -0,0 +1,13 @@
fastapi
uvicorn
transformers==4.26.0
torch
Pillow
pydantic
python-multipart
httpx
python-dotenv
aiofiles
kiwipiepy==0.20.4
numpy
requests
utils/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""
Package for utility modules
"""
from .similarity import calculate_similarity, find_similar_items

__all__ = ['calculate_similarity', 'find_similar_items']
utils/similarity.py
ADDED
@@ -0,0 +1,355 @@
"""
Similarity calculation and related utility functions
Korean text analysis enhanced with the Kiwi morphological analyzer
"""
import os
import sys
import logging
import numpy as np
import re
from collections import Counter
from kiwipiepy import Kiwi

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Initialize the Kiwi morphological analyzer
kiwi = Kiwi()

# Settings (environment variables or defaults)
SIMILARITY_THRESHOLD = float(os.getenv('SIMILARITY_THRESHOLD', '0.7'))
TEXT_WEIGHT = float(os.getenv('TEXT_WEIGHT', '0.7'))
IMAGE_WEIGHT = float(os.getenv('IMAGE_WEIGHT', '0.3'))
CATEGORY_WEIGHT = float(os.getenv('CATEGORY_WEIGHT', '0.5'))
ITEM_NAME_WEIGHT = float(os.getenv('ITEM_NAME_WEIGHT', '0.3'))
COLOR_WEIGHT = float(os.getenv('COLOR_WEIGHT', '0.1'))
CONTENT_WEIGHT = float(os.getenv('CONTENT_WEIGHT', '0.1'))

def preprocess_text(text):
    """
    Text preprocessing function

    Args:
        text (str): Text to preprocess

    Returns:
        str: Preprocessed text
    """
    if not text:
        return ""

    # Convert to lowercase (for English text)
    text = text.lower()

    # Remove unnecessary whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    # Remove special characters (keep Korean, English, and digits)
    text = re.sub(r'[^\w\s가-힣ㄱ-ㅎㅏ-ㅣ]', ' ', text)

    return text

def extract_keywords(text):
    """
    Extract key keywords from text using the Kiwi morphological analyzer

    Args:
        text (str): Text to extract keywords from

    Returns:
        list: List of keywords (mainly nouns and adjectives)
    """
    if not text:
        return []

    # Preprocess the text
    processed_text = preprocess_text(text)

    try:
        # Run Kiwi morphological analysis
        result = kiwi.analyze(processed_text)

        # Extract key keywords (nouns, adjectives, etc.)
        keywords = []
        for token in result[0][0]:
            # NNG: common noun, NNP: proper noun, VA: adjective, VV: verb, SL: foreign word (e.g. English)
            if token.tag in ['NNG', 'NNP', 'VA', 'SL']:
                # Single-character nouns may be less meaningful, so filter them (optional)
                if len(token.form) > 1 or token.tag in ['SL']:
                    keywords.append(token.form)

        logger.debug(f"Extracted keywords: {keywords}")
        return keywords

    except Exception as e:
        logger.warning(f"Morphological analysis error: {str(e)}, falling back to simple splitting")
        # Fall back to simple splitting on error
        words = processed_text.split()
        return words

def calculate_text_similarity(text1, text2, weights=None):
    """
    Calculate the similarity between two texts (using Kiwi morphological analysis)

    Args:
        text1 (str): First text
        text2 (str): Second text
        weights (dict, optional): Weights for each component

    Returns:
        float: Similarity score (between 0 and 1)
    """
    if not text1 or not text2:
        return 0.0

    # Default weights
    if weights is None:
        weights = {
            'common_words': 0.7,  # increased weight for the common-word ratio
            'length_ratio': 0.15,
            'word_order': 0.15
        }

    # Extract keywords from the texts (using the Kiwi morphological analyzer)
    keywords1 = extract_keywords(text1)
    keywords2 = extract_keywords(text2)

    if not keywords1 or not keywords2:
        return 0.0

    # 1. Common-word ratio
    common_words = set(keywords1) & set(keywords2)
    common_ratio = len(common_words) / max(1, min(len(set(keywords1)), len(set(keywords2))))

    # 2. Text length similarity
    length_ratio = min(len(keywords1), len(keywords2)) / max(1, max(len(keywords1), len(keywords2)))

    # 3. Word-order similarity (optional)
    word_order_sim = 0.0
    if common_words:
        # Similarity based on the position differences of common words
        positions1 = {word: i for i, word in enumerate(keywords1) if word in common_words}
        positions2 = {word: i for i, word in enumerate(keywords2) if word in common_words}

        if positions1 and positions2:
            pos_diff_sum = sum(abs(positions1[word] - positions2[word]) for word in common_words if word in positions1 and word in positions2)
            max_diff = len(keywords1) + len(keywords2)
            word_order_sim = 1.0 - (pos_diff_sum / max(1, max_diff))

    # Apply the weights to compute the final similarity
    similarity = (
        weights['common_words'] * common_ratio +
        weights['length_ratio'] * length_ratio +
        weights['word_order'] * word_order_sim
    )

    return min(1.0, max(0.0, similarity))

def calculate_category_similarity(category1, category2):
    """
    Calculate the similarity between two categories (handles the '기타' (etc.) category)

    Args:
        category1 (str): First category
        category2 (str): Second category

    Returns:
        float: Similarity score (between 0 and 1)
    """
    if not category1 or not category2:
        return 0.0

    # Preprocess the categories
    cat1 = preprocess_text(str(category1))
    cat2 = preprocess_text(str(category2))

    # Exact match
    if cat1 == cat2:
        return 1.0

    # Extract keywords with Kiwi
    keywords1 = set(extract_keywords(cat1))
    keywords2 = set(extract_keywords(cat2))

    # Handle the '기타' (etc.) category
    if '기타' in cat1 or '기타' in cat2:
        # Extract keywords and compare the intersection
        if not keywords1 or not keywords2:
            return 0.3  # give the '기타' category a base similarity

        # Higher similarity if there are common words
        common_words = keywords1 & keywords2
        if common_words:
            return 0.7

        return 0.3  # '기타' category but no common keywords

    # Similarity for regular categories
    return calculate_text_similarity(cat1, cat2)

def calculate_similarity(user_post, lost_item, clip_model=None):
    """
    Calculate the overall similarity between a user post and a found item

    Args:
        user_post (dict): User post information
        lost_item (dict): Found-item data
        clip_model (KoreanCLIPModel, optional): CLIP model instance

    Returns:
        float: Similarity score (between 0 and 1)
        dict: Detailed similarity information
    """
    # Text similarity calculation
    text_similarities = {}

    # 1. Category similarity
    category_sim = 0.0
    if 'category' in user_post and 'category' in lost_item:
        category_sim = calculate_category_similarity(user_post['category'], lost_item['category'])
    text_similarities['category'] = category_sim

    # 2. Item name similarity
    item_name_sim = 0.0
    user_item_name = user_post.get('item_name', '')
    lost_item_name = lost_item.get('item_name', '')
    if user_item_name and lost_item_name:
        item_name_sim = calculate_text_similarity(user_item_name, lost_item_name)
    text_similarities['item_name'] = item_name_sim

    # 3. Color similarity
    color_sim = 0.0
    user_color = user_post.get('color', '')
    lost_color = lost_item.get('color', '')
    if user_color and lost_color:
        color_sim = calculate_text_similarity(user_color, lost_color)
    text_similarities['color'] = color_sim

    # 4. Content similarity
    content_sim = 0.0
    user_content = user_post.get('content', '')
    lost_content = lost_item.get('content', '')
    if user_content and lost_content:
        content_sim = calculate_text_similarity(user_content, lost_content)
    text_similarities['content'] = content_sim

    # Combined text similarity (weighted)
    text_similarity = (
        CATEGORY_WEIGHT * category_sim +
        ITEM_NAME_WEIGHT * item_name_sim +
        COLOR_WEIGHT * color_sim +
        CONTENT_WEIGHT * content_sim
    )

    # Image-text similarity using the CLIP model
    image_similarity = 0.0
    has_image = False

    if clip_model is not None:
        # When both the user post and the found item have an image
        user_image = user_post.get('image', None) or user_post.get('image_url', None)
        lost_image = lost_item.get('image', None) or lost_item.get('image_url', None)

        if user_image and lost_image:
            try:
                # Similarity calculation using the CLIP model
                user_text_embedding = clip_model.encode_text(user_post.get('content', ''))
                user_image_embedding = clip_model.encode_image(user_image)

                item_text_embedding = clip_model.encode_text(lost_item.get('content', ''))
                item_image_embedding = clip_model.encode_image(lost_image)

                # Cross text-image similarities
                text_to_image_sim = clip_model.calculate_similarity(user_text_embedding, item_image_embedding)
                image_to_text_sim = clip_model.calculate_similarity(item_text_embedding, user_image_embedding)
                image_to_image_sim = clip_model.calculate_similarity(user_image_embedding, item_image_embedding)

                image_similarity = (text_to_image_sim + image_to_text_sim + image_to_image_sim) / 3
                has_image = True
            except Exception as e:
                logger.warning(f"Error while calculating image similarity: {str(e)}")

    # Final similarity (text and image weights applied)
    if has_image:
        final_similarity = TEXT_WEIGHT * text_similarity + IMAGE_WEIGHT * image_similarity
    else:
        final_similarity = text_similarity

    # Detailed similarity information
    similarity_details = {
        'text_similarity': text_similarity,
        'image_similarity': image_similarity if has_image else None,
        'final_similarity': final_similarity,
        'details': text_similarities
    }

    return final_similarity, similarity_details

def find_similar_items(user_post, lost_items, threshold=SIMILARITY_THRESHOLD, clip_model=None):
    """
    Find found items similar to a user post

    Args:
        user_post (dict): User post information
        lost_items (list): List of found-item data
        threshold (float): Similarity threshold (default: set from config)
        clip_model (KoreanCLIPModel, optional): CLIP model instance

    Returns:
        list: Found items with similarity above the threshold (sorted by descending similarity)
    """
    similar_items = []

    logger.info(f"Comparing the user post against {len(lost_items)} found items...")

    for item in lost_items:
        similarity, details = calculate_similarity(user_post, item, clip_model)

        if similarity >= threshold:
            similar_items.append({
                'item': item,
                'similarity': similarity,
                'details': details
            })

    # Sort by similarity, descending
    similar_items.sort(key=lambda x: x['similarity'], reverse=True)

    logger.info(f"Found {len(similar_items)} items with similarity >= {threshold}")

    return similar_items

# Module test code
if __name__ == "__main__":
    # Text similarity test
    text1 = "검정색 가죽 지갑을 잃어버렸습니다."  # "I lost a black leather wallet."
    text2 = "검은 가죽 지갑을 찾았습니다."        # "I found a black leather wallet."
    text3 = "노트북을 분실했습니다."              # "I lost a laptop."

    # Keyword extraction test
    print("[ Keyword extraction test ]")
    print(f"Text 1: '{text1}'")
    print(f"Extracted keywords: {extract_keywords(text1)}")
    print(f"Text 2: '{text2}'")
    print(f"Extracted keywords: {extract_keywords(text2)}")

    # Similarity test
    sim12 = calculate_text_similarity(text1, text2)
    sim13 = calculate_text_similarity(text1, text3)

    print("\n[ Similarity test ]")
    print(f"Text 1-2 similarity: {sim12:.4f}")
    print(f"Text 1-3 similarity: {sim13:.4f}")

    # Category similarity test
    cat1 = "지갑"       # wallet
    cat2 = "가방/지갑"  # bag/wallet
    cat3 = "기타"       # etc.

    cat_sim12 = calculate_category_similarity(cat1, cat2)
    cat_sim13 = calculate_category_similarity(cat1, cat3)

    print("\n[ Category similarity test ]")
    print(f"Category 1-2 similarity: {cat_sim12:.4f}")
    print(f"Category 1-3 similarity: {cat_sim13:.4f}")
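Finally, a hedged end-to-end sketch (not part of the commit) of text-only matching with find_similar_items, passing clip_model=None so only the weighted keyword similarity is used. The item dictionaries are illustrative and follow the shape the API endpoint expects.

# Sketch only: in-memory matching without the CLIP model.
from utils.similarity import find_similar_items

user_post = {
    "category": "지갑",
    "item_name": "검정색 가죽 지갑",
    "color": "검은색",
    "content": "강남역 근처에서 검정색 가죽 지갑을 잃어버렸습니다.",
}

found_items = [
    {"category": "지갑", "item_name": "검정 가죽 지갑", "color": "검은색",
     "content": "강남역 출구에서 검은 지갑을 주웠습니다."},
    {"category": "전자기기", "item_name": "노트북", "color": "은색",
     "content": "카페에 노트북이 놓여 있었습니다."},
]

matches = find_similar_items(user_post, found_items, threshold=0.3, clip_model=None)
for m in matches:
    print(m["similarity"], m["item"]["item_name"])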