Toonies committed on
Commit d9e380c · 1 Parent(s): 5644b77

update temp

Files changed (2)
  1. app.py +71 -2
  2. requirements.txt +7 -0
app.py CHANGED
@@ -1,7 +1,76 @@
import gradio as gr
+ from datasets import load_dataset
+ from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel
+ import torch
+ from tqdm.auto import tqdm
+ import numpy as np
+
+ device = 'cpu'  # 'cuda' if torch.cuda.is_available() else 'cpu'
+ model_id = 'openai/clip-vit-base-patch32'
+ model = CLIPModel.from_pretrained(model_id).to(device)
+ tokenizer = CLIPTokenizerFast.from_pretrained(model_id)
+ processor = CLIPProcessor.from_pretrained(model_id)
+
+

def greet(name):
    return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch(inline = False )
+ def load_data():
+     # Load the Imagenette training split and cache it in a module-level global.
+     global imagenette
+     imagenette = load_dataset(
+         'frgfm/imagenette',
+         'full_size',
+         split='train',
+         ignore_verifications=False  # set to True if a splits verification error occurs
+     )
+     return imagenette
+
+ def embedding_input(text_input):
+     # Encode the text query into a CLIP text embedding.
+     token_input = tokenizer(text_input, return_tensors="pt")
+     text_emb = model.get_text_features(**token_input.to(device))
+     return text_emb
+
+ def embedding_img():
+     # Embed a random sample of 100 dataset images in batches of 5.
+     global images
+     load_data()
+     sample_idx = np.random.randint(0, len(imagenette), 100).tolist()  # upper bound is exclusive
+     images = [imagenette[i]['image'] for i in sample_idx]
+     batch_size = 5
+     image_arr = None
+     for i in tqdm(range(0, len(images), batch_size)):
+         batch = images[i:i+batch_size]
+
+         batch = processor(
+             text=None,
+             images=batch,
+             return_tensors='pt',
+             padding=True
+         )['pixel_values'].to(device)
+         batch_emb = model.get_image_features(pixel_values=batch)
+         batch_emb = batch_emb.cpu().detach().numpy()
+         if image_arr is None:
+             image_arr = batch_emb
+         else:
+             image_arr = np.concatenate((image_arr, batch_emb), axis=0)
+     return image_arr
+
+ def norm_val(text_input):
+     # Score every sampled image against the query and return the best match.
+     image_arr = embedding_img()
+     text_emb = embedding_input(text_input)
+
+     image_arr = (image_arr.T / np.linalg.norm(image_arr, axis=1)).T  # L2-normalize rows
+     text_emb = text_emb.cpu().detach().numpy()
+     scores = np.dot(text_emb, image_arr.T)
+     top_k = 1
+     idx = np.argsort(-scores[0])[:top_k]
+     return images[idx[0]]
+
+
+ if __name__ == "__main__":
+     iface = gr.Interface(fn=norm_val, inputs="text", outputs="image")
+     iface.launch(inline=False)
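
For context on what the new entry point computes: norm_val L2-normalizes the image embeddings and ranks them by dot product with the text embedding, which gives the same ordering over images as cosine similarity. A minimal sketch of that ranking step with dummy data (the 512-dimensional embedding size is an assumption based on CLIP ViT-B/32; the names here are illustrative, not part of the commit):

import numpy as np

# Stand-ins for the real CLIP outputs: 100 image embeddings and one text embedding.
image_arr = np.random.randn(100, 512)  # assumed 512-d, as produced by CLIP ViT-B/32
text_emb = np.random.randn(1, 512)

image_arr = (image_arr.T / np.linalg.norm(image_arr, axis=1)).T  # unit-norm rows
scores = np.dot(text_emb, image_arr.T)  # shape (1, 100): one score per image
best = np.argsort(-scores[0])[0]  # index of the highest-scoring image

Because only the ordering matters for a top-1 result, leaving the text embedding unnormalized does not change which image is returned.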
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ gradio
+ numpy
+ pandas
+ datasets
+ tqdm
+ transformers
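
These dependencies are unpinned, so each build resolves the latest compatible versions; on a Hugging Face Space they are installed automatically from this file, and a local pip install -r requirements.txt reproduces the same set.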