#!/usr/bin/env python3
"""
Advanced ConvNeXt training script using Transformers Trainer.

This provides more sophisticated training features like evaluation,
checkpointing, and logging.
"""

import argparse
import json
import os
import sys

import torch

from dataset import FlowerDataset, advanced_collate_fn
from transformers import (
    ConvNextForImageClassification,
    ConvNextImageProcessor,
    Trainer,
    TrainingArguments,
)


class ConvNeXtTrainer(Trainer):
    """Custom trainer for ConvNeXt with proper loss computation."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute cross-entropy loss explicitly from the logits.

        Computing the loss here (instead of trusting ``outputs.loss``)
        keeps training correct after the classifier head has been
        replaced for a different number of classes than the checkpoint
        was configured with.

        Args:
            model: The model being trained.
            inputs: Batch dict; may contain a ``labels`` tensor.
            return_outputs: If True, also return the model outputs.
            **kwargs: Absorbs extra args newer Trainer versions pass
                (e.g. ``num_items_in_batch``).

        Returns:
            The loss tensor, or ``(loss, outputs)`` if requested.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)
        if labels is not None:
            loss = torch.nn.functional.cross_entropy(outputs.logits, labels)
        else:
            loss = outputs.loss
        return (loss, outputs) if return_outputs else loss


def advanced_train(
    image_dir="training_data/images",
    output_dir="training_data/trained_models/advanced_trained",
    model_name="facebook/convnext-base-224-22k",
    num_epochs=5,
    batch_size=8,
    learning_rate=1e-5,
    flower_labels=None,
):
    """
    Advanced training function using Transformers Trainer.

    Args:
        image_dir: Directory containing training images organized by flower type
        output_dir: Directory to save the trained model
        model_name: Base ConvNeXt model to fine-tune
        num_epochs: Number of training epochs
        batch_size: Training batch size
        learning_rate: Learning rate for optimization
        flower_labels: List of flower labels (auto-detected if None)

    Returns:
        str: Path to the saved model directory, or None if training failed
    """
    print("🌸 Advanced ConvNeXt Flower Model Training")
    print("=" * 50)

    # Check training data
    if not os.path.exists(image_dir):
        print(f"āŒ Training directory not found: {image_dir}")
        return None

    # Load model and processor
    print(f"Loading model: {model_name}")
    model = ConvNextForImageClassification.from_pretrained(model_name)
    processor = ConvNextImageProcessor.from_pretrained(model_name)

    # Create dataset
    dataset = FlowerDataset(image_dir, processor, flower_labels)
    if len(dataset) == 0:
        print(
            "āŒ No training data found. Please add images to subdirectories in training_data/images/"
        )
        print(
            "Example: training_data/images/roses/, training_data/images/tulips/, etc."
        )
        return None

    # Split dataset (80% train, 20% eval)
    train_size = int(0.8 * len(dataset))
    eval_size = len(dataset) - train_size
    # Guard against datasets too small to split — random_split with a
    # zero-length part would otherwise fail later with an opaque error.
    if train_size == 0 or eval_size == 0:
        print("āŒ Not enough images for an 80/20 train/eval split (need at least 2).")
        return None
    train_dataset, eval_dataset = torch.utils.data.random_split(
        dataset,
        [train_size, eval_size],
        # Fixed seed so the split (and therefore eval_loss comparisons
        # used by load_best_model_at_end) is reproducible across runs.
        generator=torch.Generator().manual_seed(42),
    )

    # Update model config for the number of classes
    if len(dataset.flower_labels) != model.config.num_labels:
        model.config.num_labels = len(dataset.flower_labels)
        # Keep the label maps in sync with the new head so the saved
        # model reports flower names (not the checkpoint's original
        # classes) at inference time.
        model.config.id2label = dict(enumerate(dataset.flower_labels))
        model.config.label2id = {
            label: idx for idx, label in enumerate(dataset.flower_labels)
        }
        # ConvNeXt uses hidden_sizes[-1] as the final hidden dimension
        final_hidden_size = (
            model.config.hidden_sizes[-1]
            if hasattr(model.config, "hidden_sizes")
            else 768
        )
        model.classifier = torch.nn.Linear(
            final_hidden_size, len(dataset.flower_labels)
        )

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=learning_rate,
        warmup_steps=100,
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        dataloader_num_workers=0,  # Set to 0 to avoid multiprocessing issues
        remove_unused_columns=False,
    )

    # Create trainer
    try:
        trainer = ConvNeXtTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=advanced_collate_fn,
        )
        print("āœ… Trainer created successfully")
    except Exception as e:
        print(f"āŒ Error creating trainer: {e}")
        return None

    # Train model
    print("Starting advanced training...")
    try:
        trainer.train()
        print("āœ… Training completed successfully!")
    except Exception as e:
        print(f"āŒ Training failed: {e}")
        import traceback

        traceback.print_exc()
        return None

    # Save final model (best checkpoint was reloaded into `model` by
    # the trainer thanks to load_best_model_at_end=True).
    final_model_path = os.path.join(output_dir, "final_model")
    model.save_pretrained(final_model_path)
    processor.save_pretrained(final_model_path)

    # Save training config alongside the model for later provenance.
    config = {
        "model_name": model_name,
        "flower_labels": dataset.flower_labels,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "train_samples": len(train_dataset),
        "eval_samples": len(eval_dataset),
        "training_type": "advanced",
    }
    with open(os.path.join(final_model_path, "training_config.json"), "w") as f:
        json.dump(config, f, indent=2)

    print(f"āœ… Advanced training complete! Model saved to {final_model_path}")
    return final_model_path


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Advanced ConvNeXt training for flower classification"
    )
    parser.add_argument(
        "--image_dir",
        default="training_data/images",
        help="Directory containing training images",
    )
    parser.add_argument(
        "--output_dir",
        default="training_data/trained_models/advanced_trained",
        help="Output directory for trained model",
    )
    parser.add_argument(
        "--model_name", default="facebook/convnext-base-224-22k", help="Base model name"
    )
    parser.add_argument(
        "--epochs", type=int, default=5, help="Number of training epochs"
    )
    parser.add_argument("--batch_size", type=int, default=8, help="Training batch size")
    parser.add_argument(
        "--learning_rate", type=float, default=1e-5, help="Learning rate"
    )

    args = parser.parse_args()

    try:
        result = advanced_train(
            image_dir=args.image_dir,
            output_dir=args.output_dir,
            model_name=args.model_name,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
        )
        if not result:
            print("āŒ Training failed!")
            # sys.exit is the reliable script-exit call; the `exit`
            # builtin is a site.py convenience not guaranteed to exist.
            sys.exit(1)
    except KeyboardInterrupt:
        print("\nāš ļø Training interrupted by user.")
    except Exception as e:
        print(f"āŒ Training failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)