asefasdfcv committed (verified)
Commit f965b35 · Parent(s): 83c7141

Update models/clip_model.py

Files changed (1)
  1. models/clip_model.py +63 -22
models/clip_model.py CHANGED
@@ -30,6 +30,9 @@ logger = logging.getLogger(__name__)
 CLIP_MODEL_NAME = os.getenv('CLIP_MODEL_NAME', 'Bingsu/clip-vit-large-patch14-ko')
 DEVICE = "cuda" if torch.cuda.is_available() and os.getenv('USE_GPU', 'True').lower() == 'true' else "cpu"
 
+# Request timeout configuration
+REQUEST_TIMEOUT = int(os.getenv('REQUEST_TIMEOUT', '10'))  # 10-second timeout
+
 def preload_clip_model():
     """Pre-download and cache the CLIP model"""
     try:
@@ -40,7 +43,8 @@ def preload_clip_model():
         CLIPModel.from_pretrained(
             CLIP_MODEL_NAME,
             cache_dir='/tmp/huggingface_cache',
-            low_cpu_mem_usage=True  # optimize memory usage
+            low_cpu_mem_usage=True,    # optimize memory usage
+            torch_dtype=torch.float32  # standardize on the float32 dtype
         )
 
         CLIPProcessor.from_pretrained(
@@ -62,6 +66,7 @@ class KoreanCLIPModel:
         """Initialize the CLIP model - memory optimized"""
         self.device = device
         self.model_name = model_name
+        self.embedding_dim = None  # added: store the embedding dimension
 
         logger.info(f"Loading CLIP model '{model_name}' (device: {device})...")
 
@@ -70,14 +75,20 @@ class KoreanCLIPModel:
         os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
         os.makedirs("/tmp/transformers_cache", exist_ok=True)
 
-        # Added memory optimization options
+        # Added memory optimization options - standardize on the float32 dtype
         self.model = CLIPModel.from_pretrained(
             model_name,
             cache_dir='/tmp/huggingface_cache',
             low_cpu_mem_usage=True,
-            torch_dtype=torch.float16  # use half precision
+            torch_dtype=torch.float32  # changed from float16 to float32
         ).to(device)
 
+        # Store the embedding dimensions
+        self.text_embedding_dim = self.model.text_model.config.hidden_size
+        self.image_embedding_dim = self.model.vision_model.config.hidden_size
+
+        logger.info(f"Text embedding dim: {self.text_embedding_dim}, image embedding dim: {self.image_embedding_dim}")
+
         self.processor = CLIPProcessor.from_pretrained(
             model_name,
             cache_dir='/tmp/huggingface_cache'
@@ -114,7 +125,8 @@ class KoreanCLIPModel:
             return text_embeddings.cpu().numpy()
         except Exception as e:
             logger.error(f"Error during text encoding: {str(e)}")
-            return np.zeros((len(text), self.model.text_embed_dim))
+            # Return a zero vector with the matching dimensionality
+            return np.zeros((len(text), self.text_embedding_dim))
 
     def encode_image(self, image_source):
         """
@@ -130,12 +142,32 @@ class KoreanCLIPModel:
             # Load the image (URL, file path, PIL image object, or Base64)
             if isinstance(image_source, str):
                 if image_source.startswith('http'):
-                    # Load the image from a URL
-                    response = requests.get(image_source)
-                    image = Image.open(BytesIO(response.content)).convert('RGB')
+                    # Load the image from a URL - timeout added
+                    try:
+                        response = requests.get(image_source, timeout=REQUEST_TIMEOUT)
+                        if response.status_code == 200:
+                            image = Image.open(BytesIO(response.content)).convert('RGB')
+                        else:
+                            logger.warning(f"Error response from image URL: {response.status_code}")
+                            # On error, fall back to a dummy (black) image
+                            image = Image.new('RGB', (224, 224), color='black')
+                    except requests.exceptions.RequestException as e:
+                        logger.error(f"Error while accessing image URL: {str(e)}")
+                        # On error, fall back to a dummy (black) image
+                        image = Image.new('RGB', (224, 224), color='black')
                 else:
                     # Load the image from a local file
-                    image = Image.open(image_source).convert('RGB')
+                    try:
+                        if os.path.exists(image_source):
+                            image = Image.open(image_source).convert('RGB')
+                        else:
+                            logger.warning(f"Image file does not exist: {image_source}")
+                            # If the file is missing, fall back to a dummy image
+                            image = Image.new('RGB', (224, 224), color='black')
+                    except Exception as e:
+                        logger.error(f"Error while loading local image: {str(e)}")
+                        # On error, fall back to a dummy image
+                        image = Image.new('RGB', (224, 224), color='black')
             else:
                 # Already a PIL image object
                 image = image_source.convert('RGB')
@@ -151,29 +183,38 @@ class KoreanCLIPModel:
             return image_embeddings.cpu().numpy()
         except Exception as e:
             logger.error(f"Error during image encoding: {str(e)}")
-            return np.zeros((1, self.model.vision_embed_dim))
+            # Return a zero vector with the matching dimensionality
+            return np.zeros((1, self.image_embedding_dim))
 
-    def calculate_similarity(self, text_embedding, image_embedding=None):
+    def calculate_similarity(self, embedding1, embedding2):
         """
-        Calculate the similarity between text and image embeddings
+        Calculate the similarity between two embeddings
 
         Args:
-            text_embedding (numpy.ndarray): text embedding
-            image_embedding (numpy.ndarray, optional): image embedding (if omitted, only the text is compared)
+            embedding1 (numpy.ndarray): first embedding
+            embedding2 (numpy.ndarray): second embedding
 
         Returns:
             float: similarity score (between 0 and 1)
         """
-        if image_embedding is None:
-            # Text-to-text similarity (cosine similarity)
-            similarity = np.dot(text_embedding, text_embedding.T)[0, 0]
-        else:
-            # Text-to-image similarity (cosine similarity)
-            similarity = np.dot(text_embedding, image_embedding.T)[0, 0]
+        try:
+            # Check the embedding shapes and log them
+            logger.debug(f"embedding1 shape: {embedding1.shape}, embedding2 shape: {embedding2.shape}")
 
-        # Normalize the similarity to the 0~1 range
-        similarity = (similarity + 1) / 2
-        return float(similarity)
+            # Handle differing dimensions - return a default value if they do not match
+            if embedding1.shape[1] != embedding2.shape[1]:
+                logger.warning(f"Embedding dimension mismatch: {embedding1.shape} vs {embedding2.shape}")
+                return 0.5
+
+            # Cosine similarity
+            similarity = np.dot(embedding1, embedding2.T)[0, 0]
+
+            # Normalize the similarity to the 0~1 range
+            similarity = (similarity + 1) / 2
+            return float(similarity)
+        except Exception as e:
+            logger.error(f"Error while calculating similarity: {str(e)}")
+            return 0.5  # return the midpoint value on error
 
     def encode_batch_texts(self, texts):
         """