Jeremy Watt committed on
Commit
bc066dc
·
1 Parent(s): b9a532a
data/dbs/memes.db DELETED
Binary file (28.7 kB)
 
data/dbs/memes.faiss DELETED
Binary file (393 kB)
 
data/dbs/placeholder DELETED
File without changes
data/input/test_meme_1.jpg DELETED
Binary file (77.1 kB)
 
data/input/test_meme_2.jpg DELETED
Binary file (64.6 kB)
 
data/input/test_meme_3.jpg DELETED
Binary file (8.13 kB)
 
data/input/test_meme_4.jpg DELETED
Binary file (12.7 kB)
 
data/input/test_meme_5.jpg DELETED
Binary file (14.5 kB)
 
data/input/test_meme_6.jpg DELETED
Binary file (65 kB)
 
data/input/test_meme_7.jpg DELETED
Binary file (105 kB)
 
data/input/test_meme_8.jpg DELETED
Binary file (43.9 kB)
 
data/input/test_meme_9.jpg DELETED
Binary file (37.7 kB)
 
meme_search/__init__.py DELETED
@@ -1,10 +0,0 @@
1
- import os
2
-
3
- base_dir = os.path.dirname(os.path.abspath(__file__))
4
- meme_search_root_dir = os.path.dirname(base_dir)
5
-
6
- vector_db_path = meme_search_root_dir + "/data/dbs/memes.faiss"
7
- sqlite_db_path = meme_search_root_dir + "/data/dbs/memes.db"
8
-
9
- from meme_search.data_puller import pull_demo_data
10
- pull_demo_data()
 
 
 
 
 
 
 
 
 
 
 
