Spaces:
Running
on
Zero
Running
on
Zero
import shutil | |
from test import test_settings | |
from scripts.redact_config import redact_config | |
from utils import load_config | |
from pathlib import Path | |
import os | |
import numpy as np | |
import soundfile as sf | |
from typing import List, Dict | |
MODEL_CONFIGS = { | |
'config_apollo.yaml': {'model_type': 'apollo'}, | |
'config_dnr_bandit_bsrnn_multi_mus64.yaml': {'model_type': 'bandit'}, | |
'config_dnr_bandit_v2_mus64.yaml': {'model_type': 'bandit_v2'}, | |
'config_drumsep.yaml': {'model_type': 'htdemucs'}, | |
'config_htdemucs_6stems.yaml': {'model_type': 'htdemucs'}, | |
'config_musdb18_bs_roformer.yaml': {'model_type': 'bs_roformer'}, | |
'config_musdb18_demucs3_mmi.yaml': {'model_type': 'htdemucs'}, | |
'config_musdb18_htdemucs.yaml': {'model_type': 'htdemucs'}, | |
'config_musdb18_mdx23c.yaml': {'model_type': 'mdx23c'}, | |
'config_musdb18_mel_band_roformer.yaml': {'model_type': 'mel_band_roformer'}, | |
'config_musdb18_mel_band_roformer_all_stems.yaml': {'model_type': 'mel_band_roformer'}, | |
'config_musdb18_scnet.yaml': {'model_type': 'scnet'}, | |
'config_musdb18_scnet_large.yaml': {'model_type': 'scnet'}, | |
# 'config_musdb18_scnet_large_starrytong.yaml': {'model_type': 'scnet'}, | |
'config_vocals_bandit_bsrnn_multi_mus64.yaml': {'model_type': 'bandit'}, | |
'config_vocals_bs_roformer.yaml': {'model_type': 'bs_roformer'}, | |
'config_vocals_htdemucs.yaml': {'model_type': 'htdemucs'}, | |
'config_vocals_mdx23c.yaml': {'model_type': 'mdx23c'}, | |
'config_vocals_mel_band_roformer.yaml': {'model_type': 'mel_band_roformer'}, | |
'config_vocals_scnet.yaml': {'model_type': 'scnet'}, | |
'config_vocals_scnet_large.yaml': {'model_type': 'scnet'}, | |
'config_vocals_scnet_unofficial.yaml': {'model_type': 'scnet_unofficial'}, | |
'config_vocals_segm_models.yaml': {'model_type': 'segm_models'}, | |
# 'config_vocals_swin_upernet.yaml': {'model_type': 'swin_upernet'}, | |
# 'config_musdb18_torchseg.yaml': {'model_type': 'torchseg'}, | |
# 'config_musdb18_segm_models.yaml': {'model_type': 'segm_models'}, | |
# 'config_musdb18_bs_mamba2.yaml': {'model_type': 'bs_mamba2'}, | |
# 'config_vocals_bs_mamba2.yaml': {'model_type': 'bs_mamba2'}, | |
# 'config_vocals_torchseg.yaml': {'model_type': 'torchseg'} | |
} | |
# Folders for tests | |
ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
CONFIGS_DIR = ROOT_DIR / 'configs/' | |
TEST_DIR = ROOT_DIR / "tests_cache/" | |
TRAIN_DIR = TEST_DIR / "train_tracks/" | |
VALID_DIR = TEST_DIR / "valid_tracks/" | |
def create_dummy_tracks(directory: Path, num_tracks: int, instruments: List[str], | |
duration: float = 5.0, sample_rate: int = 44100) -> None: | |
""" | |
Generates random audio tracks for stems in two subdirectories within the specified directory. | |
Parameters: | |
---------- | |
directory : Path | |
Path to the directory where the tracks will be saved. | |
num_tracks : int | |
Number of tracks to generate in each folder. | |
instruments : List[str] | |
List of instrument names (stems) to create. | |
duration : float, optional | |
Duration of each track in seconds. Default is 5.0. | |
sample_rate : int, optional | |
Sampling rate of the generated audio. Default is 44100 Hz. | |
Returns: | |
------- | |
None | |
""" | |
os.makedirs(directory, exist_ok=True) | |
for folder_name in [str(i) for i in range(1, num_tracks+1)]: | |
folder_path = directory / folder_name | |
os.makedirs(folder_path, exist_ok=True) | |
for instrument in instruments: | |
# Generate random noice for each track | |
samples = int(duration * sample_rate) | |
track = np.random.uniform(-1.0, 1.0, (2, samples)).astype(np.float32) | |
file_path = folder_path / f"{instrument}.wav" | |
sf.write(file_path, track.T, sample_rate) | |
def cleanup_test_tracks() -> None: | |
""" | |
Removes all cached test tracks. | |
This function deletes the entire directory specified by the global `TEST_DIR` variable | |
if it exists. | |
Returns: | |
------- | |
None | |
This function does not return a value. It performs cleanup of test data. | |
""" | |
def modify_configs() -> Dict[str, Path]: | |
""" | |
Updates configuration files in the `configs` directory for use with test data. | |
This function processes configuration files defined in the global `MODEL_CONFIGS` dictionary, | |
modifies them to be compatible with test scenarios, and saves the updated configurations | |
in a test-specific directory. | |
Returns: | |
------- | |
Dict[str, Path] | |
A dictionary where the keys are the original configuration file names, and the values | |
are the paths to the updated configuration files. | |
""" | |
config_dir = CONFIGS_DIR | |
updated_configs = {} | |
for config, args in MODEL_CONFIGS.items(): | |
model_type = args['model_type'] | |
config_path = config_dir / config | |
updated_config_path = redact_config({ | |
'orig_config': str(config_path), | |
'model_type': model_type, | |
'new_config': str(TEST_DIR / 'configs' / config) | |
}) | |
updated_configs[config] = updated_config_path | |
return updated_configs | |
def run_tests() -> None: | |
""" | |
Executes validation tests for all configurations. | |
This function updates configurations, generates random dummy data for testing, | |
and runs a series of tests (training, validation, and inference checks) for each | |
model configuration specified in the global `MODEL_CONFIGS` dictionary. | |
Returns: | |
------- | |
None | |
""" | |
updated_configs = modify_configs() | |
# For every config | |
for config, args in MODEL_CONFIGS.items(): | |
model_type = args['model_type'] | |
cfg = load_config(model_type=model_type, config_path=TEST_DIR / 'configs' / config) | |
# Random tracks | |
create_dummy_tracks(TRAIN_DIR, instruments=cfg.training.instruments+['mixture'], num_tracks=2) | |
create_dummy_tracks(VALID_DIR, instruments=cfg.training.instruments+['mixture'], num_tracks=2) | |
print(f"\nRunning tests for model: {model_type} (config: {config})") | |
test_args = { | |
'check_train': False, | |
'check_valid': True, | |
'check_inference': True, | |
'config_path': updated_configs[config], | |
'data_path': str(TRAIN_DIR), | |
'valid_path': str(VALID_DIR), | |
'results_path': str(TEST_DIR / "results" / model_type), | |
'store_dir': str(TEST_DIR / "inference_results" / model_type), | |
'metrics': ['sdr', 'si_sdr', 'l1_freq'] | |
} | |
test_args.update(args) | |
test_settings(test_args, 'admin') | |
print(f"Tests for model {model_type} completed successfully.") | |
# Remove test_cache | |
cleanup_test_tracks() | |
if __name__ == "__main__": | |
run_tests() | |