File size: 7,016 Bytes
f81cfe2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import torch
import numpy as np
from typing import List, Union, Dict, Any
from config import *

def get_encoded_input_single(text, tokenizer, token_style, sequence_len = 256):
    """Encode one text into token ids, an attention mask and a word-end mask.

    The text is split on whitespace and every word is tokenized on its own.
    Words are appended until the next word would leave no room for the
    end-of-sequence id; that word and all words after it are left out of the
    encoding. The lists are then padded with the PAD id up to ``sequence_len``.

    Args:
        text: Input string; words are taken from ``text.split()``.
        tokenizer: Object providing ``tokenize(word)`` and
            ``convert_tokens_to_ids(token)``. ``unk_token`` is read only when
            a word produces no tokens (presumably a Hugging Face tokenizer --
            confirm against the caller).
        token_style: Key into ``TOKEN_IDX`` selecting the ``START_SEQ``,
            ``END_SEQ`` and ``PAD`` ids.
        sequence_len: Length the three lists are padded to.

    Returns:
        dict with three lists of equal length:
        ``input_values`` -- token ids framed by START_SEQ / END_SEQ,
        ``attention_mask`` -- 1 where the id differs from PAD, else 0,
        ``y_mask`` -- 1 on the last sub-token of each encoded word, else 0.
    """
    special_ids = TOKEN_IDX[token_style]
    pad_id = special_ids['PAD']

    x = [special_ids['START_SEQ']]
    y_mask = [0]

    for word in text.split():
        if len(x) >= sequence_len:
            break

        tokens = tokenizer.tokenize(word)
        if not tokens:
            # Some words (e.g. made only of characters the tokenizer strips)
            # yield no tokens. Indexing tokens[-1] would raise IndexError, and
            # skipping the word would shift every later word/label pair, so
            # the word is represented by the unknown token instead.
            tokens = [tokenizer.unk_token]

        # ">=" keeps one slot free for the END_SEQ id appended below.
        if len(tokens) + len(x) >= sequence_len:
            break

        x.extend(tokenizer.convert_tokens_to_ids(token) for token in tokens)
        # Only the final sub-token of a word carries that word's label.
        y_mask.extend([0] * (len(tokens) - 1))
        y_mask.append(1)

    x.append(special_ids['END_SEQ'])
    y_mask.append(0)

    # Pad to sequence_len
    if len(x) < sequence_len:
        padding = sequence_len - len(x)
        x = x + [pad_id] * padding
        y_mask = y_mask + [0] * padding

    attn_mask = [1 if token != pad_id else 0 for token in x]

    return {
        'input_values': x,
        'attention_mask': attn_mask,
        'y_mask': y_mask
    }

def get_encoded_input_batch(texts, tokenizer, token_style, sequence_len = 256):
    """Encode every text in ``texts`` and stack the results into tensors.

    Each text is passed to ``get_encoded_input_single`` with the same
    tokenizer, token style and sequence length.

    Returns:
        dict with the keys ``input_values``, ``attention_mask`` and ``y_mask``;
        each value is a ``torch.tensor`` built from the per-text lists, one
        row per input text.
    """
    per_text = [
        get_encoded_input_single(sample, tokenizer, token_style, sequence_len)
        for sample in texts
    ]

    # One tensor per field, rows in the same order as ``texts``.
    return {
        field: torch.tensor([encoding[field] for encoding in per_text])
        for field in ('input_values', 'attention_mask', 'y_mask')
    }

def run_onnx_inference(input_values, attention_mask, session):
    """Feed one batch to the ONNX session and return the arg-max class ids.

    Args:
        input_values: Tensor bound to the session's first declared input.
        attention_mask: Tensor bound to the session's second declared input.
        session: Object exposing ``get_inputs()``, ``get_outputs()`` and
            ``run(output_names, feed)`` (presumably an onnxruntime
            InferenceSession -- confirm against the caller).

    Returns:
        torch tensor holding, for every position, the index of the largest
        value along dimension 2 of the session's first output (assumed to be
        shaped [batch_size, seq_len, num_classes]).
    """
    declared_inputs = session.get_inputs()
    first_output = session.get_outputs()[0].name

    # The session is fed numpy arrays, so both tensors are moved to the CPU
    # and converted first.
    feed = {
        declared_inputs[0].name: input_values.cpu().numpy(),
        declared_inputs[1].name: attention_mask.cpu().numpy(),
    }

    raw_outputs = session.run([first_output], feed)
    scores = torch.tensor(raw_outputs[0])
    return scores.argmax(dim=2)

