"""Train a watermelon sweetness regressor from paired audio and image samples.

The dataset directory is expected to look like::

    <data_dir>/<sweetness>/<sample_id>/<sample_id>.wav
    <data_dir>/<sweetness>/<sample_id>/<sample_id>.jpg

where ``<sweetness>`` is a directory name parseable as a float and is used as
the regression label for every sample below it.
"""

import os
import sys
import time

import numpy as np
import torch
import torchaudio
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data

# Print library versions
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")

# Device selection: prefer CUDA, then Apple MPS, then fall back to CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")

# Hyperparameters
batch_size = 16
epochs = 2
learning_rate = 0.0001

# Model save directory
os.makedirs("models/", exist_ok=True)


class WatermelonDataset(Dataset):
    """Dataset of (audio features, image, sweetness) triples read from disk.

    ``samples`` holds ``(audio_path, image_path, sweetness)`` tuples; the
    files themselves are loaded lazily in ``__getitem__``.
    """

    def __init__(self, data_dir):
        """Index every sample under *data_dir*.

        Args:
            data_dir: Root directory containing one sub-directory per
                sweetness value, each containing one sub-directory per sample.

        Entries that are not directories, or whose name is not a valid float,
        are skipped. Samples missing either the ``.wav`` or the ``.jpg`` file
        are skipped as well.
        """
        self.data_dir = data_dir
        self.samples = []

        # Sorted so that sample indices are stable across runs/filesystems
        # (os.listdir order is arbitrary).
        for sweetness_dir in sorted(os.listdir(data_dir)):
            sweetness_path = os.path.join(data_dir, sweetness_dir)
            # Check for a directory *before* parsing the name so stray files
            # (e.g. .DS_Store, README) do not abort dataset construction.
            if not os.path.isdir(sweetness_path):
                continue
            try:
                sweetness = float(sweetness_dir)
            except ValueError:
                print(
                    f"\033[93mWARN\033[0m: Skipping directory with non-numeric "
                    f"sweetness label: {sweetness_path}"
                )
                continue

            for id_dir in sorted(os.listdir(sweetness_path)):
                id_path = os.path.join(sweetness_path, id_dir)
                if not os.path.isdir(id_path):
                    continue
                audio_file = os.path.join(id_path, f"{id_dir}.wav")
                image_file = os.path.join(id_path, f"{id_dir}.jpg")
                if os.path.exists(audio_file) and os.path.exists(image_file):
                    self.samples.append((audio_file, image_file, sweetness))

        print(f"\033[92mINFO\033[0m: Loaded {len(self.samples)} samples from {data_dir}")

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load and preprocess the sample at *idx*.

        Returns:
            A ``(mfcc, processed_image, label)`` tuple where ``mfcc`` and
            ``processed_image`` are whatever ``process_audio_data`` and
            ``process_image_data`` return, and ``label`` is a float32 scalar
            tensor holding the sweetness value.

        Raises:
            Exception: Re-raised only when sample 0 itself fails to load; any
                other failing index falls back to returning sample 0.
        """
        audio_path, image_path, label = self.samples[idx]

        try:
            # Load and process audio
            waveform, sample_rate = torchaudio.load(audio_path)
            mfcc = process_audio_data(waveform, sample_rate)

            # Load and process image
            image = torchvision.io.read_image(image_path)
            image = image.float()
            processed_image = process_image_data(image)

            return mfcc, processed_image, torch.tensor(label, dtype=torch.float32)
        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error processing sample {idx}: {e}")
            # Best-effort fallback: substitute the first sample so a single
            # corrupt file does not abort a whole epoch.
            if idx == 0:
                # Prevent infinite recursion
                raise
            return self.__getitem__(0)


class WatermelonModel(torch.nn.Module):
    """Two-branch regressor: LSTM over audio features + ResNet50 over the image.

    Each branch is projected to 128 dimensions; the two projections are
    concatenated (256) and passed through a small MLP producing one value.
    """

    def __init__(self):
        """Build both branches and the fusion head."""
        super().__init__()

        # LSTM for audio features. With batch_first=True the input must be
        # (batch, time, 376).
        self.lstm = torch.nn.LSTM(
            input_size=376, hidden_size=64, num_layers=2, batch_first=True
        )
        # Convert LSTM output to 128-dim for merging
        self.lstm_fc = torch.nn.Linear(64, 128)

        # ResNet50 for image features, initialised from torchvision's DEFAULT
        # pretrained weights.
        self.resnet = torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.DEFAULT
        )
        # Replace the classification head: convert ResNet output to 128-dim
        # for merging.
        self.resnet.fc = torch.nn.Linear(self.resnet.fc.in_features, 128)

        # Fully connected layers for final prediction
        self.fc1 = torch.nn.Linear(256, 64)
        self.fc2 = torch.nn.Linear(64, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, mfcc, image):
        """Predict sweetness for a batch.

        Args:
            mfcc: Audio feature tensor fed to the LSTM, (batch, time, 376).
            image: Image batch fed to the ResNet50 backbone.

        Returns:
            Tensor of shape (batch, 1) with the predicted sweetness.
        """
        # LSTM branch
        lstm_output, _ = self.lstm(mfcc)
        lstm_output = lstm_output[:, -1, :]  # Use the output of the last time step
        lstm_output = self.lstm_fc(lstm_output)

        # ResNet branch
        resnet_output = self.resnet(image)

        # Concatenate LSTM and ResNet outputs
        merged = torch.cat((lstm_output, resnet_output), dim=1)

        # Fully connected layers
        output = self.relu(self.fc1(merged))
        output = self.fc2(output)

        return output


def train_model(data_dir, output_dir="models/"):
    """Train ``WatermelonModel`` on the dataset under *data_dir*.

    The dataset is split 70/20/10 into train/validation/test subsets (the test
    subset is only reported, not evaluated here). A checkpoint is written
    after every epoch and a final ``watermelon_model_final.pt`` at the end.
    Training and validation losses are logged to TensorBoard under ``runs/``.

    Args:
        data_dir: Root directory of the dataset (see ``WatermelonDataset``).
        output_dir: Directory receiving checkpoints and the final model; it is
            created if it does not exist.

    Returns:
        Path of the saved final model state dict.

    Raises:
        ValueError: If the dataset is too small to yield any training sample.
    """
    # Make the function usable on its own, not only via the CLI entry point.
    os.makedirs(output_dir, exist_ok=True)

    # Create dataset
    dataset = WatermelonDataset(data_dir)
    n_samples = len(dataset)

    # Split dataset
    train_size = int(0.7 * n_samples)
    val_size = int(0.2 * n_samples)
    test_size = n_samples - train_size - val_size

    if train_size == 0:
        # Fail early with a clear message instead of an obscure sampler error
        # from the shuffling DataLoader.
        raise ValueError(
            f"Not enough samples ({n_samples}) in {data_dir} to build a training split"
        )

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = WatermelonModel().to(device)

    # Loss function and optimizer
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    print(f"\033[92mINFO\033[0m: Training model for {epochs} epochs")
    print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
    print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
    print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
    print(f"\033[92mINFO\033[0m: Batch size: {batch_size}")

    # TensorBoard
    writer = SummaryWriter("runs/")
    global_step = 0

    try:
        # Training loop
        for epoch in range(epochs):
            print(f"\033[92mINFO\033[0m: Training epoch ({epoch+1}/{epochs})")

            model.train()
            running_loss = 0.0
            train_batches = 0  # batches that actually contributed to running_loss

            for i, (mfcc, image, label) in enumerate(train_loader):
                try:
                    mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)

                    optimizer.zero_grad()
                    output = model(mfcc, image)
                    # Match the (batch, 1) shape of the model output.
                    label = label.view(-1, 1).float()
                    loss = criterion(output, label)
                    loss.backward()
                    optimizer.step()

                    running_loss += loss.item()
                    train_batches += 1
                    writer.add_scalar("Training Loss", loss.item(), global_step)
                    global_step += 1

                    if i % 10 == 0:
                        print(f"\033[92mINFO\033[0m: Batch {i}/{len(train_loader)}, Loss: {loss.item():.4f}")
                except Exception as e:
                    # Deliberate best-effort: log and skip a bad batch rather
                    # than aborting the whole run.
                    print(f"\033[91mERR!\033[0m: Error in training batch {i}: {e}")
                    continue

            # Validation phase
            model.eval()
            val_loss = 0.0
            val_batches = 0
            with torch.no_grad():
                for i, (mfcc, image, label) in enumerate(val_loader):
                    try:
                        mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
                        output = model(mfcc, image)
                        label = label.view(-1, 1).float()
                        loss = criterion(output, label)
                        val_loss += loss.item()
                        val_batches += 1
                    except Exception as e:
                        print(f"\033[91mERR!\033[0m: Error in validation batch {i}: {e}")
                        continue

            # Average over batches that succeeded so skipped batches do not
            # artificially lower the reported loss.
            avg_train_loss = running_loss / train_batches if train_batches else float("inf")
            avg_val_loss = val_loss / val_batches if val_batches else float("inf")

            # Record validation loss
            writer.add_scalar("Validation Loss", avg_val_loss, epoch)

            print(
                f"Epoch [{epoch+1}/{epochs}], Training Loss: {avg_train_loss:.4f}, "
                f"Validation Loss: {avg_val_loss:.4f}"
            )

            # Save model checkpoint
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            model_path = os.path.join(output_dir, f"model_{epoch+1}_{timestamp}.pt")
            torch.save(model.state_dict(), model_path)
            print(
                f"\033[92mINFO\033[0m: Model checkpoint epoch [{epoch+1}/{epochs}] saved: {model_path}"
            )
    finally:
        # Flush pending events and release the event file even on failure.
        writer.close()

    # Save final model
    final_model_path = os.path.join(output_dir, "watermelon_model_final.pt")
    torch.save(model.state_dict(), final_model_path)
    print(f"\033[92mINFO\033[0m: Final model saved: {final_model_path}")

    print("\033[92mINFO\033[0m: Training complete")
    return final_model_path


def main():
    """Parse command-line arguments and run training."""
    import argparse

    parser = argparse.ArgumentParser(description="Train the Watermelon Sweetness Prediction Model")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="../cleaned",
        help="Path to the cleaned dataset directory",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="models/",
        help="Directory to save model checkpoints and the final model",
    )

    args = parser.parse_args()

    # Ensure output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    # Train the model
    final_model_path = train_model(args.data_dir, args.output_dir)

    print("\033[92mINFO\033[0m: Training completed successfully!")
    print(f"\033[92mINFO\033[0m: Final model saved at: {final_model_path}")


if __name__ == "__main__":
    main()