e5-small-math / usage_example.py
ThanhLe0125's picture
MRR-optimized E5-Math (+0.046 MRR vs base) - 25/06/2025
049bdf4 verified
raw
history blame
2.55 kB
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
# --- Model ------------------------------------------------------------------
# Load the fine-tuned checkpoint by its Hugging Face model id.
model = SentenceTransformer('ThanhLe0125/e5-small-math')

print("🧪 Testing MRR-optimized fine-tuned model:")
print("=" * 50)

# --- Example data -----------------------------------------------------------
# A Vietnamese math question ("Definition of an increasing function") and five
# candidate passages. Translated, in order:
#   0. definition of a function increasing on an interval (a;b)
#   1. worked example: find where y = x^2 - 2x + 1 is increasing
#   2. quadratic equation and its discriminant
#   3. differentiating a polynomial function
#   4. limit of a sequence
# Inputs carry the "query: " / "passage: " prefixes used by E5-style models.
query = "query: Định nghĩa hàm số đồng biến"
chunks = [
    "passage: Hàm số đồng biến trên khoảng (a;b) là hàm số mà với mọi x1 < x2 thì f(x1) < f(x2)",
    "passage: Ví dụ: Tìm khoảng đồng biến của hàm số y = x^2 - 2x + 1",
    "passage: Phương trình bậc hai ax^2 + bx + c = 0 có delta = b^2 - 4ac",
    "passage: Tính đạo hàm của hàm số đa thức",
    "passage: Giới hạn của dãy số",
]

# --- Encode and rank --------------------------------------------------------
# The query is encoded first, then the passages (same order as before).
query_emb, chunk_embs = model.encode([query]), model.encode(chunks)
# A single query row yields a (1, len(chunks)) matrix; flatten it to one score
# per passage.
similarities = cosine_similarity(query_emb, chunk_embs).ravel()
# argsort is ascending, so reverse it to put the most similar passage first.
ranked_indices = np.argsort(similarities)[::-1]
# --- Display results ----------------------------------------------------------
print("🎯 MRR-Optimized Rankings:")
# Ground-truth label for each entry of `chunks`, in the same order; exactly one
# passage is the answer the query is looking for.
chunk_types = ["CORRECT", "RELATED", "IRRELEVANT", "IRRELEVANT", "IRRELEVANT"]
# Fail up front rather than with an IndexError halfway through the printout
# if someone edits `chunks` without updating the labels.
if len(chunk_types) != len(chunks):
    raise ValueError("chunk_types must provide one label per entry in chunks")

PREVIEW_CHARS = 70  # maximum number of passage characters shown per entry
for rank, idx in enumerate(ranked_indices, 1):
    text = chunks[idx]
    # Only mark the preview with an ellipsis when the passage was actually cut.
    preview = text if len(text) <= PREVIEW_CHARS else text[:PREVIEW_CHARS] + "..."
    print(f"Rank {rank}: {chunk_types[idx]:>10} (Score: {similarities[idx]:.4f})")
    print(f" {preview}")
    print()
# --- Metrics for this query ---------------------------------------------------
# Index (into `chunks`) of the passage labelled as the right answer; derived
# from the labels so the two cannot drift apart.
correct_idx = chunk_types.index("CORRECT")
# 1-based position of that passage in the ranking, or None if it is absent.
correct_rank = next(
    (rank for rank, idx in enumerate(ranked_indices, 1) if idx == correct_idx),
    None,
)

if correct_rank is not None:
    # With a single relevant passage, the reciprocal rank is 1/rank and
    # Recall@k is 1 exactly when that passage is within the top k.
    mrr = 1.0 / correct_rank
    recall_at_k = {k: int(correct_rank <= k) for k in range(1, 6)}

    print("📊 Query Metrics:")
    print(f" MRR: {mrr:.4f} (correct chunk at rank #{correct_rank})")
    print(f" Recall@1: {recall_at_k[1]} | Recall@2: {recall_at_k[2]} | Recall@3: {recall_at_k[3]}")
    print(f" Recall@4: {recall_at_k[4]} | Recall@5: {recall_at_k[5]}")

    if correct_rank == 1:
        print(" 🌟 PERFECT! Correct chunk at rank #1!")
    elif correct_rank <= 2:
        print(" 🎯 EXCELLENT! Correct chunk in top 2!")
    elif correct_rank <= 3:
        print(" 👍 GOOD! Correct chunk in top 3!")
    else:
        print(" 📈 Could be better - but still found the answer!")
# --- Summary ------------------------------------------------------------------
# Closing banner listing the advertised advantages of the fine-tuned model.
print("\n" + "=" * 50)
print("💡 Fine-tuning Benefits:")
# One print of the newline-joined bullets emits the same text as one print per
# bullet (print appends the final newline).
print("\n".join((
    " ✅ Pushes correct chunks to rank #1",
    " ✅ Reduces inference cost (need fewer chunks)",
    " ✅ Improves user experience (instant answers)",
    " ✅ Specialized for Vietnamese mathematics",
)))