erfanzar committed · Commit 45b1af2 (verified) · 1 Parent(s): 8754a18

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ easydel-model.parameters filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,129 @@
+ ---
+ tags:
+ - EasyDeL
+ - llama
+ - CausalLM
+ - splash
+ - safetensors
+ - Flax
+ - JAX
+
+ - TPU
+
+ ---
+ <p align="center">
+ <a href="https://github.com/erfanzar/EasyDeL">
+ <img src="https://raw.githubusercontent.com/erfanzar/easydel/main/images/easydel-logo-with-text.png" height="80">
+ </a>
+ </p>
+ <p align="center">
+ <a href="https://github.com/erfanzar/EasyDeL">
+ <img src="https://img.shields.io/badge/🤗_EasyDeL-v0.1.5-blue.svg" />
+ </a>
+ <a href="https://github.com/erfanzar/EasyDeL">
+ <img src="https://img.shields.io/badge/Model_Arch-llama-green.svg" />
+ </a>
+ </p>
+
+ # Training Run: marin-8b-instruct-orpo
+
+ This document outlines the configuration and parameters used for training the model `marin-8b-instruct-orpo` using the [EasyDeL](https://github.com/erfanzar/EasyDeL) library.
+
+ EasyDeL is an open-source framework designed to enhance and streamline the training process of machine learning models, with a primary focus on JAX/Flax for TPU/GPU environments.
+
+ ## How to Load This Checkpoint
+
+ You can load the checkpoint generated from this training run using EasyDeL as follows:
+
+ ```python
+ import easydel as ed
+ from jax import numpy as jnp, lax
+
+ # Path to the directory where this README.md is located
+ repo_id = "user/model-id"  # <-- TODO: Update this path with the actual save directory or model repo
+
+ model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
+     repo_id,
+     config_kwargs=ed.EasyDeLBaseConfigDict(
+         # use_scan_mlp=False,  # Set to True to potentially reduce memory usage
+         attn_dtype=jnp.float16,  # Or jnp.bfloat16
+         # freq_max_position_embeddings=max_length,  # Set if using RoPE and need truncation
+         # mask_max_position_embeddings=max_length,  # Set if max length is defined
+         attn_mechanism=ed.AttentionMechanisms.SPLASH,  # Matches the mechanism used by this model
+     ),
+     dtype=jnp.float16,  # Or jnp.bfloat16 - Computation data type
+     param_dtype=jnp.float16,  # Or jnp.bfloat16 - Parameter data type
+     precision=lax.Precision("fastest"),  # Like "default", "fastest", "high", "highest"
+     auto_shard_model=True,  # Auto-shard across available devices
+ )
+ ```
+ *Note: Replace `repo_id` with the actual path to the saved checkpoint directory or model repo.*
+ *The returned `model` already holds the loaded parameters.*
+
+ ## Training Configuration Summary
+
+ ### Model & Hardware
+
+ - **Model Name (Run Name)**: `marin-8b-instruct-orpo`
+ - **Base Model Architecture**: `llama`
+ - **Platform**: `TPU`
+ - **Number of Devices Used**: `4` (total), `4` (local)
+ - **EasyDeL Version**: `v0.1.5`
+
+ ### Key Training Parameters
+
+ - **Learning Rate (Start → End)**: `8e-07` → `Not Set`
+ - **Optimizer**: `EasyDeLOptimizers.ADAMW`
+ - **Scheduler**: `EasyDeLSchedulers.COSINE`
+ - **Warmup Steps**: `0`
+ - **Weight Decay**: `0.01`
+ - **Loss Configuration**: `LossConfig(
+   ignore_index : -100
+   label_smoothing : 0.0
+   z_loss : 0.0
+   loss_normalizing_factor : SpecialLossNormalizingFactor.NO_WEIGHT_NUM_REAL_TARGET_TOKENS
+   num_labels : None
+   problem_type : None
+   divide_weight_sum : False
+   shift_tokens : True
+   break_on_nan : True
+   reduction : None
+   num_classification_labels : None
+   classification_problem_type : None
+ )`
+
+ ### Data & Batching
+
+ - **Number of Training Epochs**: `8`
+ - **Total Batch Size (per step)**: `4`
+ - **Maximum Sequence Length**: `4096`
+ - **Gradient Accumulation Steps**: `1`
+
+ ### Datatypes & Precision
+
+ - **Computation `dtype`**: `<class 'jax.numpy.bfloat16'>`
+ - **Parameter `param_dtype`**: `<class 'jax.numpy.bfloat16'>`
+ - **Gradient Checkpointing Method**: `EasyDeLGradientCheckPointers.NOTHING_SAVEABLE`
+ - **Attention Mechanism Used in Training**: `splash` (can be loaded as `AttentionMechanisms.SPLASH` if using `EasyDeLConfig`)
+
+ ### Run Control
+
+ - **Max Training Steps**: `Not Set`
+ - **Max Evaluation Steps**: `Not Set`
+ - **Training Time Limit**: `Not Set`
+
+ ## Citation
+
+ If you use EasyDeL in your research or work, please cite it:
+
+ ```bibtex
+ @misc{ZareChavoshi_2023,
+   title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
+   url={https://github.com/erfanzar/EasyDeL},
+   author={Zare Chavoshi, Erfan},
+   year={2023}
+ }
+ ```
+
+ ---
+ *This document was automatically generated by EasyDeL v0.1.5 during the training run.*
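Reviewer note on the README's "Key Training Parameters": the run uses a cosine scheduler with a start learning rate of `8e-07`, zero warmup steps, and no end value. The sketch below shows what such a curve looks like in plain Python. It assumes the schedule decays to `0.0` when no end value is set, and `total_steps` is a hypothetical number because the commit does not record the step count. EasyDeL's own scheduler may differ in detail.

```python
import math


def cosine_lr(step: int, total_steps: int, lr_start: float = 8e-7, lr_end: float = 0.0) -> float:
    """Cosine decay from lr_start to lr_end over total_steps, with no warmup.

    lr_end=0.0 is an assumption; the training arguments leave it unset.
    """
    # Clamp so that steps outside [0, total_steps] stay on the curve's endpoints.
    progress = min(max(step, 0), total_steps) / total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr_end + (lr_start - lr_end) * cosine


if __name__ == "__main__":
    total_steps = 1000  # hypothetical; not taken from the training run
    for step in (0, 250, 500, 750, 1000):
        print(step, f"{cosine_lr(step, total_steps):.3e}")
```

With zero warmup the first optimizer step already runs at the full `8e-07`, and the rate reaches half of that at the midpoint of the schedule.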
config.json ADDED
@@ -0,0 +1,125 @@
+ {
+   "architectures": [
+     "LlamaForCausalLM"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "attn_mechanism": "splash",
+   "backend": null,
+   "begin_suppress_tokens": [
+     128000,
+     128001
+   ],
+   "bits": null,
+   "blocksize_b": 1,
+   "blocksize_k": 128,
+   "blocksize_q": 128,
+   "bos_token_id": 128000,
+   "decode_attn_mechanism": null,
+   "decoder_start_token_id": 128000,
+   "easy_method": "train",
+   "embd_pdrop": 0.0,
+   "eos_token_id": 128009,
+   "fcm_max_ratio": -1,
+   "fcm_min_ratio": -1,
+   "flash_attention_backward_pass_impl": "triton",
+   "freq_max_position_embeddings": 2048,
+   "gradient_checkpointing": "nothing_saveable",
+   "hardware_abstraction": false,
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "hidden_size": 4096,
+   "initializer_range": 0.02,
+   "intermediate_size": 14336,
+   "kv_cache_quantization_blocksize": 64,
+   "kv_cache_quantization_method": "None",
+   "kv_cache_sharding_sequence_axis_name": "sp",
+   "mask_max_position_embeddings": 2048,
+   "max_position_embeddings": 4096,
+   "mlp_bias": false,
+   "model_type": "llama",
+   "num_attention_heads": 32,
+   "num_hidden_layers": 32,
+   "num_key_value_heads": 8,
+   "number_rep_kv": 1,
+   "pallas_k_block_size": 128,
+   "pallas_m_block_size": 128,
+   "pallas_n_block_size": 128,
+   "partition_axis": {
+     "attention_dim_axis": null,
+     "attention_kv_dim_axis": null,
+     "batch_axis": [
+       "fsdp",
+       "dp"
+     ],
+     "bias_head_sequence_axis": null,
+     "bias_key_sequence_axis": null,
+     "data_parallel_axis": "dp",
+     "decode_attention_dim_axis": null,
+     "decode_attention_kv_dim_axis": null,
+     "decode_batch_axis": [
+       "fsdp",
+       "dp"
+     ],
+     "decode_head_axis": "tp",
+     "decode_key_sequence_axis": "sp",
+     "decode_kv_head_axis": "tp",
+     "decode_query_sequence_axis": null,
+     "expert_axis": "ep",
+     "expert_gate_axis": null,
+     "expert_parallel_axis": "ep",
+     "fully_sharded_data_parallel_axis": "fsdp",
+     "head_axis": "tp",
+     "hidden_state_axis": "tp",
+     "key_sequence_axis": "sp",
+     "kv_head_axis": "tp",
+     "mlp_intermediate_axis": "tp",
+     "query_sequence_axis": "sp",
+     "sequence_axis": "sp",
+     "sequence_parallel_axis": "sp",
+     "tensor_parallel_axis": "tp",
+     "vocab_axis": "tp"
+   },
+   "platform": null,
+   "precompute_masks": true,
+   "pretraining_tp": 1,
+   "quantization_blocksize": 64,
+   "quantization_method": "None",
+   "quantization_pattern": ".*",
+   "resid_pdrop": 0.0,
+   "rms_norm_eps": 1e-05,
+   "rope_scaling": {
+     "factor": 8.0,
+     "high_freq_factor": 4.0,
+     "low_freq_factor": 1.0,
+     "original_max_position_embeddings": 8192,
+     "rope_type": "llama3"
+   },
+   "rope_theta": 500000,
+   "scan_attention_layers": false,
+   "scan_layers": false,
+   "scan_mlp_chunk_size": 1024,
+   "scan_ring_attention": true,
+   "sequence_axis_name": "sp",
+   "shard_attention_computation": true,
+   "sharding_axis_dims": [
+     1,
+     -1,
+     1,
+     1
+   ],
+   "sharding_axis_names": [
+     "dp",
+     "fsdp",
+     "tp",
+     "sp"
+   ],
+   "sharding_dcn_axis_dims": null,
+   "tie_word_embeddings": false,
+   "transformers_version": "4.51.3",
+   "use_cache": true,
+   "use_scan_mlp": false,
+   "use_sharded_kv_caching": false,
+   "use_sharding_constraint": false,
+   "vocab_size": 128256
+ }
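Reviewer note on `rope_scaling`: the config selects `rope_type: "llama3"` with `factor` 8.0, `low_freq_factor` 1.0, `high_freq_factor` 4.0, and `original_max_position_embeddings` 8192. The sketch below follows the frequency rescaling rule from the public Llama 3.1 reference code, written in plain Python for readability. It is an illustration of what those four numbers control, not EasyDeL's implementation, which this commit does not show; the function name is hypothetical.

```python
import math


def llama3_scaled_inv_freq(
    head_dim: int = 128,
    rope_theta: float = 500000.0,
    factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    original_max_position_embeddings: int = 8192,
):
    """Return (base, scaled) rotary inverse frequencies for one attention head."""
    base = [rope_theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor

    scaled = []
    for freq in base:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # Short wavelengths (high frequencies) are left untouched.
            scaled.append(freq)
        elif wavelen > low_freq_wavelen:
            # Long wavelengths are slowed down by the full scaling factor.
            scaled.append(freq / factor)
        else:
            # In between, interpolate smoothly between the two regimes.
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            scaled.append((1 - smooth) * freq / factor + smooth * freq)
    return base, scaled


if __name__ == "__main__":
    base, scaled = llama3_scaled_inv_freq()
    unchanged = sum(1 for b, s in zip(base, scaled) if b == s)
    print(f"{unchanged} of {len(base)} frequencies are left unscaled")
```

The defaults mirror `head_dim`, `rope_theta`, and the `rope_scaling` values in this file.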
easydel-model.parameters ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7ac2e62e6e6a803c321ee79678a0b2df8ea4bfd635d1ed4c44a2410866d99c3d
+ size 16060556584
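Reviewer note on the LFS pointer: the reported size can be sanity-checked against the shapes in `config.json`. The sketch below counts parameters for a standard Llama layout (no attention or MLP biases, untied embeddings, two RMSNorm weights per layer plus a final norm) and assumes 2 bytes per parameter, matching the `bfloat16` `param_dtype` listed in the README. The layout of `easydel-model.parameters` itself is not documented in this commit, so the small remainder is presumably file metadata rather than weights.

```python
def llama_param_count(
    hidden_size: int = 4096,
    intermediate_size: int = 14336,
    num_hidden_layers: int = 32,
    num_attention_heads: int = 32,
    num_key_value_heads: int = 8,
    head_dim: int = 128,
    vocab_size: int = 128256,
    tie_word_embeddings: bool = False,
) -> int:
    """Count weights for a bias-free Llama-style decoder (assumed layout)."""
    q_proj = hidden_size * num_attention_heads * head_dim
    kv_proj = 2 * hidden_size * num_key_value_heads * head_dim  # k_proj and v_proj
    o_proj = num_attention_heads * head_dim * hidden_size
    mlp = 3 * hidden_size * intermediate_size  # gate, up, and down projections
    norms = 2 * hidden_size  # input and post-attention RMSNorm weights
    per_layer = q_proj + kv_proj + o_proj + mlp + norms

    embeddings = vocab_size * hidden_size
    lm_head = 0 if tie_word_embeddings else vocab_size * hidden_size
    final_norm = hidden_size
    return num_hidden_layers * per_layer + embeddings + lm_head + final_norm


n_params = llama_param_count()
expected_bytes = 2 * n_params  # 2 bytes per bfloat16 parameter (assumption)
lfs_size = 16060556584  # "size" field of the LFS pointer above

print(n_params)                   # 8030261248
print(lfs_size - expected_bytes)  # 34088
```

The estimate lands within about 34 KB of the pointer's size, which is consistent with an 8B-parameter model stored in 16-bit precision.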
easydel-training-arguments.json ADDED
@@ -0,0 +1,106 @@
+ {
+   "_can_log_metrics": null,
+   "auto_shard_states": true,
+   "aux_loss_enabled": false,
+   "backend": null,
+   "beta": 0.1,
+   "clip_grad": 1.0,
+   "custom_scheduler": null,
+   "dataloader_num_workers": 0,
+   "dataloader_pin_memory": false,
+   "dataset_num_proc": null,
+   "disable_dropout": true,
+   "do_eval": true,
+   "do_last_save": true,
+   "do_train": true,
+   "eval_batch_size": 4,
+   "evaluation_steps": null,
+   "extra_optimizer_kwargs": {},
+   "frozen_parameters": null,
+   "generate_during_eval": false,
+   "gradient_accumulation_steps": 1,
+   "ids_to_pop_from_dataset": [],
+   "init_tx": true,
+   "is_encoder_decoder": null,
+   "is_fine_tuning": true,
+   "jax_distributed_config": null,
+   "label_pad_token_id": -100,
+   "learning_rate": 8e-07,
+   "learning_rate_end": null,
+   "log_all_workers": false,
+   "log_grad_norms": true,
+   "log_steps": 5,
+   "loss_config": {
+     "break_on_nan": true,
+     "classification_problem_type": null,
+     "divide_weight_sum": false,
+     "ignore_index": -100,
+     "label_smoothing": 0.0,
+     "loss_normalizing_factor": "SpecialLossNormalizingFactor.NO_WEIGHT_NUM_REAL_TARGET_TOKENS",
+     "num_classification_labels": null,
+     "num_labels": null,
+     "problem_type": null,
+     "reduction": null,
+     "shift_tokens": true,
+     "z_loss": 0.0
+   },
+   "low_mem_usage": true,
+   "max_completion_length": 2048,
+   "max_evaluation_steps": null,
+   "max_length": 2048,
+   "max_prompt_length": 1024,
+   "max_sequence_length": 4096,
+   "max_training_steps": null,
+   "metrics_to_show_in_rich_pbar": null,
+   "model_name": "marin-8b-instruct-orpo",
+   "model_parameters": null,
+   "num_train_epochs": 8,
+   "offload_dataset": false,
+   "offload_device_index": 0,
+   "offload_device_type": "cpu",
+   "optimizer": "adamw",
+   "padding_value": 128009,
+   "per_epoch_evaluation_steps": null,
+   "per_epoch_training_steps": null,
+   "performance_mode": false,
+   "process_zero_is_admin": true,
+   "progress_bar_type": "json",
+   "pruning_module": null,
+   "remove_ckpt_after_load": false,
+   "remove_unused_columns": true,
+   "report_metrics": true,
+   "report_steps": 10,
+   "save_directory": "EasyDeL-Checkpoints",
+   "save_optimizer_state": false,
+   "save_steps": 1000,
+   "save_total_limit": 1,
+   "scheduler": "cosine",
+   "shuffle_train_dataset": true,
+   "sparse_module_type": "bcoo",
+   "sparsify_module": false,
+   "state_apply_fn_kwarguments_to_model": null,
+   "step_partition_spec": [
+     [
+       "dp",
+       "fsdp"
+     ],
+     "sp"
+   ],
+   "step_start_point": 0,
+   "total_batch_size": 4,
+   "track_memory": false,
+   "train_on_inputs": true,
+   "trainer_config_class": "ORPOConfig",
+   "training_time_limit": null,
+   "truncation_mode": "keep_end",
+   "tx_mu_dtype": null,
+   "use_data_collactor": true,
+   "use_wandb": true,
+   "verbose": true,
+   "wandb_entity": "erfanzar",
+   "wandb_name": null,
+   "warmup_steps": 0,
+   "weight_decay": 0.01,
+   "weight_distribution_log_steps": 100,
+   "weight_distribution_pattern": ".*"
+ }
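Reviewer note on batching and sharding: `step_partition_spec` splits the batch dimension over the `dp` and `fsdp` axes, while `config.json` declares `sharding_axis_dims` of `[1, -1, 1, 1]` for the axes `dp`, `fsdp`, `tp`, `sp`. The sketch below works through the arithmetic for the 4 devices reported in the README. It assumes `-1` means "use all remaining devices", in the same spirit as a `-1` in an array reshape; the helper name is hypothetical and is not an EasyDeL function. The token figure is only an upper bound, since ORPO batches pair chosen and rejected completions and the run also sets `max_length` to 2048.

```python
import math


def resolve_axis_dims(axis_dims, num_devices):
    """Fill a single -1 entry so the product of all axis sizes equals num_devices."""
    if axis_dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(d for d in axis_dims if d != -1)
    if num_devices % known != 0:
        raise ValueError("device count is not divisible by the fixed axis sizes")
    resolved = [num_devices // known if d == -1 else d for d in axis_dims]
    if math.prod(resolved) != num_devices:
        raise ValueError("axis sizes do not use every device")
    return resolved


axis_names = ["dp", "fsdp", "tp", "sp"]  # sharding_axis_names in config.json
mesh = dict(zip(axis_names, resolve_axis_dims([1, -1, 1, 1], num_devices=4)))

total_batch_size = 4        # from the training arguments
max_sequence_length = 4096  # from the training arguments

# The batch dimension is split over dp * fsdp; the sequence dimension over sp.
batch_shards = mesh["dp"] * mesh["fsdp"]
samples_per_shard = total_batch_size // batch_shards
tokens_per_step_upper_bound = total_batch_size * max_sequence_length

print(mesh)
print(samples_per_shard, tokens_per_step_upper_bound)
```

Under these assumptions the mesh resolves to pure FSDP over the 4 devices, with one sample per device per step.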
generation_config.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "_from_model_config": true,
+   "begin_suppress_tokens": [
+     128000,
+     128001
+   ],
+   "bos_token_id": 128000,
+   "decoder_start_token_id": 128000,
+   "eos_token_id": 128009,
+   "transformers_version": "4.51.3"
+ }