# artcase22/app.py (Hugging Face Space source)
# Last commit: 23684c1 "Update app.py" by abdelrahmanASDLF (verified).
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
from io import BytesIO
import requests
import numpy as np
from sklearn.cluster import KMeans
# Load the CLIP checkpoint and its matching processor once, at import time,
# so every request reuses the same weights. Inference runs on the CPU.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch16"
device = "cpu"
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
model.to(device)
# Text prompts scored against the image by CLIP, one per category.
# classify_and_tag takes the argmax over these prompts and reports
# CATEGORIES[argmax], so the two lists must stay index-aligned.
PROMPT_CATEGORIES = [
    "Anime & Manga", "TV Shows & Movies", "Video Games", "Cartoons & Animated Characters",
    "Pop Culture & Music", "K-Pop & Idol Groups", "Celebrities & Influencers",
    "Floral & Botanical", "Scenery & Landscapes", "Abstract & Minimalist",
    "Cats & Dogs", "Wildlife & Exotic Animals", "Fantasy Creatures",
    "Football & Basketball", "Extreme Sports", "Fitness & Gym",
    "Motivational & Inspirational", "Funny & Meme-Based", "Dark & Gothic",
    "Cyberpunk & Sci-Fi", "Glitch & Vaporwave", "AI & Robotics",
    "Flags & National Pride", "Traditional Art", "Astrology & Zodiac Signs",
]

# Human-readable labels shown to the user. They are currently identical to
# the prompts, so derive them instead of maintaining a second hand-copied list
# that could silently drift out of alignment. If the prompts ever become
# richer templates (e.g. "a photo of ..."), replace this with an explicit
# list of the same length and order.
CATEGORIES = list(PROMPT_CATEGORIES)
# Reference palette: color name -> (R, G, B) in 0-255.
# Insertion order matters: when two entries are equally close to a color,
# the nearest-color search keeps the one listed first.
TARGET_COLORS = {
    "Red":          (255, 107, 107),
    "Teal":         (78, 205, 196),
    "Blue":         (69, 183, 209),
    "Green":        (150, 206, 180),
    "Yellow":       (255, 234, 167),
    "Orange":       (225, 112, 85),
    "Black":        (45, 52, 54),
    "Gray":         (223, 230, 233),
    "Pink":         (253, 121, 168),
    "Purple":       (155, 89, 182),
    "Light Purple": (200, 162, 200),
    "Dark Blue":    (52, 73, 94),
    "Beige":        (245, 222, 179),
    "Brown":        (139, 69, 19),
    "White":        (255, 255, 255),
}
def closest_palette_color(rgb):
    """Name the TARGET_COLORS entry closest to ``rgb``.

    Closeness is squared Euclidean distance over paired channels. On a tie,
    the entry that appears first in TARGET_COLORS is kept. If no entry is
    strictly closer than infinity (e.g. an empty palette), "Unknown" is
    returned.
    """
    best = ("Unknown", float('inf'))
    for candidate, reference in TARGET_COLORS.items():
        distance = sum((a - b) ** 2 for a, b in zip(rgb, reference))
        if distance < best[1]:
            best = (candidate, distance)
    return best[0]
def get_dominant_colors(img, num_colors=3, n_clusters=5):
    """Return up to ``num_colors`` palette names for the most prevalent colors in ``img``.

    Steps:
    1. Downscale the image to 128x128.
    2. Discard near-white (every channel > 240) and near-black (every
       channel < 15) pixels.
    3. Cluster the remaining pixels with k-means.
    4. Visit the clusters from largest to smallest, map each centre to its
       nearest TARGET_COLORS name, and skip names already collected.

    Args:
        img: A PIL image. It is converted to RGB first, so RGBA/L/P inputs
            are handled.
        num_colors: Maximum number of distinct color names to return.
        n_clusters: Number of k-means clusters. It is capped at the number of
            surviving pixels, because KMeans raises when asked for more
            clusters than samples.

    Returns:
        A non-empty list of color names ordered by cluster size (most
        dominant first). Returns ``["Unknown"]`` if every pixel was filtered
        out.
    """
    # Convert before reshaping so each row is exactly one (R, G, B) triple.
    pixels = np.asarray(img.convert("RGB").resize((128, 128))).reshape(-1, 3)

    # Vectorized background filter (replaces a slow per-pixel Python loop).
    near_white = np.all(pixels > 240, axis=1)
    near_black = np.all(pixels < 15, axis=1)
    filtered = pixels[~(near_white | near_black)]
    if len(filtered) == 0:
        return ["Unknown"]

    k = min(n_clusters, len(filtered))
    kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
    kmeans.fit(filtered)

    # cluster_centers_ come back in arbitrary order; rank the clusters by how
    # many pixels they own so the first names really are the dominant colors.
    # A stable sort keeps tie-breaking deterministic.
    counts = np.bincount(kmeans.labels_, minlength=k)
    order = np.argsort(-counts, kind="stable")

    unique_colors = []
    for idx in order:
        center = kmeans.cluster_centers_[idx]
        name = closest_palette_color(tuple(map(int, center)))
        if name != "Unknown" and name not in unique_colors:
            unique_colors.append(name)
            if len(unique_colors) == num_colors:
                break
    return unique_colors or ["Unknown"]
def classify_and_tag(image_url):
    """Download an image, classify it with CLIP, and name its dominant colors.

    Args:
        image_url: URL of the image to analyse.

    Returns:
        A 4-tuple of strings ``(category, color1, color2, color3)``, one per
        Gradio output box. Colors that could not be determined are ``""``.
        If the image cannot be downloaded or decoded, the error message is
        placed in the category slot and the three color slots are empty.
    """
    try:
        # Without a timeout, a stalled server would block this worker forever.
        response = requests.get(image_url, timeout=15)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except Exception:
        # UI boundary: report any fetch/decode failure instead of crashing.
        # The message goes to the first output box ("Category"), not "Color 3".
        return "Error: Could not load image.", "", "", ""

    # Zero-shot classification: score the image against every category prompt.
    inputs = processor(text=PROMPT_CATEGORIES, images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():  # inference only; skip building an autograd graph
        outputs = model(**inputs)
    # logits_per_image is (num_images, num_prompts); there is a single image here.
    probs = outputs.logits_per_image.softmax(dim=1)
    top_idx = int(probs[0].argmax())
    category = CATEGORIES[top_idx]  # plain-text label, not a hashtag

    # Pad to exactly three entries so every color output box receives a string.
    colors = get_dominant_colors(image)
    color1, color2, color3 = (list(colors) + ["", "", ""])[:3]
    return category, color1, color2, color3
# Gradio UI: one URL input and four text outputs, matching the
# (category, color1, color2, color3) tuple returned by classify_and_tag.
demo = gr.Interface(
    fn=classify_and_tag,
    inputs=gr.Text(label="Image URL"),
    outputs=[
        gr.Text(label="Category"),  # plain-text category label
        gr.Text(label="Color 1"),
        gr.Text(label="Color 2"),
        gr.Text(label="Color 3"),
    ],
    title="🎨 ArtCase Design Analyzer",
    description="Paste an image URL to get category and top 3 colors.",
)

# Start the server only when run as a script (`python app.py`), so that
# importing this module (tests, tooling) does not launch a web server.
if __name__ == "__main__":
    demo.launch()