clip / app.py
Vivien
Add eval and torch.no_grad (because inference only)
0779f15
raw
history blame
8.68 kB
from html import escape
import re
import streamlit as st
import pandas as pd, numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
from st_clickable_images import clickable_images
# CLIP checkpoint suffixes; load() expands each one to "openai/clip-vit-<suffix>".
# Only the ViT-L/14 @ 336px checkpoint is enabled. The other variants
# ("base-patch32", "base-patch16", "large-patch14") are currently disabled.
MODEL_NAMES = ["large-patch14-336"]
@st.cache(allow_output_mutation=True)
def load():
    """Load the image metadata, the CLIP models/processors and the image embeddings.

    Returns:
        A ``(models, processors, df, embeddings)`` tuple where
        ``df[0]`` / ``df[1]`` are the DataFrames read from ``data.csv`` /
        ``data2.csv``; ``models[name]`` and ``processors[name]`` are the
        Hugging Face CLIP objects for every entry of ``MODEL_NAMES`` (models
        put in eval mode); and ``embeddings[name][0]`` / ``embeddings[name][1]``
        are the arrays loaded from ``embeddings-vit-<name>.npy`` /
        ``embeddings2-vit-<name>.npy`` with every row scaled to unit L2 norm.
    """
    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
    models, processors, embeddings = {}, {}, {}
    for name in MODEL_NAMES:
        repo_id = f"openai/clip-vit-{name}"
        models[name] = CLIPModel.from_pretrained(repo_id).eval()
        processors[name] = CLIPProcessor.from_pretrained(repo_id)
        raw = {
            0: np.load(f"embeddings-vit-{name}.npy"),
            1: np.load(f"embeddings2-vit-{name}.npy"),
        }
        # Unit-normalise every row so that dot products with the (also
        # normalised) text embeddings are cosine similarities.
        embeddings[name] = {
            corpus_id: matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
            for corpus_id, matrix in raw.items()
        }
    return models, processors, df, embeddings
# Heavy resources are loaded at import time; load() is wrapped in st.cache so
# Streamlit reruns reuse the same objects instead of reloading them.
models, processors, df, embeddings = load()

# Attribution appended to each result tooltip, indexed by corpus id
# (0 = Unsplash, 1 = TMDB movies).
source = {
    0: "\nSource: Unsplash",
    1: "\nSource: The Movie Database (TMDB)",
}
def compute_text_embeddings(list_of_strings, name):
    """Encode *list_of_strings* with the CLIP model *name*.

    The strings are tokenised as one padded batch and run through the text
    tower with autograd disabled (inference only).

    Returns:
        A NumPy array with one row per input string, each row scaled to unit
        L2 norm.
    """
    processor = processors[name]
    batch = processor(text=list_of_strings, return_tensors="pt", padding=True)
    with torch.no_grad():
        features = models[name].get_text_features(**batch)
        vectors = features.detach().numpy()
    lengths = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / lengths
# Matches an image reference such as "[Unsplash:123]" or "[Movies:45] at night":
# group 1 is the corpus label, group 2 the row index, group 3 any trailing text.
_IMAGE_REFERENCE = re.compile(r"\[(Movies|Unsplash):(\d{1,5})\](.*)")


def _split_subqueries(text):
    """Return the stripped, non-empty ";"-separated parts of *text*."""
    parts = (part.strip() for part in text.split(";"))
    return [part for part in parts if part]


def _normalize_scores(similarities):
    """Centre each column of *similarities* on its median and scale its max to 1.

    After centring, each column's maximum is >= 0. A column whose maximum
    equals its median (max 0 after centring) is left unscaled instead of
    being divided by zero.
    """
    centred = similarities - np.median(similarities, axis=0)
    peak = np.max(centred, axis=0, keepdims=True)
    return centred / np.where(peak > 0, peak, 1)