def get_transcription_batch(texts, session, tokenizer, device, token_style):
    """Punctuate several texts with the ONNX model and return the results.

    Every text is split on whitespace. The encoder only takes as many words
    as fit into one fixed-length window, so a long text is handled in several
    rounds: after each inference the words that received a label are consumed
    and the remainder is encoded again, until every word has been emitted.
    Texts that contain no words are returned as ``""`` without being sent to
    the model, and an empty ``texts`` yields ``[]``.

    Args:
        texts: Sequence of input strings.
        session: Inference session handed to ``run_onnx_inference``.
        tokenizer: Tokenizer handed to ``get_encoded_input_batch``.
        device: Target passed to ``.to(device)`` for the model inputs.
        token_style: Key selecting the special-token ids used for encoding.

    Returns:
        list of str, one per input text in the same order; each is the text's
        words followed by their ``punctuation_map`` entry, joined by single
        spaces.
    """
    word_lists = [text.split() for text in texts]
    pieces = [[] for _ in word_lists]
    offsets = [0] * len(word_lists)

    # Texts without any words need no inference; they come out as "".
    pending = [idx for idx, words in enumerate(word_lists) if words]

    while pending:
        remainders = [word_lists[idx][offsets[idx]:] for idx in pending]
        encoded_batch = get_encoded_input_batch(
            [' '.join(words) for words in remainders], tokenizer, token_style
        )

        # Move to device
        input_values = encoded_batch['input_values'].to(device)
        attention_mask = encoded_batch['attention_mask'].to(device)
        # Plain lists avoid one tensor comparison per position below.
        y_masks = encoded_batch['y_mask'].tolist()

        predictions = run_onnx_inference(input_values, attention_mask, session).tolist()

        still_pending = []
        for row, idx in enumerate(pending):
            words = remainders[row]
            # The label of a word sits on the position flagged 1 in y_mask
            # (the word's last sub-token); flagged positions appear in word
            # order.
            word_labels = [
                label
                for label, flag in zip(predictions[row], y_masks[row])
                if flag == 1
            ]
            consumed = min(len(word_labels), len(words))
            pieces[idx].extend(
                word + punctuation_map[label]
                for word, label in zip(words, word_labels)
            )

            if consumed == 0:
                # Not even the first remaining word fits into one window.
                # Emit it without punctuation so the loop always advances.
                pieces[idx].append(words[0])
                consumed = 1

            offsets[idx] += consumed
            if offsets[idx] < len(word_lists[idx]):
                still_pending.append(idx)

        pending = still_pending

    return [' '.join(parts).strip() for parts in pieces]

def get_transcription(text_or_texts, session, tokenizer, device, token_style):
    """
    Entry point accepting either one text or a list of texts.

    A single string is wrapped in a one-element list; both cases are then
    forwarded to get_transcription_batch.

    Args:
        text_or_texts: Single text string or list of text strings
        session, tokenizer, device, token_style: Passed through unchanged to
            get_transcription_batch

    Returns:
        The list produced by get_transcription_batch. For a single string
        input this is a one-element list, not a bare string.

    Raises:
        ValueError: If text_or_texts is neither a str nor a list
    """
    if not isinstance(text_or_texts, (str, list)):
        raise ValueError("Input must be either a string or a list of strings")

    batch = [text_or_texts] if isinstance(text_or_texts, str) else text_or_texts
    return get_transcription_batch(batch, session, tokenizer, device, token_style)


# Ad-hoc timing script: punctuates one sample text three times, then a
# two-item batch, and prints the wall-clock durations.
# NOTE(review): get_transcription is defined with four further required
# parameters (session, tokenizer, device, token_style). The calls below pass
# only the text, so as written they raise TypeError before any timing is
# printed. A session and tokenizer must be created and passed in -- nothing in
# this file constructs them.
if __name__ == '__main__':
    import time
    
    test_text = 'ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান চতুর্দশ পাকিস্তানি বোলার হিসেবে অভিষেকেই তুলে নিলেন ছয় উইকেট'
    
    print("Testing single text processing:")
    print("=" * 50)
    
    # Test single text processing
    # Three repetitions, each timed separately.
    for i in range(3):
        start_time = time.time()
        result = get_transcription(test_text)
        end_time = time.time()
        print(f"Run {i+1}: {end_time - start_time:.4f}s")
    
    # get_transcription forwards a one-element list for string input and
    # returns the batch function's list, so this slice is taken over a list,
    # not over the characters of a string.
    print(f"\nSingle result: {result[:100]}...")
    
    print("\nTesting batch text processing:")
    print("=" * 50)
    
    # Test batch processing
    # The same sample sentence twice, processed in one call.
    batch_texts = [
        'ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান চতুর্দশ পাকিস্তানি বোলার হিসেবে অভিষেকেই তুলে নিলেন ছয় উইকেট',
        'ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান চতুর্দশ পাকিস্তানি বোলার হিসেবে অভিষেকেই তুলে নিলেন ছয় উইকেট',
    ]
    
    start_time = time.time()
    batch_results = get_transcription(batch_texts)
    end_time = time.time()
    
    print(f"Batch processing time: {end_time - start_time:.4f}s")
    print(f"Processed {len(batch_texts)} texts")
    print(f"Average time per text: {(end_time - start_time) / len(batch_texts):.4f}s")
    
    # Show the first 50 characters of every punctuated text.
    for i, result in enumerate(batch_results):
        print(f"Text {i+1}: {result[:50]}...")