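"""Serve an NVIDIA NeMo automatic speech recognition model as a Hugging Face
Inference Endpoint speaking the OpenAI-compatible transcription API."""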
import asyncio
import zlib
from functools import partial
from io import BytesIO

import torch
from hfendpoints.openai import Context, run
from hfendpoints.openai.audio import AutomaticSpeechRecognitionEndpoint, SegmentBuilder, Segment, \
    TranscriptionRequest, TranscriptionResponse, TranscriptionResponseKind, VerboseTranscription
from librosa import load as load_audio, get_duration
from loguru import logger
from nemo.collections.asr.models import ASRModel

from hfendpoints import EndpointConfig, Handler, __version__


def compression_ratio(text: str) -> float:
    """
    Ratio of the UTF-8 byte length of ``text`` to its zlib-compressed length.

    Repetitive text compresses well and therefore yields a higher ratio, so the
    value is reported per segment as a repetition indicator in verbose responses.

    :param text: transcribed text to measure
    :return: uncompressed byte length divided by compressed byte length
    """
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))


def get_segment(idx: int, segment, tokenizer, request: TranscriptionRequest) -> Segment:
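    """Build an OpenAI-style ``Segment`` from a single NeMo segment-timestamp dict."""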
    return SegmentBuilder() \
        .id(idx) \
        .start(segment['start']) \
        .end(segment['end']) \
        .text(segment['segment']) \
        .tokens(tokenizer.text_to_ids(segment['segment'])) \
        .temperature(request.temperature) \
        .compression_ratio(compression_ratio(segment['segment'])) \
        .build()


class NemoAsrHandler(Handler):
    __slots__ = ("_model",)

    def __init__(self, config: EndpointConfig):
        logger.info(config.repository)
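        # Cast the weights to bfloat16 to reduce the memory footprint; this
        # assumes the target device supports bfloat16.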
        self._model = ASRModel.from_pretrained(model_name=str(config.repository)).eval()
        self._model = self._model.to(torch.bfloat16)

    async def __call__(self, request: TranscriptionRequest, ctx: Context) -> TranscriptionResponse:
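        """Decode the request's audio, transcribe it off the event loop, and
        shape the response according to ``request.response_kind``."""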
        with logger.contextualize(request_id=ctx.request_id):
            # Decode the raw request body to a 16 kHz mono waveform.
            with memoryview(request) as audio:
                (waveform, sampling) = load_audio(BytesIO(audio), sr=16000, mono=True)
                logger.debug(
                    f"Successfully decoded audio chunk ({len(waveform)} samples)"
                )

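            # Segment-level timestamps are only needed for verbose JSON output.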
            needs_timestamps = request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
            transcribe_f = partial(self._model.transcribe, timestamps=needs_timestamps, verbose=False)

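            # NeMo's transcribe() is blocking, so hand it off to the default
            # executor to keep the event loop responsive.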
            outputs = await asyncio.get_running_loop().run_in_executor(
                None,
                transcribe_f,
                (waveform,)
            )

            output = outputs[0]
            text = output.text

            match request.response_kind:
                case TranscriptionResponseKind.VERBOSE_JSON:
                    segment_timestamps = output.timestamp['segment']
                    segments = [
                        get_segment(idx, stamp, self._model.tokenizer, request)
                        for (idx, stamp) in enumerate(segment_timestamps)
                    ]

                    # Guard against empty output before logging the first segment.
                    if segments:
                        logger.debug(f"Segment: {segment_timestamps[0]}")

                    return TranscriptionResponse.verbose(
                        VerboseTranscription(
                            text=text,
                            duration=get_duration(y=waveform, sr=sampling),
                            language=request.language,
                            segments=segments,
                        )
                    )

                case TranscriptionResponseKind.JSON:
                    return TranscriptionResponse.json(text)

                case TranscriptionResponseKind.TEXT:
                    return TranscriptionResponse.text(text)

                case _:
                    raise RuntimeError(f"unknown response_kind: {request.response_kind}")


def entrypoint():
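    """Build the handler from environment-derived configuration and start serving."""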
    config = EndpointConfig.from_env()
    handler = NemoAsrHandler(config)
    endpoint = AutomaticSpeechRecognitionEndpoint(handler)

    logger.info(f"[Hugging Face Endpoint v{__version__}] Serving: {config.model_id}")
    run(endpoint, config.interface, config.port)


if __name__ == '__main__':
    entrypoint()