Ibrahim Kaiser committed
Commit 3c7b348 · 1 Parent(s): 9bda727

Replace default files with traffic rules Q&A application

Files changed (9)
  1. .gitignore +18 -0
  2. Dockerfile +0 -21
  3. README.md +12 -17
  4. requirements.txt +10 -3
  5. src/app.py +238 -0
  6. src/config.py +37 -0
  7. src/rag.py +262 -0
  8. src/state.py +27 -0
  9. src/streamlit_app.py +0 -40
.gitignore ADDED
@@ -0,0 +1,18 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ vector_store/
+ data/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
+ .env
+
+ # PDF files
+ data/traffic_rules.pdf
+ src/data/traffic_rules.pdf
+ *.pdf
Dockerfile DELETED
@@ -1,21 +0,0 @@
- FROM python:3.9-slim
-
- WORKDIR /app
-
- RUN apt-get update && apt-get install -y \
-     build-essential \
-     curl \
-     software-properties-common \
-     git \
-     && rm -rf /var/lib/apt/lists/*
-
- COPY requirements.txt ./
- COPY src/ ./src/
-
- RUN pip3 install -r requirements.txt
-
- EXPOSE 8501
-
- HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
-
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md CHANGED
@@ -1,20 +1,15 @@
- ---
- title: Bangladesh Traffic Rules
- emoji: 🚀
- colorFrom: red
- colorTo: red
- sdk: docker
- app_port: 8501
- tags:
-   - streamlit
- pinned: false
- short_description: Streamlit template space
- license: mit
- ---
-
- # Welcome to Streamlit!
-
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
-
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
+ # Bangladesh Traffic Rules Q&A
+
+ This chatbot specializes in answering questions about Bangladesh traffic rules and regulations.
+
+ ## Features
+
+ - RAG-based retrieval from official documents
+ - Caching for faster responses
+ - Traffic-specific query detection
+
+ ## How to Use
+
+ 1. Ask a question about Bangladesh traffic rules
+ 2. The system will search for relevant information
+ 3. You'll receive an answer based on the traffic rules database
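To try the new app locally, a minimal sketch, assuming a Python 3.9+ environment and the `src/app.py` entry point added in this commit:

```
pip install -r requirements.txt
streamlit run src/app.py
```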
requirements.txt CHANGED
@@ -1,3 +1,10 @@
- altair
- pandas
- streamlit
+ langgraph
+ langchain-community  # for local Ollama
+ langchain-ollama
+ langchain-core
+ langchain-huggingface
+ python-dotenv
+ faiss-cpu
+ pypdf
+ streamlit
+ sentence-transformers
src/app.py ADDED
@@ -0,0 +1,238 @@
+ # Import necessary libraries
+ import streamlit as st  # For creating the web interface
+ from langgraph.graph import StateGraph, START, END  # For creating the agent workflow
+ from langgraph.checkpoint.memory import MemorySaver  # For saving conversation state
+ from langchain_core.messages import HumanMessage, AIMessage  # For message types
+ from rag import build_vectorstore, load_vectorstore, get_context  # RAG functions
+ from state import AgentState  # State definition for our agent
+
+ # Build or load the vector store (only done once at startup)
+ build_vectorstore()
+ vectorstore = load_vectorstore()
+
+ def retrieve_rules(state: AgentState) -> AgentState:
+     """Retrieve relevant traffic rules based on the user's question."""
+     question = state["messages"][-1].content
+     context = get_context(vectorstore, question, k=5)
+     state["retriever"] = context
+     return state
+
+ def is_traffic_related(question: str) -> bool:
+     """Check if the user's question is related to traffic rules."""
+     traffic_keywords = [
+         'traffic', 'driving', 'vehicle', 'car', 'motorcycle', 'bike', 'road', 'speed',
+         'license', 'licence', 'signal', 'parking', 'park', 'helmet', 'seatbelt', 'accident',
+         'fine', 'penalty', 'police', 'highway', 'junction', 'crossing', 'lane', 'overtake',
+         'bangladesh', 'dhaka', 'chittagong', 'sylhet', 'rajshahi', 'khulna', 'barisal',
+         'rangpur', 'mymensingh', 'traffic light', 'roundabout', 'u-turn', 'horn',
+         'designated', 'violation', 'rule', 'law', 'regulation', 'transport', 'motor',
+         'rickshaw', 'bus', 'truck', 'taxi', 'cng', 'auto', 'tempo'
+     ]
+     question_lower = question.lower()
+     return any(keyword in question_lower for keyword in traffic_keywords)
+
+ def clean_text(text: str) -> str:
+     """Clean and normalize text formatting."""
+     # Remove excessive whitespace and normalize line breaks
+     import re
+
+     # Replace multiple whitespace characters with a single space
+     text = re.sub(r'\s+', ' ', text)
+
+     # Remove extra newlines and normalize
+     text = re.sub(r'\n+', ' ', text)
+
+     # Clean up bullet points and formatting
+     text = re.sub(r'●\s*', '• ', text)
+
+     return text.strip()
+
+ def generate_smart_answer(question: str, context: str) -> str:
+     """Generate an answer by intelligently parsing the context."""
+     if not context:
+         return "I don't have specific information about this traffic rule in my database."
+
+     # Clean the context first
+     context = clean_text(context)
+     context_lower = context.lower()
+     question_lower = question.lower()
+
+     # Split context into sentences for better processing
+     sentences = []
+     for s in context.split('.'):
+         cleaned_sentence = clean_text(s)
+         if cleaned_sentence and len(cleaned_sentence) > 10:  # Only keep meaningful sentences
+             sentences.append(cleaned_sentence)
+
+     relevant_sentences = []
+
+     # Look for penalty/fine information
+     if any(word in question_lower for word in ['penalty', 'fine', 'punishment']):
+         for sentence in sentences:
+             if any(word in sentence.lower() for word in ['penalty', 'fine', 'tk ', 'taka', 'punishment', 'fee', 'charge', 'jail', 'prison']):
+                 relevant_sentences.append(sentence)
+
+     # Look for speed-related information
+     elif any(word in question_lower for word in ['speed', 'limit', 'fast', 'overspeeding']):
+         for sentence in sentences:
+             if any(word in sentence.lower() for word in ['speed', 'limit', 'kmph', 'km/h', 'velocity', 'overtake', 'overspeed']):
+                 relevant_sentences.append(sentence)
+
+     # Look for parking-related information
+     elif any(word in question_lower for word in ['park', 'parking']):
+         for sentence in sentences:
+             if any(word in sentence.lower() for word in ['park', 'parking', 'designated', 'stand', 'stopping']):
+                 relevant_sentences.append(sentence)
+
+     # Look for helmet-related information
+     elif 'helmet' in question_lower:
+         for sentence in sentences:
+             if 'helmet' in sentence.lower():
+                 relevant_sentences.append(sentence)
+
+     # Look for license-related information
+     elif any(word in question_lower for word in ['license', 'licence']):
+         for sentence in sentences:
+             if any(word in sentence.lower() for word in ['license', 'licence', 'permit', 'driving']):
+                 relevant_sentences.append(sentence)
+
+     # General keyword matching
+     else:
+         question_words = question_lower.split()
+         for sentence in sentences:
+             sentence_lower = sentence.lower()
+             if any(word in sentence_lower for word in question_words if len(word) > 3):
+                 relevant_sentences.append(sentence)
+
+     # If we found relevant sentences, use them
+     if relevant_sentences:
+         # Take the most relevant sentences (max 3) and clean them
+         cleaned_sentences = []
+         for sentence in relevant_sentences[:3]:
+             cleaned = clean_text(sentence)
+             if cleaned:
+                 cleaned_sentences.append(cleaned)
+
+         if cleaned_sentences:
+             answer_text = '. '.join(cleaned_sentences)
+             return f"**Based on Bangladesh traffic rules:**\n\n{answer_text}."
+
+     # Fallback: return the first few sentences of context
+     fallback_sentences = []
+     for sentence in sentences[:2]:
+         cleaned = clean_text(sentence)
+         if cleaned:
+             fallback_sentences.append(cleaned)
+
+     if fallback_sentences:
+         return f"**Based on Bangladesh traffic rules:**\n\n{'. '.join(fallback_sentences)}."
+     else:
+         return "I found relevant information but couldn't format it properly. Please try rephrasing your question."
+
+ def generate_answer(state: AgentState) -> AgentState:
+     """Generate an answer using context parsing instead of an LLM."""
+     question = state['messages'][-1].content
+     context = state['retriever'].strip()
+
+     # Show debug information
+     st.write(f"🔍 Debug - Is traffic related: {is_traffic_related(question)}")
+     st.write(f"🔍 Debug - Context found: {len(context) > 0}")
+     st.write(f"🔍 Debug - Context length: {len(context)}")
+
+     # Option to show the actual context being used
+     if context and st.checkbox("Show retrieved context (for debugging)", key="debug_context"):
+         st.text_area("Retrieved Context:", context, height=200)
+
+     if is_traffic_related(question):
+         if context:
+             # Use our smart parsing instead of an LLM
+             answer = generate_smart_answer(question, context)
+             st.write("✅ Using context-based answer generation")
+         else:
+             answer = "I don't have specific information about this particular traffic rule in my Bangladesh traffic rules database. For accurate and up-to-date information, I recommend checking with local traffic police or the Bangladesh Road Transport Authority (BRTA)."
+     else:
+         answer = """I specialize in answering questions about Bangladesh traffic rules and regulations. Your question seems to be about something else.
+
+ If you have any questions about traffic laws, driving regulations, parking rules, vehicle licensing, or traffic penalties in Bangladesh, I'd be happy to help!"""
+
+     state["answer"] = answer
+     state["messages"].append(AIMessage(content=answer))
+     return state
+
+ # Create the workflow
+ workflow = StateGraph(AgentState)
+ workflow.add_node("retrieve", retrieve_rules)
+ workflow.add_node("answer", generate_answer)
+ workflow.add_edge(START, "retrieve")
+ workflow.add_edge("retrieve", "answer")
+ workflow.add_edge("answer", END)
+ graph = workflow.compile(checkpointer=MemorySaver())
+
+ # Streamlit UI
+ st.title("🚦 Bangladesh Traffic Rules Q&A")
+ st.markdown("Ask questions about traffic rules, parking regulations, penalties, and driving laws in Bangladesh.")
+
+ if "state" not in st.session_state:
+     st.session_state.state = AgentState(messages=[], retriever="", answer="")
+
+ with st.expander("📝 Example Questions"):
+     st.markdown("""
+ - What is the penalty for not parking in a designated place?
+ - What is the penalty for overspeeding?
+ - What are the helmet rules for motorcycle riders?
+ - What documents do I need to carry while driving?
+ - What is the fine for running a red light?
+ """)
+
+ question = st.text_input("Ask a question about Bangladesh traffic rules:",
+                          placeholder="e.g., What is the penalty for overspeeding?")
+
+ if st.button("Submit") and question.strip():
+     with st.spinner("Searching traffic rules..."):
+         st.session_state.state["messages"].append(HumanMessage(content=question))
+         thread_config = {"configurable": {"thread_id": "traffic_qa"}}
+         result = graph.invoke(st.session_state.state, config=thread_config)
+         st.session_state.state = result
+
+     st.markdown("### 📋 Answer")
+     st.markdown(result['answer'])
+
+ # Show conversation history
+ if st.session_state.state["messages"]:
+     st.markdown("---")
+     st.markdown("### 💬 Conversation History")
+     recent_messages = st.session_state.state["messages"][-10:]
+
+     for i, msg in enumerate(recent_messages):
+         if isinstance(msg, HumanMessage):
+             st.markdown(f"**❓ You:** {msg.content}")
+         else:
+             st.markdown(f"**🤖 Assistant:** {msg.content}")
+         if i < len(recent_messages) - 1:
+             st.markdown("---")
+
+ # Sidebar
+ with st.sidebar:
+     st.markdown("### 📊 System Status")
+     from rag import query_cache
+     st.write(f"🔍 Cached queries: {len(query_cache)}")
+     st.write(f"💬 Total messages: {len(st.session_state.state['messages'])}")
+
+     if st.button("Clear Cache"):
+         query_cache.clear()
+         st.success("Cache cleared!")
+
+     if st.button("Clear Conversation"):
+         st.session_state.state = AgentState(messages=[], retriever="", answer="")
+         st.success("Conversation cleared!")
+
+     st.markdown("---")
+     st.markdown("### ℹ️ About")
+     st.markdown("""
+ This version uses intelligent context parsing instead of an LLM for more reliable responses.
+
+ **Features:**
+ - RAG-based retrieval from official documents
+ - Smart context parsing for answers
+ - Traffic-specific query detection
+ """)
src/config.py ADDED
@@ -0,0 +1,37 @@
+ # Import necessary libraries
+ import os  # For interacting with the operating system
+ from dotenv import load_dotenv  # For loading environment variables from a .env file
+
+ # Load environment variables from .env file
+ load_dotenv()
+
+ # Configuration for the Hugging Face model
+ # This specifies which model to use for generating answers
+ HUGGINGFACE_MODEL = os.getenv("HUGGINGFACE_MODEL", "google/flan-t5-base")
+
+ # Path to the traffic rules PDF document
+ # This is the source document that contains all the traffic rules
+ DATA_PATH = os.getenv("DATA_PATH", "./data/traffic_rules.pdf")
+
+ # Path to the vector store file
+ # This is where the processed and searchable version of the traffic rules is stored
+ VECTOR_STORE_PATH = os.getenv("VECTOR_STORE_PATH", "./vector_store/index.faiss")
+
+ # Hugging Face API token for accessing models
+ # This is required to use Hugging Face's hosted models
+ HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN", "")
+
+ # Hugging Face embedding model - using a small but efficient model
+ # This model is only 22MB in size but provides good performance for semantic search
+ EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
+
+ # Text chunking configuration
+ # These settings determine how the PDF document is broken into manageable pieces
+
+ # CHUNK_SIZE: The maximum number of characters in each chunk
+ # Larger chunks preserve more context but may be less precise
+ CHUNK_SIZE = 800
+
+ # CHUNK_OVERLAP: The number of characters that overlap between adjacent chunks
+ # Overlap helps ensure that important information isn't split between chunks
+ CHUNK_OVERLAP = 100
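To see what `CHUNK_SIZE` and `CHUNK_OVERLAP` actually control, here is a small sketch with deliberately tiny values (the real config uses 800/100, which would not show the effect on a short string):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=40, chunk_overlap=10)
text = "Section 47 covers parking rules. Section 90 covers penalties for violations."

for chunk in splitter.split_text(text):
    print(repr(chunk))
# Each chunk stays under 40 characters, and adjacent chunks may repeat up to
# ~10 characters, so a fact straddling a boundary survives intact in one chunk.
```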
src/rag.py ADDED
@@ -0,0 +1,262 @@
+ import os  # For file and directory operations
+ import re  # For regular expressions (text processing)
+ from langchain_community.document_loaders import PyPDFLoader  # For loading PDF documents
+ from langchain_community.vectorstores import FAISS  # For creating and using vector stores
+ from langchain_huggingface import HuggingFaceEmbeddings  # For creating text embeddings using Hugging Face
+ from langchain.text_splitter import RecursiveCharacterTextSplitter  # For splitting text into chunks
+ from huggingface_hub import hf_hub_download  # For downloading files from Hugging Face
+ from config import DATA_PATH, VECTOR_STORE_PATH, CHUNK_SIZE, CHUNK_OVERLAP, EMBEDDING_MODEL  # Configuration settings
+
+ # Simple cache for storing query results
+ # This helps speed up responses for repeated questions
+ query_cache = {}
+
+ def build_vectorstore():
+     """
+     Build a vector store from the traffic rules PDF document.
+
+     This function processes the PDF document by:
+     1. Downloading the PDF from Hugging Face Datasets
+     2. Loading the document
+     3. Splitting it into manageable chunks
+     4. Creating embeddings for each chunk using Hugging Face
+     5. Storing the embeddings in a vector store for efficient searching
+
+     The vector store is saved to disk so it only needs to be built once.
+     """
+     # Check if the vector store already exists
+     if os.path.exists(VECTOR_STORE_PATH):
+         print("Vector store already exists.")
+         return
+
+     print("Building vector store...")
+
+     # Download the PDF from Hugging Face Datasets
+     # Replace "ikReza/traffic-rules-pdf" with your actual dataset ID
+     pdf_path = hf_hub_download(
+         repo_id="ikReza/traffic-rules-pdf",
+         filename="traffic_rules.pdf",
+         repo_type="dataset",
+         local_dir="./data"  # Download to the data directory
+     )
+
+     # Load the PDF document
+     loader = PyPDFLoader(pdf_path)
+     docs = loader.load()
+
+     # Create a text splitter to break the document into chunks
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=CHUNK_SIZE,  # Maximum characters per chunk
+         chunk_overlap=CHUNK_OVERLAP,  # Characters that overlap between chunks
+         separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]  # Hierarchy of separators to use
+     )
+
+     # Split the document into chunks
+     split_docs = splitter.split_documents(docs)
+
+     # Create embeddings for the chunks using the Hugging Face model
+     # Embeddings are numerical representations of text that capture semantic meaning
+     embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
+
+     # Create a vector store from the document chunks and their embeddings
+     vectorstore = FAISS.from_documents(split_docs, embeddings)
+
+     # Save the vector store to disk
+     os.makedirs(os.path.dirname(VECTOR_STORE_PATH), exist_ok=True)
+     vectorstore.save_local(VECTOR_STORE_PATH)
+
+     print(f"Vector store built with {len(split_docs)} chunks")
+
+ def load_vectorstore():
+     """
+     Load the vector store from disk.
+
+     This function loads the previously created vector store so it can be used
+     for searching relevant traffic rules based on user questions.
+
+     Returns:
+         The loaded vector store
+     """
+     # Create embeddings object (same model used when building the vector store)
+     embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
+
+     # Load the vector store from disk
+     return FAISS.load_local(VECTOR_STORE_PATH, embeddings, allow_dangerous_deserialization=True)
+
+ def preprocess_query(query):
+     """
+     Enhance the user's query with synonyms and related terms.
+
+     This function improves search results by adding related terms to the query,
+     increasing the chances of finding relevant traffic rules.
+
+     Args:
+         query: The original user query
+
+     Returns:
+         Enhanced query with additional related terms
+     """
+     # Convert query to lowercase for processing
+     query_lower = query.lower()
+
+     # List of enhancements to add based on query content
+     enhancements = []
+
+     # Add parking-related terms if the query mentions parking
+     if 'park' in query_lower or 'parking' in query_lower:
+         enhancements.extend(['parking', 'park', 'designated place', 'vehicle standing',
+                              'section 47', 'section 90', 'picking up', 'dropping off'])
+
+     # Add penalty-related terms if the query mentions penalties or fines
+     if 'penalty' in query_lower or 'fine' in query_lower:
+         enhancements.extend(['penalty', 'fine', 'punishment', 'fee', 'charge', 'taka',
+                              'demerit point', 'violating', 'violation', 'offence'])
+
+     # Add helmet-related terms if the query mentions helmets
+     if 'helmet' in query_lower:
+         enhancements.extend(['helmet', 'protective gear', 'safety equipment'])
+
+     # Add speed-related terms if the query mentions speed
+     if 'speed' in query_lower:
+         enhancements.extend(['speed limit', 'velocity', 'driving speed'])
+
+     # Add license-related terms if the query mentions licenses
+     if 'license' in query_lower or 'licence' in query_lower:
+         enhancements.extend(['license', 'licence', 'driving permit', 'permit'])
+
+     # Add section references if the query mentions designated places
+     if 'designated' in query_lower:
+         enhancements.extend(['designated place', 'section 47', 'section 90',
+                              'motor vehicles', 'provisions'])
+
+     # Combine the original query with the enhancements
+     enhanced_query = f"{query} {' '.join(enhancements)}"
+     return enhanced_query
+
+ def get_context(vectorstore, query, k=5):
+     """
+     Retrieve relevant traffic rules based on the user's query.
+
+     This function searches the vector store for traffic rules that match the query,
+     using caching to speed up repeated queries.
+
+     Args:
+         vectorstore: The vector store to search in
+         query: The user's question
+         k: The number of relevant documents to retrieve
+
+     Returns:
+         A string containing the relevant traffic rules, or an empty string if no relevant rules are found
+     """
+     # Create a cache key for this query
+     cache_key = f"{query}_{k}"
+
+     # Check if we already have a cached result for this query
+     if cache_key in query_cache:
+         print("Using cached result")
+         return query_cache[cache_key]
+
+     print(f"Searching for: {query}")
+
+     # Enhance the query with related terms to improve search results
+     enhanced_query = preprocess_query(query)
+     print(f"Enhanced query: {enhanced_query}")
+
+     try:
+         # Search the vector store for documents similar to the enhanced query
+         docs_with_scores = vectorstore.similarity_search_with_score(enhanced_query, k=k)
+
+         # Filter relevant results based on their similarity scores
+         # Lower scores indicate higher similarity
+         relevant_docs = []
+         for doc, score in docs_with_scores:
+             # Use a threshold to determine if a document is relevant enough
+             if score < 1.5:  # This threshold may need adjustment based on your data
+                 relevant_docs.append(doc.page_content)
+                 print(f"Including doc with score: {score}")
+             else:
+                 print(f"Excluding doc with score: {score}")
+
+         # Join the relevant documents into a single context string
+         if relevant_docs:
+             context = "\n\n---\n\n".join(relevant_docs)
+             print(f"Found {len(relevant_docs)} relevant documents")
+         else:
+             context = ""
+             print("No relevant documents found")
+
+     except Exception as e:
+         # Handle any errors that might occur during the search
+         print(f"Error during search: {e}")
+         context = ""
+
+     # Simple cache management - keep only the last 100 queries
+     if len(query_cache) > 100:
+         # Remove the oldest entries to make room for new ones
+         oldest_keys = list(query_cache.keys())[:20]
+         for key in oldest_keys:
+             del query_cache[key]
+
+     # Cache the result for future use
+     query_cache[cache_key] = context
+
+     return context
+
+ def extract_key_terms(query):
+     """
+     Extract key terms from the query for better searching.
+
+     This function removes common stop words from the query to focus on
+     the important terms that are likely to match traffic rules.
+
+     Args:
+         query: The user's question
+
+     Returns:
+         A list of key terms extracted from the query
+     """
+     # Set of common stop words to filter out
+     stop_words = {'what', 'is', 'the', 'how', 'where', 'when', 'why', 'do', 'does',
+                   'can', 'could', 'should', 'would', 'will', 'are', 'am', 'was',
+                   'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having',
+                   'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
+                   'of', 'with', 'by', 'about', 'into', 'through', 'during',
+                   'before', 'after', 'above', 'below', 'up', 'down', 'out',
+                   'off', 'over', 'under', 'again', 'further', 'then', 'once'}
+
+     # Extract words from the query using regular expressions
+     words = re.findall(r'\b\w+\b', query.lower())
+
+     # Filter out stop words and very short words
+     key_terms = [word for word in words if word not in stop_words and len(word) > 2]
+
+     return key_terms
+
+ # Debug function to test search results (useful for development)
+ def debug_search(vectorstore, query, k=5):
+     """
+     Debug function to test and analyze search results.
+
+     This function is useful during development to understand how the search
+     is performing and to fine-tune parameters.
+
+     Args:
+         vectorstore: The vector store to search in
+         query: The query to test
+         k: The number of results to return
+
+     Returns:
+         The search results with their scores
+     """
+     print(f"\n=== DEBUG SEARCH FOR: '{query}' ===")
+
+     # Perform the search and get results with scores
+     docs_with_scores = vectorstore.similarity_search_with_score(query, k=k)
+
+     # Print each result with its score
+     for i, (doc, score) in enumerate(docs_with_scores):
+         print(f"\n--- Result {i+1} (Score: {score:.4f}) ---")
+         print(f"Content: {doc.page_content[:200]}...")  # Show first 200 characters
+         print(f"Metadata: {doc.metadata}")
+
+     return docs_with_scores
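A hypothetical usage sketch for the helpers above, assuming it runs from `src/` (so the relative paths in `config.py` resolve) and the dataset download succeeds:

```python
from rag import build_vectorstore, load_vectorstore, get_context, debug_search, query_cache

build_vectorstore()        # no-op if the vector store already exists on disk
vs = load_vectorstore()

# The first call searches the FAISS index; the repeat is served from
# query_cache, which is keyed by "<query>_<k>"
ctx = get_context(vs, "What is the penalty for overspeeding?", k=5)
ctx = get_context(vs, "What is the penalty for overspeeding?", k=5)
print(len(query_cache))    # 1

# Inspect raw similarity scores when tuning the 1.5 relevance threshold
debug_search(vs, "helmet rules for motorcycle riders", k=3)
```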
src/state.py ADDED
@@ -0,0 +1,27 @@
+ from typing import Annotated, TypedDict
+ from langgraph.graph.message import add_messages  # reducer
+
+ class AgentState(TypedDict):
+     """
+     Define the state of our traffic rules agent.
+
+     This class defines the structure of the state that will be passed between
+     different nodes in our LangGraph workflow. It contains all the information
+     needed to process a user's question and generate an answer.
+
+     Attributes:
+         messages: A list of conversation messages (both human and AI)
+         retriever: The retrieved traffic rules context
+         answer: The generated answer to the user's question
+     """
+     # The messages attribute stores the conversation history
+     # It's annotated with add_messages to handle appending new messages
+     messages: Annotated[list, add_messages]
+
+     # The retriever attribute stores the traffic rules context retrieved from the vector store
+     # This is used as the source of information for generating answers
+     retriever: str
+
+     # The answer attribute stores the final generated answer
+     # This is what will be displayed to the user
+     answer: str
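The `add_messages` reducer can also be called directly to see the merge behaviour the annotation provides; a quick sketch:

```python
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph.message import add_messages

existing = [HumanMessage(content="What is the helmet rule?")]
update = [AIMessage(content="Riders must wear a helmet.")]

merged = add_messages(existing, update)
print([m.content for m in merged])
# ['What is the helmet rule?', 'Riders must wear a helmet.']
# Messages with new ids are appended; a returned message whose id already
# exists replaces the stored copy instead of duplicating it.
```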
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
- import altair as alt
- import numpy as np
- import pandas as pd
- import streamlit as st
-
- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))