Vasanth Sarathy committed on
Commit
73f6aba
·
1 Parent(s): e3b439a

Working streamlit app

Browse files
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ faiss_document_store.db filter=lfs diff=lfs merge=lfs -text
2
+ faiss_document_store.faiss filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,8 +1,149 @@
1
- import gradio as gr
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- demo.launch()
 
1
+ import streamlit as st
2
+ import os
3
+ from pipelines import get_pipeline
4
+ import logging
5
+ from json import JSONDecodeError
6
+ from utils import find_substring_indices
7
+ from annotated_text import annotation
8
+ from markdown import markdown
9
 
 
 
10
 
11
+ # Sliders
12
+ DEFAULT_DOCS_FROM_RETRIEVER = int(os.getenv("DEFAULT_DOCS_FROM_RETRIEVER", "3"))
13
+
14
+
15
def set_state_if_absent(key, value):
    """Seed a Streamlit session-state entry only when it is not set yet."""
    if key in st.session_state:
        return
    st.session_state[key] = value
18
+
19
def query(concept, filters=None, top_k_retriever=5):
    """Run the concept-grounding pipeline and collect locatable spans.

    Parameters
    ----------
    concept : str
        Abstract concept to ground in the indexed narratives.
    filters : dict | None
        Reserved for document-store filters; currently unused by the
        pipeline (kept for interface compatibility). Was a mutable
        default (``{}``) — changed to ``None`` to avoid shared state.
    top_k_retriever : int
        Number of documents the retriever returns.

    Returns
    -------
    list[dict]
        One entry per span found verbatim in its source document, with
        keys "context", "span", "span_start", "span_end".
    """
    pipe = get_pipeline("data/narratives/processed")

    prediction = pipe.run(
        query=concept,
        params={"Retriever": {"top_k": top_k_retriever}},
    )

    # Keep only spans that can actually be located in their source document;
    # the prompt asks for verbatim spans but the LLM may still drift.
    results = []
    for idx, span in enumerate(prediction["results"]):
        context = prediction["documents"][idx].to_dict()["content"]
        span_indices = find_substring_indices(context, span)
        if span_indices:
            results.append({
                "context": context,
                "span": span,
                "span_start": span_indices[0],
                "span_end": span_indices[1],
            })
    return results
