""" | |
ํ๊ตญ์ด CLIP ๋ชจ๋ธ ๊ตฌํ | |
์ด ๋ชจ๋์ HuggingFace์ ํ๊ตญ์ด CLIP ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ํ ์คํธ์ ์ด๋ฏธ์ง์ ์๋ฒ ๋ฉ์ ์์ฑ | |
""" | |
import os
import sys
import logging
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import time
# Cache directory configuration
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
os.environ['HF_HOME'] = '/tmp/huggingface_cache'

# Create working directories
os.makedirs('/tmp/huggingface_cache', exist_ok=True)
os.makedirs('/tmp/uploads', exist_ok=True)
# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Model configuration - read from environment variables, with defaults
CLIP_MODEL_NAME = os.getenv('CLIP_MODEL_NAME', 'Bingsu/clip-vit-large-patch14-ko')
DEVICE = "cuda" if torch.cuda.is_available() and os.getenv('USE_GPU', 'True').lower() == 'true' else "cpu"

# HTTP request timeout in seconds (used when downloading images from URLs)
REQUEST_TIMEOUT = int(os.getenv('REQUEST_TIMEOUT', '10'))
def preload_clip_model():
    """Download the CLIP model ahead of time and populate the cache."""
    try:
        start_time = time.time()
        logger.info(f"Starting CLIP model pre-download: {CLIP_MODEL_NAME}")

        # Download the model and processor; the objects are discarded here,
        # only the on-disk cache matters.
        CLIPModel.from_pretrained(
            CLIP_MODEL_NAME,
            cache_dir='/tmp/huggingface_cache',
            low_cpu_mem_usage=True,    # reduce peak memory while loading
            torch_dtype=torch.float32  # keep a single float32 dtype throughout
        )
        CLIPProcessor.from_pretrained(
            CLIP_MODEL_NAME,
            cache_dir='/tmp/huggingface_cache'
        )
        logger.info(f"CLIP model pre-download finished (elapsed: {time.time() - start_time:.2f}s)")
    except Exception as e:
        logger.error(f"CLIP model pre-download failed: {str(e)}")
class KoreanCLIPModel:
    """
    Korean CLIP model wrapper.
    Provides text/image embedding and similarity computation.
    """
    def __init__(self, model_name=CLIP_MODEL_NAME, device=DEVICE):
        """Initialize the CLIP model with memory-friendly settings."""
        self.device = device
        self.model_name = model_name
        self.embedding_dim = None  # shared embedding dimension, set after loading
        logger.info(f"Loading CLIP model '{model_name}' (device: {device})...")
        try:
            # Cache directory configuration (same cache as preload_clip_model)
            os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
            os.makedirs("/tmp/huggingface_cache", exist_ok=True)

            # Memory optimization options; keep a single float32 dtype
            self.model = CLIPModel.from_pretrained(
                model_name,
                cache_dir='/tmp/huggingface_cache',
                low_cpu_mem_usage=True,
                torch_dtype=torch.float32  # float32 instead of float16 (unified dtype)
            ).to(device)

            # get_text_features / get_image_features project both modalities into a
            # shared space of size projection_dim, so use that for fallback vectors.
            self.embedding_dim = self.model.config.projection_dim
            self.text_embedding_dim = self.embedding_dim
            self.image_embedding_dim = self.embedding_dim
            logger.info(f"Text embedding dim: {self.text_embedding_dim}, image embedding dim: {self.image_embedding_dim}")

            self.processor = CLIPProcessor.from_pretrained(
                model_name,
                cache_dir='/tmp/huggingface_cache'
            )
            logger.info("CLIP model loaded")
        except Exception as e:
            logger.error(f"Failed to load CLIP model: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            raise
    def encode_text(self, text):
        """
        Convert text into embedding vectors.

        Args:
            text (str or list): text or list of texts to encode

        Returns:
            numpy.ndarray: embedding vectors, one row per input text
        """
        if isinstance(text, str):
            text = [text]
        try:
            with torch.no_grad():
                # Tokenize and encode the text
                inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True).to(self.device)
                text_features = self.model.get_text_features(**inputs)
                # L2-normalize the text features
                text_embeddings = text_features / text_features.norm(dim=1, keepdim=True)
                return text_embeddings.cpu().numpy()
        except Exception as e:
            logger.error(f"Error while encoding text: {str(e)}")
            # Return zero vectors with the matching dimension
            return np.zeros((len(text), self.text_embedding_dim))
    def encode_image(self, image_source):
        """
        Convert an image into an embedding vector.

        Args:
            image_source: image to encode (PIL Image, URL, or local file path)

        Returns:
            numpy.ndarray: embedding vector
        """
        try:
            # Load the image (URL, file path, or PIL Image object)
            if isinstance(image_source, str):
                if image_source.startswith('http'):
                    # Load from URL with a request timeout
                    try:
                        response = requests.get(image_source, timeout=REQUEST_TIMEOUT)
                        if response.status_code == 200:
                            image = Image.open(BytesIO(response.content)).convert('RGB')
                        else:
                            logger.warning(f"Image URL returned an error status: {response.status_code}")
                            # Fall back to a dummy (black) image on error
                            image = Image.new('RGB', (224, 224), color='black')
                    except requests.exceptions.RequestException as e:
                        logger.error(f"Error while fetching image URL: {str(e)}")
                        # Fall back to a dummy (black) image on error
                        image = Image.new('RGB', (224, 224), color='black')
                else:
                    # Load from a local file
                    try:
                        if os.path.exists(image_source):
                            image = Image.open(image_source).convert('RGB')
                        else:
                            logger.warning(f"Image file does not exist: {image_source}")
                            # Fall back to a dummy image when the file is missing
                            image = Image.new('RGB', (224, 224), color='black')
                    except Exception as e:
                        logger.error(f"Error while loading local image: {str(e)}")
                        # Fall back to a dummy image on error
                        image = Image.new('RGB', (224, 224), color='black')
            else:
                # Already a PIL Image object
                image = image_source.convert('RGB')

            with torch.no_grad():
                # Preprocess and encode the image
                inputs = self.processor(images=image, return_tensors="pt").to(self.device)
                image_features = self.model.get_image_features(**inputs)
                # L2-normalize the image features
                image_embeddings = image_features / image_features.norm(dim=1, keepdim=True)
                return image_embeddings.cpu().numpy()
        except Exception as e:
            logger.error(f"Error while encoding image: {str(e)}")
            # Return a zero vector with the matching dimension
            return np.zeros((1, self.image_embedding_dim))
    def calculate_similarity(self, embedding1, embedding2):
        """
        Compute the similarity between two embeddings.

        Args:
            embedding1 (numpy.ndarray): first embedding
            embedding2 (numpy.ndarray): second embedding

        Returns:
            float: similarity score in the range 0-1
        """
        try:
            # Check and log the embedding shapes
            logger.debug(f"Embedding1 shape: {embedding1.shape}, embedding2 shape: {embedding2.shape}")

            # If the dimensions do not match, return a neutral default
            if embedding1.shape[1] != embedding2.shape[1]:
                logger.warning(f"Embedding dimension mismatch: {embedding1.shape} vs {embedding2.shape}")
                return 0.5

            # Cosine similarity (the embeddings are already L2-normalized)
            similarity = np.dot(embedding1, embedding2.T)[0, 0]
            # Map from [-1, 1] to [0, 1]
            similarity = (similarity + 1) / 2
            return float(similarity)
        except Exception as e:
            logger.error(f"Error while computing similarity: {str(e)}")
            return 0.5  # neutral mid value on error
    def encode_batch_texts(self, texts):
        """
        Embed multiple texts at once.

        Args:
            texts (list): list of texts

        Returns:
            numpy.ndarray: array of embedding vectors
        """
        # Simple batch path; in a real deployment the batch size should be
        # capped according to the available memory (see the sketch below).
        return self.encode_text(texts)
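
    # A minimal sketch (not in the original module) of how the memory cap
    # mentioned above could be applied: encode the texts in fixed-size chunks
    # and stack the results. The method name and the chunk size of 32 are
    # illustrative assumptions.
    def encode_batch_texts_chunked(self, texts, chunk_size=32):
        """Encode texts in chunks of `chunk_size` to bound peak memory usage."""
        chunks = [texts[i:i + chunk_size] for i in range(0, len(texts), chunk_size)]
        embeddings = [self.encode_text(chunk) for chunk in chunks]
        return np.vstack(embeddings) if embeddings else np.zeros((0, self.text_embedding_dim))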
# Module self-test code
if __name__ == "__main__":
    # Initialize the model
    clip_model = KoreanCLIPModel()

    # Encode a sample text ("I lost a black wallet. It had cash and cards in it.")
    sample_text = "검정색 지갑을 잃어버렸습니다. 현금과 카드가 들어있어요."
    text_embedding = clip_model.encode_text(sample_text)
    print(f"Text embedding shape: {text_embedding.shape}")

    # Compute similarity (text only)
    # ("I found a black wallet. There are cash and cards inside.")
    sample_text2 = "검정색 지갑을 찾았습니다. 안에 현금과 카드가 있습니다."
    text_embedding2 = clip_model.encode_text(sample_text2)
    similarity = clip_model.calculate_similarity(text_embedding, text_embedding2)
    print(f"Text-to-text similarity: {similarity:.4f}")