|
|
|
import os
import threading

import torch
import torchvision.transforms as transforms
|
|
|
class Config:
    """Central settings for the Jewelry Recommender System.

    All values are class attributes: vector-index storage locations, the
    torch device, image preprocessing parameters and recommendation limits.
    """

    # --- Vector index / metadata storage ---
    VECTOR_DIMENSION = 1280
    INDEX_PATH = "rootdir/trained_models/jewelry_index.idx"
    METADATA_PATH = "rootdir/trained_models/jewelry_metadata.pkl"

    # --- Compute device: CUDA when available, otherwise CPU ---
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- Image preprocessing ---
    IMAGE_SIZE = (640, 640)
    # Per-channel RGB statistics (the standard ImageNet values used with
    # torchvision's pretrained weights).
    NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
    NORMALIZATION_STD = [0.229, 0.224, 0.225]

    # --- Recommendation limits ---
    DEFAULT_NUM_RECOMMENDATIONS = 5
    MAX_RECOMMENDATIONS = 20

    @classmethod
    def get_image_transform(cls):
        """Build the preprocessing pipeline applied to every query image.

        Steps, in order: apply the EXIF orientation tag, resize to
        ``IMAGE_SIZE``, convert to a tensor, and normalise with
        ``NORMALIZATION_MEAN`` / ``NORMALIZATION_STD``.

        Returns:
            torchvision.transforms.Compose: The composed transform.
        """
        from PIL import ImageOps

        pipeline = [
            # exif_transpose is already a one-argument callable; no wrapper needed.
            transforms.Lambda(ImageOps.exif_transpose),
            transforms.Resize(cls.IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=cls.NORMALIZATION_MEAN, std=cls.NORMALIZATION_STD),
        ]
        return transforms.Compose(pipeline)
|
|
|
|
|
|
|
import os |
|
import pickle |
|
import faiss |
|
import torch |
|
import torchvision.models as models |
|
import warnings |
|
|
|
class ModelLoader:
    """Handles loading of the feature extraction model and FAISS index."""

    @staticmethod
    def load_feature_extraction_model():
        """Load the EfficientNet-B0 backbone used for feature extraction.

        The torchvision ``efficientnet_b0`` model is created with its default
        pretrained weights. Its last child module is dropped, and the remaining
        modules are wrapped in a ``torch.nn.Sequential``. That wrapper is put
        in evaluation mode and moved to ``Config.DEVICE``.

        Returns:
            torch.nn.Sequential: The feature extractor, ready for inference.
        """
        print("Loading feature extraction model...")
        backbone = models.efficientnet_b0(weights='EfficientNet_B0_Weights.DEFAULT')

        # Drop the last child (the classification head in torchvision's
        # EfficientNet) so the forward pass yields the pooled feature map.
        model = torch.nn.Sequential(*list(backbone.children())[:-1])

        # Call eval() on the wrapper itself, not only on the original model,
        # so the returned module also reports ``training == False``.
        model.eval()
        return model.to(Config.DEVICE)

    @staticmethod
    def load_index_and_metadata(index_path=None, metadata_path=None):
        """Loads the FAISS index and metadata from files.

        Args:
            index_path (str, optional): Path to the FAISS index file;
                falls back to ``Config.INDEX_PATH`` when falsy.
            metadata_path (str, optional): Path to the metadata pickle file;
                falls back to ``Config.METADATA_PATH`` when falsy.

        Returns:
            tuple: ``(index, metadata, True)`` on success, otherwise
            ``(None, {}, False)``. Failures are reported with ``print``
            rather than raised.
        """
        # NOTE: this installs a process-wide "ignore" filter. It silences all
        # warnings for the rest of the process, not just inside this call.
        warnings.filterwarnings("ignore")

        index_path = index_path or Config.INDEX_PATH
        metadata_path = metadata_path or Config.METADATA_PATH

        try:
            missing = [path for path in (index_path, metadata_path) if not os.path.exists(path)]
            if missing:
                # Name the missing file(s) so a wrong path is easy to diagnose.
                print(f"Index file or metadata file not found: {', '.join(map(str, missing))}")
                return None, {}, False

            index = faiss.read_index(index_path)
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file; only point metadata_path at files from a trusted source.
            with open(metadata_path, "rb") as f:
                metadata = pickle.load(f)
        except Exception as e:
            print(f"Error loading index or metadata: {e}")
            return None, {}, False

        print("Index and metadata loaded successfully.")
        return index, metadata, True
