File size: 9,264 Bytes
2dba380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import streamlit as st
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import pipeline
import chromadb
import logging

# Logging setup: INFO-level root configuration plus a module-level logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Seed session state: each key is only written when it is not present yet,
# so values set during earlier runs of the script are left alone.
for _state_key, _state_default in (
    ('image', None),
    ('detected_items', None),
    ('selected_item_index', None),
    ('upload_state', 'initial'),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Load models (wrapped in st.cache_resource)
@st.cache_resource
def load_models():
    """Create the embedding model, its preprocessing transform and the segmenter.

    Returns:
        A tuple ``(model, preprocess, segmenter, device)`` where ``model`` is
        the open_clip model already moved to ``device`` and switched to eval
        mode, ``preprocess`` is the validation-time image transform returned
        alongside it, ``segmenter`` is the transformers pipeline for
        ``mattmdjaga/segformer_b2_clothes`` and ``device`` is CUDA when
        available, otherwise CPU.

    Raises:
        Exception: any loading failure is logged with its traceback and
            re-raised unchanged.
    """
    try:
        # Image embedding model
        model, _, preprocess = open_clip.create_model_and_transforms(
            'hf-hub:Marqo/marqo-fashionSigLIP'
        )

        # Clothes segmentation model
        clothes_segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # The model is only used for inference; open_clip hands models back
        # in train mode, so switch off train-time behaviour explicitly.
        model.eval()

        return model, preprocess, clothes_segmenter, device
    except Exception as e:
        logger.exception("Error loading models: %s", e)
        raise

# Load the models once at import time; the functions below use these
# module-level handles.
clip_model, preprocess_val, segmenter, device = load_models()

# ChromaDB setup: persistent client on the local store directory, then look up
# the collection named "clothes" that the search function queries.
client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
collection = client.get_collection(name="clothes")

def process_segmentation(image):
    """Run the clothes segmenter on *image* and describe each segment.

    Args:
        image: the image handed straight to the module-level ``segmenter``.

    Returns:
        A list with one dict per segment:
        ``'label'`` is the segment label reported by the segmenter,
        ``'mask'`` is a ``uint8`` array holding 1 inside the segment and 0
        outside, and ``'score'`` is the fraction (0.0-1.0) of mask pixels
        that belong to the segment.
        An empty list is returned when segmentation fails; the error is
        logged with its traceback.
    """
    try:
        segments = segmenter(image)
        valid_items = []

        for segment in segments:
            # The segmentation pipeline returns masks as 8-bit images where
            # the segment is marked with 255. Reduce them to 0/1 so that
            # (a) multiplying a uint8 image by the mask cannot wrap around and
            # (b) the mean below is a fraction that can be shown as a percent.
            # `> 0` also accepts masks that are already 0/1.
            mask_array = (np.array(segment['mask']) > 0).astype(np.uint8)
            coverage = float(mask_array.mean()) if mask_array.size else 0.0

            valid_items.append({
                'score': coverage,
                'label': segment['label'],
                'mask': mask_array,
            })

        return valid_items
    except Exception as e:
        logger.exception("Segmentation error: %s", e)
        return []

def extract_features(image, mask=None):
    """Return the L2-normalised embedding of *image* as a flat NumPy array.

    Args:
        image: PIL image to embed. When *mask* is given the image is assumed
            to be H x W x C (the app converts uploads to RGB) -- confirm for
            other callers.
        mask: optional H x W array. Any value greater than zero marks a pixel
            to keep, so both 0/1 and 0/255 masks are accepted. Pixels outside
            the mask are painted white before encoding.

    Returns:
        1-D ``numpy.ndarray`` with the normalised image features.

    Raises:
        Exception: any failure is logged with its traceback and re-raised.
    """
    try:
        if mask is not None:
            pixels = np.array(image)
            keep = np.asarray(mask) > 0
            # Select instead of multiplying: a uint8 image times a 0/255
            # uint8 mask wraps around and corrupts the kept pixels.
            isolated = np.where(keep[..., np.newaxis], pixels, 255)
            image = Image.fromarray(isolated.astype(np.uint8))

        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            features = clip_model.encode_image(image_tensor)
            features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy().flatten()
    except Exception as e:
        logger.exception("Feature extraction error: %s", e)
        raise

def search_similar_items(features, top_k=10):
    """Query the collection for the *top_k* nearest items to *features*.

    Each returned metadata dict gets an extra ``'similarity_score'`` entry
    computed as ``1 / (1 + distance)``. Any error is logged and results in an
    empty list.
    """
    try:
        hits = collection.query(
            query_embeddings=[features.tolist()],
            n_results=top_k,
            include=['metadatas', 'distances'],  # distances are needed for scoring
        )

        matches = []
        for meta, dist in zip(hits['metadatas'][0], hits['distances'][0]):
            # Turn the distance into a similarity score and store it on the
            # metadata itself so the display code can read it.
            meta['similarity_score'] = 1 / (1 + dist)
            matches.append(meta)

        return matches
    except Exception as e:
        logger.error(f"Search error: {e}")
        return []

def show_similar_items(similar_items):
    """Render each similar item as an image column plus a details column.

    Args:
        similar_items: iterable of metadata dicts. ``'similarity_score'`` is
            required; ``'image_url'``, ``'brand'``, ``'name'``, ``'price'``,
            ``'discount'`` and ``'original_price'`` are optional.
    """

    def _won(amount):
        # The ',' format spec is only valid for numbers; a missing price
        # (shown as 'Unknown') or a price stored as text is printed as-is
        # instead of raising ValueError.
        if isinstance(amount, (int, float)) and not isinstance(amount, bool):
            return f"{amount:,}원"
        return str(amount)

    st.subheader("Similar Items:")
    for item in similar_items:
        col1, col2 = st.columns([1, 2])
        with col1:
            # Skip the picture rather than fail when the record has no URL.
            image_url = item.get('image_url')
            if image_url:
                st.image(image_url)
        with col2:
            # Show the similarity score as a percentage
            similarity_percent = item['similarity_score'] * 100
            st.write(f"Similarity: {similarity_percent:.1f}%")
            st.write(f"Brand: {item.get('brand', 'Unknown')}")
            st.write(f"Name: {item.get('name', 'Unknown')}")
            st.write(f"Price: {_won(item.get('price', 'Unknown'))}")
            if 'discount' in item:
                st.write(f"Discount: {item['discount']}%")
                if 'original_price' in item:
                    st.write(f"Original Price: {_won(item['original_price'])}")

# Seed session state (including the search flag); existing keys are kept.
for _state_key, _state_default in (
    ('image', None),
    ('detected_items', None),
    ('selected_item_index', None),
    ('upload_state', 'initial'),
    ('search_clicked', False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

def reset_state():
    """Delete every key currently stored in the session state."""
    # Snapshot the keys first so the mapping is not mutated while iterating.
    stale_keys = list(st.session_state.keys())
    for name in stale_keys:
        del st.session_state[name]

# Callback functions
def handle_file_upload():
    """Uploader ``on_change`` callback: store the upload as an RGB image.

    Reads the file held under the ``uploaded_file`` widget key, converts it
    to RGB, saves it in ``st.session_state.image`` and moves
    ``upload_state`` to ``'image_uploaded'``. Does nothing when no file is
    present.
    """
    if st.session_state.uploaded_file is not None:
        image = Image.open(st.session_state.uploaded_file).convert('RGB')
        st.session_state.image = image
        st.session_state.upload_state = 'image_uploaded'
        # No st.rerun() here: Streamlit reruns the script after a widget
        # callback by itself, and calling st.rerun() inside a callback is a
        # no-op that only produces a warning.

def handle_detection():
    """Button callback: segment the stored image and record the result.

    Saves the list from ``process_segmentation`` in
    ``st.session_state.detected_items`` and moves ``upload_state`` to
    ``'items_detected'``. Does nothing when no image is stored.
    """
    if st.session_state.image is not None:
        detected_items = process_segmentation(st.session_state.image)
        st.session_state.detected_items = detected_items
        st.session_state.upload_state = 'items_detected'
        # No st.rerun() here: Streamlit reruns the script after a widget
        # callback by itself, and calling st.rerun() inside a callback is a
        # no-op that only produces a warning.
        
def handle_search():
    """Button callback: record that a similarity search was requested."""
    st.session_state['search_clicked'] = True

def main():
    """Render the page: upload, detect items, pick one, search similar items.

    The flow is driven by ``st.session_state``: the uploader is shown only
    while ``upload_state`` is ``'initial'``, detection results are listed once
    ``detected_items`` is populated, and the search runs on every script run
    after the search button has set ``search_clicked``.
    """
    # NOTE(review): the Korean part of the title was mis-encoded in the file
    # and has been restored as '포어블랙' -- confirm the spelling.
    st.title("포어블랙 fashion demo!!!")

    # File uploader (shown only while upload_state is 'initial')
    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    # An image has been uploaded
    if st.session_state.image is not None:
        st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)

        if st.session_state.detected_items is None:
            st.button("Detect Items", key='detect_button', on_click=handle_detection)

        # Show the detected items
        if st.session_state.detected_items:
            detected_items = st.session_state.detected_items
            # Convert the image once instead of once per detected item.
            pixels = np.array(st.session_state.image)

            # Lay the detected items out in two columns
            cols = st.columns(2)
            for idx, item in enumerate(detected_items):
                with cols[idx % 2]:
                    keep = np.asarray(item['mask']) > 0
                    # Black out everything outside the mask. Selecting avoids
                    # the uint8 wrap-around that multiplying by a 0/255 mask
                    # would cause.
                    preview = np.where(keep[..., np.newaxis], pixels, 0)
                    st.image(preview.astype(np.uint8), caption=f"Detected {item['label']}")
                    st.write(f"Item {idx + 1}: {item['label']}")
                    st.write(f"Confidence: {item['score']*100:.1f}%")

            # Item selection
            selected_idx = st.selectbox(
                "Select item to search:",
                range(len(detected_items)),
                format_func=lambda i: f"{detected_items[i]['label']}",
                key='item_selector'
            )
            st.session_state.selected_item_index = selected_idx

            # Similar item search controls
            col1, col2 = st.columns([1, 2])
            with col1:
                st.button("Search Similar Items",
                          key='search_button',
                          on_click=handle_search,
                          type="primary")  # emphasised button
            with col2:
                num_results = st.slider("Number of results:",
                                        min_value=1,
                                        max_value=20,
                                        value=5,
                                        key='num_results')

            if st.session_state.search_clicked:
                with st.spinner("Searching similar items..."):
                    try:
                        selected_mask = detected_items[selected_idx]['mask']
                        features = extract_features(st.session_state.image, selected_mask)
                        similar_items = search_similar_items(features, top_k=num_results)

                        if similar_items:
                            show_similar_items(similar_items)
                        else:
                            st.warning("No similar items found.")
                    except Exception as e:
                        st.error(f"Error during search: {str(e)}")

    # New search button: wipe the session state and start over
    if st.button("Start New Search  ", key='new_search'):
        reset_state()
        st.rerun()

# Build the page when this file is executed as the main module.
if __name__ == "__main__":
    main()