Spaces:
Sleeping
Sleeping
Upload 7 files
Browse files- app.py +15 -14
- core/agent.py +35 -46
- core/tool_recommender.py +23 -13
- database/setup.py +38 -27
- tools/news_tool.py +7 -7
- tools/stock_tool.py +8 -8
- tools/tool_registry.py +11 -12
app.py
CHANGED
@@ -3,39 +3,41 @@ import os
|
|
3 |
import time
|
4 |
|
5 |
# ------------------------------------------------------------------
|
6 |
-
# 1.
|
7 |
# ------------------------------------------------------------------
|
8 |
api_key = os.environ.get("GEMINI_API_KEY")
|
9 |
if not api_key:
|
10 |
-
print(
|
|
|
|
|
11 |
|
12 |
# ------------------------------------------------------------------
|
13 |
-
# 2.
|
14 |
# ------------------------------------------------------------------
|
15 |
-
print("---
|
16 |
try:
|
17 |
from database.setup import initialize_system
|
18 |
from core.agent import SmartAIAgent
|
19 |
|
20 |
-
print("
|
21 |
|
22 |
registered_tools, tool_recommender = initialize_system()
|
23 |
-
print("
|
24 |
agent = SmartAIAgent(
|
25 |
tool_recommender=tool_recommender,
|
26 |
registered_tools=registered_tools,
|
27 |
api_key=api_key,
|
28 |
)
|
29 |
-
print("AI
|
30 |
except Exception as e:
|
31 |
-
print(f"
|
32 |
agent = None
|
33 |
|
34 |
-
print("--- FeiMatrix Synapse
|
35 |
|
36 |
|
37 |
# ------------------------------------------------------------------
|
38 |
-
# 3. Gradio
|
39 |
# ------------------------------------------------------------------
|
40 |
def handle_user_message(user_input, history):
|
41 |
if not user_input.strip():
|
@@ -72,7 +74,7 @@ def generate_bot_response(history):
|
|
72 |
|
73 |
|
74 |
# ------------------------------------------------------------------
|
75 |
-
# 4.
|
76 |
# ------------------------------------------------------------------
|
77 |
custom_css = """
|
78 |
#chatbot .message-bubble-content { color: #000000 !important; }
|
@@ -87,7 +89,6 @@ with gr.Blocks(
|
|
87 |
title="FeiMatrix Synapse",
|
88 |
) as demo:
|
89 |
|
90 |
-
# --- 界面文本已全部修改为英文 ---
|
91 |
gr.Markdown(
|
92 |
"""
|
93 |
# 🚀 FeiMatrix Synapse - Intelligent AI Assistant
|
@@ -154,13 +155,13 @@ with gr.Blocks(
|
|
154 |
elem_classes="footer",
|
155 |
)
|
156 |
|
157 |
-
# --- 对话事件的触发流程 (保持不变) ---
|
158 |
submit_event = text_input.submit(
|
159 |
fn=handle_user_message,
|
160 |
inputs=[text_input, chatbot],
|
161 |
outputs=[text_input, chatbot],
|
162 |
queue=False,
|
163 |
).then(fn=generate_bot_response, inputs=[chatbot], outputs=[chatbot])
|
|
|
164 |
submit_button.click(
|
165 |
fn=handle_user_message,
|
166 |
inputs=[text_input, chatbot],
|
@@ -169,7 +170,7 @@ with gr.Blocks(
|
|
169 |
).then(fn=generate_bot_response, inputs=[chatbot], outputs=[chatbot])
|
170 |
|
171 |
# ------------------------------------------------------------------
|
172 |
-
# 5.
|
173 |
# ------------------------------------------------------------------
|
174 |
if __name__ == "__main__":
|
175 |
demo.queue()
|
|
|
3 |
import time
|
4 |
|
5 |
# ------------------------------------------------------------------
|
6 |
+
# 1. Load Environment Variables
|
7 |
# ------------------------------------------------------------------
|
8 |
api_key = os.environ.get("GEMINI_API_KEY")
|
9 |
if not api_key:
|
10 |
+
print(
|
11 |
+
"Warning: GEMINI_API_KEY not found. Please set it in your Hugging Face Spaces Secrets."
|
12 |
+
)
|
13 |
|
14 |
# ------------------------------------------------------------------
|
15 |
+
# 2. Initialize Backend
|
16 |
# ------------------------------------------------------------------
|
17 |
+
print("--- Starting FeiMatrix Synapse System ---")
|
18 |
try:
|
19 |
from database.setup import initialize_system
|
20 |
from core.agent import SmartAIAgent
|
21 |
|
22 |
+
print("Core modules imported successfully.")
|
23 |
|
24 |
registered_tools, tool_recommender = initialize_system()
|
25 |
+
print("System database and tool recommender initialized successfully.")
|
26 |
agent = SmartAIAgent(
|
27 |
tool_recommender=tool_recommender,
|
28 |
registered_tools=registered_tools,
|
29 |
api_key=api_key,
|
30 |
)
|
31 |
+
print("AI Agent Core created successfully.")
|
32 |
except Exception as e:
|
33 |
+
print(f"A critical error occurred during system initialization: {e}")
|
34 |
agent = None
|
35 |
|
36 |
+
print("--- FeiMatrix Synapse is ready ---")
|
37 |
|
38 |
|
39 |
# ------------------------------------------------------------------
|
40 |
+
# 3. Gradio Event Handler Functions
|
41 |
# ------------------------------------------------------------------
|
42 |
def handle_user_message(user_input, history):
|
43 |
if not user_input.strip():
|
|
|
74 |
|
75 |
|
76 |
# ------------------------------------------------------------------
|
77 |
+
# 4. Create Gradio Interface
|
78 |
# ------------------------------------------------------------------
|
79 |
custom_css = """
|
80 |
#chatbot .message-bubble-content { color: #000000 !important; }
|
|
|
89 |
title="FeiMatrix Synapse",
|
90 |
) as demo:
|
91 |
|
|
|
92 |
gr.Markdown(
|
93 |
"""
|
94 |
# 🚀 FeiMatrix Synapse - Intelligent AI Assistant
|
|
|
155 |
elem_classes="footer",
|
156 |
)
|
157 |
|
|
|
158 |
submit_event = text_input.submit(
|
159 |
fn=handle_user_message,
|
160 |
inputs=[text_input, chatbot],
|
161 |
outputs=[text_input, chatbot],
|
162 |
queue=False,
|
163 |
).then(fn=generate_bot_response, inputs=[chatbot], outputs=[chatbot])
|
164 |
+
|
165 |
submit_button.click(
|
166 |
fn=handle_user_message,
|
167 |
inputs=[text_input, chatbot],
|
|
|
170 |
).then(fn=generate_bot_response, inputs=[chatbot], outputs=[chatbot])
|
171 |
|
172 |
# ------------------------------------------------------------------
|
173 |
+
# 5. Launch the Application
|
174 |
# ------------------------------------------------------------------
|
175 |
if __name__ == "__main__":
|
176 |
demo.queue()
|
core/agent.py
CHANGED
@@ -3,32 +3,32 @@ from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
|
3 |
from typing import List, Any
|
4 |
import json
|
5 |
import os
|
6 |
-
import re
|
7 |
|
8 |
from .tool_recommender import DirectToolRecommender
|
9 |
from tools.tool_registry import get_tool_by_name
|
10 |
|
11 |
-
# Agent
|
12 |
AGENT_PROMPT_TEMPLATE = """
|
13 |
-
|
14 |
|
15 |
-
|
16 |
{tools}
|
17 |
|
18 |
-
|
19 |
{{
|
20 |
-
"tool": "
|
21 |
-
"tool_input": {{ "
|
22 |
}}
|
23 |
|
24 |
-
|
25 |
|
26 |
-
|
27 |
{chat_history}
|
28 |
|
29 |
-
|
30 |
|
31 |
-
|
32 |
"""
|
33 |
|
34 |
|
@@ -48,35 +48,28 @@ class SmartAIAgent:
|
|
48 |
convert_system_message_to_human=True,
|
49 |
)
|
50 |
self.chat_history = []
|
51 |
-
print(f"LangChain Agent
|
52 |
|
53 |
-
# ------------------- 核心修复在这里! -------------------
|
54 |
-
# 我们添加一个更健壮的JSON提取函数
|
55 |
def _extract_json_from_string(self, text: str) -> dict | None:
|
56 |
-
"""
|
57 |
-
# 匹配被 markdown 包裹的JSON
|
58 |
match = re.search(r"```json\s*(\{.*?\})\s*```", text, re.DOTALL)
|
59 |
if match:
|
60 |
json_str = match.group(1)
|
61 |
else:
|
62 |
-
# 匹配裸露的JSON
|
63 |
match = re.search(r"\{.*\}", text, re.DOTALL)
|
64 |
if match:
|
65 |
json_str = match.group(0)
|
66 |
else:
|
67 |
return None
|
68 |
-
|
69 |
try:
|
70 |
return json.loads(json_str)
|
71 |
except json.JSONDecodeError:
|
72 |
return None
|
73 |
|
74 |
-
# ----------------------------------------------------
|
75 |
-
|
76 |
def _format_tools_for_prompt(self, tools: List[dict]) -> str:
|
77 |
-
|
78 |
if not tools:
|
79 |
-
return "
|
80 |
tool_strings = []
|
81 |
for tool in tools:
|
82 |
try:
|
@@ -85,44 +78,45 @@ class SmartAIAgent:
|
|
85 |
[f"{p_name}: {p_type}" for p_name, p_type in params.items()]
|
86 |
)
|
87 |
tool_strings.append(
|
88 |
-
f"-
|
89 |
)
|
90 |
except (json.JSONDecodeError, TypeError):
|
91 |
tool_strings.append(
|
92 |
-
f"-
|
93 |
)
|
94 |
return "\n".join(tool_strings)
|
95 |
|
96 |
def _format_chat_history(self) -> str:
|
97 |
-
|
98 |
formatted_history = []
|
99 |
for msg in self.chat_history:
|
100 |
if isinstance(msg, HumanMessage):
|
101 |
-
formatted_history.append(f"
|
102 |
elif isinstance(msg, AIMessage):
|
103 |
-
formatted_history.append(f"
|
104 |
elif isinstance(msg, ToolMessage):
|
105 |
-
formatted_history.append(f"
|
106 |
return "\n".join(formatted_history)
|
107 |
|
108 |
def stream_run(self, user_input: str):
|
|
|
109 |
self.chat_history.append(HumanMessage(content=user_input))
|
110 |
-
yield "🤔
|
111 |
|
112 |
-
yield "🔍
|
113 |
recommended_tools_meta = self.tool_recommender.recommend_tools(user_input)
|
114 |
|
115 |
if not recommended_tools_meta:
|
116 |
-
yield "ℹ️
|
117 |
-
recommended_tools_prompt = "
|
118 |
else:
|
119 |
tool_names = [t["name"] for t in recommended_tools_meta]
|
120 |
-
yield f"✅
|
121 |
recommended_tools_prompt = self._format_tools_for_prompt(
|
122 |
recommended_tools_meta
|
123 |
)
|
124 |
|
125 |
-
yield f"🧠
|
126 |
prompt = AGENT_PROMPT_TEMPLATE.format(
|
127 |
tools=recommended_tools_prompt,
|
128 |
chat_history=self._format_chat_history(),
|
@@ -132,22 +126,19 @@ class SmartAIAgent:
|
|
132 |
llm_response = self.llm.invoke(prompt)
|
133 |
llm_decision_content = llm_response.content.strip()
|
134 |
|
135 |
-
# ------------------- 核心修复在这里! -------------------
|
136 |
-
# 使用我们新的、更健壮的JSON提取逻辑
|
137 |
decision = self._extract_json_from_string(llm_decision_content)
|
138 |
|
139 |
if decision and "tool" in decision and "tool_input" in decision:
|
140 |
-
# 如果成功提取出有效的工具调用JSON
|
141 |
tool_name = decision.get("tool")
|
142 |
tool_input = decision.get("tool_input")
|
143 |
|
144 |
-
yield f"💡 AI
|
145 |
|
146 |
tool_to_execute = get_tool_by_name(tool_name)
|
147 |
if tool_to_execute:
|
148 |
-
yield f"⚙️
|
149 |
tool_output = tool_to_execute.invoke(tool_input)
|
150 |
-
yield f"📊
|
151 |
|
152 |
self.chat_history.append(
|
153 |
AIMessage(content=json.dumps(decision, ensure_ascii=False))
|
@@ -156,8 +147,8 @@ class SmartAIAgent:
|
|
156 |
ToolMessage(content=str(tool_output), tool_call_id="N/A")
|
157 |
)
|
158 |
|
159 |
-
yield "✍️
|
160 |
-
final_answer_prompt = f"
|
161 |
final_answer_stream = self.llm.stream(final_answer_prompt)
|
162 |
full_final_answer = ""
|
163 |
for chunk in final_answer_stream:
|
@@ -165,10 +156,8 @@ class SmartAIAgent:
|
|
165 |
full_final_answer += chunk.content
|
166 |
self.chat_history.append(AIMessage(content=full_final_answer))
|
167 |
else:
|
168 |
-
yield f"❌
|
169 |
else:
|
170 |
-
|
171 |
-
yield "✅ AI决策:直接回答。\n\n"
|
172 |
yield llm_decision_content
|
173 |
self.chat_history.append(AIMessage(content=llm_decision_content))
|
174 |
-
# ----------------------------------------------------
|
|
|
3 |
from typing import List, Any
|
4 |
import json
|
5 |
import os
|
6 |
+
import re
|
7 |
|
8 |
from .tool_recommender import DirectToolRecommender
|
9 |
from tools.tool_registry import get_tool_by_name
|
10 |
|
11 |
+
# --- Agent Prompt, now fully in English ---
|
12 |
AGENT_PROMPT_TEMPLATE = """
|
13 |
+
You are a powerful AI assistant. Your task is to understand the user's question and decide if a tool is needed to answer it.
|
14 |
|
15 |
+
You have the following tools available:
|
16 |
{tools}
|
17 |
|
18 |
+
If you need to use a tool, you must respond in the following JSON format strictly, without any other text or explanation:
|
19 |
{{
|
20 |
+
"tool": "the_name_of_the_tool_to_call",
|
21 |
+
"tool_input": {{ "parameter1": "value1", "parameter2": "value2" }}
|
22 |
}}
|
23 |
|
24 |
+
If you do not need to use any tool, answer the user's question directly.
|
25 |
|
26 |
+
This is the conversation history:
|
27 |
{chat_history}
|
28 |
|
29 |
+
User's question: {input}
|
30 |
|
31 |
+
Now, think and provide your response (either JSON or a direct answer):
|
32 |
"""
|
33 |
|
34 |
|
|
|
48 |
convert_system_message_to_human=True,
|
49 |
)
|
50 |
self.chat_history = []
|
51 |
+
print(f"LangChain Agent initialized, using model: {self.model_name}.")
|
52 |
|
|
|
|
|
53 |
def _extract_json_from_string(self, text: str) -> dict | None:
|
54 |
+
"""Extracts a JSON block from a string that might contain other text."""
|
|
|
55 |
match = re.search(r"```json\s*(\{.*?\})\s*```", text, re.DOTALL)
|
56 |
if match:
|
57 |
json_str = match.group(1)
|
58 |
else:
|
|
|
59 |
match = re.search(r"\{.*\}", text, re.DOTALL)
|
60 |
if match:
|
61 |
json_str = match.group(0)
|
62 |
else:
|
63 |
return None
|
|
|
64 |
try:
|
65 |
return json.loads(json_str)
|
66 |
except json.JSONDecodeError:
|
67 |
return None
|
68 |
|
|
|
|
|
69 |
def _format_tools_for_prompt(self, tools: List[dict]) -> str:
|
70 |
+
"""Formats the list of tools into a clear string for the prompt."""
|
71 |
if not tools:
|
72 |
+
return "No tools available."
|
73 |
tool_strings = []
|
74 |
for tool in tools:
|
75 |
try:
|
|
|
78 |
[f"{p_name}: {p_type}" for p_name, p_type in params.items()]
|
79 |
)
|
80 |
tool_strings.append(
|
81 |
+
f"- Tool Name: {tool['name']}\n - Description: {tool['description']}\n - Parameters: {param_str}"
|
82 |
)
|
83 |
except (json.JSONDecodeError, TypeError):
|
84 |
tool_strings.append(
|
85 |
+
f"- Tool Name: {tool['name']}\n - Description: {tool['description']}\n - Parameters: Could not be parsed"
|
86 |
)
|
87 |
return "\n".join(tool_strings)
|
88 |
|
89 |
def _format_chat_history(self) -> str:
|
90 |
+
"""Formats the chat history for the prompt."""
|
91 |
formatted_history = []
|
92 |
for msg in self.chat_history:
|
93 |
if isinstance(msg, HumanMessage):
|
94 |
+
formatted_history.append(f"User: {msg.content}")
|
95 |
elif isinstance(msg, AIMessage):
|
96 |
+
formatted_history.append(f"Assistant: {msg.content}")
|
97 |
elif isinstance(msg, ToolMessage):
|
98 |
+
formatted_history.append(f"Tool Result: {msg.content}")
|
99 |
return "\n".join(formatted_history)
|
100 |
|
101 |
def stream_run(self, user_input: str):
|
102 |
+
"""Processes user input in a streaming fashion."""
|
103 |
self.chat_history.append(HumanMessage(content=user_input))
|
104 |
+
yield "🤔 Analyzing your question...\n"
|
105 |
|
106 |
+
yield "🔍 Recommending relevant tools from the library...\n"
|
107 |
recommended_tools_meta = self.tool_recommender.recommend_tools(user_input)
|
108 |
|
109 |
if not recommended_tools_meta:
|
110 |
+
yield "ℹ️ No relevant tools found. Answering directly.\n"
|
111 |
+
recommended_tools_prompt = "No recommended tools."
|
112 |
else:
|
113 |
tool_names = [t["name"] for t in recommended_tools_meta]
|
114 |
+
yield f"✅ Recommended tools: `{', '.join(tool_names)}`\n"
|
115 |
recommended_tools_prompt = self._format_tools_for_prompt(
|
116 |
recommended_tools_meta
|
117 |
)
|
118 |
|
119 |
+
yield f"🧠 Letting the AI Brain ({self.model_name}) decide on the action...\n"
|
120 |
prompt = AGENT_PROMPT_TEMPLATE.format(
|
121 |
tools=recommended_tools_prompt,
|
122 |
chat_history=self._format_chat_history(),
|
|
|
126 |
llm_response = self.llm.invoke(prompt)
|
127 |
llm_decision_content = llm_response.content.strip()
|
128 |
|
|
|
|
|
129 |
decision = self._extract_json_from_string(llm_decision_content)
|
130 |
|
131 |
if decision and "tool" in decision and "tool_input" in decision:
|
|
|
132 |
tool_name = decision.get("tool")
|
133 |
tool_input = decision.get("tool_input")
|
134 |
|
135 |
+
yield f"💡 AI Action: Call tool `{tool_name}` with parameters `{tool_input}`\n"
|
136 |
|
137 |
tool_to_execute = get_tool_by_name(tool_name)
|
138 |
if tool_to_execute:
|
139 |
+
yield f"⚙️ Executing tool `{tool_name}`...\n"
|
140 |
tool_output = tool_to_execute.invoke(tool_input)
|
141 |
+
yield f"📊 Tool Result:\n---\n{str(tool_output)[:500]}...\n---\n"
|
142 |
|
143 |
self.chat_history.append(
|
144 |
AIMessage(content=json.dumps(decision, ensure_ascii=False))
|
|
|
147 |
ToolMessage(content=str(tool_output), tool_call_id="N/A")
|
148 |
)
|
149 |
|
150 |
+
yield "✍️ Generating final answer based on tool results...\n\n"
|
151 |
+
final_answer_prompt = f"Based on the conversation history and the latest tool result, generate a final, complete, and natural response for the user.\n\nConversation History:\n{self._format_chat_history()}\n\nPlease answer directly without mentioning your thought process."
|
152 |
final_answer_stream = self.llm.stream(final_answer_prompt)
|
153 |
full_final_answer = ""
|
154 |
for chunk in final_answer_stream:
|
|
|
156 |
full_final_answer += chunk.content
|
157 |
self.chat_history.append(AIMessage(content=full_final_answer))
|
158 |
else:
|
159 |
+
yield f"❌ Error: The tool `{tool_name}` decided by the AI does not exist.\n"
|
160 |
else:
|
161 |
+
yield "✅ AI Action: Answer directly.\n\n"
|
|
|
162 |
yield llm_decision_content
|
163 |
self.chat_history.append(AIMessage(content=llm_decision_content))
|
|
core/tool_recommender.py
CHANGED
@@ -6,27 +6,35 @@ from typing import List, Dict
|
|
6 |
|
7 |
|
8 |
class DirectToolRecommender:
|
|
|
|
|
|
|
|
|
|
|
9 |
def __init__(self, milvus_client: MilvusClient, sqlite_db_path: str):
|
10 |
self.milvus_client = milvus_client
|
11 |
self.sqlite_db_path = sqlite_db_path
|
12 |
self.collection_name = "tool_embeddings"
|
13 |
-
|
14 |
-
# ------------------- 核心修复在这里! -------------------
|
15 |
-
# 使用你指定的、有额度的嵌入模型
|
16 |
self.embedding_model_name = "gemini-embedding-exp-03-07"
|
17 |
-
# ----------------------------------------------------
|
18 |
|
19 |
api_key = os.environ.get("GEMINI_API_KEY")
|
20 |
if not api_key:
|
21 |
-
raise ValueError(
|
|
|
|
|
22 |
genai.configure(api_key=api_key)
|
23 |
|
24 |
-
print(
|
|
|
|
|
25 |
|
26 |
def recommend_tools(self, user_query: str, top_k: int = 3) -> List[Dict]:
|
27 |
-
|
|
|
|
|
|
|
28 |
|
29 |
-
# 1.
|
30 |
result = genai.embed_content(
|
31 |
model=self.embedding_model_name,
|
32 |
content=user_query,
|
@@ -34,7 +42,7 @@ class DirectToolRecommender:
|
|
34 |
)
|
35 |
query_embedding = result["embedding"]
|
36 |
|
37 |
-
# 2.
|
38 |
search_results = self.milvus_client.search(
|
39 |
collection_name=self.collection_name,
|
40 |
data=[query_embedding],
|
@@ -43,13 +51,13 @@ class DirectToolRecommender:
|
|
43 |
)
|
44 |
|
45 |
if not search_results or not search_results[0]:
|
46 |
-
print("[
|
47 |
return []
|
48 |
|
49 |
recommended_ids = [hit["id"] for hit in search_results[0]]
|
50 |
-
print(f"[
|
51 |
|
52 |
-
# 3.
|
53 |
with sqlite3.connect(self.sqlite_db_path) as conn:
|
54 |
cursor = conn.cursor()
|
55 |
if not recommended_ids:
|
@@ -71,5 +79,7 @@ class DirectToolRecommender:
|
|
71 |
if tool_id in id_to_tool_meta
|
72 |
]
|
73 |
|
74 |
-
print(
|
|
|
|
|
75 |
return sorted_tools
|
|
|
6 |
|
7 |
|
8 |
class DirectToolRecommender:
|
9 |
+
"""
|
10 |
+
Directly uses Milvus and Google GenAI for tool recommendation.
|
11 |
+
No dependency on LlamaIndex.
|
12 |
+
"""
|
13 |
+
|
14 |
def __init__(self, milvus_client: MilvusClient, sqlite_db_path: str):
|
15 |
self.milvus_client = milvus_client
|
16 |
self.sqlite_db_path = sqlite_db_path
|
17 |
self.collection_name = "tool_embeddings"
|
|
|
|
|
|
|
18 |
self.embedding_model_name = "gemini-embedding-exp-03-07"
|
|
|
19 |
|
20 |
api_key = os.environ.get("GEMINI_API_KEY")
|
21 |
if not api_key:
|
22 |
+
raise ValueError(
|
23 |
+
"Error: GEMINI_API_KEY not found. The recommender cannot function."
|
24 |
+
)
|
25 |
genai.configure(api_key=api_key)
|
26 |
|
27 |
+
print(
|
28 |
+
f"Direct Tool Recommender initialized, using embedding model: {self.embedding_model_name}."
|
29 |
+
)
|
30 |
|
31 |
def recommend_tools(self, user_query: str, top_k: int = 3) -> List[Dict]:
|
32 |
+
"""
|
33 |
+
Recommends the top_k most relevant tools based on the user query.
|
34 |
+
"""
|
35 |
+
print(f"\n[Tool Recommender] Received query: '{user_query}'")
|
36 |
|
37 |
+
# 1. Generate query embedding directly
|
38 |
result = genai.embed_content(
|
39 |
model=self.embedding_model_name,
|
40 |
content=user_query,
|
|
|
42 |
)
|
43 |
query_embedding = result["embedding"]
|
44 |
|
45 |
+
# 2. Search for similar tools in Milvus
|
46 |
search_results = self.milvus_client.search(
|
47 |
collection_name=self.collection_name,
|
48 |
data=[query_embedding],
|
|
|
51 |
)
|
52 |
|
53 |
if not search_results or not search_results[0]:
|
54 |
+
print("[Tool Recommender] No similar tools found in Milvus.")
|
55 |
return []
|
56 |
|
57 |
recommended_ids = [hit["id"] for hit in search_results[0]]
|
58 |
+
print(f"[Tool Recommender] Milvus recommended tool IDs: {recommended_ids}")
|
59 |
|
60 |
+
# 3. Get full tool metadata from SQLite and sort
|
61 |
with sqlite3.connect(self.sqlite_db_path) as conn:
|
62 |
cursor = conn.cursor()
|
63 |
if not recommended_ids:
|
|
|
79 |
if tool_id in id_to_tool_meta
|
80 |
]
|
81 |
|
82 |
+
print(
|
83 |
+
f"[Tool Recommender] Final recommended tools: {[t['name'] for t in sorted_tools]}"
|
84 |
+
)
|
85 |
return sorted_tools
|
database/setup.py
CHANGED
@@ -6,46 +6,52 @@ import google.generativeai as genai
|
|
6 |
|
7 |
from tools.tool_registry import get_all_tools
|
8 |
|
9 |
-
# ---
|
10 |
DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data"))
|
11 |
SQLITE_DB_PATH = os.path.join(DATA_DIR, "tools.metadata.db")
|
12 |
MILVUS_DATA_PATH = os.path.join(DATA_DIR, "milvus_lite.db")
|
13 |
|
14 |
-
# ---
|
15 |
EMBEDDING_DIM = 3072
|
16 |
EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"
|
17 |
MILVUS_COLLECTION_NAME = "tool_embeddings"
|
18 |
|
19 |
|
20 |
def initialize_system():
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
22 |
os.makedirs(DATA_DIR, exist_ok=True)
|
23 |
|
24 |
-
# ---
|
25 |
|
26 |
-
# 1.
|
27 |
-
#
|
28 |
_init_sqlite_db()
|
29 |
all_tools_definitions = get_all_tools()
|
30 |
_sync_tools_to_sqlite(all_tools_definitions)
|
31 |
|
32 |
-
# 2.
|
33 |
-
#
|
34 |
milvus_client = _init_milvus_and_sync_embeddings()
|
35 |
|
36 |
-
# 3.
|
37 |
from core.tool_recommender import DirectToolRecommender
|
38 |
|
39 |
tool_recommender = DirectToolRecommender(
|
40 |
milvus_client=milvus_client, sqlite_db_path=SQLITE_DB_PATH
|
41 |
)
|
42 |
|
43 |
-
print("---
|
44 |
return all_tools_definitions, tool_recommender
|
45 |
|
46 |
|
47 |
def _init_sqlite_db():
|
48 |
-
|
|
|
49 |
with sqlite3.connect(SQLITE_DB_PATH) as conn:
|
50 |
cursor = conn.cursor()
|
51 |
cursor.execute(
|
@@ -59,11 +65,12 @@ def _init_sqlite_db():
|
|
59 |
"""
|
60 |
)
|
61 |
conn.commit()
|
62 |
-
print("SQLite DB
|
63 |
|
64 |
|
65 |
def _sync_tools_to_sqlite(tools_definitions):
|
66 |
-
|
|
|
67 |
with sqlite3.connect(SQLITE_DB_PATH) as conn:
|
68 |
cursor = conn.cursor()
|
69 |
for tool in tools_definitions:
|
@@ -73,21 +80,24 @@ def _sync_tools_to_sqlite(tools_definitions):
|
|
73 |
"INSERT INTO tools (name, description, parameters) VALUES (?, ?, ?)",
|
74 |
(tool.name, tool.description, json.dumps(tool.args)),
|
75 |
)
|
76 |
-
print(f" -
|
77 |
conn.commit()
|
78 |
-
print("SQLite
|
79 |
|
80 |
|
81 |
def _init_milvus_and_sync_embeddings():
|
82 |
-
|
|
|
83 |
client = MilvusClient(uri=MILVUS_DATA_PATH)
|
84 |
|
85 |
-
#
|
86 |
if client.has_collection(collection_name=MILVUS_COLLECTION_NAME):
|
87 |
client.drop_collection(collection_name=MILVUS_COLLECTION_NAME)
|
88 |
-
print("
|
89 |
|
90 |
-
print(
|
|
|
|
|
91 |
fields = [
|
92 |
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
|
93 |
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM),
|
@@ -102,9 +112,9 @@ def _init_milvus_and_sync_embeddings():
|
|
102 |
client.create_index(
|
103 |
collection_name=MILVUS_COLLECTION_NAME, index_params=index_params
|
104 |
)
|
105 |
-
print("Milvus
|
106 |
|
107 |
-
#
|
108 |
_sync_tool_embeddings_to_milvus(client)
|
109 |
|
110 |
client.load_collection(collection_name=MILVUS_COLLECTION_NAME)
|
@@ -112,10 +122,11 @@ def _init_milvus_and_sync_embeddings():
|
|
112 |
|
113 |
|
114 |
def _sync_tool_embeddings_to_milvus(milvus_client):
|
115 |
-
|
|
|
116 |
api_key = os.environ.get("GEMINI_API_KEY")
|
117 |
if not api_key:
|
118 |
-
print("
|
119 |
return
|
120 |
genai.configure(api_key=api_key)
|
121 |
|
@@ -125,13 +136,13 @@ def _sync_tool_embeddings_to_milvus(milvus_client):
|
|
125 |
all_tools_in_db = cursor.fetchall()
|
126 |
|
127 |
if not all_tools_in_db:
|
128 |
-
print("SQLite
|
129 |
return
|
130 |
|
131 |
-
print(f"
|
132 |
docs_to_embed = [tool[1] for tool in all_tools_in_db]
|
133 |
|
134 |
-
print(f"
|
135 |
result = genai.embed_content(
|
136 |
model=EMBEDDING_MODEL_NAME,
|
137 |
content=docs_to_embed,
|
@@ -147,4 +158,4 @@ def _sync_tool_embeddings_to_milvus(milvus_client):
|
|
147 |
]
|
148 |
|
149 |
milvus_client.insert(collection_name=MILVUS_COLLECTION_NAME, data=data_to_insert)
|
150 |
-
print(f"
|
|
|
6 |
|
7 |
from tools.tool_registry import get_all_tools
|
8 |
|
9 |
+
# --- Configuration for persistence paths ---
|
10 |
DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data"))
|
11 |
SQLITE_DB_PATH = os.path.join(DATA_DIR, "tools.metadata.db")
|
12 |
MILVUS_DATA_PATH = os.path.join(DATA_DIR, "milvus_lite.db")
|
13 |
|
14 |
+
# --- Model and DB Configuration ---
|
15 |
EMBEDDING_DIM = 3072
|
16 |
EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"
|
17 |
MILVUS_COLLECTION_NAME = "tool_embeddings"
|
18 |
|
19 |
|
20 |
def initialize_system():
|
21 |
+
"""
|
22 |
+
The main system initialization function.
|
23 |
+
It creates directories, sets up the database and vector store, and loads tools.
|
24 |
+
This function is designed to be idempotent.
|
25 |
+
"""
|
26 |
+
print("--- Starting System Initialization (Final Version) ---")
|
27 |
os.makedirs(DATA_DIR, exist_ok=True)
|
28 |
|
29 |
+
# --- Correct Initialization Order ---
|
30 |
|
31 |
+
# 1. Initialize SQLite and sync tool metadata
|
32 |
+
# Ensures SQLite always has the latest tool information
|
33 |
_init_sqlite_db()
|
34 |
all_tools_definitions = get_all_tools()
|
35 |
_sync_tools_to_sqlite(all_tools_definitions)
|
36 |
|
37 |
+
# 2. Initialize Milvus and sync vector embeddings
|
38 |
+
# It reads data from the already populated SQLite DB
|
39 |
milvus_client = _init_milvus_and_sync_embeddings()
|
40 |
|
41 |
+
# 3. Create the tool recommender instance
|
42 |
from core.tool_recommender import DirectToolRecommender
|
43 |
|
44 |
tool_recommender = DirectToolRecommender(
|
45 |
milvus_client=milvus_client, sqlite_db_path=SQLITE_DB_PATH
|
46 |
)
|
47 |
|
48 |
+
print("--- System Initialization Complete ---")
|
49 |
return all_tools_definitions, tool_recommender
|
50 |
|
51 |
|
52 |
def _init_sqlite_db():
|
53 |
+
"""Initializes the SQLite database and creates the tools table if it doesn't exist."""
|
54 |
+
print(f"SQLite DB Path: {SQLITE_DB_PATH}")
|
55 |
with sqlite3.connect(SQLITE_DB_PATH) as conn:
|
56 |
cursor = conn.cursor()
|
57 |
cursor.execute(
|
|
|
65 |
"""
|
66 |
)
|
67 |
conn.commit()
|
68 |
+
print("SQLite DB table verified.")
|
69 |
|
70 |
|
71 |
def _sync_tools_to_sqlite(tools_definitions):
|
72 |
+
"""Syncs tool definitions into the SQLite database."""
|
73 |
+
print("Syncing tool metadata to SQLite...")
|
74 |
with sqlite3.connect(SQLITE_DB_PATH) as conn:
|
75 |
cursor = conn.cursor()
|
76 |
for tool in tools_definitions:
|
|
|
80 |
"INSERT INTO tools (name, description, parameters) VALUES (?, ?, ?)",
|
81 |
(tool.name, tool.description, json.dumps(tool.args)),
|
82 |
)
|
83 |
+
print(f" - Added new tool to SQLite: {tool.name}")
|
84 |
conn.commit()
|
85 |
+
print("SQLite sync complete.")
|
86 |
|
87 |
|
88 |
def _init_milvus_and_sync_embeddings():
|
89 |
+
"""Initializes Milvus Lite, rebuilds the collection, and syncs embeddings."""
|
90 |
+
print(f"Milvus Lite Data Path: {MILVUS_DATA_PATH}")
|
91 |
client = MilvusClient(uri=MILVUS_DATA_PATH)
|
92 |
|
93 |
+
# Recreate the collection on every startup to ensure correct dimensionality and fresh data for the demo.
|
94 |
if client.has_collection(collection_name=MILVUS_COLLECTION_NAME):
|
95 |
client.drop_collection(collection_name=MILVUS_COLLECTION_NAME)
|
96 |
+
print("Found old Milvus collection. Dropped it to rebuild.")
|
97 |
|
98 |
+
print(
|
99 |
+
f"Creating Milvus collection '{MILVUS_COLLECTION_NAME}' with dimension {EMBEDDING_DIM}..."
|
100 |
+
)
|
101 |
fields = [
|
102 |
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
|
103 |
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM),
|
|
|
112 |
client.create_index(
|
113 |
collection_name=MILVUS_COLLECTION_NAME, index_params=index_params
|
114 |
)
|
115 |
+
print("Milvus collection and index created successfully.")
|
116 |
|
117 |
+
# Critical Step: Now we sync the embeddings to the newly created collection
|
118 |
_sync_tool_embeddings_to_milvus(client)
|
119 |
|
120 |
client.load_collection(collection_name=MILVUS_COLLECTION_NAME)
|
|
|
122 |
|
123 |
|
124 |
def _sync_tool_embeddings_to_milvus(milvus_client):
|
125 |
+
"""Generates and syncs tool description embeddings to Milvus Lite."""
|
126 |
+
print("Syncing tool embeddings to Milvus...")
|
127 |
api_key = os.environ.get("GEMINI_API_KEY")
|
128 |
if not api_key:
|
129 |
+
print("Error: GEMINI_API_KEY not found.")
|
130 |
return
|
131 |
genai.configure(api_key=api_key)
|
132 |
|
|
|
136 |
all_tools_in_db = cursor.fetchall()
|
137 |
|
138 |
if not all_tools_in_db:
|
139 |
+
print("Error: No tools found in SQLite to sync.")
|
140 |
return
|
141 |
|
142 |
+
print(f"Found {len(all_tools_in_db)} tools from SQLite, generating embeddings...")
|
143 |
docs_to_embed = [tool[1] for tool in all_tools_in_db]
|
144 |
|
145 |
+
print(f"Using embedding model: {EMBEDDING_MODEL_NAME}")
|
146 |
result = genai.embed_content(
|
147 |
model=EMBEDDING_MODEL_NAME,
|
148 |
content=docs_to_embed,
|
|
|
158 |
]
|
159 |
|
160 |
milvus_client.insert(collection_name=MILVUS_COLLECTION_NAME, data=data_to_insert)
|
161 |
+
print(f"Successfully inserted {len(data_to_insert)} new embeddings into Milvus.")
|
tools/news_tool.py
CHANGED
@@ -6,9 +6,9 @@ from bs4 import BeautifulSoup
|
|
6 |
|
7 |
def search_latest_news(query: str) -> str:
|
8 |
"""
|
9 |
-
|
10 |
"""
|
11 |
-
print(f"---
|
12 |
headers = {
|
13 |
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
|
14 |
}
|
@@ -22,9 +22,9 @@ def search_latest_news(query: str) -> str:
|
|
22 |
results = soup.find_all("div", class_="result")
|
23 |
|
24 |
if not results:
|
25 |
-
return "
|
26 |
|
27 |
-
#
|
28 |
snippets = []
|
29 |
for result in results[:3]:
|
30 |
title_tag = result.find("a", class_="result__a")
|
@@ -32,11 +32,11 @@ def search_latest_news(query: str) -> str:
|
|
32 |
if title_tag and snippet_tag:
|
33 |
title = title_tag.text.strip()
|
34 |
snippet = snippet_tag.text.strip()
|
35 |
-
snippets.append(f"
|
36 |
|
37 |
return "\n---\n".join(snippets)
|
38 |
|
39 |
except requests.RequestException as e:
|
40 |
-
return f"
|
41 |
except Exception as e:
|
42 |
-
return f"
|
|
|
6 |
|
7 |
def search_latest_news(query: str) -> str:
|
8 |
"""
|
9 |
+
Simulates a news search by scraping DuckDuckGo search results using requests and BeautifulSoup.
|
10 |
"""
|
11 |
+
print(f"--- Executing Tool: search_latest_news, Parameters: {query} ---")
|
12 |
headers = {
|
13 |
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
|
14 |
}
|
|
|
22 |
results = soup.find_all("div", class_="result")
|
23 |
|
24 |
if not results:
|
25 |
+
return "No relevant news articles were found."
|
26 |
|
27 |
+
# Extract snippets from the top 3 results
|
28 |
snippets = []
|
29 |
for result in results[:3]:
|
30 |
title_tag = result.find("a", class_="result__a")
|
|
|
32 |
if title_tag and snippet_tag:
|
33 |
title = title_tag.text.strip()
|
34 |
snippet = snippet_tag.text.strip()
|
35 |
+
snippets.append(f"Title: {title}\nSnippet: {snippet}\n")
|
36 |
|
37 |
return "\n---\n".join(snippets)
|
38 |
|
39 |
except requests.RequestException as e:
|
40 |
+
return f"A network error occurred while searching for news: {e}"
|
41 |
except Exception as e:
|
42 |
+
return f"An error occurred while parsing news search results: {e}"
|
tools/stock_tool.py
CHANGED
@@ -6,18 +6,18 @@ import random
|
|
6 |
|
7 |
def get_stock_price(symbol: str) -> str:
|
8 |
"""
|
9 |
-
|
10 |
-
|
11 |
"""
|
12 |
-
print(f"---
|
13 |
symbol = symbol.upper()
|
14 |
-
#
|
15 |
try:
|
16 |
-
#
|
17 |
if symbol in ["AAPL", "GOOGL", "MSFT"]:
|
18 |
price = round(random.uniform(100, 500), 2)
|
19 |
-
return f"
|
20 |
else:
|
21 |
-
return f"
|
22 |
except Exception as e:
|
23 |
-
return f"
|
|
|
6 |
|
7 |
def get_stock_price(symbol: str) -> str:
|
8 |
"""
|
9 |
+
Simulates fetching a stock price.
|
10 |
+
In a real-world scenario, this would call a proper financial API.
|
11 |
"""
|
12 |
+
print(f"--- Executing Tool: get_stock_price, Parameters: {symbol} ---")
|
13 |
symbol = symbol.upper()
|
14 |
+
# Simulate an API call
|
15 |
try:
|
16 |
+
# This is a simulation. A real implementation would use an API like Alpha Vantage, Yahoo Finance, etc.
|
17 |
if symbol in ["AAPL", "GOOGL", "MSFT"]:
|
18 |
price = round(random.uniform(100, 500), 2)
|
19 |
+
return f"The simulated real-time price for stock {symbol} is ${price}."
|
20 |
else:
|
21 |
+
return f"Could not find information for stock symbol: {symbol}."
|
22 |
except Exception as e:
|
23 |
+
return f"An error occurred while calling the stock API: {e}"
|
tools/tool_registry.py
CHANGED
@@ -1,22 +1,21 @@
|
|
1 |
# tools/tool_registry.py
|
2 |
|
3 |
from langchain_core.tools import tool
|
4 |
-
from typing import List,
|
5 |
|
6 |
-
#
|
7 |
from .stock_tool import get_stock_price
|
8 |
from .news_tool import search_latest_news
|
9 |
|
10 |
-
#
|
11 |
-
#
|
12 |
|
13 |
|
14 |
-
# 使用 @tool 装饰器定义你的工具
|
15 |
@tool
|
16 |
def get_stock_price_tool(symbol: str) -> str:
|
17 |
"""
|
18 |
-
|
19 |
-
|
20 |
"""
|
21 |
return get_stock_price(symbol)
|
22 |
|
@@ -24,13 +23,13 @@ def get_stock_price_tool(symbol: str) -> str:
|
|
24 |
@tool
|
25 |
def search_latest_news_tool(query: str) -> str:
|
26 |
"""
|
27 |
-
|
28 |
-
|
29 |
"""
|
30 |
return search_latest_news(query)
|
31 |
|
32 |
|
33 |
-
#
|
34 |
_all_tools = [
|
35 |
get_stock_price_tool,
|
36 |
search_latest_news_tool,
|
@@ -38,12 +37,12 @@ _all_tools = [
|
|
38 |
|
39 |
|
40 |
def get_all_tools() -> List[Any]:
|
41 |
-
"""
|
42 |
return _all_tools
|
43 |
|
44 |
|
45 |
def get_tool_by_name(name: str) -> Any:
|
46 |
-
"""
|
47 |
for tool_obj in _all_tools:
|
48 |
if tool_obj.name == name:
|
49 |
return tool_obj
|
|
|
1 |
# tools/tool_registry.py
|
2 |
|
3 |
from langchain_core.tools import tool
|
4 |
+
from typing import List, Any
|
5 |
|
6 |
+
# Import the actual tool functions
|
7 |
from .stock_tool import get_stock_price
|
8 |
from .news_tool import search_latest_news
|
9 |
|
10 |
+
# Use LangChain's @tool decorator to define tools.
|
11 |
+
# This is more robust as it automatically handles descriptions and argument schemas.
|
12 |
|
13 |
|
|
|
14 |
@tool
|
15 |
def get_stock_price_tool(symbol: str) -> str:
|
16 |
"""
|
17 |
+
Gets the real-time stock price for a given stock symbol (e.g., AAPL, GOOGL).
|
18 |
+
Use this tool when the user asks for the stock price of a specific company.
|
19 |
"""
|
20 |
return get_stock_price(symbol)
|
21 |
|
|
|
23 |
@tool
|
24 |
def search_latest_news_tool(query: str) -> str:
|
25 |
"""
|
26 |
+
Searches for the latest news articles based on a keyword.
|
27 |
+
Use this tool when the user asks about the latest updates, events, or news on a certain topic.
|
28 |
"""
|
29 |
return search_latest_news(query)
|
30 |
|
31 |
|
32 |
+
# Central registry for all tools
|
33 |
_all_tools = [
|
34 |
get_stock_price_tool,
|
35 |
search_latest_news_tool,
|
|
|
37 |
|
38 |
|
39 |
def get_all_tools() -> List[Any]:
|
40 |
+
"""Returns a list containing all defined tool objects."""
|
41 |
return _all_tools
|
42 |
|
43 |
|
44 |
def get_tool_by_name(name: str) -> Any:
|
45 |
+
"""Finds and returns a tool object by its name."""
|
46 |
for tool_obj in _all_tools:
|
47 |
if tool_obj.name == name:
|
48 |
return tool_obj
|