#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script to start a Mistral-7B fine-tuning job on Hugging Face AutoTrain
This uses the AutoTrain API to launch a fine-tuning job in the cloud
"""
import argparse
import os

from autotrain.trainers.clm.__main__ import train
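# Note: the `autotrain` package is provided by the autotrain-advanced
# distribution (pip install autotrain-advanced); the exact module layout
# may differ between releases.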
def parse_args():
    parser = argparse.ArgumentParser(description="Start a fine-tuning job on Hugging Face AutoTrain")
    parser.add_argument(
        "--project_name",
        type=str,
        default="ZamAI-Mistral-7B-Pashto",
        help="Name for your AutoTrain project",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mistral-7B-v0.1",
        help="Base model to fine-tune",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="tasal9/ZamAI_Pashto_Training",
        help="Dataset on the HF Hub to use for training",
    )
    parser.add_argument(
        "--text_column",
        type=str,
        default="text",
        help="Column name containing the formatted text",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=2e-4,
        help="Learning rate for training",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--lora_r",
        type=int,
        default=16,
        help="LoRA attention dimension (rank)",
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=32,
        help="LoRA alpha (scaling) parameter",
    )
    parser.add_argument(
        "--lora_dropout",
        type=float,
        default=0.05,
        help="LoRA attention dropout",
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        default=os.environ.get("HF_TOKEN"),
        help="Hugging Face API token (required; falls back to the HF_TOKEN environment variable)",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    if not args.hf_token:
        print("Error: a Hugging Face API token is required")
        print("Pass --hf_token or set HF_TOKEN; get a token from https://huggingface.co/settings/tokens")
        return

    print(f"Starting fine-tuning job for {args.model} on dataset {args.dataset}")

    # Define the training configuration
    config = {
        "model": args.model,
        "data_path": args.dataset,
        "text_column": args.text_column,
        "project_name": args.project_name,
        "token": args.hf_token,
        "lr": args.lr,
        "epochs": args.epochs,
        "push_to_hub": True,
        "repo_id": f"tasal9/{args.project_name}",
        "trainer": "sft",       # supervised fine-tuning
        "peft": True,           # train LoRA adapters instead of full weights
        "lora_r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "batch_size": 4,
        "block_size": 1024,     # token chunk length for training
        "logging_steps": 10,
        # Apply LoRA to all attention and MLP projection matrices
        "target_modules": "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
        "log": "wandb",         # change to "tensorboard" or "none" if preferred
    }
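
    # Note: some autotrain-advanced releases expect an LLMTrainingParams
    # object rather than a plain dict. If train(config) rejects the dict on
    # your version, wrapping it first should work (a sketch, assuming the
    # autotrain.trainers.clm.params module shipped with autotrain-advanced):
    #
    #   from autotrain.trainers.clm.params import LLMTrainingParams
    #   config = LLMTrainingParams(**config)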
    # Run the training job (blocks until training finishes)
    train(config)

    print("\n==== Training Job Complete ====")
    print(f"Project: {args.project_name}")
    print('Training metrics are logged to Weights & Biases (see the "log" setting above)')
    print(f"Your fine-tuned model will be pushed to: https://huggingface.co/tasal9/{args.project_name}")


if __name__ == "__main__":
    main()
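
# After training, the model lives at https://huggingface.co/tasal9/<project_name>.
# A minimal sketch of loading it for inference, assuming AutoTrain pushed a
# LoRA adapter (depending on version and settings it may instead push merged
# weights, in which case AutoModelForCausalLM.from_pretrained on the repo
# alone suffices):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from peft import PeftModel
#
#   base = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
#   model = PeftModel.from_pretrained(base, "tasal9/ZamAI-Mistral-7B-Pashto")
#   tokenizer = AutoTokenizer.from_pretrained("tasal9/ZamAI-Mistral-7B-Pashto")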