import torch
from typing import Any, Dict, List, Union

# config is expected to define TOKEN_IDX (special-token ids keyed by
# tokenizer style) and punctuation_map (class index -> punctuation string).
from config import *
def get_encoded_input_single(text: str, tokenizer, token_style: str, sequence_len: int = 256) -> Dict[str, Any]:
    """Tokenize a single text into fixed-length model inputs.

    Mirrors the preprocessing used when the model was exported: each word is
    split into sub-tokens, and y_mask marks the last sub-token of every word
    as the position where a punctuation class is predicted.
    """
    words = text.split()
    word_pos = 0
    x = [TOKEN_IDX[token_style]['START_SEQ']]
    y_mask = [0]
    while len(x) < sequence_len and word_pos < len(words):
        tokens = tokenizer.tokenize(words[word_pos])
        # Stop if the next word would overflow the sequence (one slot is
        # reserved for the END_SEQ token).
        if len(tokens) + len(x) >= sequence_len:
            break
        for i in range(len(tokens) - 1):
            x.append(tokenizer.convert_tokens_to_ids(tokens[i]))
            y_mask.append(0)
        x.append(tokenizer.convert_tokens_to_ids(tokens[-1]))
        y_mask.append(1)
        word_pos += 1
    x.append(TOKEN_IDX[token_style]['END_SEQ'])
    y_mask.append(0)
    # Pad to sequence_len
    if len(x) < sequence_len:
        x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))]
        y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))]
    attn_mask = [1 if token != TOKEN_IDX[token_style]['PAD'] else 0 for token in x]
    return {
        'input_values': x,
        'attention_mask': attn_mask,
        'y_mask': y_mask
    }
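# Illustrative check (a sketch, assuming a WordPiece-style tokenizer and
# that TOKEN_IDX[token_style] supplies ids for 'START_SEQ', 'END_SEQ', and
# 'PAD', as used above):
#
#   enc = get_encoded_input_single('hello world', tokenizer, 'bert')
#   len(enc['input_values'])   # 256: always padded to sequence_len
#   sum(enc['y_mask'])         # 2: one prediction slot per input word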
def get_encoded_input_batch(texts: List[str], tokenizer, token_style: str, sequence_len: int = 256) -> Dict[str, torch.Tensor]:
    """Encode a batch of texts with the same preprocessing as above.

    Every sequence is padded to sequence_len, so the per-text lists stack
    directly into rectangular tensors.
    """
    batch_data = []
    for text in texts:
        encoded = get_encoded_input_single(text, tokenizer, token_style, sequence_len)
        batch_data.append(encoded)
    # Stack all sequences into batch tensors
    batch_input_values = torch.tensor([item['input_values'] for item in batch_data])
    batch_attention_mask = torch.tensor([item['attention_mask'] for item in batch_data])
    batch_y_mask = torch.tensor([item['y_mask'] for item in batch_data])
    encoded_input = {
        'input_values': batch_input_values,
        'attention_mask': batch_attention_mask,
        'y_mask': batch_y_mask
    }
    return encoded_input
def run_onnx_inference(input_values, attention_mask, session):
    """Run ONNX inference and return per-token class predictions."""
    # Resolve input/output names from the session (assumes the model was
    # exported with input_values first and attention_mask second).
    input_values_name = session.get_inputs()[0].name
    attention_mask_name = session.get_inputs()[1].name
    output_name = session.get_outputs()[0].name
    # ONNX Runtime expects numpy inputs
    inputs = {
        input_values_name: input_values.cpu().numpy(),
        attention_mask_name: attention_mask.cpu().numpy()
    }
    # Run inference
    output = session.run([output_name], inputs)
    predictions = torch.tensor(output[0])  # Shape: [batch_size, seq_len, num_classes]
    predictions = torch.argmax(predictions, dim=2)  # Shape: [batch_size, seq_len]
    return predictions
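# Minimal sketch of building the session that run_onnx_inference expects.
# The file name 'model.onnx' is an assumption for illustration; use the
# actual export path from this repo.
def load_onnx_session(model_path='model.onnx'):
    import onnxruntime
    # CPUExecutionProvider keeps the sketch dependency-free; swap in
    # 'CUDAExecutionProvider' if onnxruntime-gpu is installed.
    return onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])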
def get_transcription_batch(texts: List[str], session, tokenizer, device, token_style: str) -> List[str]:
    """Punctuate a batch of texts and return one restored string per input."""
    # Prepare batch data
    encoded_batch = get_encoded_input_batch(texts, tokenizer, token_style)
    # Move to device (run_onnx_inference copies back to CPU for the ONNX call)
    input_values = encoded_batch['input_values'].to(device)
    attention_mask = encoded_batch['attention_mask'].to(device)
    y_masks = encoded_batch['y_mask']
    # Run batch inference
    predictions = run_onnx_inference(input_values, attention_mask, session)
    # Post-process: re-attach the predicted punctuation after each word
    results = []
    for text_idx, text in enumerate(texts):
        words_original_case = text.split()
        y_mask = y_masks[text_idx]
        y_predict = predictions[text_idx]
        result = ""
        decode_idx = 0
        for i in range(y_mask.shape[0]):
            if y_mask[i] == 1 and decode_idx < len(words_original_case):
                result += words_original_case[decode_idx] + punctuation_map[y_predict[i].item()] + ' '
                decode_idx += 1
        results.append(result.strip())
    return results
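# Note: punctuation_map comes in through the wildcard import of config at
# the top and maps each predicted class index to a punctuation string. The
# authoritative mapping lives in config.py; as an illustrative assumption,
# a Bangla model might use something like {0: '', 1: ',', 2: '।', 3: '?'}.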
def get_transcription(text_or_texts: Union[str, List[str]], session, tokenizer, device, token_style: str) -> Union[str, List[str]]:
    """
    Main entry point that handles both single-text and batch processing
    through the same unified ONNX model.
    Args:
        text_or_texts: Single text string or list of text strings
    Returns:
        Single punctuated string or list of punctuated strings
    """
    if isinstance(text_or_texts, str):
        # Wrap in a one-element batch and unwrap the result so a string
        # input yields a string output, as documented above.
        return get_transcription_batch([text_or_texts], session, tokenizer, device, token_style)[0]
    elif isinstance(text_or_texts, list):
        return get_transcription_batch(text_or_texts, session, tokenizer, device, token_style)
    else:
        raise ValueError("Input must be either a string or a list of strings")
if __name__ == '__main__':
    import time
    from transformers import AutoTokenizer

    # Demo setup. The model path, tokenizer name, and token_style below are
    # illustrative assumptions; point them at the files and settings
    # actually shipped with this repo.
    session = load_onnx_session('model.onnx')
    tokenizer = AutoTokenizer.from_pretrained('sagorsarker/bangla-bert-base')
    device = torch.device('cpu')
    token_style = 'bert'

    test_text = 'ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান চতুর্দশ পাকিস্তানি বোলার হিসেবে অভিষেকেই তুলে নিলেন ছয় উইকেট'
    print("Testing single text processing:")
    print("=" * 50)
    # Test single text processing
    for i in range(3):
        start_time = time.time()
        result = get_transcription(test_text, session, tokenizer, device, token_style)
        end_time = time.time()
        print(f"Run {i+1}: {end_time - start_time:.4f}s")
    print(f"\nSingle result: {result[:100]}...")
    print("\nTesting batch text processing:")
    print("=" * 50)
    # Test batch processing
    batch_texts = [
        'ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান চতুর্দশ পাকিস্তানি বোলার হিসেবে অভিষেকেই তুলে নিলেন ছয় উইকেট',
        'ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান চতুর্দশ পাকিস্তানি বোলার হিসেবে অভিষেকেই তুলে নিলেন ছয় উইকেট',
    ]
    start_time = time.time()
    batch_results = get_transcription(batch_texts, session, tokenizer, device, token_style)
    end_time = time.time()
    print(f"Batch processing time: {end_time - start_time:.4f}s")
    print(f"Processed {len(batch_texts)} texts")
    print(f"Average time per text: {(end_time - start_time) / len(batch_texts):.4f}s")
    for i, result in enumerate(batch_results):
        print(f"Text {i+1}: {result[:50]}...")