|
import torch

from typing import List, Union, Dict, Any

from config import *
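
# NOTE: the star import above is assumed to provide TOKEN_IDX (special-token
# ids keyed by tokenizer style) and punctuation_map (class index ->
# punctuation string); both are used below but defined only in config.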
|
|
|
def get_encoded_input_single(text: str, tokenizer, token_style: str, sequence_len: int = 256) -> Dict[str, Any]:
    """Encode a single text sequence, mirroring the preprocessing used when the model was exported."""
    words = text.split()
    word_pos = 0

    # x collects token ids; y_mask flags the positions where a punctuation
    # prediction is read out (the last subword of each word).
    x = [TOKEN_IDX[token_style]['START_SEQ']]
    y_mask = [0]

    while len(x) < sequence_len and word_pos < len(words):
        tokens = tokenizer.tokenize(words[word_pos])
        # Guard against tokenizers that return no tokens for a word; without
        # this, tokens[-1] below would raise an IndexError.
        if not tokens:
            word_pos += 1
            continue
        # Stop before overflowing the window; one slot stays free for END_SEQ.
        if len(tokens) + len(x) >= sequence_len:
            break
        for i in range(len(tokens) - 1):
            x.append(tokenizer.convert_tokens_to_ids(tokens[i]))
            y_mask.append(0)
        x.append(tokenizer.convert_tokens_to_ids(tokens[-1]))
        y_mask.append(1)
        word_pos += 1

    x.append(TOKEN_IDX[token_style]['END_SEQ'])
    y_mask.append(0)

    # Pad out to the fixed sequence length.
    if len(x) < sequence_len:
        x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))]
        y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))]

    # Attend to real tokens only; padding positions get 0.
    attn_mask = [1 if token != TOKEN_IDX[token_style]['PAD'] else 0 for token in x]

    return {
        'input_values': x,
        'attention_mask': attn_mask,
        'y_mask': y_mask
    }
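
# Shape illustration (assuming each word maps to a single subword): a two-word
# text with sequence_len 8 encodes as
#   input_values: [START_SEQ, w1, w2, END_SEQ, PAD, PAD, PAD, PAD]
#   y_mask:       [0,         1,  1,  0,       0,   0,   0,   0  ]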
|
|
|
def get_encoded_input_batch(texts: List[str], tokenizer, token_style: str, sequence_len: int = 256) -> Dict[str, torch.Tensor]:
    """Encode a batch of texts by applying the single-sequence encoder to each one."""
    batch_data = []

    for text in texts:
        encoded = get_encoded_input_single(text, tokenizer, token_style, sequence_len)
        batch_data.append(encoded)

    # Stack the per-text lists into (batch_size, sequence_len) tensors.
    batch_input_values = torch.tensor([item['input_values'] for item in batch_data])
    batch_attention_mask = torch.tensor([item['attention_mask'] for item in batch_data])
    batch_y_mask = torch.tensor([item['y_mask'] for item in batch_data])

    return {
        'input_values': batch_input_values,
        'attention_mask': batch_attention_mask,
        'y_mask': batch_y_mask
    }
|
|
|
def run_onnx_inference(input_values: torch.Tensor, attention_mask: torch.Tensor, session) -> torch.Tensor:
    """Run inference with the unified ONNX model and return per-token class predictions."""
    # Resolve I/O names from the session so the exported graph's naming does
    # not have to be hard-coded; this assumes the first two graph inputs are
    # the token ids and the attention mask, in that order.
    input_values_name = session.get_inputs()[0].name
    attention_mask_name = session.get_inputs()[1].name
    output_name = session.get_outputs()[0].name

    inputs = {
        input_values_name: input_values.cpu().numpy(),
        attention_mask_name: attention_mask.cpu().numpy()
    }

    output = session.run([output_name], inputs)

    # Logits have shape (batch, sequence_len, num_classes); argmax over the
    # class dimension yields one punctuation class id per token position.
    predictions = torch.tensor(output[0])
    predictions = torch.argmax(predictions, dim=2)

    return predictions
|
|
|
def get_transcription_batch(texts: List[str], session, tokenizer, device, token_style: str) -> List[str]:
    """Punctuate multiple texts: encode, run the ONNX model, then decode."""
    encoded_batch = get_encoded_input_batch(texts, tokenizer, token_style)

    input_values = encoded_batch['input_values'].to(device)
    attention_mask = encoded_batch['attention_mask'].to(device)
    y_masks = encoded_batch['y_mask']

    predictions = run_onnx_inference(input_values, attention_mask, session)

    # Decode: walk each sequence and, at every flagged position, emit the next
    # original word followed by its predicted punctuation mark.
    results = []
    for text_idx, text in enumerate(texts):
        words_original_case = text.split()
        y_mask = y_masks[text_idx]
        y_predict = predictions[text_idx]

        result = ""
        decode_idx = 0
        for i in range(y_mask.shape[0]):
            if y_mask[i] == 1 and decode_idx < len(words_original_case):
                result += words_original_case[decode_idx] + punctuation_map[y_predict[i].item()] + ' '
                decode_idx += 1

        results.append(result.strip())

    return results
|
|
|
def get_transcription(text_or_texts: Union[str, List[str]], session, tokenizer, device, token_style: str) -> Union[str, List[str]]:
    """
    Main entry point handling both single-text and batch processing with the
    unified ONNX model.

    Args:
        text_or_texts: A single text string or a list of text strings.

    Returns:
        A single punctuated string, or a list of punctuated strings.
    """
    if isinstance(text_or_texts, str):
        # Unwrap the one-element batch so a string input returns a string,
        # as the docstring promises.
        return get_transcription_batch([text_or_texts], session, tokenizer, device, token_style)[0]
    elif isinstance(text_or_texts, list):
        return get_transcription_batch(text_or_texts, session, tokenizer, device, token_style)
    else:
        raise ValueError("Input must be either a string or a list of strings")
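
# Note: inputs longer than roughly sequence_len - 2 subword tokens are
# truncated by the encoder above; split long transcripts into chunks before
# calling if that matters for your use case.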
|
|
|
|
|
if __name__ == '__main__':
    import time
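
    # The demo below needs a live ONNX session, a tokenizer, a device and a
    # token style, none of which are defined in this file. A minimal sketch
    # follows; the model path, tokenizer checkpoint and token_style value are
    # placeholders, so substitute whatever your config and export actually use.
    import onnxruntime
    from transformers import AutoTokenizer

    session = onnxruntime.InferenceSession('punctuation_model.onnx')  # hypothetical path
    tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')  # hypothetical checkpoint
    device = torch.device('cpu')  # ONNX Runtime consumes NumPy arrays, so CPU tensors suffice
    token_style = 'bert'  # must be a key of TOKEN_IDX in config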
|
|
|
    # Unpunctuated Bengali test sentence.
    test_text = 'ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান চতুর্দশ পাকিস্তানি বোলার হিসেবে অভিষেকেই তুলে নিলেন ছয় উইকেট'

    print("Testing single text processing:")
    print("=" * 50)

    for i in range(3):
        start_time = time.time()
        result = get_transcription(test_text, session, tokenizer, device, token_style)
        end_time = time.time()
        print(f"Run {i+1}: {end_time - start_time:.4f}s")

    print(f"\nSingle result: {result[:100]}...")

print("\nTesting batch text processing:") |
|
print("=" * 50) |
|
|
|
|
|
batch_texts = [ |
|
'ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান চতুর্দশ পাকিস্তানি বোলার হিসেবে অভিষেকেই তুলে নিলেন ছয় উইকেট', |
|
'ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান চতুর্দশ পাকিস্তানি বোলার হিসেবে অভিষেকেই তুলে নিলেন ছয় উইকেট', |
|
] |
|
|
|
    start_time = time.time()
    batch_results = get_transcription(batch_texts, session, tokenizer, device, token_style)
    end_time = time.time()

    print(f"Batch processing time: {end_time - start_time:.4f}s")
    print(f"Processed {len(batch_texts)} texts")
    print(f"Average time per text: {(end_time - start_time) / len(batch_texts):.4f}s")

    for i, result in enumerate(batch_results):
        print(f"Text {i+1}: {result[:50]}...")