File size: 3,322 Bytes
bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f bed1967 b24c04f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
#!/usr/bin/env python3
"""
Flower Dataset class for training ConvNeXt models.
"""
import glob
import os
import torch
from PIL import Image
from torch.utils.data import Dataset
class FlowerDataset(Dataset):
    """Dataset of flower images organized in per-class subdirectories.

    Expects an ``image_dir/<flower_type>/*.jpg`` (or jpeg/png/webp) layout.
    Integer labels are assigned from the sorted list of flower type names,
    so label ids are stable across runs for the same directory contents.
    """

    def __init__(self, image_dir, processor, flower_labels=None):
        """Index all images under *image_dir*.

        Args:
            image_dir: Root directory containing one subdirectory per flower type.
            processor: Image processor (e.g. a HF ConvNeXt processor) called in
                ``__getitem__``; only stored here, never invoked during init.
            flower_labels: Optional explicit list of class names. When ``None``,
                classes are auto-detected from subdirectories that contain at
                least one supported image file.
        """
        self.image_paths = []
        self.labels = []
        self.processor = processor
        # Auto-detect flower types from directory structure if not provided
        if flower_labels is None:
            detected_types = []
            for item in os.listdir(image_dir):
                item_path = os.path.join(image_dir, item)
                # Only keep subdirectories that actually contain images
                if os.path.isdir(item_path) and self._get_image_files(item_path):
                    detected_types.append(item)
            # Sorted so label ids are deterministic across runs/platforms
            self.flower_labels = sorted(detected_types)
        else:
            self.flower_labels = flower_labels
        self.label_to_id = {label: idx for idx, label in enumerate(self.flower_labels)}
        # Load images from subdirectories (organized by flower type);
        # directories not in label_to_id are skipped silently.
        for flower_type in os.listdir(image_dir):
            flower_path = os.path.join(image_dir, flower_type)
            if os.path.isdir(flower_path) and flower_type in self.label_to_id:
                for img_path in self._get_image_files(flower_path):
                    self.image_paths.append(img_path)
                    self.labels.append(self.label_to_id[flower_type])
        print(
            f"Loaded {len(self.image_paths)} images from {len(set(self.labels))} flower types"
        )
        print(f"Flower types: {self.flower_labels}")

    def _get_image_files(self, directory):
        """Get all supported image files from directory.

        Returns a sorted, de-duplicated list. De-duplication matters on
        case-insensitive filesystems (macOS/Windows), where the lower- and
        upper-case glob patterns would otherwise match the same file twice;
        sorting makes dataset order deterministic (glob order is
        filesystem-dependent).
        """
        extensions = ["*.jpg", "*.jpeg", "*.png", "*.webp"]
        image_files = set()
        for ext in extensions:
            image_files.update(glob.glob(os.path.join(directory, ext)))
            image_files.update(glob.glob(os.path.join(directory, ext.upper())))
        return sorted(image_files)

    def __len__(self):
        """Number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, preprocess, and return one sample.

        Returns:
            dict with ``pixel_values`` (C, H, W tensor) and ``labels``
            (scalar long tensor).
        """
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]
        # Process image for ConvNeXt
        inputs = self.processor(images=image, return_tensors="pt")
        return {
            # squeeze(0) removes only the batch dim added by the processor;
            # a bare squeeze() would also collapse H or W if either were 1.
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }
def simple_collate_fn(batch):
    """Simple collation function for training.

    Stacks per-sample ``pixel_values`` and ``labels`` tensors into batched
    tensors. Every sample must provide both keys.
    """
    stacked_pixels = torch.stack([sample["pixel_values"] for sample in batch])
    stacked_labels = torch.stack([sample["labels"] for sample in batch])
    return {"pixel_values": stacked_pixels, "labels": stacked_labels}
def advanced_collate_fn(batch):
    """Advanced collation function for Trainer.

    ``pixel_values`` is required on every sample; ``labels`` is optional so
    the same collator works for unlabeled (inference) batches — the key is
    only present in the output when at least one sample carried it.
    """
    pixels = []
    gathered_labels = []
    for sample in batch:
        pixels.append(sample["pixel_values"])
        if "labels" in sample:
            gathered_labels.append(sample["labels"])
    collated = {"pixel_values": torch.stack(pixels)}
    if gathered_labels:
        collated["labels"] = torch.stack(gathered_labels)
    return collated
|