Spaces:
Running
Running
| import argparse | |
| import shutil | |
| from pathlib import Path | |
| from queue import Queue | |
| from threading import Thread | |
| from typing import Any, Optional | |
| import soundfile as sf | |
| import torch | |
| from tqdm import tqdm | |
| from config import get_path_config | |
| from style_bert_vits2.logging import logger | |
| from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT | |
def is_audio_file(file: Path) -> bool:
    """Return True when ``file`` carries one of the audio extensions handled here.

    Only the (case-insensitive) suffix of the path is inspected; the file
    itself is never opened.
    """
    audio_suffixes = {".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a"}
    suffix = file.suffix.lower()
    return suffix in audio_suffixes
def get_stamps(
    vad_model: Any,
    utils: Any,
    audio_file: Path,
    min_silence_dur_ms: int = 700,
    min_sec: float = 2,
    max_sec: float = 12,
):
    """Run the VAD model over ``audio_file`` and return its speech timestamps.

    Args:
        vad_model: Model object passed unchanged to the VAD timestamp helper.
        utils: Helper collection; its first item is used as the
            speech-timestamp function and its third item as the audio reader.
        audio_file: Path of the audio file to analyse.
        min_silence_dur_ms: Silence of at least this many milliseconds is
            treated as a split point; shorter silences do not split the
            audio. Too small a value chops the audio into tiny pieces, too
            large a value makes each piece too long. Probably needs tuning
            per dataset.
        min_sec: Minimum utterance length in seconds (forwarded in
            milliseconds as ``min_speech_duration_ms``); shorter utterances
            are ignored.
        max_sec: Maximum utterance length in seconds (forwarded as
            ``max_speech_duration_s``). The original note says longer
            utterances are ignored.
            NOTE(review): confirm against the VAD documentation whether
            over-long utterances are dropped or split.

    Returns:
        Whatever the speech-timestamp helper returns for this file.
    """
    detect_speech, _, load_audio, *_ = utils

    # The audio is always read at 16 kHz (the original author notes that only
    # 16 kHz and 8 kHz are supported).
    vad_rate = 16000
    min_speech_ms = int(min_sec * 1000)

    waveform = load_audio(str(audio_file), sampling_rate=vad_rate)
    return detect_speech(
        waveform,
        vad_model,
        sampling_rate=vad_rate,
        min_silence_duration_ms=min_silence_dur_ms,
        min_speech_duration_ms=min_speech_ms,
        max_speech_duration_s=max_sec,
    )
def split_wav(
    vad_model: Any,
    utils: Any,
    audio_file: Path,
    target_dir: Path,
    min_sec: float = 2,
    max_sec: float = 12,
    min_silence_dur_ms: int = 700,
    time_suffix: bool = False,
) -> tuple[float, int]:
    """Cut ``audio_file`` into speech segments and write each as a wav file.

    Speech regions come from ``get_stamps``. Each region is widened by a
    200 ms margin on both sides (clamped to the bounds of the file) and
    written into ``target_dir``, which is created if missing.

    Args:
        vad_model: VAD model forwarded to ``get_stamps``.
        utils: VAD helper collection forwarded to ``get_stamps``.
        audio_file: Audio file to split.
        target_dir: Directory that receives the segment files.
        min_sec: Forwarded to ``get_stamps``.
        max_sec: Forwarded to ``get_stamps``.
        min_silence_dur_ms: Forwarded to ``get_stamps``.
        time_suffix: When True, files are named
            ``{stem}-{start_ms}-{end_ms}.wav``; otherwise ``{stem}-{index}.wav``.

    Returns:
        ``(total_seconds, segment_count)`` for the segments written.
    """
    margin_ms: int = 200  # padding (ms) kept before and after each utterance
    # get_stamps runs the VAD at 16 kHz, so 16 samples correspond to 1 ms.
    vad_samples_per_ms = 16

    stamps = get_stamps(
        vad_model=vad_model,
        utils=utils,
        audio_file=audio_file,
        min_silence_dur_ms=min_silence_dur_ms,
        min_sec=min_sec,
        max_sec=max_sec,
    )

    samples, sr = sf.read(audio_file)
    duration_ms = len(samples) / sr * 1000
    stem = audio_file.stem
    target_dir.mkdir(parents=True, exist_ok=True)

    written = 0
    written_ms: float = 0
    # Cut the audio at every timestamp and save each piece.
    for index, stamp in enumerate(stamps):
        start_ms = max(stamp["start"] / vad_samples_per_ms - margin_ms, 0)
        end_ms = min(stamp["end"] / vad_samples_per_ms + margin_ms, duration_ms)

        # Convert milliseconds back to sample offsets at the file's own rate.
        first = int(start_ms / 1000 * sr)
        last = int(end_ms / 1000 * sr)

        out_name = (
            f"{stem}-{int(start_ms)}-{int(end_ms)}.wav"
            if time_suffix
            else f"{stem}-{index}.wav"
        )
        sf.write(str(target_dir / out_name), samples[first:last], sr)

        written_ms += end_ms - start_ms
        written += 1

    return written_ms / 1000, written
