import streamlit as st
from dotenv import load_dotenv
from langchain import hub
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain_experimental.text_splitter import SemanticChunker
# Current-style imports; the langchain_community paths for these two are deprecated.
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Load environment variables from .env (the OpenAI clients read OPENAI_API_KEY).
load_dotenv()
def get_docs_from_pdf(file):
    # Load a PDF and split it into one Document per page.
    loader = PyPDFLoader(file)
    docs = loader.load_and_split()
    return docs


def get_doc_chunks(docs):
    # Re-chunk the page documents along semantic boundaries.
    text_splitter = SemanticChunker(OpenAIEmbeddings(model="text-embedding-3-small"))
    chunks = text_splitter.split_documents(docs)
    return chunks
def get_vectorstore_from_docs(doc_chunks):
    # Embed the chunks and index them in an in-memory FAISS store.
    embedding = OpenAIEmbeddings(model="text-embedding-3-small")
    vectorstore = FAISS.from_documents(documents=doc_chunks, embedding=embedding)
    return vectorstore
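
# Hedged sketch (not in the original file): a one-time build step of this shape,
# chaining the helpers above and persisting with FAISS's save_local, is presumably
# how the prebuilt index loaded by create_db() below was produced. The function
# name and paths are placeholders.
def build_vectorstore(pdf_path, out_dir):
    docs = get_docs_from_pdf(pdf_path)
    chunks = get_doc_chunks(docs)
    vectorstore = get_vectorstore_from_docs(chunks)
    vectorstore.save_local(out_dir)  # writes index.faiss and index.pkl
    return vectorstore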
def get_conversation_chain(vectorstore):
    # Retrieval-augmented chain: retrieve chunks, fill the RAG prompt, ask GPT-4o.
    llm = ChatOpenAI(model="gpt-4o", temperature=0.5, max_tokens=2048)
    retriever = vectorstore.as_retriever()
    prompt = hub.pull("rlm/rag-prompt")
    # Chain
    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain
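
# Usage sketch (illustrative, not in the original file): the LCEL chain maps a
# plain question string to an answer string, e.g.
#   chain = get_conversation_chain(create_db("./DATA_bziiit/vectorstore_op"))
#   answer = chain.invoke("What is the ecological transition?")
# chain.stream(...) yields the same answer chunk by chunk, as used further below.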
def create_db(file):
    # The index is prebuilt, so the ingestion pipeline is skipped at runtime:
    # docs = get_docs_from_pdf(file)
    # doc_chunks = get_doc_chunks(docs)
    # vectorstore = get_vectorstore_from_docs(doc_chunks)
    # allow_dangerous_deserialization is required because load_local unpickles
    # index.pkl; only enable it for index files you built yourself.
    vectorstore = FAISS.load_local(file, OpenAIEmbeddings(model="text-embedding-3-small"),
                                   allow_dangerous_deserialization=True)
    return vectorstore
def get_response(chain, user_query, chat_history):
    # Fold the running chat history into the question string, then stream
    # the chain's answer token by token.
    template = """
    Chat history: {chat_history}
    User question: {user_question}
    """
    question = ChatPromptTemplate.from_template(template)
    question = question.format(chat_history=chat_history, user_question=user_query)
    return chain.stream(question)
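
# Sketch of how the returned generator could be consumed standalone (hypothetical):
#   for token in get_response(chain, "What is the ecological transition?", []):
#       print(token, end="")
# In the app, st.write_stream() drains it inside the chat UI instead.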
def vote(item):
    # Vestigial feedback widget (its call sites are commented out below);
    # note the typed reason is currently discarded on rerun.
    st.write(f"Why is {item} your favorite?")
    reason = st.text_input("Because...")
    if st.button("Submit"):
        st.rerun()
def display_chat_te():
    # app config
    st.title("Chatbot")
    # session state
    if "chat_history_te" not in st.session_state:
        st.session_state.chat_history_te = [
            AIMessage(content="Hi, ask me your questions about the ecological transition."),
        ]
    if "chain" not in st.session_state:
        db = create_db("./DATA_bziiit/vectorstore_op")
        chain = get_conversation_chain(db)
        st.session_state.chain = chain
    # conversation
    for message in st.session_state.chat_history_te:
        if isinstance(message, AIMessage):
            with st.chat_message("AI"):
                st.write(message.content)
        elif isinstance(message, HumanMessage):
            with st.chat_message("Human"):
                st.write(message.content)
style = """ | |
<style> | |
.css-ocqkz7 { | |
position: fixed; | |
bottom: 0; | |
width: 50%; | |
justify-content: center; | |
align-items: end; | |
margin-bottom: 0.5rem; | |
} | |
</style> | |
""" | |
    # Styling injection and the vote buttons are disabled for now:
    # st.markdown(style, unsafe_allow_html=True)
    # user input
    # col1, col2 = st.columns([1, 8])
    # if col1.button("chatbot"):
    #     vote("chatbot")
    # with col2:
    user_query = st.chat_input(placeholder="What is the ecological transition?")
    if user_query:
        st.session_state.chat_history_te.append(HumanMessage(content=user_query))
        with st.chat_message("Human"):
            st.markdown(user_query)
        with st.chat_message("AI"):
            response = st.write_stream(get_response(st.session_state.chain, user_query, st.session_state.chat_history_te))
        st.session_state.chat_history_te.append(AIMessage(content=response))
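

# Minimal entry-point sketch (an assumption: the original Space may instead call
# display_chat_te() from another page/module). Under `streamlit run app.py` the
# script executes with __name__ == "__main__", so this guard works as expected.
if __name__ == "__main__":
    display_chat_te()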