Fine-Tuned Mistral-7B for Indian Law
A fine-tuned version of the Mistral 7B model, optimized for understanding and generating responses related to Indian law using Parameter-Efficient Fine-Tuning (PEFT) with QLoRA and LoRA techniques.
Model Details
- Task: Legal Text Understanding and Generation
- Fine-Tuning Dataset: Custom Indian Law Corpus (jizzu/llama2_indian_law_v3)
- Fine-Tuning Method: PEFT with QLoRA and LoRA
- Perplexity Score: 37.32
- Base Model Repository: mistralai/Mistral-7B-v0.1
Out-of-Scope Use
- May struggle with highly ambiguous legal queries or non-Indian legal systems.
- The perplexity score of 37.32 leaves room for improvement through longer training or additional data.
How to Get Started with the Model
Use the code below to get started with the model.
Install the dependencies with pip install transformers peft torch, using a CUDA-enabled build of PyTorch.
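import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
model_name = "ajay-drew/Mistral-7B-Indian-Law"  # fine-tuned LoRA adapter
base_model_name = "mistralai/Mistral-7B-v0.1"  # base model the adapter was trained on
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use half precision on the GPU; fall back to float32 on CPU
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, device_map="auto")
# Load fine-tuned weights with PEFT
model = PeftModel.from_pretrained(base_model, model_name)
model.eval()
text = "What is the penalty for using a forged document?"  # Ask custom questions on Indian Law
# Move the tokenized prompt to the same device as the model
inputs = tokenizer(text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))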
Metrics
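To check the perplexity of the model, run the code below after installing the dependencies with pip install transformers datasets torch peft bitsandbytes accelerate. A CUDA-enabled build of PyTorch is recommended: 4-bit loading through bitsandbytes expects a GPU, and evaluation is much faster there.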
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
import torch
dataset = load_dataset("kshitij230/Indian-Law", split="train")
# Load the model with 4-bit weights to reduce GPU memory use
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)
# PEFT adapter repository; with peft installed, transformers loads the base model named in the adapter config
model_name = "ajay-drew/Mistral-7B-Indian-Law"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto"
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
total_loss = 0
total_tokens = 0
test_texts = dataset['Instruction'][:500]  # first 500 instructions as the evaluation set
with torch.no_grad():
    for text in test_texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
        # Using the input ids as labels yields the mean next-token cross-entropy
        outputs = model(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss
        if loss is not None:  # Ensure loss is valid
            # Weight each example's mean loss by its length before averaging
            total_loss += loss.item() * inputs["input_ids"].size(1)
            total_tokens += inputs["input_ids"].size(1)
if total_tokens > 0:
    # Perplexity is the exponential of the average per-token loss
    perplexity = torch.exp(torch.tensor(total_loss / total_tokens)).item()
    print(f"Perplexity: {perplexity}")
    print(f"Total tokens: {total_tokens}")
    print(f"Total loss: {total_loss}")
else:
    print("Error: No tokens processed. Check dataset or tokenization.")
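The script above turns each example's mean cross-entropy loss back into a summed loss by multiplying by its token count, then exponentiates the overall average. A minimal stdlib-only sketch of that aggregation step follows; the per-example loss values are hypothetical placeholders, not outputs of this model.

```python
import math


def corpus_perplexity(examples):
    """Token-weighted perplexity from (mean_loss, num_tokens) pairs.

    mean_loss is a per-token cross-entropy in nats (what a causal LM's
    outputs.loss reports); num_tokens is the count it was averaged over.
    """
    total_loss = 0.0
    total_tokens = 0
    for mean_loss, num_tokens in examples:
        total_loss += mean_loss * num_tokens  # undo the per-example averaging
        total_tokens += num_tokens
    if total_tokens == 0:
        raise ValueError("no tokens processed")
    return math.exp(total_loss / total_tokens)


# Hypothetical per-example losses, NOT taken from the model
examples = [(3.2, 40), (3.9, 120), (3.5, 60)]
print(f"Perplexity: {corpus_perplexity(examples):.2f}")

# Going the other way: a perplexity of 37.32 corresponds to an
# average loss of ln(37.32) nats per token
print(f"Average loss for perplexity 37.32: {math.log(37.32):.3f} nats/token")
```

Weighting by token count keeps long examples from being under-counted; a plain mean of per-example perplexities would give a different, usually larger, number.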
Results
Hardware Used
- Hardware Type: NVIDIA GeForce RTX 4050 Laptop GPU
- Hours used: 24:19:47 (hh:mm:ss)
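NVIDIA specifies 6 GB of VRAM for the RTX 4050 Laptop GPU, which is why the 4-bit quantization in QLoRA matters here. A rough weights-only estimate follows, assuming the Mistral-7B-v0.1 parameter count of 7,241,732,096 and ignoring activations, LoRA weights, optimizer state, and the extra storage for 4-bit quantization constants, so real usage is higher.

```python
# Rough weights-only memory estimate for Mistral 7B at different precisions.
# Assumption: 7,241,732,096 parameters (Mistral-7B-v0.1); activations, LoRA
# weights, optimizer state and 4-bit quantization constants are ignored.
NUM_PARAMS = 7_241_732_096
GIB = 1024 ** 3

BYTES_PER_PARAM = {
    "float32": 4.0,
    "float16": 2.0,
    "4-bit": 0.5,
}


def weight_memory_gib(num_params, bytes_per_param):
    """Return the memory needed to hold the weights, in GiB."""
    return num_params * bytes_per_param / GIB


for precision, nbytes in BYTES_PER_PARAM.items():
    print(f"{precision:>8}: {weight_memory_gib(NUM_PARAMS, nbytes):6.2f} GiB")
```

Under these assumptions only the 4-bit weights fit within 6 GB with headroom left for activations and adapters, which is consistent with choosing QLoRA for this hardware.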
Model Architecture
- Base Model: Mistral 7B (a transformer-based language model with 7 billion parameters)
- Architecture:
- Decoder-only transformer with grouped-query attention and sliding-window attention.
- 32 layers, 4096 hidden size, and 32 attention heads sharing 8 key-value heads (inherited from Mistral 7B).
- Modified with Low-Rank Adaptation (LoRA) layers for efficient fine-tuning.
- Fine-Tuning Approach:
- PEFT: Parameter-Efficient Fine-Tuning to reduce memory footprint.
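- QLoRA: Quantized LoRA; the frozen base model is loaded with 4-bit quantization and LoRA adapters are trained on top of it, reducing GPU memory use.
- Parameters Fine-Tuned: only the LoRA adapter weights attached to specific weight matrices are trained; the base model weights stay frozen.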
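To make "parameter-efficient" concrete, the sketch below derives the base model's parameter count from the published Mistral 7B configuration (vocabulary 32,000, hidden size 4096, MLP size 14,336, 32 layers, 8 key-value heads of dimension 128) and compares it with the trainable LoRA parameters. The rank of 16 and the choice of the four attention projections as targets are hypothetical; this card does not state the actual adapter configuration.

```python
# Assumed Mistral-7B-v0.1 dimensions (from its published config)
VOCAB = 32_000
HIDDEN = 4_096
INTERMEDIATE = 14_336
LAYERS = 32
KV_DIM = 8 * 128  # 8 key-value heads x head dimension 128 (grouped-query attention)


def base_param_count():
    """Count base-model weights (Mistral's linear layers have no biases)."""
    attn = HIDDEN * HIDDEN * 2 + HIDDEN * KV_DIM * 2  # q_proj, o_proj, k_proj, v_proj
    mlp = HIDDEN * INTERMEDIATE * 3                   # gate_proj, up_proj, down_proj
    norms = HIDDEN * 2                                # two RMSNorm weight vectors per layer
    per_layer = attn + mlp + norms
    embeddings = VOCAB * HIDDEN * 2                   # input embeddings + untied lm_head
    return LAYERS * per_layer + embeddings + HIDDEN   # + final RMSNorm


def lora_param_count(rank, target_shapes):
    """LoRA on a (d_in -> d_out) matrix adds A (rank x d_in) and B (d_out x rank)."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in target_shapes)
    return LAYERS * per_layer


# Hypothetical adapter setup: rank 16 on the four attention projections
TARGETS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
}

base = base_param_count()
lora = lora_param_count(16, TARGETS.values())
print(f"Base parameters: {base:,}")
print(f"LoRA parameters: {lora:,}")
print(f"Trainable share: {100 * lora / base:.3f}%")
```

With these assumed settings the adapter trains roughly 0.2% of the parameter count, which is what lets the base weights stay frozen (and quantized) during fine-tuning.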
Software
- PyTorch: 2.6.0+cu126 (CUDA 12.6)
Model Card Contact
- Gmail: [email protected]
- GitHub: github.com/ajay-drew
- LinkedIn: linkedin.com/in/ajay-a-133b1326a/
Framework versions
- PEFT 0.14.0
Model tree for ajay-drew/Mistral-7B-Indian-Law
Base model
mistralai/Mistral-7B-v0.1