import gradio as gr
from datasets import load_dataset
from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel
import torch
from tqdm.auto import tqdm
import numpy as np
import time

# Run on CPU by default; point `device` at 'cuda' if a GPU is available.
device = 'cpu'
model_id = 'openai/clip-vit-base-patch32'

model = CLIPModel.from_pretrained(model_id).to(device)
tokenizer = CLIPTokenizerFast.from_pretrained(model_id)
processor = CLIPProcessor.from_pretrained(model_id)
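
# Pipeline: load_data() downloads Imagenette, embedding_img() embeds every image with
# CLIP, embedding_input() embeds the text query, and norm_val() ranks the images by
# similarity and returns the best match for Gradio to display.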


def load_data():
    # Load the Imagenette training split once and keep it as a module-level global.
    global imagenette
    imagenette = load_dataset(
        'frgfm/imagenette',
        'full_size',
        split='train',
        ignore_verifications=False
    )
    return imagenette


def embedding_input(text_input):
    # Tokenize the query and project it into CLIP's shared text/image embedding space.
    token_input = tokenizer(text_input, return_tensors="pt")
    text_emb = model.get_text_features(**token_input.to(device))
    return text_emb


def embedding_img():
    global images
    img_batch = imagenette['image']

    # Preprocess every image into the pixel_values tensor that CLIP expects.
    images = processor(
        text=None,
        images=img_batch,
        return_tensors='pt'
    )['pixel_values'].to(device)

    # Embed the preprocessed tensor (not the raw PIL images).
    batch_emb = model.get_image_features(pixel_values=images)
    batch_emb = batch_emb.squeeze(0)
    image_arr = batch_emb.cpu().detach().numpy()

    return image_arr


def norm_val(text_input):
    image_arr = embedding_img()
    time.sleep(5)
    text_emb = embedding_input(text_input)

    # L2-normalise the image embeddings so the dot product behaves like cosine similarity.
    image_arr = (image_arr.T / np.linalg.norm(image_arr, axis=1)).T
    text_emb = text_emb.cpu().detach().numpy()
    scores = np.dot(text_emb, image_arr.T)
    top_k = 1
    idx = np.argsort(-scores[0])[:top_k]
    # Return the original PIL image of the best match so Gradio can display it.
    return imagenette[int(idx[0])]['image']
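
# Note: norm_val() re-embeds the entire image set on every query. That keeps the demo
# self-contained, but for repeated searches the image embeddings would normally be
# computed once after load_data() and reused across queries.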


if __name__ == "__main__":
    load_data()
    iface = gr.Interface(fn=norm_val, inputs="text", outputs="image")
    iface.launch(inline=False)
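
# Example (illustrative only): exercising the search without the Gradio UI.
# "golf ball" is one of the ten Imagenette classes, so it makes a reasonable query.
#
#   load_data()
#   best_match = norm_val("a photo of a golf ball")   # returns a PIL image
#   best_match.save("best_match.png")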