tdoehmen committed on
Commit
9e806a7
·
verified ·
1 Parent(s): c47beda

update response extraction

Browse files
Files changed (1) hide show
  1. duckdb-nsql/eval/prompt_formatters.py +10 -71
duckdb-nsql/eval/prompt_formatters.py CHANGED
@@ -65,77 +65,16 @@ class RajkumarFormatter:
65
 
66
  @classmethod
67
  def format_model_output(cls, output_sql: str, prompt: str) -> str:
68
- def clean_code_block(block: str) -> str:
69
- """Clean a code block by removing markdown syntax and extra whitespace."""
70
- # Remove markdown indicators and common SQL prefixes
71
- cleaned = (block
72
- .replace('```sql\n', '')
73
- .replace('```duckdb\n', '')
74
- .replace('```\n', '')
75
- .replace('```', '')
76
- .strip())
77
-
78
- return cleaned
79
-
80
- def ensure_semicolon(sql: str) -> str:
81
- """Ensure the SQL query ends with exactly one semicolon."""
82
- sql = sql.strip()
83
- # Remove any existing trailing semicolons
84
- while sql.endswith(';'):
85
- sql = sql[:-1].strip()
86
- # Add back exactly one semicolon
87
- return sql + ";"
88
-
89
- # First, try to find SQL-specific code blocks
90
- sql_blocks = []
91
- start_pos = 0
92
- while True:
93
- start = output_sql.find('```sql', start_pos)
94
- if start == -1:
95
- start = output_sql.find('```duckdb', start_pos)
96
- if start == -1:
97
- break
98
-
99
- end = output_sql.find('```', start + 4)
100
- if end == -1:
101
- break
102
-
103
- sql_blocks.append(output_sql[start:end+3])
104
- start_pos = end + 3
105
-
106
- # If SQL blocks found, use the last one
107
- if sql_blocks:
108
- return ensure_semicolon(clean_code_block(sql_blocks[-1])).replace('```sql','').replace('```','').strip()
109
-
110
- # If no SQL blocks, look for generic code blocks
111
- generic_blocks = []
112
- start_pos = 0
113
- while True:
114
- start = output_sql.find('```', start_pos)
115
- if start == -1:
116
- break
117
-
118
- end = output_sql.find('```', start + 3)
119
- if end == -1:
120
- break
121
-
122
- block = output_sql[start:end+3]
123
- # Skip if this is actually an SQL block (we already handled those)
124
- if not block.startswith('```sql') and not block.startswith('```duckdb'):
125
- generic_blocks.append(block)
126
- start_pos = end + 3
127
-
128
- # If generic blocks found, use the last one
129
- if generic_blocks:
130
- return ensure_semicolon(clean_code_block(generic_blocks[-1])).replace('```sql','').replace('```','').strip()
131
-
132
- # If no code blocks found at all, take everything up to first semicolon
133
- semicolon_pos = output_sql.find(';')
134
- if semicolon_pos != -1:
135
- return ensure_semicolon(output_sql[:semicolon_pos].strip()).replace('```sql','').replace('```','').strip()
136
-
137
- # If no semicolon found, use the entire text
138
- return ensure_semicolon(output_sql.strip()).replace('```sql','').replace('```','').strip()
139
 
140
  @classmethod
141
  def format_gold_output(cls, output_sql: str) -> str:
 
65
 
66
  @classmethod
67
  def format_model_output(cls, output_sql: str, prompt: str) -> str:
68
+ pattern = r"```(?:sql|mysql|duckdb)?\n?(.*?)```"
69
+ match = re.search(pattern, output_sql, re.DOTALL)
70
+ sql = match.group(1).strip() if match else output_sql.strip()
71
+
72
+ # Handle edge case where regex captured empty content
73
+ if not sql:
74
+ sql = output_sql.strip()
75
+
76
+ # Ensure single trailing semicolon
77
+ return sql.rstrip(';') + ';'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  @classmethod
80
  def format_gold_output(cls, output_sql: str) -> str: