pvanand committed on
Commit
e35c088
·
verified ·
1 Parent(s): eaee333

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +85 -14
main.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, HTTPException
2
  from fastapi.responses import StreamingResponse
3
  import onnxruntime as ort
4
  import numpy as np
@@ -7,7 +7,14 @@ import io
7
  import wave
8
  import uvicorn
9
  from fastapi.middleware.cors import CORSMiddleware
10
- app = FastAPI()
 
 
 
 
 
 
 
11
 
12
  # Set up CORS
13
  app.add_middleware(
@@ -18,21 +25,51 @@ app.add_middleware(
18
  allow_headers=["*"], # Allows all headers
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # Load the ONNX model
23
- session = ort.InferenceSession("mms-tts-hin.onnx", providers=['CPUExecutionProvider'])
 
24
 
25
- # Load the tokenizer
26
- tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-hin")
 
 
27
 
28
  CHUNK_SIZE = 4000 # Number of samples per chunk
29
 
30
- def text_to_speech_generator(text):
31
- inputs = tokenizer(text, return_tensors="np")
32
  input_ids = inputs.input_ids.astype(np.int64)
33
- onnx_output = session.run(None, {"input_ids": input_ids})
34
  waveform = onnx_output[0][0]
35
-
36
  for i in range(0, len(waveform), CHUNK_SIZE):
37
  yield waveform[i:i+CHUNK_SIZE]
38
 
@@ -45,8 +82,32 @@ def create_wav_header(sample_rate, bits_per_sample, channels):
45
  wav_file.writeframes(b'') # Write empty frames to create header
46
  return byte_io.getvalue()
47
 
48
- @app.get("/tts")
49
- async def tts_endpoint(text: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  try:
51
  sample_rate = 16000
52
  bits_per_sample = 16
@@ -55,15 +116,25 @@ async def tts_endpoint(text: str):
55
  async def audio_stream_generator():
56
  # First, yield the WAV header
57
  yield create_wav_header(sample_rate, bits_per_sample, channels)
58
-
59
  # Then stream the audio data
60
- for chunk in text_to_speech_generator(text):
61
  yield (chunk * 32767).astype(np.int16).tobytes()
62
 
63
  return StreamingResponse(audio_stream_generator(), media_type="audio/wav")
64
  except Exception as e:
65
  raise HTTPException(status_code=500, detail=str(e))
66
 
 
 
 
 
 
 
 
 
 
 
 
67
  if __name__ == "__main__":
68
  host = "0.0.0.0"
69
  port = 8000
 
1
+ from fastapi import FastAPI, HTTPException, Query
2
  from fastapi.responses import StreamingResponse
3
  import onnxruntime as ort
4
  import numpy as np
 
7
  import wave
8
  import uvicorn
9
  from fastapi.middleware.cors import CORSMiddleware
10
+ from typing import Dict, List
11
+ from enum import Enum
12
+
13
# OpenAPI metadata for the auto-generated /docs page.
_API_TITLE = "Multilingual Text-to-Speech API"
_API_DESCRIPTION = "This API provides text-to-speech conversion for multiple Indian languages using ONNX models."

# Application object shared by all route decorators below.
app = FastAPI(title=_API_TITLE, description=_API_DESCRIPTION, version="1.0.0")
18
 
19
  # Set up CORS
20
  app.add_middleware(
 
25
  allow_headers=["*"], # Allows all headers
26
  )
27
 
28
# Registry of supported languages: ISO 639-3 code -> display name and the
# local ONNX model file used for synthesis. Insertion order is preserved and
# is the order exposed by the /languages endpoint.
SUPPORTED_LANGUAGES: Dict[str, Dict[str, str]] = {
    code: {"name": name, "file": model_file}
    for code, name, model_file in [
        ("hin", "Hindi", "mms-tts-hin.onnx"),
        ("ben", "Bengali", "mms-tts-ben.onnx"),
        ("mar", "Marathi", "mms-tts-mar.onnx"),
        ("tel", "Telugu", "mms-tts-tel.onnx"),
        ("tam", "Tamil", "mms-tts-tam.onnx"),
        ("guj", "Gujarati", "mms-tts-guj.onnx"),
        ("urd", "Urdu", "mms-tts-urd-script_arabic.onnx"),
        ("kan", "Kannada", "mms-tts-kan.onnx"),
        ("mal", "Malayalam", "mms-tts-mal.onnx"),
        ("pan", "Punjabi", "mms-tts-pan.onnx"),
        ("nep", "Nepali", "mms-tts-nep.onnx"),
    ]
}
42
+
43
class LanguageCode(str, Enum):
    """ISO 639-3 language codes accepted by the TTS endpoint.

    Subclassing ``str`` lets FastAPI parse and document the ``lang`` query
    parameter as a plain string enum, and lets members be used directly as
    keys into the session/tokenizer dictionaries.
    """

    hindi = "hin"
    bengali = "ben"
    marathi = "mar"
    telugu = "tel"
    tamil = "tam"
    gujarati = "guj"
    urdu = "urd"
    kannada = "kan"
    malayalam = "mal"
    punjabi = "pan"
    nepali = "nep"
56
 
57
# Eagerly load one ONNX inference session and one tokenizer per supported
# language at startup, so requests never pay model-load latency.
sessions: Dict[str, ort.InferenceSession] = {}
tokenizers: Dict[str, AutoTokenizer] = {}

for lang, info in SUPPORTED_LANGUAGES.items():
    sessions[lang] = ort.InferenceSession(info["file"], providers=['CPUExecutionProvider'])
    # Derive the Hugging Face repo id from the model file name rather than the
    # bare language code: this keeps the tokenizer consistent with the ONNX
    # file for suffixed variants such as Urdu ("mms-tts-urd-script_arabic"),
    # where f"facebook/mms-tts-{lang}" would name a different repo.
    # NOTE(review): confirm the repo ids resolve for all eleven languages.
    repo = info["file"].removesuffix(".onnx")
    tokenizers[lang] = AutoTokenizer.from_pretrained(f"facebook/{repo}")
65
 
66
  CHUNK_SIZE = 4000 # Number of samples per chunk
67
 
68
def text_to_speech_generator(text: str, lang: str):
    """Yield the synthesized waveform for *text* in CHUNK_SIZE-sample slices.

    Runs the pre-loaded tokenizer and ONNX session for *lang* and streams the
    resulting float waveform chunk by chunk, so the HTTP layer can start
    sending audio before the whole generator is consumed.
    """
    token_batch = tokenizers[lang](text, return_tensors="np")
    ids = token_batch.input_ids.astype(np.int64)  # the ONNX model expects int64 ids
    # First output tensor, first batch element: the raw float waveform.
    waveform = sessions[lang].run(None, {"input_ids": ids})[0][0]
    offset = 0
    while offset < len(waveform):
        yield waveform[offset:offset + CHUNK_SIZE]
        offset += CHUNK_SIZE
75
 
 
82
  wav_file.writeframes(b'') # Write empty frames to create header
83
  return byte_io.getvalue()
84
 
85
+ @app.get("/tts", summary="Convert text to speech", response_description="Audio stream in WAV format")
86
+ async def tts_endpoint(
87
+ text: str = Query(..., description="The text to convert to speech"),
88
+ lang: LanguageCode = Query(..., description="The language code for text-to-speech conversion")
89
+ ):
90
+ """
91
+ Convert the given text to speech in the specified language.
92
+
93
+ - **text**: The input text to be converted to speech
94
+ - **lang**: The language code for the input text and desired speech output
95
+
96
+ Available language codes:
97
+ - hin: Hindi
98
+ - ben: Bengali
99
+ - mar: Marathi
100
+ - tel: Telugu
101
+ - tam: Tamil
102
+ - guj: Gujarati
103
+ - urd: Urdu
104
+ - kan: Kannada
105
+ - mal: Malayalam
106
+ - pan: Punjabi
107
+ - nep: Nepali
108
+
109
+ Returns a streaming response with the audio data in WAV format.
110
+ """
111
  try:
112
  sample_rate = 16000
113
  bits_per_sample = 16
 
116
  async def audio_stream_generator():
117
  # First, yield the WAV header
118
  yield create_wav_header(sample_rate, bits_per_sample, channels)
 
119
  # Then stream the audio data
120
+ for chunk in text_to_speech_generator(text, lang):
121
  yield (chunk * 32767).astype(np.int16).tobytes()
122
 
123
  return StreamingResponse(audio_stream_generator(), media_type="audio/wav")
124
  except Exception as e:
125
  raise HTTPException(status_code=500, detail=str(e))
126
 
127
@app.get("/languages", summary="Get supported languages", response_model=List[Dict[str, str]])
async def get_languages():
    """List every supported language.

    Each entry is a dictionary with:
    - **code**: the language code (used as the ``lang`` query parameter)
    - **name**: the full English name of the language
    """
    languages: List[Dict[str, str]] = []
    for code, info in SUPPORTED_LANGUAGES.items():
        languages.append({"code": code, "name": info["name"]})
    return languages
137
+
138
  if __name__ == "__main__":
139
  host = "0.0.0.0"
140
  port = 8000