Upload architectureV3.py

#3
by win10 - opened
Files changed (1)
  1. architectureV3.py +33 -13
architectureV3.py CHANGED
@@ -99,13 +99,19 @@ class ReflectiveMemoryLayer(nn.Module):
 
     def forward(self, x: torch.Tensor):
         base_output = self.linear(x)
-        if 'embeds' not in self.global_state_storage: return base_output
+        if 'embeds' not in self.global_state_storage:
+            return base_output
 
         global_embeds = self.global_state_storage['embeds']
-        if global_embeds.shape[1] != x.shape[1]: global_embeds = global_embeds[:, -x.shape[1]:, :]
+        if global_embeds.shape[1] != x.shape[1]:
+            global_embeds = global_embeds[:, -x.shape[1]:, :]
         B, S, _ = x.shape
 
+        # CRITICAL FIX: Always detach LTM state to prevent backward through previous graphs
         ltm_state = self.global_state_storage.get('ltm', None)
+        if ltm_state is not None:
+            ltm_state = ltm_state.detach()
+
         proj_local = self.local_state_proj(x)
         proj_global = self.global_state_proj(global_embeds)
         memory_input = torch.stack([proj_global, proj_local], dim=2)
@@ -122,6 +128,7 @@ class ReflectiveMemoryLayer(nn.Module):
         if new_ltm_state_expanded is not None:
             num_ltm_slots = new_ltm_state_expanded.shape[1]
             new_ltm_condensed = new_ltm_state_expanded.view(B, S, num_ltm_slots, self.memory_dim).mean(dim=1)
+            # CRITICAL FIX: Always detach when storing in global state
             self.global_state_storage['ltm'] = new_ltm_condensed.detach()
 
         initial_thought = compressed_mem_flat.mean(dim=1).view(B, S, self.memory_dim)
@@ -143,10 +150,11 @@ class ReflectiveMemoryLayer(nn.Module):
         final_activation = base_output * torch.sigmoid(gate.to(x.dtype)) + value.to(x.dtype)
 
         if self.training:
-            self.last_corrected_activation = final_activation
-            self.last_additive_correction = value
-            self.last_memory_input = memory_input
-            self.last_reconstructed_from_memory = recon_flat.view(B, S, 2, self.memory_dim)
+            # CRITICAL FIX: Detach tensors stored for debugging/analysis
+            self.last_corrected_activation = final_activation.detach()
+            self.last_additive_correction = value.detach()
+            self.last_memory_input = memory_input.detach()
+            self.last_reconstructed_from_memory = recon_flat.view(B, S, 2, self.memory_dim).detach()
         return final_activation
 
 # --- BUILDING BLOCK 3: The Full Custom Model with State Management ---
@@ -155,16 +163,19 @@ class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
         super().__init__(config)
         self.global_state_storage = {}
         self.target_layer_path = "model.layers.15.mlp.gate_up_proj"
-        self.memory_dim, self.num_long_term_memory_slots = 64, 32
+        self.memory_dim, self.num_long_term_memory_slots = 128, 32
 
-        self.model.embed_tokens.register_forward_hook(
-            lambda module, input, output: self.global_state_storage.update({'embeds': output}))
+        # CRITICAL FIX: Ensure embeddings are detached when stored
+        def embedding_hook(module, input, output):
+            self.global_state_storage['embeds'] = output.detach()
+
+        self.model.embed_tokens.register_forward_hook(embedding_hook)
 
         try:
             original_layer = self.get_submodule(self.target_layer_path)
             custom_layer = ReflectiveMemoryLayer(
                 original_layer=original_layer, global_input_dim=config.hidden_size,
-                memory_dim=self.memory_dim, num_memory_slots=8, memory_num_heads=4,
+                memory_dim=self.memory_dim, num_memory_slots=16, memory_num_heads=4,
                 global_state_storage=self.global_state_storage)
             parent_path = ".".join(self.target_layer_path.split('.')[:-1])
             setattr(self.get_submodule(parent_path), self.target_layer_path.split('.')[-1], custom_layer)
@@ -184,13 +195,19 @@ class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
                 output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None,
                 ltm_state: Optional[torch.Tensor] = None):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self.global_state_storage.clear()
+
+        # CRITICAL FIX: Don't clear global state storage completely, just reset embeds
+        # This prevents losing LTM state continuity
+        if 'embeds' in self.global_state_storage:
+            del self.global_state_storage['embeds']
 
         # *** FIX: Initialize LTM state if not provided, for both training and first step of inference ***
         if ltm_state is None:
             batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
             ltm_state = self._init_ltm_state(batch_size, self.device, self.dtype)
-        self.global_state_storage['ltm'] = ltm_state
+
+        # CRITICAL FIX: Ensure LTM state is detached when stored
+        self.global_state_storage['ltm'] = ltm_state.detach() if ltm_state is not None else None
 
         outputs = self.model(
             input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
@@ -207,7 +224,10 @@ class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
                 labels[..., 1:].contiguous().view(-1))
             # Note: Auxiliary losses from main.py are calculated outside the model forward pass.
 
+        # CRITICAL FIX: Ensure returned LTM state is detached
         new_ltm_state = self.global_state_storage.get('ltm', None)
+        if new_ltm_state is not None:
+            new_ltm_state = new_ltm_state.detach()
 
         if not return_dict:
             output = (logits,) + outputs[1:] + (new_ltm_state,)
@@ -215,4 +235,4 @@ class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
 
         return CausalLMOutputWithLTM(
             loss=loss, logits=logits, past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states, attentions=outputs.attentions, ltm_state=new_ltm_state)
+            hidden_states=outputs.hidden_states, attentions=outputs.attentions, ltm_state=new_ltm_state)
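
For context, here is a minimal usage sketch of the stateful loop this patch enables: the ltm_state returned by one forward call is passed into the next so the reflective memory persists across chunks. Only Phi3WithReflectiveMemoryForCausalLM, the ltm_state argument, and the CausalLMOutputWithLTM.ltm_state field come from this file; the checkpoint name, tokenizer, and chunking below are illustrative assumptions, not part of the PR.

import torch
from transformers import AutoTokenizer
# Assumes the patched class is importable from this file's module.
from architectureV3 import Phi3WithReflectiveMemoryForCausalLM

# Placeholder checkpoint; substitute whatever Phi-3 base the project actually uses.
checkpoint = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = Phi3WithReflectiveMemoryForCausalLM.from_pretrained(checkpoint).eval()

ltm_state = None  # forward() initializes it via _init_ltm_state() on the first call
with torch.no_grad():
    for chunk in ["First chunk of context.", "Second chunk that should recall the first."]:
        inputs = tokenizer(chunk, return_tensors="pt").to(model.device)
        outputs = model(**inputs, ltm_state=ltm_state)
        # The detached LTM state in the output is what carries memory to the next call.
        ltm_state = outputs.ltm_state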