Spaces:
Sleeping
Sleeping
File size: 4,813 Bytes
be63200 db70198 375bd04 db70198 3373c54 db70198 1dc9fa7 db70198 7713f97 db70198 7713f97 3373c54 7713f97 3373c54 db70198 be63200 de20d93 db70198 259cbe8 9e53bcd db70198 a850fbe db70198 7713f97 db70198 7713f97 db70198 7713f97 db70198 7713f97 db70198 7713f97 3373c54 db70198 7713f97 db70198 7713f97 db70198 7713f97 db70198 7713f97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os
import tempfile
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import DocArrayInMemorySearch
from streamlit_extras.add_vertical_space import add_vertical_space
# TODO: refactor
# TODO: extract class
# TODO: modularize
# TODO: hide side bar
# TODO: make the page attractive

# --- Page chrome: browser tab title/icon, logo, heading and tagline. --------
st.set_page_config(
    page_title=":books: InkChatGPT: Chat with Documents",
    page_icon="π",
)
st.image("./assets/icon.jpg", width=150)
st.header(":gray[:books: InkChatGPT]", divider="blue")
st.write("**Chat** with Documents")

# Chat message history for the conversation. The script below reads it,
# resets it, and hands it to the chain's ConversationBufferMemory so that
# follow-up questions keep their context.
msgs = StreamlitChatMessageHistory()
@st.cache_resource(ttl="1h")
def configure_retriever(uploaded_files):
    """Build an MMR retriever over the text of the uploaded PDF files.

    Each upload is written to a scratch directory so ``PyPDFLoader`` can read
    it from a real path. The pages are split into overlapping chunks, embedded
    with the ``all-MiniLM-L6-v2`` HuggingFace model and indexed in an
    in-memory vector store. Streamlit caches the result for one hour.

    Args:
        uploaded_files: The objects returned by ``st.file_uploader``. Each one
            must expose ``.name`` and ``.getvalue()``.

    Returns:
        A retriever using maximal-marginal-relevance search
        (``k=2`` results drawn from ``fetch_k=4`` candidates).
    """
    docs = []
    # The context manager removes the scratch directory deterministically.
    # Relying on garbage collection would leave it behind and emit a
    # ResourceWarning. Deleting it here is safe because loader.load() returns
    # a fully materialized list of pages before the directory goes away.
    with tempfile.TemporaryDirectory() as temp_dir:
        for file in uploaded_files:
            # The file name comes from the client. basename() keeps it from
            # pointing outside temp_dir (e.g. "../x.pdf").
            safe_name = os.path.basename(file.name)
            temp_filepath = os.path.join(temp_dir, safe_name)
            with open(temp_filepath, "wb") as f:
                f.write(file.getvalue())
            loader = PyPDFLoader(temp_filepath)
            docs.extend(loader.load())

    # Split documents into overlapping chunks for embedding.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    # Create embeddings and store them in the in-memory vector DB.
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = DocArrayInMemorySearch.from_documents(splits, embeddings)

    # Define the retriever.
    retriever = vectordb.as_retriever(
        search_type="mmr", search_kwargs={"k": 2, "fetch_k": 4}
    )
    return retriever
class StreamHandler(BaseCallbackHandler):
    """Append streamed LLM tokens to a Streamlit element and re-render it.

    Tokens belonging to an LLM run flagged in ``on_llm_start`` are skipped.
    That is a workaround to avoid displaying the rephrased question.
    """

    def __init__(
        self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""
    ):
        self.container = container
        self.text = initial_text
        # run_id of the LLM call whose tokens must not be rendered, if any.
        self.run_id_ignore_token = None

    def on_llm_start(self, serialized: dict, prompts: list, **kwargs):
        # Workaround: a call whose first prompt starts with "Human" is the one
        # producing the rephrased question. Remember its run_id so that
        # on_llm_new_token can drop its tokens.
        first_prompt = prompts[0]
        if not first_prompt.startswith("Human"):
            return
        self.run_id_ignore_token = kwargs.get("run_id")

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # A missing run_id defaults to False, so it never matches an unset
        # (None) ignore token.
        from_ignored_run = self.run_id_ignore_token == kwargs.get("run_id", False)
        if from_ignored_run:
            return
        self.text = self.text + token
        self.container.markdown(self.text)
class PrintRetrievalHandler(BaseCallbackHandler):
    """Show a temporary "Thinking..." status while documents are retrieved.

    The status element is created inside ``container`` at construction time.
    It reports the query when retrieval starts. When retrieval ends,
    ``container`` is emptied, which clears the status. The retrieved
    documents themselves are not displayed.
    """

    def __init__(self, container):
        self.container = container
        self.status = container.status("**Thinking...**")

    def on_retriever_start(self, serialized: dict, query: str, **kwargs):
        notice = f"**Checking document for query:** `{query}`. Please wait..."
        self.status.write(notice)

    def on_retriever_end(self, documents, **kwargs):
        # Clear the placeholder so the status does not linger above the answer.
        self.container.empty()
# --- Sidebar: document upload ----------------------------------------------
with st.sidebar.expander("Documents"):
    st.subheader("Files")
    uploaded_files = st.file_uploader(
        label="Select PDF files", type=["pdf"], accept_multiple_files=True
    )

# --- Sidebar: credentials and history reset --------------------------------
with st.sidebar.expander("Setup"):
    st.subheader("API Key")
    openai_api_key = st.text_input("OpenAI API Key", type="password")

    # Evaluate the button before the emptiness check so the widget is rendered
    # on every run. Writing `is_empty or st.button(...)` short-circuits and
    # skips rendering it whenever the history is empty.
    clear_requested = st.button("Clear message history")
    if clear_requested or len(msgs.messages) == 0:
        msgs.clear()
        msgs.add_ai_message("How can I help you?")

# Halt the script run until every prerequisite is available.
if not openai_api_key:
    st.info("Please add your OpenAI API key in the sidebar to continue.")
    st.stop()

# Without uploaded files there is no retriever and hence no chain. Stopping
# here keeps the chat handler below from referencing an undefined `chain`.
if not uploaded_files:
    st.info("Please upload PDF documents in the sidebar to continue.")
    st.stop()

retriever = configure_retriever(uploaded_files)

# The chain's memory is backed by `msgs`, so past turns feed follow-up
# questions and each new exchange is appended to the visible history.
memory = ConversationBufferMemory(
    memory_key="chat_history", chat_memory=msgs, return_messages=True
)

# Setup LLM and QA chain
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    openai_api_key=openai_api_key,
    temperature=0,
    streaming=True,
)
chain = ConversationalRetrievalChain.from_llm(
    llm, retriever=retriever, memory=memory, verbose=False
)

# Replay the stored conversation; map LangChain message types to chat roles.
avatars = {"human": "user", "ai": "assistant"}
for msg in msgs.messages:
    st.chat_message(avatars[msg.type]).write(msg.content)

if user_query := st.chat_input(placeholder="Ask me anything!"):
    st.chat_message("user").write(user_query)

    with st.chat_message("assistant"):
        retrieval_handler = PrintRetrievalHandler(st.empty())
        stream_handler = StreamHandler(st.empty())
        # The answer is rendered token by token by stream_handler, and the
        # memory records the exchange, so the return value is not needed.
        chain.run(user_query, callbacks=[retrieval_handler, stream_handler])
|