zakerytclarke committed on
Commit
d97238d
·
verified ·
1 Parent(s): d6c6437

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -14
app.py CHANGED
@@ -53,29 +53,45 @@ async def brave_search(query, count=1):
53
  print(f"Error: {response.status}, {await response.text()}")
54
  return []
55
 
56
- @traceable
57
- @log_time
58
- def query_teapot(prompt, context, user_input):
59
- input_text = prompt + "\n" + context + "\n" + user_input
60
 
61
- start_time = time.time()
62
 
63
- inputs = tokenizer(input_text, return_tensors="pt")
64
- input_length = inputs["input_ids"].shape[1]
65
 
66
- output = model.generate(**inputs, max_new_tokens=512)
67
 
68
- output_text = tokenizer.decode(output[0], skip_special_tokens=True)
69
- total_length = output.shape[1] # Includes both input and output tokens
70
- output_length = total_length - input_length # Extract output token count
71
 
72
- end_time = time.time()
73
 
74
- elapsed_time = end_time - start_time
75
- tokens_per_second = total_length / elapsed_time if elapsed_time > 0 else float("inf")
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  return output_text
78
 
 
79
  @log_time
80
  async def handle_chat(user_input):
81
  search_start_time = time.time()
 
53
  print(f"Error: {response.status}, {await response.text()}")
54
  return []
55
 
56
+ # @traceable
57
+ # @log_time
58
+ # def query_teapot(prompt, context, user_input):
59
+ # input_text = prompt + "\n" + context + "\n" + user_input
60
 
61
+ # start_time = time.time()
62
 
63
+ # inputs = tokenizer(input_text, return_tensors="pt")
64
+ # input_length = inputs["input_ids"].shape[1]
65
 
66
+ # output = model.generate(**inputs, max_new_tokens=512)
67
 
68
+ # output_text = tokenizer.decode(output[0], skip_special_tokens=True)
69
+ # total_length = output.shape[1] # Includes both input and output tokens
70
+ # output_length = total_length - input_length # Extract output token count
71
 
72
+ # end_time = time.time()
73
 
74
+ # elapsed_time = end_time - start_time
75
+ # tokens_per_second = total_length / elapsed_time if elapsed_time > 0 else float("inf")
76
 
77
+ # return output_text
78
+
79
+
80
+ pipeline_lock = asyncio.Lock()
81
+
82
+ @traceable
83
+ @log_time
84
+ async def query_teapot(prompt, context, user_input):
85
+ input_text = prompt + "\n" + context + "\n" + user_input
86
+ inputs = tokenizer(input_text, return_tensors="pt")
87
+
88
+ async with pipeline_lock: # Ensure only one call runs at a time
89
+ output = await asyncio.to_thread(model.generate, **inputs, max_new_tokens=512)
90
+
91
+ output_text = tokenizer.decode(output[0], skip_special_tokens=True)
92
  return output_text
93
 
94
+
95
  @log_time
96
  async def handle_chat(user_input):
97
  search_start_time = time.time()