Spaces:
Running
Running
| """ | |
| Data models and types for the TTSFM package. | |
| This module defines the core data structures used throughout the package, | |
| including request/response models, enums, and error types. | |
| """ | |
| from enum import Enum | |
| from typing import Optional, Dict, Any, Union | |
| from dataclasses import dataclass | |
| from datetime import datetime | |
class Voice(str, Enum):
    """Voices that can be selected for speech synthesis.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase name (``Voice.ALLOY == "alloy"``).
    """

    ALLOY = "alloy"
    ASH = "ash"
    BALLAD = "ballad"
    CORAL = "coral"
    ECHO = "echo"
    FABLE = "fable"
    NOVA = "nova"
    ONYX = "onyx"
    SAGE = "sage"
    SHIMMER = "shimmer"
    VERSE = "verse"
class AudioFormat(str, Enum):
    """Audio container/codec identifiers accepted for TTS output.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase name (``AudioFormat.MP3 == "mp3"``).
    """

    MP3 = "mp3"
    WAV = "wav"
    OPUS = "opus"
    AAC = "aac"
    FLAC = "flac"
    PCM = "pcm"
@dataclass
class TTSRequest:
    """
    Request model for TTS generation.

    The class is a dataclass: the generated ``__init__`` assigns the fields
    below and then runs ``__post_init__``, which normalizes and validates them.

    Attributes:
        input: Text to convert to speech
        voice: Voice to use for generation
        response_format: Audio format for output
        instructions: Optional instructions for voice modulation
        model: Model to use (for OpenAI compatibility, usually ignored)
        speed: Speech speed (for OpenAI compatibility, usually ignored)
        max_length: Maximum allowed text length (default: 4096 characters)
        validate_length: Whether to validate text length (default: True)

    Raises:
        ValueError: If ``voice`` or ``response_format`` is an unknown string,
            ``input`` is empty or whitespace-only, ``max_length`` is not
            positive, ``input`` exceeds ``max_length`` while
            ``validate_length`` is true, or ``speed`` is outside 0.25-4.0.
    """

    input: str
    voice: Union[Voice, str] = Voice.ALLOY
    response_format: Union[AudioFormat, str] = AudioFormat.MP3
    instructions: Optional[str] = None
    model: Optional[str] = None
    speed: Optional[float] = None
    max_length: int = 4096
    validate_length: bool = True

    def __post_init__(self):
        """Validate and normalize fields after initialization."""
        # Coerce string voices (case-insensitively) to the Voice enum.
        if isinstance(self.voice, str):
            try:
                self.voice = Voice(self.voice.lower())
            except ValueError as exc:
                raise ValueError(
                    f"Invalid voice: {self.voice}. Must be one of {list(Voice)}"
                ) from exc

        # Coerce string formats (case-insensitively) to the AudioFormat enum.
        if isinstance(self.response_format, str):
            try:
                self.response_format = AudioFormat(self.response_format.lower())
            except ValueError as exc:
                raise ValueError(
                    f"Invalid format: {self.response_format}. Must be one of {list(AudioFormat)}"
                ) from exc

        # Reject missing or whitespace-only text.
        if not self.input or not self.input.strip():
            raise ValueError("Input text cannot be empty")

        # Check max_length before using it as a limit; otherwise a
        # non-positive value is reported as "text is too long".
        if self.max_length <= 0:
            raise ValueError("max_length must be a positive integer")

        # Enforce the length limit only when the caller has not opted out.
        if self.validate_length:
            text_length = len(self.input)
            if text_length > self.max_length:
                raise ValueError(
                    f"Input text is too long ({text_length} characters). "
                    f"Maximum allowed length is {self.max_length} characters. "
                    f"Consider splitting your text into smaller chunks or disable "
                    f"length validation with validate_length=False."
                )

        # Speed is optional; when given it must lie within 0.25-4.0 inclusive.
        if self.speed is not None and (self.speed < 0.25 or self.speed > 4.0):
            raise ValueError("Speed must be between 0.25 and 4.0")

    def to_dict(self) -> Dict[str, Any]:
        """Convert request to dictionary for API calls.

        Returns:
            A dict that always holds ``input``, ``voice`` and
            ``response_format`` (enum members are replaced by their string
            values). ``instructions`` and ``model`` are included only when
            truthy, ``speed`` only when it is not ``None``.
        """
        data = {
            "input": self.input,
            "voice": self.voice.value if isinstance(self.voice, Voice) else self.voice,
            "response_format": (
                self.response_format.value
                if isinstance(self.response_format, AudioFormat)
                else self.response_format
            ),
        }
        if self.instructions:
            data["instructions"] = self.instructions
        if self.model:
            data["model"] = self.model
        if self.speed is not None:
            data["speed"] = self.speed
        return data
