Qwen2.5-3B-Instruct-Based Text-to-SQL Generation Model Aligned with Multiple Reward Functions via GRPO

This model is fine-tuned with reinforcement learning using GRPO (Group Relative Policy Optimization) to produce reasoning-based SQL queries as output.
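Here is a minimal sketch of what aligning with multiple reward functions via GRPO involves. The two reward functions (`format_reward`, `answer_reward`) are hypothetical stand-ins, not the rewards actually used to train this model. Combining them by an unweighted sum is also an assumption. Only `group_relative_advantages` reflects GRPO itself: each sampled completion's reward is normalized against the mean and standard deviation of the group of completions sampled for the same prompt.

```python
import re
import statistics
from typing import List

# Hypothetical reward 1: the completion must be a <reason> block followed by an
# <answer> block (mirrors the format requested in the system prompt below).
FORMAT_RE = re.compile(r"^\s*<reason>.*?</reason>\s*<answer>.*?</answer>\s*$", re.DOTALL)

def format_reward(completion: str) -> float:
    return 1.0 if FORMAT_RE.match(completion) else 0.0

# Hypothetical reward 2: the <answer> block must contain a SELECT statement.
def answer_reward(completion: str) -> float:
    match = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    return 1.0 if match and match.group(1).strip().upper().startswith("SELECT") else 0.0

def group_relative_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """GRPO-style advantage: normalize each reward by its group's mean and std.

    Population std is used here; implementations differ on this detail.
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

# A group of sampled completions for the same prompt.
completions = [
    "<reason>r</reason>\n<answer>\nSELECT 1\n</answer>",  # passes both rewards
    "<reason>r</reason> SELECT 1",                        # passes neither
    "<answer>SELECT 1</answer>",                          # answer reward only
    "<reason>r</reason><answer>DROP TABLE t</answer>",    # format reward only
]
# Assumption: multiple rewards are combined by an unweighted sum.
rewards = [format_reward(c) + answer_reward(c) for c in completions]
advantages = group_relative_advantages(rewards)
print(rewards)     # [2.0, 0.0, 1.0, 1.0]
print(advantages)  # approximately [1.414, -1.414, 0.0, 0.0]
```

Completions that score above their group's mean get a positive advantage and are reinforced, and those below it are pushed down. If every completion in a group receives the same reward, all advantages are zero and that group contributes no learning signal.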

You can use the system prompt below as-is or modify it as needed.

Enter the SCHEMAS and QUESTION in the user prompt using the format below. The model then generates the SQL query that answers the question, along with its reasoning trace.

Quick start

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Load the base Qwen2.5-3B-Instruct model, then attach the GRPO-trained PEFT adapter on top of it.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length=2560)
model = PeftModel.from_pretrained(model, "DeathReaper0965/Qwen2.5-3B-Inst-SQL-Reasoning-GRPO", is_trainable=False)

# Use the base model's tokenizer; model_max_length sets the maximum sequence length it assumes.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct", model_max_length=2560)

def create_prompt(schemas, question):
    prompt = [
      {
        'role': 'system',
        'content': """\
You are an expert SQL Query Writer. 
Given relevant Schemas and the Question, you first understand the problem entirely and then reason about the best possible approach to come up with an answer.
Once, you are confident in your reasoning, you will then start generating the SQL Query as the answer that accurately solves the given question leveraging some or all schemas.

Remember that you should place all your reasoning between <reason> and </reason> tags.
Also, you should provide your solution between <answer> and </answer> tags.

An example generation is as follows:
<reason>
This is a sample reasoning that solves the question based on the schema.
</reason>
<answer>
SELECT
    COLUMN
FROM TABLE_NAME
WHERE
    CONDITION
</answer>"""
      },
      {
        'role': 'user',
        'content': f"""\
SCHEMAS:
---------------

{schemas}

---------------

QUESTION: "{question}"\
"""
      }
    ]

    return prompt


schemas = """\
CREATE TABLE lab (
    subject_id text,
    hadm_id text,
    itemid int,
    charttime date,
    flag bool,
    value_unit int,
    label text,
    fluid text
)

CREATE TABLE diagnoses (
    subject_id text,
    hadm_id text,
    icd9_code text,
    short_title text,
    long_title text
)

CREATE TABLE procedures (
    subject_id text,
    hadm_id text,
    icd9_code text,
    short_title text,
    long_title text
)

CREATE TABLE demographic (
    subject_id text,
    hadm_id text,
    name text,
    marital_status text,
    age int,
    dob date,
    gender text,
    language text,
    religion text,
    admission_type text,
    days_stay text,
    insurance text,
    ethnicity text,
    expire_flag bool,
    admission_location text,
    discharge_location text,
    diagnosis text,
    dod date,
    dob_year date,
    dod_year date,
    admittime date,
    dischtime date,
    admityear int
)

CREATE TABLE prescriptions (
    subject_id text,
    hadm_id text,
    icustay_id text,
    drug_type text,
    drug text,
    formulary_drug_cd text,
    route text,
    drug_dose text
)\
"""

question = "How many patients whose admission type is emergency and diagnoses icd9 code is 56210?"

example_prompt = create_prompt(schemas, question)

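# Stream generated tokens to stdout as they are produced (the prompt itself is not echoed).
streamer = TextStreamer(tokenizer, skip_prompt=True)

# Apply Qwen's chat template and move the input tensors to the model's device.
inputs = tokenizer.apply_chat_template(example_prompt,
                                       tokenize=True,
                                       add_generation_prompt=True,
                                       return_dict=True,
                                       return_tensors="pt").to(model.device)

with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=1024, streamer=streamer)

# Decode the full sequence and keep only the assistant's turn.
outputs = tokenizer.batch_decode(outputs)
print(outputs[0].split("<|im_start|>assistant")[-1])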


###########OUTPUT###########
<reason>
To answer this question, we need to perform the following steps:

1. Identify patients who have an 'emergency' admission type from the `demographic` table.
2. Identify patients who have the ICD-9 code '56210' in their `diagnosis` field from the same `demographic` table.
3. Find the intersection of these two groups by joining the results of the above queries.
4. Count the number of unique patients who meet both criteria.

We can achieve this using a combination of JOIN operations in our SQL query.
</reason>
<answer>
SELECT
    COUNT(DISTINCT d.subject_id)
FROM demographic AS d
JOIN diagnoses AS di
    ON d.subject_id = di.subject_id AND d.hadm_id = di.hadm_id
WHERE
    d.admission_type = 'Emergency' AND di.icd9_code = '56210'
</answer>
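Because the SQL is wrapped in `<answer>` tags, one quick way to sanity-check a generation is to extract it and run it against an empty in-memory SQLite database built from the same schemas. The sketch below uses only Python's standard library. `extract_answer` and `check_sql` are hypothetical helpers, and the schema is a trimmed subset of the quick-start `schemas`. SQLite is an assumed stand-in for whatever database you actually target, so dialect differences may cause false alarms. The helper also assumes the CREATE TABLE statements are separated by blank lines, as in the quick-start string.

```python
import re
import sqlite3
from typing import List, Optional, Tuple

def extract_answer(generation: str) -> Optional[str]:
    """Return the text between <answer> and </answer>, or None if the tags are missing."""
    match = re.search(r"<answer>(.*?)</answer>", generation, re.DOTALL)
    return match.group(1).strip() if match else None

def check_sql(schemas: str, sql: str) -> List[Tuple]:
    """Run `sql` against an empty in-memory SQLite DB created from `schemas`.

    Assumes CREATE TABLE statements are separated by blank lines.
    Raises sqlite3.Error (e.g. OperationalError) if the query is invalid.
    """
    conn = sqlite3.connect(":memory:")
    try:
        for statement in re.split(r"\n\s*\n", schemas.strip()):
            conn.execute(statement)
        return conn.execute(sql).fetchall()
    finally:
        conn.close()

# Trimmed subset of the quick-start schemas (only the tables/columns the query touches).
schemas = """\
CREATE TABLE demographic (
    subject_id text,
    hadm_id text,
    admission_type text
)

CREATE TABLE diagnoses (
    subject_id text,
    hadm_id text,
    icd9_code text,
    short_title text,
    long_title text
)"""

# The assistant turn from the example output above (reasoning omitted).
generation = """<reason>
(reasoning omitted)
</reason>
<answer>
SELECT
    COUNT(DISTINCT d.subject_id)
FROM demographic AS d
JOIN diagnoses AS di
    ON d.subject_id = di.subject_id AND d.hadm_id = di.hadm_id
WHERE
    d.admission_type = 'Emergency' AND di.icd9_code = '56210'
</answer>"""

sql = extract_answer(generation)
print(check_sql(schemas, sql))  # [(0,)] because the tables are empty
```

A query that references a nonexistent table or column raises `sqlite3.OperationalError`, which catches schema mismatches early. This check cannot catch value-level problems. For example, it cannot tell whether `admission_type` is stored as 'Emergency' or 'EMERGENCY' in the real data.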

Designed and Developed with ♥ by Praneet | LinkedIn | GitHub

