| """Run dataset on text2sql zazu experiment. | |
| See README.md for more details. | |
| """ | |
| import datetime | |
| import json | |
| import multiprocessing | |
| import random | |
| import re | |
| from pathlib import Path | |
| import click | |
| import numpy as np | |
| from constants import PROMPT_FORMATTERS | |
| from loaders import DefaultLoader | |
| from get_manifest import get_manifest | |
| from manifest import Manifest | |
| from prompt_formatters import RajkumarFormatter | |
| from rich.console import Console | |
| from schema import Table, TextToSQLModelResponse, TextToSQLParams | |
| from text_to_sql import instruction_to_sql, instruction_to_sql_list | |
| from doc_retriever import ( | |
| load_documentation, | |
| split_documents, | |
| embed_documents, | |
| query_docs, | |
| ) | |
| from tqdm import tqdm | |
| from transformers import AutoTokenizer | |
| console = Console(soft_wrap=True) | |


def generate_sql(
    manifest: Manifest,
    text_to_sql_in: list[TextToSQLParams],
    retrieved_docs: list[list[str]],
    prompt_formatter: RajkumarFormatter,
    stop_tokens: list[str] | None = None,
    overwrite_manifest: bool = False,
    max_tokens: int = 300,
    temperature: float = 0.1,
    num_beams: int = 2,
    parallel: bool = False,
) -> list[tuple[str, TextToSQLModelResponse]]:
    """Call our text2sql function with manifest of our choice."""
    if parallel:
        instruction_to_sql_resps: list[
            TextToSQLModelResponse
        ] = instruction_to_sql_list(
            params=text_to_sql_in,
            extra_context=retrieved_docs,
            manifest=manifest,
            prompt_formatter=prompt_formatter,
            overwrite_manifest=overwrite_manifest,
            max_tokens=max_tokens,
            temperature=temperature,
            stop_sequences=stop_tokens,
            num_beams=num_beams,
        )
    else:
        instruction_to_sql_resps = [
            instruction_to_sql(
                params=_text_to_sql_in,
                extra_context=_retrieved_docs,
                manifest=manifest,
                prompt_formatter=prompt_formatter,
                overwrite_manifest=overwrite_manifest,
                max_tokens=max_tokens,
                temperature=temperature,
                stop_sequences=stop_tokens,
                num_beams=num_beams,
            )
            for _retrieved_docs, _text_to_sql_in in tqdm(
                zip(retrieved_docs, text_to_sql_in),
                desc="Generating SQL",
                total=len(text_to_sql_in),
                disable=(len(text_to_sql_in) <= 1),
            )
        ]
    assert len(instruction_to_sql_resps) == len(text_to_sql_in)
    sql_statements = []
    for i in range(len(instruction_to_sql_resps)):
        sql_statement = instruction_to_sql_resps[i].output.strip()
        if "<>" in sql_statement:
            # Normalize the non-standard inequality operator.
            sql_statement = sql_statement.replace("<>", "!=")
        # Models are sometimes trained to predict <databasename/schema> | <sql>;
        # keep only the SQL portion.
        sql_statement = sql_statement.split("|")[-1].strip()
        sql_statements.append(sql_statement)
    return list(zip(sql_statements, instruction_to_sql_resps))


def get_text_to_sql_in(
    input_question: dict, db_to_tables: dict[str, dict[str, Table]]
) -> TextToSQLParams:
    """Format input question for text2sql function."""
    question = input_question["question"]
    db_id = input_question.get("db_id", None)
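    # A db_id of "none" marks questions with no associated schema;
    # skip the table lookup for those.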
| if db_id != "none": | |
| table_params = list(db_to_tables.get(db_id, {}).values()) | |
| else: | |
| table_params = [] | |
| if len(table_params) == 0: | |
| console.print(f"[red] WARNING: No tables found for {db_id} [/red]") | |
| text_to_sql_in = TextToSQLParams( | |
| instruction=question, | |
| database=db_id, | |
| tables=table_params, | |
| ) | |
| return text_to_sql_in | |


@click.group()
def cli() -> None:
    """Entrypoint."""
    pass


# Format options
# Prompt options
# use whatever is in manifest
# Docs options
# Manifest options
@cli.command()
def predict(
    dataset_path: str,
    table_meta_path: str,
    output_dir: str,
    run_name: str,
    num_run: int,
    num_print: int,
    prompt_format: str,
    stop_tokens: list[str],
    max_tokens: int,
    temperature: float,
    num_beams: int,
    max_context_length: int,
    markdown_docs_path: Path,
    num_retrieved_docs: int,
    manifest_client: str,
    manifest_engine: str,
    manifest_connection: str,
    overwrite_manifest: bool,
    parallel: bool,
) -> None:
| """Predict SQL. | |
| Args: | |
| dataset_path: the dataset path. | |
| table_meta_path: the json path of the table metadata. | |
| database_path: the database path for sqlite. | |
| output_dir: the prediction output directory | |
| run_name: special prefix to add to filename | |
| num_run: the number of examples to run | |
| num_print: the number of examples to print | |
| prompt_format: the format of the prompt. E.g., "rajkumar" | |
| stop_tokens: the stop tokens to try | |
| max_tokens: the max tokens | |
| temperature: the temperature | |
| num_beams: the number of beams | |
| max_context_length: max context length for demonstration truncation (-1 means None) | |
| markdown_docs_path: path to duckdb sql docs | |
| num_retrieved_docs: number of docs to retrieve | |
| manifest_client: the manifest client | |
| manifest_engine: the manifest engine | |
| manifest_connection: the manifest connection | |
| """ | |
    multiprocessing.set_start_method("spawn", force=True)
    random.seed(0)
    np.random.seed(0)
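    # Echo the effective run configuration; Path objects must be stringified
    # before json.dumps.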
    locals_dict = locals()
    locals_dict["markdown_docs_path"] = str(markdown_docs_path)
    console.print(json.dumps(locals_dict, indent=2))

    data_formatter = DefaultLoader()
    if prompt_format not in PROMPT_FORMATTERS:
        raise ValueError(f"Unknown prompt format {prompt_format}")
    prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

    # Load manifest
    manifest = get_manifest(
        manifest_client=manifest_client,
        manifest_connection=manifest_connection,
        manifest_engine=manifest_engine,
    )
    manifest_params = manifest.client_pool.get_current_client().get_model_params()
    console.print(f"Running with {manifest_params} manifest.")
    model_name = manifest_params.get("engine", manifest_params["model_name"])
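    # Hosted API clients do not ship a local tokenizer, so fall back to the
    # gpt2 tokenizer purely for token counting.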
    if manifest_client in {"openai", "openaichat", "openrouter", "azureendpoint", "inference_api"}:
        tokenizer = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=True)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if stop_tokens:
        stop_tokens = [st.strip("'") for st in stop_tokens]
        console.print(f"Stop tokens: {stop_tokens}")

    # Get output filename
    full_dataset_path = Path(dataset_path)
    # Get today's date
    date_today = datetime.datetime.now().strftime("%y-%m-%d")
    if run_name:
        run_name = f"{run_name}_"
    suffix = f"{run_name}{full_dataset_path.stem}_{date_today}.json"  # noqa: E501
    prefix = f"{prompt_format}_{num_retrieved_docs}docs"
    if manifest_client in {"openai", "openaiazure"}:
        middleix = manifest_engine
    elif manifest_client in {"huggingface", "ray"}:
        middleix = Path(manifest_params.get("model_path", "")).name.replace("/", "-")
    elif manifest_client in {"toma", "openrouter", "openaichat", "azureendpoint", "inference_api"}:
        middleix = manifest_engine.split("/")[-1]
    else:
        raise ValueError(f"Unknown manifest client {manifest_client}")
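    # Final filename: <prompt_format>_<N>docs_<model>_<run_name_><dataset_stem>_<date>.json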
    output_filename = f"{prefix}_{middleix}_{suffix}"
    console.print(f"Saving to {Path(output_dir) / output_filename}")
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    console.print("Loading metadata...")
    db_to_tables = data_formatter.load_table_metadata(table_meta_path)

    console.print("Loading data...")
    data = data_formatter.load_data(dataset_path)
    if num_run > 0:
        console.print(f"Running on {min(len(data), num_run)} examples")
        data = data[:num_run]
    original_data = data

    # Load the examples
    console.print("Formatting data...")
    num_print = min(num_print, len(data))
    token_lengths = []
    text_to_sql_in = [
        get_text_to_sql_in(input_question, db_to_tables) for input_question in data
    ]
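    # Retrieval pipeline: load the markdown docs, chunk them, embed the chunks
    # once, then fetch the top-n chunks for each question.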
    if num_retrieved_docs > 0:
        console.print("Loading documentation and indexing...")
        retrieved_docs = []
        doc_contents = load_documentation(markdown_docs_path)
        chunked_docs = split_documents(doc_contents)
        embedded_docs, full_embedding_mat = embed_documents(chunked_docs)
        for i in tqdm(range(len(text_to_sql_in)), desc="Retrieving docs"):
            _, retrieved_docs_strings = query_docs(
                text_to_sql_in[i].instruction,
                embedded_docs,
                full_embedding_mat,
                top_n=num_retrieved_docs,
            )
            retrieved_docs.append(retrieved_docs_strings)
    else:
        retrieved_docs = [[] for _ in range(len(text_to_sql_in))]
    for i in range(num_print):
        # Run a few to get some examples to print
        generated_responses = generate_sql(
            manifest=manifest,
            text_to_sql_in=[text_to_sql_in[i]],
            retrieved_docs=[retrieved_docs[i]],
            stop_tokens=stop_tokens,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            prompt_formatter=prompt_formatter,
            overwrite_manifest=overwrite_manifest,
            parallel=parallel,
        )
        for prediction, model_response in generated_responses:
            prediction = re.sub(r"\s+", " ", prediction)
            token_lengths.append(len(tokenizer(prediction).input_ids))
            console.print(f"[red]Prediction:[/red] {prediction}")
            if data[i].get("query") or data[i].get("sql"):
                console.print(
                    "[purple]Gold:[/purple] "
                    f"{data[i].get('query') or data[i].get('sql')}"
                )
            console.print("\n****\n")
    # Run the entire thing now - the printed examples are already in cache and fast
    generated_sqls = generate_sql(
        manifest=manifest,
        text_to_sql_in=text_to_sql_in,
        retrieved_docs=retrieved_docs,
        stop_tokens=stop_tokens,
        max_tokens=max_tokens,
        temperature=temperature,
        num_beams=num_beams,
        prompt_formatter=prompt_formatter,
        overwrite_manifest=overwrite_manifest,
        parallel=parallel,
    )
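    # Write one JSON object per line: the original example plus the cleaned
    # prediction, raw model output, final prompt, and serialized table schemas.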
    with open(Path(output_dir) / output_filename, "w") as fout:
        for i, (prediction, model_response) in enumerate(generated_sqls):
            if isinstance(model_response.final_prompt, str):
                token_lengths.append(
                    len(tokenizer(model_response.final_prompt).input_ids)
                )
            else:
                # Chat-style prompts are lists of messages; count each content.
                for prompt in model_response.final_prompt:
                    token_lengths.append(len(tokenizer(prompt["content"]).input_ids))
            entry = {
                **original_data[i],
                "pred": prediction,
                "raw_pred": model_response.output,
                "raw_output": model_response.raw_output,
                "prompt": model_response.final_prompt,
                "tables": [tbl.dict() for tbl in text_to_sql_in[i].tables or []],
            }
            formatted_entry = data_formatter.format_output(entry)
            print(json.dumps(formatted_entry), file=fout)
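    # Report the share of measured token lengths (prompts and printed
    # predictions) that exceed a 2048-token context window.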
    overflow = len([tl for tl in token_lengths if tl > 2048]) / len(token_lengths)
    console.print(f"Overflow 2048 prompt {100 * overflow:.2f}%")
    console.print(f"Saved to {Path(output_dir) / output_filename}")


if __name__ == "__main__":
    cli()