leedoming committed on
Commit
badc60b
•
1 Parent(s): 0786d34

Update app.py

Files changed (1)
  1. app.py +433 -443
app.py CHANGED
@@ -1,444 +1,434 @@
-import streamlit as st
-import open_clip
-import torch
-from PIL import Image
-import numpy as np
-from transformers import pipeline
-import chromadb
-import logging
-import io
-import requests
-from concurrent.futures import ThreadPoolExecutor
-
-# Logging setup
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# Initialize session state
-if 'image' not in st.session_state:
-    st.session_state.image = None
-if 'detected_items' not in st.session_state:
-    st.session_state.detected_items = None
-if 'selected_item_index' not in st.session_state:
-    st.session_state.selected_item_index = None
-if 'upload_state' not in st.session_state:
-    st.session_state.upload_state = 'initial'
-if 'search_clicked' not in st.session_state:
-    st.session_state.search_clicked = False
-
-# Load models
-@st.cache_resource
-def load_models():
-    try:
-        # CLIP model
-        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
-
-        # Segmentation model
-        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
-
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model.to(device)
-
-        return model, preprocess_val, segmenter, device
-    except Exception as e:
-        logger.error(f"Error loading models: {e}")
-        raise
-
-# Load the models
-clip_model, preprocess_val, segmenter, device = load_models()
-
-# ChromaDB setup
-client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
-collection = client.get_collection(name="clothes")
-
-def process_segmentation(image):
-    """Segmentation processing"""
-    try:
-        # Process the pipeline output directly
-        output = segmenter(image)
-
-        if not output:
-            logger.warning("No segments found in image")
-            return None
-
-        # Compute each segment's mask size
-        segment_sizes = [np.sum(seg['mask']) for seg in output]
-
-        if not segment_sizes:
-            return None
-
-        # Select the largest segment
-        largest_idx = np.argmax(segment_sizes)
-        mask = output[largest_idx]['mask']
-
-        # Convert the mask if it is not a numpy array
-        if not isinstance(mask, np.ndarray):
-            mask = np.array(mask)
-
-        # Use the first channel if the mask is not 2D
-        if len(mask.shape) > 2:
-            mask = mask[:, :, 0]
-
-        # Convert a bool mask to float
-        mask = mask.astype(float)
-
-        logger.info(f"Successfully created mask with shape {mask.shape}")
-        return mask
-
-    except Exception as e:
-        logger.error(f"Segmentation error: {str(e)}")
-        import traceback
-        logger.error(traceback.format_exc())
-        return None
-
-def download_and_process_image(image_url, metadata_id):
-    """Download image from URL and apply segmentation"""
-    try:
-        response = requests.get(image_url, timeout=10)  # added timeout
-        if response.status_code != 200:
-            logger.error(f"Failed to download image {metadata_id}: HTTP {response.status_code}")
-            return None
-
-        image = Image.open(io.BytesIO(response.content)).convert('RGB')
-        logger.info(f"Successfully downloaded image {metadata_id}")
-
-        mask = process_segmentation(image)
-        if mask is not None:
-            features = extract_features(image, mask)
-            logger.info(f"Successfully extracted features for image {metadata_id}")
-            return features
-
-        logger.warning(f"No valid mask found for image {metadata_id}")
-        return None
-
-    except Exception as e:
-        logger.error(f"Error processing image {metadata_id}: {str(e)}")
-        import traceback
-        logger.error(traceback.format_exc())
-        return None
-
-def update_db_with_segmentation():
-    """Apply segmentation to every image in the DB and update its features"""
-    try:
-        logger.info("Starting database update with segmentation")
-
-        # Create a new collection
-        try:
-            client.delete_collection("clothes_segmented")
-            logger.info("Deleted existing segmented collection")
-        except:
-            logger.info("No existing segmented collection to delete")
-
-        new_collection = client.create_collection(
-            name="clothes_segmented",
-            metadata={"description": "Clothes collection with segmentation-based features"}
-        )
-        logger.info("Created new segmented collection")
-
-        # Fetch only the metadata from the existing collection
-        try:
-            all_items = collection.get(include=['metadatas'])
-            total_items = len(all_items['metadatas'])
-            logger.info(f"Found {total_items} items in database")
-        except Exception as e:
-            logger.error(f"Error getting items from collection: {str(e)}")
-            # Initialize to an empty list on error
-            all_items = {'metadatas': []}
-            total_items = 0
-
-        # Progress bar for progress display
-        progress_bar = st.progress(0)
-        status_text = st.empty()
-
-        successful_updates = 0
-        failed_updates = 0
-
-        with ThreadPoolExecutor(max_workers=4) as executor:
-            futures = []
-            # Process only items that have an image URL
-            valid_items = [m for m in all_items['metadatas'] if 'image_url' in m]
-
-            for metadata in valid_items:
-                future = executor.submit(
-                    download_and_process_image,
-                    metadata['image_url'],
-                    metadata.get('id', 'unknown')
-                )
-                futures.append((metadata, future))
-
-            # Process results and store them in the new DB
-            for idx, (metadata, future) in enumerate(futures):
-                try:
-                    new_features = future.result()
-                    if new_features is not None:
-                        item_id = metadata.get('id', str(hash(metadata['image_url'])))
-                        try:
-                            new_collection.add(
-                                embeddings=[new_features.tolist()],
-                                metadatas=[metadata],
-                                ids=[item_id]
-                            )
-                            successful_updates += 1
-                            logger.info(f"Successfully added item {item_id}")
-                        except Exception as e:
-                            logger.error(f"Error adding item to new collection: {str(e)}")
-                            failed_updates += 1
-                    else:
-                        failed_updates += 1
-
-                    # Update progress
-                    progress = (idx + 1) / len(futures)
-                    progress_bar.progress(progress)
-                    status_text.text(f"Processing: {idx + 1}/{len(futures)} items. Success: {successful_updates}, Failed: {failed_updates}")
-
-                except Exception as e:
-                    logger.error(f"Error processing item: {str(e)}")
-                    failed_updates += 1
-                    continue
-
-        # Show the final result
-        status_text.text(f"Update completed. Successfully processed: {successful_updates}, Failed: {failed_updates}")
-        logger.info(f"Database update completed. Successful: {successful_updates}, Failed: {failed_updates}")
-
-        # Check whether any items were processed successfully
-        if successful_updates > 0:
-            return True
-        else:
-            logger.error("No items were successfully processed")
-            return False
-
-    except Exception as e:
-        logger.error(f"Database update error: {str(e)}")
-        import traceback
-        logger.error(traceback.format_exc())
-        return False
-
-def extract_features(image, mask=None):
-    """Extract CLIP features with segmentation mask"""
-    try:
-        if mask is not None:
-            img_array = np.array(image)
-            mask = np.expand_dims(mask, axis=2)
-            masked_img = img_array * mask
-            masked_img[mask[:,:,0] == 0] = 255  # set the background to white
-            image = Image.fromarray(masked_img.astype(np.uint8))
-
-        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
-        with torch.no_grad():
-            features = clip_model.encode_image(image_tensor)
-            features /= features.norm(dim=-1, keepdim=True)
-        return features.cpu().numpy().flatten()
-    except Exception as e:
-        logger.error(f"Feature extraction error: {e}")
-        raise
-
-def search_similar_items(features, top_k=10):
-    """Search similar items using segmentation-based features"""
-    try:
-        # Check whether a segmented collection exists
-        try:
-            search_collection = client.get_collection("clothes_segmented")
-            logger.info("Using segmented collection for search")
-        except:
-            # If not, use the original collection
-            search_collection = collection
-            logger.info("Using original collection for search")
-
-        results = search_collection.query(
-            query_embeddings=[features.tolist()],
-            n_results=top_k,
-            include=['metadatas', 'distances']
-        )
-
-        if not results or not results['metadatas'] or not results['distances']:
-            logger.warning("No results returned from ChromaDB")
-            return []
-
-        similar_items = []
-        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
-            try:
-                similarity_score = 1 / (1 + float(distance))
-                item_data = metadata.copy()
-                item_data['similarity_score'] = similarity_score
-                similar_items.append(item_data)
-            except Exception as e:
-                logger.error(f"Error processing search result: {str(e)}")
-                continue
-
-        similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
-        return similar_items
-    except Exception as e:
-        logger.error(f"Search error: {str(e)}")
-        return []
-
-def show_similar_items(similar_items):
-    """Display similar items in a structured format with similarity scores"""
-    if not similar_items:
-        st.warning("No similar items found.")
-        return
-
-    st.subheader("Similar Items:")
-
-    # Display results in two columns
-    items_per_row = 2
-    for i in range(0, len(similar_items), items_per_row):
-        cols = st.columns(items_per_row)
-        for j, col in enumerate(cols):
-            if i + j < len(similar_items):
-                item = similar_items[i + j]
-                with col:
-                    try:
-                        if 'image_url' in item:
-                            st.image(item['image_url'], use_column_width=True)
-
-                        # Show the similarity score as a percentage
-                        similarity_percent = item['similarity_score'] * 100
-                        st.markdown(f"**Similarity: {similarity_percent:.1f}%**")
-
-                        st.write(f"Brand: {item.get('brand', 'Unknown')}")
-                        name = item.get('name', 'Unknown')
-                        if len(name) > 50:  # truncate long names
-                            name = name[:47] + "..."
-                        st.write(f"Name: {name}")
-
-                        # Show price information
-                        price = item.get('price', 0)
-                        if isinstance(price, (int, float)):
-                            st.write(f"Price: {price:,} KRW")
-                        else:
-                            st.write(f"Price: {price}")
-
-                        # If discount information exists
-                        if 'discount' in item and item['discount']:
-                            st.write(f"Discount: {item['discount']}%")
-                            if 'original_price' in item:
-                                st.write(f"Original: {item['original_price']:,} KRW")
-
-                        st.divider()  # add a divider
-
-                    except Exception as e:
-                        logger.error(f"Error displaying item: {e}")
-                        st.error("Error displaying this item")
-
-def process_search(image, mask, num_results):
-    """Process the similar-item search"""
-    try:
-        with st.spinner("Extracting features..."):
-            features = extract_features(image, mask)
-
-        with st.spinner("Finding similar items..."):
-            similar_items = search_similar_items(features, top_k=num_results)
-
-        return similar_items
-    except Exception as e:
-        logger.error(f"Search processing error: {e}")
-        return None
-
-# Callback functions
-def handle_file_upload():
-    if st.session_state.uploaded_file is not None:
-        image = Image.open(st.session_state.uploaded_file).convert('RGB')
-        st.session_state.image = image
-        st.session_state.upload_state = 'image_uploaded'
-        st.rerun()
-
-def handle_detection():
-    if st.session_state.image is not None:
-        detected_items = process_segmentation(st.session_state.image)
-        st.session_state.detected_items = detected_items
-        st.session_state.upload_state = 'items_detected'
-        st.rerun()
-
-def handle_search():
-    st.session_state.search_clicked = True
-
-def admin_interface():
-    st.title("Admin Interface - DB Update")
-    if st.button("Update DB with Segmentation"):
-        with st.spinner("Updating database with segmentation... This may take a while..."):
-            success = update_db_with_segmentation()
-            if success:
-                st.success("Database successfully updated with segmentation-based features!")
-            else:
-                st.error("Failed to update database. Please check the logs.")
-
-def main():
-    st.title("Fashion Search App")
-
-    # Admin controls in sidebar
-    st.sidebar.title("Admin Controls")
-    if st.sidebar.checkbox("Show Admin Interface"):
-        admin_interface()
-        st.divider()
-
-    # File uploader (shown only while upload_state is 'initial')
-    if st.session_state.upload_state == 'initial':
-        uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
-                                         key='uploaded_file', on_change=handle_file_upload)
-
-    # State after an image has been uploaded
-    if st.session_state.image is not None:
-        st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)
-
-        if st.session_state.detected_items is None:
-            if st.button("Detect Items", key='detect_button', on_click=handle_detection):
-                pass
-
-        # Display detected items
-        if st.session_state.detected_items:
-            # Show detected items in two columns
-            cols = st.columns(2)
-            for idx, item in enumerate(st.session_state.detected_items):
-                with cols[idx % 2]:
-                    mask = item['mask']
-                    masked_img = np.array(st.session_state.image) * np.expand_dims(mask, axis=2)
-                    st.image(masked_img.astype(np.uint8), caption=f"Detected {item['label']}")
-                    st.write(f"Item {idx + 1}: {item['label']}")
-                    st.write(f"Confidence: {item['score']*100:.1f}%")
-
-            # Item selection
-            selected_idx = st.selectbox(
-                "Select item to search:",
-                range(len(st.session_state.detected_items)),
-                format_func=lambda i: f"{st.session_state.detected_items[i]['label']}",
-                key='item_selector'
-            )
-
-            # Search controls
-            search_col1, search_col2 = st.columns([1, 2])
-            with search_col1:
-                search_clicked = st.button("Search Similar Items",
-                                           key='search_button',
-                                           type="primary")
-            with search_col2:
-                num_results = st.slider("Number of results:",
-                                        min_value=1,
-                                        max_value=20,
-                                        value=5,
-                                        key='num_results')
-
-            # Handle search results
-            if search_clicked or st.session_state.get('search_clicked', False):
-                st.session_state.search_clicked = True
-                selected_mask = st.session_state.detected_items[selected_idx]['mask']
-
-                # Store search results in session state
-                if 'search_results' not in st.session_state:
-                    similar_items = process_search(st.session_state.image, selected_mask, num_results)
-                    st.session_state.search_results = similar_items
-
-                # Show the stored search results
-                if st.session_state.search_results:
-                    show_similar_items(st.session_state.search_results)
-                else:
-                    st.warning("No similar items found.")
-
-        # New search button
-        if st.button("Start New Search", key='new_search'):
-            # Reset all state
-            for key in list(st.session_state.keys()):
-                del st.session_state[key]
-            st.rerun()
-
-if __name__ == "__main__":
+import streamlit as st
+import open_clip
+import torch
+from PIL import Image
+import numpy as np
+from transformers import pipeline
+import chromadb
+import logging
+import io
+import requests
+from concurrent.futures import ThreadPoolExecutor
+
+# Logging setup
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Initialize session state
+if 'image' not in st.session_state:
+    st.session_state.image = None
+if 'detected_items' not in st.session_state:
+    st.session_state.detected_items = None
+if 'selected_item_index' not in st.session_state:
+    st.session_state.selected_item_index = None
+if 'upload_state' not in st.session_state:
+    st.session_state.upload_state = 'initial'
+if 'search_clicked' not in st.session_state:
+    st.session_state.search_clicked = False
+
+# Load models
+@st.cache_resource
+def load_models():
+    try:
+        # CLIP model
+        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
+
+        # Segmentation model
+        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model.to(device)
+
+        return model, preprocess_val, segmenter, device
+    except Exception as e:
+        logger.error(f"Error loading models: {e}")
+        raise
+
+# Load the models
+clip_model, preprocess_val, segmenter, device = load_models()
+
+# ChromaDB setup
+client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
+collection = client.get_collection(name="clothes")
+
+def process_segmentation(image):
+    """Segmentation processing"""
+    try:
+        # Process the pipeline output directly
+        output = segmenter(image)
+
+        if not output or len(output) == 0:
+            logger.warning("No segments found in image")
+            return []
+
+        processed_items = []
+        for segment in output:
+            mask = segment['mask']
+            # Convert the mask if it is not a numpy array
+            if not isinstance(mask, np.ndarray):
+                mask = np.array(mask)
+
+            # Use the first channel if the mask is not 2D
+            if len(mask.shape) > 2:
+                mask = mask[:, :, 0]
+
+            # Convert a bool mask to float
+            mask = mask.astype(float)
+
+            processed_items.append({
+                'mask': mask,
+                'label': segment.get('label', 'Unknown'),
+                'score': segment.get('score', 0.0)
+            })
+
+        logger.info(f"Successfully processed {len(processed_items)} segments")
+        return processed_items
+
+    except Exception as e:
+        logger.error(f"Segmentation error: {str(e)}")
+        import traceback
+        logger.error(traceback.format_exc())
+        return []
+
+def download_and_process_image(image_url, metadata_id):
+    """Download image from URL and apply segmentation"""
+    try:
+        response = requests.get(image_url, timeout=10)
+        if response.status_code != 200:
+            logger.error(f"Failed to download image {metadata_id}: HTTP {response.status_code}")
+            return None
+
+        image = Image.open(io.BytesIO(response.content)).convert('RGB')
+        logger.info(f"Successfully downloaded image {metadata_id}")
+
+        processed_items = process_segmentation(image)
+        if processed_items and len(processed_items) > 0:
+            # Use the largest segment's mask
+            largest_mask = max(processed_items, key=lambda x: np.sum(x['mask']))['mask']
+            features = extract_features(image, largest_mask)
+            logger.info(f"Successfully extracted features for image {metadata_id}")
+            return features
+
+        logger.warning(f"No valid mask found for image {metadata_id}")
+        return None
+
+    except Exception as e:
+        logger.error(f"Error processing image {metadata_id}: {str(e)}")
+        import traceback
+        logger.error(traceback.format_exc())
+        return None
+
+def update_db_with_segmentation():
+    """Apply segmentation to every image in the DB and update its features"""
+    try:
+        logger.info("Starting database update with segmentation")
+
+        # Create a new collection
+        try:
+            client.delete_collection("clothes_segmented")
+            logger.info("Deleted existing segmented collection")
+        except:
+            logger.info("No existing segmented collection to delete")
+
+        new_collection = client.create_collection(
+            name="clothes_segmented",
+            metadata={"description": "Clothes collection with segmentation-based features"}
+        )
+        logger.info("Created new segmented collection")
+
+        # Fetch only the metadata from the existing collection
+        try:
+            all_items = collection.get(include=['metadatas'])
+            total_items = len(all_items['metadatas'])
+            logger.info(f"Found {total_items} items in database")
+        except Exception as e:
+            logger.error(f"Error getting items from collection: {str(e)}")
+            # Initialize to an empty list on error
+            all_items = {'metadatas': []}
+            total_items = 0
+
+        # Progress bar for progress display
+        progress_bar = st.progress(0)
+        status_text = st.empty()
+
+        successful_updates = 0
+        failed_updates = 0
+
+        with ThreadPoolExecutor(max_workers=4) as executor:
+            futures = []
+            # Process only items that have an image URL
+            valid_items = [m for m in all_items['metadatas'] if 'image_url' in m]
+
+            for metadata in valid_items:
+                future = executor.submit(
+                    download_and_process_image,
+                    metadata['image_url'],
+                    metadata.get('id', 'unknown')
+                )
+                futures.append((metadata, future))
+
+            # Process results and store them in the new DB
+            for idx, (metadata, future) in enumerate(futures):
+                try:
+                    new_features = future.result()
+                    if new_features is not None:
+                        item_id = metadata.get('id', str(hash(metadata['image_url'])))
+                        try:
+                            new_collection.add(
+                                embeddings=[new_features.tolist()],
+                                metadatas=[metadata],
+                                ids=[item_id]
+                            )
+                            successful_updates += 1
+                            logger.info(f"Successfully added item {item_id}")
+                        except Exception as e:
+                            logger.error(f"Error adding item to new collection: {str(e)}")
+                            failed_updates += 1
+                    else:
+                        failed_updates += 1
+
+                    # Update progress
+                    progress = (idx + 1) / len(futures)
+                    progress_bar.progress(progress)
+                    status_text.text(f"Processing: {idx + 1}/{len(futures)} items. Success: {successful_updates}, Failed: {failed_updates}")
+
+                except Exception as e:
+                    logger.error(f"Error processing item: {str(e)}")
+                    failed_updates += 1
+                    continue
+
+        # Show the final result
+        status_text.text(f"Update completed. Successfully processed: {successful_updates}, Failed: {failed_updates}")
+        logger.info(f"Database update completed. Successful: {successful_updates}, Failed: {failed_updates}")
+
+        # Check whether any items were processed successfully
+        if successful_updates > 0:
+            return True
+        else:
+            logger.error("No items were successfully processed")
+            return False
+
+    except Exception as e:
+        logger.error(f"Database update error: {str(e)}")
+        import traceback
+        logger.error(traceback.format_exc())
+        return False
+
+def extract_features(image, mask=None):
+    """Extract CLIP features with segmentation mask"""
+    try:
+        if mask is not None:
+            img_array = np.array(image)
+            mask = np.expand_dims(mask, axis=2)
+            masked_img = img_array * mask
+            masked_img[mask[:,:,0] == 0] = 255  # set the background to white
+            image = Image.fromarray(masked_img.astype(np.uint8))
+
+        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
+        with torch.no_grad():
+            features = clip_model.encode_image(image_tensor)
+            features /= features.norm(dim=-1, keepdim=True)
+        return features.cpu().numpy().flatten()
+    except Exception as e:
+        logger.error(f"Feature extraction error: {e}")
+        raise
+
+def search_similar_items(features, top_k=10):
+    """Search similar items using segmentation-based features"""
+    try:
+        # Check whether a segmented collection exists
+        try:
+            search_collection = client.get_collection("clothes_segmented")
+            logger.info("Using segmented collection for search")
+        except:
+            # If not, use the original collection
+            search_collection = collection
+            logger.info("Using original collection for search")
+
+        results = search_collection.query(
+            query_embeddings=[features.tolist()],
+            n_results=top_k,
+            include=['metadatas', 'distances']
+        )
+
+        if not results or not results['metadatas'] or not results['distances']:
+            logger.warning("No results returned from ChromaDB")
+            return []
+
+        similar_items = []
+        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
+            try:
+                similarity_score = 1 / (1 + float(distance))
+                item_data = metadata.copy()
+                item_data['similarity_score'] = similarity_score
+                similar_items.append(item_data)
+            except Exception as e:
+                logger.error(f"Error processing search result: {str(e)}")
+                continue
+
+        similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
+        return similar_items
+    except Exception as e:
+        logger.error(f"Search error: {str(e)}")
+        return []
+
+def show_similar_items(similar_items):
+    """Display similar items in a structured format with similarity scores"""
+    if not similar_items:
+        st.warning("No similar items found.")
+        return
+
+    st.subheader("Similar Items:")
+
+    # Display results in two columns
+    items_per_row = 2
+    for i in range(0, len(similar_items), items_per_row):
+        cols = st.columns(items_per_row)
+        for j, col in enumerate(cols):
+            if i + j < len(similar_items):
+                item = similar_items[i + j]
+                with col:
+                    try:
+                        if 'image_url' in item:
+                            st.image(item['image_url'], use_column_width=True)
+
+                        # Show the similarity score as a percentage
+                        similarity_percent = item['similarity_score'] * 100
+                        st.markdown(f"**Similarity: {similarity_percent:.1f}%**")
+
+                        st.write(f"Brand: {item.get('brand', 'Unknown')}")
+                        name = item.get('name', 'Unknown')
+                        if len(name) > 50:  # truncate long names
+                            name = name[:47] + "..."
+                        st.write(f"Name: {name}")
+
+                        # Show price information
+                        price = item.get('price', 0)
+                        if isinstance(price, (int, float)):
+                            st.write(f"Price: {price:,} KRW")
+                        else:
+                            st.write(f"Price: {price}")
+
+                        # If discount information exists
+                        if 'discount' in item and item['discount']:
+                            st.write(f"Discount: {item['discount']}%")
+                            if 'original_price' in item:
+                                st.write(f"Original: {item['original_price']:,} KRW")
+
+                        st.divider()  # add a divider
+
+                    except Exception as e:
+                        logger.error(f"Error displaying item: {e}")
+                        st.error("Error displaying this item")
+
+def process_search(image, mask, num_results):
+    """Process the similar-item search"""
+    try:
+        with st.spinner("Extracting features..."):
+            features = extract_features(image, mask)
+
+        with st.spinner("Finding similar items..."):
+            similar_items = search_similar_items(features, top_k=num_results)
+
+        return similar_items
+    except Exception as e:
+        logger.error(f"Search processing error: {e}")
+        return None
+
+def handle_file_upload():
+    if st.session_state.uploaded_file is not None:
+        image = Image.open(st.session_state.uploaded_file).convert('RGB')
+        st.session_state.image = image
+        st.session_state.upload_state = 'image_uploaded'
+        st.rerun()
+
+def handle_detection():
+    if st.session_state.image is not None:
+        detected_items = process_segmentation(st.session_state.image)
+        st.session_state.detected_items = detected_items
+        st.session_state.upload_state = 'items_detected'
+        st.rerun()
+
+def handle_search():
+    st.session_state.search_clicked = True
+
+def main():
+    st.title("Fashion Search App")
+
+    # Admin controls in sidebar
+    st.sidebar.title("Admin Controls")
+    if st.sidebar.checkbox("Show Admin Interface"):
+        admin_interface()
+        st.divider()
+
+    # File uploader
+    if st.session_state.upload_state == 'initial':
+        uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
+                                         key='uploaded_file', on_change=handle_file_upload)
+
+    # State after an image has been uploaded
+    if st.session_state.image is not None:
+        st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)
+
+        if st.session_state.detected_items is None:
+            if st.button("Detect Items", key='detect_button', on_click=handle_detection):
+                pass
+
+        # Display detected items
+        if st.session_state.detected_items is not None and len(st.session_state.detected_items) > 0:
+            # Show detected items in two columns
+            cols = st.columns(2)
+            for idx, item in enumerate(st.session_state.detected_items):
+                with cols[idx % 2]:
+                    mask = item['mask']
+                    masked_img = np.array(st.session_state.image) * np.expand_dims(mask, axis=2)
+                    st.image(masked_img.astype(np.uint8), caption=f"Detected {item['label']}")
+                    st.write(f"Item {idx + 1}: {item['label']}")
+                    st.write(f"Confidence: {item['score']*100:.1f}%")
+
+            # Item selection
+            selected_idx = st.selectbox(
+                "Select item to search:",
+                range(len(st.session_state.detected_items)),
+                format_func=lambda i: f"{st.session_state.detected_items[i]['label']}",
+                key='item_selector'
+            )
+
+            # Search controls
+            search_col1, search_col2 = st.columns([1, 2])
+            with search_col1:
+                search_clicked = st.button("Search Similar Items",
+                                           key='search_button',
+                                           type="primary")
+            with search_col2:
+                num_results = st.slider("Number of results:",
+                                        min_value=1,
+                                        max_value=20,
+                                        value=5,
+                                        key='num_results')
+
+            # Handle search results
+            if search_clicked or st.session_state.get('search_clicked', False):
+                st.session_state.search_clicked = True
+                selected_mask = st.session_state.detected_items[selected_idx]['mask']
+
+                # Store search results in session state
+                if 'search_results' not in st.session_state:
+                    similar_items = process_search(st.session_state.image, selected_mask, num_results)
+                    st.session_state.search_results = similar_items
+
+                # Show the stored search results
+                if st.session_state.search_results:
+                    show_similar_items(st.session_state.search_results)
+                else:
+                    st.warning("No similar items found.")
+
+        # New search button
+        if st.button("Start New Search", key='new_search'):
+            # Reset all state
+            for key in list(st.session_state.keys()):
+                del st.session_state[key]
+            st.rerun()
+
+if __name__ == "__main__":
     main()