jsbeaudry committed on
Commit
7b59fd0
·
verified ·
1 Parent(s): 4369244

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +35 -0
handler.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import soundfile as sf
3
+ import numpy as np
4
+ from transformers import CsmForConditionalGeneration, AutoProcessor
5
+
6
class EndpointHandler:
    """Text-to-speech endpoint handler built on a CSM model.

    The processor and model are loaded once in ``__init__``. Each call
    synthesizes one audio clip from a text prompt and returns it as a
    base64-encoded WAV file.
    """

    # Fallbacks used when the request omits the corresponding key.
    DEFAULT_TEXT = "[0]Bonjou tout moun koman nou ye?"
    DEFAULT_SAMPLING_RATE = 24000

    def __init__(self, model_path: str = "jsbeaudry/sesame-creole-tts"):
        """Load the processor and the model from *model_path*.

        Args:
            model_path: Identifier or directory passed to ``from_pretrained``
                for both the processor and the model.
        """
        # Use the GPU when torch can see one, otherwise fall back to CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = CsmForConditionalGeneration.from_pretrained(model_path, device_map=device)

    @staticmethod
    def _check_text(text) -> str:
        """Return *text* unchanged if it is a non-blank string, else raise ValueError."""
        if not isinstance(text, str) or not text.strip():
            raise ValueError(
                f"'inputs' must be a non-empty string, got {text!r}"
            )
        return text

    @staticmethod
    def _check_sampling_rate(value) -> int:
        """Return *value* as a positive ``int`` sampling rate, else raise ValueError.

        Whole-number floats (e.g. ``24000.0``) are accepted and converted;
        booleans, fractional values, non-numbers and non-positive values are
        rejected.
        """
        is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
        # float.is_integer() is False for NaN and infinities as well.
        is_whole = is_number and (isinstance(value, int) or value.is_integer())
        if not is_whole or value <= 0:
            raise ValueError(
                f"'sampling_rate' must be a positive integer, got {value!r}"
            )
        return int(value)

    def __call__(self, data: dict) -> dict:
        """Synthesize speech for one request.

        Args:
            data: Request payload. ``"inputs"`` holds the text prompt and
                ``"sampling_rate"`` the rate written into the WAV header;
                both fall back to the class defaults when absent.

        Returns:
            A dict with the prompt (``"text"``), the WAV file encoded as
            base64 text (``"audio_base64"``) and the sampling rate used
            (``"sampling_rate"``).

        Raises:
            ValueError: If ``"inputs"`` is not a non-empty string or
                ``"sampling_rate"`` is not a positive whole number.
        """
        import base64
        import io

        # Validate before touching the model so a malformed request fails
        # fast instead of after a full (and expensive) generation pass.
        text = self._check_text(data.get("inputs", self.DEFAULT_TEXT))
        sampling_rate = self._check_sampling_rate(
            data.get("sampling_rate", self.DEFAULT_SAMPLING_RATE)
        )

        # Prepare input
        inputs = self.processor(text, add_special_tokens=True).to(self.device)

        # Generate audio; only the first item of the output is returned.
        output = self.model.generate(**inputs, output_audio=True)
        # Move to CPU and cast to float32 so the array has a dtype that
        # soundfile can write.
        audio_tensor = output[0].to(torch.float32).cpu().numpy()

        # Return audio as base64-encoded WAV (binary isn't supported directly in response)
        buffer = io.BytesIO()
        sf.write(buffer, audio_tensor, sampling_rate, format="WAV")
        audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return {
            "text": text,
            "audio_base64": audio_base64,
            "sampling_rate": sampling_rate,
        }