---
base_model:
- mistralai/Mistral-Nemo-Base-2407
license: apache-2.0
tags:
- writing
- creative-writing
---

# Koto 22B (Pretrained)

![image/png](https://cdn-uploads.huggingface.co/production/uploads/634262af8d8089ebaefd410e/cnBQlWjMTKGLOKMudPBVj.png)

Koto-22B-PT is a [depth-upscaled](https://arxiv.org/abs/2312.15166) version of Mistral-Nemo-Base-2407, healed and trained on almost a billion tokens of creative writing data.

## Usage

This model is not intended for use outside of raw text completion settings, such as cowriting. Instruct will *not* work. Multi-turn roleplay will *not* work.

It was trained at 32k context, but since not all samples were that long, expect roughly 16k of effective context in the best case.

We found that a temperature of 1.5-1.55 and a min_p of 0.05-0.1 worked best, but YMMV! (See the sampling sketch in the Technical Appendix below.)

## Datasets

Some of the data used to train this model includes:

- Most of [The Anarchist Library](https://theanarchistlibrary.org/), a repository for anarchist manifestos and writing (see [allura-org/the-anarchist-library](https://huggingface.co/datasets/allura-org/the-anarchist-library))
- A random sample of public domain books from Project Gutenberg
- Furry (anthro and feral) storytelling and smut
- A small subset of known high-quality books and story data

## Acknowledgements

- Thank you to [@takeshimaxfj](https://x.com/takeshimaxfj) on Twitter for drawing the art used in the model card!
- Thank you very much to [mango/deltavector](https://huggingface.co/Delta-Vector) for providing the compute used to train this model
- Thanks to curse for testing and ideas
- Thanks to toasty for some data and ideas
- Thanks to everyone else in allura for moral support, ilya <3

## Call for Help

If you would like to help build on this model (instruct/RP SFT, further annealing on higher-quality data, etc.), please join [our Discord](https://discord.gg/PPBMhF2vgC) or [our Matrix](https://matrix.to/#/#allura:allura.moe)! <3

## Technical Appendix
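
### Sampling Example

As a quick reference, here is a minimal sketch of raw completion with the sampler settings above. It assumes a recent `transformers` release (one with `min_p` sampling support); the repo id and prompt are placeholders, so adjust both to taste.

```python
# Minimal sketch: raw text completion with the recommended sampler settings.
# Assumes transformers >= 4.39 (min_p support) and accelerate for device_map.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "allura-org/Koto-22B-PT"  # placeholder repo id

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Raw prose prompt: no chat template, no instruct formatting.
prompt = "The lighthouse keeper had not spoken to another soul in three winters, and yet"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

out = model.generate(
    **inputs,
    do_sample=True,
    temperature=1.5,   # 1.5-1.55 recommended
    min_p=0.05,        # 0.05-0.1 recommended
    max_new_tokens=256,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```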
### Training Notes

This model was trained over the course of ~14 hours on an 8xB200 node. We used 8-bit AdamW and the REX LR scheduler, as well as both gradient clipping and weight decay for regularization.

There *was* a very odd loss spike ~60% of the way through training, but it recovered and the model seems fine? So? Eh? If it works it works :3

### WandB

![image/png](https://cdn-uploads.huggingface.co/production/uploads/634262af8d8089ebaefd410e/6XFFhkQD8lUFGerBrOAyd.png)

### Finetuning Notes

ChatML tokens have already been added to this model, in case you prefer to tune with that chat format. Please do not re-add them, so the vocab size stays the same for (possible) usage on places like Featherless.

### Axolotl Config

```yaml
## model
base_model: allura-forge/nemo-upscaled-2
#tokenizer_use_mistral_common: true

## qlora COPE!!!
load_in_8bit: false
load_in_4bit: false
strict: false

## data
datasets:
  - path: estrogen/bookscpt2
    type: completion
    field: text
shuffle_merged_datasets: true
dataset_prepared_path: dataset_preparedss
val_set_size: 0.0
output_dir: ./Pretrain

## Liger + CCE
plugins:
  - axolotl.integrations.liger.LigerPlugin
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
liger_rope: true
liger_rms_norm: true
liger_layer_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: false
cut_cross_entropy: true

## CTX settings
sequence_len: 32768
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true

## max grad norm
max_grad_norm: 1.0

## WandB
wandb_project: NeMo-Upscale
wandb_entity:
wandb_watch:
wandb_name: Pretrain-22B
wandb_log_model:

## hoe params
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1

optimizer: adamw_bnb_8bit
lr_scheduler: rex
learning_rate: 2e-5

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 50
saves_per_epoch: 2
debug:
deepspeed: ./deepspeed_configs/zero3_bf16.json
weight_decay: 0.0025
fsdp:
fsdp_config:

special_tokens:
  pad_token:
```

### Mergekit Config

```yaml
dtype: bfloat16
merge_method: passthrough
slices:
  # untouched intro
  - sources:
      - layer_range: [0, 8]
        model: mistralai/Mistral-Nemo-Base-2407
  - sources:
      - layer_range: [8, 12]
        model: mistralai/Mistral-Nemo-Base-2407
  # 8–16 baseline
  - sources:
      - layer_range: [8, 16]
        model: mistralai/Mistral-Nemo-Base-2407
  # 8–16 duplicate with projections nulled
  - sources:
      - layer_range: [8, 16]
        model: mistralai/Mistral-Nemo-Base-2407
        parameters:
          scale:
            - filter: o_proj
              value: 0.0
            - filter: down_proj
              value: 0.0
            - value: 1.0
  # 16–24 duplicate
  - sources:
      - layer_range: [16, 24]
        model: mistralai/Mistral-Nemo-Base-2407
        parameters:
          scale:
            - filter: o_proj
              value: 0.0
            - filter: down_proj
              value: 0.0
            - value: 1.0
  # 16–24 baseline
  - sources:
      - layer_range: [16, 24]
        model: mistralai/Mistral-Nemo-Base-2407
  # 16–24 duplicate
  - sources:
      - layer_range: [16, 24]
        model: mistralai/Mistral-Nemo-Base-2407
        parameters:
          scale:
            - filter: o_proj
              value: 0.0
            - filter: down_proj
              value: 0.0
            - value: 1.0
  # 24–32 baseline
  - sources:
      - layer_range: [24, 32]
        model: mistralai/Mistral-Nemo-Base-2407
  # 24–32 duplicate
  - sources:
      - layer_range: [24, 32]
        model: mistralai/Mistral-Nemo-Base-2407
        parameters:
          scale:
            - filter: o_proj
              value: 0.0
            - filter: down_proj
              value: 0.0
            - value: 1.0
  # untouched tail
  - sources:
      - layer_range: [32, 40]
        model: mistralai/Mistral-Nemo-Base-2407
```
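
If you want to reproduce the upscale (before healing), the config above should run through stock mergekit. Below is a hedged sketch using mergekit's Python entry point as documented in its README (`MergeConfiguration`, `run_merge`, `MergeOptions`); the config filename and output path are placeholders, and option names may shift between mergekit versions.

```python
# Sketch: run the depth-upscale config above with mergekit's Python API.
# Assumes the YAML is saved as koto-upscale.yaml next to this script;
# double-check option names against your installed mergekit version.
import torch
import yaml
from mergekit.config import MergeConfiguration
from mergekit.merge import MergeOptions, run_merge

with open("koto-upscale.yaml", "r", encoding="utf-8") as fp:
    merge_config = MergeConfiguration.model_validate(yaml.safe_load(fp))

run_merge(
    merge_config,
    out_path="./nemo-upscaled",           # output directory for the upscaled model
    options=MergeOptions(
        cuda=torch.cuda.is_available(),   # use a GPU if one is present
        copy_tokenizer=True,              # carry the Nemo tokenizer over unchanged
        lazy_unpickle=False,
        low_cpu_memory=False,
    ),
)
```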