|
--- |
|
language: en |
|
tags: |
|
- jax |
|
- flax |
|
- text-generation |
|
- transformers |
|
- meta-llama/Llama-3.2-3B |
|
--- |
|
|
|
# meta-llama/Llama-3.2-3B - JAX/Flax |
|
|
|
This repository contains a JAX/Flax version of the meta-llama/Llama-3.2-3B model, converted from Meta's original PyTorch release. The conversion enables efficient inference and training on TPUs and GPUs using the JAX/Flax framework.
|
|
|
## Model Description |
|
|
|
meta-llama/Llama-3.2-3B is a 3B-parameter, decoder-only transformer language model developed by Meta and distributed under the `meta-llama` organization on the Hugging Face Hub.
|
|
|
## Conversion Details |
|
|
|
This model was converted from the original PyTorch implementation to JAX/Flax. The conversion involved the following steps (a code sketch follows the list):
|
|
|
1. **Loading the PyTorch model and configuration:** The pretrained PyTorch model and its configuration were loaded using the Hugging Face Transformers library. |
|
2. **Creating an equivalent Flax model architecture:** A Flax model with the same architecture as the original PyTorch model was created. |
|
3. **Converting the PyTorch weights to Flax format:** The weights from the PyTorch model were converted to the Flax format using the `convert_pytorch_state_dict_to_flax` utility function provided by Hugging Face. |
|
4. **Verifying the converted weights:** The converted Flax weights were compared against the original PyTorch weights to ensure that the conversion process was performed accurately. |
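Below is a minimal sketch of this workflow, not the exact script used for this conversion. It assumes the Flax Llama classes (`FlaxLlamaForCausalLM`) and the `convert_pytorch_state_dict_to_flax` helper shipped with recent versions of the Hugging Face Transformers library:

```python
from transformers import AutoConfig, FlaxLlamaForCausalLM, LlamaForCausalLM
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

model_id = "meta-llama/Llama-3.2-3B"

# 1. Load the PyTorch model and its configuration.
config = AutoConfig.from_pretrained(model_id)
config.max_position_embeddings = 16384  # reduced context length; see the note below
pt_model = LlamaForCausalLM.from_pretrained(model_id)

# 2. Create an equivalent Flax model (randomly initialized; its weights are replaced next).
flax_model = FlaxLlamaForCausalLM(config)

# 3. Convert the PyTorch state dict into the Flax parameter tree.
flax_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), flax_model)

# 4. Save the converted weights; verification is summarized in "Weight Comparison" below.
flax_model.save_pretrained("Llama-3.2-3B-JAX")
```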
|
|
|
### Important Note about `max_position_embeddings` |
|
|
|
During conversion, the `max_position_embeddings` parameter in the model's configuration had to be reduced: the original value of 131072 caused out-of-memory (OOM) errors on the hardware used for conversion, so it was lowered to 16384.
|
|
|
**Implications of this change:** |
|
|
|
* The model may not be able to handle sequences longer than 16384 tokens without truncation or other modifications. |
|
* If you fine-tune this model, keep the revised `max_position_embeddings` in mind when preparing your training data; a quick way to check the configured value is shown below.
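A minimal check of the value this repository ships with (a sketch; the repository id matches the one used in the Usage section below):

```python
from transformers import AutoConfig

repo_id = "Erland/Llama-3.2-3B-JAX"  # replace with your repository name

# Inspect the context length stored in the converted checkpoint's config.
config = AutoConfig.from_pretrained(repo_id)
print(config.max_position_embeddings)  # expected: 16384
```

If your hardware allows, the value can be overridden at load time (for example by passing `max_position_embeddings=...` to `from_pretrained`), but behaviour beyond 16384 tokens has not been verified for this conversion.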
|
|
|
## Weight Comparison |
|
|
|
The following table summarizes the comparison between the weights of the original PyTorch model and the converted JAX/Flax model. This detailed verification confirms that the conversion was accurate and that both models should produce (approximately) the same outputs given the same inputs. |
|
|
|
| Layer | PyTorch Shape | Flax Shape | Allclose | Max Diff | Mean Diff | Std Diff | |
|
| :---- | :------------ | :--------- | :------- | :------- | :-------- | :------- | |
|
| model.embed_tokens.weight | (128256, 3072) | (128256, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.0.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.0.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.0.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.0.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.0.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.0.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.0.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.0.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.0.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.1.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.1.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.1.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.1.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.1.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.1.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.1.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.1.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.1.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.2.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.2.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.2.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.2.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.2.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.2.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.2.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.2.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.2.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.3.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.3.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.3.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.3.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.3.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.3.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.3.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.3.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.3.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.4.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.4.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.4.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.4.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.4.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.4.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.4.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.4.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.4.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.5.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.5.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.5.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.5.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.5.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.5.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.5.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.5.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.5.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.6.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.6.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.6.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.6.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.6.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.6.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.6.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.6.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.6.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.7.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.7.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.7.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.7.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.7.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.7.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.7.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.7.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.7.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.8.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.8.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.8.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.8.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.8.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.8.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.8.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.8.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.8.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.9.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.9.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.9.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.9.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.9.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.9.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.9.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.9.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.9.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.10.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.10.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.10.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.10.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.10.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.10.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.10.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.10.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.10.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.11.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.11.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.11.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.11.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.11.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.11.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.11.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.11.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.11.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.12.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.12.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.12.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.12.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.12.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.12.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.12.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.12.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.12.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.13.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.13.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.13.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.13.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.13.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.13.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.13.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.13.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.13.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.14.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.14.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.14.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.14.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.14.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.14.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.14.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.14.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.14.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.15.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.15.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.15.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.15.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.15.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.15.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.15.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.15.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.15.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.16.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.16.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.16.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.16.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.16.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.16.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.16.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.16.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.16.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.17.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.17.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.17.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.17.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.17.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.17.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.17.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.17.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.17.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.18.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.18.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.18.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.18.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.18.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.18.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.18.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.18.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.18.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.19.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.19.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.19.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.19.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.19.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.19.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.19.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.19.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.19.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.20.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.20.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.20.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.20.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.20.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.20.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.20.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.20.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.20.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.21.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.21.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.21.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.21.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.21.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.21.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.21.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.21.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.21.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.22.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.22.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.22.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.22.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.22.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.22.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.22.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.22.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.22.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.23.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.23.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.23.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.23.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.23.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.23.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.23.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.23.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.23.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.24.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.24.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.24.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.24.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.24.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.24.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.24.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.24.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.24.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.25.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.25.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.25.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.25.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.25.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.25.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.25.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.25.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.25.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.26.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.26.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.26.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.26.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.26.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.26.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.26.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.26.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.26.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.27.self_attn.q_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.27.self_attn.k_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.27.self_attn.v_proj.weight | (3072, 1024) | (3072, 1024) | True | 0 | 0 | 0 | |
|
| model.layers.27.self_attn.o_proj.weight | (3072, 3072) | (3072, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.27.mlp.gate_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.27.mlp.up_proj.weight | (3072, 8192) | (3072, 8192) | True | 0 | 0 | 0 | |
|
| model.layers.27.mlp.down_proj.weight | (8192, 3072) | (8192, 3072) | True | 0 | 0 | 0 | |
|
| model.layers.27.input_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.layers.27.post_attention_layernorm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| model.norm.weight | (3072,) | (3072,) | True | 0 | 0 | 0 | |
|
| lm_head.weight | (3072, 128256) | (3072, 128256) | True | 0 | 0 | 0 | |
|
|
|
**Note:** |
|
|
|
* `Allclose` indicates whether the weights are approximately equal within the specified relative (`rtol=1e-5`) and absolute (`atol=1e-3`) tolerances using `jnp.allclose()`. |
|
* `Max Diff`, `Mean Diff`, and `Std Diff` quantify the element-wise differences between the corresponding weights; small nonzero values for some layers would not be unexpected given numerical precision differences between frameworks, but all entries in the table above are exactly zero. A sketch of the per-tensor check is shown below.
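For reference, each row of the table can be produced with a helper along these lines (`compare_weights` is a hypothetical function for illustration; pairing each PyTorch tensor with its Flax counterpart, including transposing linear weights into the Flax orientation shown above, is assumed to happen before the call):

```python
import numpy as np
import jax.numpy as jnp

def compare_weights(pt_weight: np.ndarray, flax_weight: jnp.ndarray,
                    rtol: float = 1e-5, atol: float = 1e-3) -> dict:
    """Summarize how closely a converted Flax tensor matches its PyTorch source."""
    diff = np.abs(np.asarray(flax_weight) - np.asarray(pt_weight))
    return {
        "pytorch_shape": tuple(pt_weight.shape),
        "flax_shape": tuple(flax_weight.shape),
        "allclose": bool(jnp.allclose(flax_weight, pt_weight, rtol=rtol, atol=atol)),
        "max_diff": float(diff.max()),
        "mean_diff": float(diff.mean()),
        "std_diff": float(diff.std()),
    }
```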
|
|
|
## Hardware Used for Conversion |
|
|
|
The conversion process was performed on the following hardware configuration: |
|
|
|
* **CPU:** |
|
* **RAM:** 251.67 GB |
|
* **OS:** Linux-5.15.0-107-generic-x86_64-with-glibc2.36 |
|
* **JAX version:** 0.3.22 |
|
* **Flax version:** 0.6.2 |
|
* **Transformers version:** 4.47.0 |
|
* **GPU:** NVIDIA A100-SXM4-40GB |
|
|
|
This conversion took approximately 81.05 seconds to complete. |
|
|
|
## Usage |
|
|
|
Here's how you can use the converted model in JAX/Flax for text generation: |
|
|
|
```python
from transformers import FlaxAutoModelForCausalLM, AutoTokenizer

model_name = "Erland/Llama-3.2-3B-JAX"  # Replace with your repository name
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The weights are already stored in Flax format, so no from_pt conversion is needed.
model = FlaxAutoModelForCausalLM.from_pretrained(model_name)

# Example prompt
prompt = "The quick brown fox"

# Tokenize the prompt (Flax models expect NumPy/JAX arrays rather than PyTorch tensors)
tokenized_prompt = tokenizer(prompt, return_tensors="np")

# Generate text (greedy decoding by default)
output = model.generate(tokenized_prompt.input_ids, max_length=50)

# Decode the generated text; Flax generate returns an output object with a `sequences` field
generated_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
print(generated_text)
```
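
Flax generation defaults to greedy decoding. Sampling requires an explicit JAX PRNG key, since JAX has no global random state. A sketch reusing `model`, `tokenizer`, and `tokenized_prompt` from the snippet above (the sampling parameters are illustrative):

```python
import jax

# Sampled generation; the PRNG key makes the randomness explicit and reproducible.
sampled = model.generate(
    tokenized_prompt.input_ids,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    max_length=50,
    prng_key=jax.random.PRNGKey(0),
)
print(tokenizer.decode(sampled.sequences[0], skip_special_tokens=True))
```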
|
## Limitations |
|
|
|
**Sequence Length:** As noted above, `max_position_embeddings` has been reduced to 16384. Be mindful of this limit when working with long sequences; a simple guard is shown below.
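If your inputs can exceed that limit, truncating at tokenization time is the simplest guard. A sketch using the `tokenizer` from the Usage section above, with a synthetic long input:

```python
# A synthetic input long enough to exceed the 16384-token context window.
very_long_text = " ".join(["word"] * 100_000)

inputs = tokenizer(very_long_text, truncation=True, max_length=16384, return_tensors="np")
print(inputs.input_ids.shape)  # the sequence dimension is capped at 16384
```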
|
|
|
**Numerical Precision:** Minor differences in outputs compared to the original PyTorch model might be observed due to numerical precision variations between PyTorch and JAX/Flax, particularly on different hardware.
|
|
|
## Acknowledgements |
|
|
|
We thank Meta, the original authors of meta-llama/Llama-3.2-3B, for developing and releasing this model.
|
|
|
We acknowledge the Hugging Face Transformers library for providing the essential tools and infrastructure that made this conversion possible. |
|
|
|
Thanks to the JAX and Flax teams for developing such performant and flexible frameworks for numerical computation and deep learning. |
|
|
|
## License |
|
|
|
This JAX/Flax conversion is released under the same license as the original meta-llama/Llama-3.2-3B model (the Llama 3.2 Community License).