File size: 1,619 Bytes
7b59fd0
 
2618cfe
7b59fd0
 
 
 
 
 
2618cfe
7b59fd0
 
4c00d6d
1022f85
4c00d6d
 
1022f85
f908e68
4c00d6d
f908e68
7283d38
 
4c00d6d
 
f908e68
 
2d831b9
f908e68
2618cfe
 
 
4c00d6d
 
 
 
 
 
0d1bfa2
7b59fd0
f908e68
4c00d6d
7b59fd0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import soundfile as sf
from transformers import CsmForConditionalGeneration, AutoProcessor

class EndpointHandler:
    """Inference-endpoint handler for the Sesame Creole CSM text-to-speech model.

    Loads the processor and model once at construction, then serves
    ``__call__(data)`` requests that return base64-encoded WAV audio.
    """

    def __init__(self, model_path: str = "jsbeaudry/sesame-creole-tts"):
        # Prefer GPU when available; device_map places the weights there directly.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = CsmForConditionalGeneration.from_pretrained(model_path, device_map=device)

    def __call__(self, data: dict) -> dict:
        """Generate speech for the text in ``data["inputs"]``.

        Expected payload: ``{"inputs": {"text": str, "speaker_id": int,
        "sampling_rate": int}}``; a plain ``{"inputs": "some text"}`` string
        form is also accepted. Returns a dict with ``input_text``,
        ``audio_base64`` (WAV) and ``sampling_rate``, or an ``error`` dict
        when no text was supplied.
        """
        input_data = data.get("inputs", {})
        # Accept the common bare-string payload shape {"inputs": "..."} too.
        if isinstance(input_data, str):
            input_data = {"text": input_data}

        text = input_data.get("text")
        if not text:
            return {"error": "Missing 'text' parameter inside 'inputs'."}

        speaker_id = input_data.get("speaker_id", 0)
        sampling_rate = input_data.get("sampling_rate", 24000)

        # CSM expects a leading speaker token of the form "[<id>]".
        input_text = f"[{speaker_id}]{text}"

        # Tokenize and generate. NOTE(review): the processor call relies on the
        # CSM processor returning tensors by default — confirm against the model card.
        inputs = self.processor(input_text, add_special_tokens=True).to(self.device)
        # inference_mode skips autograd bookkeeping during generation.
        with torch.inference_mode():
            output = self.model.generate(**inputs, output_audio=True)
        audio_array = output[0].to(torch.float32).cpu().numpy()

        return {
            "input_text": input_text,
            "audio_base64": self._encode_wav_base64(audio_array, sampling_rate),
            "sampling_rate": sampling_rate,
        }

    @staticmethod
    def _encode_wav_base64(audio, sampling_rate: int) -> str:
        """Encode a float audio array as a base64 string of an in-memory WAV file."""
        import base64
        import io

        buffer = io.BytesIO()
        sf.write(buffer, audio, sampling_rate, format="WAV")
        buffer.seek(0)
        return base64.b64encode(buffer.read()).decode("utf-8")