Spaces:
Running
Running
| import json | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import shutil | |
| import torchaudio | |
| from datasets import load_dataset | |
| from datasets.arrow_writer import ArrowWriter | |
| from tqdm import tqdm | |
| import soundfile as sf | |
| import csv | |
| import subprocess | |
| import argparse | |
def save_dataset_to_local_disk(output_dir, base_model, audio_header, text_header):
    """
    Export the 'train' split of a dataset as WAV files plus a metadata CSV.

    Each sample's audio is written to ``<output_dir>/wavs/audio_<idx>.wav``
    (index zero-padded to six digits) and a ``wavs/<file>|<text>`` row is
    collected for ``<output_dir>/metadata.csv``. Samples that fail are
    reported on stderr and skipped. A dataset-loading failure is reported on
    stderr and the function returns without writing any metadata.

    NOTE(review): no header row is written to metadata.csv - confirm the
    downstream CSV reader does not skip the first line.

    Args:
        output_dir (str): Directory that receives ``wavs/`` and ``metadata.csv``.
        base_model (str): Identifier passed to ``load_dataset`` (despite the
            name, it is used as the dataset to load).
        audio_header (str): Column holding the audio; each value is indexed
            with ``'array'`` and ``'sampling_rate'``.
        text_header (str): Column holding the transcript text.
    """
    wavs_dir = os.path.join(output_dir, "wavs")
    metadata_path = os.path.join(output_dir, "metadata.csv")
    os.makedirs(wavs_dir, exist_ok=True)

    try:
        dataset = load_dataset(base_model)['train']
    except Exception as load_error:
        print(f"Error loading dataset: {load_error}", file=sys.stderr)
        return

    rows = []
    progress = tqdm(enumerate(dataset), total=len(dataset), desc="Saving samples to directory")
    for index, record in progress:
        try:
            audio = record[audio_header]
            waveform = audio['array']
            rate = audio['sampling_rate']
            wav_name = f"audio_{index:06d}.wav"
            sf.write(os.path.join(wavs_dir, wav_name), waveform, rate)
            # The row is recorded only after the WAV was written successfully.
            rows.append([f"wavs/{wav_name}", record[text_header]])
        except Exception as sample_error:
            # Best effort: report the bad sample and move on to the next one.
            print(f"Error processing sample {index}: {sample_error}", file=sys.stderr)

    try:
        with open(metadata_path, 'w', newline='', encoding='utf-8') as handle:
            writer = csv.writer(handle, delimiter='|')
            writer.writerows(rows)
        print(f"Dataset saved to {output_dir}")
    except Exception as write_error:
        print(f"Error writing metadata: {write_error}", file=sys.stderr)
def _relay_lines(stream, to_stderr):
    """Echo each line of *stream* to this process's stderr or stdout as it arrives."""
    for line in stream:
        # Look up the sink at print time so a redirected sys.stdout/sys.stderr is honoured.
        print(line, end='', flush=True, file=sys.stderr if to_stderr else sys.stdout)


def run_preprocess(input_dir, output_dir, workers):
    """
    Runs the preprocessing script with real-time output.

    The script is launched as a child process using the interpreter that is
    running this file, and its stdout/stderr are forwarded line by line to
    this process's stdout/stderr. A missing script, a launch failure or a
    non-zero exit code is reported on stderr; nothing is raised or returned.

    Args:
        input_dir (str): Input directory for preprocessing.
        output_dir (str): Output directory for processed data.
        workers (int): Number of parallel processes.
    """
    import threading  # only needed here; keeps the file's import block untouched

    script_path = "./src/f5_tts/train/datasets/prepare_csv_wavs.py"
    if not os.path.exists(script_path):
        print(f"Preprocessing script not found at {script_path}", file=sys.stderr)
        return

    command = [
        # Use the current interpreter so the child sees the same environment;
        # a bare "python" may be absent from PATH or point somewhere else.
        sys.executable or "python", script_path,
        input_dir, output_dir,
        "--workers", str(workers),
    ]

    try:
        with subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,  # Line buffered
        ) as process:
            # Drain stderr on its own thread. Alternating blocking readline()
            # calls on both pipes can deadlock once the child fills one pipe
            # while this process is waiting on the other.
            stderr_pump = threading.Thread(
                target=_relay_lines, args=(process.stderr, True), daemon=True
            )
            stderr_pump.start()
            _relay_lines(process.stdout, False)
            stderr_pump.join()
            return_code = process.wait()

        if return_code == 0:
            print("\nPreprocessing completed successfully.")
        else:
            print(f"\nPreprocessing failed with return code {return_code}.", file=sys.stderr)
    except Exception as e:
        print(f"Error during preprocessing: {e}", file=sys.stderr)
if __name__ == "__main__":
    # Command-line options as (flag, type, default, help) rows, registered in
    # this order.
    cli_options = (
        ("--output_dir", str, "./data/vin100h-preprocessed-v2",
         "Output directory for save command"),
        ("--base_model", str, "htdung167/vin100h-preprocessed-v2",
         "Base model for save command"),
        ("--audio_header", str, "audio",
         "Audio header for save command"),
        ("--text_header", str, "preprocessed_sentence_v2",
         "Text header for save command"),
        ("--prepare_csv_input_dir", str, "./data/vin100h-preprocessed-v2",
         "Input directory for preprocess command"),
        ("--prepare_csv_output_dir", str, "./data/vin100h-preprocessed-v2_pinyin",
         "Output directory for preprocess command"),
        ("--workers", int, 4,
         "Number of parallel processes for preprocess command"),
    )

    parser = argparse.ArgumentParser(description="Prepare dataset for training.")
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()

    # There is no sub-command switch: both stages always run, first the
    # export to disk, then the preprocessing script.
    save_dataset_to_local_disk(
        output_dir=args.output_dir,
        base_model=args.base_model,
        audio_header=args.audio_header,
        text_header=args.text_header,
    )
    run_preprocess(
        input_dir=args.prepare_csv_input_dir,
        output_dir=args.prepare_csv_output_dir,
        workers=args.workers,
    )