# Jewelry Recommender System (single-file build).
#
# The sections below correspond to the original modules (config.py,
# model_loader.py, image_processor.py, jewelry_recommender.py, formatter.py,
# input_handlers.py, gradio_app.py, main.py); their names are kept as section
# comments. Imports for all sections are consolidated here.
import base64  # used only by the (disabled) base64 input handler below
import html
import io
import os
import pickle
import threading
import warnings

import faiss
import gradio as gr
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image, ImageOps


# config.py
class Config:
    """Configuration class for the Jewelry Recommender System."""

    # Model settings
    VECTOR_DIMENSION = 1280  # expected embedding width (EfficientNet-B0 features)
    INDEX_PATH = "rootdir/trained_models/jewelry_index.idx"
    METADATA_PATH = "rootdir/trained_models/jewelry_metadata.pkl"

    # Hardware settings
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Image processing settings
    # NOTE(review): presumably these must match the preprocessing used when the
    # FAISS index was built — confirm before changing.
    IMAGE_SIZE = (640, 640)
    NORMALIZATION_MEAN = [0.485, 0.456, 0.406]  # ImageNet channel means
    NORMALIZATION_STD = [0.229, 0.224, 0.225]  # ImageNet channel stds

    # Recommendation settings
    DEFAULT_NUM_RECOMMENDATIONS = 5
    MAX_RECOMMENDATIONS = 20

    # Network settings
    URL_TIMEOUT_SECONDS = 10  # timeout for downloading an image by URL

    @classmethod
    def get_image_transform(cls):
        """Return the image transformation pipeline.

        The pipeline applies EXIF orientation, resizes to ``IMAGE_SIZE``,
        converts to a tensor and normalizes with the configured mean/std.
        """
        return transforms.Compose([
            # Honour camera orientation metadata before resizing.
            transforms.Lambda(ImageOps.exif_transpose),
            transforms.Resize(cls.IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=cls.NORMALIZATION_MEAN,
                std=cls.NORMALIZATION_STD,
            ),
        ])


# model_loader.py
class ModelLoader:
    """Handles loading of the feature extraction model and FAISS index."""

    @staticmethod
    def load_feature_extraction_model():
        """Load EfficientNet-B0 with its classification head removed.

        Returns:
            torch.nn.Module: Model in eval mode on ``Config.DEVICE`` that maps an
            image batch to the features preceding the classifier.
        """
        print("Loading feature extraction model...")
        model = models.efficientnet_b0(weights='EfficientNet_B0_Weights.DEFAULT')
        # Keep every child except the last one (the classification head).
        model = torch.nn.Sequential(*list(model.children())[:-1])
        # Set eval mode on the final wrapper so the whole module tree is in
        # inference mode (no dropout, running BatchNorm statistics).
        model.eval()
        return model.to(Config.DEVICE)

    @staticmethod
    def load_index_and_metadata(index_path=None, metadata_path=None):
        """Load the FAISS index and metadata from files.

        Args:
            index_path (str, optional): Path to the FAISS index file. Defaults
                to ``Config.INDEX_PATH``.
            metadata_path (str, optional): Path to the metadata pickle file.
                Defaults to ``Config.METADATA_PATH``.

        Returns:
            tuple: ``(index, metadata, success_flag)``. On any failure this is
            ``(None, {}, False)``; errors are printed, not raised.
        """
        warnings.filterwarnings("ignore")
        index_path = index_path or Config.INDEX_PATH
        metadata_path = metadata_path or Config.METADATA_PATH

        try:
            if not (os.path.exists(index_path) and os.path.exists(metadata_path)):
                print("Index file or metadata file not found.")
                return None, {}, False
            index = faiss.read_index(index_path)
            # SECURITY: pickle.load can execute arbitrary code. Only load
            # metadata files produced by this project from a trusted location.
            with open(metadata_path, "rb") as f:
                metadata = pickle.load(f)
        except Exception as e:
            print(f"Error loading index or metadata: {e}")
            return None, {}, False

        print("Index and metadata loaded successfully.")
        return index, metadata, True


