"""Flower Dataset class for training ConvNeXt models."""

import glob
import os

import torch
from PIL import Image
from torch.utils.data import Dataset


class FlowerDataset(Dataset):
    """Flower images arranged one class per subdirectory of ``image_dir``.

    If ``flower_labels`` is not given, class names are auto-detected from the
    subdirectories that contain at least one supported image file.
    """

    def __init__(self, image_dir, processor, flower_labels=None):
        self.image_paths = []
        self.labels = []
        self.processor = processor

        if flower_labels is None:
            # Auto-detect classes: any subdirectory with at least one
            # supported image file counts as a flower type.
            detected_types = []
            for item in os.listdir(image_dir):
                item_path = os.path.join(image_dir, item)
                if os.path.isdir(item_path):
                    image_files = self._get_image_files(item_path)
                    if image_files:
                        detected_types.append(item)
            # Sort so label ids stay stable across runs.
            self.flower_labels = sorted(detected_types)
        else:
            self.flower_labels = flower_labels

        self.label_to_id = {label: idx for idx, label in enumerate(self.flower_labels)}

        # Collect every image path together with its numeric label.
        for flower_type in os.listdir(image_dir):
            flower_path = os.path.join(image_dir, flower_type)
            if os.path.isdir(flower_path) and flower_type in self.label_to_id:
                image_files = self._get_image_files(flower_path)

                for img_path in image_files:
                    self.image_paths.append(img_path)
                    self.labels.append(self.label_to_id[flower_type])

        print(
            f"Loaded {len(self.image_paths)} images from {len(set(self.labels))} flower types"
        )
        print(f"Flower types: {self.flower_labels}")

    def _get_image_files(self, directory):
        """Get all supported image files from directory."""
        extensions = ["*.jpg", "*.jpeg", "*.png", "*.webp"]
        image_files = []
        for ext in extensions:
            # Glob both lower- and upper-case variants so files like
            # IMG_1.JPG are found on case-sensitive filesystems.
            image_files.extend(glob.glob(os.path.join(directory, ext)))
            image_files.extend(glob.glob(os.path.join(directory, ext.upper())))
        return image_files

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]

        # The processor resizes and normalizes the image and returns a batch
        # of size 1; squeeze(0) drops only that batch dimension so the
        # DataLoader can stack samples itself.
        inputs = self.processor(images=image, return_tensors="pt")

        return {
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }


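# Usage sketch (illustrative, not taken from this module): construct the
# dataset with a Hugging Face image processor. The checkpoint name and the
# "data/flowers" directory below are placeholder assumptions.
#
#     from transformers import AutoImageProcessor
#
#     processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
#     dataset = FlowerDataset("data/flowers", processor)
#     item = dataset[0]
#     item["pixel_values"].shape  # e.g. torch.Size([3, 224, 224]) for a 224px processor
#     item["labels"]              # tensor(<class id>)

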
def simple_collate_fn(batch):
    """Simple collation function for training."""
    pixel_values = []
    labels = []

    for item in batch:
        pixel_values.append(item["pixel_values"])
        labels.append(item["labels"])

    return {"pixel_values": torch.stack(pixel_values), "labels": torch.stack(labels)}


def advanced_collate_fn(batch):
    """Collation function for the Hugging Face Trainer.

    Unlike ``simple_collate_fn``, labels are optional here, so the same
    function also works for unlabeled (inference-time) batches.
    """
    pixel_values = [item["pixel_values"] for item in batch]
    labels = [item["labels"] for item in batch if "labels" in item]

    result = {"pixel_values": torch.stack(pixel_values)}

    if labels:
        result["labels"] = torch.stack(labels)

    return result


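# Usage sketch (illustrative, not taken from this module): both collate
# functions plug into standard training loops. `dataset`, `model`, and `args`
# below are placeholder assumptions.
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(dataset, batch_size=16, shuffle=True,
#                         collate_fn=simple_collate_fn)
#
#     from transformers import Trainer
#     trainer = Trainer(model=model, args=args, train_dataset=dataset,
#                       data_collator=advanced_collate_fn)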