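"""Poetic Mirror: a Gradio app that matches a modern sentence to Tang dynasty verses.

Pipeline: encode the input with a SentenceTransformer, retrieve the top-k most similar
poem lines by cosine similarity, then ask DeepSeek to rerank the candidates and
highlight the best match.
"""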
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from openai import OpenAI
import time
from data.to_poem_list import to_poem_list
import os
import gradio as gr
from huggingface_hub import hf_hub_download, login

# Authenticate with the Hugging Face Hub (token supplied via the HF_TOKEN environment variable)
hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)
# ==== Settings ====
model_path = "slxhere/modern_ancientpoem_encoder"  # SentenceTransformer encoder used for retrieval
poem_csv_path = hf_hub_download(
    repo_id="slxhere/tang_poems",
    repo_type="dataset",
    filename="tang_poem.csv"
)
api_key = os.environ.get("DEEPSEEK_API_KEY")
base_url = "https://api.deepseek.com"
top_k = 5  # number of candidate verses retrieved before LLM reranking
# Precomputed poem embeddings: fetch the cached .npy from the dataset repo;
# if the download fails, fall back to a local path and encode from scratch below.
try:
    embedding_cache_path = hf_hub_download(
        repo_id="slxhere/poetic-mirror-cache-tang-embedding",
        repo_type="dataset",
        filename="cached_tang_embedding.npy"
    )
except Exception:
    embedding_cache_path = "cached_tang_embedding.npy"

print("Loading model and data...")
model = SentenceTransformer(model_path)
client = OpenAI(api_key=api_key, base_url=base_url)
poem_sentences = to_poem_list(poem_csv_path)  # list of individual Tang poem lines

# ==== Poem embeddings: load from cache, or encode and save ====
if os.path.exists(embedding_cache_path):
    poem_embeddings = np.load(embedding_cache_path)
else:
    print("Cached embeddings not found! Encoding... This might take some time...")
    poem_embeddings = model.encode(
        poem_sentences, batch_size=64, show_progress_bar=True, normalize_embeddings=True
    )
    np.save(embedding_cache_path, poem_embeddings)
    print(f"Embedding saved to {embedding_cache_path}")
def rerank_with_llm(modern, candidates):
    """Ask the DeepSeek chat model which candidate verse best matches the modern sentence.

    Returns the index of the chosen candidate, or 0 (the top retrieval hit) if the
    call fails or the reply cannot be parsed.
    """
    # Prompt (Chinese): "I said: '{modern}'. Which of the following ancient verses
    # best expresses the emotion and mood of this sentence?"
    prompt = f"""
我说了一句话:“{modern}”,你觉得下面哪一句古诗最能表达这句话的情绪与意境?
"""
    for i, c in enumerate(candidates):
        prompt += f"{i+1}. {c}\n"
    # "Reply with only the number of the best match (e.g. 2), no explanation."
    prompt += "\n请直接回复最匹配的一句编号(如 2),不要解释。"
    try:
        resp = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                # System role (Chinese): "You are an expert in matching ancient poetry."
                {"role": "system", "content": "你是古诗匹配专家。"},
                {"role": "user", "content": prompt}
            ]
        )
        reply = resp.choices[0].message.content.strip()
        # The model is asked to answer with a bare number; parse the first all-digit line.
        for line in reply.splitlines():
            if line.strip().isdigit():
                idx = int(line.strip()) - 1
                if 0 <= idx < len(candidates):
                    return idx
    except Exception as e:
        print("LLM error: ", e)
    return 0  # fall back to the highest-similarity candidate
def retrieve_and_rerank(modern_sentence):
    """Retrieve the top-k most similar poem lines, then let the LLM pick the best one."""
    start_time = time.time()
    # Dense retrieval: cosine similarity between the query and all poem embeddings
    emb = model.encode([modern_sentence], normalize_embeddings=True)
    sims = cosine_similarity(emb, poem_embeddings)[0]
    top_k_idx = sims.argsort()[-top_k:][::-1]
    top_k_sims = sims[top_k_idx]
    top_k_poems = [poem_sentences[i] for i in top_k_idx]
    # LLM reranking over the retrieved candidates
    rerank_idx = rerank_with_llm(modern_sentence, top_k_poems)
    # Softmax over the top-k similarities to show relative scores
    scores = np.exp(top_k_sims - np.max(top_k_sims))
    probs = scores / scores.sum()
    results = [{
        "poem": top_k_poems[i],
        "score": round(float(probs[i]), 4),
        "llm_selected": i == rerank_idx
    } for i in range(top_k)]
    print(f"Response time: {time.time() - start_time:.2f}s")
    return results
def poetry_matcher(input_text):
    results = retrieve_and_rerank(input_text)
    # One line per candidate: a check mark for the LLM's pick, the softmax score, and the verse
    return "\n".join(
        [f"{'✅' if r['llm_selected'] else ' '} [{r['score']}] {r['poem']}" for r in results]
    )
iface = gr.Interface(
    fn=poetry_matcher,
    inputs=gr.Textbox(lines=2, placeholder="Enter your sentence..."),
    outputs="text",
    title="🔭 Poetic Mirror 🖌",
    description="穿越千年诗意,为你精准匹配最契合的古诗名句——输入你的句子,邂逅古人共鸣。\nTravel through a thousand years of poetry: enter your sentence and we'll find the Tang dynasty verse that best matches it."
)
iface.launch()
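# Note: on a Hugging Face Space this script runs automatically as the app entry point;
# run it locally and Gradio serves the UI at http://127.0.0.1:7860 by default.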