# flowerfy/app_original.py
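"""Gradio demo: SDXL-Turbo text-to-image generation, flower identification
(fine-tuned ConvNeXt with a CLIP zero-shot fallback), simple in-app training,
and French-style arrangement generation from an uploaded flower photo."""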
import glob
import os
import gradio as gr
import numpy as np
import torch
from diffusers import AutoPipelineForText2Image
from simple_train import simple_train
from sklearn.cluster import KMeans
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
ConvNextForImageClassification,
ConvNextImageProcessor,
pipeline,
)
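# Text-to-image backbone; override with the MODEL_ID environment variable.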
MODEL_ID = os.getenv("MODEL_ID", "stabilityai/sdxl-turbo")
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = AutoPipelineForText2Image.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device)
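# Prefer xformers memory-efficient attention on GPU; fall back to attention
# slicing if xformers is not available.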
if device == "cuda":
try:
pipe.enable_xformers_memory_efficient_attention()
except Exception:
pipe.enable_attention_slicing()
else:
pipe.enable_attention_slicing()
def generate(prompt, steps, width, height, seed):
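    """Generate one image with SDXL-Turbo; a negative seed means random."""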
if seed is None or int(seed) < 0:
generator = None
else:
generator = torch.Generator(device=device).manual_seed(int(seed))
result = pipe(
prompt=prompt,
num_inference_steps=int(steps),
guidance_scale=0.0, # SDXL-Turbo works best at 0.0
        width=int(width) // 8 * 8,  # snap to a multiple of 8 (required by the VAE)
        height=int(height) // 8 * 8,
generator=generator,
)
return result.images[0]
# ---------- Flower identification (zero-shot) ----------
# Curated label set; edit/extend as you like
FLOWER_LABELS = [
"rose",
"tulip",
"lily",
"peony",
"sunflower",
"chrysanthemum",
"carnation",
"orchid",
"hydrangea",
"daisy",
"dahlia",
"ranunculus",
"anemone",
"marigold",
"lavender",
"magnolia",
"gardenia",
"camellia",
"jasmine",
"iris",
"gerbera",
"zinnia",
"hibiscus",
"lotus",
"poppy",
"sweet pea",
"freesia",
"lisianthus",
"calla lily",
"cherry blossom",
"plumeria",
"cosmos",
]
# Initialize classifier - will be updated when trained model is loaded
clf_device = 0 if torch.cuda.is_available() else -1
zs_classifier = None
convnext_model = None
convnext_processor = None
current_model_path = "facebook/convnext-base-224-22k"
def load_classifier(model_path="facebook/convnext-base-224-22k"):
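    """Load a fine-tuned ConvNeXt checkpoint if model_path exists locally;
    otherwise load the default ConvNeXt plus a CLIP zero-shot fallback."""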
global zs_classifier, convnext_model, convnext_processor, current_model_path
try:
if os.path.exists(model_path):
# Load custom trained model
convnext_model = AutoModelForImageClassification.from_pretrained(model_path)
convnext_processor = AutoImageProcessor.from_pretrained(model_path)
current_model_path = model_path
# Also keep zero-shot classifier for fallback
zs_classifier = pipeline(
task="zero-shot-image-classification",
model="openai/clip-vit-base-patch32",
device=clf_device,
)
return f"βœ… Loaded custom ConvNeXt model from: {model_path}"
else:
# Load default ConvNeXt model for feature extraction and fallback to CLIP for zero-shot
convnext_model = ConvNextForImageClassification.from_pretrained(
"facebook/convnext-base-224-22k"
)
convnext_processor = ConvNextImageProcessor.from_pretrained(
"facebook/convnext-base-224-22k"
)
zs_classifier = pipeline(
task="zero-shot-image-classification",
model="openai/clip-vit-base-patch32",
device=clf_device,
)
current_model_path = "facebook/convnext-base-224-22k"
return "βœ… Loaded default ConvNeXt model: facebook/convnext-base-224-22k"
except Exception as e:
return f"❌ Error loading model: {e!s}"
# Initialize with default model
load_classifier()
def identify_flowers(image, candidate_labels, top_k, min_score):
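    """Classify flowers in an image, preferring the fine-tuned ConvNeXt model
    and falling back to CLIP zero-shot classification."""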
if image is None:
return [], "Please provide an image (upload or generate first)."
labels = candidate_labels if candidate_labels else FLOWER_LABELS
    # Prefer the fine-tuned ConvNeXt checkpoint when one is loaded; otherwise
    # fall back to CLIP zero-shot classification.
    use_convnext = (
        convnext_model is not None
        and os.path.exists(current_model_path)
        and current_model_path != "facebook/convnext-base-224-22k"
    )
    model_type = "ConvNeXt" if use_convnext else "CLIP zero-shot"
    if use_convnext:
try:
# Use trained ConvNeXt model
inputs = convnext_processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = convnext_model(**inputs)
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # Map logits to class names via the model's id2label config; the
            # UI candidate list is only a positional fallback and may not
            # match the order the model was trained with.
            id2label = getattr(convnext_model.config, "id2label", None) or {}
            results = []
            for i, score in enumerate(predictions[0]):
                label = id2label.get(i, labels[i] if i < len(labels) else None)
                if label is not None:
                    results.append({"label": str(label), "score": float(score)})
# Sort by score
results = sorted(results, key=lambda r: r["score"], reverse=True)
        except Exception:
            # Fall back to CLIP zero-shot if the ConvNeXt path fails
            model_type = "CLIP zero-shot"
results = zs_classifier(
image, candidate_labels=labels, hypothesis_template="a photo of a {}"
)
else:
# Use CLIP zero-shot classification
results = zs_classifier(
image, candidate_labels=labels, hypothesis_template="a photo of a {}"
)
# Filter and format results
results = [r for r in results if r["score"] >= float(min_score)]
results = sorted(results, key=lambda r: r["score"], reverse=True)[: int(top_k)]
table = [[r["label"], round(float(r["score"]), 4)] for r in results]
msg = f"Detected flowers using {model_type}."
return table, msg
# simple passthrough so the generated image appears in the Identify tab automatically
def passthrough(img):
return img
# Training functions
def get_available_models():
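    """List the default model plus any fine-tuned checkpoints found on disk."""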
models_dir = "training_data/trained_models"
if not os.path.exists(models_dir):
return ["facebook/convnext-base-224-22k (default)"]
models = ["facebook/convnext-base-224-22k (default)"]
for item in os.listdir(models_dir):
model_path = os.path.join(models_dir, item)
if os.path.isdir(model_path) and os.path.exists(
os.path.join(model_path, "config.json")
):
models.append(f"Custom: {item}")
return models
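

IMAGE_EXTENSIONS = ("*.jpg", "*.jpeg", "*.png", "*.webp")


def list_image_files(directory):
    """Return all supported image files directly inside `directory`.

    Shared by count_training_images() and start_training() below, which
    previously duplicated the same four glob patterns.
    """
    files = []
    for pattern in IMAGE_EXTENSIONS:
        files.extend(glob.glob(os.path.join(directory, pattern)))
    return files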
def count_training_images():
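    """Summarize how many training images exist per flower subdirectory."""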
images_dir = "training_data/images"
if not os.path.exists(images_dir):
return "Training directory not found"
total_images = 0
flower_counts = {}
for flower_type in os.listdir(images_dir):
flower_path = os.path.join(images_dir, flower_type)
if os.path.isdir(flower_path):
            image_files = list_image_files(flower_path)
count = len(image_files)
if count > 0:
flower_counts[flower_type] = count
total_images += count
if total_images == 0:
return "No training images found. Add images to subdirectories in training_data/images/"
result = f"**Total images: {total_images}**\n\n"
for flower_type, count in sorted(flower_counts.items()):
result += f"- {flower_type}: {count} images\n"
return result
def start_training(epochs=None, batch_size=None, learning_rate=None):
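    # NOTE: epochs/batch_size/learning_rate come from the UI controls but are
    # not forwarded to simple_train(), which uses its own defaults.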
try:
# Check if training data exists
images_dir = "training_data/images"
if not os.path.exists(images_dir):
return "❌ Training directory not found. Please create training_data/images/ and add your data."
# Count images
total_images = 0
for flower_type in os.listdir(images_dir):
flower_path = os.path.join(images_dir, flower_type)
if os.path.isdir(flower_path):
                image_files = list_image_files(flower_path)
total_images += len(image_files)
if total_images < 10:
return f"❌ Need at least 10 training images. Found {total_images}. Add more images to training_data/images/"
# Start training
model_path = simple_train()
        if model_path:
            return f"✅ Training completed! Model saved to: {model_path}"
        else:
            return "❌ Training failed. Check the console for details."
    except Exception as e:
        return f"❌ Training error: {e!s}"
def load_trained_model(model_selection):
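    """Resolve the dropdown selection to a checkpoint path and (re)load it."""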
if model_selection.startswith("Custom:"):
model_name = model_selection.replace("Custom: ", "")
model_path = os.path.join("training_data/trained_models", model_name)
return load_classifier(model_path)
else:
return load_classifier("facebook/convnext-base-224-22k")
# French-style arrangement functions
def extract_dominant_colors(image, num_colors=5):
"""Extract dominant colors from an image using k-means clustering"""
    if image is None:
        return [], []
    # Convert PIL image to an RGB numpy array (handles RGBA and grayscale uploads)
    img_array = np.array(image.convert("RGB"))
# Reshape image to be a list of pixels
pixels = img_array.reshape(-1, 3)
# Use k-means to find dominant colors
kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
kmeans.fit(pixels)
# Get the colors and convert to RGB values
colors = kmeans.cluster_centers_.astype(int)
# Convert to color names/descriptions
color_names = []
for color in colors:
r, g, b = color
if r > 200 and g > 200 and b > 200:
color_names.append("white")
elif r < 50 and g < 50 and b < 50:
color_names.append("black")
        elif r > 150 and g > 100 and b < 100:
            # Orange is tested before the generic red-dominant branch below,
            # which would otherwise shadow it (r > g and r > b also holds here).
            color_names.append("orange")
        elif r > g and r > b:
            if r > 150 and g < 100:
                color_names.append("red" if g < 50 else "pink")
            else:
                color_names.append("coral")
        elif g > r and g > b:
            if b < 100:
                color_names.append("yellow" if g > 200 and r > 150 else "green")
            else:
                color_names.append("teal")
        elif b > r and b > g:
            if r < 100:
                color_names.append("blue" if b > 150 else "navy")
            else:
                color_names.append("purple" if r > g else "lavender")
        else:
            color_names.append("cream")
return color_names, colors
def analyze_and_generate_french_style(image):
"""Analyze uploaded flower image and generate French-style arrangement"""
if image is None:
return None, "Please upload an image", ""
# Identify the flower type
if zs_classifier is None:
return None, "Model not loaded", ""
try:
progress_log = "πŸ”„ **Step 1/4:** Starting flower analysis...\n\n"
# Identify flower
progress_log += "πŸ” Identifying flower type using AI model...\n"
results = zs_classifier(
image, candidate_labels=FLOWER_LABELS, hypothesis_template="a photo of a {}"
)
top_flower = results[0]["label"] if results else "flower"
confidence = results[0]["score"] if results else 0
progress_log += (
f"βœ… Identified: **{top_flower}** (confidence: {confidence:.2%})\n\n"
)
# Extract dominant colors
progress_log += "πŸ”„ **Step 2/4:** Analyzing color palette...\n\n"
progress_log += "🎨 Extracting dominant colors from image...\n"
color_names, color_rgb = extract_dominant_colors(image, num_colors=3)
# Create color description
main_colors = color_names[:3] # Top 3 colors
color_desc = ", ".join(main_colors)
progress_log += f"βœ… Color palette: **{color_desc}**\n\n"
# Generate French-style prompt
progress_log += (
"πŸ”„ **Step 3/4:** Creating French-style arrangement prompt...\n\n"
)
prompt = f"elegant French-style floral arrangement featuring {top_flower}s in {color_desc} colors, displayed in a clear crystal vase on a marble kitchen countertop, soft natural lighting, minimalist French country kitchen background, professional photography, sophisticated composition"
progress_log += f"βœ… Prompt created: *{prompt[:100]}...*\n\n"
# Generate the image
progress_log += (
"πŸ”„ **Step 4/4:** Generating French-style arrangement image...\n\n"
)
progress_log += "πŸ–ΌοΈ Using AI image generation (SDXL-Turbo)...\n"
generated_image = generate(prompt, steps=4, width=1024, height=1024, seed=-1)
progress_log += "βœ… French-style arrangement generated successfully!\n\n"
# Create analysis summary
analysis = f"""
**🌸 Flower Analysis:**
- **Type:** {top_flower} (confidence: {confidence:.2%})
- **Dominant Colors:** {color_desc}
**πŸ‡«πŸ‡· Generated Prompt:**
"{prompt}"
---
**πŸ“‹ Process Log:**
{progress_log}
"""
return (
generated_image,
"βœ… Analysis complete! French-style arrangement generated.",
analysis,
)
    except Exception as e:
        error_log = f"❌ **Error occurred during processing:**\n\n{e!s}\n\n"
        if "progress_log" in locals():
            error_log += f"**Progress before error:**\n{progress_log}"
        return None, f"❌ Error: {e!s}", error_log
# ---------- UI ----------
with gr.Blocks() as demo:
gr.Markdown("# 🌸 SDXL-Turbo β€” Text β†’ Image + Flower Identifier")
with gr.Tabs():
with gr.TabItem("Generate"):
with gr.Row():
with gr.Column():
prompt = gr.Textbox(
value="ikebana-style flower arrangement, soft natural light, minimalist",
label="Prompt",
)
steps = gr.Slider(1, 8, value=4, step=1, label="Steps")
width = gr.Slider(512, 1536, value=1024, step=8, label="Width")
height = gr.Slider(512, 1536, value=1024, step=8, label="Height")
seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
go = gr.Button("Generate", variant="primary")
out = gr.Image(label="Result", type="pil")
with gr.TabItem("Identify"):
with gr.Row():
with gr.Column():
img_in = gr.Image(
label="Image (upload or auto-filled from 'Generate')",
type="pil",
interactive=True,
)
labels_box = gr.CheckboxGroup(
choices=FLOWER_LABELS,
value=[
"rose",
"tulip",
"lily",
"peony",
"hydrangea",
"orchid",
"sunflower",
],
label="Candidate labels (edit as needed)",
)
topk = gr.Slider(1, 15, value=7, step=1, label="Top-K")
min_score = gr.Slider(
0.0, 1.0, value=0.12, step=0.01, label="Min confidence"
)
detect_btn = gr.Button("Identify Flowers", variant="primary")
with gr.Column():
results_tbl = gr.Dataframe(
headers=["Flower", "Confidence"],
datatype=["str", "number"],
interactive=False,
)
status = gr.Markdown()
with gr.TabItem("Train Model"):
gr.Markdown("## 🎯 Fine-tune the flower identification model")
gr.Markdown(
"Organize your training images in subdirectories by flower type in `training_data/images/`"
)
gr.Markdown(
"Example: `training_data/images/roses/`, `training_data/images/tulips/`, etc."
)
with gr.Row():
with gr.Column():
gr.Markdown("### Training Data")
refresh_btn = gr.Button("πŸ”„ Refresh Data Count", size="sm")
data_status = gr.Markdown()
gr.Markdown("### Training Parameters")
epochs = gr.Slider(1, 20, value=5, step=1, label="Training Epochs")
batch_size = gr.Slider(1, 16, value=8, step=1, label="Batch Size")
learning_rate = gr.Number(
value=1e-5, label="Learning Rate", precision=6
)
train_btn = gr.Button("πŸš€ Start Training", variant="primary")
with gr.Column():
gr.Markdown("### Model Management")
model_dropdown = gr.Dropdown(
choices=get_available_models(),
value="facebook/convnext-base-224-22k (default)",
label="Select Model",
)
refresh_models_btn = gr.Button("πŸ”„ Refresh Models", size="sm")
load_model_btn = gr.Button(
"πŸ“₯ Load Selected Model", variant="secondary"
)
model_status = gr.Markdown(
f"**Current model:** {current_model_path}"
)
gr.Markdown("### Training Status")
training_output = gr.Markdown()
with gr.TabItem("French Style arrangement"):
gr.Markdown("## πŸ‡«πŸ‡· French-Style Flower Arrangements")
gr.Markdown(
"Upload a flower image and generate an elegant French-style arrangement with matching colors!"
)
with gr.Row():
with gr.Column():
upload_img = gr.Image(label="Upload Flower Image", type="pil")
analyze_btn = gr.Button(
"🎨 Analyze & Generate French Style",
variant="primary",
size="lg",
)
with gr.Column():
french_result = gr.Image(
label="Generated French-Style Arrangement", type="pil"
)
french_status = gr.Markdown()
analysis_details = gr.Markdown()
# Wire events
go.click(generate, [prompt, steps, width, height, seed], [out])
# Auto-send generated image to Identify tab
out.change(passthrough, inputs=out, outputs=img_in)
# Run identification
detect_btn.click(
identify_flowers, [img_in, labels_box, topk, min_score], [results_tbl, status]
)
# Training tab events
refresh_btn.click(count_training_images, outputs=[data_status])
refresh_models_btn.click(
lambda: gr.Dropdown(choices=get_available_models()), outputs=[model_dropdown]
)
load_model_btn.click(
load_trained_model, inputs=[model_dropdown], outputs=[model_status]
)
train_btn.click(
start_training,
inputs=[epochs, batch_size, learning_rate],
outputs=[training_output],
)
# French Style tab events - update status during processing
def update_french_status():
return "πŸ”„ Processing... Please wait while we analyze your flower image...", ""
analyze_btn.click(
update_french_status, outputs=[french_status, analysis_details]
).then(
analyze_and_generate_french_style,
inputs=[upload_img],
outputs=[french_result, french_status, analysis_details],
)
# Initialize data count on load
demo.load(count_training_images, outputs=[data_status])
demo.queue().launch()