# khowar-whisper-asr / inference.py
# Author: Aizazayyubi — "created inference" (commit e69ec9f, verified)
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
class Inference:
    """Inference handler for the Khowar Whisper ASR model.

    Loads the processor and seq2seq model from the current directory
    (the repo root, where the weights are expected to live) and
    transcribes one audio file per call.
    """

    def __init__(self):
        # Artifacts are expected alongside this file — TODO confirm the
        # serving runtime's working directory is the repo root.
        self.processor = AutoProcessor.from_pretrained(".")
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(".")
        self.model.eval()

    def __call__(self, inputs):
        """Transcribe the audio referenced by ``inputs["inputs"]``.

        Parameters
        ----------
        inputs : dict
            Must contain an ``"inputs"`` key holding a path (or
            file-like object) that ``torchaudio.load`` accepts.

        Returns
        -------
        dict
            ``{"text": <transcription>}`` on success, or
            ``{"error": ...}`` when no audio was provided.
        """
        audio_path = inputs.get("inputs")
        if not audio_path:
            return {"error": "No audio provided."}

        waveform, sample_rate = torchaudio.load(audio_path)

        # Whisper expects 16 kHz mono input.
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
        # FIX: squeeze() alone cannot collapse stereo input — a (2, T)
        # tensor stays 2-D and the processor would see channels-first
        # stereo. Downmix to mono by averaging channels first.
        if waveform.dim() > 1 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Renamed from `inputs` to avoid shadowing the request payload.
        features = self.processor(
            waveform.squeeze(), sampling_rate=16000, return_tensors="pt"
        )
        with torch.no_grad():
            generated_ids = self.model.generate(features["input_features"])
        text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return {"text": text}