# paraformer / example_usage.py
# Author: nguyenthanhasia (uploaded with huggingface_hub, commit 57a3905, verified)
"""
Example usage of Paraformer model for legal document retrieval.
This is a simplified implementation. For full functionality and customization,
visit: https://github.com/nguyenthanhasia/paraformer
License: Research purposes - free to use. Commercial purposes - at your own risk.
"""
from transformers import AutoModel
import torch
def _print_section(title):
    """Print an example heading followed by a short dashed rule."""
    print(title)
    print("-" * 30)


def _run_single_pair_example(model):
    """Example 1: score one query against one multi-sentence article."""
    _print_section("\n1. Single Query-Article Example:")
    query = "What are the legal requirements for contract formation?"
    article = [
        "A contract is a legally binding agreement between two or more parties.",
        "For a contract to be valid, it must have offer, acceptance, and consideration.",
        "The parties must have legal capacity to enter into the contract.",
    ]
    print(f"Query: {query}")
    print(f"Article: {len(article)} sentences")

    # Continuous score from the model's remote-code helper method.
    relevance_score = model.get_relevance_score(query, article)
    print(f"Relevance Score: {relevance_score:.4f}")

    # 0/1 output from the model's remote-code helper method.
    prediction = model.predict_relevance(query, article)
    print(f"Binary Output: {prediction} (0=lower similarity, 1=higher similarity)")


def _run_batch_example(model):
    """Example 2: score several query/article pairs in a single forward pass."""
    _print_section("\n2. Batch Processing Example:")
    queries = [
        "What constitutes a valid contract?",
        "How can employment be terminated?",
        "What are the requirements for copyright protection?",
    ]
    articles = [
        ["A contract requires offer, acceptance, and consideration.", "All parties must have legal capacity."],
        ["Employment can be terminated by mutual agreement.", "Notice period must be respected."],
        ["Copyright protects original works of authorship.", "The work must be fixed in a tangible medium."],
    ]

    # Call the module itself (not ``model.forward``) so any registered hooks run.
    outputs = model(
        query_texts=queries,
        article_texts=articles,
        return_dict=True,
    )

    # NOTE(review): assumes logits are (batch, num_classes) with column 1 meaning
    # "higher similarity" -- confirm against the model's remote code.
    probabilities = torch.softmax(outputs.logits, dim=-1)
    predictions = torch.argmax(outputs.logits, dim=-1)

    for i, query in enumerate(queries):
        score = probabilities[i, 1].item()
        pred = predictions[i].item()
        print(f"\nQuery {i+1}: {query}")
        print(f" Similarity Score: {score:.4f}")
        print(f" Binary Output: {pred}")


def _run_attention_example(model):
    """Example 3: print per-sentence attention weights for a single query."""
    _print_section("\n3. Attention Weights Example:")
    query = "What is required for a valid contract?"
    article = [
        "A contract is an agreement between parties.",
        "It must have offer and acceptance.",
        "Consideration is also required.",
        "The weather is nice today.",  # Irrelevant sentence
    ]
    outputs = model(
        query_texts=[query],
        article_texts=[article],
        return_dict=True,
    )

    if outputs.attentions is None:
        # Say why the section is empty instead of silently printing nothing.
        print("Attention weights were not returned by the model; skipping.")
        return

    # First batch item, first query.
    # NOTE(review): assumes the leading two axes of ``attentions`` are
    # (batch, query) -- confirm against the model's remote code.
    attention_weights = outputs.attentions[0, 0]
    print(f"Query: {query}")
    print("Attention weights per sentence:")
    for i, (sentence, weight) in enumerate(zip(article, attention_weights)):
        # float() accepts a Python number or any one-element tensor, unlike
        # formatting a tensor directly with ``:.4f``.
        print(f" Sentence {i+1}: {float(weight):.4f} - {sentence}")


def _print_notes():
    """Print closing caveats about score interpretation and licensing."""
    print("\n" + "=" * 50)
    print("Important Notes:")
    print("- Scores represent similarity in learned feature space, not absolute relevance")
    print("- This is a simplified implementation for easy integration")
    print("- For full functionality: https://github.com/nguyenthanhasia/paraformer")
    print("- Research use: free | Commercial use: at your own risk")


def main():
    """Load Paraformer from the Hugging Face Hub and run the usage examples.

    Loads ``nguyenthanhasia/paraformer`` via ``AutoModel.from_pretrained``.
    ``trust_remote_code=True`` is passed, so the repository's own model code
    is executed. The function then prints the results of three examples
    (single pair, batch, attention weights) and some closing notes to stdout.
    """
    print("Paraformer Model - Example Usage")
    print("=" * 50)

    # Load the model
    print("Loading model from Hugging Face Hub...")
    model = AutoModel.from_pretrained('nguyenthanhasia/paraformer', trust_remote_code=True)
    print("✓ Model loaded successfully")

    # Inference only: disable autograd so the forward passes don't build graphs.
    with torch.no_grad():
        _run_single_pair_example(model)
        _run_batch_example(model)
        _run_attention_example(model)

    _print_notes()
# Run the examples only when executed as a script, not when imported.
if __name__ == "__main__":
    main()