Toy (Claude) committed · Commit b24c04f · 1 Parent(s): ef5274f
Apply code formatting and fix compatibility issues
- Update pydantic dependencies for HF Spaces compatibility
- Apply ruff formatting across codebase
- Fix import organization and type annotations
- Ensure SDXL-only model architecture
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <[email protected]>
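
The formatting bullets above can be reproduced locally with the ruff CLI; this is a sketch under the assumption that ruff is installed in the project environment (the exact ruff version and rule selection used for this commit are not recorded on this page):

    ruff check --fix .   # lint autofixes, including import sorting when the isort "I" rules are enabled
    ruff format .        # apply ruff's code formatter across the codebase

The pydantic compatibility bullet corresponds to the pydantic-core pin shown in the requirements.txt diff below.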
- app.py +18 -14
- app_original.py +250 -124
- requirements.txt +1 -1
- src/__init__.py +1 -1
- src/core/__init__.py +1 -1
- src/core/config.py +12 -8
- src/core/constants.py +2 -2
- src/services/__init__.py +1 -1
- src/services/models/__init__.py +1 -1
- src/services/models/flower_classification.py +70 -41
- src/services/models/image_generation.py +18 -15
- src/services/training/__init__.py +1 -1
- src/services/training/dataset.py +22 -17
- src/services/training/training_service.py +17 -13
- src/training/__init__.py +1 -1
- src/training/simple_train.py +56 -43
- src/ui/__init__.py +1 -1
- src/ui/french_style/__init__.py +1 -1
- src/ui/french_style/french_style_tab.py +59 -47
- src/ui/generate/__init__.py +1 -1
- src/ui/generate/generate_tab.py +25 -19
- src/ui/identify/__init__.py +1 -1
- src/ui/identify/identify_tab.py +45 -25
- src/ui/train/__init__.py +1 -1
- src/ui/train/train_tab.py +40 -32
- src/utils/__init__.py +1 -1
- src/utils/color_utils.py +14 -10
- src/utils/file_utils.py +25 -18
- test_external_cache.py +27 -23
- tests/__init__.py +1 -1
- tests/test_models.py +42 -23
- training/advanced_trainer.py +74 -38
- training/dataset.py +30 -32
- training/simple_trainer.py +83 -50
app.py
CHANGED
@@ -20,6 +20,7 @@ if src_path not in sys.path:
 
 # Initialize config early to setup cache paths before model imports
 from core.config import config
+
 print(f"🔧 Environment: {'HF Spaces' if config.is_hf_spaces else 'Local'}")
 print(f"🔧 Device: {config.device}, dtype: {config.dtype}")
 
@@ -28,64 +29,66 @@ from ui.generate.generate_tab import GenerateTab
 from ui.identify.identify_tab import IdentifyTab
 from ui.train.train_tab import TrainTab
 
+
 class FlowerifyApp:
     """Main application class for Flowerify."""
-
+
     def __init__(self):
         self.generate_tab = GenerateTab()
         self.identify_tab = IdentifyTab()
         self.train_tab = TrainTab()
         self.french_style_tab = FrenchStyleTab()
-
+
     def create_interface(self) -> gr.Blocks:
         """Create the main Gradio interface."""
         with gr.Blocks(title="🌸 Flowerify - AI Flower Generator & Identifier") as demo:
             gr.Markdown("# 🌸 Flowerfy β Text β Image + Flower Identifier")
-
+
             with gr.Tabs():
                 # Create each tab
                 generate_tab = self.generate_tab.create_ui()
                 identify_tab = self.identify_tab.create_ui()
                 train_tab = self.train_tab.create_ui()
                 french_style_tab = self.french_style_tab.create_ui()
-
+
             # Wire cross-tab interactions
             self._setup_cross_tab_interactions()
-
+
             # Initialize data on load
             demo.load(
                 self.train_tab._count_training_images,
-                outputs=[self.train_tab.data_status]
+                outputs=[self.train_tab.data_status],
             )
-
+
         return demo
-
+
     def _setup_cross_tab_interactions(self):
         """Setup interactions between tabs."""
         # Auto-send generated image to Identify tab
         self.generate_tab.output_image.change(
            self.identify_tab.set_image,
            inputs=self.generate_tab.output_image,
-           outputs=self.identify_tab.image_input
+           outputs=self.identify_tab.image_input,
         )
-
+
     def launch(self, **kwargs):
         """Launch the application."""
         demo = self.create_interface()
         # Add share=True for HF Spaces compatibility
         if config.is_hf_spaces:
-            kwargs.setdefault(
+            kwargs.setdefault("share", True)
         return demo.queue().launch(**kwargs)
 
+
 def main():
     """Main entry point."""
     try:
         print("🌸 Starting Flowerify (SDXL models)")
         print("Loading models and initializing UI...")
-
+
         app = FlowerifyApp()
         app.launch()
-
+
     except KeyboardInterrupt:
         print("\nπ Application stopped by user")
     except Exception as e:
@@ -93,5 +96,6 @@ def main():
         traceback.print_exc()
         sys.exit(1)
 
+
 if __name__ == "__main__":
-    main()
+    main()
app_original.py
CHANGED
@@ -1,13 +1,19 @@
-import os, torch, gradio as gr, json
-from diffusers import AutoPipelineForText2Image
-from transformers import pipeline, ConvNextImageProcessor, ConvNextForImageClassification, AutoImageProcessor, AutoModelForImageClassification
-from simple_train import simple_train
 import glob
-
-
+import os
+
+import gradio as gr
 import numpy as np
+import torch
+from diffusers import AutoPipelineForText2Image
+from simple_train import simple_train
 from sklearn.cluster import KMeans
-
+from transformers import (
+    AutoImageProcessor,
+    AutoModelForImageClassification,
+    ConvNextForImageClassification,
+    ConvNextImageProcessor,
+    pipeline,
+)
 
 MODEL_ID = os.getenv("MODEL_ID", "stabilityai/sdxl-turbo")
 
@@ -23,6 +29,7 @@ if device == "cuda":
 else:
     pipe.enable_attention_slicing()
 
+
 def generate(prompt, steps, width, height, seed):
     if seed is None or int(seed) < 0:
         generator = None
@@ -32,23 +39,49 @@ def generate(prompt, steps, width, height, seed):
     result = pipe(
         prompt=prompt,
         num_inference_steps=int(steps),
-        guidance_scale=0.0,
+        guidance_scale=0.0,  # SDXL-Turbo works best at 0.0
         width=int(width // 8) * 8,
         height=int(height // 8) * 8,
-        generator=generator
+        generator=generator,
     )
     return result.images[0]
 
 
-
 # ---------- Flower identification (zero-shot) ----------
 # Curated label set; edit/extend as you like
 FLOWER_LABELS = [
-    "rose",
-    "
-    "
-    "
-    "
+    "rose",
+    "tulip",
+    "lily",
+    "peony",
+    "sunflower",
+    "chrysanthemum",
+    "carnation",
+    "orchid",
+    "hydrangea",
+    "daisy",
+    "dahlia",
+    "ranunculus",
+    "anemone",
+    "marigold",
+    "lavender",
+    "magnolia",
+    "gardenia",
+    "camellia",
+    "jasmine",
+    "iris",
+    "gerbera",
+    "zinnia",
+    "hibiscus",
+    "lotus",
+    "poppy",
+    "sweet pea",
+    "freesia",
+    "lisianthus",
+    "calla lily",
+    "cherry blossom",
+    "plumeria",
+    "cosmos",
 ]
 
 # Initialize classifier - will be updated when trained model is loaded
@@ -58,6 +91,7 @@ convnext_model = None
 convnext_processor = None
 current_model_path = "facebook/convnext-base-224-22k"
 
+
 def load_classifier(model_path="facebook/convnext-base-224-22k"):
     global zs_classifier, convnext_model, convnext_processor, current_model_path
     try:
@@ -70,149 +104,174 @@ def load_classifier(model_path="facebook/convnext-base-224-22k"):
             zs_classifier = pipeline(
                 task="zero-shot-image-classification",
                 model="openai/clip-vit-base-patch32",
-                device=clf_device
+                device=clf_device,
             )
             return f"✅ Loaded custom ConvNeXt model from: {model_path}"
         else:
             # Load default ConvNeXt model for feature extraction and fallback to CLIP for zero-shot
-            convnext_model = ConvNextForImageClassification.from_pretrained(
-
+            convnext_model = ConvNextForImageClassification.from_pretrained(
+                "facebook/convnext-base-224-22k"
+            )
+            convnext_processor = ConvNextImageProcessor.from_pretrained(
+                "facebook/convnext-base-224-22k"
+            )
             zs_classifier = pipeline(
                 task="zero-shot-image-classification",
                 model="openai/clip-vit-base-patch32",
-                device=clf_device
+                device=clf_device,
             )
             current_model_path = "facebook/convnext-base-224-22k"
-            return
+            return "✅ Loaded default ConvNeXt model: facebook/convnext-base-224-22k"
     except Exception as e:
-        return f"❌ Error loading model: {str(e)}"
+        return f"❌ Error loading model: {e!s}"
+
 
 # Initialize with default model
 load_classifier()
 
+
 def identify_flowers(image, candidate_labels, top_k, min_score):
     if image is None:
         return [], "Please provide an image (upload or generate first)."
-
+
     labels = candidate_labels if candidate_labels else FLOWER_LABELS
-
+
     # Use ConvNeXt for feature extraction if we have a trained model, otherwise fallback to CLIP
-    if convnext_model is not None and os.path.exists(current_model_path) and current_model_path != "facebook/convnext-base-224-22k":
+    if (
+        convnext_model is not None
+        and os.path.exists(current_model_path)
+        and current_model_path != "facebook/convnext-base-224-22k"
+    ):
         try:
             # Use trained ConvNeXt model
             inputs = convnext_processor(images=image, return_tensors="pt")
             with torch.no_grad():
                 outputs = convnext_model(**inputs)
                 predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
-
+
            # Convert predictions to results format
            results = []
            for i, score in enumerate(predictions[0]):
                if i < len(labels):
                    results.append({"label": labels[i], "score": float(score)})
-
+
            # Sort by score
            results = sorted(results, key=lambda r: r["score"], reverse=True)
-        except Exception as e:
+        except Exception:
            # Fallback to CLIP zero-shot
            results = zs_classifier(
-                image,
-                candidate_labels=labels,
-                hypothesis_template="a photo of a {}"
+                image, candidate_labels=labels, hypothesis_template="a photo of a {}"
            )
    else:
        # Use CLIP zero-shot classification
        results = zs_classifier(
-            image,
-            candidate_labels=labels,
-            hypothesis_template="a photo of a {}"
+            image, candidate_labels=labels, hypothesis_template="a photo of a {}"
        )
-
    # Filter and format results
    results = [r for r in results if r["score"] >= float(min_score)]
-    results = sorted(results, key=lambda r: r["score"], reverse=True)[:int(top_k)]
+    results = sorted(results, key=lambda r: r["score"], reverse=True)[: int(top_k)]
    table = [[r["label"], round(float(r["score"]), 4)] for r in results]
-    model_type = "ConvNeXt" if convnext_model is not None and os.path.exists(current_model_path) and current_model_path != "facebook/convnext-base-224-22k" else "CLIP zero-shot"
+    model_type = (
+        "ConvNeXt"
+        if (
+            convnext_model is not None
+            and os.path.exists(current_model_path)
+            and current_model_path != "facebook/convnext-base-224-22k"
+        )
+        else "CLIP zero-shot"
+    )
    msg = f"Detected flowers using {model_type}."
    return table, msg
 
+
 # simple passthrough so the generated image appears in the Identify tab automatically
 def passthrough(img):
     return img
 
+
 # Training functions
 def get_available_models():
     models_dir = "training_data/trained_models"
     if not os.path.exists(models_dir):
         return ["facebook/convnext-base-224-22k (default)"]
-
+
     models = ["facebook/convnext-base-224-22k (default)"]
     for item in os.listdir(models_dir):
         model_path = os.path.join(models_dir, item)
-        if os.path.isdir(model_path) and os.path.exists(os.path.join(model_path, "config.json")):
+        if os.path.isdir(model_path) and os.path.exists(
+            os.path.join(model_path, "config.json")
+        ):
            models.append(f"Custom: {item}")
    return models
 
+
 def count_training_images():
     images_dir = "training_data/images"
     if not os.path.exists(images_dir):
         return "Training directory not found"
-
+
     total_images = 0
     flower_counts = {}
-
+
     for flower_type in os.listdir(images_dir):
         flower_path = os.path.join(images_dir, flower_type)
         if os.path.isdir(flower_path):
-            image_files =
-
-
-
+            image_files = (
+                glob.glob(os.path.join(flower_path, "*.jpg"))
+                + glob.glob(os.path.join(flower_path, "*.jpeg"))
+                + glob.glob(os.path.join(flower_path, "*.png"))
+                + glob.glob(os.path.join(flower_path, "*.webp"))
+            )
            count = len(image_files)
            if count > 0:
                flower_counts[flower_type] = count
                total_images += count
-
+
     if total_images == 0:
         return "No training images found. Add images to subdirectories in training_data/images/"
-
+
     result = f"**Total images: {total_images}**\n\n"
     for flower_type, count in sorted(flower_counts.items()):
         result += f"- {flower_type}: {count} images\n"
-
+
     return result
 
+
 def start_training(epochs=None, batch_size=None, learning_rate=None):
     try:
         # Check if training data exists
         images_dir = "training_data/images"
         if not os.path.exists(images_dir):
             return "❌ Training directory not found. Please create training_data/images/ and add your data."
-
+
         # Count images
         total_images = 0
         for flower_type in os.listdir(images_dir):
             flower_path = os.path.join(images_dir, flower_type)
             if os.path.isdir(flower_path):
-                image_files =
-
-
-
+                image_files = (
+                    glob.glob(os.path.join(flower_path, "*.jpg"))
+                    + glob.glob(os.path.join(flower_path, "*.jpeg"))
+                    + glob.glob(os.path.join(flower_path, "*.png"))
+                    + glob.glob(os.path.join(flower_path, "*.webp"))
+                )
                total_images += len(image_files)
-
+
         if total_images < 10:
             return f"❌ Need at least 10 training images. Found {total_images}. Add more images to training_data/images/"
-
+
         # Start training
         model_path = simple_train()
-
+
         if model_path:
             return f"✅ Training completed! Model saved to: {model_path}"
         else:
             return "❌ Training failed. Check the console for details."
-
+
     except Exception as e:
-        return f"❌ Training error: {str(e)}"
+        return f"❌ Training error: {e!s}"
+
 
 def load_trained_model(model_selection):
     if model_selection.startswith("Custom:"):
@@ -222,25 +281,26 @@ def load_trained_model(model_selection):
     else:
         return load_classifier("facebook/convnext-base-224-22k")
 
+
 # French-style arrangement functions
 def extract_dominant_colors(image, num_colors=5):
     """Extract dominant colors from an image using k-means clustering"""
     if image is None:
         return [], "No image provided"
-
+
     # Convert PIL image to numpy array
     img_array = np.array(image)
-
+
     # Reshape image to be a list of pixels
     pixels = img_array.reshape(-1, 3)
-
+
     # Use k-means to find dominant colors
     kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
     kmeans.fit(pixels)
-
+
     # Get the colors and convert to RGB values
     colors = kmeans.cluster_centers_.astype(int)
-
+
     # Convert to color names/descriptions
     color_names = []
     for color in colors:
@@ -268,54 +328,59 @@ def extract_dominant_colors(image, num_colors=5):
             color_names.append("orange")
         else:
             color_names.append("cream")
-
+
     return color_names, colors
 
+
 def analyze_and_generate_french_style(image):
     """Analyze uploaded flower image and generate French-style arrangement"""
     if image is None:
         return None, "Please upload an image", ""
-
+
     # Identify the flower type
     if zs_classifier is None:
         return None, "Model not loaded", ""
-
+
     try:
         progress_log = "π **Step 1/4:** Starting flower analysis...\n\n"
-
+
         # Identify flower
         progress_log += "π Identifying flower type using AI model...\n"
         results = zs_classifier(
-            image,
-            candidate_labels=FLOWER_LABELS,
-            hypothesis_template="a photo of a {}"
+            image, candidate_labels=FLOWER_LABELS, hypothesis_template="a photo of a {}"
        )
-
+
        top_flower = results[0]["label"] if results else "flower"
        confidence = results[0]["score"] if results else 0
-        progress_log +=
-
+        progress_log += (
+            f"✅ Identified: **{top_flower}** (confidence: {confidence:.2%})\n\n"
+        )
+
        # Extract dominant colors
        progress_log += "π **Step 2/4:** Analyzing color palette...\n\n"
        progress_log += "π¨ Extracting dominant colors from image...\n"
        color_names, color_rgb = extract_dominant_colors(image, num_colors=3)
-
+
        # Create color description
        main_colors = color_names[:3]  # Top 3 colors
        color_desc = ", ".join(main_colors)
        progress_log += f"✅ Color palette: **{color_desc}**\n\n"
-
+
        # Generate French-style prompt
-        progress_log += "π **Step 3/4:** Creating French-style arrangement prompt...\n\n"
+        progress_log += (
+            "π **Step 3/4:** Creating French-style arrangement prompt...\n\n"
+        )
        prompt = f"elegant French-style floral arrangement featuring {top_flower}s in {color_desc} colors, displayed in a clear crystal vase on a marble kitchen countertop, soft natural lighting, minimalist French country kitchen background, professional photography, sophisticated composition"
        progress_log += f"✅ Prompt created: *{prompt[:100]}...*\n\n"
-
+
        # Generate the image
-        progress_log += "π **Step 4/4:** Generating French-style arrangement image...\n\n"
+        progress_log += (
+            "π **Step 4/4:** Generating French-style arrangement image...\n\n"
+        )
        progress_log += "πΌοΈ Using AI image generation (SDXL-Turbo)...\n"
        generated_image = generate(prompt, steps=4, width=1024, height=1024, seed=-1)
        progress_log += "✅ French-style arrangement generated successfully!\n\n"
-
+
        # Create analysis summary
        analysis = f"""
 **🌸 Flower Analysis:**
@@ -330,14 +395,19 @@ def analyze_and_generate_french_style(image):
 **π Process Log:**
 {progress_log}
 """
-
-        return
-
+
+        return (
+            generated_image,
+            "✅ Analysis complete! French-style arrangement generated.",
+            analysis,
+        )
+
     except Exception as e:
-        error_log = f"❌ **Error occurred during processing:**\n\n{str(e)}\n\n"
-        if 'progress_log' in locals():
+        error_log = f"❌ **Error occurred during processing:**\n\n{e!s}\n\n"
+        if "progress_log" in locals():
            error_log += f"**Progress before error:**\n{progress_log}"
-        return None, f"❌ Error: {str(e)}", error_log
+        return None, f"❌ Error: {e!s}", error_log
+
 
 # ---------- UI ----------
 with gr.Blocks() as demo:
@@ -347,66 +417,113 @@ with gr.Blocks() as demo:
         with gr.TabItem("Generate"):
             with gr.Row():
                 with gr.Column():
-                    prompt = gr.Textbox(
-
-
+                    prompt = gr.Textbox(
+                        value="ikebana-style flower arrangement, soft natural light, minimalist",
+                        label="Prompt",
+                    )
+                    steps = gr.Slider(1, 8, value=4, step=1, label="Steps")
+                    width = gr.Slider(512, 1536, value=1024, step=8, label="Width")
                     height = gr.Slider(512, 1536, value=1024, step=8, label="Height")
-                    seed
-                    go
+                    seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
+                    go = gr.Button("Generate", variant="primary")
                     out = gr.Image(label="Result", type="pil")
 
         with gr.TabItem("Identify"):
             with gr.Row():
                 with gr.Column():
-                    img_in = gr.Image(
-
+                    img_in = gr.Image(
+                        label="Image (upload or auto-filled from 'Generate')",
+                        type="pil",
+                        interactive=True,
+                    )
+                    labels_box = gr.CheckboxGroup(
+                        choices=FLOWER_LABELS,
+                        value=[
+                            "rose",
+                            "tulip",
+                            "lily",
+                            "peony",
+                            "hydrangea",
+                            "orchid",
+                            "sunflower",
+                        ],
+                        label="Candidate labels (edit as needed)",
+                    )
                     topk = gr.Slider(1, 15, value=7, step=1, label="Top-K")
-                    min_score = gr.Slider(0.0, 1.0, value=0.12, step=0.01, label="Min confidence")
+                    min_score = gr.Slider(
+                        0.0, 1.0, value=0.12, step=0.01, label="Min confidence"
+                    )
                    detect_btn = gr.Button("Identify Flowers", variant="primary")
                with gr.Column():
-                    results_tbl = gr.Dataframe(headers=["Flower", "Confidence"], datatype=["str", "number"], interactive=False)
+                    results_tbl = gr.Dataframe(
+                        headers=["Flower", "Confidence"],
+                        datatype=["str", "number"],
+                        interactive=False,
+                    )
                    status = gr.Markdown()
 
         with gr.TabItem("Train Model"):
             gr.Markdown("## π― Fine-tune the flower identification model")
-            gr.Markdown(
-
-
+            gr.Markdown(
+                "Organize your training images in subdirectories by flower type in `training_data/images/`"
+            )
+            gr.Markdown(
+                "Example: `training_data/images/roses/`, `training_data/images/tulips/`, etc."
+            )
+
             with gr.Row():
                 with gr.Column():
                     gr.Markdown("### Training Data")
                     refresh_btn = gr.Button("π Refresh Data Count", size="sm")
                     data_status = gr.Markdown()
-
+
                     gr.Markdown("### Training Parameters")
                     epochs = gr.Slider(1, 20, value=5, step=1, label="Training Epochs")
                     batch_size = gr.Slider(1, 16, value=8, step=1, label="Batch Size")
-                    learning_rate = gr.Number(
-
+                    learning_rate = gr.Number(
+                        value=1e-5, label="Learning Rate", precision=6
+                    )
+
                    train_btn = gr.Button("π Start Training", variant="primary")
-
+
                with gr.Column():
                    gr.Markdown("### Model Management")
-                    model_dropdown = gr.Dropdown(choices=get_available_models(), value="facebook/convnext-base-224-22k (default)", label="Select Model")
+                    model_dropdown = gr.Dropdown(
+                        choices=get_available_models(),
+                        value="facebook/convnext-base-224-22k (default)",
+                        label="Select Model",
+                    )
                    refresh_models_btn = gr.Button("π Refresh Models", size="sm")
-                    load_model_btn = gr.Button(
-
-
-
+                    load_model_btn = gr.Button(
+                        "π₯ Load Selected Model", variant="secondary"
+                    )
+
+                    model_status = gr.Markdown(
+                        f"**Current model:** {current_model_path}"
+                    )
+
                    gr.Markdown("### Training Status")
                    training_output = gr.Markdown()
 
         with gr.TabItem("French Style arrangement"):
             gr.Markdown("## π«π· French-Style Flower Arrangements")
-            gr.Markdown(
-
+            gr.Markdown(
+                "Upload a flower image and generate an elegant French-style arrangement with matching colors!"
+            )
+
             with gr.Row():
                 with gr.Column():
                     upload_img = gr.Image(label="Upload Flower Image", type="pil")
-                    analyze_btn = gr.Button(
-
+                    analyze_btn = gr.Button(
+                        "π¨ Analyze & Generate French Style",
+                        variant="primary",
+                        size="lg",
+                    )
+
                 with gr.Column():
-                    french_result = gr.Image(label="Generated French-Style Arrangement", type="pil")
+                    french_result = gr.Image(
+                        label="Generated French-Style Arrangement", type="pil"
+                    )
                    french_status = gr.Markdown()
                    analysis_details = gr.Markdown()
 
@@ -415,28 +532,37 @@
     # Auto-send generated image to Identify tab
     out.change(passthrough, inputs=out, outputs=img_in)
     # Run identification
-    detect_btn.click(
-
+    detect_btn.click(
+        identify_flowers, [img_in, labels_box, topk, min_score], [results_tbl, status]
+    )
+
     # Training tab events
     refresh_btn.click(count_training_images, outputs=[data_status])
-    refresh_models_btn.click(
-
-
-
+    refresh_models_btn.click(
+        lambda: gr.Dropdown(choices=get_available_models()), outputs=[model_dropdown]
+    )
+    load_model_btn.click(
+        load_trained_model, inputs=[model_dropdown], outputs=[model_status]
+    )
+    train_btn.click(
+        start_training,
+        inputs=[epochs, batch_size, learning_rate],
+        outputs=[training_output],
+    )
+
     # French Style tab events - update status during processing
     def update_french_status():
         return "π Processing... Please wait while we analyze your flower image...", ""
-
+
     analyze_btn.click(
-        update_french_status,
-        outputs=[french_status, analysis_details]
+        update_french_status, outputs=[french_status, analysis_details]
     ).then(
-        analyze_and_generate_french_style,
-        inputs=[upload_img],
-        outputs=[french_result, french_status, analysis_details]
+        analyze_and_generate_french_style,
+        inputs=[upload_img],
+        outputs=[french_result, french_status, analysis_details],
     )
-
+
     # Initialize data count on load
     demo.load(count_training_images, outputs=[data_status])
 
-demo.queue().launch()
+demo.queue().launch()
requirements.txt
CHANGED
@@ -234,7 +234,7 @@ pydantic==2.10.6
     # via
     #   fastapi
     #   gradio
-pydantic-core==2.27.
+pydantic-core==2.27.2
     # via pydantic
 pydub==0.25.1
     # via gradio
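
A quick way to confirm that the pins above are what actually got installed in the Space (a sketch, not part of the commit; assumes pydantic 2.10.6 and pydantic-core 2.27.2 as pinned):

    python -c "import pydantic, pydantic_core; print(pydantic.VERSION, pydantic_core.__version__)"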
src/__init__.py
CHANGED
@@ -1 +1 @@
-# Flowerify application package
+# Flowerify application package
src/core/__init__.py
CHANGED
@@ -1 +1 @@
-# Core package
+# Core package
src/core/config.py
CHANGED
@@ -3,26 +3,29 @@ Configuration management for the application.
 """
 
 import os
+
 import torch
+
 from .constants import DEFAULT_MODEL_ID
 
+
 class AppConfig:
     """Application configuration singleton."""
-
+
     def __init__(self):
         self._setup_device()
         self.model_id = DEFAULT_MODEL_ID
         # Auto-detect Hugging Face Spaces environment
         self.is_hf_spaces = os.getenv("SPACE_ID") is not None
         self._setup_cache_paths()
-
+
     def _setup_device(self):
         """Setup device configuration for PyTorch."""
         if torch.cuda.is_available():
             self.device = "cuda"
             self.dtype = torch.float16
             self.clf_device = 0
-        elif hasattr(torch.backends,
+        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
             self.device = "mps"
             self.dtype = torch.float16
             self.clf_device = 0
@@ -30,7 +33,7 @@ class AppConfig:
             self.device = "cpu"
             self.dtype = torch.float32
             self.clf_device = -1
-
+
     def _setup_cache_paths(self):
         """Setup cache paths based on environment."""
         if self.is_hf_spaces:
@@ -48,16 +51,17 @@ class AppConfig:
             print(f"π Using configured HF_HOME: {os.getenv('HF_HOME')}")
         else:
             print("π Using default Hugging Face cache")
-
+
     @property
     def is_cuda_available(self):
         """Check if CUDA is available."""
         return torch.cuda.is_available()
-
+
     @property
     def is_mps_available(self):
         """Check if Apple MPS is available."""
-        return hasattr(torch.backends,
+        return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+
 
 # Global configuration instance
-config = AppConfig()
+config = AppConfig()
src/core/constants.py
CHANGED
@@ -6,7 +6,7 @@ import os
 # If using external SSD, models will be cached at /Volumes/extssd/huggingface/hub
 # This is configured via environment variables (see .env file and run.sh script)
 
-# Model configuration
+# Model configuration
 DEFAULT_MODEL_ID = os.getenv("MODEL_ID", "stabilityai/stable-diffusion-xl-base-1.0")
 FALLBACK_MODEL_ID = "stabilityai/sdxl-turbo"  # Lightweight fallback model
 DEFAULT_CONVNEXT_MODEL = "facebook/convnext-tiny-224"
@@ -62,4 +62,4 @@ DEFAULT_MIN_SCORE = 0.12
 DEFAULT_NUM_COLORS = 3
 
 # File extensions for image files
-SUPPORTED_IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".webp"]
+SUPPORTED_IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".webp"]
src/services/__init__.py
CHANGED
@@ -1 +1 @@
-# Services package
+# Services package
src/services/models/__init__.py
CHANGED
@@ -1 +1 @@
-# Models package
+# Models package
src/services/models/flower_classification.py
CHANGED
@@ -3,93 +3,120 @@ Flower classification service using ConvNeXt and CLIP models.
 """
 
 import os
+
 import torch
+from PIL import Image
 from transformers import (
-
-
+    AutoImageProcessor,
+    AutoModelForImageClassification,
+    ConvNextForImageClassification,
+    ConvNextImageProcessor,
+    pipeline,
 )
-from PIL import Image
-from typing import List, Dict, Tuple, Optional
 
 try:
     from core.config import config
-    from core.constants import
+    from core.constants import (
+        DEFAULT_CLIP_MODEL,
+        DEFAULT_CONVNEXT_MODEL,
+        FLOWER_LABELS,
+        MODELS_DIR,
+    )
 except ImportError:
-    import sys
     import os
+    import sys
+
     sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
     from core.config import config
-    from core.constants import
+    from core.constants import (
+        DEFAULT_CLIP_MODEL,
+        DEFAULT_CONVNEXT_MODEL,
+        FLOWER_LABELS,
+        MODELS_DIR,
+    )
+
 
 class FlowerClassificationService:
     """Service for flower classification using ConvNeXt and CLIP models."""
-
+
     def __init__(self):
         self.zs_classifier = None
         self.convnext_model = None
         self.convnext_processor = None
         self.current_model_path = DEFAULT_CONVNEXT_MODEL
         self._initialize_models()
-
+
     def _initialize_models(self):
         """Initialize the classification models."""
         self.load_classifier()
-
+
     def load_classifier(self, model_path: str = DEFAULT_CONVNEXT_MODEL) -> str:
         """Load classification model from path."""
         try:
             if os.path.exists(model_path):
                 # Load custom trained model
-                self.convnext_model = AutoModelForImageClassification.from_pretrained(model_path)
+                self.convnext_model = AutoModelForImageClassification.from_pretrained(
+                    model_path
+                )
                 self.convnext_processor = AutoImageProcessor.from_pretrained(model_path)
                 self.current_model_path = model_path
                 # Also keep zero-shot classifier for fallback
                 self.zs_classifier = pipeline(
                     task="zero-shot-image-classification",
                     model=DEFAULT_CLIP_MODEL,
-                    device=config.clf_device
+                    device=config.clf_device,
                 )
                 return f"✅ Loaded custom ConvNeXt model from: {model_path}"
             else:
                 # Load default ConvNeXt model for feature extraction and fallback to CLIP for zero-shot
-                self.convnext_model = ConvNextForImageClassification.from_pretrained(
-
+                self.convnext_model = ConvNextForImageClassification.from_pretrained(
+                    DEFAULT_CONVNEXT_MODEL
+                )
+                self.convnext_processor = ConvNextImageProcessor.from_pretrained(
+                    DEFAULT_CONVNEXT_MODEL
+                )
                 self.zs_classifier = pipeline(
                     task="zero-shot-image-classification",
                     model=DEFAULT_CLIP_MODEL,
-                    device=config.clf_device
+                    device=config.clf_device,
                 )
                 self.current_model_path = DEFAULT_CONVNEXT_MODEL
                 return f"✅ Loaded default ConvNeXt model: {DEFAULT_CONVNEXT_MODEL}"
         except Exception as e:
-            return f"❌ Error loading model: {str(e)}"
-
-    def identify_flowers(
-
-
+            return f"❌ Error loading model: {e!s}"
+
+    def identify_flowers(
+        self,
+        image: Image.Image | None,
+        candidate_labels: list[str] | None = None,
+        top_k: int = 7,
+        min_score: float = 0.12,
+    ) -> tuple[list[list], str]:
         """Identify flowers in an image."""
         if image is None:
             return [], "Please provide an image (upload or generate first)."
-
+
         labels = candidate_labels if candidate_labels else FLOWER_LABELS
-
+
         # Use ConvNeXt for feature extraction if we have a trained model
-        if (
-
-            self.current_model_path
+        if (
+            self.convnext_model is not None
+            and os.path.exists(self.current_model_path)
+            and self.current_model_path != DEFAULT_CONVNEXT_MODEL
+        ):
            try:
                # Use trained ConvNeXt model
                inputs = self.convnext_processor(images=image, return_tensors="pt")
                with torch.no_grad():
                    outputs = self.convnext_model(**inputs)
                    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
-
+
                # Convert predictions to results format
                results = []
                for i, score in enumerate(predictions[0]):
                    if i < len(labels):
                        results.append({"label": labels[i], "score": float(score)})
-
+
                # Sort by score
                results = sorted(results, key=lambda r: r["score"], reverse=True)
                model_type = "ConvNeXt"
@@ -101,35 +128,36 @@ class FlowerClassificationService:
             # Use CLIP zero-shot classification
             results = self._use_clip_classification(image, labels)
             model_type = "CLIP zero-shot"
-
+
         # Filter and format results
         results = [r for r in results if r["score"] >= min_score]
         results = sorted(results, key=lambda r: r["score"], reverse=True)[:top_k]
         table = [[r["label"], round(float(r["score"]), 4)] for r in results]
         msg = f"Detected flowers using {model_type}."
         return table, msg
-
-    def _use_clip_classification(
+
+    def _use_clip_classification(
+        self, image: Image.Image, labels: list[str]
+    ) -> list[dict]:
         """Use CLIP zero-shot classification."""
         return self.zs_classifier(
-            image,
-            candidate_labels=labels,
-            hypothesis_template="a photo of a {}"
+            image, candidate_labels=labels, hypothesis_template="a photo of a {}"
         )
-
-    def get_available_models(self) ->
+
+    def get_available_models(self) -> list[str]:
         """Get list of available models."""
         models = [f"{DEFAULT_CONVNEXT_MODEL} (default)"]
-
+
         if os.path.exists(MODELS_DIR):
             for item in os.listdir(MODELS_DIR):
                 model_path = os.path.join(MODELS_DIR, item)
-                if
-                    os.path.
+                if os.path.isdir(model_path) and os.path.exists(
+                    os.path.join(model_path, "config.json")
+                ):
                    models.append(f"Custom: {item}")
-
+
         return models
-
+
     def load_trained_model(self, model_selection: str) -> str:
         """Load a specific trained model."""
         if model_selection.startswith("Custom:"):
@@ -139,5 +167,6 @@ class FlowerClassificationService:
         else:
             return self.load_classifier(DEFAULT_CONVNEXT_MODEL)
 
+
 # Global service instance
-flower_classifier = FlowerClassificationService()
+flower_classifier = FlowerClassificationService()
src/services/models/image_generation.py CHANGED
@@ -1,9 +1,8 @@
 """Image generation service using SDXL models."""
 
-from typing import Optional
 
-import torch
 import numpy as np
+import torch
 from diffusers import AutoPipelineForText2Image
 from PIL import Image
 
@@ -18,6 +17,7 @@ except ImportError:
 from core.config import config
 from core.constants import DEFAULT_MODEL_ID, FALLBACK_MODEL_ID
 
+
 class ImageGenerationService:
     """Service for generating images using SDXL models."""
 
@@ -78,17 +78,19 @@
                         return
                     except Exception as turbo_error:
                         print(f"β οΈ SDXL-Turbo also failed to load: {turbo_error}")
+                        raise RuntimeError(
+                            f"All SDXL models failed to load. Last error: {turbo_error}"
+                        )
             else:
                 raise RuntimeError(f"SDXL model failed to load: {e}")
-
+
     def generate(
         self,
         prompt: str,
         steps: int = 4,
         width: int = 1024,
         height: int = 1024,
-        seed: Optional[int] = None,
+        seed: int | None = None,
     ) -> Image.Image:
         """Generate an image from a text prompt."""
         if seed is None or seed < 0:
@@ -123,12 +125,13 @@
             # Ensure values are in valid range [0, 255]
             img_array = np.clip(img_array, 0, 255).astype(np.uint8)
             image = Image.fromarray(img_array)
-
+
         return image
-
+
     def get_model_info(self) -> str:
         """Get information about the currently loaded model."""
         return f"Model: {self.model_type} (Stable Diffusion XL)"
 
+
 # Global service instance
-image_generator = ImageGenerationService()
+image_generator = ImageGenerationService()

(The remaining hunks at lines 35-72 and 112-120 only normalize blank lines and trailing whitespace around the model-loading and image-validation steps.)
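The validation step above guards against NaN/inf pixels before the generated image reaches Gradio. A standalone sketch of that guard, assuming `np.nan_to_num` as the replacement strategy (the exact replacement line is collapsed in this diff):

```python
# Minimal sketch of the post-generation sanity check; replacement strategy assumed.
import numpy as np
from PIL import Image


def clean_image(image: Image.Image) -> Image.Image:
    """Replace NaN/inf pixels and clamp values to [0, 255] before display."""
    arr = np.array(image).astype(np.float32)
    if np.any(np.isnan(arr)) or np.any(np.isinf(arr)):
        # Assumed cleanup: zero out NaN, clamp infinities to the valid range
        arr = np.nan_to_num(arr, nan=0.0, posinf=255.0, neginf=0.0)
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    return Image.fromarray(arr)
```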
src/services/training/__init__.py CHANGED
@@ -1 +1 @@
-# Training package
+# Training package
src/services/training/dataset.py CHANGED
@@ -3,54 +3,59 @@ Dataset class for flower training data.
 """
 
 import os
+
 import torch
 from PIL import Image
 from torch.utils.data import Dataset
-from typing import List, Optional
 
+from utils.file_utils import get_flower_types_from_directory, get_image_files
+
 
 class FlowerDataset(Dataset):
     """Dataset for flower classification training."""
+
+    def __init__(
+        self, image_dir: str, processor, flower_labels: list[str] | None = None
+    ):
         self.image_paths = []
         self.labels = []
         self.processor = processor
+
         # Auto-detect flower types from directory structure if not provided
         if flower_labels is None:
             self.flower_labels = get_flower_types_from_directory(image_dir)
         else:
             self.flower_labels = flower_labels
+
         self.label_to_id = {label: idx for idx, label in enumerate(self.flower_labels)}
+
         # Load images from subdirectories (organized by flower type)
         for flower_type in os.listdir(image_dir):
             flower_path = os.path.join(image_dir, flower_type)
             if os.path.isdir(flower_path) and flower_type in self.label_to_id:
                 image_files = get_image_files(flower_path)
+
                 for img_path in image_files:
                     self.image_paths.append(img_path)
                     self.labels.append(self.label_to_id[flower_type])
+
+        print(
+            f"Loaded {len(self.image_paths)} images from {len(set(self.labels))} flower types"
+        )
         print(f"Flower types: {self.flower_labels}")
+
     def __len__(self):
         return len(self.image_paths)
+
     def __getitem__(self, idx):
         image_path = self.image_paths[idx]
         image = Image.open(image_path).convert("RGB")
         label = self.labels[idx]
+
         # Process image for ConvNeXt
         inputs = self.processor(images=image, return_tensors="pt")
+
         return {
+            "pixel_values": inputs["pixel_values"].squeeze(),
+            "labels": torch.tensor(label, dtype=torch.long),
+        }
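A minimal usage sketch for `FlowerDataset`, assuming the `training_data/images/<flower_type>/` layout it expects, a public ConvNeXt processor, and that `src/` is on `sys.path`:

```python
# Sketch: checkpoint name and directory contents are assumptions, not project constants.
from transformers import ConvNextImageProcessor

from services.training.dataset import FlowerDataset  # assumes src/ is importable

processor = ConvNextImageProcessor.from_pretrained("facebook/convnext-tiny-224")
dataset = FlowerDataset("training_data/images", processor)

sample = dataset[0]
print(sample["pixel_values"].shape)  # e.g. torch.Size([3, 224, 224])
print(sample["labels"])              # class index as a torch.long tensor
```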
src/services/training/training_service.py CHANGED
@@ -3,36 +3,38 @@ Training service for flower classification models.
 """
 
 import os
-from typing import Optional
 
 from core.constants import IMAGES_DIR
 from utils.file_utils import count_training_images
 
+
 class TrainingService:
     """Service for managing model training."""
+
     def __init__(self):
         pass
+
+    def start_training(
+        self, epochs: int = 5, batch_size: int = 8, learning_rate: float = 1e-5
+    ) -> str:
         """Start the training process."""
         try:
             # Check if training data exists
             if not os.path.exists(IMAGES_DIR):
                 return "β Training directory not found. Please create training_data/images/ and add your data."
+
             # Count images
             total_images, _ = count_training_images()
+
             if total_images < 10:
                 return f"β Need at least 10 training images. Found {total_images}. Add more images to training_data/images/"
+
             # Import and run training (lazy import to avoid startup issues)
             try:
                 from training.simple_train import simple_train
+
                 model_path = simple_train()
+
                 if model_path:
                     return f"✅ Training completed! Model saved to: {model_path}"
                 else:
@@ -41,17 +43,19 @@ class TrainingService:
                 # Fallback to old training method
                 try:
                     from simple_train import simple_train as legacy_train
+
                     model_path = legacy_train()
+
                     if model_path:
                         return f"✅ Training completed! Model saved to: {model_path}"
                     else:
                         return "β Training failed. Check the console for details."
                 except ImportError:
                     return "β Training module not found. Please ensure training scripts are available."
+
         except Exception as e:
+            return f"β Training error: {e!s}"
+
 
 # Global service instance
 training_service = TrainingService()
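For reference, this service is what the Train tab calls. A usage sketch outside Gradio, assuming `src/` is on `sys.path`; note that the hyperparameters are accepted for the UI, while the simple trainer below uses its own fixed values:

```python
# Usage sketch (import path assumed); returns a status string rather than raising.
from services.training.training_service import training_service

status = training_service.start_training(epochs=5, batch_size=8, learning_rate=1e-5)
print(status)  # "✅ Training completed! Model saved to: ..." or an error message
```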
src/training/__init__.py CHANGED
@@ -1 +1 @@
-# Training package
+# Training package
src/training/simple_train.py CHANGED
@@ -3,114 +3,125 @@ Simple ConvNeXt training script without using the Transformers Trainer class.
 Refactored version of the original simple_train.py
 """
 
+import json
 import os
+
 import torch
-import torch.nn as nn
 from torch.utils.data import DataLoader
-import json
+from transformers import ConvNextForImageClassification, ConvNextImageProcessor
 
-from ..services.training.dataset import FlowerDataset
 from ..core.config import config
 from ..core.constants import DEFAULT_CONVNEXT_MODEL, MODELS_DIR
+from ..services.training.dataset import FlowerDataset
+
 
 def simple_train():
     """Simple ConvNeXt training function."""
     print("πΈ Simple ConvNeXt Flower Model Training")
     print("=" * 40)
+
     # Check training data
     images_dir = "training_data/images"
     if not os.path.exists(images_dir):
         print("β Training directory not found")
         return
+
     device = config.device
     print(f"Using device: {device}")
+
     # Load model and processor
     model_name = DEFAULT_CONVNEXT_MODEL
     model = ConvNextForImageClassification.from_pretrained(model_name)
     processor = ConvNextImageProcessor.from_pretrained(model_name)
     model.to(device)
+
     # Create dataset
     dataset = FlowerDataset(images_dir, processor)
+
     if len(dataset) < 5:
         print("β Need at least 5 images for training")
         return
+
     # Update model config for the number of classes
     if len(dataset.flower_labels) != model.config.num_labels:
         model.config.num_labels = len(dataset.flower_labels)
         # ConvNeXt uses hidden_sizes[-1] as the final hidden dimension
+        final_hidden_size = (
+            model.config.hidden_sizes[-1]
+            if hasattr(model.config, "hidden_sizes")
+            else 768
+        )
+        model.classifier = torch.nn.Linear(
+            final_hidden_size, len(dataset.flower_labels)
+        )
+
     # Split dataset
     train_size = int(0.8 * len(dataset))
     train_dataset = torch.utils.data.Subset(dataset, range(train_size))
+
     # Create data loader
     def simple_collate_fn(batch):
         pixel_values = []
         labels = []
+
         for item in batch:
+            pixel_values.append(item["pixel_values"])
+            labels.append(item["labels"])
+
         return {
+            "pixel_values": torch.stack(pixel_values),
+            "labels": torch.stack(labels),
         }
+
+    train_loader = DataLoader(
+        train_dataset, batch_size=4, shuffle=True, collate_fn=simple_collate_fn
+    )
+
     # Setup optimizer
     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
+
     # Training loop
     model.train()
     print(f"Starting training on {len(train_dataset)} samples...")
+
     for epoch in range(3):
         total_loss = 0
         num_batches = 0
+
         for batch_idx, batch in enumerate(train_loader):
             # Move to device
+            pixel_values = batch["pixel_values"].to(device)
+            labels = batch["labels"].to(device)
+
             # Zero gradients
             optimizer.zero_grad()
+
             # Forward pass
             outputs = model(pixel_values=pixel_values, labels=labels)
             loss = outputs.loss
+
             # Backward pass
             loss.backward()
             optimizer.step()
+
             total_loss += loss.item()
             num_batches += 1
+
             if batch_idx % 2 == 0:
+                print(
+                    f"Epoch {epoch + 1}, Batch {batch_idx + 1}: Loss = {loss.item():.4f}"
+                )
+
         avg_loss = total_loss / num_batches if num_batches > 0 else 0
+        print(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")
+
     # Save model
     output_dir = os.path.join(MODELS_DIR, "simple_trained_convnext")
     os.makedirs(output_dir, exist_ok=True)
+
     model.save_pretrained(output_dir)
     processor.save_pretrained(output_dir)
+
     # Save config
     config_data = {
         "model_name": model_name,
@@ -119,15 +130,16 @@ def simple_train():
         "batch_size": 4,
         "learning_rate": 1e-5,
         "train_samples": len(train_dataset),
+        "num_labels": len(dataset.flower_labels),
     }
+
     with open(os.path.join(output_dir, "training_config.json"), "w") as f:
         json.dump(config_data, f, indent=2)
+
     print(f"✅ ConvNeXt training completed! Model saved to {output_dir}")
     return output_dir
 
+
 if __name__ == "__main__":
     try:
         simple_train()
@@ -136,4 +148,5 @@ if __name__ == "__main__":
     except Exception as e:
         print(f"β Training failed: {e}")
         import traceback
+
+        traceback.print_exc()
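The head-replacement block in this script swaps the pretrained classifier for a fresh linear layer sized from `hidden_sizes[-1]`. A quick self-contained check of that idea, assuming a public ConvNeXt checkpoint and a hypothetical label count:

```python
# Sketch: checkpoint name and label count are assumptions for illustration.
import torch
from transformers import ConvNextForImageClassification

model = ConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
num_flower_labels = 4  # hypothetical count detected from training_data/images/

hidden = model.config.hidden_sizes[-1] if hasattr(model.config, "hidden_sizes") else 768
model.classifier = torch.nn.Linear(hidden, num_flower_labels)

dummy = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(pixel_values=dummy).logits
print(logits.shape)  # torch.Size([1, 4]) -> the new head produces one logit per flower label
```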
src/ui/__init__.py CHANGED
@@ -1 +1 @@
-# UI package
+# UI package
src/ui/french_style/__init__.py CHANGED
@@ -1 +1 @@
-# French style tab package
+# French style tab package
src/ui/french_style/french_style_tab.py CHANGED
@@ -2,121 +2,129 @@
 French Style tab UI components and logic.
 """
 
+
 import gradio as gr
 from PIL import Image
-from typing import Optional, Tuple
 
 try:
+    from core.constants import DEFAULT_NUM_COLORS, FLOWER_LABELS
     from services.models.flower_classification import flower_classifier
     from services.models.image_generation import image_generator
     from utils.color_utils import extract_dominant_colors
-    from core.constants import FLOWER_LABELS, DEFAULT_NUM_COLORS
 except ImportError:
     # Handle when imported from root app.py
     import os
+    import sys
+
     sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+    from core.constants import DEFAULT_NUM_COLORS, FLOWER_LABELS
     from services.models.flower_classification import flower_classifier
     from services.models.image_generation import image_generator
     from utils.color_utils import extract_dominant_colors
+
 
 class FrenchStyleTab:
     """UI component for the French Style tab."""
+
     def __init__(self):
         pass
+
     def create_ui(self) -> gr.TabItem:
         """Create the French Style tab UI."""
         with gr.TabItem("French Style arrangement") as tab:
             gr.Markdown("## π«π· French-Style Flower Arrangements")
+            gr.Markdown(
+                "Upload a flower image and generate an elegant French-style arrangement with matching colors!"
+            )
+
             with gr.Row():
                 with gr.Column():
                     self.upload_img = gr.Image(label="Upload Flower Image", type="pil")
                     self.analyze_btn = gr.Button(
+                        "π¨ Analyze & Generate French Style",
+                        variant="primary",
+                        size="lg",
                     )
+
                 with gr.Column():
                     self.french_result = gr.Image(
+                        label="Generated French-Style Arrangement", type="pil"
                    )
                    self.french_status = gr.Markdown()
                    self.analysis_details = gr.Markdown()
+
            # Wire events
            self.analyze_btn.click(
+                self._update_status, outputs=[self.french_status, self.analysis_details]
            ).then(
                self.analyze_and_generate,
                inputs=[self.upload_img],
+                outputs=[self.french_result, self.french_status, self.analysis_details],
            )
+
        return tab
+
+    def _update_status(self) -> tuple[str, str]:
        """Update status during processing."""
        return "π Processing... Please wait while we analyze your flower image...", ""
+
+    def analyze_and_generate(
+        self, image: Image.Image | None
+    ) -> tuple[Image.Image | None, str, str]:
        """Analyze uploaded flower image and generate French-style arrangement."""
        if image is None:
            return None, "Please upload an image", ""
+
        # Check if classifier is loaded
        if flower_classifier.zs_classifier is None:
            return None, "Model not loaded", ""
+
        try:
            progress_log = "π **Step 1/4:** Starting flower analysis...\n\n"
+
            # Identify flower
            progress_log += "π Identifying flower type using AI model...\n"
            results = flower_classifier._use_clip_classification(image, FLOWER_LABELS)
+
            top_flower = results[0]["label"] if results else "flower"
            confidence = results[0]["score"] if results else 0
+            progress_log += (
+                f"✅ Identified: **{top_flower}** (confidence: {confidence:.2%})\n\n"
+            )
+
            # Extract dominant colors
            progress_log += "π **Step 2/4:** Analyzing color palette...\n\n"
            progress_log += "π¨ Extracting dominant colors from image...\n"
+            color_names, color_rgb = extract_dominant_colors(
+                image, num_colors=DEFAULT_NUM_COLORS
+            )
+
            # Create color description
            main_colors = color_names[:3]  # Top 3 colors
            color_desc = ", ".join(main_colors)
            progress_log += f"✅ Color palette: **{color_desc}**\n\n"
+
            # Generate French-style prompt
+            progress_log += (
+                "π **Step 3/4:** Creating French-style arrangement prompt...\n\n"
+            )
            prompt = (
                f"elegant French-style floral arrangement featuring {top_flower}s in {color_desc} colors, "
                f"displayed in a clear crystal vase on a marble kitchen countertop, soft natural lighting, "
                f"minimalist French country kitchen background, professional photography, sophisticated composition"
            )
            progress_log += f"✅ Prompt created: *{prompt[:100]}...*\n\n"
+
            # Generate the image
+            progress_log += (
+                "π **Step 4/4:** Generating French-style arrangement image...\n\n"
+            )
            progress_log += "πΌοΈ Using AI image generation (SDXL-Turbo)...\n"
            generated_image = image_generator.generate(
+                prompt=prompt, steps=4, width=1024, height=1024, seed=None
            )
            progress_log += "✅ French-style arrangement generated successfully!\n\n"
+
            # Create analysis summary
            analysis = f"""
**πΈ Flower Analysis:**
@@ -131,11 +139,15 @@ class FrenchStyleTab:
**π Process Log:**
{progress_log}
"""
+
+            return (
+                generated_image,
+                "✅ Analysis complete! French-style arrangement generated.",
+                analysis,
+            )
+
        except Exception as e:
+            error_log = f"β **Error occurred during processing:**\n\n{e!s}\n\n"
+            if "progress_log" in locals():
                error_log += f"**Progress before error:**\n{progress_log}"
+            return None, f"β Error: {e!s}", error_log
src/ui/generate/__init__.py CHANGED
@@ -1 +1 @@
-# Generate tab package
+# Generate tab package
src/ui/generate/generate_tab.py CHANGED
@@ -2,27 +2,29 @@
 Generate tab UI components and logic.
 """
 
+
 import gradio as gr
 from PIL import Image
-from typing import Optional
 
 try:
+    from core.constants import DEFAULT_GENERATE_STEPS, DEFAULT_HEIGHT, DEFAULT_WIDTH
     from services.models.image_generation import image_generator
-    from core.constants import DEFAULT_GENERATE_STEPS, DEFAULT_WIDTH, DEFAULT_HEIGHT
 except ImportError:
     # Handle when imported from root app.py
     import os
+    import sys
+
     sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+    from core.constants import DEFAULT_GENERATE_STEPS, DEFAULT_HEIGHT, DEFAULT_WIDTH
     from services.models.image_generation import image_generator
+
 
 class GenerateTab:
     """UI component for the Generate tab."""
+
     def __init__(self):
         self.output_image = None
+
     def create_ui(self) -> gr.TabItem:
         """Create the Generate tab UI."""
         with gr.TabItem("Generate") as tab:
@@ -30,7 +32,7 @@ class GenerateTab:
             with gr.Column():
                 self.prompt_input = gr.Textbox(
                     value="ikebana-style flower arrangement, soft natural light, minimalist",
-                    label="Prompt"
+                    label="Prompt",
                 )
                 self.steps_input = gr.Slider(
                     1, 8, value=DEFAULT_GENERATE_STEPS, step=1, label="Steps"
@@ -45,23 +47,27 @@ class GenerateTab:
                     value=-1, precision=0, label="Seed (-1 = random)"
                 )
                 self.generate_btn = gr.Button("Generate", variant="primary")
+
             self.output_image = gr.Image(label="Result", type="pil")
+
             # Wire events
             self.generate_btn.click(
                 self.generate_image,
                 inputs=[
+                    self.prompt_input,
+                    self.steps_input,
+                    self.width_input,
+                    self.height_input,
+                    self.seed_input,
                 ],
+                outputs=self.output_image,
             )
+
         return tab
+
+    def generate_image(
+        self, prompt: str, steps: int, width: int, height: int, seed: int
+    ) -> Image.Image | None:
         """Generate an image from the given parameters."""
         try:
             return image_generator.generate(
@@ -69,8 +75,8 @@ class GenerateTab:
                 steps=steps,
                 width=width,
                 height=height,
-                seed=seed if seed >= 0 else None
+                seed=seed if seed >= 0 else None,
             )
         except Exception as e:
+            gr.Warning(f"Error generating image: {e!s}")
+            return None
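The Generate tab passes the raw seed number through, with -1 acting as a "random" sentinel that becomes `None` before reaching the service. A small sketch of that convention, assuming the usual `torch.Generator` pattern for diffusers pipelines (the service's internals are not shown in this hunk):

```python
# Sketch only: the generator-based seeding here is an assumption about generate().
import random

import torch


def resolve_seed(seed: int | None, device: str = "cpu") -> torch.Generator:
    """Map the UI's -1/None sentinel to a concrete, reproducible generator."""
    if seed is None or seed < 0:
        seed = random.randint(0, 2**31 - 1)
    return torch.Generator(device=device).manual_seed(seed)


gen = resolve_seed(-1)  # UI sentinel -> fresh random seed
```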
src/ui/identify/__init__.py CHANGED
@@ -1 +1 @@
-# Identify tab package
+# Identify tab package
src/ui/identify/identify_tab.py CHANGED
@@ -2,27 +2,29 @@
 Identify tab UI components and logic.
 """
 
+
 import gradio as gr
 from PIL import Image
-from typing import List, Optional, Tuple
 
 try:
+    from core.constants import DEFAULT_MIN_SCORE, DEFAULT_TOP_K, FLOWER_LABELS
     from services.models.flower_classification import flower_classifier
-    from core.constants import FLOWER_LABELS, DEFAULT_TOP_K, DEFAULT_MIN_SCORE
 except ImportError:
     # Handle when imported from root app.py
     import os
+    import sys
+
     sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+    from core.constants import DEFAULT_MIN_SCORE, DEFAULT_TOP_K, FLOWER_LABELS
     from services.models.flower_classification import flower_classifier
+
 
 class IdentifyTab:
     """UI component for the Identify tab."""
+
     def __init__(self):
         pass
+
     def create_ui(self) -> gr.TabItem:
         """Create the Identify tab UI."""
         with gr.TabItem("Identify") as tab:
@@ -31,52 +33,70 @@ class IdentifyTab:
                 self.image_input = gr.Image(
                     label="Image (upload or auto-filled from 'Generate')",
                     type="pil",
-                    interactive=True
+                    interactive=True,
                 )
                 self.labels_input = gr.CheckboxGroup(
                     choices=FLOWER_LABELS,
+                    value=[
+                        "rose",
+                        "tulip",
+                        "lily",
+                        "peony",
+                        "hydrangea",
+                        "orchid",
+                        "sunflower",
+                    ],
+                    label="Candidate labels (edit as needed)",
                 )
                 self.topk_input = gr.Slider(
                     1, 15, value=DEFAULT_TOP_K, step=1, label="Top-K"
                 )
                 self.min_score_input = gr.Slider(
+                    0.0,
+                    1.0,
+                    value=DEFAULT_MIN_SCORE,
+                    step=0.01,
+                    label="Min confidence",
                 )
                 self.detect_btn = gr.Button("Identify Flowers", variant="primary")
+
             with gr.Column():
                 self.results_table = gr.Dataframe(
                     headers=["Flower", "Confidence"],
                     datatype=["str", "number"],
-                    interactive=False
+                    interactive=False,
                 )
                 self.status_output = gr.Markdown()
+
             # Wire events
             self.detect_btn.click(
                 self.identify_flowers,
                 inputs=[
+                    self.image_input,
+                    self.labels_input,
+                    self.topk_input,
+                    self.min_score_input,
                 ],
+                outputs=[self.results_table, self.status_output],
            )
+
        return tab
+
+    def identify_flowers(
+        self,
+        image: Image.Image | None,
+        candidate_labels: list[str],
+        top_k: int,
+        min_score: float,
+    ) -> tuple[list[list], str]:
        """Identify flowers in the provided image."""
        return flower_classifier.identify_flowers(
            image=image,
            candidate_labels=candidate_labels,
            top_k=top_k,
+            min_score=min_score,
        )
+
+    def set_image(self, image: Image.Image | None) -> Image.Image | None:
        """Set the image for identification (used by other tabs)."""
+        return image
src/ui/train/__init__.py CHANGED
@@ -1 +1 @@
-# Train tab package
+# Train tab package
src/ui/train/train_tab.py CHANGED
@@ -2,8 +2,8 @@
 Train Model tab UI components and logic.
 """
 
+
 import gradio as gr
-from typing import List
 
 try:
     from services.models.flower_classification import flower_classifier
@@ -11,32 +11,38 @@ try:
     from utils.file_utils import count_training_images
 except ImportError:
     # Handle when imported from root app.py
     import os
+    import sys
+
     sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
     from services.models.flower_classification import flower_classifier
     from services.training.training_service import training_service
     from utils.file_utils import count_training_images
 
+
 class TrainTab:
     """UI component for the Train Model tab."""
+
     def __init__(self):
         pass
+
     def create_ui(self) -> gr.TabItem:
         """Create the Train Model tab UI."""
         with gr.TabItem("Train Model") as tab:
             gr.Markdown("## π― Fine-tune the flower identification model")
+            gr.Markdown(
+                "Organize your training images in subdirectories by flower type in `training_data/images/`"
+            )
+            gr.Markdown(
+                "Example: `training_data/images/roses/`, `training_data/images/tulips/`, etc."
+            )
+
             with gr.Row():
                 with gr.Column():
                     gr.Markdown("### Training Data")
                     self.refresh_btn = gr.Button("π Refresh Data Count", size="sm")
                     self.data_status = gr.Markdown()
+
                     gr.Markdown("### Training Parameters")
                     self.epochs_input = gr.Slider(
                         1, 20, value=5, step=1, label="Training Epochs"
@@ -47,69 +53,71 @@ class TrainTab:
                     self.learning_rate_input = gr.Number(
                         value=1e-5, label="Learning Rate", precision=6
                     )
+
                     self.train_btn = gr.Button("π Start Training", variant="primary")
+
                 with gr.Column():
                     gr.Markdown("### Model Management")
                     self.model_dropdown = gr.Dropdown(
                         choices=flower_classifier.get_available_models(),
                         value=f"{flower_classifier.current_model_path} (default)",
-                        label="Select Model"
+                        label="Select Model",
                     )
                     self.refresh_models_btn = gr.Button("π Refresh Models", size="sm")
+                    self.load_model_btn = gr.Button(
+                        "π₯ Load Selected Model", variant="secondary"
+                    )
+
                     self.model_status = gr.Markdown(
                         f"**Current model:** {flower_classifier.current_model_path}"
                     )
+
                     gr.Markdown("### Training Status")
                     self.training_output = gr.Markdown()
+
             # Wire events
             self.refresh_btn.click(self._count_training_images, outputs=[self.data_status])
             self.refresh_models_btn.click(
                 self._refresh_models, outputs=[self.model_dropdown]
             )
             self.load_model_btn.click(
+                self._load_trained_model,
+                inputs=[self.model_dropdown],
+                outputs=[self.model_status],
             )
             self.train_btn.click(
                 self._start_training,
                 inputs=[self.epochs_input, self.batch_size_input, self.learning_rate_input],
+                outputs=[self.training_output],
             )
+
         return tab
+
     def _count_training_images(self) -> str:
         """Count and display training images."""
         total_images, flower_counts = count_training_images()
+
         if total_images == 0:
             return "No training images found. Add images to subdirectories in training_data/images/"
+
         result = f"**Total images: {total_images}**\n\n"
         for flower_type, count in sorted(flower_counts.items()):
             result += f"- {flower_type}: {count} images\n"
+
         return result
+
     def _refresh_models(self) -> gr.Dropdown:
         """Refresh the list of available models."""
         return gr.Dropdown(choices=flower_classifier.get_available_models())
+
     def _load_trained_model(self, model_selection: str) -> str:
         """Load the selected trained model."""
         return flower_classifier.load_trained_model(model_selection)
+
+    def _start_training(
+        self, epochs: int, batch_size: int, learning_rate: float
+    ) -> str:
         """Start the training process."""
         return training_service.start_training(
+            epochs=epochs, batch_size=batch_size, learning_rate=learning_rate
+        )
src/utils/__init__.py CHANGED
@@ -1 +1 @@
-# Utils package
+# Utils package
src/utils/color_utils.py CHANGED
@@ -2,38 +2,42 @@
 Color analysis utilities.
 """
 
+
 import numpy as np
 from PIL import Image
 from sklearn.cluster import KMeans
-from typing import List, Tuple, Optional
 
+
+def extract_dominant_colors(
+    image: Image.Image | None, num_colors: int = 5
+) -> tuple[list[str], np.ndarray]:
     """Extract dominant colors from an image using k-means clustering."""
     if image is None:
         return [], np.array([])
+
     # Convert PIL image to numpy array
     img_array = np.array(image)
+
     # Reshape image to be a list of pixels
     pixels = img_array.reshape(-1, 3)
+
     # Use k-means to find dominant colors
     kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
     kmeans.fit(pixels)
+
     # Get the colors and convert to RGB values
     colors = kmeans.cluster_centers_.astype(int)
+
     # Convert to color names/descriptions
     color_names = [_rgb_to_color_name(color) for color in colors]
+
     return color_names, colors
 
+
 def _rgb_to_color_name(color: np.ndarray) -> str:
     """Convert RGB values to descriptive color name."""
     r, g, b = color
+
     if r > 200 and g > 200 and b > 200:
         return "white"
     elif r < 50 and g < 50 and b < 50:
@@ -56,4 +60,4 @@ def _rgb_to_color_name(color: np.ndarray) -> str:
     elif r > 150 and g > 100 and b < 100:
         return "orange"
     else:
+        return "cream"
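A small usage sketch for `extract_dominant_colors`, using a synthetic two-tone image so it runs without local files; the exact names returned depend on the threshold rules in `_rgb_to_color_name`, and the import path assumes `src/` is on `sys.path`:

```python
# Sketch: synthetic image, assumed import path.
import numpy as np
from PIL import Image

from utils.color_utils import extract_dominant_colors  # assumes src/ is importable

arr = np.zeros((64, 64, 3), dtype=np.uint8)
arr[:, :32] = [220, 30, 60]    # crimson half
arr[:, 32:] = [255, 255, 255]  # white half
img = Image.fromarray(arr)

names, rgb = extract_dominant_colors(img, num_colors=2)
print(names)  # e.g. ['red', 'white'], depending on the naming thresholds
print(rgb)    # the two k-means cluster centers as integer RGB rows
```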
src/utils/file_utils.py CHANGED
@@ -2,19 +2,21 @@
File and directory utilities.
"""

-import os
import glob
+import os
+
try:
    from ..core.constants import IMAGES_DIR, MODELS_DIR, SUPPORTED_IMAGE_EXTENSIONS
except ImportError:
    # Handle direct execution
-    import sys
    import os
+    import sys
+
    sys.path.append(os.path.dirname(os.path.dirname(__file__)))
    from core.constants import IMAGES_DIR, MODELS_DIR, SUPPORTED_IMAGE_EXTENSIONS


-def get_image_files(directory: str) -> List[str]:
+def get_image_files(directory: str) -> list[str]:
    """Get all image files from a directory."""
    image_files = []
    for ext in SUPPORTED_IMAGE_EXTENSIONS:
@@ -22,11 +24,12 @@ def get_image_files(directory: str) -> List[str]:
        image_files.extend(glob.glob(pattern))
    return image_files


-def get_flower_types_from_directory(image_dir: str = IMAGES_DIR) -> List[str]:
+def get_flower_types_from_directory(image_dir: str = IMAGES_DIR) -> list[str]:
    """Auto-detect flower types from directory structure."""
    if not os.path.exists(image_dir):
        return []

    detected_types = []
    for item in os.listdir(image_dir):
        item_path = os.path.join(image_dir, item)
@@ -34,17 +37,18 @@ def get_flower_types_from_directory(image_dir: str = IMAGES_DIR) -> List[str]:
            image_files = get_image_files(item_path)
            if image_files:  # Only add if there are images
                detected_types.append(item)

    return sorted(detected_types)


-def count_training_images() -> Tuple[int, dict]:
+def count_training_images() -> tuple[int, dict]:
    """Count training images by flower type."""
    if not os.path.exists(IMAGES_DIR):
        return 0, {}

    total_images = 0
    flower_counts = {}

    for flower_type in os.listdir(IMAGES_DIR):
        flower_path = os.path.join(IMAGES_DIR, flower_type)
        if os.path.isdir(flower_path):
@@ -53,18 +57,21 @@ def count_training_images() -> Tuple[int, dict]:
            if count > 0:
                flower_counts[flower_type] = count
                total_images += count

    return total_images, flower_counts


-def get_available_trained_models() -> List[str]:
+def get_available_trained_models() -> list[str]:
    """Get list of available trained models."""
    if not os.path.exists(MODELS_DIR):
        return []

    models = []
    for item in os.listdir(MODELS_DIR):
        model_path = os.path.join(MODELS_DIR, item)
-        if os.path.isdir(model_path) and os.path.exists(os.path.join(model_path, "config.json")):
+        if os.path.isdir(model_path) and os.path.exists(
+            os.path.join(model_path, "config.json")
+        ):
            models.append(item)

    return sorted(models)
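For reference, a small sketch of how these helpers are typically called; the import path is an assumption (run with src/ on sys.path):

# Hypothetical quick check of the training-data helpers.
from utils.file_utils import (
    count_training_images,
    get_available_trained_models,
    get_flower_types_from_directory,
)

total, per_type = count_training_images()
print(f"{total} training images by type: {per_type}")
print("Detected flower types:", get_flower_types_from_directory())
print("Trained models on disk:", get_available_trained_models())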
test_external_cache.py CHANGED
@@ -5,59 +5,60 @@ import os
import sys
from pathlib import Path


def test_cache_configuration():
    """Test that the external cache configuration is working."""

    print("🧪 Testing External SSD Cache Configuration")
    print("=" * 50)

    # Check if external SSD is mounted
    external_path = Path("/Volumes/extssd")
    if not external_path.exists():
        print("❌ External SSD not found at /Volumes/extssd")
        return False

    print("✅ External SSD is mounted")

    # Check if HF_HOME is set correctly
    hf_home = os.environ.get("HF_HOME")
    expected_hf_home = "/Volumes/extssd/huggingface"

    if hf_home != expected_hf_home:
-        print(f"⚠️ HF_HOME not set correctly. Expected: {expected_hf_home}, Got: {hf_home}")
+        print(
+            f"⚠️ HF_HOME not set correctly. Expected: {expected_hf_home}, Got: {hf_home}"
+        )
        print("   Set HF_HOME with: export HF_HOME=/Volumes/extssd/huggingface")
        return False

    print(f"✅ HF_HOME correctly set to: {hf_home}")

    # Check if cache directories exist
    hub_cache = Path(hf_home) / "hub"
    if not hub_cache.exists():
        print(f"❌ Hub cache directory not found at: {hub_cache}")
        return False

    print(f"✅ Hub cache directory exists at: {hub_cache}")

    # Check if models are present
    model_count = len(list(hub_cache.glob("models--*")))
    print(f"✅ Found {model_count} models in cache")

    # Test importing Hugging Face libraries and check their cache detection
    try:
-        from huggingface_hub import HfFolder
        from transformers import AutoTokenizer

        print("✅ Hugging Face libraries imported successfully")

        # Test a small model to verify cache is working
        print("🔍 Testing cache with a small model (this may take a moment)...")

        # This should use the external cache
        tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        print("✅ Successfully loaded model from cache")

        # Check if the model files are in the expected location
        clip_path = hub_cache / "models--openai--clip-vit-base-patch32"
        if clip_path.exists():
@@ -65,13 +66,14 @@ def test_cache_configuration():
        else:
            print(f"⚠️ Model files not found at expected location: {clip_path}")
            return False

        return True

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return False


def main():
    """Main test function."""
    # Load .env file if available
@@ -79,12 +81,13 @@ def main():
    if env_file.exists():
        print("📁 Loading .env file...")
        from dotenv import load_dotenv

        load_dotenv()
    else:
        print("⚠️ No .env file found, using system environment variables")

    success = test_cache_configuration()

    print("\n" + "=" * 50)
    if success:
        print("🎉 All tests passed! External SSD cache is working correctly.")
@@ -93,5 +96,6 @@ def main():
        print("❌ Some tests failed. Please check the configuration.")
        sys.exit(1)


if __name__ == "__main__":
    main()
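The behaviour this script verifies can be reproduced in a few lines; the sketch below reuses the external-SSD path from the test and only illustrates that HF_HOME must be set before the Hugging Face libraries are imported:

# Hypothetical sketch: redirect the Hugging Face cache before any HF import.
import os

os.environ.setdefault("HF_HOME", "/Volumes/extssd/huggingface")  # path from the test above

from transformers import AutoTokenizer

# Downloads into (or reuses) $HF_HOME/hub instead of ~/.cache/huggingface/hub.
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
print("Cache root:", os.environ["HF_HOME"])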
tests/__init__.py CHANGED
@@ -1 +1 @@
-"""Tests package for Flowerfy application."""
+"""Tests package for Flowerfy application."""
tests/test_models.py CHANGED
@@ -15,18 +15,22 @@ from PIL import Image
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), "src"))

# Import all required modules - if any fail, the script will fail immediately
-from transformers import ConvNextForImageClassification, ConvNextImageProcessor, pipeline
+from transformers import (
+    ConvNextForImageClassification,
+    ConvNextImageProcessor,
+    pipeline,
+)

from core.constants import DEFAULT_CLIP_MODEL, DEFAULT_CONVNEXT_MODEL
from services.models.flower_classification import FlowerClassificationService
-from services.models.image_generation import ImageGenerationService

print("✅ All dependencies imported successfully")


def test_convnext_model() -> bool:
    """Test ConvNeXt model loading."""
    print("1️⃣ Testing ConvNeXt model loading...")

    try:
        print(f"Loading ConvNeXt model: {DEFAULT_CONVNEXT_MODEL}")
        model = ConvNextForImageClassification.from_pretrained(DEFAULT_CONVNEXT_MODEL)
@@ -38,19 +42,23 @@ def test_convnext_model() -> bool:
        print(f"❌ ConvNeXt model test failed: {e}")
        return False


def test_clip_model() -> bool:
    """Test CLIP model loading."""
    print("\n2️⃣ Testing CLIP model loading...")

    try:
        print(f"Loading CLIP model: {DEFAULT_CLIP_MODEL}")
-        classifier = pipeline("zero-shot-image-classification", model=DEFAULT_CLIP_MODEL)
+        classifier = pipeline(
+            "zero-shot-image-classification", model=DEFAULT_CLIP_MODEL
+        )
        print("✅ CLIP model loaded successfully")
        return True
    except Exception as e:
        print(f"❌ CLIP model test failed: {e}")
        return False


def test_image_generation_models() -> bool:
    """Test image generation models (SDXL models)."""
    print("\n3️⃣ Testing image generation models...")
@@ -59,42 +67,50 @@ def test_image_generation_models() -> bool:
    # Test SDXL first (now primary)
    sdxl_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    print(f"Testing SDXL model (primary): {sdxl_model_id}")

    try:
        from diffusers import AutoPipelineForText2Image

-        pipe = AutoPipelineForText2Image.from_pretrained(sdxl_model_id, torch_dtype=torch.float32).to("cpu")
+        pipe = AutoPipelineForText2Image.from_pretrained(
+            sdxl_model_id, torch_dtype=torch.float32
+        ).to("cpu")
        print("✅ SDXL model loaded successfully")
        return True
    except Exception as sdxl_error:
        print(f"⚠️ SDXL model failed: {sdxl_error}")

        # Test SDXL-Turbo fallback
        turbo_model_id = "stabilityai/sdxl-turbo"
        print(f"Testing SDXL-Turbo fallback: {turbo_model_id}")

        try:
-            pipe = AutoPipelineForText2Image.from_pretrained(turbo_model_id, torch_dtype=torch.float32).to("cpu")
+            pipe = AutoPipelineForText2Image.from_pretrained(
+                turbo_model_id, torch_dtype=torch.float32
+            ).to("cpu")
            print("✅ SDXL-Turbo model loaded successfully as fallback")
            return True
        except Exception as turbo_error:
            print(f"❌ Both SDXL models failed: {turbo_error}")
            return False

    except Exception as e:
        print(f"❌ Image generation model test failed: {e}")
        return False


def test_flower_classification_service() -> bool:
    """Test flower classification service."""
    print("\n4️⃣ Testing flower classification service...")

    try:
        print("Initializing flower classification service...")
        classifier = FlowerClassificationService()

        # Create a dummy test image (3-channel RGB)
-        test_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
+        test_image = Image.fromarray(
+            np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
+        )

        # Test classification
        results, message = classifier.identify_flowers(test_image, top_k=3)
        print(f"✅ Classification service working: {message}")
@@ -104,10 +120,11 @@ def test_flower_classification_service() -> bool:
        print(f"❌ Classification service test failed: {e}")
        return False


def test_image_generation_service() -> bool:
    """Test image generation service initialization."""
    print("\n5️⃣ Testing image generation service initialization...")

    try:
        print("Testing image generation service initialization...")
        # This will test if the service can be imported and initialized
@@ -119,11 +136,12 @@ def test_image_generation_service() -> bool:
        print(f"❌ Image generation service test failed: {e}")
        return False


def main():
    """Run all model tests."""
    print("🧪 Testing Flowerfy models...")
    print("==============================")

    tests = [
        ("ConvNeXt Model", test_convnext_model),
        ("CLIP Model", test_clip_model),
@@ -131,10 +149,10 @@ def main():
        ("Classification Service", test_flower_classification_service),
        ("Generation Service", test_image_generation_service),
    ]

    passed = 0
    failed = 0

    for test_name, test_func in tests:
        try:
            if test_func():
@@ -145,11 +163,11 @@ def main():
        except Exception as e:
            failed += 1
            print(f"❌ {test_name} test failed with exception: {e}")

-    print(
+    print("\n📊 Test Results:")
    print(f"✅ Passed: {passed}")
    print(f"❌ Failed: {failed}")

    if failed == 0:
        print("\n🎉 All model tests passed successfully!")
        print("======================================")
@@ -167,6 +185,7 @@ def main():
        print(f"\n❌ {failed} test(s) failed. Please check the errors above.")
        return False


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)
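As a point of reference, the zero-shot check in test_clip_model boils down to roughly the following; the model id and candidate labels below are illustrative, not taken from core.constants:

# Hypothetical standalone version of the CLIP zero-shot check.
import numpy as np
from PIL import Image
from transformers import pipeline

classifier = pipeline(
    "zero-shot-image-classification", model="openai/clip-vit-base-patch32"
)
# Random RGB image, same shape as the dummy image used in the tests.
image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
print(classifier(image, candidate_labels=["rose", "tulip", "daisy"]))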
training/advanced_trainer.py CHANGED
@@ -4,26 +4,32 @@ Advanced ConvNeXt training script using Transformers Trainer.
This provides more sophisticated training features like evaluation, checkpointing, and logging.
"""

+import argparse
+import json
import os
+
import torch
-import json
-from transformers import ConvNextImageProcessor, ConvNextForImageClassification, Trainer, TrainingArguments
from dataset import FlowerDataset, advanced_collate_fn
-import argparse
+from transformers import (
+    ConvNextForImageClassification,
+    ConvNextImageProcessor,
+    Trainer,
+    TrainingArguments,
+)


class ConvNeXtTrainer(Trainer):
    """Custom trainer for ConvNeXt with proper loss computation."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        outputs = model(**inputs)

        if labels is not None:
            loss = torch.nn.functional.cross_entropy(outputs.logits, labels)
        else:
            loss = outputs.loss

        return (loss, outputs) if return_outputs else loss


@@ -34,11 +40,11 @@ def advanced_train(
    num_epochs=5,
    batch_size=8,
    learning_rate=1e-5,
-    flower_labels=None
+    flower_labels=None,
):
    """
    Advanced training function using Transformers Trainer.

    Args:
        image_dir: Directory containing training images organized by flower type
        output_dir: Directory to save the trained model
@@ -47,43 +53,55 @@ def advanced_train(
        batch_size: Training batch size
        learning_rate: Learning rate for optimization
        flower_labels: List of flower labels (auto-detected if None)

    Returns:
        str: Path to the saved model directory, or None if training failed
    """
    print("🌸 Advanced ConvNeXt Flower Model Training")
    print("=" * 50)

    # Check training data
    if not os.path.exists(image_dir):
        print(f"❌ Training directory not found: {image_dir}")
        return None

    # Load model and processor
    print(f"Loading model: {model_name}")
    model = ConvNextForImageClassification.from_pretrained(model_name)
    processor = ConvNextImageProcessor.from_pretrained(model_name)

    # Create dataset
    dataset = FlowerDataset(image_dir, processor, flower_labels)

    if len(dataset) == 0:
-        print("❌ No training data found. Please add images to subdirectories in training_data/images/")
-        print("Example: training_data/images/roses/, training_data/images/tulips/, etc.")
+        print(
+            "❌ No training data found. Please add images to subdirectories in training_data/images/"
+        )
+        print(
+            "Example: training_data/images/roses/, training_data/images/tulips/, etc."
+        )
        return None

    # Split dataset (80% train, 20% eval)
    train_size = int(0.8 * len(dataset))
    eval_size = len(dataset) - train_size
-    train_dataset, eval_dataset = torch.utils.data.random_split(dataset, [train_size, eval_size])
+    train_dataset, eval_dataset = torch.utils.data.random_split(
+        dataset, [train_size, eval_size]
+    )

    # Update model config for the number of classes
    if len(dataset.flower_labels) != model.config.num_labels:
        model.config.num_labels = len(dataset.flower_labels)
        # ConvNeXt uses hidden_sizes[-1] as the final hidden dimension
-        final_hidden_size = model.config.hidden_sizes[-1] if hasattr(model.config, "hidden_sizes") else 768
-        model.classifier = torch.nn.Linear(final_hidden_size, len(dataset.flower_labels))
+        final_hidden_size = (
+            model.config.hidden_sizes[-1]
+            if hasattr(model.config, "hidden_sizes")
+            else 768
+        )
+        model.classifier = torch.nn.Linear(
+            final_hidden_size, len(dataset.flower_labels)
+        )

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
@@ -102,7 +120,7 @@ def advanced_train(
        dataloader_num_workers=0,  # Set to 0 to avoid multiprocessing issues
        remove_unused_columns=False,
    )

    # Create trainer
    try:
        trainer = ConvNeXtTrainer(
@@ -116,7 +134,7 @@ def advanced_train(
    except Exception as e:
        print(f"❌ Error creating trainer: {e}")
        return None

    # Train model
    print("Starting advanced training...")
    try:
@@ -125,14 +143,15 @@ def advanced_train(
    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback

        traceback.print_exc()
        return None

    # Save final model
    final_model_path = os.path.join(output_dir, "final_model")
    model.save_pretrained(final_model_path)
    processor.save_pretrained(final_model_path)

    # Save training config
    config = {
        "model_name": model_name,
@@ -142,27 +161,43 @@ def advanced_train(
        "learning_rate": learning_rate,
        "train_samples": len(train_dataset),
        "eval_samples": len(eval_dataset),
-        "training_type": "advanced"
+        "training_type": "advanced",
    }

    with open(os.path.join(final_model_path, "training_config.json"), "w") as f:
        json.dump(config, f, indent=2)

    print(f"✅ Advanced training complete! Model saved to {final_model_path}")
    return final_model_path


if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Advanced ConvNeXt training for flower classification")
-    parser.add_argument("--image_dir", default="training_data/images", help="Directory containing training images")
-    parser.add_argument("--output_dir", default="training_data/trained_models/advanced_trained", help="Output directory for trained model")
-    parser.add_argument("--model_name", default="facebook/convnext-base-224-22k", help="Base model name")
-    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
+    parser = argparse.ArgumentParser(
+        description="Advanced ConvNeXt training for flower classification"
+    )
+    parser.add_argument(
+        "--image_dir",
+        default="training_data/images",
+        help="Directory containing training images",
+    )
+    parser.add_argument(
+        "--output_dir",
+        default="training_data/trained_models/advanced_trained",
+        help="Output directory for trained model",
+    )
+    parser.add_argument(
+        "--model_name", default="facebook/convnext-base-224-22k", help="Base model name"
+    )
+    parser.add_argument(
+        "--epochs", type=int, default=5, help="Number of training epochs"
+    )
    parser.add_argument("--batch_size", type=int, default=8, help="Training batch size")
-    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate")
+    parser.add_argument(
+        "--learning_rate", type=float, default=1e-5, help="Learning rate"
+    )

    args = parser.parse_args()

    try:
        result = advanced_train(
            image_dir=args.image_dir,
@@ -170,7 +205,7 @@ if __name__ == "__main__":
            model_name=args.model_name,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
-            learning_rate=args.learning_rate
+            learning_rate=args.learning_rate,
        )
        if not result:
            print("❌ Training failed!")
@@ -180,5 +215,6 @@ if __name__ == "__main__":
    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback

        traceback.print_exc()
        exit(1)
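The script can also be driven programmatically rather than via the CLI; a sketch assuming it is run from inside training/ and that the keyword values mirror the argparse defaults above:

# Hypothetical programmatic call into advanced_train (paths mirror the argparse defaults).
from advanced_trainer import advanced_train

model_path = advanced_train(
    image_dir="training_data/images",
    output_dir="training_data/trained_models/advanced_trained",
    model_name="facebook/convnext-base-224-22k",
    num_epochs=5,
    batch_size=8,
    learning_rate=1e-5,
)
print("Saved model to:", model_path)  # None if training failed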
training/dataset.py CHANGED
@@ -3,9 +3,10 @@
Flower Dataset class for training ConvNeXt models.
"""

+import glob
import os
+
import torch
-import glob
from PIL import Image
from torch.utils.data import Dataset

@@ -15,7 +16,7 @@ class FlowerDataset(Dataset):
        self.image_paths = []
        self.labels = []
        self.processor = processor

        # Auto-detect flower types from directory structure if not provided
        if flower_labels is None:
            detected_types = []
@@ -28,22 +29,24 @@ class FlowerDataset(Dataset):
            self.flower_labels = sorted(detected_types)
        else:
            self.flower_labels = flower_labels

        self.label_to_id = {label: idx for idx, label in enumerate(self.flower_labels)}

        # Load images from subdirectories (organized by flower type)
        for flower_type in os.listdir(image_dir):
            flower_path = os.path.join(image_dir, flower_type)
            if os.path.isdir(flower_path) and flower_type in self.label_to_id:
                image_files = self._get_image_files(flower_path)

                for img_path in image_files:
                    self.image_paths.append(img_path)
                    self.labels.append(self.label_to_id[flower_type])

-        print(f"Loaded {len(self.image_paths)} images from {len(set(self.labels))} flower types")
+        print(
+            f"Loaded {len(self.image_paths)} images from {len(set(self.labels))} flower types"
+        )
        print(f"Flower types: {self.flower_labels}")

    def _get_image_files(self, directory):
        """Get all supported image files from directory."""
        extensions = ["*.jpg", "*.jpeg", "*.png", "*.webp"]
@@ -52,21 +55,21 @@ class FlowerDataset(Dataset):
            image_files.extend(glob.glob(os.path.join(directory, ext)))
            image_files.extend(glob.glob(os.path.join(directory, ext.upper())))
        return image_files

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]

        # Process image for ConvNeXt
        inputs = self.processor(images=image, return_tensors="pt")

        return {
-            'pixel_values': inputs['pixel_values'].squeeze(),
-            'labels': torch.tensor(label, dtype=torch.long)
+            "pixel_values": inputs["pixel_values"].squeeze(),
+            "labels": torch.tensor(label, dtype=torch.long),
        }


@@ -74,29 +77,24 @@ def simple_collate_fn(batch):
    """Simple collation function for training."""
    pixel_values = []
    labels = []

    for item in batch:
-        pixel_values.append(item['pixel_values'])
-        labels.append(item['labels'])
-
-    return {
-        'pixel_values': torch.stack(pixel_values),
-        'labels': torch.stack(labels)
-    }
+        pixel_values.append(item["pixel_values"])
+        labels.append(item["labels"])
+
+    return {"pixel_values": torch.stack(pixel_values), "labels": torch.stack(labels)}


def advanced_collate_fn(batch):
    """Advanced collation function for Trainer."""
    # Extract components
-    pixel_values = [item['pixel_values'] for item in batch]
-    labels = [item['labels'] for item in batch if 'labels' in item]
+    pixel_values = [item["pixel_values"] for item in batch]
+    labels = [item["labels"] for item in batch if "labels" in item]

    # Stack everything
-    result = {
-        'pixel_values': torch.stack(pixel_values)
-    }
+    result = {"pixel_values": torch.stack(pixel_values)}

    if labels:
-        result['labels'] = torch.stack(labels)
+        result["labels"] = torch.stack(labels)

    return result
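A short sketch of how FlowerDataset and the collate functions plug into a PyTorch DataLoader, assuming the script is run from inside training/ with images under training_data/images/:

# Hypothetical wiring of the dataset into a DataLoader.
from torch.utils.data import DataLoader
from transformers import ConvNextImageProcessor

from dataset import FlowerDataset, simple_collate_fn

processor = ConvNextImageProcessor.from_pretrained("facebook/convnext-base-224-22k")
dataset = FlowerDataset("training_data/images", processor)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=simple_collate_fn)

batch = next(iter(loader))
print(batch["pixel_values"].shape, batch["labels"].shape)  # e.g. (4, 3, 224, 224) and (4,)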
training/simple_trainer.py CHANGED
@@ -4,13 +4,13 @@ Simple ConvNeXt training script without using the Transformers Trainer class.
This is a lightweight training implementation for quick model fine-tuning.
"""

+import json
import os
+
import torch
-import torch.nn as nn
-from torch.utils.data import DataLoader
-from transformers import ConvNextImageProcessor, ConvNextForImageClassification
from dataset import FlowerDataset, simple_collate_fn
-import json
+from torch.utils.data import DataLoader
+from transformers import ConvNextForImageClassification, ConvNextImageProcessor


def simple_train(
@@ -19,11 +19,11 @@ def simple_train(
    epochs=3,
    batch_size=4,
    learning_rate=1e-5,
-    model_name="facebook/convnext-base-224-22k"
+    model_name="facebook/convnext-base-224-22k",
):
    """
    Simple training function for ConvNeXt flower classification.

    Args:
        image_dir: Directory containing training images organized by flower type
        output_dir: Directory to save the trained model
@@ -31,91 +31,107 @@ def simple_train(
        batch_size: Training batch size
        learning_rate: Learning rate for optimization
        model_name: Base ConvNeXt model to fine-tune

    Returns:
        str: Path to the saved model directory, or None if training failed
    """
    print("🌸 Simple ConvNeXt Flower Model Training")
    print("=" * 40)

    # Check training data
    if not os.path.exists(image_dir):
        print(f"❌ Training directory not found: {image_dir}")
        return None

-    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+    device = (
+        "cuda"
+        if torch.cuda.is_available()
+        else "mps"
+        if torch.backends.mps.is_available()
+        else "cpu"
+    )
    print(f"Using device: {device}")

    # Load model and processor
    print(f"Loading model: {model_name}")
    model = ConvNextForImageClassification.from_pretrained(model_name)
    processor = ConvNextImageProcessor.from_pretrained(model_name)
    model.to(device)

    # Create dataset
    dataset = FlowerDataset(image_dir, processor)

    if len(dataset) < 5:
        print("❌ Need at least 5 images for training")
        return None

    # Split dataset
    train_size = int(0.8 * len(dataset))
    train_dataset = torch.utils.data.Subset(dataset, range(train_size))

    # Update model config for the number of classes
    if len(dataset.flower_labels) != model.config.num_labels:
        model.config.num_labels = len(dataset.flower_labels)
        # ConvNeXt uses hidden_sizes[-1] as the final hidden dimension
-        final_hidden_size = model.config.hidden_sizes[-1] if hasattr(model.config, "hidden_sizes") else 768
-        model.classifier = torch.nn.Linear(final_hidden_size, len(dataset.flower_labels))
+        final_hidden_size = (
+            model.config.hidden_sizes[-1]
+            if hasattr(model.config, "hidden_sizes")
+            else 768
+        )
+        model.classifier = torch.nn.Linear(
+            final_hidden_size, len(dataset.flower_labels)
+        )
        model.classifier.to(device)

    # Create data loader
-    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=simple_collate_fn)
+    train_loader = DataLoader(
+        train_dataset, batch_size=batch_size, shuffle=True, collate_fn=simple_collate_fn
+    )

    # Setup optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Training loop
    model.train()
    print(f"Starting training on {len(train_dataset)} samples for {epochs} epochs...")

    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            # Move to device
-            pixel_values = batch['pixel_values'].to(device)
-            labels = batch['labels'].to(device)
+            pixel_values = batch["pixel_values"].to(device)
+            labels = batch["labels"].to(device)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

            # Backward pass
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if batch_idx % 2 == 0 or batch_idx == len(train_loader) - 1:
-                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx+1}/{len(train_loader)}: Loss = {loss.item():.4f}")
+                print(
+                    f"Epoch {epoch + 1}/{epochs}, Batch {batch_idx + 1}/{len(train_loader)}: Loss = {loss.item():.4f}"
+                )

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
-        print(f"Epoch {epoch+1} completed. Average loss: {avg_loss:.4f}")
+        print(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")

    # Save model
    os.makedirs(output_dir, exist_ok=True)

    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    # Save config
    config = {
        "model_name": model_name,
@@ -125,29 +141,45 @@ def simple_train(
        "learning_rate": learning_rate,
        "train_samples": len(train_dataset),
        "num_labels": len(dataset.flower_labels),
-        "training_type": "simple"
+        "training_type": "simple",
    }

    with open(os.path.join(output_dir, "training_config.json"), "w") as f:
        json.dump(config, f, indent=2)

    print(f"✅ ConvNeXt training completed! Model saved to {output_dir}")
    return output_dir


if __name__ == "__main__":
    import argparse

-    parser = argparse.ArgumentParser(description="Simple ConvNeXt training for flower classification")
-    parser.add_argument("--image_dir", default="training_data/images", help="Directory containing training images")
-    parser.add_argument("--output_dir", default="training_data/trained_models/simple_trained", help="Output directory for trained model")
-    parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
+    parser = argparse.ArgumentParser(
+        description="Simple ConvNeXt training for flower classification"
+    )
+    parser.add_argument(
+        "--image_dir",
+        default="training_data/images",
+        help="Directory containing training images",
+    )
+    parser.add_argument(
+        "--output_dir",
+        default="training_data/trained_models/simple_trained",
+        help="Output directory for trained model",
+    )
+    parser.add_argument(
+        "--epochs", type=int, default=3, help="Number of training epochs"
+    )
    parser.add_argument("--batch_size", type=int, default=4, help="Training batch size")
-    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate")
-    parser.add_argument("--model_name", default="facebook/convnext-base-224-22k", help="Base model name")
+    parser.add_argument(
+        "--learning_rate", type=float, default=1e-5, help="Learning rate"
+    )
+    parser.add_argument(
+        "--model_name", default="facebook/convnext-base-224-22k", help="Base model name"
+    )

    args = parser.parse_args()

    try:
        result = simple_train(
            image_dir=args.image_dir,
@@ -155,7 +187,7 @@ if __name__ == "__main__":
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
-            model_name=args.model_name
+            model_name=args.model_name,
        )
        if not result:
            print("❌ Training failed!")
@@ -165,5 +197,6 @@ if __name__ == "__main__":
    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback

        traceback.print_exc()
        exit(1)
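As with the advanced trainer, simple_train can be called directly; a sketch assuming it runs from inside training/ with values mirroring the argparse defaults:

# Hypothetical programmatic call into simple_train (values mirror the argparse defaults).
from simple_trainer import simple_train

saved_dir = simple_train(
    image_dir="training_data/images",
    output_dir="training_data/trained_models/simple_trained",
    epochs=3,
    batch_size=4,
    learning_rate=1e-5,
    model_name="facebook/convnext-base-224-22k",
)
print("Saved model to:", saved_dir)  # None if training failed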