# image_processor.py
class ImageProcessor:
    """Handles processing and feature extraction from images."""

    def __init__(self, model):
        """Initialize with a pre-trained model.

        Args:
            model: The pre-trained model for feature extraction.
        """
        self.model = model
        self.transform = Config.get_image_transform()

    def normalize_image_input(self, image):
        """Normalize different image input types to an RGB PIL Image.

        Args:
            image: A PIL.Image, file path, ``bytes``/``io.BytesIO`` stream, or
                numpy array (e.g. as delivered by Gradio).

        Returns:
            PIL.Image.Image | None: The RGB image, or ``None`` if the input could
            not be decoded (the error is printed).
        """
        try:
            if isinstance(image, str):
                # File path: close the file once the pixels are decoded.
                with Image.open(image) as img:
                    return img.convert('RGB')
            if isinstance(image, (bytes, io.BytesIO)):
                stream = io.BytesIO(image) if isinstance(image, bytes) else image
                return Image.open(stream).convert('RGB')
            if isinstance(image, np.ndarray):
                return Image.fromarray(image.astype('uint8')).convert('RGB')
            if isinstance(image, Image.Image):
                return image.convert('RGB')
            raise ValueError(f"Unsupported image type: {type(image)}")
        except Exception as e:
            print(f"Error normalizing image: {str(e)}")
            return None

    def extract_embedding(self, image):
        """Extract a feature embedding from an image.

        Args:
            image: The image to embed (any format accepted by
                ``normalize_image_input``).

        Returns:
            numpy.ndarray | None: float32 feature vector, or ``None`` if
            extraction failed (the error is printed).
        """
        try:
            img = self.normalize_image_input(image)
            if img is None:
                return None
            img_tensor = self.transform(img).unsqueeze(0).to(Config.DEVICE)
            with torch.no_grad():
                features = self.model(img_tensor)
            # squeeze() drops the batch dim and any singleton spatial dims;
            # FAISS expects float32 vectors.
            return features.squeeze().cpu().numpy().astype(np.float32, copy=False)
        except Exception as e:
            print(f"Error extracting embedding: {str(e)}")
            return None


# recommender.py - Already provided in the artifact above
# NOTE(review): RecommenderEngine is defined in recommender.py, which is not
# part of this file; it must be defined/imported before the service is built.


# jewelry_recommender.py
class JewelryRecommenderService:
    """Main service class for the Jewelry Recommender System."""

    def __init__(self, index_path=None, metadata_path=None):
        """Initialize the jewelry recommender service.

        Args:
            index_path (str, optional): Path to FAISS index.
            metadata_path (str, optional): Path to metadata pickle file.
        """
        warnings.filterwarnings("ignore")

        # Load the model
        self.model = ModelLoader.load_feature_extraction_model()

        # Load index and metadata; ``is_ready`` records whether loading worked.
        self.index, self.metadata, self.is_ready = ModelLoader.load_index_and_metadata(
            index_path, metadata_path
        )

        # Initialize pipeline components
        self.image_processor = ImageProcessor(self.model)
        self.recommender = RecommenderEngine(self.index, self.metadata)

    def get_recommendations(self, image, num_recommendations=None, skip_exact_match=True):
        """Get recommendations for a query image.

        Args:
            image: Query image (various formats).
            num_recommendations (int, optional): Number of recommendations;
                falsy values fall back to ``Config.DEFAULT_NUM_RECOMMENDATIONS``.
                Floats (e.g. from a UI slider) are truncated to int and the
                result is clamped to ``[1, Config.MAX_RECOMMENDATIONS]``.
            skip_exact_match (bool): Whether to skip the first/exact match.

        Returns:
            list: Recommendation results; empty if the index is not loaded or
            the image could not be embedded.
        """
        requested = int(num_recommendations or Config.DEFAULT_NUM_RECOMMENDATIONS)
        # Clamp to the configured bounds.
        requested = max(1, min(requested, Config.MAX_RECOMMENDATIONS))

        if self.index is None:
            print("Index is not loaded; cannot compute recommendations.")
            return []

        # Extract embedding from the image
        embedding = self.image_processor.extract_embedding(image)
        if embedding is None:
            return []

        # Get similar items based on the embedding
        return self.recommender.find_similar_items(
            embedding, requested, skip_exact_match
        )


