# flowerfy/training/dataset.py (repository: Toy)
# Commit b24c04f: "Apply code formatting and fix compatibility issues"
#!/usr/bin/env python3
"""
Flower Dataset class for training ConvNeXt models.
"""
import glob
import os
import torch
from PIL import Image
from torch.utils.data import Dataset
class FlowerDataset(Dataset):
    """Image-classification dataset for flower photos laid out as
    ``image_dir/<flower_type>/<image>``.

    Each subdirectory of *image_dir* is one class; label ids are assigned
    from the sorted class names so they are deterministic across runs.

    Args:
        image_dir: Root directory containing one subdirectory per flower type.
        processor: Image processor (e.g. a HuggingFace ConvNeXt processor)
            that maps a PIL image to model-ready ``pixel_values``.
        flower_labels: Optional explicit list of class names. When ``None``,
            classes are auto-detected from subdirectories that contain at
            least one supported image file.
    """

    # Supported image extensions, matched case-insensitively.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")

    def __init__(self, image_dir, processor, flower_labels=None):
        self.image_paths = []
        self.labels = []
        self.processor = processor

        # Auto-detect flower types from directory structure if not provided.
        if flower_labels is None:
            detected_types = []
            for item in os.listdir(image_dir):
                item_path = os.path.join(image_dir, item)
                # Only keep subdirectories that actually contain images.
                if os.path.isdir(item_path) and self._get_image_files(item_path):
                    detected_types.append(item)
            # Sort so label ids don't depend on filesystem listing order.
            self.flower_labels = sorted(detected_types)
        else:
            self.flower_labels = flower_labels

        self.label_to_id = {label: idx for idx, label in enumerate(self.flower_labels)}

        # Load images from subdirectories (organized by flower type).
        for flower_type in os.listdir(image_dir):
            flower_path = os.path.join(image_dir, flower_type)
            if os.path.isdir(flower_path) and flower_type in self.label_to_id:
                for img_path in self._get_image_files(flower_path):
                    self.image_paths.append(img_path)
                    self.labels.append(self.label_to_id[flower_type])

        print(
            f"Loaded {len(self.image_paths)} images from {len(set(self.labels))} flower types"
        )
        print(f"Flower types: {self.flower_labels}")

    def _get_image_files(self, directory):
        """Return a sorted list of supported image files in *directory*.

        Extensions are matched case-insensitively (``.JPG``, ``.Jpg`` and
        ``.jpg`` all count) and each file appears exactly once.  The previous
        glob-based implementation matched ``*.jpg`` and ``*.JPG`` as separate
        patterns, which duplicated every file on case-insensitive filesystems
        (macOS/Windows), missed mixed-case extensions, and returned files in
        nondeterministic order.
        """
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": Tensor, "labels": LongTensor}`` for *idx*."""
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]

        # Process image for ConvNeXt.
        inputs = self.processor(images=image, return_tensors="pt")
        return {
            # squeeze(0) removes only the batch dim added by the processor;
            # a bare .squeeze() would also drop genuine size-1 dims (e.g. a
            # single-channel image), corrupting the tensor shape.
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }
def simple_collate_fn(batch):
    """Collate a list of dataset items into batched tensors.

    Stacks each item's ``pixel_values`` and ``labels`` along a new leading
    batch dimension and returns them under the same keys.
    """
    return {
        "pixel_values": torch.stack([example["pixel_values"] for example in batch]),
        "labels": torch.stack([example["labels"] for example in batch]),
    }
def advanced_collate_fn(batch):
    """Collate function suitable for the HuggingFace Trainer.

    ``pixel_values`` are always stacked into a batch tensor.  ``labels`` are
    stacked and included only when at least one item carries them, so
    label-free batches (e.g. pure inference) produce no ``labels`` key.
    """
    collated = {
        "pixel_values": torch.stack([example["pixel_values"] for example in batch])
    }
    label_tensors = [example["labels"] for example in batch if "labels" in example]
    if label_tensors:
        collated["labels"] = torch.stack(label_tensors)
    return collated