shrimai19 committed on
Commit 61b43b0 · verified · 1 Parent(s): be6a907

Delete Mistral-NeMo-12B-Instruct-HF/convert_mistral_weights_to_hf.py

Mistral-NeMo-12B-Instruct-HF/convert_mistral_weights_to_hf.py DELETED
@@ -1,260 +0,0 @@
- # Copyright 2023 Mistral AI and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import argparse
- import gc
- import json
- import os
- import shutil
- import warnings
-
- import torch
- from safetensors.torch import load_file as safe_load_file
-
- from transformers import (
-     LlamaTokenizer,
-     MistralConfig,
-     MistralForCausalLM,
- )
-
-
- try:
-     from transformers import LlamaTokenizerFast
-
-     tokenizer_class = LlamaTokenizerFast
- except ImportError as e:
-     warnings.warn(e)
-     warnings.warn(
-         "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
-     )
-     tokenizer_class = LlamaTokenizer
-
- """
43
- Sample usage:
44
-
45
- ```
46
- python src/transformers/models/mistral/convert_mistral_weights_to_hf.py \
47
- --input_dir /path/to/downloaded/mistral/weights --model_size 7B --output_dir /output/path
48
- ```
49
-
50
- Thereafter, models can be loaded via:
51
-
52
- ```py
53
- from transformers import MistralForCausalLM, LlamaTokenizer
54
-
55
- model = MistralForCausalLM.from_pretrained("/output/path")
56
- tokenizer = LlamaTokenizer.from_pretrained("/output/path")
57
- ```
58
-
59
- Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
60
- come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
61
- """
62
-
- NUM_SHARDS = {"7B": 1}
-
-
- def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
-     return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
-
-
- def read_json(path):
-     with open(path, "r") as f:
-         return json.load(f)
-
-
- def write_json(text, path):
-     with open(path, "w") as f:
-         json.dump(text, f)
-
-
- def write_model(model_path, input_base_path, safe_serialization=True, is_v3=False):
-     # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
-     os.makedirs(model_path, exist_ok=True)
-     tmp_model_path = os.path.join(model_path, "tmp")
-     os.makedirs(tmp_model_path, exist_ok=True)
-
-     params = read_json(os.path.join(input_base_path, "params.json"))
-     num_shards = 1
-
-     n_layers = params["n_layers"]
-     n_heads = params["n_heads"]
-     n_heads_per_shard = n_heads // num_shards
-     dim = params["dim"]
-     dims_per_head = params["head_dim"]
-     base = params.get("rope_theta", 10000.0)
-     inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
-     max_position_embeddings = 128000 * 8
-
-     vocab_size = params["vocab_size"]
-
-     if "n_kv_heads" in params:
-         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
-         num_local_key_value_heads = num_key_value_heads // num_shards
-         key_value_dim = dims_per_head * num_local_key_value_heads
-
-     # permute for sliced rotary
-     def permute(w, n_heads=n_heads, dim1=dims_per_head * n_heads, dim2=dim):
-         return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
-
-     print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
-
-     # Load weights - for v3 models the consolidated weights are in a single file format in safetensors
-     if is_v3:
-         loaded = [safe_load_file(os.path.join(input_base_path, "consolidated.safetensors"))]
-     else:
-         loaded = [
-             torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
-             for i in range(num_shards)
-         ]
-     param_count = 0
-     index_dict = {"weight_map": {}}
-     for layer_i in range(n_layers):
-         filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
-
-         # Sharded
-         # Note that attention.w{q,k,v,o}, feed_forward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
-         # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
-         # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
-
-         state_dict = {
-             f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
-                 f"layers.{layer_i}.attention_norm.weight"
-             ].clone(),
-             f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
-                 f"layers.{layer_i}.ffn_norm.weight"
-             ].clone(),
-         }
-         state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
-             torch.cat(
-                 [
-                     loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
-                     for i in range(num_shards)
-                 ],
-                 dim=0,
-             ).reshape(n_heads_per_shard * dims_per_head, dim)
-         )
-         state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
-             torch.cat(
-                 [
-                     loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
-                         num_local_key_value_heads, dims_per_head, dim
-                     )
-                     for i in range(num_shards)
-                 ],
-                 dim=0,
-             ).reshape(key_value_dim, dim),
-             num_key_value_heads,
-             key_value_dim,
-             dim,
-         )
-         state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
-             [
-                 loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(num_local_key_value_heads, dims_per_head, dim)
-                 for i in range(num_shards)
-             ],
-             dim=0,
-         ).reshape(key_value_dim, dim)
-
-         state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
-             [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
-         )
-         state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
-             [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
-         )
-         state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
-             [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
-         )
-         state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
-             [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
-         )
-
-         state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
-         for k, v in state_dict.items():
-             index_dict["weight_map"][k] = filename
-             param_count += v.numel()
-         torch.save(state_dict, os.path.join(tmp_model_path, filename))
-
-     filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
-     state_dict = {
-         "model.norm.weight": loaded[0]["norm.weight"],
-         "model.embed_tokens.weight": torch.cat([loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1),
-         "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
-     }
-
-     for k, v in state_dict.items():
-         index_dict["weight_map"][k] = filename
-         param_count += v.numel()
-     torch.save(state_dict, os.path.join(tmp_model_path, filename))
-
-     # Write configs
-     index_dict["metadata"] = {"total_size": param_count * 2}
-     write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
-     config = MistralConfig(
-         hidden_size=dim,
-         intermediate_size=params["hidden_dim"],
-         num_attention_heads=params["n_heads"],
-         num_hidden_layers=params["n_layers"],
-         rms_norm_eps=params["norm_eps"],
-         kv_channels=params["head_dim"],
-         num_key_value_heads=num_key_value_heads,
-         vocab_size=vocab_size,
-         rope_theta=base,
-         max_position_embeddings=max_position_embeddings,
-         sliding_window=None,
-     )
-     config.save_pretrained(tmp_model_path)
-
-     # Make space so we can load the model properly now.
-     del state_dict
-     del loaded
-     gc.collect()
-
-     print("Loading the checkpoint in a Mistral model.")
-     model = MistralForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16)
-     # Avoid saving this as part of the config.
-     del model.config._name_or_path
-     model.config.torch_dtype = torch.float16
-     print("Saving in the Transformers format.")
-
-     model.save_pretrained(model_path, safe_serialization=safe_serialization)
-     shutil.rmtree(tmp_model_path)
-
-
- def write_tokenizer(tokenizer_path, input_tokenizer_path):
-     # Initialize the tokenizer based on the `spm` model
-     print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
-     tokenizer = tokenizer_class(input_tokenizer_path)
-     tokenizer.save_pretrained(tokenizer_path)
-
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "--input_dir",
-         help="Location of Mistral weights, which contains tokenizer.model and model folders",
-     )
-     parser.add_argument(
-         "--output_dir",
-         help="Location to write HF model and tokenizer",
-     )
-     args = parser.parse_args()
-     write_model(
-         model_path=args.output_dir,
-         input_base_path=args.input_dir,
-         safe_serialization=True,
-         is_v3=True,
-     )
-
-
- if __name__ == "__main__":
-     main()