meme_search/app.py DELETED
@@ -1,34 +0,0 @@
1
- from meme_search import base_dir, sqlite_db_path, vector_db_path
2
- from meme_search.utilities.query import complete_query
3
- import streamlit as st
4
-
5
- st.set_page_config(page_title="Meme Search")
6
-
7
-
8
- # search bar taken from --> https://discuss.streamlit.io/t/creating-a-nicely-formatted-search-field/1804/2
9
- def local_css(file_name):
10
- with open(file_name) as f:
11
- st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
12
-
13
-
14
- def remote_css(url):
15
- st.markdown(f'<link href="{url}" rel="stylesheet">', unsafe_allow_html=True)
16
-
17
-
18
- local_css(base_dir + "/style.css")
19
- remote_css("https://fonts.googleapis.com/icon?family=Material+Icons")
20
-
21
- # icon("search")
22
- buff, col, buff2 = st.columns([1, 4, 1])
23
-
24
- selected = col.text_input(label="search for meme", placeholder="search for a meme")
25
- if selected:
26
- results = complete_query(selected, vector_db_path, sqlite_db_path)
27
- img_paths = [v["img_path"] for v in results]
28
- for result in results:
29
- with col.container(border=True):
30
- st.image(
31
- result["img_path"],
32
- output_format="auto",
33
- caption=f'{result["full_description"]} (query distance = {result["distance"]})',
34
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
meme_search/data_puller.py DELETED
@@ -1,114 +0,0 @@
1
- import os
2
- import requests
3
-
4
- def download_file_from_github(repo_url, file_path, save_dir):
5
- raw_url = f"https://raw.githubusercontent.com/{repo_url}/main/{file_path}"
6
- response = requests.get(raw_url)
7
-
8
- if response.status_code == 200:
9
- if not os.path.exists(save_dir):
10
- os.makedirs(save_dir)
11
- save_path = os.path.join(save_dir, os.path.basename(file_path))
12
- with open(save_path, 'wb') as file:
13
- file.write(response.content)
14
-
15
- print(f"File downloaded successfully: {save_path}")
16
- else:
17
- print(f"Failed to download file. Status code: {response.status_code}")
18
-
19
-
20
- def list_files_in_github_directory(owner, repo, directory_path):
21
- url = f"https://api.github.com/repos/{owner}/{repo}/contents/{directory_path}"
22
- response = requests.get(url)
23
-
24
- if response.status_code == 200:
25
- files = response.json()
26
- names = []
27
- for file in files:
28
- names.append(file["name"])
29
- return names
30
- else:
31
- print(f"Failed to retrieve directory contents. Status code: {response.status_code}")
32
-
33
-
34
- def collect_repo_file_names():
35
- owner = "neonwatty"
36
- repo = "meme_search"
37
- input_path = "/data/input"
38
- input_names = list_files_in_github_directory(owner, repo, input_path)
39
-
40
- db_path = "/data/dbs"
41
- db_names = list_files_in_github_directory(owner, repo, db_path)
42
- db_names = [v for v in db_names if ".db" in v or ".faiss" in v]
43
- return input_path, db_path, input_names, db_names
44
-
45
- def check_directory_exists(directory_path):
46
- return os.path.isdir(directory_path)
47
-
48
- def create_directory(directory_path):
49
- try:
50
- os.makedirs(directory_path, exist_ok=True)
51
- print(f"Directory '{directory_path}' created successfully.")
52
- except OSError as error:
53
- print(f"Error creating directory '{directory_path}': {error}")
54
-
55
- def check_files_in_directory(directory_path, file_list):
56
- missing_files = []
57
- for file_name in file_list:
58
- if not os.path.isfile(os.path.join(directory_path, file_name)):
59
- missing_files.append(file_name)
60
- return missing_files
61
-
62
- def list_files_in_directory(directory_path):
63
- try:
64
- files = [f for f in os.listdir(directory_path) if os.path.isfile(os.path.join(directory_path, f))]
65
- return files
66
- except OSError as error:
67
- print(f"Error accessing directory '{directory_path}': {error}")
68
- return []
69
-
70
- def delete_file(directory_path, file_name):
71
- try:
72
- file_path = os.path.join(directory_path, file_name)
73
- if os.path.isfile(file_path):
74
- os.remove(file_path)
75
- print(f"File '{file_name}' deleted successfully.")
76
- else:
77
- print(f"File '{file_name}' does not exist in the directory '{directory_path}'.")
78
- except OSError as error:
79
- print(f"Error deleting file '{file_name}': {error}")
80
-
81
- def pull_demo_data():
82
- repo_url = "neonwatty/meme_search"
83
- input_path, db_path, repo_input_names, repo_db_names = collect_repo_file_names()
84
- if not check_directory_exists("." + input_path):
85
- create_directory("." + input_path)
86
- for name in repo_input_names:
87
- file_path = input_path + "/" + name
88
- download_file_from_github(repo_url, file_path, "." + input_path)
89
- else:
90
- local_input_files = list_files_in_directory("." + input_path)
91
- input_files_to_pull = [item for item in repo_input_names if item not in local_input_files]
92
- input_files_to_delete = [item for item in local_input_files if item not in repo_input_names]
93
-
94
- for name in input_files_to_delete:
95
- delete_file("." + input_path, name)
96
- for name in input_files_to_pull:
97
- file_path = input_path + "/" + name
98
- download_file_from_github(repo_url, file_path, "." + input_path)
99
-
100
- if not check_directory_exists("." + db_path):
101
- create_directory("." + db_path)
102
- repo_url = "neonwatty/meme_search"
103
- for name in repo_db_names:
104
- file_path = db_path + "/" + name
105
- download_file_from_github(repo_url, file_path, "." + db_path)
106
- else:
107
- local_db_files = list_files_in_directory("." + db_path)
108
- db_files_to_pull = [item for item in repo_db_names if item not in local_db_files]
109
- db_files_to_delete = [item for item in local_db_files if item not in repo_db_names]
110
- for name in db_files_to_delete:
111
- delete_file("." + db_path, name)
112
- for name in db_files_to_pull:
113
- file_path = db_path + "/" + name
114
- download_file_from_github(repo_url, file_path, "." + db_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
meme_search/style.css DELETED
@@ -1,15 +0,0 @@
1
- body {
2
- color: #fff;
3
- background-color: #4F8BF9;
4
- }
5
-
6
- .stButton>button {
7
- color: #4F8BF9;
8
- border-radius: 50%;
9
- height: 3em;
10
- width: 3em;
11
- }
12
-
13
- .stTextInput>div>div>input {
14
- color: #4F8BF9;
15
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
meme_search/utilities/__init__.py DELETED
@@ -1,10 +0,0 @@
1
- import os
2
- from sentence_transformers import SentenceTransformer
3
-
4
- model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
5
- utilities_base_dir = os.path.dirname(os.path.abspath(__file__))
6
- meme_search_dir = os.path.dirname(utilities_base_dir)
7
- meme_search_root_dir = os.path.dirname(meme_search_dir)
8
-
9
- vector_db_path = meme_search_root_dir + "/data/dbs/memes.faiss"
10
- sqlite_db_path = meme_search_root_dir + "/data/dbs/memes.db"
 
 
 
 
 
 
 
 
 
 
 
meme_search/utilities/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (538 Bytes)
 
meme_search/utilities/__pycache__/query.cpython-310.pyc DELETED
Binary file (2.66 kB)
 
meme_search/utilities/chunks.py DELETED
@@ -1,68 +0,0 @@
1
- import re
2
-
3
-
4
- def clean_word(text: str) -> str:
5
- # clean input text - keeping only lower case letters, numbers, punctuation, and single quote symbols
6
- return re.sub(" +", " ", re.compile("[^a-z0-9,.!?']").sub(" ", text.lower().strip()))
7
-
8
-
9
- def chunk_text(text: str) -> list:
10
- # split and clean input text
11
- text_split = clean_word(text).split(" ")
12
- text_split = [v for v in text_split if len(v) > 0]
13
-
14
- # use two pointers to create chunks
15
- chunk_size = 4
16
- overlap_size = 2
17
-
18
- # create next chunk by moving right pointer until chunk_size is reached or line_number changes by more than 1 or end of word_sequence is reached
19
- left_pointer = 0
20
- right_pointer = chunk_size - 1
21
- chunks = []
22
-
23
- if right_pointer >= len(text_split):
24
- chunks = [" ".join(text_split)]
25
- else:
26
- while right_pointer < len(text_split):
27
- # check if chunk_size has been reached
28
- # create chunk
29
- chunk = text_split[left_pointer : right_pointer + 1]
30
-
31
- # move left pointer
32
- left_pointer += chunk_size - overlap_size
33
-
34
- # move right pointer
35
- right_pointer += chunk_size - overlap_size
36
-
37
- # store chunk
38
- chunks.append(" ".join(chunk))
39
-
40
- # check if there is final chunk
41
- if len(text_split[left_pointer:]) > 0:
42
- last_chunk = text_split[left_pointer:]
43
- chunks.append(" ".join(last_chunk))
44
-
45
- # insert the full text
46
- if len(chunks) > 1:
47
- chunks.insert(0, text.lower())
48
- return chunks
49
-
50
-
51
- # loop over each meme's moondream based text descriptor and create a short dict containing its full and chunked text
52
- def create_all_img_chunks(img_paths: list, answers: list) -> list:
53
- try:
54
- print("STARTING: create_all_img_chunks")
55
- img_chunks = []
56
- for ind, img_path in enumerate(img_paths):
57
- moondream_meme_text = answers[ind]
58
- moondream_chunks = chunk_text(moondream_meme_text)
59
- for chunk in moondream_chunks:
60
- entry = {}
61
- entry["img_path"] = img_path
62
- entry["chunk"] = chunk
63
- img_chunks.append(entry)
64
- print("SUCCESS: create_all_img_chunks ran successfully")
65
- return img_chunks
66
- except Exception as e:
67
- print(f"FAILURE: create_all_img_chunks failed with exception {e}")
68
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
meme_search/utilities/create.py DELETED
@@ -1,75 +0,0 @@
1
- import sqlite3
2
- import faiss
3
- from meme_search.utilities import model
4
- from meme_search.utilities.imgs import collect_img_paths
5
- from meme_search.utilities.text_extraction import extract_text_from_imgs
6
- from meme_search.utilities.chunks import create_all_img_chunks
7
- from meme_search.utilities import vector_db_path, sqlite_db_path
8
-
9
-
10
- def create_chunk_db(img_chunks: list, db_filepath: str) -> None:
11
- # Create a lookup table for chunks
12
- conn = sqlite3.connect(db_filepath)
13
- cursor = conn.cursor()
14
-
15
- # Create the table - delete old table if it exists
16
- cursor.execute("DROP TABLE IF EXISTS chunks_reverse_lookup")
17
-
18
- # Create the table - alias rowid as chunk_index
19
- cursor.execute("""
20
- CREATE TABLE IF NOT EXISTS chunks_reverse_lookup (
21
- chunk_index INTEGER PRIMARY KEY,
22
- img_path TEXT,
23
- chunk TEXT
24
- );
25
- """)
26
-
27
- # Insert data into the table
28
- for chunk_index, entry in enumerate(img_chunks):
29
- img_path = entry["img_path"]
30
- chunk = entry["chunk"]
31
- cursor.execute(
32
- "INSERT INTO chunks_reverse_lookup (chunk_index, img_path, chunk) VALUES (?, ?, ?)",
33
- (chunk_index, img_path, chunk),
34
- )
35
-
36
- conn.commit()
37
- conn.close()
38
-
39
-
40
- def create_vector_db(chunks: list, db_file_path: str) -> None:
41
- # embed inputs
42
- embeddings = model.encode(chunks)
43
-
44
- # dump all_embeddings to faiss index
45
- index = faiss.IndexFlatL2(embeddings.shape[1])
46
- index.add(embeddings)
47
-
48
- # write index to disk
49
- faiss.write_index(index, db_file_path)
50
-
51
-
52
- def complete_create_dbs(img_chunks: list, vector_db_path: str, sqlite_db_path: str) -> None:
53
- try:
54
- print("STARTING: complete_create_dbs")
55
-
56
- # create db for img_chunks
57
- create_chunk_db(img_chunks, sqlite_db_path)
58
-
59
- # create vector embedding db for chunks
60
- chunks = [v["chunk"] for v in img_chunks]
61
- create_vector_db(chunks, vector_db_path)
62
- print("SUCCESS: complete_create_dbs succeeded")
63
- except Exception as e:
64
- print(f"FAILURE: complete_create_dbs failed with exception {e}")
65
-
66
-
67
- def process():
68
- all_img_paths = collect_img_paths()
69
- moondream_answers = extract_text_from_imgs(all_img_paths)
70
- img_chunks = create_all_img_chunks(all_img_paths, moondream_answers)
71
- complete_create_dbs(img_chunks, vector_db_path, sqlite_db_path)
72
-
73
-
74
- if __name__ == "__main__":
75
- process()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
meme_search/utilities/imgs.py DELETED
@@ -1,19 +0,0 @@
1
- import os
2
- from meme_search.utilities import meme_search_root_dir
3
-
4
- allowable_extensions = ["jpg", "jpeg", "png"]
5
-
6
-
7
- def collect_img_paths() -> list:
8
- try:
9
- img_dir = meme_search_root_dir + "/data/input"
10
- print("STARTING: collect_img_paths")
11
-
12
- all_img_paths = [os.path.join(img_dir, name) for name in os.listdir(img_dir) if name.split(".")[-1] in allowable_extensions]
13
- all_img_paths = sorted(all_img_paths)
14
-
15
- print(f"SUCCESS: collect_img_paths ran successfully - image paths loaded from '{img_dir}'")
16
- return all_img_paths
17
- except Exception as e:
18
- print(f"FAILURE: collect_img_paths failed with img_dir {img_dir} with exception {e}")
19
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
meme_search/utilities/query.py DELETED
@@ -1,76 +0,0 @@
1
- import faiss
2
- import sqlite3
3
- import numpy as np
4
- from typing import Tuple
5
- import argparse
6
- from meme_search.utilities import model
7
- from meme_search.utilities import vector_db_path, sqlite_db_path
8
-
9
-
10
- def query_vector_db(query: str, db_file_path: str, k: int = 10) -> Tuple[list, list]:
11
- # connect to db
12
- faiss_index = faiss.read_index(db_file_path)
13
-
14
- # test
15
- encoded_query = np.expand_dims(model.encode(query), axis=0)
16
-
17
- # query db
18
- distances, indices = faiss_index.search(encoded_query, k)
19
- distances = distances.tolist()[0]
20
- indices = indices.tolist()[0]
21
- return distances, indices
22
-
23
-
24
- def query_sqlite_db(indices: list, db_filepath: str) -> list:
25
- conn = sqlite3.connect(db_filepath)
26
- cursor = conn.cursor()
27
- query = f"SELECT * FROM chunks_reverse_lookup WHERE chunk_index IN {tuple(indices)}"
28
- cursor.execute(query)
29
- rows = cursor.fetchall()
30
- rows = [{"index": row[0], "img_path": row[1], "chunk": row[2]} for row in rows]
31
- rows = sorted(rows, key=lambda x: indices.index(x["index"])) # re-sort rows according to input indices
32
- for row in rows:
33
- query = f"SELECT * FROM chunks_reverse_lookup WHERE chunk_index=(SELECT MIN(chunk_index) FROM chunks_reverse_lookup WHERE img_path='{row['img_path']}')"
34
- cursor.execute(query)
35
- full_description_row = cursor.fetchall()
36
- row["full_description"] = full_description_row[0][2]
37
- conn.close()
38
- return rows
39
-
40
-
41
- def complete_query(query: str, vector_db_path: str, sqlite_db_path: str, k: int = 10) -> list:
42
- try:
43
- print("STARTING: complete_query")
44
-
45
- # query vector_db, first converting input query to embedding
46
- distances, indices = query_vector_db(query, vector_db_path, k=k)
47
-
48
- # use indices to query sqlite db containing chunk data
49
- img_chunks = query_sqlite_db(indices, sqlite_db_path) # bump up indices by 1 since sqlite row index starts at 1 not 0
50
-
51
- # map indices back to correct image in img_chunks
52
- imgs_seen = []
53
- unique_img_entries = []
54
- for ind, entry in enumerate(img_chunks):
55
- if entry["img_path"] in imgs_seen:
56
- continue
57
- else:
58
- entry["distance"] = round(distances[ind], 2)
59
- unique_img_entries.append(entry)
60
- imgs_seen.append(entry["img_path"])
61
- print("SUCCESS: complete_query succeeded")
62
- return unique_img_entries
63
- except Exception as e:
64
- print(f"FAILURE: complete_query failed with exception {e}")
65
- raise e
66
-
67
-
68
- if __name__ == "__main__":
69
- parser = argparse.ArgumentParser()
70
- parser.add_argument("--query", dest="query", type=str, help="Add query")
71
- args = parser.parse_args()
72
- query = args.query
73
-
74
- print(query)
75
- results = complete_query(query, vector_db_path, sqlite_db_path)
76
- print(results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
meme_search/utilities/text_extraction.py DELETED
@@ -1,38 +0,0 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer
2
- from PIL import Image
3
- import transformers
4
-
5
- transformers.logging.set_verbosity_error()
6
-
7
-
8
- def prompt_moondream(img_path: str, prompt: str) -> str:
9
- # copied from moondream demo readme --> https://github.com/vikhyat/moondream/tree/main
10
- model_id = "vikhyatk/moondream2"
11
- revision = "2024-05-20"
12
- model = AutoModelForCausalLM.from_pretrained(
13
- model_id,
14
- trust_remote_code=True,
15
- revision=revision,
16
- )
17
- tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
18
- image = Image.open(img_path)
19
- enc_image = model.encode_image(image)
20
- moondream_response = model.answer_question(enc_image, prompt, tokenizer)
21
- return moondream_response
22
-
23
-
24
- def extract_text_from_imgs(img_paths: list) -> list:
25
- try:
26
- print("STARTING: extract_text_from_imgs")
27
- prompt = "Describe this image."
28
- answers = []
29
- for img_path in img_paths:
30
- print(f"INFO: prompting moondream for a description of image: '{img_path}'")
31
- answer = prompt_moondream(img_path, prompt)
32
- answers.append(answer)
33
- print("DONE!")
34
- print("SUCCESS: extract_text_from_imgs succeeded")
35
- return answers
36
- except Exception as e:
37
- print(f"FAILURE: extract_text_from_imgs failed with exception {e}")
38
- raise e