MoizK committed
Commit 1b52b0b · verified · 1 Parent(s): 760961f

Update model.py

Files changed (1)
  1. model.py +130 -128
model.py CHANGED
@@ -1,128 +1,130 @@
  from langchain.prompts import PromptTemplate
  from langchain_community.embeddings import HuggingFaceEmbeddings
  from langchain_community.vectorstores import FAISS
  from langchain.llms import HuggingFacePipeline
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
  from langchain.chains import RetrievalQA
+ from download_assets import download_assets
+ download_assets()
  import chainlit as cl
  from dotenv import load_dotenv
  import torch
  import os

  load_dotenv()

  DB_FAISS_PATH = 'vectorstore/db_faiss'

  # Prompt Template
  custom_prompt_template = """Use the following pieces of information to answer the user's question.
  If you don't know the answer, just say that you don't know, don't try to make up an answer.

  Context: {context}
  Question: {question}

  Only return the helpful answer below and nothing else.
  Helpful answer:
  """

  def set_custom_prompt():
      prompt = PromptTemplate(template=custom_prompt_template,
                              input_variables=['context', 'question'])
      return prompt

  # Create RetrievalQA chain
  def retrieval_qa_chain(llm, prompt, db):
      qa_chain = RetrievalQA.from_chain_type(
          llm=llm,
          chain_type='stuff',
          retriever=db.as_retriever(search_kwargs={'k': 2}),
          return_source_documents=True,
          chain_type_kwargs={'prompt': prompt}
      )
      return qa_chain

  # Load Hugging Face LLM
  def load_llm():
      # Load model and tokenizer
      tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
      model = AutoModelForSeq2SeqLM.from_pretrained(
          "google/flan-t5-base",
          device_map="cpu",
          torch_dtype=torch.float32
      )

      # Create text-generation pipeline without invalid parameters
      pipe = pipeline(
          "text2text-generation",
          model=model,
          tokenizer=tokenizer,
          max_new_tokens=512,
          repetition_penalty=1.15
      )

      # Create LangChain wrapper for the pipeline
      llm = HuggingFacePipeline(pipeline=pipe)
      return llm

  # Build full chatbot pipeline
  def qa_bot():
      embeddings = HuggingFaceEmbeddings(
          model_name="sentence-transformers/all-MiniLM-L6-v2",
          model_kwargs={'device': 'cpu'}
      )
      db = FAISS.load_local(
          DB_FAISS_PATH,
          embeddings,
          allow_dangerous_deserialization=True
      )

      llm = load_llm()
      qa_prompt = set_custom_prompt()
      qa = retrieval_qa_chain(llm, qa_prompt, db)
      return qa

  # Run for one query (used internally)
  def final_result(query):
      qa_result = qa_bot()
      response = qa_result({'query': query})
      return response

  # Chainlit UI - Start
  @cl.on_chat_start
  async def start():
      chain = qa_bot()
      msg = cl.Message(content="Starting the bot...")
      await msg.send()
      msg.content = "Hi, Welcome to MindMate. What is your query?"
      await msg.update()
      cl.user_session.set("chain", chain)

  # Chainlit UI - Handle messages
  @cl.on_message
  async def main(message: cl.Message):
      chain = cl.user_session.get("chain")
      cb = cl.AsyncLangchainCallbackHandler(
          stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
      )
      cb.answer_reached = True

      # Use invoke with proper query format
      res = await cl.make_async(chain.invoke)(
          {"query": message.content},
          callbacks=[cb]
      )

      # Extract result and sources from the response
      answer = res.get("result", "No result found")
      sources = res.get("source_documents", [])

      # Format sources to show only the content
      if sources:
          formatted_sources = []
          for source in sources:
              if hasattr(source, 'page_content'):
                  formatted_sources.append(source.page_content.strip())

          if formatted_sources:
              answer = f"{answer}\n\nBased on the following information:\n" + "\n\n".join(formatted_sources)

      await cl.Message(content=answer).send()
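
The only functional change in this commit is the new `download_assets` import and the `download_assets()` call, placed among the imports so the FAISS index is on disk before `FAISS.load_local` runs at chat start. The `download_assets` module itself is not part of this diff; a minimal sketch of what such a helper might look like, assuming the prebuilt vectorstore is hosted in a Hugging Face dataset repo (the repo_id below is a placeholder, not taken from this repository):

# download_assets.py -- hypothetical sketch; the real module is not shown in this commit.
# Assumes the prebuilt FAISS index is stored in a Hugging Face dataset repo;
# the repo_id is a placeholder.
import os
from huggingface_hub import snapshot_download

DB_FAISS_PATH = 'vectorstore/db_faiss'

def download_assets():
    # Skip the download when the index files are already present locally
    # (FAISS.save_local writes index.faiss and index.pkl into the folder)
    if os.path.exists(os.path.join(DB_FAISS_PATH, 'index.faiss')):
        return
    # Pull vectorstore/db_faiss/* into the working directory, where
    # model.py's FAISS.load_local(DB_FAISS_PATH, ...) expects to find it
    snapshot_download(
        repo_id="MoizK/mindmate-assets",  # placeholder
        repo_type="dataset",
        local_dir=".",
        allow_patterns=["vectorstore/**"],
    )

Running the download at import time, before the Chainlit handlers are registered, is the simplest way to guarantee the assets exist by the time `@cl.on_chat_start` fires; the trade-off is that every process importing model.py repeats the existence check on startup.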