import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, set_seed
import soundfile as sf
import base64
import logging

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

|
class EndpointHandler:
    def __init__(self, path=""):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # float16 halves memory on GPU; fall back to float32 on CPU, where
        # half-precision inference is poorly supported.
        dtype = torch.float16 if self.device != "cpu" else torch.float32
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            "parler-tts/parler-tts-mini-expresso",
            torch_dtype=dtype,
        ).to(self.device)

        self.tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")
|
    def __call__(self, data):
        inputs = data["inputs"]
        prompt = inputs["prompt"]
        description = inputs["description"]

        # The description conditions the voice and style; the prompt is the text to speak.
        input_ids = self.tokenizer(description, return_tensors="pt").input_ids.to(self.device)
        prompt_input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)

        # Fix the seed so repeated requests with the same inputs produce the same audio.
        set_seed(42)
        generation = self.model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)

        # soundfile cannot write float16 arrays, so cast to float32 before saving.
        audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()
        sf.write("parler_tts_out.wav", audio_arr, self.model.config.sampling_rate)

        # Return the WAV bytes as a base64-encoded string.
        with open("parler_tts_out.wav", "rb") as f:
            return base64.b64encode(f.read()).decode()
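
# A minimal local smoke test, as a sketch only. It assumes the endpoint wraps the
# request JSON as {"inputs": {"prompt": ..., "description": ...}}, matching what
# __call__ expects above; the prompt and description strings are illustrative.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": {
            "prompt": "Hello, this is a test of Parler-TTS.",
            "description": "Thomas speaks in a calm, clear voice with no background noise.",
        }
    }
    audio_b64 = handler(payload)
    # Decode the base64 response and write it out to verify the round trip.
    with open("smoke_test.wav", "wb") as f:
        f.write(base64.b64decode(audio_b64))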