File size: 4,015 Bytes
7ff080c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import asyncio
import zlib
from functools import partial
from io import BytesIO

from hfendpoints.openai import Context, run
from hfendpoints.openai.audio import AutomaticSpeechRecognitionEndpoint, SegmentBuilder, Segment, \
    TranscriptionRequest, TranscriptionResponse, TranscriptionResponseKind, VerboseTranscription
from librosa import load as load_audio, get_duration
from loguru import logger
from nemo.collections.asr.models import ASRModel

from hfendpoints import EndpointConfig, Handler, __version__


def compression_ratio(text: str) -> float:
    """
    Measure how well ``text`` compresses with zlib.

    :param text: Text to measure; it is encoded as UTF-8 before being compressed.
    :return: Length of the UTF-8 encoding divided by the length of its zlib-compressed form.
    """
    raw = text.encode("utf-8")
    packed = zlib.compress(raw)
    return len(raw) / len(packed)


def get_segment(idx: int, segment, tokenizer, request: TranscriptionRequest) -> Segment:
    """
    Build a response ``Segment`` from one timestamped segment entry.

    :param idx: Identifier assigned to the segment.
    :param segment: Mapping read through the keys ``'start'``, ``'end'`` and ``'segment'``
        (the latter being the segment's text).
    :param tokenizer: Object whose ``text_to_ids`` turns the segment text into token ids.
    :param request: Originating request; its ``temperature`` is copied onto the segment.
    :return: The segment produced by ``SegmentBuilder.build()``.
    """
    builder = SegmentBuilder()
    builder = builder.id(idx)
    builder = builder.start(segment['start'])
    builder = builder.end(segment['end'])
    builder = builder.text(segment['segment'])
    builder = builder.tokens(tokenizer.text_to_ids(segment['segment']))
    builder = builder.temperature(request.temperature)
    builder = builder.compression_ratio(compression_ratio(segment['segment']))
    return builder.build()


class NemoAsrHandler(Handler):
    __slots__ = ("_model",)

    def __init__(self, config: EndpointConfig):
        logger.info(config.repository)
        self._model = ASRModel.from_pretrained(model_name=str(config.repository)).eval()

    async def __call__(self, request: TranscriptionRequest, ctx: Context) -> TranscriptionResponse:
        with logger.contextualize(request_id=ctx.request_id):
            with memoryview(request) as audio:
                (waveform, sampling) = load_audio(BytesIO(audio), sr=16000, mono=True)
                logger.debug(
                    f"Successfully decoded {len(waveform)} bytes PCM audio chunk"
                )

                # Do we need to compute the timestamps?
                needs_timestamps = request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
                transcribe_f = partial(self._model.transcribe, timestamps=needs_timestamps, verbose=False)

                outputs = await asyncio.get_running_loop().run_in_executor(
                    None,
                    transcribe_f,
                    (waveform,)
                )

                output = outputs[0]
                text = output.text

                match request.response_kind:
                    case TranscriptionResponseKind.VERBOSE_JSON:
                        segment_timestamps = output.timestamp['segment']
                        segments = [
                            get_segment(idx, stamp, self._model.tokenizer, request)
                            for (idx, stamp) in enumerate(segment_timestamps)
                        ]

                        logger.info(f"Segment: {segment_timestamps[0]}")

                        return TranscriptionResponse.verbose(
                            VerboseTranscription(
                                text=text,
                                duration=get_duration(y=waveform, sr=sampling),
                                language=request.language,
                                segments=segments,
                                # word=None
                            )
                        )
                    case TranscriptionResponseKind.JSON:
                        return TranscriptionResponse.json(text)

                    case TranscriptionResponseKind.TEXT:
                        return TranscriptionResponse.text(text)

                # Theoretically, we can't end up there as Rust validates the enum value beforehand
                raise RuntimeError(f"unknown response_kind: {request.response_kind}")


def entrypoint():
    """
    Build the endpoint from environment configuration and start serving it.

    The configuration comes from ``EndpointConfig.from_env()``; the endpoint is served on
    ``config.interface`` / ``config.port``.
    """
    config = EndpointConfig.from_env()
    endpoint = AutomaticSpeechRecognitionEndpoint(NemoAsrHandler(config))

    logger.info(f"[Hugging Face Endpoint v{__version__}] Serving: {config.model_id}")
    run(endpoint, config.interface, config.port)


# Start the endpoint only when this module is executed as a script
if "__main__" == __name__:
    entrypoint()