# 🧠 LoRA Fine-Tuned Mistral-7B on MTS-Dialog

This repository contains a LoRA fine-tuned version of mistralai/Mistral-7B-v0.1 for medical dialogue summarization, trained on the MTS-Dialog dataset.


## 📘 Model Summary

- Base Model: mistralai/Mistral-7B-v0.1
- Fine-tuning Method: LoRA (Low-Rank Adaptation)
- Frameworks: 🤗 Transformers, PEFT, bitsandbytes
- Quantization: 4-bit
- Task: Medical dialogue summarization
- Dataset: MTS-Dialog

πŸ₯ Task Description

This model is trained to summarize doctor-patient conversations into concise clinical notes, organized by section headers such as GENHX (general history), HPI (history of present illness), and ROS (review of systems). These summaries can assist with EHR documentation and clinical decision-making.
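
A minimal sketch of how training prompts in this format can be built from the dataset. The local CSV path is hypothetical, and the column names (`section_header`, `dialogue`, `section_text`) are assumptions based on the public MTS-Dialog release, not taken from the original training script.

```python
import pandas as pd

# Hypothetical local path to an MTS-Dialog split; column names are assumed.
df = pd.read_csv("MTS-Dialog-TrainingSet.csv")

def build_prompt(row):
    # Same prompt template as shown in the "Example Prompt" section below.
    return (
        f"Summarize the following dialogue for section: {row['section_header']}\n"
        f"{row['dialogue']}\n"
        "Summary:"
    )

# Pair each prompt with its gold section text for supervised fine-tuning.
train_texts = [build_prompt(row) + " " + row["section_text"] for _, row in df.iterrows()]
print(train_texts[0][:300])
```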


## ⚙️ Training Configuration

| Parameter | Value |
|---|---|
| LoRA Rank | 4 |
| Epochs | 3 |
| Batch Size | 4 (×4 gradient accumulation) |
| Learning Rate | 3e-4 |
| Device | CUDA:0 |
| Quantization | 4-bit (bitsandbytes) |

⚠️ Due to limited GPU resources (office laptop), training was constrained to 3 epochs and a small LoRA rank. Performance is expected to improve significantly with extended training and better hardware.
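
A minimal sketch of how this configuration could be expressed with PEFT and 🤗 Transformers. The target modules, LoRA alpha/dropout, and output path are assumptions for illustration; the actual training script may differ.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model

# 4-bit quantization via bitsandbytes, matching the table above.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", quantization_config=bnb_config, device_map={"": 0}
)

# LoRA rank 4 as in the table; target modules are an assumed (common) choice for Mistral.
lora_config = LoraConfig(
    r=4, lora_alpha=16, lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Hyperparameters from the table: 3 epochs, batch size 4 with 4-step gradient accumulation, lr 3e-4.
training_args = TrainingArguments(
    output_dir="lora-mistral-mts-dialog",  # hypothetical output path
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=3e-4,
    fp16=True,
    logging_steps=10,
)
```

These arguments would typically be passed to a `Trainer` (or TRL's `SFTTrainer`) together with the tokenized MTS-Dialog prompts.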


## 📊 Evaluation Metrics

| Metric | Score |
|---|---|
| ROUGE-1 | 0.1318 |
| ROUGE-2 | 0.0456 |
| ROUGE-L | 0.0900 |
| BLEU | 0.0260 |
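
For reference, a minimal sketch of computing ROUGE and BLEU with the 🤗 `evaluate` library. The prediction and reference strings are placeholders, and this is an assumption about the evaluation setup, not necessarily the exact pipeline used for the scores above.

```python
import evaluate

rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

# Placeholder generations and gold section texts; in practice these would come
# from running the model over the MTS-Dialog validation split.
predictions = ["Patient reports chest pain for the last few days."]
references = ["The patient presents with chest pain that started a few days ago."]

rouge_scores = rouge.compute(predictions=predictions, references=references)
bleu_scores = bleu.compute(predictions=predictions, references=[[r] for r in references])

print(rouge_scores)  # rouge1, rouge2, rougeL, rougeLsum
print(bleu_scores)   # bleu plus n-gram precisions
```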

## 💡 Example Prompt

```
Summarize the following dialogue for section: GENHX
Doctor: What brings you back into the clinic today, miss?
Patient: I've had chest pain for the last few days.
Doctor: When did it start?
Summary:
```


## 🧪 Inference Code

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Load the 4-bit quantized base model, then attach the LoRA adapter.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", quantization_config=bnb_config, device_map="auto"
)
model = PeftModel.from_pretrained(model, "Imsachinsingh00/Fine_tuned_LoRA_Mistral_MTSDialog_Summarization")
model.eval()

tokenizer = AutoTokenizer.from_pretrained("Imsachinsingh00/Fine_tuned_LoRA_Mistral_MTSDialog_Summarization")

# Prompt in the same format used during fine-tuning.
prompt = "Summarize the following dialogue for section: HPI\nDoctor: Hello, what brings you in?\nPatient: I've been dizzy for two days.\nSummary:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=150)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
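
Optionally, the adapter can be merged into the base weights so that serving no longer requires PEFT. Merging into a 4-bit quantized model is not supported, so this sketch (an illustration, not part of the released code) reloads the base model in half precision first; the output directory name is hypothetical.

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Reload the base model in fp16, attach the adapter, then fold the LoRA weights in.
base = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16)
merged = PeftModel.from_pretrained(base, "Imsachinsingh00/Fine_tuned_LoRA_Mistral_MTSDialog_Summarization")
merged = merged.merge_and_unload()
merged.save_pretrained("mistral-7b-mts-dialog-merged")  # hypothetical output directory
```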


## 📁 Included Files

- `config.json` – PEFT configuration for LoRA  
- `adapter_model.bin` – LoRA adapter weights  
- `tokenizer/` – Tokenizer files  
- `README.md` – This model card  

## 📌 Notes

- 🚫 This is not a fully optimized clinical model; it is a proof of concept only.
- 💡 Consider training longer (`epochs=10`, `rank=8`) on GPUs with higher VRAM for better results.