khalifssa committed
Commit 7a146f7 · verified · 1 Parent(s): 862b47f

Update app.py

Files changed (1): app.py (+71 -23)
app.py CHANGED
@@ -1,5 +1,6 @@
 import os
 import torch
+import torch.backends.cudnn as cudnn
 import streamlit as st
 from langchain_community.document_loaders import PyPDFLoader
 from langchain_community.embeddings import HuggingFaceEmbeddings
@@ -8,6 +9,10 @@ from langchain_community.vectorstores import FAISS
 from langchain.prompts import PromptTemplate
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

+# Enable CUDA optimizations if available
+if torch.cuda.is_available():
+    cudnn.benchmark = True
+
 # Step 1: Load the PDF and create a vector store
 @st.cache_resource
 def load_pdf_to_vectorstore(pdf_path):
@@ -16,8 +21,8 @@ def load_pdf_to_vectorstore(pdf_path):
     documents = loader.load()

     text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=500,
-        chunk_overlap=50,
+        chunk_size=1000,
+        chunk_overlap=20,
         separators=["\n\n", "\n", ".", " ", ""]
     )

@@ -34,33 +39,40 @@ def load_pdf_to_vectorstore(pdf_path):
 # Step 2: Initialize the LaMini model
 @st.cache_resource
 def setup_model():
-    model_id = "MBZUAI/LaMini-Flan-T5-783M"
+    model_id = "MBZUAI/LaMini-Flan-T5-248M"  # Using smaller model for faster inference
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+    model = AutoModelForSeq2SeqLM.from_pretrained(
+        model_id,
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        low_cpu_mem_usage=True
+    )
+
+    if torch.cuda.is_available():
+        model = model.cuda()

     pipe = pipeline(
         "text2text-generation",
         model=model,
         tokenizer=tokenizer,
-        max_length=512,
-        do_sample=True,
+        max_length=256,
+        do_sample=False,
         temperature=0.3,
         top_p=0.95,
-        device=0 if torch.cuda.is_available() else -1
+        device=0 if torch.cuda.is_available() else -1,
+        batch_size=1
     )
     return pipe

 # Step 3: Generate a response using the model and vector store
 def generate_response(pipe, vectorstore, user_input):
     # Get relevant context
-
-    # Increase k for more context if needed
-    docs = vectorstore.similarity_search(user_input, k=4)  # increased from 3
-
-    # Add document metadata (like page numbers) to help track sources
-    context = "\n".join([f"Page {doc.metadata.get('page', 'unknown')}: {doc.page_content}" for doc in docs])
+    docs = vectorstore.similarity_search(user_input, k=2)
+    context = "\n".join([
+        f"Page {doc.metadata.get('page', 'unknown')}: {doc.page_content}"
+        for doc in docs
+    ])

-    # Enhanced prompt template
+    # Create prompt
     prompt = PromptTemplate(
         input_variables=["context", "question"],
         template="""
@@ -80,6 +92,20 @@ def generate_response(pipe, vectorstore, user_input):

     return response

+# Cache responses for repeated questions
+@st.cache_data
+def cached_generate_response(user_input, _pipe, _vectorstore):
+    return generate_response(_pipe, _vectorstore, user_input)
+
+# Batch processing for multiple questions
+def batch_generate_responses(pipe, vectorstore, questions, batch_size=4):
+    responses = []
+    for i in range(0, len(questions), batch_size):
+        batch = questions[i:i + batch_size]
+        batch_responses = [generate_response(pipe, vectorstore, q) for q in batch]
+        responses.extend(batch_responses)
+    return responses
+
 # Streamlit UI
 def main():
     st.title("Medical Chatbot Assistant 🏥")
@@ -88,17 +114,39 @@ def main():
     pdf_path = "Medical_book.pdf"

     if os.path.exists(pdf_path):
-        # Load vector store and model
-        vectorstore = load_pdf_to_vectorstore(pdf_path)
-        pipe = setup_model()
+        # Initialize progress
+        progress_text = "Operation in progress. Please wait."

-        # User input
-        user_input = st.text_input("Ask your medical question:")
+        # Load vector store and model with progress indication
+        with st.spinner("Loading PDF and initializing model..."):
+            vectorstore = load_pdf_to_vectorstore(pdf_path)
+            pipe = setup_model()
+            st.success("Ready to answer questions!")

-        if user_input:
-            with st.spinner("Generating response..."):
-                response = generate_response(pipe, vectorstore, user_input)
-                st.write(response)
+        # Create a chat-like interface
+        if "messages" not in st.session_state:
+            st.session_state.messages = []
+
+        # Display chat history
+        for message in st.session_state.messages:
+            with st.chat_message(message["role"]):
+                st.markdown(message["content"])
+
+        # User input
+        if prompt := st.chat_input("Ask your medical question:"):
+            # Add user message to chat history
+            st.session_state.messages.append({"role": "user", "content": prompt})
+            with st.chat_message("user"):
+                st.markdown(prompt)
+
+            # Generate and display response
+            with st.chat_message("assistant"):
+                with st.spinner("Generating response..."):
+                    response = cached_generate_response(prompt, pipe, vectorstore)
+                    st.markdown(response)
+                    # Add assistant message to chat history
+                    st.session_state.messages.append({"role": "assistant", "content": response})
+
     else:
         st.error("The file 'Medical_book.pdf' was not found in the root directory.")
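
A note on the new cached_generate_response helper: st.cache_data builds its cache key by hashing the function's arguments, and parameters whose names start with an underscore (here _pipe and _vectorstore) are skipped by the hasher. That is why the unhashable pipeline and FAISS store can be passed through while repeated questions still hit the cache. A minimal sketch of the same pattern, with a hypothetical _search_resource object standing in for the real pipeline/vector store:

import streamlit as st

@st.cache_data
def cached_answer(question: str, _search_resource):
    # `question` is hashed and forms the cache key;
    # `_search_resource` starts with an underscore, so Streamlit skips hashing it,
    # letting unhashable objects (models, vector stores) pass through.
    return _search_resource.answer(question)  # hypothetical method, for illustration only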
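
The batch_generate_responses helper added here is not wired into the Streamlit UI yet; it simply walks the question list in slices of batch_size and calls generate_response on each item. A hedged usage sketch, assuming vectorstore and pipe were created with load_pdf_to_vectorstore() and setup_model() as above (the example questions are illustrative, not from the source PDF):

vectorstore = load_pdf_to_vectorstore("Medical_book.pdf")
pipe = setup_model()

questions = [
    "What are common symptoms of anemia?",      # illustrative question
    "How is type 2 diabetes usually managed?",  # illustrative question
]

# Each question is answered independently; batch_size only controls the slice size.
answers = batch_generate_responses(pipe, vectorstore, questions, batch_size=2)
for question, answer in zip(questions, answers):
    print(question, "->", answer)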