blackccpie commited on
Commit
16f8a25
·
1 Parent(s): d7d8164

feat : managing search credits display.

Browse files
Files changed (4) hide show
  1. agent.py +8 -1
  2. agent_ui.py +47 -23
  3. app.py +4 -3
  4. web_tools.py +27 -8
agent.py CHANGED
@@ -25,7 +25,7 @@ import os
25
  from smolagents import CodeAgent, InferenceClientModel
26
 
27
  from other_tools import ImageQueryTool
28
- from web_tools import TavilySearchTool, TavilyExtractTool, TavilyImageURLSearchTool
29
 
30
  class SmolAlbert(CodeAgent):
31
  """
@@ -81,3 +81,10 @@ class SmolAlbert(CodeAgent):
81
  Reset the agent's internal state.
82
  """
83
  self.agent.memory.reset()
 
 
 
 
 
 
 
 
25
  from smolagents import CodeAgent, InferenceClientModel
26
 
27
  from other_tools import ImageQueryTool
28
+ from web_tools import TavilyBaseClient, TavilySearchTool, TavilyExtractTool, TavilyImageURLSearchTool
29
 
30
  class SmolAlbert(CodeAgent):
31
  """
 
81
  Reset the agent's internal state.
82
  """
83
  self.agent.memory.reset()
84
+
85
+ @staticmethod
86
+ def get_search_credits() -> str:
87
+ """
88
+ Get the current search credits of the Tavily API.
89
+ """
90
+ return TavilyBaseClient.get_usage()
agent_ui.py CHANGED
@@ -42,6 +42,8 @@ from smolagents.agents import MultiStepAgent, PlanningStep
42
  from smolagents.memory import ActionStep, FinalAnswerStep
43
  from smolagents.models import ChatMessageStreamDelta, MessageRole, agglomerate_stream_deltas
44
 
 
 
45
  def get_step_footnote_content(step_log: ActionStep | PlanningStep, step_name: str) -> str:
46
  """Get a footnote string for a step log with duration and token information"""
47
  step_footnote = f"**{step_name}**"
@@ -218,7 +220,7 @@ def _process_final_answer_step(step_log: FinalAnswerStep) -> Generator:
218
  if isinstance(final_answer, AgentText):
219
  yield gr.ChatMessage(
220
  role=MessageRole.ASSISTANT,
221
- content=f"**Final answer:**\n{final_answer.to_string()}\n",
222
  metadata={"status": "done"},
223
  )
224
  elif isinstance(final_answer, AgentImage):
@@ -235,7 +237,7 @@ def _process_final_answer_step(step_log: FinalAnswerStep) -> Generator:
235
  )
236
  else:
237
  yield gr.ChatMessage(
238
- role=MessageRole.ASSISTANT, content=f"**Final answer:** {str(final_answer)}", metadata={"status": "done"}
239
  )
240
 
241
 
@@ -331,11 +333,11 @@ class AgentUI:
331
  content_text = msg.content if isinstance(msg.content, str) else ""
332
 
333
  # Detect final answer messages and append to quiet
