Vivien committed on
Commit 5185219
1 Parent(s): dab4ce7

Add possibility to compose queries and use images as queries

Files changed (2)
  1. app.py +75 -27
  2. requirements.txt +1 -0
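Before the diff, a brief note on what the commit message means in practice. In the reworked `image_search()` below, `/` separates positive terms from negative ones, `;` separates multiple terms on either side, and a token like `[Unsplash:1234]` (which clicking a result now inserts into the search box) makes an indexed corpus image part of the query. The sketch below is a hypothetical standalone helper that mirrors that parsing; `parse_query` and its return format are illustrative names, not part of the commit:

```python
import re

# Hypothetical helper (not in the commit) mirroring how the new
# image_search() splits a composed query string.
def parse_query(query):
    # "/" separates positive terms (left) from negative terms (right);
    # ";" separates multiple terms on either side.
    parts = query.split("/")
    positives, negatives = parts[0].split(";"), []
    if len(parts) > 1:
        negatives = (" ".join(parts[1:])).split(";")
    parsed = []
    for p in positives:
        # "[Unsplash:123]" or "[Movies:42]" references an indexed corpus image,
        # optionally followed by extra text used as an additional positive term.
        match = re.match(r"\[(Movies|Unsplash):(\d{1,5})\](.*)", p)
        if match:
            corpus, idx, remainder = match.groups()
            parsed.append(("image", corpus, int(idx), remainder.strip()))
        else:
            parsed.append(("text", p.strip()))
    return parsed, [n.strip() for n in negatives]

print(parse_query("[Unsplash:1234] at night; beach / people"))
# => ([('image', 'Unsplash', 1234, 'at night'), ('text', 'beach')], ['people'])
```

So a query such as `[Unsplash:1234] at night; beach / people` asks for images that score well against image 1234, the text "at night", and the text "beach", while penalizing similarity to "people".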
app.py CHANGED
@@ -1,8 +1,9 @@
+from html import escape
+import re
 import streamlit as st
 import pandas as pd, numpy as np
-from html import escape
-import os
 from transformers import CLIPProcessor, CLIPModel
+from st_clickable_images import clickable_images
 
 
 @st.cache(
@@ -19,47 +20,72 @@ def load():
     df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
     embeddings = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}
     for k in [0, 1]:
-        embeddings[k] = np.divide(
-            embeddings[k], np.sqrt(np.sum(embeddings[k] ** 2, axis=1, keepdims=True))
+        embeddings[k] = embeddings[k] - np.mean(embeddings[k], axis=0)
+        embeddings[k] = embeddings[k] / np.linalg.norm(
+            embeddings[k], axis=1, keepdims=True
         )
     return model, processor, df, embeddings
 
 
 model, processor, df, embeddings = load()
-
 source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
 
 
-def get_html(url_list, height=200):
-    html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
-    for url, title, link in url_list:
-        html2 = f"<img title='{escape(title)}' style='height: {height}px; margin: 5px' src='{escape(url)}'>"
-        if len(link) > 0:
-            html2 = f"<a href='{escape(link)}' target='_blank'>" + html2 + "</a>"
-        html = html + html2
-    html += "</div>"
-    return html
-
-
 def compute_text_embeddings(list_of_strings):
     inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
-    return model.get_text_features(**inputs)
-
-
-st.cache(show_spinner=False)
+    result = model.get_text_features(**inputs).detach().numpy()
+    return result / np.linalg.norm(result, axis=1, keepdims=True)
 
 
 def image_search(query, corpus, n_results=24):
-    text_embeddings = compute_text_embeddings([query]).detach().numpy()
+    positive_embeddings = None
+
+    def concatenate_embeddings(e1, e2):
+        if e1 is None:
+            return e2
+        else:
+            return np.concatenate((e1, e2), axis=0)
+
+    splitted_query = query.split("/")
+
+    positive_queries = splitted_query[0].split(";")
+    for positive_query in positive_queries:
+        match = re.match(r"\[(Movies|Unsplash):(\d{1,5})\](.*)", positive_query)
+        if match:
+            corpus2, idx, remainder = match.groups()
+            idx, remainder = int(idx), remainder.strip()
+            k = 0 if corpus2 == "Unsplash" else 1
+            positive_embeddings = concatenate_embeddings(
+                positive_embeddings, embeddings[k][idx : idx + 1, :]
+            )
+            if len(remainder) > 0:
+                positive_embeddings = concatenate_embeddings(
+                    positive_embeddings, compute_text_embeddings([remainder])
+                )
+        else:
+            positive_embeddings = concatenate_embeddings(
+                positive_embeddings, compute_text_embeddings([positive_query])
+            )
     k = 0 if corpus == "Unsplash" else 1
-    results = np.argsort((embeddings[k] @ text_embeddings.T)[:, 0])[
-        -1 : -n_results - 1 : -1
-    ]
+    dot_product = embeddings[k] @ positive_embeddings.T
+    dot_product = dot_product - np.mean(dot_product, axis=0)
+    dot_product = dot_product / np.linalg.norm(dot_product, axis=0)
+    dot_product = np.min(dot_product, axis=1)
+
+    if len(splitted_query) > 1:
+        negative_queries = (" ".join(splitted_query[1:])).split(";")
+        negative_embeddings = compute_text_embeddings(negative_queries)
+        dot_product2 = embeddings[k] @ negative_embeddings.T
+        dot_product2 = dot_product2 - np.mean(dot_product2, axis=0)
+        dot_product2 = dot_product2 / np.linalg.norm(dot_product2, axis=0)
+        dot_product -= np.max(dot_product2, axis=1)
+
+    results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
     return [
         (
            df[k].iloc[i]["path"],
            df[k].iloc[i]["tooltip"] + source[k],
-           df[k].iloc[i]["link"],
+           i,
        )
        for i in results
    ]
@@ -112,11 +138,33 @@ def main():
     )
     st.sidebar.markdown(description)
     _, c, _ = st.columns((1, 3, 1))
-    query = c.text_input("", value="clouds at sunset")
+    if "query" in st.session_state:
+        query = c.text_input("", value=st.session_state["query"])
+    else:
+        query = c.text_input("", value="clouds at sunset")
     corpus = st.radio("", ["Unsplash", "Movies"])
     if len(query) > 0:
         results = image_search(query, corpus)
-        st.markdown(get_html(results), unsafe_allow_html=True)
+        clicked = clickable_images(
+            [result[0] for result in results],
+            titles=[result[1] for result in results],
+            div_style={
+                "display": "flex",
+                "justify-content": "center",
+                "flex-wrap": "wrap",
+            },
+            img_style={"margin": "2px", "height": "200px"},
+        )
+        if clicked >= 0:
+            change_query = False
+            if "last_clicked" not in st.session_state:
+                change_query = True
+            else:
+                if clicked != st.session_state["last_clicked"]:
+                    change_query = True
+            if change_query:
+                st.session_state["query"] = f"[{corpus}:{results[clicked][2]}]"
+                st.experimental_rerun()
 
 
 if __name__ == "__main__":
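On the ranking itself: the new `image_search()` standardizes each similarity column over the corpus (subtract the column mean, divide by the column norm), keeps the minimum across positive queries as an image's score, and subtracts the maximum across negative queries. A toy sketch of that rule on synthetic data; `corpus_emb`, `pos`, `neg`, and `standardize` are made-up names, not identifiers from the commit:

```python
import numpy as np

# Toy illustration (synthetic data, not from the app) of the ranking rule
# introduced in image_search(): take the worst (min) standardized similarity
# over the positive queries and subtract the best (max) over the negatives.
rng = np.random.default_rng(0)
corpus_emb = rng.normal(size=(1000, 512))           # stand-in for embeddings[k]
corpus_emb /= np.linalg.norm(corpus_emb, axis=1, keepdims=True)
pos = rng.normal(size=(2, 512))                     # two positive query embeddings
neg = rng.normal(size=(1, 512))                     # one negative query embedding
pos /= np.linalg.norm(pos, axis=1, keepdims=True)
neg /= np.linalg.norm(neg, axis=1, keepdims=True)

def standardize(cols):
    # Center and L2-normalize each similarity column over the corpus
    # so scores from different queries are comparable.
    cols = cols - np.mean(cols, axis=0)
    return cols / np.linalg.norm(cols, axis=0)

score = np.min(standardize(corpus_emb @ pos.T), axis=1)
score -= np.max(standardize(corpus_emb @ neg.T), axis=1)
top = np.argsort(score)[-1 : -24 - 1 : -1]          # top 24 indices, best first
```

Taking the min across positives means an image has to match every positive term to rank highly, while subtracting the max across negatives pushes down anything that matches even one negative term.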
requirements.txt CHANGED
@@ -2,3 +2,4 @@ torch
 transformers
 numpy
 pandas
+st-clickable-images