import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast


def CLIP_model():
    """Load the pretrained CLIP model, tokenizer, and image processor."""
    global model, token, processor
    model_id = 'openai/clip-vit-base-patch32'
    model = CLIPModel.from_pretrained(model_id)
    token = CLIPTokenizerFast.from_pretrained(model_id)
    processor = CLIPProcessor.from_pretrained(model_id)


def load_data():
    """Load the Imagenette training split (full-size images)."""
    global data
    data = load_dataset(
        'frgfm/imagenette',
        'full_size',
        split='train'
    )


def embedding_input(text_input):
    """Encode a text query into a CLIP text embedding."""
    token_input = token(text_input, return_tensors="pt")
    with torch.no_grad():
        text_embedd = model.get_text_features(**token_input)
    return text_embedd


def embedding_img():
    """Encode every dataset image into a CLIP image embedding, in batches."""
    global img_arr, images
    images = data['image']
    batch_size = 10
    img_arr = None
    for i in tqdm(range(0, len(images), batch_size)):
        batch = images[i:i + batch_size]
        # Preprocess the PIL images into the pixel tensors CLIP expects.
        batch = processor(
            text=None,
            images=batch,
            return_tensors='pt',
            padding=True
        )['pixel_values']
        with torch.no_grad():
            batch_emb = model.get_image_features(pixel_values=batch)
        batch_emb = batch_emb.detach().numpy()
        # Accumulate the embeddings into a single (num_images, dim) array.
        if img_arr is None:
            img_arr = batch_emb
        else:
            img_arr = np.concatenate((img_arr, batch_emb), axis=0)
    return images, img_arr


def main():
    CLIP_model()
    load_data()
    embedding_img()
    iface = gr.Interface(fn=process, inputs="text", outputs="image")
    iface.launch(inline=False)


def process(text):
    """Return the dataset image whose embedding best matches the text query."""
    text_input = embedding_input(text)
    # L2-normalise the image embeddings so the dot product ranks by cosine similarity.
    image_emb = (img_arr.T / np.linalg.norm(img_arr, axis=1)).T
    text_emb = text_input.detach().numpy()
    scores = np.dot(text_emb, image_emb.T)
    idx = np.argmax(scores[0])
    return images[idx]


if __name__ == "__main__":
    main()