zakerytclarke committed on
Commit
e2b19bb
·
verified ·
1 Parent(s): 3947bbb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -31
app.py CHANGED
@@ -1,14 +1,13 @@
1
  import streamlit as st
2
  import hashlib
3
  import os
4
- import requests
 
5
  import time
6
  from langsmith import traceable
7
  import random
8
  import discord
9
- import os
10
  from transformers import pipeline
11
- import torch
12
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
  import numpy as np
14
  from sklearn.metrics.pairwise import cosine_similarity
@@ -18,8 +17,6 @@ from tqdm import tqdm
18
  import re
19
  import os
20
 
21
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
22
-
23
  st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide")
24
  tokenizer = None
25
  model = None
@@ -28,11 +25,10 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
28
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
29
 
30
 
31
-
32
  def log_time(func):
33
- def wrapper(*args, **kwargs):
34
  start_time = time.time()
35
- result = func(*args, **kwargs)
36
  end_time = time.time()
37
  print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds")
38
  return result
@@ -42,24 +38,25 @@ def log_time(func):
42
  API_KEY = os.environ.get("brave_api_key")
43
 
44
  @log_time
45
- def brave_search(query, count=3):
46
  url = "https://api.search.brave.com/res/v1/web/search"
47
  headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
48
  params = {"q": query, "count": count}
49
-
50
- response = requests.get(url, headers=headers, params=params)
51
-
52
- if response.status_code == 200:
53
- results = response.json().get("web", {}).get("results", [])
54
- print(results)
55
- return [(res["title"], res["description"], res["url"]) for res in results]
56
- else:
57
- print(f"Error: {response.status_code}, {response.text}")
58
- return []
 
59
 
60
  @traceable
61
  @log_time
62
- def query_teapot(prompt, context, user_input):
63
  input_text = prompt + "\n" + context + "\n" + user_input
64
  print(input_text)
65
  start_time = time.time()
@@ -67,7 +64,7 @@ def query_teapot(prompt, context, user_input):
67
  inputs = tokenizer(input_text, return_tensors="pt")
68
  input_length = inputs["input_ids"].shape[1]
69
 
70
- output = model.generate(**inputs, max_new_tokens=512)
71
 
72
  output_text = tokenizer.decode(output[0], skip_special_tokens=True)
73
  total_length = output.shape[1] # Includes both input and output tokens
@@ -81,19 +78,18 @@ def query_teapot(prompt, context, user_input):
81
  return output_text
82
 
83
 
84
-
85
  @log_time
86
- def handle_chat(user_input):
87
  search_start_time = time.time()
88
- results = brave_search(user_input)
89
  search_end_time = time.time()
90
 
91
- documents = [desc.replace('<strong>','').replace('</strong>','') for _, desc, _ in results]
92
 
93
  context = "\n".join(documents)
94
  prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization. If a user asks who you are reply "I am Teapot"."""
95
  generation_start_time = time.time()
96
- response = query_teapot(prompt, context, user_input)
97
  generation_end_time = time.time()
98
 
99
  debug_info = f"""
@@ -108,9 +104,9 @@ Generation time: {generation_end_time - generation_start_time:.2f} seconds
108
  Response: {response}
109
  """
110
 
111
-
112
  return response, debug_info
113
 
 
114
  st.write("418 I'm a teapot")
115
 
116
  DISCORD_TOKEN = os.environ.get("discord_key")
@@ -135,11 +131,10 @@ async def on_message(message):
135
  return
136
  print(message.content)
137
 
138
-
139
  is_debug = "<debug>" in message.content
140
  async with message.channel.typing():
141
  # Respond with "pong" if the message contains "ping"
142
- response, debug_info = handle_chat(message.content.replace("<debug>","").replace("</debug>",""))
143
 
144
  print(response)
145
  sent_message = await message.reply(response)
@@ -153,6 +148,4 @@ async def on_message(message):
153
 
154
 
155
  # Run the bot with your token
156
-
157
  client.run(DISCORD_TOKEN)
158
-
 
1
  import streamlit as st
2
  import hashlib
3
  import os
4
+ import aiohttp
5
+ import asyncio
6
  import time
7
  from langsmith import traceable
8
  import random
9
  import discord
 
10
  from transformers import pipeline
 
11
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
12
  import numpy as np
13
  from sklearn.metrics.pairwise import cosine_similarity
 
17
  import re
18
  import os
19
 
 
 
20
  st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide")
21
  tokenizer = None
22
  model = None
 
25
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
26
 
27
 
 
28
  def log_time(func):
29
+ async def wrapper(*args, **kwargs):
30
  start_time = time.time()
31
+ result = await func(*args, **kwargs) # Make it awaitable
32
  end_time = time.time()
33
  print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds")
34
  return result
 
38
  API_KEY = os.environ.get("brave_api_key")
39
 
40
  @log_time
41
+ async def brave_search(query, count=3):
42
  url = "https://api.search.brave.com/res/v1/web/search"
43
  headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
44
  params = {"q": query, "count": count}
45
+
46
+ async with aiohttp.ClientSession() as session:
47
+ async with session.get(url, headers=headers, params=params) as response:
48
+ if response.status == 200:
49
+ results = await response.json()
50
+ print(results)
51
+ return [(res["title"], res["description"], res["url"]) for res in results.get("web", {}).get("results", [])]
52
+ else:
53
+ print(f"Error: {response.status}, {await response.text()}")
54
+ return []
55
+
56
 
57
  @traceable
58
  @log_time
59
+ async def query_teapot(prompt, context, user_input):
60
  input_text = prompt + "\n" + context + "\n" + user_input
61
  print(input_text)
62
  start_time = time.time()
 
64
  inputs = tokenizer(input_text, return_tensors="pt")
65
  input_length = inputs["input_ids"].shape[1]
66
 
67
+ output = await asyncio.to_thread(model.generate, **inputs, max_new_tokens=512)
68
 
69
  output_text = tokenizer.decode(output[0], skip_special_tokens=True)
70
  total_length = output.shape[1] # Includes both input and output tokens
 
78
  return output_text
79
 
80
 
 
81
  @log_time
82
+ async def handle_chat(user_input):
83
  search_start_time = time.time()
84
+ results = await brave_search(user_input)
85
  search_end_time = time.time()
86
 
87
+ documents = [desc.replace('<strong>', '').replace('</strong>', '') for _, desc, _ in results]
88
 
89
  context = "\n".join(documents)
90
  prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization. If a user asks who you are reply "I am Teapot"."""
91
  generation_start_time = time.time()
92
+ response = await query_teapot(prompt, context, user_input)
93
  generation_end_time = time.time()
94
 
95
  debug_info = f"""
 
104
  Response: {response}
105
  """
106
 
 
107
  return response, debug_info
108
 
109
+
110
  st.write("418 I'm a teapot")
111
 
112
  DISCORD_TOKEN = os.environ.get("discord_key")
 
131
  return
132
  print(message.content)
133
 
 
134
  is_debug = "<debug>" in message.content
135
  async with message.channel.typing():
136
  # Respond with "pong" if the message contains "ping"
137
+ response, debug_info = await handle_chat(message.content.replace("<debug>", "").replace("</debug>", ""))
138
 
139
  print(response)
140
  sent_message = await message.reply(response)
 
148
 
149
 
150
  # Run the bot with your token
 
151
  client.run(DISCORD_TOKEN)