|
from html import escape |
|
import re |
|
import streamlit as st |
|
import pandas as pd, numpy as np |
|
import torch |
|
from transformers import CLIPProcessor, CLIPModel |
|
from st_clickable_images import clickable_images |
|
|
|
# Suffixes of the "openai/clip-vit-<name>" checkpoints that load() fetches; for
# each entry load() also reads "embeddings-vit-<name>.npy" and
# "embeddings2-vit-<name>.npy". main() searches with the last entry.
MODEL_NAMES = ["large-patch14-336"]
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def load():
    """Load the corpus tables, CLIP models/processors and image embeddings.

    Returns:
        A ``(models, processors, df, embeddings)`` tuple:

        * ``models[name]`` -- ``CLIPModel`` for ``openai/clip-vit-<name>``,
          switched to eval mode.
        * ``processors[name]`` -- the matching ``CLIPProcessor``.
        * ``df`` -- ``{0: data.csv, 1: data2.csv}`` as DataFrames.
        * ``embeddings[name]`` -- ``{0: ..., 1: ...}`` image-embedding
          matrices, each row scaled to unit L2 norm.

        Key 0 / key 1 are the same corpus indices used by ``image_search``
        (0 for "Unsplash", 1 otherwise).
    """
    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
    models, processors, embeddings = {}, {}, {}

    for name in MODEL_NAMES:
        checkpoint = f"openai/clip-vit-{name}"
        models[name] = CLIPModel.from_pretrained(checkpoint).eval()
        processors[name] = CLIPProcessor.from_pretrained(checkpoint)

        raw = {
            0: np.load(f"embeddings-vit-{name}.npy"),
            1: np.load(f"embeddings2-vit-{name}.npy"),
        }
        # Unit-normalize rows so that dot products are cosine similarities.
        embeddings[name] = {
            corpus_key: matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
            for corpus_key, matrix in raw.items()
        }

    return models, processors, df, embeddings
|
|
|
|
|
# Load everything once; the @st.cache decorator on load() memoizes it across
# script reruns.
models, processors, df, embeddings = load()

# Attribution appended to each result tooltip, keyed by corpus index.
source = {
    0: "\nSource: Unsplash",
    1: "\nSource: The Movie Database (TMDB)",
}
|
|
|
|
|
def compute_text_embeddings(list_of_strings, name):
    """Encode ``list_of_strings`` with CLIP model ``name``.

    Args:
        list_of_strings: Texts to encode; padded to a common length as one batch.
        name: Key into the module-level ``models`` / ``processors`` dicts.

    Returns:
        A NumPy array with one row per input string, each row scaled to unit
        L2 norm.
    """
    model, processor = models[name], processors[name]
    tokens = processor(text=list_of_strings, return_tensors="pt", padding=True)
    with torch.no_grad():
        features = model.get_text_features(**tokens)
    features = features.detach().numpy()
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    return features / norms
|
|
|
|
|
# "[Unsplash:<idx>]" / "[Movies:<idx>]" image reference, optionally followed by text.
_IMAGE_REFERENCE = re.compile(r"\[(Movies|Unsplash):(\d{1,5})\](.*)")


def _corpus_index(corpus):
    """Map a corpus label to its key in ``df`` / ``embeddings`` (0 = Unsplash, else 1)."""
    return 0 if corpus == "Unsplash" else 1


def _normalized_similarities(corpus_embeddings, query_embeddings):
    """Score every corpus row against every query row, rescaled per query.

    Returns an array of shape ``(n_corpus_rows, n_queries)``: the raw dot
    products with each column's median subtracted and then divided by that
    column's maximum, so each query's median item scores 0 and its best
    item scores 1.
    """
    scores = corpus_embeddings @ query_embeddings.T
    scores = scores - np.median(scores, axis=0)
    return scores / np.max(scores, axis=0, keepdims=True)


def _positive_query_embeddings(positive_part, name):
    """Embed each ";"-separated positive sub-query; return a stacked array or None.

    Sub-queries are whitespace-stripped (so "a; [Unsplash:3]" is recognized as
    an image reference) and empty ones are skipped. An image reference whose
    index is out of range for its corpus is ignored; any text after it is
    still used.
    """
    parts = []
    for positive_query in positive_part.split(";"):
        positive_query = positive_query.strip()
        if not positive_query:
            continue
        match = _IMAGE_REFERENCE.match(positive_query)
        if match is None:
            parts.append(compute_text_embeddings([positive_query], name))
            continue
        corpus_label, idx, remainder = match.groups()
        idx, remainder = int(idx), remainder.strip()
        reference = embeddings[name][_corpus_index(corpus_label)][idx : idx + 1, :]
        if len(reference):  # out-of-range index -> empty slice; skip it
            parts.append(reference)
        if remainder:
            parts.append(compute_text_embeddings([remainder], name))
    return np.concatenate(parts, axis=0) if parts else None


