Spaces:
Runtime error
Runtime error
File size: 2,104 Bytes
feb2a2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import numpy as np
from typing import Dict
import torch
import pyctcdecode
from transformers import (
Wav2Vec2Processor,
Wav2Vec2ProcessorWithLM,
Wav2Vec2ForCTC,
)
class PreTrainedPipeline():
    """Speech-to-text pipeline: Wav2Vec2 CTC model decoded with a KenLM beam search.

    Loads a ``Wav2Vec2ForCTC`` model and its processor from ``model_path``, builds a
    ``pyctcdecode`` decoder over the tokenizer vocabulary backed by the KenLM model
    at ``language_model_fp``, and bundles both into a ``Wav2Vec2ProcessorWithLM``.
    """

    def __init__(self, model_path: str, language_model_fp: str):
        """
        Args:
            model_path (:obj:`str`):
                Location passed to ``from_pretrained`` for both the model and
                its processor.
            language_model_fp (:obj:`str`):
                Path to the KenLM model handed to ``pyctcdecode.build_ctcdecoder``.
        """
        self.language_model_fp = language_model_fp
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
        self.model.to(self.device)

        processor = Wav2Vec2Processor.from_pretrained(model_path)
        self.sampling_rate = processor.feature_extractor.sampling_rate

        # Decoder labels must be listed in token-id order so that label i lines up
        # with logit column i. Assumes the vocab ids are contiguous from 0
        # (TODO confirm for custom tokenizers).
        vocab = processor.tokenizer.get_vocab()
        labels = [token for token, _ in sorted(vocab.items(), key=lambda item: item[1])]
        self.decoder = pyctcdecode.build_ctcdecoder(
            labels=labels,
            kenlm_model_path=self.language_model_fp,
        )
        self.processor_with_lm = Wav2Vec2ProcessorWithLM(
            feature_extractor=processor.feature_extractor,
            tokenizer=processor.tokenizer,
            decoder=self.decoder,
        )

    def __call__(self, inputs: np.ndarray) -> Dict[str, str]:
        """
        Args:
            inputs (:obj:`np.ndarray`):
                The raw waveform of a single audio clip. By default at 16KHz
                (the feature extractor's ``sampling_rate``).
        Return:
            A :obj:`dict` of the form {"text": "XXX"} containing the text
            decoded from the input audio.
        """
        # With return_tensors="pt" the processor already returns a torch tensor
        # holding a batch of one. Move it to the model's device rather than
        # re-wrapping it with torch.tensor(), which copies it and warns.
        input_values = self.processor_with_lm(
            inputs, return_tensors="pt",
            sampling_rate=self.sampling_rate
        )['input_values'].to(self.device)

        with torch.no_grad():
            logits = self.model(input_values).logits

        # pyctcdecode operates on host-side numpy arrays.
        logits = logits.cpu().numpy()
        # batch_decode yields one transcription per batch item. A single
        # waveform is a batch of one, and callers expect a plain string.
        text_predicted = self.processor_with_lm.batch_decode(logits)['text'][0]
        return {
            "text": text_predicted
        }
|