extraction sqlquery (#4027)
Browse filesclone https://github.com/infiniflow/ragflow/pull/4023
improve the information extraction, most llm return results in markdown
format ````sql ___ query `____ ```
- agent/component/exesql.py +17 -3
agent/component/exesql.py
CHANGED
|
@@ -20,7 +20,7 @@ import pymysql
|
|
| 20 |
import psycopg2
|
| 21 |
from agent.component.base import ComponentBase, ComponentParamBase
|
| 22 |
import pyodbc
|
| 23 |
-
|
| 24 |
|
| 25 |
class ExeSQLParam(ComponentParamBase):
|
| 26 |
"""
|
|
@@ -65,13 +65,26 @@ class ExeSQL(ComponentBase, ABC):
|
|
| 65 |
self._loop += 1
|
| 66 |
|
| 67 |
ans = self.get_input()
|
|
|
|
|
|
|
| 68 |
ans = "".join([str(a) for a in ans["content"]]) if "content" in ans else ""
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
ans = re.sub(r';.*?SELECT ', '; SELECT ', ans, flags=re.IGNORECASE)
|
| 71 |
ans = re.sub(r';[^;]*$', r';', ans)
|
| 72 |
if not ans:
|
| 73 |
raise Exception("SQL statement not found!")
|
| 74 |
|
|
|
|
| 75 |
if self._param.db_type in ["mysql", "mariadb"]:
|
| 76 |
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
|
| 77 |
port=self._param.port, password=self._param.password)
|
|
@@ -96,11 +109,12 @@ class ExeSQL(ComponentBase, ABC):
|
|
| 96 |
if not single_sql:
|
| 97 |
continue
|
| 98 |
try:
|
|
|
|
| 99 |
cursor.execute(single_sql)
|
| 100 |
if cursor.rowcount == 0:
|
| 101 |
sql_res.append({"content": "\nTotal: 0\n No record in the database!"})
|
| 102 |
continue
|
| 103 |
-
single_res = pd.DataFrame([i for i in cursor.fetchmany(
|
| 104 |
single_res.columns = [i[0] for i in cursor.description]
|
| 105 |
sql_res.append({"content": "\nTotal: " + str(cursor.rowcount) + "\n" + single_res.to_markdown()})
|
| 106 |
except Exception as e:
|
|
|
|
| 20 |
import psycopg2
|
| 21 |
from agent.component.base import ComponentBase, ComponentParamBase
|
| 22 |
import pyodbc
|
| 23 |
+
import logging
|
| 24 |
|
| 25 |
class ExeSQLParam(ComponentParamBase):
|
| 26 |
"""
|
|
|
|
| 65 |
self._loop += 1
|
| 66 |
|
| 67 |
ans = self.get_input()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
ans = "".join([str(a) for a in ans["content"]]) if "content" in ans else ""
|
| 71 |
+
if self._param.db_type == 'mssql':
|
| 72 |
+
# improve the information extraction, most llm return results in markdown format ```sql query ```
|
| 73 |
+
match = re.search(r"```sql\s*(.*?)\s*```", ans, re.DOTALL)
|
| 74 |
+
if match:
|
| 75 |
+
ans = match.group(1) # Query content
|
| 76 |
+
print(ans)
|
| 77 |
+
else:
|
| 78 |
+
print("no markdown")
|
| 79 |
+
ans = re.sub(r'^.*?SELECT ', 'SELECT ', (ans), flags=re.IGNORECASE)
|
| 80 |
+
else:
|
| 81 |
+
ans = re.sub(r'^.*?SELECT ', 'SELECT ', repr(ans), flags=re.IGNORECASE)
|
| 82 |
ans = re.sub(r';.*?SELECT ', '; SELECT ', ans, flags=re.IGNORECASE)
|
| 83 |
ans = re.sub(r';[^;]*$', r';', ans)
|
| 84 |
if not ans:
|
| 85 |
raise Exception("SQL statement not found!")
|
| 86 |
|
| 87 |
+
logging.info("db_type: ",self._param.db_type)
|
| 88 |
if self._param.db_type in ["mysql", "mariadb"]:
|
| 89 |
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
|
| 90 |
port=self._param.port, password=self._param.password)
|
|
|
|
| 109 |
if not single_sql:
|
| 110 |
continue
|
| 111 |
try:
|
| 112 |
+
logging.info("single_sql: ",single_sql)
|
| 113 |
cursor.execute(single_sql)
|
| 114 |
if cursor.rowcount == 0:
|
| 115 |
sql_res.append({"content": "\nTotal: 0\n No record in the database!"})
|
| 116 |
continue
|
| 117 |
+
single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.top_n)])
|
| 118 |
single_res.columns = [i[0] for i in cursor.description]
|
| 119 |
sql_res.append({"content": "\nTotal: " + str(cursor.rowcount) + "\n" + single_res.to_markdown()})
|
| 120 |
except Exception as e:
|