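"""Custom ASR handler for a Hugging Face Space: transcribes raw audio with a
Wav2Vec2 CTC model and decodes the logits with a KenLM-backed beam-search
decoder built with pyctcdecode."""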
from typing import Dict
import numpy as np
import pyctcdecode
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
class PreTrainedPipeline:
    def __init__(self, model_path: str, language_model_fp: str):
        self.language_model_fp = language_model_fp
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the acoustic model and move it to the selected device.
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
        self.model.to(self.device)

        processor = Wav2Vec2Processor.from_pretrained(model_path)
        self.sampling_rate = processor.feature_extractor.sampling_rate

        # pyctcdecode expects the vocabulary labels ordered by token id.
        vocab = processor.tokenizer.get_vocab()
        sorted_vocab_dict = [
            (char, ix) for char, ix in sorted(vocab.items(), key=lambda item: item[1])
        ]

        # Build a CTC beam-search decoder backed by the KenLM language model.
        self.decoder = pyctcdecode.build_ctcdecoder(
            labels=[x[0] for x in sorted_vocab_dict],
            kenlm_model_path=self.language_model_fp,
        )

        self.processor_with_lm = Wav2Vec2ProcessorWithLM(
            feature_extractor=processor.feature_extractor,
            tokenizer=processor.tokenizer,
            decoder=self.decoder,
        )
    def __call__(self, inputs: np.ndarray) -> Dict[str, str]:
        """
        Args:
            inputs (:obj:`np.ndarray`):
                The raw audio waveform, sampled at 16 kHz by default.
        Return:
            A :obj:`dict` of the form {"text": "XXX"} containing the text
            detected in the input audio.
        """
        input_values = self.processor_with_lm(
            inputs, return_tensors="pt", sampling_rate=self.sampling_rate
        )["input_values"]
        input_values = input_values.to(self.device)

        with torch.no_grad():
            # input_values is a 2D tensor at this point; the first dimension is the batch.
            model_outs = self.model(input_values)
        logits = model_outs.logits.cpu().detach().numpy()

        # batch_decode returns one transcription per batch element; keep the first.
        text_predicted = self.processor_with_lm.batch_decode(logits)["text"][0]
        return {"text": text_predicted}
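

# A minimal local smoke test, assuming a fine-tuned Wav2Vec2 checkpoint directory and
# a KenLM n-gram file are available on disk. The paths below are placeholders, and
# soundfile is used only for this illustration; it is not required by the pipeline itself.
if __name__ == "__main__":
    import soundfile as sf

    pipeline = PreTrainedPipeline(
        model_path="./",                   # placeholder: directory with the Wav2Vec2 model
        language_model_fp="./5gram.arpa",  # placeholder: KenLM language model file
    )
    waveform, sample_rate = sf.read("sample.wav")  # expects 16 kHz mono audio
    print(pipeline(waveform))                      # -> {"text": "..."}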