import gradio as gr
from speechbrain.inference.VAD import VAD
import torch
import torchaudio
import tempfile
import os

# Initialize the VAD model (downloads the pretrained CRDNN model on first run)
vad = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty")


def perform_vad(audio_input):
    """
    Takes the audio tuple from Gradio, saves it to a temporary WAV file,
    runs VAD on that file, and cleans up afterwards.
    """
    if audio_input is None:
        return "Please upload an audio file.", None

    sample_rate, waveform_data = audio_input

    # --- START OF FIX ---
    # Convert the numpy array to a torch tensor.
    waveform_tensor = torch.from_numpy(waveform_data)

    # Ensure the tensor is 2D: [channels, num_samples].
    if waveform_tensor.ndim == 1:
        # Mono file: unsqueeze to add the channel dimension.
        waveform_tensor = waveform_tensor.unsqueeze(0)
    else:
        # Gradio returns multichannel audio as [num_samples, num_channels],
        # while torchaudio.save expects [num_channels, num_samples].
        waveform_tensor = waveform_tensor.t()

    # We now have a 2D tensor, which is what torchaudio.save expects.
    # We no longer add the extra batch dimension.
    # --- END OF FIX ---

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        temp_audio_path = tmpfile.name
    # Write after the handle is closed so the path is safe to reopen on all platforms.
    torchaudio.save(temp_audio_path, waveform_tensor, sample_rate, format="wav")

    try:
        # get_speech_segments expects a path to an audio file, not a tensor --
        # this is exactly why we saved the temporary WAV above.
        speech_segments = vad.get_speech_segments(temp_audio_path)

        if speech_segments.shape[0] == 0:
            return "No speech detected in the audio.", None

        output_text = "Detected Speech Segments (startTime, endTime in seconds):\n"
        output_json = []
        for segment in speech_segments:
            start_time = round(segment[0].item(), 3)
            end_time = round(segment[1].item(), 3)
            output_text += f"- [{start_time}, {end_time}]\n"
            output_json.append({"startTime": start_time, "endTime": end_time})

        return output_text, output_json

    except Exception as e:
        return f"An error occurred: {str(e)}", None
    finally:
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# SpeechBrain VAD Demo")
    gr.Markdown("Upload an audio file to detect speech segments...")

    audio_input = gr.Audio(type="numpy", label="Upload Your Audio")
    process_button = gr.Button("Detect Speech")

    with gr.Row():
        text_output = gr.Textbox(label="Detected Timestamps")
        json_output = gr.JSON(label="JSON Output for Backend")

    process_button.click(
        fn=perform_vad,
        inputs=audio_input,
        outputs=[text_output, json_output],
    )

demo.launch()
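
One caveat: the CRDNN LibriParty VAD model expects 16 kHz mono recordings, while Gradio hands you whatever sample rate the uploaded file happens to use. If your inputs may arrive at other rates, it is safer to normalize the audio before saving the temporary WAV. A minimal sketch, assuming int16 PCM input from Gradio (the prepare_for_vad helper name is my own, not part of SpeechBrain or Gradio):

import torch
import torchaudio

TARGET_SR = 16000  # the CRDNN LibriParty VAD model operates on 16 kHz audio


def prepare_for_vad(waveform_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """Downmix to mono and resample to 16 kHz before running VAD."""
    # Resampling needs a float tensor; Gradio typically delivers int16 PCM.
    if not waveform_tensor.is_floating_point():
        waveform_tensor = waveform_tensor.float() / 32768.0  # assumes int16 input
    # Average channels down to mono: [channels, samples] -> [1, samples].
    if waveform_tensor.shape[0] > 1:
        waveform_tensor = waveform_tensor.mean(dim=0, keepdim=True)
    # Resample only when the rate differs from the model's expected rate.
    if sample_rate != TARGET_SR:
        waveform_tensor = torchaudio.functional.resample(
            waveform_tensor, orig_freq=sample_rate, new_freq=TARGET_SR
        )
    return waveform_tensor

Inside perform_vad you would call this right before torchaudio.save and write the file with TARGET_SR instead of the original sample_rate, so the file handed to get_speech_segments always matches what the model was trained on.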