zakerytclarke committed on
Commit
90ef2fa
·
verified ·
1 Parent(s): 9db1f1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -35
app.py CHANGED
@@ -53,43 +53,21 @@ async def brave_search(query, count=1):
53
  print(f"Error: {response.status}, {await response.text()}")
54
  return []
55
 
56
- @traceable
57
- @log_time
58
- def query_teapot(prompt, context, user_input):
59
- input_text = prompt + "\n" + context + "\n" + user_input
60
-
61
- start_time = time.time()
62
-
63
- inputs = tokenizer(input_text, return_tensors="pt")
64
- input_length = inputs["input_ids"].shape[1]
65
-
66
- output = model.generate(**inputs, max_new_tokens=512)
67
-
68
- output_text = tokenizer.decode(output[0], skip_special_tokens=True)
69
- total_length = output.shape[1] # Includes both input and output tokens
70
- output_length = total_length - input_length # Extract output token count
71
-
72
- end_time = time.time()
73
-
74
- elapsed_time = end_time - start_time
75
- tokens_per_second = total_length / elapsed_time if elapsed_time > 0 else float("inf")
76
-
77
- return output_text
78
-
79
 
80
- # pipeline_lock = asyncio.Lock()
81
 
82
- # @traceable
83
- # @log_time
84
- # async def query_teapot(prompt, context, user_input):
85
- # input_text = prompt + "\n" + context + "\n" + user_input
86
- # inputs = tokenizer(input_text, return_tensors="pt")
87
 
88
- # async with pipeline_lock: # Ensure only one call runs at a time
89
- # output = await asyncio.to_thread(model.generate, **inputs, max_new_tokens=512)
90
 
91
- # output_text = tokenizer.decode(output[0], skip_special_tokens=True)
92
- # return output_text
 
 
 
 
 
 
 
 
93
 
94
 
95
  @log_time
@@ -172,5 +150,11 @@ async def on_message(message):
172
  await thread.send(debug_info)
173
 
174
 
175
- # Run the bot with your token
176
- client.run(DISCORD_TOKEN)
 
 
 
 
 
 
 
53
  print(f"Error: {response.status}, {await response.text()}")
54
  return []
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
 
57
 
 
 
 
 
 
58
 
59
+ pipeline_lock = asyncio.Lock()
 
60
 
61
+ @traceable
62
+ @log_time
63
+ async def query_teapot(prompt, context, user_input):
64
+ input_text = prompt + "\n" + context + "\n" + user_input
65
+ inputs = tokenizer(input_text, return_tensors="pt")
66
+
67
+ async with pipeline_lock: # Ensure only one call runs at a time
68
+ output = await asyncio.to_thread(model.generate, **inputs, max_new_tokens=512)
69
+ output_text = tokenizer.decode(output[0], skip_special_tokens=True)
70
+ return output_text
71
 
72
 
73
  @log_time
 
150
  await thread.send(debug_info)
151
 
152
 
153
+ @st.cache_resource
154
+ def initialize():
155
+ st.session_state["initialized"] = True
156
+ client.run(DISCORD_TOKEN)
157
+
158
+ return
159
+
160
+ initialize()