"""
Script to start a Mistral-7B fine-tuning job with Hugging Face AutoTrain.

This calls the AutoTrain (autotrain-advanced) CLM trainer directly, so the
job runs in the current environment and the fine-tuned model is pushed to
the Hugging Face Hub when training completes.
"""

import argparse
import os
import sys

from autotrain.trainers.clm.__main__ import train
|
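# Example usage (the filename start_autotrain.py is illustrative):
#   python start_autotrain.py \
#       --hf_token hf_your_token_here \
#       --project_name ZamAI-Mistral-7B-Pashto \
#       --epochs 3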
|
|
def parse_args():
    parser = argparse.ArgumentParser(
        description="Start a fine-tuning job on Hugging Face AutoTrain"
    )
    parser.add_argument(
        "--project_name",
        type=str,
        default="ZamAI-Mistral-7B-Pashto",
        help="Name for your AutoTrain project",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mistral-7B-v0.1",
        help="Base model to fine-tune",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="tasal9/ZamAI_Pashto_Training",
        help="Dataset on HF Hub to use for training",
    )
    parser.add_argument(
        "--text_column",
        type=str,
        default="text",
        help="Column name containing the formatted text",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=2e-4,
        help="Learning rate for training",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--lora_r",
        type=int,
        default=16,
        help="LoRA attention dimension",
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=32,
        help="LoRA alpha parameter",
    )
    parser.add_argument(
        "--lora_dropout",
        type=float,
        default=0.05,
        help="LoRA attention dropout",
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        default=os.environ.get("HF_TOKEN"),  # falls back to the standard HF_TOKEN env var
        help="Hugging Face API token (required; read from HF_TOKEN if not passed)",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    if not args.hf_token:
        print("Error: Hugging Face API token is required")
        print("Get your token from https://huggingface.co/settings/tokens")
        sys.exit(1)  # exit with a nonzero status so failures are visible to callers

    print(f"Starting fine-tuning job for {args.model} on dataset {args.dataset}")
|
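    # Training configuration consumed by the AutoTrain CLM trainer. The exact
    # set of accepted keys depends on the installed autotrain-advanced version.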
    config = {
        "model": args.model,
        "data_path": args.dataset,
        "text_column": args.text_column,
        "project_name": args.project_name,
        "token": args.hf_token,
        "lr": args.lr,
        "epochs": args.epochs,
        "push_to_hub": True,  # upload the fine-tuned model to the Hub
        "repo_id": f"tasal9/{args.project_name}",
        "trainer": "sft",  # supervised fine-tuning
        "peft": True,  # train LoRA adapters instead of full model weights
        "lora_r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "batch_size": 4,
        "block_size": 1024,  # maximum sequence length in tokens
        "logging_steps": 10,
        # Apply LoRA to all attention and MLP projection layers
        "target_modules": "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
        "log": "wandb",  # report training metrics to Weights & Biases
    }
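
    # Run the fine-tuning job; train() executes in this process and blocks
    # until training (and, with push_to_hub, the upload) completes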
    train(config)

    print("\n==== Training Job Finished ====")
    print(f"Project: {args.project_name}")
    print("You can review your training runs at: https://ui.autotrain.huggingface.co/projects")
    print(f"Your fine-tuned model was pushed to: https://huggingface.co/tasal9/{args.project_name}")


if __name__ == "__main__":
    main()