File size: 1,619 Bytes
7b59fd0 2618cfe 7b59fd0 2618cfe 7b59fd0 4c00d6d 1022f85 4c00d6d 1022f85 f908e68 4c00d6d f908e68 7283d38 4c00d6d f908e68 2d831b9 f908e68 2618cfe 4c00d6d 0d1bfa2 7b59fd0 f908e68 4c00d6d 7b59fd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import base64
import io

import torch
import soundfile as sf
from transformers import CsmForConditionalGeneration, AutoProcessor
class EndpointHandler:
    """Inference-endpoint handler for the Sesame CSM Creole TTS model.

    Loads the model/processor once at startup and, per request, turns a
    text payload into base64-encoded WAV audio.

    Expected request shape (all inside ``data["inputs"]``):
        text (str, required): the text to synthesize.
        speaker_id (int, optional, default 0): CSM speaker token id.
        sampling_rate (int, optional, default 24000): WAV sample rate.
    """

    def __init__(self, model_path: str = "jsbeaudry/sesame-creole-tts"):
        # Prefer GPU when present; device_map places the weights there directly.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = CsmForConditionalGeneration.from_pretrained(model_path, device_map=device)

    def __call__(self, data: dict) -> dict:
        """Synthesize speech for ``data["inputs"]["text"]``.

        Returns a dict with ``input_text``, ``audio_base64`` (WAV bytes,
        base64-encoded) and ``sampling_rate``, or ``{"error": ...}`` on
        malformed input.
        """
        # Tolerate a missing, null, or non-dict "inputs" field instead of
        # raising AttributeError (which would surface as an unhandled 500).
        input_data = data.get("inputs") or {}
        if not isinstance(input_data, dict):
            return {"error": "'inputs' must be a JSON object."}

        text = input_data.get("text")
        if not text:
            return {"error": "Missing 'text' parameter inside 'inputs'."}

        speaker_id = input_data.get("speaker_id", 0)
        sampling_rate = input_data.get("sampling_rate", 24000)

        # CSM expects the speaker id encoded as a leading "[id]" token.
        input_text = f"[{speaker_id}]{text}"

        inputs = self.processor(input_text, add_special_tokens=True).to(self.device)
        # Pure inference: disable autograd tracking to cut memory and time.
        with torch.inference_mode():
            output = self.model.generate(**inputs, output_audio=True)
        # First (only) generated waveform; soundfile needs float32 on CPU.
        audio_tensor = output[0].to(torch.float32).cpu().numpy()

        # Encode the waveform as an in-memory WAV, then base64 for JSON transport.
        buffer = io.BytesIO()
        sf.write(buffer, audio_tensor, sampling_rate, format="WAV")
        buffer.seek(0)
        audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")

        return {
            "input_text": input_text,
            "audio_base64": audio_base64,
            "sampling_rate": sampling_rate,
        }
|