File size: 3,322 Bytes
bed1967
 
 
 
 
b24c04f
bed1967
b24c04f
bed1967
 
 
 
 
 
 
 
 
 
b24c04f
bed1967
 
 
 
 
 
 
 
 
 
 
 
b24c04f
bed1967
b24c04f
bed1967
 
 
 
 
b24c04f
bed1967
 
 
b24c04f
 
 
 
bed1967
b24c04f
bed1967
 
 
 
 
 
 
 
b24c04f
bed1967
 
b24c04f
bed1967
 
 
 
b24c04f
bed1967
 
b24c04f
bed1967
b24c04f
 
bed1967
 
 
 
 
 
 
b24c04f
bed1967
b24c04f
 
 
 
bed1967
 
 
 
 
b24c04f
 
 
bed1967
b24c04f
 
bed1967
b24c04f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python3
"""
Flower Dataset class for training ConvNeXt models.
"""

import glob
import os

import torch
from PIL import Image
from torch.utils.data import Dataset


class FlowerDataset(Dataset):
    """Image-classification dataset laid out as one sub-directory per flower type.

    Every sub-directory of ``image_dir`` whose name is a known label contributes
    its image files as samples of that label.  Directories and files are visited
    in sorted order so that sample indices are reproducible across runs and
    filesystems.

    Args:
        image_dir: Root directory containing one sub-directory per flower type.
        processor: Callable invoked as ``processor(images=img, return_tensors="pt")``
            that returns a mapping with a ``"pixel_values"`` tensor.
        flower_labels: Optional sequence of label names; a label's position is
            its class id.  When ``None``, the labels are the sorted names of
            all sub-directories that contain at least one image.

    Raises:
        OSError: If ``image_dir`` or one of the scanned sub-directories cannot
            be listed.
    """

    # Lower-cased file suffixes that are treated as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")

    def __init__(self, image_dir, processor, flower_labels=None):
        self.image_paths = []
        self.labels = []
        self.processor = processor

        # Sort once: os.listdir order is arbitrary and would otherwise make the
        # sample order differ between machines.
        subdirs = sorted(
            name
            for name in os.listdir(image_dir)
            if os.path.isdir(os.path.join(image_dir, name))
        )

        # Cache of directory scans so auto-detection does not read each
        # directory twice.
        scanned = {}

        # Auto-detect flower types from directory structure if not provided
        if flower_labels is None:
            for name in subdirs:
                scanned[name] = self._get_image_files(os.path.join(image_dir, name))
            # Only keep directories that actually contain images.
            self.flower_labels = [name for name in subdirs if scanned[name]]
        else:
            self.flower_labels = flower_labels

        self.label_to_id = {label: idx for idx, label in enumerate(self.flower_labels)}

        # Load images from subdirectories (organized by flower type)
        for flower_type in subdirs:
            if flower_type not in self.label_to_id:
                continue
            image_files = scanned.get(flower_type)
            if image_files is None:
                image_files = self._get_image_files(
                    os.path.join(image_dir, flower_type)
                )
            label_id = self.label_to_id[flower_type]
            self.image_paths.extend(image_files)
            self.labels.extend([label_id] * len(image_files))

        print(
            f"Loaded {len(self.image_paths)} images from {len(set(self.labels))} flower types"
        )
        print(f"Flower types: {self.flower_labels}")

    def _get_image_files(self, directory):
        """Return the sorted paths of all supported image files in ``directory``.

        The suffix comparison is case-insensitive, each file is reported exactly
        once, and hidden files (names starting with ``.``) are skipped.  The
        search is not recursive.
        """
        image_files = []
        with os.scandir(directory) as entries:
            for entry in entries:
                if entry.name.startswith("."):
                    continue
                suffix = os.path.splitext(entry.name)[1].lower()
                if suffix in self.IMAGE_EXTENSIONS and entry.is_file():
                    image_files.append(entry.path)
        image_files.sort()
        return image_files

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, preprocess and return the sample at position ``idx``.

        Returns:
            A dict with ``"pixel_values"`` (the processor output with its
            leading axis squeezed away) and ``"labels"`` (a 0-dim ``torch.long``
            tensor holding the class id).
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # The context manager closes the underlying file handle; convert()
        # returns an independent, fully loaded copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Process image for ConvNeXt
        inputs = self.processor(images=image, return_tensors="pt")

        return {
            # Drop only the leading axis (assumed to be a batch of one) so that
            # other size-1 axes are never removed by accident.
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }


def simple_collate_fn(batch):
    """Collate a list of samples into one batch for training.

    Each sample must be a mapping with ``"pixel_values"`` and ``"labels"``
    tensors; the tensors are stacked along a new leading dimension.
    """
    # Read both fields of a sample together, in sample order.
    pairs = [(sample["pixel_values"], sample["labels"]) for sample in batch]

    stacked_images = torch.stack([image for image, _ in pairs])
    stacked_targets = torch.stack([target for _, target in pairs])

    return {"pixel_values": stacked_images, "labels": stacked_targets}


def advanced_collate_fn(batch):
    """Collate samples for the Trainer, with labels being optional.

    ``"pixel_values"`` is required on every sample and is always stacked.
    ``"labels"`` is stacked from the samples that carry it; when no sample
    has a label the key is left out of the result.
    """
    images = [sample["pixel_values"] for sample in batch]
    targets = [sample["labels"] for sample in batch if "labels" in sample]

    collated = {"pixel_values": torch.stack(images)}
    if len(targets) > 0:
        collated["labels"] = torch.stack(targets)
    return collated