MedCRAFT-SFT-DPO

Introduction

MedCRAFT-SFT-DPO is a specialized medical large language model (LLM) designed to address the challenges of complex medical question answering (QA). It is fine-tuned from the meta-llama/Llama-3.1-8B base model using data produced by MedCRAFT, a framework for constructing high-quality, complex medical QA datasets.

Real-world medical inquiries are often intricate and involve multi-dimensional information and diverse constraints. Handling these constraints correctly is crucial for delivering safe and accurate responses. Traditional medical LLMs often struggle with this because their training data lacks sufficient complexity and quality.

The MedCRAFT framework tackles these challenges through three key methodologies:

  • Constraint-Driven Instruction Evolution: Utilizing a hierarchical taxonomy of medical constraints to systematically generate diverse and complex instructions.
  • Dual-Verified Response Generation: Implementing a rigorous verification process to ensure the high quality and safety of generated responses.
  • Constraint-Enhanced Instruction Expansion: Expanding the instruction set with additional relevant constraints to enhance the model's understanding and coverage of complex scenarios.
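To make the pipeline above more concrete, here is a toy sketch of two of its ideas: composing a complex instruction by drawing constraints from distinct categories of a hierarchical taxonomy, and keeping a response only when two independent verifiers both accept it. This is a hypothetical illustration, not the MedCRAFT implementation. The taxonomy contents, the names `CONSTRAINT_TAXONOMY`, `expand_instruction`, and `dual_verify`, and the toy verifiers are all assumptions made for this example.

```python
import random
from typing import Callable, Dict, List

# HYPOTHETICAL two-level taxonomy: category -> concrete constraints.
# The actual MedCRAFT taxonomy is not reproduced here.
CONSTRAINT_TAXONOMY: Dict[str, List[str]] = {
    "patient_profile": [
        "The patient is a 3-week-old infant.",
        "The patient is pregnant in the second trimester.",
    ],
    "clinical_context": [
        "Symptoms began within the last 24 hours.",
        "The patient has a known penicillin allergy.",
    ],
    "response_format": [
        "List red-flag symptoms that require urgent care.",
        "Answer in no more than five sentences.",
    ],
}


def expand_instruction(base_question: str, num_constraints: int, seed: int = 0) -> str:
    """Append one constraint from each of `num_constraints` distinct categories."""
    rng = random.Random(seed)  # seeded so the same inputs give the same instruction
    categories = rng.sample(sorted(CONSTRAINT_TAXONOMY), k=num_constraints)
    constraints = [rng.choice(CONSTRAINT_TAXONOMY[c]) for c in categories]
    lines = [base_question, "Constraints:"] + [f"- {c}" for c in constraints]
    return "\n".join(lines)


Verifier = Callable[[str, str], bool]


def dual_verify(instruction: str, response: str, verifier_a: Verifier, verifier_b: Verifier) -> bool:
    """Accept a response only if both independent verifiers accept it."""
    return verifier_a(instruction, response) and verifier_b(instruction, response)


# Toy stand-ins for real verifiers (e.g., a rule-based check and an LLM judge).
def is_non_empty(instruction: str, response: str) -> bool:
    return len(response.strip()) > 0


def mentions_doctor(instruction: str, response: str) -> bool:
    return "doctor" in response.lower()


instruction = expand_instruction("What should I do about sudden vomiting?", num_constraints=2, seed=42)
print(instruction)
print(dual_verify(instruction, "Please contact your doctor today.", is_non_empty, mentions_doctor))  # True
```

Sampling from distinct categories avoids pairing contradictory constraints from the same category (for example, two different patient profiles) in one instruction.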

MedCRAFT-SFT-DPO was further refined with Supervised Fine-Tuning (SFT) and Direct Preference Optimization (DPO) on the high-quality, constraint-enhanced datasets generated by the MedCRAFT framework. This training is designed to strengthen the model's medical reasoning and its ability to give precise, context-aware answers.
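For reference, the DPO stage optimizes a preference objective over pairs of chosen and rejected responses. The sketch below computes the standard per-example DPO loss (Rafailov et al., 2023) from sequence log-probabilities under the policy and a frozen reference model. It is not the authors' training code: the card does not state the value of `beta`, the reference model, or how log-probabilities were computed, so `beta=0.1` is only an illustrative default.

```python
import math


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,  # illustrative value; the card does not specify beta
) -> float:
    """Standard DPO loss: -log sigmoid(beta * (chosen log-ratio - rejected log-ratio))."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    z = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(z)) equals softplus(-z); branch on sign to avoid overflow in exp.
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))


# Policy prefers the chosen response more than the reference does: loss < log(2).
print(dpo_loss(-10.0, -12.0, -11.0, -11.0))
```

When the policy and reference assign the same log-ratios to both responses, the loss is exactly log(2); it falls below that as the policy shifts probability toward the preferred response.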

Usage

You can use MedCRAFT-SFT-DPO in a similar manner to other instruction-tuned LLMs (e.g., Llama-3.1-8B-Instruct).

```python
# Example usage (requires `transformers` and `torch`; device_map="auto" also needs `accelerate`)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "sherry0213/MedCRAFT-SFT-DPO"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # or torch.float16 if your hardware does not support bfloat16
    device_map="auto",  # or specify a device such as "cuda"
)

# Example conversation (replace with the specific template/system prompt from our documentation)
messages = [
    {"role": "system", "content": "You are a helpful and accurate medical assistant. Provide concise and medically sound advice, considering all given constraints."},
    {"role": "user", "content": "My baby is 3 weeks old and started having forceful vomiting after feeding yesterday. He seems hungry again right after throwing up. What should I do?"}
]

# Apply the chat template; return_dict=True also returns the attention mask
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=2048,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    repetition_penalty=1.1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,  # avoids a warning if the tokenizer defines no pad token
)

# Decode only the newly generated tokens, skipping the prompt
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(response)
```
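Because the model is trained on constraint-rich instructions, you may want to state constraints explicitly in the user turn. The helper below is a hypothetical convenience function, not part of this repository: `build_messages` and the "Constraints:" bullet layout are assumptions, and the card does not document a required prompt format. It produces the same list-of-dicts structure that `tokenizer.apply_chat_template` accepts in the example above.

```python
from typing import Dict, List, Sequence

# System prompt copied from the usage example above.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful and accurate medical assistant. Provide concise and medically "
    "sound advice, considering all given constraints."
)


def build_messages(
    question: str,
    constraints: Sequence[str] = (),
    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
) -> List[Dict[str, str]]:
    """HYPOTHETICAL helper: chat messages with optional constraints listed after the question."""
    user_content = question
    if constraints:
        user_content += "\n\nConstraints:\n" + "\n".join(f"- {c}" for c in constraints)
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]


messages = build_messages(
    "My baby is 3 weeks old and started having forceful vomiting after feeding yesterday.",
    constraints=[
        "Answer in plain language for a parent.",
        "State clearly when to seek emergency care.",
    ],
)
print(messages[1]["content"])
```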

Model Details

  • Format: Safetensors
  • Model size: 8.03B parameters
  • Tensor type: BF16