| """ | |
| ํ๊ตญ์ด CLIP ๋ชจ๋ธ ๊ตฌํ | |
| ์ด ๋ชจ๋์ HuggingFace์ ํ๊ตญ์ด CLIP ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ํ ์คํธ์ ์ด๋ฏธ์ง์ ์๋ฒ ๋ฉ์ ์์ฑ | |
| """ | |
| import os | |
| import sys | |
| import logging | |
| import torch | |
| from transformers import CLIPProcessor, CLIPModel | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| import numpy as np | |
| # ์บ์ ๋๋ ํ ๋ฆฌ ์ค์ | |
| os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache' | |
| os.environ['HF_HOME'] = '/tmp/huggingface_cache' | |
| # ๋๋ ํ ๋ฆฌ ์์ฑ | |
| os.makedirs('/tmp/huggingface_cache', exist_ok=True) | |
| os.makedirs('/tmp/uploads', exist_ok=True) | |
| # ๋ก๊น ์ค์ | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
| logger = logging.getLogger(__name__) | |
| # ๋ชจ๋ธ ์ค์ - ํ๊ฒฝ ๋ณ์์์ ๊ฐ์ ธ์ค๊ฑฐ๋ ๊ธฐ๋ณธ๊ฐ ์ฌ์ฉ | |
| CLIP_MODEL_NAME = os.getenv('CLIP_MODEL_NAME', 'Bingsu/clip-vit-large-patch14-ko') | |
| DEVICE = "cuda" if torch.cuda.is_available() and os.getenv('USE_GPU', 'True').lower() == 'true' else "cpu" | |

class KoreanCLIPModel:
    """
    Korean CLIP model wrapper.

    Provides text and image embedding plus similarity computation.
    """

    def __init__(self, model_name=CLIP_MODEL_NAME, device=DEVICE):
        """
        Initialize the CLIP model.

        Args:
            model_name (str): Name or path of the CLIP model to use
            device (str): Device to run on ('cuda' or 'cpu')
        """
        self.device = device
        self.model_name = model_name
        logger.info(f"Loading CLIP model '{model_name}' (device: {device})...")
        try:
            # The cache directory is configured once at module level above;
            # re-setting it here after transformers has been imported has no effect.
            self.model = CLIPModel.from_pretrained(model_name).to(device)
            self.processor = CLIPProcessor.from_pretrained(model_name)
            logger.info("CLIP model loaded")
        except Exception as e:
            logger.error(f"Failed to load CLIP model: {str(e)}")
            raise

    def encode_text(self, text):
        """
        Encode text into embedding vectors.

        Args:
            text (str or list): Text or list of texts to encode

        Returns:
            numpy.ndarray: Embedding vectors
        """
        if isinstance(text, str):
            text = [text]
        try:
            with torch.no_grad():
                # Tokenize and encode the text
                inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True).to(self.device)
                text_features = self.model.get_text_features(**inputs)
                # L2-normalize the text features
                text_embeddings = text_features / text_features.norm(dim=1, keepdim=True)
            return text_embeddings.cpu().numpy()
        except Exception as e:
            logger.error(f"Error while encoding text: {str(e)}")
            # get_text_features returns vectors of size projection_dim
            return np.zeros((len(text), self.model.projection_dim))

    def encode_image(self, image_source):
        """
        Encode an image into an embedding vector.

        Args:
            image_source: Image to encode (PIL Image, URL, or file path)

        Returns:
            numpy.ndarray: Embedding vector
        """
        try:
            # Load the image (URL, file path, or PIL Image object)
            if isinstance(image_source, str):
                if image_source.startswith('http'):
                    # Load the image from a URL
                    response = requests.get(image_source, timeout=10)
                    image = Image.open(BytesIO(response.content)).convert('RGB')
                else:
                    # Load the image from a local file
                    image = Image.open(image_source).convert('RGB')
            else:
                # Already a PIL Image object
                image = image_source.convert('RGB')
            with torch.no_grad():
                # Preprocess and encode the image
                inputs = self.processor(images=image, return_tensors="pt").to(self.device)
                image_features = self.model.get_image_features(**inputs)
                # L2-normalize the image features
                image_embeddings = image_features / image_features.norm(dim=1, keepdim=True)
            return image_embeddings.cpu().numpy()
        except Exception as e:
            logger.error(f"Error while encoding image: {str(e)}")
            # get_image_features returns vectors of size projection_dim
            return np.zeros((1, self.model.projection_dim))

    def calculate_similarity(self, text_embedding, image_embedding=None):
        """
        Compute the cosine similarity between two embeddings.

        Args:
            text_embedding (numpy.ndarray): Text embedding
            image_embedding (numpy.ndarray, optional): Second embedding (image or
                text). If omitted, the text embedding is compared with itself,
                which always yields 1.0.

        Returns:
            float: Similarity score (scaled to the range 0~1)
        """
        if image_embedding is None:
            # Cosine similarity of the text embedding with itself
            similarity = np.dot(text_embedding, text_embedding.T)[0, 0]
        else:
            # Cosine similarity between the two embeddings (text-image or text-text)
            similarity = np.dot(text_embedding, image_embedding.T)[0, 0]
        # Rescale the similarity from [-1, 1] to [0, 1]
        similarity = (similarity + 1) / 2
        return float(similarity)

    def encode_batch_texts(self, texts):
        """
        Encode multiple texts in a single batch.

        Args:
            texts (list): List of texts

        Returns:
            numpy.ndarray: Array of embedding vectors
        """
        # Simple single-batch processing; in a real deployment the batch size
        # should be tuned to the available memory (see the chunked sketch below).
        return self.encode_text(texts)
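
    # The following method is an illustrative sketch, not part of the original
    # module: it splits a long list of texts into fixed-size chunks before
    # encoding, so peak memory stays bounded. The method name and the default
    # chunk size of 64 are assumptions chosen for the example.
    def encode_texts_chunked(self, texts, chunk_size=64):
        """Encode a long list of texts chunk by chunk and concatenate the results."""
        if not texts:
            return np.zeros((0, self.model.projection_dim))
        chunks = [texts[i:i + chunk_size] for i in range(0, len(texts), chunk_size)]
        return np.concatenate([self.encode_text(chunk) for chunk in chunks], axis=0)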

# Module self-test code
if __name__ == "__main__":
    # Initialize the model
    clip_model = KoreanCLIPModel()

    # Encode a sample text
    sample_text = "검은색 지갑을 잃어버렸습니다. 현금과 카드가 들어있어요."  # "I lost a black wallet. It contains cash and cards."
    text_embedding = clip_model.encode_text(sample_text)
    print(f"Text embedding shape: {text_embedding.shape}")

    # Compute similarity (text only)
    sample_text2 = "검은색 지갑을 찾았습니다. 안에 현금과 카드가 있습니다."  # "I found a black wallet. There are cash and cards inside."
    text_embedding2 = clip_model.encode_text(sample_text2)
    similarity = clip_model.calculate_similarity(text_embedding, text_embedding2)
    print(f"Text-to-text similarity: {similarity:.4f}")