File size: 4,032 Bytes
57a3905
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
Example usage of Paraformer model for legal document retrieval.

This is a simplified implementation. For full functionality and customization,
visit: https://github.com/nguyenthanhasia/paraformer

License: Research purposes - free to use. Commercial purposes - at your own risk.
"""

from transformers import AutoModel
import torch

def _print_section(title):
    """Print an example heading preceded by a blank line and followed by a dashed rule."""
    print(f"\n{title}")
    print("-" * 30)


def _example_single(model):
    """Example 1: score a single query against one multi-sentence article."""
    _print_section("1. Single Query-Article Example:")

    query = "What are the legal requirements for contract formation?"
    article = [
        "A contract is a legally binding agreement between two or more parties.",
        "For a contract to be valid, it must have offer, acceptance, and consideration.",
        "The parties must have legal capacity to enter into the contract.",
    ]

    print(f"Query: {query}")
    print(f"Article: {len(article)} sentences")

    # Continuous score from the model's convenience API.
    relevance_score = model.get_relevance_score(query, article)
    print(f"Relevance Score: {relevance_score:.4f}")

    # Thresholded 0/1 output from the model's convenience API.
    prediction = model.predict_relevance(query, article)
    print(f"Binary Output: {prediction} (0=lower similarity, 1=higher similarity)")


def _example_batch(model):
    """Example 2: run several query/article pairs through one batched forward pass."""
    _print_section("2. Batch Processing Example:")

    queries = [
        "What constitutes a valid contract?",
        "How can employment be terminated?",
        "What are the requirements for copyright protection?",
    ]
    articles = [
        ["A contract requires offer, acceptance, and consideration.", "All parties must have legal capacity."],
        ["Employment can be terminated by mutual agreement.", "Notice period must be respected."],
        ["Copyright protects original works of authorship.", "The work must be fixed in a tangible medium."],
    ]

    # Call the module itself (not .forward) so any registered hooks also run.
    outputs = model(query_texts=queries, article_texts=articles, return_dict=True)

    # Logits are indexed as [example, class] below, with class 1 reported as
    # the "similar" score — assumes a 2-class head; confirm against the model code.
    probabilities = torch.softmax(outputs.logits, dim=-1)
    predictions = torch.argmax(outputs.logits, dim=-1)

    for number, (query, probs, pred) in enumerate(zip(queries, probabilities, predictions), start=1):
        print(f"\nQuery {number}: {query}")
        print(f"  Similarity Score: {probs[1].item():.4f}")
        print(f"  Binary Output: {pred.item()}")


def _example_attention(model):
    """Example 3: show per-sentence attention weights, including one irrelevant sentence."""
    _print_section("3. Attention Weights Example:")

    query = "What is required for a valid contract?"
    article = [
        "A contract is an agreement between parties.",
        "It must have offer and acceptance.",
        "Consideration is also required.",
        "The weather is nice today.",  # Irrelevant sentence
    ]

    outputs = model(query_texts=[query], article_texts=[article], return_dict=True)

    if outputs.attentions is None:
        # Say so explicitly rather than silently printing nothing.
        print("Attention weights were not returned by the model.")
        return

    # Presumably indexed as [batch, query, sentence] — TODO confirm against the model code.
    attention_weights = outputs.attentions[0, 0]  # First batch, first query
    print(f"Query: {query}")
    print("Attention weights per sentence:")
    for number, (sentence, weight) in enumerate(zip(article, attention_weights), start=1):
        # float() accepts Python floats and single-element tensors alike.
        print(f"  Sentence {number}: {float(weight):.4f} - {sentence}")


def _print_notes():
    """Print the closing usage and licensing notes."""
    print("\n" + "=" * 50)
    print("Important Notes:")
    print("- Scores represent similarity in learned feature space, not absolute relevance")
    print("- This is a simplified implementation for easy integration")
    print("- For full functionality: https://github.com/nguyenthanhasia/paraformer")
    print("- Research use: free | Commercial use: at your own risk")


def main():
    """Load the Paraformer model from the Hugging Face Hub and run three usage examples.

    Prints a single-pair example, a batched example and an attention-weights
    example, followed by usage notes. Returns None.
    """
    print("Paraformer Model - Example Usage")
    print("=" * 50)

    print("Loading model from Hugging Face Hub...")
    # SECURITY NOTE: trust_remote_code=True executes Python code downloaded from
    # the Hub repository. Review that code, and consider pinning `revision=`.
    model = AutoModel.from_pretrained('nguyenthanhasia/paraformer', trust_remote_code=True)
    print("✓ Model loaded successfully")

    # Pure inference: no gradients are needed, so skip building autograd graphs.
    with torch.no_grad():
        _example_single(model)
        _example_batch(model)
        _example_attention(model)

    _print_notes()

if __name__ == "__main__":
    # Run the demo only when executed as a script, never on import.
    main()