# config.py
import os

import torch
import torchvision.transforms as transforms


class Config:
    """Configuration class for the Jewelry Recommender System."""

    # Model settings
    # NOTE(review): presumably must equal the length of the embeddings stored in
    # the FAISS index and produced by ModelLoader's model -- confirm.
    VECTOR_DIMENSION = 1280
    # Relative paths: resolved against the current working directory.
    INDEX_PATH = "rootdir/trained_models/jewelry_index.idx"
    METADATA_PATH = "rootdir/trained_models/jewelry_metadata.pkl"

    # Hardware settings
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Image processing settings
    # Images are resized to exactly this (height, width), ignoring aspect ratio.
    # NOTE(review): presumably must match the size used when the index was built.
    IMAGE_SIZE = (640, 640)
    # The standard ImageNet per-channel mean/std used by torchvision's pretrained models.
    NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
    NORMALIZATION_STD = [0.229, 0.224, 0.225]

    # Recommendation settings
    DEFAULT_NUM_RECOMMENDATIONS = 5
    # Upper bound enforced by JewelryRecommenderService.get_recommendations.
    MAX_RECOMMENDATIONS = 20

    @classmethod
    def get_image_transform(cls):
        """Return the preprocessing pipeline applied to every query image.

        Steps: apply the EXIF orientation tag, resize to ``IMAGE_SIZE``, convert
        to a tensor, then normalize with ``NORMALIZATION_MEAN``/``NORMALIZATION_STD``.
        """
        from PIL import ImageOps

        return transforms.Compose([
            # Pass the function itself instead of wrapping it in a lambda: same
            # behavior, but the pipeline stays picklable.
            transforms.Lambda(ImageOps.exif_transpose),
            transforms.Resize(cls.IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=cls.NORMALIZATION_MEAN,
                std=cls.NORMALIZATION_STD,
            ),
        ])


# model_loader.py
import os
import pickle
import warnings

import faiss
import torch
import torchvision.models as models


class ModelLoader:
    """Handles loading of the feature extraction model and FAISS index."""

    @staticmethod
    def load_feature_extraction_model():
        """Load a pretrained EfficientNet-B0 with its classification head removed.

        Returns:
            torch.nn.Sequential: Every child of the backbone except the last one,
            in eval mode and moved to ``Config.DEVICE``.
        """
        print("Loading feature extraction model...")
        backbone = models.efficientnet_b0(weights='EfficientNet_B0_Weights.DEFAULT')
        # Drop the last child (the classifier) so the model outputs features.
        model = torch.nn.Sequential(*list(backbone.children())[:-1])
        # eval() on the wrapper as well: a freshly built Sequential starts in
        # training mode, so calling eval() only on the backbone left it inconsistent.
        model.eval()
        return model.to(Config.DEVICE)

    @staticmethod
    def load_index_and_metadata(index_path=None, metadata_path=None):
        """Load the FAISS index and metadata from files.

        Args:
            index_path (str, optional): Path to the FAISS index file. Falsy values
                fall back to ``Config.INDEX_PATH``.
            metadata_path (str, optional): Path to the metadata pickle file. Falsy
                values fall back to ``Config.METADATA_PATH``.

        Returns:
            tuple: ``(index, metadata, success_flag)``. On a missing file or any
            loading error this is ``(None, {}, False)``; a message is printed
            instead of raising.
        """
        index_path = index_path or Config.INDEX_PATH
        metadata_path = metadata_path or Config.METADATA_PATH

        try:
            if not (os.path.exists(index_path) and os.path.exists(metadata_path)):
                print("Index file or metadata file not found.")
                return None, {}, False
            # Silence warnings only while loading, instead of installing a
            # process-wide "ignore" filter as a side effect of this call.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                index = faiss.read_index(index_path)
                # SECURITY: pickle.load can execute arbitrary code embedded in the
                # file; only load metadata files from a trusted source.
                with open(metadata_path, "rb") as f:
                    metadata = pickle.load(f)
        except Exception as e:
            print(f"Error loading index or metadata: {e}")
            return None, {}, False

        print("Index and metadata loaded successfully.")
        return index, metadata, True


# image_processor.py
import io
import os

import numpy as np
import torch
from PIL import Image


