Spaces:
Running
Running
File size: 5,595 Bytes
3f9cba0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import json
import os
import sys
from pathlib import Path
import shutil
import torchaudio
from datasets import load_dataset
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
import soundfile as sf
import csv
import subprocess
import argparse
def save_dataset_to_local_disk(output_dir, base_model, audio_header, text_header):
    """
    Download a Hugging Face dataset and export it as wavs/ + metadata.csv.

    Creates ``<output_dir>/wavs/audio_NNNNNN.wav`` files and a pipe-delimited
    ``<output_dir>/metadata.csv`` whose rows are ``wavs/<file>|<transcript>``.

    Args:
        output_dir (str): Directory the dataset is exported into.
        base_model (str): Hugging Face dataset repo id; the 'train' split is used.
        audio_header (str): Column holding the audio dict
            ({'array': ..., 'sampling_rate': ...}).
        text_header (str): Column holding the transcript text.
    """
    wavs_dir = os.path.join(output_dir, "wavs")
    metadata_path = os.path.join(output_dir, "metadata.csv")
    os.makedirs(wavs_dir, exist_ok=True)
    try:
        ds = load_dataset(base_model)['train']
    except Exception as e:
        print(f"Error loading dataset: {e}", file=sys.stderr)
        return
    metadata = []
    for idx, sample in tqdm(enumerate(ds), total=len(ds), desc="Saving samples to directory"):
        try:
            audio_array = sample[audio_header]['array']
            sampling_rate = sample[audio_header]['sampling_rate']
            filename = f"audio_{idx:06d}.wav"
            sf.write(os.path.join(wavs_dir, filename), audio_array, sampling_rate)
            # BUG FIX: row must reference the wav just written, not a
            # hard-coded "(unknown)" placeholder — otherwise every metadata
            # entry points at a nonexistent file.
            metadata.append([f"wavs/{filename}", sample[text_header]])
        except Exception as e:
            # Best-effort: skip a corrupt sample rather than aborting the export.
            print(f"Error processing sample {idx}: {e}", file=sys.stderr)
            continue
    try:
        with open(metadata_path, 'w', newline='', encoding='utf-8') as f:
            csv.writer(f, delimiter='|').writerows(metadata)
        print(f"Dataset saved to {output_dir}")
    except Exception as e:
        print(f"Error writing metadata: {e}", file=sys.stderr)
def run_preprocess(input_dir, output_dir, workers):
    """
    Run the F5-TTS csv/wav preprocessing script as a subprocess with live output.

    Args:
        input_dir (str): Input directory for preprocessing.
        output_dir (str): Output directory for processed data.
        workers (int): Number of parallel processes passed to the script.

    Returns:
        None. Progress and errors are streamed to stdout/stderr.
    """
    import threading

    script_path = "./src/f5_tts/train/datasets/prepare_csv_wavs.py"
    if not os.path.exists(script_path):
        print(f"Preprocessing script not found at {script_path}", file=sys.stderr)
        return
    command = [
        "python", script_path,
        input_dir, output_dir,
        "--workers", str(workers)
    ]

    def _pump(pipe, sink):
        # Drain one pipe line-by-line so the child can never block on a full
        # OS pipe buffer while we are waiting on the other stream.
        for line in iter(pipe.readline, ''):
            print(line, end='', flush=True, file=sink)
        pipe.close()

    try:
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,  # line-buffered (universal_newlines is implied by text=True)
        )
        # BUG FIX: the original alternated blocking readline() calls on stdout
        # and stderr in one loop. If the child filled one pipe while we were
        # blocked reading the other, both processes deadlocked. Draining each
        # pipe on its own thread removes that hazard while keeping real-time
        # output.
        readers = [
            threading.Thread(target=_pump, args=(process.stdout, sys.stdout), daemon=True),
            threading.Thread(target=_pump, args=(process.stderr, sys.stderr), daemon=True),
        ]
        for t in readers:
            t.start()
        process.wait()
        for t in readers:
            t.join()
        if process.returncode == 0:
            print("\nPreprocessing completed successfully.")
        else:
            print(f"\nPreprocessing failed with return code {process.returncode}.", file=sys.stderr)
    except Exception as e:
        print(f"Error during preprocessing: {e}", file=sys.stderr)
if __name__ == "__main__":
    # CLI entry point: export the HF dataset to disk, then run the
    # prepare_csv_wavs preprocessing step over the exported directory.
    parser = argparse.ArgumentParser(description="Prepare dataset for training.")
    parser.add_argument("--output_dir", type=str, default="./data/vin100h-preprocessed-v2",
                        help="Output directory for save command")
    parser.add_argument("--base_model", type=str, default="htdung167/vin100h-preprocessed-v2",
                        help="Base model for save command")
    parser.add_argument("--audio_header", type=str, default="audio",
                        help="Audio header for save command")
    parser.add_argument("--text_header", type=str, default="preprocessed_sentence_v2",
                        help="Text header for save command")
    parser.add_argument("--prepare_csv_input_dir", type=str,
                        default="./data/vin100h-preprocessed-v2",
                        help="Input directory for preprocess command")
    parser.add_argument("--prepare_csv_output_dir", type=str,
                        default="./data/vin100h-preprocessed-v2_pinyin",
                        help="Output directory for preprocess command")
    parser.add_argument("--workers", type=int, default=4,
                        help="Number of parallel processes for preprocess command")
    args = parser.parse_args()
    # Step 1: download and save the dataset as wavs/ + metadata.csv.
    save_dataset_to_local_disk(args.output_dir, args.base_model, args.audio_header, args.text_header)
    # Step 2: run the preprocessing script over the exported data.
    run_preprocess(args.prepare_csv_input_dir, args.prepare_csv_output_dir, args.workers)