import gradio as gr
from datasets import load_dataset
from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel
import torch
from tqdm.auto import tqdm
import numpy as np
import time
device = 'cpu'  # switch to 'cuda' if torch.cuda.is_available() for GPU inference
model_id = 'openai/clip-vit-base-patch32'
model = CLIPModel.from_pretrained(model_id).to(device)
tokenizer = CLIPTokenizerFast.from_pretrained(model_id)
processor = CLIPProcessor.from_pretrained(model_id)
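
# Note: CLIPProcessor wraps both the tokenizer and the image processor, so a
# single call such as
#   processor(text=["a photo of a dog"], images=img, return_tensors="pt")
# could encode both modalities at once; the separate tokenizer below is kept
# for clarity.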

def load_data():
    """Load the Imagenette training split into the global `imagenette`."""
    global imagenette
    imagenette = load_dataset(
        'frgfm/imagenette',
        'full_size',
        split='train',
        ignore_verifications=False  # set to True if a split verification error occurs
    )
    return imagenette

def embedding_input(text_input):
    """Encode a text query into a CLIP text embedding."""
    token_input = tokenizer(text_input, return_tensors="pt")
    text_emb = model.get_text_features(**token_input.to(device))
    return text_emb
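
# Usage sketch (the query string is a hypothetical example):
#   emb = embedding_input("a photo of a dog")
#   emb.shape  # torch.Size([1, 512]) for the ViT-B/32 checkpoint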

def embedding_img():
    """Sample 100 Imagenette images and embed them with CLIP in batches."""
    global images, image_arr
    load_data()
    # randint's upper bound is exclusive, so use len(imagenette) to stay in range
    sample_idx = np.random.randint(0, len(imagenette), 100).tolist()
    images = [imagenette[i]['image'] for i in sample_idx]
    batch_size = 5
    image_arr = None
    for i in tqdm(range(0, len(images), batch_size)):
        time.sleep(1)  # throttle each batch so the progress bar stays readable
        batch = images[i:i + batch_size]
        batch = processor(
            text=None,
            images=batch,
            return_tensors='pt',
            padding=True
        )['pixel_values'].to(device)
        # get_image_features already returns (batch_size, 512); no squeeze needed
        batch_emb = model.get_image_features(pixel_values=batch)
        batch_emb = batch_emb.cpu().detach().numpy()
        if image_arr is None:
            image_arr = batch_emb
        else:
            image_arr = np.concatenate((image_arr, batch_emb), axis=0)
    return image_arr
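
# With 100 sampled images and this checkpoint, image_arr has shape (100, 512):
# one 512-dimensional embedding per image.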

def norm_val(text_input):
    """Return the sampled image whose embedding best matches the text query."""
    text_emb = embedding_input(text_input)
    # L2-normalize each image row so the dot product ranks by cosine similarity
    image_emb = (image_arr.T / np.linalg.norm(image_arr, axis=1)).T
    text_emb = text_emb.cpu().detach().numpy()
    scores = np.dot(text_emb, image_emb.T)
    top_k = 1
    idx = np.argsort(-scores[0])[:top_k]
    return images[idx[0]]
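
# Usage sketch (hypothetical query; embedding_img() must have run first):
#   best = norm_val("a golf ball on green grass")  # returns a PIL image
# Leaving the text embedding un-normalized only scales every score by the same
# constant, so the ranking is still by cosine similarity.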

if __name__ == "__main__":
    embedding_img()  # also calls load_data(), so no separate call is needed
    iface = gr.Interface(fn=norm_val, inputs="text", outputs="image")
    iface.launch(inline=False)