Fine-Tuning Mistral-7B on Clinical Notes with QLoRA

πŸš€ Project Overview

This project focuses on fine-tuning Mistral-7B on synthetic clinical notes using QLoRA (LoRA adapters trained on top of a 4-bit quantized base model). The goal is to make the model more useful for clinical NLP tasks, such as answering questions based on patient discharge summaries.

I used the "starmpcc/Asclepius-Synthetic-Clinical-Notes" dataset from Hugging Face. The training process was run on Kaggle T4 GPUs, using LoRA adapters to efficiently fine-tune specific parts of the model.


πŸ› οΈ Fine-Tuning Strategy

1️⃣ Using QLoRA for Efficient Training

  • The model was quantized to 4-bit using bitsandbytes, which helps reduce memory requirements.
  • LoRA adapters were applied to q_proj and v_proj layers to fine-tune the most critical attention components.
  • This approach allowed me to fine-tune a 7B-parameter model without exceeding Kaggle's GPU memory limits (see the configuration sketch below).
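
A minimal configuration sketch of this setup. It assumes the mistralai/Mistral-7B-v0.1 base checkpoint and typical LoRA rank/alpha/dropout values; those exact values are not reported in this card.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Load the base model in 4-bit so it fits in a T4's 16 GB of memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",          # assumed base checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)

# Attach LoRA adapters only to the query/value attention projections.
lora_config = LoraConfig(
    r=16,                                  # assumed rank
    lora_alpha=32,                         # assumed scaling factor
    lora_dropout=0.05,                     # assumed dropout
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()         # only the adapter weights are trainable
```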

2️⃣ Dataset & Preprocessing

  • The dataset consists of synthetic clinical notes where each example includes:

    • Discharge Summary: A short medical note.
    • Instruction: A question or instruction from a healthcare professional.
    • Response: The correct answer.
  • I structured the prompt to make the model follow a clear format:

    You are an intelligent clinical language model. Below is a snippet of a patient's discharge summary and a following instruction from a healthcare professional. Write a response that appropriately completes the instruction. The response should provide the accurate answer to the instruction, while being concise.
    
    [Discharge Summary Begin]
    {note}
    [Discharge Summary End]
    
    [Instruction Begin]
    {question}
    [Instruction End]
    
  • Tokenization was done with padding and truncation to a maximum length of 640 tokens to handle long medical notes (see the preprocessing sketch below).
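
A sketch of the prompt construction and tokenization described above. The field names "note", "question", and "answer" are placeholders; the Asclepius dataset's actual column names may differ.

```python
from transformers import AutoTokenizer

PROMPT_TEMPLATE = (
    "You are an intelligent clinical language model. Below is a snippet of a patient's "
    "discharge summary and a following instruction from a healthcare professional. "
    "Write a response that appropriately completes the instruction. The response should "
    "provide the accurate answer to the instruction, while being concise.\n\n"
    "[Discharge Summary Begin]\n{note}\n[Discharge Summary End]\n\n"
    "[Instruction Begin]\n{question}\n[Instruction End]\n"
)

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.pad_token = tokenizer.eos_token  # Mistral's tokenizer ships without a pad token

def preprocess(example):
    # Build the full training sequence: prompt followed by the reference response.
    prompt = PROMPT_TEMPLATE.format(note=example["note"], question=example["question"])
    full_text = prompt + example["answer"] + tokenizer.eos_token
    return tokenizer(
        full_text,
        padding="max_length",
        truncation=True,
        max_length=640,  # cap chosen to accommodate long discharge-summary snippets
    )
```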

3️⃣ Hyperparameters Used

| Parameter | Value |
|---|---|
| Learning Rate | 3e-5 |
| Optimizer | paged_adamw_32bit |
| Batch Size | 12 (train) / 4 (eval) |
| Epochs | 10 (planned; monitored with early stopping) |
| Gradient Accumulation | 2 |
| Weight Decay | 0.01 |
| Scheduler | Cosine with warmup |
| Warmup Ratio | 0.1 |
| Mixed Precision | FP16 |
| Gradient Checkpointing | Enabled |
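
A minimal sketch of how these values map onto Hugging Face TrainingArguments; the output directory, logging cadence, evaluation/save strategy, and load_best_model_at_end setting are assumptions.

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="medimistral-qlora",       # assumed output directory
    learning_rate=3e-5,
    optim="paged_adamw_32bit",
    per_device_train_batch_size=12,
    per_device_eval_batch_size=4,
    num_train_epochs=10,                  # upper bound; early stopping may end training sooner
    gradient_accumulation_steps=2,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    fp16=True,
    gradient_checkpointing=True,
    eval_strategy="epoch",                # named evaluation_strategy on older transformers versions
    save_strategy="epoch",
    load_best_model_at_end=True,          # needed to restore the best checkpoint under early stopping
    logging_steps=50,                     # assumed logging cadence
)
```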

πŸ“ˆ Training Results So Far

Current Progress (Checkpoints Saved: 2 Epochs)

Loss Values:

| Epoch | Training Loss | Validation Loss |
|---|---|---|
| 1 | 1.39 | 1.29 |
| 2 | 1.16 | 1.18 |
  • The model is improving steadily, with validation loss approaching the 1.0 mark, which is a healthy trend for LLM fine-tuning.

  • I plan to train for up to 3 more epochs, but I will stop early if validation loss stops improving, to avoid overfitting (see the early-stopping sketch below).
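
A sketch of how early stopping could be wired into the run. It reuses the model, training_args, and tokenizer from the sketches above; the dataset split names and the patience value are assumptions.

```python
from transformers import Trainer, EarlyStoppingCallback, DataCollatorForLanguageModeling

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,   # preprocessed splits (placeholder names)
    eval_dataset=tokenized_eval,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),  # causal-LM labels
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],  # stop after 2 stagnant evals
)
trainer.train()
```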


⏳ Training Time per Epoch

On Kaggle's T4x2 GPU setup, training time per epoch is approximately:

  • ~3 hours per epoch (varies slightly depending on GPU load).
  • At that rate, the full 10-epoch run would take ~30 hours, but I might shorten it based on loss trends.

Next Steps:

  • Continue training for 3 more epochs (or stop early if overfitting is detected).
  • Evaluate model outputs to check response quality and structure (see the inference sketch below).
  • Add better formatting rules in the prompt to ensure clear, structured responses.
  • Consider safety disclaimers to handle medical-related uncertainty in real-world applications.
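
A sketch of how the adapter could be loaded for qualitative evaluation, reusing PROMPT_TEMPLATE from the preprocessing sketch. It assumes this repo hosts the LoRA adapter weights; the example question is made up and the generation settings are assumptions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    ),
    device_map="auto",
)
model = PeftModel.from_pretrained(base, "Yuvrajxms09/MediMistral")  # assumes adapter weights live here
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

# Build a prompt with the same template used for training; the note is left as a placeholder.
prompt = PROMPT_TEMPLATE.format(
    note="...",
    question="What was the patient's discharge diagnosis?",
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128, do_sample=False)

# Print only the newly generated tokens, not the echoed prompt.
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```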

