import base64
import io

import torch
import soundfile as sf
from transformers import CsmForConditionalGeneration, AutoProcessor


class EndpointHandler:
    """Inference-endpoint wrapper around the Sesame CSM Creole TTS model.

    Loads the processor and model once at construction and exposes
    ``__call__`` following the Hugging Face Inference Endpoints handler
    contract: a ``{"inputs": {...}}`` payload in, a JSON-serializable
    dict out.
    """

    def __init__(self, model_path: str = "jsbeaudry/sesame-creole-tts"):
        """Load the processor and model, using GPU when available.

        Args:
            model_path: Hub repo id or local path of the CSM checkpoint.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = CsmForConditionalGeneration.from_pretrained(
            model_path, device_map=self.device
        )

    def __call__(self, data: dict) -> dict:
        """Synthesize speech for ``data["inputs"]["text"]``.

        Expected payload::

            {"inputs": {"text": "...", "speaker_id": 0, "sampling_rate": 24000}}

        Returns:
            A dict with the speaker-formatted input text, a base64-encoded
            WAV payload, and the sampling rate used — or ``{"error": ...}``
            when the required ``text`` field is missing or empty.
        """
        input_data = data.get("inputs", {})

        text = input_data.get("text")
        if not text:
            return {"error": "Missing 'text' parameter inside 'inputs'."}

        speaker_id = input_data.get("speaker_id", 0)
        # NOTE(review): CSM generates audio at a fixed model rate (24 kHz).
        # A caller-supplied rate only relabels the WAV header (no resampling)
        # and will change playback speed — confirm this is intended.
        sampling_rate = input_data.get("sampling_rate", 24000)

        # CSM expects the speaker id prefixed in square brackets.
        input_text = f"[{speaker_id}]{text}"

        # Fix: return_tensors="pt" is required so the processor output holds
        # PyTorch tensors that .to(device) can move and generate() can consume;
        # without it the BatchEncoding contains plain Python lists.
        inputs = self.processor(
            input_text, add_special_tokens=True, return_tensors="pt"
        ).to(self.device)
        output = self.model.generate(**inputs, output_audio=True)
        audio_tensor = output[0].to(torch.float32).cpu().numpy()

        # Encode the waveform as an in-memory WAV, then base64 for JSON transport.
        buffer = io.BytesIO()
        sf.write(buffer, audio_tensor, sampling_rate, format="WAV")
        buffer.seek(0)
        audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")

        return {
            "input_text": input_text,
            "audio_base64": audio_base64,
            "sampling_rate": sampling_rate,
        }