Fine tuning example(s)?
#2, opened by hackr
Hi,
Just wondering if there are any fine-tuning samples. I'm asking because I hit a snag updating my own template to use this model; I'm sure I'll resolve the issue.
Thanks.
Sure, here's a sample fine-tuning script that has been tested on a single MI250 GPU.
(We are preparing a proper GitHub repository with fine-tuning and reinforcement learning examples.)
# package requirements
# pip install transformers datasets loguru kernels tqdm
import argparse
import os
from functools import partial

import kernels
import torch
from datasets import load_dataset
from loguru import logger
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

optimizer_kernels = kernels.get_kernel("Motif-Technologies/optimizer")


def collate_fn(samples, tokenizer):
    # builds chat-formatted, padded samples and returns stacked input_ids and attention_mask tensors
    inp, attn_mask = [], []
    for x in samples:
        message = [
            {"role": "system", "content": "you are a helpful assistant"},
            {"role": "user", "content": x["input"]},
            {"role": "assistant", "content": x["output"]},
        ]
        chat = tokenizer.apply_chat_template(message, tokenize=False)
        single_batch = tokenizer(
            chat,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        inp.append(single_batch["input_ids"])
        attn_mask.append(single_batch["attention_mask"])
    return torch.concatenate(inp, dim=0), torch.concatenate(attn_mask, dim=0)


def main(args):
    # this demo uses 100 samples of the original data
    # downloading the dataset will consume about 2.5 GB of storage
    train_dataset = load_dataset(
        "nvidia/AceReason-1.1-SFT", split="train[:100]"
    )
    total_iters = len(train_dataset) // args.batchsize

    # loading the model
    # due to tied weights, you may see logs like the ones below (they can be ignored; a fix is in progress)
    # Some weights of MotifForCausalLM were not initialized from the model checkpoint at Motif-Technologies/Motif-2.6B and are newly initialized: ['lm_head.weight']
    # You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
    model = AutoModelForCausalLM.from_pretrained(
        "Motif-Technologies/Motif-2.6B",
        trust_remote_code=True,
        _attn_implementation="eager",  # also supports flash_attention_2, install it if interested
        torch_dtype="bfloat16",  # bfloat16 is used to fit a single-GPU budget, but you are free to use float32
    ).to("cuda")

    # loading the tokenizer
    # you may want to apply your own chat template here, for example
    # tokenizer.chat_template = "some_jinja_template"  (a short sketch follows after this script)
    tokenizer = AutoTokenizer.from_pretrained(
        "Motif-Technologies/Motif-2.6B",
        trust_remote_code=True,
    )

    # defining the dataloader, optimizer and scheduler
    dataloader = DataLoader(
        train_dataset,
        batch_size=args.batchsize,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
    )

    # you can use the Muon optimizer (thanks to @iamwyldecat), if interested
    if args.optimizer == "AdamW":
        optimizer = AdamW(model.parameters(), lr=args.lr)
    elif args.optimizer == "MuonWithAuxAdam":
        # is_muon_func selects the >=2-D weight matrices for Muon and excludes the embeddings and lm_head
        optimizer = optimizer_kernels.Muon(
            model=model,
            is_muon_func=lambda x, name: x.ndim >= 2
            and "embed_tokens" not in name
            and "lm_head" not in name,
            lr=args.lr,
            momentum=0.95,
            nesterov=True,
            ns_steps=5,
            weight_decay=0.01,
            adamw_betas=(0.9, 0.99),
            adamw_eps=1e-12,
        )
    else:
        raise ValueError(f"unsupported optimizer {args.optimizer}")

    # note: with LinearLR's defaults (start_factor=1/3, end_factor=1.0), the LR warms up
    # linearly from lr/3 to lr over total_iters steps
    lr_scheduler = LinearLR(optimizer=optimizer, total_iters=total_iters, last_epoch=-1)

    # train loop starts
    logger.info("=====TRAIN START=====")
    for epoch in range(args.epochs):
        for idx, batch in enumerate(dataloader):
            loss = model(
                input_ids=batch[0].cuda(),
                labels=batch[0].cuda(),
                attention_mask=batch[1].cuda(),
                # return_dict = True
            ).loss
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            logger.info(
                f"TRAIN | {epoch + 1}/{args.epochs} epochs | {idx + 1}/{total_iters} steps | loss: {loss.item()} | lr: {lr_scheduler.get_lr()[0]}"
            )

    # save the trained model & tokenizer
    curr_path = os.getcwd()
    save_path = os.path.join(curr_path, "exp01")
    os.makedirs(save_path, exist_ok=True)
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    logger.info("=====TRAIN COMPLETE=====")


if __name__ == "__main__":
    # define the argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", "-e", type=int, default=1)
    parser.add_argument("--batchsize", "-b", type=int, default=4)
    parser.add_argument("--lr", "-l", type=float, default=5e-5)
    parser.add_argument(
        "--optimizer",
        "-o",
        type=str,
        default="AdamW",
        choices=["AdamW", "MuonWithAuxAdam"],
    )
    args = parser.parse_args()
    main(args)
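If you want to swap in your own chat template (as mentioned in the question), you can assign a Jinja string to tokenizer.chat_template before building the dataloader. The snippet below is only a generic illustration: the <|role|> markers are made-up placeholders, not Motif's actual format, so adjust them to whatever your data and the model's tokenizer expect.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "Motif-Technologies/Motif-2.6B", trust_remote_code=True
)

# a generic, illustrative Jinja template; the <|role|> markers are placeholders,
# not the model's actual special tokens
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{{ '<|' ~ message['role'] ~ '|>\n' ~ message['content'] ~ '\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}"
)

# quick check that the template renders as expected
messages = [{"role": "user", "content": "Hello!"}]
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))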
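Once training finishes, you can sanity-check the result by loading the checkpoint back from the exp01 directory the script writes. This is a minimal sketch under that assumption; the prompt and generation settings are just placeholders.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# load the fine-tuned checkpoint saved by the training script above
tokenizer = AutoTokenizer.from_pretrained("./exp01", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "./exp01",
    trust_remote_code=True,
    torch_dtype="bfloat16",
).to("cuda")

# build a chat-formatted prompt, mirroring the template used during training
messages = [
    {"role": "system", "content": "you are a helpful assistant"},
    {"role": "user", "content": "Explain the Pythagorean theorem in one sentence."},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))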