Open-Source AI Cookbook documentation

Agent for text-to-SQL with automatic error correction

Author: Aymeric Roucher

In this tutorial, we'll see how to implement an agent that leverages SQL, using transformers.agents.

What's the advantage over a standard text-to-SQL pipeline?

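A standard text-to-SQL pipeline is brittle, since the generated SQL query can be incorrect. Even worse, the query can be incorrect without raising an error: it silently returns wrong or useless output, and nothing flags the problem.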

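👉 In contrast, an agent system can critically inspect the output and decide whether the query needs to be changed, which greatly improves its performance.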
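To make the silent-failure case concrete, here is a minimal sketch using Python's built-in sqlite3 module rather than the SQLAlchemy setup used in the rest of this tutorial (that swap is an assumption made only to keep the example self-contained). The throwaway table mirrors part of the receipts table built later on. The flawed query is valid SQL, so it runs without any error, yet it answers the wrong question.

```python
import sqlite3

# Throwaway in-memory table, for illustration only.
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE receipts (receipt_id INTEGER, customer_name TEXT, price REAL)")
con.executemany(
    "INSERT INTO receipts VALUES (?, ?, ?)",
    [(1, "Alan Payne", 12.06), (2, "Alex Mason", 23.86), (3, "Woodrow Wilson", 53.43)],
)

# Intended question: which customer has the most expensive receipt?
# Flawed query: the default sort order is ascending, so this returns the cheapest receipt.
flawed = con.execute(
    "SELECT customer_name FROM receipts ORDER BY price LIMIT 1"
).fetchone()

# Corrected query: sort in descending order.
correct = con.execute(
    "SELECT customer_name FROM receipts ORDER BY price DESC LIMIT 1"
).fetchone()

print(flawed)   # ('Alan Payne',)
print(correct)  # ('Woodrow Wilson',)
```

Both queries return a single plausible-looking name, which is exactly why a one-shot pipeline cannot tell them apart without some form of inspection.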

Let's build this agent! 💪

Setting up the SQL table

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    Float,
    insert,
    inspect,
    text,
)

engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# create the receipts SQL table
table_name = "receipts"
receipts = Table(
    table_name,
    metadata_obj,
    Column("receipt_id", Integer, primary_key=True),
    Column("customer_name", String(16), primary_key=True),
    Column("price", Float),
    Column("tip", Float),
)
metadata_obj.create_all(engine)
rows = [
    {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
    {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
    {"receipt_id": 3, "customer_name": "Woodrow Wilson", "price": 53.43, "tip": 5.43},
    {"receipt_id": 4, "customer_name": "Margaret James", "price": 21.11, "tip": 1.00},
]
for row in rows:
    stmt = insert(receipts).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)

Let's check that our system works with a basic query:

>>> with engine.connect() as con:
...     rows = con.execute(text("""SELECT * from receipts"""))
...     for row in rows:
...         print(row)
(1, 'Alan Payne', 12.06, 1.2)
(2, 'Alex Mason', 23.86, 0.24)
(3, 'Woodrow Wilson', 53.43, 5.43)
(4, 'Margaret James', 21.11, 1.0)

Building our agent

Now let's make our SQL table accessible to the agent through a tool.

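The tool's description attribute will be embedded in the LLM's prompt by the agent system, so that the LLM knows how to use the tool. This is where we describe the structure of the SQL table, so that the agent can write correct queries against the database.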

>>> inspector = inspect(engine)
>>> columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]

>>> table_description = "Columns:\n" + "\n".join([f"  - {name}: {col_type}" for name, col_type in columns_info])
>>> print(table_description)
Columns:
  - receipt_id: INTEGER
  - customer_name: VARCHAR(16)
  - price: FLOAT
  - tip: FLOAT

Now let's build our tool. It needs the following (see the documentation for details):

  • A docstring with an Args: section
  • Type hints

from transformers.agents import tool


@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    The table is named 'receipts'. Its description is as follows:
        Columns:
        - receipt_id: INTEGER
        - customer_name: VARCHAR(16)
        - price: FLOAT
        - tip: FLOAT

    Args:
        query: The query to perform. This should be correct SQL.
    """
    output = ""
    with engine.connect() as con:
        rows = con.execute(text(query))
        for row in rows:
            output += "\n" + str(row)
    return output
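One detail of the tool above is worth noting: when a query matches no rows, it returns an empty string, which gives the LLM little to reason about. The sketch below shows one possible variation, not the tutorial's own approach: a hypothetical helper named format_rows that makes the empty case explicit. It uses the built-in sqlite3 module instead of SQLAlchemy purely to stay self-contained.

```python
import sqlite3


def format_rows(rows) -> str:
    """Render one row per line; spell out the empty case for the LLM."""
    if not rows:
        return "No rows returned."
    return "\n".join(str(row) for row in rows)


# Throwaway in-memory table, for illustration only.
con = sqlite3.connect(":memory:")
con.execute(
    "CREATE TABLE receipts (receipt_id INTEGER, customer_name TEXT, price REAL, tip REAL)"
)
con.execute("INSERT INTO receipts VALUES (1, 'Alan Payne', 12.06, 1.2)")

hit = format_rows(con.execute("SELECT * FROM receipts").fetchall())
miss = format_rows(con.execute("SELECT * FROM receipts WHERE price > 100").fetchall())

print(hit)   # (1, 'Alan Payne', 12.06, 1.2)
print(miss)  # No rows returned.
```

Whether an explicit message helps depends on the model; it is a small design choice that can make the agent's observations easier to interpret.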

Now let's create an agent that leverages this tool.

We'll use ReactCodeAgent, the main agent class in transformers.agents: an agent that writes its actions as code and can iterate on previous outputs, following the ReAct framework.

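llm_engine is the LLM that powers the agent system. HfApiEngine lets you call LLMs through HF's Inference API, via either a serverless or a dedicated endpoint, but you can also use any proprietary API: see this guide to learn how to adapt it.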
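As a rough illustration of what adapting another API involves, the sketch below assumes that the agent only needs a callable that takes a list of chat messages plus a stop_sequences argument and returns a string. That interface is an assumption to check against the guide for your installed version. The function name canned_llm_engine and its fixed reply are hypothetical stand-ins for a real API call.

```python
from typing import Dict, List, Optional


def canned_llm_engine(
    messages: List[Dict[str, str]],
    stop_sequences: Optional[List[str]] = None,
) -> str:
    """Hypothetical engine: returns a fixed reply instead of calling a real LLM API."""
    # In a real adapter, this is where `messages` would be sent to the provider.
    raw = "Thought: I should query the receipts table.<end_action>Observation: ignored"

    # Cut the reply at the first occurrence of each stop sequence,
    # mimicking what a provider-side `stop` parameter would do.
    for stop in stop_sequences or []:
        index = raw.find(stop)
        if index != -1:
            raw = raw[:index]
    return raw


reply = canned_llm_engine(
    [{"role": "user", "content": "Who got the most expensive receipt?"}],
    stop_sequences=["<end_action>"],
)
print(reply)  # Thought: I should query the receipts table.
```

A real adapter would replace the canned string with the provider's response text and, where the provider supports it, pass the stop sequences through instead of truncating locally.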

from transformers.agents import ReactCodeAgent, HfApiEngine

agent = ReactCodeAgent(
    tools=[sql_engine],
    llm_engine=HfApiEngine("meta-llama/Meta-Llama-3-8B-Instruct"),
)
agent.run("Can you give me the name of the client who got the most expensive receipt?")
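The error-correction behavior can be pictured as a loop: run a query, read the resulting observation (a result or an error message), and try again if needed. The sketch below is a deliberately simplified illustration of that idea and not how ReactCodeAgent is implemented internally. A fixed list of attempts stands in for the LLM, the helper run_with_retries is hypothetical, and sqlite3 replaces SQLAlchemy to keep the example self-contained.

```python
import sqlite3

# Throwaway in-memory table, for illustration only.
con = sqlite3.connect(":memory:")
con.execute(
    "CREATE TABLE receipts (receipt_id INTEGER, customer_name TEXT, price REAL, tip REAL)"
)
con.executemany(
    "INSERT INTO receipts VALUES (?, ?, ?, ?)",
    [(1, "Alan Payne", 12.06, 1.20), (3, "Woodrow Wilson", 53.43, 5.43)],
)

# Hypothetical stand-in for the LLM: the first attempt uses a column that does not exist.
attempts = [
    "SELECT name FROM receipts ORDER BY price DESC LIMIT 1",
    "SELECT customer_name FROM receipts ORDER BY price DESC LIMIT 1",
]


def run_with_retries(connection, queries):
    """Try each query in turn, recording what an agent would observe at each step."""
    observations = []
    for query in queries:
        try:
            rows = connection.execute(query).fetchall()
        except sqlite3.Error as exc:
            # The error text is the feedback a real agent would use to revise its query.
            observations.append(f"Error: {exc}")
            continue
        observations.append(f"Result: {rows}")
        return rows, observations
    return None, observations


result, observations = run_with_retries(con, attempts)
for line in observations:
    print(line)
```

In the real agent, the next attempt is generated by the LLM from the previous observation rather than read from a list, which also lets it react to results that run without error but look wrong.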

Increasing difficulty: table joins

Now let's make it more challenging! We want our agent to handle joins across multiple tables.

So let's create a second table that records the waiter's name for each receipt_id!

# create the waiters SQL table
table_name = "waiters"
waiters = Table(
    table_name,
    metadata_obj,
    Column("receipt_id", Integer, primary_key=True),
    Column("waiter_name", String(16), primary_key=True),
)
metadata_obj.create_all(engine)

rows = [
    {"receipt_id": 1, "waiter_name": "Corey Johnson"},
    {"receipt_id": 2, "waiter_name": "Michael Watts"},
    {"receipt_id": 3, "waiter_name": "Michael Watts"},
    {"receipt_id": 4, "waiter_name": "Margaret James"},
]
# insert into the waiters table (kept under its own name so `receipts` is not overwritten)
for row in rows:
    stmt = insert(waiters).values(**row)
    with engine.begin() as connection:
        connection.execute(stmt)

We need to update the description of the sql_engine tool to include this table, so that the LLM can properly use the information it contains.

>>> updated_description = """Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
... It can use the following tables:"""

>>> inspector = inspect(engine)
>>> for table in ["receipts", "waiters"]:
...     columns_info = [(col["name"], col["type"]) for col in inspector.get_columns(table)]
...     table_description = f"Table '{table}':\n"
...     table_description += "Columns:\n" + "\n".join([f"  - {name}: {col_type}" for name, col_type in columns_info])
...     updated_description += "\n\n" + table_description

>>> print(updated_description)
Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
It can use the following tables:

Table 'receipts':
Columns:
  - receipt_id: INTEGER
  - customer_name: VARCHAR(16)
  - price: FLOAT
  - tip: FLOAT

Table 'waiters':
Columns:
  - receipt_id: INTEGER
  - waiter_name: VARCHAR(16)
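The loop above hard-codes the table names. As a hedged variation, the sketch below discovers the tables from the database itself, so the description stays in sync as tables are added. It uses the built-in sqlite3 module and SQLite's own catalog (sqlite_master and PRAGMA table_info), so it only applies to SQLite; with SQLAlchemy, the inspector's get_table_names() method plays a similar role. The function name describe_tables is hypothetical.

```python
import sqlite3


def describe_tables(con: sqlite3.Connection) -> str:
    """Build a schema description covering every table in a SQLite database."""
    names = [
        row[0]
        for row in con.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
        )
    ]
    blocks = []
    for name in names:
        # PRAGMA table_info yields (cid, name, type, notnull, dflt_value, pk) per column.
        columns = con.execute(f"PRAGMA table_info('{name}')").fetchall()
        lines = [f"  - {col[1]}: {col[2]}" for col in columns]
        blocks.append(f"Table '{name}':\nColumns:\n" + "\n".join(lines))
    return "\n\n".join(blocks)


# Throwaway schema mirroring the two tables in this tutorial.
con = sqlite3.connect(":memory:")
con.execute(
    "CREATE TABLE receipts (receipt_id INTEGER, customer_name VARCHAR(16), price FLOAT, tip FLOAT)"
)
con.execute("CREATE TABLE waiters (receipt_id INTEGER, waiter_name VARCHAR(16))")

schema_description = describe_tables(con)
print(schema_description)
```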

Since this request is more challenging than the previous one, we'll switch the LLM engine to the more powerful Qwen/Qwen2.5-72B-Instruct.

sql_engine.description = updated_description

agent = ReactCodeAgent(
    tools=[sql_engine],
    llm_engine=HfApiEngine("Qwen/Qwen2.5-72B-Instruct"),
)

agent.run("Which waiter got more total money from tips?")
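For reference, here is one query that answers this question, which is handy for checking the agent's answer by hand. It is a sketch only: the agent may well arrive at different SQL, and the example uses the built-in sqlite3 module with the same rows as above instead of the SQLAlchemy engine.

```python
import sqlite3

# Rebuild both tables with the tutorial's rows in a throwaway in-memory database.
con = sqlite3.connect(":memory:")
con.execute(
    "CREATE TABLE receipts (receipt_id INTEGER, customer_name TEXT, price REAL, tip REAL)"
)
con.execute("CREATE TABLE waiters (receipt_id INTEGER, waiter_name TEXT)")
con.executemany(
    "INSERT INTO receipts VALUES (?, ?, ?, ?)",
    [
        (1, "Alan Payne", 12.06, 1.20),
        (2, "Alex Mason", 23.86, 0.24),
        (3, "Woodrow Wilson", 53.43, 5.43),
        (4, "Margaret James", 21.11, 1.00),
    ],
)
con.executemany(
    "INSERT INTO waiters VALUES (?, ?)",
    [
        (1, "Corey Johnson"),
        (2, "Michael Watts"),
        (3, "Michael Watts"),
        (4, "Margaret James"),
    ],
)

# Join receipts to waiters on receipt_id, then total the tips per waiter.
query = """
SELECT w.waiter_name, SUM(r.tip) AS total_tips
FROM receipts AS r
JOIN waiters AS w ON w.receipt_id = r.receipt_id
GROUP BY w.waiter_name
ORDER BY total_tips DESC
"""
totals = con.execute(query).fetchall()
top_waiter, top_tips = totals[0]
print(top_waiter, round(top_tips, 2))
```

With this data, Michael Watts comes out on top, with roughly 5.67 in total tips across receipts 2 and 3.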

It works right away! The setup was surprisingly simple, wasn't it?

✅ Now you can go build the text-to-SQL system you've always dreamed of! ✨
