File size: 3,753 Bytes
24cd111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import csv
import json
import os
import re
import time

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download, login
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from data.to_poem_list import to_poem_list


hf_token = os.environ.get("HF_TOKEN")
# login(token=None) falls back to an interactive prompt, which cannot be
# answered when running headless. Downloads of public repos work without a
# token, so only authenticate when one is actually provided.
if hf_token:
    login(token=hf_token)
#====Settings====
model_path = "slxhere/modern_ancientpoem_encoder"  # SentenceTransformer model id
poem_csv_path = hf_hub_download(
    repo_id="slxhere/tang_poems",
    repo_type="dataset",
    filename="tang_poem.csv"
)
api_key = os.environ.get("DEEPSEEK_API_KEY")
base_url = "https://api.deepseek.com"  # served through the OpenAI client below
top_k = 5  # number of retrieved candidates handed to the LLM re-ranker
embedding_cache_path = hf_hub_download(
    repo_id="slxhere/poetic-mirror-cache-tang-embedding",
    repo_type="dataset",
    filename="cached_tang_embedding.npy"
)


print("Loading model and data...")
model = SentenceTransformer(model_path)
client = OpenAI(api_key=api_key, base_url=base_url)
poem_sentences = to_poem_list(poem_csv_path)

#========

# Embeddings for every entry of `poem_sentences`, row-aligned with it:
# retrieval uses the same index into both, so a cache whose row count does
# not match the current poem list is rejected and recomputed.
poem_embeddings = None
if os.path.exists(embedding_cache_path):
    _cached_embeddings = np.load(embedding_cache_path)
    if _cached_embeddings.ndim == 2 and len(_cached_embeddings) == len(poem_sentences):
        poem_embeddings = _cached_embeddings
    else:
        print(
            f"Cached embeddings shape {_cached_embeddings.shape} does not match "
            f"{len(poem_sentences)} poem sentences! Re-encoding..."
        )
    del _cached_embeddings
else:
    print("Cached embeddings not found! Encoding... This might take some time...")

if poem_embeddings is None:
    poem_embeddings = model.encode(
        poem_sentences, batch_size=64, show_progress_bar=True, normalize_embeddings=True
    )
    # Only write when no cache file existed, so a downloaded (possibly shared)
    # cache file is never clobbered by a locally re-encoded one.
    if not os.path.exists(embedding_cache_path):
        np.save(embedding_cache_path, poem_embeddings)
        print(f"Embedding saved to {embedding_cache_path}")


def rerank_with_llm(modern, candidates):
    """Ask the chat model which candidate verse best matches *modern*.

    Args:
        modern: The user's modern-language sentence.
        candidates: Candidate poem lines, shown to the model as a 1-based
            numbered list.

    Returns:
        The 0-based index into *candidates* chosen by the model. Falls back
        to 0 (the first candidate) when *candidates* is empty, the API call
        fails, or the reply contains no in-range number.
    """
    if not candidates:
        return 0

    # The prompt asks (in Chinese) which numbered verse best expresses the
    # emotion and mood of the user's sentence, answering with the number only.
    prompt = f"""
我说了一句话:“{modern}”,你觉得下面哪一句古诗最能表达这句话的情绪与意境?

"""
    prompt += "".join(f"{i}. {c}\n" for i, c in enumerate(candidates, start=1))
    prompt += "\n请直接回复最匹配的一句编号(如 2),不要解释。"

    try:
        resp = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "你是古诗匹配专家。"},
                {"role": "user", "content": prompt}
            ]
        )
        # `content` can be None; treat that as an unparseable reply.
        reply = resp.choices[0].message.content or ""
    except Exception as e:
        # Network/auth/rate-limit failures degrade to the retrieval order.
        print("LLM error: ", e)
        return 0

    # Models often decorate the answer ("2.", "第2句", "**2**"), so accept the
    # first in-range integer anywhere in the reply, not only a bare-digit line.
    for match in re.finditer(r"\d+", reply):
        idx = int(match.group()) - 1
        if 0 <= idx < len(candidates):
            return idx
    return 0


def retrieve_and_rerank(modern_sentence):
    """Retrieve the verses closest to *modern_sentence* and let the LLM pick one.

    The sentence is embedded with the same encoder as the poem corpus and
    compared against every poem embedding by cosine similarity. The ``top_k``
    best lines are then passed to :func:`rerank_with_llm`.

    Returns:
        A list of at most ``top_k`` dicts, ordered by descending similarity,
        with keys:
        ``"poem"`` -- the verse;
        ``"score"`` -- softmax of the candidates' cosine similarities, rounded
        to 4 places (a relative weight among these candidates, not a
        calibrated probability);
        ``"(LLM selected)"`` -- True for the verse the LLM chose.
        The list is empty when nothing could be retrieved.
    """
    start_time = time.time()
    emb = model.encode([modern_sentence], normalize_embeddings=True)
    sims = cosine_similarity(emb, poem_embeddings)[0]

    # Best-first order truncated to top_k; also correct when top_k is 0 or
    # larger than the corpus (argsort()[-0:] would select everything).
    top_k_idx = sims.argsort()[::-1][:top_k]
    top_k_sims = sims[top_k_idx]
    top_k_poems = [poem_sentences[i] for i in top_k_idx]
    if not top_k_poems:
        return []

    rerank_idx = rerank_with_llm(modern_sentence, top_k_poems)

    # Numerically stable softmax over the candidates' similarities.
    scores = np.exp(top_k_sims - np.max(top_k_sims))
    probs = scores / scores.sum()

    # Iterate over what was actually retrieved, which may be fewer than top_k.
    results = [
        {
            "poem": poem,
            "score": round(float(prob), 4),
            "(LLM selected)": i == rerank_idx,
        }
        for i, (poem, prob) in enumerate(zip(top_k_poems, probs))
    ]

    print(f"Reaction time: {time.time() - start_time:.2f}s")
    return results


def poetry_matcher(input_text):
    """Gradio callback: render the ranked matches for *input_text* as text.

    Produces one line per result, in the order returned by
    ``retrieve_and_rerank``: a check-mark emoji for the LLM-selected verse
    (two spaces otherwise), then the bracketed score, then the verse.
    """
    rendered = []
    for match in retrieve_and_rerank(input_text):
        marker = "✅" if match["(LLM selected)"] else "  "
        rendered.append(f"{marker} [{match['score']}] {match['poem']}")
    return "\n".join(rendered)

# Web UI: a two-line free-text box in, the plain-text ranking out.
_sentence_box = gr.Textbox(lines=2, placeholder="Enter your sentence...")

iface = gr.Interface(
    fn=poetry_matcher,
    inputs=_sentence_box,
    outputs="text",
    title="🔭 Poetic Mirror 🖌",
    description=(
        "穿越千年诗意,为你精准匹配最契合的古诗名句——输入你的句子,邂逅古人共鸣。\n"
        "Travel through a thousand years of poetry—enter your sentence, and we'll find the most matching Tang dynasty verse for you."
    ),
)

iface.launch()