import os, json
from typing import List, Tuple
import numpy as np
import gradio as gr
import faiss
from sentence_transformers import SentenceTransformer
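# HF token for downloading the gated EmbeddingGemma checkpoint (set it as a Space secret).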
HF_TOKEN = os.getenv("HF_TOKEN")
# ---------- Paths (expects files committed under ./assets) ----------
APP_DIR = os.path.dirname(__file__)
ASSETS_DIR = os.path.join(APP_DIR, "assets")
CACHE_DIR = "/mnt/data/eg_space_cache" # runtime cache
os.makedirs(CACHE_DIR, exist_ok=True)
CORPUS_JSON = os.path.join(ASSETS_DIR, "corpus.json")
EMB_FP32 = os.path.join(ASSETS_DIR, "doc_embs_fp32.npy")
EMB_FP16 = os.path.join(ASSETS_DIR, "doc_embs_fp16.npy")
FAISS_MAIN = os.path.join(ASSETS_DIR, "faiss_ip_768.index")
# ---------- Matryoshka dims ----------
MATRYOSHKA_DIMS = [768, 512, 256, 128]
DEFAULT_DIMS = 768
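# EmbeddingGemma is trained with Matryoshka Representation Learning (MRL): the
# leading d dimensions of each 768-d embedding are themselves a usable embedding,
# so smaller dims can be served by simply slicing the stored vectors.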
# ---------- Load corpus ----------
with open(CORPUS_JSON, "r", encoding="utf-8") as f:
    corpus = json.load(f)  # list of {"title", "text"} in the EXACT same order as the embeddings
# ---------- Load embeddings ----------
if os.path.exists(EMB_FP32):
    doc_embs = np.load(EMB_FP32).astype(np.float32, copy=False)
elif os.path.exists(EMB_FP16):
    doc_embs = np.load(EMB_FP16).astype(np.float32)  # cast back to fp32 for FAISS
else:
    raise FileNotFoundError("Expected assets/doc_embs_fp32.npy or assets/doc_embs_fp16.npy")
if doc_embs.ndim != 2 or doc_embs.shape[0] != len(corpus):
    raise ValueError("Embeddings shape mismatch vs corpus length.")
EMB_DIM = doc_embs.shape[1] # should be 768
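# Document embeddings are expected to be L2-normalized, so inner-product (IP)
# search over them is equivalent to cosine similarity at full dimension.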
# ---------- Model (for queries + sentence-level ops) ----------
model = SentenceTransformer("google/embeddinggemma-300m", token=HF_TOKEN) # CPU is fine for queries
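# encode_query()/encode_document() (sentence-transformers >= 5.0) apply the
# task-specific prompts EmbeddingGemma was trained with for retrieval.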
# ---------- FAISS indexes ----------
if os.path.exists(FAISS_MAIN):
    base_index_768 = faiss.read_index(FAISS_MAIN)
else:
    base_index_768 = faiss.IndexFlatIP(EMB_DIM)
    base_index_768.add(doc_embs.astype(np.float32, copy=False))
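    # Optionally persist the rebuilt index for faster cold starts, assuming the
    # assets directory is writable (on Spaces, runtime writes are ephemeral):
    # faiss.write_index(base_index_768, FAISS_MAIN)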
# Build per-dimension flat IP indexes from the loaded embeddings
class MultiDimFaiss:
    def __init__(self, doc_embs_full: np.ndarray):
        self.full = doc_embs_full
        self.indexes = {}
        for d in MATRYOSHKA_DIMS:
            if d == EMB_DIM:
                # Reuse the full-dimension index loaded or built above.
                self.indexes[d] = base_index_768
            else:
                # Column slices are not C-contiguous; FAISS needs contiguous float32.
                view = np.ascontiguousarray(self.full[:, :d], dtype=np.float32)
                idx = faiss.IndexFlatIP(d)
                idx.add(view)
                self.indexes[d] = idx

    def search(self, q_vec: np.ndarray, top_k: int, dims: int) -> Tuple[np.ndarray, np.ndarray]:
        q = q_vec[:dims].astype(np.float32, copy=False)[None, :]
        return self.indexes[dims].search(q, top_k)

faiss_md = MultiDimFaiss(doc_embs)
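# Note: slicing a unit-norm 768-d vector to d < 768 dims does not preserve unit
# norm, so IP scores at lower dims are not exact cosine values; MRL training
# keeps the *ranking* useful, which is what search needs.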
# ---------- Core ops ----------
def _format_snippet(text: str, max_len: int = 380) -> str:
    return text[:max_len] + ("…" if len(text) > max_len else "")
def do_search(query: str, top_k: int = 5, dims: int = DEFAULT_DIMS) -> List[List[str]]:
    if not query or not query.strip():
        return []
    q_emb = model.encode_query(
        query.strip(),
        normalize_embeddings=True,
        convert_to_numpy=True,
    )
    scores, idxs = faiss_md.search(q_emb, top_k=top_k, dims=dims)
    rows = []
    for s, i in zip(scores[0].tolist(), idxs[0].tolist()):
        if i == -1:  # FAISS pads with -1 when fewer than top_k results exist
            continue
        title = corpus[i]["title"]
        snippet = _format_snippet(corpus[i]["text"])
        rows.append([f"{s:.4f}", title, snippet])
    return rows
def do_similarity(text_a: str, text_b: str, dims: int = DEFAULT_DIMS) -> float:
    if not text_a or not text_b:
        return 0.0
    a = model.encode_document([text_a], normalize_embeddings=True, convert_to_numpy=True)[0][:dims]
    b = model.encode_document([text_b], normalize_embeddings=True, convert_to_numpy=True)[0][:dims]
    # Truncated vectors lose unit norm, so renormalize to report a true cosine.
    a = a / (np.linalg.norm(a) + 1e-12)
    b = b / (np.linalg.norm(b) + 1e-12)
    return float(np.dot(a, b))
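# Cross-lingual pairs (e.g. the English/Spanish defaults below) should still score
# high, since EmbeddingGemma embeds 100+ languages into a shared space.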
# ---------- Gradio UI ----------
with gr.Blocks(title="EmbeddingGemma × Wikipedia (EN corpus)") as demo:
    gr.Markdown(
        """
# Demo: EmbeddingGemma × Wikipedia (EN corpus)

This Space showcases [Google DeepMind’s EmbeddingGemma models](https://huggingface.co/collections/google/embeddinggemma-68b9ae3a72a82f0562a80dc4) on a pre-indexed **random 10k sample** of [English Wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia).

You can try:

- **Semantic search** (English queries)
- **Cross-lingual search** (queries in other languages → English articles)
- **Sentence similarity** (compare two texts)

🔗 Learn more in the [EmbeddingGemma blog post](https://huggingface.co/blog/embeddinggemma).
        """
    )
    with gr.Tabs():
        # 1) Semantic Search (EN-only corpus)
        with gr.TabItem("Semantic Search (EN corpus)"):
            with gr.Row():
                q = gr.Textbox(label="Query", value="Who discovered penicillin?")
                topk = gr.Slider(1, 20, value=5, step=1, label="Top-K")
                dims = gr.Dropdown([str(d) for d in MATRYOSHKA_DIMS], value=str(DEFAULT_DIMS), label="Embedding dims")
            run = gr.Button("Search")
            out = gr.Dataframe(headers=["score", "title", "snippet"], wrap=True)
            run.click(lambda query, k, d: do_search(query, int(k), int(d)), [q, topk, dims], out)

        # 2) Cross-lingual (queries in FR/ES/etc. → EN corpus)
        with gr.TabItem("Cross-Lingual (EN corpus)"):
            gr.Markdown("Type your query in **French/Spanish/Arabic**. Results come from the **English-only** corpus.")
            with gr.Row():
                qx = gr.Textbox(label="Query", value="¿Quién descubrió la penicilina?")
                topkx = gr.Slider(1, 20, value=5, step=1, label="Top-K")
                dimsx = gr.Dropdown([str(d) for d in MATRYOSHKA_DIMS], value=str(DEFAULT_DIMS), label="Embedding dims")
            runx = gr.Button("Search")
            outx = gr.Dataframe(headers=["score", "title", "snippet"], wrap=True)
            runx.click(lambda query, k, d: do_search(query, int(k), int(d)), [qx, topkx, dimsx], outx)

        # 3) Similarity
        with gr.TabItem("Similarity"):
            with gr.Row():
                a = gr.Textbox(lines=5, label="Text A", value="Alexander Fleming observed a mold that killed bacteria in 1928.")
                b = gr.Textbox(lines=5, label="Text B", value="La penicilina fue descubierta por Alexander Fleming en 1928.")
            dims2 = gr.Dropdown([str(d) for d in MATRYOSHKA_DIMS], value=str(DEFAULT_DIMS), label="Embedding dims")
            sim_btn = gr.Button("Compute Similarity")
            sim_out = gr.Number(label="Cosine similarity (-1..1)")
            sim_btn.click(lambda x, y, d: do_similarity(x, y, int(d)), [a, b, dims2], sim_out)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)