ollieollie committed
Commit bff7fc0 · verified · 1 Parent(s): de60788

Update chatterbox/tts_turbo.py

Files changed (1): chatterbox/tts_turbo.py (+293 -186)
chatterbox/tts_turbo.py CHANGED
@@ -1,189 +1,296 @@
- import random
  import os
- import numpy as np
  import torch
- import gradio as gr
- import spaces
- from chatterbox.tts_turbo import ChatterboxTurboTTS
-
- # --- 1. FORCE CPU FOR GLOBAL LOADING ---
- # ZeroGPU forbids CUDA during startup. We only move to CUDA inside the decorated function.
- DEVICE = "cpu"
-
- MODEL = None
-
- EVENT_TAGS = [
-     "[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
-     "[sniff]", "[gasp]", "[chuckle]", "[laugh]"
- ]
-
- CUSTOM_CSS = """
- .tag-container {
-     display: flex !important;
-     flex-wrap: wrap !important;
-     gap: 8px !important;
-     margin-top: 5px !important;
-     margin-bottom: 10px !important;
-     border: none !important;
-     background: transparent !important;
- }
-
- .tag-btn {
-     min-width: fit-content !important;
-     width: auto !important;
-     height: 32px !important;
-     font-size: 13px !important;
-     background: #eef2ff !important;
-     border: 1px solid #c7d2fe !important;
-     color: #3730a3 !important;
-     border-radius: 6px !important;
-     padding: 0 10px !important;
-     margin: 0 !important;
-     box-shadow: none !important;
- }
-
- .tag-btn:hover {
-     background: #c7d2fe !important;
-     transform: translateY(-1px);
- }
- """
-
- INSERT_TAG_JS = """
- (tag_val, current_text) => {
-     const textarea = document.querySelector('#main_textbox textarea');
-     if (!textarea) return current_text + " " + tag_val;
-
-     const start = textarea.selectionStart;
-     const end = textarea.selectionEnd;
-
-     let prefix = " ";
-     let suffix = " ";
-
-     if (start === 0) prefix = "";
-     else if (current_text[start - 1] === ' ') prefix = "";
-
-     if (end < current_text.length && current_text[end] === ' ') suffix = "";
-
-     return current_text.slice(0, start) + prefix + tag_val + suffix + current_text.slice(end);
- }
- """
-
- def set_seed(seed: int):
-     torch.manual_seed(seed)
-     torch.cuda.manual_seed(seed)
-     torch.cuda.manual_seed_all(seed)
-     random.seed(seed)
-     np.random.seed(seed)
-
-
- def load_model():
-     global MODEL
-     print(f"Loading Chatterbox-Turbo on {DEVICE}...")
-     MODEL = ChatterboxTurboTTS.from_pretrained(DEVICE)
-     return MODEL
-
- @spaces.GPU
- def generate(
-     text,
-     audio_prompt_path,
-     temperature,
-     seed_num,
-     min_p,
-     top_p,
-     top_k,
-     repetition_penalty,
-     norm_loudness
- ):
-     global MODEL
-     # Reload if the worker lost the global state
-     if MODEL is None:
-         MODEL = ChatterboxTurboTTS.from_pretrained("cpu")
-
-     # --- MOVE TO GPU HERE ---
-     MODEL.to("cuda")
-
-     if seed_num != 0:
-         set_seed(int(seed_num))
-
-     wav = MODEL.generate(
          text,
-         audio_prompt_path=audio_prompt_path,
-         temperature=temperature,
-         min_p=min_p,
-         top_p=top_p,
-         top_k=int(top_k),
-         repetition_penalty=repetition_penalty,
-         norm_loudness=norm_loudness,
-     )
-
-     return (MODEL.sr, wav.squeeze(0).cpu().numpy())
-
-
- with gr.Blocks(title="Chatterbox Turbo") as demo:
-     gr.Markdown("# Chatterbox Turbo")
-
-     with gr.Row():
-         with gr.Column():
-             text = gr.Textbox(
-                 value="Congratulations Miss Connor! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?",
-                 label="Text to synthesize (max chars 300)",
-                 max_lines=5,
-                 elem_id="main_textbox"
-             )
-
-             with gr.Row(elem_classes=["tag-container"]):
-                 for tag in EVENT_TAGS:
-                     btn = gr.Button(tag, elem_classes=["tag-btn"])
-                     btn.click(
-                         fn=None,
-                         inputs=[btn, text],
-                         outputs=text,
-                         js=INSERT_TAG_JS
-                     )
-
-             ref_wav = gr.Audio(
-                 sources=["upload", "microphone"],
-                 type="filepath",
-                 label="Reference Audio File",
-                 value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_random_podcast.wav"
-             )
-
-             run_btn = gr.Button("Generate ⚡", variant="primary")
-
-         with gr.Column():
-             audio_output = gr.Audio(label="Output Audio")
-
-             with gr.Accordion("Advanced Options", open=False):
-                 seed_num = gr.Number(value=0, label="Random seed (0 for random)")
-                 temp = gr.Slider(0.05, 2.0, step=.05, label="Temperature", value=0.8)
-                 top_p = gr.Slider(0.00, 1.00, step=0.01, label="Top P", value=0.95)
-                 top_k = gr.Slider(0, 1000, step=10, label="Top K", value=1000)
-                 repetition_penalty = gr.Slider(1.00, 2.00, step=0.05, label="Repetition Penalty", value=1.2)
-                 min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
-                 norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (Match prompt volume)")
-
-     # Load on startup (CPU)
-     demo.load(fn=load_model, inputs=[], outputs=[])
-
-     run_btn.click(
-         fn=generate,
-         inputs=[
-             text,
-             ref_wav,
-             temp,
-             seed_num,
-             min_p,
-             top_p,
-             top_k,
-             repetition_penalty,
-             norm_loudness,
-         ],
-         outputs=audio_output,
-     )
-
- if __name__ == "__main__":
-     demo.queue().launch(
-         mcp_server=True,
-         css=CUSTOM_CSS,
-         ssr_mode=False
-     )
  import os
+ import math
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ import librosa
  import torch
+ import perth
+ import pyloudnorm as ln
+
+ from safetensors.torch import load_file
+ from huggingface_hub import snapshot_download
+ from transformers import AutoTokenizer
+
+ from .models.t3 import T3
+ from .models.s3tokenizer import S3_SR
+ from .models.s3gen import S3GEN_SR, S3Gen
+ from .models.tokenizers import EnTokenizer
+ from .models.voice_encoder import VoiceEncoder
+ from .models.t3.modules.cond_enc import T3Cond
+ from .models.t3.modules.t3_config import T3Config
+ from .models.s3gen.const import S3GEN_SIL
+ import logging
+ logger = logging.getLogger(__name__)
+
+ REPO_ID = "ResembleAI/chatterbox-turbo"
+
+
+ def punc_norm(text: str) -> str:
+     """
+     Quick cleanup func for punctuation from LLMs or
+     containing chars not seen often in the dataset
+     """
+     if len(text) == 0:
+         return "You need to add some text for me to talk."
+
+     # Capitalise first letter
+     if text[0].islower():
+         text = text[0].upper() + text[1:]
+
+     # Remove multiple space chars
+     text = " ".join(text.split())
+
+     # Replace uncommon/llm punc
+     punc_to_replace = [
+         ("…", ", "),
+         (":", ","),
+         ("—", "-"),
+         ("–", "-"),
+         (" ,", ","),
+         ("“", "\""),
+         ("”", "\""),
+         ("‘", "'"),
+         ("’", "'"),
+     ]
+     for old_char_sequence, new_char in punc_to_replace:
+         text = text.replace(old_char_sequence, new_char)
+
+     # Add full stop if no ending punc
+     text = text.rstrip(" ")
+     sentence_enders = {".", "!", "?", "-", ","}
+     if not any(text.endswith(p) for p in sentence_enders):
+         text += "."
+
+     return text
+
+
+ @dataclass
+ class Conditionals:
+     """
+     Conditionals for T3 and S3Gen
+     - T3 conditionals:
+         - speaker_emb
+         - clap_emb
+         - cond_prompt_speech_tokens
+         - cond_prompt_speech_emb
+         - emotion_adv
+     - S3Gen conditionals:
+         - prompt_token
+         - prompt_token_len
+         - prompt_feat
+         - prompt_feat_len
+         - embedding
+     """
+     t3: T3Cond
+     gen: dict
+
+     def to(self, device):
+         self.t3 = self.t3.to(device=device)
+         for k, v in self.gen.items():
+             if torch.is_tensor(v):
+                 self.gen[k] = v.to(device=device)
+         return self
+
+     def save(self, fpath: Path):
+         arg_dict = dict(
+             t3=self.t3.__dict__,
+             gen=self.gen
+         )
+         torch.save(arg_dict, fpath)
+
+     @classmethod
+     def load(cls, fpath, map_location="cpu"):
+         if isinstance(map_location, str):
+             map_location = torch.device(map_location)
+         kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
+         return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
+
+
+ class ChatterboxTurboTTS:
+     ENC_COND_LEN = 15 * S3_SR
+     DEC_COND_LEN = 10 * S3GEN_SR
+
+     def __init__(
+         self,
+         t3: T3,
+         s3gen: S3Gen,
+         ve: VoiceEncoder,
+         tokenizer: EnTokenizer,
+         device: str,
+         conds: Conditionals = None,
+     ):
+         self.sr = S3GEN_SR  # sample rate of synthesized audio
+         self.t3 = t3
+         self.s3gen = s3gen
+         self.ve = ve
+         self.tokenizer = tokenizer
+         self.device = device
+         self.conds = conds
+         self.watermarker = perth.PerthImplicitWatermarker()
+
+     @classmethod
+     def from_local(cls, ckpt_dir, device) -> 'ChatterboxTurboTTS':
+         ckpt_dir = Path(ckpt_dir)
+
+         # Always load to CPU first for non-CUDA devices to handle CUDA-saved models
+         if device in ["cpu", "mps"]:
+             map_location = torch.device('cpu')
+         else:
+             map_location = None
+
+         ve = VoiceEncoder()
+         ve.load_state_dict(
+             load_file(ckpt_dir / "ve.safetensors")
+         )
+         ve.to(device).eval()
+
+         # Turbo specific hp
+         hp = T3Config(text_tokens_dict_size=50276)
+         hp.llama_config_name = "GPT2_medium"
+         hp.speech_tokens_dict_size = 6563
+         hp.input_pos_emb = None
+         hp.speech_cond_prompt_len = 375
+         hp.use_perceiver_resampler = False
+         hp.emotion_adv = False
+
+         t3 = T3(hp)
+         t3_state = load_file(ckpt_dir / "t3_turbo_v1.safetensors")
+         if "model" in t3_state.keys():
+             t3_state = t3_state["model"][0]
+         t3.load_state_dict(t3_state)
+         del t3.tfmr.wte
+         t3.to(device).eval()
+
+         s3gen = S3Gen(meanflow=True)
+         weights = load_file(ckpt_dir / "s3gen_meanflow.safetensors")
+         s3gen.load_state_dict(
+             weights, strict=True
+         )
+         s3gen.to(device).eval()
+
+         tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+         if len(tokenizer) != 50276:
+             print(f"WARNING: Tokenizer len {len(tokenizer)} != 50276")
+
+         conds = None
+         builtin_voice = ckpt_dir / "conds.pt"
+         if builtin_voice.exists():
+             conds = Conditionals.load(builtin_voice, map_location=map_location).to(device)
+
+         return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
+
+     @classmethod
+     def from_pretrained(cls, device) -> 'ChatterboxTurboTTS':
+         # Check if MPS is available on macOS
+         if device == "mps" and not torch.backends.mps.is_available():
+             if not torch.backends.mps.is_built():
+                 print("MPS not available because the current PyTorch install was not built with MPS enabled.")
+             else:
+                 print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
+             device = "cpu"
+
+         local_path = snapshot_download(
+             repo_id=REPO_ID,
+             token=os.getenv("HF_TOKEN") or True,
+             # Optional: Filter to download only what you need
+             allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"]
+         )
+
+         return cls.from_local(local_path, device)
+
+     def norm_loudness(self, wav, sr, target_lufs=-27):
+         try:
+             meter = ln.Meter(sr)
+             loudness = meter.integrated_loudness(wav)
+             gain_db = target_lufs - loudness
+             gain_linear = 10.0 ** (gain_db / 20.0)
+             if math.isfinite(gain_linear) and gain_linear > 0.0:
+                 wav = wav * gain_linear
+         except Exception as e:
+             print(f"Warning: Error in norm_loudness, skipping: {e}")
+
+         return wav
+
+     def prepare_conditionals(self, wav_fpath, exaggeration=0.5, norm_loudness=True):
+         ## Load and norm reference wav
+         s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
+
+         assert len(s3gen_ref_wav) / _sr > 5.0, "Audio prompt must be longer than 5 seconds!"
+
+         if norm_loudness:
+             s3gen_ref_wav = self.norm_loudness(s3gen_ref_wav, _sr)
+
+         ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
+
+         s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
+         s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
+
+         # Speech cond prompt tokens
+         if plen := self.t3.hp.speech_cond_prompt_len:
+             s3_tokzr = self.s3gen.tokenizer
+             t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
+             t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
+
+         # Voice-encoder speaker embedding
+         ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
+         ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
+
+         t3_cond = T3Cond(
+             speaker_emb=ve_embed,
+             cond_prompt_speech_tokens=t3_cond_prompt_tokens,
+             emotion_adv=exaggeration * torch.ones(1, 1, 1),
+         ).to(device=self.device)
+         self.conds = Conditionals(t3_cond, s3gen_ref_dict)
+
+     def generate(
+         self,
          text,
+         repetition_penalty=1.2,
+         min_p=0.00,
+         top_p=0.95,
+         audio_prompt_path=None,
+         exaggeration=0.0,
+         cfg_weight=0.0,
+         temperature=0.8,
+         top_k=1000,
+         norm_loudness=True,
+     ):
+         if audio_prompt_path:
+             self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration, norm_loudness=norm_loudness)
+         else:
+             assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
+
+         if cfg_weight > 0.0 or exaggeration > 0.0 or min_p > 0.0:
+             logger.warning("CFG, min_p and exaggeration are not supported by Turbo version and will be ignored.")
+
+         # Norm and tokenize text
+         text = punc_norm(text)
+         text_tokens = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+         text_tokens = text_tokens.input_ids.to(self.device)
+
+         speech_tokens = self.t3.inference_turbo(
+             t3_cond=self.conds.t3,
+             text_tokens=text_tokens,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+         )
+
+         # Remove OOV tokens and add silence to end
+         speech_tokens = speech_tokens[speech_tokens < 6561]
+         speech_tokens = speech_tokens.to(self.device)
+         silence = torch.tensor([S3GEN_SIL, S3GEN_SIL, S3GEN_SIL]).long().to(self.device)
+         speech_tokens = torch.cat([speech_tokens, silence])
+
+         wav, _ = self.s3gen.inference(
+             speech_tokens=speech_tokens,
+             ref_dict=self.conds.gen,
+             n_cfm_timesteps=2,
+         )
+         wav = wav.squeeze(0).detach().cpu().numpy()
+         watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
+         return torch.from_numpy(watermarked_wav).unsqueeze(0)
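
For context, a minimal usage sketch of the API this commit introduces. The file paths (`reference.wav`, `output.wav`) are placeholders, and `torchaudio` is used here only to write the result; neither is part of this commit:

```python
import torchaudio as ta
from chatterbox.tts_turbo import ChatterboxTurboTTS

# Downloads checkpoints from ResembleAI/chatterbox-turbo; set HF_TOKEN if the repo is gated.
model = ChatterboxTurboTTS.from_pretrained(device="cuda")

# The voice prompt must be longer than 5 seconds; generate() returns
# a (1, num_samples) float tensor at model.sr (== S3GEN_SR).
wav = model.generate(
    "Hello from Chatterbox Turbo! [chuckle]",
    audio_prompt_path="reference.wav",  # placeholder path
    temperature=0.8,
    top_p=0.95,
)
ta.save("output.wav", wav, model.sr)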
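```

Since `prepare_conditionals` caches the reference embeddings on `self.conds`, the prompt only needs to be encoded once when synthesizing many lines in the same voice; a sketch under the same assumptions as above:

```python
# Encode the reference clip once, then reuse the cached conditionals.
model.prepare_conditionals("reference.wav", norm_loudness=True)

for i, line in enumerate(["First line.", "Second line."]):
    wav = model.generate(line)  # no audio_prompt_path: falls back to self.conds
    ta.save(f"line_{i}.wav", wav, model.sr)
```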