|
|
""" |
|
|
Example usage of the Paraformer model for legal document retrieval.

This is a simplified implementation. For full functionality and customization,
visit: https://github.com/nguyenthanhasia/paraformer

License: free to use for research purposes; commercial use is at your own risk.
|
|
""" |
|
|
|
|
|
from transformers import AutoModel |
|
|
import torch |
|
|
|
|
|
def main():
    """Run the Paraformer example end to end and print the results.

    Loads the ``nguyenthanhasia/paraformer`` checkpoint from the Hugging Face
    Hub and walks through three demonstrations:

    1. scoring a single query against one article,
    2. scoring a batch of query/article pairs,
    3. printing per-sentence attention weights, when the model returns them.

    An article is passed to the model as a list of sentence strings.
    Everything is written to stdout; nothing is returned.
    """
    print("Paraformer Model - Example Usage")
    print("=" * 50)

    # NOTE: trust_remote_code=True executes Python code shipped with the model
    # repository; only enable it for repositories you trust.
    print("Loading model from Hugging Face Hub...")
    model = AutoModel.from_pretrained('nguyenthanhasia/paraformer', trust_remote_code=True)
    print("✓ Model loaded successfully")

    # ------------------------------------------------------------------
    # 1. Single query / single article
    # ------------------------------------------------------------------
    print("\n1. Single Query-Article Example:")
    print("-" * 30)

    query = "What are the legal requirements for contract formation?"
    article = [
        "A contract is a legally binding agreement between two or more parties.",
        "For a contract to be valid, it must have offer, acceptance, and consideration.",
        "The parties must have legal capacity to enter into the contract.",
    ]

    print(f"Query: {query}")
    print(f"Article: {len(article)} sentences")

    # This script only runs inference, so gradient tracking is switched off
    # to avoid building autograd graphs that are never used.
    with torch.no_grad():
        relevance_score = model.get_relevance_score(query, article)
        prediction = model.predict_relevance(query, article)

    print(f"Relevance Score: {relevance_score:.4f}")
    print(f"Binary Output: {prediction} (0=lower similarity, 1=higher similarity)")

    # ------------------------------------------------------------------
    # 2. Batch of query / article pairs
    # ------------------------------------------------------------------
    print("\n2. Batch Processing Example:")
    print("-" * 30)

    queries = [
        "What constitutes a valid contract?",
        "How can employment be terminated?",
        "What are the requirements for copyright protection?",
    ]

    articles = [
        ["A contract requires offer, acceptance, and consideration.", "All parties must have legal capacity."],
        ["Employment can be terminated by mutual agreement.", "Notice period must be respected."],
        ["Copyright protects original works of authorship.", "The work must be fixed in a tangible medium."],
    ]

    # Call the module itself rather than .forward(): nn.Module.__call__ runs
    # registered hooks, which a direct .forward() call silently skips.
    with torch.no_grad():
        outputs = model(
            query_texts=queries,
            article_texts=articles,
            return_dict=True,
        )

    # Column 1 of the softmax is reported as the similarity score.
    # NOTE(review): assumes logits are (batch, 2) - confirm against the model.
    probabilities = torch.softmax(outputs.logits, dim=-1)
    predictions = torch.argmax(outputs.logits, dim=-1)

    for i, batch_query in enumerate(queries):
        score = probabilities[i, 1].item()
        pred = predictions[i].item()
        print(f"\nQuery {i + 1}: {batch_query}")
        print(f"  Similarity Score: {score:.4f}")
        print(f"  Binary Output: {pred}")

    # ------------------------------------------------------------------
    # 3. Attention weights over the sentences of one article
    # ------------------------------------------------------------------
    print("\n3. Attention Weights Example:")
    print("-" * 30)

    query = "What is required for a valid contract?"
    article = [
        "A contract is an agreement between parties.",
        "It must have offer and acceptance.",
        "Consideration is also required.",
        "The weather is nice today.",
    ]

    with torch.no_grad():
        outputs = model(
            query_texts=[query],
            article_texts=[article],
            return_dict=True,
        )

    if outputs.attentions is not None:
        # NOTE(review): presumably [0, 0] selects the first batch item and
        # yields one weight per sentence - confirm the attention layout.
        attention_weights = outputs.attentions[0, 0]
        print(f"Query: {query}")
        print("Attention weights per sentence:")
        for i, (sentence, weight) in enumerate(zip(article, attention_weights), start=1):
            print(f"  Sentence {i}: {weight:.4f} - {sentence}")

    print("\n" + "=" * 50)
    print("Important Notes:")
    print("- Scores represent similarity in learned feature space, not absolute relevance")
    print("- This is a simplified implementation for easy integration")
    print("- For full functionality: https://github.com/nguyenthanhasia/paraformer")
    print("- Research use: free | Commercial use: at your own risk")
|
|
|
|
|
# Run the demo only when this file is executed as a script, so importing the
# module does not trigger a model download.
if __name__ == "__main__":
    main()
|
|
|
|
|
|