#!/usr/bin/env python3
"""
Ringg Parrot STT V1 🦜 - Hugging Face Space (Frontend)
Real-time streaming transcription using Gradio's audio streaming.
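
Configuration (read at startup):
    STT_API_ENDPOINT - base URL of the backend transcription API
                       (defaults to http://localhost:7864)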
"""
import os
import tempfile
from pathlib import Path
import gradio as gr
import requests
import numpy as np
import soundfile as sf
from dotenv import load_dotenv
try:
import librosa
HAS_LIBROSA = True
except ImportError:
HAS_LIBROSA = False
print("⚠️ librosa not installed. Install with: pip install librosa")
load_dotenv()
# Backend API endpoint
API_ENDPOINT = os.environ.get("STT_API_ENDPOINT", "http://localhost:7864")
TARGET_SAMPLE_RATE = 16000
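# Streamed microphone audio is resampled to this rate before upload;
# 16 kHz mono is the typical input format for NeMo-style STT models.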
# Minimum new audio (in seconds) required before re-transcribing
MIN_AUDIO_LENGTH = 0.4  # Transcribe when we have at least 400 ms of new audio
class RinggSTTClient:
"""Client for Ringg Parrot STT API"""
def __init__(self, api_endpoint: str):
self.api_endpoint = api_endpoint.rstrip("/")
self.session = requests.Session()
self.session.headers.update({"User-Agent": "RinggSTT-HF-Space/1.0"})
def check_health(self) -> dict:
try:
response = self.session.get(f"{self.api_endpoint}/health", timeout=5)
if response.status_code == 200:
return {"status": "healthy", "message": "βœ… API is online"}
return {"status": "error", "message": f"❌ API returned status {response.status_code}"}
except Exception as e:
return {"status": "error", "message": f"❌ Error: {str(e)}"}
def transcribe_audio_data(self, audio_data: np.ndarray, sample_rate: int, language: str = "hi") -> str:
"""Transcribe audio data (numpy array) via multipart upload API"""
try:
# Save to temporary WAV file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
temp_path = f.name
sf.write(temp_path, audio_data, sample_rate)
try:
with open(temp_path, "rb") as f:
files = {"file": ("audio.wav", f, "audio/wav")}
data = {"language": language, "punctuate": "false"}
response = self.session.post(
f"{self.api_endpoint}/v1/audio/transcriptions",
files=files,
data=data,
timeout=30,
)
# Debug: log the response for troubleshooting
print(
f"[transcribe_audio_data] status={response.status_code} "
f"body={response.text[:500]}"
)
if response.status_code == 200:
result = response.json()
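                    # Response shape (inferred from the keys read below):
                    # either {"transcription": "..."} or per-channel keys
                    # like {"transcription_channel_0": "..."}.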
if "transcription_channel_0" in result:
return result.get("transcription_channel_0", "")
return result.get("transcription", "")
else:
return ""
finally:
os.unlink(temp_path)
except Exception as e:
print(f"Transcription error: {e}")
return ""
def transcribe_file(self, audio_file_path: str, language: str = "hi") -> str:
"""Transcribe audio file via multipart upload API"""
try:
with open(audio_file_path, "rb") as f:
files = {"file": (Path(audio_file_path).name, f)}
data = {"language": language, "punctuate": "false"}
response = self.session.post(
f"{self.api_endpoint}/v1/audio/transcriptions",
files=files,
data=data,
timeout=120,
)
if response.status_code == 200:
result = response.json()
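            # Stereo files may come back as two per-channel transcripts
            # (transcription_channel_0 / transcription_channel_1).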
if "transcription_channel_0" in result:
transcripts = []
if result.get("transcription_channel_0"):
transcripts.append(result["transcription_channel_0"])
if result.get("transcription_channel_1"):
transcripts.append(f"\n[Channel 2]: {result['transcription_channel_1']}")
return "".join(transcripts) if transcripts else "No speech detected"
return result.get("transcription", "No transcription received")
else:
return f"❌ API Error: {response.status_code}"
except Exception as e:
return f"❌ Error: {str(e)}"
# Initialize API client
print(f"πŸ”— Connecting to STT API: {API_ENDPOINT}")
stt_client = RinggSTTClient(API_ENDPOINT)
health_status = stt_client.check_health()
print(f"API Health: {health_status}")
def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
"""Resample audio to target sample rate"""
if orig_sr == target_sr:
return audio
if HAS_LIBROSA:
return librosa.resample(audio.astype(np.float64), orig_sr=orig_sr, target_sr=target_sr)
else:
# Simple linear interpolation fallback
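        # Note: np.interp applies no anti-alias low-pass filter, so this
        # fallback is lower quality than librosa's resampler.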
duration = len(audio) / orig_sr
new_length = int(duration * target_sr)
indices = np.linspace(0, len(audio) - 1, new_length)
return np.interp(indices, np.arange(len(audio)), audio.astype(np.float64))
def transcribe_stream(audio, language, audio_buffer, last_transcription, samples_processed):
"""
Process streaming audio from microphone.
Simplified approach:
- Accumulate ALL audio chunks
- When we have enough new audio, transcribe the ENTIRE recording
- Display the complete transcription (backend handles everything)
"""
# Initialize states
if audio_buffer is None:
audio_buffer = []
if last_transcription is None:
last_transcription = ""
if samples_processed is None:
samples_processed = 0
# Handle invalid audio input
if audio is None or isinstance(audio, int):
display = last_transcription if last_transcription else "🎀 Click microphone to start..."
return display, audio_buffer, last_transcription, samples_processed
# Gradio streaming returns (sample_rate, audio_data)
if not isinstance(audio, tuple) or len(audio) != 2:
display = last_transcription if last_transcription else "🎀 Listening..."
return display, audio_buffer, last_transcription, samples_processed
sample_rate, audio_data = audio
if not isinstance(audio_data, np.ndarray) or len(audio_data) == 0:
display = last_transcription if last_transcription else "🎀 Listening..."
return display, audio_buffer, last_transcription, samples_processed
# Convert stereo to mono if needed
if len(audio_data.shape) > 1:
audio_data = np.mean(audio_data, axis=1)
# Append this chunk to buffer
audio_buffer.append(audio_data.copy())
# Calculate total samples we have now
total_samples = sum(len(arr) for arr in audio_buffer)
total_duration = total_samples / sample_rate
# Calculate new audio since last transcription
new_samples = total_samples - samples_processed
new_duration = new_samples / sample_rate
# Only transcribe if we have enough NEW audio (to avoid too frequent API calls)
if new_duration < MIN_AUDIO_LENGTH:
display = last_transcription if last_transcription else f"🎀 Recording... ({total_duration:.1f}s)"
return display, audio_buffer, last_transcription, samples_processed
try:
# Concatenate ALL buffered audio
full_audio = np.concatenate(audio_buffer)
# Resample to 16kHz if needed
if sample_rate != TARGET_SAMPLE_RATE:
full_audio = resample_audio(full_audio, sample_rate, TARGET_SAMPLE_RATE)
# Normalize audio
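        # Peak-normalize to 0.95 to leave a little headroom against clipping.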
max_val = np.max(np.abs(full_audio))
if max_val > 0:
full_audio = full_audio / max_val * 0.95
# Get language code
lang_code = "hi" if language == "Hindi" else "en"
# Transcribe the ENTIRE audio
transcription = stt_client.transcribe_audio_data(
full_audio.astype(np.float32),
TARGET_SAMPLE_RATE,
lang_code
)
# Update state
if transcription.strip():
last_transcription = transcription
# Mark all current samples as processed
samples_processed = total_samples
display = last_transcription if last_transcription else f"🎀 Recording... ({total_duration:.1f}s)"
return display, audio_buffer, last_transcription, samples_processed
except Exception as e:
print(f"Processing error: {e}")
display = last_transcription if last_transcription else "🎀 Listening..."
return display, audio_buffer, last_transcription, samples_processed
def clear_transcription():
"""Clear all transcription state"""
return "🎀 Click microphone to start...", None, "", 0
def transcribe_file(audio_file, language):
"""Transcribe uploaded audio file"""
if audio_file is None:
return "⚠️ Please upload an audio file to transcribe."
lang_code = "hi" if language == "Hindi" else "en"
transcription = stt_client.transcribe_file(audio_file, lang_code)
text = (transcription or "").strip()
if not text or text.startswith("❌") or text.startswith("⏱"):
return text or "⚠️ No speech detectedβ€”try a clearer recording."
return text
def create_interface():
"""Create Gradio interface"""
with gr.Blocks(
theme=gr.themes.Base(
font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
),
css=".gradio-container {max-width: none !important;}",
) as demo:
gr.HTML("""
<div style="display: flex; align-items: center; gap: 10px;">
<img style="width: 50px; height: 50px; background-color: white; border-radius: 10%;"
src="https://storage.googleapis.com/desivocal-prod/desi-vocal/ringg.svg" alt="Logo">
<h1 style="margin: 0;">Ringg Parrot STT V1.0 🦜</h1>
</div>
""")
# Real-time streaming section
gr.Markdown("""
## 🎀 Real-time Transcription
Click the microphone to start recording. Transcription updates as you speak.
*The entire recording is transcribed each time, so text may refine as more context is added.*
""")
# States for streaming
audio_buffer = gr.State(None)
last_transcription = gr.State("")
samples_processed = gr.State(0)
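        # Per-session state: buffered raw chunks, the last transcript shown,
        # and the number of samples already transcribed.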
with gr.Row():
with gr.Column(scale=1):
stream_language = gr.Dropdown(
choices=["Hindi", "English"],
value="Hindi",
label="Language",
)
audio_input = gr.Audio(
sources=["microphone"],
type="numpy",
streaming=True,
label="🎀 Click to start recording",
)
                clear_btn = gr.Button("🗑️ Clear & Reset", variant="secondary")
with gr.Column(scale=2):
text_output = gr.Textbox(
label="Transcription",
value="🎀 Click microphone to start...",
lines=10,
interactive=False,
)
# Wire up streaming
audio_input.stream(
fn=transcribe_stream,
inputs=[audio_input, stream_language, audio_buffer, last_transcription, samples_processed],
outputs=[text_output, audio_buffer, last_transcription, samples_processed],
)
# Clear button
clear_btn.click(
fn=clear_transcription,
inputs=[],
outputs=[text_output, audio_buffer, last_transcription, samples_processed],
)
gr.Markdown("<br>")
# File upload section
gr.Markdown("""
## πŸ“ Upload an audio file for transcription
Supports WAV, MP3, FLAC, M4A, and more.
""")
with gr.Row():
with gr.Column(scale=1):
file_language = gr.Dropdown(
choices=["Hindi", "English"],
value="Hindi",
label="Language",
)
file_input = gr.Audio(
type="filepath",
sources=["upload"],
label="Upload Audio",
)
transcribe_btn = gr.Button("Transcribe File", variant="primary", size="lg")
with gr.Column(scale=2):
file_output = gr.Textbox(
label="Transcription",
lines=8,
interactive=False,
)
transcribe_btn.click(
fn=transcribe_file,
inputs=[file_input, file_language],
outputs=file_output,
)
gr.Markdown("""
<br>
## 🎯 Performance Benchmarks
        **Ringg Parrot STT V1** ranks **1st** among the models benchmarked below (lower WER is better).
""")
with gr.Row():
gr.DataFrame(
value=[
["Parrot STT (Ringg AI)", "15.00%", "15.92%"],
["IndicWav2Vec ", "19.35%", "20.91%"],
["VakyanSh Wav2Vec2", "22.73%", "24.78%"],
],
headers=["Model", "Median WER ↓", "Mean WER ↓"],
datatype=["str", "str", "str"],
row_count=3,
col_count=(3, "fixed"),
interactive=False,
)
gr.Markdown("""
## πŸ™ Acknowledgements
- Built with [NVIDIA NeMo](https://github.com/NVIDIA/NeMo) models
""")
return demo
if __name__ == "__main__":
print("🌐 Launching Ringg Parrot STT V1 Gradio Interface...")
print(f"Backend API: {API_ENDPOINT}")
demo = create_interface()
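    # Handle at most 2 transcriptions concurrently; queue up to 20 more.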
demo.queue(default_concurrency_limit=2, max_size=20)
demo.launch(
share=False,
server_name="0.0.0.0",
server_port=7860,
debug=True,
show_api=False,
)