Toonies committed on
Commit
eca1850
·
1 Parent(s): f5c0af0

update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -5
app.py CHANGED
@@ -14,16 +14,66 @@ def CLIP_model():
14
  token = CLIPTokenizerFast.from_pretrained(model_id)
15
  processor = CLIPProcessor.from_pretrained(model_id)
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- def hello_name(name):
19
- return "Hello " + name
20
 
21
  def main():
22
  CLIP_model()
23
-
24
- iface = gr.Interface(fn = hello_name, inputs = "text", outputs = "text")
 
25
  iface.launch(inline = False)
26
 
27
 
 
 
 
 
 
 
 
 
 
 
28
  if __name__ == "__main__":
29
- main()
 
 
14
  token = CLIPTokenizerFast.from_pretrained(model_id)
15
  processor = CLIPProcessor.from_pretrained(model_id)
16
 
17
def load_data():
    """Load the Imagenette training split into the module-level ``data``.

    Requests the ``full_size`` configuration of ``frgfm/imagenette`` through
    ``load_dataset`` and stores the result in the global ``data`` so the
    other functions in this module can read it.

    Returns:
        The loaded dataset (the same object bound to the global ``data``).
    """
    global data
    # NOTE(review): assuming ``load_dataset`` is Hugging Face
    # ``datasets.load_dataset``. Its ``ignore_verifications`` argument is
    # deprecated in favour of ``verification_mode``; the value previously
    # passed here (False) was the default, so omitting it keeps the same
    # behaviour while staying compatible with newer releases.
    data = load_dataset(
        'frgfm/imagenette',
        'full_size',
        split='train',
    )
    return data
25
+
26
def embedding_input(text_input):
    """Encode ``text_input`` into a text-feature tensor.

    Tokenizes the text with the module-level ``token`` object (requesting
    PyTorch tensors) and feeds the resulting fields to
    ``model.get_text_features``.

    Args:
        text_input: Text to embed, passed straight to the tokenizer.

    Returns:
        Whatever ``model.get_text_features`` returns for the tokenized input
        (callers in this file treat it as a tensor and call ``.detach()``).
    """
    encoded = token(text_input, return_tensors="pt")
    return model.get_text_features(**encoded)
30
+
31
def embedding_img():
    """Embed every dataset image and cache the results in module globals.

    Reads the ``image`` column of the global ``data``, runs it through the
    global ``processor`` and ``model`` in batches of 10, and stores:

    * ``images``  - the image column itself, and
    * ``img_arr`` - the per-batch embeddings stacked along axis 0
      (``None`` when the dataset has no images).

    Returns:
        tuple: ``(images, img_arr)``, the same objects bound to the globals.
    """
    # Local import keeps this block self-contained; torch is already required
    # because the processor is asked for ``return_tensors='pt'``.
    import torch

    global img_arr, images
    images = data['image']
    batch_size = 10

    # Collect per-batch arrays and concatenate once at the end; concatenating
    # inside the loop re-copies everything accumulated so far on every step.
    chunks = []
    for start in tqdm(range(0, len(images), batch_size)):
        pixel_values = processor(
            text=None,
            images=images[start:start + batch_size],
            return_tensors='pt',
            padding=True,
        )['pixel_values']

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            batch_emb = model.get_image_features(pixel_values=pixel_values)

        # Deliberately no ``squeeze(0)``: for a final batch holding a single
        # image it would drop the batch axis and make the result impossible
        # to stack with the other batches along axis 0.
        chunks.append(batch_emb.numpy())

    img_arr = np.concatenate(chunks, axis=0) if chunks else None
    return images, img_arr
56
+
57
 
 
 
58
 
59
  def main():
60
  CLIP_model()
61
+ load_data()
62
+ embedding_img()
63
+ iface = gr.Interface(fn = process, inputs = "text", outputs = "image")
64
  iface.launch(inline = False)
65
 
66
 
67
def process(text):
    """Return the cached image whose embedding best matches ``text``.

    Embeds ``text`` with ``embedding_input``, L2-normalizes each row of the
    global ``img_arr``, scores every image by dot product with the text
    embedding, and returns the entry of the global ``images`` with the
    highest score.

    Args:
        text: Query text to match against the cached image embeddings.

    Returns:
        The element of ``images`` at the best-scoring index.
    """
    text_emb = embedding_input(text).detach().numpy()

    # Row-wise normalization; ``keepdims`` lets broadcasting do the division
    # without transposing back and forth.
    norms = np.linalg.norm(img_arr, axis=1, keepdims=True)
    # Guard against an all-zero embedding so it scores 0 instead of NaN.
    norms[norms == 0] = 1.0
    image_emb = img_arr / norms

    # The text embedding is left un-normalized: scaling it by a positive
    # constant does not change which image ranks first.
    scores = np.dot(text_emb, image_emb.T)

    # Only the single best match is needed, so take the maximum directly
    # rather than sorting every score.
    best = int(np.argmax(scores[0]))
    return images[best]
74
+
75
+
76
+
77
  if __name__ == "__main__":
78
+ main()
79
+