Bryan Lincoln committed
Commit 009313d · 0 Parent(s)

feat: add demo code

Files changed (3)
  1. .gitignore +1 -0
  2. main.py +203 -0
  3. requirements.txt +7 -0
.gitignore ADDED
@@ -0,0 +1 @@
+.vscode
main.py ADDED
@@ -0,0 +1,203 @@
+import gradio as gr
+
+from langchain.chains import (
+    ConversationalRetrievalChain,
+    LLMChain,
+    MapReduceDocumentsChain,
+    ReduceDocumentsChain,
+    StuffDocumentsChain,
+)
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.memory import ConversationBufferMemory
+from langchain.prompts import PromptTemplate
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores import Chroma
+from langchain_community.chat_models import ChatOpenAI
+from langchain_community.document_loaders import WebBaseLoader
+
+
+def wait_for_summarization(url):
+    return [(None, f"Please wait while I summarize the contents of {url}...")]
+
+
+def load_page(url, api_key, history):
+    global docs, summary, llm
+    loader = WebBaseLoader(url)
+    docs = loader.load()
+    llm = ChatOpenAI(
+        model_name="gpt-3.5-turbo-1106", temperature=0, openai_api_key=api_key
+    )
+    map_template = """The following is a set of snippets from a web page:
+{docs}
+Based on this list of snippets, please identify the main themes.
+Helpful Answer:"""
+    map_prompt = PromptTemplate.from_template(map_template)
+    map_chain = LLMChain(llm=llm, prompt=map_prompt)
+
+    # Reduce
+
+    reduce_template = """The following is a set of summaries of a web page:
+{docs}
+Take these and distill them into a final, consolidated summary of the main themes.
+Helpful Answer:"""
+    reduce_prompt = PromptTemplate.from_template(reduce_template)
+    reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
+
+    # Takes a list of documents, combines them into a single string, and passes this to an LLMChain
+    combine_documents_chain = StuffDocumentsChain(
+        llm_chain=reduce_chain, document_variable_name="docs"
+    )
+
+    # Combines and iteratively reduces the mapped documents
+    reduce_documents_chain = ReduceDocumentsChain(
+        # This is the final chain that is called.
+        combine_documents_chain=combine_documents_chain,
+        # If documents exceed context for `StuffDocumentsChain`
+        collapse_documents_chain=combine_documents_chain,
+        # The maximum number of tokens to group documents into.
+        token_max=4000,
+    )
+    # Combining documents by mapping a chain over them, then combining results
+    map_reduce_chain = MapReduceDocumentsChain(
+        # Map chain
+        llm_chain=map_chain,
+        # Reduce chain
+        reduce_documents_chain=reduce_documents_chain,
+        # The variable name in the llm_chain to put the documents in
+        document_variable_name="docs",
+        # Return the results of the map steps in the output
+        return_intermediate_steps=False,
+    )
+
+    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+        chunk_size=1000, chunk_overlap=0
+    )
+    split_docs = text_splitter.split_documents(docs)
+
+    summary = map_reduce_chain.run(split_docs)
+    return history + [(None, summary)]
+
+
+def prepare_chat(api_key, history):
+    global docs, summary, llm, qa
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=128)
+    documents = text_splitter.split_documents(docs)
+    embeddings = OpenAIEmbeddings(openai_api_key=api_key)
+    vectorstore = Chroma.from_documents(documents, embeddings)
+    retriever = vectorstore.as_retriever(
+        search_type="similarity", search_kwargs={"k": 6}
+    )
+    qa_prompt_template = (
+        """As an AI assistant, you help in answering questions about the contents of a web page.
+The summary of the current web page is this:
+
+"""
+        + summary
+        + """
+
+Also, consider this additional context that may be relevant for the user's question:
+
+{context}
+
+Please answer the following question: {question}"""
+    )
+
+    qa_prompt = PromptTemplate(
+        template=qa_prompt_template, input_variables=["context", "question"]
+    )
+
+    memory = ConversationBufferMemory(
+        memory_key="chat_history", return_messages=True, output_key="answer"
+    )
+    qa = ConversationalRetrievalChain.from_llm(
+        llm=llm,
+        memory=memory,
+        retriever=retriever,
+        combine_docs_chain_kwargs={"prompt": qa_prompt},
+    )
+    return history + [(None, "You can now ask me specific questions about the page.")]
+
+
+def chatbot_function(message, history):
+    global qa
+    return "", history + [(message, qa.run(message))]
+
+
+def build_demo():
+    with gr.Blocks(theme=gr.themes.Default()) as demo:
+        with gr.Row() as config_row:
+            with gr.Column():
+                api_key_box = gr.Textbox(
+                    show_label=False,
+                    placeholder="OpenAI API Key",
+                    container=False,
+                    autofocus=True,
+                )
+                url_box = gr.Textbox(
+                    show_label=False,
+                    placeholder="URL",
+                    container=False,
+                )
+                load_btn = gr.Button(value="Load", variant="primary")
+        with gr.Row(visible=False) as chat_row:
+            with gr.Column():
+                with gr.Row():
+                    chatbot = gr.Chatbot(
+                        elem_id="chatbot",
+                        label="Web Chat",
+                        height=550,
+                    )
+                with gr.Row(visible=False) as inputs_row:
+                    with gr.Column(scale=8):
+                        text_box = gr.Textbox(
+                            show_label=False,
+                            placeholder="Enter text and press ENTER",
+                            autofocus=True,
+                            container=False,
+                        )
+                    with gr.Column(scale=1, min_width=50):
+                        submit_btn = gr.Button(
+                            value="Send",
+                            variant="primary",
+                        )
+
+        load_btn.click(
+            lambda: gr.update(visible=False),
+            outputs=[config_row],
+        ).then(
+            lambda: gr.update(visible=True),
+            outputs=[chat_row],
+        ).then(
+            wait_for_summarization,
+            inputs=[url_box],
+            outputs=[chatbot],
+        ).then(
+            load_page,
+            inputs=[url_box, api_key_box, chatbot],
+            outputs=[chatbot],
+        ).then(
+            prepare_chat,
+            inputs=[api_key_box, chatbot],
+            outputs=[chatbot],
+        ).then(
+            lambda: gr.update(visible=True),
+            outputs=[inputs_row],
+        )
+
+        text_box.submit(
+            chatbot_function,
+            [text_box, chatbot],
+            [text_box, chatbot],
+        )
+        submit_btn.click(
+            chatbot_function,
+            [text_box, chatbot],
+            [text_box, chatbot],
+        )
+
+    return demo
+
+
+if __name__ == "__main__":
+    demo = build_demo()
+    demo.launch()
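
For reference, a minimal sketch (not part of this commit) of how the same functions could be driven without the Gradio UI; it assumes main.py is importable and that the placeholder URL and API key are replaced with real values:

# Hypothetical driver; the URL and API key below are placeholders.
from main import load_page, prepare_chat, chatbot_function

history = load_page("https://example.com", "sk-...", [])  # fetch the page and append the map-reduce summary
history = prepare_chat("sk-...", history)  # embed the page chunks and build the conversational retrieval chain
_, history = chatbot_function("What is this page about?", history)  # answer a question over the indexed page
print(history[-1][1])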
requirements.txt ADDED
@@ -0,0 +1,7 @@
+langchain==0.1.0
+langchain-community==0.0.12
+langchain-core==0.1.10
+langsmith==0.0.80
+openai==1.7.2
+chromadb==0.4.22
+tiktoken==0.5.2
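
With the pinned dependencies installed (for example, pip install -r requirements.txt), running python main.py builds the interface with build_demo() and serves it locally through demo.launch().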