|
import json
import os
import re
import tempfile
from typing import List, Tuple

import faiss
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer
|
# Optional Hugging Face token from the environment; passed to SentenceTransformer
# when the model is loaded.
HF_TOKEN = os.getenv("HF_TOKEN")

# Resolve paths relative to this file (abspath guards against a relative
# __file__, whose dirname would be "").
APP_DIR = os.path.dirname(os.path.abspath(__file__))
ASSETS_DIR = os.path.join(APP_DIR, "assets")

# Cache directory, overridable via EG_SPACE_CACHE_DIR. If the default location
# cannot be created (e.g. /mnt/data missing or read-only on this machine), fall
# back to the system temp dir instead of aborting the whole app at import time.
CACHE_DIR = os.getenv("EG_SPACE_CACHE_DIR", "/mnt/data/eg_space_cache")
try:
    os.makedirs(CACHE_DIR, exist_ok=True)
except OSError:
    CACHE_DIR = os.path.join(tempfile.gettempdir(), "eg_space_cache")
    os.makedirs(CACHE_DIR, exist_ok=True)

# Precomputed assets: article records, document embeddings (the float32 file is
# preferred, float16 accepted) and an optional prebuilt FAISS index.
CORPUS_JSON = os.path.join(ASSETS_DIR, "corpus.json")
EMB_FP32 = os.path.join(ASSETS_DIR, "doc_embs_fp32.npy")
EMB_FP16 = os.path.join(ASSETS_DIR, "doc_embs_fp16.npy")
FAISS_MAIN = os.path.join(ASSETS_DIR, "faiss_ip_768.index")

# Matryoshka truncation sizes offered in the UI (largest first) and the default.
MATRYOSHKA_DIMS = [768, 512, 256, 128]
DEFAULT_DIMS = 768
|
|
|
|
|
# Article records; each entry is later read through its "title" and "text" keys.
with open(CORPUS_JSON, encoding="utf-8") as f:
    corpus = json.load(f)

# Document embeddings: use the float32 file when present, else the float16 one.
# Either way the matrix ends up as float32.
_emb_path = next((p for p in (EMB_FP32, EMB_FP16) if os.path.exists(p)), None)
if _emb_path is None:
    raise FileNotFoundError("Expected assets/doc_embs_fp32.npy or assets/doc_embs_fp16.npy")
doc_embs = np.load(_emb_path).astype(np.float32, copy=False)

# Exactly one 2-D embedding row per corpus entry, so index row ids are corpus indices.
if not (doc_embs.ndim == 2 and doc_embs.shape[0] == len(corpus)):
    raise ValueError("Embeddings shape mismatch vs corpus length.")

EMB_DIM = doc_embs.shape[1]
|
|
|
|
|
# Embedding model shared by query encoding (encode_query) and document encoding
# (encode_document); HF_TOKEN is forwarded for authenticated Hub downloads.
model = SentenceTransformer(model_name_or_path="google/embeddinggemma-300m", token=HF_TOKEN)
|
|
|
|
|
# Full-width inner-product index over doc_embs. Prefer the prebuilt index file;
# otherwise build an exact flat index from the loaded embeddings.
if os.path.exists(FAISS_MAIN):
    base_index_768 = faiss.read_index(FAISS_MAIN)
    # A stale or mismatched index file would return row ids that do not line up
    # with `corpus` (wrong articles or IndexError at search time) -- fail fast.
    if base_index_768.d != EMB_DIM or base_index_768.ntotal != len(doc_embs):
        raise ValueError(
            f"{FAISS_MAIN} has d={base_index_768.d}, ntotal={base_index_768.ntotal}; "
            f"expected d={EMB_DIM}, ntotal={len(doc_embs)}."
        )
else:
    base_index_768 = faiss.IndexFlatIP(EMB_DIM)
    # FAISS consumes C-contiguous float32 buffers.
    base_index_768.add(np.ascontiguousarray(doc_embs, dtype=np.float32))
