# Spaces: Runtime error
# app.py
import json
import os

import faiss
import numpy as np
import torch
from flask import Flask, request, render_template_string
from transformers import CLIPProcessor, CLIPModel
# Flask application object for this module.
app = Flask(__name__)

# Module-level state. Everything starts out as None and is filled in by
# load_model() and load_data() before the server starts handling requests.
model = processor = None  # CLIP model and its matching processor
index = None  # FAISS index built over image_embeddings
image_embeddings = image_metadata = None  # embedding matrix / parsed JSON records
def load_model():
    """Load the CLIP checkpoint into the module-level ``model`` and ``processor``.

    Both objects come from the same ``openai/clip-vit-base-patch32``
    checkpoint id. Nothing is returned; the results are stored in globals.
    """
    global model, processor
    checkpoint = "openai/clip-vit-base-patch32"
    print("Loading CLIP model and processor...")
    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)
def load_data():
    """Load precomputed image embeddings and build the FAISS search index.

    Reads ``data/embeddings.json``, which must contain a non-empty JSON list
    of objects. Every object needs an ``"embedding"`` list, and all embeddings
    must have the same length. ``search`` additionally reads ``"url"`` and the
    optional ``"id"`` from each object.

    On success the module globals ``image_embeddings`` (float32 matrix),
    ``image_metadata`` (the parsed records) and ``index`` (an L2 FAISS index
    over the matrix) are set together; on failure they are left untouched.

    Raises:
        ValueError: if the file holds no records, or the embeddings do not
            form a 2-D matrix with at least one column.
    """
    global image_embeddings, image_metadata, index
    print("Loading image embeddings and metadata...")
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open("data/embeddings.json", "r", encoding="utf-8") as f:
        records = json.load(f)
    if not records:
        raise ValueError("data/embeddings.json contains no image records")

    # The index is fed a contiguous float32 matrix. Ragged embeddings make
    # NumPy raise ValueError here instead of failing later inside FAISS.
    embeddings = np.ascontiguousarray(
        [record["embedding"] for record in records], dtype="float32"
    )
    if embeddings.ndim != 2 or embeddings.shape[1] == 0:
        raise ValueError(
            "embeddings must form a 2-D matrix (one equal-length vector per "
            f"image); got array of shape {embeddings.shape}"
        )

    # Build a flat L2 index whose dimension matches the embedding size.
    dimension = embeddings.shape[1]
    flat_index = faiss.IndexFlatL2(dimension)
    flat_index.add(embeddings)

    # Publish all three globals only once everything has succeeded.
    image_embeddings, image_metadata, index = embeddings, records, flat_index
    print(f"FAISS index built with {index.ntotal} embeddings.")
# Jinja2 source for the search page. User-controlled values (the query text
# and the image metadata) are supplied as template *context*, never spliced
# into this string, so they are autoescaped and cannot inject template code.
_SEARCH_PAGE_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Image Search with CLIP & FAISS</title>
</head>
<body>
<h1>Image Search</h1>
<form method="post">
<input type="text" name="query" placeholder="Enter search text" value="{{ query }}" required>
<input type="submit" value="Search">
</form>
<div style="display:flex; flex-wrap: wrap; margin-top:20px;">
{% for item in results %}<div style="margin:10px;"><img src="{{ item["url"] }}" alt="Image {{ item["alt_id"] }}" style="max-width:200px;"><br>ID: {{ item["label"] }}</div>{% endfor %}
</div>
</body>
</html>
"""


@app.route("/", methods=["GET", "POST"])
def search():
    """Render the search page and, on POST, the images closest to the query.

    GET shows an empty form. POST reads the ``query`` form field, encodes it
    with CLIP's text encoder, and looks up at most 10 nearest neighbours in
    the module-level FAISS ``index``. Each hit is rendered from the matching
    ``image_metadata`` record (``"url"`` required, ``"id"`` optional).

    Relies on ``load_model()`` and ``load_data()`` having populated the
    module globals before the first request.

    Returns:
        The rendered HTML page as a string.
    """
    query = ""
    results = []
    if request.method == "POST":
        query = request.form.get("query", "")

    # A whitespace-only query carries no text to search for.
    if query.strip():
        # truncation=True keeps over-long queries within the text encoder's
        # maximum sequence length instead of failing inside the model.
        inputs = processor(
            text=[query], return_tensors="pt", padding=True, truncation=True
        )
        with torch.no_grad():
            text_features = model.get_text_features(**inputs)
        text_features = text_features.cpu().numpy().astype("float32")

        # Never ask for more neighbours than the index holds.
        k = min(10, index.ntotal)
        if k > 0:
            _distances, indices = index.search(text_features, k)
            for idx in indices[0]:
                # FAISS marks missing neighbours with -1; as a list index that
                # would silently select the last record.
                if idx < 0:
                    continue
                meta = image_metadata[int(idx)]
                results.append(
                    {
                        "url": meta["url"],
                        "alt_id": meta.get("id", ""),
                        "label": meta.get("id", "N/A"),
                    }
                )

    return render_template_string(
        _SEARCH_PAGE_TEMPLATE, query=query, results=results
    )
if __name__ == "__main__":
    # Load the model and build the search index once, before serving requests.
    load_model()
    load_data()
    # Listen on all interfaces. The port defaults to 8080 and can be overridden
    # with the PORT environment variable.
    # NOTE(review): Hugging Face Spaces forwards traffic to the Space's
    # configured `app_port` — confirm that setting matches the port used here.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))