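"""Prepare a dataset for F5-TTS training.

Stage 1 exports a Hugging Face audio dataset to an LJSpeech-style layout
(a "wavs/" folder plus a pipe-delimited "metadata.csv"); stage 2 runs the
repository's prepare_csv_wavs.py preprocessing script on that output.
"""
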
import argparse
import csv
import os
import subprocess
import sys

import soundfile as sf
from datasets import load_dataset
from tqdm import tqdm

def save_dataset_to_local_disk(output_dir, base_model, audio_header, text_header):
    """

    Saves a dataset to a local directory.



    Args:

        output_dir (str): The directory to save the dataset to.

        base_model (str): The base model to load the dataset from.

        audio_header (str): The header for the audio data in the dataset.

        text_header (str): The header for the text data in the dataset.

    """
    wavs_dir = os.path.join(output_dir, "wavs")
    metadata_path = os.path.join(output_dir, "metadata.csv")
    os.makedirs(wavs_dir, exist_ok=True)

    try:
        ds = load_dataset(base_model)['train']
    except Exception as e:
        print(f"Error loading dataset: {e}", file=sys.stderr)
        return

    metadata = []
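    # The `datasets` Audio feature decodes each sample into a dict holding
    # the raw waveform under 'array' and its rate under 'sampling_rate'.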
    for idx, sample in tqdm(enumerate(ds), total=len(ds), desc="Saving samples to directory"):
        try:
            audio_array = sample[audio_header]['array']
            sampling_rate = sample[audio_header]['sampling_rate']
            filename = f"audio_{idx:06d}.wav"
            sf.write(os.path.join(wavs_dir, filename), audio_array, sampling_rate)
            metadata.append([f"wavs/{filename}", sample[text_header]])
        except Exception as e:
            print(f"Error processing sample {idx}: {e}", file=sys.stderr)
            continue

    try:
        with open(metadata_path, 'w', newline='', encoding='utf-8') as f:
            csv.writer(f, delimiter='|').writerows(metadata)
        print(f"Dataset saved to {output_dir}")
    except Exception as e:
        print(f"Error writing metadata: {e}", file=sys.stderr)

def run_preprocess(input_dir, output_dir, workers):
    """

    Runs the preprocessing script with real-time output.



    Args:

        input_dir (str): Input directory for preprocessing.

        output_dir (str): Output directory for processed data.

        workers (int): Number of parallel processes.

    """
    script_path = "./src/f5_tts/train/datasets/prepare_csv_wavs.py"
    if not os.path.exists(script_path):
        print(f"Preprocessing script not found at {script_path}", file=sys.stderr)
        return

    command = [
        sys.executable, script_path,  # current interpreter, not whatever "python" resolves to
        input_dir, output_dir,
        "--workers", str(workers)
    ]
    try:
        # Merge stderr into stdout: alternating blocking readline() calls on
        # two separate pipes stall whenever the child writes to only one of
        # them, so a single merged stream is read instead.
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,  # line buffered
        )

        # Echo each line as it arrives; the loop ends when the pipe closes.
        for line in process.stdout:
            print(line, end='', flush=True)
        process.wait()

        if process.returncode == 0:
            print("\nPreprocessing completed successfully.")
        else:
            print(f"\nPreprocessing failed with return code {process.returncode}.", file=sys.stderr)
    except Exception as e:
        print(f"Error during preprocessing: {e}", file=sys.stderr)

if __name__ == "__main__":
    # Set up argument parsing
    parser = argparse.ArgumentParser(description="Prepare dataset for training.")
    parser.add_argument("--output_dir", type=str, default="./data/vin100h-preprocessed-v2",
                        help="Output directory for the export stage")
    parser.add_argument("--base_model", type=str, default="htdung167/vin100h-preprocessed-v2",
                        help="Hugging Face dataset repository to export")
    parser.add_argument("--audio_header", type=str, default="audio",
                        help="Audio column name in the dataset")
    parser.add_argument("--text_header", type=str, default="preprocessed_sentence_v2",
                        help="Text column name in the dataset")
    parser.add_argument("--prepare_csv_input_dir", type=str,
                        default="./data/vin100h-preprocessed-v2",
                        help="Input directory for the preprocessing stage")
    parser.add_argument("--prepare_csv_output_dir", type=str,
                        default="./data/vin100h-preprocessed-v2_pinyin",
                        help="Output directory for the preprocessing stage")
    parser.add_argument("--workers", type=int, default=4,
                        help="Number of parallel processes for preprocessing")

    args = parser.parse_args()

    # Run both stages back to back: export the dataset, then preprocess it.
    save_dataset_to_local_disk(args.output_dir, args.base_model, args.audio_header, args.text_header)
    run_preprocess(args.prepare_csv_input_dir, args.prepare_csv_output_dir, args.workers)
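
# Example invocation (script filename is illustrative; defaults match the flags above):
#   python prepare_dataset.py --base_model htdung167/vin100h-preprocessed-v2 --workers 4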