# flowerfy/app_original.py
# ruff: noqa
import glob
import os
import gradio as gr
import numpy as np
import torch
from diffusers import AutoPipelineForText2Image
from simple_train import simple_train
from sklearn.cluster import KMeans
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
ConvNextForImageClassification,
ConvNextImageProcessor,
pipeline,
)
MODEL_ID = os.getenv("MODEL_ID", "stabilityai/sdxl-turbo")
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = AutoPipelineForText2Image.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device)
if device == "cuda":
try:
pipe.enable_xformers_memory_efficient_attention()
except Exception:
pipe.enable_attention_slicing()
else:
pipe.enable_attention_slicing()
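# Note: enable_xformers_memory_efficient_attention() needs the optional `xformers`
# package; if it is missing (or we are on CPU) the pipeline falls back to attention
# slicing, which lowers peak memory at some cost in speed.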
def generate(prompt, steps, width, height, seed):
if seed is None or int(seed) < 0:
generator = None
else:
generator = torch.Generator(device=device).manual_seed(int(seed))
result = pipe(
prompt=prompt,
num_inference_steps=int(steps),
guidance_scale=0.0, # SDXL-Turbo works best at 0.0
width=int(width // 8) * 8,
height=int(height // 8) * 8,
generator=generator,
)
return result.images[0]
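# Illustrative usage (not called by the app itself):
#   img = generate("a single red rose in a glass vase", steps=4, width=768, height=768, seed=42)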
# ---------- Flower identification (zero-shot) ----------
# Curated label set; edit/extend as you like
FLOWER_LABELS = [
"rose",
"tulip",
"lily",
"peony",
"sunflower",
"chrysanthemum",
"carnation",
"orchid",
"hydrangea",
"daisy",
"dahlia",
"ranunculus",
"anemone",
"marigold",
"lavender",
"magnolia",
"gardenia",
"camellia",
"jasmine",
"iris",
"gerbera",
"zinnia",
"hibiscus",
"lotus",
"poppy",
"sweet pea",
"freesia",
"lisianthus",
"calla lily",
"cherry blossom",
"plumeria",
"cosmos",
]
# Initialize classifier - will be updated when trained model is loaded
clf_device = 0 if torch.cuda.is_available() else -1
zs_classifier = None
convnext_model = None
convnext_processor = None
current_model_path = "facebook/convnext-base-224-22k"
def load_classifier(model_path="facebook/convnext-base-224-22k"):
global zs_classifier, convnext_model, convnext_processor, current_model_path
try:
if os.path.exists(model_path):
# Load custom trained model
convnext_model = AutoModelForImageClassification.from_pretrained(model_path)
convnext_processor = AutoImageProcessor.from_pretrained(model_path)
current_model_path = model_path
# Also keep zero-shot classifier for fallback
zs_classifier = pipeline(
task="zero-shot-image-classification",
model="openai/clip-vit-base-patch32",
device=clf_device,
)
return f"βœ… Loaded custom ConvNeXt model from: {model_path}"
else:
# Load default ConvNeXt model for feature extraction and fallback to CLIP for zero-shot
convnext_model = ConvNextForImageClassification.from_pretrained(
"facebook/convnext-base-224-22k"
)
convnext_processor = ConvNextImageProcessor.from_pretrained(
"facebook/convnext-base-224-22k"
)
zs_classifier = pipeline(
task="zero-shot-image-classification",
model="openai/clip-vit-base-patch32",
device=clf_device,
)
current_model_path = "facebook/convnext-base-224-22k"
return "βœ… Loaded default ConvNeXt model: facebook/convnext-base-224-22k"
except Exception as e:
return f"❌ Error loading model: {e!s}"
# Initialize with default model
load_classifier()
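# To point the app at a fine-tuned checkpoint instead of the default, pass a local
# path such as those produced under training_data/trained_models/, e.g.
# (illustrative): load_classifier("training_data/trained_models/<run_name>")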
def identify_flowers(image, candidate_labels, top_k, min_score):
if image is None:
return [], "Please provide an image (upload or generate first)."
labels = candidate_labels if candidate_labels else FLOWER_LABELS
    # Use the fine-tuned ConvNeXt model when one has been loaded; otherwise fall
    # back to CLIP zero-shot classification
    use_trained_convnext = (
        convnext_model is not None
        and os.path.exists(current_model_path)
        and current_model_path != "facebook/convnext-base-224-22k"
    )
    if use_trained_convnext:
try:
# Use trained ConvNeXt model
inputs = convnext_processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = convnext_model(**inputs)
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # Convert predictions to results, using the fine-tuned model's own
            # class mapping (config.id2label) rather than the UI candidate labels
            id2label = getattr(convnext_model.config, "id2label", None) or {}
            results = []
            for i, score in enumerate(predictions[0]):
                label = id2label.get(i, labels[i] if i < len(labels) else str(i))
                results.append({"label": label, "score": float(score)})
            # Sort by score
            results = sorted(results, key=lambda r: r["score"], reverse=True)
        except Exception:
            # Fall back to CLIP zero-shot if the fine-tuned model fails
            use_trained_convnext = False
            results = zs_classifier(
                image, candidate_labels=labels, hypothesis_template="a photo of a {}"
            )
else:
# Use CLIP zero-shot classification
results = zs_classifier(
image, candidate_labels=labels, hypothesis_template="a photo of a {}"
)
# Filter and format results
results = [r for r in results if r["score"] >= float(min_score)]
results = sorted(results, key=lambda r: r["score"], reverse=True)[: int(top_k)]
table = [[r["label"], round(float(r["score"]), 4)] for r in results]
    model_type = "ConvNeXt" if use_trained_convnext else "CLIP zero-shot"
msg = f"Detected flowers using {model_type}."
return table, msg
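# Illustrative usage (assumes `img` is a PIL image):
#   table, msg = identify_flowers(img, ["rose", "tulip", "peony"], top_k=5, min_score=0.1)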
# simple passthrough so the generated image appears in the Identify tab automatically
def passthrough(img):
return img
# Training functions
def get_available_models():
models_dir = "training_data/trained_models"
if not os.path.exists(models_dir):
return ["facebook/convnext-base-224-22k (default)"]
models = ["facebook/convnext-base-224-22k (default)"]
for item in os.listdir(models_dir):
model_path = os.path.join(models_dir, item)
if os.path.isdir(model_path) and os.path.exists(
os.path.join(model_path, "config.json")
):
models.append(f"Custom: {item}")
return models
def count_training_images():
images_dir = "training_data/images"
if not os.path.exists(images_dir):
return "Training directory not found"
total_images = 0
flower_counts = {}
for flower_type in os.listdir(images_dir):
flower_path = os.path.join(images_dir, flower_type)
if os.path.isdir(flower_path):
image_files = (
glob.glob(os.path.join(flower_path, "*.jpg"))
+ glob.glob(os.path.join(flower_path, "*.jpeg"))
+ glob.glob(os.path.join(flower_path, "*.png"))
+ glob.glob(os.path.join(flower_path, "*.webp"))
)
count = len(image_files)
if count > 0:
flower_counts[flower_type] = count
total_images += count
if total_images == 0:
return "No training images found. Add images to subdirectories in training_data/images/"
result = f"**Total images: {total_images}**\n\n"
for flower_type, count in sorted(flower_counts.items()):
result += f"- {flower_type}: {count} images\n"
return result
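# Expected on-disk layout (illustrative), one subdirectory per flower type:
#   training_data/images/roses/001.jpg
#   training_data/images/tulips/garden.png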
def start_training(epochs=None, batch_size=None, learning_rate=None):
try:
# Check if training data exists
images_dir = "training_data/images"
if not os.path.exists(images_dir):
return "❌ Training directory not found. Please create training_data/images/ and add your data."
# Count images
total_images = 0
for flower_type in os.listdir(images_dir):
flower_path = os.path.join(images_dir, flower_type)
if os.path.isdir(flower_path):
image_files = (
glob.glob(os.path.join(flower_path, "*.jpg"))
+ glob.glob(os.path.join(flower_path, "*.jpeg"))
+ glob.glob(os.path.join(flower_path, "*.png"))
+ glob.glob(os.path.join(flower_path, "*.webp"))
)
total_images += len(image_files)
if total_images < 10:
return f"❌ Need at least 10 training images. Found {total_images}. Add more images to training_data/images/"
        # Start training (note: the UI-selected epochs/batch_size/learning_rate are
        # accepted here but not currently forwarded to simple_train(), which runs
        # with its own defaults)
        model_path = simple_train()
if model_path:
return f"βœ… Training completed! Model saved to: {model_path}"
else:
return "❌ Training failed. Check the console for details."
except Exception as e:
return f"❌ Training error: {e!s}"
def load_trained_model(model_selection):
if model_selection.startswith("Custom:"):
model_name = model_selection.replace("Custom: ", "")
model_path = os.path.join("training_data/trained_models", model_name)
return load_classifier(model_path)
else:
return load_classifier("facebook/convnext-base-224-22k")
# French-style arrangement functions
def extract_dominant_colors(image, num_colors=5):
"""Extract dominant colors from an image using k-means clustering"""
if image is None:
return [], "No image provided"
    # Convert PIL image to an RGB numpy array (dropping any alpha channel so the
    # pixels reshape cleanly to (N, 3))
    img_array = np.array(image.convert("RGB"))
# Reshape image to be a list of pixels
pixels = img_array.reshape(-1, 3)
# Use k-means to find dominant colors
kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
kmeans.fit(pixels)
# Get the colors and convert to RGB values
colors = kmeans.cluster_centers_.astype(int)
# Convert to color names/descriptions
color_names = []
for color in colors:
r, g, b = color
if r > 200 and g > 200 and b > 200:
color_names.append("white")
elif r < 50 and g < 50 and b < 50:
color_names.append("black")
elif r > g and r > b:
if r > 150 and g < 100:
color_names.append("red" if g < 50 else "pink")
else:
color_names.append("coral")
elif g > r and g > b:
if b < 100:
color_names.append("yellow" if g > 200 and r > 150 else "green")
else:
color_names.append("teal")
elif b > r and b > g:
if r < 100:
color_names.append("blue" if b > 150 else "navy")
else:
color_names.append("purple" if r > g else "lavender")
elif r > 150 and g > 100 and b < 100:
color_names.append("orange")
else:
color_names.append("cream")
return color_names, colors
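# Illustrative helper (not wired into the UI): convert the raw k-means cluster
# centers returned by extract_dominant_colors() into hex strings, e.g. for
# building a color swatch.
def _colors_to_hex(colors):
    return ["#{:02x}{:02x}{:02x}".format(int(r), int(g), int(b)) for r, g, b in colors]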
def analyze_and_generate_french_style(image):
"""Analyze uploaded flower image and generate French-style arrangement"""
if image is None:
return None, "Please upload an image", ""
# Identify the flower type
if zs_classifier is None:
return None, "Model not loaded", ""
try:
progress_log = "πŸ”„ **Step 1/4:** Starting flower analysis...\n\n"
# Identify flower
progress_log += "πŸ” Identifying flower type using AI model...\n"
results = zs_classifier(
image, candidate_labels=FLOWER_LABELS, hypothesis_template="a photo of a {}"
)
top_flower = results[0]["label"] if results else "flower"
confidence = results[0]["score"] if results else 0
progress_log += (
f"βœ… Identified: **{top_flower}** (confidence: {confidence:.2%})\n\n"
)
# Extract dominant colors
progress_log += "πŸ”„ **Step 2/4:** Analyzing color palette...\n\n"
progress_log += "🎨 Extracting dominant colors from image...\n"
color_names, color_rgb = extract_dominant_colors(image, num_colors=3)
# Create color description
main_colors = color_names[:3] # Top 3 colors
color_desc = ", ".join(main_colors)
progress_log += f"βœ… Color palette: **{color_desc}**\n\n"
# Generate French-style prompt
progress_log += (
"πŸ”„ **Step 3/4:** Creating French-style arrangement prompt...\n\n"
)
prompt = f"elegant French-style floral arrangement featuring {top_flower}s in {color_desc} colors, displayed in a clear crystal vase on a marble kitchen countertop, soft natural lighting, minimalist French country kitchen background, professional photography, sophisticated composition"
progress_log += f"βœ… Prompt created: *{prompt[:100]}...*\n\n"
# Generate the image
progress_log += (
"πŸ”„ **Step 4/4:** Generating French-style arrangement image...\n\n"
)
progress_log += "πŸ–ΌοΈ Using AI image generation (SDXL-Turbo)...\n"
generated_image = generate(prompt, steps=4, width=1024, height=1024, seed=-1)
progress_log += "βœ… French-style arrangement generated successfully!\n\n"
# Create analysis summary
analysis = f"""
**🌸 Flower Analysis:**
- **Type:** {top_flower} (confidence: {confidence:.2%})
- **Dominant Colors:** {color_desc}
**πŸ‡«πŸ‡· Generated Prompt:**
"{prompt}"
---
**πŸ“‹ Process Log:**
{progress_log}
"""
return (
generated_image,
"βœ… Analysis complete! French-style arrangement generated.",
analysis,
)
except Exception as e:
error_log = f"❌ **Error occurred during processing:**\n\n{e!s}\n\n"
if "progress_log" in locals():
error_log += f"**Progress before error:**\n{progress_log}"
return None, f"❌ Error: {e!s}", error_log
# ---------- UI ----------
with gr.Blocks() as demo:
gr.Markdown("# 🌸 SDXL-Turbo β€” Text β†’ Image + Flower Identifier")
with gr.Tabs():
with gr.TabItem("Generate"):
with gr.Row():
with gr.Column():
prompt = gr.Textbox(
value="ikebana-style flower arrangement, soft natural light, minimalist",
label="Prompt",
)
steps = gr.Slider(1, 8, value=4, step=1, label="Steps")
width = gr.Slider(512, 1536, value=1024, step=8, label="Width")
height = gr.Slider(512, 1536, value=1024, step=8, label="Height")
seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
go = gr.Button("Generate", variant="primary")
out = gr.Image(label="Result", type="pil")
with gr.TabItem("Identify"), gr.Row():
with gr.Column():
img_in = gr.Image(
label="Image (upload or auto-filled from 'Generate')",
type="pil",
interactive=True,
)
labels_box = gr.CheckboxGroup(
choices=FLOWER_LABELS,
value=[
"rose",
"tulip",
"lily",
"peony",
"hydrangea",
"orchid",
"sunflower",
],
label="Candidate labels (edit as needed)",
)
topk = gr.Slider(1, 15, value=7, step=1, label="Top-K")
min_score = gr.Slider(
0.0, 1.0, value=0.12, step=0.01, label="Min confidence"
)
detect_btn = gr.Button("Identify Flowers", variant="primary")
with gr.Column():
results_tbl = gr.Dataframe(
headers=["Flower", "Confidence"],
datatype=["str", "number"],
interactive=False,
)
status = gr.Markdown()
with gr.TabItem("Train Model"):
gr.Markdown("## 🎯 Fine-tune the flower identification model")
gr.Markdown(
"Organize your training images in subdirectories by flower type in `training_data/images/`"
)
gr.Markdown(
"Example: `training_data/images/roses/`, `training_data/images/tulips/`, etc."
)
with gr.Row():
with gr.Column():
gr.Markdown("### Training Data")
refresh_btn = gr.Button("πŸ”„ Refresh Data Count", size="sm")
data_status = gr.Markdown()
gr.Markdown("### Training Parameters")
epochs = gr.Slider(1, 20, value=5, step=1, label="Training Epochs")
batch_size = gr.Slider(1, 16, value=8, step=1, label="Batch Size")
learning_rate = gr.Number(
value=1e-5, label="Learning Rate", precision=6
)
train_btn = gr.Button("πŸš€ Start Training", variant="primary")
with gr.Column():
gr.Markdown("### Model Management")
model_dropdown = gr.Dropdown(
choices=get_available_models(),
value="facebook/convnext-base-224-22k (default)",
label="Select Model",
)
refresh_models_btn = gr.Button("πŸ”„ Refresh Models", size="sm")
load_model_btn = gr.Button(
"πŸ“₯ Load Selected Model", variant="secondary"
)
model_status = gr.Markdown(
f"**Current model:** {current_model_path}"
)
gr.Markdown("### Training Status")
training_output = gr.Markdown()
with gr.TabItem("French Style arrangement"):
gr.Markdown("## πŸ‡«πŸ‡· French-Style Flower Arrangements")
gr.Markdown(
"Upload a flower image and generate an elegant French-style arrangement with matching colors!"
)
with gr.Row():
with gr.Column():
upload_img = gr.Image(label="Upload Flower Image", type="pil")
analyze_btn = gr.Button(
"🎨 Analyze & Generate French Style",
variant="primary",
size="lg",
)
with gr.Column():
french_result = gr.Image(
label="Generated French-Style Arrangement", type="pil"
)
french_status = gr.Markdown()
analysis_details = gr.Markdown()
# Wire events
go.click(generate, [prompt, steps, width, height, seed], [out])
# Auto-send generated image to Identify tab
out.change(passthrough, inputs=out, outputs=img_in)
# Run identification
detect_btn.click(
identify_flowers, [img_in, labels_box, topk, min_score], [results_tbl, status]
)
# Training tab events
refresh_btn.click(count_training_images, outputs=[data_status])
refresh_models_btn.click(
lambda: gr.Dropdown(choices=get_available_models()), outputs=[model_dropdown]
)
load_model_btn.click(
load_trained_model, inputs=[model_dropdown], outputs=[model_status]
)
train_btn.click(
start_training,
inputs=[epochs, batch_size, learning_rate],
outputs=[training_output],
)
# French Style tab events - update status during processing
def update_french_status():
return "πŸ”„ Processing... Please wait while we analyze your flower image...", ""
analyze_btn.click(
update_french_status, outputs=[french_status, analysis_details]
).then(
analyze_and_generate_french_style,
inputs=[upload_img],
outputs=[french_result, french_status, analysis_details],
)
# Initialize data count on load
demo.load(count_training_images, outputs=[data_status])
demo.queue().launch()