"""
Ringg Parrot STT V1 🦜 - Hugging Face Space (Frontend)

Real-time streaming transcription using Gradio's audio streaming.
"""

import os
import tempfile
from pathlib import Path

import gradio as gr
import requests
import numpy as np
import soundfile as sf
from dotenv import load_dotenv

try:
    import librosa

    HAS_LIBROSA = True
except ImportError:
    HAS_LIBROSA = False
    print("⚠️ librosa not installed. Install with: pip install librosa")

load_dotenv()

# Configuration. Override the backend endpoint with the STT_API_ENDPOINT
# environment variable.
API_ENDPOINT = os.environ.get("STT_API_ENDPOINT", "http://localhost:7864")
TARGET_SAMPLE_RATE = 16000

# Minimum seconds of NEW audio to accumulate before re-transcribing the stream.
MIN_AUDIO_LENGTH = 0.4
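
# At MIN_AUDIO_LENGTH = 0.4 s of new audio per request, the streaming handler
# re-transcribes at most ~2.5 times per second; raising it trades UI latency
# for fewer backend calls.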


class RinggSTTClient:
    """Client for the Ringg Parrot STT API."""

    def __init__(self, api_endpoint: str):
        self.api_endpoint = api_endpoint.rstrip("/")
        self.session = requests.Session()
        self.session.headers.update({"User-Agent": "RinggSTT-HF-Space/1.0"})

    def check_health(self) -> dict:
        try:
            response = self.session.get(f"{self.api_endpoint}/health", timeout=5)
            if response.status_code == 200:
                return {"status": "healthy", "message": "✅ API is online"}
            return {"status": "error", "message": f"❌ API returned status {response.status_code}"}
        except Exception as e:
            return {"status": "error", "message": f"❌ Error: {str(e)}"}

    def transcribe_audio_data(self, audio_data: np.ndarray, sample_rate: int, language: str = "hi") -> str:
        """Transcribe audio data (numpy array) via the multipart upload API."""
        try:
            # Write the buffer to a temporary WAV so it can be sent as multipart form data.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                temp_path = f.name
                sf.write(temp_path, audio_data, sample_rate)

            try:
                with open(temp_path, "rb") as f:
                    files = {"file": ("audio.wav", f, "audio/wav")}
                    data = {"language": language, "punctuate": "false"}
                    response = self.session.post(
                        f"{self.api_endpoint}/v1/audio/transcriptions",
                        files=files,
                        data=data,
                        timeout=30,
                    )

                print(
                    f"[transcribe_audio_data] status={response.status_code} "
                    f"body={response.text[:500]}"
                )

                if response.status_code == 200:
                    result = response.json()
                    # Multi-channel responses use transcription_channel_N keys.
                    if "transcription_channel_0" in result:
                        return result.get("transcription_channel_0", "")
                    return result.get("transcription", "")
                return ""
            finally:
                # Always remove the temp file, even if the request failed.
                os.unlink(temp_path)

        except Exception as e:
            print(f"Transcription error: {e}")
            return ""

    def transcribe_file(self, audio_file_path: str, language: str = "hi") -> str:
        """Transcribe an audio file via the multipart upload API."""
        try:
            with open(audio_file_path, "rb") as f:
                files = {"file": (Path(audio_file_path).name, f)}
                data = {"language": language, "punctuate": "false"}
                response = self.session.post(
                    f"{self.api_endpoint}/v1/audio/transcriptions",
                    files=files,
                    data=data,
                    timeout=120,
                )

            if response.status_code == 200:
                result = response.json()
                if "transcription_channel_0" in result:
                    # Stereo uploads come back with one transcription per channel.
                    transcripts = []
                    if result.get("transcription_channel_0"):
                        transcripts.append(result["transcription_channel_0"])
                    if result.get("transcription_channel_1"):
                        transcripts.append(f"\n[Channel 2]: {result['transcription_channel_1']}")
                    return "".join(transcripts) if transcripts else "No speech detected"
                return result.get("transcription", "No transcription received")
            return f"❌ API Error: {response.status_code}"

        except Exception as e:
            return f"❌ Error: {str(e)}"
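
# Minimal standalone usage sketch (assumes the backend above is reachable):
#     client = RinggSTTClient("http://localhost:7864")
#     print(client.check_health())
#     print(client.transcribe_file("sample.wav", language="hi"))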
print(f"π Connecting to STT API: {API_ENDPOINT}") |
|
|
stt_client = RinggSTTClient(API_ENDPOINT) |
|
|
health_status = stt_client.check_health() |
|
|
print(f"API Health: {health_status}") |
|
|
|
|
|
|
|
|
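# NOTE: a failed health check is only logged; the interface still launches,
# so the backend can come online later.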


def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample audio to the target sample rate."""
    if orig_sr == target_sr:
        return audio

    if HAS_LIBROSA:
        return librosa.resample(audio.astype(np.float64), orig_sr=orig_sr, target_sr=target_sr)
    else:
        # Fallback: plain linear interpolation. Cruder than librosa's
        # band-limited resampling, but serviceable for speech.
        duration = len(audio) / orig_sr
        new_length = int(duration * target_sr)
        indices = np.linspace(0, len(audio) - 1, new_length)
        return np.interp(indices, np.arange(len(audio)), audio.astype(np.float64))
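
# Example: a 1-second 48 kHz microphone chunk (48,000 samples) becomes 16,000
# samples at TARGET_SAMPLE_RATE; the np.interp fallback simply picks those
# 16,000 points along the original waveform by linear interpolation.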


def transcribe_stream(audio, language, audio_buffer, last_transcription, samples_processed):
    """
    Process streaming audio from the microphone.

    Simplified approach:
    - Accumulate ALL audio chunks.
    - Once enough new audio has arrived, transcribe the ENTIRE recording.
    - Display the complete transcription (the backend handles everything).

    Because the whole buffer is re-sent on every update, request size grows
    with recording length; earlier text may be refined as context accumulates.
    """
    # Initialize per-session state on the first chunk.
    if audio_buffer is None:
        audio_buffer = []
    if last_transcription is None:
        last_transcription = ""
    if samples_processed is None:
        samples_processed = 0

    # Ignore empty events (None or a bare int) that can arrive before real audio.
    if audio is None or isinstance(audio, int):
        display = last_transcription if last_transcription else "🎤 Click microphone to start..."
        return display, audio_buffer, last_transcription, samples_processed

    if not isinstance(audio, tuple) or len(audio) != 2:
        display = last_transcription if last_transcription else "🎤 Listening..."
        return display, audio_buffer, last_transcription, samples_processed

    sample_rate, audio_data = audio

    if not isinstance(audio_data, np.ndarray) or len(audio_data) == 0:
        display = last_transcription if last_transcription else "🎤 Listening..."
        return display, audio_buffer, last_transcription, samples_processed

    # Downmix stereo to mono.
    if len(audio_data.shape) > 1:
        audio_data = np.mean(audio_data, axis=1)

    audio_buffer.append(audio_data.copy())

    total_samples = sum(len(arr) for arr in audio_buffer)
    total_duration = total_samples / sample_rate

    # Throttle: only call the API once at least MIN_AUDIO_LENGTH seconds of
    # new audio have accumulated since the last transcription.
    new_samples = total_samples - samples_processed
    new_duration = new_samples / sample_rate

    if new_duration < MIN_AUDIO_LENGTH:
        display = last_transcription if last_transcription else f"🎤 Recording... ({total_duration:.1f}s)"
        return display, audio_buffer, last_transcription, samples_processed

    try:
        full_audio = np.concatenate(audio_buffer)

        if sample_rate != TARGET_SAMPLE_RATE:
            full_audio = resample_audio(full_audio, sample_rate, TARGET_SAMPLE_RATE)

        # Peak-normalize to 95% of full scale to avoid clipping.
        max_val = np.max(np.abs(full_audio))
        if max_val > 0:
            full_audio = full_audio / max_val * 0.95

        lang_code = "hi" if language == "Hindi" else "en"

        transcription = stt_client.transcribe_audio_data(
            full_audio.astype(np.float32),
            TARGET_SAMPLE_RATE,
            lang_code,
        )

        # Keep the previous text if this pass returned nothing.
        if transcription.strip():
            last_transcription = transcription

        samples_processed = total_samples

        display = last_transcription if last_transcription else f"🎤 Recording... ({total_duration:.1f}s)"
        return display, audio_buffer, last_transcription, samples_processed

    except Exception as e:
        print(f"Processing error: {e}")
        display = last_transcription if last_transcription else "🎤 Listening..."
        return display, audio_buffer, last_transcription, samples_processed


def clear_transcription():
    """Clear all transcription state."""
    return "🎤 Click microphone to start...", None, "", 0
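

# NOTE: distinct from RinggSTTClient.transcribe_file; this module-level
# wrapper is what the Gradio click handler below binds to.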
def transcribe_file(audio_file, language):
    """Transcribe an uploaded audio file."""
    if audio_file is None:
        return "⚠️ Please upload an audio file to transcribe."

    lang_code = "hi" if language == "Hindi" else "en"
    transcription = stt_client.transcribe_file(audio_file, lang_code)
    text = (transcription or "").strip()

    # Pass backend error messages (❌ / ⏱ prefixes) through unchanged.
    if not text or text.startswith("❌") or text.startswith("⏱"):
        return text or "⚠️ No speech detected. Try a clearer recording."

    return text


def create_interface():
    """Create the Gradio interface."""
    with gr.Blocks(
        theme=gr.themes.Base(
            font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
        ),
        css=".gradio-container {max-width: none !important;}",
    ) as demo:
        gr.HTML("""
        <div style="display: flex; align-items: center; gap: 10px;">
            <img style="width: 50px; height: 50px; background-color: white; border-radius: 10%;"
                 src="https://storage.googleapis.com/desivocal-prod/desi-vocal/ringg.svg" alt="Logo">
            <h1 style="margin: 0;">Ringg Parrot STT V1.0 🦜</h1>
        </div>
        """)

        gr.Markdown("""
        ## 🎤 Real-time Transcription
        Click the microphone to start recording. The transcription updates as you speak.

        *The entire recording is transcribed each time, so earlier text may be refined as more context is added.*
        """)

        # Per-session streaming state.
        audio_buffer = gr.State(None)
        last_transcription = gr.State("")
        samples_processed = gr.State(0)

        with gr.Row():
            with gr.Column(scale=1):
                stream_language = gr.Dropdown(
                    choices=["Hindi", "English"],
                    value="Hindi",
                    label="Language",
                )
                audio_input = gr.Audio(
                    sources=["microphone"],
                    type="numpy",
                    streaming=True,
                    label="🎤 Click to start recording",
                )
                clear_btn = gr.Button("🗑️ Clear & Reset", variant="secondary")

            with gr.Column(scale=2):
                text_output = gr.Textbox(
                    label="Transcription",
                    value="🎤 Click microphone to start...",
                    lines=10,
                    interactive=False,
                )

        audio_input.stream(
            fn=transcribe_stream,
            inputs=[audio_input, stream_language, audio_buffer, last_transcription, samples_processed],
            outputs=[text_output, audio_buffer, last_transcription, samples_processed],
        )

        clear_btn.click(
            fn=clear_transcription,
            inputs=[],
            outputs=[text_output, audio_buffer, last_transcription, samples_processed],
        )

        gr.Markdown("<br>")

        gr.Markdown("""
        ## 📁 Upload an audio file for transcription
        Supports WAV, MP3, FLAC, M4A, and more.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                file_language = gr.Dropdown(
                    choices=["Hindi", "English"],
                    value="Hindi",
                    label="Language",
                )
                file_input = gr.Audio(
                    type="filepath",
                    sources=["upload"],
                    label="Upload Audio",
                )
                transcribe_btn = gr.Button("Transcribe File", variant="primary", size="lg")

            with gr.Column(scale=2):
                file_output = gr.Textbox(
                    label="Transcription",
                    lines=8,
                    interactive=False,
                )

        transcribe_btn.click(
            fn=transcribe_file,
            inputs=[file_input, file_language],
            outputs=file_output,
        )

        gr.Markdown("""
        <br>

        ## 🎯 Performance Benchmarks
        **Ringg Parrot STT V1** ranks **1st** among top models.
        """)

        with gr.Row():
            gr.DataFrame(
                value=[
                    ["Parrot STT (Ringg AI)", "15.00%", "15.92%"],
                    ["IndicWav2Vec", "19.35%", "20.91%"],
                    ["Vakyansh Wav2Vec2", "22.73%", "24.78%"],
                ],
                headers=["Model", "Median WER ↓", "Mean WER ↓"],
                datatype=["str", "str", "str"],
                row_count=3,
                col_count=(3, "fixed"),
                interactive=False,
            )

        gr.Markdown("""
        ## 🙏 Acknowledgements
        - Built with [NVIDIA NeMo](https://github.com/NVIDIA/NeMo) models
        """)

    return demo


if __name__ == "__main__":
    print("🚀 Launching Ringg Parrot STT V1 Gradio Interface...")
    print(f"Backend API: {API_ENDPOINT}")
    demo = create_interface()
    # Queue limits concurrent transcription jobs so the backend isn't flooded.
    demo.queue(default_concurrency_limit=2, max_size=20)
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
        show_api=False,
    )