import gradio as gr
import numpy as np
import librosa
import tensorflow as tf
from scipy.fftpack import dct
import os

# DSCNN model configuration
MODEL_PATH = "ds_cnn_l_quantized.tflite"

# Keywords based on the Speech Commands dataset (12 classes)
KEYWORDS = [
    "silence", "unknown", "yes", "no", "up", "down",
    "left", "right", "on", "off", "stop", "go"
]
print("Loading DSCNN TensorFlow Lite model...")

try:
    # Load the TFLite model
    interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
    interpreter.allocate_tensors()

    # Get input and output details
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print("✅ DSCNN model loaded successfully!")
    print(f"Input shape: {input_details[0]['shape']}")
    print(f"Output shape: {output_details[0]['shape']}")
    print(f"Input dtype: {input_details[0]['dtype']}")
    print(f"Output dtype: {output_details[0]['dtype']}")
except Exception as e:
    print(f"❌ Error loading DSCNN model: {e}")
    interpreter = None
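
# The fixed scaling used further down assumes roughly symmetric INT8
# quantization. A minimal sketch (not part of the original app) that reads the
# quantization parameters the TFLite model itself reports, assuming a fully
# quantized model where 'quantization' holds a non-zero (scale, zero_point):
if interpreter is not None:
    input_scale, input_zero_point = input_details[0]['quantization']
    output_scale, output_zero_point = output_details[0]['quantization']
    print(f"Input quantization: scale={input_scale}, zero_point={input_zero_point}")
    print(f"Output quantization: scale={output_scale}, zero_point={output_zero_point}")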

def extract_mfcc_features(audio_path, target_length=490):
    """
    Extract MFCC features following the original DSCNN paper,
    "Hello Edge: Keyword Spotting on Microcontrollers".

    Parameters from the paper:
    - 40 ms frame length (640 samples at 16 kHz)
    - 20 ms stride (320 samples at 16 kHz)
    - 10 MFCC features per frame
    - 49 frames total for 1 second → 49 × 10 = 490 features
    """
    try:
        # Load audio and resample to 16 kHz (standard for speech commands)
        audio, sr = librosa.load(audio_path, sr=16000, mono=True)

        # Ensure audio is exactly 1 second (16000 samples)
        if len(audio) < 16000:
            # Pad with zeros
            audio = np.pad(audio, (0, 16000 - len(audio)), 'constant')
        else:
            # Truncate to 1 second
            audio = audio[:16000]

        # DSCNN paper parameters
        frame_length = 640   # 40 ms at 16 kHz
        hop_length = 320     # 20 ms at 16 kHz (50% overlap)
        n_mfcc = 10          # 10 MFCC features as in the paper
        n_fft = 1024         # FFT size
        n_mels = 40          # Mel filter bank size (before DCT)

        # Extract mel spectrogram
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=frame_length,
            n_mels=n_mels,
            fmin=20,
            fmax=4000
        )

        # Convert to log scale
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        # Apply DCT to get MFCC features (keep only the first 10 coefficients)
        mfcc_features = dct(log_mel_spec, axis=0, norm='ortho')[:n_mfcc, :]
        # The paper's framing gives 49 frames, but librosa's default
        # center=True padding yields 51 frames here, so the flattened
        # vector is truncated to target_length below.
        print(f"MFCC shape before flattening: {mfcc_features.shape}")

        # Flatten to a 1D array (target: 10 × 49 = 490 features)
        features_flat = mfcc_features.flatten()

        # Ensure exactly 490 features
        if len(features_flat) > target_length:
            features_flat = features_flat[:target_length]
        elif len(features_flat) < target_length:
            features_flat = np.pad(features_flat, (0, target_length - len(features_flat)), 'constant')

        print(f"Features length after processing: {len(features_flat)}")

        # Normalize features (zero mean, unit variance)
        features_flat = (features_flat - np.mean(features_flat)) / (np.std(features_flat) + 1e-8)

        # Quantize to the INT8 range for the DSCNN model. The fixed *127 scale
        # only approximates the training distribution; see the
        # quantize_with_model_params sketch after this function for a version
        # that uses the model's own quantization parameters.
        features_int8 = np.clip(features_flat * 127.0, -128, 127).astype(np.int8)

        return features_int8.reshape(1, -1)  # Shape: (1, 490)
    except Exception as e:
        raise Exception(f"Error extracting MFCC features: {str(e)}")
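
# A minimal sketch (not used by the pipeline above) of quantizing float
# features with the interpreter's own input parameters instead of the fixed
# *127 scale. The helper name quantize_with_model_params is hypothetical, and
# it assumes the model reports a non-zero input scale.
def quantize_with_model_params(features_float):
    """Quantize float features using the model's input scale/zero-point."""
    scale, zero_point = input_details[0]['quantization']
    if scale == 0:
        # A zero scale means the model expects float input; pass through.
        return features_float.astype(np.float32)
    quantized = np.round(features_float / scale) + zero_point
    return np.clip(quantized, -128, 127).astype(np.int8)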

def classify_audio(audio_input):
    """
    Classify the input audio using the DSCNN model and return keyword predictions.
    """
    if audio_input is None:
        return "Please upload an audio file or record audio."

    if interpreter is None:
        return "❌ DSCNN model not loaded. Please refresh the page and try again."

    try:
        # Extract MFCC features
        features = extract_mfcc_features(audio_input)
        print(f"Input features shape: {features.shape}")
        print(f"Input features dtype: {features.dtype}")
        print(f"Input features range: [{features.min()}, {features.max()}]")

        # Set input tensor
        interpreter.set_tensor(input_details[0]['index'], features)

        # Run inference
        interpreter.invoke()

        # Get output
        output_data = interpreter.get_tensor(output_details[0]['index'])
        print(f"Raw output shape: {output_data.shape}")
        print(f"Raw output dtype: {output_data.dtype}")
        print(f"Raw output range: [{output_data.min()}, {output_data.max()}]")
        # Handle quantized INT8 output. The fixed 1/128 scale assumes symmetric
        # quantization; the dequantize_with_model_params sketch after this
        # function shows how to use the scale/zero-point the model reports.
        if output_data.dtype == np.int8:
            logits = output_data.astype(np.float32) / 128.0
        else:
            logits = output_data.astype(np.float32)
        # Apply softmax to get probabilities
        exp_logits = np.exp(logits - np.max(logits))
        probabilities = exp_logits / np.sum(exp_logits)

        # Get predictions with confidence scores
        predictions = []
        for i, prob in enumerate(probabilities[0]):
            predictions.append({
                'label': KEYWORDS[i],
                'score': float(prob)
            })

        # Sort by confidence score
        predictions = sorted(predictions, key=lambda x: x['score'], reverse=True)

        # Format results
        results = []
        for i, pred in enumerate(predictions[:5]):
            confidence = pred['score'] * 100
            label = pred['label']
            indicator = "🎯" if i == 0 else "  "
            results.append(f"{indicator} {i+1}. {label}: {confidence:.1f}%")

        return "\n".join(results)
    except Exception as e:
        error_msg = str(e)
        if "mfcc" in error_msg.lower() or "librosa" in error_msg.lower():
            return "❌ Audio processing error. Please ensure your audio file is in a supported format (WAV, MP3, etc.)"
        elif "model" in error_msg.lower() or "tensor" in error_msg.lower():
            return "❌ Model inference error. Please try recording a clear 1-second audio clip."
        else:
            return f"❌ Error processing audio: {error_msg}\n\nTip: Try recording a clear 1-second word like 'yes' or 'stop'."
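
# A minimal sketch (not called by classify_audio above) of dequantizing INT8
# logits with the quantization parameters the TFLite model actually reports,
# rather than the assumed 1/128 scale. The helper name
# dequantize_with_model_params is hypothetical.
def dequantize_with_model_params(output_data):
    """Map raw INT8 output back to float logits via the output scale/zero-point."""
    scale, zero_point = output_details[0]['quantization']
    if output_data.dtype == np.int8 and scale != 0:
        return scale * (output_data.astype(np.float32) - zero_point)
    return output_data.astype(np.float32)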

# Create the Gradio interface
demo = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record or Upload Audio"),
    outputs=gr.Textbox(label="DSCNN Keyword Predictions", lines=8),
    title="🎤 DSCNN Wake Word Detection Demo",
    description="""
**Advanced wake word detection using a Depthwise Separable CNN (DSCNN)**

This demo uses a quantized DSCNN model optimized for edge deployment. Upload audio or record directly to test keyword recognition.

**Supported Keywords:** yes, no, up, down, left, right, on, off, stop, go, silence, unknown

**Model Details:**
- Architecture: Depthwise Separable CNN (DSCNN)
- Quantization: INT8 (504 KB model size)
- Accuracy: 94.5% on Google Speech Commands
- Input: MFCC features (1×490)
- Optimized for: ARM Cortex-M, embedded systems
"""
)
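
# Optional sanity check (not part of the original app): run the pipeline once
# on a second of generated low-level noise before launching. The file name
# "smoke_test.wav" is arbitrary, and soundfile is assumed to be available
# since it ships as a librosa dependency. Call run_smoke_test() from a local
# Python shell to verify feature extraction and inference end to end.
def run_smoke_test():
    import soundfile as sf
    test_audio = np.random.uniform(-0.1, 0.1, 16000).astype(np.float32)
    sf.write("smoke_test.wav", test_audio, 16000)
    print(classify_audio("smoke_test.wav"))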

# Launch the demo
if __name__ == "__main__":
    demo.launch(share=True)