# rawistt / app.py
# walker11's picture
# Upload 2 files
# 4197dc6 verified
import os
import tempfile
import speech_recognition as sr
import gradio as gr
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import uvicorn
from pathlib import Path
from pydub import AudioSegment
# Application object that every route and mount below is attached to.
app = FastAPI(title="Speech to Text Model")

# CORS policy registered on the app so a separately hosted frontend can call
# the API from the browser.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; list the real frontend origins before a production deployment.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Single recognizer instance shared by every transcription call in this module.
recognizer = sr.Recognizer()
# FastAPI endpoint for direct API access
@app.post("/generate-story")
async def generate_story_api(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and return the text as JSON.

    The upload is written to a temporary file that keeps the original file
    extension (``.wav`` when the upload has no filename) and is then passed
    to ``transcribe_audio``.

    Args:
        file: The uploaded audio file from the multipart request.

    Returns:
        A ``JSONResponse`` with ``{"transcript": ...}`` on success, or a
        status-500 ``JSONResponse`` with ``{"error": ...}`` if saving or
        processing the upload raises.
    """
    tmp_path = None
    try:
        # Keep the original extension; the conversion step reads it.
        file_extension = os.path.splitext(file.filename)[1] if file.filename else ".wav"
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp:
            # Record the path before writing so a failed read/write is
            # still cleaned up in the ``finally`` clause below.
            tmp_path = tmp.name
            tmp.write(await file.read())

        # NOTE(review): this call is synchronous, so the event loop is
        # blocked while the audio is being processed.
        transcript = transcribe_audio(tmp_path)

        return JSONResponse({"transcript": transcript})
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Always remove the temp file, on success and on failure alike.
        if tmp_path is not None and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup: a failed delete must not replace
                # the response that is already being returned.
                pass
# Convert any audio format to WAV
def convert_to_wav(audio_path):
    """Return the path of a WAV version of ``audio_path``.

    A path whose extension is already ``.wav`` (case-insensitive) is returned
    unchanged. Any other file is loaded with ``AudioSegment.from_file`` and
    exported next to the input as ``<name>_converted.wav``.

    Args:
        audio_path: Path of the audio file to convert.

    Returns:
        ``audio_path`` itself for WAV input, otherwise the path of the newly
        written WAV file.

    Raises:
        Exception: If reading or exporting the audio fails. The original
            error is attached as ``__cause__``.
    """
    try:
        base_path, file_extension = os.path.splitext(audio_path)

        # If already WAV, don't convert
        if file_extension.lower() == ".wav":
            return audio_path

        # Known and unknown extensions are loaded the same way, so one code
        # path covers every non-WAV input.
        wav_path = base_path + "_converted.wav"
        audio = AudioSegment.from_file(audio_path)
        audio.export(wav_path, format="wav")
        return wav_path
    except Exception as e:
        # Chain the original error so the root cause stays in the traceback.
        raise Exception(f"Error converting audio format: {str(e)}") from e
# Function for processing audio (used by both FastAPI and Gradio)
def transcribe_audio(audio_path):
    """Transcribe the Arabic speech in an audio file to text.

    The file is first converted to WAV with ``convert_to_wav``, then read with
    ``sr.AudioFile`` and sent to ``recognizer.recognize_google`` using the
    language code ``"ar-AR"``.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        The recognized text, or an Arabic error message when the speech could
        not be understood, the recognition request failed, or any other error
        occurred. This function does not raise for those cases.
    """
    wav_path = None
    try:
        # Convert audio to WAV format first
        wav_path = convert_to_wav(audio_path)

        # Load the whole file into memory; the file is closed before the
        # recognition request is made.
        with sr.AudioFile(wav_path) as source:
            audio_data = recognizer.record(source)

        return recognizer.recognize_google(audio_data, language="ar-AR")
    except sr.UnknownValueError:
        return "لم يتم التعرف على الكلام"
    except sr.RequestError as e:
        return f"حدث خطأ في خدمة التعرف على الصوت: {e}"
    except Exception as e:
        return f"حدث خطأ: {str(e)}"
    finally:
        # Remove the converted copy on every exit path (including recognition
        # errors), but never the caller's own input file.
        if wav_path is not None and wav_path != audio_path and os.path.exists(wav_path):
            try:
                os.remove(wav_path)
            except OSError:
                # Best-effort cleanup: a failed delete must not replace the
                # transcription result.
                pass
# Gradio callback: adapts the component's value to transcribe_audio
def gradio_process(audio_file):
    """Return the transcript for the audio supplied by the Gradio UI.

    Args:
        audio_file: Either a filesystem path string or an object exposing the
            path through a ``name`` attribute.

    Returns:
        Whatever ``transcribe_audio`` returns for the resolved path, or an
        Arabic error message if resolving the path or transcribing raises.
    """
    try:
        if isinstance(audio_file, str):
            resolved_path = audio_file
        else:
            # Non-string values are expected to carry the path in ``.name``.
            resolved_path = audio_file.name
        return transcribe_audio(resolved_path)
    except Exception as err:
        return f"حدث خطأ: {err}"
# Gradio UI: an audio input, a submit button, and a text box for the result.
# NOTE(review): nothing in the visible code mounts or launches `demo`, so
# confirm how this interface is meant to be served.
with gr.Blocks(title="Speech to Text Model") as demo:
    # Page heading and the Arabic usage instructions.
    gr.Markdown("# Speech to Text")
    gr.Markdown("قم بتسجيل أو تحميل ملف صوتي باللغة العربية وسيقوم النظام بتحويله إلى نص.")

    # type="filepath" makes the component hand the callback a path string.
    with gr.Row():
        audio_input = gr.Audio(label="تسجيل أو تحميل صوت", type="filepath")

    with gr.Row():
        submit_btn = gr.Button("تحويل إلى نص")

    with gr.Row():
        transcript_output = gr.Textbox(label="النص المستخرج من التسجيل الصوتي")

    # Clicking the button runs the transcription and fills the text box.
    submit_btn.click(
        fn=gradio_process,
        inputs=audio_input,
        outputs=transcript_output,
    )
# Serve the static frontend from the site root when its directory is present.
# The path is relative, so it is resolved against the current working
# directory of the process rather than this file's location.
frontend_path = Path("../front")
if frontend_path.exists():
    frontend_files = StaticFiles(directory=str(frontend_path), html=True)
    app.mount("/", frontend_files, name="frontend")

# Start the server only when this file is executed as a script.
# NOTE(review): reload=True is a development convenience; confirm it is wanted
# in the deployed environment.
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)