import os
import re
import json
import base64
import asyncio
from typing import AsyncGenerator, List

import numpy as np
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from kokoro_onnx import Kokoro  # pip install kokoro-onnx
APP_NAME = "kokoro-onnx-fastapi-sse"
SAMPLE_RATE = 24000   # Kokoro outputs 24 kHz mono audio
CHANNELS = 1
CHUNK_SAMPLES = 2400  # 100 ms per frame at 24 kHz
MODEL_PATH = os.getenv("KOKORO_MODEL", "models/kokoro-v1.0.int8.onnx")
VOICES_PATH = os.getenv("KOKORO_VOICES", "models/voices-v1.0.bin")
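# Note: the model and voices files are not bundled here. The default filenames
# above appear to match the kokoro-onnx release assets (an assumption); point
# the KOKORO_MODEL / KOKORO_VOICES env vars at your local copies.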
app = FastAPI(title=APP_NAME)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
)

kokoro = None                 # loaded once at startup
_model_lock = asyncio.Lock()  # serializes access to the ONNX session
def split_text(text: str, max_len: int = 220) -> List[str]:
    # Split on sentence boundaries, then fold pieces longer than max_len,
    # preferring to cut at the last space before the limit.
    parts = re.split(r"(?<=[\.\!\?।])\s+", text.strip())
    chunks: List[str] = []
    for p in parts:
        p = p.strip()
        while len(p) > max_len:
            cut = p.rfind(" ", 0, max_len)
            if cut == -1:
                cut = max_len
            chunks.append(p[:cut].strip())
            p = p[cut:].strip()
        if p:
            chunks.append(p)
    return [c for c in chunks if c]
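# Illustrative behavior (example inputs, not from the original source):
#   split_text("One. Two? Three!") -> ["One.", "Two?", "Three!"]
#   A long run with no sentence enders is folded into <=220-char pieces,
#   each cut at the last space before the limit.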
def floats_to_s16le(samples: np.ndarray) -> np.ndarray:
    # Clamp float samples to [-1, 1] and scale to signed 16-bit PCM.
    # "<i2" forces little-endian to match the advertised "s16le" format
    # regardless of platform byte order.
    x = np.clip(samples, -1.0, 1.0)
    return (x * 32767.0).astype("<i2")
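# E.g. (illustrative): floats_to_s16le(np.array([0.0, 0.5, -1.2]))
#   -> array([0, 16383, -32767], dtype=int16)   # -1.2 is clipped to -1.0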
async def sse_generator(text: str, voice: str, speed: float, lang: str) -> AsyncGenerator[bytes, None]:
    # Yields SSE messages carrying base64 PCM16 frames (100 ms per chunk).
    seq = 0
    total_samples = 0
    try:
        # Warmup ping so proxies flush headers immediately
        yield b": keep-alive\n\n"
        for sentence in split_text(text):
            async with _model_lock:
                # kokoro.create is blocking; run it in a worker thread so the
                # event loop stays responsive while the model synthesizes.
                samples, sr = await asyncio.to_thread(
                    kokoro.create, sentence, voice=voice, speed=speed, lang=lang
                )
            assert sr == SAMPLE_RATE, f"Expected {SAMPLE_RATE}, got {sr}"
            pcm16 = floats_to_s16le(np.asarray(samples))
            for i in range(0, len(pcm16), CHUNK_SAMPLES):
                frame = pcm16[i:i + CHUNK_SAMPLES]
                total_samples += len(frame)
                payload = {
                    "seq": seq,
                    "sr": SAMPLE_RATE,
                    "ch": CHANNELS,
                    "format": "s16le",
                    "pcm16": base64.b64encode(frame.tobytes()).decode("ascii"),
                }
                msg = f"data: {json.dumps(payload, separators=(',',':'))}\n\n"
                yield msg.encode("utf-8")
                seq += 1
                await asyncio.sleep(0)  # give control back to the event loop
        # Done event
        done = {"total_chunks": seq, "total_samples": total_samples}
        yield f"event: done\ndata: {json.dumps(done)}\n\n".encode("utf-8")
    except asyncio.CancelledError:
        # Client disconnected; stop synthesizing
        return
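# Wire format produced above, for reference (values illustrative):
#   : keep-alive
#
#   data: {"seq":0,"sr":24000,"ch":1,"format":"s16le","pcm16":"...base64..."}
#
#   event: done
#   data: {"total_chunks":42,"total_samples":100800}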
@app.on_event("startup")
async def _load_model():
    global kokoro
    kokoro = Kokoro(MODEL_PATH, VOICES_PATH)


@app.get("/healthz")
async def healthz():
    return {"status": "ok", "model": os.path.basename(MODEL_PATH)}
@app.get("/voices")
async def list_voices():
    try:
        # The voices file is a NumPy archive; its entry names are the voice ids.
        with np.load(VOICES_PATH) as z:
            names = sorted(z.files)
        return {"voices": names}
    except Exception:
        fallback = [
            "af", "af_bella", "af_nicole", "af_sarah", "af_sky",
            "am_adam", "am_michael",
            "bf_emma", "bf_isabella",
            "bm_george", "bm_lewis",
        ]
        return {"voices": fallback, "note": "fallback list; voices file not parsed"}
@app.get("/tts")
async def tts_sse(
    text: str = Query(..., description="Text to synthesize"),
    voice: str = Query("af_sarah"),
    speed: float = Query(1.0, ge=0.5, le=1.5),
    lang: str = Query("en-us"),
):
    headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",  # disable response buffering in nginx
    }
    return StreamingResponse(
        sse_generator(text, voice, speed, lang),
        media_type="text/event-stream",
        headers=headers,
    )
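# Minimal local entry point (a sketch; assumes uvicorn is installed and the
# route is mounted at /tts as above -- adjust if you rename it):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#   curl -N "http://localhost:8000/tts?text=Hello%20world&voice=af_sarah"
#
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)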