# watermelon2/evaluate_backbones.py
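"""Evaluate combinations of image and audio backbones for watermelon sweetness prediction.

Each (image backbone, audio backbone) pair is trained for a single pass over the
training split and then scored on the validation and test splits with MSE and MAE.
Results are appended to a JSON file after every run so interrupted sweeps can resume.

Example invocation (flags as defined by the argparse options at the bottom of the file):

    python evaluate_backbones.py --data_dir ../cleaned --evaluate_all
"""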
import os
import torch
import torchaudio
import torchvision
import numpy as np
import time
import json
from torch.utils.data import Dataset, DataLoader
import sys
from tqdm import tqdm
# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data
# Print library versions
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")
# Device selection
device = torch.device(
"cuda"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")
# Hyperparameters
batch_size = 16
epochs = 1 # Just one epoch for evaluation
learning_rate = 0.0001
class WatermelonDataset(Dataset):
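    """Dataset of paired audio recordings and photos of watermelons.

    Expects the layout data_dir/<sweetness>/<sample_id>/<sample_id>.wav and .jpg,
    where the folder name encodes the sweetness label used as the regression target.
    """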
def __init__(self, data_dir):
self.data_dir = data_dir
self.samples = []
# Walk through the directory structure
        for sweetness_dir in os.listdir(data_dir):
            sweetness_path = os.path.join(data_dir, sweetness_dir)
            if os.path.isdir(sweetness_path):
                # Folder names encode the sweetness label; skip non-numeric entries
                try:
                    sweetness = float(sweetness_dir)
                except ValueError:
                    continue
for id_dir in os.listdir(sweetness_path):
id_path = os.path.join(sweetness_path, id_dir)
if os.path.isdir(id_path):
audio_file = os.path.join(id_path, f"{id_dir}.wav")
image_file = os.path.join(id_path, f"{id_dir}.jpg")
if os.path.exists(audio_file) and os.path.exists(image_file):
self.samples.append((audio_file, image_file, sweetness))
print(f"\033[92mINFO\033[0m: Loaded {len(self.samples)} samples from {data_dir}")
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
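        """Load one sample and return (mfcc, image, label) tensors."""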
audio_path, image_path, label = self.samples[idx]
# Load and process audio
try:
waveform, sample_rate = torchaudio.load(audio_path)
mfcc = process_audio_data(waveform, sample_rate)
# Load and process image
image = torchvision.io.read_image(image_path)
image = image.float()
processed_image = process_image_data(image)
return mfcc, processed_image, torch.tensor(label).float()
except Exception as e:
print(f"\033[91mERR!\033[0m: Error processing sample {idx}: {e}")
# Return a fallback sample or skip this sample
# For simplicity, we'll return the first sample again
if idx == 0: # Prevent infinite recursion
raise e
return self.__getitem__(0)
# Define available backbone models
IMAGE_BACKBONES = {
"resnet50": {
"model": torchvision.models.resnet50,
"weights": torchvision.models.ResNet50_Weights.DEFAULT,
"output_dim": lambda model: model.fc.in_features
},
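    # The memory-adjustment logic and the default backbone lists further below also
    # reference resnet101, convnext_base, and swin_b, so they are registered here as
    # well (assuming these torchvision models/weights exist in the installed version).
    "resnet101": {
        "model": torchvision.models.resnet101,
        "weights": torchvision.models.ResNet101_Weights.DEFAULT,
        "output_dim": lambda model: model.fc.in_features
    },
    "convnext_base": {
        "model": torchvision.models.convnext_base,
        "weights": torchvision.models.ConvNeXt_Base_Weights.DEFAULT,
        "output_dim": lambda model: model.classifier[2].in_features
    },
    "swin_b": {
        "model": torchvision.models.swin_b,
        "weights": torchvision.models.Swin_B_Weights.DEFAULT,
        "output_dim": lambda model: model.head.in_features
    },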
"efficientnet_b0": {
"model": torchvision.models.efficientnet_b0,
"weights": torchvision.models.EfficientNet_B0_Weights.DEFAULT,
"output_dim": lambda model: model.classifier[1].in_features
},
"efficientnet_b3": {
"model": torchvision.models.efficientnet_b3,
"weights": torchvision.models.EfficientNet_B3_Weights.DEFAULT,
"output_dim": lambda model: model.classifier[1].in_features
}
}
AUDIO_BACKBONES = {
"lstm": {
"model": lambda input_size, hidden_size: torch.nn.LSTM(
input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True
),
"output_dim": lambda hidden_size: hidden_size
},
"gru": {
"model": lambda input_size, hidden_size: torch.nn.GRU(
input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True
),
"output_dim": lambda hidden_size: hidden_size
},
"bidirectional_lstm": {
"model": lambda input_size, hidden_size: torch.nn.LSTM(
input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True, bidirectional=True
),
"output_dim": lambda hidden_size: hidden_size * 2 # * 2 because bidirectional
},
"transformer": {
"model": lambda input_size, hidden_size: torch.nn.TransformerEncoder(
torch.nn.TransformerEncoderLayer(
d_model=input_size, nhead=8, dim_feedforward=hidden_size, batch_first=True
),
num_layers=2
),
"output_dim": lambda hidden_size: 376 # Using input_size (mfcc dimensions)
}
}
class WatermelonModelModular(torch.nn.Module):
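    """Two-branch regression model.

    The audio backbone and the image backbone are each projected to a 128-dim
    feature vector; the two vectors are concatenated and passed through a small
    MLP that outputs a single sweetness value.
    """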
def __init__(self, image_backbone_name, audio_backbone_name, audio_hidden_size=128):
super(WatermelonModelModular, self).__init__()
# Audio backbone setup
self.audio_backbone_name = audio_backbone_name
self.audio_hidden_size = audio_hidden_size
self.audio_input_size = 376 # From MFCC dimensions
audio_config = AUDIO_BACKBONES[audio_backbone_name]
self.audio_backbone = audio_config["model"](self.audio_input_size, self.audio_hidden_size)
audio_output_dim = audio_config["output_dim"](self.audio_hidden_size)
self.audio_fc = torch.nn.Linear(audio_output_dim, 128)
# Image backbone setup
self.image_backbone_name = image_backbone_name
image_config = IMAGE_BACKBONES[image_backbone_name]
self.image_backbone = image_config["model"](weights=image_config["weights"])
# Replace final layer for all image backbones to get features
if image_backbone_name.startswith("resnet"):
self.image_output_dim = image_config["output_dim"](self.image_backbone)
self.image_backbone.fc = torch.nn.Identity()
elif image_backbone_name.startswith("efficientnet"):
self.image_output_dim = image_config["output_dim"](self.image_backbone)
self.image_backbone.classifier = torch.nn.Identity()
elif image_backbone_name.startswith("convnext"):
self.image_output_dim = image_config["output_dim"](self.image_backbone)
self.image_backbone.classifier = torch.nn.Identity()
elif image_backbone_name.startswith("swin"):
self.image_output_dim = image_config["output_dim"](self.image_backbone)
self.image_backbone.head = torch.nn.Identity()
self.image_fc = torch.nn.Linear(self.image_output_dim, 128)
# Fully connected layers for final prediction
self.fc1 = torch.nn.Linear(256, 64)
self.fc2 = torch.nn.Linear(64, 1)
self.relu = torch.nn.ReLU()
def forward(self, mfcc, image):
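        """Fuse MFCC and image features and return a (batch, 1) sweetness prediction."""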
# Audio backbone processing
        if self.audio_backbone_name in ("lstm", "gru", "bidirectional_lstm"):
            audio_output, _ = self.audio_backbone(mfcc)
            audio_output = audio_output[:, -1, :]  # Use the output of the last time step
        elif self.audio_backbone_name == "transformer":
            audio_output = self.audio_backbone(mfcc)
            audio_output = audio_output.mean(dim=1)  # Average pooling over sequence length
        else:
            raise ValueError(f"Unsupported audio backbone: {self.audio_backbone_name}")
audio_output = self.audio_fc(audio_output)
# Image backbone processing
image_output = self.image_backbone(image)
image_output = self.image_fc(image_output)
# Concatenate audio and image outputs
merged = torch.cat((audio_output, image_output), dim=1)
# Fully connected layers
output = self.relu(self.fc1(merged))
output = self.fc2(output)
return output
def evaluate_model(data_dir, image_backbone, audio_backbone, audio_hidden_size=128, save_model_dir=None):
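    """Train one backbone combination for a single pass over the training split and
    report MSE/MAE on the validation and test splits.

    Returns a dict with the backbone names and metrics; when save_model_dir is given,
    the trained weights are saved and their path is added under "model_path".
    """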
# Adjust batch size based on model complexity to avoid OOM errors
adjusted_batch_size = batch_size
# Models that typically require more memory get smaller batch sizes
if image_backbone in ["swin_b", "convnext_base"] or audio_backbone in ["transformer", "bidirectional_lstm"]:
adjusted_batch_size = max(4, batch_size // 2) # At least batch size of 4, but reduce by half if needed
print(f"\033[92mINFO\033[0m: Adjusted batch size to {adjusted_batch_size} for larger model")
# Create dataset
dataset = WatermelonDataset(data_dir)
n_samples = len(dataset)
# Split dataset
train_size = int(0.7 * n_samples)
val_size = int(0.2 * n_samples)
test_size = n_samples - train_size - val_size
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
dataset, [train_size, val_size, test_size]
)
train_loader = DataLoader(train_dataset, batch_size=adjusted_batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=adjusted_batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=adjusted_batch_size, shuffle=False)
# Initialize model
model = WatermelonModelModular(image_backbone, audio_backbone, audio_hidden_size).to(device)
# Loss function and optimizer
criterion = torch.nn.MSELoss()
mae_criterion = torch.nn.L1Loss() # For MAE evaluation
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
print(f"\033[92mINFO\033[0m: Evaluating model with {image_backbone} (image) and {audio_backbone} (audio)")
print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
print(f"\033[92mINFO\033[0m: Batch size: {adjusted_batch_size}")
# Training loop
print(f"\033[92mINFO\033[0m: Training for evaluation...")
model.train()
running_loss = 0.0
# Wrap with tqdm for progress visualization
train_iterator = tqdm(train_loader, desc="Training")
for i, (mfcc, image, label) in enumerate(train_iterator):
try:
mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
optimizer.zero_grad()
output = model(mfcc, image)
label = label.view(-1, 1).float()
loss = criterion(output, label)
loss.backward()
optimizer.step()
running_loss += loss.item()
train_iterator.set_postfix({"Loss": f"{loss.item():.4f}"})
# Clear memory after each batch
if device.type == 'cuda':
del mfcc, image, label, output, loss
torch.cuda.empty_cache()
except Exception as e:
print(f"\033[91mERR!\033[0m: Error in training batch {i}: {e}")
# Clear memory in case of error
if device.type == 'cuda':
torch.cuda.empty_cache()
continue
# Validation phase
print(f"\033[92mINFO\033[0m: Validating...")
model.eval()
val_loss = 0.0
val_mae = 0.0
val_iterator = tqdm(val_loader, desc="Validation")
with torch.no_grad():
for i, (mfcc, image, label) in enumerate(val_iterator):
try:
mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
output = model(mfcc, image)
label = label.view(-1, 1).float()
# Calculate MSE loss
loss = criterion(output, label)
val_loss += loss.item()
# Calculate MAE
mae = mae_criterion(output, label)
val_mae += mae.item()
val_iterator.set_postfix({"MSE": f"{loss.item():.4f}", "MAE": f"{mae.item():.4f}"})
# Clear memory after each batch
if device.type == 'cuda':
del mfcc, image, label, output, loss, mae
torch.cuda.empty_cache()
except Exception as e:
print(f"\033[91mERR!\033[0m: Error in validation batch {i}: {e}")
# Clear memory in case of error
if device.type == 'cuda':
torch.cuda.empty_cache()
continue
avg_val_loss = val_loss / len(val_loader) if len(val_loader) > 0 else float('inf')
avg_val_mae = val_mae / len(val_loader) if len(val_loader) > 0 else float('inf')
# Test phase
print(f"\033[92mINFO\033[0m: Testing...")
model.eval()
test_loss = 0.0
test_mae = 0.0
test_iterator = tqdm(test_loader, desc="Testing")
with torch.no_grad():
for i, (mfcc, image, label) in enumerate(test_iterator):
try:
mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
output = model(mfcc, image)
label = label.view(-1, 1).float()
# Calculate MSE loss
loss = criterion(output, label)
test_loss += loss.item()
# Calculate MAE
mae = mae_criterion(output, label)
test_mae += mae.item()
test_iterator.set_postfix({"MSE": f"{loss.item():.4f}", "MAE": f"{mae.item():.4f}"})
# Clear memory after each batch
if device.type == 'cuda':
del mfcc, image, label, output, loss, mae
torch.cuda.empty_cache()
except Exception as e:
print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}")
# Clear memory in case of error
if device.type == 'cuda':
torch.cuda.empty_cache()
continue
avg_test_loss = test_loss / len(test_loader) if len(test_loader) > 0 else float('inf')
avg_test_mae = test_mae / len(test_loader) if len(test_loader) > 0 else float('inf')
results = {
"image_backbone": image_backbone,
"audio_backbone": audio_backbone,
"validation_mse": avg_val_loss,
"validation_mae": avg_val_mae,
"test_mse": avg_test_loss,
"test_mae": avg_test_mae
}
print(f"\033[92mINFO\033[0m: Evaluation Results:")
print(f"Image Backbone: {image_backbone}")
print(f"Audio Backbone: {audio_backbone}")
print(f"Validation MSE: {avg_val_loss:.4f}")
print(f"Validation MAE: {avg_val_mae:.4f}")
print(f"Test MSE: {avg_test_loss:.4f}")
print(f"Test MAE: {avg_test_mae:.4f}")
# Save model if save_model_dir is provided
if save_model_dir:
os.makedirs(save_model_dir, exist_ok=True)
model_filename = f"{image_backbone}_{audio_backbone}_model.pt"
model_path = os.path.join(save_model_dir, model_filename)
torch.save(model.state_dict(), model_path)
print(f"\033[92mINFO\033[0m: Model saved to {model_path}")
# Add model path to results
results["model_path"] = model_path
# Clean up memory before returning
if device.type == 'cuda':
del model, optimizer, criterion, mae_criterion
torch.cuda.empty_cache()
return results
def evaluate_all_combinations(data_dir, image_backbones=None, audio_backbones=None, save_model_dir="test_models", results_file="backbone_evaluation_results.json"):
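    """Evaluate every (image, audio) backbone pair, skipping pairs already present in
    results_file, and save the updated results after each evaluation.
    """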
if image_backbones is None:
image_backbones = list(IMAGE_BACKBONES.keys())
if audio_backbones is None:
audio_backbones = list(AUDIO_BACKBONES.keys())
# Create directory for saving models
if save_model_dir:
os.makedirs(save_model_dir, exist_ok=True)
# Load previous results if the file exists
results = []
evaluated_combinations = set()
if os.path.exists(results_file):
try:
with open(results_file, 'r') as f:
results = json.load(f)
evaluated_combinations = {(r["image_backbone"], r["audio_backbone"]) for r in results}
print(f"\033[92mINFO\033[0m: Loaded {len(results)} previous results from {results_file}")
except Exception as e:
print(f"\033[91mERR!\033[0m: Error loading previous results from {results_file}: {e}")
results = []
evaluated_combinations = set()
else:
print(f"\033[93mWARN\033[0m: Results file '{results_file}' does not exist. Starting with empty results.")
# Create combinations to evaluate, skipping any that have already been evaluated
combinations = [(img, aud) for img in image_backbones for aud in audio_backbones
if (img, aud) not in evaluated_combinations]
    skipped = len(image_backbones) * len(audio_backbones) - len(combinations)
    if skipped > 0:
        print(f"\033[92mINFO\033[0m: Skipping {skipped} already evaluated combinations")
print(f"\033[92mINFO\033[0m: Will evaluate {len(combinations)} combinations")
for image_backbone, audio_backbone in combinations:
print(f"\033[92mINFO\033[0m: Evaluating {image_backbone} + {audio_backbone}")
try:
# Clean GPU memory before each model evaluation
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"\033[92mINFO\033[0m: CUDA memory cleared before evaluation")
# Print memory usage for debugging
print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
result = evaluate_model(data_dir, image_backbone, audio_backbone, save_model_dir=save_model_dir)
results.append(result)
# Save results after each evaluation
save_results(results, results_file)
print(f"\033[92mINFO\033[0m: Updated results saved to {results_file}")
# Force garbage collection to free memory
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"\033[92mINFO\033[0m: CUDA memory cleared after evaluation")
# Print memory usage for debugging
print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
except Exception as e:
print(f"\033[91mERR!\033[0m: Error evaluating {image_backbone} + {audio_backbone}: {e}")
print(f"\033[91mERR!\033[0m: To continue from this point, use --start_from={image_backbone}:{audio_backbone}")
# Force garbage collection to free memory even if there's an error
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"\033[92mINFO\033[0m: CUDA memory cleared after error")
continue
# Sort results by test MAE (ascending)
results.sort(key=lambda x: x["test_mae"])
# Save final sorted results
save_results(results, results_file)
print("\n\033[92mINFO\033[0m: === FINAL RESULTS (Sorted by Test MAE) ===")
print(f"{'Image Backbone':<20} {'Audio Backbone':<20} {'Val MAE':<10} {'Test MAE':<10}")
print("="*60)
for result in results:
print(f"{result['image_backbone']:<20} {result['audio_backbone']:<20} {result['validation_mae']:<10.4f} {result['test_mae']:<10.4f}")
return results
def save_results(results, filename="backbone_evaluation_results.json"):
"""Save evaluation results to a JSON file."""
with open(filename, 'w') as f:
json.dump(results, f, indent=4)
print(f"\033[92mINFO\033[0m: Results saved to {filename}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Evaluate Different Backbones for Watermelon Sweetness Prediction")
parser.add_argument(
"--data_dir",
type=str,
default="../cleaned",
help="Path to the cleaned dataset directory"
)
parser.add_argument(
"--image_backbone",
type=str,
default=None,
help="Specific image backbone to evaluate (leave empty to evaluate all available)"
)
parser.add_argument(
"--audio_backbone",
type=str,
default=None,
help="Specific audio backbone to evaluate (leave empty to evaluate all available)"
)
parser.add_argument(
"--evaluate_all",
action="store_true",
help="Evaluate all combinations of backbones"
)
parser.add_argument(
"--start_from",
type=str,
default=None,
help="Start evaluation from a specific combination, format: 'image_backbone:audio_backbone'"
)
parser.add_argument(
"--prioritize_efficient",
action="store_true",
help="Prioritize more efficient models first to avoid memory issues"
)
parser.add_argument(
"--results_file",
type=str,
default="backbone_evaluation_results.json",
help="File to save the evaluation results"
)
parser.add_argument(
"--load_previous_results",
action="store_true",
help="Load previous results from results_file if it exists"
)
parser.add_argument(
"--model_dir",
type=str,
default="test_models",
help="Directory to save model checkpoints"
)
args = parser.parse_args()
# Create model directory if it doesn't exist
if args.model_dir:
os.makedirs(args.model_dir, exist_ok=True)
print(f"\033[92mINFO\033[0m: === Available Image Backbones ===")
for name in IMAGE_BACKBONES.keys():
print(f"- {name}")
print(f"\033[92mINFO\033[0m: === Available Audio Backbones ===")
for name in AUDIO_BACKBONES.keys():
print(f"- {name}")
if args.evaluate_all:
evaluate_all_combinations(args.data_dir, results_file=args.results_file, save_model_dir=args.model_dir)
elif args.image_backbone and args.audio_backbone:
result = evaluate_model(args.data_dir, args.image_backbone, args.audio_backbone, save_model_dir=args.model_dir)
save_results([result], args.results_file)
else:
# Define a default set of backbones to evaluate if not specified
if args.prioritize_efficient:
# Start with less memory-intensive models
image_backbones = ["resnet50", "efficientnet_b0", "resnet101", "efficientnet_b3", "convnext_base", "swin_b"]
audio_backbones = ["lstm", "gru", "bidirectional_lstm", "transformer"]
else:
# Default selection focusing on better performance models
image_backbones = ["resnet101", "efficientnet_b3", "swin_b"]
audio_backbones = ["lstm", "bidirectional_lstm", "transformer"]
# Create all combinations
combinations = [(img, aud) for img in image_backbones for aud in audio_backbones]
# Load previous results if requested and file exists
previous_results = []
previous_combinations = set()
if args.load_previous_results:
try:
if os.path.exists(args.results_file):
with open(args.results_file, 'r') as f:
previous_results = json.load(f)
previous_combinations = {(r["image_backbone"], r["audio_backbone"]) for r in previous_results}
print(f"\033[92mINFO\033[0m: Loaded {len(previous_results)} previous results")
else:
print(f"\033[93mWARN\033[0m: Results file '{args.results_file}' does not exist. Starting with empty results.")
except Exception as e:
print(f"\033[91mERR!\033[0m: Error loading previous results: {e}")
previous_results = []
previous_combinations = set()
# If starting from a specific point
if args.start_from:
try:
start_img, start_aud = args.start_from.split(':')
start_idx = combinations.index((start_img, start_aud))
combinations = combinations[start_idx:]
print(f"\033[92mINFO\033[0m: Starting from combination: {start_img} (image) + {start_aud} (audio)")
except (ValueError, IndexError):
print(f"\033[91mERR!\033[0m: Invalid start_from format or combination not found. Format should be 'image_backbone:audio_backbone'")
print(f"\033[91mERR!\033[0m: Continuing with all combinations.")
# Skip combinations that have already been evaluated
if previous_combinations:
original_count = len(combinations)
combinations = [(img, aud) for img, aud in combinations if (img, aud) not in previous_combinations]
print(f"\033[92mINFO\033[0m: Skipping {original_count - len(combinations)} already evaluated combinations")
# Evaluate each combination
results = previous_results.copy()
for img_backbone, audio_backbone in combinations:
print(f"\033[92mINFO\033[0m: Evaluating {img_backbone} + {audio_backbone}")
try:
# Clean GPU memory before each model evaluation
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"\033[92mINFO\033[0m: CUDA memory cleared before evaluation")
print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
result = evaluate_model(args.data_dir, img_backbone, audio_backbone, save_model_dir=args.model_dir)
results.append(result)
# Save results after each evaluation
save_results(results, args.results_file)
# Force garbage collection to free memory
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"\033[92mINFO\033[0m: CUDA memory cleared after evaluation")
print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
except Exception as e:
print(f"\033[91mERR!\033[0m: Error evaluating {img_backbone} + {audio_backbone}: {e}")
print(f"\033[91mERR!\033[0m: To continue from this point later, use --start_from={img_backbone}:{audio_backbone}")
# Force garbage collection to free memory even if there's an error
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"\033[92mINFO\033[0m: CUDA memory cleared after error")
continue
# Sort results by test MAE (ascending)
results.sort(key=lambda x: x["test_mae"])
# Save final sorted results
save_results(results, args.results_file)
print("\n\033[92mINFO\033[0m: === FINAL RESULTS (Sorted by Test MAE) ===")
print(f"{'Image Backbone':<20} {'Audio Backbone':<20} {'Val MAE':<10} {'Test MAE':<10}")
print("="*60)
for result in results:
print(f"{result['image_backbone']:<20} {result['audio_backbone']:<20} {result['validation_mae']:<10.4f} {result['test_mae']:<10.4f}")