Remove handler.py
handler.py +0 -103
handler.py
DELETED
@@ -1,103 +0,0 @@
-from typing import Dict, List, Any
-import torch
-from pyctcdecode import build_ctcdecoder
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2Tokenizer, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer
-from transformers import pipeline
-
-
-class EndpointHandler:
-    def __init__(self, path=""):
-
-        print("init")
-
-        self.pipeline = pipeline("automatic-speech-recognition", model=path)
-
-        # Preload all the elements you are going to need at inference.
-        self.model = Wav2Vec2ForCTC.from_pretrained(path)
-        self.model.to("cuda" if torch.cuda.is_available() else "cpu")
-
-        print(self.model)
-
-        print("Wav2Vec2Tokenizer")
-
-        # self.processor = Wav2Vec2Processor.from_pretrained(os.path.join(path, "pytorch_model.bin"))
-        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(path)
-
-        print(self.tokenizer)
-
-        print("Wav2Vec2FeatureExtractor")
-
-        self.feature_extractor = Wav2Vec2FeatureExtractor(
-            feature_size=1,
-            sampling_rate=16000,
-            padding_value=0.0,
-            do_normalize=True,
-            return_attention_mask=True
-        )
-
-        print(self.feature_extractor)
-
-        vocab_dict = self.tokenizer.get_vocab()
-        sorted_dict = {k: v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
-        sorted_dict_keys = list(sorted_dict.keys())
-        self.vocab = sorted_dict
-
-        print("Vocabulary", self.vocab)
-
-        language_model_decoder = build_ctcdecoder(
-            labels=sorted_dict_keys,
-            alpha=0.5,
-            beta=1.5
-        )
-
-        # beam size?
-        self.processor = Wav2Vec2ProcessorWithLM(
-            tokenizer=self.tokenizer,
-            feature_extractor=self.feature_extractor,
-            decoder=language_model_decoder
-        )
-
-    def __call__(self, inputs: Dict[str, Any]) -> List[Dict[str, Any]]:
-        """
-        data args:
-            inputs (:obj: `str` | `PIL.Image` | `np.array`)
-            kwargs
-        Return:
-            A :obj:`list` | `dict`: will be serialized and returned
-        """
-
-        print("inputs")
-        print(inputs)
-
-
-        if "audio_file_path" in inputs.keys():
-            audio_file_path = inputs.pop("audio_file_path")
-            prediction = self.pipeline(audio_file_path)
-
-            return prediction
-
-        if "audio" in inputs.keys():
-            audio_input = inputs.pop("audio")
-            sample_rate = inputs.pop("sampling_rate", 16000)
-            inputs = self.processor(audio_input, sampling_rate=sample_rate, return_tensors="pt", padding=True)
-
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-            inputs = {key: value.to(device) for key, value in inputs.items()}
-
-            print(inputs)
-
-
-            print(r)
-
-            # Perform inference
-            with torch.no_grad():
-                logits = self.model(inputs["input_values"][0]).logits
-
-            # predicted_ids = torch.argmax(logits, dim=-1)
-            transcription = self.processor.batch_decode(logits.cpu().numpy()).text
-
-            return {"prediction": transcription[0]}
-
-    def postprocess(self, inference_output):
-        pass
-
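Note that the removed __call__ had a stray print(r) on its "audio" branch, which would have raised a NameError at request time, and that build_ctcdecoder was called without a kenlm_model_path, so alpha (the language-model weight) and beta (the word-insertion bonus) had no language model to act on and the decoder effectively ran plain beam search over the CTC logits. A minimal self-contained sketch of that decoding setup; the checkpoint name and the 16 kHz mono-audio assumption are illustrative, not taken from this repo:

    import torch
    from pyctcdecode import build_ctcdecoder
    from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

    MODEL_PATH = "facebook/wav2vec2-base-960h"  # illustrative checkpoint, not this repo

    processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_PATH)

    # pyctcdecode wants labels ordered by token id so they line up with logit columns.
    vocab = processor.tokenizer.get_vocab()
    labels = [tok for tok, _ in sorted(vocab.items(), key=lambda kv: kv[1])]

    # Without kenlm_model_path this is plain beam search; pass a KenLM .arpa/.bin
    # path (plus alpha/beta) to actually rescore hypotheses with a language model.
    decoder = build_ctcdecoder(labels)

    def transcribe(speech, sampling_rate=16000):
        inputs = processor(speech, sampling_rate=sampling_rate, return_tensors="pt")
        with torch.no_grad():
            logits = model(inputs.input_values).logits  # (batch, time, vocab)
        log_probs = torch.log_softmax(logits[0], dim=-1).cpu().numpy()
        return decoder.decode(log_probs)

Wav2Vec2ProcessorWithLM.batch_decode, as used in the removed handler, performs the same decoding batched and returns an output object whose .text field holds the transcriptions.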
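More broadly, handler.py is the hook Hugging Face Inference Endpoints uses for custom inference: the service imports a top-level EndpointHandler class, instantiates it with the local path of the model repository, and invokes it with each request payload, expecting a JSON-serializable result. With the file removed, the endpoint falls back to the default pipeline handler for the repository's task. A minimal sketch of that contract; the class name and method signatures are the documented interface, while the payload handling is illustrative:

    from typing import Any, Dict

    from transformers import pipeline


    class EndpointHandler:
        def __init__(self, path: str = ""):
            # `path` is the model repository checked out on the endpoint.
            self.pipeline = pipeline("automatic-speech-recognition", model=path)

        def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
            # "inputs" typically carries the audio as a file path, raw bytes,
            # or an array of samples; the ASR pipeline accepts any of these.
            return self.pipeline(data["inputs"])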