# koekeloer / app.py
# ChristopherMarais's picture
# Create app.py
# bb00855 verified
# app.py
import json
import numpy as np
import torch
from flask import Flask, request, render_template_string
from transformers import CLIPProcessor, CLIPModel
import faiss
# The Flask application object; the search route below is registered on it.
app = Flask(__name__)

# Shared module-level state: the CLIP model and processor, the FAISS index,
# the embedding matrix and the per-image metadata. All of them start out as
# None and are filled in by load_model() and load_data().
model = processor = index = image_embeddings = image_metadata = None
def load_model(model_name: str = "openai/clip-vit-base-patch32") -> None:
    """Load the CLIP model and its processor into the module globals.

    The same identifier is used for both ``from_pretrained`` calls so the
    model weights and the processor can never come from different
    checkpoints.

    Args:
        model_name: Checkpoint identifier handed to
            ``CLIPModel.from_pretrained`` and
            ``CLIPProcessor.from_pretrained``. Defaults to the checkpoint the
            application was originally hard-wired to.
    """
    global model, processor
    print("Loading CLIP model and processor...")
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
def load_data(path: str = "data/embeddings.json") -> None:
    """Load precomputed image embeddings and build the FAISS index.

    The JSON file must contain a non-empty list of records. Each record needs
    an "embedding" key (a list of numbers, the same length for every record);
    the search route additionally reads "url" and, optionally, "id" from it.

    On success the module globals ``image_embeddings``, ``image_metadata`` and
    ``index`` are replaced together; on failure they are left untouched.

    Args:
        path: Location of the JSON file. Defaults to the path the application
            originally hard-coded.

    Raises:
        ValueError: If the file holds no records, or the embeddings do not
            form a 2-D matrix with at least one column.
    """
    global image_embeddings, image_metadata, index
    print("Loading image embeddings and metadata...")
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # An empty list would otherwise surface as an opaque IndexError on
    # ``shape[1]`` further down.
    if not data:
        raise ValueError(f"No image records found in {path!r}.")

    # FAISS expects a C-contiguous float32 matrix.
    embeddings = np.ascontiguousarray(
        [record["embedding"] for record in data], dtype="float32"
    )
    if embeddings.ndim != 2 or embeddings.shape[1] == 0:
        raise ValueError(
            f"Embeddings in {path!r} must form a 2-D matrix; "
            f"got shape {embeddings.shape}."
        )

    # Build a FAISS index using L2 distance. The index dimension must match
    # the embedding size.
    # NOTE(review): L2 ranking only matches cosine ranking if the stored
    # embeddings are unit-normalised - confirm how the file was produced.
    new_index = faiss.IndexFlatL2(embeddings.shape[1])
    new_index.add(embeddings)

    # Publish all three globals only once everything has succeeded.
    image_embeddings, image_metadata, index = embeddings, data, new_index
    print(f"FAISS index built with {index.ntotal} embeddings.")
# Jinja template for the search page. It is a constant: user input and image
# metadata are passed in as template variables (autoescaped by Flask) and are
# never spliced into the template source itself.
_SEARCH_PAGE_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Image Search with CLIP &amp; FAISS</title>
</head>
<body>
<h1>Image Search</h1>
<form method="post">
<input type="text" name="query" placeholder="Enter search text" value="{{ query }}" required>
<input type="submit" value="Search">
</form>
<div style="display:flex; flex-wrap: wrap; margin-top:20px;">
{% for hit in results %}<div style="margin:10px;"><img src="{{ hit["url"] }}" alt="Image {{ hit["alt_id"] }}" style="max-width:200px;"><br>ID: {{ hit["label"] }}</div>{% endfor %}
</div>
</body>
</html>
"""


@app.route("/", methods=["GET", "POST"])
def search():
    """Render the search form and, on POST, the closest images for the query.

    A GET request (or a POST with an empty query) renders only the form. A
    POST with a query encodes the text with the CLIP text encoder, looks up
    the nearest embeddings in the FAISS index and renders one image per hit.

    Returns:
        The rendered HTML page as a string.
    """
    query = ""
    results = []

    if request.method == "POST":
        query = request.form.get("query", "")
        if query:
            # Encode the text query using CLIP's text encoder. Truncation
            # keeps over-long queries within the model's context length
            # instead of failing inside the encoder.
            inputs = processor(
                text=[query], return_tensors="pt", padding=True, truncation=True
            )
            with torch.no_grad():
                text_features = model.get_text_features(**inputs)
            text_features = text_features.cpu().numpy().astype("float32")

            # Never request more neighbours than the index holds: FAISS pads
            # missing results with -1, and -1 would silently select the last
            # metadata entry.
            k = min(10, index.ntotal)
            if k > 0:
                _distances, indices = index.search(text_features, k)
                for idx in indices[0]:
                    if idx < 0:
                        continue
                    meta = image_metadata[int(idx)]
                    results.append(
                        {
                            "url": meta["url"],
                            "alt_id": meta.get("id", ""),
                            "label": meta.get("id", "N/A"),
                        }
                    )

    # Values are handed over as context so Jinja escapes them; the query is
    # untrusted and must never become part of the template source.
    return render_template_string(
        _SEARCH_PAGE_TEMPLATE, query=query, results=results
    )
if __name__ == "__main__":
    # Populate the model and the index before the server accepts requests;
    # the search route reads both from module globals.
    for initialise in (load_model, load_data):
        initialise()

    # Listen on every interface. The original author notes that 8080 is the
    # port the Hugging Face Space expects - NOTE(review): confirm against the
    # Space's configured app port.
    app.run(port=8080, host="0.0.0.0")