|
|
import base64
import io

import soundfile as sf
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
|
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler for the Sesame CSM Creole text-to-speech model.

    Loads the processor and model once at startup, then turns incoming text
    requests into base64-encoded WAV audio.
    """

    def __init__(self, model_path: str = "jsbeaudry/sesame-creole-tts"):
        """Load the processor and model, placing the model on GPU when available.

        Args:
            model_path: Hugging Face model id or local path of the CSM checkpoint.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = CsmForConditionalGeneration.from_pretrained(model_path, device_map=device)

    def __call__(self, data: dict) -> dict:
        """Generate speech audio for a request payload.

        Accepted payload shapes:
            {"inputs": {"text": "...", "speaker_id": 0, "sampling_rate": 24000}}
            {"inputs": "..."}  # bare text; defaults applied

        Returns:
            On success: {"input_text", "audio_base64" (WAV, base64), "sampling_rate"}.
            On invalid payload: {"error": <message>}.
        """
        input_data = data.get("inputs", {})

        # The standard Inference-API payload sends "inputs" as a bare string;
        # previously this crashed with AttributeError on .get(). Accept it.
        if isinstance(input_data, str):
            input_data = {"text": input_data}
        if not isinstance(input_data, dict):
            return {"error": "'inputs' must be a string or an object."}

        text = input_data.get("text")
        if not text:
            return {"error": "Missing 'text' parameter inside 'inputs'."}

        speaker_id = input_data.get("speaker_id", 0)
        sampling_rate = input_data.get("sampling_rate", 24000)

        # CSM expects the speaker tag prepended to the prompt, e.g. "[0]Hello".
        input_text = f"[{speaker_id}]{text}"

        inputs = self.processor(input_text, add_special_tokens=True).to(self.device)
        output = self.model.generate(**inputs, output_audio=True)
        # Move the generated waveform to CPU float32 so soundfile can consume it.
        audio_tensor = output[0].to(torch.float32).cpu().numpy()

        # Encode the waveform as an in-memory WAV file, then base64 for JSON transport.
        buffer = io.BytesIO()
        sf.write(buffer, audio_tensor, sampling_rate, format="WAV")
        buffer.seek(0)
        audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")

        return {
            "input_text": input_text,
            "audio_base64": audio_base64,
            "sampling_rate": sampling_rate,
        }
|
|
|