# MindMedic / model.py
import os

import chainlit as cl
import torch
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

from download_assets import download_assets

# Fetch required assets before the app starts.
download_assets()

# Load environment variables from a local .env file, if present.
load_dotenv()

DB_FAISS_PATH = 'vectorstore/db_faiss'
# Prompt Template
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""

def set_custom_prompt():
    prompt = PromptTemplate(template=custom_prompt_template,
                            input_variables=['context', 'question'])
    return prompt

# Create RetrievalQA chain
def retrieval_qa_chain(llm, prompt, db):
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=db.as_retriever(search_kwargs={'k': 2}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt}
    )
    return qa_chain

# Load the Hugging Face LLM
def load_llm():
    # Load the tokenizer and the seq2seq model on CPU
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/flan-t5-base",
        device_map="cpu",
        torch_dtype=torch.float32
    )
    # Build a text2text-generation pipeline around the model
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        repetition_penalty=1.15
    )
    # Wrap the pipeline so LangChain can use it as an LLM
    llm = HuggingFacePipeline(pipeline=pipe)
    return llm

# Build full chatbot pipeline
def qa_bot():
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'}
    )
    db = FAISS.load_local(
        DB_FAISS_PATH,
        embeddings,
        allow_dangerous_deserialization=True
    )
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    qa = retrieval_qa_chain(llm, qa_prompt, db)
    return qa

# Run the chain for a single query (used internally / for local testing)
def final_result(query):
    qa_result = qa_bot()
    response = qa_result.invoke({'query': query})
    return response

# Chainlit UI - chat start
@cl.on_chat_start
async def start():
    chain = qa_bot()
    msg = cl.Message(content="Starting the bot...")
    await msg.send()
    msg.content = "Hi, welcome to MindMate. What is your query?"
    await msg.update()
    cl.user_session.set("chain", chain)

# Chainlit UI - handle messages
@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True

    # Run the chain; callbacks are passed through the invoke config mapping
    res = await cl.make_async(chain.invoke)(
        {"query": message.content},
        config={"callbacks": [cb]}
    )

    # Extract the answer and the retrieved source documents
    answer = res.get("result", "No result found")
    sources = res.get("source_documents", [])

    # Append the retrieved passages so the user can see what the answer is based on
    if sources:
        formatted_sources = []
        for source in sources:
            if hasattr(source, 'page_content'):
                formatted_sources.append(source.page_content.strip())
        if formatted_sources:
            answer = f"{answer}\n\nBased on the following information:\n" + "\n\n".join(formatted_sources)

    await cl.Message(content=answer).send()
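
# The block below is an added sketch for quick local testing outside Chainlit.
# It assumes the FAISS index is already available at DB_FAISS_PATH and simply
# pushes one query through final_result; the sample question is hypothetical.
if __name__ == "__main__":
    sample_query = "What are common symptoms of anxiety?"  # hypothetical test query
    result = final_result(sample_query)
    print("Answer:", result.get("result", ""))
    for doc in result.get("source_documents", []):
        print("Source:", doc.page_content[:200])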