Update handler.py
Browse files- handler.py +10 -9
handler.py
CHANGED
|
@@ -10,25 +10,26 @@ class EndpointHandler:
|
|
| 10 |
self.model = CsmForConditionalGeneration.from_pretrained(model_path, device_map=device)
|
| 11 |
|
| 12 |
def __call__(self, data: dict) -> dict:
|
|
|
|
|
|
|
|
|
|
| 13 |
# Extract input values
|
| 14 |
-
text =
|
| 15 |
if not text:
|
| 16 |
-
return {"error": "Missing 'text' parameter
|
| 17 |
|
| 18 |
-
speaker_id =
|
| 19 |
-
sampling_rate =
|
| 20 |
|
| 21 |
-
#
|
| 22 |
input_text = f"[{speaker_id}]{text}"
|
| 23 |
|
| 24 |
-
# Tokenize
|
| 25 |
inputs = self.processor(input_text, add_special_tokens=True).to(self.device)
|
| 26 |
-
|
| 27 |
-
# Generate audio
|
| 28 |
output = self.model.generate(**inputs, output_audio=True)
|
| 29 |
audio_tensor = output[0].to(torch.float32).cpu().numpy()
|
| 30 |
|
| 31 |
-
#
|
| 32 |
import io, base64
|
| 33 |
buffer = io.BytesIO()
|
| 34 |
sf.write(buffer, audio_tensor, sampling_rate, format="WAV")
|
|
|
|
| 10 |
self.model = CsmForConditionalGeneration.from_pretrained(model_path, device_map=device)
|
| 11 |
|
| 12 |
def __call__(self, data: dict) -> dict:
|
| 13 |
+
# Get nested input dict
|
| 14 |
+
input_data = data.get("inputs", {})
|
| 15 |
+
|
| 16 |
# Extract input values
|
| 17 |
+
text = input_data.get("text")
|
| 18 |
if not text:
|
| 19 |
+
return {"error": "Missing 'text' parameter inside 'inputs'."}
|
| 20 |
|
| 21 |
+
speaker_id = input_data.get("speaker_id", 0)
|
| 22 |
+
sampling_rate = input_data.get("sampling_rate", 24000)
|
| 23 |
|
| 24 |
+
# Format input text with speaker token
|
| 25 |
input_text = f"[{speaker_id}]{text}"
|
| 26 |
|
| 27 |
+
# Tokenize and generate
|
| 28 |
inputs = self.processor(input_text, add_special_tokens=True).to(self.device)
|
|
|
|
|
|
|
| 29 |
output = self.model.generate(**inputs, output_audio=True)
|
| 30 |
audio_tensor = output[0].to(torch.float32).cpu().numpy()
|
| 31 |
|
| 32 |
+
# Encode audio to base64 WAV
|
| 33 |
import io, base64
|
| 34 |
buffer = io.BytesIO()
|
| 35 |
sf.write(buffer, audio_tensor, sampling_rate, format="WAV")
|