import chromadb
import logging
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import pipeline
import requests
import io
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import os

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('db_creation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


def load_models():
    """Load CLIP and segmentation models"""
    try:
        logger.info("Loading models...")
        # CLIP model
        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        # Segmentation model
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")
        model.to(device)
        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise


def process_segmentation(image, segmenter):
    """Apply segmentation to image"""
    try:
        segments = segmenter(image)
        if not segments:
            return None
        # Drop the background segment so the garment, not the backdrop, is selected
        segments = [s for s in segments if s.get('label') != 'Background']
        if not segments:
            return None
        # Pick the largest remaining segment
        largest_segment = max(segments, key=lambda s: np.array(s['mask']).sum())
        mask = np.array(largest_segment['mask'])
        return mask
    except Exception as e:
        logger.error(f"Segmentation error: {e}")
        return None


def extract_features(image, mask, model, preprocess_val, device):
    """Extract CLIP features with segmentation mask"""
    try:
        if mask is not None:
            img_array = np.array(image)
            # Pipeline masks are 0/255; binarize so multiplication keeps pixel values
            mask = (mask > 0).astype(np.uint8)
            mask = np.expand_dims(mask, axis=2)
            masked_img = img_array * mask
            masked_img[mask[:, :, 0] == 0] = 255  # paint the background white
            image = Image.fromarray(masked_img.astype(np.uint8))
        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            features = model.encode_image(image_tensor)
            features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy().flatten()
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        return None


def download_and_process_image(url, metadata_id, model, preprocess_val, segmenter, device):
    """Download and process a single image"""
    try:
        response = requests.get(url, timeout=10)
        if response.status_code != 200:
            logger.error(f"Failed to download image {metadata_id}: HTTP {response.status_code}")
            return None
        image = Image.open(io.BytesIO(response.content)).convert('RGB')

        # Apply segmentation
        mask = process_segmentation(image, segmenter)
        if mask is None:
            logger.warning(f"No valid mask found for image {metadata_id}")
            return None

        # Extract features
        features = extract_features(image, mask, model, preprocess_val, device)
        if features is None:
            logger.warning(f"Failed to extract features for image {metadata_id}")
            return None

        return features
    except Exception as e:
        logger.error(f"Error processing image {metadata_id}: {e}")
        return None


def create_segmented_db(source_path, target_path, batch_size=100):
    """Create a new segmented database from an existing one"""
    try:
        logger.info("Loading models...")
        model, preprocess_val, segmenter, device = load_models()

        # Connect to the source DB
        source_client = chromadb.PersistentClient(path=source_path)
        source_collection = source_client.get_collection(name="clothes")

        # Create the target DB
        os.makedirs(target_path, exist_ok=True)
        target_client = chromadb.PersistentClient(path=target_path)
        try:
            target_client.delete_collection("clothes_segmented")
        except Exception:
            pass  # collection may not exist yet
        target_collection = target_client.create_collection(
            name="clothes_segmented",
            metadata={"description": "Clothes collection with segmentation-based features"}
        )

        # Count the items in the source collection
        all_items = source_collection.get(include=['metadatas'])
        total_items = len(all_items['metadatas'])
        logger.info(f"Found {total_items} items in source database")

        # Counters for batch processing
        successful_updates = 0
        failed_updates = 0

        # ThreadPoolExecutor setup; the workers share the one loaded model
        max_workers = min(10, os.cpu_count() or 4)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Process the data in batches
            for batch_start in tqdm(range(0, total_items, batch_size), desc="Processing batches"):
                batch_end = min(batch_start + batch_size, total_items)
                batch_items = all_items['metadatas'][batch_start:batch_end]

                # Submit a future for every image in the batch
                futures = []
                for metadata in batch_items:
                    if 'image_url' in metadata:
                        future = executor.submit(
                            download_and_process_image,
                            metadata['image_url'],
                            metadata.get('id', 'unknown'),
                            model,
                            preprocess_val,
                            segmenter,
                            device
                        )
                        futures.append((metadata, future))

                # Collect the batch results
                batch_embeddings = []
                batch_metadatas = []
                batch_ids = []
                for metadata, future in futures:
                    try:
                        features = future.result()
                        if features is not None:
                            batch_embeddings.append(features.tolist())
                            batch_metadatas.append(metadata)
                            batch_ids.append(metadata.get('id', str(hash(metadata['image_url']))))
                            successful_updates += 1
                        else:
                            failed_updates += 1
                    except Exception as e:
                        logger.error(f"Error processing batch item: {e}")
                        failed_updates += 1
                        continue

                # Persist the batch
                if batch_embeddings:
                    try:
                        target_collection.add(
                            embeddings=batch_embeddings,
                            metadatas=batch_metadatas,
                            ids=batch_ids
                        )
                        logger.info(f"Added batch of {len(batch_embeddings)} items")
                    except Exception as e:
                        logger.error(f"Error adding batch to collection: {e}")
                        failed_updates += len(batch_embeddings)
                        successful_updates -= len(batch_embeddings)

        # Final summary (guard against an empty source collection)
        completion_rate = (successful_updates / total_items * 100) if total_items else 0.0
        logger.info("Database creation completed.")
        logger.info(f"Successfully processed: {successful_updates}")
        logger.info(f"Failed: {failed_updates}")
        logger.info(f"Total completion rate: {completion_rate:.2f}%")

        return True
    except Exception as e:
        logger.error(f"Database creation error: {e}")
        return False
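
# --- Optional sanity check (illustrative sketch, not called by the pipeline) ---
# One way to verify the rebuilt collection: embed a single probe image with the
# same models and masking path, then query for its nearest neighbours. The
# helper name, probe path, and result count below are assumptions, not part of
# the original pipeline.
def verify_segmented_db(db_path, probe_image_path, n_results=3):
    """Query the segmented collection with one local probe image."""
    model, preprocess_val, segmenter, device = load_models()
    image = Image.open(probe_image_path).convert('RGB')
    mask = process_segmentation(image, segmenter)  # may be None; extract_features tolerates that
    features = extract_features(image, mask, model, preprocess_val, device)
    if features is None:
        logger.error("Could not embed probe image")
        return None
    client = chromadb.PersistentClient(path=db_path)
    collection = client.get_collection(name="clothes_segmented")
    results = collection.query(query_embeddings=[features.tolist()], n_results=n_results)
    for item_id, distance in zip(results['ids'][0], results['distances'][0]):
        logger.info(f"match {item_id}: distance={distance:.4f}")
    return results
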
if __name__ == "__main__":
    # Configuration
    SOURCE_DB_PATH = "./clothesDB_11GmarketMusinsa"             # source DB path
    TARGET_DB_PATH = "./clothesDB_11GmarketMusinsa_segmented"   # new DB path
    BATCH_SIZE = 50  # items processed per batch

    # Run DB creation
    success = create_segmented_db(SOURCE_DB_PATH, TARGET_DB_PATH, BATCH_SIZE)
    if success:
        logger.info("Successfully created segmented database!")
    else:
        logger.error("Failed to create segmented database.")