import os
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Configuration
CONFIG = {
    'img_size': 224,
    'batch_size': 16,        # Reduced batch size
    'num_epochs': 30,
    'learning_rate': 0.0001,
    'patience': 7,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_workers': 0,        # Set to 0 to avoid multiprocessing issues
    'model_save_path': 'best_trash_classifier.pth',
}

print(f"Using device: {CONFIG['device']}")

# Memory-Efficient Dataset Class
class TrashDatasetLazy(Dataset):
    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        actual_idx = self.indices[idx]
        item = self.dataset[actual_idx]
        image = item['image']
        label = item['label']

        # Convert to PIL Image if needed
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.array(image))

        # Convert to RGB if grayscale
        if image.mode != 'RGB':
            image = image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# Data Augmentation and Normalization
train_transform = transforms.Compose([
    transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load Dataset (full in-memory dataset; images are decoded lazily per item)
print("\n" + "="*60)
print("LOADING DATASET")
print("="*60)

ds = load_dataset("rootstrap-org/waste-classifier", split="train")
print("Dataset loaded successfully!")
print(f"Total samples: {len(ds)}")

# Get class names
class_names = ds.features['label'].names
num_classes = len(class_names)
print(f"\nNumber of classes: {num_classes}")
print(f"Classes: {class_names}")

# Extract only labels for splitting (not images!)
labels = [item['label'] for item in ds]

# Check class distribution
unique, counts = np.unique(labels, return_counts=True)
print("\nClass Distribution:")
for cls_idx, count in zip(unique, counts):
    print(f" {class_names[cls_idx]}: {count} samples ({count/len(labels)*100:.2f}%)")

# Split dataset: 70% train, 15% val, 15% test
print("\n" + "="*60)
print("SPLITTING DATASET")
print("="*60)

indices = np.arange(len(ds))
train_idx, temp_idx, y_train, y_temp = train_test_split(
    indices, labels, test_size=0.3, random_state=42, stratify=labels
)
val_idx, test_idx, y_val, y_test = train_test_split(
    temp_idx, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)

print(f"Train set: {len(train_idx)} samples")
print(f"Validation set: {len(val_idx)} samples")
print(f"Test set: {len(test_idx)} samples")

# Calculate class weights for handling imbalance
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train),
    y=y_train
)
class_weights = torch.FloatTensor(class_weights).to(CONFIG['device'])
print(f"\nClass weights (for imbalance): {class_weights.cpu().numpy()}")

# Create datasets and dataloaders
train_dataset = TrashDatasetLazy(ds, train_idx, transform=train_transform)
val_dataset = TrashDatasetLazy(ds, val_idx, transform=val_transform)
test_dataset = TrashDatasetLazy(ds, test_idx, transform=val_transform)

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=True
)

# Build Model using EfficientNetV2 (pretrained)
print("\n" + "="*60)
print("BUILDING MODEL")
print("="*60)

model = models.efficientnet_v2_s(weights='IMAGENET1K_V1')

# Freeze early layers
for param in list(model.parameters())[:-20]:
    param.requires_grad = False

# Modify classifier for our number of classes
num_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
    nn.Dropout(p=0.3, inplace=True),
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(p=0.3),
    nn.Linear(512, num_classes)
)

model = model.to(CONFIG['device'])
print("Model: EfficientNetV2-S (pretrained)")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Loss function with class weights and optimizer
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, verbose=True
)

# Training and Validation Functions
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100.*correct/total:.2f}%'})

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

def validate_epoch(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(loader, desc='Validation')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100.*correct/total:.2f}%'})

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

# Training Loop
print("\n" + "="*60)
print("TRAINING MODEL")
print("="*60)

history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': []
}

best_val_acc = 0.0
patience_counter = 0

for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
    print("-" * 60)

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, CONFIG['device'])
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, CONFIG['device'])

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    scheduler.step(val_loss)

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'class_names': class_names
        }, CONFIG['model_save_path'])
        print(f"✓ Model saved! (Val Acc: {val_acc:.2f}%)")
        patience_counter = 0
    else:
        patience_counter += 1
        print(f"No improvement ({patience_counter}/{CONFIG['patience']})")

    if patience_counter >= CONFIG['patience']:
        print("\nEarly stopping triggered!")
        break

# Plot Training History
print("\n" + "="*60)
print("SAVING TRAINING GRAPHS")
print("="*60)

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss plot
axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['val_loss'], label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy plot
axes[1].plot(history['train_acc'], label='Train Acc', marker='o')
axes[1].plot(history['val_acc'], label='Val Acc', marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
print("✓ Training graphs saved as 'training_history.png'")

# Load Best Model and Test
print("\n" + "="*60)
print("LOADING BEST MODEL AND TESTING")
print("="*60)

checkpoint = torch.load(CONFIG['model_save_path'])
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✓ Loaded best model from epoch {checkpoint['epoch']+1}")
print(f" Best validation accuracy: {checkpoint['val_acc']:.2f}%")

# Test the model
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Testing'):
        images = images.to(CONFIG['device'])
        outputs = model(images)
        _, predicted = outputs.max(1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())

# Calculate test accuracy
test_acc = 100. * np.sum(np.array(all_preds) == np.array(all_labels)) / len(all_labels)
print(f"\n{'='*60}")
print(f"TEST SET ACCURACY: {test_acc:.2f}%")
print(f"{'='*60}")

# Classification Report
print("\n" + "="*60)
print("CLASSIFICATION REPORT")
print("="*60)
print(classification_report(all_labels, all_preds, target_names=class_names, digits=4))

# Confusion Matrix
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix - Test Set')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
print("\n✓ Confusion matrix saved as 'confusion_matrix.png'")

print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)
print(f"✓ Best model saved: {CONFIG['model_save_path']}")
print("✓ Training history: training_history.png")
print("✓ Confusion matrix: confusion_matrix.png")
print(f"✓ Final test accuracy: {test_acc:.2f}%")
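
# --- Optional: single-image inference sketch (not part of the training run above) ---
# A minimal example of how the saved checkpoint could be used for prediction.
# It reuses `val_transform`, the model defined above, and `checkpoint['class_names']`;
# 'some_image.jpg' is a placeholder path, not a file produced by this script.
def predict_image(image_path, model, transform, class_names, device):
    """Classify a single image file and return (predicted class name, confidence)."""
    image = Image.open(image_path).convert('RGB')
    tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension
    model.eval()
    with torch.no_grad():
        logits = model(tensor)
        probs = torch.softmax(logits, dim=1)
        conf, pred = probs.max(1)
    return class_names[pred.item()], conf.item()

# Example usage (hypothetical image path):
# label, confidence = predict_image('some_image.jpg', model, val_transform,
#                                   checkpoint['class_names'], CONFIG['device'])
# print(f"Predicted: {label} ({confidence*100:.1f}%)")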