# sesame-creole-tts / handler.py
# Author: jsbeaudry
# Last commit: 1022f85 "Update handler.py" (verified)
import base64
import io

import soundfile as sf
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
class EndpointHandler:
    """Inference Endpoint handler that turns text into a base64-encoded WAV clip.

    The processor and ``CsmForConditionalGeneration`` model are loaded once at
    construction; each call validates the request, generates audio, and returns
    it as base64 WAV data.
    """

    def __init__(self, model_path: str = "jsbeaudry/sesame-creole-tts"):
        """Load the processor and model from ``model_path``.

        Args:
            model_path: Hub repo id or local directory holding the model.
        """
        # Use the GPU when one is visible, otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = CsmForConditionalGeneration.from_pretrained(
            model_path, device_map=self.device
        )

    @staticmethod
    def _parse_int(value, name: str) -> int:
        """Return ``value`` as an int, accepting ints and integer strings only.

        Going through ``str`` rejects floats like ``1.5`` and booleans instead
        of silently truncating/coercing them.

        Raises:
            ValueError: if ``value`` is not an integer or integer string.
        """
        try:
            return int(str(value))
        except ValueError:
            raise ValueError(f"'{name}' must be an integer.") from None

    @classmethod
    def _parse_request(cls, data: dict) -> tuple[str, int, int]:
        """Validate a request payload and return ``(text, speaker_id, sampling_rate)``.

        Accepts either the nested form
        ``{"inputs": {"text": ..., "speaker_id": ..., "sampling_rate": ...}}``
        or the plain form ``{"inputs": "<text>"}``. ``speaker_id`` defaults to
        0 and ``sampling_rate`` to 24000.

        Raises:
            ValueError: with a user-facing message when the payload is invalid.
        """
        request = data.get("inputs", {}) if isinstance(data, dict) else {}
        if isinstance(request, str):
            # Plain-string payloads carry only the text.
            request = {"text": request}
        elif not isinstance(request, dict):
            request = {}

        text = request.get("text")
        if text is not None and not isinstance(text, str):
            raise ValueError("'text' must be a string.")
        if not text or not text.strip():
            raise ValueError("Missing 'text' parameter inside 'inputs'.")

        # Validated as an int so it cannot inject extra "[...]" speaker tags.
        speaker_id = cls._parse_int(request.get("speaker_id", 0), "speaker_id")

        # NOTE(review): this only labels the WAV header, it does not resample.
        # A value other than the model's native output rate (presumably the
        # 24000 default) will change playback speed — confirm before exposing.
        sampling_rate = cls._parse_int(
            request.get("sampling_rate", 24000), "sampling_rate"
        )
        if sampling_rate <= 0:
            raise ValueError("'sampling_rate' must be a positive integer.")

        return text, speaker_id, sampling_rate

    @staticmethod
    def _encode_wav_base64(audio, sampling_rate: int) -> str:
        """Write ``audio`` to an in-memory WAV file and return it base64-encoded."""
        buffer = io.BytesIO()
        sf.write(buffer, audio, sampling_rate, format="WAV")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def __call__(self, data: dict) -> dict:
        """Synthesize speech for one request.

        Args:
            data: Request payload; see ``_parse_request`` for accepted shapes.

        Returns:
            ``{"input_text", "audio_base64", "sampling_rate"}`` on success, or
            ``{"error": <message>}`` when the payload is invalid.
        """
        try:
            text, speaker_id, sampling_rate = self._parse_request(data)
        except ValueError as err:
            return {"error": str(err)}

        # The prompt is prefixed with the speaker tag, e.g. "[0]Bonjou".
        input_text = f"[{speaker_id}]{text}"

        inputs = self.processor(input_text, add_special_tokens=True).to(self.device)
        output = self.model.generate(**inputs, output_audio=True)
        # Cast to float32 on CPU so soundfile can consume it as a NumPy array.
        audio = output[0].to(torch.float32).cpu().numpy()

        return {
            "input_text": input_text,
            "audio_base64": self._encode_wav_base64(audio, sampling_rate),
            "sampling_rate": sampling_rate,
        }