Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from fastapi import FastAPI, HTTPException
|
2 |
from fastapi.responses import StreamingResponse
|
3 |
import onnxruntime as ort
|
4 |
import numpy as np
|
@@ -7,7 +7,14 @@ import io
|
|
7 |
import wave
|
8 |
import uvicorn
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# Set up CORS
|
13 |
app.add_middleware(
|
@@ -18,21 +25,51 @@ app.add_middleware(
|
|
18 |
allow_headers=["*"], # Allows all headers
|
19 |
)
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
#
|
23 |
-
|
|
|
24 |
|
25 |
-
# Load
|
26 |
-
|
|
|
|
|
27 |
|
28 |
CHUNK_SIZE = 4000 # Number of samples per chunk
|
29 |
|
30 |
-
def text_to_speech_generator(text):
|
31 |
-
inputs =
|
32 |
input_ids = inputs.input_ids.astype(np.int64)
|
33 |
-
onnx_output =
|
34 |
waveform = onnx_output[0][0]
|
35 |
-
|
36 |
for i in range(0, len(waveform), CHUNK_SIZE):
|
37 |
yield waveform[i:i+CHUNK_SIZE]
|
38 |
|
@@ -45,8 +82,32 @@ def create_wav_header(sample_rate, bits_per_sample, channels):
|
|
45 |
wav_file.writeframes(b'') # Write empty frames to create header
|
46 |
return byte_io.getvalue()
|
47 |
|
48 |
-
@app.get("/tts")
|
49 |
-
async def tts_endpoint(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
try:
|
51 |
sample_rate = 16000
|
52 |
bits_per_sample = 16
|
@@ -55,15 +116,25 @@ async def tts_endpoint(text: str):
|
|
55 |
async def audio_stream_generator():
|
56 |
# First, yield the WAV header
|
57 |
yield create_wav_header(sample_rate, bits_per_sample, channels)
|
58 |
-
|
59 |
# Then stream the audio data
|
60 |
-
for chunk in text_to_speech_generator(text):
|
61 |
yield (chunk * 32767).astype(np.int16).tobytes()
|
62 |
|
63 |
return StreamingResponse(audio_stream_generator(), media_type="audio/wav")
|
64 |
except Exception as e:
|
65 |
raise HTTPException(status_code=500, detail=str(e))
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
if __name__ == "__main__":
|
68 |
host = "0.0.0.0"
|
69 |
port = 8000
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException, Query
|
2 |
from fastapi.responses import StreamingResponse
|
3 |
import onnxruntime as ort
|
4 |
import numpy as np
|
|
|
7 |
import wave
|
8 |
import uvicorn
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
+
from typing import Dict, List
|
11 |
+
from enum import Enum
|
12 |
+
|
13 |
+
# FastAPI application object.  Title/description/version are surfaced in the
# generated OpenAPI (Swagger) documentation.
_API_TITLE = "Multilingual Text-to-Speech API"
_API_DESCRIPTION = "This API provides text-to-speech conversion for multiple Indian languages using ONNX models."

app = FastAPI(title=_API_TITLE, description=_API_DESCRIPTION, version="1.0.0")
|
18 |
|
19 |
# Set up CORS
|
20 |
app.add_middleware(
|
|
|
25 |
allow_headers=["*"], # Allows all headers
|
26 |
)
|
27 |
|
28 |
+
# Supported languages: (code, display name, ONNX model file).  Kept as a flat
# table and expanded into the lookup dict below; note "urd" uses a
# script-specific model file.
_LANGUAGE_TABLE = [
    ("hin", "Hindi", "mms-tts-hin.onnx"),
    ("ben", "Bengali", "mms-tts-ben.onnx"),
    ("mar", "Marathi", "mms-tts-mar.onnx"),
    ("tel", "Telugu", "mms-tts-tel.onnx"),
    ("tam", "Tamil", "mms-tts-tam.onnx"),
    ("guj", "Gujarati", "mms-tts-guj.onnx"),
    ("urd", "Urdu", "mms-tts-urd-script_arabic.onnx"),
    ("kan", "Kannada", "mms-tts-kan.onnx"),
    ("mal", "Malayalam", "mms-tts-mal.onnx"),
    ("pan", "Punjabi", "mms-tts-pan.onnx"),
    ("nep", "Nepali", "mms-tts-nep.onnx"),
]

SUPPORTED_LANGUAGES: Dict[str, Dict[str, str]] = {
    code: {"name": name, "file": model_file}
    for code, name, model_file in _LANGUAGE_TABLE
}
|
42 |
+
|
43 |
+
# Language codes accepted by the endpoints.  Built with the functional Enum
# API; `type=str` produces the same str-mixin class as
# `class LanguageCode(str, Enum)`, which FastAPI relies on to validate and
# serialize the query parameter.
LanguageCode = Enum(
    "LanguageCode",
    [
        ("hindi", "hin"),
        ("bengali", "ben"),
        ("marathi", "mar"),
        ("telugu", "tel"),
        ("tamil", "tam"),
        ("gujarati", "guj"),
        ("urdu", "urd"),
        ("kannada", "kan"),
        ("malayalam", "mal"),
        ("punjabi", "pan"),
        ("nepali", "nep"),
    ],
    type=str,
)
|
56 |
|
57 |
+
# Per-language inference sessions and tokenizers, loaded eagerly at import
# time so the first request does not pay the model-load cost.
sessions: Dict[str, ort.InferenceSession] = {}
tokenizers: Dict[str, AutoTokenizer] = {}

for code, info in SUPPORTED_LANGUAGES.items():
    sessions[code] = ort.InferenceSession(info["file"], providers=['CPUExecutionProvider'])
    # Derive the Hugging Face repo from the model file name rather than from
    # the bare language code: for "urd" the file is
    # "mms-tts-urd-script_arabic.onnx", so f"facebook/mms-tts-{code}" would
    # point at a script-less repo that does not match the loaded model.
    # NOTE(review): assumes every repo is named after its .onnx file stem —
    # confirm against the model hub.
    repo = info["file"].rsplit(".onnx", 1)[0]
    tokenizers[code] = AutoTokenizer.from_pretrained(f"facebook/{repo}")
|
65 |
|
66 |
CHUNK_SIZE = 4000  # Number of samples per chunk

def text_to_speech_generator(text: str, lang: str):
    """Synthesize `text` and yield the waveform in CHUNK_SIZE-sample slices.

    `lang` must be a key of the module-level `sessions`/`tokenizers`
    registries (a supported language code).
    """
    encoded = tokenizers[lang](text, return_tensors="np")
    token_ids = encoded.input_ids.astype(np.int64)
    # run(None, ...) returns every model output; output 0 holds a batch of
    # waveforms and a single item was synthesized, hence [0][0].
    waveform = sessions[lang].run(None, {"input_ids": token_ids})[0][0]
    offset = 0
    while offset < len(waveform):
        yield waveform[offset:offset + CHUNK_SIZE]
        offset += CHUNK_SIZE
|
75 |
|
|
|
82 |
wav_file.writeframes(b'') # Write empty frames to create header
|
83 |
return byte_io.getvalue()
|
84 |
|
85 |
+
@app.get("/tts", summary="Convert text to speech", response_description="Audio stream in WAV format")
|
86 |
+
async def tts_endpoint(
|
87 |
+
text: str = Query(..., description="The text to convert to speech"),
|
88 |
+
lang: LanguageCode = Query(..., description="The language code for text-to-speech conversion")
|
89 |
+
):
|
90 |
+
"""
|
91 |
+
Convert the given text to speech in the specified language.
|
92 |
+
|
93 |
+
- **text**: The input text to be converted to speech
|
94 |
+
- **lang**: The language code for the input text and desired speech output
|
95 |
+
|
96 |
+
Available language codes:
|
97 |
+
- hin: Hindi
|
98 |
+
- ben: Bengali
|
99 |
+
- mar: Marathi
|
100 |
+
- tel: Telugu
|
101 |
+
- tam: Tamil
|
102 |
+
- guj: Gujarati
|
103 |
+
- urd: Urdu
|
104 |
+
- kan: Kannada
|
105 |
+
- mal: Malayalam
|
106 |
+
- pan: Punjabi
|
107 |
+
- nep: Nepali
|
108 |
+
|
109 |
+
Returns a streaming response with the audio data in WAV format.
|
110 |
+
"""
|
111 |
try:
|
112 |
sample_rate = 16000
|
113 |
bits_per_sample = 16
|
|
|
116 |
async def audio_stream_generator():
|
117 |
# First, yield the WAV header
|
118 |
yield create_wav_header(sample_rate, bits_per_sample, channels)
|
|
|
119 |
# Then stream the audio data
|
120 |
+
for chunk in text_to_speech_generator(text, lang):
|
121 |
yield (chunk * 32767).astype(np.int16).tobytes()
|
122 |
|
123 |
return StreamingResponse(audio_stream_generator(), media_type="audio/wav")
|
124 |
except Exception as e:
|
125 |
raise HTTPException(status_code=500, detail=str(e))
|
126 |
|
127 |
+
@app.get("/languages", summary="Get supported languages", response_model=List[Dict[str, str]])
async def get_languages():
    """
    List the supported languages with their codes and names.

    Returns one entry per language, each containing:
    - **code**: The language code
    - **name**: The full name of the language
    """
    return [
        {"code": code, "name": details["name"]}
        for code, details in SUPPORTED_LANGUAGES.items()
    ]
|
137 |
+
|
138 |
if __name__ == "__main__":
|
139 |
host = "0.0.0.0"
|
140 |
port = 8000
|