Improve model card: Add metadata, links, and detailed sections

#1
by nielsr HF Staff - opened
Files changed (1)
  1. README.md +134 -25
README.md CHANGED
@@ -1,26 +1,51 @@
  ---
  base_model: JunxiongWang/llama3_0_875_mamba2_sft
- tags:
- - alignment-handbook
- - generated_from_trainer
  datasets:
  - HuggingFaceH4/ultrafeedback_binarized
  - HuggingFaceH4/orca_dpo_pairs
  - JunxiongWang/llama3-ultrafeedback-armorm
  model-index:
  - name: JunxiongWang/Mamba2InLlama_0_875
    results: []
  ---

- <!-- This model card has been generated automatically according to the information the Trainer had access to. You
- should probably proofread and complete it, then remove this comment. -->

- Please check [here](https://github.com/jxiw/MambaInLlama/tree/main) for details.

  [<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="200" height="32"/>](https://wandb.ai/junxiong12/huggingface/runs/58mrdgq8)
- # JunxiongWang/Mamba2InLlama_0_875

- This model is a fine-tuned version of [JunxiongWang/llama3_0_875_mamba2_sft](https://huggingface.co/JunxiongWang/llama3_0_875_mamba2_sft/) on the HuggingFaceH4/ultrafeedback_binarized, the HuggingFaceH4/orca_dpo_pairs and the JunxiongWang/llama3-ultrafeedback-armorm datasets.
  It achieves the following results on the evaluation set:
  - Loss: 0.4761
  - Rewards/chosen: -1.4040
@@ -32,19 +57,15 @@ It achieves the following results on the evaluation set:
  - Logits/rejected: 0.3408
  - Logits/chosen: 0.3851

- ## Model description
-
- More information needed
-
- ## Intended uses & limitations

- More information needed

- ## Training and evaluation data

- More information needed
-
- ## Training procedure

  ### Training hyperparameters

@@ -69,7 +90,6 @@ The following hyperparameters were used during training:
  | 0.5009 | 0.4798 | 2000 | 0.4998 | -1.4973 | -2.6147 | 0.7804 | 1.1175 | -586.2582 | -468.3976 | 0.4682 | 0.5136 |
  | 0.4895 | 0.9597 | 4000 | 0.4761 | -1.4040 | -2.6012 | 0.7982 | 1.1973 | -584.9104 | -459.0677 | 0.3408 | 0.3851 |

-
  ### Framework versions

  - Transformers 4.43.1
@@ -77,13 +97,102 @@ The following hyperparameters were used during training:
  - Datasets 2.20.0
  - Tokenizers 0.19.1

- [MambaInLlama](arxiv.org/abs/2408.15237)

  ```
- @article{junxiongdaniele2024mambainllama,
-   title = {The Mamba in the Llama: Distilling and Accelerating Hybrid Models},
-   author = {Junxiong Wang and Daniele Paliotta and Avner May and Alexander M. Rush and Tri Dao},
-   journal = {arXiv preprint arXiv:2408.15237},
-   year = {2024}
  }
  ```
 
  ---
  base_model: JunxiongWang/llama3_0_875_mamba2_sft
  datasets:
  - HuggingFaceH4/ultrafeedback_binarized
  - HuggingFaceH4/orca_dpo_pairs
  - JunxiongWang/llama3-ultrafeedback-armorm
+ tags:
+ - alignment-handbook
+ - generated_from_trainer
+ - mamba
+ - distillation
  model-index:
  - name: JunxiongWang/Mamba2InLlama_0_875
    results: []
+ pipeline_tag: text-generation
+ library_name: transformers
+ license: apache-2.0
  ---

+ # JunxiongWang/Mamba2InLlama_0_875: The Mamba in the Llama
+
+ This model is part of the work presented in the paper [The Mamba in the Llama: Distilling and Accelerating Hybrid Models](https://arxiv.org/abs/2408.15237).
+
+ - **Code Repository (New Version)**: [https://github.com/jxiw/M1](https://github.com/jxiw/M1)
+ - **Code Repository (Original)**: [https://github.com/jxiw/MambaInLlama](https://github.com/jxiw/MambaInLlama)
+ - **Project Page**: [https://openreview.net/forum?id=uAzhODjALU](https://openreview.net/forum?id=uAzhODjALU)
+
  [<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="200" height="32"/>](https://wandb.ai/junxiong12/huggingface/runs/58mrdgq8)

+ ## Model description
+
+ This model, `JunxiongWang/Mamba2InLlama_0_875`, is a fine-tuned hybrid language model that combines Transformer attention layers with Mamba (linear RNN) layers. It is the result of a distillation process that converts a large pre-trained Transformer, Llama3-8B-Instruct, into a mostly linear-RNN model with more favorable deployment characteristics. The approach reuses the linear projection weights of the attention layers to initialize the Mamba layers (sketched below); per the `0_875` naming, this checkpoint replaces 87.5% of the attention layers with Mamba2 layers, keeping roughly one in eight of the original attention layers.
+
+ The goal is to match the original Transformer on chat benchmarks while offering better inference characteristics, aided by a hardware-aware speculative decoding algorithm. The best-performing hybrid reported in the paper, distilled from Llama3-8B-Instruct, achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench. The distilled hybrids also extrapolate in length, showing almost perfect accuracy on the needle-in-a-haystack test at 20x the distillation length.
+
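+ The weight-reuse step can be pictured with a minimal sketch. The module and attribute names below are hypothetical stand-ins, not the repository's actual classes; the point is only that the Mamba-side projections start from the teacher's attention projections (roughly query→C, key→B, value→x) rather than from random initialization.
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ d_model = 4096  # illustrative hidden size, not necessarily the exact Llama3 config
+
+ class TinyAttention(nn.Module):
+     """Stand-in for a teacher Transformer attention block."""
+     def __init__(self, d):
+         super().__init__()
+         self.q_proj = nn.Linear(d, d, bias=False)
+         self.k_proj = nn.Linear(d, d, bias=False)
+         self.v_proj = nn.Linear(d, d, bias=False)
+         self.o_proj = nn.Linear(d, d, bias=False)
+
+ class TinyLinearRNN(nn.Module):
+     """Stand-in for a Mamba-style student layer (hypothetical attribute names)."""
+     def __init__(self, d):
+         super().__init__()
+         self.c_proj = nn.Linear(d, d, bias=False)    # plays the role of C (query-like)
+         self.b_proj = nn.Linear(d, d, bias=False)    # plays the role of B (key-like)
+         self.x_proj = nn.Linear(d, d, bias=False)    # plays the role of x (value-like)
+         self.out_proj = nn.Linear(d, d, bias=False)
+
+ attn, rnn = TinyAttention(d_model), TinyLinearRNN(d_model)
+ with torch.no_grad():
+     # Copy the attention projections into the linear-RNN layer as its initialization.
+     rnn.c_proj.weight.copy_(attn.q_proj.weight)
+     rnn.b_proj.weight.copy_(attn.k_proj.weight)
+     rnn.x_proj.weight.copy_(attn.v_proj.weight)
+     rnn.out_proj.weight.copy_(attn.o_proj.weight)
+ ```
+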
+ ## Intended uses & limitations
+
+ This model is intended for efficient text generation, particularly where the deployment advantages of linear RNNs such as Mamba (a fixed-size recurrent state instead of a per-token KV cache for the converted layers) are preferable to a full Transformer. It is suitable for chat applications, general language modeling, and tasks that require handling long contexts.
+
+ Limitations: although the distillation is designed to preserve generative quality, the hybrid architecture may behave differently from a full Transformer on specific tasks. Optimal performance and reproducibility depend on the recommended environment setup, including the CUDA and Python package versions listed below. As with all large language models, biases inherited from the training data and hallucination remain considerations.
+
+ ## Training and evaluation data
+
+ This model is a fine-tuned version of [JunxiongWang/llama3_0_875_mamba2_sft](https://huggingface.co/JunxiongWang/llama3_0_875_mamba2_sft/) on the following preference datasets (a quick way to inspect them is sketched after the list):
+ * [HuggingFaceH4/ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
+ * [HuggingFaceH4/orca_dpo_pairs](https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs)
+ * [JunxiongWang/llama3-ultrafeedback-armorm](https://huggingface.co/datasets/JunxiongWang/llama3-ultrafeedback-armorm)
+
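+ The sketch below shows one way to peek at one of these preference datasets; the split and field names follow the dataset cards on the Hub and may need adjusting if they differ.
+
+ ```python
+ from datasets import load_dataset
+
+ # ultrafeedback_binarized exposes preference splits such as "train_prefs"/"test_prefs";
+ # each example carries a prompt plus "chosen" and "rejected" conversations.
+ ds = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split="train_prefs")
+ example = ds[0]
+ print(example.keys())
+ print(example["chosen"][-1]["content"][:200])    # preferred assistant reply (truncated)
+ print(example["rejected"][-1]["content"][:200])  # dispreferred assistant reply (truncated)
+ ```
+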
  It achieves the following results on the evaluation set:
  - Loss: 0.4761
  - Rewards/chosen: -1.4040

  - Logits/rejected: 0.3408
  - Logits/chosen: 0.3851

+ ## Training procedure
+
+ The models were distilled in a multi-stage process:
+
+ 1. **Stepwise layer alignment** (optional): attention layers are replaced by Mamba2 layers one at a time. During this stage the MLP layers are typically frozen so that the model stays close to the initialization model.
+ 2. **End-to-end distillation** (most important): the main phase minimizes the KL divergence between the student (hybrid) and teacher (original Transformer) output distributions (sketched below). In this stage all parameters, including the MLP layers, are trained.
+ 3. **Instruction tuning** (optional): DPO (Direct Preference Optimization) is used to align the model with human preferences.
+
+ The full distillation process takes roughly 3 to 4 days on 8x 80GB A100 GPUs.
+
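+ As a rough illustration of the stage-2 objective only (this is not the repository's training code; the function name and temperature are illustrative), the KL term can be written as:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def distillation_kl_loss(student_logits, teacher_logits, temperature: float = 1.0):
+     """KL(teacher || student) over the vocabulary; logits have shape (batch, seq, vocab)."""
+     s_logprobs = F.log_softmax(student_logits / temperature, dim=-1)
+     t_logprobs = F.log_softmax(teacher_logits / temperature, dim=-1)
+     # F.kl_div takes log-probabilities as input and, with log_target=True, as target.
+     return F.kl_div(s_logprobs, t_logprobs, log_target=True, reduction="batchmean") * temperature**2
+
+ # Usage sketch: the frozen teacher and the hybrid student score the same tokens.
+ # with torch.no_grad():
+ #     teacher_logits = teacher(input_ids).logits
+ # student_logits = student(input_ids).logits
+ # distillation_kl_loss(student_logits, teacher_logits).backward()
+ ```
+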
  ### Training hyperparameters

  | 0.5009 | 0.4798 | 2000 | 0.4998 | -1.4973 | -2.6147 | 0.7804 | 1.1175 | -586.2582 | -468.3976 | 0.4682 | 0.5136 |
  | 0.4895 | 0.9597 | 4000 | 0.4761 | -1.4040 | -2.6012 | 0.7982 | 1.1973 | -584.9104 | -459.0677 | 0.3408 | 0.3851 |

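+ The reward and log-probability metrics reported above come from DPO training. As a sketch of the standard DPO formulation (not the exact trainer code used here), the implicit rewards and loss are computed from policy and reference log-probabilities:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def dpo_loss(policy_chosen_logps, policy_rejected_logps,
+              ref_chosen_logps, ref_rejected_logps, beta: float = 0.1):
+     """Standard DPO loss; inputs are per-sequence summed log-probs of shape (batch,)."""
+     chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)        # "Rewards/chosen"
+     rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)  # "Rewards/rejected"
+     margins = chosen_rewards - rejected_rewards                             # "Rewards/margins"
+     loss = -F.logsigmoid(margins).mean()
+     accuracy = (margins > 0).float().mean()                                 # "Rewards/accuracies"
+     return loss, chosen_rewards.detach(), rejected_rewards.detach(), accuracy
+ ```
+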
  ### Framework versions

  - Transformers 4.43.1

  - Datasets 2.20.0
  - Tokenizers 0.19.1

+ ## Usage
+
+ For detailed instructions, the full codebase, and other released models, please refer to the [MambaInLlama GitHub repository](https://github.com/jxiw/MambaInLlama); the newer [M1 repository](https://github.com/jxiw/M1) hosts the follow-up work.
+
+ ### Environment Setup
+
+ The project provides an `environment.yml` file listing the exact Python package versions used; sticking to these versions is recommended for reproducibility and performance. Key packages include `mamba-ssm`, `causal-conv1d`, and `flash-attn`, together with matching PyTorch and CUDA versions.
+
+ ```bash
+ # CUDA>=11.6 is needed for `mamba-ssm` and `causal-conv1d`.
+ conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
+ # Install PyTorch (with CUDA 11.8) before everything else; the commands below assume cu118.
+ pip install torch==2.1.1+cu118 --index-url https://download.pytorch.org/whl/cu118
+
+ pip install causal-conv1d==1.4.0
+ pip install flash-attn==2.6.3
+ pip install mamba-ssm==2.2.2  # see the GQA note below if installing from PyPI
+
+ # Make sure you use this alignment-handbook revision.
+ git clone https://github.com/huggingface/alignment-handbook.git
+ cd alignment-handbook/
+ git checkout 606d2e9
+
+ git clone https://github.com/huggingface/transformers.git --branch v4.43.1
+
+ # Check that your installed versions match these:
+ # deepspeed==0.12.2
+ # torch==2.1.1+cu118
+ # transformers==4.43.1
+ # trl==0.8.6
+ # accelerate==0.33.0
+ # peft==0.12.0
+ # huggingface-hub==0.24.5
+ ```
+
+ If `mamba-ssm==2.2.2` is installed via pip, a manual change to `CONDA_ENV_PATH/site-packages/mamba_ssm/modules/mha.py` may be needed to support GQA (used in Llama3). Refer to [this version](https://github.com/state-spaces/mamba/blob/014c094d11f780a27330657faabecaaded7a31db/mamba_ssm/modules/mha.py) of the file, or build `mamba-ssm` from source at any commit after `014c094d11f780a27330657faabecaaded7a31db`.
+
+ ### Generation Example (Mamba 2)
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer
+ # For Mamba2InLlama models, use mamba2_inference.hybrid_wrapper from the MambaInLlama
+ # repository (run from the repository root or add it to your PYTHONPATH).
+ from mamba2_inference.hybrid_wrapper import MambaTransformerHybridModelWrapper
+
+ pretrained_model_name = "JunxiongWang/Mamba2InLlama_0_875"  # this model
+ model = MambaTransformerHybridModelWrapper.from_pretrained(pretrained_model_name, torch_dtype=torch.bfloat16)
+ model.eval()
+
+ messages = [[
+     {
+         "role": "user",
+         "content": "Farmer Brown has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens?",
+     },
+ ]]
+
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
+ formatted_prompts = [
+     tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages
+ ]
+
+ prompts = [
+     tokenizer.encode(formatted_prompt, return_tensors="pt", truncation=True, max_length=200)
+     for formatted_prompt in formatted_prompts
+ ]
+ batch_prompts = torch.cat(prompts, dim=0).cuda()
+
+ outputs = model.generate(
+     input_ids=batch_prompts,
+     max_length=1000,
+     cg=True,  # use CUDA graphs to speed up decoding
+     return_dict_in_generate=True,
+     output_scores=True,
+     enable_timing=True,
+     top_k=1,
+     eos_token_id=tokenizer.eos_token_id
+ )
+
+ generated_text = tokenizer.batch_decode(outputs.sequences.tolist())
+ print(generated_text[0])
+
+ # Example output (trimmed for brevity):
+ # Let's use algebra to solve this problem. Let \( c \) represent the number of chickens and \( k \) represent the number of cows.
+ # ... (full derivation and answer)
+ # So, there are 5 chickens on Farmer Brown's farm.
  ```
+
+ ## Citation
+
+ If you use this codebase, or otherwise find our work valuable, please cite:
+
+ ```bibtex
+ @inproceedings{junxiongdaniele2024mambainllama,
+   title={The Mamba in the Llama: Distilling and Accelerating Hybrid Models},
+   author={Junxiong Wang and Daniele Paliotta and Avner May and Alexander M Rush and Tri Dao},
+   booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
+   year={2024},
+   url={https://openreview.net/forum?id=uAzhODjALU}
  }
  ```