# Page header captured together with the code: "Spaces: Sleeping".
import logging
from io import BytesIO

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from sklearn.cluster import KMeans
from transformers import CLIPProcessor, CLIPModel
# Load CLIP once at import time so every request reuses the same weights.
# The model and its processor must come from the same checkpoint, so the
# identifier is written a single time.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch16"
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

# Inference is pinned to the CPU.
device = "cpu"
model.to(device)
# Text prompts scored against the image by CLIP.
PROMPT_CATEGORIES = [
    "Anime & Manga", "TV Shows & Movies", "Video Games", "Cartoons & Animated Characters",
    "Pop Culture & Music", "K-Pop & Idol Groups", "Celebrities & Influencers",
    "Floral & Botanical", "Scenery & Landscapes", "Abstract & Minimalist",
    "Cats & Dogs", "Wildlife & Exotic Animals", "Fantasy Creatures",
    "Football & Basketball", "Extreme Sports", "Fitness & Gym",
    "Motivational & Inspirational", "Funny & Meme-Based", "Dark & Gothic",
    "Cyberpunk & Sci-Fi", "Glitch & Vaporwave", "AI & Robotics",
    "Flags & National Pride", "Traditional Art", "Astrology & Zodiac Signs",
]

# Display labels, looked up by the index of the winning prompt. The two lists
# must stay aligned position by position, so the labels are derived from the
# prompts instead of being maintained as a second hand-written copy.
CATEGORIES = list(PROMPT_CATEGORIES)
# Named reference palette: color name -> (R, G, B) channel values.
# Insertion order matters: when two entries are equally close to a sample,
# the lookup keeps whichever one appears first here.
TARGET_COLORS = {
    'Red': (255, 107, 107),
    'Teal': (78, 205, 196),
    'Blue': (69, 183, 209),
    'Green': (150, 206, 180),
    'Yellow': (255, 234, 167),
    'Orange': (225, 112, 85),
    'Black': (45, 52, 54),
    'Gray': (223, 230, 233),
    'Pink': (253, 121, 168),
    'Purple': (155, 89, 182),
    'Light Purple': (200, 162, 200),
    'Dark Blue': (52, 73, 94),
    'Beige': (245, 222, 179),
    'Brown': (139, 69, 19),
    'White': (255, 255, 255),
}
def closest_palette_color(rgb, palette=None):
    """Return the name of the palette color nearest to ``rgb``.

    Nearness is the squared Euclidean distance between channel values. When
    several entries are equally close, the one that appears first in the
    palette wins.

    Args:
        rgb: Sequence of channel values, e.g. ``(r, g, b)``.
        palette: Optional mapping of color name to reference channel tuple.
            Defaults to the module-level ``TARGET_COLORS``.

    Returns:
        The name of the closest palette entry, or ``"Unknown"`` when the
        palette has no entries.
    """
    if palette is None:
        palette = TARGET_COLORS

    closest_name = "Unknown"
    min_diff = float("inf")
    for name, ref_rgb in palette.items():
        diff = sum((got - ref) ** 2 for got, ref in zip(rgb, ref_rgb))
        # Strict "<" keeps the earliest entry on ties.
        if diff < min_diff:
            min_diff = diff
            closest_name = name
    return closest_name
def get_dominant_colors(img, num_colors=3):
    """Return up to ``num_colors`` palette names for the image's main colors.

    The image is downscaled to 128x128, near-white and near-black pixels are
    discarded, the remaining pixels are clustered with k-means, and each
    cluster center is mapped to its nearest named palette color. Clusters are
    visited from largest to smallest, so the first name returned belongs to
    the color covering the most pixels.

    Args:
        img: A PIL image. It is converted to RGB before sampling.
        num_colors: Maximum number of distinct color names to return.

    Returns:
        A list of distinct palette color names, most dominant first, or
        ``["Unknown"]`` when no usable pixels or names are found.
    """
    # Force three channels so the (-1, 3) reshape is valid for any image mode.
    pixels = np.asarray(img.convert("RGB").resize((128, 128))).reshape(-1, 3)

    # Drop pixels whose three channels are all above 240 or all below 15;
    # one vectorized mask replaces a Python-level loop over every pixel.
    near_white = (pixels > 240).all(axis=1)
    near_black = (pixels < 15).all(axis=1)
    filtered = pixels[~(near_white | near_black)]
    if filtered.shape[0] == 0:
        return ["Unknown"]

    # Use at least 5 clusters (more when the caller asks for more colors), but
    # never more clusters than samples: KMeans rejects that with a ValueError.
    n_clusters = min(max(5, num_colors), filtered.shape[0])
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    kmeans.fit(filtered)

    # cluster_centers_ comes back in no meaningful order; rank the centers by
    # how many pixels were assigned to each so "dominant" really is dominant.
    counts = np.bincount(kmeans.labels_, minlength=n_clusters)
    ranked_centers = kmeans.cluster_centers_[np.argsort(-counts, kind="stable")]

    unique_colors = []
    for center in ranked_centers:
        color = closest_palette_color(tuple(map(int, center)))
        if color == "Unknown" or color in unique_colors:
            continue
        unique_colors.append(color)
        if len(unique_colors) == num_colors:
            break
    return unique_colors or ["Unknown"]
def classify_and_tag(image_url):
    """Download an image and return its best category and up to three colors.

    Args:
        image_url: URL of the image to analyze.

    Returns:
        A 4-tuple ``(category, color1, color2, color3)`` of strings. Missing
        colors are empty strings. If the image cannot be downloaded or
        decoded, the tuple is ``("", "", "", "Error: Could not load image.")``.
    """
    try:
        # A timeout keeps a stalled server from blocking the worker forever,
        # and raise_for_status() rejects HTTP error pages before decoding.
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except Exception:
        # Broad on purpose: this is the UI boundary and any failure should
        # become a message, but the cause is logged instead of discarded.
        logging.getLogger(__name__).warning(
            "Could not load image from %r", image_url, exc_info=True
        )
        # NOTE(review): the message sits in the fourth slot, which the UI
        # labels "Color 3"; kept as-is so existing consumers see no change.
        return "", "", "", "Error: Could not load image."

    # Classification: score the image against every category prompt.
    inputs = processor(
        text=PROMPT_CATEGORIES, images=image, return_tensors="pt", padding=True
    ).to(device)
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    top_idx = torch.argmax(probs).item()
    category = CATEGORIES[top_idx]  # plain text, not a hashtag

    # Colors: pad with empty strings so there are always exactly three slots.
    colors = get_dominant_colors(image)
    color1, color2, color3 = (list(colors) + ["", "", ""])[:3]

    # Return category + 3 color names.
    return category, color1, color2, color3
# Gradio UI: one text input (the image URL) mapped to four plain-text outputs,
# in the same order as the tuple returned by classify_and_tag.
demo = gr.Interface(
    fn=classify_and_tag,
    inputs=gr.Text(label="Image URL"),
    outputs=[
        gr.Text(label="Category"),  # plain text
        gr.Text(label="Color 1"),
        gr.Text(label="Color 2"),
        gr.Text(label="Color 3"),
    ],
    # The leading emoji was mis-decoded ("π¨") in the captured text; restored.
    title="🎨 ArtCase Design Analyzer",
    description="Paste an image URL to get category and top 3 colors."
)
demo.launch()