# ConvNeXt-Tiny fine-tuning script (RunPod): trains an image classifier on an
# ImageFolder dataset, checkpoints every epoch, and saves best/final weights.
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import ConvNeXt_Tiny_Weights, convnext_tiny
from tqdm import tqdm
| # Dataset path on RunPod | |
| dataset_path = "/workspace/VCR Cleaned" | |
| train_dir = os.path.join(dataset_path, "train") | |
| val_dir = os.path.join(dataset_path, "val") | |
| # Transforms | |
| train_transform = transforms.Compose([ | |
| transforms.Resize((512, 512)), | |
| transforms.ToTensor(), | |
| transforms.Normalize([0.485, 0.456, 0.406], | |
| [0.229, 0.224, 0.225]) | |
| ]) | |
| val_transform = transforms.Compose([ | |
| transforms.Resize((512, 512)), | |
| transforms.ToTensor(), | |
| transforms.Normalize([0.485, 0.456, 0.406], | |
| [0.229, 0.224, 0.225]) | |
| ]) | |
| # Load datasets | |
| train_dataset = datasets.ImageFolder(train_dir, transform=train_transform) | |
| val_dataset = datasets.ImageFolder(val_dir, transform=val_transform) | |
| # Verify class mapping | |
| print("\nLabel mapping:", train_dataset.class_to_idx) | |
| print("Number of classes:", len(train_dataset.classes)) | |
| # Load model | |
| model = convnext_tiny(pretrained=True) | |
| model.classifier[2] = nn.Linear(768, len(train_dataset.classes)) | |
| # Setup | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model = model.to(device) | |
| criterion = nn.CrossEntropyLoss() | |
| optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01) | |
| scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) | |
| # Paths | |
| checkpoint_path = "/workspace/convnext_checkpoint.pth" | |
| best_model_path = "/workspace/convnext_best_model.pth" | |
| final_model_path = "/workspace/convnext_final_model.pth" | |
| # Load checkpoint if available | |
| start_epoch = 0 | |
| train_losses = [] | |
| val_losses = [] | |
| val_accuracies = [] | |
| best_acc = 0.0 | |
| if os.path.exists(checkpoint_path): | |
| print("\nLoading checkpoint...") | |
| checkpoint = torch.load(checkpoint_path, map_location=device) | |
| model.load_state_dict(checkpoint['model_state_dict']) | |
| optimizer.load_state_dict(checkpoint['optimizer_state_dict']) | |
| scheduler.load_state_dict(checkpoint['scheduler_state_dict']) | |
| train_losses = checkpoint['train_losses'] | |
| val_losses = checkpoint['val_losses'] | |
| val_accuracies = checkpoint['val_accuracies'] | |
| best_acc = max(val_accuracies) if val_accuracies else 0.0 | |
| start_epoch = checkpoint['epoch'] | |
| print(f"Resumed from epoch {start_epoch}") | |
| else: | |
| print("\nStarting training from scratch") | |
| # Training loop | |
| for epoch in range(start_epoch, 100): | |
| model.train() | |
| train_loss = 0 | |
| for images, labels in tqdm(DataLoader(train_dataset, batch_size=64, shuffle=True), desc=f"Epoch {epoch+1}"): | |
| images, labels = images.to(device), labels.to(device) | |
| optimizer.zero_grad() | |
| outputs = model(images) | |
| loss = criterion(outputs, labels) | |
| loss.backward() | |
| optimizer.step() | |
| train_loss += loss.item() | |
| scheduler.step() | |
| # Validation | |
| model.eval() | |
| val_loss = 0 | |
| correct = 0 | |
| with torch.no_grad(): | |
| for images, labels in DataLoader(val_dataset, batch_size=64): | |
| images, labels = images.to(device), labels.to(device) | |
| outputs = model(images) | |
| val_loss += criterion(outputs, labels).item() | |
| preds = outputs.argmax(dim=1) | |
| correct += (preds == labels).sum().item() | |
| # Metrics | |
| epoch_train_loss = train_loss / len(train_dataset) | |
| epoch_val_loss = val_loss / len(val_dataset) | |
| epoch_val_acc = correct / len(val_dataset) | |
| train_losses.append(epoch_train_loss) | |
| val_losses.append(epoch_val_loss) | |
| val_accuracies.append(epoch_val_acc) | |
| print(f"\nEpoch {epoch+1}:") | |
| print(f" Train Loss: {epoch_train_loss:.4f}") | |
| print(f" Val Loss: {epoch_val_loss:.4f}") | |
| print(f" Val Acc: {epoch_val_acc:.4f}") | |
| # Save checkpoint | |
| torch.save({ | |
| 'epoch': epoch + 1, | |
| 'model_state_dict': model.state_dict(), | |
| 'optimizer_state_dict': optimizer.state_dict(), | |
| 'scheduler_state_dict': scheduler.state_dict(), | |
| 'train_losses': train_losses, | |
| 'val_losses': val_losses, | |
| 'val_accuracies': val_accuracies | |
| }, checkpoint_path) | |
| # β Save best model | |
| if epoch_val_acc > best_acc: | |
| best_acc = epoch_val_acc | |
| torch.save(model.state_dict(), best_model_path) | |
| print(f"β Best model saved at epoch {epoch+1} with acc {best_acc:.4f}") | |
| # β Save final model | |
| torch.save(model.state_dict(), final_model_path) | |
| print(f"\nβ Final model saved to {final_model_path}") | |
| # Plot | |
| plt.figure(figsize=(12, 4)) | |
| plt.subplot(1, 2, 1) | |
| plt.plot(train_losses, label='Train Loss') | |
| plt.plot(val_losses, label='Val Loss') | |
| plt.title("Loss") | |
| plt.legend() | |
| plt.subplot(1, 2, 2) | |
| plt.plot(val_accuracies, label='Val Accuracy') | |
| plt.title("Validation Accuracy") | |
| plt.legend() | |
| plt.show() | |