""" Data models and types for the TTSFM package. This module defines the core data structures used throughout the package, including request/response models, enums, and error types. """ from enum import Enum from typing import Optional, Dict, Any, Union from dataclasses import dataclass from datetime import datetime class Voice(str, Enum): """Available voice options for TTS generation.""" ALLOY = "alloy" ASH = "ash" BALLAD = "ballad" CORAL = "coral" ECHO = "echo" FABLE = "fable" NOVA = "nova" ONYX = "onyx" SAGE = "sage" SHIMMER = "shimmer" VERSE = "verse" class AudioFormat(str, Enum): """Supported audio output formats.""" MP3 = "mp3" WAV = "wav" OPUS = "opus" AAC = "aac" FLAC = "flac" PCM = "pcm" @dataclass class TTSRequest: """ Request model for TTS generation. Attributes: input: Text to convert to speech voice: Voice to use for generation response_format: Audio format for output instructions: Optional instructions for voice modulation model: Model to use (for OpenAI compatibility, usually ignored) speed: Speech speed (for OpenAI compatibility, usually ignored) max_length: Maximum allowed text length (default: 4096 characters) validate_length: Whether to validate text length (default: True) """ input: str voice: Union[Voice, str] = Voice.ALLOY response_format: Union[AudioFormat, str] = AudioFormat.MP3 instructions: Optional[str] = None model: Optional[str] = None speed: Optional[float] = None max_length: int = 4096 validate_length: bool = True def __post_init__(self): """Validate and normalize fields after initialization.""" # Ensure voice is a valid Voice enum if isinstance(self.voice, str): try: self.voice = Voice(self.voice.lower()) except ValueError: raise ValueError(f"Invalid voice: {self.voice}. Must be one of {list(Voice)}") # Ensure response_format is a valid AudioFormat enum if isinstance(self.response_format, str): try: self.response_format = AudioFormat(self.response_format.lower()) except ValueError: raise ValueError(f"Invalid format: {self.response_format}. Must be one of {list(AudioFormat)}") # Validate input text if not self.input or not self.input.strip(): raise ValueError("Input text cannot be empty") # Validate text length if enabled if self.validate_length: text_length = len(self.input) if text_length > self.max_length: raise ValueError( f"Input text is too long ({text_length} characters). " f"Maximum allowed length is {self.max_length} characters. " f"Consider splitting your text into smaller chunks or disable " f"length validation with validate_length=False." ) # Validate max_length parameter if self.max_length <= 0: raise ValueError("max_length must be a positive integer") # Validate speed if provided if self.speed is not None and (self.speed < 0.25 or self.speed > 4.0): raise ValueError("Speed must be between 0.25 and 4.0") def to_dict(self) -> Dict[str, Any]: """Convert request to dictionary for API calls.""" data = { "input": self.input, "voice": self.voice.value if isinstance(self.voice, Voice) else self.voice, "response_format": self.response_format.value if isinstance(self.response_format, AudioFormat) else self.response_format } if self.instructions: data["instructions"] = self.instructions if self.model: data["model"] = self.model if self.speed is not None: data["speed"] = self.speed return data @dataclass class TTSResponse: """ Response model for TTS generation. Attributes: audio_data: Generated audio as bytes content_type: MIME type of the audio data format: Audio format used size: Size of audio data in bytes duration: Estimated duration in seconds (if available) metadata: Additional response metadata """ audio_data: bytes content_type: str format: AudioFormat size: int duration: Optional[float] = None metadata: Optional[Dict[str, Any]] = None def __post_init__(self): """Calculate derived fields after initialization.""" if self.size is None: self.size = len(self.audio_data) def save_to_file(self, filename: str) -> str: """ Save audio data to a file. Args: filename: Target filename (extension will be added if missing) Returns: str: Final filename used """ import os # Use the actual returned format for the extension, not any requested format expected_extension = f".{self.format.value}" # Check if filename already has the correct extension if filename.endswith(expected_extension): final_filename = filename else: # Remove any existing extension and add the correct one base_name = filename # Remove common audio extensions if present for ext in ['.mp3', '.wav', '.opus', '.aac', '.flac', '.pcm']: if base_name.endswith(ext): base_name = base_name[:-len(ext)] break final_filename = f"{base_name}{expected_extension}" # Create directory if it doesn't exist os.makedirs(os.path.dirname(final_filename) if os.path.dirname(final_filename) else ".", exist_ok=True) # Write audio data with open(final_filename, "wb") as f: f.write(self.audio_data) return final_filename @dataclass class TTSError: """ Error information from TTS API. Attributes: code: Error code message: Human-readable error message type: Error type/category details: Additional error details timestamp: When the error occurred """ code: str message: str type: Optional[str] = None details: Optional[Dict[str, Any]] = None timestamp: Optional[datetime] = None def __post_init__(self): """Set timestamp if not provided.""" if self.timestamp is None: self.timestamp = datetime.now() @dataclass class APIError(TTSError): """API-specific error information.""" status_code: int = 500 headers: Optional[Dict[str, str]] = None @dataclass class NetworkError(TTSError): """Network-related error information.""" timeout: Optional[float] = None retry_count: int = 0 @dataclass class ValidationError(TTSError): """Validation error information.""" field: Optional[str] = None value: Optional[Any] = None # Content type mappings for audio formats CONTENT_TYPE_MAP = { AudioFormat.MP3: "audio/mpeg", AudioFormat.OPUS: "audio/opus", AudioFormat.AAC: "audio/aac", AudioFormat.FLAC: "audio/flac", AudioFormat.WAV: "audio/wav", AudioFormat.PCM: "audio/pcm" } # Reverse mapping for content type to format FORMAT_FROM_CONTENT_TYPE = {v: k for k, v in CONTENT_TYPE_MAP.items()} def get_content_type(format: Union[AudioFormat, str]) -> str: """Get MIME content type for audio format.""" if isinstance(format, str): format = AudioFormat(format.lower()) return CONTENT_TYPE_MAP.get(format, "audio/mpeg") def get_format_from_content_type(content_type: str) -> AudioFormat: """Get audio format from MIME content type.""" return FORMAT_FROM_CONTENT_TYPE.get(content_type, AudioFormat.MP3) def get_supported_format(requested_format: AudioFormat) -> AudioFormat: """ Map requested format to supported format. Args: requested_format: The requested audio format Returns: AudioFormat: MP3 or WAV (the supported formats) """ if requested_format == AudioFormat.MP3: return AudioFormat.MP3 else: # All other formats (WAV, OPUS, AAC, FLAC, PCM) return WAV return AudioFormat.WAV def maps_to_wav(format_value: str) -> bool: """ Check if a format maps to WAV. Args: format_value: Format string to check Returns: bool: True if the format maps to WAV """ return format_value.lower() in ['wav', 'opus', 'aac', 'flac', 'pcm']