#!/usr/bin/env python3
"""
Simple ConvNeXt training script without using the Transformers Trainer class.

This is a lightweight training implementation for quick model fine-tuning.
"""

import json
import os
import sys


def simple_train(
    image_dir="training_data/images",
    output_dir="training_data/trained_models/simple_trained",
    epochs=3,
    batch_size=4,
    learning_rate=1e-5,
    model_name="facebook/convnext-base-224-22k",
):
    """
    Simple training function for ConvNeXt flower classification.

    Args:
        image_dir: Directory containing training images organized by flower type
        output_dir: Directory to save the trained model
        epochs: Number of training epochs
        batch_size: Training batch size
        learning_rate: Learning rate for optimization
        model_name: Base ConvNeXt model to fine-tune

    Returns:
        str: Path to the saved model directory, or None if training failed
    """
    print("🌸 Simple ConvNeXt Flower Model Training")
    print("=" * 40)

    # Validate the training data path before paying for the heavy imports below.
    if not os.path.exists(image_dir):
        print(f"āŒ Training directory not found: {image_dir}")
        return None

    # Deferred imports: torch/transformers are slow to load, so keep argument
    # validation (and `--help`) cheap when training never actually starts.
    import torch
    from torch.utils.data import DataLoader
    from transformers import ConvNextForImageClassification, ConvNextImageProcessor

    from dataset import FlowerDataset, simple_collate_fn

    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using device: {device}")

    # Load model and processor
    print(f"Loading model: {model_name}")
    model = ConvNextForImageClassification.from_pretrained(model_name)
    processor = ConvNextImageProcessor.from_pretrained(model_name)
    model.to(device)

    # Create dataset
    dataset = FlowerDataset(image_dir, processor)
    if len(dataset) < 5:
        print("āŒ Need at least 5 images for training")
        return None

    # Split dataset: the first 80% of samples are used for training.
    train_size = int(0.8 * len(dataset))
    train_dataset = torch.utils.data.Subset(dataset, range(train_size))

    # Re-head the model when the flower label set differs from the pretrained one.
    if len(dataset.flower_labels) != model.config.num_labels:
        num_labels = len(dataset.flower_labels)
        model.config.num_labels = num_labels
        # Record the label mapping in the config; otherwise the saved model
        # keeps the pretrained (ImageNet-22k) id2label and inference cannot
        # map predicted indices back to flower names.
        model.config.id2label = dict(enumerate(dataset.flower_labels))
        model.config.label2id = {
            label: i for i, label in enumerate(dataset.flower_labels)
        }
        # ConvNeXt uses hidden_sizes[-1] as the final hidden dimension.
        final_hidden_size = (
            model.config.hidden_sizes[-1]
            if hasattr(model.config, "hidden_sizes")
            else 768
        )
        model.classifier = torch.nn.Linear(final_hidden_size, num_labels)
        model.classifier.to(device)

    # Create data loader
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=simple_collate_fn,
    )

    # Setup optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Training loop
    model.train()
    print(f"Starting training on {len(train_dataset)} samples for {epochs} epochs...")

    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            # Move to device
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            # Standard step: zero grads, forward (HF model computes the loss
            # when `labels` is supplied), backward, update.
            optimizer.zero_grad()
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if batch_idx % 2 == 0 or batch_idx == len(train_loader) - 1:
                print(
                    f"Epoch {epoch + 1}/{epochs}, Batch {batch_idx + 1}/{len(train_loader)}: Loss = {loss.item():.4f}"
                )

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        print(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")

    # Save model and processor in Hugging Face format
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    # Save a small sidecar config describing this training run
    config = {
        "model_name": model_name,
        "flower_labels": dataset.flower_labels,
        "num_epochs": epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "train_samples": len(train_dataset),
        "num_labels": len(dataset.flower_labels),
        "training_type": "simple",
    }
    with open(os.path.join(output_dir, "training_config.json"), "w") as f:
        json.dump(config, f, indent=2)

    print(f"āœ… ConvNeXt training completed! Model saved to {output_dir}")
    return output_dir


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Simple ConvNeXt training for flower classification"
    )
    parser.add_argument(
        "--image_dir",
        default="training_data/images",
        help="Directory containing training images",
    )
    parser.add_argument(
        "--output_dir",
        default="training_data/trained_models/simple_trained",
        help="Output directory for trained model",
    )
    parser.add_argument(
        "--epochs", type=int, default=3, help="Number of training epochs"
    )
    parser.add_argument("--batch_size", type=int, default=4, help="Training batch size")
    parser.add_argument(
        "--learning_rate", type=float, default=1e-5, help="Learning rate"
    )
    parser.add_argument(
        "--model_name", default="facebook/convnext-base-224-22k", help="Base model name"
    )

    args = parser.parse_args()

    try:
        result = simple_train(
            image_dir=args.image_dir,
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            model_name=args.model_name,
        )
        if not result:
            print("āŒ Training failed!")
            sys.exit(1)
    except KeyboardInterrupt:
        print("\nāš ļø Training interrupted by user.")
    except Exception as e:
        print(f"āŒ Training failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)