Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import numpy as np | |
| import time | |
| import matplotlib.pyplot as plt | |
| from typing import Tuple, List | |
| from statistics import mean, median, stdev | |
| from lib import ( | |
| normalize_text, | |
| chunk_text, | |
| count_tokens, | |
| load_module_from_file, | |
| download_model_files, | |
| list_voice_files, | |
| download_voice_files, | |
| ensure_dir, | |
| concatenate_audio_chunks | |
| ) | |
class TTSModel:
    """GPU-accelerated TTS model manager"""

    def __init__(self):
        # Model weights are built lazily in initialize(); voices are cached
        # on disk under self.voices_dir.
        self.model = None
        self.voices_dir = "voices"
        self.model_repo = "hexgrad/Kokoro-82M"
        ensure_dir(self.voices_dir)

        # Fetch the upstream python modules the checkpoint depends on and
        # register each one under its plain module name.
        module_names = ["istftnet", "plbert", "models", "kokoro"]
        module_paths = download_model_files(self.model_repo, [f"{name}.py" for name in module_names])
        for name, path in zip(module_names, module_paths):
            load_module_from_file(name, path)

        # Bind the generation entry points exposed by the downloaded modules.
        self.generate = __import__("kokoro").generate
        self.build_model = __import__("models").build_model
| def initialize(self) -> bool: | |
| """Initialize model and download voices""" | |
| try: | |
| print("Initializing model...") | |
| # Download model files | |
| model_files = download_model_files( | |
| self.model_repo, | |
| ["kokoro-v0_19.pth", "config.json"] | |
| ) | |
| model_path = model_files[0] # kokoro-v0_19.pth | |
| # Build model directly on GPU | |
| with torch.cuda.device(0): | |
| torch.cuda.set_device(0) | |
| self.model = self.build_model(model_path, 'cuda') | |
| self._model_on_gpu = True | |
| print("Model initialization complete") | |
| return True | |
| except Exception as e: | |
| print(f"Error initializing model: {str(e)}") | |
| return False | |
| def ensure_voice_downloaded(self, voice_name: str) -> bool: | |
| """Ensure specific voice is downloaded""" | |
| try: | |
| voice_path = os.path.join(self.voices_dir, f"{voice_name}.pt") | |
| if not os.path.exists(voice_path): | |
| print(f"Downloading voice {voice_name}.pt...") | |
| download_voice_files(self.model_repo, [f"{voice_name}.pt"], self.voices_dir) | |
| return True | |
| except Exception as e: | |
| print(f"Error downloading voice {voice_name}: {str(e)}") | |
| return False | |
| def list_voices(self) -> List[str]: | |
| """List available voices""" | |
| return [ | |
| "af_bella", "af_nicole", "af_sarah", "af_sky", "af", | |
| "am_adam", "am_michael", "bf_emma", "bf_isabella", | |
| "bm_george", "bm_lewis" | |
| ] | |
| def _ensure_model_on_gpu(self) -> None: | |
| """Ensure model is on GPU and stays there""" | |
| if not hasattr(self, '_model_on_gpu') or not self._model_on_gpu: | |
| print("Moving model to GPU...") | |
| with torch.cuda.device(0): | |
| torch.cuda.set_device(0) | |
| if hasattr(self.model, 'to'): | |
| self.model.to('cuda') | |
| else: | |
| for name in self.model: | |
| if isinstance(self.model[name], torch.Tensor): | |
| self.model[name] = self.model[name].cuda() | |
| self._model_on_gpu = True | |
| def _generate_audio(self, text: str, voicepack: torch.Tensor, lang: str, speed: float) -> np.ndarray: | |
| """GPU-accelerated audio generation""" | |
| try: | |
| with torch.cuda.device(0): | |
| torch.cuda.set_device(0) | |
| # Move everything to GPU in a single context | |
| if not hasattr(self, '_model_on_gpu') or not self._model_on_gpu: | |
| print("Moving model to GPU...") | |
| if hasattr(self.model, 'to'): | |
| self.model.to('cuda') | |
| else: | |
| for name in self.model: | |
| if isinstance(self.model[name], torch.Tensor): | |
| self.model[name] = self.model[name].cuda() | |
| self._model_on_gpu = True | |
| # Move voicepack to GPU | |
| voicepack = voicepack.cuda() | |
| # Run generation with everything on GPU | |
| audio, _ = self.generate( | |
| self.model, | |
| text, | |
| voicepack, | |
| lang=lang, | |
| speed=speed | |
| ) | |
| return audio | |
| except Exception as e: | |
| print(f"Error in audio generation: {str(e)}") | |
| raise e | |
| def generate_speech(self, text: str, voice_name: str, speed: float = 1.0, progress_callback=None) -> Tuple[np.ndarray, float]: | |
| """Generate speech from text. Returns (audio_array, duration) | |
| Args: | |
| text: Input text to convert to speech | |
| voice_name: Name of voice to use | |
| speed: Speech speed multiplier | |
| progress_callback: Optional callback function(chunk_num, total_chunks, tokens_per_sec, rtf) | |
| """ | |
| try: | |
| if not text or not voice_name: | |
| raise ValueError("Text and voice name are required") | |
| start_time = time.time() | |
| # Count tokens and normalize text | |
| total_tokens = count_tokens(text) | |
| text = normalize_text(text) | |
| if not text: | |
| raise ValueError("Text is empty after normalization") | |
| # Load voice and process within GPU context | |
| with torch.cuda.device(0): | |
| torch.cuda.set_device(0) | |
| voice_path = os.path.join(self.voices_dir, f"{voice_name}.pt") | |
| # Ensure voice is downloaded and load directly to GPU | |
| if not self.ensure_voice_downloaded(voice_name): | |
| raise ValueError(f"Failed to download voice: {voice_name}") | |
| voicepack = torch.load(voice_path, map_location='cuda', weights_only=True) | |
| # Break text into chunks for better memory management | |
| chunks = chunk_text(text) | |
| print(f"Processing {len(chunks)} chunks...") | |
| # Ensure model is initialized and on GPU | |
| if self.model is None: | |
| print("Model not initialized, reinitializing...") | |
| if not self.initialize(): | |
| raise ValueError("Failed to initialize model") | |
| # Move model to GPU if needed | |
| if not hasattr(self, '_model_on_gpu') or not self._model_on_gpu: | |
| print("Moving model to GPU...") | |
| if hasattr(self.model, 'to'): | |
| self.model.to('cuda') | |
| else: | |
| for name in self.model: | |
| if isinstance(self.model[name], torch.Tensor): | |
| self.model[name] = self.model[name].cuda() | |
| self._model_on_gpu = True | |
| # Process all chunks within same GPU context | |
| audio_chunks = [] | |
| chunk_times = [] | |
| chunk_sizes = [] # Store chunk lengths | |
| total_processed_tokens = 0 | |
| total_processed_time = 0 | |
| for i, chunk in enumerate(chunks): | |
| chunk_start = time.time() | |
| chunk_audio = self._generate_audio( | |
| text=chunk, | |
| voicepack=voicepack, | |
| lang=voice_name[0], | |
| speed=speed | |
| ) | |
| chunk_time = time.time() - chunk_start | |
| # Update metrics | |
| chunk_tokens = count_tokens(chunk) | |
| total_processed_tokens += chunk_tokens | |
| total_processed_time += chunk_time | |
| current_tokens_per_sec = total_processed_tokens / total_processed_time | |
| # Calculate processing speed metrics | |
| chunk_duration = len(chunk_audio) / 24000 # audio duration in seconds | |
| rtf = chunk_time / chunk_duration | |
| times_faster = 1 / rtf | |
| chunk_times.append(chunk_time) | |
| chunk_sizes.append(len(chunk)) | |
| print(f"Chunk {i+1}/{len(chunks)} processed in {chunk_time:.2f}s") | |
| print(f"Current tokens/sec: {current_tokens_per_sec:.2f}") | |
| print(f"Real-time factor: {rtf:.2f}x") | |
| print(f"{times_faster:.1f}x faster than real-time") | |
| audio_chunks.append(chunk_audio) | |
| # Call progress callback if provided | |
| if progress_callback: | |
| progress_callback(i + 1, len(chunks), current_tokens_per_sec, rtf) | |
| # Concatenate audio chunks | |
| audio = concatenate_audio_chunks(audio_chunks) | |
| def setup_plot(fig, ax, title): | |
| """Configure plot styling""" | |
| # Improve grid | |
| ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff") | |
| # Set title and labels with better fonts and more padding | |
| ax.set_title(title, pad=40, fontsize=16, fontweight="bold", color="#ffffff") | |
| ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff") | |
| ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff") | |
| # Improve tick labels | |
| ax.tick_params(labelsize=12, colors="#ffffff") | |
| # Style spines | |
| for spine in ax.spines.values(): | |
| spine.set_color("#ffffff") | |
| spine.set_alpha(0.3) | |
| spine.set_linewidth(0.5) | |
| # Set background colors | |
| ax.set_facecolor("#1a1a2e") | |
| fig.patch.set_facecolor("#1a1a2e") | |
| return fig, ax | |
| # Set dark style | |
| plt.style.use("dark_background") | |
| # Create figure with subplots | |
| fig = plt.figure(figsize=(18, 16)) | |
| fig.patch.set_facecolor("#1a1a2e") | |
| # Create subplot grid | |
| gs = plt.GridSpec(2, 1, left=0.15, right=0.85, top=0.9, bottom=0.15, hspace=0.4) | |
| # Processing times plot | |
| ax1 = plt.subplot(gs[0]) | |
| chunks_x = list(range(1, len(chunks) + 1)) | |
| bars = ax1.bar(chunks_x, chunk_times, color='#ff2a6d', alpha=0.8) | |
| # Add statistics lines | |
| mean_time = mean(chunk_times) | |
| median_time = median(chunk_times) | |
| std_time = stdev(chunk_times) if len(chunk_times) > 1 else 0 | |
| ax1.axhline(y=mean_time, color='#05d9e8', linestyle='--', | |
| label=f'Mean: {mean_time:.2f}s') | |
| ax1.axhline(y=median_time, color='#d1f7ff', linestyle=':', | |
| label=f'Median: {median_time:.2f}s') | |
| # Add ±1 std dev range | |
| if len(chunk_times) > 1: | |
| ax1.axhspan(mean_time - std_time, mean_time + std_time, | |
| color='#8c1eff', alpha=0.2, label='±1 Std Dev') | |
| # Add value labels on top of bars | |
| for bar in bars: | |
| height = bar.get_height() | |
| ax1.text(bar.get_x() + bar.get_width() / 2.0, | |
| height, | |
| f'{height:.2f}s', | |
| ha='center', | |
| va='bottom', | |
| color='white', | |
| fontsize=10) | |
| ax1.set_xlabel('Chunk Number') | |
| ax1.set_ylabel('Processing Time (seconds)') | |
| setup_plot(fig, ax1, 'Chunk Processing Times') | |
| ax1.legend(facecolor="#1a1a2e", edgecolor="#ffffff") | |
| # Chunk sizes plot | |
| ax2 = plt.subplot(gs[1]) | |
| ax2.plot(chunks_x, chunk_sizes, color='#ff9e00', marker='o', linewidth=2) | |
| ax2.set_xlabel('Chunk Number') | |
| ax2.set_ylabel('Chunk Size (chars)') | |
| setup_plot(fig, ax2, 'Chunk Sizes') | |
| # Save plot | |
| plt.savefig('chunk_times.png') | |
| plt.close() | |
| # Calculate metrics | |
| total_time = time.time() - start_time | |
| tokens_per_second = total_tokens / total_time | |
| print(f"\nProcessing Metrics:") | |
| print(f"Total tokens: {total_tokens}") | |
| print(f"Total time: {total_time:.2f}s") | |
| print(f"Tokens per second: {tokens_per_second:.2f}") | |
| print(f"Mean chunk time: {mean_time:.2f}s") | |
| print(f"Median chunk time: {median_time:.2f}s") | |
| if len(chunk_times) > 1: | |
| print(f"Std dev: {std_time:.2f}s") | |
| print(f"\nChunk time plot saved as 'chunk_times.png'") | |
| return audio, len(audio) / 24000 # Return audio array and duration | |
| except Exception as e: | |
| print(f"Error generating speech: {str(e)}") | |
| raise | |