Tanishq_jewelry_recomm_system / #jewelry_recommender_full.py
Maaz Uddin
Add application file
e5eabef
raw
history blame
16 kB
# config.py
import os
import torch
import torchvision.transforms as transforms
class Config:
"""Configuration class for the Jewelry Recommender System."""
# Model settings
VECTOR_DIMENSION = 1280
INDEX_PATH = "rootdir/trained_models/jewelry_index.idx"
METADATA_PATH = "rootdir/trained_models/jewelry_metadata.pkl"
# Hardware settings
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Image processing settings
IMAGE_SIZE = (640, 640)
NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
NORMALIZATION_STD = [0.229, 0.224, 0.225]
# Recommendation settings
DEFAULT_NUM_RECOMMENDATIONS = 5
MAX_RECOMMENDATIONS = 20
@classmethod
def get_image_transform(cls):
"""Returns the image transformation pipeline."""
from PIL import ImageOps
return transforms.Compose([
transforms.Lambda(lambda img: ImageOps.exif_transpose(img)),
transforms.Resize(cls.IMAGE_SIZE),
transforms.ToTensor(),
transforms.Normalize(
mean=cls.NORMALIZATION_MEAN,
std=cls.NORMALIZATION_STD
)
])
# model_loader.py
import os
import pickle
import faiss
import torch
import torchvision.models as models
import warnings
class ModelLoader:
"""Handles loading of the feature extraction model and FAISS index."""
@staticmethod
def load_feature_extraction_model():
"""Loads and configures the EfficientNet model for feature extraction."""
print("Loading feature extraction model...")
model = models.efficientnet_b0(weights='EfficientNet_B0_Weights.DEFAULT')
model.eval()
# Remove the classification head
model = torch.nn.Sequential(*list(model.children())[:-1])
model = model.to(Config.DEVICE)
return model
@staticmethod
def load_index_and_metadata(index_path=None, metadata_path=None):
"""Loads the FAISS index and metadata from files.
Args:
index_path (str): Path to the FAISS index file
metadata_path (str): Path to the metadata pickle file
Returns:
tuple: (index, metadata, success_flag)
"""
warnings.filterwarnings("ignore")
index_path = index_path or Config.INDEX_PATH
metadata_path = metadata_path or Config.METADATA_PATH
try:
if os.path.exists(index_path) and os.path.exists(metadata_path):
index = faiss.read_index(index_path)
with open(metadata_path, "rb") as f:
metadata = pickle.load(f)
print(f"Index and metadata loaded successfully.")
return index, metadata, True
else:
print(f"Index file or metadata file not found.")
return None, {}, False
except Exception as e:
print(f"Error loading index or metadata: {e}")
return None, {}, False
# image_processor.py
import io
import torch
import numpy as np
from PIL import Image
class ImageProcessor:
"""Handles processing and feature extraction from images."""
def __init__(self, model):
"""Initialize with a pre-trained model.
Args:
model: The pre-trained model for feature extraction
"""
self.model = model
self.transform = Config.get_image_transform()
def normalize_image_input(self, image):
"""Normalize different image input types to a PIL Image.
Args:
image: Can be a PIL.Image, file path, byte stream, or numpy array
Returns:
PIL.Image: The normalized image
"""
try:
if isinstance(image, str):
# If image is a file path
return Image.open(image).convert('RGB')
elif isinstance(image, bytes) or isinstance(image, io.BytesIO):
# If image is a byte stream
if isinstance(image, bytes):
image = io.BytesIO(image)
return Image.open(image).convert('RGB')
elif isinstance(image, np.ndarray):
# If image is a numpy array (as from gradio)
return Image.fromarray(image.astype('uint8')).convert('RGB')
elif isinstance(image, Image.Image):
# If image is already a PIL Image
return image.convert('RGB')
else:
raise ValueError(f"Unsupported image type: {type(image)}")
except Exception as e:
print(f"Error normalizing image: {str(e)}")
return None
def extract_embedding(self, image):
"""Extract feature embedding from an image.
Args:
image: The image to extract features from (various formats accepted)
Returns:
numpy.ndarray: The feature embedding or None if extraction failed
"""
try:
img = self.normalize_image_input(image)
if img is None:
return None
img_tensor = self.transform(img).unsqueeze(0).to(Config.DEVICE)
with torch.no_grad():
embedding = self.model(img_tensor).squeeze().cpu().numpy()
return embedding
except Exception as e:
print(f"Error extracting embedding: {str(e)}")
return None
# recommender.py - Already provided in the artifact above
# jewelry_recommender.py
import warnings
class JewelryRecommenderService:
"""Main service class for the Jewelry Recommender System."""
def __init__(self,
index_path=None,
metadata_path=None):
"""Initialize the jewelry recommender service.
Args:
index_path (str, optional): Path to FAISS index
metadata_path (str, optional): Path to metadata pickle file
"""
warnings.filterwarnings("ignore")
# Load the model
self.model = ModelLoader.load_feature_extraction_model()
# Load index and metadata
self.index, self.metadata, success = ModelLoader.load_index_and_metadata(
index_path, metadata_path
)
# Initialize pipeline components
self.image_processor = ImageProcessor(self.model)
self.recommender = RecommenderEngine(self.index, self.metadata)
def get_recommendations(self, image, num_recommendations=None, skip_exact_match=True):
"""Get recommendations for a query image.
Args:
image: Query image (various formats)
num_recommendations (int, optional): Number of recommendations
skip_exact_match (bool): Whether to skip the first/exact match
Returns:
list: Recommendation results
"""
num_recommendations = num_recommendations or Config.DEFAULT_NUM_RECOMMENDATIONS
# Extract embedding from the image
embedding = self.image_processor.extract_embedding(image)
# Get similar items based on the embedding
recommendations = self.recommender.find_similar_items(
embedding, num_recommendations, skip_exact_match
)
return recommendations
# formatter.py
class ResultFormatter:
"""Formats recommendation results for display."""
@staticmethod
def format_html(recommendations):
"""Format recommendations as HTML for the Gradio interface.
Args:
recommendations (list): List of recommendation dictionaries
Returns:
str: HTML formatted results
"""
if not recommendations:
return "No recommendations found."
result_html = "<h3>Recommended Jewelry Items:</h3>"
for i, rec in enumerate(recommendations, 1):
metadata = rec["metadata"]
result_html += f"<div style='margin-bottom:15px; padding:10px; border:1px solid #ddd; border-radius:5px;'>"
result_html += f"<h4>#{i}: {metadata.get('name', 'Unknown')}</h4>"
result_html += f"<p><b>Category:</b> {metadata.get('category', 'Unknown')}</p>"
result_html += f"<p><b>Description:</b> {metadata.get('description', 'No description available')}</p>"
result_html += f"<p><b>Price:</b> ${metadata.get('price', 'N/A')}</p>"
result_html += f"<p><b>Similarity Score:</b> {rec['similarity_score']:.4f}</p>"
if 'image_url' in metadata:
result_html += f"<p><img src='{metadata['image_url']}' style='max-width:200px; max-height:200px;'></p>"
result_html += "</div>"
return result_html
@staticmethod
def format_json(recommendations):
"""Format recommendations as JSON.
Args:
recommendations (list): List of recommendation dictionaries
Returns:
list: Clean JSON-serializable results
"""
if not recommendations:
return []
results = []
for rec in recommendations:
results.append({
"item": rec["metadata"].get("name", "Unknown"),
"category": rec["metadata"].get("category", "Unknown"),
"description": rec["metadata"].get("description", "No description"),
"price": rec["metadata"].get("price", "N/A"),
"similarity_score": round(rec["similarity_score"], 4),
"image_url": rec["metadata"].get("image_url", None)
})
return results
# input_handlers.py
import io
import base64
from PIL import Image
class InputHandlers:
"""Handles different types of image inputs for recommendation."""
@staticmethod
def process_image(image, num_recommendations=5, skip_exact_match=True):
"""Process direct image input.
Args:
image: The image (PIL, numpy array, etc.)
num_recommendations (int): Number of recommendations
skip_exact_match (bool): Whether to skip the first/exact match
Returns:
str: HTML formatted results
"""
recommender = JewelryRecommenderService()
recommendations = recommender.get_recommendations(
image, num_recommendations, skip_exact_match
)
return ResultFormatter.format_html(recommendations)
@staticmethod
def process_url(url, num_recommendations=5, skip_exact_match=True):
"""Process image from URL.
Args:
url (str): URL to the image
num_recommendations (int): Number of recommendations
skip_exact_match (bool): Whether to skip the first/exact match
Returns:
str: HTML formatted results
"""
try:
import requests
response = requests.get(url)
image = Image.open(io.BytesIO(response.content))
return InputHandlers.process_image(image, num_recommendations, skip_exact_match)
except Exception as e:
return f"Error processing URL: {str(e)}"
# Base64 input handler is commented out
"""
@staticmethod
def process_base64(base64_string, num_recommendations=5, skip_exact_match=True):
# Process base64-encoded image.
#
# Args:
# base64_string (str): Base64 encoded image
# num_recommendations (int): Number of recommendations
# skip_exact_match (bool): Whether to skip the first/exact match
#
# Returns:
# str: HTML formatted results
try:
# Remove data URL prefix if present
if ',' in base64_string:
base64_string = base64_string.split(',', 1)[1]
image_bytes = base64.b64decode(base64_string)
image = Image.open(io.BytesIO(image_bytes))
return InputHandlers.process_image(image, num_recommendations, skip_exact_match)
except Exception as e:
return f"Error processing base64 image: {str(e)}"
"""
# gradio_app.py
import gradio as gr
def create_gradio_interface():
"""Create and configure the Gradio web interface.
Returns:
gradio.Blocks: The configured Gradio interface
"""
with gr.Blocks(title="Jewelry Recommender") as demo:
gr.Markdown("# Jewelry Recommendation System")
gr.Markdown("Upload an image of jewelry to get similar recommendations.")
with gr.Tab("Upload Image"):
with gr.Row():
image_input = gr.Image(type="pil", label="Upload Jewelry Image")
num_recs_slider = gr.Slider(
minimum=1,
maximum=Config.MAX_RECOMMENDATIONS,
value=Config.DEFAULT_NUM_RECOMMENDATIONS,
step=1,
label="Number of Recommendations"
)
skip_exact = gr.Checkbox(value=True, label="Skip Exact Match")
submit_btn = gr.Button("Get Recommendations")
output_html = gr.HTML(label="Recommendations")
submit_btn.click(
InputHandlers.process_image,
inputs=[image_input, num_recs_slider, skip_exact],
outputs=output_html
)
with gr.Tab("Image URL"):
with gr.Row():
url_input = gr.Textbox(label="Enter Image URL")
url_num_recs = gr.Slider(
minimum=1,
maximum=Config.MAX_RECOMMENDATIONS,
value=Config.DEFAULT_NUM_RECOMMENDATIONS,
step=1,
label="Number of Recommendations"
)
url_skip_exact = gr.Checkbox(value=True, label="Skip Exact Match")
url_btn = gr.Button("Get Recommendations from URL")
url_output = gr.HTML(label="Recommendations")
url_btn.click(
InputHandlers.process_url,
inputs=[url_input, url_num_recs, url_skip_exact],
outputs=url_output
)
# Base64 tab is commented out
"""
with gr.Tab("Base64 Image"):
with gr.Row():
base64_input = gr.Textbox(label="Enter Base64 Image String")
base64_num_recs = gr.Slider(
minimum=1,
maximum=Config.MAX_RECOMMENDATIONS,
value=Config.DEFAULT_NUM_RECOMMENDATIONS,
step=1,
label="Number of Recommendations"
)
base64_skip_exact = gr.Checkbox(value=True, label="Skip Exact Match")
base64_btn = gr.Button("Get Recommendations from Base64")
base64_output = gr.HTML(label="Recommendations")
base64_btn.click(
InputHandlers.process_base64,
inputs=[base64_input, base64_num_recs, base64_skip_exact],
outputs=base64_output
)
"""
gr.Markdown("## How to Use")
gr.Markdown("""
1. Upload an image of jewelry or provide an image URL
2. Adjust the number of recommendations you want to see
3. Check "Skip Exact Match" to exclude the identical or closest match from results
4. Click the 'Get Recommendations' button
5. View similar jewelry items based on visual similarity
""")
return demo
# main.py
def main():
"""Main entry point to run the Jewelry Recommender application."""
print("Starting Jewelry Recommender System...")
demo = create_gradio_interface()
demo.launch()
if __name__ == "__main__":
main()