"""
ํ•œ๊ตญ์–ด CLIP ๋ชจ๋ธ ๊ตฌํ˜„
์ด ๋ชจ๋“ˆ์€ HuggingFace์˜ ํ•œ๊ตญ์–ด CLIP ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ํ…์ŠคํŠธ์™€ ์ด๋ฏธ์ง€์˜ ์ž„๋ฒ ๋”ฉ์„ ์ƒ์„ฑ
"""
import os
import sys
import logging
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import time
# ์บ์‹œ ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ •
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
os.environ['HF_HOME'] = '/tmp/huggingface_cache'
# ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ
os.makedirs('/tmp/huggingface_cache', exist_ok=True)
os.makedirs('/tmp/uploads', exist_ok=True)
# ๋กœ๊น… ์„ค์ •
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# ๋ชจ๋ธ ์„ค์ • - ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ ๊ฐ€์ ธ์˜ค๊ฑฐ๋‚˜ ๊ธฐ๋ณธ๊ฐ’ ์‚ฌ์šฉ
CLIP_MODEL_NAME = os.getenv('CLIP_MODEL_NAME', 'Bingsu/clip-vit-large-patch14-ko')
DEVICE = "cuda" if torch.cuda.is_available() and os.getenv('USE_GPU', 'True').lower() == 'true' else "cpu"
# ์š”์ฒญ ํƒ€์ž„์•„์›ƒ ์„ค์ •
REQUEST_TIMEOUT = int(os.getenv('REQUEST_TIMEOUT', '10')) # 10์ดˆ ํƒ€์ž„์•„์›ƒ

def preload_clip_model():
    """Download and cache the CLIP model ahead of time."""
    try:
        start_time = time.time()
        logger.info(f"Starting CLIP model pre-download: {CLIP_MODEL_NAME}")

        # Pre-download the model and processor into the local cache
        CLIPModel.from_pretrained(
            CLIP_MODEL_NAME,
            cache_dir='/tmp/huggingface_cache',
            low_cpu_mem_usage=True,    # reduce peak memory while loading
            torch_dtype=torch.float32  # standardize on float32
        )
        CLIPProcessor.from_pretrained(
            CLIP_MODEL_NAME,
            cache_dir='/tmp/huggingface_cache'
        )

        logger.info(f"✅ CLIP model pre-download complete (elapsed: {time.time() - start_time:.2f}s)")
    except Exception as e:
        logger.error(f"❌ CLIP model pre-download failed: {str(e)}")