43
+
44
+
45
+
46
+
47
def main():
    """Render the ANCHOR Streamlit UI and drive the concept-grounding query."""
    st.set_page_config(page_title="Anchor")

    # Persistent state across Streamlit reruns
    set_state_if_absent("question", "husband's permission")
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)
    set_state_if_absent("random_question_requested", False)

    # Small callback to reset the interface when the question/options change
    def reset_results(*args):
        st.session_state.answer = None
        st.session_state.results = None
        st.session_state.raw_json = None

    # Title
    st.write("""
    # ⚓ ANCHOR

    #### Grounding Abstract Concepts in Text
    """)

    # Sidebar
    st.sidebar.header("Options")
    top_k_retriever = st.sidebar.slider(
        "Max. number of documents from retriever",
        min_value=1,
        max_value=20,
        value=DEFAULT_DOCS_FROM_RETRIEVER,
        step=1,
        on_change=reset_results,
    )

    # Search bar
    question = st.text_input(
        value=st.session_state.question,
        max_chars=100,
        on_change=reset_results,
        label="Concept",
        label_visibility="visible",
    )
    col1, col2 = st.columns(2)
    col1.markdown("<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
    col2.markdown("<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)

    # Run button; also fire when the typed question differs from stored state
    run_pressed = col1.button("Run")
    run_query = run_pressed or question != st.session_state.question

    # Get results for query
    if run_query and question:
        reset_results()
        st.session_state.question = question

        with st.spinner(
            "🧠 &nbsp;&nbsp; Performing neural search on documents... \n "
        ):
            try:
                st.session_state.results = query(
                    question, top_k_retriever=top_k_retriever
                )
            except JSONDecodeError:
                st.error("👓 &nbsp;&nbsp; An error occurred reading the results. Is the document store working?")
                return
            except Exception as e:
                logging.exception(e)
                if "The server is busy processing requests" in str(e) or "503" in str(e):
                    st.error("🧑‍🌾 &nbsp;&nbsp; All our workers are busy! Try again later.")
                else:
                    st.error("🐞 &nbsp;&nbsp; An error occurred during the request.")
                return

    if st.session_state.results:
        st.write("## Results:")

        for result in st.session_state.results:
            if result['span']:
                # Re-render the context with the grounded span highlighted;
                # span_end is inclusive, hence the +1 on the tail slice.
                st.write(
                    markdown(result['context'][:result['span_start']] +
                             str(annotation(result['span'], "anchor", "#fad6a5")) +
                             result['context'][result['span_end']+1:]),
                    unsafe_allow_html=True
                )
            else:
                st.info(
                    "🤔 &nbsp;&nbsp; Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
                )
                # BUG FIX: result dicts built by query() never contain a
                # "relevance" key, so result["relevance"] raised KeyError here.
                st.write("**Relevance:** ", result.get("relevance"))


if __name__ == "__main__":
    main()
148
+
149
 
 
faiss_document_store.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:815c57a332029c405a6bda58876f80db6a31a60fbbd35b8a4ea8595e9fcd398a
3
+ size 1839104
faiss_document_store.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e55b94d536a2d9998f7cd9a24c9dc13bd86072699ae2045cc574a7ff7a5b0af
3
+ size 5050413
faiss_document_store.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"faiss_index_factory_str": "Flat"}
pipelines.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from haystack.document_stores import InMemoryDocumentStore
3
+ from haystack.pipelines.standard_pipelines import TextIndexingPipeline
4
+ from haystack.nodes import BM25Retriever
5
+ from haystack.nodes import FARMReader
6
+ from haystack.pipelines import ExtractiveQAPipeline
7
+ from haystack.nodes.other import Shaper
8
+ from haystack.nodes import PromptNode, PromptTemplate
9
+ from haystack.pipelines import Pipeline
10
+
11
+ from haystack.document_stores import FAISSDocumentStore
12
+ from haystack.nodes import EmbeddingRetriever
13
+
14
+ from haystack.utils import convert_files_to_docs
15
+
16
+
17
# Configure logging: warnings globally, INFO for haystack internals.
import logging
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(logging.INFO)

# SECURITY FIX: an OpenAI API key was hard-coded here and committed to the
# repository (it is now in public git history and MUST be revoked).
# Read the key from the environment instead of embedding it in source.
api_key = os.environ.get("OPENAI_API_KEY")
#faiss_index_path = "faiss_document_store.db"
24
+
25
+
26
+ # Shaper helps expand the `query` variable into a list of identical queries (length of documents)
27
+ # and store the list of queries in the `questions` variable
28
+ # (the variable used in the question answering template)
29
def get_pipeline(doc_dir):
    """Build the retrieval + prompting pipeline used by the app.

    Loads a persisted FAISS document store if one exists, otherwise
    indexes every file in *doc_dir*, embeds the documents, and persists
    the store. An embedding retriever feeds a Shaper and an OpenAI
    PromptNode that extracts a concept-exemplifying span per document.

    Parameters
    ----------
    doc_dir : str
        Directory of text files to index when no FAISS store exists yet.

    Returns
    -------
    Pipeline
        Query pipeline with nodes Retriever -> shaper -> exemplifier.
    """
    # Register a new prompt template called "concept-exemplar": asks the LLM
    # for a verbatim span of the context that exemplifies the given concept.
    prompt_template = \
    """
    Identify a non-overlapping spans of text in the given
    context that resonates with the given concept. By resonate we mean that the
    meaning of the concept is captured in the span. The span exemplifies what
    the concept means. The identified span MUST be present verbatim in the context. \n\nConcept: 'family support'\nContext: Abubakar and his
    wife are expecting his mother to help his wife with the new born. It is
    evening and she has not arrived yet. Then he decided to call the neighbor's
    wife to Come and help the baby and the new born woman to boil hot water and
    baths the baby and made the baby to sleep before the mother come
    back.\nSpan: are expecting his mother to help\n\n\nConcept: $concepts
    \nContext: $documents\nSpan:
    """
    template = PromptTemplate(name="concept-exemplar", prompt_text=prompt_template)
    prompt_node = PromptNode("text-davinci-003", api_key=api_key)
    prompt_node.add_prompt_template(template)

    # Make concept-exemplar this node's default template.
    exemplifier = prompt_node.set_default_prompt_template("concept-exemplar")

    # Shaper expands the single `query` into a list (one copy per retrieved
    # document) stored in `concepts`, the variable the template consumes.
    shaper = Shaper(func="value_to_list", inputs={"value": "query", "target_list": "documents"}, outputs=["concepts"])

    if os.path.exists("faiss_document_store.db"):
        # Reuse the persisted index; embeddings were computed when it was built.
        print("FAISS document store already exists")
        document_store = FAISSDocumentStore(
            faiss_index_path="faiss_document_store.faiss",
            faiss_config_path="faiss_document_store.json")

        retriever = EmbeddingRetriever(
            document_store=document_store,
            embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1")
    else:
        # First run: index every file in doc_dir, embed, and persist the store.
        print("New Document Store created")
        document_store = FAISSDocumentStore(faiss_index_factory_str='Flat')

        docs = convert_files_to_docs(dir_path=doc_dir)
        document_store.write_documents(docs)

        retriever = EmbeddingRetriever(
            document_store=document_store,
            embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1"
        )
        # One-time (per corpus) embedding pass over all indexed documents;
        # at query time only the query itself is embedded, which is fast.
        document_store.update_embeddings(retriever)

        document_store.save("faiss_document_store.faiss")

    # Combined query pipeline
    pipe = Pipeline()
    pipe.add_node(component=retriever, name="Retriever", inputs=["Query"])
    pipe.add_node(component=shaper, name="shaper", inputs=["Retriever"])
    pipe.add_node(component=exemplifier, name="exemplifier", inputs=['shaper'])

    return pipe
118
+
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
- gradio
2
  farm-haystack[all-gpu]
 
 
 
 
1
  farm-haystack[all-gpu]
2
+ streamlit
3
+ st-annotated-text
utils.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import re
3
+
4
def get_span_indices(document, span):
    """Return the (start, end) indices of the first case-insensitive match
    of *span* inside *document*, or None when it does not occur.

    The end index is exclusive (``re.Match.span`` convention).
    """
    # BUG FIX: *span* is free text from an LLM and may contain regex
    # metacharacters (e.g. '(', '$', '.'); it was previously passed to
    # re.search as a raw pattern, which could raise re.error or mis-match.
    # Escape it so it is matched literally. Debug prints removed.
    res = re.search(re.escape(span), document, re.IGNORECASE)
    if res:
        return res.span()
    return None
12
+
13
def find_substring_indices(string, substring):
    """Locate *substring* inside *string*, case-insensitively.

    Trailing periods are stripped from the needle before searching.
    Returns an inclusive ``(start, end)`` index pair, or None when the
    needle is absent.
    """
    needle = remove_trailing_periods(substring)
    start = string.lower().find(needle.lower())
    if start < 0:
        return None
    # end index is inclusive, hence the -1
    return (start, start + len(needle) - 1)
20
+
21
def remove_trailing_periods(string):
    """Return *string* with any run of trailing '.' characters removed."""
    # Equivalent to the one-at-a-time while-loop strip, in a single call.
    return string.rstrip('.')