#!/usr/bin/env python3
"""
Flower Dataset class for training ConvNeXt models.
"""

import glob
import os

import torch
from PIL import Image
from torch.utils.data import Dataset


class FlowerDataset(Dataset):
    """Image-classification dataset for flower photos organized by subdirectory."""

    def __init__(self, image_dir, processor, flower_labels=None):
        self.image_paths = []
        self.labels = []
        self.processor = processor

        # Auto-detect flower types from the directory structure if not provided
        if flower_labels is None:
            detected_types = []
            for item in os.listdir(image_dir):
                item_path = os.path.join(image_dir, item)
                if os.path.isdir(item_path):
                    image_files = self._get_image_files(item_path)
                    if image_files:  # Only add directories that contain images
                        detected_types.append(item)
            self.flower_labels = sorted(detected_types)
        else:
            self.flower_labels = flower_labels

        self.label_to_id = {label: idx for idx, label in enumerate(self.flower_labels)}

        # Load images from subdirectories (organized by flower type)
        for flower_type in os.listdir(image_dir):
            flower_path = os.path.join(image_dir, flower_type)
            if os.path.isdir(flower_path) and flower_type in self.label_to_id:
                image_files = self._get_image_files(flower_path)
                for img_path in image_files:
                    self.image_paths.append(img_path)
                    self.labels.append(self.label_to_id[flower_type])

        print(
            f"Loaded {len(self.image_paths)} images from {len(set(self.labels))} flower types"
        )
        print(f"Flower types: {self.flower_labels}")

    def _get_image_files(self, directory):
        """Get all supported image files from a directory."""
        extensions = ["*.jpg", "*.jpeg", "*.png", "*.webp"]
        image_files = []
        for ext in extensions:
            image_files.extend(glob.glob(os.path.join(directory, ext)))
            image_files.extend(glob.glob(os.path.join(directory, ext.upper())))
        # Deduplicate (case-insensitive filesystems match both glob patterns)
        # and sort for deterministic ordering.
        return sorted(set(image_files))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]

        # Preprocess the image for ConvNeXt (resize, normalize, convert to tensor)
        inputs = self.processor(images=image, return_tensors="pt")

        return {
            # squeeze(0) drops only the batch dimension, leaving (C, H, W)
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }


def simple_collate_fn(batch):
    """Simple collation function for training; assumes every item has a label."""
    pixel_values = []
    labels = []

    for item in batch:
        pixel_values.append(item["pixel_values"])
        labels.append(item["labels"])

    return {"pixel_values": torch.stack(pixel_values), "labels": torch.stack(labels)}


def advanced_collate_fn(batch):
    """Collation function for the Hugging Face Trainer; tolerates unlabeled items."""
    # Extract components
    pixel_values = [item["pixel_values"] for item in batch]
    labels = [item["labels"] for item in batch if "labels" in item]

    # Stack pixel values; include labels only if any were present
    result = {"pixel_values": torch.stack(pixel_values)}
    if labels:
        result["labels"] = torch.stack(labels)

    return result
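

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's API. Assumptions: the
    # `transformers` package is installed, images live under a hypothetical
    # data/flowers/<flower_type>/ layout, and the ConvNeXt checkpoint name
    # below is only an example.
    from torch.utils.data import DataLoader
    from transformers import AutoImageProcessor

    processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
    dataset = FlowerDataset("data/flowers", processor)

    # simple_collate_fn suits a plain training loop; advanced_collate_fn is
    # the safer choice when batches may omit labels (e.g. with Trainer).
    loader = DataLoader(
        dataset, batch_size=8, shuffle=True, collate_fn=simple_collate_fn
    )
    batch = next(iter(loader))
    print(batch["pixel_values"].shape, batch["labels"].shape)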