Toy Claude committed on
Commit b24c04f · 1 Parent(s): ef5274f

Apply code formatting and fix compatibility issues


- Update pydantic dependencies for HF Spaces compatibility
- Apply ruff formatting across codebase
- Fix import organization and type annotations
- Ensure SDXL-only model architecture

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
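The formatting and import-organization changes in the diff below are mechanical; a minimal sketch of the kind of commands that would produce them (assuming ruff is installed for the project; the exact invocation is not recorded in this commit):

```
# Hypothetical reproduction of this commit's mechanical changes; not part of the commit itself.
ruff check --fix .   # organize imports and apply safe lint fixes
ruff format .        # normalize quotes, line wrapping, trailing commas, and blank lines
```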

app.py CHANGED
@@ -20,6 +20,7 @@ if src_path not in sys.path:
20
 
21
  # Initialize config early to setup cache paths before model imports
22
  from core.config import config
 
23
  print(f"🔧 Environment: {'HF Spaces' if config.is_hf_spaces else 'Local'}")
24
  print(f"🔧 Device: {config.device}, dtype: {config.dtype}")
25
 
@@ -28,64 +29,66 @@ from ui.generate.generate_tab import GenerateTab
28
  from ui.identify.identify_tab import IdentifyTab
29
  from ui.train.train_tab import TrainTab
30
 
 
31
  class FlowerifyApp:
32
  """Main application class for Flowerify."""
33
-
34
  def __init__(self):
35
  self.generate_tab = GenerateTab()
36
  self.identify_tab = IdentifyTab()
37
  self.train_tab = TrainTab()
38
  self.french_style_tab = FrenchStyleTab()
39
-
40
  def create_interface(self) -> gr.Blocks:
41
  """Create the main Gradio interface."""
42
  with gr.Blocks(title="🌸 Flowerify - AI Flower Generator & Identifier") as demo:
43
  gr.Markdown("# 🌸 Flowerfy — Text → Image + Flower Identifier")
44
-
45
  with gr.Tabs():
46
  # Create each tab
47
  generate_tab = self.generate_tab.create_ui()
48
  identify_tab = self.identify_tab.create_ui()
49
  train_tab = self.train_tab.create_ui()
50
  french_style_tab = self.french_style_tab.create_ui()
51
-
52
  # Wire cross-tab interactions
53
  self._setup_cross_tab_interactions()
54
-
55
  # Initialize data on load
56
  demo.load(
57
  self.train_tab._count_training_images,
58
- outputs=[self.train_tab.data_status]
59
  )
60
-
61
  return demo
62
-
63
  def _setup_cross_tab_interactions(self):
64
  """Setup interactions between tabs."""
65
  # Auto-send generated image to Identify tab
66
  self.generate_tab.output_image.change(
67
  self.identify_tab.set_image,
68
  inputs=self.generate_tab.output_image,
69
- outputs=self.identify_tab.image_input
70
  )
71
-
72
  def launch(self, **kwargs):
73
  """Launch the application."""
74
  demo = self.create_interface()
75
  # Add share=True for HF Spaces compatibility
76
  if config.is_hf_spaces:
77
- kwargs.setdefault('share', True)
78
  return demo.queue().launch(**kwargs)
79
 
 
80
  def main():
81
  """Main entry point."""
82
  try:
83
  print("🌸 Starting Flowerify (SDXL models)")
84
  print("Loading models and initializing UI...")
85
-
86
  app = FlowerifyApp()
87
  app.launch()
88
-
89
  except KeyboardInterrupt:
90
  print("\n👋 Application stopped by user")
91
  except Exception as e:
@@ -93,5 +96,6 @@ def main():
93
  traceback.print_exc()
94
  sys.exit(1)
95
 
 
96
  if __name__ == "__main__":
97
- main()
 
20
 
21
  # Initialize config early to setup cache paths before model imports
22
  from core.config import config
23
+
24
  print(f"🔧 Environment: {'HF Spaces' if config.is_hf_spaces else 'Local'}")
25
  print(f"🔧 Device: {config.device}, dtype: {config.dtype}")
26
 
 
29
  from ui.identify.identify_tab import IdentifyTab
30
  from ui.train.train_tab import TrainTab
31
 
32
+
33
  class FlowerifyApp:
34
  """Main application class for Flowerify."""
35
+
36
  def __init__(self):
37
  self.generate_tab = GenerateTab()
38
  self.identify_tab = IdentifyTab()
39
  self.train_tab = TrainTab()
40
  self.french_style_tab = FrenchStyleTab()
41
+
42
  def create_interface(self) -> gr.Blocks:
43
  """Create the main Gradio interface."""
44
  with gr.Blocks(title="🌸 Flowerify - AI Flower Generator & Identifier") as demo:
45
  gr.Markdown("# 🌸 Flowerfy — Text → Image + Flower Identifier")
46
+
47
  with gr.Tabs():
48
  # Create each tab
49
  generate_tab = self.generate_tab.create_ui()
50
  identify_tab = self.identify_tab.create_ui()
51
  train_tab = self.train_tab.create_ui()
52
  french_style_tab = self.french_style_tab.create_ui()
53
+
54
  # Wire cross-tab interactions
55
  self._setup_cross_tab_interactions()
56
+
57
  # Initialize data on load
58
  demo.load(
59
  self.train_tab._count_training_images,
60
+ outputs=[self.train_tab.data_status],
61
  )
62
+
63
  return demo
64
+
65
  def _setup_cross_tab_interactions(self):
66
  """Setup interactions between tabs."""
67
  # Auto-send generated image to Identify tab
68
  self.generate_tab.output_image.change(
69
  self.identify_tab.set_image,
70
  inputs=self.generate_tab.output_image,
71
+ outputs=self.identify_tab.image_input,
72
  )
73
+
74
  def launch(self, **kwargs):
75
  """Launch the application."""
76
  demo = self.create_interface()
77
  # Add share=True for HF Spaces compatibility
78
  if config.is_hf_spaces:
79
+ kwargs.setdefault("share", True)
80
  return demo.queue().launch(**kwargs)
81
 
82
+
83
  def main():
84
  """Main entry point."""
85
  try:
86
  print("🌸 Starting Flowerify (SDXL models)")
87
  print("Loading models and initializing UI...")
88
+
89
  app = FlowerifyApp()
90
  app.launch()
91
+
92
  except KeyboardInterrupt:
93
  print("\n👋 Application stopped by user")
94
  except Exception as e:
 
96
  traceback.print_exc()
97
  sys.exit(1)
98
 
99
+
100
  if __name__ == "__main__":
101
+ main()
app_original.py CHANGED
@@ -1,13 +1,19 @@
1
- import os, torch, gradio as gr, json
2
- from diffusers import AutoPipelineForText2Image
3
- from transformers import pipeline, ConvNextImageProcessor, ConvNextForImageClassification, AutoImageProcessor, AutoModelForImageClassification
4
- from simple_train import simple_train
5
  import glob
6
- from pathlib import Path
7
- from PIL import Image
 
8
  import numpy as np
 
 
 
9
  from sklearn.cluster import KMeans
10
-
 
 
 
 
 
 
11
 
12
  MODEL_ID = os.getenv("MODEL_ID", "stabilityai/sdxl-turbo")
13
 
@@ -23,6 +29,7 @@ if device == "cuda":
23
  else:
24
  pipe.enable_attention_slicing()
25
 
 
26
  def generate(prompt, steps, width, height, seed):
27
  if seed is None or int(seed) < 0:
28
  generator = None
@@ -32,23 +39,49 @@ def generate(prompt, steps, width, height, seed):
32
  result = pipe(
33
  prompt=prompt,
34
  num_inference_steps=int(steps),
35
- guidance_scale=0.0, # SDXL-Turbo works best at 0.0
36
  width=int(width // 8) * 8,
37
  height=int(height // 8) * 8,
38
- generator=generator
39
  )
40
  return result.images[0]
41
 
42
 
43
-
44
  # ---------- Flower identification (zero-shot) ----------
45
  # Curated label set; edit/extend as you like
46
  FLOWER_LABELS = [
47
- "rose", "tulip", "lily", "peony", "sunflower", "chrysanthemum", "carnation",
48
- "orchid", "hydrangea", "daisy", "dahlia", "ranunculus", "anemone", "marigold",
49
- "lavender", "magnolia", "gardenia", "camellia", "jasmine", "iris", "gerbera",
50
- "zinnia", "hibiscus", "lotus", "poppy", "sweet pea", "freesia", "lisianthus",
51
- "calla lily", "cherry blossom", "plumeria", "cosmos"
 
 
 
52
  ]
53
 
54
  # Initialize classifier - will be updated when trained model is loaded
@@ -58,6 +91,7 @@ convnext_model = None
58
  convnext_processor = None
59
  current_model_path = "facebook/convnext-base-224-22k"
60
 
 
61
  def load_classifier(model_path="facebook/convnext-base-224-22k"):
62
  global zs_classifier, convnext_model, convnext_processor, current_model_path
63
  try:
@@ -70,149 +104,174 @@ def load_classifier(model_path="facebook/convnext-base-224-22k"):
70
  zs_classifier = pipeline(
71
  task="zero-shot-image-classification",
72
  model="openai/clip-vit-base-patch32",
73
- device=clf_device
74
  )
75
  return f"✅ Loaded custom ConvNeXt model from: {model_path}"
76
  else:
77
  # Load default ConvNeXt model for feature extraction and fallback to CLIP for zero-shot
78
- convnext_model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224-22k")
79
- convnext_processor = ConvNextImageProcessor.from_pretrained("facebook/convnext-base-224-22k")
 
80
  zs_classifier = pipeline(
81
  task="zero-shot-image-classification",
82
  model="openai/clip-vit-base-patch32",
83
- device=clf_device
84
  )
85
  current_model_path = "facebook/convnext-base-224-22k"
86
- return f"✅ Loaded default ConvNeXt model: facebook/convnext-base-224-22k"
87
  except Exception as e:
88
- return f"❌ Error loading model: {str(e)}"
 
89
 
90
  # Initialize with default model
91
  load_classifier()
92
 
 
93
  def identify_flowers(image, candidate_labels, top_k, min_score):
94
  if image is None:
95
  return [], "Please provide an image (upload or generate first)."
96
-
97
  labels = candidate_labels if candidate_labels else FLOWER_LABELS
98
-
99
  # Use ConvNeXt for feature extraction if we have a trained model, otherwise fallback to CLIP
100
- if convnext_model is not None and os.path.exists(current_model_path) and current_model_path != "facebook/convnext-base-224-22k":
 
 
 
 
101
  try:
102
  # Use trained ConvNeXt model
103
  inputs = convnext_processor(images=image, return_tensors="pt")
104
  with torch.no_grad():
105
  outputs = convnext_model(**inputs)
106
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
107
-
108
  # Convert predictions to results format
109
  results = []
110
  for i, score in enumerate(predictions[0]):
111
  if i < len(labels):
112
  results.append({"label": labels[i], "score": float(score)})
113
-
114
  # Sort by score
115
  results = sorted(results, key=lambda r: r["score"], reverse=True)
116
- except Exception as e:
117
  # Fallback to CLIP zero-shot
118
  results = zs_classifier(
119
- image,
120
- candidate_labels=labels,
121
- hypothesis_template="a photo of a {}"
122
  )
123
  else:
124
  # Use CLIP zero-shot classification
125
  results = zs_classifier(
126
- image,
127
- candidate_labels=labels,
128
- hypothesis_template="a photo of a {}"
129
  )
130
-
131
  # Filter and format results
132
  results = [r for r in results if r["score"] >= float(min_score)]
133
- results = sorted(results, key=lambda r: r["score"], reverse=True)[:int(top_k)]
134
  table = [[r["label"], round(float(r["score"]), 4)] for r in results]
135
- model_type = "ConvNeXt" if (convnext_model is not None and os.path.exists(current_model_path) and current_model_path != "facebook/convnext-base-224-22k") else "CLIP zero-shot"
 
 
136
  msg = f"Detected flowers using {model_type}."
137
  return table, msg
138
 
 
139
  # simple passthrough so the generated image appears in the Identify tab automatically
140
  def passthrough(img):
141
  return img
142
 
 
143
  # Training functions
144
  def get_available_models():
145
  models_dir = "training_data/trained_models"
146
  if not os.path.exists(models_dir):
147
  return ["facebook/convnext-base-224-22k (default)"]
148
-
149
  models = ["facebook/convnext-base-224-22k (default)"]
150
  for item in os.listdir(models_dir):
151
  model_path = os.path.join(models_dir, item)
152
- if os.path.isdir(model_path) and os.path.exists(os.path.join(model_path, "config.json")):
 
 
153
  models.append(f"Custom: {item}")
154
  return models
155
 
 
156
  def count_training_images():
157
  images_dir = "training_data/images"
158
  if not os.path.exists(images_dir):
159
  return "Training directory not found"
160
-
161
  total_images = 0
162
  flower_counts = {}
163
-
164
  for flower_type in os.listdir(images_dir):
165
  flower_path = os.path.join(images_dir, flower_type)
166
  if os.path.isdir(flower_path):
167
- image_files = glob.glob(os.path.join(flower_path, "*.jpg")) + \
168
- glob.glob(os.path.join(flower_path, "*.jpeg")) + \
169
- glob.glob(os.path.join(flower_path, "*.png")) + \
170
- glob.glob(os.path.join(flower_path, "*.webp"))
 
 
171
  count = len(image_files)
172
  if count > 0:
173
  flower_counts[flower_type] = count
174
  total_images += count
175
-
176
  if total_images == 0:
177
  return "No training images found. Add images to subdirectories in training_data/images/"
178
-
179
  result = f"**Total images: {total_images}**\n\n"
180
  for flower_type, count in sorted(flower_counts.items()):
181
  result += f"- {flower_type}: {count} images\n"
182
-
183
  return result
184
 
 
185
  def start_training(epochs=None, batch_size=None, learning_rate=None):
186
  try:
187
  # Check if training data exists
188
  images_dir = "training_data/images"
189
  if not os.path.exists(images_dir):
190
  return "❌ Training directory not found. Please create training_data/images/ and add your data."
191
-
192
  # Count images
193
  total_images = 0
194
  for flower_type in os.listdir(images_dir):
195
  flower_path = os.path.join(images_dir, flower_type)
196
  if os.path.isdir(flower_path):
197
- image_files = glob.glob(os.path.join(flower_path, "*.jpg")) + \
198
- glob.glob(os.path.join(flower_path, "*.jpeg")) + \
199
- glob.glob(os.path.join(flower_path, "*.png")) + \
200
- glob.glob(os.path.join(flower_path, "*.webp"))
 
 
201
  total_images += len(image_files)
202
-
203
  if total_images < 10:
204
  return f"❌ Need at least 10 training images. Found {total_images}. Add more images to training_data/images/"
205
-
206
  # Start training
207
  model_path = simple_train()
208
-
209
  if model_path:
210
  return f"✅ Training completed! Model saved to: {model_path}"
211
  else:
212
  return "❌ Training failed. Check the console for details."
213
-
214
  except Exception as e:
215
- return f"❌ Training error: {str(e)}"
 
216
 
217
  def load_trained_model(model_selection):
218
  if model_selection.startswith("Custom:"):
@@ -222,25 +281,26 @@ def load_trained_model(model_selection):
222
  else:
223
  return load_classifier("facebook/convnext-base-224-22k")
224
 
 
225
  # French-style arrangement functions
226
  def extract_dominant_colors(image, num_colors=5):
227
  """Extract dominant colors from an image using k-means clustering"""
228
  if image is None:
229
  return [], "No image provided"
230
-
231
  # Convert PIL image to numpy array
232
  img_array = np.array(image)
233
-
234
  # Reshape image to be a list of pixels
235
  pixels = img_array.reshape(-1, 3)
236
-
237
  # Use k-means to find dominant colors
238
  kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
239
  kmeans.fit(pixels)
240
-
241
  # Get the colors and convert to RGB values
242
  colors = kmeans.cluster_centers_.astype(int)
243
-
244
  # Convert to color names/descriptions
245
  color_names = []
246
  for color in colors:
@@ -268,54 +328,59 @@ def extract_dominant_colors(image, num_colors=5):
268
  color_names.append("orange")
269
  else:
270
  color_names.append("cream")
271
-
272
  return color_names, colors
273
 
 
274
  def analyze_and_generate_french_style(image):
275
  """Analyze uploaded flower image and generate French-style arrangement"""
276
  if image is None:
277
  return None, "Please upload an image", ""
278
-
279
  # Identify the flower type
280
  if zs_classifier is None:
281
  return None, "Model not loaded", ""
282
-
283
  try:
284
  progress_log = "🔄 **Step 1/4:** Starting flower analysis...\n\n"
285
-
286
  # Identify flower
287
  progress_log += "🔍 Identifying flower type using AI model...\n"
288
  results = zs_classifier(
289
- image,
290
- candidate_labels=FLOWER_LABELS,
291
- hypothesis_template="a photo of a {}"
292
  )
293
-
294
  top_flower = results[0]["label"] if results else "flower"
295
  confidence = results[0]["score"] if results else 0
296
- progress_log += f"✅ Identified: **{top_flower}** (confidence: {confidence:.2%})\n\n"
297
-
 
 
298
  # Extract dominant colors
299
  progress_log += "🔄 **Step 2/4:** Analyzing color palette...\n\n"
300
  progress_log += "🎨 Extracting dominant colors from image...\n"
301
  color_names, color_rgb = extract_dominant_colors(image, num_colors=3)
302
-
303
  # Create color description
304
  main_colors = color_names[:3] # Top 3 colors
305
  color_desc = ", ".join(main_colors)
306
  progress_log += f"✅ Color palette: **{color_desc}**\n\n"
307
-
308
  # Generate French-style prompt
309
- progress_log += "🔄 **Step 3/4:** Creating French-style arrangement prompt...\n\n"
 
 
310
  prompt = f"elegant French-style floral arrangement featuring {top_flower}s in {color_desc} colors, displayed in a clear crystal vase on a marble kitchen countertop, soft natural lighting, minimalist French country kitchen background, professional photography, sophisticated composition"
311
  progress_log += f"✅ Prompt created: *{prompt[:100]}...*\n\n"
312
-
313
  # Generate the image
314
- progress_log += "🔄 **Step 4/4:** Generating French-style arrangement image...\n\n"
 
 
315
  progress_log += "🖼️ Using AI image generation (SDXL-Turbo)...\n"
316
  generated_image = generate(prompt, steps=4, width=1024, height=1024, seed=-1)
317
  progress_log += "✅ French-style arrangement generated successfully!\n\n"
318
-
319
  # Create analysis summary
320
  analysis = f"""
321
  **🌸 Flower Analysis:**
@@ -330,14 +395,19 @@ def analyze_and_generate_french_style(image):
330
  **📋 Process Log:**
331
  {progress_log}
332
  """
333
-
334
- return generated_image, "✅ Analysis complete! French-style arrangement generated.", analysis
335
-
 
 
 
 
336
  except Exception as e:
337
- error_log = f"❌ **Error occurred during processing:**\n\n{str(e)}\n\n"
338
- if 'progress_log' in locals():
339
  error_log += f"**Progress before error:**\n{progress_log}"
340
- return None, f"❌ Error: {str(e)}", error_log
 
341
 
342
  # ---------- UI ----------
343
  with gr.Blocks() as demo:
@@ -347,66 +417,113 @@ with gr.Blocks() as demo:
347
  with gr.TabItem("Generate"):
348
  with gr.Row():
349
  with gr.Column():
350
- prompt = gr.Textbox(value="ikebana-style flower arrangement, soft natural light, minimalist", label="Prompt")
351
- steps = gr.Slider(1, 8, value=4, step=1, label="Steps")
352
- width = gr.Slider(512, 1536, value=1024, step=8, label="Width")
 
 
 
353
  height = gr.Slider(512, 1536, value=1024, step=8, label="Height")
354
- seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
355
- go = gr.Button("Generate", variant="primary")
356
  out = gr.Image(label="Result", type="pil")
357
 
358
  with gr.TabItem("Identify"):
359
  with gr.Row():
360
  with gr.Column():
361
- img_in = gr.Image(label="Image (upload or auto-filled from 'Generate')", type="pil", interactive=True)
362
- labels_box = gr.CheckboxGroup(choices=FLOWER_LABELS, value=["rose","tulip","lily","peony","hydrangea","orchid","sunflower"], label="Candidate labels (edit as needed)")
 
 
363
  topk = gr.Slider(1, 15, value=7, step=1, label="Top-K")
364
- min_score = gr.Slider(0.0, 1.0, value=0.12, step=0.01, label="Min confidence")
 
 
365
  detect_btn = gr.Button("Identify Flowers", variant="primary")
366
  with gr.Column():
367
- results_tbl = gr.Dataframe(headers=["Flower", "Confidence"], datatype=["str", "number"], interactive=False)
 
368
  status = gr.Markdown()
369
 
370
  with gr.TabItem("Train Model"):
371
  gr.Markdown("## 🎯 Fine-tune the flower identification model")
372
- gr.Markdown("Organize your training images in subdirectories by flower type in `training_data/images/`")
373
- gr.Markdown("Example: `training_data/images/roses/`, `training_data/images/tulips/`, etc.")
374
-
 
 
 
 
375
  with gr.Row():
376
  with gr.Column():
377
  gr.Markdown("### Training Data")
378
  refresh_btn = gr.Button("🔄 Refresh Data Count", size="sm")
379
  data_status = gr.Markdown()
380
-
381
  gr.Markdown("### Training Parameters")
382
  epochs = gr.Slider(1, 20, value=5, step=1, label="Training Epochs")
383
  batch_size = gr.Slider(1, 16, value=8, step=1, label="Batch Size")
384
- learning_rate = gr.Number(value=1e-5, label="Learning Rate", precision=6)
385
-
 
 
386
  train_btn = gr.Button("🚀 Start Training", variant="primary")
387
-
388
  with gr.Column():
389
  gr.Markdown("### Model Management")
390
- model_dropdown = gr.Dropdown(choices=get_available_models(), value="facebook/convnext-base-224-22k (default)", label="Select Model")
 
391
  refresh_models_btn = gr.Button("🔄 Refresh Models", size="sm")
392
- load_model_btn = gr.Button("📥 Load Selected Model", variant="secondary")
393
-
394
- model_status = gr.Markdown(f"**Current model:** {current_model_path}")
395
-
 
 
 
 
396
  gr.Markdown("### Training Status")
397
  training_output = gr.Markdown()
398
 
399
  with gr.TabItem("French Style arrangement"):
400
  gr.Markdown("## 🇫🇷 French-Style Flower Arrangements")
401
- gr.Markdown("Upload a flower image and generate an elegant French-style arrangement with matching colors!")
402
-
 
 
403
  with gr.Row():
404
  with gr.Column():
405
  upload_img = gr.Image(label="Upload Flower Image", type="pil")
406
- analyze_btn = gr.Button("🎨 Analyze & Generate French Style", variant="primary", size="lg")
407
-
 
 
 
 
408
  with gr.Column():
409
- french_result = gr.Image(label="Generated French-Style Arrangement", type="pil")
 
 
410
  french_status = gr.Markdown()
411
  analysis_details = gr.Markdown()
412
 
@@ -415,28 +532,37 @@ with gr.Blocks() as demo:
415
  # Auto-send generated image to Identify tab
416
  out.change(passthrough, inputs=out, outputs=img_in)
417
  # Run identification
418
- detect_btn.click(identify_flowers, [img_in, labels_box, topk, min_score], [results_tbl, status])
419
-
 
 
420
  # Training tab events
421
  refresh_btn.click(count_training_images, outputs=[data_status])
422
- refresh_models_btn.click(lambda: gr.Dropdown(choices=get_available_models()), outputs=[model_dropdown])
423
- load_model_btn.click(load_trained_model, inputs=[model_dropdown], outputs=[model_status])
424
- train_btn.click(start_training, inputs=[epochs, batch_size, learning_rate], outputs=[training_output])
425
-
 
 
426
  # French Style tab events - update status during processing
427
  def update_french_status():
428
  return "🔄 Processing... Please wait while we analyze your flower image...", ""
429
-
430
  analyze_btn.click(
431
- update_french_status,
432
- outputs=[french_status, analysis_details]
433
  ).then(
434
- analyze_and_generate_french_style,
435
- inputs=[upload_img],
436
- outputs=[french_result, french_status, analysis_details]
437
  )
438
-
439
  # Initialize data count on load
440
  demo.load(count_training_images, outputs=[data_status])
441
 
442
- demo.queue().launch()
 
 
1
  import glob
2
+ import os
3
+
4
+ import gradio as gr
5
  import numpy as np
6
+ import torch
7
+ from diffusers import AutoPipelineForText2Image
8
+ from simple_train import simple_train
9
  from sklearn.cluster import KMeans
10
+ from transformers import (
11
+ AutoImageProcessor,
12
+ AutoModelForImageClassification,
13
+ ConvNextForImageClassification,
14
+ ConvNextImageProcessor,
15
+ pipeline,
16
+ )
17
 
18
  MODEL_ID = os.getenv("MODEL_ID", "stabilityai/sdxl-turbo")
19
 
 
29
  else:
30
  pipe.enable_attention_slicing()
31
 
32
+
33
  def generate(prompt, steps, width, height, seed):
34
  if seed is None or int(seed) < 0:
35
  generator = None
 
39
  result = pipe(
40
  prompt=prompt,
41
  num_inference_steps=int(steps),
42
+ guidance_scale=0.0, # SDXL-Turbo works best at 0.0
43
  width=int(width // 8) * 8,
44
  height=int(height // 8) * 8,
45
+ generator=generator,
46
  )
47
  return result.images[0]
48
 
49
 
 
50
  # ---------- Flower identification (zero-shot) ----------
51
  # Curated label set; edit/extend as you like
52
  FLOWER_LABELS = [
53
+ "rose",
54
+ "tulip",
55
+ "lily",
56
+ "peony",
57
+ "sunflower",
58
+ "chrysanthemum",
59
+ "carnation",
60
+ "orchid",
61
+ "hydrangea",
62
+ "daisy",
63
+ "dahlia",
64
+ "ranunculus",
65
+ "anemone",
66
+ "marigold",
67
+ "lavender",
68
+ "magnolia",
69
+ "gardenia",
70
+ "camellia",
71
+ "jasmine",
72
+ "iris",
73
+ "gerbera",
74
+ "zinnia",
75
+ "hibiscus",
76
+ "lotus",
77
+ "poppy",
78
+ "sweet pea",
79
+ "freesia",
80
+ "lisianthus",
81
+ "calla lily",
82
+ "cherry blossom",
83
+ "plumeria",
84
+ "cosmos",
85
  ]
86
 
87
  # Initialize classifier - will be updated when trained model is loaded
 
91
  convnext_processor = None
92
  current_model_path = "facebook/convnext-base-224-22k"
93
 
94
+
95
  def load_classifier(model_path="facebook/convnext-base-224-22k"):
96
  global zs_classifier, convnext_model, convnext_processor, current_model_path
97
  try:
 
104
  zs_classifier = pipeline(
105
  task="zero-shot-image-classification",
106
  model="openai/clip-vit-base-patch32",
107
+ device=clf_device,
108
  )
109
  return f"✅ Loaded custom ConvNeXt model from: {model_path}"
110
  else:
111
  # Load default ConvNeXt model for feature extraction and fallback to CLIP for zero-shot
112
+ convnext_model = ConvNextForImageClassification.from_pretrained(
113
+ "facebook/convnext-base-224-22k"
114
+ )
115
+ convnext_processor = ConvNextImageProcessor.from_pretrained(
116
+ "facebook/convnext-base-224-22k"
117
+ )
118
  zs_classifier = pipeline(
119
  task="zero-shot-image-classification",
120
  model="openai/clip-vit-base-patch32",
121
+ device=clf_device,
122
  )
123
  current_model_path = "facebook/convnext-base-224-22k"
124
+ return "✅ Loaded default ConvNeXt model: facebook/convnext-base-224-22k"
125
  except Exception as e:
126
+ return f"❌ Error loading model: {e!s}"
127
+
128
 
129
  # Initialize with default model
130
  load_classifier()
131
 
132
+
133
  def identify_flowers(image, candidate_labels, top_k, min_score):
134
  if image is None:
135
  return [], "Please provide an image (upload or generate first)."
136
+
137
  labels = candidate_labels if candidate_labels else FLOWER_LABELS
138
+
139
  # Use ConvNeXt for feature extraction if we have a trained model, otherwise fallback to CLIP
140
+ if (
141
+ convnext_model is not None
142
+ and os.path.exists(current_model_path)
143
+ and current_model_path != "facebook/convnext-base-224-22k"
144
+ ):
145
  try:
146
  # Use trained ConvNeXt model
147
  inputs = convnext_processor(images=image, return_tensors="pt")
148
  with torch.no_grad():
149
  outputs = convnext_model(**inputs)
150
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
151
+
152
  # Convert predictions to results format
153
  results = []
154
  for i, score in enumerate(predictions[0]):
155
  if i < len(labels):
156
  results.append({"label": labels[i], "score": float(score)})
157
+
158
  # Sort by score
159
  results = sorted(results, key=lambda r: r["score"], reverse=True)
160
+ except Exception:
161
  # Fallback to CLIP zero-shot
162
  results = zs_classifier(
163
+ image, candidate_labels=labels, hypothesis_template="a photo of a {}"
 
 
164
  )
165
  else:
166
  # Use CLIP zero-shot classification
167
  results = zs_classifier(
168
+ image, candidate_labels=labels, hypothesis_template="a photo of a {}"
 
 
169
  )
170
+
171
  # Filter and format results
172
  results = [r for r in results if r["score"] >= float(min_score)]
173
+ results = sorted(results, key=lambda r: r["score"], reverse=True)[: int(top_k)]
174
  table = [[r["label"], round(float(r["score"]), 4)] for r in results]
175
+ model_type = (
176
+ "ConvNeXt"
177
+ if (
178
+ convnext_model is not None
179
+ and os.path.exists(current_model_path)
180
+ and current_model_path != "facebook/convnext-base-224-22k"
181
+ )
182
+ else "CLIP zero-shot"
183
+ )
184
  msg = f"Detected flowers using {model_type}."
185
  return table, msg
186
 
187
+
188
  # simple passthrough so the generated image appears in the Identify tab automatically
189
  def passthrough(img):
190
  return img
191
 
192
+
193
  # Training functions
194
  def get_available_models():
195
  models_dir = "training_data/trained_models"
196
  if not os.path.exists(models_dir):
197
  return ["facebook/convnext-base-224-22k (default)"]
198
+
199
  models = ["facebook/convnext-base-224-22k (default)"]
200
  for item in os.listdir(models_dir):
201
  model_path = os.path.join(models_dir, item)
202
+ if os.path.isdir(model_path) and os.path.exists(
203
+ os.path.join(model_path, "config.json")
204
+ ):
205
  models.append(f"Custom: {item}")
206
  return models
207
 
208
+
209
  def count_training_images():
210
  images_dir = "training_data/images"
211
  if not os.path.exists(images_dir):
212
  return "Training directory not found"
213
+
214
  total_images = 0
215
  flower_counts = {}
216
+
217
  for flower_type in os.listdir(images_dir):
218
  flower_path = os.path.join(images_dir, flower_type)
219
  if os.path.isdir(flower_path):
220
+ image_files = (
221
+ glob.glob(os.path.join(flower_path, "*.jpg"))
222
+ + glob.glob(os.path.join(flower_path, "*.jpeg"))
223
+ + glob.glob(os.path.join(flower_path, "*.png"))
224
+ + glob.glob(os.path.join(flower_path, "*.webp"))
225
+ )
226
  count = len(image_files)
227
  if count > 0:
228
  flower_counts[flower_type] = count
229
  total_images += count
230
+
231
  if total_images == 0:
232
  return "No training images found. Add images to subdirectories in training_data/images/"
233
+
234
  result = f"**Total images: {total_images}**\n\n"
235
  for flower_type, count in sorted(flower_counts.items()):
236
  result += f"- {flower_type}: {count} images\n"
237
+
238
  return result
239
 
240
+
241
  def start_training(epochs=None, batch_size=None, learning_rate=None):
242
  try:
243
  # Check if training data exists
244
  images_dir = "training_data/images"
245
  if not os.path.exists(images_dir):
246
  return "❌ Training directory not found. Please create training_data/images/ and add your data."
247
+
248
  # Count images
249
  total_images = 0
250
  for flower_type in os.listdir(images_dir):
251
  flower_path = os.path.join(images_dir, flower_type)
252
  if os.path.isdir(flower_path):
253
+ image_files = (
254
+ glob.glob(os.path.join(flower_path, "*.jpg"))
255
+ + glob.glob(os.path.join(flower_path, "*.jpeg"))
256
+ + glob.glob(os.path.join(flower_path, "*.png"))
257
+ + glob.glob(os.path.join(flower_path, "*.webp"))
258
+ )
259
  total_images += len(image_files)
260
+
261
  if total_images < 10:
262
  return f"❌ Need at least 10 training images. Found {total_images}. Add more images to training_data/images/"
263
+
264
  # Start training
265
  model_path = simple_train()
266
+
267
  if model_path:
268
  return f"✅ Training completed! Model saved to: {model_path}"
269
  else:
270
  return "❌ Training failed. Check the console for details."
271
+
272
  except Exception as e:
273
+ return f"❌ Training error: {e!s}"
274
+
275
 
276
  def load_trained_model(model_selection):
277
  if model_selection.startswith("Custom:"):
 
281
  else:
282
  return load_classifier("facebook/convnext-base-224-22k")
283
 
284
+
285
  # French-style arrangement functions
286
  def extract_dominant_colors(image, num_colors=5):
287
  """Extract dominant colors from an image using k-means clustering"""
288
  if image is None:
289
  return [], "No image provided"
290
+
291
  # Convert PIL image to numpy array
292
  img_array = np.array(image)
293
+
294
  # Reshape image to be a list of pixels
295
  pixels = img_array.reshape(-1, 3)
296
+
297
  # Use k-means to find dominant colors
298
  kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
299
  kmeans.fit(pixels)
300
+
301
  # Get the colors and convert to RGB values
302
  colors = kmeans.cluster_centers_.astype(int)
303
+
304
  # Convert to color names/descriptions
305
  color_names = []
306
  for color in colors:
 
328
  color_names.append("orange")
329
  else:
330
  color_names.append("cream")
331
+
332
  return color_names, colors
333
 
334
+
335
  def analyze_and_generate_french_style(image):
336
  """Analyze uploaded flower image and generate French-style arrangement"""
337
  if image is None:
338
  return None, "Please upload an image", ""
339
+
340
  # Identify the flower type
341
  if zs_classifier is None:
342
  return None, "Model not loaded", ""
343
+
344
  try:
345
  progress_log = "🔄 **Step 1/4:** Starting flower analysis...\n\n"
346
+
347
  # Identify flower
348
  progress_log += "🔍 Identifying flower type using AI model...\n"
349
  results = zs_classifier(
350
+ image, candidate_labels=FLOWER_LABELS, hypothesis_template="a photo of a {}"
 
 
351
  )
352
+
353
  top_flower = results[0]["label"] if results else "flower"
354
  confidence = results[0]["score"] if results else 0
355
+ progress_log += (
356
+ f"✅ Identified: **{top_flower}** (confidence: {confidence:.2%})\n\n"
357
+ )
358
+
359
  # Extract dominant colors
360
  progress_log += "🔄 **Step 2/4:** Analyzing color palette...\n\n"
361
  progress_log += "🎨 Extracting dominant colors from image...\n"
362
  color_names, color_rgb = extract_dominant_colors(image, num_colors=3)
363
+
364
  # Create color description
365
  main_colors = color_names[:3] # Top 3 colors
366
  color_desc = ", ".join(main_colors)
367
  progress_log += f"✅ Color palette: **{color_desc}**\n\n"
368
+
369
  # Generate French-style prompt
370
+ progress_log += (
371
+ "🔄 **Step 3/4:** Creating French-style arrangement prompt...\n\n"
372
+ )
373
  prompt = f"elegant French-style floral arrangement featuring {top_flower}s in {color_desc} colors, displayed in a clear crystal vase on a marble kitchen countertop, soft natural lighting, minimalist French country kitchen background, professional photography, sophisticated composition"
374
  progress_log += f"✅ Prompt created: *{prompt[:100]}...*\n\n"
375
+
376
  # Generate the image
377
+ progress_log += (
378
+ "🔄 **Step 4/4:** Generating French-style arrangement image...\n\n"
379
+ )
380
  progress_log += "🖼️ Using AI image generation (SDXL-Turbo)...\n"
381
  generated_image = generate(prompt, steps=4, width=1024, height=1024, seed=-1)
382
  progress_log += "✅ French-style arrangement generated successfully!\n\n"
383
+
384
  # Create analysis summary
385
  analysis = f"""
386
  **🌸 Flower Analysis:**
 
395
  **πŸ“‹ Process Log:**
396
  {progress_log}
397
  """
398
+
399
+ return (
400
+ generated_image,
401
+ "✅ Analysis complete! French-style arrangement generated.",
402
+ analysis,
403
+ )
404
+
405
  except Exception as e:
406
+ error_log = f"❌ **Error occurred during processing:**\n\n{e!s}\n\n"
407
+ if "progress_log" in locals():
408
  error_log += f"**Progress before error:**\n{progress_log}"
409
+ return None, f"❌ Error: {e!s}", error_log
410
+
411
 
412
  # ---------- UI ----------
413
  with gr.Blocks() as demo:
 
417
  with gr.TabItem("Generate"):
418
  with gr.Row():
419
  with gr.Column():
420
+ prompt = gr.Textbox(
421
+ value="ikebana-style flower arrangement, soft natural light, minimalist",
422
+ label="Prompt",
423
+ )
424
+ steps = gr.Slider(1, 8, value=4, step=1, label="Steps")
425
+ width = gr.Slider(512, 1536, value=1024, step=8, label="Width")
426
  height = gr.Slider(512, 1536, value=1024, step=8, label="Height")
427
+ seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
428
+ go = gr.Button("Generate", variant="primary")
429
  out = gr.Image(label="Result", type="pil")
430
 
431
  with gr.TabItem("Identify"):
432
  with gr.Row():
433
  with gr.Column():
434
+ img_in = gr.Image(
435
+ label="Image (upload or auto-filled from 'Generate')",
436
+ type="pil",
437
+ interactive=True,
438
+ )
439
+ labels_box = gr.CheckboxGroup(
440
+ choices=FLOWER_LABELS,
441
+ value=[
442
+ "rose",
443
+ "tulip",
444
+ "lily",
445
+ "peony",
446
+ "hydrangea",
447
+ "orchid",
448
+ "sunflower",
449
+ ],
450
+ label="Candidate labels (edit as needed)",
451
+ )
452
  topk = gr.Slider(1, 15, value=7, step=1, label="Top-K")
453
+ min_score = gr.Slider(
454
+ 0.0, 1.0, value=0.12, step=0.01, label="Min confidence"
455
+ )
456
  detect_btn = gr.Button("Identify Flowers", variant="primary")
457
  with gr.Column():
458
+ results_tbl = gr.Dataframe(
459
+ headers=["Flower", "Confidence"],
460
+ datatype=["str", "number"],
461
+ interactive=False,
462
+ )
463
  status = gr.Markdown()
464
 
465
  with gr.TabItem("Train Model"):
466
  gr.Markdown("## 🎯 Fine-tune the flower identification model")
467
+ gr.Markdown(
468
+ "Organize your training images in subdirectories by flower type in `training_data/images/`"
469
+ )
470
+ gr.Markdown(
471
+ "Example: `training_data/images/roses/`, `training_data/images/tulips/`, etc."
472
+ )
473
+
474
  with gr.Row():
475
  with gr.Column():
476
  gr.Markdown("### Training Data")
477
  refresh_btn = gr.Button("🔄 Refresh Data Count", size="sm")
478
  data_status = gr.Markdown()
479
+
480
  gr.Markdown("### Training Parameters")
481
  epochs = gr.Slider(1, 20, value=5, step=1, label="Training Epochs")
482
  batch_size = gr.Slider(1, 16, value=8, step=1, label="Batch Size")
483
+ learning_rate = gr.Number(
484
+ value=1e-5, label="Learning Rate", precision=6
485
+ )
486
+
487
  train_btn = gr.Button("🚀 Start Training", variant="primary")
488
+
489
  with gr.Column():
490
  gr.Markdown("### Model Management")
491
+ model_dropdown = gr.Dropdown(
492
+ choices=get_available_models(),
493
+ value="facebook/convnext-base-224-22k (default)",
494
+ label="Select Model",
495
+ )
496
  refresh_models_btn = gr.Button("🔄 Refresh Models", size="sm")
497
+ load_model_btn = gr.Button(
498
+ "📥 Load Selected Model", variant="secondary"
499
+ )
500
+
501
+ model_status = gr.Markdown(
502
+ f"**Current model:** {current_model_path}"
503
+ )
504
+
505
  gr.Markdown("### Training Status")
506
  training_output = gr.Markdown()
507
 
508
  with gr.TabItem("French Style arrangement"):
509
  gr.Markdown("## 🇫🇷 French-Style Flower Arrangements")
510
+ gr.Markdown(
511
+ "Upload a flower image and generate an elegant French-style arrangement with matching colors!"
512
+ )
513
+
514
  with gr.Row():
515
  with gr.Column():
516
  upload_img = gr.Image(label="Upload Flower Image", type="pil")
517
+ analyze_btn = gr.Button(
518
+ "🎨 Analyze & Generate French Style",
519
+ variant="primary",
520
+ size="lg",
521
+ )
522
+
523
  with gr.Column():
524
+ french_result = gr.Image(
525
+ label="Generated French-Style Arrangement", type="pil"
526
+ )
527
  french_status = gr.Markdown()
528
  analysis_details = gr.Markdown()
529
 
 
532
  # Auto-send generated image to Identify tab
533
  out.change(passthrough, inputs=out, outputs=img_in)
534
  # Run identification
535
+ detect_btn.click(
536
+ identify_flowers, [img_in, labels_box, topk, min_score], [results_tbl, status]
537
+ )
538
+
539
  # Training tab events
540
  refresh_btn.click(count_training_images, outputs=[data_status])
541
+ refresh_models_btn.click(
542
+ lambda: gr.Dropdown(choices=get_available_models()), outputs=[model_dropdown]
543
+ )
544
+ load_model_btn.click(
545
+ load_trained_model, inputs=[model_dropdown], outputs=[model_status]
546
+ )
547
+ train_btn.click(
548
+ start_training,
549
+ inputs=[epochs, batch_size, learning_rate],
550
+ outputs=[training_output],
551
+ )
552
+
553
  # French Style tab events - update status during processing
554
  def update_french_status():
555
  return "🔄 Processing... Please wait while we analyze your flower image...", ""
556
+
557
  analyze_btn.click(
558
+ update_french_status, outputs=[french_status, analysis_details]
 
559
  ).then(
560
+ analyze_and_generate_french_style,
561
+ inputs=[upload_img],
562
+ outputs=[french_result, french_status, analysis_details],
563
  )
564
+
565
  # Initialize data count on load
566
  demo.load(count_training_images, outputs=[data_status])
567
 
568
+ demo.queue().launch()
requirements.txt CHANGED
@@ -234,7 +234,7 @@ pydantic==2.10.6
234
  # via
235
  # fastapi
236
  # gradio
237
- pydantic-core==2.27.0
238
  # via pydantic
239
  pydub==0.25.1
240
  # via gradio
 
234
  # via
235
  # fastapi
236
  # gradio
237
+ pydantic-core==2.27.2
238
  # via pydantic
239
  pydub==0.25.1
240
  # via gradio
src/__init__.py CHANGED
@@ -1 +1 @@
1
- # Flowerify application package
 
1
+ # Flowerify application package
src/core/__init__.py CHANGED
@@ -1 +1 @@
1
- # Core package
 
1
+ # Core package
src/core/config.py CHANGED
@@ -3,26 +3,29 @@ Configuration management for the application.
3
  """
4
 
5
  import os
 
6
  import torch
 
7
  from .constants import DEFAULT_MODEL_ID
8
 
 
9
  class AppConfig:
10
  """Application configuration singleton."""
11
-
12
  def __init__(self):
13
  self._setup_device()
14
  self.model_id = DEFAULT_MODEL_ID
15
  # Auto-detect Hugging Face Spaces environment
16
  self.is_hf_spaces = os.getenv("SPACE_ID") is not None
17
  self._setup_cache_paths()
18
-
19
  def _setup_device(self):
20
  """Setup device configuration for PyTorch."""
21
  if torch.cuda.is_available():
22
  self.device = "cuda"
23
  self.dtype = torch.float16
24
  self.clf_device = 0
25
- elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
26
  self.device = "mps"
27
  self.dtype = torch.float16
28
  self.clf_device = 0
@@ -30,7 +33,7 @@ class AppConfig:
30
  self.device = "cpu"
31
  self.dtype = torch.float32
32
  self.clf_device = -1
33
-
34
  def _setup_cache_paths(self):
35
  """Setup cache paths based on environment."""
36
  if self.is_hf_spaces:
@@ -48,16 +51,17 @@ class AppConfig:
48
  print(f"🏠 Using configured HF_HOME: {os.getenv('HF_HOME')}")
49
  else:
50
  print("🏠 Using default Hugging Face cache")
51
-
52
  @property
53
  def is_cuda_available(self):
54
  """Check if CUDA is available."""
55
  return torch.cuda.is_available()
56
-
57
  @property
58
  def is_mps_available(self):
59
  """Check if Apple MPS is available."""
60
- return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
 
61
 
62
  # Global configuration instance
63
- config = AppConfig()
 
3
  """
4
 
5
  import os
6
+
7
  import torch
8
+
9
  from .constants import DEFAULT_MODEL_ID
10
 
11
+
12
  class AppConfig:
13
  """Application configuration singleton."""
14
+
15
  def __init__(self):
16
  self._setup_device()
17
  self.model_id = DEFAULT_MODEL_ID
18
  # Auto-detect Hugging Face Spaces environment
19
  self.is_hf_spaces = os.getenv("SPACE_ID") is not None
20
  self._setup_cache_paths()
21
+
22
  def _setup_device(self):
23
  """Setup device configuration for PyTorch."""
24
  if torch.cuda.is_available():
25
  self.device = "cuda"
26
  self.dtype = torch.float16
27
  self.clf_device = 0
28
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
29
  self.device = "mps"
30
  self.dtype = torch.float16
31
  self.clf_device = 0
 
33
  self.device = "cpu"
34
  self.dtype = torch.float32
35
  self.clf_device = -1
36
+
37
  def _setup_cache_paths(self):
38
  """Setup cache paths based on environment."""
39
  if self.is_hf_spaces:
 
51
  print(f"🏠 Using configured HF_HOME: {os.getenv('HF_HOME')}")
52
  else:
53
  print("🏠 Using default Hugging Face cache")
54
+
55
  @property
56
  def is_cuda_available(self):
57
  """Check if CUDA is available."""
58
  return torch.cuda.is_available()
59
+
60
  @property
61
  def is_mps_available(self):
62
  """Check if Apple MPS is available."""
63
+ return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
64
+
65
 
66
  # Global configuration instance
67
+ config = AppConfig()
src/core/constants.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  # If using external SSD, models will be cached at /Volumes/extssd/huggingface/hub
7
  # This is configured via environment variables (see .env file and run.sh script)
8
 
9
- # Model configuration
10
  DEFAULT_MODEL_ID = os.getenv("MODEL_ID", "stabilityai/stable-diffusion-xl-base-1.0")
11
  FALLBACK_MODEL_ID = "stabilityai/sdxl-turbo" # Lightweight fallback model
12
  DEFAULT_CONVNEXT_MODEL = "facebook/convnext-tiny-224"
@@ -62,4 +62,4 @@ DEFAULT_MIN_SCORE = 0.12
62
  DEFAULT_NUM_COLORS = 3
63
 
64
  # File extensions for image files
65
- SUPPORTED_IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".webp"]
 
6
  # If using external SSD, models will be cached at /Volumes/extssd/huggingface/hub
7
  # This is configured via environment variables (see .env file and run.sh script)
8
 
9
+ # Model configuration
10
  DEFAULT_MODEL_ID = os.getenv("MODEL_ID", "stabilityai/stable-diffusion-xl-base-1.0")
11
  FALLBACK_MODEL_ID = "stabilityai/sdxl-turbo" # Lightweight fallback model
12
  DEFAULT_CONVNEXT_MODEL = "facebook/convnext-tiny-224"
 
62
  DEFAULT_NUM_COLORS = 3
63
 
64
  # File extensions for image files
65
+ SUPPORTED_IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".webp"]
src/services/__init__.py CHANGED
@@ -1 +1 @@
1
- # Services package
 
1
+ # Services package
src/services/models/__init__.py CHANGED
@@ -1 +1 @@
1
- # Models package
 
1
+ # Models package
src/services/models/flower_classification.py CHANGED
@@ -3,93 +3,120 @@ Flower classification service using ConvNeXt and CLIP models.
3
  """
4
 
5
  import os
 
6
  import torch
 
7
  from transformers import (
8
- pipeline, ConvNextImageProcessor, ConvNextForImageClassification,
9
- AutoImageProcessor, AutoModelForImageClassification
 
 
 
10
  )
11
- from PIL import Image
12
- from typing import List, Dict, Tuple, Optional
13
 
14
  try:
15
  from core.config import config
16
- from core.constants import DEFAULT_CONVNEXT_MODEL, DEFAULT_CLIP_MODEL, FLOWER_LABELS, MODELS_DIR
 
 
 
 
 
17
  except ImportError:
18
- import sys
19
  import os
 
 
20
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
21
  from core.config import config
22
- from core.constants import DEFAULT_CONVNEXT_MODEL, DEFAULT_CLIP_MODEL, FLOWER_LABELS, MODELS_DIR
 
 
 
 
 
 
23
 
24
  class FlowerClassificationService:
25
  """Service for flower classification using ConvNeXt and CLIP models."""
26
-
27
  def __init__(self):
28
  self.zs_classifier = None
29
  self.convnext_model = None
30
  self.convnext_processor = None
31
  self.current_model_path = DEFAULT_CONVNEXT_MODEL
32
  self._initialize_models()
33
-
34
  def _initialize_models(self):
35
  """Initialize the classification models."""
36
  self.load_classifier()
37
-
38
  def load_classifier(self, model_path: str = DEFAULT_CONVNEXT_MODEL) -> str:
39
  """Load classification model from path."""
40
  try:
41
  if os.path.exists(model_path):
42
  # Load custom trained model
43
- self.convnext_model = AutoModelForImageClassification.from_pretrained(model_path)
 
 
44
  self.convnext_processor = AutoImageProcessor.from_pretrained(model_path)
45
  self.current_model_path = model_path
46
  # Also keep zero-shot classifier for fallback
47
  self.zs_classifier = pipeline(
48
  task="zero-shot-image-classification",
49
  model=DEFAULT_CLIP_MODEL,
50
- device=config.clf_device
51
  )
52
  return f"✅ Loaded custom ConvNeXt model from: {model_path}"
53
  else:
54
  # Load default ConvNeXt model for feature extraction and fallback to CLIP for zero-shot
55
- self.convnext_model = ConvNextForImageClassification.from_pretrained(DEFAULT_CONVNEXT_MODEL)
56
- self.convnext_processor = ConvNextImageProcessor.from_pretrained(DEFAULT_CONVNEXT_MODEL)
 
 
 
 
57
  self.zs_classifier = pipeline(
58
  task="zero-shot-image-classification",
59
  model=DEFAULT_CLIP_MODEL,
60
- device=config.clf_device
61
  )
62
  self.current_model_path = DEFAULT_CONVNEXT_MODEL
63
  return f"✅ Loaded default ConvNeXt model: {DEFAULT_CONVNEXT_MODEL}"
64
  except Exception as e:
65
- return f"❌ Error loading model: {str(e)}"
66
-
67
- def identify_flowers(self, image: Optional[Image.Image],
68
- candidate_labels: Optional[List[str]] = None,
69
- top_k: int = 7, min_score: float = 0.12) -> Tuple[List[List], str]:
 
 
 
 
70
  """Identify flowers in an image."""
71
  if image is None:
72
  return [], "Please provide an image (upload or generate first)."
73
-
74
  labels = candidate_labels if candidate_labels else FLOWER_LABELS
75
-
76
  # Use ConvNeXt for feature extraction if we have a trained model
77
- if (self.convnext_model is not None and
78
- os.path.exists(self.current_model_path) and
79
- self.current_model_path != DEFAULT_CONVNEXT_MODEL):
 
 
80
  try:
81
  # Use trained ConvNeXt model
82
  inputs = self.convnext_processor(images=image, return_tensors="pt")
83
  with torch.no_grad():
84
  outputs = self.convnext_model(**inputs)
85
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
86
-
87
  # Convert predictions to results format
88
  results = []
89
  for i, score in enumerate(predictions[0]):
90
  if i < len(labels):
91
  results.append({"label": labels[i], "score": float(score)})
92
-
93
  # Sort by score
94
  results = sorted(results, key=lambda r: r["score"], reverse=True)
95
  model_type = "ConvNeXt"
@@ -101,35 +128,36 @@ class FlowerClassificationService:
101
  # Use CLIP zero-shot classification
102
  results = self._use_clip_classification(image, labels)
103
  model_type = "CLIP zero-shot"
104
-
105
  # Filter and format results
106
  results = [r for r in results if r["score"] >= min_score]
107
  results = sorted(results, key=lambda r: r["score"], reverse=True)[:top_k]
108
  table = [[r["label"], round(float(r["score"]), 4)] for r in results]
109
  msg = f"Detected flowers using {model_type}."
110
  return table, msg
111
-
112
- def _use_clip_classification(self, image: Image.Image, labels: List[str]) -> List[Dict]:
 
 
113
  """Use CLIP zero-shot classification."""
114
  return self.zs_classifier(
115
- image,
116
- candidate_labels=labels,
117
- hypothesis_template="a photo of a {}"
118
  )
119
-
120
- def get_available_models(self) -> List[str]:
121
  """Get list of available models."""
122
  models = [f"{DEFAULT_CONVNEXT_MODEL} (default)"]
123
-
124
  if os.path.exists(MODELS_DIR):
125
  for item in os.listdir(MODELS_DIR):
126
  model_path = os.path.join(MODELS_DIR, item)
127
- if (os.path.isdir(model_path) and
128
- os.path.exists(os.path.join(model_path, "config.json"))):
 
129
  models.append(f"Custom: {item}")
130
-
131
  return models
132
-
133
  def load_trained_model(self, model_selection: str) -> str:
134
  """Load a specific trained model."""
135
  if model_selection.startswith("Custom:"):
@@ -139,5 +167,6 @@ class FlowerClassificationService:
139
  else:
140
  return self.load_classifier(DEFAULT_CONVNEXT_MODEL)
141
 
 
142
  # Global service instance
143
- flower_classifier = FlowerClassificationService()
 
3
  """
4
 
5
  import os
6
+
7
  import torch
8
+ from PIL import Image
9
  from transformers import (
10
+ AutoImageProcessor,
11
+ AutoModelForImageClassification,
12
+ ConvNextForImageClassification,
13
+ ConvNextImageProcessor,
14
+ pipeline,
15
  )
 
 
16
 
17
  try:
18
  from core.config import config
19
+ from core.constants import (
20
+ DEFAULT_CLIP_MODEL,
21
+ DEFAULT_CONVNEXT_MODEL,
22
+ FLOWER_LABELS,
23
+ MODELS_DIR,
24
+ )
25
  except ImportError:
 
26
  import os
27
+ import sys
28
+
29
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
30
  from core.config import config
31
+ from core.constants import (
32
+ DEFAULT_CLIP_MODEL,
33
+ DEFAULT_CONVNEXT_MODEL,
34
+ FLOWER_LABELS,
35
+ MODELS_DIR,
36
+ )
37
+
38
 
39
  class FlowerClassificationService:
40
  """Service for flower classification using ConvNeXt and CLIP models."""
41
+
42
  def __init__(self):
43
  self.zs_classifier = None
44
  self.convnext_model = None
45
  self.convnext_processor = None
46
  self.current_model_path = DEFAULT_CONVNEXT_MODEL
47
  self._initialize_models()
48
+
49
  def _initialize_models(self):
50
  """Initialize the classification models."""
51
  self.load_classifier()
52
+
53
  def load_classifier(self, model_path: str = DEFAULT_CONVNEXT_MODEL) -> str:
54
  """Load classification model from path."""
55
  try:
56
  if os.path.exists(model_path):
57
  # Load custom trained model
58
+ self.convnext_model = AutoModelForImageClassification.from_pretrained(
59
+ model_path
60
+ )
61
  self.convnext_processor = AutoImageProcessor.from_pretrained(model_path)
62
  self.current_model_path = model_path
63
  # Also keep zero-shot classifier for fallback
64
  self.zs_classifier = pipeline(
65
  task="zero-shot-image-classification",
66
  model=DEFAULT_CLIP_MODEL,
67
+ device=config.clf_device,
68
  )
69
  return f"✅ Loaded custom ConvNeXt model from: {model_path}"
70
  else:
71
  # Load default ConvNeXt model for feature extraction and fallback to CLIP for zero-shot
72
+ self.convnext_model = ConvNextForImageClassification.from_pretrained(
73
+ DEFAULT_CONVNEXT_MODEL
74
+ )
75
+ self.convnext_processor = ConvNextImageProcessor.from_pretrained(
76
+ DEFAULT_CONVNEXT_MODEL
77
+ )
78
  self.zs_classifier = pipeline(
79
  task="zero-shot-image-classification",
80
  model=DEFAULT_CLIP_MODEL,
81
+ device=config.clf_device,
82
  )
83
  self.current_model_path = DEFAULT_CONVNEXT_MODEL
84
  return f"✅ Loaded default ConvNeXt model: {DEFAULT_CONVNEXT_MODEL}"
85
  except Exception as e:
86
+ return f"❌ Error loading model: {e!s}"
87
+
88
+ def identify_flowers(
89
+ self,
90
+ image: Image.Image | None,
91
+ candidate_labels: list[str] | None = None,
92
+ top_k: int = 7,
93
+ min_score: float = 0.12,
94
+ ) -> tuple[list[list], str]:
95
  """Identify flowers in an image."""
96
  if image is None:
97
  return [], "Please provide an image (upload or generate first)."
98
+
99
  labels = candidate_labels if candidate_labels else FLOWER_LABELS
100
+
101
  # Use ConvNeXt for feature extraction if we have a trained model
102
+ if (
103
+ self.convnext_model is not None
104
+ and os.path.exists(self.current_model_path)
105
+ and self.current_model_path != DEFAULT_CONVNEXT_MODEL
106
+ ):
107
  try:
108
  # Use trained ConvNeXt model
109
  inputs = self.convnext_processor(images=image, return_tensors="pt")
110
  with torch.no_grad():
111
  outputs = self.convnext_model(**inputs)
112
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
113
+
114
  # Convert predictions to results format
115
  results = []
116
  for i, score in enumerate(predictions[0]):
117
  if i < len(labels):
118
  results.append({"label": labels[i], "score": float(score)})
119
+
120
  # Sort by score
121
  results = sorted(results, key=lambda r: r["score"], reverse=True)
122
  model_type = "ConvNeXt"
 
128
  # Use CLIP zero-shot classification
129
  results = self._use_clip_classification(image, labels)
130
  model_type = "CLIP zero-shot"
131
+
132
  # Filter and format results
133
  results = [r for r in results if r["score"] >= min_score]
134
  results = sorted(results, key=lambda r: r["score"], reverse=True)[:top_k]
135
  table = [[r["label"], round(float(r["score"]), 4)] for r in results]
136
  msg = f"Detected flowers using {model_type}."
137
  return table, msg
138
+
139
+ def _use_clip_classification(
140
+ self, image: Image.Image, labels: list[str]
141
+ ) -> list[dict]:
142
  """Use CLIP zero-shot classification."""
143
  return self.zs_classifier(
144
+ image, candidate_labels=labels, hypothesis_template="a photo of a {}"
 
 
145
  )
146
+
147
+ def get_available_models(self) -> list[str]:
148
  """Get list of available models."""
149
  models = [f"{DEFAULT_CONVNEXT_MODEL} (default)"]
150
+
151
  if os.path.exists(MODELS_DIR):
152
  for item in os.listdir(MODELS_DIR):
153
  model_path = os.path.join(MODELS_DIR, item)
154
+ if os.path.isdir(model_path) and os.path.exists(
155
+ os.path.join(model_path, "config.json")
156
+ ):
157
  models.append(f"Custom: {item}")
158
+
159
  return models
160
+
161
  def load_trained_model(self, model_selection: str) -> str:
162
  """Load a specific trained model."""
163
  if model_selection.startswith("Custom:"):
 
167
  else:
168
  return self.load_classifier(DEFAULT_CONVNEXT_MODEL)
169
 
170
+
171
  # Global service instance
172
+ flower_classifier = FlowerClassificationService()
src/services/models/image_generation.py CHANGED
@@ -1,9 +1,8 @@
1
  """Image generation service using SDXL models."""
2
 
3
- from typing import Optional
4
 
5
- import torch
6
  import numpy as np
 
7
  from diffusers import AutoPipelineForText2Image
8
  from PIL import Image
9
 
@@ -18,6 +17,7 @@ except ImportError:
18
  from core.config import config
19
  from core.constants import DEFAULT_MODEL_ID, FALLBACK_MODEL_ID
20
 
 
21
  class ImageGenerationService:
22
  """Service for generating images using SDXL models."""
23
 
@@ -35,7 +35,7 @@ class ImageGenerationService:
35
  else:
36
  model_id = DEFAULT_MODEL_ID
37
  model_name = "SDXL"
38
-
39
  # Try primary SDXL model
40
  try:
41
  print(f"🔄 Attempting to load {model_name}: {model_id}")
@@ -44,7 +44,7 @@ class ImageGenerationService:
44
  ).to(config.device)
45
  self.model_type = "SDXL"
46
  print(f"✅ {model_name} loaded successfully")
47
-
48
  # Enable SDXL-specific optimizations
49
  if config.device == "cuda":
50
  try:
@@ -53,10 +53,10 @@ class ImageGenerationService:
53
  self.pipe.enable_attention_slicing()
54
  else:
55
  self.pipe.enable_attention_slicing()
56
-
57
  except Exception as e:
58
  print(f"⚠️ {model_name} failed to load: {e}")
59
-
60
  # Try fallback to SDXL-Turbo if we're not on HF Spaces and not using it already
61
  if not config.is_hf_spaces and model_id != FALLBACK_MODEL_ID:
62
  try:
@@ -66,7 +66,7 @@ class ImageGenerationService:
66
  ).to(config.device)
67
  self.model_type = "SDXL"
68
  print("βœ… SDXL-Turbo loaded successfully")
69
-
70
  # Enable optimizations
71
  if config.device == "cuda":
72
  try:
@@ -78,17 +78,19 @@ class ImageGenerationService:
78
  return
79
  except Exception as turbo_error:
80
  print(f"⚠️ SDXL-Turbo also failed to load: {turbo_error}")
81
- raise RuntimeError(f"All SDXL models failed to load. Last error: {turbo_error}")
 
 
82
  else:
83
  raise RuntimeError(f"SDXL model failed to load: {e}")
84
-
85
  def generate(
86
  self,
87
  prompt: str,
88
  steps: int = 4,
89
  width: int = 1024,
90
  height: int = 1024,
91
- seed: Optional[int] = None,
92
  ) -> Image.Image:
93
  """Generate an image from a text prompt."""
94
  if seed is None or seed < 0:
@@ -112,10 +114,10 @@ class ImageGenerationService:
112
 
113
  # Validate and clean the image before returning
114
  image = result.images[0]
115
-
116
  # Convert to numpy array to check for invalid values
117
  img_array = np.array(image)
118
-
119
  # Check for NaN or inf values and replace them
120
  if np.any(np.isnan(img_array)) or np.any(np.isinf(img_array)):
121
  print("⚠️ Warning: Image contains invalid values (NaN/inf), cleaning...")
@@ -123,12 +125,13 @@ class ImageGenerationService:
123
  # Ensure values are in valid range [0, 255]
124
  img_array = np.clip(img_array, 0, 255).astype(np.uint8)
125
  image = Image.fromarray(img_array)
126
-
127
  return image
128
-
129
  def get_model_info(self) -> str:
130
  """Get information about the currently loaded model."""
131
  return f"Model: {self.model_type} (Stable Diffusion XL)"
132
 
 
133
  # Global service instance
134
- image_generator = ImageGenerationService()
 
1
  """Image generation service using SDXL models."""
2
 
 
3
 
 
4
  import numpy as np
5
+ import torch
6
  from diffusers import AutoPipelineForText2Image
7
  from PIL import Image
8
 
 
17
  from core.config import config
18
  from core.constants import DEFAULT_MODEL_ID, FALLBACK_MODEL_ID
19
 
20
+
21
  class ImageGenerationService:
22
  """Service for generating images using SDXL models."""
23
 
 
35
  else:
36
  model_id = DEFAULT_MODEL_ID
37
  model_name = "SDXL"
38
+
39
  # Try primary SDXL model
40
  try:
41
  print(f"πŸ”„ Attempting to load {model_name}: {model_id}")
 
44
  ).to(config.device)
45
  self.model_type = "SDXL"
46
  print(f"βœ… {model_name} loaded successfully")
47
+
48
  # Enable SDXL-specific optimizations
49
  if config.device == "cuda":
50
  try:
 
53
  self.pipe.enable_attention_slicing()
54
  else:
55
  self.pipe.enable_attention_slicing()
56
+
57
  except Exception as e:
58
  print(f"⚠️ {model_name} failed to load: {e}")
59
+
60
  # Try fallback to SDXL-Turbo if we're not on HF Spaces and not using it already
61
  if not config.is_hf_spaces and model_id != FALLBACK_MODEL_ID:
62
  try:
 
66
  ).to(config.device)
67
  self.model_type = "SDXL"
68
  print("βœ… SDXL-Turbo loaded successfully")
69
+
70
  # Enable optimizations
71
  if config.device == "cuda":
72
  try:
 
78
  return
79
  except Exception as turbo_error:
80
  print(f"⚠️ SDXL-Turbo also failed to load: {turbo_error}")
81
+ raise RuntimeError(
82
+ f"All SDXL models failed to load. Last error: {turbo_error}"
83
+ )
84
  else:
85
  raise RuntimeError(f"SDXL model failed to load: {e}")
86
+
87
  def generate(
88
  self,
89
  prompt: str,
90
  steps: int = 4,
91
  width: int = 1024,
92
  height: int = 1024,
93
+ seed: int | None = None,
94
  ) -> Image.Image:
95
  """Generate an image from a text prompt."""
96
  if seed is None or seed < 0:
 
114
 
115
  # Validate and clean the image before returning
116
  image = result.images[0]
117
+
118
  # Convert to numpy array to check for invalid values
119
  img_array = np.array(image)
120
+
121
  # Check for NaN or inf values and replace them
122
  if np.any(np.isnan(img_array)) or np.any(np.isinf(img_array)):
123
  print("⚠️ Warning: Image contains invalid values (NaN/inf), cleaning...")
 
125
  # Ensure values are in valid range [0, 255]
126
  img_array = np.clip(img_array, 0, 255).astype(np.uint8)
127
  image = Image.fromarray(img_array)
128
+
129
  return image
130
+
131
  def get_model_info(self) -> str:
132
  """Get information about the currently loaded model."""
133
  return f"Model: {self.model_type} (Stable Diffusion XL)"
134
 
135
+
136
  # Global service instance
137
+ image_generator = ImageGenerationService()
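
For reference, a minimal sketch of calling the generation service above directly (assumes src/ is on sys.path and that an SDXL checkpoint can actually be loaded on the current device):

    from services.models.image_generation import image_generator

    print(image_generator.get_model_info())  # "Model: SDXL (Stable Diffusion XL)"
    img = image_generator.generate(
        prompt="ikebana-style flower arrangement, soft natural light, minimalist",
        steps=4,
        width=1024,
        height=1024,
        seed=42,  # omit, or pass None / a negative value, for a random seed
    )
    img.save("arrangement.png")  # hypothetical output path
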
src/services/training/__init__.py CHANGED
@@ -1 +1 @@
1
- # Training package
 
1
+ # Training package
src/services/training/dataset.py CHANGED
@@ -3,54 +3,59 @@ Dataset class for flower training data.
3
  """
4
 
5
  import os
 
6
  import torch
7
  from PIL import Image
8
  from torch.utils.data import Dataset
9
- from typing import List, Optional
10
 
11
- from utils.file_utils import get_image_files, get_flower_types_from_directory
 
12
 
13
  class FlowerDataset(Dataset):
14
  """Dataset for flower classification training."""
15
-
16
- def __init__(self, image_dir: str, processor, flower_labels: Optional[List[str]] = None):
17
  self.image_paths = []
18
  self.labels = []
19
  self.processor = processor
20
-
21
  # Auto-detect flower types from directory structure if not provided
22
  if flower_labels is None:
23
  self.flower_labels = get_flower_types_from_directory(image_dir)
24
  else:
25
  self.flower_labels = flower_labels
26
-
27
  self.label_to_id = {label: idx for idx, label in enumerate(self.flower_labels)}
28
-
29
  # Load images from subdirectories (organized by flower type)
30
  for flower_type in os.listdir(image_dir):
31
  flower_path = os.path.join(image_dir, flower_type)
32
  if os.path.isdir(flower_path) and flower_type in self.label_to_id:
33
  image_files = get_image_files(flower_path)
34
-
35
  for img_path in image_files:
36
  self.image_paths.append(img_path)
37
  self.labels.append(self.label_to_id[flower_type])
38
-
39
- print(f"Loaded {len(self.image_paths)} images from {len(set(self.labels))} flower types")
 
 
40
  print(f"Flower types: {self.flower_labels}")
41
-
42
  def __len__(self):
43
  return len(self.image_paths)
44
-
45
  def __getitem__(self, idx):
46
  image_path = self.image_paths[idx]
47
  image = Image.open(image_path).convert("RGB")
48
  label = self.labels[idx]
49
-
50
  # Process image for ConvNeXt
51
  inputs = self.processor(images=image, return_tensors="pt")
52
-
53
  return {
54
- 'pixel_values': inputs['pixel_values'].squeeze(),
55
- 'labels': torch.tensor(label, dtype=torch.long)
56
- }
 
3
  """
4
 
5
  import os
6
+
7
  import torch
8
  from PIL import Image
9
  from torch.utils.data import Dataset
 
10
 
11
+ from utils.file_utils import get_flower_types_from_directory, get_image_files
12
+
13
 
14
  class FlowerDataset(Dataset):
15
  """Dataset for flower classification training."""
16
+
17
+ def __init__(
18
+ self, image_dir: str, processor, flower_labels: list[str] | None = None
19
+ ):
20
  self.image_paths = []
21
  self.labels = []
22
  self.processor = processor
23
+
24
  # Auto-detect flower types from directory structure if not provided
25
  if flower_labels is None:
26
  self.flower_labels = get_flower_types_from_directory(image_dir)
27
  else:
28
  self.flower_labels = flower_labels
29
+
30
  self.label_to_id = {label: idx for idx, label in enumerate(self.flower_labels)}
31
+
32
  # Load images from subdirectories (organized by flower type)
33
  for flower_type in os.listdir(image_dir):
34
  flower_path = os.path.join(image_dir, flower_type)
35
  if os.path.isdir(flower_path) and flower_type in self.label_to_id:
36
  image_files = get_image_files(flower_path)
37
+
38
  for img_path in image_files:
39
  self.image_paths.append(img_path)
40
  self.labels.append(self.label_to_id[flower_type])
41
+
42
+ print(
43
+ f"Loaded {len(self.image_paths)} images from {len(set(self.labels))} flower types"
44
+ )
45
  print(f"Flower types: {self.flower_labels}")
46
+
47
  def __len__(self):
48
  return len(self.image_paths)
49
+
50
  def __getitem__(self, idx):
51
  image_path = self.image_paths[idx]
52
  image = Image.open(image_path).convert("RGB")
53
  label = self.labels[idx]
54
+
55
  # Process image for ConvNeXt
56
  inputs = self.processor(images=image, return_tensors="pt")
57
+
58
  return {
59
+ "pixel_values": inputs["pixel_values"].squeeze(),
60
+ "labels": torch.tensor(label, dtype=torch.long),
61
+ }
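
A quick way to sanity-check the dataset above is to batch it with a plain DataLoader; since __getitem__ returns fixed-size tensors under "pixel_values" and "labels", the default collate function is enough. A sketch, assuming training_data/images/ is populated and src/ is on sys.path:

    from torch.utils.data import DataLoader
    from transformers import ConvNextImageProcessor

    from core.constants import DEFAULT_CONVNEXT_MODEL
    from services.training.dataset import FlowerDataset

    processor = ConvNextImageProcessor.from_pretrained(DEFAULT_CONVNEXT_MODEL)
    dataset = FlowerDataset("training_data/images", processor)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    batch = next(iter(loader))
    print(batch["pixel_values"].shape)  # (4, 3, H, W), as produced by the processor
    print(batch["labels"])              # integer class ids from label_to_id
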
src/services/training/training_service.py CHANGED
@@ -3,36 +3,38 @@ Training service for flower classification models.
3
  """
4
 
5
  import os
6
- from typing import Optional
7
 
8
  from core.constants import IMAGES_DIR
9
  from utils.file_utils import count_training_images
10
 
 
11
  class TrainingService:
12
  """Service for managing model training."""
13
-
14
  def __init__(self):
15
  pass
16
-
17
- def start_training(self, epochs: int = 5, batch_size: int = 8,
18
- learning_rate: float = 1e-5) -> str:
 
19
  """Start the training process."""
20
  try:
21
  # Check if training data exists
22
  if not os.path.exists(IMAGES_DIR):
23
  return "❌ Training directory not found. Please create training_data/images/ and add your data."
24
-
25
  # Count images
26
  total_images, _ = count_training_images()
27
-
28
  if total_images < 10:
29
  return f"❌ Need at least 10 training images. Found {total_images}. Add more images to training_data/images/"
30
-
31
  # Import and run training (lazy import to avoid startup issues)
32
  try:
33
  from training.simple_train import simple_train
 
34
  model_path = simple_train()
35
-
36
  if model_path:
37
  return f"βœ… Training completed! Model saved to: {model_path}"
38
  else:
@@ -41,17 +43,19 @@ class TrainingService:
41
  # Fallback to old training method
42
  try:
43
  from simple_train import simple_train as legacy_train
 
44
  model_path = legacy_train()
45
-
46
  if model_path:
47
  return f"βœ… Training completed! Model saved to: {model_path}"
48
  else:
49
  return "❌ Training failed. Check the console for details."
50
  except ImportError:
51
  return "❌ Training module not found. Please ensure training scripts are available."
52
-
53
  except Exception as e:
54
- return f"❌ Training error: {str(e)}"
 
55
 
56
  # Global service instance
57
- training_service = TrainingService()
 
3
  """
4
 
5
  import os
 
6
 
7
  from core.constants import IMAGES_DIR
8
  from utils.file_utils import count_training_images
9
 
10
+
11
  class TrainingService:
12
  """Service for managing model training."""
13
+
14
  def __init__(self):
15
  pass
16
+
17
+ def start_training(
18
+ self, epochs: int = 5, batch_size: int = 8, learning_rate: float = 1e-5
19
+ ) -> str:
20
  """Start the training process."""
21
  try:
22
  # Check if training data exists
23
  if not os.path.exists(IMAGES_DIR):
24
  return "❌ Training directory not found. Please create training_data/images/ and add your data."
25
+
26
  # Count images
27
  total_images, _ = count_training_images()
28
+
29
  if total_images < 10:
30
  return f"❌ Need at least 10 training images. Found {total_images}. Add more images to training_data/images/"
31
+
32
  # Import and run training (lazy import to avoid startup issues)
33
  try:
34
  from training.simple_train import simple_train
35
+
36
  model_path = simple_train()
37
+
38
  if model_path:
39
  return f"βœ… Training completed! Model saved to: {model_path}"
40
  else:
 
43
  # Fallback to old training method
44
  try:
45
  from simple_train import simple_train as legacy_train
46
+
47
  model_path = legacy_train()
48
+
49
  if model_path:
50
  return f"βœ… Training completed! Model saved to: {model_path}"
51
  else:
52
  return "❌ Training failed. Check the console for details."
53
  except ImportError:
54
  return "❌ Training module not found. Please ensure training scripts are available."
55
+
56
  except Exception as e:
57
+ return f"❌ Training error: {e!s}"
58
+
59
 
60
  # Global service instance
61
+ training_service = TrainingService()
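
For reference, the service above is a thin wrapper that returns a status string rather than raising (a sketch, assuming src/ is on sys.path). Note that epochs, batch_size and learning_rate are accepted here but simple_train() in src/training/simple_train.py currently runs with its own fixed values.

    from services.training.training_service import training_service

    status = training_service.start_training(epochs=5, batch_size=8, learning_rate=1e-5)
    print(status)  # "βœ… Training completed! ..." or a "❌ ..." message
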
src/training/__init__.py CHANGED
@@ -1 +1 @@
1
- # Training package
 
1
+ # Training package
src/training/simple_train.py CHANGED
@@ -3,114 +3,125 @@ Simple ConvNeXt training script without using the Transformers Trainer class.
3
  Refactored version of the original simple_train.py
4
  """
5
 
 
6
  import os
 
7
  import torch
8
- import torch.nn as nn
9
  from torch.utils.data import DataLoader
10
- from transformers import ConvNextImageProcessor, ConvNextForImageClassification
11
- import json
12
 
13
- from ..services.training.dataset import FlowerDataset
14
  from ..core.config import config
15
  from ..core.constants import DEFAULT_CONVNEXT_MODEL, MODELS_DIR
 
 
16
 
17
  def simple_train():
18
  """Simple ConvNeXt training function."""
19
  print("🌸 Simple ConvNeXt Flower Model Training")
20
  print("=" * 40)
21
-
22
  # Check training data
23
  images_dir = "training_data/images"
24
  if not os.path.exists(images_dir):
25
  print("❌ Training directory not found")
26
  return
27
-
28
  device = config.device
29
  print(f"Using device: {device}")
30
-
31
  # Load model and processor
32
  model_name = DEFAULT_CONVNEXT_MODEL
33
  model = ConvNextForImageClassification.from_pretrained(model_name)
34
  processor = ConvNextImageProcessor.from_pretrained(model_name)
35
  model.to(device)
36
-
37
  # Create dataset
38
  dataset = FlowerDataset(images_dir, processor)
39
-
40
  if len(dataset) < 5:
41
  print("❌ Need at least 5 images for training")
42
  return
43
-
44
  # Update model config for the number of classes
45
  if len(dataset.flower_labels) != model.config.num_labels:
46
  model.config.num_labels = len(dataset.flower_labels)
47
  # ConvNeXt uses hidden_sizes[-1] as the final hidden dimension
48
- final_hidden_size = model.config.hidden_sizes[-1] if hasattr(model.config, 'hidden_sizes') else 768
49
- model.classifier = torch.nn.Linear(final_hidden_size, len(dataset.flower_labels))
50
-
 
 
 
 
 
 
51
  # Split dataset
52
  train_size = int(0.8 * len(dataset))
53
  train_dataset = torch.utils.data.Subset(dataset, range(train_size))
54
-
55
  # Create data loader
56
  def simple_collate_fn(batch):
57
  pixel_values = []
58
  labels = []
59
-
60
  for item in batch:
61
- pixel_values.append(item['pixel_values'])
62
- labels.append(item['labels'])
63
-
64
  return {
65
- 'pixel_values': torch.stack(pixel_values),
66
- 'labels': torch.stack(labels)
67
  }
68
-
69
- train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=simple_collate_fn)
70
-
 
 
71
  # Setup optimizer
72
  optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
73
-
74
  # Training loop
75
  model.train()
76
  print(f"Starting training on {len(train_dataset)} samples...")
77
-
78
  for epoch in range(3):
79
  total_loss = 0
80
  num_batches = 0
81
-
82
  for batch_idx, batch in enumerate(train_loader):
83
  # Move to device
84
- pixel_values = batch['pixel_values'].to(device)
85
- labels = batch['labels'].to(device)
86
-
87
  # Zero gradients
88
  optimizer.zero_grad()
89
-
90
  # Forward pass
91
  outputs = model(pixel_values=pixel_values, labels=labels)
92
  loss = outputs.loss
93
-
94
  # Backward pass
95
  loss.backward()
96
  optimizer.step()
97
-
98
  total_loss += loss.item()
99
  num_batches += 1
100
-
101
  if batch_idx % 2 == 0:
102
- print(f"Epoch {epoch+1}, Batch {batch_idx+1}: Loss = {loss.item():.4f}")
103
-
 
 
104
  avg_loss = total_loss / num_batches if num_batches > 0 else 0
105
- print(f"Epoch {epoch+1} completed. Average loss: {avg_loss:.4f}")
106
-
107
  # Save model
108
  output_dir = os.path.join(MODELS_DIR, "simple_trained_convnext")
109
  os.makedirs(output_dir, exist_ok=True)
110
-
111
  model.save_pretrained(output_dir)
112
  processor.save_pretrained(output_dir)
113
-
114
  # Save config
115
  config_data = {
116
  "model_name": model_name,
@@ -119,15 +130,16 @@ def simple_train():
119
  "batch_size": 4,
120
  "learning_rate": 1e-5,
121
  "train_samples": len(train_dataset),
122
- "num_labels": len(dataset.flower_labels)
123
  }
124
-
125
  with open(os.path.join(output_dir, "training_config.json"), "w") as f:
126
  json.dump(config_data, f, indent=2)
127
-
128
  print(f"βœ… ConvNeXt training completed! Model saved to {output_dir}")
129
  return output_dir
130
 
 
131
  if __name__ == "__main__":
132
  try:
133
  simple_train()
@@ -136,4 +148,5 @@ if __name__ == "__main__":
136
  except Exception as e:
137
  print(f"❌ Training failed: {e}")
138
  import traceback
139
- traceback.print_exc()
 
 
3
  Refactored version of the original simple_train.py
4
  """
5
 
6
+ import json
7
  import os
8
+
9
  import torch
 
10
  from torch.utils.data import DataLoader
11
+ from transformers import ConvNextForImageClassification, ConvNextImageProcessor
 
12
 
 
13
  from ..core.config import config
14
  from ..core.constants import DEFAULT_CONVNEXT_MODEL, MODELS_DIR
15
+ from ..services.training.dataset import FlowerDataset
16
+
17
 
18
  def simple_train():
19
  """Simple ConvNeXt training function."""
20
  print("🌸 Simple ConvNeXt Flower Model Training")
21
  print("=" * 40)
22
+
23
  # Check training data
24
  images_dir = "training_data/images"
25
  if not os.path.exists(images_dir):
26
  print("❌ Training directory not found")
27
  return
28
+
29
  device = config.device
30
  print(f"Using device: {device}")
31
+
32
  # Load model and processor
33
  model_name = DEFAULT_CONVNEXT_MODEL
34
  model = ConvNextForImageClassification.from_pretrained(model_name)
35
  processor = ConvNextImageProcessor.from_pretrained(model_name)
36
  model.to(device)
37
+
38
  # Create dataset
39
  dataset = FlowerDataset(images_dir, processor)
40
+
41
  if len(dataset) < 5:
42
  print("❌ Need at least 5 images for training")
43
  return
44
+
45
  # Update model config for the number of classes
46
  if len(dataset.flower_labels) != model.config.num_labels:
47
  model.config.num_labels = len(dataset.flower_labels)
48
  # ConvNeXt uses hidden_sizes[-1] as the final hidden dimension
49
+ final_hidden_size = (
50
+ model.config.hidden_sizes[-1]
51
+ if hasattr(model.config, "hidden_sizes")
52
+ else 768
53
+ )
54
+ model.classifier = torch.nn.Linear(
55
+ final_hidden_size, len(dataset.flower_labels)
56
+ )
57
+
58
  # Split dataset
59
  train_size = int(0.8 * len(dataset))
60
  train_dataset = torch.utils.data.Subset(dataset, range(train_size))
61
+
62
  # Create data loader
63
  def simple_collate_fn(batch):
64
  pixel_values = []
65
  labels = []
66
+
67
  for item in batch:
68
+ pixel_values.append(item["pixel_values"])
69
+ labels.append(item["labels"])
70
+
71
  return {
72
+ "pixel_values": torch.stack(pixel_values),
73
+ "labels": torch.stack(labels),
74
  }
75
+
76
+ train_loader = DataLoader(
77
+ train_dataset, batch_size=4, shuffle=True, collate_fn=simple_collate_fn
78
+ )
79
+
80
  # Setup optimizer
81
  optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
82
+
83
  # Training loop
84
  model.train()
85
  print(f"Starting training on {len(train_dataset)} samples...")
86
+
87
  for epoch in range(3):
88
  total_loss = 0
89
  num_batches = 0
90
+
91
  for batch_idx, batch in enumerate(train_loader):
92
  # Move to device
93
+ pixel_values = batch["pixel_values"].to(device)
94
+ labels = batch["labels"].to(device)
95
+
96
  # Zero gradients
97
  optimizer.zero_grad()
98
+
99
  # Forward pass
100
  outputs = model(pixel_values=pixel_values, labels=labels)
101
  loss = outputs.loss
102
+
103
  # Backward pass
104
  loss.backward()
105
  optimizer.step()
106
+
107
  total_loss += loss.item()
108
  num_batches += 1
109
+
110
  if batch_idx % 2 == 0:
111
+ print(
112
+ f"Epoch {epoch + 1}, Batch {batch_idx + 1}: Loss = {loss.item():.4f}"
113
+ )
114
+
115
  avg_loss = total_loss / num_batches if num_batches > 0 else 0
116
+ print(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")
117
+
118
  # Save model
119
  output_dir = os.path.join(MODELS_DIR, "simple_trained_convnext")
120
  os.makedirs(output_dir, exist_ok=True)
121
+
122
  model.save_pretrained(output_dir)
123
  processor.save_pretrained(output_dir)
124
+
125
  # Save config
126
  config_data = {
127
  "model_name": model_name,
 
130
  "batch_size": 4,
131
  "learning_rate": 1e-5,
132
  "train_samples": len(train_dataset),
133
+ "num_labels": len(dataset.flower_labels),
134
  }
135
+
136
  with open(os.path.join(output_dir, "training_config.json"), "w") as f:
137
  json.dump(config_data, f, indent=2)
138
+
139
  print(f"βœ… ConvNeXt training completed! Model saved to {output_dir}")
140
  return output_dir
141
 
142
+
143
  if __name__ == "__main__":
144
  try:
145
  simple_train()
 
148
  except Exception as e:
149
  print(f"❌ Training failed: {e}")
150
  import traceback
151
+
152
+ traceback.print_exc()
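
For reference, simple_train() above expects training images grouped by flower type under training_data/images/, one subdirectory per label (the folder and file names below are illustrative):

    training_data/
      images/
        roses/     rose_01.jpg, rose_02.jpg, ...
        tulips/    tulip_01.jpg, ...
        peonies/   peony_01.jpg, ...

Each subdirectory containing images becomes a class label via get_flower_types_from_directory(), and at least 5 images are required before training starts.
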
src/ui/__init__.py CHANGED
@@ -1 +1 @@
1
- # UI package
 
1
+ # UI package
src/ui/french_style/__init__.py CHANGED
@@ -1 +1 @@
1
- # French style tab package
 
1
+ # French style tab package
src/ui/french_style/french_style_tab.py CHANGED
@@ -2,121 +2,129 @@
2
  French Style tab UI components and logic.
3
  """
4
 
 
5
  import gradio as gr
6
  from PIL import Image
7
- from typing import Optional, Tuple
8
 
9
  try:
 
10
  from services.models.flower_classification import flower_classifier
11
  from services.models.image_generation import image_generator
12
  from utils.color_utils import extract_dominant_colors
13
- from core.constants import FLOWER_LABELS, DEFAULT_NUM_COLORS
14
  except ImportError:
15
  # Handle when imported from root app.py
16
- import sys
17
  import os
 
 
18
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
 
19
  from services.models.flower_classification import flower_classifier
20
  from services.models.image_generation import image_generator
21
  from utils.color_utils import extract_dominant_colors
22
- from core.constants import FLOWER_LABELS, DEFAULT_NUM_COLORS
23
 
24
  class FrenchStyleTab:
25
  """UI component for the French Style tab."""
26
-
27
  def __init__(self):
28
  pass
29
-
30
  def create_ui(self) -> gr.TabItem:
31
  """Create the French Style tab UI."""
32
  with gr.TabItem("French Style arrangement") as tab:
33
  gr.Markdown("## πŸ‡«πŸ‡· French-Style Flower Arrangements")
34
- gr.Markdown("Upload a flower image and generate an elegant French-style arrangement with matching colors!")
35
-
 
 
36
  with gr.Row():
37
  with gr.Column():
38
  self.upload_img = gr.Image(label="Upload Flower Image", type="pil")
39
  self.analyze_btn = gr.Button(
40
- "🎨 Analyze & Generate French Style",
41
- variant="primary",
42
- size="lg"
43
  )
44
-
45
  with gr.Column():
46
  self.french_result = gr.Image(
47
- label="Generated French-Style Arrangement",
48
- type="pil"
49
  )
50
  self.french_status = gr.Markdown()
51
  self.analysis_details = gr.Markdown()
52
-
53
  # Wire events
54
  self.analyze_btn.click(
55
- self._update_status,
56
- outputs=[self.french_status, self.analysis_details]
57
  ).then(
58
  self.analyze_and_generate,
59
  inputs=[self.upload_img],
60
- outputs=[self.french_result, self.french_status, self.analysis_details]
61
  )
62
-
63
  return tab
64
-
65
- def _update_status(self) -> Tuple[str, str]:
66
  """Update status during processing."""
67
  return "πŸ”„ Processing... Please wait while we analyze your flower image...", ""
68
-
69
- def analyze_and_generate(self, image: Optional[Image.Image]) -> Tuple[Optional[Image.Image], str, str]:
 
 
70
  """Analyze uploaded flower image and generate French-style arrangement."""
71
  if image is None:
72
  return None, "Please upload an image", ""
73
-
74
  # Check if classifier is loaded
75
  if flower_classifier.zs_classifier is None:
76
  return None, "Model not loaded", ""
77
-
78
  try:
79
  progress_log = "πŸ”„ **Step 1/4:** Starting flower analysis...\n\n"
80
-
81
  # Identify flower
82
  progress_log += "πŸ” Identifying flower type using AI model...\n"
83
  results = flower_classifier._use_clip_classification(image, FLOWER_LABELS)
84
-
85
  top_flower = results[0]["label"] if results else "flower"
86
  confidence = results[0]["score"] if results else 0
87
- progress_log += f"βœ… Identified: **{top_flower}** (confidence: {confidence:.2%})\n\n"
88
-
 
 
89
  # Extract dominant colors
90
  progress_log += "πŸ”„ **Step 2/4:** Analyzing color palette...\n\n"
91
  progress_log += "🎨 Extracting dominant colors from image...\n"
92
- color_names, color_rgb = extract_dominant_colors(image, num_colors=DEFAULT_NUM_COLORS)
93
-
 
 
94
  # Create color description
95
  main_colors = color_names[:3] # Top 3 colors
96
  color_desc = ", ".join(main_colors)
97
  progress_log += f"βœ… Color palette: **{color_desc}**\n\n"
98
-
99
  # Generate French-style prompt
100
- progress_log += "πŸ”„ **Step 3/4:** Creating French-style arrangement prompt...\n\n"
 
 
101
  prompt = (
102
  f"elegant French-style floral arrangement featuring {top_flower}s in {color_desc} colors, "
103
  f"displayed in a clear crystal vase on a marble kitchen countertop, soft natural lighting, "
104
  f"minimalist French country kitchen background, professional photography, sophisticated composition"
105
  )
106
  progress_log += f"βœ… Prompt created: *{prompt[:100]}...*\n\n"
107
-
108
  # Generate the image
109
- progress_log += "πŸ”„ **Step 4/4:** Generating French-style arrangement image...\n\n"
 
 
110
  progress_log += "πŸ–ΌοΈ Using AI image generation (SDXL-Turbo)...\n"
111
  generated_image = image_generator.generate(
112
- prompt=prompt,
113
- steps=4,
114
- width=1024,
115
- height=1024,
116
- seed=None
117
  )
118
  progress_log += "βœ… French-style arrangement generated successfully!\n\n"
119
-
120
  # Create analysis summary
121
  analysis = f"""
122
  **🌸 Flower Analysis:**
@@ -131,11 +139,15 @@ class FrenchStyleTab:
131
  **πŸ“‹ Process Log:**
132
  {progress_log}
133
  """
134
-
135
- return generated_image, "βœ… Analysis complete! French-style arrangement generated.", analysis
136
-
 
 
 
 
137
  except Exception as e:
138
- error_log = f"❌ **Error occurred during processing:**\n\n{str(e)}\n\n"
139
- if 'progress_log' in locals():
140
  error_log += f"**Progress before error:**\n{progress_log}"
141
- return None, f"❌ Error: {str(e)}", error_log
 
2
  French Style tab UI components and logic.
3
  """
4
 
5
+
6
  import gradio as gr
7
  from PIL import Image
 
8
 
9
  try:
10
+ from core.constants import DEFAULT_NUM_COLORS, FLOWER_LABELS
11
  from services.models.flower_classification import flower_classifier
12
  from services.models.image_generation import image_generator
13
  from utils.color_utils import extract_dominant_colors
 
14
  except ImportError:
15
  # Handle when imported from root app.py
 
16
  import os
17
+ import sys
18
+
19
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
20
+ from core.constants import DEFAULT_NUM_COLORS, FLOWER_LABELS
21
  from services.models.flower_classification import flower_classifier
22
  from services.models.image_generation import image_generator
23
  from utils.color_utils import extract_dominant_colors
24
+
25
 
26
  class FrenchStyleTab:
27
  """UI component for the French Style tab."""
28
+
29
  def __init__(self):
30
  pass
31
+
32
  def create_ui(self) -> gr.TabItem:
33
  """Create the French Style tab UI."""
34
  with gr.TabItem("French Style arrangement") as tab:
35
  gr.Markdown("## πŸ‡«πŸ‡· French-Style Flower Arrangements")
36
+ gr.Markdown(
37
+ "Upload a flower image and generate an elegant French-style arrangement with matching colors!"
38
+ )
39
+
40
  with gr.Row():
41
  with gr.Column():
42
  self.upload_img = gr.Image(label="Upload Flower Image", type="pil")
43
  self.analyze_btn = gr.Button(
44
+ "🎨 Analyze & Generate French Style",
45
+ variant="primary",
46
+ size="lg",
47
  )
48
+
49
  with gr.Column():
50
  self.french_result = gr.Image(
51
+ label="Generated French-Style Arrangement", type="pil"
 
52
  )
53
  self.french_status = gr.Markdown()
54
  self.analysis_details = gr.Markdown()
55
+
56
  # Wire events
57
  self.analyze_btn.click(
58
+ self._update_status, outputs=[self.french_status, self.analysis_details]
 
59
  ).then(
60
  self.analyze_and_generate,
61
  inputs=[self.upload_img],
62
+ outputs=[self.french_result, self.french_status, self.analysis_details],
63
  )
64
+
65
  return tab
66
+
67
+ def _update_status(self) -> tuple[str, str]:
68
  """Update status during processing."""
69
  return "πŸ”„ Processing... Please wait while we analyze your flower image...", ""
70
+
71
+ def analyze_and_generate(
72
+ self, image: Image.Image | None
73
+ ) -> tuple[Image.Image | None, str, str]:
74
  """Analyze uploaded flower image and generate French-style arrangement."""
75
  if image is None:
76
  return None, "Please upload an image", ""
77
+
78
  # Check if classifier is loaded
79
  if flower_classifier.zs_classifier is None:
80
  return None, "Model not loaded", ""
81
+
82
  try:
83
  progress_log = "πŸ”„ **Step 1/4:** Starting flower analysis...\n\n"
84
+
85
  # Identify flower
86
  progress_log += "πŸ” Identifying flower type using AI model...\n"
87
  results = flower_classifier._use_clip_classification(image, FLOWER_LABELS)
88
+
89
  top_flower = results[0]["label"] if results else "flower"
90
  confidence = results[0]["score"] if results else 0
91
+ progress_log += (
92
+ f"βœ… Identified: **{top_flower}** (confidence: {confidence:.2%})\n\n"
93
+ )
94
+
95
  # Extract dominant colors
96
  progress_log += "πŸ”„ **Step 2/4:** Analyzing color palette...\n\n"
97
  progress_log += "🎨 Extracting dominant colors from image...\n"
98
+ color_names, color_rgb = extract_dominant_colors(
99
+ image, num_colors=DEFAULT_NUM_COLORS
100
+ )
101
+
102
  # Create color description
103
  main_colors = color_names[:3] # Top 3 colors
104
  color_desc = ", ".join(main_colors)
105
  progress_log += f"βœ… Color palette: **{color_desc}**\n\n"
106
+
107
  # Generate French-style prompt
108
+ progress_log += (
109
+ "πŸ”„ **Step 3/4:** Creating French-style arrangement prompt...\n\n"
110
+ )
111
  prompt = (
112
  f"elegant French-style floral arrangement featuring {top_flower}s in {color_desc} colors, "
113
  f"displayed in a clear crystal vase on a marble kitchen countertop, soft natural lighting, "
114
  f"minimalist French country kitchen background, professional photography, sophisticated composition"
115
  )
116
  progress_log += f"βœ… Prompt created: *{prompt[:100]}...*\n\n"
117
+
118
  # Generate the image
119
+ progress_log += (
120
+ "πŸ”„ **Step 4/4:** Generating French-style arrangement image...\n\n"
121
+ )
122
  progress_log += "πŸ–ΌοΈ Using AI image generation (SDXL-Turbo)...\n"
123
  generated_image = image_generator.generate(
124
+ prompt=prompt, steps=4, width=1024, height=1024, seed=None
 
 
 
 
125
  )
126
  progress_log += "βœ… French-style arrangement generated successfully!\n\n"
127
+
128
  # Create analysis summary
129
  analysis = f"""
130
  **🌸 Flower Analysis:**
 
139
  **πŸ“‹ Process Log:**
140
  {progress_log}
141
  """
142
+
143
+ return (
144
+ generated_image,
145
+ "βœ… Analysis complete! French-style arrangement generated.",
146
+ analysis,
147
+ )
148
+
149
  except Exception as e:
150
+ error_log = f"❌ **Error occurred during processing:**\n\n{e!s}\n\n"
151
+ if "progress_log" in locals():
152
  error_log += f"**Progress before error:**\n{progress_log}"
153
+ return None, f"❌ Error: {e!s}", error_log
src/ui/generate/__init__.py CHANGED
@@ -1 +1 @@
1
- # Generate tab package
 
1
+ # Generate tab package
src/ui/generate/generate_tab.py CHANGED
@@ -2,27 +2,29 @@
2
  Generate tab UI components and logic.
3
  """
4
 
 
5
  import gradio as gr
6
  from PIL import Image
7
- from typing import Optional
8
 
9
  try:
 
10
  from services.models.image_generation import image_generator
11
- from core.constants import DEFAULT_GENERATE_STEPS, DEFAULT_WIDTH, DEFAULT_HEIGHT
12
  except ImportError:
13
  # Handle when imported from root app.py
14
- import sys
15
  import os
 
 
16
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
 
17
  from services.models.image_generation import image_generator
18
- from core.constants import DEFAULT_GENERATE_STEPS, DEFAULT_WIDTH, DEFAULT_HEIGHT
19
 
20
  class GenerateTab:
21
  """UI component for the Generate tab."""
22
-
23
  def __init__(self):
24
  self.output_image = None
25
-
26
  def create_ui(self) -> gr.TabItem:
27
  """Create the Generate tab UI."""
28
  with gr.TabItem("Generate") as tab:
@@ -30,7 +32,7 @@ class GenerateTab:
30
  with gr.Column():
31
  self.prompt_input = gr.Textbox(
32
  value="ikebana-style flower arrangement, soft natural light, minimalist",
33
- label="Prompt"
34
  )
35
  self.steps_input = gr.Slider(
36
  1, 8, value=DEFAULT_GENERATE_STEPS, step=1, label="Steps"
@@ -45,23 +47,27 @@ class GenerateTab:
45
  value=-1, precision=0, label="Seed (-1 = random)"
46
  )
47
  self.generate_btn = gr.Button("Generate", variant="primary")
48
-
49
  self.output_image = gr.Image(label="Result", type="pil")
50
-
51
  # Wire events
52
  self.generate_btn.click(
53
  self.generate_image,
54
  inputs=[
55
- self.prompt_input, self.steps_input, self.width_input,
56
- self.height_input, self.seed_input
 
 
 
57
  ],
58
- outputs=self.output_image
59
  )
60
-
61
  return tab
62
-
63
- def generate_image(self, prompt: str, steps: int, width: int,
64
- height: int, seed: int) -> Optional[Image.Image]:
 
65
  """Generate an image from the given parameters."""
66
  try:
67
  return image_generator.generate(
@@ -69,8 +75,8 @@ class GenerateTab:
69
  steps=steps,
70
  width=width,
71
  height=height,
72
- seed=seed if seed >= 0 else None
73
  )
74
  except Exception as e:
75
- gr.Warning(f"Error generating image: {str(e)}")
76
- return None
 
2
  Generate tab UI components and logic.
3
  """
4
 
5
+
6
  import gradio as gr
7
  from PIL import Image
 
8
 
9
  try:
10
+ from core.constants import DEFAULT_GENERATE_STEPS, DEFAULT_HEIGHT, DEFAULT_WIDTH
11
  from services.models.image_generation import image_generator
 
12
  except ImportError:
13
  # Handle when imported from root app.py
 
14
  import os
15
+ import sys
16
+
17
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
18
+ from core.constants import DEFAULT_GENERATE_STEPS, DEFAULT_HEIGHT, DEFAULT_WIDTH
19
  from services.models.image_generation import image_generator
20
+
21
 
22
  class GenerateTab:
23
  """UI component for the Generate tab."""
24
+
25
  def __init__(self):
26
  self.output_image = None
27
+
28
  def create_ui(self) -> gr.TabItem:
29
  """Create the Generate tab UI."""
30
  with gr.TabItem("Generate") as tab:
 
32
  with gr.Column():
33
  self.prompt_input = gr.Textbox(
34
  value="ikebana-style flower arrangement, soft natural light, minimalist",
35
+ label="Prompt",
36
  )
37
  self.steps_input = gr.Slider(
38
  1, 8, value=DEFAULT_GENERATE_STEPS, step=1, label="Steps"
 
47
  value=-1, precision=0, label="Seed (-1 = random)"
48
  )
49
  self.generate_btn = gr.Button("Generate", variant="primary")
50
+
51
  self.output_image = gr.Image(label="Result", type="pil")
52
+
53
  # Wire events
54
  self.generate_btn.click(
55
  self.generate_image,
56
  inputs=[
57
+ self.prompt_input,
58
+ self.steps_input,
59
+ self.width_input,
60
+ self.height_input,
61
+ self.seed_input,
62
  ],
63
+ outputs=self.output_image,
64
  )
65
+
66
  return tab
67
+
68
+ def generate_image(
69
+ self, prompt: str, steps: int, width: int, height: int, seed: int
70
+ ) -> Image.Image | None:
71
  """Generate an image from the given parameters."""
72
  try:
73
  return image_generator.generate(
 
75
  steps=steps,
76
  width=width,
77
  height=height,
78
+ seed=seed if seed >= 0 else None,
79
  )
80
  except Exception as e:
81
+ gr.Warning(f"Error generating image: {e!s}")
82
+ return None
src/ui/identify/__init__.py CHANGED
@@ -1 +1 @@
1
- # Identify tab package
 
1
+ # Identify tab package
src/ui/identify/identify_tab.py CHANGED
@@ -2,27 +2,29 @@
2
  Identify tab UI components and logic.
3
  """
4
 
 
5
  import gradio as gr
6
  from PIL import Image
7
- from typing import List, Optional, Tuple
8
 
9
  try:
 
10
  from services.models.flower_classification import flower_classifier
11
- from core.constants import FLOWER_LABELS, DEFAULT_TOP_K, DEFAULT_MIN_SCORE
12
  except ImportError:
13
  # Handle when imported from root app.py
14
- import sys
15
  import os
 
 
16
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
 
17
  from services.models.flower_classification import flower_classifier
18
- from core.constants import FLOWER_LABELS, DEFAULT_TOP_K, DEFAULT_MIN_SCORE
19
 
20
  class IdentifyTab:
21
  """UI component for the Identify tab."""
22
-
23
  def __init__(self):
24
  pass
25
-
26
  def create_ui(self) -> gr.TabItem:
27
  """Create the Identify tab UI."""
28
  with gr.TabItem("Identify") as tab:
@@ -31,52 +33,70 @@ class IdentifyTab:
31
  self.image_input = gr.Image(
32
  label="Image (upload or auto-filled from 'Generate')",
33
  type="pil",
34
- interactive=True
35
  )
36
  self.labels_input = gr.CheckboxGroup(
37
  choices=FLOWER_LABELS,
38
- value=["rose", "tulip", "lily", "peony", "hydrangea", "orchid", "sunflower"],
39
- label="Candidate labels (edit as needed)"
 
 
 
 
 
 
 
 
40
  )
41
  self.topk_input = gr.Slider(
42
  1, 15, value=DEFAULT_TOP_K, step=1, label="Top-K"
43
  )
44
  self.min_score_input = gr.Slider(
45
- 0.0, 1.0, value=DEFAULT_MIN_SCORE, step=0.01, label="Min confidence"
 
 
 
 
46
  )
47
  self.detect_btn = gr.Button("Identify Flowers", variant="primary")
48
-
49
  with gr.Column():
50
  self.results_table = gr.Dataframe(
51
  headers=["Flower", "Confidence"],
52
  datatype=["str", "number"],
53
- interactive=False
54
  )
55
  self.status_output = gr.Markdown()
56
-
57
  # Wire events
58
  self.detect_btn.click(
59
  self.identify_flowers,
60
  inputs=[
61
- self.image_input, self.labels_input,
62
- self.topk_input, self.min_score_input
 
 
63
  ],
64
- outputs=[self.results_table, self.status_output]
65
  )
66
-
67
  return tab
68
-
69
- def identify_flowers(self, image: Optional[Image.Image],
70
- candidate_labels: List[str], top_k: int,
71
- min_score: float) -> Tuple[List[List], str]:
 
 
 
 
72
  """Identify flowers in the provided image."""
73
  return flower_classifier.identify_flowers(
74
  image=image,
75
  candidate_labels=candidate_labels,
76
  top_k=top_k,
77
- min_score=min_score
78
  )
79
-
80
- def set_image(self, image: Optional[Image.Image]) -> Optional[Image.Image]:
81
  """Set the image for identification (used by other tabs)."""
82
- return image
 
2
  Identify tab UI components and logic.
3
  """
4
 
5
+
6
  import gradio as gr
7
  from PIL import Image
 
8
 
9
  try:
10
+ from core.constants import DEFAULT_MIN_SCORE, DEFAULT_TOP_K, FLOWER_LABELS
11
  from services.models.flower_classification import flower_classifier
 
12
  except ImportError:
13
  # Handle when imported from root app.py
 
14
  import os
15
+ import sys
16
+
17
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
18
+ from core.constants import DEFAULT_MIN_SCORE, DEFAULT_TOP_K, FLOWER_LABELS
19
  from services.models.flower_classification import flower_classifier
20
+
21
 
22
  class IdentifyTab:
23
  """UI component for the Identify tab."""
24
+
25
  def __init__(self):
26
  pass
27
+
28
  def create_ui(self) -> gr.TabItem:
29
  """Create the Identify tab UI."""
30
  with gr.TabItem("Identify") as tab:
 
33
  self.image_input = gr.Image(
34
  label="Image (upload or auto-filled from 'Generate')",
35
  type="pil",
36
+ interactive=True,
37
  )
38
  self.labels_input = gr.CheckboxGroup(
39
  choices=FLOWER_LABELS,
40
+ value=[
41
+ "rose",
42
+ "tulip",
43
+ "lily",
44
+ "peony",
45
+ "hydrangea",
46
+ "orchid",
47
+ "sunflower",
48
+ ],
49
+ label="Candidate labels (edit as needed)",
50
  )
51
  self.topk_input = gr.Slider(
52
  1, 15, value=DEFAULT_TOP_K, step=1, label="Top-K"
53
  )
54
  self.min_score_input = gr.Slider(
55
+ 0.0,
56
+ 1.0,
57
+ value=DEFAULT_MIN_SCORE,
58
+ step=0.01,
59
+ label="Min confidence",
60
  )
61
  self.detect_btn = gr.Button("Identify Flowers", variant="primary")
62
+
63
  with gr.Column():
64
  self.results_table = gr.Dataframe(
65
  headers=["Flower", "Confidence"],
66
  datatype=["str", "number"],
67
+ interactive=False,
68
  )
69
  self.status_output = gr.Markdown()
70
+
71
  # Wire events
72
  self.detect_btn.click(
73
  self.identify_flowers,
74
  inputs=[
75
+ self.image_input,
76
+ self.labels_input,
77
+ self.topk_input,
78
+ self.min_score_input,
79
  ],
80
+ outputs=[self.results_table, self.status_output],
81
  )
82
+
83
  return tab
84
+
85
+ def identify_flowers(
86
+ self,
87
+ image: Image.Image | None,
88
+ candidate_labels: list[str],
89
+ top_k: int,
90
+ min_score: float,
91
+ ) -> tuple[list[list], str]:
92
  """Identify flowers in the provided image."""
93
  return flower_classifier.identify_flowers(
94
  image=image,
95
  candidate_labels=candidate_labels,
96
  top_k=top_k,
97
+ min_score=min_score,
98
  )
99
+
100
+ def set_image(self, image: Image.Image | None) -> Image.Image | None:
101
  """Set the image for identification (used by other tabs)."""
102
+ return image
src/ui/train/__init__.py CHANGED
@@ -1 +1 @@
1
- # Train tab package
 
1
+ # Train tab package
src/ui/train/train_tab.py CHANGED
@@ -2,8 +2,8 @@
2
  Train Model tab UI components and logic.
3
  """
4
 
 
5
  import gradio as gr
6
- from typing import List
7
 
8
  try:
9
  from services.models.flower_classification import flower_classifier
@@ -11,32 +11,38 @@ try:
11
  from utils.file_utils import count_training_images
12
  except ImportError:
13
  # Handle when imported from root app.py
14
- import sys
15
  import os
 
 
16
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
17
  from services.models.flower_classification import flower_classifier
18
  from services.training.training_service import training_service
19
  from utils.file_utils import count_training_images
20
 
 
21
  class TrainTab:
22
  """UI component for the Train Model tab."""
23
-
24
  def __init__(self):
25
  pass
26
-
27
  def create_ui(self) -> gr.TabItem:
28
  """Create the Train Model tab UI."""
29
  with gr.TabItem("Train Model") as tab:
30
  gr.Markdown("## 🎯 Fine-tune the flower identification model")
31
- gr.Markdown("Organize your training images in subdirectories by flower type in `training_data/images/`")
32
- gr.Markdown("Example: `training_data/images/roses/`, `training_data/images/tulips/`, etc.")
33
-
 
 
 
 
34
  with gr.Row():
35
  with gr.Column():
36
  gr.Markdown("### Training Data")
37
  self.refresh_btn = gr.Button("πŸ”„ Refresh Data Count", size="sm")
38
  self.data_status = gr.Markdown()
39
-
40
  gr.Markdown("### Training Parameters")
41
  self.epochs_input = gr.Slider(
42
  1, 20, value=5, step=1, label="Training Epochs"
@@ -47,69 +53,71 @@ class TrainTab:
47
  self.learning_rate_input = gr.Number(
48
  value=1e-5, label="Learning Rate", precision=6
49
  )
50
-
51
  self.train_btn = gr.Button("πŸš€ Start Training", variant="primary")
52
-
53
  with gr.Column():
54
  gr.Markdown("### Model Management")
55
  self.model_dropdown = gr.Dropdown(
56
  choices=flower_classifier.get_available_models(),
57
  value=f"{flower_classifier.current_model_path} (default)",
58
- label="Select Model"
59
  )
60
  self.refresh_models_btn = gr.Button("πŸ”„ Refresh Models", size="sm")
61
- self.load_model_btn = gr.Button("πŸ“₯ Load Selected Model", variant="secondary")
62
-
 
 
63
  self.model_status = gr.Markdown(
64
  f"**Current model:** {flower_classifier.current_model_path}"
65
  )
66
-
67
  gr.Markdown("### Training Status")
68
  self.training_output = gr.Markdown()
69
-
70
  # Wire events
71
  self.refresh_btn.click(self._count_training_images, outputs=[self.data_status])
72
  self.refresh_models_btn.click(
73
  self._refresh_models, outputs=[self.model_dropdown]
74
  )
75
  self.load_model_btn.click(
76
- self._load_trained_model,
77
- inputs=[self.model_dropdown],
78
- outputs=[self.model_status]
79
  )
80
  self.train_btn.click(
81
  self._start_training,
82
  inputs=[self.epochs_input, self.batch_size_input, self.learning_rate_input],
83
- outputs=[self.training_output]
84
  )
85
-
86
  return tab
87
-
88
  def _count_training_images(self) -> str:
89
  """Count and display training images."""
90
  total_images, flower_counts = count_training_images()
91
-
92
  if total_images == 0:
93
  return "No training images found. Add images to subdirectories in training_data/images/"
94
-
95
  result = f"**Total images: {total_images}**\n\n"
96
  for flower_type, count in sorted(flower_counts.items()):
97
  result += f"- {flower_type}: {count} images\n"
98
-
99
  return result
100
-
101
  def _refresh_models(self) -> gr.Dropdown:
102
  """Refresh the list of available models."""
103
  return gr.Dropdown(choices=flower_classifier.get_available_models())
104
-
105
  def _load_trained_model(self, model_selection: str) -> str:
106
  """Load the selected trained model."""
107
  return flower_classifier.load_trained_model(model_selection)
108
-
109
- def _start_training(self, epochs: int, batch_size: int, learning_rate: float) -> str:
 
 
110
  """Start the training process."""
111
  return training_service.start_training(
112
- epochs=epochs,
113
- batch_size=batch_size,
114
- learning_rate=learning_rate
115
- )
 
2
  Train Model tab UI components and logic.
3
  """
4
 
5
+
6
  import gradio as gr
 
7
 
8
  try:
9
  from services.models.flower_classification import flower_classifier
 
11
  from utils.file_utils import count_training_images
12
  except ImportError:
13
  # Handle when imported from root app.py
 
14
  import os
15
+ import sys
16
+
17
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
18
  from services.models.flower_classification import flower_classifier
19
  from services.training.training_service import training_service
20
  from utils.file_utils import count_training_images
21
 
22
+
23
  class TrainTab:
24
  """UI component for the Train Model tab."""
25
+
26
  def __init__(self):
27
  pass
28
+
29
  def create_ui(self) -> gr.TabItem:
30
  """Create the Train Model tab UI."""
31
  with gr.TabItem("Train Model") as tab:
32
  gr.Markdown("## 🎯 Fine-tune the flower identification model")
33
+ gr.Markdown(
34
+ "Organize your training images in subdirectories by flower type in `training_data/images/`"
35
+ )
36
+ gr.Markdown(
37
+ "Example: `training_data/images/roses/`, `training_data/images/tulips/`, etc."
38
+ )
39
+
40
  with gr.Row():
41
  with gr.Column():
42
  gr.Markdown("### Training Data")
43
  self.refresh_btn = gr.Button("πŸ”„ Refresh Data Count", size="sm")
44
  self.data_status = gr.Markdown()
45
+
46
  gr.Markdown("### Training Parameters")
47
  self.epochs_input = gr.Slider(
48
  1, 20, value=5, step=1, label="Training Epochs"
 
53
  self.learning_rate_input = gr.Number(
54
  value=1e-5, label="Learning Rate", precision=6
55
  )
56
+
57
  self.train_btn = gr.Button("πŸš€ Start Training", variant="primary")
58
+
59
  with gr.Column():
60
  gr.Markdown("### Model Management")
61
  self.model_dropdown = gr.Dropdown(
62
  choices=flower_classifier.get_available_models(),
63
  value=f"{flower_classifier.current_model_path} (default)",
64
+ label="Select Model",
65
  )
66
  self.refresh_models_btn = gr.Button("πŸ”„ Refresh Models", size="sm")
67
+ self.load_model_btn = gr.Button(
68
+ "πŸ“₯ Load Selected Model", variant="secondary"
69
+ )
70
+
71
  self.model_status = gr.Markdown(
72
  f"**Current model:** {flower_classifier.current_model_path}"
73
  )
74
+
75
  gr.Markdown("### Training Status")
76
  self.training_output = gr.Markdown()
77
+
78
  # Wire events
79
  self.refresh_btn.click(self._count_training_images, outputs=[self.data_status])
80
  self.refresh_models_btn.click(
81
  self._refresh_models, outputs=[self.model_dropdown]
82
  )
83
  self.load_model_btn.click(
84
+ self._load_trained_model,
85
+ inputs=[self.model_dropdown],
86
+ outputs=[self.model_status],
87
  )
88
  self.train_btn.click(
89
  self._start_training,
90
  inputs=[self.epochs_input, self.batch_size_input, self.learning_rate_input],
91
+ outputs=[self.training_output],
92
  )
93
+
94
  return tab
95
+
96
  def _count_training_images(self) -> str:
97
  """Count and display training images."""
98
  total_images, flower_counts = count_training_images()
99
+
100
  if total_images == 0:
101
  return "No training images found. Add images to subdirectories in training_data/images/"
102
+
103
  result = f"**Total images: {total_images}**\n\n"
104
  for flower_type, count in sorted(flower_counts.items()):
105
  result += f"- {flower_type}: {count} images\n"
106
+
107
  return result
108
+
109
  def _refresh_models(self) -> gr.Dropdown:
110
  """Refresh the list of available models."""
111
  return gr.Dropdown(choices=flower_classifier.get_available_models())
112
+
113
  def _load_trained_model(self, model_selection: str) -> str:
114
  """Load the selected trained model."""
115
  return flower_classifier.load_trained_model(model_selection)
116
+
117
+ def _start_training(
118
+ self, epochs: int, batch_size: int, learning_rate: float
119
+ ) -> str:
120
  """Start the training process."""
121
  return training_service.start_training(
122
+ epochs=epochs, batch_size=batch_size, learning_rate=learning_rate
123
+ )
 
 
src/utils/__init__.py CHANGED
@@ -1 +1 @@
1
- # Utils package
 
1
+ # Utils package
src/utils/color_utils.py CHANGED
@@ -2,38 +2,42 @@
2
  Color analysis utilities.
3
  """
4
 
 
5
  import numpy as np
6
  from PIL import Image
7
  from sklearn.cluster import KMeans
8
- from typing import List, Tuple, Optional
9
 
10
- def extract_dominant_colors(image: Optional[Image.Image], num_colors: int = 5) -> Tuple[List[str], np.ndarray]:
 
 
 
11
  """Extract dominant colors from an image using k-means clustering."""
12
  if image is None:
13
  return [], np.array([])
14
-
15
  # Convert PIL image to numpy array
16
  img_array = np.array(image)
17
-
18
  # Reshape image to be a list of pixels
19
  pixels = img_array.reshape(-1, 3)
20
-
21
  # Use k-means to find dominant colors
22
  kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
23
  kmeans.fit(pixels)
24
-
25
  # Get the colors and convert to RGB values
26
  colors = kmeans.cluster_centers_.astype(int)
27
-
28
  # Convert to color names/descriptions
29
  color_names = [_rgb_to_color_name(color) for color in colors]
30
-
31
  return color_names, colors
32
 
 
33
  def _rgb_to_color_name(color: np.ndarray) -> str:
34
  """Convert RGB values to descriptive color name."""
35
  r, g, b = color
36
-
37
  if r > 200 and g > 200 and b > 200:
38
  return "white"
39
  elif r < 50 and g < 50 and b < 50:
@@ -56,4 +60,4 @@ def _rgb_to_color_name(color: np.ndarray) -> str:
56
  elif r > 150 and g > 100 and b < 100:
57
  return "orange"
58
  else:
59
- return "cream"
 
2
  Color analysis utilities.
3
  """
4
 
5
+
6
  import numpy as np
7
  from PIL import Image
8
  from sklearn.cluster import KMeans
 
9
 
10
+
11
+ def extract_dominant_colors(
12
+ image: Image.Image | None, num_colors: int = 5
13
+ ) -> tuple[list[str], np.ndarray]:
14
  """Extract dominant colors from an image using k-means clustering."""
15
  if image is None:
16
  return [], np.array([])
17
+
18
  # Convert PIL image to numpy array
19
  img_array = np.array(image)
20
+
21
  # Reshape image to be a list of pixels
22
  pixels = img_array.reshape(-1, 3)
23
+
24
  # Use k-means to find dominant colors
25
  kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
26
  kmeans.fit(pixels)
27
+
28
  # Get the colors and convert to RGB values
29
  colors = kmeans.cluster_centers_.astype(int)
30
+
31
  # Convert to color names/descriptions
32
  color_names = [_rgb_to_color_name(color) for color in colors]
33
+
34
  return color_names, colors
35
 
36
+
37
  def _rgb_to_color_name(color: np.ndarray) -> str:
38
  """Convert RGB values to descriptive color name."""
39
  r, g, b = color
40
+
41
  if r > 200 and g > 200 and b > 200:
42
  return "white"
43
  elif r < 50 and g < 50 and b < 50:
 
60
  elif r > 150 and g > 100 and b < 100:
61
  return "orange"
62
  else:
63
+ return "cream"
src/utils/file_utils.py CHANGED
@@ -2,19 +2,21 @@
2
  File and directory utilities.
3
  """
4
 
5
- import os
6
  import glob
7
- from typing import List, Tuple
 
8
  try:
9
- from ..core.constants import SUPPORTED_IMAGE_EXTENSIONS, IMAGES_DIR, MODELS_DIR
10
  except ImportError:
11
  # Handle direct execution
12
- import sys
13
  import os
 
 
14
  sys.path.append(os.path.dirname(os.path.dirname(__file__)))
15
- from core.constants import SUPPORTED_IMAGE_EXTENSIONS, IMAGES_DIR, MODELS_DIR
 
16
 
17
- def get_image_files(directory: str) -> List[str]:
18
  """Get all image files from a directory."""
19
  image_files = []
20
  for ext in SUPPORTED_IMAGE_EXTENSIONS:
@@ -22,11 +24,12 @@ def get_image_files(directory: str) -> List[str]:
22
  image_files.extend(glob.glob(pattern))
23
  return image_files
24
 
25
- def get_flower_types_from_directory(image_dir: str = IMAGES_DIR) -> List[str]:
 
26
  """Auto-detect flower types from directory structure."""
27
  if not os.path.exists(image_dir):
28
  return []
29
-
30
  detected_types = []
31
  for item in os.listdir(image_dir):
32
  item_path = os.path.join(image_dir, item)
@@ -34,17 +37,18 @@ def get_flower_types_from_directory(image_dir: str = IMAGES_DIR) -> List[str]:
34
  image_files = get_image_files(item_path)
35
  if image_files: # Only add if there are images
36
  detected_types.append(item)
37
-
38
  return sorted(detected_types)
39
 
40
- def count_training_images() -> Tuple[int, dict]:
 
41
  """Count training images by flower type."""
42
  if not os.path.exists(IMAGES_DIR):
43
  return 0, {}
44
-
45
  total_images = 0
46
  flower_counts = {}
47
-
48
  for flower_type in os.listdir(IMAGES_DIR):
49
  flower_path = os.path.join(IMAGES_DIR, flower_type)
50
  if os.path.isdir(flower_path):
@@ -53,18 +57,21 @@ def count_training_images() -> Tuple[int, dict]:
53
  if count > 0:
54
  flower_counts[flower_type] = count
55
  total_images += count
56
-
57
  return total_images, flower_counts
58
 
59
- def get_available_trained_models() -> List[str]:
 
60
  """Get list of available trained models."""
61
  if not os.path.exists(MODELS_DIR):
62
  return []
63
-
64
  models = []
65
  for item in os.listdir(MODELS_DIR):
66
  model_path = os.path.join(MODELS_DIR, item)
67
- if os.path.isdir(model_path) and os.path.exists(os.path.join(model_path, "config.json")):
 
 
68
  models.append(item)
69
-
70
- return sorted(models)
 
2
  File and directory utilities.
3
  """
4
 
 
5
  import glob
6
+ import os
7
+
8
  try:
9
+ from ..core.constants import IMAGES_DIR, MODELS_DIR, SUPPORTED_IMAGE_EXTENSIONS
10
  except ImportError:
11
  # Handle direct execution
 
12
  import os
13
+ import sys
14
+
15
  sys.path.append(os.path.dirname(os.path.dirname(__file__)))
16
+ from core.constants import IMAGES_DIR, MODELS_DIR, SUPPORTED_IMAGE_EXTENSIONS
17
+
18
 
19
+ def get_image_files(directory: str) -> list[str]:
20
  """Get all image files from a directory."""
21
  image_files = []
22
  for ext in SUPPORTED_IMAGE_EXTENSIONS:
 
24
  image_files.extend(glob.glob(pattern))
25
  return image_files
26
 
27
+
28
+ def get_flower_types_from_directory(image_dir: str = IMAGES_DIR) -> list[str]:
29
  """Auto-detect flower types from directory structure."""
30
  if not os.path.exists(image_dir):
31
  return []
32
+
33
  detected_types = []
34
  for item in os.listdir(image_dir):
35
  item_path = os.path.join(image_dir, item)
 
37
  image_files = get_image_files(item_path)
38
  if image_files: # Only add if there are images
39
  detected_types.append(item)
40
+
41
  return sorted(detected_types)
42
 
43
+
44
+ def count_training_images() -> tuple[int, dict]:
45
  """Count training images by flower type."""
46
  if not os.path.exists(IMAGES_DIR):
47
  return 0, {}
48
+
49
  total_images = 0
50
  flower_counts = {}
51
+
52
  for flower_type in os.listdir(IMAGES_DIR):
53
  flower_path = os.path.join(IMAGES_DIR, flower_type)
54
  if os.path.isdir(flower_path):
 
57
  if count > 0:
58
  flower_counts[flower_type] = count
59
  total_images += count
60
+
61
  return total_images, flower_counts
62
 
63
+
64
+ def get_available_trained_models() -> list[str]:
65
  """Get list of available trained models."""
66
  if not os.path.exists(MODELS_DIR):
67
  return []
68
+
69
  models = []
70
  for item in os.listdir(MODELS_DIR):
71
  model_path = os.path.join(MODELS_DIR, item)
72
+ if os.path.isdir(model_path) and os.path.exists(
73
+ os.path.join(model_path, "config.json")
74
+ ):
75
  models.append(item)
76
+
77
+ return sorted(models)
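
For reference, the helpers above can be exercised directly (a sketch, assuming src/ is on sys.path and the training_data/ and models/ directories exist):

    from utils.file_utils import (
        count_training_images,
        get_available_trained_models,
        get_flower_types_from_directory,
    )

    total, per_type = count_training_images()
    print(f"{total} training images: {per_type}")  # e.g. {"roses": 12, "tulips": 9}
    print(get_flower_types_from_directory())       # subfolders that contain images
    print(get_available_trained_models())          # model folders containing a config.json
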
test_external_cache.py CHANGED
@@ -5,59 +5,60 @@ import os
5
  import sys
6
  from pathlib import Path
7
 
 
8
  def test_cache_configuration():
9
  """Test that the external cache configuration is working."""
10
-
11
  print("πŸ§ͺ Testing External SSD Cache Configuration")
12
  print("=" * 50)
13
-
14
  # Check if external SSD is mounted
15
  external_path = Path("/Volumes/extssd")
16
  if not external_path.exists():
17
  print("❌ External SSD not found at /Volumes/extssd")
18
  return False
19
-
20
  print("βœ… External SSD is mounted")
21
-
22
  # Check if HF_HOME is set correctly
23
  hf_home = os.environ.get("HF_HOME")
24
  expected_hf_home = "/Volumes/extssd/huggingface"
25
-
26
  if hf_home != expected_hf_home:
27
- print(f"⚠️ HF_HOME not set correctly. Expected: {expected_hf_home}, Got: {hf_home}")
 
 
28
  print(" Set HF_HOME with: export HF_HOME=/Volumes/extssd/huggingface")
29
  return False
30
-
31
  print(f"βœ… HF_HOME correctly set to: {hf_home}")
32
-
33
  # Check if cache directories exist
34
  hub_cache = Path(hf_home) / "hub"
35
  if not hub_cache.exists():
36
  print(f"❌ Hub cache directory not found at: {hub_cache}")
37
  return False
38
-
39
  print(f"βœ… Hub cache directory exists at: {hub_cache}")
40
-
41
  # Check if models are present
42
  model_count = len(list(hub_cache.glob("models--*")))
43
  print(f"βœ… Found {model_count} models in cache")
44
-
45
  # Test importing Hugging Face libraries and check their cache detection
46
  try:
47
- from huggingface_hub import HfFolder
48
  from transformers import AutoTokenizer
49
- from diffusers import DiffusionPipeline
50
-
51
  print("βœ… Hugging Face libraries imported successfully")
52
-
53
  # Test a small model to verify cache is working
54
  print("πŸ”„ Testing cache with a small model (this may take a moment)...")
55
-
56
  # This should use the external cache
57
  tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
58
-
59
  print("βœ… Successfully loaded model from cache")
60
-
61
  # Check if the model files are in the expected location
62
  clip_path = hub_cache / "models--openai--clip-vit-base-patch32"
63
  if clip_path.exists():
@@ -65,13 +66,14 @@ def test_cache_configuration():
65
  else:
66
  print(f"⚠️ Model files not found at expected location: {clip_path}")
67
  return False
68
-
69
  return True
70
-
71
  except Exception as e:
72
  print(f"❌ Error loading model: {e}")
73
  return False
74
 
 
75
  def main():
76
  """Main test function."""
77
  # Load .env file if available
@@ -79,12 +81,13 @@ def main():
79
  if env_file.exists():
80
  print("πŸ“ Loading .env file...")
81
  from dotenv import load_dotenv
 
82
  load_dotenv()
83
  else:
84
  print("⚠️ No .env file found, using system environment variables")
85
-
86
  success = test_cache_configuration()
87
-
88
  print("\n" + "=" * 50)
89
  if success:
90
  print("πŸŽ‰ All tests passed! External SSD cache is working correctly.")
@@ -93,5 +96,6 @@ def main():
93
  print("❌ Some tests failed. Please check the configuration.")
94
  sys.exit(1)
95
 
 
96
  if __name__ == "__main__":
97
- main()
 
5
  import sys
6
  from pathlib import Path
7
 
8
+
9
  def test_cache_configuration():
10
  """Test that the external cache configuration is working."""
11
+
12
  print("πŸ§ͺ Testing External SSD Cache Configuration")
13
  print("=" * 50)
14
+
15
  # Check if external SSD is mounted
16
  external_path = Path("/Volumes/extssd")
17
  if not external_path.exists():
18
  print("❌ External SSD not found at /Volumes/extssd")
19
  return False
20
+
21
  print("βœ… External SSD is mounted")
22
+
23
  # Check if HF_HOME is set correctly
24
  hf_home = os.environ.get("HF_HOME")
25
  expected_hf_home = "/Volumes/extssd/huggingface"
26
+
27
  if hf_home != expected_hf_home:
28
+ print(
29
+ f"⚠️ HF_HOME not set correctly. Expected: {expected_hf_home}, Got: {hf_home}"
30
+ )
31
  print(" Set HF_HOME with: export HF_HOME=/Volumes/extssd/huggingface")
32
  return False
33
+
34
  print(f"βœ… HF_HOME correctly set to: {hf_home}")
35
+
36
  # Check if cache directories exist
37
  hub_cache = Path(hf_home) / "hub"
38
  if not hub_cache.exists():
39
  print(f"❌ Hub cache directory not found at: {hub_cache}")
40
  return False
41
+
42
  print(f"βœ… Hub cache directory exists at: {hub_cache}")
43
+
44
  # Check if models are present
45
  model_count = len(list(hub_cache.glob("models--*")))
46
  print(f"βœ… Found {model_count} models in cache")
47
+
48
  # Test importing Hugging Face libraries and check their cache detection
49
  try:
 
50
  from transformers import AutoTokenizer
51
+
 
52
  print("βœ… Hugging Face libraries imported successfully")
53
+
54
  # Test a small model to verify cache is working
55
  print("πŸ”„ Testing cache with a small model (this may take a moment)...")
56
+
57
  # This should use the external cache
58
  tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
59
+
60
  print("βœ… Successfully loaded model from cache")
61
+
62
  # Check if the model files are in the expected location
63
  clip_path = hub_cache / "models--openai--clip-vit-base-patch32"
64
  if clip_path.exists():
 
66
  else:
67
  print(f"⚠️ Model files not found at expected location: {clip_path}")
68
  return False
69
+
70
  return True
71
+
72
  except Exception as e:
73
  print(f"❌ Error loading model: {e}")
74
  return False
75
 
76
+
77
  def main():
78
  """Main test function."""
79
  # Load .env file if available
 
81
  if env_file.exists():
82
  print("πŸ“ Loading .env file...")
83
  from dotenv import load_dotenv
84
+
85
  load_dotenv()
86
  else:
87
  print("⚠️ No .env file found, using system environment variables")
88
+
89
  success = test_cache_configuration()
90
+
91
  print("\n" + "=" * 50)
92
  if success:
93
  print("πŸŽ‰ All tests passed! External SSD cache is working correctly.")
 
96
  print("❌ Some tests failed. Please check the configuration.")
97
  sys.exit(1)
98
 
99
+
100
  if __name__ == "__main__":
101
+ main()
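For reference, a minimal sketch of pointing the Hugging Face cache at the external SSD before this script's imports run (it assumes the same /Volumes/extssd mount checked above; illustrative only, not from the repository):

  # Set HF_HOME before importing transformers/diffusers so the external cache is used.
  import os
  os.environ.setdefault("HF_HOME", "/Volumes/extssd/huggingface")

  # Assumes the repository root is on sys.path.
  from test_external_cache import test_cache_configuration
  print("cache OK" if test_cache_configuration() else "cache misconfigured")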
tests/__init__.py CHANGED
@@ -1 +1 @@
1
- """Tests package for Flowerfy application."""
 
1
+ """Tests package for Flowerfy application."""
tests/test_models.py CHANGED
@@ -15,18 +15,22 @@ from PIL import Image
15
  sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), "src"))
16
 
17
  # Import all required modules - if any fail, the script will fail immediately
18
- from transformers import ConvNextForImageClassification, ConvNextImageProcessor, pipeline
 
 
 
 
19
 
20
  from core.constants import DEFAULT_CLIP_MODEL, DEFAULT_CONVNEXT_MODEL
21
  from services.models.flower_classification import FlowerClassificationService
22
- from services.models.image_generation import ImageGenerationService
23
 
24
  print("βœ… All dependencies imported successfully")
25
 
 
26
  def test_convnext_model() -> bool:
27
  """Test ConvNeXt model loading."""
28
  print("1️⃣ Testing ConvNeXt model loading...")
29
-
30
  try:
31
  print(f"Loading ConvNeXt model: {DEFAULT_CONVNEXT_MODEL}")
32
  model = ConvNextForImageClassification.from_pretrained(DEFAULT_CONVNEXT_MODEL)
@@ -38,19 +42,23 @@ def test_convnext_model() -> bool:
38
  print(f"❌ ConvNeXt model test failed: {e}")
39
  return False
40
 
 
41
  def test_clip_model() -> bool:
42
  """Test CLIP model loading."""
43
  print("\n2️⃣ Testing CLIP model loading...")
44
-
45
  try:
46
  print(f"Loading CLIP model: {DEFAULT_CLIP_MODEL}")
47
- classifier = pipeline('zero-shot-image-classification', model=DEFAULT_CLIP_MODEL)
 
 
48
  print("βœ… CLIP model loaded successfully")
49
  return True
50
  except Exception as e:
51
  print(f"❌ CLIP model test failed: {e}")
52
  return False
53
 
 
54
  def test_image_generation_models() -> bool:
55
  """Test image generation models (SDXL models)."""
56
  print("\n3️⃣ Testing image generation models...")
@@ -59,42 +67,50 @@ def test_image_generation_models() -> bool:
59
  # Test SDXL first (now primary)
60
  sdxl_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
61
  print(f"Testing SDXL model (primary): {sdxl_model_id}")
62
-
63
  try:
64
  from diffusers import AutoPipelineForText2Image
65
- pipe = AutoPipelineForText2Image.from_pretrained(sdxl_model_id, torch_dtype=torch.float32).to("cpu")
 
 
 
66
  print("βœ… SDXL model loaded successfully")
67
  return True
68
  except Exception as sdxl_error:
69
  print(f"⚠️ SDXL model failed: {sdxl_error}")
70
-
71
  # Test SDXL-Turbo fallback
72
  turbo_model_id = "stabilityai/sdxl-turbo"
73
  print(f"Testing SDXL-Turbo fallback: {turbo_model_id}")
74
-
75
  try:
76
- pipe = AutoPipelineForText2Image.from_pretrained(turbo_model_id, torch_dtype=torch.float32).to("cpu")
 
 
77
  print("βœ… SDXL-Turbo model loaded successfully as fallback")
78
  return True
79
  except Exception as turbo_error:
80
  print(f"❌ Both SDXL models failed: {turbo_error}")
81
  return False
82
-
83
  except Exception as e:
84
  print(f"❌ Image generation model test failed: {e}")
85
  return False
86
 
 
87
  def test_flower_classification_service() -> bool:
88
  """Test flower classification service."""
89
  print("\n4️⃣ Testing flower classification service...")
90
-
91
  try:
92
  print("Initializing flower classification service...")
93
  classifier = FlowerClassificationService()
94
-
95
  # Create a dummy test image (3-channel RGB)
96
- test_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
97
-
 
 
98
  # Test classification
99
  results, message = classifier.identify_flowers(test_image, top_k=3)
100
  print(f"βœ… Classification service working: {message}")
@@ -104,10 +120,11 @@ def test_flower_classification_service() -> bool:
104
  print(f"❌ Classification service test failed: {e}")
105
  return False
106
 
 
107
  def test_image_generation_service() -> bool:
108
  """Test image generation service initialization."""
109
  print("\n5️⃣ Testing image generation service initialization...")
110
-
111
  try:
112
  print("Testing image generation service initialization...")
113
  # This will test if the service can be imported and initialized
@@ -119,11 +136,12 @@ def test_image_generation_service() -> bool:
119
  print(f"❌ Image generation service test failed: {e}")
120
  return False
121
 
 
122
  def main():
123
  """Run all model tests."""
124
  print("πŸ§ͺ Testing Flowerfy models...")
125
  print("==============================")
126
-
127
  tests = [
128
  ("ConvNeXt Model", test_convnext_model),
129
  ("CLIP Model", test_clip_model),
@@ -131,10 +149,10 @@ def main():
131
  ("Classification Service", test_flower_classification_service),
132
  ("Generation Service", test_image_generation_service),
133
  ]
134
-
135
  passed = 0
136
  failed = 0
137
-
138
  for test_name, test_func in tests:
139
  try:
140
  if test_func():
@@ -145,11 +163,11 @@ def main():
145
  except Exception as e:
146
  failed += 1
147
  print(f"❌ {test_name} test failed with exception: {e}")
148
-
149
- print(f"\nπŸ“Š Test Results:")
150
  print(f"βœ… Passed: {passed}")
151
  print(f"❌ Failed: {failed}")
152
-
153
  if failed == 0:
154
  print("\nπŸŽ‰ All model tests passed successfully!")
155
  print("======================================")
@@ -167,6 +185,7 @@ def main():
167
  print(f"\n❌ {failed} test(s) failed. Please check the errors above.")
168
  return False
169
 
 
170
  if __name__ == "__main__":
171
  success = main()
172
- sys.exit(0 if success else 1)
 
15
  sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), "src"))
16
 
17
  # Import all required modules - if any fail, the script will fail immediately
18
+ from transformers import (
19
+ ConvNextForImageClassification,
20
+ ConvNextImageProcessor,
21
+ pipeline,
22
+ )
23
 
24
  from core.constants import DEFAULT_CLIP_MODEL, DEFAULT_CONVNEXT_MODEL
25
  from services.models.flower_classification import FlowerClassificationService
 
26
 
27
  print("βœ… All dependencies imported successfully")
28
 
29
+
30
  def test_convnext_model() -> bool:
31
  """Test ConvNeXt model loading."""
32
  print("1️⃣ Testing ConvNeXt model loading...")
33
+
34
  try:
35
  print(f"Loading ConvNeXt model: {DEFAULT_CONVNEXT_MODEL}")
36
  model = ConvNextForImageClassification.from_pretrained(DEFAULT_CONVNEXT_MODEL)
 
42
  print(f"❌ ConvNeXt model test failed: {e}")
43
  return False
44
 
45
+
46
  def test_clip_model() -> bool:
47
  """Test CLIP model loading."""
48
  print("\n2️⃣ Testing CLIP model loading...")
49
+
50
  try:
51
  print(f"Loading CLIP model: {DEFAULT_CLIP_MODEL}")
52
+ classifier = pipeline(
53
+ "zero-shot-image-classification", model=DEFAULT_CLIP_MODEL
54
+ )
55
  print("βœ… CLIP model loaded successfully")
56
  return True
57
  except Exception as e:
58
  print(f"❌ CLIP model test failed: {e}")
59
  return False
60
 
61
+
62
  def test_image_generation_models() -> bool:
63
  """Test image generation models (SDXL models)."""
64
  print("\n3️⃣ Testing image generation models...")
 
67
  # Test SDXL first (now primary)
68
  sdxl_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
69
  print(f"Testing SDXL model (primary): {sdxl_model_id}")
70
+
71
  try:
72
  from diffusers import AutoPipelineForText2Image
73
+
74
+ pipe = AutoPipelineForText2Image.from_pretrained(
75
+ sdxl_model_id, torch_dtype=torch.float32
76
+ ).to("cpu")
77
  print("βœ… SDXL model loaded successfully")
78
  return True
79
  except Exception as sdxl_error:
80
  print(f"⚠️ SDXL model failed: {sdxl_error}")
81
+
82
  # Test SDXL-Turbo fallback
83
  turbo_model_id = "stabilityai/sdxl-turbo"
84
  print(f"Testing SDXL-Turbo fallback: {turbo_model_id}")
85
+
86
  try:
87
+ pipe = AutoPipelineForText2Image.from_pretrained(
88
+ turbo_model_id, torch_dtype=torch.float32
89
+ ).to("cpu")
90
  print("βœ… SDXL-Turbo model loaded successfully as fallback")
91
  return True
92
  except Exception as turbo_error:
93
  print(f"❌ Both SDXL models failed: {turbo_error}")
94
  return False
95
+
96
  except Exception as e:
97
  print(f"❌ Image generation model test failed: {e}")
98
  return False
99
 
100
+
101
  def test_flower_classification_service() -> bool:
102
  """Test flower classification service."""
103
  print("\n4️⃣ Testing flower classification service...")
104
+
105
  try:
106
  print("Initializing flower classification service...")
107
  classifier = FlowerClassificationService()
108
+
109
  # Create a dummy test image (3-channel RGB)
110
+ test_image = Image.fromarray(
111
+ np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
112
+ )
113
+
114
  # Test classification
115
  results, message = classifier.identify_flowers(test_image, top_k=3)
116
  print(f"βœ… Classification service working: {message}")
 
120
  print(f"❌ Classification service test failed: {e}")
121
  return False
122
 
123
+
124
  def test_image_generation_service() -> bool:
125
  """Test image generation service initialization."""
126
  print("\n5️⃣ Testing image generation service initialization...")
127
+
128
  try:
129
  print("Testing image generation service initialization...")
130
  # This will test if the service can be imported and initialized
 
136
  print(f"❌ Image generation service test failed: {e}")
137
  return False
138
 
139
+
140
  def main():
141
  """Run all model tests."""
142
  print("πŸ§ͺ Testing Flowerfy models...")
143
  print("==============================")
144
+
145
  tests = [
146
  ("ConvNeXt Model", test_convnext_model),
147
  ("CLIP Model", test_clip_model),
 
149
  ("Classification Service", test_flower_classification_service),
150
  ("Generation Service", test_image_generation_service),
151
  ]
152
+
153
  passed = 0
154
  failed = 0
155
+
156
  for test_name, test_func in tests:
157
  try:
158
  if test_func():
 
163
  except Exception as e:
164
  failed += 1
165
  print(f"❌ {test_name} test failed with exception: {e}")
166
+
167
+ print("\nπŸ“Š Test Results:")
168
  print(f"βœ… Passed: {passed}")
169
  print(f"❌ Failed: {failed}")
170
+
171
  if failed == 0:
172
  print("\nπŸŽ‰ All model tests passed successfully!")
173
  print("======================================")
 
185
  print(f"\n❌ {failed} test(s) failed. Please check the errors above.")
186
  return False
187
 
188
+
189
  if __name__ == "__main__":
190
  success = main()
191
+ sys.exit(0 if success else 1)
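The CLIP check above boils down to a zero-shot classification pipeline; here is a hedged standalone sketch of the same call (openai/clip-vit-base-patch32 is used as a stand-in, since the value of DEFAULT_CLIP_MODEL is not shown in this diff):

  # Minimal zero-shot image classification, mirroring what test_clip_model() exercises.
  import numpy as np
  from PIL import Image
  from transformers import pipeline

  clf = pipeline("zero-shot-image-classification", model="openai/clip-vit-base-patch32")
  image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
  print(clf(image, candidate_labels=["rose", "tulip", "daisy"]))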
training/advanced_trainer.py CHANGED
@@ -4,26 +4,32 @@ Advanced ConvNeXt training script using Transformers Trainer.
4
  This provides more sophisticated training features like evaluation, checkpointing, and logging.
5
  """
6
 
 
 
7
  import os
 
8
  import torch
9
- import json
10
- from transformers import ConvNextImageProcessor, ConvNextForImageClassification, Trainer, TrainingArguments
11
  from dataset import FlowerDataset, advanced_collate_fn
12
- import argparse
 
 
 
 
 
13
 
14
 
15
  class ConvNeXtTrainer(Trainer):
16
  """Custom trainer for ConvNeXt with proper loss computation."""
17
-
18
  def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
19
  labels = inputs.get("labels")
20
  outputs = model(**inputs)
21
-
22
  if labels is not None:
23
  loss = torch.nn.functional.cross_entropy(outputs.logits, labels)
24
  else:
25
  loss = outputs.loss
26
-
27
  return (loss, outputs) if return_outputs else loss
28
 
29
 
@@ -34,11 +40,11 @@ def advanced_train(
34
  num_epochs=5,
35
  batch_size=8,
36
  learning_rate=1e-5,
37
- flower_labels=None
38
  ):
39
  """
40
  Advanced training function using Transformers Trainer.
41
-
42
  Args:
43
  image_dir: Directory containing training images organized by flower type
44
  output_dir: Directory to save the trained model
@@ -47,43 +53,55 @@ def advanced_train(
47
  batch_size: Training batch size
48
  learning_rate: Learning rate for optimization
49
  flower_labels: List of flower labels (auto-detected if None)
50
-
51
  Returns:
52
  str: Path to the saved model directory, or None if training failed
53
  """
54
  print("🌸 Advanced ConvNeXt Flower Model Training")
55
  print("=" * 50)
56
-
57
  # Check training data
58
  if not os.path.exists(image_dir):
59
  print(f"❌ Training directory not found: {image_dir}")
60
  return None
61
-
62
  # Load model and processor
63
  print(f"Loading model: {model_name}")
64
  model = ConvNextForImageClassification.from_pretrained(model_name)
65
  processor = ConvNextImageProcessor.from_pretrained(model_name)
66
-
67
  # Create dataset
68
  dataset = FlowerDataset(image_dir, processor, flower_labels)
69
-
70
  if len(dataset) == 0:
71
- print("❌ No training data found. Please add images to subdirectories in training_data/images/")
72
- print("Example: training_data/images/roses/, training_data/images/tulips/, etc.")
 
 
 
 
73
  return None
74
-
75
  # Split dataset (80% train, 20% eval)
76
  train_size = int(0.8 * len(dataset))
77
  eval_size = len(dataset) - train_size
78
- train_dataset, eval_dataset = torch.utils.data.random_split(dataset, [train_size, eval_size])
79
-
 
 
80
  # Update model config for the number of classes
81
  if len(dataset.flower_labels) != model.config.num_labels:
82
  model.config.num_labels = len(dataset.flower_labels)
83
  # ConvNeXt uses hidden_sizes[-1] as the final hidden dimension
84
- final_hidden_size = model.config.hidden_sizes[-1] if hasattr(model.config, 'hidden_sizes') else 768
85
- model.classifier = torch.nn.Linear(final_hidden_size, len(dataset.flower_labels))
86
-
 
 
 
 
 
 
87
  # Training arguments
88
  training_args = TrainingArguments(
89
  output_dir=output_dir,
@@ -102,7 +120,7 @@ def advanced_train(
102
  dataloader_num_workers=0, # Set to 0 to avoid multiprocessing issues
103
  remove_unused_columns=False,
104
  )
105
-
106
  # Create trainer
107
  try:
108
  trainer = ConvNeXtTrainer(
@@ -116,7 +134,7 @@ def advanced_train(
116
  except Exception as e:
117
  print(f"❌ Error creating trainer: {e}")
118
  return None
119
-
120
  # Train model
121
  print("Starting advanced training...")
122
  try:
@@ -125,14 +143,15 @@ def advanced_train(
125
  except Exception as e:
126
  print(f"❌ Training failed: {e}")
127
  import traceback
 
128
  traceback.print_exc()
129
  return None
130
-
131
  # Save final model
132
  final_model_path = os.path.join(output_dir, "final_model")
133
  model.save_pretrained(final_model_path)
134
  processor.save_pretrained(final_model_path)
135
-
136
  # Save training config
137
  config = {
138
  "model_name": model_name,
@@ -142,27 +161,43 @@ def advanced_train(
142
  "learning_rate": learning_rate,
143
  "train_samples": len(train_dataset),
144
  "eval_samples": len(eval_dataset),
145
- "training_type": "advanced"
146
  }
147
-
148
  with open(os.path.join(final_model_path, "training_config.json"), "w") as f:
149
  json.dump(config, f, indent=2)
150
-
151
  print(f"βœ… Advanced training complete! Model saved to {final_model_path}")
152
  return final_model_path
153
 
154
 
155
  if __name__ == "__main__":
156
- parser = argparse.ArgumentParser(description="Advanced ConvNeXt training for flower classification")
157
- parser.add_argument("--image_dir", default="training_data/images", help="Directory containing training images")
158
- parser.add_argument("--output_dir", default="training_data/trained_models/advanced_trained", help="Output directory for trained model")
159
- parser.add_argument("--model_name", default="facebook/convnext-base-224-22k", help="Base model name")
160
- parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  parser.add_argument("--batch_size", type=int, default=8, help="Training batch size")
162
- parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate")
163
-
 
 
164
  args = parser.parse_args()
165
-
166
  try:
167
  result = advanced_train(
168
  image_dir=args.image_dir,
@@ -170,7 +205,7 @@ if __name__ == "__main__":
170
  model_name=args.model_name,
171
  num_epochs=args.epochs,
172
  batch_size=args.batch_size,
173
- learning_rate=args.learning_rate
174
  )
175
  if not result:
176
  print("❌ Training failed!")
@@ -180,5 +215,6 @@ if __name__ == "__main__":
180
  except Exception as e:
181
  print(f"❌ Training failed: {e}")
182
  import traceback
 
183
  traceback.print_exc()
184
- exit(1)
 
4
  This provides more sophisticated training features like evaluation, checkpointing, and logging.
5
  """
6
 
7
+ import argparse
8
+ import json
9
  import os
10
+
11
  import torch
 
 
12
  from dataset import FlowerDataset, advanced_collate_fn
13
+ from transformers import (
14
+ ConvNextForImageClassification,
15
+ ConvNextImageProcessor,
16
+ Trainer,
17
+ TrainingArguments,
18
+ )
19
 
20
 
21
  class ConvNeXtTrainer(Trainer):
22
  """Custom trainer for ConvNeXt with proper loss computation."""
23
+
24
  def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
25
  labels = inputs.get("labels")
26
  outputs = model(**inputs)
27
+
28
  if labels is not None:
29
  loss = torch.nn.functional.cross_entropy(outputs.logits, labels)
30
  else:
31
  loss = outputs.loss
32
+
33
  return (loss, outputs) if return_outputs else loss
34
 
35
 
 
40
  num_epochs=5,
41
  batch_size=8,
42
  learning_rate=1e-5,
43
+ flower_labels=None,
44
  ):
45
  """
46
  Advanced training function using Transformers Trainer.
47
+
48
  Args:
49
  image_dir: Directory containing training images organized by flower type
50
  output_dir: Directory to save the trained model
 
53
  batch_size: Training batch size
54
  learning_rate: Learning rate for optimization
55
  flower_labels: List of flower labels (auto-detected if None)
56
+
57
  Returns:
58
  str: Path to the saved model directory, or None if training failed
59
  """
60
  print("🌸 Advanced ConvNeXt Flower Model Training")
61
  print("=" * 50)
62
+
63
  # Check training data
64
  if not os.path.exists(image_dir):
65
  print(f"❌ Training directory not found: {image_dir}")
66
  return None
67
+
68
  # Load model and processor
69
  print(f"Loading model: {model_name}")
70
  model = ConvNextForImageClassification.from_pretrained(model_name)
71
  processor = ConvNextImageProcessor.from_pretrained(model_name)
72
+
73
  # Create dataset
74
  dataset = FlowerDataset(image_dir, processor, flower_labels)
75
+
76
  if len(dataset) == 0:
77
+ print(
78
+ "❌ No training data found. Please add images to subdirectories in training_data/images/"
79
+ )
80
+ print(
81
+ "Example: training_data/images/roses/, training_data/images/tulips/, etc."
82
+ )
83
  return None
84
+
85
  # Split dataset (80% train, 20% eval)
86
  train_size = int(0.8 * len(dataset))
87
  eval_size = len(dataset) - train_size
88
+ train_dataset, eval_dataset = torch.utils.data.random_split(
89
+ dataset, [train_size, eval_size]
90
+ )
91
+
92
  # Update model config for the number of classes
93
  if len(dataset.flower_labels) != model.config.num_labels:
94
  model.config.num_labels = len(dataset.flower_labels)
95
  # ConvNeXt uses hidden_sizes[-1] as the final hidden dimension
96
+ final_hidden_size = (
97
+ model.config.hidden_sizes[-1]
98
+ if hasattr(model.config, "hidden_sizes")
99
+ else 768
100
+ )
101
+ model.classifier = torch.nn.Linear(
102
+ final_hidden_size, len(dataset.flower_labels)
103
+ )
104
+
105
  # Training arguments
106
  training_args = TrainingArguments(
107
  output_dir=output_dir,
 
120
  dataloader_num_workers=0, # Set to 0 to avoid multiprocessing issues
121
  remove_unused_columns=False,
122
  )
123
+
124
  # Create trainer
125
  try:
126
  trainer = ConvNeXtTrainer(
 
134
  except Exception as e:
135
  print(f"❌ Error creating trainer: {e}")
136
  return None
137
+
138
  # Train model
139
  print("Starting advanced training...")
140
  try:
 
143
  except Exception as e:
144
  print(f"❌ Training failed: {e}")
145
  import traceback
146
+
147
  traceback.print_exc()
148
  return None
149
+
150
  # Save final model
151
  final_model_path = os.path.join(output_dir, "final_model")
152
  model.save_pretrained(final_model_path)
153
  processor.save_pretrained(final_model_path)
154
+
155
  # Save training config
156
  config = {
157
  "model_name": model_name,
 
161
  "learning_rate": learning_rate,
162
  "train_samples": len(train_dataset),
163
  "eval_samples": len(eval_dataset),
164
+ "training_type": "advanced",
165
  }
166
+
167
  with open(os.path.join(final_model_path, "training_config.json"), "w") as f:
168
  json.dump(config, f, indent=2)
169
+
170
  print(f"βœ… Advanced training complete! Model saved to {final_model_path}")
171
  return final_model_path
172
 
173
 
174
  if __name__ == "__main__":
175
+ parser = argparse.ArgumentParser(
176
+ description="Advanced ConvNeXt training for flower classification"
177
+ )
178
+ parser.add_argument(
179
+ "--image_dir",
180
+ default="training_data/images",
181
+ help="Directory containing training images",
182
+ )
183
+ parser.add_argument(
184
+ "--output_dir",
185
+ default="training_data/trained_models/advanced_trained",
186
+ help="Output directory for trained model",
187
+ )
188
+ parser.add_argument(
189
+ "--model_name", default="facebook/convnext-base-224-22k", help="Base model name"
190
+ )
191
+ parser.add_argument(
192
+ "--epochs", type=int, default=5, help="Number of training epochs"
193
+ )
194
  parser.add_argument("--batch_size", type=int, default=8, help="Training batch size")
195
+ parser.add_argument(
196
+ "--learning_rate", type=float, default=1e-5, help="Learning rate"
197
+ )
198
+
199
  args = parser.parse_args()
200
+
201
  try:
202
  result = advanced_train(
203
  image_dir=args.image_dir,
 
205
  model_name=args.model_name,
206
  num_epochs=args.epochs,
207
  batch_size=args.batch_size,
208
+ learning_rate=args.learning_rate,
209
  )
210
  if not result:
211
  print("❌ Training failed!")
 
215
  except Exception as e:
216
  print(f"❌ Training failed: {e}")
217
  import traceback
218
+
219
  traceback.print_exc()
220
+ exit(1)
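A short sketch of driving the refactored trainer programmatically rather than through the CLI (values simply mirror the argparse defaults above; illustrative only):

  # Assumes the training/ directory is on sys.path so `advanced_trainer` is importable.
  from advanced_trainer import advanced_train

  saved_path = advanced_train(
      image_dir="training_data/images",
      output_dir="training_data/trained_models/advanced_trained",
      model_name="facebook/convnext-base-224-22k",
      num_epochs=5,
      batch_size=8,
      learning_rate=1e-5,
  )
  print("Model saved to:", saved_path)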
training/dataset.py CHANGED
@@ -3,9 +3,10 @@
3
  Flower Dataset class for training ConvNeXt models.
4
  """
5
 
 
6
  import os
 
7
  import torch
8
- import glob
9
  from PIL import Image
10
  from torch.utils.data import Dataset
11
 
@@ -15,7 +16,7 @@ class FlowerDataset(Dataset):
15
  self.image_paths = []
16
  self.labels = []
17
  self.processor = processor
18
-
19
  # Auto-detect flower types from directory structure if not provided
20
  if flower_labels is None:
21
  detected_types = []
@@ -28,22 +29,24 @@ class FlowerDataset(Dataset):
28
  self.flower_labels = sorted(detected_types)
29
  else:
30
  self.flower_labels = flower_labels
31
-
32
  self.label_to_id = {label: idx for idx, label in enumerate(self.flower_labels)}
33
-
34
  # Load images from subdirectories (organized by flower type)
35
  for flower_type in os.listdir(image_dir):
36
  flower_path = os.path.join(image_dir, flower_type)
37
  if os.path.isdir(flower_path) and flower_type in self.label_to_id:
38
  image_files = self._get_image_files(flower_path)
39
-
40
  for img_path in image_files:
41
  self.image_paths.append(img_path)
42
  self.labels.append(self.label_to_id[flower_type])
43
-
44
- print(f"Loaded {len(self.image_paths)} images from {len(set(self.labels))} flower types")
 
 
45
  print(f"Flower types: {self.flower_labels}")
46
-
47
  def _get_image_files(self, directory):
48
  """Get all supported image files from directory."""
49
  extensions = ["*.jpg", "*.jpeg", "*.png", "*.webp"]
@@ -52,21 +55,21 @@ class FlowerDataset(Dataset):
52
  image_files.extend(glob.glob(os.path.join(directory, ext)))
53
  image_files.extend(glob.glob(os.path.join(directory, ext.upper())))
54
  return image_files
55
-
56
  def __len__(self):
57
  return len(self.image_paths)
58
-
59
  def __getitem__(self, idx):
60
  image_path = self.image_paths[idx]
61
  image = Image.open(image_path).convert("RGB")
62
  label = self.labels[idx]
63
-
64
  # Process image for ConvNeXt
65
  inputs = self.processor(images=image, return_tensors="pt")
66
-
67
  return {
68
- 'pixel_values': inputs['pixel_values'].squeeze(),
69
- 'labels': torch.tensor(label, dtype=torch.long)
70
  }
71
 
72
 
@@ -74,29 +77,24 @@ def simple_collate_fn(batch):
74
  """Simple collation function for training."""
75
  pixel_values = []
76
  labels = []
77
-
78
  for item in batch:
79
- pixel_values.append(item['pixel_values'])
80
- labels.append(item['labels'])
81
-
82
- return {
83
- 'pixel_values': torch.stack(pixel_values),
84
- 'labels': torch.stack(labels)
85
- }
86
 
87
 
88
  def advanced_collate_fn(batch):
89
  """Advanced collation function for Trainer."""
90
  # Extract components
91
- pixel_values = [item['pixel_values'] for item in batch]
92
- labels = [item['labels'] for item in batch if 'labels' in item]
93
-
94
  # Stack everything
95
- result = {
96
- 'pixel_values': torch.stack(pixel_values)
97
- }
98
-
99
  if labels:
100
- result['labels'] = torch.stack(labels)
101
-
102
- return result
 
3
  Flower Dataset class for training ConvNeXt models.
4
  """
5
 
6
+ import glob
7
  import os
8
+
9
  import torch
 
10
  from PIL import Image
11
  from torch.utils.data import Dataset
12
 
 
16
  self.image_paths = []
17
  self.labels = []
18
  self.processor = processor
19
+
20
  # Auto-detect flower types from directory structure if not provided
21
  if flower_labels is None:
22
  detected_types = []
 
29
  self.flower_labels = sorted(detected_types)
30
  else:
31
  self.flower_labels = flower_labels
32
+
33
  self.label_to_id = {label: idx for idx, label in enumerate(self.flower_labels)}
34
+
35
  # Load images from subdirectories (organized by flower type)
36
  for flower_type in os.listdir(image_dir):
37
  flower_path = os.path.join(image_dir, flower_type)
38
  if os.path.isdir(flower_path) and flower_type in self.label_to_id:
39
  image_files = self._get_image_files(flower_path)
40
+
41
  for img_path in image_files:
42
  self.image_paths.append(img_path)
43
  self.labels.append(self.label_to_id[flower_type])
44
+
45
+ print(
46
+ f"Loaded {len(self.image_paths)} images from {len(set(self.labels))} flower types"
47
+ )
48
  print(f"Flower types: {self.flower_labels}")
49
+
50
  def _get_image_files(self, directory):
51
  """Get all supported image files from directory."""
52
  extensions = ["*.jpg", "*.jpeg", "*.png", "*.webp"]
 
55
  image_files.extend(glob.glob(os.path.join(directory, ext)))
56
  image_files.extend(glob.glob(os.path.join(directory, ext.upper())))
57
  return image_files
58
+
59
  def __len__(self):
60
  return len(self.image_paths)
61
+
62
  def __getitem__(self, idx):
63
  image_path = self.image_paths[idx]
64
  image = Image.open(image_path).convert("RGB")
65
  label = self.labels[idx]
66
+
67
  # Process image for ConvNeXt
68
  inputs = self.processor(images=image, return_tensors="pt")
69
+
70
  return {
71
+ "pixel_values": inputs["pixel_values"].squeeze(),
72
+ "labels": torch.tensor(label, dtype=torch.long),
73
  }
74
 
75
 
 
77
  """Simple collation function for training."""
78
  pixel_values = []
79
  labels = []
80
+
81
  for item in batch:
82
+ pixel_values.append(item["pixel_values"])
83
+ labels.append(item["labels"])
84
+
85
+ return {"pixel_values": torch.stack(pixel_values), "labels": torch.stack(labels)}
 
 
 
86
 
87
 
88
  def advanced_collate_fn(batch):
89
  """Advanced collation function for Trainer."""
90
  # Extract components
91
+ pixel_values = [item["pixel_values"] for item in batch]
92
+ labels = [item["labels"] for item in batch if "labels" in item]
93
+
94
  # Stack everything
95
+ result = {"pixel_values": torch.stack(pixel_values)}
96
+
 
 
97
  if labels:
98
+ result["labels"] = torch.stack(labels)
99
+
100
+ return result
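To make the collate functions concrete, a minimal sketch of feeding FlowerDataset into a DataLoader (assumes the training_data/images layout described above; illustrative only, not from the repository):

  # Assumes the training/ directory is on sys.path so `dataset` is importable.
  from torch.utils.data import DataLoader
  from transformers import ConvNextImageProcessor
  from dataset import FlowerDataset, simple_collate_fn

  processor = ConvNextImageProcessor.from_pretrained("facebook/convnext-base-224-22k")
  flowers = FlowerDataset("training_data/images", processor)
  loader = DataLoader(flowers, batch_size=4, shuffle=True, collate_fn=simple_collate_fn)

  batch = next(iter(loader))
  print(batch["pixel_values"].shape, batch["labels"].shape)  # e.g. [4, 3, 224, 224] and [4]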
training/simple_trainer.py CHANGED
@@ -4,13 +4,13 @@ Simple ConvNeXt training script without using the Transformers Trainer class.
4
  This is a lightweight training implementation for quick model fine-tuning.
5
  """
6
 
 
7
  import os
 
8
  import torch
9
- import torch.nn as nn
10
- from torch.utils.data import DataLoader
11
- from transformers import ConvNextImageProcessor, ConvNextForImageClassification
12
  from dataset import FlowerDataset, simple_collate_fn
13
- import json
 
14
 
15
 
16
  def simple_train(
@@ -19,11 +19,11 @@ def simple_train(
19
  epochs=3,
20
  batch_size=4,
21
  learning_rate=1e-5,
22
- model_name="facebook/convnext-base-224-22k"
23
  ):
24
  """
25
  Simple training function for ConvNeXt flower classification.
26
-
27
  Args:
28
  image_dir: Directory containing training images organized by flower type
29
  output_dir: Directory to save the trained model
@@ -31,91 +31,107 @@ def simple_train(
31
  batch_size: Training batch size
32
  learning_rate: Learning rate for optimization
33
  model_name: Base ConvNeXt model to fine-tune
34
-
35
  Returns:
36
  str: Path to the saved model directory, or None if training failed
37
  """
38
  print("🌸 Simple ConvNeXt Flower Model Training")
39
  print("=" * 40)
40
-
41
  # Check training data
42
  if not os.path.exists(image_dir):
43
  print(f"❌ Training directory not found: {image_dir}")
44
  return None
45
-
46
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 
 
 
 
47
  print(f"Using device: {device}")
48
-
49
  # Load model and processor
50
  print(f"Loading model: {model_name}")
51
  model = ConvNextForImageClassification.from_pretrained(model_name)
52
  processor = ConvNextImageProcessor.from_pretrained(model_name)
53
  model.to(device)
54
-
55
  # Create dataset
56
  dataset = FlowerDataset(image_dir, processor)
57
-
58
  if len(dataset) < 5:
59
  print("❌ Need at least 5 images for training")
60
  return None
61
-
62
  # Split dataset
63
  train_size = int(0.8 * len(dataset))
64
  train_dataset = torch.utils.data.Subset(dataset, range(train_size))
65
-
66
  # Update model config for the number of classes
67
  if len(dataset.flower_labels) != model.config.num_labels:
68
  model.config.num_labels = len(dataset.flower_labels)
69
  # ConvNeXt uses hidden_sizes[-1] as the final hidden dimension
70
- final_hidden_size = model.config.hidden_sizes[-1] if hasattr(model.config, 'hidden_sizes') else 768
71
- model.classifier = torch.nn.Linear(final_hidden_size, len(dataset.flower_labels))
 
 
 
 
 
 
72
  model.classifier.to(device)
73
-
74
  # Create data loader
75
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=simple_collate_fn)
76
-
 
 
77
  # Setup optimizer
78
  optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
79
-
80
  # Training loop
81
  model.train()
82
  print(f"Starting training on {len(train_dataset)} samples for {epochs} epochs...")
83
-
84
  for epoch in range(epochs):
85
  total_loss = 0
86
  num_batches = 0
87
-
88
  for batch_idx, batch in enumerate(train_loader):
89
  # Move to device
90
- pixel_values = batch['pixel_values'].to(device)
91
- labels = batch['labels'].to(device)
92
-
93
  # Zero gradients
94
  optimizer.zero_grad()
95
-
96
  # Forward pass
97
  outputs = model(pixel_values=pixel_values, labels=labels)
98
  loss = outputs.loss
99
-
100
  # Backward pass
101
  loss.backward()
102
  optimizer.step()
103
-
104
  total_loss += loss.item()
105
  num_batches += 1
106
-
107
  if batch_idx % 2 == 0 or batch_idx == len(train_loader) - 1:
108
- print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx+1}/{len(train_loader)}: Loss = {loss.item():.4f}")
109
-
 
 
110
  avg_loss = total_loss / num_batches if num_batches > 0 else 0
111
- print(f"Epoch {epoch+1} completed. Average loss: {avg_loss:.4f}")
112
-
113
  # Save model
114
  os.makedirs(output_dir, exist_ok=True)
115
-
116
  model.save_pretrained(output_dir)
117
  processor.save_pretrained(output_dir)
118
-
119
  # Save config
120
  config = {
121
  "model_name": model_name,
@@ -125,29 +141,45 @@ def simple_train(
125
  "learning_rate": learning_rate,
126
  "train_samples": len(train_dataset),
127
  "num_labels": len(dataset.flower_labels),
128
- "training_type": "simple"
129
  }
130
-
131
  with open(os.path.join(output_dir, "training_config.json"), "w") as f:
132
  json.dump(config, f, indent=2)
133
-
134
  print(f"βœ… ConvNeXt training completed! Model saved to {output_dir}")
135
  return output_dir
136
 
137
 
138
  if __name__ == "__main__":
139
  import argparse
140
-
141
- parser = argparse.ArgumentParser(description="Simple ConvNeXt training for flower classification")
142
- parser.add_argument("--image_dir", default="training_data/images", help="Directory containing training images")
143
- parser.add_argument("--output_dir", default="training_data/trained_models/simple_trained", help="Output directory for trained model")
144
- parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
 
 
 
 
 
 
 
 
 
 
 
 
145
  parser.add_argument("--batch_size", type=int, default=4, help="Training batch size")
146
- parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate")
147
- parser.add_argument("--model_name", default="facebook/convnext-base-224-22k", help="Base model name")
148
-
 
 
 
 
149
  args = parser.parse_args()
150
-
151
  try:
152
  result = simple_train(
153
  image_dir=args.image_dir,
@@ -155,7 +187,7 @@ if __name__ == "__main__":
155
  epochs=args.epochs,
156
  batch_size=args.batch_size,
157
  learning_rate=args.learning_rate,
158
- model_name=args.model_name
159
  )
160
  if not result:
161
  print("❌ Training failed!")
@@ -165,5 +197,6 @@ if __name__ == "__main__":
165
  except Exception as e:
166
  print(f"❌ Training failed: {e}")
167
  import traceback
 
168
  traceback.print_exc()
169
- exit(1)
 
4
  This is a lightweight training implementation for quick model fine-tuning.
5
  """
6
 
7
+ import json
8
  import os
9
+
10
  import torch
 
 
 
11
  from dataset import FlowerDataset, simple_collate_fn
12
+ from torch.utils.data import DataLoader
13
+ from transformers import ConvNextForImageClassification, ConvNextImageProcessor
14
 
15
 
16
  def simple_train(
 
19
  epochs=3,
20
  batch_size=4,
21
  learning_rate=1e-5,
22
+ model_name="facebook/convnext-base-224-22k",
23
  ):
24
  """
25
  Simple training function for ConvNeXt flower classification.
26
+
27
  Args:
28
  image_dir: Directory containing training images organized by flower type
29
  output_dir: Directory to save the trained model
 
31
  batch_size: Training batch size
32
  learning_rate: Learning rate for optimization
33
  model_name: Base ConvNeXt model to fine-tune
34
+
35
  Returns:
36
  str: Path to the saved model directory, or None if training failed
37
  """
38
  print("🌸 Simple ConvNeXt Flower Model Training")
39
  print("=" * 40)
40
+
41
  # Check training data
42
  if not os.path.exists(image_dir):
43
  print(f"❌ Training directory not found: {image_dir}")
44
  return None
45
+
46
+ device = (
47
+ "cuda"
48
+ if torch.cuda.is_available()
49
+ else "mps"
50
+ if torch.backends.mps.is_available()
51
+ else "cpu"
52
+ )
53
  print(f"Using device: {device}")
54
+
55
  # Load model and processor
56
  print(f"Loading model: {model_name}")
57
  model = ConvNextForImageClassification.from_pretrained(model_name)
58
  processor = ConvNextImageProcessor.from_pretrained(model_name)
59
  model.to(device)
60
+
61
  # Create dataset
62
  dataset = FlowerDataset(image_dir, processor)
63
+
64
  if len(dataset) < 5:
65
  print("❌ Need at least 5 images for training")
66
  return None
67
+
68
  # Split dataset
69
  train_size = int(0.8 * len(dataset))
70
  train_dataset = torch.utils.data.Subset(dataset, range(train_size))
71
+
72
  # Update model config for the number of classes
73
  if len(dataset.flower_labels) != model.config.num_labels:
74
  model.config.num_labels = len(dataset.flower_labels)
75
  # ConvNeXt uses hidden_sizes[-1] as the final hidden dimension
76
+ final_hidden_size = (
77
+ model.config.hidden_sizes[-1]
78
+ if hasattr(model.config, "hidden_sizes")
79
+ else 768
80
+ )
81
+ model.classifier = torch.nn.Linear(
82
+ final_hidden_size, len(dataset.flower_labels)
83
+ )
84
  model.classifier.to(device)
85
+
86
  # Create data loader
87
+ train_loader = DataLoader(
88
+ train_dataset, batch_size=batch_size, shuffle=True, collate_fn=simple_collate_fn
89
+ )
90
+
91
  # Setup optimizer
92
  optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
93
+
94
  # Training loop
95
  model.train()
96
  print(f"Starting training on {len(train_dataset)} samples for {epochs} epochs...")
97
+
98
  for epoch in range(epochs):
99
  total_loss = 0
100
  num_batches = 0
101
+
102
  for batch_idx, batch in enumerate(train_loader):
103
  # Move to device
104
+ pixel_values = batch["pixel_values"].to(device)
105
+ labels = batch["labels"].to(device)
106
+
107
  # Zero gradients
108
  optimizer.zero_grad()
109
+
110
  # Forward pass
111
  outputs = model(pixel_values=pixel_values, labels=labels)
112
  loss = outputs.loss
113
+
114
  # Backward pass
115
  loss.backward()
116
  optimizer.step()
117
+
118
  total_loss += loss.item()
119
  num_batches += 1
120
+
121
  if batch_idx % 2 == 0 or batch_idx == len(train_loader) - 1:
122
+ print(
123
+ f"Epoch {epoch + 1}/{epochs}, Batch {batch_idx + 1}/{len(train_loader)}: Loss = {loss.item():.4f}"
124
+ )
125
+
126
  avg_loss = total_loss / num_batches if num_batches > 0 else 0
127
+ print(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")
128
+
129
  # Save model
130
  os.makedirs(output_dir, exist_ok=True)
131
+
132
  model.save_pretrained(output_dir)
133
  processor.save_pretrained(output_dir)
134
+
135
  # Save config
136
  config = {
137
  "model_name": model_name,
 
141
  "learning_rate": learning_rate,
142
  "train_samples": len(train_dataset),
143
  "num_labels": len(dataset.flower_labels),
144
+ "training_type": "simple",
145
  }
146
+
147
  with open(os.path.join(output_dir, "training_config.json"), "w") as f:
148
  json.dump(config, f, indent=2)
149
+
150
  print(f"βœ… ConvNeXt training completed! Model saved to {output_dir}")
151
  return output_dir
152
 
153
 
154
  if __name__ == "__main__":
155
  import argparse
156
+
157
+ parser = argparse.ArgumentParser(
158
+ description="Simple ConvNeXt training for flower classification"
159
+ )
160
+ parser.add_argument(
161
+ "--image_dir",
162
+ default="training_data/images",
163
+ help="Directory containing training images",
164
+ )
165
+ parser.add_argument(
166
+ "--output_dir",
167
+ default="training_data/trained_models/simple_trained",
168
+ help="Output directory for trained model",
169
+ )
170
+ parser.add_argument(
171
+ "--epochs", type=int, default=3, help="Number of training epochs"
172
+ )
173
  parser.add_argument("--batch_size", type=int, default=4, help="Training batch size")
174
+ parser.add_argument(
175
+ "--learning_rate", type=float, default=1e-5, help="Learning rate"
176
+ )
177
+ parser.add_argument(
178
+ "--model_name", default="facebook/convnext-base-224-22k", help="Base model name"
179
+ )
180
+
181
  args = parser.parse_args()
182
+
183
  try:
184
  result = simple_train(
185
  image_dir=args.image_dir,
 
187
  epochs=args.epochs,
188
  batch_size=args.batch_size,
189
  learning_rate=args.learning_rate,
190
+ model_name=args.model_name,
191
  )
192
  if not result:
193
  print("❌ Training failed!")
 
197
  except Exception as e:
198
  print(f"❌ Training failed: {e}")
199
  import traceback
200
+
201
  traceback.print_exc()
202
+ exit(1)
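And a matching sketch for the lightweight trainer, calling simple_train() directly with the same defaults the CLI exposes (illustrative only):

  # Assumes the training/ directory is on sys.path so `simple_trainer` is importable.
  from simple_trainer import simple_train

  out_dir = simple_train(
      image_dir="training_data/images",
      output_dir="training_data/trained_models/simple_trained",
      epochs=3,
      batch_size=4,
      learning_rate=1e-5,
      model_name="facebook/convnext-base-224-22k",
  )
  print("Trained model directory:", out_dir)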