Fine-tuning example(s)?

#2
by hackr - opened

Hi,

Just wondering whether there are any fine-tuning samples. I'm asking because I hit a snag updating my own template to use this model, though I'm sure I'll resolve the issue.

Thanks.

Motif Technologies org
•
edited 4 days ago

Sure, here is some sample fine-tuning code that has been tested on a single MI250 GPU.
(We are preparing a proper GitHub repository with fine-tuning and reinforcement learning examples.)

# package requirements
# pip install transformers datasets loguru kernels tqdm

import argparse
import os
from functools import partial

import kernels
import torch
from datasets import load_dataset
from loguru import logger
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Motif optimizer kernels, loaded through the `kernels` package. The training
# script uses their `Muon` class when `--optimizer MuonWithAuxAdam` is chosen.
# NOTE: this is a module-level statement, so it runs at import time regardless
# of which optimizer is selected on the command line.
_OPTIMIZER_KERNEL_REPO = "Motif-Technologies/optimizer"
optimizer_kernels = kernels.get_kernel(_OPTIMIZER_KERNEL_REPO)


def collate_fn(
    samples,
    tokenizer,
    max_length=512,
    system_prompt="you are an helpful assistant",
):
    """Render ``{"input", "output"}`` records as chats and tokenize them as a batch.

    Each record becomes a system / user / assistant conversation, rendered to
    text with the tokenizer's chat template. The whole batch is then tokenized
    in a single call, truncated and padded to exactly ``max_length`` tokens.

    Args:
        samples: sequence of mappings with ``"input"`` (user turn) and
            ``"output"`` (assistant turn) keys.
        tokenizer: tokenizer exposing ``apply_chat_template`` and ``__call__``
            (e.g. a Hugging Face ``AutoTokenizer``).
        max_length: fixed sequence length every row is truncated/padded to.
        system_prompt: content of the system turn prepended to every chat.

    Returns:
        Tuple ``(input_ids, attention_mask)`` of tensors, one row per sample.
    """
    chats = [
        tokenizer.apply_chat_template(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": sample["input"]},
                {"role": "assistant", "content": sample["output"]},
            ],
            tokenize=False,
        )
        for sample in samples
    ]
    # Tokenizing the batch at once gives the same result as per-sample calls
    # here, because padding="max_length" pads every row to the same length.
    encoded = tokenizer(
        chats,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return encoded["input_ids"], encoded["attention_mask"]


def _build_optimizer(model, args):
    """Create the optimizer named by ``args.optimizer`` for ``model``.

    ``"AdamW"`` uses ``torch.optim.AdamW``. ``"MuonWithAuxAdam"`` uses the
    ``Muon`` class from the Motif optimizer kernels, routing parameters with
    ``ndim >= 2`` outside ``embed_tokens`` / ``lm_head`` to Muon. Judging by the
    ``adamw_*`` arguments, the remaining parameters presumably use AdamW.

    Raises:
        ValueError: if ``args.optimizer`` is not one of the names above.
    """
    if args.optimizer == "AdamW":
        return AdamW(model.parameters(), lr=args.lr)
    if args.optimizer == "MuonWithAuxAdam":
        # You can use the Muon optimizer (thanks to @iamwyldecat), if interested.
        return optimizer_kernels.Muon(
            model=model,
            is_muon_func=lambda x, name: x.ndim >= 2
            and "embed_tokens" not in name
            and "lm_head" not in name,
            lr=args.lr,
            momentum=0.95,
            nesterov=True,
            ns_steps=5,
            weight_decay=0.01,
            adamw_betas=(0.9, 0.99),
            adamw_eps=1e-12,
        )
    raise ValueError(f"not supported optimizer {args.optimizer}")


def main(args):
    """Fine-tune Motif-2.6B on a small slice of AceReason-1.1-SFT and save it.

    Args:
        args: parsed CLI namespace providing ``epochs``, ``batchsize``, ``lr``
            and ``optimizer``.

    The trained model and tokenizer are written to ``exp01`` under the current
    working directory. An existing directory is reused.
    """
    # This demo uses the first 100 samples of the original data.
    # Downloading the dataset consumes about 2.5 GB of storage.
    train_dataset = load_dataset(
        "nvidia/AceReason-1.1-SFT", split="train[:100]"
    )

    # Loading the model.
    # Due to tied weights, you may see logs like the ones below (they can be
    # ignored; a fix is in progress):
    #    Some weights of MotifForCausalLM were not initialized from the model checkpoint at Motif-Technologies/Motif-2.6B and are newly initialized: ['lm_head.weight']
    #    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
    model = AutoModelForCausalLM.from_pretrained(
        "Motif-Technologies/Motif-2.6B",
        trust_remote_code=True,
        _attn_implementation="eager",  # also supports flash_attention_2, install if interested
        torch_dtype="bfloat16",  # bfloat16 fits a 1-GPU budget; float32 also works
    ).to("cuda")
    # from_pretrained returns the model in eval mode; switch to train mode so
    # training-only behavior (e.g. dropout) is active while fine-tuning.
    model.train()

    # Loading the tokenizer. You may want to apply your own chat template here,
    # for example:
    # tokenizer.chat_template = "some_jinja_template"
    tokenizer = AutoTokenizer.from_pretrained(
        "Motif-Technologies/Motif-2.6B",
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        # collate_fn pads with padding="max_length", which needs a pad token.
        tokenizer.pad_token = tokenizer.eos_token

    # Defining the dataloader, optimizer and scheduler.
    dataloader = DataLoader(
        train_dataset,
        batch_size=args.batchsize,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
    )
    # Steps per epoch. len(dataloader) also counts a final partial batch,
    # which len(train_dataset) // batchsize would drop.
    total_iters = len(dataloader)

    optimizer = _build_optimizer(model, args)
    lr_scheduler = LinearLR(optimizer=optimizer, total_iters=total_iters, last_epoch=-1)

    # Train loop.
    logger.info("=====TRAIN START=====")
    for epoch in range(args.epochs):
        for idx, (input_ids, attention_mask) in enumerate(dataloader):
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            # Exclude padding positions from the loss. -100 is the ignore index
            # of Hugging Face causal-LM losses; this assumes MotifForCausalLM
            # follows that convention (TODO confirm in the remote model code).
            labels = input_ids.masked_fill(attention_mask == 0, -100)
            loss = model(
                input_ids=input_ids,
                labels=labels,
                attention_mask=attention_mask,
            ).loss

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            # get_last_lr() reports the rate the scheduler just set; get_lr()
            # is an internal hook and is not meant to be called directly.
            logger.info(
                f"TRAIN | {epoch + 1}/{args.epochs} epochs | {idx + 1}/{total_iters} steps | loss: {loss.item()} | lr: {lr_scheduler.get_last_lr()[0]}"
            )

    # Save the trained model and tokenizer (reruns reuse the directory).
    save_path = os.path.join(os.getcwd(), "exp01")
    os.makedirs(save_path, exist_ok=True)
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    logger.info("=====TRAIN COMPLETE=====")


if __name__ == "__main__":
    # Command-line interface for the fine-tuning demo.
    parser = argparse.ArgumentParser(
        description="Fine-tune Motif-2.6B on a small sample of an SFT dataset."
    )
    parser.add_argument(
        "--epochs", "-e", type=int, default=1,
        help="number of passes over the training data",
    )
    parser.add_argument(
        "--batchsize", "-b", type=int, default=4,
        help="number of samples per optimizer step",
    )
    parser.add_argument(
        "--lr", "-l", type=float, default=5e-5,
        help="base learning rate passed to the optimizer",
    )
    parser.add_argument(
        "--optimizer",
        "-o",
        type=str,
        default="AdamW",
        choices=["AdamW", "MuonWithAuxAdam"],
        help="optimizer used for training",
    )
    args = parser.parse_args()
    # Fail early with a usage message instead of an obscure error inside main.
    if args.batchsize < 1:
        parser.error("--batchsize must be a positive integer")

    main(args)

Sign up or log in to comment