import torch
import numpy as np
from typing import List, Union, Dict, Any

from config import *  # expected to provide TOKEN_IDX and punctuation_map used below


def _encode_words(words, tokenizer, token_style, sequence_len, start=0):
    """Encode as many words of ``words[start:]`` as fit into one sequence.

    The sequence is framed by the START_SEQ / END_SEQ ids of ``token_style``
    and is NOT padded. A word is added only if all of its sub-tokens fit while
    leaving room for END_SEQ.

    Returns:
        ``(x, y_mask, word_ids, next_pos)``:
        - ``x``: token ids, starting with START_SEQ and ending with END_SEQ.
        - ``y_mask``: same length as ``x``; 1 on the last sub-token of each
          encoded word, 0 elsewhere.
        - ``word_ids``: ``word_ids[k]`` is the index into ``words`` of the
          k-th position marked 1 in ``y_mask``.
        - ``next_pos``: index of the first word not consumed
          (``len(words)`` when every remaining word was consumed).
    """
    special = TOKEN_IDX[token_style]
    x = [special['START_SEQ']]
    y_mask = [0]
    word_ids = []
    pos = start
    while pos < len(words):
        tokens = tokenizer.tokenize(words[pos])
        if len(tokens) + len(x) >= sequence_len:
            if len(x) > 1:
                # Sequence is full; the caller resumes from this word.
                break
            # A lone word longer than a whole sequence: keep the sub-tokens
            # that fit so every call consumes at least one word.
            tokens = tokens[:max(sequence_len - 2, 0)]
        pos += 1
        if not tokens:
            # The tokenizer produced no sub-tokens for this word (presumably
            # e.g. a bare zero-width character), so there is nothing to score;
            # indexing tokens[-1] would raise IndexError.
            continue
        x.extend(tokenizer.convert_tokens_to_ids(token) for token in tokens)
        y_mask.extend([0] * (len(tokens) - 1) + [1])
        word_ids.append(pos - 1)
    x.append(special['END_SEQ'])
    y_mask.append(0)
    return x, y_mask, word_ids, pos


def _pad_encoding(x, y_mask, token_style, sequence_len):
    """Right-pad ``x`` / ``y_mask`` to ``sequence_len`` and build the attention mask.

    Returns a dict with 'input_values', 'attention_mask' and 'y_mask' lists.
    """
    n_pad = max(sequence_len - len(x), 0)
    # Derive the mask from the real length rather than comparing ids with PAD,
    # so a genuine token that happens to share the PAD id is not masked out.
    attn_mask = [1] * len(x) + [0] * n_pad
    x = x + [TOKEN_IDX[token_style]['PAD']] * n_pad
    y_mask = y_mask + [0] * n_pad
    return {
        'input_values': x,
        'attention_mask': attn_mask,
        'y_mask': y_mask,
    }


def get_encoded_input_single(text: str, tokenizer, token_style, sequence_len: int = 256) -> Dict[str, List[int]]:
    """Encode one text into a single padded sequence.

    Only the leading words that fit within ``sequence_len`` are encoded; use
    ``get_transcription_batch`` for texts that may be longer, since it spreads
    them over several sequences.

    Args:
        text: whitespace-separated words.
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
        token_style: key into ``TOKEN_IDX`` selecting the special token ids.
        sequence_len: target sequence length (padding length).

    Returns:
        dict with 'input_values', 'attention_mask' and 'y_mask' lists padded to
        ``sequence_len``; 'y_mask' is 1 on the last sub-token of each word.
    """
    x, y_mask, _, _ = _encode_words(text.split(), tokenizer, token_style, sequence_len)
    return _pad_encoding(x, y_mask, token_style, sequence_len)


def get_encoded_input_batch(texts, tokenizer, token_style, sequence_len: int = 256) -> Dict[str, Any]:
    """Encode each text into one padded sequence and stack them into tensors.

    Returns:
        dict with 'input_values', 'attention_mask' and 'y_mask' tensors, one
        row per input text (see ``get_encoded_input_single``).
    """
    encoded = [
        get_encoded_input_single(text, tokenizer, token_style, sequence_len)
        for text in texts
    ]
    return {
        key: torch.tensor([item[key] for item in encoded])
        for key in ('input_values', 'attention_mask', 'y_mask')
    }


