Improve model card: Add metadata, links, and detailed sections
#1
by
nielsr
HF Staff
- opened
README.md
CHANGED
@@ -1,26 +1,51 @@
|
|
1 |
---
|
2 |
base_model: JunxiongWang/llama3_0_875_mamba2_sft
|
3 |
-
tags:
|
4 |
-
- alignment-handbook
|
5 |
-
- generated_from_trainer
|
6 |
datasets:
|
7 |
- HuggingFaceH4/ultrafeedback_binarized
|
8 |
- HuggingFaceH4/orca_dpo_pairs
|
9 |
- JunxiongWang/llama3-ultrafeedback-armorm
|
|
|
|
|
|
|
|
|
|
|
10 |
model-index:
|
11 |
- name: JunxiongWang/Mamba2InLlama_0_875
|
12 |
results: []
|
|
|
|
|
|
|
13 |
---
|
14 |
|
15 |
-
|
16 |
-
|
|
|
17 |
|
18 |
-
|
|
|
|
|
19 |
|
20 |
[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="200" height="32"/>](https://wandb.ai/junxiong12/huggingface/runs/58mrdgq8)
|
21 |
-
# JunxiongWang/Mamba2InLlama_0_875
|
22 |
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
It achieves the following results on the evaluation set:
|
25 |
- Loss: 0.4761
|
26 |
- Rewards/chosen: -1.4040
|
@@ -32,19 +57,15 @@ It achieves the following results on the evaluation set:
|
|
32 |
- Logits/rejected: 0.3408
|
33 |
- Logits/chosen: 0.3851
|
34 |
|
35 |
-
##
|
36 |
-
|
37 |
-
More information needed
|
38 |
-
|
39 |
-
## Intended uses & limitations
|
40 |
|
41 |
-
|
42 |
|
43 |
-
|
|
|
|
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
## Training procedure
|
48 |
|
49 |
### Training hyperparameters
|
50 |
|
@@ -69,7 +90,6 @@ The following hyperparameters were used during training:
|
|
69 |
| 0.5009 | 0.4798 | 2000 | 0.4998 | -1.4973 | -2.6147 | 0.7804 | 1.1175 | -586.2582 | -468.3976 | 0.4682 | 0.5136 |
|
70 |
| 0.4895 | 0.9597 | 4000 | 0.4761 | -1.4040 | -2.6012 | 0.7982 | 1.1973 | -584.9104 | -459.0677 | 0.3408 | 0.3851 |
|
71 |
|
72 |
-
|
73 |
### Framework versions
|
74 |
|
75 |
- Transformers 4.43.1
|
@@ -77,13 +97,102 @@ The following hyperparameters were used during training:
|
|
77 |
- Datasets 2.20.0
|
78 |
- Tokenizers 0.19.1
|
79 |
|
80 |
-
|
|
|
|
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
```
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
}
|
89 |
```
|
|
|
1 |
---
|
2 |
base_model: JunxiongWang/llama3_0_875_mamba2_sft
|
|
|
|
|
|
|
3 |
datasets:
|
4 |
- HuggingFaceH4/ultrafeedback_binarized
|
5 |
- HuggingFaceH4/orca_dpo_pairs
|
6 |
- JunxiongWang/llama3-ultrafeedback-armorm
|
7 |
+
tags:
|
8 |
+
- alignment-handbook
|
9 |
+
- generated_from_trainer
|
10 |
+
- mamba
|
11 |
+
- distillation
|
12 |
model-index:
|
13 |
- name: JunxiongWang/Mamba2InLlama_0_875
|
14 |
results: []
|
15 |
+
pipeline_tag: text-generation
|
16 |
+
library_name: transformers
|
17 |
+
license: apache-2.0
|
18 |
---
|
19 |
|
20 |
+
# JunxiongWang/Mamba2InLlama_0_875: The Mamba in the Llama
|
21 |
+
|
22 |
+
This model is part of the work presented in the paper [The Mamba in the Llama: Distilling and Accelerating Hybrid Models](https://arxiv.org/abs/2408.15237).
|
23 |
|
24 |
+
**Code Repository (New Version)**: [https://github.com/jxiw/M1](https://github.com/jxiw/M1)
|
25 |
+
**Code Repository (Original)**: [https://github.com/jxiw/MambaInLlama](https://github.com/jxiw/MambaInLlama)
|
26 |
+
**Project Page**: [https://openreview.net/forum?id=uAzhODjALU](https://openreview.net/forum?id=uAzhODjALU)
|
27 |
|
28 |
[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="200" height="32"/>](https://wandb.ai/junxiong12/huggingface/runs/58mrdgq8)
|
|
|
29 |
|
30 |
+
## Model description
|
31 |
+
|
32 |
+
This model, `JunxiongWang/Mamba2InLlama_0_875`, is a fine-tuned hybrid language model that combines elements of Transformer and Mamba (linear RNN) architectures. It is a result of a distillation process aimed at converting large pre-trained Transformer models, such as Llama3-8B-Instruct, into more deployment-advantageous linear RNNs. The approach involves reusing linear projection weights from attention layers, resulting in a hybrid model that incorporates a fraction (in this case, one-quarter) of the original attention layers.
|
33 |
+
|
34 |
+
The model aims to achieve performance comparable to the original Transformer in chat benchmarks while offering improved inference characteristics. It leverages a hardware-aware speculative decoding algorithm for accelerated inference speed. This specific model, distilled from Llama3-8B-Instruct, demonstrates competitive results with a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench. It also exhibits natural length extrapolation, showing almost perfect accuracy in the needle-in-a-haystack test at 20x the distillation length.
|
35 |
+
|
36 |
+
## Intended uses & limitations
|
37 |
+
|
38 |
+
This model is intended for efficient and accelerated text generation tasks, particularly in scenarios where the advantageous deployment characteristics of linear RNNs (like Mamba) are desired over traditional Transformer models. It is suitable for chat applications, general language modeling, and tasks requiring long-range context handling.
|
39 |
+
|
40 |
+
Limitations: While designed to preserve generative quality, the hybrid architecture might have different performance profiles compared to a full Transformer model on specific tasks. Users should be aware that the optimal performance and reproducibility are dependent on adherence to the recommended environment setup, including specific CUDA and Python package versions. As with all large language models, potential biases from training data and hallucination remain a consideration.
|
41 |
+
|
42 |
+
## Training and evaluation data
|
43 |
+
|
44 |
+
This model is a fine-tuned version of [JunxiongWang/llama3_0_875_mamba2_sft](https://huggingface.co/JunxiongWang/llama3_0_875_mamba2_sft/) on the following datasets:
|
45 |
+
* [HuggingFaceH4/ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
|
46 |
+
* [HuggingFaceH4/orca_dpo_pairs](https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs)
|
47 |
+
* [JunxiongWang/llama3-ultrafeedback-armorm](https://huggingface.co/datasets/JunxiongWang/llama3-ultrafeedback-armorm)
|
48 |
+
|
49 |
It achieves the following results on the evaluation set:
|
50 |
- Loss: 0.4761
|
51 |
- Rewards/chosen: -1.4040
|
|
|
57 |
- Logits/rejected: 0.3408
|
58 |
- Logits/chosen: 0.3851
|
59 |
|
60 |
+
## Training procedure
|
|
|
|
|
|
|
|
|
61 |
|
62 |
+
The models were distilled using a multi-stage approach:
|
63 |
|
64 |
+
1. **Stepwise layer alignment** (Optional): Attention layers are replaced by Mamba2 layers one by one in a stepwise manner. During this stage, MLP layers are typically frozen to ensure the model remains similar to the initialization model.
|
65 |
+
2. **End-to-end distillation** (Most important): The primary phase involves minimizing the KL divergence loss between the student (hybrid) and teacher (original Transformer) models. In this stage, all parameters, including MLP layers, are trained to achieve better results based purely on KL loss.
|
66 |
+
3. **Instruction tuning** (Optional): For simplicity, DPO (Direct Preference Optimization) was used for this process to align the models with human preferences.
|
67 |
|
68 |
+
The distillation process typically requires around 3 to 4 days using 8x80G A100 GPUs with limited resources.
|
|
|
|
|
69 |
|
70 |
### Training hyperparameters
|
71 |
|
|
|
90 |
| 0.5009 | 0.4798 | 2000 | 0.4998 | -1.4973 | -2.6147 | 0.7804 | 1.1175 | -586.2582 | -468.3976 | 0.4682 | 0.5136 |
|
91 |
| 0.4895 | 0.9597 | 4000 | 0.4761 | -1.4040 | -2.6012 | 0.7982 | 1.1973 | -584.9104 | -459.0677 | 0.3408 | 0.3851 |
|
92 |
|
|
|
93 |
### Framework versions
|
94 |
|
95 |
- Transformers 4.43.1
|
|
|
97 |
- Datasets 2.20.0
|
98 |
- Tokenizers 0.19.1
|
99 |
|
100 |
+
## Usage
|
101 |
+
|
102 |
+
For detailed instructions, the full codebase, and other released models, please refer to the primary [M1 GitHub repository](https://github.com/jxiw/M1).
|
103 |
|
104 |
+
### Environment Setup
|
105 |
+
|
106 |
+
The project provides an `environment.yml` file listing specific Python package versions used for reproducibility. It is recommended to use these versions for optimal performance. Key packages include `mamba-ssm`, `causal-conv1d`, and `flash-attn`, along with specific PyTorch and CUDA versions.
|
107 |
+
|
108 |
+
```bash
|
109 |
+
# CUDA>=11.6 needed for `mamba-ssm` and `causal-conv1d`.
|
110 |
+
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
|
111 |
+
# Install PyTorch (with CUDA 11.8) before everything else. those assume you are using cu118
|
112 |
+
pip install torch==2.1.1+cu118 --index-url https://download.pytorch.org/whl/cu118
|
113 |
+
|
114 |
+
pip install causal-conv1d==1.4.0
|
115 |
+
pip install flash-attn==2.6.3
|
116 |
+
|
117 |
+
# make sure you use this alignment version
|
118 |
+
git clone https://github.com/huggingface/alignment-handbook.git
|
119 |
+
cd alignment-handbook/
|
120 |
+
git checkout 606d2e9
|
121 |
+
|
122 |
+
git clone https://github.com/huggingface/transformers.git --branch v4.43.1
|
123 |
+
|
124 |
+
# check your version matches those
|
125 |
+
# deepspeed==0.12.2
|
126 |
+
# torch==2.1.1+cu118
|
127 |
+
# transformers==4.43.1
|
128 |
+
# trl==0.8.6
|
129 |
+
# accelerate==0.33.0
|
130 |
+
# peft==0.12.0
|
131 |
+
# huggingface-hub==0.24.5
|
132 |
+
```
|
133 |
+
|
134 |
+
If `mamba-ssm==2.2.2` is installed via pip, a manual change to `CONDA_ENV_PATH/site-packages/mamba_ssm/modules/mha.py` might be needed to support GQA (used in Llama3). Refer to [this version](https://github.com/state-spaces/mamba/blob/014c094d11f780a27330657faabecaaded7a31db/mamba_ssm/modules/mha.py) or build `mamba-ssm` from source (commit after `014c094d11f780a27330657faabecaaded7a31db`).
|
135 |
+
|
136 |
+
### Generation Example (Mamba 2)
|
137 |
+
|
138 |
+
```python
|
139 |
+
import torch
|
140 |
+
from transformers import AutoTokenizer
|
141 |
+
# For Mamba2InLlama models, use mamba2_inference.hybrid_wrapper
|
142 |
+
from mamba2_inference.hybrid_wrapper import MambaTransformerHybridModelWrapper
|
143 |
+
|
144 |
+
pretrained_model_name = "JunxiongWang/Mamba2InLlama_0_875" # This model
|
145 |
+
model = MambaTransformerHybridModelWrapper.from_pretrained(pretrained_model_name, torch_dtype=torch.bfloat16)
|
146 |
+
model.eval()
|
147 |
+
|
148 |
+
messages = [[
|
149 |
+
{
|
150 |
+
"role": "user",
|
151 |
+
"content": "Farmer Brown has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens?",
|
152 |
+
},
|
153 |
+
]]
|
154 |
+
|
155 |
+
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
|
156 |
+
formatted_prompts = [
|
157 |
+
tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages
|
158 |
+
]
|
159 |
+
|
160 |
+
prompts = [
|
161 |
+
tokenizer.encode(formatted_prompt, return_tensors="pt", truncation=True, max_length=200)
|
162 |
+
for formatted_prompt in formatted_prompts
|
163 |
+
]
|
164 |
+
batch_prompts = torch.cat(prompts, dim=0).cuda()
|
165 |
+
|
166 |
+
outputs = model.generate(
|
167 |
+
input_ids=batch_prompts,
|
168 |
+
max_length=1000,
|
169 |
+
cg=True, # Enables speculative decoding if supported
|
170 |
+
return_dict_in_generate=True,
|
171 |
+
output_scores=True,
|
172 |
+
enable_timing=True,
|
173 |
+
top_k=1,
|
174 |
+
eos_token_id=tokenizer.eos_token_id
|
175 |
+
)
|
176 |
+
|
177 |
+
generated_text = tokenizer.batch_decode(outputs.sequences.tolist())
|
178 |
+
print(generated_text[0])
|
179 |
+
|
180 |
+
# Example output (trimmed for brevity):
|
181 |
+
# Let's use algebra to solve this problem. Let \( c \) represent the number of chickens and \( k \) represent the number of cows.
|
182 |
+
# ... (full derivation and answer)
|
183 |
+
# So, there are 5 chickens on Farmer Brown's farm.
|
184 |
```
|
185 |
+
|
186 |
+
## Citation
|
187 |
+
If you use this codebase, or otherwise found our work valuable, please cite:
|
188 |
+
|
189 |
+
```bibtex
|
190 |
+
@inproceedings{
|
191 |
+
junxiongdaniele2024mambainllama,
|
192 |
+
title={The Mamba in the Llama: Distilling and Accelerating Hybrid Models},
|
193 |
+
author={Junxiong Wang and Daniele Paliotta and Avner May and Alexander M Rush and Tri Dao},
|
194 |
+
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
|
195 |
+
year={2024},
|
196 |
+
url={https://openreview.net/forum?id=uAzhODjALU}
|
197 |
}
|
198 |
```
|