class KoreanCLIPModel:
    """
    Korean CLIP model wrapper.

    Embeds text and images and computes similarities between the embeddings.
    """
    def __init__(self, model_name=CLIP_MODEL_NAME, device=DEVICE):
        """Initialize the CLIP model with memory-friendly options."""
        self.device = device
        self.model_name = model_name
        logger.info(f"Loading CLIP model '{model_name}' (device: {device})...")
        try:
            # Cache directory configuration, kept consistent with the module-level setting
            os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
            os.makedirs("/tmp/huggingface_cache", exist_ok=True)

            # Memory-friendly loading; float32 throughout to avoid dtype mismatches
            self.model = CLIPModel.from_pretrained(
                model_name,
                cache_dir='/tmp/huggingface_cache',
                low_cpu_mem_usage=True,
                torch_dtype=torch.float32
            ).to(device)

            # get_text_features/get_image_features return vectors in the shared
            # embedding space of size projection_dim (not the encoders' hidden
            # sizes), so record projection_dim for the zero-vector fallbacks
            self.text_embedding_dim = self.model.config.projection_dim
            self.image_embedding_dim = self.model.config.projection_dim
            logger.info(f"Text embedding dim: {self.text_embedding_dim}, image embedding dim: {self.image_embedding_dim}")

            self.processor = CLIPProcessor.from_pretrained(
                model_name,
                cache_dir='/tmp/huggingface_cache'
            )
            logger.info("CLIP model loaded")
        except Exception as e:
            logger.error(f"Failed to load CLIP model: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            raise

    def encode_text(self, text):
        """
        Convert text into embedding vectors.

        Args:
            text (str or list): text or list of texts to encode

        Returns:
            numpy.ndarray: embedding vectors, one row per input text
        """
        if isinstance(text, str):
            text = [text]
        try:
            with torch.no_grad():
                # Tokenize and encode the text
                inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True).to(self.device)
                text_features = self.model.get_text_features(**inputs)
                # L2-normalize so dot products are cosine similarities
                text_embeddings = text_features / text_features.norm(dim=1, keepdim=True)
            return text_embeddings.cpu().numpy()
        except Exception as e:
            logger.error(f"Error during text encoding: {str(e)}")
            # Return zero vectors with the matching dimensionality
            return np.zeros((len(text), self.text_embedding_dim))

    def encode_image(self, image_source):
        """
        Convert an image into an embedding vector.

        Args:
            image_source: image to encode (PIL Image, URL, or file path)

        Returns:
            numpy.ndarray: embedding vector
        """
        try:
            # Load the image (URL, file path, or PIL Image object)
            if isinstance(image_source, str):
                if image_source.startswith('http'):
                    # Load from URL, with a timeout
                    try:
                        response = requests.get(image_source, timeout=REQUEST_TIMEOUT)
                        if response.status_code == 200:
                            image = Image.open(BytesIO(response.content)).convert('RGB')
                        else:
                            logger.warning(f"Bad response from image URL: {response.status_code}")
                            # Fall back to a dummy (black) image on error
                            image = Image.new('RGB', (224, 224), color='black')
                    except requests.exceptions.RequestException as e:
                        logger.error(f"Error fetching image URL: {str(e)}")
                        # Fall back to a dummy (black) image on error
                        image = Image.new('RGB', (224, 224), color='black')
                else:
                    # Load from a local file
                    try:
                        if os.path.exists(image_source):
                            image = Image.open(image_source).convert('RGB')
                        else:
                            logger.warning(f"Image file does not exist: {image_source}")
                            # Fall back to a dummy image when the file is missing
                            image = Image.new('RGB', (224, 224), color='black')
                    except Exception as e:
                        logger.error(f"Error loading local image: {str(e)}")
                        # Fall back to a dummy image on error
                        image = Image.new('RGB', (224, 224), color='black')
            else:
                # Already a PIL Image object
                image = image_source.convert('RGB')

            with torch.no_grad():
                # Preprocess and encode the image
                inputs = self.processor(images=image, return_tensors="pt").to(self.device)
                image_features = self.model.get_image_features(**inputs)
                # L2-normalize so dot products are cosine similarities
                image_embeddings = image_features / image_features.norm(dim=1, keepdim=True)
            return image_embeddings.cpu().numpy()
        except Exception as e:
            logger.error(f"Error during image encoding: {str(e)}")
            # Return a zero vector with the matching dimensionality
            return np.zeros((1, self.image_embedding_dim))

    def calculate_similarity(self, embedding1, embedding2):
        """
        Compute the similarity between two embeddings.

        Args:
            embedding1 (numpy.ndarray): first embedding
            embedding2 (numpy.ndarray): second embedding

        Returns:
            float: similarity score in [0, 1]
        """
        try:
            logger.debug(f"embedding1 shape: {embedding1.shape}, embedding2 shape: {embedding2.shape}")

            # Mismatched dimensions cannot be scored meaningfully; return the midpoint
            if embedding1.shape[1] != embedding2.shape[1]:
                logger.warning(f"Embedding dimension mismatch: {embedding1.shape} vs {embedding2.shape}")
                return 0.5

            # Cosine similarity (the embeddings are already L2-normalized)
            similarity = np.dot(embedding1, embedding2.T)[0, 0]
            # Map cosine similarity from [-1, 1] to [0, 1]
            similarity = (similarity + 1) / 2
            return float(similarity)
        except Exception as e:
            logger.error(f"Error computing similarity: {str(e)}")
            return 0.5  # midpoint fallback on error

    def encode_batch_texts(self, texts, batch_size=32):
        """
        Embed multiple texts at once, in memory-sized batches.

        Args:
            texts (list): list of texts
            batch_size (int): texts encoded per forward pass; tune to the
                available memory

        Returns:
            numpy.ndarray: array of embedding vectors
        """
        if not texts:
            return np.zeros((0, self.text_embedding_dim))
        # Encode in chunks so large inputs do not exhaust memory
        batches = [
            self.encode_text(texts[i:i + batch_size])
            for i in range(0, len(texts), batch_size)
        ]
        return np.concatenate(batches, axis=0)

# Module self-test
if __name__ == "__main__":
    # Initialize the model
    clip_model = KoreanCLIPModel()

    # Encode a sample text ("I lost a black wallet. It has cash and cards in it.")
    sample_text = "검은색 지갑을 잃어버렸습니다. 현금과 카드가 들어있어요."
    text_embedding = clip_model.encode_text(sample_text)
    print(f"Text embedding shape: {text_embedding.shape}")

    # Text-to-text similarity ("I found a black wallet. There are cash and cards inside.")
    sample_text2 = "검은색 지갑을 찾았습니다. 안에 현금과 카드가 있습니다."
    text_embedding2 = clip_model.encode_text(sample_text2)
    similarity = clip_model.calculate_similarity(text_embedding, text_embedding2)
    print(f"Text-to-text similarity: {similarity:.4f}")