334
- # HACK : FinalAnswerStep messages are produced by _process_final_answer_step and use "**Final answer:**" text
335
- if "Final answer:" in content_text:
336
- # Remove everything before and including the "Final answer:" label (and any leading/trailing whitespace/newlines)
337
  answer_only = re.sub(
338
- r"(?s)^.*?\*\*Final answer:\*\*\s*[\n]*", # (?s) allows . to match newlines
339
  "",
340
  content_text,
341
  flags=re.IGNORECASE,
@@ -429,10 +431,17 @@ class AgentUI:
429
  """
430
  self.create_app().launch(debug=True, share=share, **kwargs)
431
 
 
 
 
 
 
 
432
  def create_app(self):
433
  import gradio as gr
434
 
435
- with gr.Blocks(theme="glass", fill_height=True) as agent:
 
436
 
437
  # Set up states to hold the session information
438
  stored_query = gr.State("") # current user query
@@ -454,31 +463,38 @@ class AgentUI:
454
  )
455
  submit_btn = gr.Button("Submit", variant="primary")
456
 
 
 
 
 
 
 
 
457
  gr.HTML(
458
  "<br><br><h4><center>Powered by <a target='_blank' href='https://github.com/huggingface/smolagents'><b>smolagents</b></a></center></h4>"
459
  )
460
 
461
  with gr.Tab("Quiet", scale=1):
 
462
  quiet_chatbot = gr.Chatbot(
463
- label="Agent",
464
- type="messages",
465
- avatar_images=(
466
- None,
467
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
468
- ),
469
- resizeable=True,
470
- scale=1,
471
- latex_delimiters=[
472
- {"left": r"$$", "right": r"$$", "display": True},
473
- {"left": r"$", "right": r"$", "display": False},
474
- {"left": r"\[", "right": r"\]", "display": True},
475
- {"left": r"\(", "right": r"\)", "display": False},
476
- ],
477
- )
478
 
479
  with gr.Tab("Chatterbox", scale=1):
480
 
481
- # Main chat interface
482
  verbose_chatbot = gr.Chatbot(
483
  label="Agent",
484
  type="messages",
@@ -505,6 +521,10 @@ class AgentUI:
505
  self.interact_with_agent,
506
  [stored_query, stored_messages_verbose, stored_messages_quiet],
507
  [verbose_chatbot, quiet_chatbot],
 
 
 
 
508
  ).then(
509
  self.enable_query,
510
  None,
@@ -519,6 +539,10 @@ class AgentUI:
519
  self.interact_with_agent,
520
  [stored_query, stored_messages_verbose, stored_messages_quiet],
521
  [verbose_chatbot, quiet_chatbot],
 
 
 
 
522
  ).then(
523
  self.enable_query,
524
  None,
 
42
  from smolagents.memory import ActionStep, FinalAnswerStep
43
  from smolagents.models import ChatMessageStreamDelta, MessageRole, agglomerate_stream_deltas
44
 
45
+ FINAL_ANSWER_TAG = "Final answer:"
46
+
47
  def get_step_footnote_content(step_log: ActionStep | PlanningStep, step_name: str) -> str:
48
  """Get a footnote string for a step log with duration and token information"""
49
  step_footnote = f"**{step_name}**"
 
220
  if isinstance(final_answer, AgentText):
221
  yield gr.ChatMessage(
222
  role=MessageRole.ASSISTANT,
223
+ content=f"**{FINAL_ANSWER_TAG}**\n{final_answer.to_string()}\n",
224
  metadata={"status": "done"},
225
  )
226
  elif isinstance(final_answer, AgentImage):
 
237
  )
238
  else:
239
  yield gr.ChatMessage(
240
+ role=MessageRole.ASSISTANT, content=f"**{FINAL_ANSWER_TAG}** {str(final_answer)}", metadata={"status": "done"}
241
  )
242
 
243
 
 
333
  content_text = msg.content if isinstance(msg.content, str) else ""
334
 
335
  # Detect final answer messages and append to quiet
336
+ # HACK : FinalAnswerStep messages are produced by _process_final_answer_step and use FINAL_ANSWER_TAG
337
+ if FINAL_ANSWER_TAG in content_text:
338
+ # Remove everything before and including the FINAL_ANSWER_TAG label (and any leading/trailing whitespace/newlines)
339
  answer_only = re.sub(
340
+ rf"(?s)^.*?\*\*{FINAL_ANSWER_TAG}\*\*\s*[\n]*", # (?s) allows . to match newlines
341
  "",
342
  content_text,
343
  flags=re.IGNORECASE,
 
431
  """
432
  self.create_app().launch(debug=True, share=share, **kwargs)
433
 
434
+ def get_tavily_credits(self):
435
+ """
436
+ Fetch the Tavily credits.
437
+ """
438
+ return self.agent.get_search_credits()
439
+
440
  def create_app(self):
441
  import gradio as gr
442
 
443
+ # some nice thmes available here: https://huggingface.co/spaces/gradio/theme-gallery
444
+ with gr.Blocks(theme="JohnSmith9982/small_and_pretty", fill_height=True) as agent:
445
 
446
  # Set up states to hold the session information
447
  stored_query = gr.State("") # current user query
 
463
  )
464
  submit_btn = gr.Button("Submit", variant="primary")
465
 
466
+ tavily_credits = gr.Textbox(
467
+ label="Tavily Credits",
468
+ value=self.get_tavily_credits(),
469
+ interactive=False,
470
+ container=True,
471
+ )
472
+
473
  gr.HTML(
474
  "<br><br><h4><center>Powered by <a target='_blank' href='https://github.com/huggingface/smolagents'><b>smolagents</b></a></center></h4>"
475
  )
476
 
477
  with gr.Tab("Quiet", scale=1):
478
+
479
  quiet_chatbot = gr.Chatbot(
480
+ label="Agent",
481
+ type="messages",
482
+ avatar_images=(
483
+ None,
484
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
485
+ ),
486
+ resizeable=True,
487
+ scale=1,
488
+ latex_delimiters=[
489
+ {"left": r"$$", "right": r"$$", "display": True},
490
+ {"left": r"$", "right": r"$", "display": False},
491
+ {"left": r"\[", "right": r"\]", "display": True},
492
+ {"left": r"\(", "right": r"\)", "display": False},
493
+ ],
494
+ )
495
 