def run_onnx_inference(input_values, attention_mask, session):
    """Run the ONNX model and return the arg-max class id per token.

    The first two session inputs are fed ``input_values`` and ``attention_mask``
    (in that order) and only the first session output is read.

    Returns:
        tensor of class ids with the class axis (dim 2) reduced by argmax.
    """
    input_values_name = session.get_inputs()[0].name
    attention_mask_name = session.get_inputs()[1].name
    output_name = session.get_outputs()[0].name

    # ONNX Runtime consumes numpy arrays on the host.
    inputs = {
        input_values_name: input_values.cpu().numpy(),
        attention_mask_name: attention_mask.cpu().numpy(),
    }
    output = session.run([output_name], inputs)
    predictions = torch.tensor(output[0])  # assumes [batch, seq_len, num_classes]
    return torch.argmax(predictions, dim=2)  # [batch, seq_len]


def get_transcription_batch(texts, session, tokenizer, device, token_style, sequence_len: int = 256) -> List[str]:
    """Punctuate several texts with one batched model call.

    Texts longer than one sequence are split into consecutive chunks, so no
    word is dropped. Each word is followed by the punctuation predicted at its
    last sub-token; words that received no prediction (no sub-tokens) are
    emitted unchanged.

    Returns:
        list of punctuated strings, one per input text, in input order.
    """
    word_lists = [text.split() for text in texts]

    # Each chunk: (index of its text, word index per y_mask==1 position, encoding).
    chunks = []
    for text_idx, words in enumerate(word_lists):
        pos = 0
        while pos < len(words):
            x, y_mask, word_ids, pos = _encode_words(words, tokenizer, token_style, sequence_len, pos)
            chunks.append((text_idx, word_ids, _pad_encoding(x, y_mask, token_style, sequence_len)))

    marks = [[''] * len(words) for words in word_lists]
    if chunks:
        input_values = torch.tensor([enc['input_values'] for _, _, enc in chunks]).to(device)
        attention_mask = torch.tensor([enc['attention_mask'] for _, _, enc in chunks]).to(device)
        predictions = run_onnx_inference(input_values, attention_mask, session)

        for row, (text_idx, word_ids, enc) in enumerate(chunks):
            marked_positions = [i for i, flag in enumerate(enc['y_mask']) if flag == 1]
            for word_id, i in zip(word_ids, marked_positions):
                marks[text_idx][word_id] = punctuation_map[predictions[row, i].item()]

    return [
        ' '.join(word + mark for word, mark in zip(words, text_marks)).strip()
        for words, text_marks in zip(word_lists, marks)
    ]


def get_transcription(text_or_texts, session, tokenizer, device, token_style, sequence_len: int = 256) -> Union[str, List[str]]:
    """Punctuate a single text or a list of texts with the ONNX model.

    Args:
        text_or_texts: a single string, or a list/tuple of strings.

    Returns:
        a punctuated string for a single string input, otherwise a list of
        punctuated strings in input order.

    Raises:
        ValueError: if the input is neither a string nor a list/tuple.
    """
    if isinstance(text_or_texts, str):
        return get_transcription_batch([text_or_texts], session, tokenizer, device, token_style, sequence_len)[0]
    if isinstance(text_or_texts, (list, tuple)):
        return get_transcription_batch(list(text_or_texts), session, tokenizer, device, token_style, sequence_len)
    raise ValueError("Input must be either a string or a list of strings")


def run_demo(session, tokenizer, device, token_style):
    """Time single-text and batch punctuation on a fixed Bengali sample and print results."""
    import time

    test_text = 'ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান চতুর্দশ পাকিস্তানি বোলার হিসেবে অভিষেকেই তুলে নিলেন ছয় উইকেট'

    print("Testing single text processing:")
    print("=" * 50)
    for i in range(3):
        start_time = time.perf_counter()
        result = get_transcription(test_text, session, tokenizer, device, token_style)
        end_time = time.perf_counter()
        print(f"Run {i+1}: {end_time - start_time:.4f}s")
    print(f"\nSingle result: {result[:100]}...")

    print("\nTesting batch text processing:")
    print("=" * 50)
    batch_texts = [test_text, test_text]
    start_time = time.perf_counter()
    batch_results = get_transcription(batch_texts, session, tokenizer, device, token_style)
    end_time = time.perf_counter()
    print(f"Batch processing time: {end_time - start_time:.4f}s")
    print(f"Processed {len(batch_texts)} texts")
    print(f"Average time per text: {(end_time - start_time) / len(batch_texts):.4f}s")
    for i, result in enumerate(batch_results):
        print(f"Text {i+1}: {result[:50]}...")


if __name__ == '__main__':
    # get_transcription requires a session, tokenizer, device and token_style,
    # none of which this module constructs; calling it with only the text
    # raises TypeError. Build them and call run_demo instead.
    raise SystemExit(
        "Build an ONNX inference session, tokenizer, device and token_style, "
        "then call run_demo(session, tokenizer, device, token_style)."
    )