"""Preprocess a directory of paired audio/image samples into ``.pt`` files.

Each sub-directory of the input directory is one sample. Its directory name
is parsed with ``float()`` and stored as the label. The first audio file and
the first image file found under it (in sorted path order) are run through
``process_audio_data`` / ``process_image_data`` and saved as a
``(mfcc, image, label)`` tuple to ``<save_dir>/<sample_name>.pt``.
"""

import glob
import os
from concurrent.futures import ThreadPoolExecutor

import torch
import torchaudio
import torchvision
from torch.utils.data import Dataset

from preprocess import process_audio_data, process_image_data, resample_rate

# File extensions (matched case-insensitively) recognised inside a sample directory.
AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


class PreprocessedDataset(Dataset):
    """Dataset over the ``.pt`` files written by ``process_and_save``.

    Args:
        data_dir: Directory containing the ``.pt`` sample files.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # Sorted so that sample indices are stable across runs and machines;
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        self.samples = sorted(
            os.path.join(data_dir, f)
            for f in os.listdir(data_dir)
            if f.endswith(".pt")
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(mfcc, image, label)``."""
        sample_path = self.samples[idx]
        mfcc, image, label = torch.load(sample_path)

        # NOTE(review): process_sample already ran process_audio_data /
        # process_image_data before saving, so this applies the processing a
        # second time (and treats the saved mfcc as audio at resample_rate).
        # Confirm this double processing is intended.
        mfcc = process_audio_data(mfcc, resample_rate)
        image = process_image_data(image)
        return mfcc, image, label


def load_audio_file(audio_path):
    """Load an audio file and return ``(waveform, sample_rate)``.

    torchaudio is tried first. If it fails, librosa is used as a fallback.
    librosa is imported lazily, so it is only needed when the fallback runs.
    The fallback returns a float tensor of shape ``[1, length]``.

    Raises:
        FileNotFoundError: if ``audio_path`` does not exist.
        RuntimeError: if neither torchaudio nor librosa can load the file.
    """
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    try:
        waveform, sample_rate = torchaudio.load(audio_path)
        return waveform, sample_rate
    except Exception as e:
        print(f"Warning: Could not load {audio_path} with torchaudio: {e}")

    # Fall back to librosa (you'll need to install it: pip install librosa).
    try:
        import librosa
        import numpy as np

        waveform_np, sample_rate = librosa.load(audio_path, sr=None)
        # Add a leading channel axis to match torchaudio's [channels, length] layout.
        waveform = torch.from_numpy(waveform_np[np.newaxis, :]).float()
    except Exception as final_e:
        raise RuntimeError(
            f"Failed to load audio file {audio_path} with all available methods: {final_e}"
        ) from final_e

    print(f"Successfully loaded with librosa: {audio_path}")
    return waveform, sample_rate


def load_image_file(image_path):
    """Read an image file into a tensor via ``torchvision.io.read_image``.

    Raises:
        FileNotFoundError: if ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")
    return torchvision.io.read_image(image_path)


def _find_media_files(sample_path):
    """Return ``(audio_files, image_files)`` found recursively under ``sample_path``, each sorted."""
    audio_files = []
    image_files = []
    for root, _, files in os.walk(sample_path):
        for file in files:
            lower = file.lower()
            if lower.endswith(AUDIO_EXTENSIONS):
                audio_files.append(os.path.join(root, file))
            elif lower.endswith(IMAGE_EXTENSIONS):
                image_files.append(os.path.join(root, file))
    # os.walk order is filesystem-dependent; sort so the "first" file chosen
    # by the caller is the same on every run.
    audio_files.sort()
    image_files.sort()
    return audio_files, image_files


def process_sample(sample_path, save_dir):
    """Process one sample directory and save it to ``<save_dir>/<dir name>.pt``.

    The first audio file and the first image file found anywhere under
    ``sample_path`` (in sorted order) are used. The label is the directory's
    base name parsed as a float.

    Returns:
        True if the sample was saved. False if it was skipped because it
        contains no audio file or no image file.

    Raises:
        ValueError: if the directory name cannot be parsed as a float.
        FileNotFoundError / RuntimeError: propagated from the loaders.
    """
    audio_files, image_files = _find_media_files(sample_path)

    if not audio_files:
        print(f"Warning: No audio file found in {sample_path}. Skipping this sample.")
        return False
    if not image_files:
        print(f"Warning: No image file found in {sample_path}. Skipping this sample.")
        return False

    sample_name = os.path.basename(sample_path)
    # Parse the label before the expensive load/process steps so a bad
    # directory name fails fast.
    label = float(sample_name)

    audio_path = audio_files[0]
    image_path = image_files[0]
    print(f"Processing audio: {audio_path}")
    print(f"Processing image: {image_path}")

    waveform, sample_rate = load_audio_file(audio_path)
    image = load_image_file(image_path)

    mfcc = process_audio_data(waveform, sample_rate)
    processed_image = process_image_data(image)

    save_path = os.path.join(save_dir, f"{sample_name}.pt")
    torch.save((mfcc, processed_image, label), save_path)
    print(f"Processed and saved: {save_path}")
    return True


def process_and_save(data_dir, save_dir, max_workers=None):
    """Process every sample sub-directory of ``data_dir`` into ``save_dir``.

    Samples are processed concurrently on a thread pool. A summary of saved,
    skipped, and failed samples is printed at the end.

    Args:
        data_dir: Directory whose immediate sub-directories are samples.
        save_dir: Output directory for ``.pt`` files; created if missing.
        max_workers: Thread-pool size. ``None`` uses the ThreadPoolExecutor default.
    """
    os.makedirs(save_dir, exist_ok=True)
    sample_paths = sorted(
        os.path.join(data_dir, d)
        for d in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, d))
    )

    if not sample_paths:
        print(f"Warning: No sample directories found in {data_dir}")
        return

    print(f"Found {len(sample_paths)} sample directories to process")

    successful = 0
    skipped = 0
    failed = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        jobs = [(path, executor.submit(process_sample, path, save_dir)) for path in sample_paths]
        for path, future in jobs:
            try:
                saved = future.result()  # Wait for this sample to finish.
            except Exception as e:
                failed += 1
                print(f"Error processing sample {path}: {e}")
                continue
            # process_sample returns False for samples it deliberately skipped;
            # those must not be reported as successes.
            if saved:
                successful += 1
            else:
                skipped += 1

    print(
        f"Processing complete. Successfully processed: {successful}, "
        f"Skipped: {skipped}, Failed: {failed}"
    )


def main():
    """Command-line entry point: parse arguments and run ``process_and_save``."""
    import argparse

    parser = argparse.ArgumentParser(description="Preprocess the dataset")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="cleaned",
        help="Path to the cleaned dataset directory",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="processed",
        help="Path to the processed dataset directory",
    )
    args = parser.parse_args()

    print(f"Processing dataset from: {args.data_dir}")
    print(f"Saving processed data to: {args.save_dir}")
    process_and_save(args.data_dir, args.save_dir)
    print("Preprocessing complete")


if __name__ == "__main__":
    main()