|
|
|
|
|
|
|
import io |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
|
|
class ImageProcessor:
    """Turns query images into embedding vectors with a feature-extraction model."""

    def __init__(self, model):
        """Keep the model and build the shared preprocessing pipeline.

        Args:
            model: The pre-trained model for feature extraction
        """
        self.model = model
        self.transform = Config.get_image_transform()

    def normalize_image_input(self, image):
        """Convert a supported image input into an RGB ``PIL.Image``.

        Accepted inputs are:

        - a file path (``str``);
        - raw ``bytes`` or an ``io.BytesIO`` stream;
        - a ``numpy.ndarray``, which is cast to ``uint8``;
        - an existing ``PIL.Image.Image``.

        Args:
            image: The image in one of the formats above

        Returns:
            PIL.Image: RGB image, or None when the input type is unsupported
            or decoding fails (the error is printed, not raised).
        """
        try:
            if isinstance(image, str):
                source = Image.open(image)
            elif isinstance(image, (bytes, io.BytesIO)):
                # Raw bytes need a file-like wrapper before PIL can read them.
                stream = io.BytesIO(image) if isinstance(image, bytes) else image
                source = Image.open(stream)
            elif isinstance(image, np.ndarray):
                source = Image.fromarray(image.astype('uint8'))
            elif isinstance(image, Image.Image):
                source = image
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")
            # Single exit point: every accepted input ends up as RGB.
            return source.convert('RGB')
        except Exception as e:
            print(f"Error normalizing image: {str(e)}")
            return None

    def extract_embedding(self, image):
        """Compute the model's feature vector for one image.

        Args:
            image: The image to embed (any format normalize_image_input accepts)

        Returns:
            numpy.ndarray: The squeezed model output, or None when
            normalization or inference fails (errors are printed, not raised).
        """
        try:
            rgb = self.normalize_image_input(image)
            if rgb is None:
                return None
            # Add a batch dimension of 1 and move to the configured device.
            batch = self.transform(rgb).unsqueeze(0).to(Config.DEVICE)
            with torch.no_grad():
                vector = self.model(batch).squeeze().cpu().numpy()
            return vector
        except Exception as e:
            print(f"Error extracting embedding: {str(e)}")
            return None
|
|
|
|
|
|
|
|
|
|
|
|
|
import warnings |
|
|
|
class JewelryRecommenderService:
    """Main service class for the Jewelry Recommender System."""

    def __init__(self,
                 index_path=None,
                 metadata_path=None):
        """Initialize the jewelry recommender service.

        Loads the feature-extraction model and the FAISS index/metadata, then
        wires them into an ImageProcessor and a RecommenderEngine.

        Args:
            index_path (str, optional): Path to FAISS index
            metadata_path (str, optional): Path to metadata pickle file
        """
        # NOTE: process-wide; silences every warning for the rest of the process.
        warnings.filterwarnings("ignore")

        # EfficientNet-based feature extractor used to embed query images.
        self.model = ModelLoader.load_feature_extraction_model()

        # On failure the loader returns (None, {}, False); get_recommendations
        # checks self.index before searching.
        self.index, self.metadata, success = ModelLoader.load_index_and_metadata(
            index_path, metadata_path
        )
        if not success:
            print("Warning: no index loaded; recommendations will be empty.")

        self.image_processor = ImageProcessor(self.model)
        self.recommender = RecommenderEngine(self.index, self.metadata)

    def get_recommendations(self, image, num_recommendations=None, skip_exact_match=True):
        """Get recommendations for a query image.

        Args:
            image: Query image (various formats)
            num_recommendations (int, optional): Number of recommendations;
                falsy values (None, 0) fall back to
                ``Config.DEFAULT_NUM_RECOMMENDATIONS``. Floats such as UI
                slider values are truncated to ``int``.
            skip_exact_match (bool): Whether to skip the first/exact match

        Returns:
            list: Recommendation results, or an empty list when no index is
            loaded or the image could not be embedded.
        """
        # The search expects an integer count, so coerce here.
        num_recommendations = int(num_recommendations or Config.DEFAULT_NUM_RECOMMENDATIONS)

        # Nothing to search against: skip the (expensive) embedding step.
        if self.index is None:
            return []

        embedding = self.image_processor.extract_embedding(image)
        # extract_embedding returns None on failure (after printing why).
        if embedding is None:
            return []

        return self.recommender.find_similar_items(
            embedding, num_recommendations, skip_exact_match
        )
