import gradio as gr
import numpy as np
import requests
import torch
from io import BytesIO
from PIL import Image
from sklearn.cluster import KMeans
from transformers import CLIPProcessor, CLIPModel

# Load CLIP
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
device = "cpu"
model.to(device)
model.eval()

# Prompts fed to CLIP. The display names are currently identical to the
# prompts, so one list serves both roles.
PROMPT_CATEGORIES = [
    "Anime & Manga", "TV Shows & Movies", "Video Games",
    "Cartoons & Animated Characters", "Pop Culture & Music",
    "K-Pop & Idol Groups", "Celebrities & Influencers",
    "Floral & Botanical", "Scenery & Landscapes", "Abstract & Minimalist",
    "Cats & Dogs", "Wildlife & Exotic Animals", "Fantasy Creatures",
    "Football & Basketball", "Extreme Sports", "Fitness & Gym",
    "Motivational & Inspirational", "Funny & Meme-Based",
    "Dark & Gothic", "Cyberpunk & Sci-Fi", "Glitch & Vaporwave",
    "AI & Robotics", "Flags & National Pride", "Traditional Art",
    "Astrology & Zodiac Signs",
]
CATEGORIES = PROMPT_CATEGORIES  # plain-text names shown in the UI

# Reference palette used to name the dominant colors.
TARGET_COLORS = {
    'Red': (255, 107, 107),
    'Teal': (78, 205, 196),
    'Blue': (69, 183, 209),
    'Green': (150, 206, 180),
    'Yellow': (255, 234, 167),
    'Orange': (225, 112, 85),
    'Black': (45, 52, 54),
    'Gray': (223, 230, 233),
    'Pink': (253, 121, 168),
    'Purple': (155, 89, 182),
    'Light Purple': (200, 162, 200),
    'Dark Blue': (52, 73, 94),
    'Beige': (245, 222, 179),
    'Brown': (139, 69, 19),
    'White': (255, 255, 255),
}


def closest_palette_color(rgb):
    """Map an RGB triple to the nearest palette name by squared Euclidean distance."""
    min_diff = float('inf')
    closest_name = "Unknown"
    for name, ref_rgb in TARGET_COLORS.items():
        diff = sum((comp1 - comp2) ** 2 for comp1, comp2 in zip(rgb, ref_rgb))
        if diff < min_diff:
            min_diff = diff
            closest_name = name
    return closest_name


def get_dominant_colors(img, num_colors=3):
    """Cluster the pixels with k-means and return up to num_colors palette names."""
    img = img.resize((128, 128))
    np_pixels = np.array(img).reshape(-1, 3)
    # Drop near-white and near-black pixels so the background doesn't dominate.
    filtered = np_pixels[
        ~(np.all(np_pixels > 240, axis=1) | np.all(np_pixels < 15, axis=1))
    ]
    if len(filtered) == 0:
        return ["Unknown"]
    # Guard against images with fewer remaining pixels than clusters.
    kmeans = KMeans(n_clusters=min(5, len(filtered)), random_state=42, n_init=10)
    kmeans.fit(filtered)
    mapped_colors = [
        closest_palette_color(tuple(map(int, center)))
        for center in kmeans.cluster_centers_
    ]
    unique_colors = []
    for color in mapped_colors:
        if color not in unique_colors and color != "Unknown":
            unique_colors.append(color)
            if len(unique_colors) == num_colors:
                break
    return unique_colors or ["Unknown"]
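
# --- Optional sketch (not used by the app below): cache the prompt embeddings ---
# The category prompts never change, so their CLIP text features can be encoded
# once at startup; the classifier then only needs to encode the image per call.
# This is an assumed optimization, not part of the original flow, and
# classify_precomputed is a hypothetical helper name.
with torch.no_grad():
    _text_inputs = processor(text=PROMPT_CATEGORIES, return_tensors="pt", padding=True).to(device)
    _text_features = model.get_text_features(**_text_inputs)
    _text_features = _text_features / _text_features.norm(dim=-1, keepdim=True)


def classify_precomputed(image):
    """Score a PIL image against the cached text embeddings and return a category name."""
    with torch.no_grad():
        image_inputs = processor(images=image, return_tensors="pt").to(device)
        image_features = model.get_image_features(**image_inputs)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logits = model.logit_scale.exp() * image_features @ _text_features.T
    return CATEGORIES[logits.softmax(dim=-1).argmax().item()]
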
def classify_and_tag(image_url):
    try:
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except Exception:
        # Put the error in the first output so it shows up under "Category".
        return "Error: Could not load image.", "", "", ""

    # Classification
    inputs = processor(
        text=PROMPT_CATEGORIES, images=image, return_tensors="pt", padding=True
    ).to(device)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    top_idx = torch.argmax(probs).item()
    category = CATEGORIES[top_idx]  # ✅ plain text, not hashtag

    # Colors
    colors = get_dominant_colors(image)
    color1 = colors[0] if len(colors) > 0 else ""
    color2 = colors[1] if len(colors) > 1 else ""
    color3 = colors[2] if len(colors) > 2 else ""

    # ✅ Return category + 3 color names
    return category, color1, color2, color3


demo = gr.Interface(
    fn=classify_and_tag,
    inputs=gr.Text(label="Image URL"),
    outputs=[
        gr.Text(label="Category"),  # <== plain text
        gr.Text(label="Color 1"),
        gr.Text(label="Color 2"),
        gr.Text(label="Color 3"),
    ],
    title="🎨 ArtCase Design Analyzer",
    description="Paste an image URL to get the category and top 3 colors.",
)

if __name__ == "__main__":
    demo.launch()
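
# Quick sanity check without the UI (placeholder URL, replace with a real image link):
#   print(classify_and_tag("https://example.com/sample.jpg"))
# To expose a temporary public link (e.g., from a notebook), Gradio also supports:
#   demo.launch(share=True)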