StevesInfinityDrive committed
Commit ff3b088 · verified · 1 Parent(s): 60f91b4

Update src/chatbot.py

Files changed (1)
  1. src/chatbot.py +136 -29
src/chatbot.py CHANGED
@@ -98,19 +98,61 @@ def is_political_stress(prompt: str) -> bool:
     return any(keyword in prompt.lower() for keyword in political_keywords)
 
 ########################################
-# 4. GENERATE RESPONSE
+# 4A. SUMMARIZATION HELPER (to keep conversation shorter)
 ########################################
-def generate_response(prompt: str, country: str) -> str:
-    lower_prompt = prompt.lower()
 
-    # 1. Self-harm detection
+def summarize_text(text: str) -> str:
+    """
+    Very basic summarization approach:
+    1) We feed a short prompt to the same model to summarize the text.
+    2) Return a short summary from the model.
+
+    In a real application, you might use a specialized summarization model.
+    """
+    summary_prompt = (
+        "Summarize this conversation in a concise way, focusing on key points:\n\n"
+        f"{text}\n\nSummary:"
+    )
+    inputs = tokenizer(summary_prompt, return_tensors="pt").to(device)
+    with torch.no_grad():
+        summary_outputs = model.generate(
+            **inputs,
+            max_new_tokens=128,
+            temperature=0.7,
+            top_p=0.9,
+            top_k=50,
+            repetition_penalty=1.2,
+            no_repeat_ngram_size=2,
+            do_sample=True
+        )
+    summary = tokenizer.decode(summary_outputs[0], skip_special_tokens=True)
+    # Basic cleanup
+    if "Summary:" in summary:
+        summary = summary.split("Summary:")[-1].strip()
+    return summary
+
+########################################
+# 4B. GPT-Style Response with Conversation History
+########################################
+
+def generate_response_with_history(conversation_history, country: str) -> str:
+    """
+    Generates a response based on the entire conversation history.
+    - conversation_history: a list of turns, each is { "role": "user"/"bot", "content": str }
+    - country: used for helpline logic if self-harm is detected in the LAST user message
+    """
+    # Check the last user message for self-harm, etc.
+    last_user_message = conversation_history[-1]["content"]
+    lower_prompt = last_user_message.lower()
+
+    # 1) Self-harm detection
     self_harm_keywords = [
-        "suicide", "kill myself", "end my life", "self-harm", ...
+        "suicide", "kill myself", "end my life", "self-harm",
+        "hurt myself", "want to die", "no reason to live", "tired of living"
     ]
     if any(keyword in lower_prompt for keyword in self_harm_keywords):
-        logging.info(f"Self-harm keyword detected in prompt: {prompt}")
+        logging.info(f"Self-harm keyword detected in prompt: {last_user_message}")
         helpline_str = get_helpline_for_country(country)
-
         return (
             "I’m really sorry you’re feeling like this. It sounds like you’re in a very tough place right now. "
             "If you’re comfortable, could you share more about what’s bringing you to feel this way? "
@@ -119,28 +161,34 @@ def generate_response(prompt: str, country: str) -> str:
             "You’re not alone, and there are caring people who want to help you."
         )
 
-    # 2. Harm-to-others detection...
-    # 3. Hate speech detection...
+    # 2) Harm-to-others / Hate speech can be similarly handled, or you can do a single pass prior to generation.
 
-    # 4. If no major triggers:
+    # 3) Construct a system prompt to direct the style:
     system_prompt = (
         "You are a supportive, empathetic companion. "
         "Your top priority is to listen and help the user feel heard. "
         "You:\n"
         "- Acknowledge the user's feelings.\n"
-        "- Ask gentle follow-up questions to encourage them to share more.\n"
-        "- Suggest small, practical ways they might cope or self-soothe.\n"
+        "- Ask gentle follow-up questions.\n"
+        "- Suggest small, practical coping ideas.\n"
         "- Avoid judgment or lecturing.\n"
-        "- Offer additional support if it seems they are in crisis.\n"
-        "But do NOT prefix your messages with 'User:' or 'Bot:'.\n"
+        "- Provide helpline info only if user is in crisis.\n"
+        "Do NOT prefix lines with 'User:' or 'Bot:'.\n"
     )
 
-    # Political stress addition if needed...
-    # system_prompt += ...
+    # 4) Build a single text prompt from the entire conversation
+    #    Start with system instructions:
+    conversation_text = system_prompt
 
-    full_prompt = f"{system_prompt}\nThe user says: {prompt}\nYour response:"
-    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
+    for turn in conversation_history:
+        if turn["role"] == "user":
+            conversation_text += f"\nUser says: {turn['content']}"
+        else:
+            conversation_text += f"\nBot says: {turn['content']}"
+    conversation_text += "\nBot says:"
 
+    # 5) Convert text to tokens and generate
+    inputs = tokenizer(conversation_text, return_tensors="pt").to(device)
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
@@ -154,15 +202,53 @@ def generate_response(prompt: str, country: str) -> str:
         )
 
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    # Clean up
-    response = response.replace(system_prompt, "").replace("The user says:", "").replace("Your response:", "")
+
+    # 6) Clean up any repeated system text
+    response = response.replace(system_prompt, "")
+    response = response.replace("User says:", "").replace("Bot says:", "")
+    response = response.strip()
+
+    # Make sure it ends with punctuation
     if response and response[-1] not in ".!?":
-        if "." in response:
-            response = response.rsplit(".", 1)[0].strip() + "."
-        else:
-            response += "."
+        response += "."
 
-    return response.strip()
+    return response
+
+
+########################################
+# 4C. Possibly Summarize Older History
+########################################
+
+def maybe_summarize_history(conversation_history, max_turns=10):
+    """
+    If conversation_history is too long, summarize the earliest turns into a single 'summary so far'
+    turn. This helps keep token usage manageable.
+    """
+    if len(conversation_history) > max_turns:
+        # Separate out the first chunk (e.g. first 5 turns).
+        old_turns = conversation_history[:-5]
+        recent_turns = conversation_history[-5:]
+
+        # Build a text from old_turns
+        old_text = ""
+        for turn in old_turns:
+            role = turn["role"]
+            content = turn["content"]
+            old_text += f"{role.upper()}:\n{content}\n\n"
+
+        summary = summarize_text(old_text)
+        # Now store that summary as a single turn with role "bot" (or "system")
+        summary_turn = {
+            "role": "bot",
+            "content": f"(A summary of earlier conversation: {summary})"
+        }
+
+        # Rebuild the conversation: summary + recent 5 turns
+        new_history = [summary_turn] + recent_turns
+
+        # Clear and replace
+        conversation_history.clear()
+        conversation_history.extend(new_history)
 
 
 ########################################
@@ -186,6 +272,10 @@ def save_user_feedback(user_input, bot_response, rating, extra_comments=""):
 ########################################
 # 6. CONVERSATION LOOP
 ########################################
+
+# Keep a global conversation history for the console-based loop
+conversation_history = []
+
 def chatbot_conversation(country_selection: str):
     """
     Starts the console-based conversation loop, using 'country_selection'
@@ -196,6 +286,7 @@ def chatbot_conversation(country_selection: str):
     print(f"Country set to: {country_selection}\n")
 
     print("Type 'exit' or 'quit' to end the conversation.")
+
     while True:
         user_input = input("\nYou: ")
         if user_input.lower() in ["exit", "quit"]:
@@ -212,8 +303,19 @@ def chatbot_conversation(country_selection: str):
             truncated_input = user_input[:HARD_LIMIT]
             print("(System Note): Your message exceeded 2048 chars and was truncated.")
 
-        # Generate & display response
-        bot_response = generate_response(truncated_input, country_selection)
+        # 1) Append user turn to the conversation
+        conversation_history.append({"role": "user", "content": truncated_input})
+
+        # 2) Summarize older history if we exceed ~10 turns
+        maybe_summarize_history(conversation_history, max_turns=10)
+
+        # 3) Generate response with the entire conversation
+        bot_response = generate_response_with_history(conversation_history, country_selection)
+
+        # 4) Append the bot's response to the conversation
+        conversation_history.append({"role": "bot", "content": bot_response})
+
+        # 5) Display
         print(bot_response)
 
         # Ask for feedback
@@ -225,12 +327,12 @@ def chatbot_conversation(country_selection: str):
             print("Please enter 'y' or 'n' or just hit Enter to skip.")
 
         if rating in ["y", "n"]:
-            # Optional additional comments
             extra_comments = input("Any additional comments? (press Enter to skip): ")
             rating_symbol = "👍" if rating == "y" else "👎"
             save_user_feedback(truncated_input, bot_response, rating_symbol, extra_comments)
             print("(Feedback saved.)")
 
+
 ########################################
 # 7. TKINTER GUI WITH COUNTRY SELECTION
 ########################################
@@ -243,7 +345,11 @@ def main():
     root.title("Mental Health Chatbot Prototype")
 
     # Disclaimer message
-    disclaimer_label = tk.Label(root, text="Disclaimer: This is just a prototype and not a substitute for professional help.", fg="red")
+    disclaimer_label = tk.Label(
+        root,
+        text="Disclaimer: This is just a prototype and not a substitute for professional help.",
+        fg="red"
+    )
     disclaimer_label.pack(padx=20, pady=10)
 
     # Country selection label
@@ -268,3 +374,4 @@ def main():
 
 if __name__ == "__main__":
     main()
+
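
Note (not part of the commit): the sketch below illustrates how the newly added pieces are intended to interact — turns are stored as {"role", "content"} dicts, maybe_summarize_history() folds older turns into a single summary turn, and generate_response_with_history() answers from the trimmed history. The model-backed calls are stubbed out here (stub_generate and stub_summarize are hypothetical stand-ins), so the flow can be run without the tokenizer, model, and device that src/chatbot.py loads elsewhere.

from typing import Dict, List

def stub_summarize(text: str) -> str:
    # Hypothetical stand-in for summarize_text(): truncate instead of calling the model.
    return text[:80] + "..."

def stub_generate(history: List[Dict[str, str]], country: str) -> str:
    # Hypothetical stand-in for generate_response_with_history(): echo the last user turn.
    return f"I hear you. You said: {history[-1]['content']}"

def trim_history(history: List[Dict[str, str]], max_turns: int = 10) -> None:
    # Mirrors maybe_summarize_history(): collapse everything but the last 5 turns into one summary turn.
    if len(history) > max_turns:
        old, recent = history[:-5], history[-5:]
        old_text = "\n".join(f"{t['role'].upper()}: {t['content']}" for t in old)
        history[:] = [{"role": "bot", "content": f"(A summary of earlier conversation: {stub_summarize(old_text)})"}] + recent

history: List[Dict[str, str]] = []
for i in range(8):
    history.append({"role": "user", "content": f"user message {i}"})
    trim_history(history, max_turns=10)
    reply = stub_generate(history, "UK")
    history.append({"role": "bot", "content": reply})

print(len(history))  # never grows far past max_turns; older turns get folded into one summary turn

The in-place mutation matters (history[:] = ... here, clear()/extend() in the commit): chatbot_conversation() keeps using the same module-level conversation_history list, so replacing the list object instead of mutating it would silently drop the trimmed history.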