import gradio as gr
from speechbrain.inference.VAD import VAD
import torch
import torchaudio
import tempfile
import os
# Initialize the VAD model
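# (the pretrained CRDNN is fetched from the Hugging Face Hub on first run
# and, per its model card, expects 16 kHz single-channel audio)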
vad = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty")


def perform_vad(audio_input):
    """
    Take the audio tuple from Gradio, save it to a temporary file,
    run VAD, and then clean up.
    """
    if audio_input is None:
        return "Please upload an audio file.", None

    sample_rate, waveform_data = audio_input
    # --- START OF FIX ---
    # Convert the numpy array to a torch tensor
    waveform_tensor = torch.from_numpy(waveform_data)
    # Ensure the tensor is 2D: [channels, num_samples]
    if waveform_tensor.ndim == 1:
        # Mono file: unsqueeze to add the channel dimension
        waveform_tensor = waveform_tensor.unsqueeze(0)
    else:
        # Gradio returns multi-channel audio as [num_samples, channels];
        # torchaudio.save expects [channels, num_samples], so transpose.
        waveform_tensor = waveform_tensor.t()
    # The tensor is now 2D, which is what torchaudio.save expects;
    # no extra batch dimension is added.
    # --- END OF FIX ---
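
    # Note: the LibriParty CRDNN VAD is trained on 16 kHz audio, so convert
    # to float and resample when needed. Gradio usually delivers int16 PCM,
    # hence the normalization to [-1, 1] before resampling.
    if waveform_tensor.dtype == torch.int16:
        waveform_tensor = waveform_tensor.float() / 32768.0
    elif waveform_tensor.dtype != torch.float32:
        waveform_tensor = waveform_tensor.float()
    if sample_rate != 16000:
        waveform_tensor = torchaudio.functional.resample(
            waveform_tensor, orig_freq=sample_rate, new_freq=16000
        )
        sample_rate = 16000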
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        temp_audio_path = tmpfile.name
    torchaudio.save(temp_audio_path, waveform_tensor, sample_rate, format="wav")

    try:
        # SpeechBrain's VAD.get_speech_segments expects a path to an audio
        # file, not a raw tensor, which is why the waveform was written to
        # the temporary WAV above.
        speech_segments = vad.get_speech_segments(temp_audio_path)
        if speech_segments.shape[0] == 0:
            return "No speech detected in the audio.", None
        output_text = "Detected Speech Segments (startTime, endTime in seconds):\n"
        output_json = []
        for segment in speech_segments:
            start_time = round(segment[0].item(), 3)
            end_time = round(segment[1].item(), 3)
            output_text += f"- [{start_time}, {end_time}]\n"
            output_json.append({"startTime": start_time, "endTime": end_time})
        return output_text, output_json
    except Exception as e:
        return f"An error occurred: {str(e)}", None
    finally:
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# SpeechBrain VAD Demo")
    gr.Markdown("Upload an audio file to detect speech segments...")
    audio_input = gr.Audio(type="numpy", label="Upload Your Audio")
    process_button = gr.Button("Detect Speech")
    with gr.Row():
        text_output = gr.Textbox(label="Detected Timestamps")
        json_output = gr.JSON(label="JSON Output for Backend")
    process_button.click(
        fn=perform_vad,
        inputs=audio_input,
        outputs=[text_output, json_output],
    )

demo.launch()