import itertools
import re
from collections import namedtuple
from typing import Any, Dict, Iterator, List, Set, Tuple, Union

import sqlparse
from sqlparse.sql import Comparison, Identifier
from sqlparse.tokens import Whitespace

Token = namedtuple("Token", ["ttype", "value"])
VALUE_NUM_SYMBOL = "VALUERARE"
QUOTE_CHARS = {"`", "'", '"'}


def tokenize(query: str) -> List[Token]:
    tokens = [Token(t.ttype, t.value) for t in sqlparse.parse(query)[0].flatten()]
    return tokens


def join_tokens(tokens: List[Token]) -> str:
    # Join the token values and squeeze double spaces into single spaces.
    return "".join([x.value for x in tokens]).strip().replace("  ", " ")


def round_trip_test(query: str) -> None:
    tokens = tokenize(query)
    reconstructed = "".join([token.value for token in tokens])
    assert query == reconstructed, "Round trip test fails for string %s" % query


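# Illustrative check, added as a usage sketch (not part of the original module):
# sqlparse's flatten() is lossless, so joining the token values from tokenize()
# reproduces the input exactly, which is what round_trip_test() asserts.
# The query string below is an arbitrary example.
def _demo_round_trip() -> None:
    round_trip_test("SELECT name FROM singer WHERE age > 20")

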
def postprocess(query: str) -> str:
    query = query.replace("> =", ">=").replace("< =", "<=").replace("! =", "!=")
    return query


def strip_query(query: str) -> Tuple[List[str], List[str]]:
    query_keywords, all_values = [], []

    # A previous regex-based way of collecting quoted string literals,
    # left here unused for reference:
    """
    str_1 = re.findall("\"[^\"]*\"", query)
    str_2 = re.findall("\'[^\']*\'", query)
    values = str_1 + str_2
    """

    # Collect string literals via sqlparse token types instead.
    toks = sqlparse.parse(query)[0].flatten()
    values = [
        t.value
        for t in toks
        if t.ttype == sqlparse.tokens.Literal.String.Single
        or t.ttype == sqlparse.tokens.Literal.String.Symbol
    ]

    # Replace every string literal with the value placeholder.
    for val in values:
        all_values.append(val)
        query = query.replace(val.strip(), VALUE_NUM_SYMBOL)

    # Replace float literals with the placeholder.
    query_tokenized = query.split()
    float_nums = re.findall(r"[-+]?\d*\.\d+", query)
    all_values += [qt for qt in query_tokenized if qt in float_nums]
    query_tokenized = [
        VALUE_NUM_SYMBOL if qt in float_nums else qt for qt in query_tokenized
    ]

    # Replace integer literals with the placeholder; the [^tT] guard keeps
    # digits inside T1/T2-style aliases from being treated as values.
    query = " ".join(query_tokenized)
    int_nums = [i.strip() for i in re.findall(r"[^tT]\d+", query)]

    all_values += [qt for qt in query_tokenized if qt in int_nums]
    query_tokenized = [
        VALUE_NUM_SYMBOL if qt in int_nums else qt for qt in query_tokenized
    ]

    # Lowercase the remaining tokens, splitting alias.column references
    # such as T1.name into separate tokens.
    for tok in query_tokenized:
        if "." in tok:
            table = re.findall(r"[Tt]\d+\.", tok)
            if len(table) > 0:
                to = tok.replace(".", " . ").split()
                to = [t.lower() for t in to if len(t) > 0]
                query_keywords.extend(to)
            else:
                query_keywords.append(tok.lower())
        elif len(tok) > 0:
            query_keywords.append(tok.lower())
    return query_keywords, all_values


def reformat_query(query: str) -> str:
    query = query.strip().replace(";", "").replace("\t", "")
    query = " ".join(
        [t.value for t in tokenize(query) if t.ttype != sqlparse.tokens.Whitespace]
    )
    t_stars = ["t1.*", "t2.*", "t3.*", "T1.*", "T2.*", "T3.*"]
    for ts in t_stars:
        query = query.replace(ts, "*")
    return query


def replace_values(sql: str) -> Tuple[List[str], Set[str]]:
    sql = sqlparse.format(sql, reindent=False, keyword_case="upper")
    # Glue alias references such as "T1. col" back together before stripping values.
    sql = re.sub(r"(T\d+\.)\s", r"\1", sql)
    query_toks_no_value, values = strip_query(sql)
    return query_toks_no_value, set(values)


def extract_query_values(sql: str) -> Tuple[List[str], Set[str]]:
    # Normalize the query, then split it into value-free tokens and the set of values.
    reformatted = reformat_query(query=sql)
    query_value_replaced, values = replace_values(reformatted)
    return query_value_replaced, values


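# Usage sketch, added for illustration (the sample query and the _demo_ helper
# are not from the original source): extract_query_values() returns a lowercased
# token list in which every literal is replaced by VALUE_NUM_SYMBOL (lowercased
# to "valuerare"), together with the set of stripped literal values.
def _demo_extract_query_values() -> None:
    toks, values = extract_query_values("SELECT name FROM singer WHERE age > 20")
    print(toks)    # e.g. [..., "where", "age", ">", "valuerare"]
    print(values)  # e.g. {"20"}

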
def plugin(query_value_replaced: List[str], values_in_order: List[str]) -> str:
    # Fill the value slots (VALUE_NUM_SYMBOL, lowercased by strip_query) back in,
    # in the order the replacement values are given.
    q_length = len(query_value_replaced)
    query_w_values = query_value_replaced[:]
    value_idx = [
        idx
        for idx in range(q_length)
        if query_value_replaced[idx] == VALUE_NUM_SYMBOL.lower()
    ]
    assert len(value_idx) == len(
        values_in_order
    ), "number of values must match the number of value slots"

    for idx, value in zip(value_idx, values_in_order):
        query_w_values[idx] = value
    return " ".join(query_w_values)


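# Minimal sketch, added here as an example (the token list and value are made
# up): plugin() fills the value slots back in, in the order the values are given.
def _demo_plugin() -> None:
    replaced = [
        "select", "name", "from", "singer", "where", "age", ">",
        VALUE_NUM_SYMBOL.lower(),
    ]
    print(plugin(replaced, ["20"]))  # "select name from singer where age > 20"

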
def plugin_all_permutations(
    query_value_replaced: List[str], values: Set[str]
) -> Iterator[str]:
    # Yield every way of assigning the candidate values to the value slots.
    num_slots = len([v for v in query_value_replaced if v == VALUE_NUM_SYMBOL.lower()])
    for combination in itertools.product(list(values), repeat=num_slots):
        yield plugin(query_value_replaced, list(combination))


def get_all_preds_for_execution(gold: str, pred: str) -> Tuple[int, Iterator[str]]:
    # Take the candidate values from the gold query and the value slots from the
    # predicted query; return how many filled-in alternatives exist and a
    # generator that yields them.
    _, gold_values = extract_query_values(gold)
    pred_query_value_replaced, _ = extract_query_values(pred)
    num_slots = len(
        [v for v in pred_query_value_replaced if v == VALUE_NUM_SYMBOL.lower()]
    )
    num_alternatives = len(gold_values) ** num_slots
    return (
        num_alternatives,
        plugin_all_permutations(pred_query_value_replaced, gold_values),
    )


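# Usage sketch, added for illustration (both queries are hypothetical): the gold
# query supplies the candidate values and the prediction supplies the slots; the
# returned generator enumerates every filled-in alternative.
def _demo_get_all_preds_for_execution() -> None:
    gold = "SELECT name FROM singer WHERE age > 20"
    pred = "SELECT name FROM singer WHERE age > 30"
    num_alternatives, candidates = get_all_preds_for_execution(gold, pred)
    print(num_alternatives)  # len(gold_values) ** num_slots
    for candidate in candidates:
        print(candidate)

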
def remove_distinct(s: str) -> str:
    toks = [t.value for t in list(sqlparse.parse(s)[0].flatten())]
    return "".join([t for t in toks if t.lower() != "distinct"])


def extract_all_comparison_from_node(node: sqlparse.sql.Token) -> List[Comparison]:
    # Recursively walk the sqlparse tree and collect every Comparison node.
    comparison_list = []
    if hasattr(node, "tokens"):
        for t in node.tokens:
            comparison_list.extend(extract_all_comparison_from_node(t))
    if type(node) == Comparison:
        comparison_list.append(node)
    return comparison_list


def extract_all_comparison(query: str) -> List[Comparison]:
    tree = sqlparse.parse(query)[0]
    comparison_list = extract_all_comparison_from_node(tree)
    return comparison_list


def extract_toks_from_comparison(comparison_node: Comparison) -> List[sqlparse.sql.Token]:
    tokens = [t for t in comparison_node.tokens if t.ttype != Whitespace]
    return tokens


def extract_info_from_comparison(comparison_node: Comparison) -> Dict[str, Any]:
    tokens = extract_toks_from_comparison(comparison_node)
    left, op, right = tokens

    returned_dict = {"left": left, "op": op.value, "right": right}

    if type(left) != Identifier:
        return returned_dict

    table = None
    # A three-token Identifier is "<table-or-alias> . <column>"; only keep the
    # prefix as a table name when it is not a T1/T2-style alias.
    if len(left.tokens) == 3 and re.match("^[tT][0-9]$", left.tokens[0].value) is None:
        table = left.tokens[0].value.lower()
    col = left.tokens[-1].value

    if type(right) == Identifier:
        if len(right.tokens) == 1 and type(right.tokens[0]) == sqlparse.sql.Token:
            right_val = right.tokens[0].value
        else:
            return returned_dict
    elif type(right) == sqlparse.sql.Token:
        right_val = right.value
    else:
        return returned_dict

    returned_dict["table_col"], returned_dict["val"] = (
        (table, col.upper()),
        process_str_value(right_val),
    )

    return returned_dict


def extract_all_comparison_from_query(query: str) -> List[Dict[str, Any]]:
    comparison_list = extract_all_comparison(query)
    return [extract_info_from_comparison(c) for c in comparison_list]


def extract_typed_value_in_comparison_from_query(
    query: str,
) -> List[Tuple[Tuple[Union[str, None], str], str]]:
    cmps = extract_all_comparison_from_query(query)
    typed_values = [
        (cmp["table_col"], cmp["val"]) for cmp in cmps if "table_col" in cmp
    ]
    # BETWEEN ... AND ... ranges are not reliably captured by the Comparison
    # nodes above, so pick those value pairs up with a regex as well.
    for table, col, val1, val2 in re.findall(
        r"(?:([^\.\s]*)\.)?([^\.\s]+) between ([^\s;]+) and ([^\s;]+)",
        query,
        re.IGNORECASE,
    ):
        if table == "":
            table = None
        else:
            table = table.lower()
        col = col.upper()
        for v in [val1, val2]:
            typed_values.append(((table, col), v))
    return typed_values


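# Usage sketch, added as an example (the query is made up): each returned item
# pairs an optional (table, COLUMN) with the raw comparison value, covering both
# plain comparisons and BETWEEN ... AND ... ranges.
def _demo_extract_typed_values() -> None:
    sample = "SELECT * FROM singer AS T1 WHERE T1.age BETWEEN 20 AND 30"
    print(extract_typed_value_in_comparison_from_query(sample))

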
def process_str_value(v: str) -> str:
    # Strip one pair of surrounding quote characters and collapse doubled
    # quote characters used as escapes inside the literal.
    if len(v) > 0 and v[0] in QUOTE_CHARS:
        v = v[1:]
    if len(v) > 0 and v[-1] in QUOTE_CHARS:
        v = v[:-1]
    for c in QUOTE_CHARS:
        v = v.replace(c + c, c)
    return v


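# Optional smoke run for the usage sketches above (added; not part of the
# original module). It only executes when this file is run directly.
if __name__ == "__main__":
    _demo_round_trip()
    _demo_extract_query_values()
    _demo_plugin()
    _demo_get_all_preds_for_execution()
    _demo_extract_typed_values()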