import gradio as gr
from datasets import load_dataset
from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel
import torch
from tqdm.auto import tqdm
import numpy as np
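# Text-to-image search demo: CLIP embeds the Imagenette training images once at
# startup, then each Gradio query embeds the input text and returns the image
# whose embedding scores highest against it.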


def CLIP_model():
    """Load the pretrained CLIP model, tokenizer, and processor."""
    global model, tokenizer, processor
    model_id = 'openai/clip-vit-base-patch32'
    model = CLIPModel.from_pretrained(model_id)
    tokenizer = CLIPTokenizerFast.from_pretrained(model_id)
    processor = CLIPProcessor.from_pretrained(model_id)

def load_data():
    """Load the Imagenette training split from the Hugging Face Hub."""
    global data
    data = load_dataset(
        'frgfm/imagenette',
        'full_size',
        split='train'
    )

def embedding_input(text_input):
    """Embed a text query with CLIP's text encoder."""
    token_input = tokenizer(text_input, return_tensors="pt")
    with torch.no_grad():
        text_embedd = model.get_text_features(**token_input)
    return text_embedd

def embedding_img():
    """Embed every dataset image with CLIP's vision encoder, in batches."""
    global img_arr, images
    images = data['image']
    batch_size = 10
    img_arr = None
    for i in tqdm(range(0, len(images), batch_size)):
        batch = images[i:i + batch_size]

        # Preprocess the PIL images into pixel tensors.
        batch = processor(
            text=None,
            images=batch,
            return_tensors='pt',
            padding=True
        )['pixel_values']

        with torch.no_grad():
            batch_emb = model.get_image_features(pixel_values=batch)
        batch_emb = batch_emb.numpy()

        # Stack the batch embeddings into one (num_images, dim) array.
        # Note: the original squeeze(0) would drop the batch dimension for a
        # single-image batch and break the concatenate below, so it is removed.
        if img_arr is None:
            img_arr = batch_emb
        else:
            img_arr = np.concatenate((img_arr, batch_emb), axis=0)
    return images, img_arr



def main():
    CLIP_model()
    load_data()
    embedding_img()
    iface = gr.Interface(fn=process, inputs="text", outputs="image")
    iface.launch(inline=False)


def process(text):
    """Return the dataset image that best matches the text query."""
    text_input = embedding_input(text)
    # L2-normalise the image embeddings; the text norm is a positive constant,
    # so the dot product ranks images by cosine similarity.
    image_emb = (img_arr.T / np.linalg.norm(img_arr, axis=1)).T
    text_emb = text_input.detach().numpy()
    scores = np.dot(text_emb, image_emb.T)
    idx = int(np.argmax(scores[0]))
    return images[idx]



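# Running this file launches the Gradio app; embedding the full Imagenette
# training split happens once at startup and can take several minutes on CPU.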
if __name__ == "__main__":
    main()