# CLIP-Docker / app.py: text-to-image search over Imagenette with CLIP, served with Gradio.
import gradio as gr
from datasets import load_dataset
from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel
import torch
from tqdm.auto import tqdm
import numpy as np
import time

def CLIP_model():
    """Load the CLIP model, tokenizer, and processor once and share them globally."""
    global model, token, processor
    model_id = 'openai/clip-vit-base-patch32'
    model = CLIPModel.from_pretrained(model_id)
    token = CLIPTokenizerFast.from_pretrained(model_id)
    processor = CLIPProcessor.from_pretrained(model_id)
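
# Optional sketch (assumption, not part of the original app): if a GPU is available,
# the model can be moved there to speed up embedding. main() below keeps everything
# on the CPU, so this helper is illustrative only and is never called.
def to_device(clip_model):
    # Standard PyTorch device selection; falls back to CPU when CUDA is absent.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return clip_model.to(device)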

def load_data():
    """Download the Imagenette training split used as the image search corpus."""
    global data
    data = load_dataset(
        'frgfm/imagenette',
        'full_size',
        split='train',
        ignore_verifications=False
    )

def embedding_input(text_input):
    """Embed a text query with CLIP and return its text-feature tensor."""
    token_input = token(text_input, return_tensors="pt")
    with torch.no_grad():
        text_embedd = model.get_text_features(**token_input)
    return text_embedd
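
# Minimal sketch (assumption): CLIP retrieval is usually scored with cosine
# similarity, i.e. dot products of L2-normalised vectors. process() below
# normalises only the image embeddings, which preserves the ranking for a single
# query; this hypothetical helper shows the fully normalised variant.
def normalize_rows(arr):
    # Divide each row by its Euclidean norm so dot products equal cosine similarity.
    return arr / np.linalg.norm(arr, axis=-1, keepdims=True)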

def embedding_img():
    """Embed every dataset image in batches and stack the results into one array."""
    global img_arr, images
    images = data['image']
    batch_size = 10
    img_arr = None
    for i in tqdm(range(0, len(images), batch_size)):
        batch = images[i:i + batch_size]
        # Preprocess the PIL images into pixel tensors for CLIP.
        batch = processor(
            text=None,
            images=batch,
            return_tensors='pt',
            padding=True
        )['pixel_values']
        with torch.no_grad():
            batch_emb = model.get_image_features(pixel_values=batch)
        # Keep the (batch, dim) shape; squeezing dimension 0 here would break the
        # concatenation whenever a batch contains a single image.
        batch_emb = batch_emb.detach().numpy()
        if img_arr is None:
            img_arr = batch_emb
        else:
            img_arr = np.concatenate((img_arr, batch_emb), axis=0)
    return images, img_arr
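
# Hedged sketch (assumption, not used by main()): re-embedding all of Imagenette on
# every start-up is slow, so the embedding array could be cached on disk with NumPy.
# The file name below is hypothetical; load_data() must have run first.
def embedding_img_cached(path='imagenette_clip_embeddings.npy'):
    import os
    global img_arr, images
    images = data['image']
    if os.path.exists(path):
        img_arr = np.load(path)
    else:
        embedding_img()
        np.save(path, img_arr)
    return images, img_arr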

def main():
    """Build the embedding index and launch the Gradio text-to-image search demo."""
    CLIP_model()
    load_data()
    embedding_img()
    iface = gr.Interface(fn=process, inputs="text", outputs="image")
    iface.launch(inline=False)

def process(text):
    """Return the dataset image whose CLIP embedding best matches the text query."""
    text_input = embedding_input(text)
    # L2-normalise the image embeddings so the dot product ranks by cosine similarity;
    # the text embedding is left unnormalised, which does not change the ranking.
    image_emb = (img_arr.T / np.linalg.norm(img_arr, axis=1)).T
    text_emb = text_input.detach().numpy()
    scores = np.dot(text_emb, image_emb.T)
    # Index of the highest-scoring image.
    idx = np.argsort(-scores[0])[0]
    return images[idx]
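
# Quick local sanity check (assumption, not part of the original app): bypasses the
# Gradio UI and exercises the full query path once the index has been built.
def sanity_check(query='a photo of a dog'):
    best = process(query)
    print(f'Best match for {query!r}: {best}')  # prints the PIL image repr (size, mode)
    return best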

if __name__ == "__main__":
    main()