Vasanth Sarathy committed
Commit 73f6aba · Parent(s): e3b439a

Working streamlit app

Files changed:
- .gitattributes +2 -0
- app.py +146 -5
- faiss_document_store.db +3 -0
- faiss_document_store.faiss +3 -0
- faiss_document_store.json +1 -0
- pipelines.py +118 -0
- requirements.txt +2 -1
- utils.py +24 -0
.gitattributes
ADDED
@@ -0,0 +1,2 @@
+faiss_document_store.db filter=lfs diff=lfs merge=lfs -text
+faiss_document_store.faiss filter=lfs diff=lfs merge=lfs -text
app.py
CHANGED
@@ -1,8 +1,149 @@
-import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!"
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
-demo.launch()
+import streamlit as st
+import os
+from pipelines import get_pipeline
+import logging
+from json import JSONDecodeError
+from utils import find_substring_indices
+from annotated_text import annotation
+from markdown import markdown
+
+# Sliders
+DEFAULT_DOCS_FROM_RETRIEVER = int(os.getenv("DEFAULT_DOCS_FROM_RETRIEVER", "3"))
+
+
+def set_state_if_absent(key, value):
+    if key not in st.session_state:
+        st.session_state[key] = value
+
+
+def query(concept, filters={}, top_k_retriever=5):
+    params = {"Retriever": {"top_k": top_k_retriever}}
+    pipe = get_pipeline("data/narratives/processed")
+
+    prediction = pipe.run(query=concept, params=params)
+
+    # Format results
+    results = []
+    spans = prediction['results']
+    for idx, span in enumerate(spans):
+        context = prediction["documents"][idx].to_dict()['content']
+        span_indices = find_substring_indices(context, span)
+
+        if span_indices:
+            result = {"context": context,
+                      "span": span,
+                      "span_start": span_indices[0],
+                      "span_end": span_indices[1]}
+            results.append(result)
+    return results
+
+
+def main():
+
+    st.set_page_config(page_title="Anchor")
+
+    # Persistent state
+    set_state_if_absent("question", "husband's permission")
+    set_state_if_absent("results", None)
+    set_state_if_absent("raw_json", None)
+    set_state_if_absent("random_question_requested", False)
+
+    # Small callback to reset the interface in case the text of the question changes
+    def reset_results(*args):
+        st.session_state.answer = None
+        st.session_state.results = None
+        st.session_state.raw_json = None
+
+    # Title
+    st.write("""
+    # ⚓ ANCHOR
+
+    #### Grounding Abstract Concepts in Text
+    """)
+
+    # Sidebar
+    st.sidebar.header("Options")
+
+    top_k_retriever = st.sidebar.slider(
+        "Max. number of documents from retriever",
+        min_value=1,
+        max_value=20,
+        value=DEFAULT_DOCS_FROM_RETRIEVER,
+        step=1,
+        on_change=reset_results,
+    )
+
+    # Search bar
+    question = st.text_input(
+        value=st.session_state.question,
+        max_chars=100,
+        on_change=reset_results,
+        label="Concept",
+        label_visibility="visible",
+    )
+    col1, col2 = st.columns(2)
+    col1.markdown("<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
+    col2.markdown("<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
+
+    # Run button
+    run_pressed = col1.button("Run")
+
+    run_query = (run_pressed or question != st.session_state.question)
+
+    # Get results for query
+    if run_query and question:
+        reset_results()
+        st.session_state.question = question
+
+        with st.spinner("🧠 Performing neural search on documents... \n "):
+            try:
+                st.session_state.results = query(
+                    question, top_k_retriever=top_k_retriever
+                )
+            except JSONDecodeError as je:
+                st.error("👓 An error occurred reading the results. Is the document store working?")
+                return
+            except Exception as e:
+                logging.exception(e)
+                if "The server is busy processing requests" in str(e) or "503" in str(e):
+                    st.error("🧑‍🌾 All our workers are busy! Try again later.")
+                else:
+                    st.error("🐞 An error occurred during the request.")
+                return
+
+    if st.session_state.results:
+
+        st.write("## Results:")
+
+        for count, result in enumerate(st.session_state.results):
+            if result['span']:
+                st.write(
+                    markdown(result['context'][:result['span_start']] +
+                             str(annotation(result['span'], "anchor", "#fad6a5")) +
+                             result['context'][result['span_end']+1:]),
+                    unsafe_allow_html=True
+                )
+            else:
+                st.info(
+                    "🤔 Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
+                )
+
+
+main()
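Note: the span-highlighting step in main() splices an annotated span back into its surrounding context before rendering it as HTML. A minimal, self-contained sketch of just that step, using a hypothetical context string and indices (the real indices come from utils.find_substring_indices, which returns inclusive start/end positions):

    from annotated_text import annotation
    from markdown import markdown

    # Hypothetical context and span indices (start and end are inclusive).
    context = "He asked for his wife's opinion before selling the farm."
    span_start, span_end = 17, 30
    span = context[span_start:span_end + 1]  # "wife's opinion"

    # Wrap the span in a colored annotation tag and re-render as HTML,
    # mirroring the composition app.py passes to st.write.
    html = markdown(
        context[:span_start]
        + str(annotation(span, "anchor", "#fad6a5"))
        + context[span_end + 1:]
    )
    print(html)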
faiss_document_store.db
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:815c57a332029c405a6bda58876f80db6a31a60fbbd35b8a4ea8595e9fcd398a
+size 1839104
faiss_document_store.faiss
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e55b94d536a2d9998f7cd9a24c9dc13bd86072699ae2045cc574a7ff7a5b0af
+size 5050413
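Note: the two files above are Git LFS pointers rather than the binaries themselves; the repo tracks only a version line, a sha256 content hash, and the byte size, while LFS storage holds the actual index. A small sketch of reading those fields in Python, assuming the pointer file sits in the working tree:

    # Parse a Git LFS pointer file into its key/value fields.
    def read_lfs_pointer(path):
        fields = {}
        with open(path) as f:
            for line in f:
                key, _, value = line.strip().partition(" ")
                fields[key] = value
        return fields

    # -> {'version': 'https://git-lfs.github.com/spec/v1',
    #     'oid': 'sha256:5e55...', 'size': '5050413'}
    print(read_lfs_pointer("faiss_document_store.faiss"))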
faiss_document_store.json
ADDED
@@ -0,0 +1 @@
+{"faiss_index_factory_str": "Flat"}
pipelines.py
ADDED
@@ -0,0 +1,118 @@
+import os
+from haystack.document_stores import InMemoryDocumentStore
+from haystack.pipelines.standard_pipelines import TextIndexingPipeline
+from haystack.nodes import BM25Retriever
+from haystack.nodes import FARMReader
+from haystack.pipelines import ExtractiveQAPipeline
+from haystack.nodes.other import Shaper
+from haystack.nodes import PromptNode, PromptTemplate
+from haystack.pipelines import Pipeline
+
+from haystack.document_stores import FAISSDocumentStore
+from haystack.nodes import EmbeddingRetriever
+
+from haystack.utils import convert_files_to_docs
+
+
+# Set logging level to INFO
+import logging
+logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
+logging.getLogger("haystack").setLevel(logging.INFO)
+
+api_key = "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"
+#faiss_index_path = "faiss_document_store.db"
+
+
+# Shaper helps expand the `query` variable into a list of identical queries (length of documents)
+# and store the list of queries in the `questions` variable
+# (the variable used in the question answering template)
+def get_pipeline(doc_dir):
+
+    # Registering a new prompt template called "concept-exemplar"
+    prompt_template = \
+        """
+        Identify a non-overlapping span of text in the given
+        context that resonates with the given concept. By resonate we mean that the
+        meaning of the concept is captured in the span. The span exemplifies what
+        the concept means. The identified span MUST be present verbatim in the context. \n\nConcept: 'family support'\nContext: Abubakar and his
+        wife are expecting his mother to help his wife with the new born. It is
+        evening and she has not arrived yet. Then he decided to call the neighbor's
+        wife to Come and help the baby and the new born woman to boil hot water and
+        baths the baby and made the baby to sleep before the mother come
+        back.\nSpan: are expecting his mother to help\n\n\nConcept: $concepts
+        \nContext: $documents\nSpan:
+        """
+    template = PromptTemplate(name="concept-exemplar", prompt_text=prompt_template)
+    prompt_node = PromptNode("text-davinci-003", api_key=api_key)
+    prompt_node.add_prompt_template(template)
+
+    # Set concept-exemplar as my default
+    exemplifier = prompt_node.set_default_prompt_template("concept-exemplar")
+
+    shaper = Shaper(func="value_to_list", inputs={"value": "query", "target_list": "documents"}, outputs=["concepts"])
+
+    if os.path.exists("faiss_document_store.db"):
+        print("FAISS document store already exists")
+        document_store = FAISSDocumentStore(
+            faiss_index_path="faiss_document_store.faiss",
+            faiss_config_path="faiss_document_store.json")
+
+        retriever = EmbeddingRetriever(
+            document_store=document_store,
+            embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1")
+    else:
+        print("New Document Store created")
+        document_store = FAISSDocumentStore(faiss_index_factory_str='Flat')
+
+        docs = convert_files_to_docs(dir_path=doc_dir)
+        document_store.write_documents(docs)
+
+        # 4. Set up retriever
+        # bm25_retriever = BM25Retriever(document_store=document_store)
+        retriever = EmbeddingRetriever(
+            document_store=document_store,
+            embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1"
+        )
+        # Important:
+        # Now that we initialized the Retriever, we need to call update_embeddings() to iterate over all
+        # previously indexed documents and update their embedding representation.
+        # While this can be a time consuming operation (depending on the corpus size), it only needs to be done once.
+        # At query time, we only need to embed the query and compare it to the existing document embeddings, which is very fast.
+        document_store.update_embeddings(retriever)
+
+        document_store.save("faiss_document_store.faiss")
+
+    # # 1. Setup document store
+    # if os.path.exists("faiss_document_store.json"):
+    #     print("Path exists")
+    #     document_store = FAISSDocumentStore(
+    #         faiss_index_path="faiss_document_store.faiss",
+    #         faiss_config_path="faiss_document_store.json")
+    # else:
+    #     print("path does not exist")
+    #     document_store = FAISSDocumentStore(faiss_index_factory_str="Flat", return_embedding=True)
+    #     document_store.save("faiss_document_store.faiss")
+
+    #document_store.save("faiss_index")
+    # document_store = FAISSDocumentStore(faiss_index_factory_str="Flat")
+    #document_store = InMemoryDocumentStore(use_bm25=True)
+
+    # 2. Put files in habitus folder into a list for indexing
+    #files_to_index = [doc_dir + "/" + f for f in os.listdir(doc_dir)]
+
+    # 3. Set up text indexing pipeline and index all files in folder
+    #indexing_pipeline = TextIndexingPipeline(document_store)
+    #indexing_pipeline.run_batch(file_paths=files_to_index)
+
+    # New combined pipeline
+    pipe = Pipeline()
+    pipe.add_node(component=retriever, name="Retriever", inputs=["Query"])
+    pipe.add_node(component=shaper, name="shaper", inputs=["Retriever"])
+    pipe.add_node(component=exemplifier, name="exemplifier", inputs=['shaper'])
+
+    return pipe
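Note: the Shaper's job is easier to see in isolation. It repeats the single query once per retrieved document, so the $concepts variable lines up one-to-one with $documents in the prompt template. A hedged plain-Python sketch of that value_to_list transformation, with made-up document texts:

    query = "family support"
    documents = ["first narrative text", "second narrative text"]

    # value_to_list copies `query` len(documents) times; the pipeline
    # exposes the copies to the prompt template under the name "concepts".
    concepts = [query] * len(documents)
    assert len(concepts) == len(documents)

End to end, the pipeline is then driven the way app.py does it: get_pipeline(...) followed by pipe.run(query=..., params={"Retriever": {"top_k": k}}), with the generated spans returned under prediction["results"].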
requirements.txt
CHANGED
@@ -1,2 +1,3 @@
-gradio
 farm-haystack[all-gpu]
+streamlit
+st-annotated-text
utils.py
ADDED
@@ -0,0 +1,24 @@
+import ast
+import re
+
+def get_span_indices(document, span):
+    print(f"\nSpan: {span}")
+    print(f"Document: {document}")
+    res = re.search(span, document, re.IGNORECASE)
+    print(f"Res: {res}")
+    print(f"Find: {document.find(span)}")
+    if res:
+        return res.span()
+
+def find_substring_indices(string, substring):
+    substring = remove_trailing_periods(substring)
+    start_index = string.lower().find(substring.lower())
+    if start_index == -1:
+        return None
+    end_index = start_index + len(substring) - 1
+    return (start_index, end_index)
+
+def remove_trailing_periods(string):
+    while string.endswith('.'):
+        string = string[:-1]
+    return string
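Note: a quick usage sketch of find_substring_indices with a made-up context string, showing the trailing-period stripping, the case-insensitive match, and the inclusive end index:

    from utils import find_substring_indices

    context = "She said she would need her husband's permission to travel."

    # The trailing period on the model's span is stripped before matching,
    # and the lookup is case-insensitive.
    print(find_substring_indices(context, "Husband's permission."))  # (28, 47)

    # A span that never occurs in the context returns None.
    print(find_substring_indices(context, "mother's permission"))    # None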