class ImageProcessor:
    """Handles processing and feature extraction from images."""

    def __init__(self, model):
        """Initialize with a pre-trained model.

        Args:
            model: The feature-extraction model, e.g. the one returned by
                ``ModelLoader.load_feature_extraction_model``.
        """
        self.model = model
        self.transform = Config.get_image_transform()

    def normalize_image_input(self, image):
        """Normalize different image input types to an RGB PIL Image.

        Args:
            image: A file path (``str`` or ``os.PathLike``), raw bytes
                (``bytes``/``bytearray``), an ``io.BytesIO`` stream, a numpy
                array, or a ``PIL.Image.Image``.

        Returns:
            PIL.Image.Image: The image converted to RGB, or ``None`` if the input
            type is unsupported or decoding fails (the error is printed).
        """
        try:
            if isinstance(image, (str, os.PathLike)):
                # The context manager closes the file; convert() returns a loaded
                # copy that stays valid afterwards.
                with Image.open(image) as img:
                    return img.convert('RGB')
            if isinstance(image, (bytes, bytearray, io.BytesIO)):
                if not isinstance(image, io.BytesIO):
                    image = io.BytesIO(image)
                return Image.open(image).convert('RGB')
            if isinstance(image, np.ndarray):
                # e.g. arrays delivered by a Gradio image component.
                # NOTE(review): astype('uint8') casts without rescaling, so this
                # assumes values already lie in 0-255 -- confirm for float inputs.
                return Image.fromarray(image.astype('uint8')).convert('RGB')
            if isinstance(image, Image.Image):
                return image.convert('RGB')
            raise ValueError(f"Unsupported image type: {type(image)}")
        except Exception as e:
            print(f"Error normalizing image: {str(e)}")
            return None

    def extract_embedding(self, image):
        """Extract a feature embedding from an image.

        Args:
            image: The image to embed; any format accepted by
                ``normalize_image_input``.

        Returns:
            numpy.ndarray: The model output with all size-1 dimensions squeezed
            away, or ``None`` if normalization or extraction failed.
        """
        img = self.normalize_image_input(image)
        if img is None:
            return None
        try:
            # Add a batch dimension of 1 for the model.
            img_tensor = self.transform(img).unsqueeze(0).to(Config.DEVICE)
            with torch.no_grad():
                features = self.model(img_tensor)
            return features.squeeze().cpu().numpy()
        except Exception as e:
            print(f"Error extracting embedding: {str(e)}")
            return None


# recommender.py - Already provided in the artifact above
# NOTE(review): RecommenderEngine lives in recommender.py, which is not part of
# this view; how it treats a None index is not visible here.

# jewelry_recommender.py
import warnings


class JewelryRecommenderService:
    """Main service class for the Jewelry Recommender System."""

    def __init__(self, index_path=None, metadata_path=None):
        """Initialize the jewelry recommender service.

        Args:
            index_path (str, optional): Path to FAISS index.
            metadata_path (str, optional): Path to metadata pickle file.

        Sets ``self.index_loaded`` to whether the index and metadata loaded
        successfully. On failure ``self.index`` is ``None`` and
        ``self.metadata`` is ``{}``.
        """
        # Suppress warnings during start-up only, rather than disabling them
        # for the whole process.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.model = ModelLoader.load_feature_extraction_model()
            self.index, self.metadata, self.index_loaded = (
                ModelLoader.load_index_and_metadata(index_path, metadata_path)
            )
            self.image_processor = ImageProcessor(self.model)
            self.recommender = RecommenderEngine(self.index, self.metadata)

    def get_recommendations(self, image, num_recommendations=None, skip_exact_match=True):
        """Get recommendations for a query image.

        Args:
            image: Query image (any format accepted by ``ImageProcessor``).
            num_recommendations (int, optional): Number of recommendations. Falsy
                values use ``Config.DEFAULT_NUM_RECOMMENDATIONS``; the result is
                clamped to ``1..Config.MAX_RECOMMENDATIONS``.
            skip_exact_match (bool): Whether to skip the first/exact match.

        Returns:
            list: Recommendation results, or ``[]`` if no embedding could be
            extracted from ``image``.
        """
        num_recommendations = num_recommendations or Config.DEFAULT_NUM_RECOMMENDATIONS
        # int() in case a UI slider passes a float; clamp to the configured range.
        num_recommendations = max(1, min(int(num_recommendations), Config.MAX_RECOMMENDATIONS))

        embedding = self.image_processor.extract_embedding(image)
        if embedding is None:
            # Extraction already printed the reason; don't hand None to the engine.
            return []

        return self.recommender.find_similar_items(
            embedding, num_recommendations, skip_exact_match
        )


# NOTE(review): the formatter.py section that starts on the next line is
# incomplete in this copy -- ResultFormatter.format_html continues beyond it and
# its HTML string literals appear to have lost their markup -- so its opening is
# kept verbatim as comment text (as it already was) rather than reconstructed.
# formatter.py class ResultFormatter: """Formats recommendation results for display.""" @staticmethod def format_html(recommendations): """Format recommendations as HTML for the Gradio interface. Args: recommendations (list): List of recommendation dictionaries Returns: str: HTML formatted results """ if not recommendations: return "No recommendations found." result_html = "
Category: {metadata.get('category', 'Unknown')}
" result_html += f"Description: {metadata.get('description', 'No description available')}
" result_html += f"Price: ${metadata.get('price', 'N/A')}
" result_html += f"Similarity Score: {rec['similarity_score']:.4f}
" if 'image_url' in metadata: result_html += f"