import gradio as gr
from speechbrain.inference.VAD import VAD
import torch
import torchaudio
import tempfile
import os

# Initialize the VAD model
vad = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty")
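
# (Assumption, not part of the original app) SpeechBrain's from_hparams also accepts
# a savedir for caching and run_opts for device placement; a sketch for running the
# model on GPU when one is available could look like:
#
# vad = VAD.from_hparams(
#     source="speechbrain/vad-crdnn-libriparty",
#     savedir="pretrained_models/vad-crdnn-libriparty",
#     run_opts={"device": "cuda" if torch.cuda.is_available() else "cpu"},
# )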


def perform_vad(audio_input):
    """
    Take an audio tuple from Gradio, save it to a temporary file,
    run VAD on it, and clean up afterwards.
    """
    if audio_input is None:
        return "Please upload an audio file.", None

    sample_rate, waveform_data = audio_input
    # --- START OF FIX ---
    # Convert the numpy array to a torch tensor
    waveform_tensor = torch.from_numpy(waveform_data)

    # Ensure the tensor is 2D: [channels, num_samples]
    if waveform_tensor.ndim == 1:
        # Mono audio: unsqueeze to add the channel dimension
        waveform_tensor = waveform_tensor.unsqueeze(0)
    else:
        # Gradio returns multi-channel audio as [num_samples, channels];
        # torchaudio.save expects [channels, num_samples], so transpose.
        waveform_tensor = waveform_tensor.t()

    # We now have a 2D tensor, which is what torchaudio.save expects.
    # We no longer add the extra batch dimension.
    # --- END OF FIX ---
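
    # (Assumption, added as a sketch) The vad-crdnn-libriparty model is trained on
    # 16 kHz audio, while browser recordings often arrive at 44.1 or 48 kHz.
    # Converting to float in [-1, 1] and resampling before writing the temp file
    # keeps the input in the range and rate the model expects.
    if not waveform_tensor.is_floating_point():
        waveform_tensor = waveform_tensor.float() / torch.iinfo(waveform_tensor.dtype).max
    target_rate = 16000
    if sample_rate != target_rate:
        waveform_tensor = torchaudio.functional.resample(
            waveform_tensor, orig_freq=sample_rate, new_freq=target_rate
        )
        sample_rate = target_rate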

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        temp_audio_path = tmpfile.name
        torchaudio.save(temp_audio_path, waveform_tensor, sample_rate, format="wav")

    try:
        # get_speech_segments expects a path to an audio file, so pass the temp file
        # we just wrote rather than the raw tensor.
        speech_segments = vad.get_speech_segments(temp_audio_path)

        if speech_segments.shape[0] == 0:
            return "No speech detected in the audio.", None

        output_text = "Detected Speech Segments (startTime, endTime in seconds):\n"
        output_json = []
        for segment in speech_segments:
            start_time = round(segment[0].item(), 3)
            end_time = round(segment[1].item(), 3)
            output_text += f"- [{start_time}, {end_time}]\n"
            output_json.append({"startTime": start_time, "endTime": end_time})

        return output_text, output_json
    except Exception as e:
        return f"An error occurred: {str(e)}", None
    finally:
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# SpeechBrain VAD Demo")
    gr.Markdown("Upload an audio file to detect speech segments...")

    audio_input = gr.Audio(type="numpy", label="Upload Your Audio")
    process_button = gr.Button("Detect Speech")

    with gr.Row():
        text_output = gr.Textbox(label="Detected Timestamps")
        json_output = gr.JSON(label="JSON Output for Backend")

    process_button.click(
        fn=perform_vad,
        inputs=audio_input,
        outputs=[text_output, json_output]
    )

demo.launch()
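
# Example of the JSON panel's output for a clip with two spoken regions
# (illustrative values only, not produced by running this app):
# [
#   {"startTime": 0.512, "endTime": 2.048},
#   {"startTime": 3.2, "endTime": 5.76}
# ]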