jsbeaudry committed on
Commit
7283d38
·
verified ·
1 Parent(s): f908e68

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -9
handler.py CHANGED
@@ -10,25 +10,26 @@ class EndpointHandler:
10
  self.model = CsmForConditionalGeneration.from_pretrained(model_path, device_map=device)
11
 
12
  def __call__(self, data: dict) -> dict:
 
 
 
13
  # Extract input values
14
- text = data.get("text")
15
  if not text:
16
- return {"error": "Missing 'text' parameter in request."}
17
 
18
- speaker_id = data.get("speaker_id", 0) # Optional speaker ID
19
- sampling_rate = data.get("sampling_rate", 24000)
20
 
21
- # Inject speaker ID token into input
22
  input_text = f"[{speaker_id}]{text}"
23
 
24
- # Tokenize input
25
  inputs = self.processor(input_text, add_special_tokens=True).to(self.device)
26
-
27
- # Generate audio
28
  output = self.model.generate(**inputs, output_audio=True)
29
  audio_tensor = output[0].to(torch.float32).cpu().numpy()
30
 
31
- # Convert audio to base64 WAV
32
  import io, base64
33
  buffer = io.BytesIO()
34
  sf.write(buffer, audio_tensor, sampling_rate, format="WAV")
 
10
  self.model = CsmForConditionalGeneration.from_pretrained(model_path, device_map=device)
11
 
12
  def __call__(self, data: dict) -> dict:
13
+ # Get nested input dict
14
+ input_data = data.get("inputs", {})
15
+
16
  # Extract input values
17
+ text = input_data.get("text")
18
  if not text:
19
+ return {"error": "Missing 'text' parameter inside 'inputs'."}
20
 
21
+ speaker_id = input_data.get("speaker_id", 0)
22
+ sampling_rate = input_data.get("sampling_rate", 24000)
23
 
24
+ # Format input text with speaker token
25
  input_text = f"[{speaker_id}]{text}"
26
 
27
+ # Tokenize and generate
28
  inputs = self.processor(input_text, add_special_tokens=True).to(self.device)
 
 
29
  output = self.model.generate(**inputs, output_audio=True)
30
  audio_tensor = output[0].to(torch.float32).cpu().numpy()
31
 
32
+ # Encode audio to base64 WAV
33
  import io, base64
34
  buffer = io.BytesIO()
35
  sf.write(buffer, audio_tensor, sampling_rate, format="WAV")