aquiffoo committed
Commit c58e471 · verified · 1 Parent(s): c42df27

Update README.md

Files changed (1): README.md (+388, -0)
README.md CHANGED
@@ -40,5 +40,393 @@ This is our first ever model! Allow us to explain how the `mesh` architecture wo
## Here's how the mesh architecture works:

![image/png](https://cdn-uploads.huggingface.co/production/uploads/6747320df82ae35f0327cdd3/WRpS2T5KBMPbacobfh0bw.png)
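In short, each mesh layer routes every token to a small grid of experts, lets neighboring experts on the grid exchange information, applies attention across the experts, and then mixes the top-k expert outputs using the router weights. The snippet below is a minimal, self-contained sketch of just the routing-and-mixing step with toy sizes; it omits the neighbor-exchange and cross-expert-attention steps, and the names are illustrative only (the real `MeshRouter`/`MeshLayer` implementation is shown in the loading code further down):

```python
import torch

# Toy sketch of the mesh routing/mixing idea: a 2x2 grid = 4 experts, top-k = 2.
# Sizes and module names here are made up for illustration; see MeshRouter/MeshLayer below.
hidden_size, num_experts, k = 8, 4, 2
x = torch.randn(3, hidden_size)                                   # 3 token vectors

gate = torch.nn.Linear(hidden_size, num_experts)                  # router: one score per expert
experts = torch.nn.ModuleList(
    torch.nn.Linear(hidden_size, hidden_size) for _ in range(num_experts)
)

weights = torch.softmax(gate(x), dim=-1)                          # routing probabilities
topk_w, topk_idx = torch.topk(weights, k, dim=-1)                 # keep the top-k experts per token

expert_out = torch.stack([e(x) for e in experts], dim=1)          # (tokens, experts, hidden)
chosen = torch.gather(expert_out, 1, topk_idx.unsqueeze(-1).expand(-1, -1, hidden_size))
mixed = (chosen * topk_w.unsqueeze(-1)).sum(dim=1)                # weighted mix of the chosen experts
print(mixed.shape)                                                # torch.Size([3, 8])
```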
## How to load the model

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PretrainedConfig, PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation import GenerationMixin
import os

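# MeshConfig: the custom PretrainedConfig for the mesh architecture. On top of the usual
# causal-LM fields it stores the expert grid (mesh_grid_size), the per-expert FFN width
# (expert_intermediate_size), how many experts each token is routed to (routing_k), and
# flags that enable the neighbor-exchange and cross-expert-attention steps.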
class MeshConfig(PretrainedConfig):
    model_type = "mesh"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        intermediate_size=2048,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_key_value_heads=12,
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        mesh_grid_size=(2, 2),
        expert_intermediate_size=256,
        routing_k=2,
        neighbor_exchange_enabled=True,
        cross_expert_attention_enabled=True,
        expert_scale_factor="sqrt_k",
        load_in_8bit=False,
        load_in_4bit=False,
        **kwargs
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.mesh_grid_size = mesh_grid_size
        self.expert_intermediate_size = kwargs.pop("expert_intermediate_size", intermediate_size // (mesh_grid_size[0] * mesh_grid_size[1]))
        self.routing_k = routing_k
        self.neighbor_exchange_enabled = neighbor_exchange_enabled
        self.cross_expert_attention_enabled = cross_expert_attention_enabled
        self.expert_scale_factor = expert_scale_factor
        self.load_in_8bit = load_in_8bit
        self.load_in_4bit = load_in_4bit

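# MeshExpert: a single expert in the grid, i.e. a two-layer feed-forward block
# (hidden_size -> expert_intermediate_size -> hidden_size) with a GELU in between.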
class MeshExpert(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.expert_intermediate_size)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(config.expert_intermediate_size, config.hidden_size)

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))

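# MeshRouter: scores every expert for every token and returns the top-k routing
# weights together with the corresponding expert indices.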
class MeshRouter(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.gate = nn.Linear(config.hidden_size, config.mesh_grid_size[0] * config.mesh_grid_size[1])
        self.softmax = nn.Softmax(dim=-1)
        self.routing_k = config.routing_k

    def forward(self, x):
        gate_scores = self.gate(x)
        gate_weights = self.softmax(gate_scores)
        topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)
        return topk_weights, topk_indices

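# NeighborExchange: treats the experts as a 2D grid and adds a linear projection of the
# average of each expert's up/down/left/right neighbors to that expert's output.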
class NeighborExchange(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        self.num_experts_x = config.mesh_grid_size[0]
        self.num_experts_y = config.mesh_grid_size[1]
        self.num_experts = self.num_experts_x * self.num_experts_y

        self.exchange_projection = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, expert_outputs, expert_indices=None):
        if not self.config.neighbor_exchange_enabled:
            return expert_outputs

        batch_size, seq_length, num_experts, hidden_size = expert_outputs.shape
        reshaped_outputs = expert_outputs.view(batch_size, seq_length, self.num_experts_x, self.num_experts_y, hidden_size)
        aggregated_neighbor_info = torch.zeros_like(reshaped_outputs)

        for i in range(self.num_experts_x):
            for j in range(self.num_experts_y):
                current_expert_output = reshaped_outputs[:, :, i, j, :]
                neighbor_info = torch.zeros_like(current_expert_output)
                neighbors = []
                if i > 0: neighbors.append(reshaped_outputs[:, :, i-1, j, :])
                if i < self.num_experts_x - 1: neighbors.append(reshaped_outputs[:, :, i+1, j, :])
                if j > 0: neighbors.append(reshaped_outputs[:, :, i, j-1, :])
                if j < self.num_experts_y - 1: neighbors.append(reshaped_outputs[:, :, i, j+1, :])

                if neighbors:
                    neighbor_stack = torch.stack(neighbors, dim=-2)
                    aggregated_info = torch.mean(neighbor_stack, dim=-2)
                    neighbor_info = aggregated_info

                transformed_neighbor_info = self.exchange_projection(neighbor_info)
                aggregated_neighbor_info[:, :, i, j, :] = transformed_neighbor_info

        aggregated_neighbor_info = aggregated_neighbor_info.view(batch_size, seq_length, num_experts, hidden_size)
        exchanged_expert_outputs = expert_outputs + aggregated_neighbor_info

        return exchanged_expert_outputs

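# CrossExpertAttention: multi-head attention across the expert dimension, letting the
# expert outputs attend to one another at every token position.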
class CrossExpertAttention(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            batch_first=True
        )

    def forward(self, expert_outputs):
        if not self.config.cross_expert_attention_enabled:
            return expert_outputs

        batch_seq_size = expert_outputs.shape[0] * expert_outputs.shape[1]
        reshaped_outputs = expert_outputs.view(batch_seq_size, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size)
        cross_attn_output, _ = self.cross_attention(reshaped_outputs, reshaped_outputs, reshaped_outputs)
        cross_attn_output = cross_attn_output.view(
            expert_outputs.shape[0], expert_outputs.shape[1], self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size
        )
        return cross_attn_output

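# MeshLayer: one block of the model. It routes the hidden states, runs all experts,
# applies neighbor exchange and cross-expert attention, then gathers the top-k expert
# outputs and combines them with the router weights.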
class MeshLayer(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        self.router = MeshRouter(config)
        self.experts = nn.ModuleList([MeshExpert(config) for _ in range(config.mesh_grid_size[0] * config.mesh_grid_size[1])])
        self.neighbor_exchange = NeighborExchange(config)
        self.cross_expert_attention = CrossExpertAttention(config)

    def forward(self, hidden_states):
        topk_weights, topk_indices = self.router(hidden_states)
        expanded_hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], -1)

        if self.config.expert_scale_factor == "sqrt_k":
            scaling_factor = math.sqrt(self.config.routing_k)
            scaled_expert_inputs = expanded_hidden_states * scaling_factor
        elif self.config.expert_scale_factor == "1_over_k":
            scaling_factor = 1.0 / self.config.routing_k
            scaled_expert_inputs = expanded_hidden_states * scaling_factor
        else:
            scaled_expert_inputs = expanded_hidden_states

        expert_outputs_list = [expert(scaled_expert_inputs[:, :, i, :]) for i, expert in enumerate(self.experts)]
        expert_outputs = torch.stack(expert_outputs_list, dim=2)

        exchanged_expert_outputs = self.neighbor_exchange(expert_outputs, topk_indices)
        cross_attned_expert_outputs = self.cross_expert_attention(exchanged_expert_outputs)

        gathered_outputs = torch.gather(
            cross_attned_expert_outputs,
            dim=2,
            index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.config.hidden_size)
        )

        combined_output = (gathered_outputs * topk_weights.unsqueeze(-1)).sum(dim=2)

        return combined_output, topk_indices

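# MeshModel: the full causal language model. Token embeddings feed a stack of MeshLayers,
# followed by a LayerNorm and a linear LM head; the class also implements generation
# support and optional gradient checkpointing.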
class MeshModel(PreTrainedModel, GenerationMixin):
    config_class = MeshConfig

    def __init__(self, config: MeshConfig):
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([MeshLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

        self._supports_gradient_checkpointing = True
        self.gradient_checkpointing = False

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        return_dict=None,
        output_attentions=None,
        output_hidden_states=None,
        past_key_values=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            inputs_embeds = self.embedding(input_ids)
        elif inputs_embeds is not None:
            pass
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            import torch.utils.checkpoint

        for i, layer in enumerate(self.layers):
            if hasattr(layer, 'forward') and callable(layer.forward):
                if self.gradient_checkpointing and self.training:
                    checkpoint_output = torch.utils.checkpoint.checkpoint(
                        layer, hidden_states, use_reentrant=False
                    )
                    if isinstance(checkpoint_output, tuple):
                        hidden_states = checkpoint_output[0]
                    else:
                        hidden_states = checkpoint_output
                else:
                    layer_output = layer(hidden_states)
                    hidden_states = layer_output[0]
            else:
                print(f"Warning: Layer {i} does not have a callable forward method. Skipping layer processing.")

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
            )
        else:
            return (loss, logits)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1, :].unsqueeze(1)

        if inputs_embeds is not None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        if "attention_mask" in kwargs:
            model_inputs["attention_mask"] = kwargs["attention_mask"]

        return model_inputs

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self.gradient_checkpointing = True
        self.config.gradient_checkpointing = True
        print("Gradient checkpointing enabled on MeshModel.")

    def gradient_checkpointing_disable(self):
        self.gradient_checkpointing = False
        self.config.gradient_checkpointing = False
        print("Gradient checkpointing disabled on MeshModel.")

    def _set_gradient_checkpointing(self, enable=True):
        if enable:
            self.gradient_checkpointing_enable()
        else:
            self.gradient_checkpointing_disable()

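# Register the custom classes so the Auto* factories can resolve model_type "mesh".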
from transformers import AutoConfig
AutoConfig.register("mesh", MeshConfig)
AutoModelForCausalLM.register(MeshConfig, MeshModel)

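# Load the merged Stage 003 checkpoint and tokenizer from the Hugging Face Hub.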
HF_MERGED_REPO_STAGE003 = "mesh-labs/v0.1-2x2-stage003"

loaded_model_stage003 = None
loaded_tokenizer_stage003 = None

try:
    print(f"Attempting to load Stage 003 merged model from HF: {HF_MERGED_REPO_STAGE003}...")
    device_map = "auto"

    loaded_model_stage003 = AutoModelForCausalLM.from_pretrained(
        HF_MERGED_REPO_STAGE003,
        trust_remote_code=True,
        device_map=device_map,
        torch_dtype=torch.float32
    )

    if torch.cuda.is_available():
        loaded_model_stage003.to('cuda')
        print("Stage 003 merged model moved to GPU.")
    else:
        print("Stage 003 merged model loaded on CPU.")

    loaded_tokenizer_stage003 = AutoTokenizer.from_pretrained(
        HF_MERGED_REPO_STAGE003,
        trust_remote_code=True,
        use_fast=False
    )

    print("Stage 003 merged model and tokenizer loaded successfully from Hugging Face Hub.")

except Exception as e:
    print(f"Error loading Stage 003 merged model or tokenizer from Hugging Face Hub: {e}")
    loaded_model_stage003 = None
    loaded_tokenizer_stage003 = None

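# Minimal interactive chat loop: greedy decoding over a "Question: ... Answer:" prompt.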
if loaded_model_stage003 is not None and loaded_tokenizer_stage003 is not None:
    print("\n--- Starting Chat Interface ---")
    print("Type your message and press Enter. Type 'quit' to exit.")

    loaded_model_stage003.eval()

    while True:
        try:
            user_input = input("You: ")
            if user_input.lower() == 'quit':
                break

            prompt = f"Question: {user_input}\nAnswer:"

            inputs = loaded_tokenizer_stage003(prompt, return_tensors="pt")

            if torch.cuda.is_available():
                inputs = {k: v.to('cuda') for k, v in inputs.items()}

            with torch.no_grad():
                outputs = loaded_model_stage003.generate(
                    **inputs,
                    max_new_tokens=128,
                    num_beams=1,
                    do_sample=False,
                )

            generated_sequence = loaded_tokenizer_stage003.decode(outputs[0], skip_special_tokens=True)

            answer_prefix = "Answer:"
            answer_start_index = generated_sequence.find(answer_prefix)

            if answer_start_index != -1:
                generated_answer = generated_sequence[answer_start_index + len(answer_prefix):].strip()
            else:
                print("Warning: 'Answer:' prefix not found in generated text. Showing full generated sequence.")
                generated_answer = generated_sequence.strip()

            print("Model:", generated_answer)

        except Exception as e:
            print(f"An error occurred: {e}")
            print("Please try again or type 'quit' to exit.")

else:
    print("\nModel or tokenizer not loaded. Cannot start chat interface.")
```

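If you only want a single completion rather than the interactive loop, a minimal sketch that reuses the `loaded_model_stage003` and `loaded_tokenizer_stage003` objects created above (the prompt string is just an example) looks like this:

```python
# One-shot generation, assuming the model and tokenizer above loaded successfully.
prompt = "Question: What is the mesh architecture?\nAnswer:"
inputs = loaded_tokenizer_stage003(prompt, return_tensors="pt").to(loaded_model_stage003.device)

with torch.no_grad():
    output_ids = loaded_model_stage003.generate(**inputs, max_new_tokens=64, do_sample=False)

print(loaded_tokenizer_stage003.decode(output_ids[0], skip_special_tokens=True))
```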
## Disclaimer

This small language model is just a proof of concept, paving the way to the final release, which is likely to happen in Q4 2025 and to include more models and better support from external libraries such as Transformers and llama.cpp.