import os

import chainlit as cl
import torch
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from download_assets import download_assets

# Fetch the FAISS index and any other required assets before the app starts.
download_assets()

# Load environment variables from a local .env file (if present).
load_dotenv()

DB_FAISS_PATH = 'vectorstore/db_faiss'

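# Prompt that instructs the model to answer only from the retrieved context.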
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""

def set_custom_prompt():
    """Create the PromptTemplate that feeds the retrieved context and the question to the LLM."""
    prompt = PromptTemplate(template=custom_prompt_template,
                            input_variables=['context', 'question'])
    return prompt

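# Wire the LLM, prompt, and FAISS retriever into a "stuff"-type RetrievalQA chain.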
def retrieval_qa_chain(llm, prompt, db):
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        # Retrieve the two most similar chunks from the FAISS index per query.
        retriever=db.as_retriever(search_kwargs={'k': 2}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt}
    )
    return qa_chain

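# Load google/flan-t5-base locally and expose it to LangChain as an LLM.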
def load_llm():
    # Run FLAN-T5 base on CPU; float32 keeps it compatible with machines without a GPU.
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/flan-t5-base",
        device_map="cpu",
        torch_dtype=torch.float32
    )

    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        repetition_penalty=1.15
    )

    # Wrap the transformers pipeline so LangChain can call it like any other LLM.
    llm = HuggingFacePipeline(pipeline=pipe)
    return llm

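# Assemble the full QA pipeline: embeddings, FAISS index, LLM, and prompt.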
def qa_bot():
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'}
    )
    # The index is a locally built, trusted pickle, so deserialization is allowed.
    db = FAISS.load_local(
        DB_FAISS_PATH,
        embeddings,
        allow_dangerous_deserialization=True
    )

    llm = load_llm()
    qa_prompt = set_custom_prompt()
    qa = retrieval_qa_chain(llm, qa_prompt, db)
    return qa

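# Convenience helper for answering a single query outside of the Chainlit UI.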
def final_result(query):
    qa_result = qa_bot()
    response = qa_result.invoke({'query': query})
    return response

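# Chainlit entry point: build the chain once per session and greet the user.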
@cl.on_chat_start
async def start():
    chain = qa_bot()
    msg = cl.Message(content="Starting the bot...")
    await msg.send()
    msg.content = "Hi, Welcome to MindMate. What is your query?"
    await msg.update()
    cl.user_session.set("chain", chain)

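# Handle each user message: run the chain and append source passages to the answer.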
@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True

    # Run the synchronous chain in a worker thread so the event loop stays responsive;
    # the Chainlit callback handler is passed through the runnable config.
    res = await cl.make_async(chain.invoke)(
        {"query": message.content},
        config={"callbacks": [cb]}
    )

    answer = res.get("result", "No result found")
    sources = res.get("source_documents", [])

    # Append the retrieved passages so the user can see what the answer is based on.
    if sources:
        formatted_sources = []
        for source in sources:
            if hasattr(source, 'page_content'):
                formatted_sources.append(source.page_content.strip())

        if formatted_sources:
            answer = f"{answer}\n\nBased on the following information:\n" + "\n\n".join(formatted_sources)

    await cl.Message(content=answer).send()