"""Batch inference script: run a registered model over a folder of audio files.

For every file in the input folder the script writes one prediction line to
``results.txt`` and one elapsed-time line to ``time.txt`` inside the output
folder.
"""

import argparse
import os
import time
import warnings
from glob import glob
from pathlib import Path

import pandas as pd

from modules.preprocessing import AudioPreprocessor
from modules.feature_extraction import FeatureExtractor
from models.lightgbm import LightGBMModel
from models.xgboost import XGBoostModel
from modules.pipelines import ModelPipeline

warnings.filterwarnings("ignore")

# Maps the model name accepted on the command line to its model class.
MODEL_NAME = {
    "XGBoost": XGBoostModel,
    "LightGBM": LightGBMModel,
}


def _numeric_stem(path):
    """Return the file stem of *path* as an int, for numeric ordering.

    Raises:
        ValueError: If the stem is not an integer; the message names the
            offending file instead of the bare ``int()`` error.
    """
    stem = Path(path).stem
    try:
        return int(stem)
    except ValueError as err:
        raise ValueError(
            f"Cannot order input file {path!r}: expected a numeric file name, got {stem!r}."
        ) from err


def run_batch_inference(model, input_folder, output_folder, sr=16000, feature_mode="traditional"):
    """Predict every file in *input_folder* and record predictions and timings.

    Files are processed in ascending numeric order of their file stem. For
    each file, one line is written to ``results.txt`` (the prediction) and one
    line to ``time.txt`` (elapsed seconds, six decimals) in *output_folder*.
    Both files are overwritten if they already exist.

    Args:
        model: Object exposing ``predict(list_of_feature_vectors)``; the first
            element of its return value is written as the prediction.
        input_folder: Folder whose entries (``*``) are treated as input files.
        output_folder: Existing folder that receives ``results.txt`` and
            ``time.txt``.
        sr: Sample rate passed to the audio loader and feature extractor.
        feature_mode: ``mode`` argument forwarded to the feature extractor.

    Raises:
        ValueError: If an input file name does not have a numeric stem.
    """
    preprocessor = AudioPreprocessor()
    extractor = FeatureExtractor()

    # Sort files in the correct (numeric, not lexicographic) order
    files = sorted(glob(os.path.join(input_folder, "*")), key=_numeric_stem)

    results_path = os.path.join(output_folder, "results.txt")
    time_path = os.path.join(output_folder, "time.txt")

    # Prediction written when preprocessing yields no usable signal.
    default_pred = 0

    # Overwrite if exists. Opening once in "w" mode truncates both files and
    # avoids re-opening them for every input file.
    with open(results_path, "w") as results_file, open(time_path, "w") as time_file:
        for file in files:
            # Reset per file so a failed file never inherits the previous
            # file's prediction.
            pred = default_pred

            # Measure inference time (load + preprocess + features + predict)
            # with a monotonic clock; output writes are excluded.
            start_time = time.perf_counter()
            y = preprocessor.preprocess(preprocessor.load_audio(str(file), sr=sr))
            if y is not None:
                x = extractor.extract(y, sr=sr, mode=feature_mode, n_mfcc=20)
                pred = model.predict([x])[0]
            elapsed = time.perf_counter() - start_time

            # Save results to results.txt; flush so partial output survives
            # a failure on a later file.
            results_file.write(f"{pred}\n")
            results_file.flush()

            # Save inference time to time.txt
            time_file.write(f"{elapsed:.6f}\n")
            time_file.flush()

    print(f"✅ Results saved to {results_path}")
    print(f"✅ Inference time saved to {time_path}")


def main(input_path, model_name, output_folder):
    """Validate arguments, load the model from the registry and run inference.

    Args:
        input_path: Folder containing the input audio files.
        model_name: Key of ``MODEL_NAME`` selecting the model class.
        output_folder: Folder for the result files; created if missing.

    Raises:
        FileNotFoundError: If *input_path* does not exist.
        ValueError: If *model_name* is not a key of ``MODEL_NAME``.
    """
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input path {input_path} does not exist.")
    if model_name not in MODEL_NAME:
        raise ValueError(
            f"Model name {model_name} is not valid. \nChoose from {list(MODEL_NAME.keys())}."
        )
    if not os.path.exists(output_folder):
        os.makedirs(output_folder, exist_ok=True)
        print(f"Output folder {output_folder} created.")

    model = ModelPipeline(model=MODEL_NAME[model_name])
    model.load_model_from_registry(model_name=model_name)
    print("✅ Model loaded successfully")

    run_batch_inference(model, input_path, output_folder)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, default="/data",
                        help="Path to the input folder containing test audio files. Default is '/data'.")
    parser.add_argument('--model-name', type=str, default="XGBoost",
                        help="Name of the model to use for inference. Default is 'XGBoost'.")
    # '--team-id' is accepted as an alias so the flag matches the hyphenated
    # style of the other options; the original '--team_id' keeps working.
    parser.add_argument('--team_id', '--team-id', dest='team_id', type=str, required=True,
                        help="Team ID for output folder.")

    args = parser.parse_args()

    output_folder = os.path.join("/results", args.team_id)

    print(f"Input Path: {args.input_path}")
    print(f"Model Name: {args.model_name}")
    print(f"Output Folder: {output_folder}")

    main(args.input_path, args.model_name, output_folder)