Training code
Hi,
Congrats on training Mistral with EasyLM. We were looking into doing exactly the same, but only found the model implementation. Are you planning to release the training code you used?
Cheers.
Hey, sure. I basically forked an EasyLM fork that already contained a GQA implementation and added a Mistral config. It's here: https://github.com/defdet/EasyLM.
I have to warn you though: while the wandb run looked fine (the model converged to a good loss value), the model's outputs are really not what they should be (whereas a fine-tuned 3B LLaMA's outputs were fine).
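For reference, the Mistral-specific part is mostly just the model shape (grouped-query attention with 8 KV heads and a wider MLP than LLaMA 7B). A rough sketch of the kind of config entry involved, in EasyLM's LLaMA-config style; this is not copied verbatim from the repo, so the exact field names may differ:

```python
# Sketch of a Mistral 7B config entry in EasyLM's LLaMA-config style.
# Values are the public Mistral 7B shape; field names are assumptions.
MISTRAL_7B_CONFIG = {
    'vocab_size': 32000,
    'hidden_size': 4096,
    'intermediate_size': 14336,   # wider MLP than LLaMA 7B's 11008
    'num_hidden_layers': 32,
    'num_attention_heads': 32,
    'num_key_value_heads': 8,     # GQA: 8 KV heads shared by 32 query heads
    'max_sequence_length': 4096,  # training sequence length used below
    'rms_norm_eps': 1e-5,
    'rope_theta': 10000.0,
}
```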
The command to run the training:
cd EasyLM && export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE SEQ_LEN=4096'; python3 -m EasyLM.models.llama.llama_train \
    --mesh_dim='1,1,8' \
    --dtype='bf16' \
    --load_llama_config='mistral' \
    --total_steps=15_000 \
    --load_checkpoint='params::easylm_mistral' \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='[prompt],answer' \
    --train_dataset.json_dataset.seq_length=4096 \
    --train_dataset.json_dataset.batch_size=1 \
    --eval_dataset.text_processor.fields='[prompt],answer' \
    --eval_dataset.json_dataset.seq_length=4096 \
    --train_dataset.json_dataset.start_seek_loc=0 \
    --eval_dataset.json_dataset.batch_size=1 \
    --eval_dataset.type='json' \
    --update_llama_config='{"resid_pdrop": 0.1, "embd_pdrop": 0.1, "attn_pdrop": 0.1, "fcm_max_ratio": 0.1}' \
    --train_dataset.json_dataset.path='mistral-v2-format-train.json' \
    --eval_dataset.json_dataset.path='mistral-v2-format-test.json' \
    --optimizer.type=adamw \
    --optimizer.adamw_optimizer.lr=1e-6 \
    --optimizer.adamw_optimizer.end_lr=1e-6 \
    --tokenizer.vocab_file=tokenizer.model \
    --log_freq=1_000 \
    --logger.online=True \
    --checkpointer.save_optimizer_state=False \
    --eval_steps=500 \
    --save_model_freq=5_000 \
    --logger.output_dir="mistral-saiga"
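About the data: with --train_dataset.type='json' and --train_dataset.text_processor.fields='[prompt],answer', EasyLM reads one JSON object per line with those keys; as far as I understand the text processor, the bracketed field (prompt) is masked out of the loss and only answer is trained on. A minimal sketch of producing such a file (the records here are made up, the file name just matches the path in the command):

```python
import json

# Toy records in the shape the command above expects: one JSON object per line,
# with a "prompt" key (loss-masked, per the square brackets in the fields spec)
# and an "answer" key (trained on).
records = [
    {"prompt": "Translate to French: Hello, world.", "answer": "Bonjour, le monde."},
    {"prompt": "What is 2 + 2?", "answer": "4"},
]

with open("mistral-v2-format-train.json", "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
```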
I see. Thanks!