import duckdb
import gradio as gr
from gradio_client import Client
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from huggingface_hub import get_token
import pandas as pd

# Query encoder: a static (model2vec) embedding wrapped as a SentenceTransformer.
static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-8M")
model = SentenceTransformer(modules=[static_embedding])
# Cross-encoder used by `rerank` to score (query, document) pairs.
reranker = CrossEncoder("sentence-transformers/all-MiniLM-L12-v2")

embedding_dimensions = model.get_sentence_embedding_dimension()
dataset_name = "cyrilzakka/pubmed-medline-embeddings"
embedding_column = "embedding"
embedding_column_float = f"{embedding_column}_float"
table_name = "pubmed_medline"

# Load the dataset's parquet files into an in-memory table, add a fixed-size
# FLOAT[] copy of the embedding column, and build a cosine HNSW index on it
# using DuckDB's `vss` extension.
duckdb.sql(
    query=f"""
    INSTALL vss;
    LOAD vss;
    CREATE TABLE {table_name} AS
        SELECT *, {embedding_column}::float[{embedding_dimensions}] as {embedding_column_float}
        FROM 'hf://datasets/{dataset_name}/**/*.parquet';
    CREATE INDEX my_hnsw_index ON {table_name}
        USING HNSW ({embedding_column_float}) WITH (metric = 'cosine');
    """
)


def similarity_search(query: str, k: int = 5) -> pd.DataFrame:
    """Return the k rows whose embeddings are closest to the query.

    Args:
        query: Free-text search query; it is embedded with the module-level
            sentence-transformer model.
        k: Maximum number of rows to return. Non-integer values are
            truncated and negative values are treated as 0.

    Returns:
        A DataFrame of the matching rows ordered by ascending cosine
        distance, with a ``distance`` column added and both embedding
        columns removed.
    """
    # The UI slider and remote (API/MCP) callers may deliver a float such as
    # 5.0; coerce to a non-negative int so the LIMIT clause is always a valid
    # integer literal and never arbitrary text.
    limit = max(int(k), 0)
    embedding = model.encode(query).tolist()
    df = duckdb.sql(
        query=f"""
        SELECT *, array_cosine_distance(
            {embedding_column_float},
            {embedding}::FLOAT[{embedding_dimensions}]
        ) as distance
        FROM {table_name}
        ORDER BY distance
        LIMIT {limit};
        """
    ).to_df()
    # The raw vectors are large and not useful to display.
    df = df.drop(columns=[embedding_column, embedding_column_float])
    return df


def rerank(query: str, documents: pd.DataFrame) -> pd.DataFrame:
    """Re-order documents by cross-encoder relevance to the query.

    Args:
        query: The search query to score each document against.
        documents: DataFrame with a ``content`` column holding the text to
            score.

    Returns:
        A copy of ``documents`` with duplicate ``content`` rows removed, a
        ``rank`` column holding the cross-encoder score, sorted from highest
        to lowest score.
    """
    # NOTE(review): this helper is not wired to the UI in the visible code,
    # although the page text mentions re-ranking -- confirm whether that is
    # intended.
    documents = documents.copy()
    documents = documents.drop_duplicates("content")
    documents["rank"] = reranker.predict(
        [[query, hit] for hit in documents["content"]]
    )
    documents = documents.sort_values(by="rank", ascending=False)
    return documents


with gr.Blocks() as demo:
    # NOTE(review): the title and description share the first line of this
    # string, so Markdown renders both as the H1 heading; the text is kept
    # verbatim here.
    gr.Markdown(
        "# RAG - PubMed Medline (https://pubmed.ncbi.nlm.nih.gov) "
        "Executes vector search and re-ranking top of "
        "[pubmed-medline-embeddings](https://huggingface.co/datasets/cyrilzakka/pubmed-medline-embeddings). \n"
        "Part of the [Therapeutics Actionability Challenge](https://sail.health/event/sail-2025/program/) Demo."
    )
    query = gr.Textbox(label="Query")
    # step=1 keeps the slider on whole numbers; without it Gradio derives a
    # fractional step from the range and passes non-integer values for k.
    k = gr.Slider(1, 50, value=5, step=1, label="Number of results")
    btn = gr.Button("Search")
    results = gr.Dataframe(headers=["url", "chunk", "distance"], wrap=True)
    btn.click(fn=similarity_search, inputs=[query, k], outputs=[results])

demo.launch(mcp_server=True)