Upload architectureV3.py

Commit message: Fix RuntimeError on H100: "Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward."

architectureV3.py (CHANGED, +33 -13)
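For context, this RuntimeError typically appears when a tensor produced under one training step's autograd graph is reused in a later step, so the second backward() walks into a graph whose saved tensors were already freed. A minimal sketch of the failure mode and the detach() fix (generic PyTorch, not code from this repo):

```python
import torch

# A recurrent-style state carried across training steps, like the LTM state below.
w = torch.nn.Parameter(torch.randn(4, 4))
state = torch.zeros(1, 4)

for step in range(3):
    out = state @ w
    loss = out.sum()
    loss.backward()        # frees the intermediate buffers of this step's graph

    # Without detach(), the next backward() would re-enter the freed graph and
    # raise "Trying to backward through the graph a second time ...".
    state = out.detach()   # cut the graph at the step boundary
```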
@@ -99,13 +99,19 @@ class ReflectiveMemoryLayer(nn.Module):

     def forward(self, x: torch.Tensor):
         base_output = self.linear(x)
-        if 'embeds' not in self.global_state_storage:
+        if 'embeds' not in self.global_state_storage:
+            return base_output

         global_embeds = self.global_state_storage['embeds']
-        if global_embeds.shape[1] != x.shape[1]:
+        if global_embeds.shape[1] != x.shape[1]:
+            global_embeds = global_embeds[:, -x.shape[1]:, :]
         B, S, _ = x.shape

+        # CRITICAL FIX: Always detach LTM state to prevent backward through previous graphs
         ltm_state = self.global_state_storage.get('ltm', None)
+        if ltm_state is not None:
+            ltm_state = ltm_state.detach()
+
         proj_local = self.local_state_proj(x)
         proj_global = self.global_state_proj(global_embeds)
         memory_input = torch.stack([proj_global, proj_local], dim=2)
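The hunk above detaches the cached LTM state before it is used. detach() is enough here because it returns a tensor that shares storage with the original but carries no autograd history, so a later backward() stops at the memory state instead of walking into the graph that produced it. A quick generic check (not repo code):

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
y = (x * 2).sum(dim=1)                  # y carries a grad_fn back to x
z = y.detach()                          # same values, same storage, no history

print(y.grad_fn is not None)            # True
print(z.grad_fn)                        # None: backward passes stop here
print(z.data_ptr() == y.data_ptr())     # True: no copy is made
```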
@@ -122,6 +128,7 @@ class ReflectiveMemoryLayer(nn.Module):
         if new_ltm_state_expanded is not None:
             num_ltm_slots = new_ltm_state_expanded.shape[1]
             new_ltm_condensed = new_ltm_state_expanded.view(B, S, num_ltm_slots, self.memory_dim).mean(dim=1)
+            # CRITICAL FIX: Always detach when storing in global state
             self.global_state_storage['ltm'] = new_ltm_condensed.detach()

         initial_thought = compressed_mem_flat.mean(dim=1).view(B, S, self.memory_dim)
@@ -143,10 +150,11 @@ class ReflectiveMemoryLayer(nn.Module):
         final_activation = base_output * torch.sigmoid(gate.to(x.dtype)) + value.to(x.dtype)

         if self.training:
-
-            self.
-            self.
-            self.
+            # CRITICAL FIX: Detach tensors stored for debugging/analysis
+            self.last_corrected_activation = final_activation.detach()
+            self.last_additive_correction = value.detach()
+            self.last_memory_input = memory_input.detach()
+            self.last_reconstructed_from_memory = recon_flat.view(B, S, 2, self.memory_dim).detach()
         return final_activation

 # --- BUILDING BLOCK 3: The Full Custom Model with State Management ---
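The attributes stored in this hunk (last_corrected_activation and friends) are read outside the layer, and the note further down says auxiliary losses are computed outside the model's forward pass. If those stored tensors stayed attached to the graph and anything backpropagated through them after the main backward(), the same RuntimeError would fire. A runnable sketch of the pattern, with a toy layer standing in for ReflectiveMemoryLayer:

```python
import torch
import torch.nn as nn

class TinyLayer(nn.Module):
    """Toy stand-in: stashes an activation on self the way the memory layer does."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        out = self.linear(x)
        self.last_activation = out.detach()   # detached copy kept for logging/analysis
        return out

layer = TinyLayer()
loss = layer(torch.randn(4, 8)).pow(2).mean()
loss.backward()                               # frees the graph

# Safe after backward() precisely because the stored copy is detached:
print(layer.last_activation.norm().item())

# If an auxiliary loss must flow gradients through a stored activation, it has to
# stay attached and be folded into one combined loss before a single backward(), e.g.
#   total = main_loss + aux_weight * layer.last_activation.pow(2).mean()
#   total.backward()
```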
@@ -155,16 +163,19 @@ class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
         super().__init__(config)
         self.global_state_storage = {}
         self.target_layer_path = "model.layers.15.mlp.gate_up_proj"
-        self.memory_dim, self.num_long_term_memory_slots =
+        self.memory_dim, self.num_long_term_memory_slots = 128, 32

-
-
+        # CRITICAL FIX: Ensure embeddings are detached when stored
+        def embedding_hook(module, input, output):
+            self.global_state_storage['embeds'] = output.detach()
+
+        self.model.embed_tokens.register_forward_hook(embedding_hook)

         try:
             original_layer = self.get_submodule(self.target_layer_path)
             custom_layer = ReflectiveMemoryLayer(
                 original_layer=original_layer, global_input_dim=config.hidden_size,
-                memory_dim=self.memory_dim, num_memory_slots=
+                memory_dim=self.memory_dim, num_memory_slots=16, memory_num_heads=4,
                 global_state_storage=self.global_state_storage)
             parent_path = ".".join(self.target_layer_path.split('.')[:-1])
             setattr(self.get_submodule(parent_path), self.target_layer_path.split('.')[-1], custom_layer)
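The embedding hook added above uses the standard nn.Module.register_forward_hook API: the hook is called after every forward pass of the hooked module with (module, input, output). A standalone sketch of the same pattern on a plain embedding layer (everything except the API itself is illustrative):

```python
import torch
import torch.nn as nn

storage = {}

def embedding_hook(module, inputs, output):
    # Store a detached copy so nothing that reads this side channel can pull
    # gradients back through the embedding layer's graph.
    storage['embeds'] = output.detach()

embed = nn.Embedding(1000, 32)
handle = embed.register_forward_hook(embedding_hook)

tokens = torch.randint(0, 1000, (2, 5))
_ = embed(tokens)
print(storage['embeds'].shape)   # torch.Size([2, 5, 32])

handle.remove()                  # the hook can be removed when no longer needed
```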
@@ -184,13 +195,19 @@ class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
                 output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None,
                 ltm_state: Optional[torch.Tensor] = None):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+
+        # CRITICAL FIX: Don't clear global state storage completely, just reset embeds
+        # This prevents losing LTM state continuity
+        if 'embeds' in self.global_state_storage:
+            del self.global_state_storage['embeds']

         # *** FIX: Initialize LTM state if not provided, for both training and first step of inference ***
         if ltm_state is None:
             batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
             ltm_state = self._init_ltm_state(batch_size, self.device, self.dtype)
-
+
+        # CRITICAL FIX: Ensure LTM state is detached when stored
+        self.global_state_storage['ltm'] = ltm_state.detach() if ltm_state is not None else None

         outputs = self.model(
             input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
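With the LTM state initialized inside forward() and a detached copy written back to global_state_storage, the state can be threaded across batches in a truncated-BPTT style: each backward() only ever sees the current batch's graph. A hypothetical training loop (this is not the repo's main.py; the tiny model below just mimics a module that accepts and returns a state):

```python
import torch
import torch.nn as nn

class TinyStateModel(nn.Module):
    """Toy stand-in for a model that accepts and returns a memory state."""
    def __init__(self, dim=16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, ltm_state=None):
        if ltm_state is None:                         # first step: zero-initialize
            ltm_state = torch.zeros(x.shape[0], x.shape[-1])
        h = self.proj(x) + ltm_state.unsqueeze(1)     # broadcast state over sequence
        return h.pow(2).mean(), h.mean(dim=1)         # (loss-like scalar, new state)

model = TinyStateModel()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
ltm_state = None

for _ in range(3):                                    # stands in for batches from a DataLoader
    x = torch.randn(2, 8, 16)
    loss, new_state = model(x, ltm_state)
    loss.backward()
    opt.step()
    opt.zero_grad()
    # Detach at the batch boundary, mirroring what the model now does internally,
    # so the next backward() never re-enters this batch's freed graph.
    ltm_state = new_state.detach()
```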
@@ -207,7 +224,10 @@ class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
                             labels[..., 1:].contiguous().view(-1))
             # Note: Auxiliary losses from main.py are calculated outside the model forward pass.

+        # CRITICAL FIX: Ensure returned LTM state is detached
         new_ltm_state = self.global_state_storage.get('ltm', None)
+        if new_ltm_state is not None:
+            new_ltm_state = new_ltm_state.detach()

         if not return_dict:
             output = (logits,) + outputs[1:] + (new_ltm_state,)
@@ -215,4 +235,4 @@ class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):

         return CausalLMOutputWithLTM(
             loss=loss, logits=logits, past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states, attentions=outputs.attentions, ltm_state=new_ltm_state)
+            hidden_states=outputs.hidden_states, attentions=outputs.attentions, ltm_state=new_ltm_state)
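CausalLMOutputWithLTM itself is defined earlier in architectureV3.py and is not touched by this diff. For readers of the diff alone, an output class with the fields used above is typically declared by extending the standard transformers causal-LM output; a plausible sketch (an assumption, not the file's actual definition):

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

@dataclass
class CausalLMOutputWithLTM(CausalLMOutputWithPast):
    # Extra field carrying the detached long-term memory state between calls.
    ltm_state: Optional[torch.Tensor] = None
```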