"""Index Pokémon card images into a persistent ChromaDB collection using CLIP.

Downloads each card image from the TheFusion21/PokemonCards dataset, embeds it
with the OpenAI-pretrained ViT-B/32 CLIP model, and stores the L2-normalized
embedding plus card metadata in a cosine-space ChromaDB collection. Indexing is
skipped entirely if the collection is already populated.
"""

import torch
from datasets import load_dataset
from PIL import Image
from io import BytesIO
import requests
import open_clip
import chromadb


def _load_clip(device):
    """Load the OpenAI-pretrained ViT-B/32 CLIP model and its preprocess transform.

    Returns (model, preprocess) with the model moved to *device* and in eval mode.
    """
    model, _, preprocess = open_clip.create_model_and_transforms(
        "ViT-B-32", pretrained="openai"
    )
    return model.to(device).eval(), preprocess


def _embed_image(model, preprocess, device, url):
    """Download the image at *url* and return its L2-normalized CLIP embedding.

    Raises requests.HTTPError on a non-2xx response and PIL/requests errors on
    bad data; callers are expected to handle failures per item.
    """
    response = requests.get(url, timeout=10)
    # Fail fast on HTTP errors (404/500 pages) instead of handing PIL an HTML body.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(image_input)
        # Normalize so cosine distance in ChromaDB behaves as expected.
        features = features / features.norm(dim=-1, keepdim=True)
    return features.squeeze().cpu().tolist()


def main():
    """Build (once) the pokemon_cards_clip collection from the dataset's image URLs."""
    # 1. Setup CLIP
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = _load_clip(device)

    # 2. Load Pokémon cards dataset
    dataset = load_dataset("TheFusion21/PokemonCards", split="train")

    # 3. Setup ChromaDB client (cosine space matches the normalized embeddings)
    client = chromadb.PersistentClient(path="./chroma_db")
    collection = client.get_or_create_collection(
        name="pokemon_cards_clip",
        metadata={"hnsw:space": "cosine"},
    )

    # 4. Insert embeddings into ChromaDB if not already indexed
    if collection.count() == 0:
        print("📦 Indexing images into ChromaDB...")
        for i, entry in enumerate(dataset):
            try:
                url = entry.get("image_url")
                if not url:
                    continue
                embedding = _embed_image(model, preprocess, device, url)
                collection.add(
                    ids=[f"card-{i}"],
                    embeddings=[embedding],
                    metadatas=[{
                        "name": entry.get("name", ""),
                        "set": entry.get("set_name", ""),
                        "hp": entry.get("hp", ""),
                        "image_url": url,
                        "caption": entry.get("caption", ""),
                    }],
                )
                if i % 100 == 0:
                    print(f"✅ {i} cartes indexées...")
            except Exception as e:
                # Best-effort indexing: one bad card/URL must not abort the run.
                print(f"❌ Erreur sur carte {i}: {e}")
        print("✅ Indexation terminée.")
    else:
        print("✅ Les embeddings sont déjà présents dans ChromaDB.")


if __name__ == "__main__":
    main()