File size: 2,267 Bytes
5644b77
d9e380c
 
 
 
 
6ae99b6
d9e380c
 
 
 
 
 
 
 
5644b77
d9e380c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dc1f91
 
 
 
 
 
 
 
 
6ae99b6
4dc1f91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ae99b6
d9e380c
 
 
4dc1f91
d9e380c
4dc1f91
d9e380c
 
 
 
 
 
 
 
 
4dc1f91
33a4812
d9e380c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
from datasets import load_dataset
from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel
import torch
from tqdm.auto import tqdm
import numpy as np
import time

# Inference device. The CUDA auto-detection is left disabled in the trailing
# comment, so the model and all tensors below stay on the CPU.
device = 'cpu' # 'cuda' if torch.cuda.is_available() else "cpu"
# Checkpoint name passed to every `from_pretrained` call below so that the
# model, tokenizer, and processor all come from the same CLIP release.
model_id = 'openai/clip-vit-base-patch32'
# These three objects are created once at import time and are read as
# module-level globals by the embedding functions further down.
model = CLIPModel.from_pretrained(model_id).to(device)
tokenizer = CLIPTokenizerFast.from_pretrained(model_id)
processor = CLIPProcessor.from_pretrained(model_id)



def load_data():
    """Load the Imagenette training split and publish it as a module global.

    The dataset is stored in the global ``imagenette`` (overwriting any
    previous value) and also returned to the caller.

    Returns:
        The object returned by ``load_dataset`` for the ``train`` split of
        the ``full_size`` configuration of ``frgfm/imagenette``.
    """
    global imagenette
    # NOTE(review): newer `datasets` releases appear to replace
    # `ignore_verifications` with `verification_mode` -- confirm against the
    # installed version before upgrading.
    dataset_options = {
        'split': 'train',
        # set to True if seeing splits Error
        'ignore_verifications': False,
    }
    imagenette = load_dataset('frgfm/imagenette', 'full_size', **dataset_options)
    return imagenette

def embedding_input(text_input):
    """Encode a text query into CLIP text features.

    Args:
        text_input: Query text handed to the module-level ``tokenizer``.

    Returns:
        The tensor produced by ``model.get_text_features`` for the tokenized
        query, computed without autograd tracking.
    """
    token_input = tokenizer(
        text_input,
        return_tensors="pt",
        padding=True,
        # The query comes from a free-form text box; without truncation an
        # over-long prompt exceeds the text encoder's maximum sequence length
        # and the forward pass fails.
        truncation=True,
    )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        text_emb = model.get_text_features(**token_input.to(device))
    return text_emb

def embedding_img(num_samples=100, batch_size=5):
    """Sample random dataset images and compute their CLIP image features.

    Side effects: sets the module globals ``images`` (the sampled images) and
    ``image_arr`` (their stacked features), which ``norm_val`` reads.

    Args:
        num_samples: Number of images to draw (with replacement) from the
            dataset returned by ``load_data``.
        batch_size: Number of images pushed through the model per forward pass.

    Returns:
        NumPy array with one row of image features per sampled image, in the
        same order as ``images``; ``None`` when no images were sampled.
    """
    global images, image_arr
    dataset = load_data()

    # `high` is exclusive in np.random.randint, so len(dataset) is the correct
    # upper bound; len(dataset) + 1 could yield an out-of-range index.
    sample_idx = np.random.randint(0, len(dataset), num_samples).tolist()
    images = [dataset[i]['image'] for i in sample_idx]

    batch_embeddings = []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for start in tqdm(range(0, len(images), batch_size)):
            batch = images[start:start + batch_size]
            pixel_values = processor(
                text=None,
                images=batch,
                return_tensors='pt',
                padding=True,
            )['pixel_values'].to(device)
            batch_emb = model.get_image_features(pixel_values=pixel_values)
            # Keep the batch axis even for a single-image batch so every
            # chunk stays 2-D and can be concatenated row-wise.
            batch_embeddings.append(batch_emb.cpu().numpy())

    # Concatenate once at the end instead of re-copying the array per batch.
    image_arr = np.concatenate(batch_embeddings, axis=0) if batch_embeddings else None
    return image_arr

def norm_val(text_input):
    """Return the sampled image whose features best match a text query.

    Reads the module globals ``image_arr`` and ``images`` populated by
    ``embedding_img``.

    Args:
        text_input: Query text forwarded to ``embedding_input``.

    Returns:
        The entry of ``images`` with the highest dot-product score between
        the query embedding and the L2-normalized image features.
    """
    query_vec = embedding_input(text_input).cpu().detach().numpy()

    # Scale every image feature row to unit length before scoring.
    row_norms = np.linalg.norm(image_arr, axis=1)
    unit_images = (image_arr.T / row_norms).T

    similarity = np.dot(query_vec, unit_images.T)

    # Rank by descending score for the (single) query row and keep the best.
    ranking = np.argsort(-similarity[0])
    best_match = ranking[:1][0]
    return images[best_match]
        




if __name__ == "__main__":
    # embedding_img() loads the dataset itself and fills the `images` /
    # `image_arr` globals that norm_val reads, so a second load_data() call
    # here would only repeat the load without using its result.
    embedding_img()
    # Text box in, best-matching image out.
    iface = gr.Interface(fn=norm_val, inputs="text", outputs="image")
    iface.launch(inline=False)