def image_search(query, corpus, name, n_results=24):
    """Return the best-matching images of ``corpus`` for ``query``.

    Query syntax:

    * The text before the first ``"EXCLUDING "`` is split on ``";"`` into
      positive sub-queries. Each is plain text or ``[Unsplash:<idx>]`` /
      ``[Movies:<idx>]`` (the stored embedding of that image) optionally
      followed by extra text. An image must score well on every positive
      sub-query, because the minimum over them is used.
    * The text after ``"EXCLUDING "`` is split on ``";"`` into negative text
      sub-queries. Each image is penalized by its strongest above-median
      match among them.

    Blank sub-queries are ignored.

    Args:
        query: Raw query string.
        corpus: ``"Unsplash"`` selects corpus 0; any other value selects corpus 1.
        name: Key into the module-level ``models`` / ``processors`` / ``embeddings``.
        n_results: Maximum number of results to return.

    Returns:
        A list of ``(path, tooltip, row_index)`` tuples, best match first. The
        tooltip has the corpus attribution from ``source`` appended.
    """
    k = _corpus_index(corpus)
    corpus_embeddings = embeddings[name][k]
    positive_part, *negative_parts = query.split("EXCLUDING ")

    # Neutral baseline, used when there is no positive sub-query.
    scores = np.zeros(len(corpus_embeddings))

    positive_embeddings = _positive_query_embeddings(positive_part, name)
    if positive_embeddings is not None:
        scores = np.min(
            _normalized_similarities(corpus_embeddings, positive_embeddings), axis=1
        )

    negative_queries = [q.strip() for q in " ".join(negative_parts).split(";")]
    negative_queries = [q for q in negative_queries if q]
    if negative_queries:
        negative_scores = _normalized_similarities(
            corpus_embeddings, compute_text_embeddings(negative_queries, name)
        )
        # Only above-median similarity to a negative sub-query is penalized.
        scores = scores - np.max(np.maximum(negative_scores, 0), axis=1)

    results = np.argsort(scores)[-1 : -n_results - 1 : -1]
    frame = df[k]
    return [
        (frame.iloc[i]["path"], frame.iloc[i]["tooltip"] + source[k], i)
        for i in results
    ]
|
|
|
|
|
# Sidebar markdown rendered by main(). The emoji below was previously stored
# as the mojibake "π€"; it is the 🤗 (Hugging Face) emoji.
description = """
# Semantic image search

**Enter your query and hit enter**

*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*

*Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
"""
|
|
|
# Markdown for the sidebar "Advanced use" expander: the query syntax that
# image_search() understands.
howto = """
- Click on an image to use it as a query and find similar images
- Several queries, including one based on an image, can be combined (use "**;**" as a separator)
- If the input includes "**EXCLUDING**", the part right of it will be used as a negative query
"""

# Container style for the clickable_images() result grid: a centered,
# wrapping flex row.
div_style = {"display": "flex", "justify-content": "center", "flex-wrap": "wrap"}
|
|
|
|
|
def main():
    """Render the search page and turn image clicks into new queries.

    Draws the page CSS, the sidebar help, a query box and a corpus selector.
    The query box is pre-filled from ``st.session_state["query"]`` when set,
    otherwise with ``"clouds at sunset"``. For a non-empty query it shows the
    results of ``image_search`` as clickable images. Clicking a result sets
    the query to an ``[<corpus>:<row index>]`` reference to that image and
    reruns the script.
    """
    st.markdown(
        """
<style>
.block-container{
    max-width: 1200px;
}
div.row-widget.stRadio > div{
    flex-direction:row;
    display: flex;
    justify-content: center;
}
div.row-widget.stRadio > div > label{
    margin-left: 5px;
    margin-right: 5px;
}
.row-widget {
    margin-top: -25px;
}
section>div:first-child {
    padding-top: 30px;
}
div.reportview-container > section:first-child{
    max-width: 320px;
}
#MainMenu {
    visibility: hidden;
}
footer {
    visibility: hidden;
}
</style>""",
        unsafe_allow_html=True,
    )
    st.sidebar.markdown(description)
    with st.sidebar.expander("Advanced use"):
        st.markdown(howto)

    _, c, _ = st.columns((1, 3, 1))
    if "query" in st.session_state:
        query = c.text_input("", value=st.session_state["query"])
    else:
        query = c.text_input("", value="clouds at sunset")
    corpus = st.radio("", ["Unsplash", "Movies"])

    # load() only loads the models listed in MODEL_NAMES; search with the last.
    name = MODEL_NAMES[-1]

    if not query:
        return

    results = image_search(query, corpus, name)
    results_key = query + corpus + name + "1"
    clicked = clickable_images(
        [path for path, _, _ in results],
        titles=[title for _, title, _ in results],
        div_style=div_style,
        img_style={"margin": "2px", "height": "200px"},
        key=results_key,
    )

    if clicked < 0:
        # Nothing clicked in this component yet: forget older clicks so that
        # a later click here is never mistaken for one already handled.
        st.session_state.pop("last_clicked", None)
        return

    # clickable_images keeps returning the same index on every rerun with the
    # same key. Record which click was handled. Otherwise a click whose new
    # query equals the current one (same key after rerun) reruns forever.
    click_id = (results_key, clicked)
    if st.session_state.get("last_clicked") == click_id:
        return
    st.session_state["last_clicked"] = click_id
    st.session_state["query"] = f"[{corpus}:{results[clicked][2]}]"
    st.experimental_rerun()
|
|
|
# Build the page only when this file is executed as the main script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|