# ruff: noqa
import glob
import os

import gradio as gr
import numpy as np
import torch
from diffusers import AutoPipelineForText2Image
from simple_train import simple_train
from sklearn.cluster import KMeans
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    ConvNextForImageClassification,
    ConvNextImageProcessor,
    pipeline,
)

MODEL_ID = os.getenv("MODEL_ID", "stabilityai/sdxl-turbo")
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = AutoPipelineForText2Image.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device)
if device == "cuda":
    # Prefer xformers memory-efficient attention on CUDA; fall back to
    # attention slicing if xformers is unavailable.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pipe.enable_attention_slicing()
else:
    pipe.enable_attention_slicing()


def generate(prompt, steps, width, height, seed):
    # A negative or missing seed means "random"; otherwise seed the generator
    # so results are reproducible.
    if seed is None or int(seed) < 0:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(int(seed))
    result = pipe(
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=0.0,  # SDXL-Turbo works best at 0.0
        width=int(width // 8) * 8,  # snap dimensions to multiples of 8
        height=int(height // 8) * 8,
        generator=generator,
    )
    return result.images[0]


# ---------- Flower identification (zero-shot) ----------
# Curated label set; edit/extend as you like
FLOWER_LABELS = [
    "rose", "tulip", "lily", "peony", "sunflower", "chrysanthemum",
    "carnation", "orchid", "hydrangea", "daisy", "dahlia", "ranunculus",
    "anemone", "marigold", "lavender", "magnolia", "gardenia", "camellia",
    "jasmine", "iris", "gerbera", "zinnia", "hibiscus", "lotus", "poppy",
    "sweet pea", "freesia", "lisianthus", "calla lily", "cherry blossom",
    "plumeria", "cosmos",
]

# Initialize classifier state - updated when a trained model is loaded
clf_device = 0 if torch.cuda.is_available() else -1
zs_classifier = None
convnext_model = None
convnext_processor = None
current_model_path = "facebook/convnext-base-224-22k"


def load_classifier(model_path="facebook/convnext-base-224-22k"):
    global zs_classifier, convnext_model, convnext_processor, current_model_path
    try:
        if os.path.exists(model_path):
            # Load a custom trained model from a local directory
            convnext_model = AutoModelForImageClassification.from_pretrained(model_path)
            convnext_processor = AutoImageProcessor.from_pretrained(model_path)
            current_model_path = model_path
            # Also keep the zero-shot classifier as a fallback
            zs_classifier = pipeline(
                task="zero-shot-image-classification",
                model="openai/clip-vit-base-patch32",
                device=clf_device,
            )
            return f"✅ Loaded custom ConvNeXt model from: {model_path}"
        else:
            # Load the default ConvNeXt model and fall back to CLIP for
            # zero-shot classification
            convnext_model = ConvNextForImageClassification.from_pretrained(
                "facebook/convnext-base-224-22k"
            )
            convnext_processor = ConvNextImageProcessor.from_pretrained(
                "facebook/convnext-base-224-22k"
            )
            zs_classifier = pipeline(
                task="zero-shot-image-classification",
                model="openai/clip-vit-base-patch32",
                device=clf_device,
            )
            current_model_path = "facebook/convnext-base-224-22k"
            return "✅ Loaded default ConvNeXt model: facebook/convnext-base-224-22k"
    except Exception as e:
        return f"❌ Error loading model: {e!s}"


# Initialize with default model
load_classifier()
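
# NOTE: identify_flowers() below normalizes both classifier paths to the
# shape returned by the zero-shot pipeline: a list of {"label", "score"}
# dicts sorted by descending score, e.g.
# [{"label": "rose", "score": 0.91}, {"label": "tulip", "score": 0.04}, ...]
# (the scores here are illustrative).
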
def identify_flowers(image, candidate_labels, top_k, min_score):
    if image is None:
        return [], "Please provide an image (upload or generate first)."

    labels = candidate_labels if candidate_labels else FLOWER_LABELS
    using_trained_convnext = (
        convnext_model is not None
        and os.path.exists(current_model_path)
        and current_model_path != "facebook/convnext-base-224-22k"
    )

    # Use the fine-tuned ConvNeXt classifier if one is loaded; otherwise fall
    # back to CLIP zero-shot classification.
    if using_trained_convnext:
        try:
            inputs = convnext_processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = convnext_model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # Map prediction indices to the model's own label names rather
            # than the UI's candidate list, whose order may differ.
            id2label = convnext_model.config.id2label
            results = [
                {"label": id2label.get(i, str(i)), "score": float(score)}
                for i, score in enumerate(predictions[0])
            ]
            results = sorted(results, key=lambda r: r["score"], reverse=True)
        except Exception:
            # Fall back to CLIP zero-shot
            results = zs_classifier(
                image, candidate_labels=labels, hypothesis_template="a photo of a {}"
            )
    else:
        # Use CLIP zero-shot classification
        results = zs_classifier(
            image, candidate_labels=labels, hypothesis_template="a photo of a {}"
        )

    # Filter and format results
    results = [r for r in results if r["score"] >= float(min_score)]
    results = sorted(results, key=lambda r: r["score"], reverse=True)[: int(top_k)]
    table = [[r["label"], round(float(r["score"]), 4)] for r in results]

    model_type = "ConvNeXt" if using_trained_convnext else "CLIP zero-shot"
    msg = f"Detected flowers using {model_type}."
    return table, msg


# Simple passthrough so the generated image appears in the Identify tab automatically
def passthrough(img):
    return img


# Training functions
def get_available_models():
    models_dir = "training_data/trained_models"
    if not os.path.exists(models_dir):
        return ["facebook/convnext-base-224-22k (default)"]

    models = ["facebook/convnext-base-224-22k (default)"]
    for item in os.listdir(models_dir):
        model_path = os.path.join(models_dir, item)
        if os.path.isdir(model_path) and os.path.exists(
            os.path.join(model_path, "config.json")
        ):
            models.append(f"Custom: {item}")
    return models


IMAGE_EXTENSIONS = ("*.jpg", "*.jpeg", "*.png", "*.webp")


def list_training_images(flower_path):
    """Collect the training images under one flower-type directory."""
    image_files = []
    for pattern in IMAGE_EXTENSIONS:
        image_files.extend(glob.glob(os.path.join(flower_path, pattern)))
    return image_files


def count_training_images():
    images_dir = "training_data/images"
    if not os.path.exists(images_dir):
        return "Training directory not found"

    total_images = 0
    flower_counts = {}
    for flower_type in os.listdir(images_dir):
        flower_path = os.path.join(images_dir, flower_type)
        if os.path.isdir(flower_path):
            count = len(list_training_images(flower_path))
            if count > 0:
                flower_counts[flower_type] = count
                total_images += count

    if total_images == 0:
        return "No training images found. Add images to subdirectories in training_data/images/"

    result = f"**Total images: {total_images}**\n\n"
    for flower_type, count in sorted(flower_counts.items()):
        result += f"- {flower_type}: {count} images\n"
    return result
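
# Expected data layout: one subdirectory per flower type, e.g.
#   training_data/images/roses/*.jpg
#   training_data/images/tulips/*.png
# simple_train() (imported from simple_train.py) is expected to return the
# path of the saved model on success and a falsy value on failure.
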
def start_training(epochs=None, batch_size=None, learning_rate=None):
    # NOTE: the UI passes epochs, batch_size, and learning_rate, but they are
    # not forwarded to simple_train(), which uses its own defaults.
    try:
        # Check that training data exists
        images_dir = "training_data/images"
        if not os.path.exists(images_dir):
            return "❌ Training directory not found. Please create training_data/images/ and add your data."

        # Count images
        total_images = 0
        for flower_type in os.listdir(images_dir):
            flower_path = os.path.join(images_dir, flower_type)
            if os.path.isdir(flower_path):
                total_images += len(list_training_images(flower_path))

        if total_images < 10:
            return f"❌ Need at least 10 training images. Found {total_images}. Add more images to training_data/images/"

        # Start training
        model_path = simple_train()
        if model_path:
            return f"✅ Training completed! Model saved to: {model_path}"
        else:
            return "❌ Training failed. Check the console for details."
    except Exception as e:
        return f"❌ Training error: {e!s}"


def load_trained_model(model_selection):
    if model_selection.startswith("Custom:"):
        model_name = model_selection.replace("Custom: ", "")
        model_path = os.path.join("training_data/trained_models", model_name)
        return load_classifier(model_path)
    else:
        return load_classifier("facebook/convnext-base-224-22k")


# French-style arrangement functions
def extract_dominant_colors(image, num_colors=5):
    """Extract dominant colors from an image using k-means clustering."""
    if image is None:
        return [], None

    # Convert the PIL image to an RGB numpy array (drops any alpha channel,
    # which would otherwise break the (-1, 3) reshape below)
    img_array = np.array(image.convert("RGB"))

    # Reshape the image into a flat list of pixels
    pixels = img_array.reshape(-1, 3)

    # Use k-means to find the dominant colors
    kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
    kmeans.fit(pixels)

    # Get the cluster centers as integer RGB values
    colors = kmeans.cluster_centers_.astype(int)

    # Map each RGB centroid to a coarse color name via simple thresholds
    color_names = []
    for color in colors:
        r, g, b = color
        if r > 200 and g > 200 and b > 200:
            color_names.append("white")
        elif r < 50 and g < 50 and b < 50:
            color_names.append("black")
        elif r > g and r > b:
            if r > 150 and g < 100:
                color_names.append("red" if g < 50 else "pink")
            else:
                color_names.append("coral")
        elif g > r and g > b:
            if b < 100:
                color_names.append("yellow" if g > 200 and r > 150 else "green")
            else:
                color_names.append("teal")
        elif b > r and b > g:
            if r < 100:
                color_names.append("blue" if b > 150 else "navy")
            else:
                color_names.append("purple" if r > g else "lavender")
        elif r > 150 and g > 100 and b < 100:
            color_names.append("orange")
        else:
            color_names.append("cream")

    return color_names, colors
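
# Sanity check of the thresholds above (centroid -> name):
#   (245, 245, 245) -> "white"    (220, 30, 60) -> "red"
#   (200, 80, 120)  -> "pink"     (60, 160, 70) -> "green"
# The mapping is deliberately coarse; adjust the thresholds for finer names.


# Optional helper (not wired into the UI): render a centroid as a hex string,
# e.g. for displaying color swatches alongside the names.
def rgb_to_hex(color):
    """Convert an (r, g, b) array-like of ints to a '#rrggbb' string."""
    r, g, b = (int(c) for c in color)
    return f"#{r:02x}{g:02x}{b:02x}"
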
def analyze_and_generate_french_style(image):
    """Analyze an uploaded flower image and generate a French-style arrangement."""
    if image is None:
        return None, "Please upload an image", ""

    if zs_classifier is None:
        return None, "Model not loaded", ""

    try:
        progress_log = "🔄 **Step 1/4:** Starting flower analysis...\n\n"

        # Step 1: identify the flower type
        progress_log += "🔍 Identifying flower type using AI model...\n"
        results = zs_classifier(
            image, candidate_labels=FLOWER_LABELS, hypothesis_template="a photo of a {}"
        )
        top_flower = results[0]["label"] if results else "flower"
        confidence = results[0]["score"] if results else 0
        progress_log += (
            f"✅ Identified: **{top_flower}** (confidence: {confidence:.2%})\n\n"
        )

        # Step 2: extract dominant colors
        progress_log += "🔄 **Step 2/4:** Analyzing color palette...\n\n"
        progress_log += "🎨 Extracting dominant colors from image...\n"
        color_names, color_rgb = extract_dominant_colors(image, num_colors=3)
        main_colors = color_names[:3]  # top 3 colors
        color_desc = ", ".join(main_colors)
        progress_log += f"✅ Color palette: **{color_desc}**\n\n"

        # Step 3: build the French-style prompt
        progress_log += "🔄 **Step 3/4:** Creating French-style arrangement prompt...\n\n"
        prompt = (
            f"elegant French-style floral arrangement featuring {top_flower}s "
            f"in {color_desc} colors, displayed in a clear crystal vase on a "
            "marble kitchen countertop, soft natural lighting, minimalist "
            "French country kitchen background, professional photography, "
            "sophisticated composition"
        )
        progress_log += f"✅ Prompt created: *{prompt[:100]}...*\n\n"

        # Step 4: generate the image
        progress_log += "🔄 **Step 4/4:** Generating French-style arrangement image...\n\n"
        progress_log += "🖼️ Using AI image generation (SDXL-Turbo)...\n"
        generated_image = generate(prompt, steps=4, width=1024, height=1024, seed=-1)
        progress_log += "✅ French-style arrangement generated successfully!\n\n"

        # Create the analysis summary
        analysis = f"""
**🌸 Flower Analysis:**
- **Type:** {top_flower} (confidence: {confidence:.2%})
- **Dominant Colors:** {color_desc}

**🇫🇷 Generated Prompt:**
"{prompt}"

---

**📋 Process Log:**
{progress_log}
"""
        return (
            generated_image,
            "✅ Analysis complete! French-style arrangement generated.",
            analysis,
        )
    except Exception as e:
        error_log = f"❌ **Error occurred during processing:**\n\n{e!s}\n\n"
        if "progress_log" in locals():
            error_log += f"**Progress before error:**\n{progress_log}"
        return None, f"❌ Error: {e!s}", error_log
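
# The three return values above map one-to-one onto the Gradio outputs wired
# up below: [french_result, french_status, analysis_details]. Note that the
# generation step calls generate() with seed=-1, so every run produces a
# different arrangement.
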
# ---------- UI ----------
with gr.Blocks() as demo:
    gr.Markdown("# 🌸 SDXL-Turbo — Text → Image + Flower Identifier")

    with gr.Tabs():
        with gr.TabItem("Generate"):
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(
                        value="ikebana-style flower arrangement, soft natural light, minimalist",
                        label="Prompt",
                    )
                    steps = gr.Slider(1, 8, value=4, step=1, label="Steps")
                    width = gr.Slider(512, 1536, value=1024, step=8, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=8, label="Height")
                    seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
                    go = gr.Button("Generate", variant="primary")
                out = gr.Image(label="Result", type="pil")

        with gr.TabItem("Identify"), gr.Row():
            with gr.Column():
                img_in = gr.Image(
                    label="Image (upload or auto-filled from 'Generate')",
                    type="pil",
                    interactive=True,
                )
                labels_box = gr.CheckboxGroup(
                    choices=FLOWER_LABELS,
                    value=[
                        "rose",
                        "tulip",
                        "lily",
                        "peony",
                        "hydrangea",
                        "orchid",
                        "sunflower",
                    ],
                    label="Candidate labels (edit as needed)",
                )
                topk = gr.Slider(1, 15, value=7, step=1, label="Top-K")
                min_score = gr.Slider(
                    0.0, 1.0, value=0.12, step=0.01, label="Min confidence"
                )
                detect_btn = gr.Button("Identify Flowers", variant="primary")
            with gr.Column():
                results_tbl = gr.Dataframe(
                    headers=["Flower", "Confidence"],
                    datatype=["str", "number"],
                    interactive=False,
                )
                status = gr.Markdown()

        with gr.TabItem("Train Model"):
            gr.Markdown("## 🎯 Fine-tune the flower identification model")
            gr.Markdown(
                "Organize your training images in subdirectories by flower type in `training_data/images/`"
            )
            gr.Markdown(
                "Example: `training_data/images/roses/`, `training_data/images/tulips/`, etc."
            )
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Training Data")
                    refresh_btn = gr.Button("🔄 Refresh Data Count", size="sm")
                    data_status = gr.Markdown()
                    gr.Markdown("### Training Parameters")
                    epochs = gr.Slider(1, 20, value=5, step=1, label="Training Epochs")
                    batch_size = gr.Slider(1, 16, value=8, step=1, label="Batch Size")
                    learning_rate = gr.Number(
                        value=1e-5, label="Learning Rate", precision=6
                    )
                    train_btn = gr.Button("🚀 Start Training", variant="primary")
                with gr.Column():
                    gr.Markdown("### Model Management")
                    model_dropdown = gr.Dropdown(
                        choices=get_available_models(),
                        value="facebook/convnext-base-224-22k (default)",
                        label="Select Model",
                    )
                    refresh_models_btn = gr.Button("🔄 Refresh Models", size="sm")
                    load_model_btn = gr.Button(
                        "📥 Load Selected Model", variant="secondary"
                    )
                    model_status = gr.Markdown(
                        f"**Current model:** {current_model_path}"
                    )
                    gr.Markdown("### Training Status")
                    training_output = gr.Markdown()

        with gr.TabItem("French Style arrangement"):
            gr.Markdown("## 🇫🇷 French-Style Flower Arrangements")
            gr.Markdown(
                "Upload a flower image and generate an elegant French-style arrangement with matching colors!"
            )
            with gr.Row():
                with gr.Column():
                    upload_img = gr.Image(label="Upload Flower Image", type="pil")
                    analyze_btn = gr.Button(
                        "🎨 Analyze & Generate French Style",
                        variant="primary",
                        size="lg",
                    )
                with gr.Column():
                    french_result = gr.Image(
                        label="Generated French-Style Arrangement", type="pil"
                    )
                    french_status = gr.Markdown()
                    analysis_details = gr.Markdown()

    # Wire events
    go.click(generate, [prompt, steps, width, height, seed], [out])

    # Auto-send generated image to the Identify tab
    out.change(passthrough, inputs=out, outputs=img_in)

    # Run identification
    detect_btn.click(
        identify_flowers, [img_in, labels_box, topk, min_score], [results_tbl, status]
    )

    # Training tab events
    refresh_btn.click(count_training_images, outputs=[data_status])
    refresh_models_btn.click(
        lambda: gr.Dropdown(choices=get_available_models()), outputs=[model_dropdown]
    )
    load_model_btn.click(
        load_trained_model, inputs=[model_dropdown], outputs=[model_status]
    )
    train_btn.click(
        start_training,
        inputs=[epochs, batch_size, learning_rate],
        outputs=[training_output],
    )

    # French Style tab events - show an interim status message while processing
    def update_french_status():
        return "🔄 Processing... Please wait while we analyze your flower image...", ""

    analyze_btn.click(
        update_french_status, outputs=[french_status, analysis_details]
    ).then(
        analyze_and_generate_french_style,
        inputs=[upload_img],
        outputs=[french_result, french_status, analysis_details],
    )

    # Initialize the data count on load
    demo.load(count_training_images, outputs=[data_status])

demo.queue().launch()