Qwen-2.5-3B-Instruct Based Text-to-SQL Generation Model Aligned with Multiple Reward Functions via GRPO
This model is RL-tuned with GRPO to generate SQL queries together with the reasoning behind them.
You can reuse the system prompt shown below or modify it as needed.
Enter the SCHEMAS and QUESTION in the format below as part of the user prompt, and the model generates the SQL query that answers the question along with its reasoning traces.
Quick start
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Load the base model, then attach the GRPO-tuned adapter on top of it (inference only).
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length=2560)
model = PeftModel.from_pretrained(model, "DeathReaper0965/Qwen2.5-3B-Inst-SQL-Reasoning-GRPO", is_trainable=False)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length = 2560)

def create_prompt(schemas, question):
    prompt = [
        {
            'role': 'system',
            'content': """\
You are an expert SQL Query Writer.
Given relevant Schemas and the Question, you first understand the problem entirely and then reason about the best possible approach to come up with an answer.
Once, you are confident in your reasoning, you will then start generating the SQL Query as the answer that accurately solves the given question leveraging some or all schemas.
Remember that you should place all your reasoning between <reason> and </reason> tags.
Also, you should provide your solution between <answer> and </answer> tags.
An example generation is as follows:
<reason>
This is a sample reasoning that solves the question based on the schema.
</reason>
<answer>
SELECT
COLUMN
FROM TABLE_NAME
WHERE
CONDITION
</answer>"""
        },
        {
            'role': 'user',
            'content': f"""\
SCHEMAS:
---------------
{schemas}
---------------
QUESTION: "{question}"\
"""
        }
    ]
    return prompt

schemas = """\
CREATE TABLE lab (
subject_id text,
hadm_id text,
itemid int,
charttime date,
flag bool,
value_unit int,
label text,
fluid text
)
CREATE TABLE diagnoses (
subject_id text,
hadm_id text,
icd9_code text,
short_title text,
long_title text
)
CREATE TABLE procedures (
subject_id text,
hadm_id text,
icd9_code text,
short_title text,
long_title text
)
CREATE TABLE demographic (
subject_id text,
hadm_id text,
name text,
marital_status text,
age int,
dob date,
gender text,
language text,
religion text,
admission_type text,
days_stay text,
insurance text,
ethnicity text,
expire_flag bool,
admission_location text,
discharge_location text,
diagnosis text,
dod date,
dob_year date,
dod_year date,
admittime date,
dischtime date,
admityear int
)
CREATE TABLE prescriptions (
subject_id text,
hadm_id text,
icustay_id text,
drug_type text,
drug text,
formulary_drug_cd text,
route text,
drug_dose text
)\
"""
question = "How many patients whose admission type is emergency and diagnoses icd9 code is 56210?"
example_prompt = create_prompt(schemas, question)
streamer = TextStreamer(tokenizer, skip_prompt=True)
inputs = tokenizer.apply_chat_template(example_prompt,
                                       tokenize=True,
                                       add_generation_prompt=True,
                                       return_dict=True,
                                       return_tensors="pt")
with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=1024, streamer=streamer)
outputs = tokenizer.batch_decode(outputs)
print(outputs[0].split("<|im_start|>assistant")[-1])
###########OUTPUT###########
<reason>
To answer this question, we need to perform the following steps:
1. Identify patients who have an 'emergency' admission type from the `demographic` table.
2. Identify patients who have the ICD-9 code '56210' in their `diagnosis` field from the same `demographic` table.
3. Find the intersection of these two groups by joining the results of the above queries.
4. Count the number of unique patients who meet both criteria.
We can achieve this using a combination of JOIN operations in our SQL query.
</reason>
<answer>
SELECT
COUNT(DISTINCT d.subject_id)
FROM demographic AS d
JOIN diagnoses AS di
ON d.subject_id = di.subject_id AND d.hadm_id = di.hadm_id
WHERE
d.admission_type = 'Emergency' AND di.icd9_code = '56210'
</answer>
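
Because the response wraps the reasoning and the query in <reason> and <answer> tags, the SQL can be pulled out with a small parser before it is run against a database. The snippet below is a minimal sketch of that step, not part of the model's own tooling: the helper name extract_tagged_section and the shortened sample_output string are hypothetical, and it assumes the model closes every tag it opens.

```python
import re
from typing import Optional


def extract_tagged_section(text: str, tag: str) -> Optional[str]:
    """Return the text between <tag> and </tag>, or None if the pair is missing.

    Hypothetical helper for illustration; assumes the tags are not nested.
    """
    pattern = rf"<{re.escape(tag)}>(.*?)</{re.escape(tag)}>"
    match = re.search(pattern, text, flags=re.DOTALL)
    return match.group(1).strip() if match else None


# Shortened, made-up stand-in for a decoded model response.
sample_output = """<reason>
Join demographic and diagnoses on the patient and admission identifiers.
</reason>
<answer>
SELECT
COUNT(DISTINCT d.subject_id)
FROM demographic AS d
</answer><|im_end|>"""

sql_query = extract_tagged_section(sample_output, "answer")
reasoning = extract_tagged_section(sample_output, "reason")
print(sql_query)
```

Returning None when a tag pair is missing lets the caller detect a truncated or malformed generation, for example when max_new_tokens cuts the answer short, rather than executing a partial query.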
Designed and Developed with ♥ by Praneet | LinkedIn | GitHub