JosefAlbers
commited on
Commit
•
5581c1c
1
Parent(s):
8cbe574
Create README.md
Browse files
README.md
ADDED
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
# RetNPhi: Byte-Level Hybrid of Phi-3.5 and RetNet
|
3 |
+
|
4 |
+
RetNPhi is an experimental architecture that transforms Phi-3.5 into a byte-level language model, incorporating RetNet-inspired mechanisms. This innovative approach enables the model to process raw byte sequences, allowing for universal file type handling.
|
5 |
+
|
6 |
+
## Key Features:
|
7 |
+
|
8 |
+
1. **Byte-Level Processing**: Operates directly on raw byte sequences, enabling universal application to any file type.
|
9 |
+
2. **RetNet Integration**: Incorporates RetNet's multi-scale exponential decay and group normalization for efficient long-range dependency modeling.
|
10 |
+
3. **Dual-mode Processing**: Supports parallel mode for efficient training and recurrent mode for inference.
|
11 |
+
4. **Selective Fine-tuning**: Trains only specific layers (e.g., token embedding, post-attention layer normalizations) while keeping most of the original Phi-3.5 weights frozen.
|
12 |
+
5. **Weight-Decomposed Low-Rank Adaptation (DoRA)**: Applies DoRA to self-attention output projections for efficient adaptation while preserving pretrained knowledge.
|
13 |
+
|
14 |
+
## Implementation Strategy:
|
15 |
+
|
16 |
+
- **Weight Reuse**: Utilizes frozen weights from the original Phi-3.5 model for most layers.
|
17 |
+
- **Flexible DoRA Application**: Allows configuration of which layers and targets to apply DoRA.
|
18 |
+
- **Configurable Architecture**: Supports both retention-based and original attention mechanisms.
|
19 |
+
- **Untied Embeddings Option**: Provides the ability to use separate input and output embeddings.
|
20 |
+
|
21 |
+
## Training and Inference:
|
22 |
+
|
23 |
+
- Implements efficient training loops with customizable learning rate schedules.
|
24 |
+
- Supports both training from scratch and fine-tuning from a checkpoint.
|
25 |
+
- Provides a generation function for text completion tasks.
|
26 |
+
|
27 |
+
## Goals:
|
28 |
+
|
29 |
+
- Explore the potential of retention-like mechanisms in a byte-level Phi architecture.
|
30 |
+
- Leverage dual-mode processing for efficient training and inference.
|
31 |
+
- Develop a universal model capable of processing any file type.
|
32 |
+
|
33 |
+
Note: This is a highly experimental implementation, designed for research and exploration rather than production use. It demonstrates the potential of combining pretrained models with novel architectures and efficient fine-tuning techniques.
|
34 |
+
|
35 |
+
Author: Josef Albers
|
36 |
+
Date: Aug 28, 2024
|
37 |
+
"""
|
38 |
+
|
39 |
+
import glob
|
40 |
+
import json
|
41 |
+
import math
|
42 |
+
import time
|
43 |
+
from datetime import datetime
|
44 |
+
from types import SimpleNamespace
|
45 |
+
|
46 |
+
import fire
|
47 |
+
import mlx.core as mx
|
48 |
+
import mlx.nn as nn
|
49 |
+
import mlx.optimizers as optim
|
50 |
+
import numpy as np
|
51 |
+
from huggingface_hub import snapshot_download
|
52 |
+
from mlx.utils import tree_flatten, tree_unflatten
|
53 |
+
|
54 |
+
from datasets import load_dataset
|
55 |
+
|
56 |
+
class Tokenizer:
|
57 |
+
def __init__(self, file_path=None):
|
58 |
+
if file_path is None:
|
59 |
+
self.vocab = list(range(256))
|
60 |
+
else:
|
61 |
+
with open(file_path, 'r') as f:
|
62 |
+
content = f.read().lower().encode('utf-8')
|
63 |
+
self.vocab = sorted(set(content))
|
64 |
+
self.vocab_size = len(self.vocab)
|
65 |
+
self.byte_to_index = {byte: index for index, byte in enumerate(self.vocab)}
|
66 |
+
self.index_to_byte = {index: byte for index, byte in enumerate(self.vocab)}
|
67 |
+
|
68 |
+
def encode(self, text):
|
69 |
+
byte_seq = text.encode('utf-8')
|
70 |
+
return [self.byte_to_index[byte] for byte in byte_seq]
|
71 |
+
|
72 |
+
def decode(self, indices):
|
73 |
+
byte_seq = bytes(self.index_to_byte[index] for index in indices)
|
74 |
+
return byte_seq.decode('utf-8', errors='ignore')
|
75 |
+
|
76 |
+
class SuRoPE(nn.Module):
|
77 |
+
def __init__(self, config):
|
78 |
+
super().__init__()
|
79 |
+
self.dim = config.hidden_size // config.num_attention_heads
|
80 |
+
self.original_max_position_embeddings = config.original_max_position_embeddings
|
81 |
+
self.rope_theta = config.rope_theta
|
82 |
+
self.scaling_factor = math.sqrt(1 + math.log(config.max_position_embeddings / config.original_max_position_embeddings) / math.log(config.original_max_position_embeddings))
|
83 |
+
self._long_factor = mx.array(config.rope_scaling["long_factor"], dtype=mx.float32)
|
84 |
+
self._short_factor = mx.array(config.rope_scaling["short_factor"], dtype=mx.float32)
|
85 |
+
|
86 |
+
def __call__(self, q, k, position_ids):
|
87 |
+
cos, sin = self._get_cos_sin(position_ids)
|
88 |
+
q = (q * cos) + (self._rotate_half(q) * sin)
|
89 |
+
k = (k * cos) + (self._rotate_half(k) * sin)
|
90 |
+
return q, k
|
91 |
+
|
92 |
+
def _get_cos_sin(self, position_ids):
|
93 |
+
su_factor = self._short_factor
|
94 |
+
position_ids_expanded = position_ids[:, None, :]
|
95 |
+
inv_freq = 1.0 / (su_factor * self.rope_theta**(mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim))
|
96 |
+
inv_freq_expanded = mx.repeat(inv_freq[None, :, None], position_ids.shape[0], axis=0)
|
97 |
+
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
|
98 |
+
emb = mx.concatenate([freqs, freqs], axis=-1)
|
99 |
+
cos = mx.expand_dims(mx.cos(emb) * self.scaling_factor, axis=1)
|
100 |
+
sin = mx.expand_dims(mx.sin(emb) * self.scaling_factor, axis=1)
|
101 |
+
return cos, sin
|
102 |
+
|
103 |
+
def _rotate_half(self, x):
|
104 |
+
midpoint = x.shape[-1] // 2
|
105 |
+
x1, x2 = x[..., :midpoint], x[..., midpoint:]
|
106 |
+
return mx.concatenate([-x2, x1], axis=-1)
|
107 |
+
|
108 |
+
class Phi3Attention(nn.Module):
|
109 |
+
def __init__(self, config):
|
110 |
+
super().__init__()
|
111 |
+
dim = config.hidden_size
|
112 |
+
self.n_heads = n_heads = config.num_attention_heads
|
113 |
+
self.n_kv_heads = n_kv_heads = config.num_key_value_heads
|
114 |
+
self.num_hidden_layers = config.num_hidden_layers
|
115 |
+
self.head_dim = head_dim = config.hidden_size // n_heads
|
116 |
+
self.scale = head_dim**-0.5
|
117 |
+
chop_1 = self.n_heads * self.head_dim
|
118 |
+
chop_2 = chop_1 + self.n_kv_heads * self.head_dim
|
119 |
+
self.chop = [chop_1, chop_2]
|
120 |
+
op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
|
121 |
+
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
|
122 |
+
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
123 |
+
self.rope = SuRoPE(config)
|
124 |
+
|
125 |
+
def __call__(self, x, position_ids, attention_mask, cache, use_recurrent_mode):
|
126 |
+
B, L, _ = x.shape
|
127 |
+
qkv = self.qkv_proj(x)
|
128 |
+
q, k, v = mx.split(qkv, self.chop, axis=-1)
|
129 |
+
q = q.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
130 |
+
k = k.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
131 |
+
v = v.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
132 |
+
if cache is None:
|
133 |
+
position_ids = mx.arange(q.shape[2], dtype=mx.float32)[None] if position_ids is None else position_ids
|
134 |
+
q, k = self.rope(q,k,position_ids)
|
135 |
+
mask = mx.triu(mx.full((v.shape[2], v.shape[2]), -mx.inf), k=1)
|
136 |
+
if attention_mask is not None:
|
137 |
+
mask += mx.where(attention_mask[:, :, None]*attention_mask[:, None, :]==1, 0, -mx.inf)
|
138 |
+
mask = mx.expand_dims(mask, 1)
|
139 |
+
else:
|
140 |
+
mask = mask[None, None]
|
141 |
+
else:
|
142 |
+
past_k, past_v, past_p, past_m = cache
|
143 |
+
position_ids = past_p[:,-1:]+1
|
144 |
+
mask = mx.pad(past_m[:,:,-1:,:], ((0,0),(0,0),(0,0),(0,1)))
|
145 |
+
q, k = self.rope(q, k, position_ids)
|
146 |
+
k = mx.concatenate([past_k, k], axis=2)
|
147 |
+
v = mx.concatenate([past_v, v], axis=2)
|
148 |
+
cache = (k, v, position_ids, mask)
|
149 |
+
w = (q * self.scale) @ k.transpose(0, 1, 3, 2)
|
150 |
+
w += mask
|
151 |
+
w = mx.softmax(w, axis=-1)
|
152 |
+
o = w @ v
|
153 |
+
o = o.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
154 |
+
return self.o_proj(o).astype(x.dtype), cache
|
155 |
+
|
156 |
+
class Phi3Retention(nn.Module):
|
157 |
+
def __init__(self, config):
|
158 |
+
super().__init__()
|
159 |
+
self.dim = dim = config.hidden_size
|
160 |
+
self.n_heads = n_heads = config.num_attention_heads
|
161 |
+
self.n_kv_heads = n_kv_heads = config.num_key_value_heads
|
162 |
+
self.num_hidden_layers = config.num_hidden_layers
|
163 |
+
self.head_dim = head_dim = config.hidden_size // n_heads
|
164 |
+
self.scale = head_dim**-0.5
|
165 |
+
chop_1 = self.n_heads * self.head_dim
|
166 |
+
chop_2 = chop_1 + self.n_kv_heads * self.head_dim
|
167 |
+
self.chop = [chop_1, chop_2]
|
168 |
+
op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
|
169 |
+
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
|
170 |
+
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
171 |
+
self.rope = SuRoPE(config)
|
172 |
+
xmin, xmax = math.log(1 / 32), math.log(1 / 512)
|
173 |
+
x = mx.linspace(xmin, xmax, num=n_heads)
|
174 |
+
self._gamma = 1 - x.exp()
|
175 |
+
self.gn = nn.GroupNorm(num_groups=head_dim, dims=-1, affine=False)
|
176 |
+
|
177 |
+
def __call__(self, x, position_ids, attention_mask, cache, use_recurrent_mode):
|
178 |
+
if use_recurrent_mode:
|
179 |
+
return self.recurrent_mode(x, cache)
|
180 |
+
B, L, _ = x.shape
|
181 |
+
qkv = self.qkv_proj(x)
|
182 |
+
q, k, v = mx.split(qkv, self.chop, axis=-1)
|
183 |
+
q = q.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
184 |
+
k = k.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
185 |
+
v = v.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
186 |
+
position_ids = mx.arange(q.shape[2], dtype=mx.float32)[None] if position_ids is None else position_ids
|
187 |
+
q, k = self.rope(q,k,position_ids)
|
188 |
+
cache = None
|
189 |
+
w = (q * self.scale) @ k.transpose(0, 1, 3, 2)
|
190 |
+
w = w * self._decay(L)
|
191 |
+
o = w @ v
|
192 |
+
o = o.transpose(0, 2, 1, 3).reshape(B*L, -1)
|
193 |
+
o = self.gn(o).reshape(B, L, -1)
|
194 |
+
return self.o_proj(o).astype(x.dtype), cache
|
195 |
+
|
196 |
+
def recurrent_mode(self, x, cache):
|
197 |
+
if cache is None:
|
198 |
+
s = mx.zeros((1, 32, 96, 96))
|
199 |
+
n = 0
|
200 |
+
else:
|
201 |
+
s, n = cache
|
202 |
+
qkv = self.qkv_proj(x)
|
203 |
+
q, k, v = mx.split(qkv, self.chop, axis=-1)
|
204 |
+
q = q.reshape(1, 1, self.n_heads, -1).transpose(0, 2, 1, 3)
|
205 |
+
k = k.reshape(1, 1, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
206 |
+
v = v.reshape(1, 1, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
207 |
+
position_ids = mx.array([[n]])
|
208 |
+
q, k = self.rope(q,k,position_ids)
|
209 |
+
k = k * self.scale
|
210 |
+
s = self._gamma[None, :, None, None] * s + (k.transpose(0, 1, 3, 2) @ v)
|
211 |
+
o = q @ s
|
212 |
+
o = o.transpose(0, 2, 1, 3).reshape(1, -1)
|
213 |
+
o = self.gn(o).reshape(1, 1, -1)
|
214 |
+
o = self.o_proj(o).astype(x.dtype)
|
215 |
+
return o, (s, n+1)
|
216 |
+
|
217 |
+
def _decay(self, sequence_length):
|
218 |
+
n = mx.arange(sequence_length)[:,None]
|
219 |
+
m = mx.arange(sequence_length)[None]
|
220 |
+
D = (self._gamma[:, None, None] ** (n-m)) * (n >= m)
|
221 |
+
return D
|
222 |
+
|
223 |
+
class Phi3MLP(nn.Module):
|
224 |
+
def __init__(self, config):
|
225 |
+
super().__init__()
|
226 |
+
self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
|
227 |
+
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
228 |
+
|
229 |
+
def __call__(self, x):
|
230 |
+
x = self.gate_up_proj(x)
|
231 |
+
gate, x = mx.split(x, 2, axis=-1)
|
232 |
+
return self.down_proj(nn.silu(gate) * x)
|
233 |
+
|
234 |
+
class Phi3DecoderLayer(nn.Module):
|
235 |
+
def __init__(self, config):
|
236 |
+
super().__init__()
|
237 |
+
if config.use_retention:
|
238 |
+
self.self_attn = Phi3Retention(config)
|
239 |
+
else:
|
240 |
+
self.self_attn = Phi3Attention(config)
|
241 |
+
self.mlp = Phi3MLP(config)
|
242 |
+
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
243 |
+
self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
244 |
+
|
245 |
+
def __call__(self, x, position_ids, attention_mask, cache, use_recurrent_mode):
|
246 |
+
r, cache = self.self_attn(self.input_layernorm(x), position_ids, attention_mask, cache, use_recurrent_mode)
|
247 |
+
h = x + r
|
248 |
+
r = self.mlp(self.post_attention_layernorm(h))
|
249 |
+
return h + r, cache
|
250 |
+
|
251 |
+
class Phi3Model(nn.Module):
|
252 |
+
def __init__(self, config):
|
253 |
+
super().__init__()
|
254 |
+
self.embed_new = nn.Embedding(config.vocab_size, config.hidden_size)
|
255 |
+
self.layers = [Phi3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
|
256 |
+
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
257 |
+
|
258 |
+
def __call__(self, input_ids, pixel_values, image_sizes, position_ids, attention_mask, cache, use_recurrent_mode):
|
259 |
+
x = self.embed_new(input_ids)
|
260 |
+
cache = [None]*len(self.layers) if cache is None else cache
|
261 |
+
for i, l in enumerate(self.layers):
|
262 |
+
x, cache[i] = l(x, position_ids, attention_mask, cache[i], use_recurrent_mode)
|
263 |
+
return self.norm(x), cache
|
264 |
+
|
265 |
+
class Phi3ForCausalLM(nn.Module):
|
266 |
+
def __init__(self, config):
|
267 |
+
super().__init__()
|
268 |
+
self.model = Phi3Model(config)
|
269 |
+
if config.untie_embedding:
|
270 |
+
self.lm_new = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
271 |
+
self.untie = True
|
272 |
+
else:
|
273 |
+
self.untie = False
|
274 |
+
|
275 |
+
def __call__(self, input_ids, pixel_values=None, image_sizes=None, position_ids=None, attention_mask=None, cache=None, use_recurrent_mode=False):
|
276 |
+
x, cache = self.model(input_ids, pixel_values, image_sizes, position_ids, attention_mask, cache, use_recurrent_mode)
|
277 |
+
if self.untie:
|
278 |
+
return self.lm_new(x), cache
|
279 |
+
return self.model.embed_new.as_linear(x), cache
|
280 |
+
|
281 |
+
@property
|
282 |
+
def layers(self):
|
283 |
+
return self.model.layers
|
284 |
+
|
285 |
+
class DoRALinear(nn.Module):
|
286 |
+
@staticmethod
|
287 |
+
def from_linear(linear, r, alpha, scale, dropout):
|
288 |
+
output_dims, input_dims = linear.weight.shape
|
289 |
+
if isinstance(linear, nn.QuantizedLinear):
|
290 |
+
input_dims *= 32 // linear.bits
|
291 |
+
lora_lin = DoRALinear(input_dims=input_dims, output_dims=output_dims, r=r, alpha=alpha, scale=scale, dropout=dropout)
|
292 |
+
lora_lin.linear = linear
|
293 |
+
return lora_lin
|
294 |
+
|
295 |
+
def __init__(self, input_dims, output_dims, r, alpha, scale, dropout, bias=False):
|
296 |
+
super().__init__()
|
297 |
+
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
|
298 |
+
self.dropout = nn.Dropout(p=dropout)
|
299 |
+
self.scale = scale * (alpha / r)
|
300 |
+
scale = 1 / math.sqrt(input_dims)
|
301 |
+
self.lora_a = mx.random.uniform(low=-scale, high=scale, shape=(input_dims, r))
|
302 |
+
self.lora_b = mx.zeros(shape=(r, output_dims))
|
303 |
+
self.m = mx.linalg.norm(self._dequantized_weight(), axis=1).astype(mx.float32)
|
304 |
+
|
305 |
+
def _dequantized_weight(self):
|
306 |
+
weight = self.linear.weight
|
307 |
+
if isinstance(self.linear, nn.QuantizedLinear):
|
308 |
+
weight = mx.dequantize(weight, self.linear.scales, self.linear.biases, self.linear.group_size, self.linear.bits)
|
309 |
+
return weight
|
310 |
+
|
311 |
+
def __call__(self, x):
|
312 |
+
y = self.linear(x)
|
313 |
+
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
|
314 |
+
z = y + (self.scale * z)
|
315 |
+
adapted = self._dequantized_weight() + (self.scale * self.lora_b.T) @ self.lora_a.T
|
316 |
+
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
|
317 |
+
z = (self.m / denom) * z
|
318 |
+
return z.astype(x.dtype)
|
319 |
+
|
320 |
+
def linear_to_lora_layers(model, lora_layers, lora_targets, lora_rank, lora_scale, lora_dropout):
|
321 |
+
if lora_layers == 'all':
|
322 |
+
lora_layers = model.layers
|
323 |
+
elif isinstance(lora_layers, int):
|
324 |
+
lora_layers = model.layers[-lora_layers:]
|
325 |
+
elif isinstance(lora_layers, list):
|
326 |
+
lora_layers = [model.layers[i] for i in lora_layers]
|
327 |
+
else:
|
328 |
+
raise ValueError("Invalid type for lora_layers. Expected int (number of layers) or list (layer indices or names).")
|
329 |
+
def to_lora(layer):
|
330 |
+
return DoRALinear.from_linear(layer, r=lora_rank, alpha=lora_rank, scale=lora_scale, dropout=lora_dropout)
|
331 |
+
for l in lora_layers:
|
332 |
+
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in lora_targets]
|
333 |
+
l.update_modules(tree_unflatten(lora_layers))
|
334 |
+
|
335 |
+
def load_base_model(model_cfg, init=False):
|
336 |
+
model_id='microsoft/Phi-3.5-mini-instruct'
|
337 |
+
model_path = snapshot_download(model_id, allow_patterns=["*.safetensors", "config.json"])
|
338 |
+
with open(f"{model_path}/config.json", "r") as f:
|
339 |
+
config = json.load(f)
|
340 |
+
config = config|model_cfg
|
341 |
+
model_config = SimpleNamespace(**config)
|
342 |
+
model = Phi3ForCausalLM(model_config)
|
343 |
+
model_weight = [(k, v) for wf in glob.glob(f"{model_path}/*.safetensors") for k, v in mx.load(wf).items()]
|
344 |
+
model.load_weights(model_weight, strict=False)
|
345 |
+
model.set_dtype(mx.float32)
|
346 |
+
if init:
|
347 |
+
init_fn_embed = nn.init.normal(mean=-0.000453949, std=0.0344238)
|
348 |
+
model.apply_to_modules(lambda k, v: v.apply(init_fn_embed) if k.endswith('embed_new') else None)
|
349 |
+
if model_config.untie_embedding:
|
350 |
+
init_fn_lm = nn.init.normal(mean=-0.000231743, std=0.043457)
|
351 |
+
model.apply_to_modules(lambda k, v: v.apply(init_fn_lm) if k.endswith('lm_new') else None)
|
352 |
+
class_predicate = lambda k, m: hasattr(m, "to_quantized") and not k.endswith('new')
|
353 |
+
nn.quantize(model, 64, 4, class_predicate)
|
354 |
+
mx.eval(model.parameters())
|
355 |
+
return model
|
356 |
+
|
357 |
+
def load_model_for_training(lora_cfg, model_cfg, thaws, from_path=None):
|
358 |
+
model = load_base_model(model_cfg, init=False)
|
359 |
+
if from_path:
|
360 |
+
model.load_weights(from_path, strict=False)
|
361 |
+
model.freeze()
|
362 |
+
if len(lora_cfg['targets']) > 1:
|
363 |
+
linear_to_lora_layers(model, lora_layers=lora_cfg['layers'], lora_targets=lora_cfg['targets'], lora_rank=lora_cfg['rank'], lora_scale=lora_cfg['scale'], lora_dropout=lora_cfg['dropout'])
|
364 |
+
model.apply_to_modules(lambda k, v: v.unfreeze() if any(k.endswith(t) for t in thaws) else None)
|
365 |
+
mx.eval(model.parameters())
|
366 |
+
# print("Trainable parameters:", [i[0] for i in tree_flatten(model.trainable_parameters())])
|
367 |
+
model.train()
|
368 |
+
return model
|
369 |
+
|
370 |
+
def load_model_for_inference(lora_cfg, model_cfg):
|
371 |
+
model = load_base_model(model_cfg, init=False)
|
372 |
+
if len(lora_cfg['targets']) > 1:
|
373 |
+
linear_to_lora_layers(model, lora_layers=lora_cfg['layers'], lora_targets=lora_cfg['targets'], lora_rank=lora_cfg['rank'], lora_scale=lora_cfg['scale'], lora_dropout=lora_cfg['dropout'])
|
374 |
+
_path = 'trained_retnphi.safetensors' if model_cfg['use_retention'] else 'trained_orgnphi.safetensors'
|
375 |
+
model.load_weights(_path, strict=False)
|
376 |
+
mx.eval(model.parameters())
|
377 |
+
model.eval()
|
378 |
+
return model
|
379 |
+
|
380 |
+
def generate(prompt, lora_cfg, model_cfg, max_tokens=50, verbose = True):
|
381 |
+
model = load_model_for_inference(lora_cfg=lora_cfg, model_cfg=model_cfg)
|
382 |
+
input_ids = mx.array(tokenizer.encode(prompt))
|
383 |
+
if model_cfg['use_retention']:
|
384 |
+
cache = None
|
385 |
+
for i in input_ids:
|
386 |
+
logits, cache = model(i[None, None], cache=cache, use_recurrent_mode=True)
|
387 |
+
else:
|
388 |
+
logits, cache = model(input_ids[None])
|
389 |
+
token = mx.argmax(logits[:,-1,:], axis=-1)
|
390 |
+
mx.eval(token, cache)
|
391 |
+
list_tokens = token.tolist()
|
392 |
+
for i in range(max_tokens):
|
393 |
+
logits, cache = model(token[None], cache=cache, use_recurrent_mode=True)
|
394 |
+
token = mx.argmax(logits[:,-1,:], axis=-1)
|
395 |
+
mx.eval(token, cache)
|
396 |
+
list_tokens += token.tolist()
|
397 |
+
if tokenizer.decode(list_tokens[-2:]) == '\n\n':
|
398 |
+
break
|
399 |
+
output = tokenizer.decode(list_tokens)
|
400 |
+
if verbose:
|
401 |
+
print(f'{prompt=} + {output=}\n-> {prompt+output}')
|
402 |
+
del model
|
403 |
+
return output
|
404 |
+
|
405 |
+
def train_gsm(learning_rates, num_epochs, batch_size, seq_length, lora_cfg, model_cfg, thaws, take, from_path=None):
|
406 |
+
def load_gsm_data(tokenizer, is_tiny=True):
|
407 |
+
if is_tiny:
|
408 |
+
data = load_dataset("TinyGSM/TinyGSM")["train"]
|
409 |
+
if take:
|
410 |
+
data = data.take(take)
|
411 |
+
data = data.filter(lambda x: len(x['question']) < 100 and ':' not in x['question'] and '-' not in x['question'] and "'" not in x['code'] and '\n result =' in x['code'])
|
412 |
+
split_point = int(len(data) * 0.8)
|
413 |
+
train_data = data.select(range(split_point))
|
414 |
+
eval_data = data.select(range(split_point, len(data)))
|
415 |
+
def format_example(example):
|
416 |
+
code_raw = example['code']
|
417 |
+
start = code_raw.rfind('\n """')
|
418 |
+
if start == -1:
|
419 |
+
print('Wrong format to start')
|
420 |
+
return code_raw.strip()
|
421 |
+
start = start + 8
|
422 |
+
end = code_raw.rfind('\n result =')
|
423 |
+
if end == -1:
|
424 |
+
print('Wrong format to end')
|
425 |
+
end = len(code_raw)
|
426 |
+
code_block = code_raw[start:end]
|
427 |
+
code_lines = code_block.split('\n ')
|
428 |
+
formatted_code = '\n'.join(line.rstrip() for line in code_lines if line.strip())
|
429 |
+
formatted_code = '\n' + formatted_code.strip() + '\n\n'
|
430 |
+
result = (example['question'].strip(), formatted_code)
|
431 |
+
return result
|
432 |
+
else:
|
433 |
+
dataset = load_dataset("openai/gsm8k", "main")
|
434 |
+
train_data = dataset["train"]
|
435 |
+
eval_data = dataset["test"]
|
436 |
+
def format_example(example):
|
437 |
+
return (example['question'].strip(), '\n'+example['answer'].strip()+'\n\n')
|
438 |
+
train_formatted = [format_example(ex) for ex in train_data]
|
439 |
+
eval_formatted = [format_example(ex) for ex in eval_data]
|
440 |
+
return train_formatted, eval_formatted
|
441 |
+
|
442 |
+
def create_batches(data, tokenizer, batch_size, seq_length):
|
443 |
+
def _encode(x):
|
444 |
+
return [tokenizer.encode(i) for i in x]
|
445 |
+
encoded_data = [_encode(x) for x in data]
|
446 |
+
encoded_data = [x for x in encoded_data if len(x[0]+x[1]) <= seq_length+1]
|
447 |
+
if batch_size is None:
|
448 |
+
batch_size = min(len(encoded_data), 64)
|
449 |
+
else:
|
450 |
+
encoded_data = encoded_data[:(len(encoded_data) // batch_size) * batch_size]
|
451 |
+
np.random.shuffle(encoded_data)
|
452 |
+
for i in range(0, len(encoded_data), batch_size):
|
453 |
+
batch = encoded_data[i:i+batch_size]
|
454 |
+
max_len = min(max(len(q+a)-1 for q, a in batch), seq_length)
|
455 |
+
x_batch = []
|
456 |
+
y_batch = []
|
457 |
+
mask_batch = []
|
458 |
+
for q, a in batch:
|
459 |
+
combined = (q+a)[:max_len+1]
|
460 |
+
x = combined[:-1]
|
461 |
+
y = combined[1:]
|
462 |
+
pad_length = max_len - len(x)
|
463 |
+
x = x + [0] * pad_length
|
464 |
+
y = y + [0] * pad_length
|
465 |
+
mask = [False] * (len(q)-1) + [True] * (len(a)) + [False] * (pad_length)
|
466 |
+
x_batch.append(x)
|
467 |
+
y_batch.append(y)
|
468 |
+
mask_batch.append(mask)
|
469 |
+
yield mx.array(x_batch), mx.array(y_batch), mx.array(mask_batch)
|
470 |
+
|
471 |
+
def loss_fn(model, X, y, mask):
|
472 |
+
logits, _ = model(X)
|
473 |
+
logits = logits.astype(mx.float32)
|
474 |
+
ce = nn.losses.cross_entropy(logits, y, reduction='none')
|
475 |
+
masked_loss = ce * mask
|
476 |
+
return masked_loss.sum(), mask.sum()
|
477 |
+
|
478 |
+
def evaluate(model, data, tokenizer, seq_length):
|
479 |
+
model.eval()
|
480 |
+
total_loss = 0
|
481 |
+
total_samples = 0
|
482 |
+
for X, y, mask in create_batches(data, tokenizer, None, seq_length):
|
483 |
+
loss, ntoks = loss_fn(model, X, y, mask)
|
484 |
+
total_loss += loss.item()
|
485 |
+
total_samples += ntoks.item()
|
486 |
+
return total_loss / total_samples if total_samples > 0 else -1
|
487 |
+
|
488 |
+
def get_optimizer(train_data):
|
489 |
+
num_batches_per_epoch = len(list(create_batches(train_data, tokenizer, batch_size, seq_length)))
|
490 |
+
print(f'{num_batches_per_epoch=}')
|
491 |
+
num_steps = num_epochs * num_batches_per_epoch
|
492 |
+
num_warmup = num_steps // 10
|
493 |
+
max_lr, min_lr = learning_rates
|
494 |
+
if num_warmup > 2:
|
495 |
+
warmup = optim.linear_schedule(min_lr*0.1, max_lr, steps=num_warmup)
|
496 |
+
cosine = optim.cosine_decay(max_lr, num_steps - num_warmup, min_lr)
|
497 |
+
lr_schedule = optim.join_schedules([warmup, cosine], [num_warmup])
|
498 |
+
else:
|
499 |
+
lr_schedule = optim.cosine_decay(max_lr, num_steps, min_lr)
|
500 |
+
return optim.Lion(learning_rate=lr_schedule), num_steps
|
501 |
+
|
502 |
+
for arg_name in sorted(locals()):
|
503 |
+
if arg_name != 'self':
|
504 |
+
arg_value = locals()[arg_name]
|
505 |
+
if not callable(arg_value):
|
506 |
+
print(f"{arg_name}: {arg_value}")
|
507 |
+
|
508 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
509 |
+
print(f'--- {timestamp} ---')
|
510 |
+
train_data, eval_data = load_gsm_data(tokenizer=tokenizer)
|
511 |
+
model = load_model_for_training(lora_cfg=lora_cfg, model_cfg=model_cfg, thaws=thaws)
|
512 |
+
optimizer, num_steps = get_optimizer(train_data)
|
513 |
+
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
514 |
+
mx.eval(model, optimizer)
|
515 |
+
metrics = {
|
516 |
+
'steps': [],
|
517 |
+
'learning_rates': [],
|
518 |
+
'all_train_losses': [],
|
519 |
+
'avg_train_losses': [],
|
520 |
+
'val_losses': [],
|
521 |
+
'trained_toks': [],
|
522 |
+
}
|
523 |
+
step = 0
|
524 |
+
trained_toks = 0
|
525 |
+
losses = []
|
526 |
+
tic = time.perf_counter()
|
527 |
+
for epoch in range(num_epochs):
|
528 |
+
for X, y, loss_mask in create_batches(data=train_data, tokenizer=tokenizer, batch_size=batch_size, seq_length=seq_length):
|
529 |
+
model.train()
|
530 |
+
(loss, ntoks), grads = loss_and_grad_fn(model, X, y, loss_mask)
|
531 |
+
optimizer.update(model, grads)
|
532 |
+
mx.eval(loss, ntoks, model, optimizer)
|
533 |
+
losses.append(loss.item())
|
534 |
+
trained_toks += ntoks.item()
|
535 |
+
step += 1
|
536 |
+
if (step % (num_steps // 30) == 0):
|
537 |
+
avg_train_loss = np.mean(losses)
|
538 |
+
lr = optimizer.learning_rate.item()
|
539 |
+
val_loss = evaluate(model=model, data=eval_data, tokenizer=tokenizer, seq_length=seq_length)
|
540 |
+
print(f"{avg_train_loss:8.4f} ({val_loss:6.4f}) @ {step//(num_steps//30):2}/30 w/ {lr:.2e} ({time.perf_counter() - tic:.2f} sec)")
|
541 |
+
metrics['val_losses'].append(val_loss)
|
542 |
+
# print(f"{avg_train_loss:8.4f} @ {step//(num_steps//30):2}/30 w/ {lr:.2e} ({time.perf_counter() - tic:.2f} sec)")
|
543 |
+
tic = time.perf_counter()
|
544 |
+
metrics['steps'].append(step)
|
545 |
+
metrics['learning_rates'].append(lr)
|
546 |
+
metrics['all_train_losses'].extend(losses)
|
547 |
+
metrics['avg_train_losses'].append(avg_train_loss)
|
548 |
+
metrics['trained_toks'].append(trained_toks)
|
549 |
+
losses = []
|
550 |
+
trained_toks = 0
|
551 |
+
_path = f'trained_retnphi.safetensors' if model_cfg['use_retention'] else f'trained_orgnphi.safetensors'
|
552 |
+
mx.save_safetensors(_path, dict(tree_flatten(model.trainable_parameters())))
|
553 |
+
log = {
|
554 |
+
'args': {
|
555 |
+
'learning_rates': learning_rates,
|
556 |
+
'num_epochs': num_epochs,
|
557 |
+
'batch_size': batch_size,
|
558 |
+
'seq_length': seq_length,
|
559 |
+
'lora_cfg': lora_cfg,
|
560 |
+
'model_cfg': model_cfg,
|
561 |
+
'thaws': thaws,
|
562 |
+
'from_path': from_path
|
563 |
+
},
|
564 |
+
'metrics': metrics
|
565 |
+
}
|
566 |
+
with open(f'train_log_{timestamp}.json', 'w') as f:
|
567 |
+
json.dump(log, f, indent=2)
|
568 |
+
del model
|
569 |
+
|
570 |
+
tokenizer = Tokenizer()
|
571 |
+
|
572 |
+
def main(take=1000, layers='all', targets=["self_attn.o_proj"], thaws=['new', 'post_attention_layernorm'], rank=32, scale=0.1, dropout=0.0, lr_max=1e-4, lr_min=1e-5, num_epochs=90, batch_size=1, seq_length=256, vocab_size=256, use_retention=True, untie_embedding=True, prompt='There are 8 candies in a carton. How many candies will be in 5 cartons?'):
|
573 |
+
lora_cfg = dict(layers=layers, targets=targets, rank=rank, scale=scale, dropout=dropout)
|
574 |
+
model_cfg = dict(vocab_size=vocab_size, use_retention=use_retention, untie_embedding=untie_embedding)
|
575 |
+
train_gsm(learning_rates=(lr_max, lr_min), num_epochs=num_epochs, batch_size=batch_size, seq_length=seq_length, lora_cfg=lora_cfg, model_cfg=model_cfg, thaws=thaws, take=take)
|
576 |
+
generate(prompt=prompt, lora_cfg=lora_cfg, model_cfg=model_cfg, max_tokens=(seq_length-len(prompt)))
|
577 |
+
|
578 |
+
if __name__ == "__main__":
|
579 |
+
main(take=None, num_epochs=3) # -> 240916
|
580 |
+
main(take=None, num_epochs=3, untie_embedding=False)
|
581 |
+
|
582 |
+
main(take=None, num_epochs=3, use_retention=False)
|
583 |
+
main(take=None, num_epochs=3, untie_embedding=False, use_retention=False)
|
584 |
+
# fire.Fire(main)
|
585 |
+
|
586 |
+
# Output:
|
587 |
+
# 388.7268 @ 1/30 w/ 3.36e-05 (64.73 sec)
|
588 |
+
# ...
|
589 |
+
# 4.3768 @ 30/30 w/ 1.00e-05 (64.36 sec)
|
590 |
+
# prompt='There are 8 candies in a carton. How many candies will be in 5 cartons?' + output='\ncandies_in_carton = 8 \nnumber_of_cartons = 5\ntotal_no_of_candies = candies_in_carton * number_of_cartons\n\n'
|
591 |
+
# -> There are 8 candies in a carton. How many candies will be in 5 cartons?
|
592 |
+
# candies_in_carton = 8
|
593 |
+
# number_of_cartons = 5
|
594 |
+
# total_no_of_candies = candies_in_carton * number_of_cartons
|