def _positive_query_embeddings(text, name):
    """Embed the ";"-separated positive sub-queries in *text* with model *name*.

    An image reference contributes the stored embedding of that image, plus an
    embedding of any text following it. Every other sub-query is embedded as
    text. All text is encoded in a single ``compute_text_embeddings`` call.

    Returns:
        A 2-D array with one row per positive embedding, or None when *text*
        contains no usable sub-query.
    """
    rows = []
    texts = []
    for subquery in _split_subqueries(text):
        match = _IMAGE_REFERENCE.match(subquery)
        if match is None:
            texts.append(subquery)
            continue
        ref_corpus, ref_idx, remainder = match.groups()
        ref_embeddings = embeddings[name][0 if ref_corpus == "Unsplash" else 1]
        ref_idx = int(ref_idx)
        # An out-of-range index would give an empty slice. If that were the
        # only positive, the min-reduction below would fail, so skip it.
        if ref_idx < len(ref_embeddings):
            rows.append(ref_embeddings[ref_idx : ref_idx + 1, :])
        remainder = remainder.strip()
        if remainder:
            texts.append(remainder)
    if texts:
        rows.append(compute_text_embeddings(texts, name))
    return np.concatenate(rows, axis=0) if rows else None


def image_search(query, corpus, name, n_results=24):
    """Rank the images of *corpus* against a composite text/image query.

    Query syntax:
      * ";" separates positive sub-queries. An image must match all of them:
        its score is the minimum of its normalised per-sub-query scores.
      * "[Unsplash:<i>]" or "[Movies:<i>]" (optionally followed by text)
        refers to the stored embedding of image ``i`` of that corpus.
      * Everything after "EXCLUDING " is a ";"-separated list of negative text
        queries. The strongest above-median match to any of them is subtracted.

    Args:
        query: raw query string typed by the user.
        corpus: "Unsplash" selects corpus 0; any other value selects corpus 1.
        name: key into ``models`` / ``processors`` / ``embeddings``.
        n_results: maximum number of results to return.

    Returns:
        A list of ``(path, tooltip, index)`` tuples, best match first. ``index``
        is the image's row in ``df[k]`` and ``embeddings[name][k]``. The list
        is empty when the query contains no usable sub-query.
    """
    k = 0 if corpus == "Unsplash" else 1
    corpus_embeddings = embeddings[name][k]
    positive_text, *negative_texts = query.split("EXCLUDING ")

    scores = None
    positive_embeddings = _positive_query_embeddings(positive_text, name)
    if positive_embeddings is not None:
        positive_scores = _normalize_scores(corpus_embeddings @ positive_embeddings.T)
        # Keep each image's worst score so it must satisfy every sub-query.
        scores = np.min(positive_scores, axis=1)

    negative_queries = _split_subqueries(" ".join(negative_texts))
    if negative_queries:
        negative_embeddings = compute_text_embeddings(negative_queries, name)
        negative_scores = _normalize_scores(corpus_embeddings @ negative_embeddings.T)
        # Only above-median (positive after centring) matches are penalised.
        penalty = np.max(np.maximum(negative_scores, 0), axis=1)
        scores = -penalty if scores is None else scores - penalty

    if scores is None:
        return []

    best_first = np.argsort(scores)[::-1][:n_results]
    results = []
    for i in best_first:
        row = df[k].iloc[i]
        results.append((row["path"], row["tooltip"] + source[k], i))
    return results