496
  with gr.Tab("Chatterbox", scale=1):
497
 
 
498
  verbose_chatbot = gr.Chatbot(
499
  label="Agent",
500
  type="messages",
 
521
  self.interact_with_agent,
522
  [stored_query, stored_messages_verbose, stored_messages_quiet],
523
  [verbose_chatbot, quiet_chatbot],
524
+ ).then(
525
+ self.get_tavily_credits,
526
+ None,
527
+ tavily_credits,
528
  ).then(
529
  self.enable_query,
530
  None,
 
539
  self.interact_with_agent,
540
  [stored_query, stored_messages_verbose, stored_messages_quiet],
541
  [verbose_chatbot, quiet_chatbot],
542
+ ).then(
543
+ self.get_tavily_credits,
544
+ None,
545
+ tavily_credits,
546
  ).then(
547
  self.enable_query,
548
  None,
app.py CHANGED
@@ -23,6 +23,7 @@
23
  from agent import SmolAlbert
24
  from agent_ui import AgentUI
25
 
26
- agent = SmolAlbert()
27
- agent_ui = AgentUI(agent)
28
- agent_ui.launch(share=False)
 
 
23
  from agent import SmolAlbert
24
  from agent_ui import AgentUI
25
 
26
+ if __name__ == "__main__":
27
+ agent = SmolAlbert()
28
+ agent_ui = AgentUI(agent)
29
+ agent_ui.launch(share=False)
web_tools.py CHANGED
@@ -21,13 +21,32 @@
21
  # THE SOFTWARE.
22
 
23
  import os
 
24
 
25
  from smolagents import Tool
26
  from tavily import TavilyClient
27
 
28
- tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
 
 
29
 
30
- class TavilySearchTool(Tool):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  """
32
  A tool to perform web searches using the Tavily API.
33
  """
@@ -42,10 +61,10 @@ class TavilySearchTool(Tool):
42
  output_type = "string"
43
 
44
  def forward(self, query: str):
45
- response = tavily_client.search(query)
46
  return response
47
-
48
- class TavilyExtractTool(Tool):
49
  """
50
  A tool to extract information from web pages using the Tavily API.
51
  """
@@ -60,10 +79,10 @@ class TavilyExtractTool(Tool):
60
  output_type = "string"
61
 
62
  def forward(self, url: str):
63
- response = tavily_client.extract(url)
64
  return response
65
 
66
- class TavilyImageURLSearchTool(Tool):
67
  """
68
  A tool to search for image URLs using the Tavily API.
69
  """
@@ -78,7 +97,7 @@ class TavilyImageURLSearchTool(Tool):
78
  output_type = "string"
79
 
80
  def forward(self, query: str):
81
- response = tavily_client.search(query, include_images=True)
82
 
83
  images = response.get("images", [])
84
 
 
21
  # THE SOFTWARE.
22
 
23
  import os
24
+ import requests
25
 
26
  from smolagents import Tool
27
  from tavily import TavilyClient
28
 
29
+ class TavilyBaseClient:
30
+ __api_key = os.getenv("TAVILY_API_KEY")
31
+ _tavily_client = TavilyClient(api_key=__api_key)
32
 
33
+ @staticmethod
34
+ def get_usage() -> str:
35
+ url = "https://api.tavily.com/usage"
36
+ headers = {
37
+ "Authorization": f"Bearer {TavilyBaseClient.__api_key}",
38
+ "Content-Type": "application/json",
39
+ }
40
+ res = requests.get(url, headers=headers)
41
+ res.raise_for_status()
42
+
43
+ account = res.json().get("account", {})
44
+ plan_usage = account.get("plan_usage")
45
+ plan_limit = account.get("plan_limit")
46
+
47
+ return f"{plan_usage}/{plan_limit}"
48
+
49
+ class TavilySearchTool(TavilyBaseClient, Tool):
50
  """
51
  A tool to perform web searches using the Tavily API.
52
  """
 
61
  output_type = "string"
62
 
63
  def forward(self, query: str):
64
+ response = self._tavily_client.search(query)
65
  return response
66
+
67
+ class TavilyExtractTool(TavilyBaseClient, Tool):
68
  """
69
  A tool to extract information from web pages using the Tavily API.
70
  """
 
79
  output_type = "string"
80
 
81
  def forward(self, url: str):
82
+ response = self._tavily_client.extract(url)
83
  return response
84
 
85
+ class TavilyImageURLSearchTool(TavilyBaseClient, Tool):
86
  """
87
  A tool to search for image URLs using the Tavily API.
88
  """
 
97
  output_type = "string"
98
 
99
  def forward(self, query: str):
100
+ response = self._tavily_client.search(query, include_images=True)
101
 
102
  images = response.get("images", [])
103