|
|
|
|
|
class MultiDimFaiss:
    """Exact inner-product FAISS indexes, one per Matryoshka dimensionality.

    For every size ``d`` in ``MATRYOSHKA_DIMS`` documents are indexed by their
    first ``d`` embedding components. Truncated rows are re-normalized to unit
    L2 norm (and ``search`` does the same for the query), so scores are cosine
    similarities at every size rather than dot products biased by how much of
    each vector's norm survived truncation.

    Attributes:
        full: the full-width document embedding matrix passed to the constructor.
        indexes: mapping from dimensionality to its FAISS index.
    """

    def __init__(self, doc_embs_full: np.ndarray):
        self.full = doc_embs_full
        self.indexes = {}
        full_dim = self.full.shape[1]
        for d in MATRYOSHKA_DIMS:
            if d == base_index_768.d:
                # Reuse the module-level full-width index instead of building a
                # second copy of the whole matrix.
                # NOTE(review): assumes the stored full-width embeddings are
                # already unit-norm, so no re-normalization at this size -- confirm.
                self.indexes[d] = base_index_768
            elif d > full_dim:
                raise ValueError(
                    f"Cannot build a {d}-dim index from {full_dim}-dim embeddings."
                )
            else:
                idx = faiss.IndexFlatIP(d)
                idx.add(self._normalize_rows(self.full[:, :d]))
                self.indexes[d] = idx

    @staticmethod
    def _normalize_rows(mat: np.ndarray) -> np.ndarray:
        """Return a C-contiguous float32 copy of *mat* with unit-L2-norm rows.

        All-zero rows stay zero instead of becoming NaN. The input is not modified.
        """
        out = np.array(mat, dtype=np.float32, order="C")  # always a fresh copy
        norms = np.linalg.norm(out, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        out /= norms
        return out

    def search(self, q_vec: np.ndarray, top_k: int, dims: int) -> Tuple[np.ndarray, np.ndarray]:
        """Search the *dims*-sized index with the first *dims* components of *q_vec*.

        The truncated query is re-normalized so returned scores are cosine
        similarities. Raises KeyError if *dims* has no index.

        Returns:
            FAISS ``(scores, ids)`` arrays of shape ``(1, top_k)``.
        """
        idx = self.indexes[dims]
        q = self._normalize_rows(np.asarray(q_vec)[:dims][None, :])
        return idx.search(q, top_k)
|
|
|
# Per-dimensionality search indexes, built once at import time.
faiss_md = MultiDimFaiss(doc_embs_full=doc_embs)
|
|
|
|
|
def _format_snippet(text: str, max_len: int = 380) -> str: |
|
return text[:max_len] + ("…" if len(text) > max_len else "") |
|
|
|
def do_search(query: str, top_k: int = 5, dims: int = DEFAULT_DIMS) -> List[List[str]]:
    """Retrieve the *top_k* corpus articles closest to *query*.

    The stripped query is embedded with ``model.encode_query`` (normalized) and
    searched in ``faiss_md`` at the Matryoshka size *dims*.

    Returns:
        One ``[score, title, snippet]`` row per hit, with the score formatted to
        four decimals; an empty list when the query is empty or blank.
    """
    if not (query and query.strip()):
        return []

    query_vec = model.encode_query(
        query.strip(), normalize_embeddings=True, convert_to_numpy=True
    )
    hit_scores, hit_ids = faiss_md.search(query_vec, top_k=top_k, dims=dims)

    # FAISS pads the result with id -1 when it has fewer than top_k hits.
    return [
        [f"{score:.4f}", corpus[doc_id]["title"], _format_snippet(corpus[doc_id]["text"])]
        for score, doc_id in zip(hit_scores[0].tolist(), hit_ids[0].tolist())
        if doc_id != -1
    ]
|
|
|
def do_similarity(text_a: str, text_b: str, dims: int = DEFAULT_DIMS) -> float:
    """Return the cosine similarity of *text_a* and *text_b* at *dims* dimensions.

    Both texts are embedded together with ``model.encode_document`` and cut to
    their first *dims* components (Matryoshka truncation). ``normalize_embeddings``
    only makes the full-width vectors unit length, so the truncated vectors are
    re-normalized before the dot product; otherwise the value shrinks as *dims*
    decreases and is not the cosine similarity shown in the UI.

    Returns:
        A float in [-1, 1]; ``0.0`` if either text is empty or whitespace-only
        (matching ``do_search``'s blank-query handling) or a truncated vector is zero.
    """
    if not (text_a and text_a.strip() and text_b and text_b.strip()):
        return 0.0

    embs = model.encode_document(
        [text_a, text_b], normalize_embeddings=True, convert_to_numpy=True
    )
    vec_a = np.asarray(embs[0][:dims], dtype=np.float32)
    vec_b = np.asarray(embs[1][:dims], dtype=np.float32)

    denom = float(np.linalg.norm(vec_a)) * float(np.linalg.norm(vec_b))
    if denom == 0.0:
        return 0.0
    # Clip tiny floating-point overshoot so the value stays in the cosine range.
    return float(np.clip(np.dot(vec_a, vec_b) / denom, -1.0, 1.0))
|
|
|
|
|
def _dims_dropdown():
    """Create an "Embedding dims" selector whose choices/value are str(MATRYOSHKA_DIMS)."""
    return gr.Dropdown(
        [str(d) for d in MATRYOSHKA_DIMS],
        value=str(DEFAULT_DIMS),
        label="Embedding dims",
    )


with gr.Blocks(title="EmbeddingGemma × Wikipedia (EN corpus)") as demo:
    gr.Markdown(
        """
        # Demo: EmbeddingGemma × Wikipedia (EN corpus)

        This Space showcases [Google DeepMind’s EmbeddingGemma models](https://huggingface.co/collections/google/embeddinggemma-68b9ae3a72a82f0562a80dc4), on a pre-indexed **random 10k sample** of [English Wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia).
        You can try:

        - **Semantic search** (English queries)
        - **Cross-lingual search** (queries in other languages → English articles)
        - **Sentence similarity** (compare two texts)

        🔗 Learn more in the [EmbeddingGemma blog post](https://huggingface.co/blog/embeddinggemma).
        """
    )

    with gr.Tabs():

        # English queries against the English corpus.
        with gr.TabItem("Semantic Search (EN corpus)"):
            with gr.Row():
                q = gr.Textbox(label="Query", value="Who discovered penicillin?")
                topk = gr.Slider(1, 20, value=5, step=1, label="Top-K")
                dims = _dims_dropdown()
            run = gr.Button("Search")
            out = gr.Dataframe(headers=["score", "title", "snippet"], wrap=True)
            # Slider/Dropdown values are coerced to int before searching.
            run.click(
                fn=lambda query, k, d: do_search(query, int(k), int(d)),
                inputs=[q, topk, dims],
                outputs=out,
            )

        # Non-English queries against the same English corpus.
        with gr.TabItem("Cross-Lingual (EN corpus)"):
            gr.Markdown("Type your query in **French/Spanish/Arabic**. Results come from the **English-only** corpus.")
            with gr.Row():
                qx = gr.Textbox(label="Query", value="¿Quién descubrió la penicilina?")
                topkx = gr.Slider(1, 20, value=5, step=1, label="Top-K")
                dimsx = _dims_dropdown()
            runx = gr.Button("Search")
            outx = gr.Dataframe(headers=["score", "title", "snippet"], wrap=True)
            runx.click(
                fn=lambda query, k, d: do_search(query, int(k), int(d)),
                inputs=[qx, topkx, dimsx],
                outputs=outx,
            )

        # Pairwise similarity between two free-form texts.
        with gr.TabItem("Similarity"):
            with gr.Row():
                a = gr.Textbox(lines=5, label="Text A", value="Alexander Fleming observed a mold that killed bacteria in 1928.")
                b = gr.Textbox(lines=5, label="Text B", value="La penicilina fue descubierta por Alexander Fleming en 1928.")
            dims2 = _dims_dropdown()
            sim_btn = gr.Button("Compute Similarity")
            sim_out = gr.Number(label="Cosine similarity (-1..1)")
            sim_btn.click(
                fn=lambda x, y, d: do_similarity(x, y, int(d)),
                inputs=[a, b, dims2],
                outputs=sim_out,
            )
|
|
|
if __name__ == "__main__":
    # Explicit launch() arguments take precedence over Gradio's
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT environment variables, so read
    # them here. The defaults are the previously hard-coded values.
    demo.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
    )