ahmedtarekabd commited on
Commit
4c8f740
·
1 Parent(s): 61e8d59

Add Models & files.

Browse files
.docker-compose.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ version: "3.8"
2
+ services:
3
+ inference:
4
+ build:
5
+ context: .
6
+ image: audio-infer
7
+ volumes:
8
+ - ./data/data_20_files:/data
9
+ - ./data/output:/results
10
+ command: ["--team_id", "8"]
.dockerignore ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data/
2
+ catboost_info/
3
+ # mlruns/
4
+ # !mlruns/models
5
+
6
+ .git
7
+
8
+ # Byte-compiled / optimized / DLL files
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+
13
+ # C extensions
14
+ *.so
15
+
16
+ # Distribution / packaging
17
+ .Python
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ share/python-wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ *.py,cover
57
+ .hypothesis/
58
+ .pytest_cache/
59
+ cover/
60
+
61
+ # Translations
62
+ *.mo
63
+ *.pot
64
+
65
+ # Django stuff:
66
+ *.log
67
+ local_settings.py
68
+ db.sqlite3
69
+ db.sqlite3-journal
70
+
71
+ # Flask stuff:
72
+ instance/
73
+ .webassets-cache
74
+
75
+ # Scrapy stuff:
76
+ .scrapy
77
+
78
+ # Sphinx documentation
79
+ docs/_build/
80
+
81
+ # PyBuilder
82
+ .pybuilder/
83
+ target/
84
+
85
+ # Jupyter Notebook
86
+ .ipynb_checkpoints
87
+
88
+ # IPython
89
+ profile_default/
90
+ ipython_config.py
91
+
92
+ # pyenv
93
+ # For a library or package, you might want to ignore these files since the code is
94
+ # intended to run in multiple environments; otherwise, check them in:
95
+ # .python-version
96
+
97
+ # pipenv
98
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
100
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
101
+ # install all needed dependencies.
102
+ #Pipfile.lock
103
+
104
+ # UV
105
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
106
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
107
+ # commonly ignored for libraries.
108
+ #uv.lock
109
+
110
+ # poetry
111
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
112
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
113
+ # commonly ignored for libraries.
114
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
115
+ #poetry.lock
116
+
117
+ # pdm
118
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
119
+ #pdm.lock
120
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
121
+ # in version control.
122
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
123
+ .pdm.toml
124
+ .pdm-python
125
+ .pdm-build/
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .venv
140
+ env/
141
+ venv/
142
+ ENV/
143
+ env.bak/
144
+ venv.bak/
145
+
146
+ # Spyder project settings
147
+ .spyderproject
148
+ .spyproject
149
+
150
+ # Rope project settings
151
+ .ropeproject
152
+
153
+ # mkdocs documentation
154
+ /site
155
+
156
+ # mypy
157
+ .mypy_cache/
158
+ .dmypy.json
159
+ dmypy.json
160
+
161
+ # Pyre type checker
162
+ .pyre/
163
+
164
+ # pytype static type analyzer
165
+ .pytype/
166
+
167
+ # Cython debug symbols
168
+ cython_debug/
169
+
170
+ # PyCharm
171
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
172
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
173
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
174
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
175
+ #.idea/
176
+
177
+ # Ruff stuff:
178
+ .ruff_cache/
179
+
180
+ # PyPI configuration file
181
+ .pypirc
.gitignore ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data/
2
+ catboost_info/
3
+ mlruns/
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # UV
102
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ #uv.lock
106
+
107
+ # poetry
108
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
109
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
110
+ # commonly ignored for libraries.
111
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
112
+ #poetry.lock
113
+
114
+ # pdm
115
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
116
+ #pdm.lock
117
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
118
+ # in version control.
119
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
120
+ .pdm.toml
121
+ .pdm-python
122
+ .pdm-build/
123
+
124
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
125
+ __pypackages__/
126
+
127
+ # Celery stuff
128
+ celerybeat-schedule
129
+ celerybeat.pid
130
+
131
+ # SageMath parsed files
132
+ *.sage.py
133
+
134
+ # Environments
135
+ .env
136
+ .venv
137
+ env/
138
+ venv/
139
+ ENV/
140
+ env.bak/
141
+ venv.bak/
142
+
143
+ # Spyder project settings
144
+ .spyderproject
145
+ .spyproject
146
+
147
+ # Rope project settings
148
+ .ropeproject
149
+
150
+ # mkdocs documentation
151
+ /site
152
+
153
+ # mypy
154
+ .mypy_cache/
155
+ .dmypy.json
156
+ dmypy.json
157
+
158
+ # Pyre type checker
159
+ .pyre/
160
+
161
+ # pytype static type analyzer
162
+ .pytype/
163
+
164
+ # Cython debug symbols
165
+ cython_debug/
166
+
167
+ # PyCharm
168
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
169
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
170
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
171
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
172
+ #.idea/
173
+
174
+ # Ruff stuff:
175
+ .ruff_cache/
176
+
177
+ # PyPI configuration file
178
+ .pypirc
Dockerfile CHANGED
@@ -1,21 +1,23 @@
1
- FROM python:3.9-slim
2
 
3
  WORKDIR /app
4
 
 
5
  RUN apt-get update && apt-get install -y \
6
  build-essential \
7
  curl \
8
  software-properties-common \
9
  git \
 
