|
import gradio as gr |
|
from sentence_transformers import CrossEncoder |
|
|
|
# Cross-encoder checkpoint (model id or local path) handed to CrossEncoder.
MODEL_NAME = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"

# Instantiated once, at import time, and reused by every rerank() call.
ce = CrossEncoder(MODEL_NAME)
|
|
|
def rerank(query, docs):
    """Score each document against *query* with the cross-encoder and rank them.

    Args:
        query: Query text. ``None`` is treated as an empty string.
        docs: Value of the Gradio JSON input, expected to be a JSON array.
            Every element is converted to text with ``str()`` before scoring.
            A single non-array value is treated as a one-element list;
            ``None`` or an empty array produces no rows.

    Returns:
        A list of ``[doc_text, score]`` rows (``score`` is a Python float),
        ordered from highest to lowest score. This matches the
        ``["doc", "score"]`` headers of the output Dataframe.
    """
    if docs is None:
        docs = []
    elif not isinstance(docs, (list, tuple)):
        # A bare JSON string/object would otherwise be iterated
        # character-by-character or key-by-key.
        docs = [docs]

    if not docs:
        # Nothing to score. Returning early also avoids handing an empty
        # batch to ce.predict, which is not guaranteed to accept one.
        return []

    texts = [str(d) for d in docs]
    query_text = "" if query is None else str(query)
    pairs = [[query_text, txt] for txt in texts]
    scores = ce.predict(pairs)

    rows = [[txt, float(score)] for txt, score in zip(texts, scores)]
    # Rank by relevance. list.sort stays stable with reverse=True, so
    # documents with equal scores keep their input order.
    rows.sort(key=lambda row: row[1], reverse=True)
    return rows
|
|
|
|
|
# UI components, created in the order query -> docs -> results.
_query_box = gr.Textbox(label="Query")
# NOTE(review): Gradio documents gr.JSON as primarily a display component, so
# docs are presumably supplied via the API endpoint rather than typed in the
# web UI — confirm.
_docs_json = gr.JSON(label="Docs (JSON array of objects)")
# Each row produced by rerank() is [doc_text, score].
_results_table = gr.Dataframe(type="array", headers=["doc", "score"])

iface = gr.Interface(
    fn=rerank,
    inputs=[_query_box, _docs_json],
    outputs=_results_table,
    api_name="rerank",  # name under which the endpoint appears to API clients
)


if __name__ == "__main__":
    # Start the server only when executed directly, not when imported.
    iface.launch()
|
|