Fine-Tuning Mistral-7B on Clinical Notes with QLoRA
Project Overview
This project focuses on fine-tuning Mistral-7B, a powerful language model, on synthetic clinical notes using QLoRA (4-bit quantization). The goal is to make the model more useful for clinical NLP tasks, such as answering questions based on patient discharge summaries.
I used the "starmpcc/Asclepius-Synthetic-Clinical-Notes" dataset from Hugging Face. The training process was run on Kaggle T4 GPUs, using LoRA adapters to efficiently fine-tune specific parts of the model.
Fine-Tuning Strategy
1. Using QLoRA for Efficient Training
- The model was quantized to 4-bit using `bitsandbytes`, which reduces memory requirements.
- LoRA adapters were applied to the `q_proj` and `v_proj` layers to fine-tune the most critical attention components.
- This approach allowed me to fine-tune a large model without exceeding Kaggle's GPU limits.
2. Dataset & Preprocessing
The dataset consists of synthetic clinical notes where each example includes:
- Discharge Summary: A short medical note.
- Instruction: A question or instruction from a healthcare professional.
- Response: The correct answer.
I structured the prompt to make the model follow a clear format:
```
You are an intelligent clinical language model. Below is a snippet of a patient's discharge summary and a following instruction from a healthcare professional. Write a response that appropriately completes the instruction. The response should provide the accurate answer to the instruction, while being concise.

[Discharge Summary Begin]
{note}
[Discharge Summary End]

[Instruction Begin]
{question}
[Instruction End]
```
Tokenization was done with padding and truncation to a max length of 640 tokens to handle long medical notes.
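A minimal sketch of this preprocessing step, assuming the dataset exposes `note`, `question`, and `answer` columns; the helper name and padding strategy are illustrative.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

PROMPT_TEMPLATE = (
    "You are an intelligent clinical language model. Below is a snippet of a "
    "patient's discharge summary and a following instruction from a healthcare "
    "professional. Write a response that appropriately completes the instruction. "
    "The response should provide the accurate answer to the instruction, while "
    "being concise.\n"
    "[Discharge Summary Begin]\n{note}\n[Discharge Summary End]\n"
    "[Instruction Begin]\n{question}\n[Instruction End]\n"
)

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.pad_token = tokenizer.eos_token  # Mistral has no pad token by default

def format_and_tokenize(example):
    # Prompt followed by the reference response, truncated/padded to 640 tokens
    text = PROMPT_TEMPLATE.format(note=example["note"], question=example["question"])
    text += example["answer"] + tokenizer.eos_token
    return tokenizer(text, truncation=True, max_length=640, padding="max_length")

dataset = load_dataset("starmpcc/Asclepius-Synthetic-Clinical-Notes", split="train")
tokenized = dataset.map(format_and_tokenize, remove_columns=dataset.column_names)
```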
3. Hyperparameters Used
| Parameter | Value |
|---|---|
| Learning Rate | 3e-5 |
| Optimizer | paged_adamw_32bit |
| Batch Size | 12 (train) / 4 (eval) |
| Epochs | 10 planned, with early stopping monitored |
| Gradient Accumulation Steps | 2 |
| Weight Decay | 0.01 |
| LR Scheduler | Cosine with warmup |
| Warmup Ratio | 0.1 |
| Mixed Precision | FP16 |
| Gradient Checkpointing | Enabled |
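The table above maps roughly to the following `transformers.TrainingArguments`; the output directory, evaluation/checkpoint strategies, and best-model settings are assumptions, and argument names may differ slightly across `transformers` versions.

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./medimistral-qlora",        # placeholder path
    learning_rate=3e-5,
    optim="paged_adamw_32bit",
    per_device_train_batch_size=12,
    per_device_eval_batch_size=4,
    num_train_epochs=10,                      # upper bound; early stopping may end sooner
    gradient_accumulation_steps=2,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    fp16=True,
    gradient_checkpointing=True,
    evaluation_strategy="epoch",              # evaluate and checkpoint once per epoch
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",        # needed for early stopping (see below)
    greater_is_better=False,
)
```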
Training Results So Far
Current Progress (Checkpoints Saved: 2 Epochs)
Loss Values:
| Epoch | Training Loss | Validation Loss |
|---|---|---|
| 1 | 1.39 | 1.29 |
| 2 | 1.16 | 1.18 |
The model is steadily improving: validation loss dropped from 1.29 to 1.18 between epochs 1 and 2, and training loss remains close to validation loss, so there is no sign of overfitting yet.
I plan to train for 3 more epochs, but will stop early if validation loss stops improving, to avoid overfitting.
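A sketch of how that early-stopping behaviour can be wired up with `EarlyStoppingCallback`, building on the model, arguments, and tokenized data from the sketches above; the split names and patience value are assumptions.

```python
from transformers import Trainer, EarlyStoppingCallback

# `model` and `training_args` come from the earlier sketches;
# `tokenized_train` / `tokenized_eval` are hypothetical split names.
trainer = Trainer(
    model=model,
    args=training_args,                        # requires load_best_model_at_end=True
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],  # patience is an assumption
)
trainer.train(resume_from_checkpoint=True)     # resume from the saved 2-epoch checkpoint
```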
Training Time per Epoch
On Kaggle's T4x2 GPU setup, training time per epoch is approximately:
- ~3 hours per epoch (varies slightly depending on GPU load).
- At that rate, the full planned run (10 epochs) would take ~30 hours, but I might adjust based on loss trends.
Next Steps:
- Continue training for 3 more epochs (or stop early if overfitting is detected).
- Evaluate model outputs to check response quality and structure (see the inference sketch below).
- Add better formatting rules in the prompt to ensure clear, structured responses.
- Consider safety disclaimers to handle medical-related uncertainty in real-world applications.
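A minimal inference sketch for that output-quality check, assuming this repo holds the LoRA adapters to be loaded on top of the base model; the prompt placeholders and generation settings are illustrative.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "mistralai/Mistral-7B-v0.1"
adapter_id = "Yuvrajxms09/MediMistral"   # this repo, assumed to hold the LoRA adapters

tokenizer = AutoTokenizer.from_pretrained(base_id)
base = AutoModelForCausalLM.from_pretrained(
    base_id, torch_dtype=torch.float16, device_map="auto"
)
model = PeftModel.from_pretrained(base, adapter_id)
model.eval()

# Same prompt format as training, with placeholder inputs
prompt = (
    "You are an intelligent clinical language model. Below is a snippet of a "
    "patient's discharge summary and a following instruction from a healthcare "
    "professional. Write a response that appropriately completes the instruction. "
    "The response should provide the accurate answer to the instruction, while "
    "being concise.\n"
    "[Discharge Summary Begin]\n<discharge summary here>\n[Discharge Summary End]\n"
    "[Instruction Begin]\n<instruction here>\n[Instruction End]\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(base.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```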
Model: `Yuvrajxms09/MediMistral` (base model: `mistralai/Mistral-7B-v0.1`)