|
|
|
|
|
|
|
class ResultFormatter:
    """Formats recommendation results for display."""

    @staticmethod
    def format_html(recommendations):
        """Format recommendations as HTML for the Gradio interface.

        Metadata values are HTML-escaped before being interpolated. Names,
        descriptions or URLs containing ``<``, ``&`` or quotes therefore render
        as text instead of breaking (or injecting markup into) the page.

        Args:
            recommendations (list): Dicts with a ``"metadata"`` dict and a
                numeric ``"similarity_score"``.

        Returns:
            str: HTML formatted results, or "No recommendations found." for an
            empty/None input.
        """
        import html

        if not recommendations:
            return "No recommendations found."

        def esc(value):
            # Stringify, then escape; quote=True (the default) also covers the
            # single quote that would otherwise close src='...'.
            return html.escape(str(value))

        parts = ["<h3>Recommended Jewelry Items:</h3>"]
        for i, rec in enumerate(recommendations, 1):
            metadata = rec["metadata"]
            parts.append("<div style='margin-bottom:15px; padding:10px; border:1px solid #ddd; border-radius:5px;'>")
            parts.append(f"<h4>#{i}: {esc(metadata.get('name', 'Unknown'))}</h4>")
            parts.append(f"<p><b>Category:</b> {esc(metadata.get('category', 'Unknown'))}</p>")
            parts.append(f"<p><b>Description:</b> {esc(metadata.get('description', 'No description available'))}</p>")
            parts.append(f"<p><b>Price:</b> ${esc(metadata.get('price', 'N/A'))}</p>")
            parts.append(f"<p><b>Similarity Score:</b> {rec['similarity_score']:.4f}</p>")
            if 'image_url' in metadata:
                parts.append(f"<p><img src='{esc(metadata['image_url'])}' style='max-width:200px; max-height:200px;'></p>")
            parts.append("</div>")

        # Join once instead of repeatedly growing a string with +=.
        return "".join(parts)

    @staticmethod
    def format_json(recommendations):
        """Format recommendations as JSON-serializable dicts.

        Args:
            recommendations (list): List of recommendation dictionaries

        Returns:
            list: One dict per recommendation with keys ``item``,
            ``category``, ``description``, ``price``, ``similarity_score``
            (a plain ``float`` rounded to 4 places) and ``image_url``.
        """
        if not recommendations:
            return []

        return [
            {
                "item": rec["metadata"].get("name", "Unknown"),
                "category": rec["metadata"].get("category", "Unknown"),
                "description": rec["metadata"].get("description", "No description"),
                "price": rec["metadata"].get("price", "N/A"),
                # float() first: rounding a numpy scalar (e.g. numpy.float32)
                # keeps its numpy type, which json.dumps cannot serialize.
                "similarity_score": round(float(rec["similarity_score"]), 4),
                "image_url": rec["metadata"].get("image_url", None),
            }
            for rec in recommendations
        ]
|
|
|
|
|
|
|
import io |
|
import base64 |
|
from PIL import Image |
|
|
|
class InputHandlers:
    """Handles different types of image inputs for recommendation.

    Each handler returns a string: HTML-formatted results, or a plain
    message when the input is missing or cannot be processed.
    """

    # Seconds to wait for a remote image before giving up (process_url).
    URL_TIMEOUT_SECONDS = 10

    # Shared service, built lazily by _get_service(). Constructing a
    # JewelryRecommenderService loads the model and the FAISS index, so it
    # should happen once rather than on every request.
    _service = None
    _service_lock = threading.Lock()

    @staticmethod
    def _get_service():
        """Return the shared recommender service, creating it on first use.

        The lock keeps concurrent first requests from each loading the model.
        A service whose index failed to load (``index is None``) is returned
        but not cached, so a later request retries the load.
        """
        with InputHandlers._service_lock:
            if InputHandlers._service is None:
                service = JewelryRecommenderService()
                if service.index is None:
                    return service
                InputHandlers._service = service
            return InputHandlers._service

    @staticmethod
    def process_image(image, num_recommendations=5, skip_exact_match=True):
        """Process direct image input.

        Args:
            image: The image (PIL, numpy array, etc.); None yields a prompt
                message without running the model.
            num_recommendations (int): Number of recommendations
            skip_exact_match (bool): Whether to skip the first/exact match

        Returns:
            str: HTML formatted results, or a message when no image was given
        """
        if image is None:
            # Presumably what an empty image field submits; answer directly
            # instead of running the pipeline on nothing.
            return "Please provide an image."
        recommendations = InputHandlers._get_service().get_recommendations(
            image, num_recommendations, skip_exact_match
        )
        return ResultFormatter.format_html(recommendations)

    @staticmethod
    def process_url(url, num_recommendations=5, skip_exact_match=True):
        """Process image from URL.

        NOTE(review): the URL is user-supplied and fetched server-side; if
        this app is exposed publicly, consider restricting hosts (SSRF risk).

        Args:
            url (str): URL to the image
            num_recommendations (int): Number of recommendations
            skip_exact_match (bool): Whether to skip the first/exact match

        Returns:
            str: HTML formatted results, or an error message
        """
        try:
            import requests
            # The timeout stops an unresponsive host from hanging the worker.
            # raise_for_status reports HTTP errors (404, 500, ...) directly
            # instead of failing later with "cannot identify image file".
            response = requests.get(url, timeout=InputHandlers.URL_TIMEOUT_SECONDS)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content))
            return InputHandlers.process_image(image, num_recommendations, skip_exact_match)
        except Exception as e:
            return f"Error processing URL: {str(e)}"

    @staticmethod
    def process_base64(base64_string, num_recommendations=5, skip_exact_match=True):
        """Process base64-encoded image.

        Args:
            base64_string (str): Base64 encoded image, optionally prefixed by a
                data URL header such as ``data:image/png;base64,``
            num_recommendations (int): Number of recommendations
            skip_exact_match (bool): Whether to skip the first/exact match

        Returns:
            str: HTML formatted results, or an error message
        """
        try:
            # Remove data URL prefix if present (',' is not a base64 character).
            if ',' in base64_string:
                base64_string = base64_string.split(',', 1)[1]

            image_bytes = base64.b64decode(base64_string)
            image = Image.open(io.BytesIO(image_bytes))
            return InputHandlers.process_image(image, num_recommendations, skip_exact_match)
        except Exception as e:
            return f"Error processing base64 image: {str(e)}"
