Spaces:
Runtime error
Runtime error
import gradio as gr | |
from datasets import load_dataset | |
from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel | |
import torch | |
from tqdm.auto import tqdm | |
import numpy as np | |
import time | |
def CLIP_model():
    """Load the pretrained CLIP model, its fast tokenizer and its processor.

    The loaded objects are published as the module-level globals ``model``,
    ``token`` and ``processor`` so the embedding helpers can use them.
    """
    global model, token, processor
    checkpoint = 'openai/clip-vit-base-patch32'
    # All three components come from the same checkpoint so they stay
    # compatible with each other.
    model, token, processor = (
        CLIPModel.from_pretrained(checkpoint),
        CLIPTokenizerFast.from_pretrained(checkpoint),
        CLIPProcessor.from_pretrained(checkpoint),
    )
def load_data():
    """Load the Imagenette training split into the global ``data``.

    Uses the ``full_size`` configuration of ``frgfm/imagenette``.
    """
    global data
    # NOTE: the original call passed ``ignore_verifications=False``. That
    # keyword was deprecated in `datasets` 2.9.1 (in favour of
    # ``verification_mode``) and later removed, so passing it breaks on
    # current releases. ``False`` was the default, so omitting the argument
    # keeps the same verification behaviour on every version.
    data = load_dataset(
        'frgfm/imagenette',
        'full_size',
        split='train',
    )
def embedding_input(text_input):
    """Embed a text query with the CLIP text encoder.

    Relies on the globals ``token`` and ``model`` set up by ``CLIP_model``.

    Args:
        text_input: The query string (a list of strings is also accepted by
            the tokenizer and is padded to a common length).

    Returns:
        The tensor produced by ``model.get_text_features``, one row per
        input text. It carries no autograd history.
    """
    # CLIP's text encoder has a fixed context length; without truncation a
    # long user query overflows the position embeddings and raises.
    token_input = token(
        text_input,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        text_embedd = model.get_text_features(**token_input)
    return text_embedd
def embedding_img():
    """Compute CLIP image embeddings for every image in the dataset.

    Reads the globals ``data``, ``processor`` and ``model``; stores the
    image column in the global ``images`` and the stacked embeddings in the
    global ``img_arr``.

    Returns:
        A tuple ``(images, img_arr)``. ``img_arr`` is a 2-D NumPy array with
        one row per image, or ``None`` when the dataset has no images.
    """
    global img_arr, images
    images = data['image']
    batch_size = 10
    batch_embeddings = []
    # Inference only: without no_grad every batch keeps autograd state alive.
    with torch.no_grad():
        for start in tqdm(range(0, len(images), batch_size)):
            pixel_values = processor(
                images=images[start:start + batch_size],
                return_tensors='pt',
            )['pixel_values']
            batch_emb = model.get_image_features(pixel_values=pixel_values)
            # Deliberately no squeeze(0): a trailing batch holding a single
            # image must stay 2-D, otherwise the final concatenation fails.
            batch_embeddings.append(batch_emb.detach().numpy())
    # Concatenate once at the end instead of re-copying the growing array
    # after every batch.
    img_arr = np.concatenate(batch_embeddings, axis=0) if batch_embeddings else None
    return images, img_arr
def main():
    """Prepare the model, dataset and image embeddings, then serve the UI.

    The Gradio interface maps a text box to ``process`` and shows the
    returned image.
    """
    # Each step fills the module-level globals that the next one relies on,
    # so the order matters.
    for setup_step in (CLIP_model, load_data, embedding_img):
        setup_step()
    demo = gr.Interface(fn=process, inputs="text", outputs="image")
    demo.launch(inline=False)
def process(text):
    """Return the dataset image whose CLIP embedding best matches ``text``.

    Relies on the globals ``img_arr`` and ``images`` filled in by
    ``embedding_img``.

    Args:
        text: The text query typed by the user.

    Returns:
        The entry of ``images`` with the highest similarity score.
    """
    text_emb = embedding_input(text).detach().numpy()
    # Scale every image embedding to unit length. keepdims lets the division
    # broadcast row-wise without the transpose round trip.
    image_emb = img_arr / np.linalg.norm(img_arr, axis=1, keepdims=True)
    # The text embedding is left unnormalised: that rescales all scores by
    # the same positive factor and cannot change which image ranks first.
    scores = np.dot(text_emb, image_emb.T)
    # Only the single best match is needed, so argmax replaces a full sort.
    idx = int(np.argmax(scores[0]))
    return images[idx]
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()