ArunKr committed (verified)
Commit ea520ad · 1 parent: 07f2f14

Create app.py

Files changed (1): app.py (+137, -0)
app.py ADDED
import os
import re
import io
import json
import time
import base64
import asyncio
from typing import AsyncGenerator, List

import numpy as np
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse

from kokoro_onnx import Kokoro  # pip install kokoro-onnx

APP_NAME = "kokoro-onnx-fastapi-sse"
SAMPLE_RATE = 24000  # Kokoro output
CHANNELS = 1
CHUNK_SAMPLES = 2400  # 100 ms

MODEL_PATH = os.getenv("KOKORO_MODEL", "models/kokoro-v1.0.int8.onnx")
VOICES_PATH = os.getenv("KOKORO_VOICES", "models/voices-v1.0.bin")

app = FastAPI(title=APP_NAME)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
)

kokoro = None
_model_lock = asyncio.Lock()


def split_text(text: str, max_len: int = 220) -> List[str]:
    # Split on sentences, then fold long pieces
    parts = re.split(r"(?<=[\.\!\?।])\s+", text.strip())
    chunks: List[str] = []
    for p in parts:
        p = p.strip()
        while len(p) > max_len:
            cut = p.rfind(" ", 0, max_len)
            if cut == -1:
                cut = max_len
            chunks.append(p[:cut].strip())
            p = p[cut:].strip()
        if p:
            chunks.append(p)
    return [c for c in chunks if c]


def floats_to_s16le(samples: np.ndarray) -> np.ndarray:
    x = np.clip(samples, -1.0, 1.0)
    return (x * 32767.0).astype(np.int16)


async def sse_generator(text: str, voice: str, speed: float, lang: str) -> AsyncGenerator[bytes, None]:
    # Yields SSE messages with base64 PCM16 frames (100 ms per chunk).
    seq = 0
    total_samples = 0
    try:
        # Warmup ping
        yield b": keep-alive\n\n"
        for sentence in split_text(text):
            async with _model_lock:
                samples, sr = kokoro.create(sentence, voice=voice, speed=speed, lang=lang)
            assert sr == SAMPLE_RATE, f"Expected {SAMPLE_RATE}, got {sr}"
            pcm16 = floats_to_s16le(np.asarray(samples))

            for i in range(0, len(pcm16), CHUNK_SAMPLES):
                frame = pcm16[i:i + CHUNK_SAMPLES]
                total_samples += len(frame)
                payload = {
                    "seq": seq,
                    "sr": SAMPLE_RATE,
                    "ch": CHANNELS,
                    "format": "s16le",
                    "pcm16": base64.b64encode(frame.tobytes()).decode("ascii"),
                }
                msg = f"data: {json.dumps(payload, separators=(',', ':'))}\n\n"
                yield msg.encode("utf-8")
                seq += 1
                await asyncio.sleep(0)  # give control back to loop
        # Done event
        done = {"total_chunks": seq, "total_samples": total_samples}
        yield f"event: done\ndata: {json.dumps(done)}\n\n".encode("utf-8")
    except asyncio.CancelledError:
        # Client disconnected
        return


@app.on_event("startup")
async def _load_model():
    global kokoro
    kokoro = Kokoro(MODEL_PATH, VOICES_PATH)


@app.get("/healthz")
async def healthz():
    return {"status": "ok", "model": os.path.basename(MODEL_PATH)}


@app.get("/v1/voices")
async def list_voices():
    try:
        import numpy as _np
        with _np.load(VOICES_PATH) as z:
            names = sorted(list(z.files))
        return {"voices": names}
    except Exception:
        fallback = [
            "af", "af_bella", "af_nicole", "af_sarah", "af_sky",
            "am_adam", "am_michael",
            "bf_emma", "bf_isabella",
            "bm_george", "bm_lewis",
        ]
        return {"voices": fallback, "note": "fallback list; voices file not parsed"}


@app.get("/v1/tts.sse")
async def tts_sse(
    text: str = Query(..., description="Text to synthesize"),
    voice: str = Query("af_sarah"),
    speed: float = Query(1.0, ge=0.5, le=1.5),
    lang: str = Query("en-us"),
):
    headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",  # for nginx
    }
    return StreamingResponse(
        sse_generator(text, voice, speed, lang),
        media_type="text/event-stream",
        headers=headers,
    )
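
For reference, below is a minimal client sketch for consuming the stream. It assumes the app is served locally with "uvicorn app:app --port 8000" (host and port are placeholders) and that the requests package is installed as an extra dependency; it parses the SSE messages by hand, reassembles the base64 PCM16 frames, and writes them to a playable WAV file. The field names (sr, pcm16) and the done event match the payload emitted by sse_generator above.

# client_demo.py - minimal consumer sketch for the /v1/tts.sse endpoint above.
# Assumes the server was started with, e.g.: uvicorn app:app --port 8000
# (host/port are placeholders) and that `requests` is installed.
import base64
import json
import wave

import requests

URL = "http://localhost:8000/v1/tts.sse"  # placeholder host/port

params = {"text": "Hello from Kokoro. This is a streaming test.", "voice": "af_sarah"}
pcm = bytearray()
sample_rate = 24000  # overwritten by the "sr" field of the first frame
event = None

with requests.get(URL, params=params, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line:                 # blank line terminates one SSE message
            event = None
            continue
        if line.startswith(":"):     # comment line, i.e. the keep-alive ping
            continue
        if line.startswith("event:"):
            event = line.split(":", 1)[1].strip()
            continue
        if line.startswith("data:"):
            data = json.loads(line.split(":", 1)[1].strip())
            if event == "done":
                print("done:", data)
                break
            sample_rate = data["sr"]
            pcm.extend(base64.b64decode(data["pcm16"]))

# Dump the accumulated s16le mono PCM as a WAV file for playback.
with wave.open("out.wav", "wb") as w:
    w.setnchannels(1)
    w.setsampwidth(2)  # 16-bit samples
    w.setframerate(sample_rate)
    w.writeframes(bytes(pcm))
print(f"wrote out.wav ({len(pcm) // 2} samples)")

Sending the PCM as base64 inside JSON adds roughly a third of wire overhead, but it keeps the stream consumable from a browser EventSource, which only handles text frames.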