import gradio as gr
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, WhisperProcessor, WhisperForConditionalGeneration
import librosa
import numpy as np
import os
import gc
import re


class MultiModelASRInterface:
    def __init__(self):
        """Initialize the ASR interface with model selection."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Available models with descriptions
        self.available_models = {
            "facebook/wav2vec2-base-960h": {
                "name": "Wav2Vec2 Base (960h)",
                "description": "Fast, good accuracy, ~1GB memory",
                "size": "~1GB",
                "type": "wav2vec2"
            },
            "facebook/wav2vec2-large-960h": {
                "name": "Wav2Vec2 Large (960h)",
                "description": "High accuracy, ~3GB memory",
                "size": "~3GB",
                "type": "wav2vec2"
            },
            "facebook/wav2vec2-base-100h": {
                "name": "Wav2Vec2 Base (100h)",
                "description": "Fast, smaller model, ~300MB memory",
                "size": "~300MB",
                "type": "wav2vec2"
            },
            "openai/whisper-large-v3-turbo": {
                "name": "Whisper Large V3 Turbo",
                "description": "State-of-the-art accuracy, multilingual, ~5GB memory",
                "size": "~5GB",
                "type": "whisper"
            }
        }

        # Current loaded model
        self.current_model_name = None
        self.processor = None
        self.model = None
        self.model_type = None

        print("Multi-model ASR interface initialized!")
        print("Available models:")
        for model_id, info in self.available_models.items():
            print(f"  - {model_id}: {info['description']}")

    def load_model(self, model_name):
        """
        Load a specific model and clear previous one.

        Args:
            model_name: Name of the model to load

        Returns:
            str: Status message
        """
        try:
            # Clear previous model if different
            if self.current_model_name != model_name:
                self.clear_model()
                print(f"Loading model: {model_name}")

                # Load new model based on type
                model_type = self.available_models[model_name]["type"]
                self.model_type = model_type

                if model_type == "wav2vec2":
                    self.processor = Wav2Vec2Processor.from_pretrained(model_name)
                    self.model = Wav2Vec2ForCTC.from_pretrained(model_name).to(self.device)
                elif model_type == "whisper":
                    self.processor = WhisperProcessor.from_pretrained(model_name)
                    self.model = WhisperForConditionalGeneration.from_pretrained(model_name).to(self.device)

                self.current_model_name = model_name
                print(f"✅ Model loaded successfully: {model_name}")
                return f"✅ Model loaded: {self.available_models[model_name]['name']}"
            else:
                return f"✅ Model already loaded: {self.available_models[model_name]['name']}"
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")
            return f"❌ Error loading model: {str(e)}"

    def clear_model(self):
        """Clear the currently loaded model to free memory."""
        if self.model is not None:
            print(f"Clearing model: {self.current_model_name}")

            # Move to CPU first to free GPU memory
            if self.device == "cuda":
                self.model = self.model.cpu()

            # Delete model and processor
            del self.model
            del self.processor

            # Clear cache
            if self.device == "cuda":
                torch.cuda.empty_cache()

            # Force garbage collection
            gc.collect()

            self.model = None
            self.processor = None
            self.current_model_name = None
            print("Model cleared from memory")

    def preprocess_audio(self, audio):
        """
        Preprocess audio for ASR models.

        Args:
            audio: Audio data (numpy array or file path)

        Returns:
            tuple: (processed_audio, sample_rate)
        """
        try:
            if isinstance(audio, tuple):
                # Direct recording from microphone
                sample_rate, audio_data = audio
                # Ensure floating-point samples (microphone recordings may arrive as int16)
                audio_data = audio_data.astype(np.float32)
                print(f"Processing recorded audio: sample_rate={sample_rate}, shape={audio_data.shape}")
            elif isinstance(audio, str):
                # Uploaded file
                print(f"Processing uploaded file: {audio}")
                audio_data, sample_rate = librosa.load(audio, sr=None)
                print(f"Loaded audio: sample_rate={sample_rate}, shape={audio_data.shape}")
            else:
                raise ValueError("Unsupported audio format")

            # Convert to mono if stereo
            if len(audio_data.shape) > 1:
                audio_data = np.mean(audio_data, axis=1)
                print(f"Converted to mono: shape={audio_data.shape}")

            # Resample to 16kHz if needed (models expect 16kHz)
            if sample_rate != 16000:
                print(f"Resampling from {sample_rate}Hz to 16000Hz")
                audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
                sample_rate = 16000

            # Normalize audio
            max_val = np.max(np.abs(audio_data))
            if max_val > 0:
                audio_data = audio_data / max_val

            return audio_data, sample_rate
        except Exception as e:
            print(f"Error in audio preprocessing: {str(e)}")
            raise

    def transcribe_audio(self, audio):
        """
        Transcribe audio using the currently loaded model.

        Args:
            audio: Audio file path or tuple from Gradio audio component

        Returns:
            str: Transcribed text
        """
        if audio is None:
            return "No audio provided. Please record or upload an audio file."
        if self.model is None:
            return "No model loaded. Please select a model first."

        try:
            # Preprocess audio
            audio_data, sample_rate = self.preprocess_audio(audio)
            print(f"Transcribing with model: {self.current_model_name}")
            print(f"Audio shape: {audio_data.shape}, sample_rate: {sample_rate}")

            # Process with the model based on type
            with torch.no_grad():
                if self.model_type == "wav2vec2":
                    inputs = self.processor(audio_data, sampling_rate=sample_rate, return_tensors="pt", padding=True)
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    logits = self.model(**inputs).logits
                    predicted_ids = torch.argmax(logits, dim=-1)
                    transcription = self.processor.batch_decode(predicted_ids)[0]
                elif self.model_type == "whisper":
                    # Whisper expects 16kHz audio
                    if sample_rate != 16000:
                        audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
                        sample_rate = 16000

                    # Process with Whisper
                    inputs = self.processor(audio_data, sampling_rate=sample_rate, return_tensors="pt")
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}

                    # Generate transcription
                    generated_ids = self.model.generate(**inputs)
                    transcription = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            print(f"Transcription: {transcription}")
            return transcription.strip()
        except Exception as e:
            print(f"Error during transcription: {str(e)}")
            return f"Error during transcription: {str(e)}"

    def normalize_text(self, text):
        """
        Normalize text for WER calculation: lowercase, remove punctuation.

        Args:
            text: Input text string

        Returns:
            str: Normalized text
        """
        # Convert to lowercase
        text = text.lower()
        # Remove punctuation except apostrophes in contractions
        text = re.sub(r'[^\w\s\']', '', text)
        # Remove extra whitespace
        text = ' '.join(text.split())
        return text
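
    # Illustrative example (not in the original code): normalize_text maps
    # "Hello, World! It's me." to "hello world it's me": lowercased,
    # punctuation stripped except apostrophes, whitespace collapsed.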

    def calculate_wer_details(self, reference, hypothesis):
        """
        Calculate WER using edit distance for accurate alignment.

        Args:
            reference: Reference text (ground truth)
            hypothesis: Hypothesis text (transcription)

        Returns:
            dict: WER details including rate, insertions, deletions, substitutions
        """
        try:
            # Normalize both texts
            ref_normalized = self.normalize_text(reference)
            hyp_normalized = self.normalize_text(hypothesis)

            # Split into words
            ref_words = ref_normalized.split()
            hyp_words = hyp_normalized.split()
            print(f"Reference words: {ref_words}")
            print(f"Hypothesis words: {hyp_words}")

            # Use dynamic programming (word-level edit distance) to find the optimal alignment
            m, n = len(ref_words), len(hyp_words)
            dp = [[0] * (n + 1) for _ in range(m + 1)]

            # Initialize DP table
            for i in range(m + 1):
                dp[i][0] = i  # deletions
            for j in range(n + 1):
                dp[0][j] = j  # insertions

            # Fill DP table
            for i in range(1, m + 1):
                for j in range(1, n + 1):
                    if ref_words[i-1] == hyp_words[j-1]:
                        dp[i][j] = dp[i-1][j-1]  # match
                    else:
                        dp[i][j] = min(dp[i-1][j] + 1,    # deletion
                                       dp[i][j-1] + 1,    # insertion
                                       dp[i-1][j-1] + 1)  # substitution

            # Backtrack to get operations
            i, j = m, n
            operations = []
            while i > 0 or j > 0:
                if i > 0 and j > 0 and ref_words[i-1] == hyp_words[j-1]:
                    operations.append(('match', ref_words[i-1], hyp_words[j-1]))
                    i, j = i-1, j-1
                elif i > 0 and (j == 0 or dp[i][j] == dp[i-1][j] + 1):
                    operations.append(('delete', ref_words[i-1], None))
                    i -= 1
                elif j > 0 and (i == 0 or dp[i][j] == dp[i][j-1] + 1):
                    operations.append(('insert', None, hyp_words[j-1]))
                    j -= 1
                else:
                    operations.append(('substitute', ref_words[i-1], hyp_words[j-1]))
                    i, j = i-1, j-1
            operations.reverse()

            # Count operations
            insertions = sum(1 for op, _, _ in operations if op == 'insert')
            deletions = sum(1 for op, _, _ in operations if op == 'delete')
            substitutions = sum(1 for op, _, _ in operations if op == 'substitute')
            correct_matches = sum(1 for op, _, _ in operations if op == 'match')

            print("Operations:")
            for op, ref, hyp in operations:
                print(f"  {op}: {ref} -> {hyp}")

            # Calculate total errors and WER
            total_errors = insertions + deletions + substitutions
            total_words = len(ref_words)
            if total_words == 0:
                wer = 0.0
            else:
                wer = total_errors / total_words

            print(f"Final counts - Insertions: {insertions}, Deletions: {deletions}, Substitutions: {substitutions}")

            return {
                'wer': wer,
                'total_errors': total_errors,
                'total_words': total_words,
                'correct_words': correct_matches,
                'insertions': insertions,
                'deletions': deletions,
                'substitutions': substitutions,
                'ref_normalized': ref_normalized,
                'hyp_normalized': hyp_normalized
            }
        except Exception as e:
            print(f"Error in WER calculation: {str(e)}")
            # Return a default result in case of error
            return {
                'wer': 1.0,
                'total_errors': len(ref_words) if 'ref_words' in locals() else 0,
                'total_words': len(ref_words) if 'ref_words' in locals() else 0,
                'correct_words': 0,
                'insertions': 0,
                'deletions': 0,
                'substitutions': 0,
                'ref_normalized': ref_normalized if 'ref_normalized' in locals() else '',
                'hyp_normalized': hyp_normalized if 'hyp_normalized' in locals() else ''
            }
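
    # Illustrative example (not in the original code): for the reference
    # "the cat sat" and the hypothesis "the cat sat down", the alignment above
    # yields 3 matches and 1 insertion, so WER = 1 / 3 ≈ 33.3% even though
    # every reference word was recognized.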

    def get_model_info(self, model_name):
        """Get information about a specific model."""
        if model_name in self.available_models:
            info = self.available_models[model_name]
            return f"**{info['name']}**\n{info['description']}\nMemory: {info['size']}"
        return "Model information not available"


# Initialize the ASR interface
asr_interface = MultiModelASRInterface()


def load_selected_model(model_name):
    """Load the selected model."""
    return asr_interface.load_model(model_name)


def transcribe(audio):
    """Transcribe audio."""
    if audio is None:
        return "Please provide audio first."
    if asr_interface.model is None:
        return "Please load a model first."
    print(f"Transcribe called with model: {asr_interface.current_model_name}")
    return asr_interface.transcribe_audio(audio)


def calculate_wer(transcription, reference):
    """Calculate WER when reference text is provided."""
    if not transcription or transcription.strip() == "":
        return "No transcription available for WER calculation."
    if not reference or reference.strip() == "":
        return "Enter reference text to calculate WER."

    try:
        wer_details = asr_interface.calculate_wer_details(reference, transcription)

        # Format WER results
        wer_percent = wer_details['wer'] * 100
        total_words = wer_details['total_words']
        accuracy = (wer_details['correct_words'] / total_words * 100) if total_words > 0 else 0.0

        result = f"""
## 📊 WER Analysis Results

**Word Error Rate:** {wer_percent:.2f}%

### Word Statistics:
- **Correct Words:** {wer_details['correct_words']}
- **Total Words:** {total_words}
- **Accuracy:** {accuracy:.2f}%

### Error Breakdown:
- **Insertions:** {wer_details['insertions']}
- **Deletions:** {wer_details['deletions']}
- **Substitutions:** {wer_details['substitutions']}
- **Total Errors:** {wer_details['total_errors']}

### Normalized Texts:
**Reference:** `{wer_details['ref_normalized']}`

**Hypothesis:** `{wer_details['hyp_normalized']}`
"""
        return result
    except Exception as e:
        return f"Error calculating WER: {str(e)}"


def clear():
    """Clear all inputs."""
    return None, "", ""


# Create the Gradio interface
with gr.Blocks(title="Multi-Model ASR", theme=gr.themes.Soft()) as interface:
    gr.Markdown("# 🎤 Multi-Model Speech Recognition")
    gr.Markdown("Select a model, then record or upload audio for transcription.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🤖 Model Selection")

            # Model dropdown
            model_dropdown = gr.Dropdown(
                choices=list(asr_interface.available_models.keys()),
                value="facebook/wav2vec2-base-960h",
                label="Select ASR Model",
                info="Choose the model based on your needs"
            )

            # Model info display
            model_info = gr.Markdown(asr_interface.get_model_info("facebook/wav2vec2-base-960h"))

            # Load model button
            load_btn = gr.Button("📥 Load Model", variant="primary")

            # Current model status
            model_status = gr.Markdown("No model loaded. Please select and load a model.")

            gr.Markdown("### 🎙️ Audio Input")
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Record or upload audio",
                show_label=True
            )

            with gr.Row():
                transcribe_btn = gr.Button("📝 Transcribe", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        with gr.Column(scale=1):
            gr.Markdown("### 📝 Transcription")
            text_output = gr.Textbox(
                label="Transcribed Text",
                placeholder="Your transcribed text will appear here...",
                lines=6,
                max_lines=10,
                show_copy_button=True
            )

            gr.Markdown("### 📊 WER Analysis")
            reference_input = gr.Textbox(
                label="Reference Text (Optional)",
                placeholder="Enter the correct/expected text to calculate WER...",
                lines=3,
                max_lines=5
            )
            wer_output = gr.Markdown("Enter reference text to see WER analysis")

    # Status indicator
    status = gr.Markdown("Ready! Select a model and load it to get started.")

    # Event handlers
    def update_model_info(model_name):
        return asr_interface.get_model_info(model_name)

    # Connect event handlers
    model_dropdown.change(
        fn=update_model_info,
        inputs=model_dropdown,
        outputs=model_info
    )

    load_btn.click(
        fn=load_selected_model,
        inputs=model_dropdown,
        outputs=model_status
    )

    transcribe_btn.click(
        fn=transcribe,
        inputs=audio_input,
        outputs=text_output
    )

    clear_btn.click(
        fn=clear,
        outputs=[audio_input, text_output, wer_output]
    )

    # Auto-transcribe when audio changes
    audio_input.change(
        fn=transcribe,
        inputs=audio_input,
        outputs=text_output
    )

    # Calculate WER when reference text changes
    reference_input.change(
        fn=calculate_wer,
        inputs=[text_output, reference_input],
        outputs=wer_output
    )

    # Calculate WER when transcription changes
    text_output.change(
        fn=calculate_wer,
        inputs=[text_output, reference_input],
        outputs=wer_output
    )

    # Instructions
    with gr.Accordion("ℹ️ Instructions", open=False):
        gr.Markdown("""
        ### How to use:
        1. **Select Model**: Choose from available Wav2Vec2 and Whisper models
        2. **Load Model**: Click 'Load Model' to load the selected model
        3. **Record/Upload**: Record audio or upload an audio file
        4. **Transcribe**: Click 'Transcribe' or wait for auto-transcription
        5. **WER Analysis**: Enter reference text to calculate Word Error Rate
        6. **Copy Text**: Use the copy button on the transcription box to copy the result

        ### Model Comparison:
        - **Wav2Vec2 Base (100h)**: Fastest, smallest memory (~300MB), good for basic tasks
        - **Wav2Vec2 Base (960h)**: Balanced speed/accuracy (~1GB), recommended for most uses
        - **Wav2Vec2 Large (960h)**: High accuracy (~3GB), best for difficult audio
        - **Whisper Large V3 Turbo**: State-of-the-art accuracy (~5GB), multilingual support

        ### Tips:
        - Larger models are more accurate but slower
        - Only one model is loaded at a time to save memory
        - Switch models anytime by selecting and loading a new one
        - WER calculation normalizes text (lowercase, no punctuation)
        - A lower WER percentage indicates better transcription accuracy
        """)


# Launch the interface
if __name__ == "__main__":
    interface.launch(
        server_name="0.0.0.0",  # Allow external connections
        server_port=7860,       # Default HF Spaces port
        share=False,            # Don't create shareable link (HF handles this)
        show_error=True,        # Show errors for debugging
        quiet=False,            # Show startup messages
        inbrowser=False         # Don't open browser (HF handles this)
    )
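
# Deployment note (assumption, not part of the original script): when hosted as a
# Hugging Face Space, the repository would also need a requirements.txt listing
# the imported packages. A hypothetical minimal example:
#
#   gradio
#   torch
#   transformers
#   librosa
#   numpy
#
# Exact versions depend on the Space's runtime and are not specified here.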