xcczach committed
Commit 80c9a1c · verified · 1 Parent(s): af85ab9

Upload model

Files changed (4)
  1. config.json +2 -2
  2. configuration_yags.py +25 -0
  3. modeling_yags.py +648 -0
  4. pytorch_model.bin +1 -1
config.json CHANGED
@@ -377,8 +377,8 @@
     "GPTSoVITSModel"
   ],
   "auto_map": {
-    "AutoConfig": "configuration_gpt_sovits.GPTSoVITSConfig",
-    "AutoModel": "modeling_gpt_sovits.GPTSoVITSModel"
+    "AutoConfig": "configuration_yags.GPTSoVITSConfig",
+    "AutoModel": "modeling_yags.GPTSoVITSModel"
   },
   "model_type": "gpt_sovits",
   "prompt_language": "zh",
configuration_yags.py ADDED
@@ -0,0 +1,25 @@
+ from transformers import PretrainedConfig
+ import torch
+
+
+ class GPTSoVITSConfig(PretrainedConfig):
+     model_type = "gpt_sovits"
+
+     def __init__(
+         self,
+         prompt_language: str = "zh",
+         _hubert_config_dict: dict[str, any] = None,
+         _hubert_extractor_config_dict: dict[str, any] = None,
+         _bert_config_dict: dict[str, any] = None,
+         _hps_dict: dict[str, any] = None,
+         _gpt_config_dict: dict[str, any] = None,
+         **kwargs
+     ):
+         self.prompt_language = prompt_language
+         self._hubert_config_dict = _hubert_config_dict
+         self._hubert_extractor_config_dict = _hubert_extractor_config_dict
+         self._bert_config_dict = _bert_config_dict
+         self._hps_dict = _hps_dict
+         self._gpt_config_dict = _gpt_config_dict
+
+         super().__init__(**kwargs)
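Note: the configuration only stores the prompt language plus nested sub-config dicts for the cnhubert, BERT, SoVITS (hps) and GPT components. A hedged sketch of building and saving it; the empty dicts are placeholders and would normally be exported from the original GPT-SoVITS checkpoints:

from configuration_yags import GPTSoVITSConfig  # assumes the file is importable locally

config = GPTSoVITSConfig(
    prompt_language="zh",
    _hubert_config_dict={},            # placeholder, exported from the cnhubert checkpoint in practice
    _hubert_extractor_config_dict={},  # placeholder
    _bert_config_dict={},              # placeholder
    _hps_dict={},                      # placeholder
    _gpt_config_dict={},               # placeholder
)
config.save_pretrained("./gpt_sovits_export")  # writes a config.json with model_type "gpt_sovits"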
modeling_yags.py ADDED
@@ -0,0 +1,648 @@
+ from transformers import PreTrainedModel
+ from .configuration_yags import GPTSoVITSConfig
+
+ import os
+ import re
+ import LangSegment
+ import torch
+ import librosa
+ import numpy as np
+ import soundfile as sf
+ from transformers import AutoModelForMaskedLM, BertConfig
+
+ from .t2s_lightning_module import \
+     Text2SemanticLightningModule
+ from . import cnhubert
+ from .mel_processing import spectrogram_torch
+ # from io import BytesIO
+ from .models import SynthesizerTrn
+ from .my_utils import load_audio
+ from .symbols import cleaned_text_to_sequence
+ from .cleaner import clean_text
+
+ from huggingface_hub import hf_hub_download
+
+ class DictToAttrRecursive(dict):
+     def __init__(self, input_dict):
+         super().__init__(input_dict)
+         for key, value in input_dict.items():
+             if isinstance(value, dict):
+                 value = DictToAttrRecursive(value)
+             self[key] = value
+             setattr(self, key, value)
+
+     def __getattr__(self, item):
+         try:
+             return self[item]
+         except KeyError:
+             raise AttributeError(f"Attribute {item} not found")
+
+     def __setattr__(self, key, value):
+         if isinstance(value, dict):
+             value = DictToAttrRecursive(value)
+         super(DictToAttrRecursive, self).__setitem__(key, value)
+         super().__setattr__(key, value)
+
+     def __delattr__(self, item):
+         try:
+             del self[item]
+         except KeyError:
+             raise AttributeError(f"Attribute {item} not found")
+
+ dict_language = {
+     "中文": "all_zh",  # treat all text as Chinese
+     "英文": "en",  # treat all text as English  ### unchanged
+     "日文": "all_ja",  # treat all text as Japanese
+     "中英混合": "zh",  # mixed Chinese/English  ### unchanged
+     "日英混合": "ja",  # mixed Japanese/English  ### unchanged
+     "多语种混合": "auto",  # multilingual: detect the language of each segment
+     "ZH": "zh",
+     "EN": "en",
+     "JA": "ja",
+     "zh": "zh",
+     "en": "en",
+     "ja": "ja",
+     "all_zh": "all_zh",  # added manually, just in case
+     "all_ja": "all_ja",  # added manually, just in case
+     "auto": "auto"  # added manually, just in case
+ }
+
+ splits = {
+     ",",
+     "。",
+     "?",
+     "!",
+     ",",
+     ".",
+     "?",
+     "!",
+     "~",
+     ":",
+     ":",
+     "—",
+     "…",
+ }  # ellipses are not handled
+
+ def splite_en_inf(sentence, language):
+     pattern = re.compile(r'[a-zA-Z ]+')
+     textlist = []
+     langlist = []
+     pos = 0
+     for match in pattern.finditer(sentence):
+         start, end = match.span()
+         if start > pos:
+             textlist.append(sentence[pos:start])
+             langlist.append(language)
+         textlist.append(sentence[start:end])
+         langlist.append("en")
+         pos = end
+     if pos < len(sentence):
+         textlist.append(sentence[pos:])
+         langlist.append(language)
+     # Merge punctuation into previous word
+     for i in range(len(textlist)-1, 0, -1):
+         if re.match(r'^[\W_]+$', textlist[i]):
+             textlist[i-1] += textlist[i]
+             del textlist[i]
+             del langlist[i]
+     # Merge consecutive words with the same language tag
+     i = 0
+     while i < len(langlist) - 1:
+         if langlist[i] == langlist[i+1]:
+             textlist[i] += textlist[i+1]
+             del textlist[i+1]
+             del langlist[i+1]
+         else:
+             i += 1
+
+     return textlist, langlist
+
+ def clean_text_inf(text, language):
+     formattext = ""
+     language = language.replace("all_", "")
+     for tmp in LangSegment.getTexts(text):
+         if language == "ja":
+             if tmp["lang"] == language or tmp["lang"] == "zh":
+                 formattext += tmp["text"] + " "
+             continue
+         if tmp["lang"] == language:
+             formattext += tmp["text"] + " "
+     while "  " in formattext:
+         formattext = formattext.replace("  ", " ")
+     phones, word2ph, norm_text = clean_text(formattext, language)
+     phones = cleaned_text_to_sequence(phones)
+     return phones, word2ph, norm_text
+
+
+
+ def nonen_clean_text_inf(text, language):
+     if (language != "auto"):
+         textlist, langlist = splite_en_inf(text, language)
+     else:
+         textlist = []
+         langlist = []
+         for tmp in LangSegment.getTexts(text):
+             langlist.append(tmp["lang"])
+             textlist.append(tmp["text"])
+     phones_list = []
+     word2ph_list = []
+     norm_text_list = []
+     for i in range(len(textlist)):
+         lang = langlist[i]
+         phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
+         phones_list.append(phones)
+         if lang == "zh":
+             word2ph_list.append(word2ph)
+         norm_text_list.append(norm_text)
+     # [log] print(word2ph_list)
+     phones = sum(phones_list, [])
+     word2ph = sum(word2ph_list, [])
+     norm_text = ' '.join(norm_text_list)
+
+     return phones, word2ph, norm_text
+
+ def get_first(text):
+     pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
+     text = re.split(pattern, text)[0].strip()
+     return text
+
+ def merge_short_text_in_array(texts, threshold):
+     if (len(texts)) < 2:
+         return texts
+     result = []
+     text = ""
+     for ele in texts:
+         text += ele
+         if len(text) >= threshold:
+             result.append(text)
+             text = ""
+     if (len(text) > 0):
+         if len(result) == 0:
+             result.append(text)
+         else:
+             result[len(result) - 1] += text
+     return result
+
+ # ====== Splitting the input text ======
+
+ def split(todo_text):
+     """
+     Split a long text at punctuation marks and return the segments (each keeping its trailing punctuation) as a list.
+     """
+     todo_text = todo_text.replace("……", "。").replace("——", ",")
+     if todo_text[-1] not in splits:
+         todo_text += "。"
+     i_split_head = i_split_tail = 0
+     len_text = len(todo_text)
+     todo_texts = []
+     while 1:
+         if i_split_head >= len_text:
+             break  # the text is guaranteed to end with punctuation, so the last segment was already appended
+         if todo_text[i_split_head] in splits:
+             i_split_head += 1
+             todo_texts.append(todo_text[i_split_tail:i_split_head])
+             i_split_tail = i_split_head
+         else:
+             i_split_head += 1
+     return todo_texts
+
+
+ def cut1(inp):
+     """
+     Cutting method 1: split the text with split() above, then batch four segments per inference call.
+     """
+     inp = inp.strip("\n")
+     inps = split(inp)
+     split_idx = list(range(0, len(inps), 4))
+     split_idx[-1] = None
+     if len(split_idx) > 1:
+         opts = []
+         for idx in range(len(split_idx) - 1):
+             opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
+     else:
+         opts = [inp]
+     return "\n".join(opts)
+
+
+ def cut2(inp):
+     """
+     Cutting method 2: split the text with split() above, then batch roughly 50 characters per inference call.
+     """
+     inp = inp.strip("\n")
+     inps = split(inp)
+     if len(inps) < 2:
+         return [inp]
+     opts = []
+     summ = 0
+     tmp_str = ""
+     for i in range(len(inps)):
+         summ += len(inps[i])
+         tmp_str += inps[i]
+         if summ > 50:
+             summ = 0
+             opts.append(tmp_str)
+             tmp_str = ""
+     if tmp_str != "":
+         opts.append(tmp_str)
+     if len(opts[-1]) < 50:  # if the last chunk is too short, merge it into the previous one
+         opts[-2] = opts[-2] + opts[-1]
+         opts = opts[:-1]
+     return "\n".join(opts)
+
+ def cut3(inp):
+     """
+     Cutting method 3: split only at Chinese full stops ("。").
+     """
+     inp = inp.strip("\n")
+     return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
+
+ # Two newly added cutting methods
+
+ def cut4(inp):
+     """
+     Cutting method 4: split at English full stops (".").
+     """
+     inp = inp.strip("\n")
+     return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
+
+
+ # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
+ def cut5(inp):
+     """
+     Cutting method 5: split at any punctuation mark.
+     """
+     # if not re.search(r'[^\w\s]', inp[-1]):
+     #     inp += '。'
+     inp = inp.strip("\n")
+     punds = r'[,.;?!、,。?!;:…]'
+     items = re.split(f'({punds})', inp)
+     mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
+     # keep the text intact when there is no punctuation, or none at the very end
+     if len(items) % 2 == 1:
+         mergeitems.append(items[-1])
+     opt = "\n".join(mergeitems)
+     return opt
+
+ def get_spepc(hps, filename):
+     audio = load_audio(filename, int(hps.data.sampling_rate))
+     audio = torch.FloatTensor(audio)
+     audio_norm = audio
+     audio_norm = audio_norm.unsqueeze(0)
+     spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
+                              hps.data.win_length, center=False)
+     return spec
+
+ class GPTSoVITSModel(PreTrainedModel):
+     config_class = GPTSoVITSConfig
+
+     def __init__(self, config: GPTSoVITSConfig):
+         super().__init__(config)
+         self.name_or_path = config.name_or_path
+         current_dir = os.path.dirname(os.path.abspath(__file__))
+         try:
+             for file in ["opencpop-strict.txt", "cmudict-fast.rep", "cmudict.rep", "engdict-hot.rep"]:
+                 hf_hub_download(
+                     repo_id=self.name_or_path,
+                     filename=file,
+                     repo_type="model",
+                     local_dir=current_dir
+                 )
+         except:
+             print("Download not executed: maybe under dev mode, please put the files in current directory")
+             pass
+
+         self.prompt_language = config.prompt_language
+
+         self.ssl_model = cnhubert.CNHubert(config._hubert_config_dict, config._hubert_extractor_config_dict)
+         self.bert_model = AutoModelForMaskedLM.from_config(BertConfig.from_dict(config._bert_config_dict))
+         self.hps = DictToAttrRecursive(config._hps_dict)
+         self.hps.model.semantic_frame_rate = "25hz"
+         self.gpt_config = config._gpt_config_dict
+         self.vq_model = SynthesizerTrn(
+             self.hps.data.filter_length // 2 + 1,
+             self.hps.train.segment_size // self.hps.data.hop_length,
+             n_speakers=self.hps.data.n_speakers,
+             **self.hps.model)
+         self.t2s_model = Text2SemanticLightningModule(self.gpt_config, "ojbk", is_train=False)
+         try:
+             self.ref_wav_path = hf_hub_download(
+                 repo_id=self.name_or_path,
+                 filename="ref.wav",
+                 repo_type="model",
+                 local_dir=current_dir
+             )
+             self.prompt_text_path = hf_hub_download(
+                 repo_id=self.name_or_path,
+                 filename="ref.txt",
+                 repo_type="model",
+                 local_dir=current_dir
+             )
+         except:
+             self.ref_wav_path = os.path.join(current_dir, "ref.wav")
+             self.prompt_text_path = os.path.join(current_dir, "ref.txt")
+             print("Download not executed: maybe under dev mode, please put the files in current directory")
+         self.refer = get_spepc(self.hps, self.ref_wav_path)
+
+
+
+
+     def get_cleaned_text_final(self, text, language):
+         """
+         Pick the text-cleaning routine for the given language and return the phoneme sequence, the word-to-phoneme mapping and the normalized text.
+         -> phones, word2ph, norm_text
+         - clean_text_inf handles single-language input {"en", "all_zh", "all_ja"}
+         - clean_text and cleaned_text_to_sequence come from the internal text module (cleaner and __init__)
+         - nonen_clean_text_inf handles mixed-language input {"zh", "ja", "auto"}
+         - splite_en_inf
+         """
+         if language in {"en", "all_zh", "all_ja"}:
+             phones, word2ph, norm_text = clean_text_inf(text, language)
+         elif language in {"zh", "ja", "auto"}:
+             phones, word2ph, norm_text = nonen_clean_text_inf(text, language)
+         return phones, word2ph, norm_text
+
+     def get_bert_inf(self, phones, word2ph, norm_text, language):
+         device = self.device  # [added]
+         is_half = self.dtype == torch.float16  # [added]
+
+         language = language.replace("all_", "")
+         if language == "zh":
+             bert = self.get_bert_feature(norm_text, word2ph).to(device)  # .to(dtype)
+         else:
+             bert = torch.zeros(
+                 (1024, len(phones)),
+                 dtype=torch.float16 if is_half == True else torch.float32,
+             ).to(device)
+
+         return bert
+
+     def get_bert_feature(self, text, word2ph, tokenizer):
+
+         is_half = self.dtype == torch.float16  # [added]
+         device = self.device  # [added]
+
+         with torch.no_grad():
+             inputs = tokenizer(text, return_tensors="pt")
+             for i in inputs:
+                 inputs[i] = inputs[i].to(device)  # inputs are integer ids, so precision simply follows bert_model
+             res = self.bert_model(**inputs, output_hidden_states=True)
+             res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+         assert len(word2ph) == len(text)
+         phone_level_feature = []
+         for i in range(len(word2ph)):
+             repeat_feature = res[i].repeat(word2ph[i], 1)
+             phone_level_feature.append(repeat_feature)
+         phone_level_feature = torch.cat(phone_level_feature, dim=0)
+         if (is_half == True): phone_level_feature = phone_level_feature.half()
+
+         return phone_level_feature.T
+
+     # ====== Mixed-language support ======
+     # ===
+     def get_cleaned_text_final(self, text, language):
+         """
+         Pick the text-cleaning routine for the given language and return the phoneme sequence, the word-to-phoneme mapping and the normalized text.
+         -> phones, word2ph, norm_text
+         - clean_text_inf handles single-language input {"en", "all_zh", "all_ja"}
+         - clean_text and cleaned_text_to_sequence come from the internal text module (cleaner and __init__)
+         - nonen_clean_text_inf handles mixed-language input {"zh", "ja", "auto"}
+         - splite_en_inf
+         """
+         if language in {"en", "all_zh", "all_ja"}:
+             phones, word2ph, norm_text = clean_text_inf(text, language)
+         elif language in {"zh", "ja", "auto"}:
+             phones, word2ph, norm_text = nonen_clean_text_inf(text, language)
+         return phones, word2ph, norm_text
+
+     def get_bert_inf(self, phones, word2ph, norm_text, language, tokenizer):
+         device = self.device  # [added]
+         is_half = self.dtype == torch.float16  # [added]
+
+         language = language.replace("all_", "")
+         if language == "zh":
+             bert = self.get_bert_feature(norm_text, word2ph, tokenizer).to(device)  # .to(dtype)
+         else:
+             bert = torch.zeros(
+                 (1024, len(phones)),
+                 dtype=torch.float16 if is_half == True else torch.float32,
+             ).to(device)
+
+         return bert
+
+     def nonen_get_bert_inf(self, text, language, tokenizer):
+         if (language != "auto"):
+             textlist, langlist = splite_en_inf(text, language)
+         else:
+             textlist = []
+             langlist = []
+             for tmp in LangSegment.getTexts(text):
+                 langlist.append(tmp["lang"])
+                 textlist.append(tmp["text"])
+         print(textlist)
+         print(langlist)
+         bert_list = []
+         for i in range(len(textlist)):
+             lang = langlist[i]
+             phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
+             bert = self.get_bert_inf(phones, word2ph, norm_text, lang, tokenizer)
+             bert_list.append(bert)
+         bert = torch.cat(bert_list, dim=1)
+
+         return bert
+
+     def get_bert_final(self, phones, word2ph, text, language, tokenizer):
+         """
+         Pick the routine for the given language and return a BERT representation.
+         Expects the text material produced by get_cleaned_text_final.
+         -> bert
+         - get_bert_inf for pure English "en"
+         - nonen_get_bert_inf for mixed languages {"zh", "ja", "auto"}
+         - get_bert_feature for pure Chinese "all_zh"
+         """
+         device = self.device  # [added]
+
+         if language == "en":
+             bert = self.get_bert_inf(phones, word2ph, text, language, tokenizer)  # [added]
+         elif language in {"zh", "ja", "auto"}:
+             bert = self.nonen_get_bert_inf(text, language, tokenizer)
+         elif language == "all_zh":
+             bert = self.get_bert_feature(text, word2ph, tokenizer).to(device)
+         else:
+             bert = torch.zeros((1024, len(phones))).to(device)
+         return bert
+
+     # ===
+     # ====== Mixed-language support ======
+
+     def infer(self, text, tokenizer, text_language="zh",
+               how_to_cut="凑四句一切",
+               top_k=20, top_p=0.6, temperature=0.6,
+               # on these three parameters see https://github.com/RVC-Boss/GPT-SoVITS/pull/457
+               # lowering temperature, top_p and top_k makes the output more consistent
+               ref_free=False) -> tuple[np.ndarray, float | int]:  # ref_free: run inference without knowing the reference transcript
+
+         # ====== local variables ======
+         # ===
+         # pick the model and the reference audio for the requested voice
+         ref_wav_path = self.ref_wav_path
+
+         if not ref_free:
+             prompt_text_path = self.prompt_text_path
+             with open(prompt_text_path, 'r', encoding='utf-8') as file:
+                 prompt_text = file.read()
+             # if the transcript in the txt file is empty, do not use a reference transcript either
+             if prompt_text is None or len(prompt_text) == 0:
+                 ref_free = True
+         prompt_language = self.prompt_language
+
+
+         device = self.device
+         is_half = self.dtype == torch.float16
+         dtype = self.dtype
+
+         hz = 50
+         max_sec = self.gpt_config['data']['max_sec']
+         # ===
+         # ====== local variables ======
+
+
+         # resolve the languages of the reference audio and the target text (not strictly necessary, since prompt_language and text_language are already restricted)
+         prompt_language = dict_language[prompt_language]
+         text_language = dict_language[text_language]
+
+         if not ref_free:
+             prompt_text = prompt_text.strip("\n")
+             if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
+             # [log] print("actual reference text:", prompt_text)
+
+         # preprocess the target text: if its first segment (get_first) is very short (< 4 characters), prepend a full stop
+         text = text.strip("\n")
+         if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
+
+         # [log] print("actual target text:", text)
+
+         # create a short stretch of silence
+         # the first torch.no_grad() block extracts semantic tokens from the reference audio with the silence appended at its end -> prompt_semantic
+         zero_wav = np.zeros(
+             int(self.hps.data.sampling_rate * 0.3),  # [added]
+             dtype=np.float16 if is_half == True else np.float32,
+         )
+         with torch.no_grad():
+             wav16k, sr = librosa.load(ref_wav_path, sr=16000)
+             if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
+                 raise OSError("The reference audio must be between 3 and 10 seconds long, please use a different one!")
+             wav16k = torch.from_numpy(wav16k)
+             zero_wav_torch = torch.from_numpy(zero_wav)
+             if is_half == True:
+                 wav16k = wav16k.half().to(device)
+                 zero_wav_torch = zero_wav_torch.half().to(device)
+             else:
+                 wav16k = wav16k.to(device)
+                 zero_wav_torch = zero_wav_torch.to(device)
+             wav16k = torch.cat([wav16k, zero_wav_torch])
+             ssl_content = self.ssl_model.model(wav16k.unsqueeze(0))[
+                 "last_hidden_state"
+             ].transpose(
+                 1, 2
+             )  # .float()
+             codes = self.vq_model.extract_latent(ssl_content)
+
+             prompt_semantic = codes[0, 0]
+
+         # split the target text with one of the five cutting methods; "every four sentences" and "at punctuation marks" are the usual choices.
+         # Segments shorter than 5 characters are then merged (merge_short_text_in_array). The result is the list of chunks to synthesize. -> texts
+         if (how_to_cut == "凑四句一切"):
+             text = cut1(text)
+         elif (how_to_cut == "凑50字一切"):
+             text = cut2(text)
+         elif (how_to_cut == "按中文句号。切"):
+             text = cut3(text)
+         elif (how_to_cut == "按英文句号.切"):
+             text = cut4(text)
+         elif (how_to_cut == "按标点符号切"):
+             text = cut5(text)
+         while "\n\n" in text:
+             text = text.replace("\n\n", "\n")
+
+         # [log] print("actual target text (after cutting):", text)
+         texts = text.split("\n")
+         texts = merge_short_text_in_array(texts, 5)
+         audio_opt = []
+         if not ref_free:
+             # clean the reference transcript (get_cleaned_text_final) to get its text material
+             # -> phones1, word2ph1, norm_text1
+             phones1, word2ph1, norm_text1 = self.get_cleaned_text_final(prompt_text, prompt_language)
+             # turn the reference material phones1, word2ph1, norm_text1
+             # into a BERT representation (get_bert_final)
+             # -> bert1
+             bert1 = self.get_bert_final(phones1, word2ph1, norm_text1, prompt_language, tokenizer).to(dtype)
+
+         # for every segment in texts:
+         #   clean the segment (get_cleaned_text_final) to get its text material
+         #   -> phones2, word2ph2, norm_text2
+         #   turn the material phones2, word2ph2, norm_text2
+         #   into a BERT representation (get_bert_final)
+         #   -> bert2
+         for text in texts:
+             # skip empty lines in the target text, which would otherwise raise an error
+             if (len(text.strip()) == 0):
+                 continue
+             if (text[-1] not in splits): text += "。" if text_language != "en" else "."
+             # [log] print("actual target text (per segment):", text)
+             phones2, word2ph2, norm_text2 = self.get_cleaned_text_final(text, text_language)
+             bert2 = self.get_bert_final(phones2, word2ph2, norm_text2, text_language, tokenizer).to(dtype)
+             if not ref_free:
+                 bert = torch.cat([bert1, bert2], 1)
+                 all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
+             else:
+                 bert = bert2
+                 all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
+
+             bert = bert.to(device).unsqueeze(0)
+             all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
+             prompt = prompt_semantic.unsqueeze(0).to(device)
+
+             with torch.no_grad():
+                 # pred_semantic = t2s_model.model.infer(
+                 pred_semantic, idx = self.t2s_model.model.infer_panel(
+                     all_phoneme_ids,
+                     all_phoneme_len,
+                     None if ref_free else prompt,
+                     bert,
+                     # prompt_phone_len=ph_offset,
+                     top_k=top_k,
+                     top_p=top_p,
+                     temperature=temperature,
+                     early_stop_num=hz * max_sec,
+                 )
+
+             # print(pred_semantic.shape, idx)
+             pred_semantic = pred_semantic[:, -idx:].unsqueeze(
+                 0
+             )  # .unsqueeze(0)  # mq needs one extra unsqueeze
+             refer = get_spepc(self.hps, ref_wav_path)  # .to(device)  # [added]
+             if is_half == True:
+                 refer = refer.half().to(device)
+             else:
+                 refer = refer.to(device)
+             # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
+             audio = (
+                 self.vq_model.decode(  # [added]
+                     pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
+                 )
+                 .detach()
+                 .cpu()
+                 .numpy()[0, 0]
+             )  # try reconstructing without the prompt part
+             max_audio = np.abs(audio).max()  # simple guard against 16-bit clipping
+             if max_audio > 1: audio /= max_audio
+             audio_opt.append(audio)
+             audio_opt.append(zero_wav)
+
+         sampling_rate, audio_data = self.hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
+             np.int16
+         )
+
+         # sf.write(wav_save_path, audio_data, sampling_rate, format='wav')
+         torch.cuda.empty_cache()
+         return audio_data, sampling_rate
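Note: infer() is the single inference entry point: it cleans and cuts the target text, extracts semantic tokens from the bundled reference audio, and decodes each chunk to a waveform. A usage sketch under assumptions: the repo id is a placeholder, and a compatible (Chinese) BERT tokenizer is assumed to be loadable from the same repository, since get_bert_feature calls tokenizer(text, return_tensors="pt"):

import soundfile as sf
from transformers import AutoModel, AutoTokenizer

repo_id = "your-username/your-gpt-sovits-repo"  # placeholder repo id
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)  # assumption: tokenizer files ship with the repo

audio, sampling_rate = model.infer(
    "你好,这是一段测试文本。",      # target text to synthesize
    tokenizer,
    text_language="zh",
    how_to_cut="按标点符号切",       # one of the five cutting modes handled above
    top_k=20, top_p=0.6, temperature=0.6,
)
sf.write("output.wav", audio, sampling_rate)  # mirrors the commented-out sf.write call in infer()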
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:3d2d132bb01eae38f54fd4eb0f2fd8087f0ddbeeeace5b81a39c26586db9a8ee
+ oid sha256:8c9291172a30df7f6e38fe99950031f9276550dc850202ab84f426492826fc00
  size 2201587998