MedCRAFT-SFT-DPO
Introduction
MedCRAFT-SFT-DPO is a specialized medical Large Language Model (LLM) engineered to address critical challenges in complex medical question answering (QA). It is built on the novel MedCRAFT framework, which focuses on constructing high-quality, complex medical QA datasets, and is fine-tuned from the meta-llama/Llama-3.1-8B base model.
Real-world medical inquiries are often intricate, involving multi-dimensional information and diverse constraints. Effectively handling these multi-faceted constraints is crucial for delivering safe and accurate responses, a task where traditional medical LLMs often struggle due to limitations in training data complexity and quality.
The MedCRAFT framework tackles these challenges through three key methodologies (see the sketch after this list):
- Constraint-Driven Instruction Evolution: Utilizing a hierarchical taxonomy of medical constraints to systematically generate diverse and complex instructions.
- Dual-Verified Response Generation: Implementing a rigorous verification process to ensure the high quality and safety of generated responses.
- Constraint-Enhanced Instruction Expansion: Expanding the instruction set with additional relevant constraints to enhance the model's understanding and coverage of complex scenarios.
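To make the constraint-driven evolution step more concrete, here is a minimal sketch of how such a pipeline could be organized. The taxonomy categories, the constraint strings, and the evolve_instruction helper are illustrative assumptions for this card, not the actual MedCRAFT implementation.

```python
# Illustrative sketch only: the taxonomy and evolve_instruction below are
# hypothetical, not the actual MedCRAFT framework code.
import random

# A hierarchical taxonomy of medical constraints (hypothetical categories).
CONSTRAINT_TAXONOMY = {
    "patient_profile": ["age: 3-week-old neonate", "pregnancy: second trimester"],
    "clinical_context": ["symptom onset within 24 hours", "no access to imaging"],
    "response_format": ["answer in at most 5 bullet points", "list red-flag symptoms first"],
    "safety": ["recommend emergency care when red flags are present"],
}

def evolve_instruction(base_question: str, n_constraints: int = 2, seed: int = 0) -> str:
    """Augment a base question with constraints sampled from distinct taxonomy branches."""
    rng = random.Random(seed)
    categories = rng.sample(list(CONSTRAINT_TAXONOMY), k=n_constraints)
    constraints = [rng.choice(CONSTRAINT_TAXONOMY[c]) for c in categories]
    return base_question + "\nConstraints:\n" + "\n".join(f"- {c}" for c in constraints)

print(evolve_instruction("What could cause forceful vomiting in a newborn?"))
```

Sampling constraints across different taxonomy branches (rather than within one) is what makes the evolved instructions multi-faceted, mirroring the multi-dimensional inquiries described above.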
The MedCRAFT-SFT-DPO model has been further refined using Supervised Fine-Tuning (SFT) and Direct Preference Optimization (DPO) on the high-quality, constraint-enhanced datasets generated by the MedCRAFT framework. This training approach optimizes the model's ability to perform robust medical reasoning and provide precise, context-aware answers.
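As a rough illustration of the DPO stage, the sketch below uses the trl library's DPOTrainer on a preference dataset with prompt/chosen/rejected fields. The dataset file name, hyperparameters, and starting checkpoint are assumptions for illustration (a recent trl version is assumed); they are not the recipe used to train this model.

```python
# Hypothetical sketch of a DPO stage with trl; the dataset file and
# hyperparameters are illustrative, not the actual training recipe.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

base_id = "meta-llama/Llama-3.1-8B"  # in practice, the SFT checkpoint would be loaded here
tokenizer = AutoTokenizer.from_pretrained(base_id)
model = AutoModelForCausalLM.from_pretrained(base_id)

# Preference data: each row needs "prompt", "chosen", and "rejected" fields.
dataset = load_dataset("json", data_files="medcraft_preferences.jsonl", split="train")

config = DPOConfig(
    output_dir="medcraft-dpo",
    beta=0.1,                       # strength of the KL penalty toward the reference model
    per_device_train_batch_size=1,
    num_train_epochs=1,
)

trainer = DPOTrainer(
    model=model,
    args=config,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
```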
Usage
You can use MedCRAFT-SFT-DPO in a similar manner to other instruction-tuned LLMs (e.g., Llama-3.1-8B-Instruct).
```python
# Example usage
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "sherry0213/MedCRAFT-SFT-DPO"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # or torch.float16 depending on your hardware
    device_map="auto",           # or specify a device like "cuda"
)

# Example chat template (replace with the specific template/system prompt from our documentation)
messages = [
    {"role": "system", "content": "You are a helpful and accurate medical assistant. Provide concise and medically sound advice, considering all given constraints."},
    {"role": "user", "content": "My baby is 3 weeks old and started having forceful vomiting after feeding yesterday. He seems hungry again right after throwing up. What should I do?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(
    input_ids,
    max_new_tokens=2048,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    repetition_penalty=1.1,
    eos_token_id=tokenizer.eos_token_id,
)

# Decode only the newly generated tokens (everything after the prompt).
response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
print(response)
```
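If you need more reproducible outputs, for example when evaluating the model, you can switch to greedy decoding; the sampling parameters above then no longer apply:

```python
# Greedy decoding: deterministic, ignores temperature/top_p.
outputs = model.generate(input_ids, max_new_tokens=2048, do_sample=False)
```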