# audio_chat_indic / main.py
from fastapi import FastAPI, HTTPException, Query, UploadFile, File
from fastapi.responses import StreamingResponse, JSONResponse
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
import io
import wave
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict, List
from enum import Enum
from pydub import AudioSegment
import lameenc
import os
import tempfile
from openai import OpenAI
from pydantic import BaseModel
import edge_tts
from fast_langdetect import detect
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
app = FastAPI(
title="Multilingual Text-to-Speech API",
description="This API provides text-to-speech conversion for multiple Indian languages using ONNX models.",
version="1.0.0",
)
# Define supported languages and their corresponding model files
SUPPORTED_LANGUAGES: Dict[str, Dict[str, str]] = {
"hin": {"name": "Hindi", "file": "mms-tts-hin.onnx"},
"ben": {"name": "Bengali", "file": "mms-tts-ben.onnx"},
"mar": {"name": "Marathi", "file": "mms-tts-mar.onnx"},
"tel": {"name": "Telugu", "file": "mms-tts-tel.onnx"},
"tam": {"name": "Tamil", "file": "mms-tts-tam.onnx"},
"guj": {"name": "Gujarati", "file": "mms-tts-guj.onnx"},
"urd-script_arabic": {"name": "Urdu", "file": "mms-tts-urd-script_arabic.onnx"},
"kan": {"name": "Kannada", "file": "mms-tts-kan.onnx"},
"mal": {"name": "Malayalam", "file": "mms-tts-mal.onnx"},
"pan": {"name": "Punjabi", "file": "mms-tts-pan.onnx"},
}
# Create an Enum for language codes
class LanguageCode(str, Enum):
hindi = "hin"
bengali = "ben"
marathi = "mar"
telugu = "tel"
tamil = "tam"
gujarati = "guj"
urdu = "urd-script_arabic"
kannada = "kan"
malayalam = "mal"
    punjabi = "pan"
    english = "eng"  # synthesized via edge-tts below rather than an ONNX model
# Initialize dictionaries to store sessions and tokenizers
sessions: Dict[str, ort.InferenceSession] = {}
tokenizers: Dict[str, AutoTokenizer] = {}
# Load models and tokenizers for all supported languages
for lang, info in SUPPORTED_LANGUAGES.items():
sessions[lang] = ort.InferenceSession(info["file"], providers=['CPUExecutionProvider'])
tokenizers[lang] = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{lang}")
def text_to_speech(text: str, lang: str):
    # Normalize both escaped ("\n" written as two characters) and real newlines
    text = text.replace('\\n', ' ').replace('\n', ' ').strip()
inputs = tokenizers[lang](text, return_tensors="np")
input_ids = inputs.input_ids.astype(np.int64)
onnx_output = sessions[lang].run(None, {"input_ids": input_ids})
waveform = onnx_output[0][0]
return waveform
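# Illustrative usage (assumes the Hindi model file listed above is present):
#   waveform = text_to_speech("नमस्ते", "hin")  # 1-D float32 numpy array at 16 kHz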
def numpy_to_mp3(waveform, sample_rate=16000):
    # Clip to [-1, 1] before scaling so out-of-range samples don't wrap around
    audio_data = (np.clip(waveform, -1.0, 1.0) * 32767).astype(np.int16)
# Create an AudioSegment
audio_segment = AudioSegment(
audio_data.tobytes(),
frame_rate=sample_rate,
sample_width=2,
channels=1
)
# Export as MP3
buffer = io.BytesIO()
audio_segment.export(buffer, format="mp3")
return buffer.getvalue()
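# Illustrative usage: mp3_bytes = numpy_to_mp3(text_to_speech("नमस्ते", "hin"))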
def create_wav_header(sample_rate, bits_per_sample, channels):
byte_io = io.BytesIO()
with wave.open(byte_io, 'wb') as wav_file:
wav_file.setnchannels(channels)
wav_file.setsampwidth(bits_per_sample // 8)
wav_file.setframerate(sample_rate)
wav_file.writeframes(b'') # Write empty frames to create header
return byte_io.getvalue()
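# NOTE: create_wav_header is not called by any endpoint below; it is kept for
# clients that want to prepend a WAV header to raw 16-bit PCM audio themselves.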
async def edge_tts_generate(text: str, voice: str = "en-GB-SoniaNeural"):
    text = text.replace('\\n', ' ').replace('\n', ' ').strip()
communicate = edge_tts.Communicate(text, voice)
async for chunk in communicate.stream():
if chunk["type"] == "audio":
yield chunk["data"]
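# Illustrative usage (from inside another coroutine):
#   audio = b"".join([chunk async for chunk in edge_tts_generate("Hello there")])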
@app.get("/tts", summary="Convert text to speech", response_description="Audio in MP3 format")
async def tts_endpoint(
text: str = Query(..., description="The text to convert to speech"),
lang: LanguageCode = Query(..., description="The language code for text-to-speech conversion"),
voice: str = Query(default="en-GB-SoniaNeural", description="Voice to use for speech (only for English)")
):
"""
Convert the given text to speech in the specified language.
- **text**: The input text to be converted to speech
- **lang**: The language code for the input text and desired speech output
- **voice**: The voice to use for speech (only applicable for English)
Available language codes:
- hin: Hindi
- ben: Bengali
- mar: Marathi
- tel: Telugu
- tam: Tamil
- guj: Gujarati
- urd-script_arabic: Urdu
- kan: Kannada
- mal: Malayalam
- pan: Punjabi
- eng: English
Returns a streaming response with the audio data in MP3 format.
"""
try:
if lang == "eng":
return StreamingResponse(edge_tts_generate(text, voice), media_type="audio/mpeg")
else:
waveform = text_to_speech(text, lang)
mp3_data = numpy_to_mp3(waveform)
return StreamingResponse(io.BytesIO(mp3_data), media_type="audio/mpeg")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
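# Example request (assuming the server runs locally on port 8000):
#   curl -G "http://localhost:8000/tts" --data-urlencode "text=नमस्ते" \
#        --data-urlencode "lang=hin" --output hello.mp3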
##########################################################
iso_code_mapping = {
"hi": "hin",
"bn": "ben",
"mr": "mar",
"te": "tel",
"ta": "tam",
"gu": "guj",
"ur": "urd-script_arabic",
"kn": "kan",
"ml": "mal",
"pa": "pan",
"en": "eng"
}
def detect_language(text):
    try:
        lang_code_2letter = detect(text, low_memory=False)["lang"]
        return iso_code_mapping.get(lang_code_2letter, "Unknown")
    except Exception as e:
        # Return "Unknown" so the caller falls back to edge-tts instead of
        # receiving an error string and using it as a language key
        logger.warning(f"Language detection failed: {e}")
        return "Unknown"
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@app.get("/auto-tts", summary="Auto-detect language and convert text to speech", response_description="Audio in MP3 format")
async def auto_tts_endpoint(
text: str = Query(..., description="The text to convert to speech"),
voice: str = Query(default="en-GB-SoniaNeural", description="Voice to use for speech (only for English)")
):
try:
        text = text.replace('\\n', ' ').replace('\n', ' ').strip()
logger.info(f"Received text: {text[:100]}...") # Log first 100 chars of input
detected_lang = detect_language(text)
logger.info(f"Detected language: {detected_lang}")
if detected_lang == "eng" or detected_lang == "Unknown":
logger.info("Using edge_tts_generate")
return StreamingResponse(edge_tts_generate(text, voice), media_type="audio/mpeg")
else:
logger.info(f"Using text_to_speech for language: {detected_lang}")
waveform = text_to_speech(text, detected_lang)
logger.info("Converting waveform to MP3")
mp3_data = numpy_to_mp3(waveform)
return StreamingResponse(io.BytesIO(mp3_data), media_type="audio/mpeg")
except Exception as e:
logger.error(f"Error in auto_tts_endpoint: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
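# Example request (the language is detected from the text itself):
#   curl -G "http://localhost:8000/auto-tts" --data-urlencode "text=வணக்கம்" --output hello.mp3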
#########################################################
# Initialize OpenAI API client with your API key
client = OpenAI(api_key=OPENAI_API_KEY)
class TranscriptionResponse(BaseModel):
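    # NOTE: currently unused; /transcribe/ below returns a plain JSON payload instead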
transcription: str
@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
try:
# Check if the file is an audio file
if not file.content_type.startswith("audio/"):
raise HTTPException(status_code=400, detail="File must be an audio file")
# Read the file content
content = await file.read()
# Create a temporary file
with open(file.filename, "wb") as temp_file:
temp_file.write(content)
# Open the temporary file and transcribe
with open(file.filename, "rb") as audio_file:
transcript = client.audio.transcriptions.create(
model="whisper-1",
file=audio_file
)
# Remove the temporary file
os.remove(file.filename)
# Return the transcript
return JSONResponse(content={"transcript": transcript.text})
except Exception as e:
# Handle any errors
return HTTPException(status_code=500, detail=str(e))
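# Example request (assumes OPENAI_API_KEY is set in the server's environment):
#   curl -X POST "http://localhost:8000/transcribe/" -F "file=@speech.mp3;type=audio/mpeg"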
@app.get("/languages", summary="Get supported languages", response_model=List[Dict[str, str]])
async def get_languages():
"""
Retrieve a list of supported languages with their codes and names.
Returns a list of dictionaries, each containing:
- **code**: The language code
- **name**: The full name of the language
"""
return [{"code": code, "name": info["name"]} for code, info in SUPPORTED_LANGUAGES.items()]
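# Example response (truncated):
#   [{"code": "hin", "name": "Hindi"}, {"code": "ben", "name": "Bengali"}, ...]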
###### TTS STREAM
import asyncio
from concurrent.futures import ThreadPoolExecutor
# Initialize ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=4) # Adjust the number of workers as needed
CHUNK_SIZE = 128  # Samples per MP3-encoding chunk for streaming; adjust as needed
async def text_to_speech_async(text: str, lang: str):
loop = asyncio.get_running_loop()
# Run the ONNX inference in a separate thread
inputs = await loop.run_in_executor(executor, lambda: tokenizers[lang](text, return_tensors="np"))
input_ids = inputs.input_ids.astype(np.int64)
onnx_output = await loop.run_in_executor(executor, lambda: sessions[lang].run(None, {"input_ids": input_ids}))
waveform = onnx_output[0][0]
# Initialize the MP3 encoder
encoder = lameenc.Encoder()
encoder.set_bit_rate(128)
encoder.set_in_sample_rate(16000) # Adjust if your model uses a different sample rate
encoder.set_channels(1)
encoder.set_quality(2)
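    # LAME buffers internally (an MP3 frame is 1152 samples), so small input
    # chunks may yield no output bytes; the `if mp3_chunk` checks below skip those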
for i in range(0, len(waveform), CHUNK_SIZE):
chunk = waveform[i:i+CHUNK_SIZE]
# Convert to int16 and encode to MP3
audio_data = (chunk * 32767).astype(np.int16)
mp3_chunk = await loop.run_in_executor(executor, encoder.encode, audio_data.tobytes())
if mp3_chunk:
yield mp3_chunk
# Flush the encoder
mp3_chunk = await loop.run_in_executor(executor, encoder.flush)
if mp3_chunk:
yield mp3_chunk
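# NOTE: numpy_to_mp3_chunk is not called anywhere in this file; it also never
# flushes its encoder, so trailing samples would be dropped if it were used alone.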
def numpy_to_mp3_chunk(waveform, sample_rate=16000):
audio_data = (waveform * 32767).astype(np.int16)
encoder = lameenc.Encoder()
encoder.set_bit_rate(128)
encoder.set_in_sample_rate(sample_rate)
encoder.set_channels(1)
encoder.set_quality(2)
mp3_data = encoder.encode(audio_data.tobytes())
return mp3_data
@app.get("/tts-stream")
async def tts_endpoint(
text: str = Query(..., description="The text to convert to speech"),
lang: LanguageCode = Query(..., description="The language code for text-to-speech conversion")
):
    try:
        if lang == "eng":
            # English has no local ONNX model; stream via edge-tts instead
            return StreamingResponse(edge_tts_generate(text), media_type="audio/mpeg")
        # text_to_speech_async is already an async generator, so it can be passed
        # straight to StreamingResponse. Note: errors raised mid-stream occur after
        # headers are sent and are not converted into HTTP errors by the except below.
        return StreamingResponse(text_to_speech_async(text, lang), media_type="audio/mpeg")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
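# Example request (-N disables curl's buffering so audio arrives as it streams):
#   curl -N -G "http://localhost:8000/tts-stream" --data-urlencode "text=नमस्ते" \
#        --data-urlencode "lang=hin" --output stream.mp3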
# CORS middleware setup
app.add_middleware(
CORSMiddleware,
allow_origins=[
"http://localhost:3000",
"https://www.elevaticsai.com",
"https://www.elevatics.cloud",
"https://www.elevatics.online",
"https://www.elevatics.ai",
"https://elevaticsai.com",
"https://elevatics.cloud",
"https://elevatics.online",
"https://elevatics.ai",
"https://pvanand-specialized-agents.hf.space"
],
allow_credentials=True,
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
if __name__ == "__main__":
host = "0.0.0.0"
port = 8000
print(f"Starting server. Access the API documentation at http://localhost:{port}/docs")
uvicorn.run(app, host=host, port=port)