# Copyright (c) 2023 Amphion.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch

from text import text_to_sequence
from text.symbol_table import SymbolTable
| """ | |
| TextToken: map text to id | |
| """ | |
| # TextTokenCollator is modified from | |
| # https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/data/collation.py | |
| class TextTokenCollator: | |
    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        # Build the vocabulary: <pad> gets id 0, followed by the optional
        # <bos>/<eos> symbols and then the sorted token inventory.
        unique_tokens = [pad_symbol]
        if add_bos:
            unique_tokens.append(bos_symbol)
        if add_eos:
            unique_tokens.append(eos_symbol)
        unique_tokens.extend(sorted(text_tokens))

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = unique_tokens

    def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            assert all(s in self.token2idx for s in tokens)
            # Wrap each sequence with the optional <bos>/<eos> symbols.
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
                + ([self.eos_symbol] if self.add_eos else [])
            )
            seqs.append(seq)
            seq_lens.append(len(seq))

        # Right-pad every sequence to the batch maximum with <pad>.
        max_len = max(seq_lens)
        for seq, seq_len in zip(seqs, seq_lens):
            seq.extend([self.pad_symbol] * (max_len - seq_len))

        tokens = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens
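
    # Illustrative example (an assumption, not from the original source):
    # for a collator built over the tokens {"a", "b"} with bos/eos enabled,
    #   collator.index([["a", "b"], ["a"]])
    # returns a (2, max_len) int64 tensor whose shorter row is padded with
    # <pad>, together with an IntTensor of the unpadded lengths (here 4 and 3).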

    def __call__(self, text) -> Tuple[List[int], int]:
        # Tokenize a single utterance (no padding, plain Python outputs).
        tokens_seq = list(text)
        seq = (
            ([self.bos_symbol] if self.add_bos else [])
            + tokens_seq
            + ([self.eos_symbol] if self.add_eos else [])
        )

        token_ids = [self.token2idx[token] for token in seq]
        # Booleans count as 0/1, so this length includes <bos>/<eos>.
        token_lens = len(tokens_seq) + self.add_eos + self.add_bos

        return token_ids, token_lens


def get_text_token_collater(
    text_tokens_file: str,
) -> Tuple[TextTokenCollator, Dict[str, int]]:
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)
    collater = TextTokenCollator(unique_tokens.symbols, add_bos=True, add_eos=True)
    token2idx = collater.token2idx
    return collater, token2idx


class phoneIDCollation:
    def __init__(self, cfg, dataset=None, symbols_dict_file=None) -> None:
        if cfg.preprocess.phone_extractor != "lexicon":
            # Build a text token collator from the phone symbols dict file.
            if symbols_dict_file is None:
                assert dataset is not None
                symbols_dict_file = os.path.join(
                    cfg.preprocess.processed_dir, dataset, cfg.preprocess.symbols_dict
                )
            self.text_token_collator, token2idx = get_text_token_collater(
                symbols_dict_file
            )

            # Optionally, the symbols dict file can be rewritten so that it
            # also contains the pad symbol and the optional <bos>/<eos>
            # tokens added by TextTokenCollator:
            # phone_symbol_dict = SymbolTable()
            # for s in sorted(set(token2idx.keys())):
            #     phone_symbol_dict.add(s)
            # phone_symbol_dict.to_file(symbols_dict_file)

    def get_phone_id_sequence(self, cfg, phones_seq):
        if cfg.preprocess.phone_extractor == "lexicon":
            # Lexicon phones go through the classic text-cleaner path.
            phones_seq = " ".join(phones_seq)
            sequence = text_to_sequence(phones_seq, cfg.preprocess.text_cleaners)
        else:
            sequence, _ = self.text_token_collator(phones_seq)

        return sequence
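

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original file).
    # The toy phoneme inventory below is an assumption; in practice the
    # inventory is loaded from a symbols dict file via
    # get_text_token_collater().
    collator = TextTokenCollator(["AH", "B", "K"], add_bos=True, add_eos=True)

    # Single utterance: plain Python list of ids plus its length.
    ids, length = collator(["B", "AH", "K"])
    print(ids, length)  # [1, 4, 3, 5, 2] 5

    # Batch of utterances: padded int64 tensor plus the true lengths.
    batch, lens = collator.index([["B", "AH"], ["K"]])
    print(batch.shape, lens)  # torch.Size([2, 4]) tensor([4, 3], dtype=torch.int32)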