# itda-segment / app_10281200.py
# Uploaded by leedoming ("Upload 14 files", commit 2dba380) - 9.26 kB.
import streamlit as st
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import pipeline
import chromadb
import logging
# Logging: module-level logger plus INFO-level root configuration.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Initialise session state: seed each key with its default only when it is
# absent, so values already stored in the session are kept across reruns.
for _key, _default in (
    ('image', None),
    ('detected_items', None),
    ('selected_item_index', None),
    ('upload_state', 'initial'),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
del _key, _default
# Load models (cached by st.cache_resource so they are built once).
@st.cache_resource
def load_models():
    """Load the fashion CLIP model and the clothes segmentation pipeline.

    Returns:
        tuple: ``(model, preprocess_val, segmenter, device)``:
            ``model`` is the open_clip model, moved to ``device`` and put in
            eval mode; ``preprocess_val`` is its validation-time image
            transform; ``segmenter`` is the Hugging Face segmentation
            pipeline; ``device`` is CUDA when available, otherwise CPU.

    Raises:
        Re-raises any exception from model creation after logging it with
        its traceback.
    """
    try:
        # CLIP-style image encoder (Marqo fashion SigLIP) and its eval transform.
        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        # Clothing segmentation model.
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # open_clip documents that models start in train mode. Switch to eval
        # so inference-only layers (dropout, normalisation, ...) behave
        # deterministically.
        model.eval()
        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.exception("Error loading models: %s", e)
        raise
# λͺ¨λΈ λ‘œλ“œ
clip_model, preprocess_val, segmenter, device = load_models()
# ChromaDB μ„€μ •
client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
collection = client.get_collection(name="clothes")
def process_segmentation(image):
    """Run clothing segmentation and return one entry per detected segment.

    Args:
        image: image passed straight to the segmentation pipeline.

    Returns:
        list[dict]: one dict per segment with
            ``'label'`` - the segment label reported by the pipeline,
            ``'mask'`` - uint8 array, 1 inside the segment and 0 elsewhere,
            ``'score'`` - fraction of the image covered by the mask (0.0-1.0).
        An empty list is returned (and the error logged) if anything fails.
    """
    try:
        segments = segmenter(image)
        valid_items = []
        for segment in segments:
            # Binarise the mask. The pipeline's masks are presumably 0/255
            # images. Keeping them that way would push the coverage score to
            # 0-255 and overflow uint8 when the mask is multiplied into the
            # image downstream, so normalise to 0/1.
            mask_array = (np.asarray(segment['mask']) > 0).astype(np.uint8)
            valid_items.append({
                'score': float(mask_array.mean()),
                'label': segment['label'],
                'mask': mask_array,
            })
        return valid_items
    except Exception as e:
        logger.exception("Segmentation error: %s", e)
        return []
def extract_features(image, mask=None):
    """Encode an image, optionally restricted to a mask, into a CLIP embedding.

    Args:
        image: PIL image (or anything ``np.array`` accepts) to encode.
        mask: optional array-like with the same height/width as ``image``.
            Any non-zero value marks a pixel to keep, so both 0/1 and 0/255
            masks work. Pixels outside the mask are painted white before
            encoding.

    Returns:
        numpy.ndarray: 1-D, L2-normalised image embedding.

    Raises:
        Re-raises any exception from masking or encoding after logging it.
    """
    try:
        if mask is not None:
            img_array = np.array(image)
            # Select pixels instead of multiplying by the mask. A uint8 image
            # times a 0/255 uint8 mask wraps around and corrupts the colours.
            keep = np.asarray(mask) > 0
            if img_array.ndim == 3:
                keep = keep[..., np.newaxis]  # broadcast over colour channels
            masked_img = np.where(keep, img_array, 255)
            image = Image.fromarray(masked_img.astype(np.uint8))
        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            features = clip_model.encode_image(image_tensor)
            features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy().flatten()
    except Exception as e:
        logger.exception("Feature extraction error: %s", e)
        raise
def search_similar_items(features, top_k=10):
    """Query the ChromaDB collection for the items closest to ``features``.

    Args:
        features: 1-D embedding (numpy array or sequence of floats).
        top_k: maximum number of results to request.

    Returns:
        list[dict]: one dict per hit. Each is a copy of the stored metadata,
        so the query result is left untouched, plus ``'similarity_score'`` =
        ``1 / (1 + distance)``, which is in (0, 1] for non-negative
        distances. Returns an empty list (and logs) if the query fails.
    """
    try:
        results = collection.query(
            query_embeddings=[np.asarray(features).tolist()],
            n_results=top_k,
            include=['metadatas', 'distances'],  # distances feed the score
        )
        similar_items = []
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            # Entries stored without metadata may come back as None; treat
            # them as empty rather than failing the whole search.
            item = dict(metadata or {})
            # Map distance (0 = identical) to a similarity score in (0, 1].
            item['similarity_score'] = 1 / (1 + distance)
            similar_items.append(item)
        return similar_items
    except Exception as e:
        logger.exception("Search error: %s", e)
        return []
def _format_won(value):
    """Format a price in won, tolerating missing or non-numeric values.

    Numbers get thousands separators (``29000`` -> ``"29,000원"``). ``None``
    becomes ``"Unknown"``. Anything else, e.g. a price stored as text, is
    shown as-is with the ``원`` suffix.
    """
    if value is None:
        return "Unknown"
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return f"{value:,}원"
    return f"{value}원"


def show_similar_items(similar_items):
    """Display search hits as rows of image + details with similarity scores.

    Args:
        similar_items: list of dicts as produced by ``search_similar_items``.
            Each must carry ``'similarity_score'`` (0-1). ``'image_url'``,
            ``'brand'``, ``'name'``, ``'price'``, ``'discount'`` and
            ``'original_price'`` are shown when present.
    """
    st.subheader("Similar Items:")
    for item in similar_items:
        col1, col2 = st.columns([1, 2])
        with col1:
            image_url = item.get('image_url')
            if image_url:
                st.image(image_url)
            else:
                st.write("No image available")
        with col2:
            # Show the similarity score as a percentage.
            similarity_percent = item['similarity_score'] * 100
            st.write(f"Similarity: {similarity_percent:.1f}%")
            st.write(f"Brand: {item.get('brand', 'Unknown')}")
            st.write(f"Name: {item.get('name', 'Unknown')}")
            # _format_won avoids the ValueError that f"{'Unknown':,}" raised
            # when the price was missing or not a number.
            st.write(f"Price: {_format_won(item.get('price'))}")
            if 'discount' in item:
                st.write(f"Discount: {item['discount']}%")
            if 'original_price' in item:
                st.write(f"Original Price: {_format_won(item['original_price'])}")
# 'image', 'detected_items', 'selected_item_index' and 'upload_state' are
# already seeded by the session-state initialisation near the top of this
# file. Only the search flag still needs a default here.
if 'search_clicked' not in st.session_state:
    st.session_state.search_clicked = False
def reset_state():
    """Remove every key from the Streamlit session state."""
    # Snapshot the keys first so the mapping is not mutated while iterated.
    keys_snapshot = list(st.session_state.keys())
    for name in keys_snapshot:
        del st.session_state[name]
# Widget callback functions
def handle_file_upload():
    """on_change callback for the file uploader: store the uploaded image as RGB.

    Streamlit reruns the script by itself after a widget callback. Calling
    st.rerun() inside a callback is a no-op, so it is not called here.
    """
    uploaded = st.session_state.uploaded_file
    if uploaded is None:
        return
    try:
        image = Image.open(uploaded).convert('RGB')
    except OSError as e:  # PIL signals unreadable/corrupt images with OSError subclasses
        logger.exception("Could not read uploaded image: %s", e)
        st.error("Could not read the uploaded image.")
        return
    st.session_state.image = image
    st.session_state.upload_state = 'image_uploaded'
def handle_detection():
    """on_click callback for "Detect Items": segment the stored image.

    The result (possibly an empty list) is stored in session state. Streamlit
    reruns the script after the callback returns, so no st.rerun() is needed;
    inside a callback it would be a no-op.
    """
    if st.session_state.image is None:
        return
    st.session_state.detected_items = process_segmentation(st.session_state.image)
    st.session_state.upload_state = 'items_detected'
def handle_search():
    """on_click callback for the search button: flag that a search was requested."""
    st.session_state['search_clicked'] = True
def _masked_preview(image, mask):
    """Return ``image`` as a uint8 array with pixels outside ``mask`` set to black.

    Any non-zero mask value counts as "inside", so 0/1 and 0/255 masks both
    work. Selecting pixels, rather than multiplying by the mask, avoids uint8
    overflow for 0/255 masks.
    """
    img_array = np.array(image)
    keep = np.asarray(mask) > 0
    if img_array.ndim == 3:
        keep = keep[..., np.newaxis]  # broadcast over colour channels
    return np.where(keep, img_array, 0).astype(np.uint8)


def main():
    """Render the upload -> detect -> search workflow of the demo page."""
    st.title("ν¬μ–΄λΈ”λž™ fashion demo!!!")

    # File uploader: shown only while upload_state is still 'initial'.
    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    # An image is loaded: show it and offer detection until detection has run.
    if st.session_state.image is not None:
        st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)
        if st.session_state.detected_items is None:
            st.button("Detect Items", key='detect_button', on_click=handle_detection)

    detected_items = st.session_state.detected_items
    if detected_items:
        # Show the detected items in two columns.
        cols = st.columns(2)
        for idx, item in enumerate(detected_items):
            with cols[idx % 2]:
                preview = _masked_preview(st.session_state.image, item['mask'])
                st.image(preview, caption=f"Detected {item['label']}")
                st.write(f"Item {idx + 1}: {item['label']}")
                st.write(f"Confidence: {item['score']*100:.1f}%")

        # Item selection
        selected_idx = st.selectbox(
            "Select item to search:",
            range(len(detected_items)),
            format_func=lambda i: f"{detected_items[i]['label']}",
            key='item_selector',
        )
        st.session_state.selected_item_index = selected_idx

        # Similar-item search controls
        col1, col2 = st.columns([1, 2])
        with col1:
            st.button("Search Similar Items",
                      key='search_button',
                      on_click=handle_search,
                      type="primary")  # highlighted button
        with col2:
            num_results = st.slider("Number of results:",
                                    min_value=1,
                                    max_value=20,
                                    value=5,
                                    key='num_results')

        # search_clicked stays True once set, so results follow later
        # changes to the selection or the slider.
        if st.session_state.search_clicked:
            with st.spinner("Searching similar items..."):
                try:
                    selected_mask = detected_items[selected_idx]['mask']
                    features = extract_features(st.session_state.image, selected_mask)
                    similar_items = search_similar_items(features, top_k=num_results)
                    if similar_items:
                        show_similar_items(similar_items)
                    else:
                        st.warning("No similar items found.")
                except Exception as e:
                    st.error(f"Error during search: {str(e)}")
    elif detected_items is not None:
        # Detection ran but produced nothing; process_segmentation also
        # returns [] on failure. Without this the page showed no next step.
        st.warning("No items were detected in this image.")

    # New-search button: wipe session state and start over.
    if st.button("Start New Search ", key='new_search'):
        reset_state()
        st.rerun()
if __name__ == "__main__":
    # Render the app when the file is executed as a script (e.g. `streamlit run`).
    main()