Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Facebook, Inc. and its affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
import os | |
from pathlib import Path | |
def _env_path(var):
    """Return ``Path(os.environ[var])`` if *var* is set to a non-empty value, else ``None``.

    An empty value is treated as unset: ``Path("")`` would otherwise resolve to the
    current directory and silently be used as a dataset root.
    """
    value = os.environ.get(var)
    return Path(value) if value else None


def get_parser():
    """
    Build the command-line parser used to train and evaluate Demucs.

    The defaults of ``--raw`` and ``--musdb`` come from the ``DEMUCS_RAW`` and
    ``DEMUCS_MUSDB`` environment variables when those are set to a non-empty
    value; otherwise they default to ``None``.

    Returns:
        argparse.ArgumentParser: the fully configured parser.
    """
    parser = argparse.ArgumentParser("demucs", description="Train and evaluate Demucs.")
    default_raw = _env_path("DEMUCS_RAW")
    default_musdb = _env_path("DEMUCS_MUSDB")

    # Dataset options
    parser.add_argument(
        "--raw",
        type=Path,
        default=default_raw,
        help="Path to raw audio, can be faster, see python3 -m demucs.raw to extract.")
    # Stores None into `raw`, overriding --raw and the DEMUCS_RAW default.
    parser.add_argument("--no_raw", action="store_const", const=None, dest="raw")
    parser.add_argument("-m",
                        "--musdb",
                        type=Path,
                        default=default_musdb,
                        help="Path to musdb root")
    parser.add_argument("--is_wav", action="store_true",
                        help="Indicate that the MusDB dataset is in wav format (i.e. MusDB-HQ).")
    parser.add_argument("--metadata", type=Path, default=Path("metadata/"),
                        help="Folder where metadata information is stored.")
    parser.add_argument("--wav", type=Path,
                        help="Path to a wav dataset. This should contain a 'train' and a 'valid' "
                        "subfolder.")
    parser.add_argument("--samplerate", type=int, default=44100)
    parser.add_argument("--audio_channels", type=int, default=2)
    parser.add_argument("--samples",
                        default=44100 * 10,
                        type=int,
                        help="number of samples to feed in")
    parser.add_argument("--data_stride",
                        default=44100,
                        type=int,
                        help="Stride for chunks, shorter = longer epochs")

    # Runtime, evaluation and distributed options
    parser.add_argument("-w", "--workers", default=10, type=int, help="Loader workers")
    parser.add_argument("--eval_workers", default=2, type=int, help="Final evaluation workers")
    parser.add_argument("-d",
                        "--device",
                        help="Device to train on, default is cuda if available else cpu")
    parser.add_argument("--eval_cpu", action="store_true", help="Eval on test will be run on cpu.")
    parser.add_argument("--dummy", help="Dummy parameter, useful to create a new checkpoint file")
    parser.add_argument("--test", help="Just run the test pipeline + one validation. "
                        "This should be a filename relative to the models/ folder.")
    parser.add_argument("--test_pretrained", help="Just run the test pipeline + one validation, "
                        "on a pretrained model. ")
    parser.add_argument("--rank", default=0, type=int)
    parser.add_argument("--world_size", default=1, type=int)
    parser.add_argument("--master")

    # Output folders
    parser.add_argument("--checkpoints",
                        type=Path,
                        default=Path("checkpoints"),
                        help="Folder where to store checkpoints etc")
    parser.add_argument("--evals",
                        type=Path,
                        default=Path("evals"),
                        help="Folder where to store evals and waveforms")
    parser.add_argument("--save",
                        action="store_true",
                        help="Save estimated waveforms for the test set")
    parser.add_argument("--logs",
                        type=Path,
                        default=Path("logs"),
                        help="Folder where to store logs")
    parser.add_argument("--models",
                        type=Path,
                        default=Path("models"),
                        help="Folder where to store trained models")
    parser.add_argument("-R",
                        "--restart",
                        action='store_true',
                        help='Restart training, ignoring previous run')

    # Optimization options
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("-e", "--epochs", type=int, default=180, help="Number of epochs")
    parser.add_argument("-r",
                        "--repeat",
                        type=int,
                        default=2,
                        help="Repeat the train set, longer epochs")
    parser.add_argument("-b", "--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--mse", action="store_true", help="Use MSE instead of L1")
    parser.add_argument("--init", help="Initialize from a pre-trained model.")

    # Augmentation options
    parser.add_argument("--no_augment",
                        action="store_false",
                        dest="augment",
                        default=True,
                        help="No basic data augmentation.")
    parser.add_argument("--repitch", type=float, default=0.2,
                        help="Probability to do tempo/pitch change")
    # '%%' is argparse's escape for a literal '%' in help text.
    parser.add_argument("--max_tempo", type=float, default=12,
                        help="Maximum relative tempo change in %% when using repitch.")
    parser.add_argument("--remix_group_size",
                        type=int,
                        default=4,
                        help="Shuffle sources using group of this size. Useful to somewhat "
                        "replicate multi-gpu training "
                        "on less GPUs.")
    parser.add_argument("--shifts",
                        type=int,
                        default=10,
                        help="Number of random shifts used for the shift trick.")
    parser.add_argument("--overlap",
                        type=float,
                        default=0.25,
                        help="Overlap when --split_valid is passed.")

    # Model options (see model.py for doc)
    parser.add_argument("--growth",
                        type=float,
                        default=2.,
                        help="Number of channels between two layers will increase by this factor")
    parser.add_argument("--depth",
                        type=int,
                        default=6,
                        help="Number of layers for the encoder and decoder")
    parser.add_argument("--lstm_layers", type=int, default=2, help="Number of layers for the LSTM")
    parser.add_argument("--channels",
                        type=int,
                        default=64,
                        help="Number of channels for the first encoder layer")
    parser.add_argument("--kernel_size",
                        type=int,
                        default=8,
                        help="Kernel size for the (transposed) convolutions")
    parser.add_argument("--conv_stride",
                        type=int,
                        default=4,
                        help="Stride for the (transposed) convolutions")
    parser.add_argument("--context",
                        type=int,
                        default=3,
                        help="Context size for the decoder convolutions "
                        "before the transposed convolutions")
    parser.add_argument("--rescale",
                        type=float,
                        default=0.1,
                        help="Initial weight rescale reference")
    parser.add_argument("--no_resample", action="store_false",
                        default=True, dest="resample",
                        help="No Resampling of the input/output x2")
    parser.add_argument("--no_glu",
                        action="store_false",
                        default=True,
                        dest="glu",
                        help="Replace all GLUs by ReLUs")
    parser.add_argument("--no_rewrite",
                        action="store_false",
                        default=True,
                        dest="rewrite",
                        help="No 1x1 rewrite convolutions")
    parser.add_argument("--normalize", action="store_true")
    parser.add_argument("--no_norm_wav", action="store_false", dest='norm_wav', default=True)

    # Tasnet options
    parser.add_argument("--tasnet", action="store_true")
    parser.add_argument("--split_valid",
                        action="store_true",
                        help="Predict chunks by chunks for valid and test. Required for tasnet")
    parser.add_argument("--X", type=int, default=8)

    # Other options
    parser.add_argument("--show",
                        action="store_true",
                        help="Show model architecture, size and exit")
    parser.add_argument("--save_model", action="store_true",
                        help="Skip training, just save final model "
                        "for the current checkpoint value.")
    parser.add_argument("--save_state",
                        help="Skip training, just save state "
                        "for the current checkpoint value. You should "
                        "provide a model name as argument.")

    # Quantization options
    parser.add_argument("--q-min-size", type=float, default=1,
                        help="Only quantize layers over this size (in MB)")
    parser.add_argument(
        "--qat", type=int, help="If provided, use QAT training with that many bits.")
    parser.add_argument("--diffq", type=float, default=0)
    parser.add_argument(
        "--ms-target", type=float, default=162,
        help="Model size target in MB, when using DiffQ. Best model will be kept "
        "only if it is smaller than this target.")
    return parser
def get_name(parser, args):
    """
    Derive a human-readable experiment name from the parsed arguments.

    Every argument whose value differs from the parser's default contributes a
    ``name=value`` token, in the order the arguments appear in ``args``; the
    tokens are joined with single spaces. ``Path`` values contribute only their
    final component (``Path.name``). Arguments that do not affect the final
    result (e.g. ``--workers`` or the output folders) are skipped. When no
    argument differs from its default, the name is ``"default"``.
    """
    skipped = {
        "checkpoints",
        "deterministic",
        "eval",
        "evals",
        "eval_cpu",
        "eval_workers",
        "logs",
        "master",
        "rank",
        "restart",
        "save",
        "save_model",
        "save_state",
        "show",
        "workers",
        "world_size",
    }

    def _token(key, val):
        # Only the last path component is kept so names stay short.
        shown = val.name if isinstance(val, Path) else val
        return f"{key}={shown}"

    tokens = [
        _token(key, val)
        for key, val in vars(args).items()
        if key not in skipped and val != parser.get_default(key)
    ]
    return " ".join(tokens) if tokens else "default"