# ssfinder-matching/models/clip_model.py
"""
ํ•œ๊ตญ์–ด CLIP ๋ชจ๋ธ ๊ตฌํ˜„
์ด ๋ชจ๋“ˆ์€ HuggingFace์˜ ํ•œ๊ตญ์–ด CLIP ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ํ…์ŠคํŠธ์™€ ์ด๋ฏธ์ง€์˜ ์ž„๋ฒ ๋”ฉ์„ ์ƒ์„ฑ
"""
import os
import sys
import logging
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
from io import BytesIO
import numpy as np
# ์บ์‹œ ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ •
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
os.environ['HF_HOME'] = '/tmp/huggingface_cache'
# Create required directories
os.makedirs('/tmp/huggingface_cache', exist_ok=True)
os.makedirs('/tmp/uploads', exist_ok=True)
# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# ๋ชจ๋ธ ์„ค์ • - ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ ๊ฐ€์ ธ์˜ค๊ฑฐ๋‚˜ ๊ธฐ๋ณธ๊ฐ’ ์‚ฌ์šฉ
CLIP_MODEL_NAME = os.getenv('CLIP_MODEL_NAME', 'Bingsu/clip-vit-large-patch14-ko')
DEVICE = "cuda" if torch.cuda.is_available() and os.getenv('USE_GPU', 'True').lower() == 'true' else "cpu"
class KoreanCLIPModel:
"""
ํ•œ๊ตญ์–ด CLIP ๋ชจ๋ธ ํด๋ž˜์Šค
ํ…์ŠคํŠธ์™€ ์ด๋ฏธ์ง€๋ฅผ ์ž„๋ฒ ๋”ฉํ•˜๊ณ  ์œ ์‚ฌ๋„๋ฅผ ๊ณ„์‚ฐํ•˜๋Š” ๊ธฐ๋Šฅ ์ œ๊ณต
"""
def __init__(self, model_name=CLIP_MODEL_NAME, device=DEVICE):
"""
CLIP ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
Args:
model_name (str): ์‚ฌ์šฉํ•  CLIP ๋ชจ๋ธ ์ด๋ฆ„ ๋˜๋Š” ๊ฒฝ๋กœ
device (str): ์‚ฌ์šฉํ•  ์žฅ์น˜ ('cuda' ๋˜๋Š” 'cpu')
"""
self.device = device
self.model_name = model_name
logger.info(f"CLIP ๋ชจ๋ธ '{model_name}' ๋กœ๋“œ ์ค‘ (device: {device})...")
try:
# ์บ์‹œ ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ •
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
os.makedirs("/tmp/transformers_cache", exist_ok=True)
self.model = CLIPModel.from_pretrained(model_name).to(device)
self.processor = CLIPProcessor.from_pretrained(model_name)
logger.info("CLIP ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ")
except Exception as e:
logger.error(f"CLIP ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: {str(e)}")
raise
def encode_text(self, text):
"""
ํ…์ŠคํŠธ๋ฅผ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๋กœ ๋ณ€ํ™˜
Args:
text (str or list): ์ธ์ฝ”๋”ฉํ•  ํ…์ŠคํŠธ ๋˜๋Š” ํ…์ŠคํŠธ ๋ฆฌ์ŠคํŠธ
Returns:
numpy.ndarray: ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ
"""
if isinstance(text, str):
text = [text]
try:
with torch.no_grad():
# ํ…์ŠคํŠธ ์ธ์ฝ”๋”ฉ
inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True).to(self.device)
text_features = self.model.get_text_features(**inputs)
                # L2-normalize the text features
text_embeddings = text_features / text_features.norm(dim=1, keepdim=True)
return text_embeddings.cpu().numpy()
except Exception as e:
logger.error(f"ํ…์ŠคํŠธ ์ธ์ฝ”๋”ฉ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
return np.zeros((len(text), self.model.text_embed_dim))
def encode_image(self, image_source):
"""
์ด๋ฏธ์ง€๋ฅผ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๋กœ ๋ณ€ํ™˜
Args:
image_source: ์ธ์ฝ”๋”ฉํ•  ์ด๋ฏธ์ง€ (PIL Image, URL ๋˜๋Š” ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ)
Returns:
numpy.ndarray: ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ
"""
try:
            # Load the image (URL, file path, or PIL Image object)
if isinstance(image_source, str):
if image_source.startswith('http'):
# URL์—์„œ ์ด๋ฏธ์ง€ ๋กœ๋“œ
response = requests.get(image_source)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
# ๋กœ์ปฌ ํŒŒ์ผ์—์„œ ์ด๋ฏธ์ง€ ๋กœ๋“œ
image = Image.open(image_source).convert('RGB')
else:
                # Already a PIL Image object
image = image_source.convert('RGB')
with torch.no_grad():
# ์ด๋ฏธ์ง€ ์ธ์ฝ”๋”ฉ
inputs = self.processor(images=image, return_tensors="pt").to(self.device)
image_features = self.model.get_image_features(**inputs)
                # L2-normalize the image features
image_embeddings = image_features / image_features.norm(dim=1, keepdim=True)
return image_embeddings.cpu().numpy()
except Exception as e:
logger.error(f"์ด๋ฏธ์ง€ ์ธ์ฝ”๋”ฉ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
return np.zeros((1, self.model.vision_embed_dim))
def calculate_similarity(self, text_embedding, image_embedding=None):
"""
ํ…์ŠคํŠธ์™€ ์ด๋ฏธ์ง€ ์ž„๋ฒ ๋”ฉ ๊ฐ„์˜ ์œ ์‚ฌ๋„ ๊ณ„์‚ฐ
Args:
text_embedding (numpy.ndarray): ํ…์ŠคํŠธ ์ž„๋ฒ ๋”ฉ
image_embedding (numpy.ndarray, optional): ์ด๋ฏธ์ง€ ์ž„๋ฒ ๋”ฉ (์—†์œผ๋ฉด ํ…์ŠคํŠธ๋งŒ ๋น„๊ต)
Returns:
float: ์œ ์‚ฌ๋„ ์ ์ˆ˜ (0~1 ์‚ฌ์ด)
"""
if image_embedding is None:
# ํ…์ŠคํŠธ-ํ…์ŠคํŠธ ์œ ์‚ฌ๋„ ๊ณ„์‚ฐ (์ฝ”์‚ฌ์ธ ์œ ์‚ฌ๋„)
similarity = np.dot(text_embedding, text_embedding.T)[0, 0]
else:
# ํ…์ŠคํŠธ-์ด๋ฏธ์ง€ ์œ ์‚ฌ๋„ ๊ณ„์‚ฐ (์ฝ”์‚ฌ์ธ ์œ ์‚ฌ๋„)
similarity = np.dot(text_embedding, image_embedding.T)[0, 0]
# ์œ ์‚ฌ๋„๋ฅผ 0~1 ๋ฒ”์œ„๋กœ ์ •๊ทœํ™”
similarity = (similarity + 1) / 2
return float(similarity)
def encode_batch_texts(self, texts):
"""
์—ฌ๋Ÿฌ ํ…์ŠคํŠธ๋ฅผ ํ•œ ๋ฒˆ์— ์ž„๋ฒ ๋”ฉ
Args:
texts (list): ํ…์ŠคํŠธ ๋ชฉ๋ก
Returns:
numpy.ndarray: ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ ๋ฐฐ์—ด
"""
# ๋ฐฐ์น˜ ์ฒ˜๋ฆฌ๋ฅผ ์œ„ํ•œ ์ฝ”๋“œ
# ์‹ค์ œ ๊ตฌํ˜„์—์„œ๋Š” ๋ฉ”๋ชจ๋ฆฌ ํฌ๊ธฐ์— ๋”ฐ๋ผ ์ ์ ˆํ•œ ๋ฐฐ์น˜ ํฌ๊ธฐ ์กฐ์ • ํ•„์š”
return self.encode_text(texts)
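    # Minimal sketch of memory-aware batching, assuming the caller picks a batch_size
    # that fits the available (GPU) memory. The method name and default batch size are
    # illustrative assumptions, not part of the original interface.
    def encode_texts_chunked(self, texts, batch_size=32):
        """
        Embed a list of texts in fixed-size chunks to bound memory usage.
        Args:
            texts (list): List of texts.
            batch_size (int): Number of texts encoded per forward pass (assumed default).
        Returns:
            numpy.ndarray: Array of embedding vectors, one row per input text.
        """
        chunks = []
        for start in range(0, len(texts), batch_size):
            # Each chunk reuses encode_text, which already normalizes the embeddings
            chunks.append(self.encode_text(texts[start:start + batch_size]))
        return np.concatenate(chunks, axis=0) if chunks else np.zeros((0, self.model.projection_dim))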
# Module self-test code
if __name__ == "__main__":
# ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
clip_model = KoreanCLIPModel()
    # Encode a sample text
sample_text = "๊ฒ€์€์ƒ‰ ์ง€๊ฐ‘์„ ์žƒ์–ด๋ฒ„๋ ธ์Šต๋‹ˆ๋‹ค. ํ˜„๊ธˆ๊ณผ ์นด๋“œ๊ฐ€ ๋“ค์–ด์žˆ์–ด์š”."
text_embedding = clip_model.encode_text(sample_text)
print(f"ํ…์ŠคํŠธ ์ž„๋ฒ ๋”ฉ shape: {text_embedding.shape}")
# ์œ ์‚ฌ๋„ ๊ณ„์‚ฐ (ํ…์ŠคํŠธ๋งŒ)
sample_text2 = "๊ฒ€์€์ƒ‰ ์ง€๊ฐ‘์„ ์ฐพ์•˜์Šต๋‹ˆ๋‹ค. ์•ˆ์— ํ˜„๊ธˆ๊ณผ ์นด๋“œ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค."
text_embedding2 = clip_model.encode_text(sample_text2)
similarity = clip_model.calculate_similarity(text_embedding, text_embedding2)
print(f"ํ…์ŠคํŠธ ๊ฐ„ ์œ ์‚ฌ๋„: {similarity:.4f}")