# formatter.py
class ResultFormatter:
    """Formats recommendation results for display."""

    @staticmethod
    def format_html(recommendations):
        """Format recommendations as HTML for the Gradio interface.

        All metadata values are HTML-escaped before being embedded.

        Args:
            recommendations (list): List of recommendation dictionaries, each
                with ``"metadata"`` (dict) and ``"similarity_score"`` keys.

        Returns:
            str: HTML formatted results.
        """
        if not recommendations:
            return "No recommendations found."

        def esc(value):
            # Escape for both element text and quoted attribute values.
            return html.escape(str(value), quote=True)

        parts = ["<div class='recommendations'>", "<h3>Recommended Jewelry Items:</h3>"]
        for i, rec in enumerate(recommendations, 1):
            metadata = rec["metadata"]
            name = esc(metadata.get('name', 'Unknown'))
            category = esc(metadata.get('category', 'Unknown'))
            description = esc(metadata.get('description', 'No description available'))
            price = metadata.get('price')
            price_text = "N/A" if price is None else esc(f"${price}")
            image_url = metadata.get('image_url')

            parts.append("<div class='recommendation' style='margin-bottom: 20px;'>")
            parts.append(f"<h4>#{i}: {name}</h4>")
            parts.append(f"<p><strong>Category:</strong> {category}</p>")
            parts.append(f"<p><strong>Description:</strong> {description}</p>")
            parts.append(f"<p><strong>Price:</strong> {price_text}</p>")
            parts.append(f"<p><strong>Similarity Score:</strong> {rec['similarity_score']:.4f}</p>")
            if image_url:
                parts.append(
                    f'<img src="{esc(image_url)}" alt="{name}" style="max-width: 200px;">'
                )
            parts.append("</div>")
        parts.append("</div>")
        return "".join(parts)

    @staticmethod
    def format_json(recommendations):
        """Format recommendations as JSON-serializable dicts.

        Args:
            recommendations (list): List of recommendation dictionaries.

        Returns:
            list: Clean JSON-serializable results.
        """
        if not recommendations:
            return []

        results = []
        for rec in recommendations:
            metadata = rec["metadata"]
            results.append({
                "item": metadata.get("name", "Unknown"),
                "category": metadata.get("category", "Unknown"),
                "description": metadata.get("description", "No description"),
                "price": metadata.get("price", "N/A"),
                # float() so numpy scalars become plain, JSON-serializable floats.
                "similarity_score": round(float(rec["similarity_score"]), 4),
                "image_url": metadata.get("image_url", None),
            })
        return results


# input_handlers.py
class InputHandlers:
    """Handles different types of image inputs for recommendation."""

    # One shared service per process: building it loads the CNN and the FAISS
    # index, which is far too expensive to repeat on every request.
    _service = None
    _service_lock = threading.Lock()  # Gradio may run handlers concurrently

    @classmethod
    def _get_service(cls):
        """Return the shared service, rebuilding it while the index is not loaded."""
        with cls._service_lock:
            if cls._service is None or cls._service.index is None:
                cls._service = JewelryRecommenderService()
            return cls._service

    @staticmethod
    def process_image(image, num_recommendations=5, skip_exact_match=True):
        """Process direct image input.

        Args:
            image: The image (PIL, numpy array, etc.).
            num_recommendations (int): Number of recommendations.
            skip_exact_match (bool): Whether to skip the first/exact match.

        Returns:
            str: HTML formatted results (or a short message if no image).
        """
        if image is None:
            return "Please provide an image."
        recommender = InputHandlers._get_service()
        recommendations = recommender.get_recommendations(
            image, num_recommendations, skip_exact_match
        )
        return ResultFormatter.format_html(recommendations)

    @staticmethod
    def process_url(url, num_recommendations=5, skip_exact_match=True):
        """Process an image downloaded from a URL.

        NOTE(review): this fetches arbitrary user-supplied URLs server-side;
        restrict hosts/schemes if the app is exposed publicly.

        Args:
            url (str): URL to the image.
            num_recommendations (int): Number of recommendations.
            skip_exact_match (bool): Whether to skip the first/exact match.

        Returns:
            str: HTML formatted results, or an (escaped) error message.
        """
        try:
            import requests

            response = requests.get(url, timeout=Config.URL_TIMEOUT_SECONDS)
            response.raise_for_status()  # don't try to decode an HTTP error page
            image = Image.open(io.BytesIO(response.content))
            return InputHandlers.process_image(image, num_recommendations, skip_exact_match)
        except Exception as e:
            # The message is rendered as HTML and may echo the user's URL.
            return f"Error processing URL: {html.escape(str(e))}"

    # Base64 input handler is disabled. To re-enable, restore as a method:
    #
    # @staticmethod
    # def process_base64(base64_string, num_recommendations=5, skip_exact_match=True):
    #     """Process a base64-encoded image (optionally a data URL)."""
    #     try:
    #         # Remove data URL prefix if present
    #         if ',' in base64_string:
    #             base64_string = base64_string.split(',', 1)[1]
    #         image_bytes = base64.b64decode(base64_string)
    #         image = Image.open(io.BytesIO(image_bytes))
    #         return InputHandlers.process_image(image, num_recommendations, skip_exact_match)
    #     except Exception as e:
    #         return f"Error processing base64 image: {str(e)}"