|
|
|
|
|
|
|
import gradio as gr |
|
|
|
def create_gradio_interface():
    """Assemble the Gradio Blocks app for the recommender.

    The UI has an "Upload Image" tab and an "Image URL" tab. Each tab holds:

    - a recommendation-count slider;
    - a "Skip Exact Match" checkbox;
    - a submit button;
    - an HTML results panel.

    Usage instructions follow the tabs.

    Returns:
        gradio.Blocks: The configured Gradio interface
    """

    def _recommendation_controls():
        # Slider + checkbox pair repeated on every input tab.
        count = gr.Slider(
            minimum=1,
            maximum=Config.MAX_RECOMMENDATIONS,
            value=Config.DEFAULT_NUM_RECOMMENDATIONS,
            step=1,
            label="Number of Recommendations",
        )
        skip = gr.Checkbox(value=True, label="Skip Exact Match")
        return count, skip

    with gr.Blocks(title="Jewelry Recommender") as demo:
        gr.Markdown("# Jewelry Recommendation System")
        gr.Markdown("Upload an image of jewelry to get similar recommendations.")

        with gr.Tab("Upload Image"):
            with gr.Row():
                image_input = gr.Image(type="pil", label="Upload Jewelry Image")
                upload_count, upload_skip = _recommendation_controls()
            upload_button = gr.Button("Get Recommendations")
            upload_results = gr.HTML(label="Recommendations")
            upload_button.click(
                InputHandlers.process_image,
                inputs=[image_input, upload_count, upload_skip],
                outputs=upload_results,
            )

        with gr.Tab("Image URL"):
            with gr.Row():
                url_input = gr.Textbox(label="Enter Image URL")
                url_count, url_skip = _recommendation_controls()
            url_button = gr.Button("Get Recommendations from URL")
            url_results = gr.HTML(label="Recommendations")
            url_button.click(
                InputHandlers.process_url,
                inputs=[url_input, url_count, url_skip],
                outputs=url_results,
            )

        # Disabled "Base64 Image" tab, kept for reference. Re-enabling it
        # requires an InputHandlers.process_base64 handler:
        #
        # with gr.Tab("Base64 Image"):
        #     with gr.Row():
        #         base64_input = gr.Textbox(label="Enter Base64 Image String")
        #         base64_count, base64_skip = _recommendation_controls()
        #     base64_button = gr.Button("Get Recommendations from Base64")
        #     base64_results = gr.HTML(label="Recommendations")
        #     base64_button.click(
        #         InputHandlers.process_base64,
        #         inputs=[base64_input, base64_count, base64_skip],
        #         outputs=base64_results,
        #     )

        gr.Markdown("## How to Use")
        gr.Markdown("""
1. Upload an image of jewelry or provide an image URL
2. Adjust the number of recommendations you want to see
3. Check "Skip Exact Match" to exclude the identical or closest match from results
4. Click the 'Get Recommendations' button
5. View similar jewelry items based on visual similarity
""")

    return demo
|
|
|
|
|
|
|
def main():
    """Run the Jewelry Recommender: build the Gradio UI and launch it."""
    print("Starting Jewelry Recommender System...")
    app = create_gradio_interface()
    app.launch()


if __name__ == "__main__":
    main()
|
|