# Sidebar introduction, rendered with st.sidebar.markdown in main().
description = """
# Semantic image search
**Enter your query and hit enter**
*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
*Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
"""
# Query-syntax help shown in the "Advanced use" sidebar expander of main();
# it describes the syntax parsed by image_search().
howto = """
- Click on an image to use it as a query and find similar images
- Several queries, including one based on an image, can be combined (use "**;**" as a separator)
- If the input includes "**EXCLUDING**", the part to the right of it will be used as a negative query
"""
# Inline CSS for the flex container that holds the clickable result thumbnails.
div_style = {
    "display": "flex",
    "justify-content": "center",
    "flex-wrap": "wrap",
}
def _follow_click(grids, corpus):
    """Turn a newly clicked result into an image query and rerun, once per click.

    Streamlit components keep reporting their last value on every rerun, so
    ``clickable_images`` returns the same index again for an unchanged key.
    Each ``(grid key, index)`` pair is therefore recorded in session state and
    only acted on the first time it is seen. Otherwise, clicking a result whose
    reference equals the current query (e.g. the top hit of "[Unsplash:12]")
    would rerun the app forever.

    Args:
        grids: iterable of ``(key, clicked_index, results)`` for each result grid.
        corpus: corpus label used to build the "[<corpus>:<index>]" query.
    """
    if "handled_clicks" not in st.session_state:
        st.session_state["handled_clicks"] = set()
    handled = st.session_state["handled_clicks"]
    for key, clicked, results in grids:
        if clicked < 0:
            continue
        click_id = (key, clicked)
        if click_id in handled:
            continue
        handled.add(click_id)
        st.session_state["query"] = f"[{corpus}:{results[clicked][2]}]"
        st.experimental_rerun()


def main():
    """Render the search page and turn result clicks into image queries.

    The page consists of injected CSS, a sidebar with ``description`` and
    ``howto``, a query box and a corpus selector. The query box is pre-filled
    from ``st.session_state["query"]`` after an image click. Below them is a
    grid of clickable results. Clicking a result replaces the query with
    "[<corpus>:<index>]" and reruns the script.
    """
    st.markdown(
        """
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
.row-widget {
margin-top: -25px;
}
section>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>""",
        unsafe_allow_html=True,
    )
    st.sidebar.markdown(description)
    with st.sidebar.expander("Advanced use"):
        st.markdown(howto)

    # Side-by-side comparison of two models is disabled. Enabling it would
    # also require loading those models: load() only loads MODEL_NAMES.
    comparison_mode = False

    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input("", value=st.session_state.get("query", "clouds at sunset"))
    corpus = st.radio("", ["Unsplash", "Movies"])
    models_dict = {
        "ViT-B/32 (quicker)": "base-patch32",
        "ViT-B/16 (average)": "base-patch16",
        # "ViT-L/14 (slow)": "large-patch14",
        "ViT-L/14@336px (slower)": "large-patch14-336",
    }
    if comparison_mode:
        c1, c2 = st.columns((1, 1))
        selection1 = c1.selectbox("", models_dict.keys(), index=0)
        selection2 = c2.selectbox("", models_dict.keys(), index=2)
        name1 = models_dict[selection1]
        name2 = models_dict[selection2]
    else:
        name1 = MODEL_NAMES[-1]

    if not query:
        return

    results1 = image_search(query, corpus, name1)
    key1 = query + corpus + name1 + "1"
    if comparison_mode:
        with c1:
            clicked1 = clickable_images(
                [result[0] for result in results1],
                titles=[result[1] for result in results1],
                div_style=div_style,
                img_style={"margin": "2px", "height": "150px"},
                key=key1,
            )
        results2 = image_search(query, corpus, name2)
        key2 = query + corpus + name2 + "2"
        with c2:
            clicked2 = clickable_images(
                [result[0] for result in results2],
                titles=[result[1] for result in results2],
                div_style=div_style,
                img_style={"margin": "2px", "height": "150px"},
                key=key2,
            )
        grids = [(key1, clicked1, results1), (key2, clicked2, results2)]
    else:
        clicked1 = clickable_images(
            [result[0] for result in results1],
            titles=[result[1] for result in results1],
            div_style=div_style,
            img_style={"margin": "2px", "height": "200px"},
            key=key1,
        )
        grids = [(key1, clicked1, results1)]

    _follow_click(grids, corpus)
if __name__ == "__main__":
    # Entry point when the file is executed as a script (e.g. via `streamlit run`).
    main()