@dataclass
class TTSResponse:
    """
    Response model for TTS generation.

    Attributes:
        audio_data: Generated audio as bytes
        content_type: MIME type of the audio data
        format: Audio format used
        size: Size of audio data in bytes; computed from ``audio_data``
            when omitted or ``None``
        duration: Estimated duration in seconds (if available)
        metadata: Additional response metadata
    """

    audio_data: bytes
    content_type: str
    format: AudioFormat
    # Defaults to None so __post_init__ can derive it; positional callers
    # that pass size explicitly are unaffected.
    size: Optional[int] = None
    duration: Optional[float] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Calculate derived fields after initialization."""
        if self.size is None:
            self.size = len(self.audio_data)

    def save_to_file(self, filename: str) -> str:
        """
        Save audio data to a file.

        The extension is taken from ``self.format``. A filename that already
        ends with it (compared case-insensitively) is used unchanged;
        otherwise one trailing known audio extension is removed, if present,
        and the correct one is appended. The parent directory is created
        when the path has one and it does not exist.

        Args:
            filename: Target filename (extension will be added if missing)

        Returns:
            str: Final filename used
        """
        import os

        # Use the actual returned format for the extension, not any requested format
        expected_extension = f".{self.format.value}"
        # Compare on a lowercased copy so "clip.MP3" counts as already correct.
        lowered = filename.lower()

        if lowered.endswith(expected_extension):
            final_filename = filename
        else:
            base_name = filename
            # Strip at most one known audio extension before appending ours.
            for ext in ('.mp3', '.wav', '.opus', '.aac', '.flac', '.pcm'):
                if lowered.endswith(ext):
                    base_name = filename[:-len(ext)]
                    break
            final_filename = f"{base_name}{expected_extension}"

        # A bare filename has no directory component, so nothing to create.
        directory = os.path.dirname(final_filename)
        if directory:
            os.makedirs(directory, exist_ok=True)

        with open(final_filename, "wb") as f:
            f.write(self.audio_data)
        return final_filename
@dataclass
class TTSError:
    """
    Error information from TTS API.

    This is a plain data record (a dataclass), not an exception class.

    Attributes:
        code: Error code
        message: Human-readable error message
        type: Error type/category
        details: Additional error details
        timestamp: When the error occurred; set to ``datetime.now()`` when
            not provided
    """

    code: str
    message: str
    type: Optional[str] = None
    details: Optional[Dict[str, Any]] = None
    timestamp: Optional[datetime] = None

    def __post_init__(self):
        """Set timestamp if not provided."""
        if self.timestamp is None:
            self.timestamp = datetime.now()


@dataclass
class APIError(TTSError):
    """API-specific error information.

    Attributes:
        status_code: Status code associated with the error (default: 500)
        headers: Optional mapping of header names to values
    """

    status_code: int = 500
    headers: Optional[Dict[str, str]] = None


@dataclass
class NetworkError(TTSError):
    """Network-related error information.

    Attributes:
        timeout: Optional timeout value associated with the failure
        retry_count: Number of retries recorded (default: 0)
    """

    timeout: Optional[float] = None
    retry_count: int = 0


@dataclass
class ValidationError(TTSError):
    """Validation error information.

    Attributes:
        field: Optional name of the field that failed validation
        value: Optional offending value
    """

    field: Optional[str] = None
    value: Optional[Any] = None
# MIME content type reported for each audio format.
CONTENT_TYPE_MAP: Dict[AudioFormat, str] = {
    AudioFormat.MP3: "audio/mpeg",
    AudioFormat.OPUS: "audio/opus",
    AudioFormat.AAC: "audio/aac",
    AudioFormat.FLAC: "audio/flac",
    AudioFormat.WAV: "audio/wav",
    AudioFormat.PCM: "audio/pcm",
}

# Inverse lookup: MIME content type -> audio format.
FORMAT_FROM_CONTENT_TYPE: Dict[str, AudioFormat] = dict(
    zip(CONTENT_TYPE_MAP.values(), CONTENT_TYPE_MAP.keys())
)
def get_content_type(format: Union[AudioFormat, str]) -> str:
    """Return the MIME content type registered for an audio format.

    Args:
        format: An ``AudioFormat`` member or a format name in any letter case.

    Returns:
        The entry from ``CONTENT_TYPE_MAP``, or ``"audio/mpeg"`` when the
        format has no entry.

    Raises:
        ValueError: If a string does not name an ``AudioFormat`` member.
    """
    # AudioFormat members are str instances too, so they also take the
    # lowercase-and-look-up path and come back as the same member.
    audio_format = AudioFormat(format.lower()) if isinstance(format, str) else format
    return CONTENT_TYPE_MAP.get(audio_format, "audio/mpeg")
def get_format_from_content_type(content_type: str) -> AudioFormat:
    """Get audio format from MIME content type.

    The value is reduced to its bare media type before the lookup: anything
    after the first ``;`` (parameters such as ``charset``) is dropped,
    surrounding whitespace is stripped, and the result is lowercased. This
    lets header values like ``"Audio/WAV; rate=24000"`` resolve to the right
    format.

    Args:
        content_type: Content type string, typically an HTTP header value.

    Returns:
        The matching ``AudioFormat``, or ``AudioFormat.MP3`` when the value
        is empty, not a string, or not present in
        ``FORMAT_FROM_CONTENT_TYPE``.
    """
    if not isinstance(content_type, str) or not content_type:
        return AudioFormat.MP3
    media_type = content_type.split(";", 1)[0].strip().lower()
    return FORMAT_FROM_CONTENT_TYPE.get(media_type, AudioFormat.MP3)
def get_supported_format(requested_format: AudioFormat) -> AudioFormat:
    """
    Collapse a requested format onto one of the two supported formats.

    Args:
        requested_format: The requested audio format

    Returns:
        AudioFormat: ``AudioFormat.MP3`` when MP3 was requested, otherwise
        ``AudioFormat.WAV`` (covers WAV, OPUS, AAC, FLAC and PCM)
    """
    wants_mp3 = requested_format == AudioFormat.MP3
    return AudioFormat.MP3 if wants_mp3 else AudioFormat.WAV
def maps_to_wav(format_value: str) -> bool:
    """
    Tell whether a format name is one of those that map to WAV.

    Args:
        format_value: Format string to check (any letter case)

    Returns:
        bool: True for wav, opus, aac, flac or pcm; False otherwise
    """
    wav_backed = ("wav", "opus", "aac", "flac", "pcm")
    normalized = format_value.lower()
    return normalized in wav_backed