Commit 5581c1c (parent 8cbe574) by JosefAlbers

Create README.md

Files changed: README.md (+594 lines)

"""
# RetNPhi: Byte-Level Hybrid of Phi-3.5 and RetNet

RetNPhi is an experimental architecture that transforms Phi-3.5 into a byte-level language model by incorporating RetNet-inspired mechanisms. Operating directly on raw byte sequences lets a single model handle any file type.

## Key Features:

1. **Byte-Level Processing**: Operates directly on raw byte sequences, enabling universal application to any file type.
2. **RetNet Integration**: Incorporates RetNet's multi-scale exponential decay and group normalization for efficient long-range dependency modeling.
3. **Dual-Mode Processing**: Supports a parallel mode for efficient training and a recurrent mode for inference (sketched below this list).
4. **Selective Fine-tuning**: Trains only specific layers (e.g., token embedding, post-attention layer normalizations) while keeping most of the original Phi-3.5 weights frozen.
5. **Weight-Decomposed Low-Rank Adaptation (DoRA)**: Applies DoRA to self-attention output projections for efficient adaptation while preserving pretrained knowledge.

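A sketch of the dual-mode equivalence borrowed from RetNet (informal, per-head notation; the exact decays $\gamma$ are set in `Phi3Retention` below). The parallel form used for training,

$$O = (Q K^\top \odot D)\,V,\qquad D_{nm} = \gamma^{\,n-m}\ \text{if } n \ge m,\ \text{else } 0$$

yields the same outputs as the recurrent form used for inference, computed one token at a time:

$$S_n = \gamma\,S_{n-1} + k_n^\top v_n,\qquad o_n = q_n S_n$$
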
## Implementation Strategy:

- **Weight Reuse**: Utilizes frozen weights from the original Phi-3.5 model for most layers.
- **Flexible DoRA Application**: Allows configuration of which layers and targets DoRA is applied to.
- **Configurable Architecture**: Supports both the retention-based and the original attention mechanism.
- **Untied Embeddings Option**: Provides the ability to use separate input and output embeddings.

## Training and Inference:

- Implements efficient training loops with customizable learning rate schedules.
- Supports both training from scratch and fine-tuning from a checkpoint.
- Provides a generation function for text completion tasks.

## Goals:

- Explore the potential of retention-like mechanisms in a byte-level Phi architecture.
- Leverage dual-mode processing for efficient training and inference.
- Develop a universal model capable of processing any file type.

Note: This is a highly experimental implementation, designed for research and exploration rather than production use. It demonstrates the potential of combining pretrained models with novel architectures and efficient fine-tuning techniques.

Author: Josef Albers
Date: Aug 28, 2024
"""

import glob
import json
import math
import time
from datetime import datetime
from types import SimpleNamespace

import fire
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from datasets import load_dataset
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_unflatten

class Tokenizer:
    def __init__(self, file_path=None):
        if file_path is None:
            self.vocab = list(range(256))
        else:
            with open(file_path, 'r') as f:
                content = f.read().lower().encode('utf-8')
            self.vocab = sorted(set(content))
        self.vocab_size = len(self.vocab)
        self.byte_to_index = {byte: index for index, byte in enumerate(self.vocab)}
        self.index_to_byte = {index: byte for index, byte in enumerate(self.vocab)}

    def encode(self, text):
        byte_seq = text.encode('utf-8')
        return [self.byte_to_index[byte] for byte in byte_seq]

    def decode(self, indices):
        byte_seq = bytes(self.index_to_byte[index] for index in indices)
        return byte_seq.decode('utf-8', errors='ignore')

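# A quick sanity check of the byte-level round trip (default 256-byte vocab):
#   >>> tok = Tokenizer()
#   >>> tok.decode(tok.encode("hello"))
#   'hello'
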
class SuRoPE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dim = config.hidden_size // config.num_attention_heads
        self.original_max_position_embeddings = config.original_max_position_embeddings
        self.rope_theta = config.rope_theta
        self.scaling_factor = math.sqrt(1 + math.log(config.max_position_embeddings / config.original_max_position_embeddings) / math.log(config.original_max_position_embeddings))
        self._long_factor = mx.array(config.rope_scaling["long_factor"], dtype=mx.float32)
        self._short_factor = mx.array(config.rope_scaling["short_factor"], dtype=mx.float32)

    def __call__(self, q, k, position_ids):
        cos, sin = self._get_cos_sin(position_ids)
        q = (q * cos) + (self._rotate_half(q) * sin)
        k = (k * cos) + (self._rotate_half(k) * sin)
        return q, k

    def _get_cos_sin(self, position_ids):
        su_factor = self._short_factor  # only the short-context factor is used; sequences here stay well under the original context length
        position_ids_expanded = position_ids[:, None, :]
        inv_freq = 1.0 / (su_factor * self.rope_theta**(mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim))
        inv_freq_expanded = mx.repeat(inv_freq[None, :, None], position_ids.shape[0], axis=0)
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        cos = mx.expand_dims(mx.cos(emb) * self.scaling_factor, axis=1)
        sin = mx.expand_dims(mx.sin(emb) * self.scaling_factor, axis=1)
        return cos, sin

    def _rotate_half(self, x):
        midpoint = x.shape[-1] // 2
        x1, x2 = x[..., :midpoint], x[..., midpoint:]
        return mx.concatenate([-x2, x1], axis=-1)

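# RoPE refresher: for each feature pair (x1, x2), q' = q*cos + rotate_half(q)*sin
# applies the 2x2 rotation [[cos, -sin], [sin, cos]] at a position-dependent
# angle, so dot products between rotated q and k depend only on relative position.
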
class Phi3Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size
        self.n_heads = n_heads = config.num_attention_heads
        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
        self.num_hidden_layers = config.num_hidden_layers
        self.head_dim = head_dim = config.hidden_size // n_heads
        self.scale = head_dim**-0.5
        chop_1 = self.n_heads * self.head_dim
        chop_2 = chop_1 + self.n_kv_heads * self.head_dim
        self.chop = [chop_1, chop_2]
        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        self.rope = SuRoPE(config)

    def __call__(self, x, position_ids, attention_mask, cache, use_recurrent_mode):
        B, L, _ = x.shape
        qkv = self.qkv_proj(x)
        q, k, v = mx.split(qkv, self.chop, axis=-1)
        q = q.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        k = k.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        v = v.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        if cache is None:
            position_ids = mx.arange(q.shape[2], dtype=mx.float32)[None] if position_ids is None else position_ids
            q, k = self.rope(q, k, position_ids)
            mask = mx.triu(mx.full((v.shape[2], v.shape[2]), -mx.inf), k=1)
            if attention_mask is not None:
                mask += mx.where(attention_mask[:, :, None] * attention_mask[:, None, :] == 1, 0, -mx.inf)
                mask = mx.expand_dims(mask, 1)
            else:
                mask = mask[None, None]
        else:
            past_k, past_v, past_p, past_m = cache
            position_ids = past_p[:, -1:] + 1
            mask = mx.pad(past_m[:, :, -1:, :], ((0, 0), (0, 0), (0, 0), (0, 1)))
            q, k = self.rope(q, k, position_ids)
            k = mx.concatenate([past_k, k], axis=2)
            v = mx.concatenate([past_v, v], axis=2)
        cache = (k, v, position_ids, mask)
        w = (q * self.scale) @ k.transpose(0, 1, 3, 2)
        w += mask
        w = mx.softmax(w, axis=-1)
        o = w @ v
        o = o.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(o).astype(x.dtype), cache

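# Cache layout: the tuple (k, v, position_ids, mask) carries keys/values across
# decoding steps; each step pads the last mask row by one column so the new
# token attends to the whole prefix. This KV cache grows with generated length,
# unlike the fixed-size retention state below.
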
class Phi3Retention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dim = dim = config.hidden_size
        self.n_heads = n_heads = config.num_attention_heads
        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
        self.num_hidden_layers = config.num_hidden_layers
        self.head_dim = head_dim = config.hidden_size // n_heads
        self.scale = head_dim**-0.5
        chop_1 = self.n_heads * self.head_dim
        chop_2 = chop_1 + self.n_kv_heads * self.head_dim
        self.chop = [chop_1, chop_2]
        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        self.rope = SuRoPE(config)
        # Multi-scale decays, one per head, spanning 1 - 1/32 (fast) to 1 - 1/512 (slow)
        xmin, xmax = math.log(1 / 32), math.log(1 / 512)
        x = mx.linspace(xmin, xmax, num=n_heads)
        self._gamma = 1 - x.exp()
        self.gn = nn.GroupNorm(num_groups=head_dim, dims=-1, affine=False)

    def __call__(self, x, position_ids, attention_mask, cache, use_recurrent_mode):
        if use_recurrent_mode:
            return self.recurrent_mode(x, cache)
        B, L, _ = x.shape
        qkv = self.qkv_proj(x)
        q, k, v = mx.split(qkv, self.chop, axis=-1)
        q = q.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        k = k.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        v = v.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        position_ids = mx.arange(q.shape[2], dtype=mx.float32)[None] if position_ids is None else position_ids
        q, k = self.rope(q, k, position_ids)
        cache = None
        w = (q * self.scale) @ k.transpose(0, 1, 3, 2)
        w = w * self._decay(L)
        o = w @ v
        o = o.transpose(0, 2, 1, 3).reshape(B * L, -1)
        o = self.gn(o).reshape(B, L, -1)
        return self.o_proj(o).astype(x.dtype), cache

    def recurrent_mode(self, x, cache):
        if cache is None:
            s = mx.zeros((1, 32, 96, 96))  # per-head state, hard-coded for Phi-3.5-mini (32 heads, head_dim 96)
            n = 0
        else:
            s, n = cache
        qkv = self.qkv_proj(x)
        q, k, v = mx.split(qkv, self.chop, axis=-1)
        q = q.reshape(1, 1, self.n_heads, -1).transpose(0, 2, 1, 3)
        k = k.reshape(1, 1, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        v = v.reshape(1, 1, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        position_ids = mx.array([[n]])
        q, k = self.rope(q, k, position_ids)
        k = k * self.scale
        s = self._gamma[None, :, None, None] * s + (k.transpose(0, 1, 3, 2) @ v)
        o = q @ s
        o = o.transpose(0, 2, 1, 3).reshape(1, -1)
        o = self.gn(o).reshape(1, 1, -1)
        o = self.o_proj(o).astype(x.dtype)
        return o, (s, n + 1)

    def _decay(self, sequence_length):
        n = mx.arange(sequence_length)[:, None]
        m = mx.arange(sequence_length)[None]
        D = (self._gamma[:, None, None] ** (n - m)) * (n >= m)
        return D

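# Illustrative decay mask (toy gamma = 0.5, L = 3; the real per-head gammas set
# in __init__ decay far more slowly):
#     [[1.  , 0.  , 0.  ],
#      [0.5 , 1.  , 0.  ],
#      [0.25, 0.5 , 1.  ]]
# Multiplying Q @ K.T elementwise by this lower-triangular mask reproduces in
# parallel what recurrent_mode accumulates step by step via
# s_n = gamma * s_{n-1} + k_n.T @ v_n and o_n = q_n @ s_n.
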
class Phi3MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def __call__(self, x):
        x = self.gate_up_proj(x)
        gate, x = mx.split(x, 2, axis=-1)
        return self.down_proj(nn.silu(gate) * x)

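# gate_up_proj packs [gate | up] into a single matmul; silu(gate) * up is the
# SwiGLU-style gated feed-forward used by Phi-3.
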
class Phi3DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.use_retention:
            self.self_attn = Phi3Retention(config)
        else:
            self.self_attn = Phi3Attention(config)
        self.mlp = Phi3MLP(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(self, x, position_ids, attention_mask, cache, use_recurrent_mode):
        r, cache = self.self_attn(self.input_layernorm(x), position_ids, attention_mask, cache, use_recurrent_mode)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        return h + r, cache

class Phi3Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_new = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [Phi3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(self, input_ids, pixel_values, image_sizes, position_ids, attention_mask, cache, use_recurrent_mode):
        x = self.embed_new(input_ids)
        cache = [None] * len(self.layers) if cache is None else cache
        for i, l in enumerate(self.layers):
            x, cache[i] = l(x, position_ids, attention_mask, cache[i], use_recurrent_mode)
        return self.norm(x), cache

class Phi3ForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = Phi3Model(config)
        if config.untie_embedding:
            self.lm_new = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.untie = True
        else:
            self.untie = False

    def __call__(self, input_ids, pixel_values=None, image_sizes=None, position_ids=None, attention_mask=None, cache=None, use_recurrent_mode=False):
        x, cache = self.model(input_ids, pixel_values, image_sizes, position_ids, attention_mask, cache, use_recurrent_mode)
        if self.untie:
            return self.lm_new(x), cache
        return self.model.embed_new.as_linear(x), cache

    @property
    def layers(self):
        return self.model.layers

class DoRALinear(nn.Module):
    @staticmethod
    def from_linear(linear, r, alpha, scale, dropout):
        output_dims, input_dims = linear.weight.shape
        if isinstance(linear, nn.QuantizedLinear):
            input_dims *= 32 // linear.bits
        lora_lin = DoRALinear(input_dims=input_dims, output_dims=output_dims, r=r, alpha=alpha, scale=scale, dropout=dropout)
        lora_lin.linear = linear
        return lora_lin

    def __init__(self, input_dims, output_dims, r, alpha, scale, dropout, bias=False):
        super().__init__()
        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
        self.dropout = nn.Dropout(p=dropout)
        self.scale = scale * (alpha / r)
        scale = 1 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(low=-scale, high=scale, shape=(input_dims, r))
        self.lora_b = mx.zeros(shape=(r, output_dims))
        self.m = mx.linalg.norm(self._dequantized_weight(), axis=1).astype(mx.float32)

    def _dequantized_weight(self):
        weight = self.linear.weight
        if isinstance(self.linear, nn.QuantizedLinear):
            weight = mx.dequantize(weight, self.linear.scales, self.linear.biases, self.linear.group_size, self.linear.bits)
        return weight

    def __call__(self, x):
        y = self.linear(x)
        z = (self.dropout(x) @ self.lora_a) @ self.lora_b
        z = y + (self.scale * z)
        adapted = self._dequantized_weight() + (self.scale * self.lora_b.T) @ self.lora_a.T
        denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
        z = (self.m / denom) * z
        return z.astype(x.dtype)

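# DoRA in brief: out = m * ((W0 + scale * (lora_a @ lora_b).T) @ x) scaled by
# 1 / ||W0 + scale * (lora_a @ lora_b).T|| per output row, i.e. the frozen
# weight plus a low-rank update, renormalized row-wise and rescaled by the
# magnitude vector m; the norm is detached (stop_gradient) so gradients flow
# through the direction term only.
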
def linear_to_lora_layers(model, lora_layers, lora_targets, lora_rank, lora_scale, lora_dropout):
    if lora_layers == 'all':
        lora_layers = model.layers
    elif isinstance(lora_layers, int):
        lora_layers = model.layers[-lora_layers:]
    elif isinstance(lora_layers, list):
        lora_layers = [model.layers[i] for i in lora_layers]
    else:
        raise ValueError("Invalid type for lora_layers. Expected 'all', an int (number of final layers), or a list of layer indices.")
    def to_lora(layer):
        return DoRALinear.from_linear(layer, r=lora_rank, alpha=lora_rank, scale=lora_scale, dropout=lora_dropout)
    for l in lora_layers:
        lora_modules = [(k, to_lora(m)) for k, m in l.named_modules() if k in lora_targets]
        l.update_modules(tree_unflatten(lora_modules))

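# Usage sketch (mirrors the defaults in main() below): wrap every layer's
# self-attention output projection in a rank-32 DoRA adapter:
#   linear_to_lora_layers(model, lora_layers='all', lora_targets=["self_attn.o_proj"],
#                         lora_rank=32, lora_scale=0.1, lora_dropout=0.0)
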
def load_base_model(model_cfg, init=False):
    model_id = 'microsoft/Phi-3.5-mini-instruct'
    model_path = snapshot_download(model_id, allow_patterns=["*.safetensors", "config.json"])
    with open(f"{model_path}/config.json", "r") as f:
        config = json.load(f)
    config = config | model_cfg
    model_config = SimpleNamespace(**config)
    model = Phi3ForCausalLM(model_config)
    model_weight = [(k, v) for wf in glob.glob(f"{model_path}/*.safetensors") for k, v in mx.load(wf).items()]
    model.load_weights(model_weight, strict=False)
    model.set_dtype(mx.float32)
    if init:
        init_fn_embed = nn.init.normal(mean=-0.000453949, std=0.0344238)
        model.apply_to_modules(lambda k, v: v.apply(init_fn_embed) if k.endswith('embed_new') else None)
        if model_config.untie_embedding:
            init_fn_lm = nn.init.normal(mean=-0.000231743, std=0.043457)
            model.apply_to_modules(lambda k, v: v.apply(init_fn_lm) if k.endswith('lm_new') else None)
    class_predicate = lambda k, m: hasattr(m, "to_quantized") and not k.endswith('new')
    nn.quantize(model, 64, 4, class_predicate)
    mx.eval(model.parameters())
    return model

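# Quantization note: the class predicate above 4-bit-quantizes (group size 64)
# every layer that supports it except the newly added `embed_new`/`lm_new`
# modules, which stay in float32 so they can be trained.
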
def load_model_for_training(lora_cfg, model_cfg, thaws, from_path=None):
    model = load_base_model(model_cfg, init=False)
    if from_path:
        model.load_weights(from_path, strict=False)
    model.freeze()
    if len(lora_cfg['targets']) > 1:  # note: with a single target (the default), DoRA layers are not instantiated
        linear_to_lora_layers(model, lora_layers=lora_cfg['layers'], lora_targets=lora_cfg['targets'], lora_rank=lora_cfg['rank'], lora_scale=lora_cfg['scale'], lora_dropout=lora_cfg['dropout'])
    model.apply_to_modules(lambda k, v: v.unfreeze() if any(k.endswith(t) for t in thaws) else None)
    mx.eval(model.parameters())
    # print("Trainable parameters:", [i[0] for i in tree_flatten(model.trainable_parameters())])
    model.train()
    return model

def load_model_for_inference(lora_cfg, model_cfg):
    model = load_base_model(model_cfg, init=False)
    if len(lora_cfg['targets']) > 1:
        linear_to_lora_layers(model, lora_layers=lora_cfg['layers'], lora_targets=lora_cfg['targets'], lora_rank=lora_cfg['rank'], lora_scale=lora_cfg['scale'], lora_dropout=lora_cfg['dropout'])
    _path = 'trained_retnphi.safetensors' if model_cfg['use_retention'] else 'trained_orgnphi.safetensors'
    model.load_weights(_path, strict=False)
    mx.eval(model.parameters())
    model.eval()
    return model

def generate(prompt, lora_cfg, model_cfg, max_tokens=50, verbose=True):
    model = load_model_for_inference(lora_cfg=lora_cfg, model_cfg=model_cfg)
    input_ids = mx.array(tokenizer.encode(prompt))
    if model_cfg['use_retention']:
        cache = None
        for i in input_ids:
            logits, cache = model(i[None, None], cache=cache, use_recurrent_mode=True)
    else:
        logits, cache = model(input_ids[None])
    token = mx.argmax(logits[:, -1, :], axis=-1)
    mx.eval(token, cache)
    list_tokens = token.tolist()
    for i in range(max_tokens):
        logits, cache = model(token[None], cache=cache, use_recurrent_mode=True)
        token = mx.argmax(logits[:, -1, :], axis=-1)
        mx.eval(token, cache)
        list_tokens += token.tolist()
        if tokenizer.decode(list_tokens[-2:]) == '\n\n':
            break
    output = tokenizer.decode(list_tokens)
    if verbose:
        print(f'{prompt=} + {output=}\n-> {prompt+output}')
    del model
    return output

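# Decoding note: in retention mode the prompt is consumed one byte at a time and
# generation carries only the fixed-size state, so per-token cost stays constant;
# the attention path prefills the prompt in parallel, then grows its KV cache.
# Sampling is plain greedy argmax, stopping once a blank line ('\n\n') appears.
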
def train_gsm(learning_rates, num_epochs, batch_size, seq_length, lora_cfg, model_cfg, thaws, take, from_path=None):
    def load_gsm_data(tokenizer, is_tiny=True):
        if is_tiny:
            data = load_dataset("TinyGSM/TinyGSM")["train"]
            if take:
                data = data.take(take)
            data = data.filter(lambda x: len(x['question']) < 100 and ':' not in x['question'] and '-' not in x['question'] and "'" not in x['code'] and '\n    result =' in x['code'])
            split_point = int(len(data) * 0.8)
            train_data = data.select(range(split_point))
            eval_data = data.select(range(split_point, len(data)))
            def format_example(example):
                code_raw = example['code']
                start = code_raw.rfind('\n    """')
                if start == -1:
                    print('Wrong format to start')
                    return code_raw.strip()
                start = start + 8  # skip past the '\n    """' docstring close
                end = code_raw.rfind('\n    result =')
                if end == -1:
                    print('Wrong format to end')
                    end = len(code_raw)
                code_block = code_raw[start:end]
                code_lines = code_block.split('\n    ')
                formatted_code = '\n'.join(line.rstrip() for line in code_lines if line.strip())
                formatted_code = '\n' + formatted_code.strip() + '\n\n'
                result = (example['question'].strip(), formatted_code)
                return result
        else:
            dataset = load_dataset("openai/gsm8k", "main")
            train_data = dataset["train"]
            eval_data = dataset["test"]
            def format_example(example):
                return (example['question'].strip(), '\n' + example['answer'].strip() + '\n\n')
        train_formatted = [format_example(ex) for ex in train_data]
        eval_formatted = [format_example(ex) for ex in eval_data]
        return train_formatted, eval_formatted

    def create_batches(data, tokenizer, batch_size, seq_length):
        def _encode(x):
            return [tokenizer.encode(i) for i in x]
        encoded_data = [_encode(x) for x in data]
        encoded_data = [x for x in encoded_data if len(x[0] + x[1]) <= seq_length + 1]
        if batch_size is None:
            batch_size = min(len(encoded_data), 64)
        else:
            encoded_data = encoded_data[:(len(encoded_data) // batch_size) * batch_size]
        np.random.shuffle(encoded_data)
        for i in range(0, len(encoded_data), batch_size):
            batch = encoded_data[i:i + batch_size]
            max_len = min(max(len(q + a) - 1 for q, a in batch), seq_length)
            x_batch = []
            y_batch = []
            mask_batch = []
            for q, a in batch:
                combined = (q + a)[:max_len + 1]
                x = combined[:-1]
                y = combined[1:]
                pad_length = max_len - len(x)
                x = x + [0] * pad_length
                y = y + [0] * pad_length
                mask = [False] * (len(q) - 1) + [True] * len(a) + [False] * pad_length
                x_batch.append(x)
                y_batch.append(y)
                mask_batch.append(mask)
            yield mx.array(x_batch), mx.array(y_batch), mx.array(mask_batch)

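    # Batch layout sketch: for a pair (q, a) the model sees x = (q+a)[:-1] and
    # predicts y = (q+a)[1:]; the loss mask is False over question tokens and
    # padding and True over answer tokens, so only answers contribute to the loss.
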
    def loss_fn(model, X, y, mask):
        logits, _ = model(X)
        logits = logits.astype(mx.float32)
        ce = nn.losses.cross_entropy(logits, y, reduction='none')
        masked_loss = ce * mask
        return masked_loss.sum(), mask.sum()

    def evaluate(model, data, tokenizer, seq_length):
        model.eval()
        total_loss = 0
        total_samples = 0
        for X, y, mask in create_batches(data, tokenizer, None, seq_length):
            loss, ntoks = loss_fn(model, X, y, mask)
            total_loss += loss.item()
            total_samples += ntoks.item()
        return total_loss / total_samples if total_samples > 0 else -1

    def get_optimizer(train_data):
        num_batches_per_epoch = len(list(create_batches(train_data, tokenizer, batch_size, seq_length)))
        print(f'{num_batches_per_epoch=}')
        num_steps = num_epochs * num_batches_per_epoch
        num_warmup = num_steps // 10
        max_lr, min_lr = learning_rates
        if num_warmup > 2:
            warmup = optim.linear_schedule(min_lr * 0.1, max_lr, steps=num_warmup)
            cosine = optim.cosine_decay(max_lr, num_steps - num_warmup, min_lr)
            lr_schedule = optim.join_schedules([warmup, cosine], [num_warmup])
        else:
            lr_schedule = optim.cosine_decay(max_lr, num_steps, min_lr)
        return optim.Lion(learning_rate=lr_schedule), num_steps

    for arg_name in sorted(locals()):
        if arg_name != 'self':
            arg_value = locals()[arg_name]
            if not callable(arg_value):
                print(f"{arg_name}: {arg_value}")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    print(f'--- {timestamp} ---')
    train_data, eval_data = load_gsm_data(tokenizer=tokenizer)
    model = load_model_for_training(lora_cfg=lora_cfg, model_cfg=model_cfg, thaws=thaws, from_path=from_path)
    optimizer, num_steps = get_optimizer(train_data)
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    mx.eval(model, optimizer)
    metrics = {
        'steps': [],
        'learning_rates': [],
        'all_train_losses': [],
        'avg_train_losses': [],
        'val_losses': [],
        'trained_toks': [],
    }
    step = 0
    trained_toks = 0
    losses = []
    tic = time.perf_counter()
    for epoch in range(num_epochs):
        for X, y, loss_mask in create_batches(data=train_data, tokenizer=tokenizer, batch_size=batch_size, seq_length=seq_length):
            model.train()
            (loss, ntoks), grads = loss_and_grad_fn(model, X, y, loss_mask)
            optimizer.update(model, grads)
            mx.eval(loss, ntoks, model, optimizer)
            losses.append(loss.item())
            trained_toks += ntoks.item()
            step += 1
            if step % (num_steps // 30) == 0:
                avg_train_loss = np.mean(losses)
                lr = optimizer.learning_rate.item()
                val_loss = evaluate(model=model, data=eval_data, tokenizer=tokenizer, seq_length=seq_length)
                print(f"{avg_train_loss:8.4f} ({val_loss:6.4f}) @ {step//(num_steps//30):2}/30 w/ {lr:.2e} ({time.perf_counter() - tic:.2f} sec)")
                metrics['val_losses'].append(val_loss)
                # print(f"{avg_train_loss:8.4f} @ {step//(num_steps//30):2}/30 w/ {lr:.2e} ({time.perf_counter() - tic:.2f} sec)")
                tic = time.perf_counter()
                metrics['steps'].append(step)
                metrics['learning_rates'].append(lr)
                metrics['all_train_losses'].extend(losses)
                metrics['avg_train_losses'].append(avg_train_loss)
                metrics['trained_toks'].append(trained_toks)
                losses = []
                trained_toks = 0
    _path = 'trained_retnphi.safetensors' if model_cfg['use_retention'] else 'trained_orgnphi.safetensors'
    mx.save_safetensors(_path, dict(tree_flatten(model.trainable_parameters())))
    log = {
        'args': {
            'learning_rates': learning_rates,
            'num_epochs': num_epochs,
            'batch_size': batch_size,
            'seq_length': seq_length,
            'lora_cfg': lora_cfg,
            'model_cfg': model_cfg,
            'thaws': thaws,
            'from_path': from_path
        },
        'metrics': metrics
    }
    with open(f'train_log_{timestamp}.json', 'w') as f:
        json.dump(log, f, indent=2)
    del model

tokenizer = Tokenizer()

def main(take=1000, layers='all', targets=["self_attn.o_proj"], thaws=['new', 'post_attention_layernorm'], rank=32, scale=0.1, dropout=0.0, lr_max=1e-4, lr_min=1e-5, num_epochs=90, batch_size=1, seq_length=256, vocab_size=256, use_retention=True, untie_embedding=True, prompt='There are 8 candies in a carton. How many candies will be in 5 cartons?'):
    lora_cfg = dict(layers=layers, targets=targets, rank=rank, scale=scale, dropout=dropout)
    model_cfg = dict(vocab_size=vocab_size, use_retention=use_retention, untie_embedding=untie_embedding)
    train_gsm(learning_rates=(lr_max, lr_min), num_epochs=num_epochs, batch_size=batch_size, seq_length=seq_length, lora_cfg=lora_cfg, model_cfg=model_cfg, thaws=thaws, take=take)
    generate(prompt=prompt, lora_cfg=lora_cfg, model_cfg=model_cfg, max_tokens=(seq_length - len(prompt)))

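# With `fire.Fire(main)` enabled below in place of the hard-coded calls, the same
# experiments could be driven from the command line (script filename assumed), e.g.:
#   python retnphi.py --take=1000 --num_epochs=3 --use_retention=False
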
if __name__ == "__main__":
    main(take=None, num_epochs=3)  # -> 240916
    main(take=None, num_epochs=3, untie_embedding=False)

    main(take=None, num_epochs=3, use_retention=False)
    main(take=None, num_epochs=3, untie_embedding=False, use_retention=False)
    # fire.Fire(main)

# Output:
# 388.7268 @ 1/30 w/ 3.36e-05 (64.73 sec)
# ...
# 4.3768 @ 30/30 w/ 1.00e-05 (64.36 sec)
# prompt='There are 8 candies in a carton. How many candies will be in 5 cartons?' + output='\ncandies_in_carton = 8 \nnumber_of_cartons = 5\ntotal_no_of_candies = candies_in_carton * number_of_cartons\n\n'
# -> There are 8 candies in a carton. How many candies will be in 5 cartons?
# candies_in_carton = 8
# number_of_cartons = 5
# total_no_of_candies = candies_in_carton * number_of_cartons