if __name__ == "__main__":
    # Command-line entry point: slice every audio file under --input_dir into
    # utterance-sized wav files below {dataset_root}/{model_name}/raw/.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--min_sec", "-m", type=float, default=2, help="Minimum seconds of a slice"
    )
    parser.add_argument(
        "--max_sec", "-M", type=float, default=12, help="Maximum seconds of a slice"
    )
    parser.add_argument(
        "--input_dir",
        "-i",
        type=str,
        default="inputs",
        help="Directory of input wav files",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="The result will be in Data/{model_name}/raw/ (if Data is dataset_root in configs/paths.yml)",
    )
    parser.add_argument(
        "--min_silence_dur_ms",
        "-s",
        type=int,
        default=700,
        help="Silence above this duration (ms) is considered as a split point.",
    )
    parser.add_argument(
        "--time_suffix",
        "-t",
        action="store_true",
        help="Make the filename end with -start_ms-end_ms when saving wav.",
    )
    parser.add_argument(
        "--num_processes",
        type=int,
        default=3,
        help="Number of processes to use. Default 3 seems to be the best.",
    )
    args = parser.parse_args()

    path_config = get_path_config()
    dataset_root = path_config.dataset_root

    model_name = str(args.model_name)
    input_dir = Path(args.input_dir)
    output_dir = dataset_root / model_name / "raw"
    min_sec: float = args.min_sec
    max_sec: float = args.max_sec
    min_silence_dur_ms: int = args.min_silence_dur_ms
    time_suffix: bool = args.time_suffix
    num_processes: int = args.num_processes

    audio_files = [file for file in input_dir.rglob("*") if is_audio_file(file)]

    logger.info(f"Found {len(audio_files)} audio files.")
    if output_dir.exists():
        logger.warning(f"Output directory {output_dir} already exists, deleting...")
        shutil.rmtree(output_dir)

    # Load the model once up front so it is already downloaded before the
    # workers load their own copies.
    _ = torch.hub.load(
        repo_or_dir="litagin02/silero-vad",
        model="silero_vad",
        onnx=True,
        trust_repo=True,
    )

    # The Silero VAD model reportedly misbehaves when a single instance is used
    # in parallel, so every worker loads its own model and pulls work from a
    # shared queue.
    def process_queue(
        q: Queue[Optional[Path]],
        result_queue: Queue[tuple[float, int]],
        error_queue: Queue[tuple[Path, Exception]],
    ):
        """Worker loop: split files taken from ``q`` until a ``None`` sentinel.

        Exactly one ``(seconds, count)`` tuple is pushed to ``result_queue``
        per file (``(0, 0)`` on failure) so the main thread can count
        completions; failures are additionally reported on ``error_queue``.
        """
        vad_model, utils = torch.hub.load(
            repo_or_dir="litagin02/silero-vad",
            model="silero_vad",
            onnx=True,
            trust_repo=True,
        )
        while True:
            file = q.get()
            if file is None:  # shutdown signal
                q.task_done()
                break
            try:
                # Mirror the input directory layout below output_dir.
                rel_path = file.relative_to(input_dir)
                time_sec, count = split_wav(
                    vad_model=vad_model,
                    utils=utils,
                    audio_file=file,
                    target_dir=output_dir / rel_path.parent,
                    min_sec=min_sec,
                    max_sec=max_sec,
                    min_silence_dur_ms=min_silence_dur_ms,
                    time_suffix=time_suffix,
                )
                result_queue.put((time_sec, count))
            except Exception as e:
                logger.error(f"Error processing {file}: {e}")
                error_queue.put((file, e))
                result_queue.put((0, 0))
            finally:
                q.task_done()

    q: Queue[Optional[Path]] = Queue()
    result_queue: Queue[tuple[float, int]] = Queue()
    error_queue: Queue[tuple[Path, Exception]] = Queue()

    # Never start more workers than there are files.
    num_processes = min(num_processes, len(audio_files))
    # With files to process but no worker, the result loop below would block
    # forever, so a non-positive --num_processes is raised to one worker.
    if audio_files and num_processes < 1:
        logger.warning(
            f"--num_processes must be at least 1 (got {args.num_processes}); using 1."
        )
        num_processes = 1

    threads = [
        Thread(target=process_queue, args=(q, result_queue, error_queue))
        for _ in range(num_processes)
    ]
    for t in threads:
        t.start()

    total_sec = 0
    total_count = 0
    # The context manager closes the progress bar on every exit path.
    with tqdm(total=len(audio_files), file=SAFE_STDOUT) as pbar:
        for file in audio_files:
            q.put(file)

        # One result arrives per file: accumulate totals and advance the bar.
        for _ in range(len(audio_files)):
            sec, count = result_queue.get()
            total_sec += sec
            total_count += count
            pbar.update(1)

        # Wait until every queued file has been marked done.
        q.join()

        # Send one shutdown sentinel per worker, then wait for them to exit.
        for _ in range(num_processes):
            q.put(None)

        for t in threads:
            t.join()

    if not error_queue.empty():
        error_str = "Error slicing some files:"
        while not error_queue.empty():
            file, e = error_queue.get()
            error_str += f"\n{file}: {e}"
        raise RuntimeError(error_str)

    logger.info(
        f"Slice done! Total time: {total_sec / 60:.2f} min, {total_count} files."
    )