import asyncio
import os
import pickle as pkl
import random
import re
import shutil
import subprocess
import threading
import time
from collections import defaultdict
from itertools import chain, product
from pathlib import Path
from typing import Any, List, Set, Tuple

import duckdb
import tqdm

from .parse import get_all_preds_for_execution, remove_distinct

threadLock = threading.Lock()
TIMEOUT = 60
TMP_DIR = "_tmp"
EXEC_TMP_DIR = os.path.join(os.path.dirname(__file__), "tmp")


def permute_tuple(element: Tuple, perm: Tuple) -> Tuple:
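    """Reorder the fields of ``element`` according to the column permutation ``perm``."""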
    assert len(element) == len(perm)
    return tuple(element[i] for i in perm)


def unorder_row(row: Tuple) -> Tuple:
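    """Sort the fields of a row so rows can be compared regardless of column order.

    The type name is appended to the sort key so that values of different
    types sort deterministically.
    """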
    return tuple(sorted(row, key=lambda x: str(x) + str(type(x))))


def tuple_sublists(row: Tuple) -> Tuple:
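    """Convert unhashable fields (lists and dicts) inside a row to hashable tuples."""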
    new_row = []
    for item in row:
        if isinstance(item, list):
            new_row.append(tuple(item))
        elif isinstance(item, dict):
            new_row.append(tuple(sorted(item.items(), key=lambda x: x[0])))
        else:
            new_row.append(item)
    return tuple(new_row)


def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
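    """Cheap first-pass check: compare rows with their fields unordered.

    If this fails, no column permutation can make the two results equal.
    """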
    s1 = [unorder_row(row) for row in result1]
    s2 = [unorder_row(row) for row in result2]
    if order_matters:
        return s1 == s2
    else:
        return set(s1) == set(s2)


def multiset_eq(l1: List, l2: List) -> bool:
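    """Check that two lists contain the same elements with the same multiplicities."""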
    if len(l1) != len(l2):
        return False
    d = defaultdict(int)
    for e in l1:
        d[e] += 1
    for e in l2:
        d[e] -= 1
        if d[e] < 0:
            return False
    return True


def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]):
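    """Yield candidate permutations mapping result2's columns onto result1's.

    With at most three columns, every permutation is tried. For wider
    results, random rows of result2 are sampled to prune column mappings
    whose values never occur in the corresponding column of result1.
    """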
    num_cols = len(result2[0])
    perm_constraints = [{i for i in range(num_cols)} for _ in range(num_cols)]
    if num_cols <= 3:
        return product(*perm_constraints)

    # Sample rows from result2 to narrow down the permissible permutations.
    for _ in range(20):
        random_tab2_row = random.choice(result2)
        for tab1_col in range(num_cols):
            for tab2_col in set(perm_constraints[tab1_col]):
                if random_tab2_row[tab2_col] not in tab1_sets_by_columns[tab1_col]:
                    perm_constraints[tab1_col].remove(tab2_col)
    return product(*perm_constraints)


def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
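    """Decide whether two execution results denote the same table.

    Columns are matched up to permutation; row order is enforced only when
    ``order_matters`` is set (i.e. the gold query has an ORDER BY clause).
    """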
    if len(result1) == 0 and len(result2) == 0:
        return True

    # If the row counts differ, the results are definitely different.
    if len(result1) != len(result2):
        return False

    num_cols = len(result1[0])

    # If the column counts differ, the results are definitely different.
    if len(result2[0]) != num_cols:
        return False

    result1 = [tuple_sublists(row) for row in result1]
    result2 = [tuple_sublists(row) for row in result2]

    # Quickly reject results whose unordered row contents already disagree.
    if not quick_rej(result1, result2, order_matters):
        return False

    # Otherwise, search for a column permutation under which the results
    # match. Precompute the set of values in each column of result1 so
    # unpromising permutations can be pruned.
    tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)]

    for perm in get_constraint_permutation(tab1_sets_by_columns, result2):
        if len(perm) != len(set(perm)):
            continue
        if num_cols == 1:
            result2_perm = result2
        else:
            result2_perm = [permute_tuple(element, perm) for element in result2]
        if order_matters:
            if result1 == result2_perm:
                return True
        else:
            # A set comparison alone would miss duplicate rows, so confirm
            # equality with a multiset comparison.
            if set(result1) == set(result2_perm) and multiset_eq(result1, result2_perm):
                return True
    return False


def replace_cur_year(query: str) -> str:
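    """Pin YEAR(CURDATE()) to 2020 so results do not depend on the current date."""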
    return re.sub(
        r"YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE
    )


class WithDuckDBConnectionInTmpDir(object):
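    """Copy a DuckDB database into a scratch directory, connect to the copy,
    and delete the directory again on exit."""
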
    def __init__(self, databases_file, tmp_dir):
        if not os.path.exists(databases_file):
            raise Exception("Database not found: %s" % databases_file)
        os.makedirs(tmp_dir)
        shutil.copy(databases_file, tmp_dir)
        self.tmp_dbfile = Path(databases_file).name
        self.tmp_dir = tmp_dir
        self.original_wd = os.getcwd()

    def __enter__(self):
        os.chdir(self.tmp_dir)
        self.con = duckdb.connect(self.tmp_dbfile)
        return self.con

    def __exit__(self, *args):
        self.con.close()
        os.chdir(self.original_wd)
        shutil.rmtree(self.tmp_dir)


async def exec_on_db_(
    duckdb_path: str, query: str, setup_sql: str, validate_sql: str
) -> Tuple[str, Any]:
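    """Execute ``query`` against a temporary copy of the database.

    Returns ``("result", rows)`` with the rows fetched by ``validate_sql``,
    or ``("exception", error)`` if anything fails.
    """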
    try:
        with WithDuckDBConnectionInTmpDir(duckdb_path, TMP_DIR) as connection:
            if setup_sql is not None:
                print("Running Setup SQL: " + setup_sql)
                connection.execute(setup_sql)
            # connection.sql() returns None for statements that produce no
            # relation; materialize an empty table in that case so the
            # validation SQL still has something to query.
            ddb_benchmark_result_rel = connection.sql(query)
            if ddb_benchmark_result_rel is not None:
                connection.execute(
                    "CREATE TABLE ddb_benchmark_result AS SELECT * FROM ddb_benchmark_result_rel"
                )
            else:
                connection.execute("CREATE TABLE ddb_benchmark_result(empty TEXT)")
            print("Running Validation SQL: " + validate_sql)
            result = connection.execute(validate_sql).fetchall()
            return "result", result
    except Exception as e:
        return "exception", e


async def exec_on_db(
    duckdb_path: str,
    query: str,
    setup_sql: str,
    validate_sql: str,
    timeout: int = TIMEOUT,
) -> Tuple[str, Any]:
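    """Run ``exec_on_db_`` with a wall-clock timeout, reported as an exception result."""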
    try:
        return await asyncio.wait_for(
            exec_on_db_(duckdb_path, query, setup_sql, validate_sql), timeout
        )
    except asyncio.TimeoutError:
        return ("exception", TimeoutError)
    except Exception as e:
        return ("exception", e)


def postprocess(query: str) -> str:
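    """Rejoin comparison operators that tokenization may have split (e.g. "> =" to ">=")."""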
    query = query.replace("> =", ">=").replace("< =", "<=").replace("! =", "!=")
    return query


def eval_exec_match(
    db: str,
    p_str: str,
    g_str: str,
    setup_sql: str,
    validate_sql: str,
    plug_value: bool,
    keep_distinct: bool,
    progress_bar_for_each_datapoint: bool,
) -> int:
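    """Return 1 if the predicted query ``p_str`` and the gold query ``g_str``
    produce equivalent results on every DuckDB file next to ``db``, else 0.

    When ``plug_value`` is set, values from the gold query are also plugged
    into the prediction, so value-prediction mistakes are not penalized.
    """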
    p_str, g_str = postprocess(p_str), postprocess(g_str)
    if not keep_distinct:
        try:
            # If the predicted query cannot be parsed, score it as incorrect.
            p_str = remove_distinct(p_str)
        except Exception:
            return 0
        g_str = remove_distinct(g_str)

    # Row order only matters when the gold query contains an ORDER BY clause.
    order_matters = "order by" in g_str.lower()

    # Collect all database files that live in the same directory as `db`.
    db_dir = os.path.dirname(db)
    db_paths = [
        os.path.join(db_dir, basename)
        for basename in os.listdir(db_dir)
        if ".duckdb" in basename
    ]

    preds = [p_str]
    # If values are plugged in, enumerate all ways of plugging gold values
    # into the prediction, keeping the original prediction as well.
    if plug_value:
        _, preds = get_all_preds_for_execution(g_str, p_str)
        preds = chain([p_str], preds)

    for pred in preds:
        pred_passes = 1

        if progress_bar_for_each_datapoint:
            ranger = tqdm.tqdm(db_paths)
        else:
            ranger = db_paths

        # A prediction passes only if it matches the gold query on every
        # database file.
        for db_path in ranger:
            g_flag, g_denotation = asyncio.run(
                exec_on_db(
                    db_path, g_str, setup_sql=setup_sql, validate_sql=validate_sql
                )
            )
            p_flag, p_denotation = asyncio.run(
                exec_on_db(
                    db_path, pred, setup_sql=setup_sql, validate_sql=validate_sql
                )
            )

            # The gold query is assumed to be correct and must execute cleanly.
            assert (
                g_flag != "exception"
            ), f"gold query {g_str} has error {g_denotation} on database file {db_path}"

            if p_flag == "exception":
                pred_passes = 0
            elif not result_eq(g_denotation, p_denotation, order_matters=order_matters):
                pred_passes = 0
            if pred_passes == 0:
                break

        # The prediction is correct if any of its plugged-in variants passes.
        if pred_passes == 1:
            return 1

    return 0