Update src/chatbot.py
Browse files- src/chatbot.py +136 -29
src/chatbot.py
CHANGED
@@ -98,19 +98,61 @@ def is_political_stress(prompt: str) -> bool:
|
|
98 |
return any(keyword in prompt.lower() for keyword in political_keywords)
|
99 |
|
100 |
########################################
|
101 |
-
#
|
102 |
########################################
|
103 |
-
def generate_response(prompt: str, country: str) -> str:
|
104 |
-
lower_prompt = prompt.lower()
|
105 |
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
self_harm_keywords = [
|
108 |
-
"suicide", "kill myself", "end my life", "self-harm",
|
|
|
109 |
]
|
110 |
if any(keyword in lower_prompt for keyword in self_harm_keywords):
|
111 |
-
logging.info(f"Self-harm keyword detected in prompt: {
|
112 |
helpline_str = get_helpline_for_country(country)
|
113 |
-
|
114 |
return (
|
115 |
"I’m really sorry you’re feeling like this. It sounds like you’re in a very tough place right now. "
|
116 |
"If you’re comfortable, could you share more about what’s bringing you to feel this way? "
|
@@ -119,28 +161,34 @@ def generate_response(prompt: str, country: str) -> str:
|
|
119 |
"You’re not alone, and there are caring people who want to help you."
|
120 |
)
|
121 |
|
122 |
-
# 2
|
123 |
-
# 3. Hate speech detection...
|
124 |
|
125 |
-
#
|
126 |
system_prompt = (
|
127 |
"You are a supportive, empathetic companion. "
|
128 |
"Your top priority is to listen and help the user feel heard. "
|
129 |
"You:\n"
|
130 |
"- Acknowledge the user's feelings.\n"
|
131 |
-
"- Ask gentle follow-up questions
|
132 |
-
"- Suggest small, practical
|
133 |
"- Avoid judgment or lecturing.\n"
|
134 |
-
"-
|
135 |
-
"
|
136 |
)
|
137 |
|
138 |
-
#
|
139 |
-
#
|
|
|
140 |
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
143 |
|
|
|
|
|
144 |
with torch.no_grad():
|
145 |
outputs = model.generate(
|
146 |
**inputs,
|
@@ -154,15 +202,53 @@ def generate_response(prompt: str, country: str) -> str:
|
|
154 |
)
|
155 |
|
156 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
157 |
-
|
158 |
-
|
|
|
|
|
|
|
|
|
|
|
159 |
if response and response[-1] not in ".!?":
|
160 |
-
|
161 |
-
response = response.rsplit(".", 1)[0].strip() + "."
|
162 |
-
else:
|
163 |
-
response += "."
|
164 |
|
165 |
-
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
|
168 |
########################################
|
@@ -186,6 +272,10 @@ def save_user_feedback(user_input, bot_response, rating, extra_comments=""):
|
|
186 |
########################################
|
187 |
# 6. CONVERSATION LOOP
|
188 |
########################################
|
|
|
|
|
|
|
|
|
189 |
def chatbot_conversation(country_selection: str):
|
190 |
"""
|
191 |
Starts the console-based conversation loop, using 'country_selection'
|
@@ -196,6 +286,7 @@ def chatbot_conversation(country_selection: str):
|
|
196 |
print(f"Country set to: {country_selection}\n")
|
197 |
|
198 |
print("Type 'exit' or 'quit' to end the conversation.")
|
|
|
199 |
while True:
|
200 |
user_input = input("\nYou: ")
|
201 |
if user_input.lower() in ["exit", "quit"]:
|
@@ -212,8 +303,19 @@ def chatbot_conversation(country_selection: str):
|
|
212 |
truncated_input = user_input[:HARD_LIMIT]
|
213 |
print("(System Note): Your message exceeded 2048 chars and was truncated.")
|
214 |
|
215 |
-
#
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
print(bot_response)
|
218 |
|
219 |
# Ask for feedback
|
@@ -225,12 +327,12 @@ def chatbot_conversation(country_selection: str):
|
|
225 |
print("Please enter 'y' or 'n' or just hit Enter to skip.")
|
226 |
|
227 |
if rating in ["y", "n"]:
|
228 |
-
# Optional additional comments
|
229 |
extra_comments = input("Any additional comments? (press Enter to skip): ")
|
230 |
rating_symbol = "👍" if rating == "y" else "👎"
|
231 |
save_user_feedback(truncated_input, bot_response, rating_symbol, extra_comments)
|
232 |
print("(Feedback saved.)")
|
233 |
|
|
|
234 |
########################################
|
235 |
# 7. TKINTER GUI WITH COUNTRY SELECTION
|
236 |
########################################
|
@@ -243,7 +345,11 @@ def main():
|
|
243 |
root.title("Mental Health Chatbot Prototype")
|
244 |
|
245 |
# Disclaimer message
|
246 |
-
disclaimer_label = tk.Label(
|
|
|
|
|
|
|
|
|
247 |
disclaimer_label.pack(padx=20, pady=10)
|
248 |
|
249 |
# Country selection label
|
@@ -268,3 +374,4 @@ def main():
|
|
268 |
|
269 |
if __name__ == "__main__":
|
270 |
main()
|
|
|
|
98 |
return any(keyword in prompt.lower() for keyword in political_keywords)
|
99 |
|
100 |
########################################
|
101 |
+
# 4A. SUMMARIZATION HELPER (to keep conversation shorter)
|
102 |
########################################
|
|
|
|
|
103 |
|
104 |
+
def summarize_text(text: str) -> str:
    """
    Condense *text* into a short summary.

    Works by prompting the same causal LM that powers the chatbot with a
    summarization instruction and decoding its continuation.  A dedicated
    summarization model would be a better fit in a real application.
    """
    marker = "Summary:"
    request = (
        "Summarize this conversation in a concise way, focusing on key points:\n\n"
        f"{text}\n\n{marker}"
    )
    encoded = tokenizer(request, return_tensors="pt").to(device)
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=128,
            temperature=0.7,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.2,
            no_repeat_ngram_size=2,
            do_sample=True,
        )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only what follows the last
    # "Summary:" marker (the marker also terminates the prompt itself).
    if marker in decoded:
        decoded = decoded.rpartition(marker)[2].strip()
    return decoded
|
133 |
+
|
134 |
+
########################################
|
135 |
+
# 4B. GPT-Style Response with Conversation History
|
136 |
+
########################################
|
137 |
+
|
138 |
+
def generate_response_with_history(conversation_history, country: str) -> str:
|
139 |
+
"""
|
140 |
+
Generates a response based on the entire conversation history.
|
141 |
+
- conversation_history: a list of turns, each is { "role": "user"/"bot", "content": str }
|
142 |
+
- country: used for helpline logic if self-harm is detected in the LAST user message
|
143 |
+
"""
|
144 |
+
# Check the last user message for self-harm, etc.
|
145 |
+
last_user_message = conversation_history[-1]["content"]
|
146 |
+
lower_prompt = last_user_message.lower()
|
147 |
+
|
148 |
+
# 1) Self-harm detection
|
149 |
self_harm_keywords = [
|
150 |
+
"suicide", "kill myself", "end my life", "self-harm",
|
151 |
+
"hurt myself", "want to die", "no reason to live", "tired of living"
|
152 |
]
|
153 |
if any(keyword in lower_prompt for keyword in self_harm_keywords):
|
154 |
+
logging.info(f"Self-harm keyword detected in prompt: {last_user_message}")
|
155 |
helpline_str = get_helpline_for_country(country)
|
|
|
156 |
return (
|
157 |
"I’m really sorry you’re feeling like this. It sounds like you’re in a very tough place right now. "
|
158 |
"If you’re comfortable, could you share more about what’s bringing you to feel this way? "
|
|
|
161 |
"You’re not alone, and there are caring people who want to help you."
|
162 |
)
|
163 |
|
164 |
+
# 2) Harm-to-others / Hate speech can be similarly handled, or you can do a single pass prior to generation.
|
|
|
165 |
|
166 |
+
# 3) Construct a system prompt to direct the style:
|
167 |
system_prompt = (
|
168 |
"You are a supportive, empathetic companion. "
|
169 |
"Your top priority is to listen and help the user feel heard. "
|
170 |
"You:\n"
|
171 |
"- Acknowledge the user's feelings.\n"
|
172 |
+
"- Ask gentle follow-up questions.\n"
|
173 |
+
"- Suggest small, practical coping ideas.\n"
|
174 |
"- Avoid judgment or lecturing.\n"
|
175 |
+
"- Provide helpline info only if user is in crisis.\n"
|
176 |
+
"Do NOT prefix lines with 'User:' or 'Bot:'.\n"
|
177 |
)
|
178 |
|
179 |
+
# 4) Build a single text prompt from the entire conversation
|
180 |
+
# Start with system instructions:
|
181 |
+
conversation_text = system_prompt
|
182 |
|
183 |
+
for turn in conversation_history:
|
184 |
+
if turn["role"] == "user":
|
185 |
+
conversation_text += f"\nUser says: {turn['content']}"
|
186 |
+
else:
|
187 |
+
conversation_text += f"\nBot says: {turn['content']}"
|
188 |
+
conversation_text += "\nBot says:"
|
189 |
|
190 |
+
# 5) Convert text to tokens and generate
|
191 |
+
inputs = tokenizer(conversation_text, return_tensors="pt").to(device)
|
192 |
with torch.no_grad():
|
193 |
outputs = model.generate(
|
194 |
**inputs,
|
|
|
202 |
)
|
203 |
|
204 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
205 |
+
|
206 |
+
# 6) Clean up any repeated system text
|
207 |
+
response = response.replace(system_prompt, "")
|
208 |
+
response = response.replace("User says:", "").replace("Bot says:", "")
|
209 |
+
response = response.strip()
|
210 |
+
|
211 |
+
# Make sure it ends with punctuation
|
212 |
if response and response[-1] not in ".!?":
|
213 |
+
response += "."
|
|
|
|
|
|
|
214 |
|
215 |
+
return response
|
216 |
+
|
217 |
+
|
218 |
+
########################################
|
219 |
+
# 4C. Possibly Summarize Older History
|
220 |
+
########################################
|
221 |
+
|
222 |
+
def maybe_summarize_history(conversation_history, max_turns=10, keep_recent=5):
    """
    Compress old turns in place when the conversation grows too long.

    If *conversation_history* holds more than *max_turns* turns, every turn
    except the most recent *keep_recent* is summarized (via summarize_text)
    into a single synthetic "bot" turn, and the list is rewritten in place
    as [summary_turn, *recent_turns].  This keeps token usage manageable.

    Args:
        conversation_history: list of {"role": "user"/"bot", "content": str}
            dicts; mutated in place so callers holding a reference see the
            shortened history.
        max_turns: threshold above which summarization kicks in.
        keep_recent: number of most-recent turns preserved verbatim
            (previously a hard-coded 5; default keeps prior behavior).
    """
    if len(conversation_history) <= max_turns:
        return

    # Everything except the last `keep_recent` turns is folded into a summary.
    old_turns = conversation_history[:-keep_recent]
    recent_turns = conversation_history[-keep_recent:]

    # Render the old turns as plain text for the summarizer.
    old_text = ""
    for turn in old_turns:
        role = turn["role"]
        content = turn["content"]
        old_text += f"{role.upper()}:\n{content}\n\n"

    summary = summarize_text(old_text)
    # Store the summary as a single synthetic turn with role "bot" so the
    # downstream prompt builder treats it as model-side context.
    summary_turn = {
        "role": "bot",
        "content": f"(A summary of earlier conversation: {summary})"
    }

    # Rebuild the conversation in place: summary + the recent turns.
    conversation_history[:] = [summary_turn] + recent_turns
|
252 |
|
253 |
|
254 |
########################################
|
|
|
272 |
########################################
|
273 |
# 6. CONVERSATION LOOP
|
274 |
########################################
|
275 |
+
|
276 |
+
# Keep a global conversation history for the console-based loop
|
277 |
+
conversation_history = []
|
278 |
+
|
279 |
def chatbot_conversation(country_selection: str):
|
280 |
"""
|
281 |
Starts the console-based conversation loop, using 'country_selection'
|
|
|
286 |
print(f"Country set to: {country_selection}\n")
|
287 |
|
288 |
print("Type 'exit' or 'quit' to end the conversation.")
|
289 |
+
|
290 |
while True:
|
291 |
user_input = input("\nYou: ")
|
292 |
if user_input.lower() in ["exit", "quit"]:
|
|
|
303 |
truncated_input = user_input[:HARD_LIMIT]
|
304 |
print("(System Note): Your message exceeded 2048 chars and was truncated.")
|
305 |
|
306 |
+
# 1) Append user turn to the conversation
|
307 |
+
conversation_history.append({"role": "user", "content": truncated_input})
|
308 |
+
|
309 |
+
# 2) Summarize older history if we exceed ~10 turns
|
310 |
+
maybe_summarize_history(conversation_history, max_turns=10)
|
311 |
+
|
312 |
+
# 3) Generate response with the entire conversation
|
313 |
+
bot_response = generate_response_with_history(conversation_history, country_selection)
|
314 |
+
|
315 |
+
# 4) Append the bot's response to the conversation
|
316 |
+
conversation_history.append({"role": "bot", "content": bot_response})
|
317 |
+
|
318 |
+
# 5) Display
|
319 |
print(bot_response)
|
320 |
|
321 |
# Ask for feedback
|
|
|
327 |
print("Please enter 'y' or 'n' or just hit Enter to skip.")
|
328 |
|
329 |
if rating in ["y", "n"]:
|
|
|
330 |
extra_comments = input("Any additional comments? (press Enter to skip): ")
|
331 |
rating_symbol = "👍" if rating == "y" else "👎"
|
332 |
save_user_feedback(truncated_input, bot_response, rating_symbol, extra_comments)
|
333 |
print("(Feedback saved.)")
|
334 |
|
335 |
+
|
336 |
########################################
|
337 |
# 7. TKINTER GUI WITH COUNTRY SELECTION
|
338 |
########################################
|
|
|
345 |
root.title("Mental Health Chatbot Prototype")
|
346 |
|
347 |
# Disclaimer message
|
348 |
+
disclaimer_label = tk.Label(
|
349 |
+
root,
|
350 |
+
text="Disclaimer: This is just a prototype and not a substitute for professional help.",
|
351 |
+
fg="red"
|
352 |
+
)
|
353 |
disclaimer_label.pack(padx=20, pady=10)
|
354 |
|
355 |
# Country selection label
|
|
|
374 |
|
375 |
if __name__ == "__main__":
|
376 |
main()
|
377 |
+
|