mict-zhaw committed on
Commit
1b0dfa6
·
1 Parent(s): 791642a

Remove handler.py

Files changed (1)
  1. handler.py +0 -103
handler.py DELETED
@@ -1,103 +0,0 @@
-from typing import Any, Dict
-import torch
-from pyctcdecode import build_ctcdecoder
-from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer
-from transformers import pipeline
-
-
-class EndpointHandler:
-    def __init__(self, path=""):
-
-        print("init")
-
-        # Fallback pipeline, used when the request passes a file path instead of raw audio.
-        self.pipeline = pipeline("automatic-speech-recognition", model=path)
-
-        # Preload all the elements you are going to need at inference.
-        self.model = Wav2Vec2ForCTC.from_pretrained(path)
-        self.model.to("cuda" if torch.cuda.is_available() else "cpu")
-
-        print(self.model)
-
-        print("Wav2Vec2Tokenizer")
-
-        # self.processor = Wav2Vec2Processor.from_pretrained(os.path.join(path, "pytorch_model.bin"))
-        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(path)
-
-        print(self.tokenizer)
-
-        print("Wav2Vec2FeatureExtractor")
-
-        self.feature_extractor = Wav2Vec2FeatureExtractor(
-            feature_size=1,
-            sampling_rate=16000,
-            padding_value=0.0,
-            do_normalize=True,
-            return_attention_mask=True
-        )
-
-        print(self.feature_extractor)
-
-        # The CTC decoder expects its labels ordered by token id.
-        vocab_dict = self.tokenizer.get_vocab()
-        sorted_dict = {k: v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
-        sorted_dict_keys = list(sorted_dict.keys())
-        self.vocab = sorted_dict
-
-        print("Vocabulary", self.vocab)
-
-        language_model_decoder = build_ctcdecoder(
-            labels=sorted_dict_keys,
-            alpha=0.5,
-            beta=1.5
-        )
-
-        # beam size?
-        self.processor = Wav2Vec2ProcessorWithLM(
-            tokenizer=self.tokenizer,
-            feature_extractor=self.feature_extractor,
-            decoder=language_model_decoder
-        )
-
-    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        data args:
-            inputs (:obj:`dict`): either an `audio_file_path` string or a raw
-            `audio` array plus an optional `sampling_rate`
-        Return:
-            A :obj:`dict`: will be serialized and returned
-        """
-
-        print("inputs")
-        print(inputs)
-
-        if "audio_file_path" in inputs:
-            audio_file_path = inputs.pop("audio_file_path")
-            prediction = self.pipeline(audio_file_path)
-
-            return prediction
-
-        if "audio" in inputs:
-            audio_input = inputs.pop("audio")
-            sample_rate = inputs.pop("sampling_rate", 16000)
-            inputs = self.processor(audio_input, sampling_rate=sample_rate, return_tensors="pt", padding=True)
-
-            # Move the input tensors to the same device as the model.
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-            inputs = {key: value.to(device) for key, value in inputs.items()}
-
-            print(inputs)
-
-            # Perform inference
-            with torch.no_grad():
-                logits = self.model(**inputs).logits
-
-            # predicted_ids = torch.argmax(logits, dim=-1)
-            # Beam-search decode with the LM-aware processor.
-            transcription = self.processor.batch_decode(logits.cpu().numpy()).text
-
-            return {"prediction": transcription[0]}
-
-    def postprocess(self, inference_output):
-        pass
-
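For context, the deleted file followed the Hugging Face Inference Endpoints custom-handler contract: a handler.py exposing an EndpointHandler class with __init__(path) and __call__(data). A minimal local smoke test for such a handler might have looked like the sketch below; the model directory "." and the one-second silent buffer are hypothetical stand-ins, not taken from this repository:

import numpy as np
from handler import EndpointHandler

# Hypothetical check: point the handler at a local model directory and
# feed one second of 16 kHz silence as raw audio.
handler = EndpointHandler(path=".")
result = handler({"audio": np.zeros(16000, dtype=np.float32), "sampling_rate": 16000})
print(result["prediction"])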