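"""Gradio app: SDXL-Turbo text-to-image plus flower identification.

Generates images from prompts, identifies flowers with a fine-tuned ConvNeXt
or CLIP zero-shot classifier, fine-tunes on user-supplied data via
simple_train, and builds French-style arrangement images from uploaded photos.
"""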
import glob
import os

import gradio as gr
import numpy as np
import torch
from diffusers import AutoPipelineForText2Image
from simple_train import simple_train
from sklearn.cluster import KMeans
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    ConvNextForImageClassification,
    ConvNextImageProcessor,
    pipeline,
)
|
|
|
MODEL_ID = os.getenv("MODEL_ID", "stabilityai/sdxl-turbo")

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = AutoPipelineForText2Image.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device)
if device == "cuda":
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        # xformers is optional; fall back to attention slicing if it is unavailable.
        pipe.enable_attention_slicing()
else:
    pipe.enable_attention_slicing()
|
|
|
|
|
def generate(prompt, steps, width, height, seed):
    """Run SDXL-Turbo text-to-image; a seed of None or < 0 means random."""
    if seed is None or int(seed) < 0:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(int(seed))

    result = pipe(
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=0.0,  # SDXL-Turbo is distilled to run without classifier-free guidance
        width=int(width // 8) * 8,  # SDXL dimensions must be multiples of 8
        height=int(height // 8) * 8,
        generator=generator,
    )
    return result.images[0]
|
|
|
|
|
|
|
|
|
FLOWER_LABELS = [
    "rose",
    "tulip",
    "lily",
    "peony",
    "sunflower",
    "chrysanthemum",
    "carnation",
    "orchid",
    "hydrangea",
    "daisy",
    "dahlia",
    "ranunculus",
    "anemone",
    "marigold",
    "lavender",
    "magnolia",
    "gardenia",
    "camellia",
    "jasmine",
    "iris",
    "gerbera",
    "zinnia",
    "hibiscus",
    "lotus",
    "poppy",
    "sweet pea",
    "freesia",
    "lisianthus",
    "calla lily",
    "cherry blossom",
    "plumeria",
    "cosmos",
]
|
|
|
|
|
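# Classifier state. A ConvNeXt checkpoint handles fine-tuned classes; a CLIP
# zero-shot pipeline serves as the fallback for arbitrary candidate labels.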
clf_device = 0 if torch.cuda.is_available() else -1
zs_classifier = None
convnext_model = None
convnext_processor = None
current_model_path = "facebook/convnext-base-224-22k"
|
|
|
|
|
def load_classifier(model_path="facebook/convnext-base-224-22k"):
    """Load a ConvNeXt classifier (custom fine-tune or default) plus the CLIP fallback."""
    global zs_classifier, convnext_model, convnext_processor, current_model_path
    try:
        if os.path.exists(model_path):
            # Local path: load a custom fine-tuned ConvNeXt model.
            convnext_model = AutoModelForImageClassification.from_pretrained(model_path)
            convnext_processor = AutoImageProcessor.from_pretrained(model_path)
            current_model_path = model_path

            zs_classifier = pipeline(
                task="zero-shot-image-classification",
                model="openai/clip-vit-base-patch32",
                device=clf_device,
            )
            return f"✅ Loaded custom ConvNeXt model from: {model_path}"
        else:
            # Otherwise fall back to the default pretrained checkpoint.
            convnext_model = ConvNextForImageClassification.from_pretrained(
                "facebook/convnext-base-224-22k"
            )
            convnext_processor = ConvNextImageProcessor.from_pretrained(
                "facebook/convnext-base-224-22k"
            )
            zs_classifier = pipeline(
                task="zero-shot-image-classification",
                model="openai/clip-vit-base-patch32",
                device=clf_device,
            )
            current_model_path = "facebook/convnext-base-224-22k"
            return "✅ Loaded default ConvNeXt model: facebook/convnext-base-224-22k"
    except Exception as e:
        return f"❌ Error loading model: {e!s}"
|
|
|
|
|
|
|
# Load the default classifier at startup.
load_classifier()
|
|
|
|
|
def identify_flowers(image, candidate_labels, top_k, min_score):
    """Classify flowers with the fine-tuned ConvNeXt when loaded, else CLIP zero-shot."""
    if image is None:
        return [], "Please provide an image (upload or generate first)."

    labels = candidate_labels if candidate_labels else FLOWER_LABELS

    use_custom = (
        convnext_model is not None
        and os.path.exists(current_model_path)
        and current_model_path != "facebook/convnext-base-224-22k"
    )

    results = None
    model_type = "CLIP zero-shot"
    if use_custom:
        try:
            inputs = convnext_processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = convnext_model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

            # Prefer the model's own id2label mapping; fall back to the
            # candidate-label order if the config has no usable mapping.
            id2label = getattr(convnext_model.config, "id2label", None) or {}
            results = []
            for i, score in enumerate(predictions[0]):
                if i in id2label:
                    results.append({"label": id2label[i], "score": float(score)})
                elif i < len(labels):
                    results.append({"label": labels[i], "score": float(score)})
            model_type = "ConvNeXt"
        except Exception:
            results = None  # fall through to zero-shot below

    if results is None:
        results = zs_classifier(
            image, candidate_labels=labels, hypothesis_template="a photo of a {}"
        )

    results = [r for r in results if r["score"] >= float(min_score)]
    results = sorted(results, key=lambda r: r["score"], reverse=True)[: int(top_k)]
    table = [[r["label"], round(float(r["score"]), 4)] for r in results]
    msg = f"Detected flowers using {model_type}."
    return table, msg
|
|
|
|
|
|
|
def passthrough(img):
    """Forward the generated image to the Identify tab's input."""
    return img
|
|
|
|
|
|
|
def get_available_models():
    models_dir = "training_data/trained_models"
    if not os.path.exists(models_dir):
        return ["facebook/convnext-base-224-22k (default)"]

    models = ["facebook/convnext-base-224-22k (default)"]
    for item in os.listdir(models_dir):
        model_path = os.path.join(models_dir, item)
        if os.path.isdir(model_path) and os.path.exists(
            os.path.join(model_path, "config.json")
        ):
            models.append(f"Custom: {item}")
    return models
|
|
|
|
|
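# Expected training-data layout (directory names are illustrative):
#
#   training_data/
#     images/
#       roses/    photo1.jpg, photo2.png, ...
#       tulips/   ...
#     trained_models/
#       <run_name>/   config.json plus weights written by simple_train()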
def count_training_images():
    images_dir = "training_data/images"
    if not os.path.exists(images_dir):
        return "Training directory not found"

    total_images = 0
    flower_counts = {}

    for flower_type in os.listdir(images_dir):
        flower_path = os.path.join(images_dir, flower_type)
        if os.path.isdir(flower_path):
            image_files = (
                glob.glob(os.path.join(flower_path, "*.jpg"))
                + glob.glob(os.path.join(flower_path, "*.jpeg"))
                + glob.glob(os.path.join(flower_path, "*.png"))
                + glob.glob(os.path.join(flower_path, "*.webp"))
            )
            count = len(image_files)
            if count > 0:
                flower_counts[flower_type] = count
                total_images += count

    if total_images == 0:
        return "No training images found. Add images to subdirectories in training_data/images/"

    result = f"**Total images: {total_images}**\n\n"
    for flower_type, count in sorted(flower_counts.items()):
        result += f"- {flower_type}: {count} images\n"

    return result
|
|
|
|
|
def start_training(epochs=None, batch_size=None, learning_rate=None):
    # NOTE: epochs/batch_size/learning_rate are accepted from the UI, but
    # simple_train() currently uses its own defaults; they are not forwarded.
    try:
        images_dir = "training_data/images"
        if not os.path.exists(images_dir):
            return "❌ Training directory not found. Please create training_data/images/ and add your data."

        total_images = 0
        for flower_type in os.listdir(images_dir):
            flower_path = os.path.join(images_dir, flower_type)
            if os.path.isdir(flower_path):
                image_files = (
                    glob.glob(os.path.join(flower_path, "*.jpg"))
                    + glob.glob(os.path.join(flower_path, "*.jpeg"))
                    + glob.glob(os.path.join(flower_path, "*.png"))
                    + glob.glob(os.path.join(flower_path, "*.webp"))
                )
                total_images += len(image_files)

        if total_images < 10:
            return f"❌ Need at least 10 training images. Found {total_images}. Add more images to training_data/images/"

        model_path = simple_train()

        if model_path:
            return f"✅ Training completed! Model saved to: {model_path}"
        else:
            return "❌ Training failed. Check the console for details."

    except Exception as e:
        return f"❌ Training error: {e!s}"
|
|
|
|
|
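# Dropdown entries are either "facebook/convnext-base-224-22k (default)" or
# "Custom: <model_dir>"; map them back to a loadable path.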
def load_trained_model(model_selection):
    if model_selection.startswith("Custom:"):
        model_name = model_selection.replace("Custom: ", "")
        model_path = os.path.join("training_data/trained_models", model_name)
        return load_classifier(model_path)
    else:
        return load_classifier("facebook/convnext-base-224-22k")
|
|
|
|
|
|
|
def extract_dominant_colors(image, num_colors=5):
    """Extract dominant colors from an image using k-means clustering."""
    if image is None:
        return [], []

    # Force 3 channels; RGBA or grayscale uploads would otherwise break the reshape.
    img_array = np.array(image.convert("RGB"))
    pixels = img_array.reshape(-1, 3)

    kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
    kmeans.fit(pixels)
    colors = kmeans.cluster_centers_.astype(int)

    # Map each RGB centroid to a rough color name with simple threshold heuristics.
    color_names = []
    for color in colors:
        r, g, b = color
        if r > 200 and g > 200 and b > 200:
            color_names.append("white")
        elif r < 50 and g < 50 and b < 50:
            color_names.append("black")
        elif r > 150 and g > 100 and b < 100:
            # Check orange before the generic red-dominant branch, which would
            # otherwise shadow it.
            color_names.append("orange")
        elif r > g and r > b:
            if r > 150 and g < 100:
                color_names.append("red" if g < 50 else "pink")
            else:
                color_names.append("coral")
        elif g > r and g > b:
            if b < 100:
                color_names.append("yellow" if g > 200 and r > 150 else "green")
            else:
                color_names.append("teal")
        elif b > r and b > g:
            if r < 100:
                color_names.append("blue" if b > 150 else "navy")
            else:
                color_names.append("purple" if r > g else "lavender")
        else:
            color_names.append("cream")

    return color_names, colors
|
|
|
|
|
def analyze_and_generate_french_style(image):
    """Analyze an uploaded flower image and generate a French-style arrangement."""
    if image is None:
        return None, "Please upload an image", ""

    if zs_classifier is None:
        return None, "Model not loaded", ""

    try:
        progress_log = "🔍 **Step 1/4:** Starting flower analysis...\n\n"

        progress_log += "🔍 Identifying flower type using AI model...\n"
        results = zs_classifier(
            image, candidate_labels=FLOWER_LABELS, hypothesis_template="a photo of a {}"
        )

        top_flower = results[0]["label"] if results else "flower"
        confidence = results[0]["score"] if results else 0
        progress_log += f"✅ Identified: **{top_flower}** (confidence: {confidence:.2%})\n\n"

        progress_log += "🌈 **Step 2/4:** Analyzing color palette...\n\n"
        progress_log += "🎨 Extracting dominant colors from image...\n"
        color_names, color_rgb = extract_dominant_colors(image, num_colors=3)

        main_colors = color_names[:3]
        color_desc = ", ".join(main_colors)
        progress_log += f"✅ Color palette: **{color_desc}**\n\n"

        progress_log += "📝 **Step 3/4:** Creating French-style arrangement prompt...\n\n"
        prompt = f"elegant French-style floral arrangement featuring {top_flower}s in {color_desc} colors, displayed in a clear crystal vase on a marble kitchen countertop, soft natural lighting, minimalist French country kitchen background, professional photography, sophisticated composition"
        progress_log += f"✅ Prompt created: *{prompt[:100]}...*\n\n"

        progress_log += "🚀 **Step 4/4:** Generating French-style arrangement image...\n\n"
        progress_log += "🖼️ Using AI image generation (SDXL-Turbo)...\n"
        generated_image = generate(prompt, steps=4, width=1024, height=1024, seed=-1)
        progress_log += "✅ French-style arrangement generated successfully!\n\n"

        analysis = f"""
**🌸 Flower Analysis:**
- **Type:** {top_flower} (confidence: {confidence:.2%})
- **Dominant Colors:** {color_desc}

**🇫🇷 Generated Prompt:**
"{prompt}"

---

**📋 Process Log:**
{progress_log}
"""

        return (
            generated_image,
            "✅ Analysis complete! French-style arrangement generated.",
            analysis,
        )

    except Exception as e:
        error_log = f"❌ **Error occurred during processing:**\n\n{e!s}\n\n"
        if "progress_log" in locals():
            error_log += f"**Progress before error:**\n{progress_log}"
        return None, f"❌ Error: {e!s}", error_log
|
|
|
|
|
|
|
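# Gradio UI: Generate, Identify, Train Model, and French-Style Arrangement tabs.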
with gr.Blocks() as demo:
    gr.Markdown("# 🌸 SDXL-Turbo: Text → Image + Flower Identifier")
|
|
|
    with gr.Tabs():
        with gr.TabItem("Generate"):
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(
                        value="ikebana-style flower arrangement, soft natural light, minimalist",
                        label="Prompt",
                    )
                    steps = gr.Slider(1, 8, value=4, step=1, label="Steps")
                    width = gr.Slider(512, 1536, value=1024, step=8, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=8, label="Height")
                    seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
                    go = gr.Button("Generate", variant="primary")
                out = gr.Image(label="Result", type="pil")
|
|
|
with gr.TabItem("Identify"), gr.Row(): |
|
with gr.Column(): |
|
img_in = gr.Image( |
|
label="Image (upload or auto-filled from 'Generate')", |
|
type="pil", |
|
interactive=True, |
|
) |
|
labels_box = gr.CheckboxGroup( |
|
choices=FLOWER_LABELS, |
|
value=[ |
|
"rose", |
|
"tulip", |
|
"lily", |
|
"peony", |
|
"hydrangea", |
|
"orchid", |
|
"sunflower", |
|
], |
|
label="Candidate labels (edit as needed)", |
|
) |
|
topk = gr.Slider(1, 15, value=7, step=1, label="Top-K") |
|
min_score = gr.Slider( |
|
0.0, 1.0, value=0.12, step=0.01, label="Min confidence" |
|
) |
|
detect_btn = gr.Button("Identify Flowers", variant="primary") |
|
with gr.Column(): |
|
results_tbl = gr.Dataframe( |
|
headers=["Flower", "Confidence"], |
|
datatype=["str", "number"], |
|
interactive=False, |
|
) |
|
status = gr.Markdown() |
|
|
|
with gr.TabItem("Train Model"): |
|
gr.Markdown("## π― Fine-tune the flower identification model") |
|
gr.Markdown( |
|
"Organize your training images in subdirectories by flower type in `training_data/images/`" |
|
) |
|
gr.Markdown( |
|
"Example: `training_data/images/roses/`, `training_data/images/tulips/`, etc." |
|
) |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
gr.Markdown("### Training Data") |
|
refresh_btn = gr.Button("π Refresh Data Count", size="sm") |
|
data_status = gr.Markdown() |
|
|
|
gr.Markdown("### Training Parameters") |
|
epochs = gr.Slider(1, 20, value=5, step=1, label="Training Epochs") |
|
batch_size = gr.Slider(1, 16, value=8, step=1, label="Batch Size") |
|
learning_rate = gr.Number( |
|
value=1e-5, label="Learning Rate", precision=6 |
|
) |
|
|
|
train_btn = gr.Button("π Start Training", variant="primary") |
|
|
|
with gr.Column(): |
|
gr.Markdown("### Model Management") |
|
model_dropdown = gr.Dropdown( |
|
choices=get_available_models(), |
|
value="facebook/convnext-base-224-22k (default)", |
|
label="Select Model", |
|
) |
|
refresh_models_btn = gr.Button("π Refresh Models", size="sm") |
|
load_model_btn = gr.Button( |
|
"π₯ Load Selected Model", variant="secondary" |
|
) |
|
|
|
model_status = gr.Markdown( |
|
f"**Current model:** {current_model_path}" |
|
) |
|
|
|
gr.Markdown("### Training Status") |
|
training_output = gr.Markdown() |
|
|
|
with gr.TabItem("French Style arrangement"): |
|
gr.Markdown("## π«π· French-Style Flower Arrangements") |
|
gr.Markdown( |
|
"Upload a flower image and generate an elegant French-style arrangement with matching colors!" |
|
) |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
upload_img = gr.Image(label="Upload Flower Image", type="pil") |
|
analyze_btn = gr.Button( |
|
"π¨ Analyze & Generate French Style", |
|
variant="primary", |
|
size="lg", |
|
) |
|
|
|
with gr.Column(): |
|
french_result = gr.Image( |
|
label="Generated French-Style Arrangement", type="pil" |
|
) |
|
french_status = gr.Markdown() |
|
analysis_details = gr.Markdown() |
|
|
|
|
|
    go.click(generate, [prompt, steps, width, height, seed], [out])

    # Auto-fill the Identify tab's input with each newly generated image.
    out.change(passthrough, inputs=out, outputs=img_in)

    detect_btn.click(
        identify_flowers, [img_in, labels_box, topk, min_score], [results_tbl, status]
    )

    refresh_btn.click(count_training_images, outputs=[data_status])
    refresh_models_btn.click(
        lambda: gr.Dropdown(choices=get_available_models()), outputs=[model_dropdown]
    )
    load_model_btn.click(
        load_trained_model, inputs=[model_dropdown], outputs=[model_status]
    )
    train_btn.click(
        start_training,
        inputs=[epochs, batch_size, learning_rate],
        outputs=[training_output],
    )

    def update_french_status():
        return "🔄 Processing... Please wait while we analyze your flower image...", ""

    # Show a progress note immediately, then run the full analysis.
    analyze_btn.click(
        update_french_status, outputs=[french_status, analysis_details]
    ).then(
        analyze_and_generate_french_style,
        inputs=[upload_img],
        outputs=[french_result, french_status, analysis_details],
    )

    # Populate the training-data count when the app first loads.
    demo.load(count_training_images, outputs=[data_status])


demo.queue().launch()