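"""Smoke tests for the train / validation / inference pipelines.

Each enabled stage runs on a trimmed copy of the user's dataset (at most
``max_folders`` folders) so a full configuration can be checked quickly.
Defaults live in ``base_args``; command-line flags override them.
"""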
import os
import sys
import argparse

# Add the repository root to the system path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from valid import check_validation
from inference import proc_folder
from train import train_model
from scripts.redact_config import redact_config
from scripts.valid_to_inference import copying_files
from scripts.trim import trim_directory
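
# Defaults for the test run; any CLI flag parsed in parse_args() overrides
# the matching key in test_settings().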
base_args = {
    'device_ids': '0',
    'model_type': '',
    'start_check_point': '',
    'config_path': '',
    'data_path': '',
    'valid_path': '',
    'results_path': 'tests/train_results',
    'store_dir': 'tests/valid_inference_result',
    'input_folder': '',
    'metrics': ['neg_log_wmse', 'l1_freq', 'si_sdr', 'sdr', 'aura_stft', 'aura_mrstft', 'bleedless', 'fullness'],
    'max_folders': 2
}


def parse_args(dict_args):
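    """Parse CLI arguments or build a Namespace from ``dict_args``.

    When ``dict_args`` is given, the parser's defaults are taken first and
    then updated from the dictionary, so no real command line is read, e.g.
    ``parse_args({'model_type': 'mdx23c'})`` (hypothetical value).
    """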
    parser = argparse.ArgumentParser()
    parser.add_argument("--check_train", action='store_true', help="Run the training check")
    parser.add_argument("--check_valid", action='store_true', help="Run the validation check")
    parser.add_argument("--check_inference", action='store_true', help="Run the inference check")
    parser.add_argument('--device_ids', type=str, help='Device IDs for training/inference')
    parser.add_argument('--model_type', type=str, help='Model type')
    parser.add_argument('--start_check_point', type=str, help='Path to the checkpoint to start from')
    parser.add_argument('--config_path', type=str, help='Path to the configuration file')
    parser.add_argument('--data_path', type=str, help='Path to the training data')
    parser.add_argument('--valid_path', type=str, help='Path to the validation data')
    parser.add_argument('--results_path', type=str, help='Path to save training results')
    parser.add_argument('--store_dir', type=str, help='Path to store validation/inference results')
    parser.add_argument('--input_folder', type=str, help='Path to the input folder for inference')
    parser.add_argument('--metrics', nargs='+', help='List of metrics to evaluate')
    parser.add_argument('--max_folders', type=int, help='Maximum number of folders to process')
parser.add_argument("--dataset_type", type=int, default=1, | |
help="Dataset type. Must be one of: 1, 2, 3 or 4.") | |
parser.add_argument("--num_workers", type=int, default=0, help="dataloader num_workers") | |
parser.add_argument("--pin_memory", action='store_true', help="dataloader pin_memory") | |
parser.add_argument("--seed", type=int, default=0, help="random seed") | |
parser.add_argument("--use_multistft_loss", action='store_true', | |
help="Use MultiSTFT Loss (from auraloss package)") | |
parser.add_argument("--use_mse_loss", action='store_true', help="Use default MSE loss") | |
parser.add_argument("--use_l1_loss", action='store_true', help="Use L1 loss") | |
parser.add_argument("--wandb_key", type=str, default='', help='wandb API Key') | |
parser.add_argument("--pre_valid", action='store_true', help='Run validation before training') | |
parser.add_argument("--metric_for_scheduler", default="sdr", | |
choices=['sdr', 'l1_freq', 'si_sdr', 'neg_log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless', | |
'fullness'], help='Metric which will be used for scheduler.') | |
parser.add_argument("--train_lora", action='store_true', help="Train with LoRA") | |
parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint to LoRA weights") | |
parser.add_argument("--extension", type=str, default='wav', help="Choose extension for validation") | |
parser.add_argument("--use_tta", action='store_true', | |
help="Flag adds test time augmentation during inference (polarity and channel inverse)." | |
" While this triples the runtime, it reduces noise and slightly improves prediction quality.") | |
parser.add_argument("--extract_instrumental", action='store_true', | |
help="invert vocals to get instrumental if provided") | |
parser.add_argument("--disable_detailed_pbar", action='store_true', help="disable detailed progress bar") | |
parser.add_argument("--force_cpu", action='store_true', help="Force the use of CPU even if CUDA is available") | |
parser.add_argument("--flac_file", action='store_true', help="Output flac file instead of wav") | |
parser.add_argument("--pcm_type", type=str, choices=['PCM_16', 'PCM_24'], default='PCM_24', | |
help="PCM type for FLAC files (PCM_16 or PCM_24)") | |
parser.add_argument("--draw_spectro", type=float, default=0, | |
help="If --store_dir is set then code will generate spectrograms for resulted stems as well." | |
" Value defines for how many seconds os track spectrogram will be generated.") | |

    if dict_args is not None:
        args = parser.parse_args([])
        args_dict = vars(args)
        args_dict.update(dict_args)
        args = argparse.Namespace(**args_dict)
    else:
        args = parser.parse_args()

    return args


def test_settings(dict_args, test_type):
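    """Run the requested checks (train / valid / inference) on trimmed data.

    For ``test_type == 'user'`` the required path arguments are verified,
    the config is rewritten for a short run, and both the train and valid
    datasets are trimmed before the individual stages are dispatched.
    """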
    # Parse from the command line
    cli_args = parse_args(dict_args)

    # Add or replace CLI values in base_args
    for key, value in vars(cli_args).items():
        if value is not None:
            base_args[key] = value

    if test_type == 'user':
        # Check required arguments
        missing_args = [arg for arg in ['model_type', 'config_path', 'start_check_point', 'data_path', 'valid_path']
                        if not base_args[arg]]
        if missing_args:
            missing_args_str = ', '.join(f'--{arg}' for arg in missing_args)
            raise ValueError(
                f"The following arguments are required but missing: {missing_args_str}."
                f" Please specify them either via command-line arguments or directly in `base_args`.")

    # Replace config
    base_args['config_path'] = redact_config({'orig_config': base_args['config_path'],
                                              'model_type': base_args['model_type'],
                                              'new_config': ''})

    # Trim train
    trim_args_train = {'input_directory': base_args['data_path'],
                       'max_folders': base_args['max_folders']}
    base_args['data_path'] = trim_directory(trim_args_train)

    # Trim valid
    trim_args_valid = {'input_directory': base_args['valid_path'],
                       'max_folders': base_args['max_folders']}
    base_args['valid_path'] = trim_directory(trim_args_valid)

    # Valid to inference
    if not base_args['input_folder']:
        tests_dir = os.path.join(os.path.dirname(base_args['valid_path']), 'for_inference')
        base_args['input_folder'] = tests_dir

    val_to_inf_args = {'valid_path': base_args['valid_path'],
                       'inference_dir': base_args['input_folder'],
                       'max_mixtures': 1}
    copying_files(val_to_inf_args)

    if base_args['check_valid']:
        valid_args = {key: base_args[key] for key in ['model_type', 'config_path', 'start_check_point',
                                                      'store_dir', 'device_ids', 'num_workers', 'pin_memory',
                                                      'extension', 'use_tta', 'metrics', 'lora_checkpoint',
                                                      'draw_spectro']}
        valid_args['valid_path'] = [base_args['valid_path']]
        print('Start validation.')
        check_validation(valid_args)
        print(f'Validation ended. See results in {base_args["store_dir"]}')

    if base_args['check_inference']:
        inference_args = {key: base_args[key] for key in ['model_type', 'config_path', 'start_check_point',
                                                          'input_folder', 'store_dir', 'device_ids',
                                                          'extract_instrumental', 'disable_detailed_pbar',
                                                          'force_cpu', 'flac_file', 'pcm_type', 'use_tta',
                                                          'lora_checkpoint', 'draw_spectro']}
        print('Start inference.')
        proc_folder(inference_args)
        print(f'Inference ended. See results in {base_args["store_dir"]}')

    if base_args['check_train']:
        train_args = {key: base_args[key] for key in ['model_type', 'config_path', 'start_check_point',
                                                      'results_path', 'data_path', 'dataset_type', 'valid_path',
                                                      'num_workers', 'pin_memory', 'seed', 'device_ids',
                                                      'use_multistft_loss', 'use_mse_loss', 'use_l1_loss',
                                                      'wandb_key', 'pre_valid', 'metrics', 'metric_for_scheduler',
                                                      'train_lora', 'lora_checkpoint']}
        print('Start train.')
        train_model(train_args)

    print('End!')


if __name__ == "__main__":
    test_settings(None, 'user')
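
# Example invocation (hypothetical paths):
#   python tests/test.py --check_train --check_valid --check_inference \
#       --model_type mdx23c --config_path configs/config.yaml \
#       --start_check_point results/model.ckpt \
#       --data_path datasets/train --valid_path datasets/valid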