# gradio_app.py
def _recommendation_controls():
    """Create the shared slider + checkbox controls inside the current layout."""
    num_recs = gr.Slider(
        minimum=1,
        maximum=Config.MAX_RECOMMENDATIONS,
        value=Config.DEFAULT_NUM_RECOMMENDATIONS,
        step=1,
        label="Number of Recommendations",
    )
    skip_exact = gr.Checkbox(value=True, label="Skip Exact Match")
    return num_recs, skip_exact


def create_gradio_interface():
    """Create and configure the Gradio web interface.

    Returns:
        gradio.Blocks: The configured Gradio interface.
    """
    with gr.Blocks(title="Jewelry Recommender") as demo:
        gr.Markdown("# Jewelry Recommendation System")
        gr.Markdown("Upload an image of jewelry to get similar recommendations.")

        with gr.Tab("Upload Image"):
            with gr.Row():
                image_input = gr.Image(type="pil", label="Upload Jewelry Image")
                num_recs_slider, skip_exact = _recommendation_controls()
            submit_btn = gr.Button("Get Recommendations")
            output_html = gr.HTML(label="Recommendations")
            submit_btn.click(
                InputHandlers.process_image,
                inputs=[image_input, num_recs_slider, skip_exact],
                outputs=output_html,
            )

        with gr.Tab("Image URL"):
            with gr.Row():
                url_input = gr.Textbox(label="Enter Image URL")
                url_num_recs, url_skip_exact = _recommendation_controls()
            url_btn = gr.Button("Get Recommendations from URL")
            url_output = gr.HTML(label="Recommendations")
            url_btn.click(
                InputHandlers.process_url,
                inputs=[url_input, url_num_recs, url_skip_exact],
                outputs=url_output,
            )

        # Base64 tab is disabled (its handler is commented out in
        # InputHandlers). To re-enable:
        #
        # with gr.Tab("Base64 Image"):
        #     with gr.Row():
        #         base64_input = gr.Textbox(label="Enter Base64 Image String")
        #         base64_num_recs, base64_skip_exact = _recommendation_controls()
        #     base64_btn = gr.Button("Get Recommendations from Base64")
        #     base64_output = gr.HTML(label="Recommendations")
        #     base64_btn.click(
        #         InputHandlers.process_base64,
        #         inputs=[base64_input, base64_num_recs, base64_skip_exact],
        #         outputs=base64_output,
        #     )

        gr.Markdown("## How to Use")
        gr.Markdown("""
1. Upload an image of jewelry or provide an image URL
2. Adjust the number of recommendations you want to see
3. Check "Skip Exact Match" to exclude the identical or closest match from results
4. Click the 'Get Recommendations' button
5. View similar jewelry items based on visual similarity
""")

    return demo


# main.py
def main():
    """Main entry point to run the Jewelry Recommender application."""
    print("Starting Jewelry Recommender System...")
    demo = create_gradio_interface()
    demo.launch()


if __name__ == "__main__":
    main()