10
  && rm -rf /var/lib/apt/lists/*
11
 
12
- COPY requirements.txt ./
13
- COPY src/ ./src/
14
 
15
- RUN pip3 install -r requirements.txt
16
 
17
- EXPOSE 8501
18
 
 
19
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
20
-
21
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
1
+ FROM python:3.10-slim
2
 
3
  WORKDIR /app
4
 
5
+ # Source: https://stackoverflow.com/questions/55036740/lightgbm-inside-docker-libgomp-so-1-cannot-open-shared-object-file
6
  RUN apt-get update && apt-get install -y \
7
  build-essential \
8
  curl \
9
  software-properties-common \
10
  git \
11
+ libgomp1 \
12
  && rm -rf /var/lib/apt/lists/*
13
 
14
+ COPY requirements_docker.txt requirements_docker.txt
15
+ RUN pip install --no-cache-dir --upgrade pip && pip install --no-cache-dir -r requirements_docker.txt
16
 
17
+ COPY . .
18
 
19
+ VOLUME ["/data", "/results"]
20
 
21
+ EXPOSE 8501
22
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
23
+ ENTRYPOINT ["streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
README.md CHANGED
@@ -1,20 +1,40 @@
1
- ---
2
- title: Audio Classifier
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: This space is used to deploy Speaker Age & Gender clf.
12
- license: mit
13
- ---
14
 
15
- # Welcome to Streamlit!
 
 
 
 
 
16
 
17
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
 
 
 
 
 
 
 
 
 
 
18
 
19
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Audio Classification
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ This repository contains a collection of Jupyter notebooks and Python scripts for audio classification tasks using various machine learning and deep learning techniques. The main focus is on classifying audio files into different categories based on their content.
4
+ The project includes the following components:
5
+ - **Data Preprocessing**: Scripts for loading and preprocessing audio data, including feature extraction using libraries like `librosa`.
6
+ - **Model Training**: Jupyter notebooks for training different models, including traditional machine learning algorithms and deep learning architectures.
7
+ - **Model Evaluation**: Scripts for evaluating the performance of trained models using metrics like accuracy, precision, recall, and F1-score.
8
+ - **Visualization**: Tools for visualizing audio data and model performance, including confusion matrices and ROC curves.
9
 
10
+ ## How to Use
11
+ 1. Clone the repository:
12
+ ```bash
13
+ git clone <repository-url>
14
+ cd Audio-Classification
15
+ ```
16
+ 2. Install the required dependencies:
17
+ ```bash
18
+ pip install -r requirements.txt
19
+ ```
20
+ 3. Prepare your audio dataset and place it in the `data/` directory.
21
 
22
+ ### Using Docker
23
+ 1. Build the Docker image:
24
+ ```bash
25
+ docker build -t audio-infer .
26
+ ```
27
+ 2. Run the Docker container with your audio files mounted:
28
+ ```bash
29
+ docker run --rm -v "$(pwd)/data/data_20_files:/data" -v "$(pwd)/data/output:/results" audio-infer --team_id 8
30
+ ```
31
+ 3. The results will be saved in the `data/output` directory.
32
+ 4. You can also run the container with a specific model:
33
+ ```bash
34
+ docker run --rm -v "$(pwd)/data/data_20_files:/data" -v "$(pwd)/data/output:/results" audio-infer --team_id 8 --model_path /path/to/your/model
35
+ ```
36
+ ### Using Docker Compose
37
+ 1. Build the image and start the services:
38
+ ```bash
39
+ docker-compose -f .docker-compose.yaml up --build
40
+ ```
config.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from pathlib import Path

import pandas as pd
from tqdm import tqdm

# Resolve the dataset root according to the runtime environment.
if os.path.exists("/kaggle"):
    # Running on Kaggle: datasets are mounted under /kaggle/input.
    DATA_DIR = Path("/kaggle/input/your-dataset-name")
elif os.path.exists("/content"):
    # Running on Google Colab.
    DATA_DIR = Path("/content")
else:
    # Local / Docker default.
    DATA_DIR = Path("data")

# Derived project paths, all rooted at DATA_DIR.
AUDIO_PATH = DATA_DIR / "audios"
AUDIO_CACHE = DATA_DIR / "audio_cache"
PREPROCESSED_CACHE = DATA_DIR / "preprocessed_cache"
FEATURES_CACHE = DATA_DIR / "features_cache"
MODELS_DIR = DATA_DIR / "models"

# Fall back to 4 workers when the CPU count cannot be determined.
NUM_WORKERS = os.cpu_count() or 4


def run_config():
    """Create the cache/model directories and enable tqdm progress bars for pandas."""
    required_dirs = (AUDIO_CACHE, PREPROCESSED_CACHE, FEATURES_CACHE, MODELS_DIR)
    for directory in required_dirs:
        directory.mkdir(parents=True, exist_ok=True)

    tqdm.pandas()
inference.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import os
import time
import warnings
from glob import glob
from pathlib import Path

import pandas as pd

from modules.preprocessing import AudioPreprocessor
from modules.feature_extraction import FeatureExtractor
from models.lightgbm import LightGBMModel
from models.xgboost import XGBoostModel
from modules.pipelines import ModelPipeline

warnings.filterwarnings("ignore")

# Registry mapping CLI model names to their implementing classes.
MODEL_NAME = {
    "XGBoost": XGBoostModel,
    "LightGBM": LightGBMModel,
}


def run_batch_inference(model, input_folder, output_folder, sr=16000, feature_mode="traditional"):
    """Predict every audio file in input_folder, writing one prediction per line
    to results.txt and the per-file wall-clock time to time.txt in output_folder."""
    preprocessor = AudioPreprocessor()
    extractor = FeatureExtractor()

    # Files are named by integer id; sort numerically so output lines match file order.
    audio_files = sorted(glob(os.path.join(input_folder, "*")), key=lambda p: int(Path(p).stem))

    results_path = os.path.join(output_folder, "results.txt")
    time_path = os.path.join(output_folder, "time.txt")
    # Truncate any previous outputs so each run starts from empty files.
    open(results_path, "w").close()
    open(time_path, "w").close()

    # NOTE(review): when preprocessing yields None, the previously written
    # prediction (or the initial 0) is repeated for that file — presumably a
    # deliberate fallback so every file still gets a line; confirm intent.
    pred = 0
    for audio_file in audio_files:
        start_time = time.time()
        signal = preprocessor.preprocess(preprocessor.load_audio(str(audio_file), sr=sr))
        if signal is not None:
            features = extractor.extract(signal, sr=sr, mode=feature_mode, n_mfcc=20)
            pred = model.predict([features])[0]
        end_time = time.time()

        with open(results_path, "a") as f:
            f.write(f"{pred}\n")
        with open(time_path, "a") as f:
            f.write(f"{end_time - start_time:.6f}\n")

    print(f"✅ Results saved to {results_path}")
    print(f"✅ Inference time saved to {time_path}")


def main(input_path, model_name, output_folder):
    """Validate arguments, load the registered model, and run batch inference."""
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input path {input_path} does not exist.")

    if model_name not in MODEL_NAME:
        raise ValueError(f"Model name {model_name} is not valid. Choose from {list(MODEL_NAME.keys())}.")

    if not os.path.exists(output_folder):
        os.makedirs(output_folder, exist_ok=True)
        print(f"Output folder {output_folder} created.")

    pipeline = ModelPipeline(model=MODEL_NAME[model_name])
    pipeline.load_model_from_registry(model_name=model_name)
    print("✅ Model loaded successfully")

    run_batch_inference(pipeline, input_path, output_folder)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, default="/data", help="Path to the input folder containing test audio files. Default is '/data'.")
    parser.add_argument('--model-name', type=str, default="XGBoost", help="Name of the model to use for inference. Default is 'XGBoost'.")
    parser.add_argument('--team_id', type=str, required=True, help="Team ID for output folder.")
    args = parser.parse_args()

    # Results land under /results/<team_id> (see the docker-compose volume mounts).
    output_folder = os.path.join("/results", args.team_id)
    print(f"Input Path: {args.input_path}")
    print(f"Model Name: {args.model_name}")
    print(f"Output Folder: {output_folder}")

    main(args.input_path, args.model_name, output_folder)
models/base_model.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import mlflow
from typing import Any, Dict, Optional
from numpy import ndarray
from sklearn.base import BaseEstimator
from modules.evaluate import PerformanceAnalyzer


# === Base Model Interface ===
class BaseModel:
    """Common interface for all models: training hook, prediction helpers,
    and MLflow logging / loading / registry utilities."""

    def __init__(self) -> None:
        # Underlying estimator; populated by train() or one of the load_* methods.
        self.model: Optional[BaseEstimator] = None
        self.best_params: Dict[str, Any] = {}
        self.model_name: str = self.__class__.__name__  # Automatically set model name

    def train(self, X_train: ndarray, y_train: ndarray, X_val: ndarray, y_val: ndarray) -> None:
        """Fit self.model on the given split; subclasses must implement."""
        raise NotImplementedError

    def predict(self, X: ndarray) -> ndarray:
        """Return the underlying model's predictions for X."""
        return self.model.predict(X)

    def score(self, X: ndarray, y: ndarray) -> float:
        """Return the underlying model's score on (X, y)."""
        return self.model.score(X, y)

    def log_mlflow(self, y_val: ndarray, y_pred: ndarray):
        """
        Logs model performance metrics and the trained model to MLflow.

        Evaluates the model with PerformanceAnalyzer, logs the best
        hyperparameters and the numeric evaluation metrics, then saves the
        trained model to MLflow for tracking and reproducibility.

        Args:
            y_val (ndarray): The ground truth target values.
            y_pred (ndarray): The predicted target values from the model.

        Returns:
            str | dict: A string representation of the evaluation metrics or
            a dictionary containing the metrics.
        """
        analyzer = PerformanceAnalyzer()
        metrics, metrics_str = analyzer.evaluate(y_val, y_pred)
        mlflow.log_params(self.best_params or {})

        # Flatten nested metric dicts to "category_metric" keys; skip non-numeric values.
        for category, category_metrics in metrics.items():
            if isinstance(category_metrics, dict):
                mlflow.log_metrics({f"{category}_{k}": v for k, v in category_metrics.items() if isinstance(v, (int, float))})

        mlflow.sklearn.log_model(self.model, "model")
        mlflow.set_tag("model_name", self.model_name)  # Add model name as a tag
        return metrics_str

    def load_model_from_run(
        self,
        run_id: Optional[str] = None,
        experiment_id: Optional[str] = None,
        experiment_name: Optional[str] = None,
        best_metric: Optional[str] = None,
        maximize: bool = True,
        additional_tags: Optional[Dict[str, str]] = None
    ) -> None:
        """
        Loads a model from a specific MLflow run, the last run, or the best run based on a metric.

        Args:
            run_id (str, optional): The ID of the MLflow run from which to load the model. Defaults to None.
            experiment_id (str, optional): The ID of the MLflow experiment to search for runs. Defaults to None.
            experiment_name (str, optional): The name of the MLflow experiment to search for runs. Used when run_id is not provided.
            best_metric (str, optional): The metric to use for selecting the best run. Defaults to None. Example: "weighted avg_f1-score".
            maximize (bool, optional): Whether to maximize or minimize the metric when selecting the best run. Defaults to True.
            additional_tags (dict, optional): Additional tags to filter runs. Defaults to None.

        Raises:
            ValueError: If no runs match the search criteria.
        """
        if run_id:
            # Load model from the specified run ID.
            run = mlflow.get_run(run_id)
        else:
            # Default to the first experiment if neither id nor name is provided.
            if not (experiment_id or experiment_name):
                experiment_id = "0"

            # Order by the requested metric, or fall back to most recent run.
            if best_metric:
                metric_order = f"metrics.'{best_metric}' {'DESC' if maximize else 'ASC'}"
                order_by = [metric_order]
            else:
                order_by = ["start_time DESC"]

            # Filter to runs of this model (run names start with the model name),
            # narrowed further by any caller-supplied tags.
            filter_string = f"attributes.run_name LIKE '{self.model_name}%'"
            if additional_tags:
                for key, value in additional_tags.items():
                    filter_string += f" and tags.{key} = '{value}'"

            runs = mlflow.search_runs(
                experiment_ids=[experiment_id] if experiment_id else None,
                experiment_names=[experiment_name] if experiment_name else None,
                filter_string=filter_string,
                order_by=order_by,
                max_results=1
            )

            if runs.empty:
                raise ValueError(f"No runs found in experiment '{experiment_name}' with the specified criteria.")

            # Take the single best (or most recent) matching run.
            run = mlflow.get_run(runs.iloc[0]["run_id"])

        # Load the model and its metadata from the selected run.
        self.model = mlflow.pyfunc.load_model(f"runs:/{run.info.run_id}/model")
        self.best_params = run.data.params
        self.metrics = run.data.metrics
        self.model_name = run.info.run_name
        self.run_id = run.info.run_id

    def register_model(
        self,
        run_id: str,
        model_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ) -> None:
        """
        Registers a model in MLflow's Model Registry.

        Args:
            run_id (str): The ID of the MLflow run containing the model to register.
            model_name (str, optional): The name to assign to the registered model.
                Defaults to this instance's model_name.
            tags (dict, optional): Tags to associate with the registered model. Defaults to None.
        """
        mlflow.register_model(
            model_uri=f"runs:/{run_id}/model",
            name=model_name or self.model_name,
            tags=tags
        )

    def load_model_from_registry(self, model_name: str, version: Optional[int] = None) -> None:
        """
        Loads a model from MLflow's Model Registry.

        Args:
            model_name (str): The name of the model to load.
            version (int, optional): The version of the model to load. If None, the latest version is loaded. Defaults to None.
        """
        self.model = mlflow.pyfunc.load_model(model_uri=f"models:/{model_name}/{version if version else 'latest'}")
models/catboost.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from catboost import CatBoostClassifier
import optuna
import cupy as cp
from numpy import ndarray
import numpy as np
from models.base_model import BaseModel
from typing import Dict, Any


# === CatBoost Implementation ===
class CatBoostModel(BaseModel):
    """CatBoost classifier with optional Optuna hyperparameter search."""

    def __init__(self) -> None:
        super().__init__()

    def objective(
        self,
        trial: optuna.trial.Trial,
        X_train: ndarray,
        y_train: ndarray,
        X_val: ndarray,
        y_val: ndarray
    ) -> float:
        """Optuna objective: fit one candidate with early stopping and return
        its validation accuracy."""
        search_space: Dict[str, Any] = {
            "iterations": trial.suggest_int("iterations", 300, 500),
            "learning_rate": trial.suggest_float("learning_rate", 1e-2, 1e-1, log=True),
            "depth": trial.suggest_int("depth", 10, 15),
            "l2_leaf_reg": trial.suggest_float("l2_leaf_reg", 1e-5, 1.0, log=True),
            # Use the GPU whenever cupy can see a CUDA device.
            "task_type": "GPU" if cp.cuda.is_available() else "CPU",
            "verbose": False
        }
        candidate = CatBoostClassifier(**search_space)
        candidate.fit(
            X_train, y_train,
            eval_set=(X_val, y_val),
            early_stopping_rounds=50,
            verbose=False,
        )
        return candidate.score(X_val, y_val)

    def train(
        self,
        X_train: ndarray,
        y_train: ndarray,
        X_val: ndarray,
        y_val: ndarray,
        use_optuna: bool = False,
        n_trials: int = 20
    ) -> None:
        """Optionally tune hyperparameters with Optuna, then fit the final
        classifier on the combined train+val data."""
        if use_optuna:
            study = optuna.create_study(direction="maximize")
            # CatBoost returns device (cuda) already in use if n_jobs > 1
            study.optimize(
                lambda t: self.objective(t, X_train, y_train, X_val, y_val),
                n_trials=n_trials,
            )
            self.best_params = study.best_params
            self.model = CatBoostClassifier(**self.best_params, verbose=False)
        else:
            self.model = CatBoostClassifier(verbose=0)
        # Final fit uses both splits together.
        X_all = np.vstack([X_train, X_val])
        y_all = np.hstack([y_train, y_val])
        self.model.fit(X_all, y_all)
+
models/gmboost.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from sklearn.ensemble import GradientBoostingClassifier
import optuna
from models.base_model import BaseModel
from numpy import ndarray
import numpy as np


# === GradientBoosting Implementation ===
class GradientBoostingModel(BaseModel):
    """sklearn GradientBoosting classifier with optional Optuna tuning."""

    def __init__(self) -> None:
        super().__init__()

    def objective(
        self,
        trial: optuna.trial.Trial,
        X_train: ndarray,
        y_train: ndarray,
        X_val: ndarray,
        y_val: ndarray
    ) -> float:
        """Optuna objective: fit one candidate and return its validation accuracy."""
        search_space = {
            "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
            "max_depth": trial.suggest_int("max_depth", 3, 10),
            "n_estimators": trial.suggest_int("n_estimators", 100, 1000),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
        }
        candidate = GradientBoostingClassifier(**search_space)
        candidate.fit(X_train, y_train)
        return candidate.score(X_val, y_val)

    def train(
        self,
        X_train: ndarray,
        y_train: ndarray,
        X_val: ndarray,
        y_val: ndarray,
        use_optuna: bool = False,
        n_trials: int = 20
    ) -> None:
        """Optionally tune hyperparameters with Optuna, then fit the final
        classifier on the combined train+val data."""
        if use_optuna:
            study = optuna.create_study(direction="maximize")
            study.optimize(
                lambda t: self.objective(t, X_train, y_train, X_val, y_val),
                n_trials=n_trials,
                n_jobs=-1,
            )
            self.best_params = study.best_params
            self.model = GradientBoostingClassifier(**self.best_params)
        else:
            self.model = GradientBoostingClassifier()
        # Final fit uses both splits together.
        X_all = np.vstack([X_train, X_val])
        y_all = np.hstack([y_train, y_val])
        self.model.fit(X_all, y_all)
models/lightgbm.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import optuna
from models.base_model import BaseModel
import lightgbm as lgb
from numpy import ndarray
import numpy as np
from sklearn.utils import class_weight


# === LightGBM Implementation ===
class LightGBMModel(BaseModel):
    """LightGBM classifier with optional Optuna tuning and sample weighting."""

    def __init__(self) -> None:
        super().__init__()

    def objective(
        self,
        trial: optuna.trial.Trial,
        X_train: ndarray,
        y_train: ndarray,
        X_val: ndarray,
        y_val: ndarray,
        class_weight_type: str = "",
    ) -> float:
        """Optuna objective: fit one candidate (optionally sample-weighted)
        and return its validation accuracy."""
        search_space = {
            "objective": "multiclass",
            "num_class": len(set(y_train)),
            "metric": "multi_logloss",
            "learning_rate": trial.suggest_float("learning_rate", 1e-2, 2e-1, log=True),
            "num_leaves": trial.suggest_int("num_leaves", 50, 130),
            "max_depth": trial.suggest_int("max_depth", 20, 30),
            "min_child_samples": trial.suggest_int("min_child_samples", 20, 50),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1.0),
            "n_jobs": -1,
            "verbosity": -1
        }
        candidate = lgb.LGBMClassifier(**search_space)
        if class_weight_type:
            # Weight samples to counter class imbalance.
            sample_weights = class_weight.compute_sample_weight(class_weight_type, y=y_train)
            candidate.fit(X_train, y_train, eval_set=[(X_val, y_val)], sample_weight=sample_weights)
        else:
            candidate.fit(X_train, y_train, eval_set=[(X_val, y_val)])
        return candidate.score(X_val, y_val)

    def train(
        self,
        X_train: ndarray,
        y_train: ndarray,
        X_val: ndarray,
        y_val: ndarray,
        use_optuna: bool = False,
        n_trials: int = 20,
        class_weight_type: str = "",
    ) -> None:
        """Optionally tune hyperparameters with Optuna, then fit the final
        classifier on the combined train+val data (optionally sample-weighted)."""
        if use_optuna:
            study = optuna.create_study(direction="maximize")
            study.optimize(
                lambda t: self.objective(t, X_train, y_train, X_val, y_val, class_weight_type),
                n_trials=n_trials,
                n_jobs=-1,
                show_progress_bar=True,
            )
            self.best_params = study.best_params
            self.model = lgb.LGBMClassifier(**self.best_params, verbosity=-1)
        else:
            # best_params is empty unless a previous search populated it.
            self.model = lgb.LGBMClassifier(**self.best_params)
        # Final fit uses both splits together.
        X_all = np.vstack([X_train, X_val])
        y_all = np.hstack([y_train, y_val])
        if class_weight_type:
            sample_weights = class_weight.compute_sample_weight(class_weight_type, y=y_all)
            self.model.fit(X_all, y_all, sample_weight=sample_weights)
        else:
            self.model.fit(X_all, y_all)
models/svm.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from models.base_model import BaseModel
from sklearn.svm import SVC
from sklearn.utils import class_weight
import optuna
from numpy import ndarray
import numpy as np


# === SVM Implementation ===
class SVMModel(BaseModel):
    """Support-vector classifier with optional Optuna tuning and sample weighting."""

    def __init__(self) -> None:
        super().__init__()

    def objective(
        self,
        trial: optuna.trial.Trial,
        X_train: ndarray,
        y_train: ndarray,
        X_val: ndarray,
        y_val: ndarray,
        class_weight_type: str = "",
    ) -> float:
        """Optuna objective: fit one candidate (optionally sample-weighted)
        and return its validation accuracy."""
        search_space = {
            "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
            "kernel": trial.suggest_categorical("kernel", ["linear", "poly", "rbf", "sigmoid"]),
            "gamma": trial.suggest_categorical("gamma", ["scale", "auto"]),
        }
        candidate = SVC(**search_space, probability=False)
        if class_weight_type:
            # Weight samples to counter class imbalance.
            sample_weights = class_weight.compute_sample_weight(class_weight_type, y=y_train)
            candidate.fit(X_train, y_train, sample_weight=sample_weights)
        else:
            candidate.fit(X_train, y_train)
        return candidate.score(X_val, y_val)

    def train(
        self,
        X_train: ndarray,
        y_train: ndarray,
        X_val: ndarray,
        y_val: ndarray,
        use_optuna: bool = False,
        n_trials: int = 20,
        class_weight_type: str = "",
    ) -> None:
        """Optionally tune hyperparameters with Optuna, then fit the final
        classifier on the combined train+val data (optionally sample-weighted)."""
        if use_optuna:
            study = optuna.create_study(direction="maximize")
            study.optimize(
                lambda t: self.objective(t, X_train, y_train, X_val, y_val, class_weight_type),
                n_trials=n_trials,
                n_jobs=-1,
                show_progress_bar=True,
            )
            self.best_params = study.best_params
        # Both paths build the same estimator: best_params is empty (defaults)
        # unless the Optuna search above populated it.
        self.model = SVC(**self.best_params, probability=False)
        # Final fit uses both splits together.
        X_all = np.vstack([X_train, X_val])
        y_all = np.hstack([y_train, y_val])
        if class_weight_type:
            sample_weights = class_weight.compute_sample_weight(class_weight_type, y=y_all)
            self.model.fit(X_all, y_all, sample_weight=sample_weights)
        else:
            self.model.fit(X_all, y_all)
models/xgboost.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.base_model import BaseModel
2
+ from xgboost import XGBClassifier
3
+ import optuna
4
+ from numpy import ndarray
5
+ import numpy as np
6
+ import cupy as cp
7
+ from sklearn.utils import class_weight
8
+
9
+
10
+ # === XGBoost Implementation ===
11
# === XGBoost Implementation ===
class XGBoostModel(BaseModel):
    """XGBClassifier wrapper with optional Optuna search; trains on GPU when CuPy sees one."""

    def __init__(self) -> None:
        super().__init__()

    def objective(
        self,
        trial: optuna.trial.Trial,
        X_train: ndarray,
        y_train: ndarray,
        X_val: ndarray,
        y_val: ndarray,
        class_weight_type: str = "",
    ) -> float:
        """Optuna objective: fit on the train split, return accuracy on the val split."""
        params = {
            "learning_rate": trial.suggest_float("learning_rate", 1e-2, 0.1, log=True),
            "max_depth": trial.suggest_int("max_depth", 15, 20),
            "n_estimators": trial.suggest_int("n_estimators", 200, 1000),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1.0),
            "gamma": trial.suggest_float("gamma", 0, 5),
            "min_child_weight": trial.suggest_int("min_child_weight", 1, 10),
            "scale_pos_weight": trial.suggest_float("scale_pos_weight", 1, 10),
            "device": "cuda" if cp.cuda.is_available() else "cpu",
        }
        model = XGBClassifier(**params, use_label_encoder=False, eval_metric="mlogloss")
        fit_kwargs = {"eval_set": [(X_val, y_val)], "verbose": False}
        if class_weight_type:
            fit_kwargs["sample_weight"] = class_weight.compute_sample_weight(class_weight_type, y=y_train)
        model.fit(X_train, y_train, **fit_kwargs)
        return model.score(X_val, y_val)

    def train(
        self,
        X_train: ndarray,
        y_train: ndarray,
        X_val: ndarray,
        y_val: ndarray,
        use_optuna: bool = False,
        n_trials: int = 20,
        class_weight_type: str = "",
    ) -> None:
        """Train the final classifier on train+val combined, optionally tuning first.

        When ``use_optuna`` is False, ``self.best_params`` must already hold the
        desired XGBClassifier keyword arguments.
        """
        if use_optuna:
            study = optuna.create_study(direction="maximize")
            study.optimize(
                lambda trial: self.objective(trial, X_train, y_train, X_val, y_val, class_weight_type),
                n_trials=n_trials,
                n_jobs=2,
                show_progress_bar=True,
            )
            self.best_params = study.best_params
            # BUG FIX: study.best_params only contains trial-suggested keys, so the
            # fixed "device" choice made in objective() was silently dropped and the
            # final refit always ran on CPU. Restore it here.
            self.best_params.setdefault("device", "cuda" if cp.cuda.is_available() else "cpu")
        self.model = XGBClassifier(**self.best_params, use_label_encoder=False, eval_metric="mlogloss")
        X, y = np.vstack([X_train, X_val]), np.hstack([y_train, y_val])
        if class_weight_type:
            # Source: https://stackoverflow.com/questions/42192227/xgboost-python-classifier-class-weight-option
            class_weights = class_weight.compute_sample_weight(class_weight_type, y=y)
            self.model.fit(X, y, sample_weight=class_weights)
        else:
            self.model.fit(X, y)
modules/evaluate.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from numpy import ndarray # For type hinting
2
+ from sklearn.metrics import classification_report
3
+
4
+ # === Evaluation ===
5
# === Evaluation ===
class PerformanceAnalyzer:
    """Thin wrapper around sklearn's classification report."""

    def evaluate(self, y_true: ndarray, y_pred: ndarray):
        """Return the classification report twice: as a dict and as printable text."""
        common = {"zero_division": 0}
        as_dict = classification_report(y_true, y_pred, output_dict=True, **common)
        as_text = classification_report(y_true, y_pred, **common)
        return as_dict, as_text
modules/feature_extraction.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import torch
2
+ import librosa
3
+ import numpy as np
4
+ import parselmouth
5
+ from transformers import Wav2Vec2Model, Wav2Vec2Processor
6
+ from config import FEATURES_CACHE
7
+ from pathlib import Path
8
+ from typing import Tuple, Optional
9
+
10
+ # === Feature Extraction ===
11
# === Feature Extraction ===
class FeatureExtractor:
    """Extracts fixed-length feature vectors from raw audio and manages the on-disk feature cache."""

    def __init__(self) -> None:
        # Wav2Vec2 extraction is currently disabled to avoid loading the heavy
        # transformer checkpoints; re-enable together with the wav2vec() method below.
        # self.wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        # self.wav2vec_proc = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        pass

    def traditional(self, y: np.ndarray, sr: int = 16000, n_mfcc: int = 13) -> np.ndarray:
        """Return a 1-D vector of spectral, prosodic and formant summary statistics.

        Args:
            y: mono audio signal.
            sr: sampling rate of *y* in Hz.
            n_mfcc: number of MFCC coefficients (13 is standard for voice tasks).
        """
        # Spectral features
        mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=n_mfcc)
        chroma = librosa.feature.chroma_stft(y=y, sr=sr)
        contrast = librosa.feature.spectral_contrast(y=y, sr=sr)

        # RMS energy, zero-crossing rate, spectral centroid
        rmse = librosa.feature.rms(y=y)
        zcr = librosa.feature.zero_crossing_rate(y)
        centroid = librosa.feature.spectral_centroid(y=y, sr=sr)

        #* PROSODIC FEATURES
        # Fundamental frequency (F0) via YIN; fall back to zeros when pitch
        # tracking fails (e.g. degenerate input).
        try:
            f0 = librosa.yin(y, fmin=50, fmax=500, sr=sr)
            f0_mean = np.nanmean(f0)
            f0_std = np.nanstd(f0)
            f0_max = np.nanmax(f0)
        except Exception:  # was a bare except: — never swallow SystemExit/KeyboardInterrupt
            f0_mean = f0_std = f0_max = 0

        # Loudness: statistics over the log-magnitude STFT
        loudness = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max)
        loudness_mean = np.mean(loudness)
        loudness_std = np.std(loudness)

        # Rhythm / duration of non-silent segments
        intervals = librosa.effects.split(y, top_db=30)
        durations = [(e - s) / sr for s, e in intervals]
        if durations:
            dur_mean = np.mean(durations)
            dur_std = np.std(durations)
            dur_count = len(durations)
        else:
            dur_mean = dur_std = dur_count = 0

        # Formant features (F1/F2 via Praat)
        formants = self.extract_formants(y, sr)

        return np.concatenate([
            mfcc.mean(axis=1),
            chroma.mean(axis=1),
            contrast.mean(axis=1),
            [rmse.mean()],
            [zcr.mean()],
            [centroid.mean()],
            [f0_mean, f0_std, f0_max],
            [loudness_mean, loudness_std],
            [dur_mean, dur_std, dur_count],
            [formants["f1_mean"], formants["f1_std"], formants["f2_mean"], formants["f2_std"]],
        ])

    def extract_formants(self, audio: np.ndarray, sr: int = 16000) -> dict:
        """Estimate mean/std of the first two formants (F1, F2) via Praat (parselmouth).

        Returns zeros for every statistic when formant analysis fails.
        """
        try:
            sound = parselmouth.Sound(audio, sampling_frequency=sr)
            formant = sound.to_formant_burg()

            duration = sound.duration
            # Sample 100 points, staying 10 ms inside each edge of the clip.
            times = np.linspace(0.01, duration - 0.01, 100)
            f1_list, f2_list = [], []

            for t in times:
                f1 = formant.get_value_at_time(1, t)
                f2 = formant.get_value_at_time(2, t)
                # NOTE(review): NaN is truthy, so NaNs may pass this filter —
                # harmless because the nan-aware reducers below ignore them.
                if f1: f1_list.append(f1)
                if f2: f2_list.append(f2)

            return {
                "f1_mean": np.nanmean(f1_list) if f1_list else 0,
                "f1_std": np.nanstd(f1_list) if f1_list else 0,
                "f2_mean": np.nanmean(f2_list) if f2_list else 0,
                "f2_std": np.nanstd(f2_list) if f2_list else 0,
            }
        except Exception as e:
            print(f"[Formant Error] {e}")
            return {
                "f1_mean": 0, "f1_std": 0,
                "f2_mean": 0, "f2_std": 0,
            }

    # def wav2vec(self, y: np.ndarray, sr: int = 16000) -> np.ndarray:
    #     if sr != 16000:
    #         y = librosa.resample(y, orig_sr=sr, target_sr=16000)
    #     input_values: torch.Tensor = self.wav2vec_proc(y, return_tensors="pt", sampling_rate=16000).input_values
    #     with torch.no_grad():
    #         embeddings: torch.Tensor = self.wav2vec_model(input_values).last_hidden_state
    #     return embeddings.mean(dim=1).squeeze().numpy()

    def extract(self, y: np.ndarray, sr: int = 16000, mode: str = "traditional", n_mfcc: int = 40) -> np.ndarray:
        """Dispatch feature extraction by *mode* ("traditional" or "wav2vec")."""
        if mode == "traditional":
            return self.traditional(y, sr, n_mfcc=n_mfcc)
        # NOTE(review): wav2vec() is currently commented out above; any other
        # mode raises AttributeError until it is re-enabled.
        return self.wav2vec(y, sr)

    def _cache_paths(self, mode: str, version: Optional[int] = None) -> Tuple[Path, Path]:
        """Return the (X, y) cache file paths for a mode/version pair."""
        suffix = "" if version is None else f"_v{version}"
        return (FEATURES_CACHE / f"X_{mode}{suffix}.npy", FEATURES_CACHE / f"y_{mode}{suffix}.npy")

    def cache_features(self, X: np.ndarray, y: np.ndarray, mode: str, version: Optional[int] = None, force_update: bool = False) -> None:
        """Persist (X, y) for *mode*/*version* unless both files already exist."""
        X_path, y_path = self._cache_paths(mode, version)
        if force_update or not X_path.exists() or not y_path.exists():
            np.save(X_path, X)
            np.save(y_path, y)

    def load_cached_features(self, mode: str, version: Optional[int] = None) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Load cached (X, y); (None, None) when either file is missing."""
        X_path, y_path = self._cache_paths(mode, version)
        if X_path.exists() and y_path.exists():
            return np.load(X_path), np.load(y_path)
        return None, None

    def remove_cached_features(self, mode: str, version: Optional[int] = None) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Delete the cache files for *mode*/*version* if present; always returns (None, None)."""
        X_path, y_path = self._cache_paths(mode, version)
        if X_path.exists(): X_path.unlink()
        if y_path.exists(): y_path.unlink()
        return None, None

    def merge_features(self, mode: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Concatenate all cached versioned features for *mode*; (None, None) when none exist."""
        X_parts, y_parts = [], []
        for file in FEATURES_CACHE.glob(f"X_{mode}_*.npy"):
            X_parts.append(np.load(file))
            y_parts.append(np.load(file.with_name(file.name.replace("X_", "y_"))))
        if not X_parts:
            # BUG FIX: previously `np.concatenate(X), np.concatenate(y) if y else None`
            # crashed on an empty cache because the conditional bound only to the
            # second tuple element.
            return None, None
        return np.concatenate(X_parts), np.concatenate(y_parts)

    def get_latest_version(self, mode: str) -> int:
        """Return the highest cached version number for *mode* (0 when none exist)."""
        versions = [
            int(file.stem.split("_v")[-1])
            for file in FEATURES_CACHE.glob(f"X_{mode}_*.npy")
            if "_v" in file.stem and file.stem.split("_v")[-1].isdigit()
        ]
        return max(versions) if versions else 0
modules/pipelines.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mlflow
2
+ from datetime import datetime
3
+ from models.lightgbm import LightGBMModel
4
+ from modules.preprocessing import AudioPreprocessor
5
+ from models.base_model import BaseModel
6
+ from typing import Tuple
7
+ import numpy as np
8
+ from typing import Dict, Optional
9
+ from modules.evaluate import PerformanceAnalyzer
10
+
11
+ # === Unified Model Pipeline ===
12
# === Unified Model Pipeline ===
class ModelPipeline:
    """Glues a concrete model class, preprocessing, and MLflow tracking behind one interface."""

    def __init__(self, model: type = LightGBMModel) -> None:
        # *model* is a BaseModel subclass (the class itself, not an instance);
        # it is instantiated here.
        self.model = model()
        self.model_name = self.model.__class__.__name__
        self.best_params = {}
        self.metrics = {}
        # Timestamp used to make default MLflow run names unique.
        self.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.preprocessor = AudioPreprocessor()

    def load_model(self, run_id: Optional[str] = None, experiment_id: Optional[str] = None, experiment_name: Optional[str] = None, best_metric: Optional[str] = None, maximize: bool = True, additional_tags: Optional[Dict[str, str]] = None) -> None:
        """Load the underlying model from an MLflow run (delegates to BaseModel)."""
        self.model.load_model_from_run(run_id, experiment_id, experiment_name, best_metric, maximize, additional_tags)

    def train(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_val: np.ndarray,
        y_val: np.ndarray,
        X_test: Optional[np.ndarray] = None,
        y_test: Optional[np.ndarray] = None,
        use_optuna: bool = False,
        n_trials: int = 20,
        class_weight_type: str = "",
        save_run: bool = True,
        experiment_name: Optional[str] = None,
        run_name: Optional[str] = None,
        mlflow_tags: Optional[Dict[str, str]] = None,
    ) -> str | dict:
        """Train the model inside an MLflow run and evaluate it.

        Falls back to the validation split when no test split is given.
        Returns the logged metrics dict when ``save_run`` is True, otherwise a
        classification-report string.
        """
        try:
            experiment_id = mlflow.set_experiment(experiment_name).experiment_id if experiment_name else None
        except mlflow.exceptions.RestException:
            # Remote experiment lookup/creation failed; fall back to the default experiment.
            experiment_id = None

        with mlflow.start_run(run_name=run_name or f"{self.model_name}_{self.run_id}", experiment_id=experiment_id):
            self.model.train(X_train, y_train, X_val, y_val, use_optuna=use_optuna, n_trials=n_trials, class_weight_type=class_weight_type)

            ## If X_test and y_test are not provided, use X_val and y_val for testing
            if X_test is None or y_test is None:
                X_test, y_test = X_val, y_val

            y_pred_test = self.model.predict(X_test)
            if save_run:
                metrics = self.model.log_mlflow(y_test, y_pred_test)
                mlflow.set_tags(mlflow_tags or {})
            else:
                metrics = self.model.classification_report(y_test, y_pred_test)

            return metrics

    def load_model_from_registry(self, model_name: str, version: Optional[int] = None) -> None:
        """Load a model from the MLflow model registry (latest version when None)."""
        self.model.load_model_from_registry(model_name, version)

    def register_model(
        self,
        run_id: str,
        model_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ) -> None:
        """Register the model logged under *run_id* in the MLflow registry."""
        self.model.register_model(run_id, model_name, tags)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict class labels for *X*."""
        return self.model.predict(X)

    def score(self, X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
        """Delegate scoring to the underlying model."""
        return self.model.score(X, y)

    def classification_report(self, X: np.ndarray, y: np.ndarray) -> str:
        """Return a printable classification report for predictions on *X* vs *y*."""
        evaluator = PerformanceAnalyzer()
        y_pred = self.model.predict(X)
        _, report_str = evaluator.evaluate(y, y_pred)
        return report_str
modules/preprocessing.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import librosa
3
+ from config import PREPROCESSED_CACHE
4
+ import noisereduce as nr
5
+ from sklearn.model_selection import train_test_split
6
+ from typing import Optional
7
+ from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift
8
+ from imblearn.combine import SMOTETomek
9
+ import random
10
+ from collections import Counter
11
+
12
+
13
+ # === Preprocessing ===
14
# === Preprocessing ===
class AudioPreprocessor:
    """Loads, cleans, augments, caches and splits audio data for training."""

    def __init__(self):
        # Augmentations are applied with per-class probability below; minority
        # classes are augmented far more aggressively than class 0.
        self.augment_pipeline = Compose([
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=1.0),
            TimeStretch(min_rate=0.9, max_rate=1.1, p=1.0),
            PitchShift(min_semitones=-2, max_semitones=2, p=1.0),
            Shift(min_shift=-0.2, max_shift=0.2, p=1.0),
        ])
        self.augment_prob_by_class = {  # set your probabilities here
            0: 0.01,
            1: 0.8,
            2: 0.9,
            3: 0.95
        }

    def load_audio(self, path: str, sr: int = 16000) -> Optional[np.ndarray]:
        """Load an audio file resampled to *sr*; returns None (and logs) on failure."""
        try:
            y, _ = librosa.load(path, sr=sr)
            return y
        except Exception as e:
            print(f"[ERROR] {path}: {e}")
            return None

    def preprocess(self, y: Optional[np.ndarray], sr: int = 16000, padding: bool = False, label: Optional[int] = None) -> Optional[np.ndarray]:
        """Trim silence, normalize, denoise, conditionally augment and optionally pad to 5 s.

        Returns None when the input is None or contains no non-silent audio.
        """
        if y is None: return None

        # Remove silence
        intervals = librosa.effects.split(y, top_db=20)
        if len(intervals) == 0:
            # BUG FIX: np.concatenate raises on an empty list; treat an
            # all-silent clip as unusable instead of crashing.
            return None
        y_trimmed = np.concatenate([y[start:end] for start, end in intervals])

        # Normalize volume: Volume variations, Different microphone quality
        y_norm = librosa.util.normalize(y_trimmed)

        # Noise reduction
        y_denoised = nr.reduce_noise(y=y_norm, sr=sr, n_jobs=-1)

        # Conditional augmentation: probability depends on the class label
        # (unlabelled audio is augmented with probability 0.5).
        if label is not None and random.random() < self.augment_prob_by_class.get(label, 0.5):
            y_augmented = self.augment_pipeline(samples=y_denoised, sample_rate=sr)
        else:
            y_augmented = y_denoised

        # Pad/truncate to a fixed 5-second window when requested
        if padding:
            desired_len = sr * 5
            if len(y_augmented) > desired_len:
                y_augmented = y_augmented[:desired_len]
            else:
                y_augmented = np.pad(y_augmented, (0, max(0, desired_len - len(y_augmented))))

        return y_augmented

    def cache_preprocessed(self, idx: str, y: np.ndarray, force_update: bool = False) -> None:
        """Save a preprocessed signal to the cache unless it already exists."""
        path = PREPROCESSED_CACHE / f"{idx}.npy"
        if force_update or not path.exists():
            np.save(path, y)

    def load_cached_preprocessed(self, idx: str) -> Optional[np.ndarray]:
        """Load a cached preprocessed signal; None when absent or unreadable."""
        # BUG FIX: path was built inside the try block but referenced in the
        # except handler; bind it first so the error message can't NameError.
        path = PREPROCESSED_CACHE / f"{idx}.npy"
        try:
            return np.load(path) if path.exists() else None
        except Exception as e:
            print(f"[ERROR] {path}: {e}")
            return None

    def split_data(self, X, y, train_size: float = 0.75, val_size: float = 0.1, random_state: int = 42, stratify: bool = True,
                   apply_smote: bool = False, smote_percentage: float = 0.7, verbose = True) -> tuple:
        """Split into train/val/test and optionally rebalance train with SMOTETomek.

        Returns (X_train, y_train, X_val, y_val, X_test, y_test); test size is
        the remainder 1 - train_size - val_size.
        """
        # First split: train vs (val + test)
        stratify_option = y if stratify else None
        X_train, X_temp, y_train, y_temp = train_test_split(
            X, y, train_size=train_size, random_state=random_state, stratify=stratify_option
        )

        # Second split: validation vs test
        stratify_temp = y_temp if stratify else None
        X_val, X_test, y_val, y_test = train_test_split(
            X_temp, y_temp, train_size=val_size / (1 - train_size), random_state=random_state, stratify=stratify_temp
        )

        if apply_smote:
            if verbose: print(f"[INFO] Class distribution before SMOTE: {Counter(y_train)}")

            # Oversample every class up to smote_percentage of the majority,
            # keeping class 0 at the full majority count.
            class_counts = Counter(y_train)
            majority_class_count = max(class_counts.values())
            sampling_strategy = {
                cls: int(majority_class_count * smote_percentage) for cls in class_counts.keys()
            }
            sampling_strategy[0] = majority_class_count

            resampler = SMOTETomek(
                random_state=random_state,
                n_jobs=-1,
                sampling_strategy=sampling_strategy  # Specify sampling strategy as a dictionary
            )
            X_train, y_train = resampler.fit_resample(X_train, y_train)

            if verbose: print(f"[INFO] Class distribution after SMOTE: {Counter(y_train)}")

        return X_train, y_train, X_val, y_val, X_test, y_test
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
 
requirements_docker.txt ADDED
Binary file (6.11 kB). View file
 
src/app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Streamlit inference UI: upload an audio clip, pick a registered model, and
# display its predicted class. Statement order matters — Streamlit renders
# widgets in execution order.
import streamlit as st
import os
import tempfile
import time
from pathlib import Path
from modules.preprocessing import AudioPreprocessor
from modules.feature_extraction import FeatureExtractor
from models.lightgbm import LightGBMModel
from models.xgboost import XGBoostModel
from modules.pipelines import ModelPipeline
import warnings
warnings.filterwarnings("ignore")

# Constants
# Maps the display name shown in the selectbox to the model class; the same
# name is also used to look the model up in the MLflow registry below.
MODEL_NAME = {
    "XGBoost": XGBoostModel,
    "LightGBM": LightGBMModel,
}

# UI Layout
st.set_page_config(page_title="Audio Classification App", layout="centered")
st.title("🎧 Audio Classification")
st.markdown("Upload an `.mp3` or `.wav` file and select a model to get a prediction.")

# File Upload
uploaded_file = st.file_uploader("Upload your audio file", type=["wav", "mp3"])

# Model Selection
selected_model_name = st.selectbox("Select a model", list(MODEL_NAME.keys()))

# Process if file is uploaded
if uploaded_file is not None:
    # Save uploaded file temporarily
    # The temp dir (and the copied audio) is removed when the block exits.
    with tempfile.TemporaryDirectory() as tmpdir:
        audio_path = os.path.join(tmpdir, "input_audio.wav")
        with open(audio_path, "wb") as f:
            f.write(uploaded_file.read())

        # Preprocess, extract features, predict
        st.info("🔍 Processing audio...")
        try:
            # Initialize pipeline
            preprocessor = AudioPreprocessor()
            extractor = FeatureExtractor()
            model = ModelPipeline(model=MODEL_NAME[selected_model_name])
            # Loads the latest registered version under this name.
            model.load_model_from_registry(model_name=selected_model_name)

            # Preprocess and predict
            start_time = time.time()
            y = preprocessor.preprocess(preprocessor.load_audio(audio_path, sr=16000))
            if y is None:
                st.error("Audio preprocessing failed.")
            else:
                # NOTE(review): n_mfcc=20 here must match whatever the loaded
                # model was trained with — confirm against the training config.
                x = extractor.extract(y, sr=16000, mode="traditional", n_mfcc=20)
                # predict expects a batch, hence the single-element list.
                pred = model.predict([x])[0]
                elapsed = time.time() - start_time

                # Display result
                st.success(f"✅ Predicted Class: `{pred}`")
                st.write(f"Inference time: `{elapsed:.4f}` seconds")
        except Exception as e:
            st.error(f"❌ An error occurred: {str(e)}")
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import joblib
4
+ import numpy as np
5
+ import mlflow
6
+ import mlflow.sklearn
7
+ from mlflow.tracking import MlflowClient
8
+
9
+ from config import AUDIO_CACHE, FEATURES_CACHE, MODELS_DIR
10
+
11
+
12
+ #* Audio Caching
13
#* Audio Caching
def cache_audio(data: np.ndarray, filename: str = "default", force_update=False):
    """Save a raw-audio array to AUDIO_CACHE/<filename>.npy unless it already exists."""
    # BUG FIX: the path was previously built from a garbled literal and ignored
    # the `filename` argument entirely, so every call hit the same file.
    path = AUDIO_CACHE / f"{filename}.npy"
    if force_update or not path.exists():
        np.save(path, data)

def load_cached_audio(filename: str = "default"):
    """Load AUDIO_CACHE/<filename>.npy; returns None when the file is absent."""
    path = AUDIO_CACHE / f"{filename}.npy"
    return np.load(path) if path.exists() else None
21
+
22
+
23
+ #* Feature Caching
24
#* Feature Caching
def cache_features(X, y, feature_name: str = "features", label_name: str = "labels", force_update=False):
    """Persist feature/label arrays to FEATURES_CACHE unless both files already exist."""
    paths = (FEATURES_CACHE / f"{feature_name}.npy", FEATURES_CACHE / f"{label_name}.npy")
    if force_update or not all(p.exists() for p in paths):
        for target, arr in zip(paths, (X, y)):
            np.save(target, arr)

def load_cached_features(feature_name: str = "features", label_name: str = "labels"):
    """Load cached (X, y) arrays; returns (None, None) when either file is missing."""
    X_path = FEATURES_CACHE / f"{feature_name}.npy"
    y_path = FEATURES_CACHE / f"{label_name}.npy"
    if not (X_path.exists() and y_path.exists()):
        return None, None
    return np.load(X_path), np.load(y_path)
37
+
38
+
39
+ #* Model Caching
40
#* Model Caching
def cache_model(model, best_params: dict, model_name: str = None, save_option='default', force_update=False):
    """Persist a fitted model and its best hyperparameters under MODELS_DIR.

    save_option == "joblib" pickles the whole object to model.pkl; any other
    value assumes the model exposes a native ``save_model`` method (e.g.
    CatBoost, hence the .cbm extension).
    """
    model_class = model.__class__.__name__
    # Folder defaults to the model's class name when no explicit name is given.
    model_folder = MODELS_DIR / (model_name or model_class)
    model_folder.mkdir(exist_ok=True)

    model_path = model_folder / ("model.pkl" if save_option == "joblib" else "model.cbm")
    params_path = model_folder / "best_params.json"

    # Save model
    if force_update or not model_path.exists():
        if save_option == "joblib":
            joblib.dump(model, model_path)
        else:
            model.save_model(model_path)

    # Save best params
    if force_update or not params_path.exists():
        with open(params_path, "w") as f:
            json.dump(best_params, f, indent=2)

def load_model(model_class, model_name: str = None, save_option='default'):
    """Load a model + best params saved by cache_model; (None, None) when missing.

    For non-joblib saves, *model_class* is instantiated and must expose a
    native ``load_model`` method, mirroring the save path above.
    """
    model_class_name = model_class.__name__
    model_folder = MODELS_DIR / (model_name or model_class_name)

    # Paths must match cache_model's naming scheme exactly.
    model_path = model_folder / ("model.pkl" if save_option == "joblib" else "model.cbm")
    params_path = model_folder / "best_params.json"

    if not model_path.exists() or not params_path.exists():
        return None, None

    with open(params_path, "r") as f:
        best_params = json.load(f)

    if save_option == "joblib":
        model = joblib.load(model_path)
    else:
        model = model_class()
        model.load_model(model_path)

    return model, best_params
80
+
81
+
82
+ # === Utility: MLflow Helpers ===
83
# === Utility: MLflow Helpers ===
def list_top_mlflow_runs(metric="f1-score", top_n=5):
    """Return the top-n runs of experiment "0" ordered by weighted-avg *metric* (descending)."""
    # (removed an unused MlflowClient instance — search_runs is module-level)
    runs = mlflow.search_runs(experiment_ids=["0"], order_by=[f"metrics.weighted avg.{metric} DESC"])
    return runs[["run_id", "params.model_type", f"metrics.weighted avg.{metric}"]].head(top_n)

def load_mlflow_model(run_id):
    """Load the sklearn model logged under *run_id* plus its params and metrics."""
    client = MlflowClient()
    run = client.get_run(run_id)
    model = mlflow.sklearn.load_model(f"runs:/{run_id}/model")
    params = run.data.params
    metrics = run.data.metrics
    return model, params, metrics
95
+