File size: 679 Bytes
06ccfd7
abe6bcd
06ccfd7
cf8dfef
06ccfd7
abe6bcd
cf8dfef
abe6bcd
06ccfd7
abe6bcd
b8cb8ea
06ccfd7
cf8dfef
b8cb8ea
06ccfd7
abe6bcd
 
 
 
b8cb8ea
06ccfd7
 
 
 
b8cb8ea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import json
from collections.abc import Iterable

import gradio as gr
from sentence_transformers import CrossEncoder

# Hugging Face model id of the cross-encoder used for scoring.
_MODEL_ID = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"

# Loaded once at import time, so every call to ``rerank`` reuses the same
# model instance instead of reloading it per request.
ce = CrossEncoder(_MODEL_ID)

def rerank(query, docs):
    texts = [str(d) for d in docs]  # просто список строк
    pairs = [[query, txt] for txt in texts]
    scores = ce.predict(pairs)
    rows = [[txt, float(score)] for txt, score in zip(texts, scores)]
    return rows


# UI components, created in the same order as they appear in the interface:
# the query box and the JSON document list are inputs, the table is output.
_query_input = gr.Textbox(label="Query")
_docs_input = gr.JSON(label="Docs (JSON array of objects)")
_scores_output = gr.Dataframe(type="array", headers=["doc", "score"])

# Gradio app wiring ``rerank`` to the components above; the endpoint is
# exposed under the API name "rerank".
iface = gr.Interface(
    fn=rerank,
    inputs=[_query_input, _docs_input],
    outputs=_scores_output,
    api_name="rerank",
)

def _serve():
    """Launch the Gradio app defined by ``iface``."""
    iface.launch()


# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    _serve()