Image-based search engine
Introduction
Efficient and accurate image retrieval is key to working with visual data. This blog post is a step-by-step guide to building an image-based search engine with open-source tools: we will embed a dataset of images with CLIP, index the embeddings with Faiss, and retrieve the entries most similar to a new query image. By the end, you'll have a customizable search engine that you can adapt to your own image collections.
Embedding the data
Embedding is crucial for building an image-based search engine as it transforms images into a format that can be easily compared and processed. This step enables efficient image retrieval by representing images as unique vectors in a high-dimensional space.
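To make "comparing vectors" concrete, here is a minimal sketch in plain Python (no extra dependencies). The three-dimensional vectors and their names are made up for illustration; real CLIP embeddings have hundreds of dimensions, and cosine similarity is just one common way to compare them.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors (1.0 = same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Hypothetical toy "embeddings", invented for this example.
charmander = [0.9, 0.1, 0.0]
charmeleon = [0.8, 0.2, 0.1]
squirtle = [0.1, 0.9, 0.2]

# Similar images should end up with a higher similarity score.
print(cosine_similarity(charmander, charmeleon))
print(cosine_similarity(charmander, squirtle))
```

The first score comes out higher than the second, which is the property a search engine relies on: images that look alike should map to vectors that point in similar directions.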
Let's start by installing our libraries. loadimg is a Python library that I developed to read images and convert them with ease; you can skip it if you want. If you are interested in it and want to contribute to its development, you can check out my GitHub repository.
pip install -qU datasets accelerate loadimg faiss-cpu transformers
Then we can move on to loading our dataset:
from datasets import load_dataset
data = load_dataset("not-lain/pokemon",split="train")
data
>>> Dataset({
    features: ['image', 'text'],
    num_rows: 898
})
After that, let's load our model. I'm using CLIP here, but you can use any similar model.
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification # or you can use CLIPProcessor, CLIPModel
device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
model = AutoModelForZeroShotImageClassification.from_pretrained("openai/clip-vit-large-patch14", device_map = device)
Now for the most important part: we will embed our dataset and add the embeddings to a new column called embeddings.
It is recommended that you use a GPU here or any other accelerator since this is a very slow process.
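def embed(batch):
    """Embed a batch of images and store the vectors in an "embeddings" column."""
    pixel_values = processor(images=batch["image"], return_tensors="pt")["pixel_values"]
    pixel_values = pixel_values.to(device)
    with torch.no_grad():  # inference only, so no gradients are needed
        img_emb = model.get_image_features(pixel_values)
    batch["embeddings"] = img_emb.cpu().numpy()  # move to CPU so the datasets library can store it
    return batch

# `data` is the dataset we loaded earlier
embedded_dataset = data.map(embed, batched=True, batch_size=16)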
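If you are wondering what batched=True and batch_size=16 change, the idea is that the function receives a dictionary of column lists instead of a single row. Here is a rough sketch of that behaviour in plain Python. Both batched_map and toy_embed are hypothetical helpers written for this illustration; they are not how the datasets library is implemented.

```python
def batched_map(rows, fn, batch_size):
    """Apply fn to column-wise batches of rows, mimicking a batched map."""
    out = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        # Turn a list of row dicts into one dict of column lists.
        batch = {key: [row[key] for row in chunk] for key in chunk[0]}
        result = fn(batch)
        # Turn the column lists back into row dicts.
        n = len(next(iter(result.values())))
        out.extend({key: result[key][i] for key in result} for i in range(n))
    return out

def toy_embed(batch):
    # Hypothetical stand-in for the CLIP call: "embed" each text as [its length].
    batch["embeddings"] = [[float(len(t))] for t in batch["text"]]
    return batch

rows = [{"text": "a"}, {"text": "bb"}, {"text": "ccc"}]
embedded = batched_map(rows, toy_embed, batch_size=2)
print(embedded)
```

With three rows and a batch size of 2, toy_embed is called twice: once with two rows and once with the remaining one. Larger batches mean fewer model calls, at the cost of more memory per call.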
It is recommended that you store your embedded data in a database, whether locally, on the Hugging Face Hub, Pinecone, ChromaDB, or any other alternative, so that you don't have to embed the dataset again.
💡TIP
Although it is unrelated to our work here, you can also use the same model to create embeddings for text. First process the input with tokens = processor(text = "some text here", padding=True, return_tensors="pt").to(device), then pass the result to your model with text_emb = model.get_text_features(**tokens).
embedded_dataset.push_to_hub("not-lain/embedded-pokemon")
Retrieve images
Once your dataset is embedded and stored, we can move on to defining the retrieval logic. First, we need to load the same model that was used in the previous section.
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification # or you can use CLIPProcessor, CLIPModel
device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
model = AutoModelForZeroShotImageClassification.from_pretrained("openai/clip-vit-large-patch14", device_map = device)
We will also need to load our embedded dataset:
from datasets import load_dataset
dataset = load_dataset("not-lain/embedded-pokemon", split="train")
You need to add a Faiss index to the embeddings column to set it up as a similarity search index.
Faiss is a library for efficient similarity search and clustering of dense vectors, and it is particularly useful for large-scale image retrieval tasks. By adding a Faiss index to the embeddings column, you enable fast and accurate nearest neighbor searches, making it efficient to retrieve similar images based on their embeddings. This step significantly enhances the performance of your image-based search engine, especially when dealing with a large number of images.
dataset = dataset.add_faiss_index("embeddings")
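Conceptually, the simplest kind of similarity index compares the query against every stored vector and keeps the closest ones. The sketch below shows that idea in plain Python, assuming Euclidean distance as the metric. It is only an illustration of nearest neighbor search, not how Faiss works internally, and nearest_neighbors is a hypothetical helper name.

```python
import math

def nearest_neighbors(query, vectors, k):
    """Return (distance, index) pairs for the k vectors closest to query."""
    scored = []
    for idx, vec in enumerate(vectors):
        # Euclidean (L2) distance between the query and this stored vector.
        dist = math.sqrt(sum((q - v) ** 2 for q, v in zip(query, vec)))
        scored.append((dist, idx))
    scored.sort()  # smallest distance first
    return scored[:k]

# Toy 2-dimensional vectors, invented for this example.
vectors = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0], [0.9, 1.2]]
top = nearest_neighbors([1.0, 1.0], vectors, k=2)
print(top)
```

This brute-force loop is fine for a few hundred images, but it scales linearly with the size of the dataset, which is why a dedicated library becomes useful as your collection grows.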
Now to retrieve the most similar images, you will need to first create the embedding of the new image, then retrieve the most similar entries from the dataset.
from PIL import Image

def search(query: Image.Image, k: int = 4):
    """Embed a query image and return the k most similar dataset entries."""
    pixel_values = processor(images=query, return_tensors="pt")["pixel_values"]  # preprocess the new image
    pixel_values = pixel_values.to(device)
    with torch.no_grad():  # inference only, so no gradients are needed
        img_emb = model.get_image_features(pixel_values)[0]  # a single image, so take the first row
    img_emb = img_emb.cpu().numpy()  # convert to numpy because the datasets library does not support torch vectors
    scores, retrieved_examples = dataset.get_nearest_examples(  # retrieve results
        "embeddings", img_emb,  # compare the query embedding with the dataset embeddings
        k=k,  # keep only the top k results
    )
    return retrieved_examples
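The retrieved examples come back column by column: one list per column rather than one dictionary per image. If you prefer to work with one record per result, and you also keep the scores, a small helper like the one below can pair them up. to_rows is a hypothetical helper, and the values are placeholders shaped like the retrieval output.

```python
def to_rows(scores, examples):
    """Pair each score with its row from a column-oriented result dict."""
    columns = list(examples)
    return [
        {"score": score, **{col: examples[col][i] for col in columns}}
        for i, score in enumerate(scores)
    ]

# Placeholder values, invented for this example.
scores = [0.0, 12.5]
examples = {
    "text": ["charmander", "charmeleon"],
    "image": ["<img 0>", "<img 1>"],
}

rows = to_rows(scores, examples)
for row in rows:
    print(row["score"], row["text"])
```

To use it with the search function above, you would need to return scores alongside retrieved_examples.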
Let's test our algorithm. Start by loading an image:
from loadimg import load_img
image = load_img("https://img.pokemondb.net/artwork/large/charmander.jpg")
image
After that, you can retrieve the most similar entries. The entries are sorted from most to least similar, so the first entry is the closest match to our input image.
retrieved_examples = search(image)
Let's visualize our results:
import matplotlib.pyplot as plt
f, axarr = plt.subplots(2, 2)
for index in range(4):
    i, j = index // 2, index % 2  # row and column in the 2x2 grid
    axarr[i, j].set_title(retrieved_examples["text"][index])
    axarr[i, j].imshow(retrieved_examples["image"][index])
    axarr[i, j].axis("off")
plt.show()
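The plotting code above hard-codes a 2x2 grid for k=4. If you change k, you can compute the grid shape and each cell position instead. This is a small sketch in plain Python; grid_shape and grid_position are hypothetical helper names.

```python
import math

def grid_shape(k, ncols=2):
    """Number of rows and columns needed to show k images, ncols per row."""
    return math.ceil(k / ncols), ncols

def grid_position(index, ncols=2):
    """Row and column of the index-th image, filling the grid row by row."""
    return divmod(index, ncols)

print(grid_shape(4))                          # shape for the default k=4
print([grid_position(i) for i in range(4)])   # cell of each result
```

One thing to keep in mind: when the grid has a single row or column, plt.subplots returns a one-dimensional array of axes unless you pass squeeze=False, so axarr[i, j] indexing only works reliably with that argument set.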
Demo
Now let's put everything together in a single application to see our work in action. You can use the following application as a reference: https://huggingface.co/spaces/not-lain/image-retriever
Acknowledgement
I would like to acknowledge the importance of Pinecone's docs in helping me develop the script I used in this blogpost 🌲
I would like to thank everyone in the Hugging Face Discord server who supported my work, especially lunarflu, christopher, and tomaarsen ❤️
If you loved this blog post, please consider upvoting it, as this will help me showcase my work 🤗
Finally, if you want to request another blog post, you can contact me on Twitter, email, or LinkedIn ✉️