Spaces: Running

Jeremy Watt committed
Commit bc066dc · 1 Parent(s): b9a532a

reset
Browse files
- data/dbs/memes.db +0 -0
- data/dbs/memes.faiss +0 -0
- data/dbs/placeholder +0 -0
- data/input/test_meme_1.jpg +0 -0
- data/input/test_meme_2.jpg +0 -0
- data/input/test_meme_3.jpg +0 -0
- data/input/test_meme_4.jpg +0 -0
- data/input/test_meme_5.jpg +0 -0
- data/input/test_meme_6.jpg +0 -0
- data/input/test_meme_7.jpg +0 -0
- data/input/test_meme_8.jpg +0 -0
- data/input/test_meme_9.jpg +0 -0
- meme_search/__init__.py +0 -10
- meme_search/app.py +0 -34
- meme_search/data_puller.py +0 -114
- meme_search/style.css +0 -15
- meme_search/utilities/__init__.py +0 -10
- meme_search/utilities/__pycache__/__init__.cpython-310.pyc +0 -0
- meme_search/utilities/__pycache__/query.cpython-310.pyc +0 -0
- meme_search/utilities/chunks.py +0 -68
- meme_search/utilities/create.py +0 -75
- meme_search/utilities/imgs.py +0 -19
- meme_search/utilities/query.py +0 -76
- meme_search/utilities/text_extraction.py +0 -38
data/dbs/memes.db
DELETED
Binary file (28.7 kB)
data/dbs/memes.faiss
DELETED
Binary file (393 kB)
data/dbs/placeholder
DELETED
File without changes
data/input/test_meme_1.jpg
DELETED
Binary file (77.1 kB)
data/input/test_meme_2.jpg
DELETED
Binary file (64.6 kB)
data/input/test_meme_3.jpg
DELETED
Binary file (8.13 kB)
data/input/test_meme_4.jpg
DELETED
Binary file (12.7 kB)
data/input/test_meme_5.jpg
DELETED
Binary file (14.5 kB)
data/input/test_meme_6.jpg
DELETED
Binary file (65 kB)
data/input/test_meme_7.jpg
DELETED
Binary file (105 kB)
data/input/test_meme_8.jpg
DELETED
Binary file (43.9 kB)
data/input/test_meme_9.jpg
DELETED
Binary file (37.7 kB)
meme_search/__init__.py
DELETED
@@ -1,10 +0,0 @@
-import os
-
-base_dir = os.path.dirname(os.path.abspath(__file__))
-meme_search_root_dir = os.path.dirname(base_dir)
-
-vector_db_path = meme_search_root_dir + "/data/dbs/memes.faiss"
-sqlite_db_path = meme_search_root_dir + "/data/dbs/memes.db"
-
-from meme_search.data_puller import pull_demo_data
-pull_demo_data()
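A note on the deleted module above: because pull_demo_data() ran at the bottom of __init__.py, merely importing the package kicked off network calls to GitHub. A minimal sketch of that (now removed) side effect, assuming the package was importable at the time:

import meme_search  # the import alone ran pull_demo_data(), hitting GitHub

print(meme_search.vector_db_path)   # <repo root>/data/dbs/memes.faiss
print(meme_search.sqlite_db_path)   # <repo root>/data/dbs/memes.db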
meme_search/app.py
DELETED
@@ -1,34 +0,0 @@
-from meme_search import base_dir, sqlite_db_path, vector_db_path
-from meme_search.utilities.query import complete_query
-import streamlit as st
-
-st.set_page_config(page_title="Meme Search")
-
-
-# search bar taken from --> https://discuss.streamlit.io/t/creating-a-nicely-formatted-search-field/1804/2
-def local_css(file_name):
-    with open(file_name) as f:
-        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
-
-
-def remote_css(url):
-    st.markdown(f'<link href="{url}" rel="stylesheet">', unsafe_allow_html=True)
-
-
-local_css(base_dir + "/style.css")
-remote_css("https://fonts.googleapis.com/icon?family=Material+Icons")
-
-# icon("search")
-buff, col, buff2 = st.columns([1, 4, 1])
-
-selected = col.text_input(label="search for meme", placeholder="search for a meme")
-if selected:
-    results = complete_query(selected, vector_db_path, sqlite_db_path)
-    img_paths = [v["img_path"] for v in results]
-    for result in results:
-        with col.container(border=True):
-            st.image(
-                result["img_path"],
-                output_format="auto",
-                caption=f'{result["full_description"]} (query distance = {result["distance"]})',
-            )
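For context, app.py was a single-file Streamlit UI: one text input feeding complete_query(), with each hit rendered as an image captioned by its full description and query distance. It would presumably have been launched with the standard Streamlit entry point (an assumption based on the file layout, not something this commit states):

streamlit run meme_search/app.py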
meme_search/data_puller.py
DELETED
@@ -1,114 +0,0 @@
-import os
-import requests
-
-def download_file_from_github(repo_url, file_path, save_dir):
-    raw_url = f"https://raw.githubusercontent.com/{repo_url}/main/{file_path}"
-    response = requests.get(raw_url)
-
-    if response.status_code == 200:
-        if not os.path.exists(save_dir):
-            os.makedirs(save_dir)
-        save_path = os.path.join(save_dir, os.path.basename(file_path))
-        with open(save_path, 'wb') as file:
-            file.write(response.content)
-
-        print(f"File downloaded successfully: {save_path}")
-    else:
-        print(f"Failed to download file. Status code: {response.status_code}")
-
-
-def list_files_in_github_directory(owner, repo, directory_path):
-    url = f"https://api.github.com/repos/{owner}/{repo}/contents/{directory_path}"
-    response = requests.get(url)
-
-    if response.status_code == 200:
-        files = response.json()
-        names = []
-        for file in files:
-            names.append(file["name"])
-        return names
-    else:
-        print(f"Failed to retrieve directory contents. Status code: {response.status_code}")
-
-
-def collect_repo_file_names():
-    owner = "neonwatty"
-    repo = "meme_search"
-    input_path = "/data/input"
-    input_names = list_files_in_github_directory(owner, repo, input_path)
-
-    db_path = "/data/dbs"
-    db_names = list_files_in_github_directory(owner, repo, db_path)
-    db_names = [v for v in db_names if ".db" in v or ".faiss" in v]
-    return input_path, db_path, input_names, db_names
-
-def check_directory_exists(directory_path):
-    return os.path.isdir(directory_path)
-
-def create_directory(directory_path):
-    try:
-        os.makedirs(directory_path, exist_ok=True)
-        print(f"Directory '{directory_path}' created successfully.")
-    except OSError as error:
-        print(f"Error creating directory '{directory_path}': {error}")
-
-def check_files_in_directory(directory_path, file_list):
-    missing_files = []
-    for file_name in file_list:
-        if not os.path.isfile(os.path.join(directory_path, file_name)):
-            missing_files.append(file_name)
-    return missing_files
-
-def list_files_in_directory(directory_path):
-    try:
-        files = [f for f in os.listdir(directory_path) if os.path.isfile(os.path.join(directory_path, f))]
-        return files
-    except OSError as error:
-        print(f"Error accessing directory '{directory_path}': {error}")
-        return []
-
-def delete_file(directory_path, file_name):
-    try:
-        file_path = os.path.join(directory_path, file_name)
-        if os.path.isfile(file_path):
-            os.remove(file_path)
-            print(f"File '{file_name}' deleted successfully.")
-        else:
-            print(f"File '{file_name}' does not exist in the directory '{directory_path}'.")
-    except OSError as error:
-        print(f"Error deleting file '{file_name}': {error}")
-
-def pull_demo_data():
-    repo_url = "neonwatty/meme_search"
-    input_path, db_path, repo_input_names, repo_db_names = collect_repo_file_names()
-    if not check_directory_exists("." + input_path):
-        create_directory("." + input_path)
-        for name in repo_input_names:
-            file_path = input_path + "/" + name
-            download_file_from_github(repo_url, file_path, "." + input_path)
-    else:
-        local_input_files = list_files_in_directory("." + input_path)
-        input_files_to_pull = [item for item in repo_input_names if item not in local_input_files]
-        input_files_to_delete = [item for item in local_input_files if item not in repo_input_names]
-
-        for name in input_files_to_delete:
-            delete_file("." + input_path, name)
-        for name in input_files_to_pull:
-            file_path = input_path + "/" + name
-            download_file_from_github(repo_url, file_path, "." + input_path)
-
-    if not check_directory_exists("." + db_path):
-        create_directory("." + db_path)
-        repo_url = "neonwatty/meme_search"
-        for name in repo_db_names:
-            file_path = db_path + "/" + name
-            download_file_from_github(repo_url, file_path, "." + db_path)
-    else:
-        local_db_files = list_files_in_directory("." + db_path)
-        db_files_to_pull = [item for item in repo_db_names if item not in local_db_files]
-        db_files_to_delete = [item for item in local_db_files if item not in repo_db_names]
-        for name in db_files_to_delete:
-            delete_file("." + db_path, name)
-        for name in db_files_to_pull:
-            file_path = db_path + "/" + name
-            download_file_from_github(repo_url, file_path, "." + db_path)
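The key mechanic in data_puller.py is the mapping from a repo-relative path to a raw.githubusercontent.com URL, which is what download_file_from_github fetched. A small illustration using one of the demo images deleted in this same commit:

repo_url = "neonwatty/meme_search"
file_path = "data/input/test_meme_1.jpg"
raw_url = f"https://raw.githubusercontent.com/{repo_url}/main/{file_path}"
# -> https://raw.githubusercontent.com/neonwatty/meme_search/main/data/input/test_meme_1.jpg

Note that list_files_in_github_directory hits the unauthenticated GitHub API, so frequent Space restarts could plausibly run into rate limits.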
meme_search/style.css
DELETED
@@ -1,15 +0,0 @@
-body {
-    color: #fff;
-    background-color: #4F8BF9;
-}
-
-.stButton>button {
-    color: #4F8BF9;
-    border-radius: 50%;
-    height: 3em;
-    width: 3em;
-}
-
-.stTextInput>div>div>input {
-    color: #4F8BF9;
-}
meme_search/utilities/__init__.py
DELETED
@@ -1,10 +0,0 @@
-import os
-from sentence_transformers import SentenceTransformer
-
-model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
-utilities_base_dir = os.path.dirname(os.path.abspath(__file__))
-meme_search_dir = os.path.dirname(utilities_base_dir)
-meme_search_root_dir = os.path.dirname(meme_search_dir)
-
-vector_db_path = meme_search_root_dir + "/data/dbs/memes.faiss"
-sqlite_db_path = meme_search_root_dir + "/data/dbs/memes.db"
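utilities/__init__.py loaded the embedding model once at import time so every downstream module shared the same instance. A minimal sketch of what that model object does (all-MiniLM-L6-v2 produces 384-dimensional embeddings):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = model.encode(["a cat wearing sunglasses"])  # numpy array, shape (1, 384)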
meme_search/utilities/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (538 Bytes)
meme_search/utilities/__pycache__/query.cpython-310.pyc
DELETED
Binary file (2.66 kB)
meme_search/utilities/chunks.py
DELETED
@@ -1,68 +0,0 @@
-import re
-
-
-def clean_word(text: str) -> str:
-    # clean input text - keeping only lower case letters, numbers, punctuation, and single quote symbols
-    return re.sub(" +", " ", re.compile("[^a-z0-9,.!?']").sub(" ", text.lower().strip()))
-
-
-def chunk_text(text: str) -> list:
-    # split and clean input text
-    text_split = clean_word(text).split(" ")
-    text_split = [v for v in text_split if len(v) > 0]
-
-    # use two pointers to create chunks
-    chunk_size = 4
-    overlap_size = 2
-
-    # create next chunk by moving right pointer until chunk_size is reached or line_number changes by more than 1 or end of word_sequence is reached
-    left_pointer = 0
-    right_pointer = chunk_size - 1
-    chunks = []
-
-    if right_pointer >= len(text_split):
-        chunks = [" ".join(text_split)]
-    else:
-        while right_pointer < len(text_split):
-            # check if chunk_size has been reached
-            # create chunk
-            chunk = text_split[left_pointer : right_pointer + 1]
-
-            # move left pointer
-            left_pointer += chunk_size - overlap_size
-
-            # move right pointer
-            right_pointer += chunk_size - overlap_size
-
-            # store chunk
-            chunks.append(" ".join(chunk))
-
-        # check if there is final chunk
-        if len(text_split[left_pointer:]) > 0:
-            last_chunk = text_split[left_pointer:]
-            chunks.append(" ".join(last_chunk))
-
-    # insert the full text
-    if len(chunks) > 1:
-        chunks.insert(0, text.lower())
-    return chunks
-
-
-# loop over each meme's moondream based text descriptor and create a short dict containing its full and chunked text
-def create_all_img_chunks(img_paths: list, answers: list) -> list:
-    try:
-        print("STARTING: create_all_img_chunks")
-        img_chunks = []
-        for ind, img_path in enumerate(img_paths):
-            moondream_meme_text = answers[ind]
-            moondream_chunks = chunk_text(moondream_meme_text)
-            for chunk in moondream_chunks:
-                entry = {}
-                entry["img_path"] = img_path
-                entry["chunk"] = chunk
-                img_chunks.append(entry)
-        print("SUCCESS: create_all_img_chunks ran successfully")
-        return img_chunks
-    except Exception as e:
-        print(f"FAILURE: create_all_img_chunks failed with exception {e}")
-        raise e
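Tracing chunk_text above: chunk_size=4 and overlap_size=2 give a sliding window of four words that advances two words per step, with any leftover tail words kept as a final chunk and the full lowercased text prepended whenever more than one chunk results. For example:

chunk_text("The quick brown fox jumps over")
# returns:
# ['the quick brown fox jumps over',   # full text, inserted at position 0
#  'the quick brown fox',              # words 0-3
#  'brown fox jumps over',             # words 2-5
#  'jumps over']                       # leftover tail (words 4-5)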
meme_search/utilities/create.py
DELETED
@@ -1,75 +0,0 @@
-import sqlite3
-import faiss
-from meme_search.utilities import model
-from meme_search.utilities.imgs import collect_img_paths
-from meme_search.utilities.text_extraction import extract_text_from_imgs
-from meme_search.utilities.chunks import create_all_img_chunks
-from meme_search.utilities import vector_db_path, sqlite_db_path
-
-
-def create_chunk_db(img_chunks: list, db_filepath: str) -> None:
-    # Create a lookup table for chunks
-    conn = sqlite3.connect(db_filepath)
-    cursor = conn.cursor()
-
-    # Create the table - delete old table if it exists
-    cursor.execute("DROP TABLE IF EXISTS chunks_reverse_lookup")
-
-    # Create the table - alias rowid as chunk_index
-    cursor.execute("""
-        CREATE TABLE IF NOT EXISTS chunks_reverse_lookup (
-            chunk_index INTEGER PRIMARY KEY,
-            img_path TEXT,
-            chunk TEXT
-        );
-    """)
-
-    # Insert data into the table
-    for chunk_index, entry in enumerate(img_chunks):
-        img_path = entry["img_path"]
-        chunk = entry["chunk"]
-        cursor.execute(
-            "INSERT INTO chunks_reverse_lookup (chunk_index, img_path, chunk) VALUES (?, ?, ?)",
-            (chunk_index, img_path, chunk),
-        )
-
-    conn.commit()
-    conn.close()
-
-
-def create_vector_db(chunks: list, db_file_path: str) -> None:
-    # embed inputs
-    embeddings = model.encode(chunks)
-
-    # dump all_embeddings to faiss index
-    index = faiss.IndexFlatL2(embeddings.shape[1])
-    index.add(embeddings)
-
-    # write index to disk
-    faiss.write_index(index, db_file_path)
-
-
-def complete_create_dbs(img_chunks: list, vector_db_path: str, sqlite_db_path: str) -> None:
-    try:
-        print("STARTING: complete_create_dbs")
-
-        # create db for img_chunks
-        create_chunk_db(img_chunks, sqlite_db_path)
-
-        # create vector embedding db for chunks
-        chunks = [v["chunk"] for v in img_chunks]
-        create_vector_db(chunks, vector_db_path)
-        print("SUCCESS: complete_create_dbs succeeded")
-    except Exception as e:
-        print(f"FAILURE: complete_create_dbs failed with exception {e}")
-
-
-def process():
-    all_img_paths = collect_img_paths()
-    moondream_answers = extract_text_from_imgs(all_img_paths)
-    img_chunks = create_all_img_chunks(all_img_paths, moondream_answers)
-    complete_create_dbs(img_chunks, vector_db_path, sqlite_db_path)
-
-
-if __name__ == "__main__":
-    process()
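create_vector_db relies on a simple faiss round trip: an exact (brute-force) L2 index is filled with the chunk embeddings and serialized to disk, then query.py reads it back. A self-contained sketch of that round trip, using random vectors at MiniLM's 384-dim width in place of real embeddings:

import faiss
import numpy as np

embeddings = np.random.rand(8, 384).astype("float32")  # stand-in for model.encode(chunks)
index = faiss.IndexFlatL2(embeddings.shape[1])          # exact L2 search, no training step
index.add(embeddings)
faiss.write_index(index, "demo.faiss")

index_loaded = faiss.read_index("demo.faiss")
distances, indices = index_loaded.search(embeddings[:1], 3)  # top-3 neighbors of the first vector

The faiss index positions line up with the chunk_index primary keys in the sqlite table because both are assigned in the same enumeration order; that implicit coupling is what query.py depends on.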
meme_search/utilities/imgs.py
DELETED
@@ -1,19 +0,0 @@
-import os
-from meme_search.utilities import meme_search_root_dir
-
-allowable_extensions = ["jpg", "jpeg", "png"]
-
-
-def collect_img_paths() -> list:
-    try:
-        img_dir = meme_search_root_dir + "/data/input"
-        print("STARTING: collect_img_paths")
-
-        all_img_paths = [os.path.join(img_dir, name) for name in os.listdir(img_dir) if name.split(".")[-1] in allowable_extensions]
-        all_img_paths = sorted(all_img_paths)
-
-        print(f"SUCCESS: collect_img_paths ran successfully - image paths loaded from '{img_dir}'")
-        return all_img_paths
-    except Exception as e:
-        print(f"FAILURE: collect_img_paths failed with img_dir {img_dir} with exception {e}")
-        raise e
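The filter in collect_img_paths compares the raw final extension against a lowercase whitelist, so uppercase extensions are silently skipped. A quick illustration with hypothetical file names:

allowable_extensions = ["jpg", "jpeg", "png"]
names = ["a.jpg", "b.png", "notes.txt", "PHOTO.JPG"]
kept = [n for n in names if n.split(".")[-1] in allowable_extensions]
# kept == ['a.jpg', 'b.png']; 'PHOTO.JPG' is dropped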
meme_search/utilities/query.py
DELETED
@@ -1,76 +0,0 @@
-import faiss
-import sqlite3
-import numpy as np
-from typing import Tuple
-import argparse
-from meme_search.utilities import model
-from meme_search.utilities import vector_db_path, sqlite_db_path
-
-
-def query_vector_db(query: str, db_file_path: str, k: int = 10) -> Tuple[list, list]:
-    # connect to db
-    faiss_index = faiss.read_index(db_file_path)
-
-    # test
-    encoded_query = np.expand_dims(model.encode(query), axis=0)
-
-    # query db
-    distances, indices = faiss_index.search(encoded_query, k)
-    distances = distances.tolist()[0]
-    indices = indices.tolist()[0]
-    return distances, indices
-
-
-def query_sqlite_db(indices: list, db_filepath: str) -> list:
-    conn = sqlite3.connect(db_filepath)
-    cursor = conn.cursor()
-    query = f"SELECT * FROM chunks_reverse_lookup WHERE chunk_index IN {tuple(indices)}"
-    cursor.execute(query)
-    rows = cursor.fetchall()
-    rows = [{"index": row[0], "img_path": row[1], "chunk": row[2]} for row in rows]
-    rows = sorted(rows, key=lambda x: indices.index(x["index"]))  # re-sort rows according to input indices
-    for row in rows:
-        query = f"SELECT * FROM chunks_reverse_lookup WHERE chunk_index=(SELECT MIN(chunk_index) FROM chunks_reverse_lookup WHERE img_path='{row['img_path']}')"
-        cursor.execute(query)
-        full_description_row = cursor.fetchall()
-        row["full_description"] = full_description_row[0][2]
-    conn.close()
-    return rows
-
-
-def complete_query(query: str, vector_db_path: str, sqlite_db_path: str, k: int = 10) -> list:
-    try:
-        print("STARTING: complete_query")
-
-        # query vector_db, first converting input query to embedding
-        distances, indices = query_vector_db(query, vector_db_path, k=k)
-
-        # use indices to query sqlite db containing chunk data
-        img_chunks = query_sqlite_db(indices, sqlite_db_path)  # bump up indices by 1 since sqlite row index starts at 1 not 0
-
-        # map indices back to correct image in img_chunks
-        imgs_seen = []
-        unique_img_entries = []
-        for ind, entry in enumerate(img_chunks):
-            if entry["img_path"] in imgs_seen:
-                continue
-            else:
-                entry["distance"] = round(distances[ind], 2)
-                unique_img_entries.append(entry)
-                imgs_seen.append(entry["img_path"])
-        print("SUCCESS: complete_query succeeded")
-        return unique_img_entries
-    except Exception as e:
-        print(f"FAILURE: complete_query failed with exception {e}")
-        raise e
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--query", dest="query", type=str, help="Add query")
-    args = parser.parse_args()
-    query = args.query
-
-    print(query)
-    results = complete_query(query, vector_db_path, sqlite_db_path)
-    print(results)
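query.py also doubled as a small CLI through its argparse block, so a search could be run without the Streamlit app; a plausible invocation from the repository root (the exact working directory is an assumption):

python -m meme_search.utilities.query --query "cat wearing sunglasses"

One caveat worth flagging in the deleted code: query_sqlite_db interpolates img_path directly into the SQL string rather than binding it with a ? parameter, which would break on paths containing single quotes.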
meme_search/utilities/text_extraction.py
DELETED
@@ -1,38 +0,0 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from PIL import Image
-import transformers
-
-transformers.logging.set_verbosity_error()
-
-
-def prompt_moondream(img_path: str, prompt: str) -> str:
-    # copied from moondream demo readme --> https://github.com/vikhyat/moondream/tree/main
-    model_id = "vikhyatk/moondream2"
-    revision = "2024-05-20"
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        trust_remote_code=True,
-        revision=revision,
-    )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
-    image = Image.open(img_path)
-    enc_image = model.encode_image(image)
-    moondream_response = model.answer_question(enc_image, prompt, tokenizer)
-    return moondream_response
-
-
-def extract_text_from_imgs(img_paths: list) -> list:
-    try:
-        print("STARTING: extract_text_from_imgs")
-        prompt = "Describe this image."
-        answers = []
-        for img_path in img_paths:
-            print(f"INFO: prompting moondream for a description of image: '{img_path}'")
-            answer = prompt_moondream(img_path, prompt)
-            answers.append(answer)
-            print("DONE!")
-        print("SUCCESS: extract_text_from_imgs succeeded")
-        return answers
-    except Exception as e:
-        print(f"FAILURE: extract_text_from_imgs failed with exception {e}")
-        raise e
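As written, prompt_moondream reloads the moondream2 model and tokenizer for every single image. A sketch of the same calls restructured to load once and reuse across images (same API as in the deleted file; the helper name describe is mine, not the author's):

from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image

model_id, revision = "vikhyatk/moondream2", "2024-05-20"
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, revision=revision)
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

def describe(img_path: str, prompt: str = "Describe this image.") -> str:
    enc_image = model.encode_image(Image.open(img_path))  # encode once per image
    return model.answer_question(enc_image, prompt, tokenizer)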