JonasGeiping committed on
Commit 06ac94c · verified · 1 Parent(s): 61eaf33

Update raven_modeling_minimal.py

Files changed (1)
  1. raven_modeling_minimal.py +601 -302
raven_modeling_minimal.py CHANGED
@@ -6,9 +6,10 @@ import math
  from torch import Tensor
  from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention
  from torch.nn.attention import bias as attn_bias
  from dataclasses import dataclass
- from typing import Union, Optional, Any
-

  from .raven_config_minimal import RavenConfig
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
@@ -21,8 +22,6 @@ from transformers.generation.utils import GenerateDecoderOnlyOutput
  import torch.nn.functional as F
  from transformers import GenerationConfig

- torch.backends.cuda.enable_math_sdp(False)
-

  class RavenPreTrainedModel(PreTrainedModel):
  config_class = RavenConfig
@@ -38,9 +37,77 @@ class RavenPreTrainedModel(PreTrainedModel):
  _supports_static_cache = True
  _tp_plan = {}

  def _init_weights(self, module):
- if not torch.rand((1,)).is_meta:
- print("Random Initialization not implemented.")


  @dataclass
@@ -468,6 +535,9 @@ class SandwichBlock(torch.nn.Module):
  return x


  class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
  freqs_cis: torch.Tensor

@@ -498,13 +568,15 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
  ln_f=RMSNorm(config.n_embd, eps=config.norm_eps), # used twice :>
  )
  )
- self.emb_scale = config.init_values["embed_scale"]
  # Head
  self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
  if self.config.tie_embeddings:
  self.tie_weights()
  # rope
  self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)

  def get_input_embeddings(self):
  return self.transformer.wte
@@ -513,11 +585,9 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
  return self.lm_head

  def _precompute_freqs_cis(self):
- # can actually be a buffer now, and remains in fp32! (at least in the settings I tested)
- freqs_cis = precompute_freqs_cis(
  self.config.n_embd // self.config.num_attention_heads, self.config.block_size, self.config.rope_base, 1
  )
- return freqs_cis

  def compile_mask(
  self,
@@ -557,72 +627,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
557
  H=None,
558
  Q_LEN=seq_len,
559
  KV_LEN=kv_length,
560
- device=input_ids.device,
561
- )
562
-
563
- # # Define mask_mod function
564
- # def mask_mod(b, h, q_idx, kv_idx):
565
- # # Always apply causal constraint
566
- # is_causal = q_idx >= kv_idx
567
-
568
- # # Handle cache vs current tokens
569
- # is_cache = kv_idx < cache_len
570
- # current_idx = kv_idx - cache_len
571
-
572
- # # For cache: always valid; For current: check padding
573
- # not_pad = input_ids[b, current_idx] != pad_token_id
574
- # valid = is_cache | not_pad
575
-
576
- # # Apply attention mask if provided
577
- # if attention_mask is not None:
578
- # q_idx_curr = q_idx - cache_len
579
- # attn_valid = attention_mask[b, q_idx_curr, current_idx]
580
- # valid = valid & (is_cache | attn_valid)
581
-
582
- # return is_causal & valid
583
-
584
- # def mask_mod(b, h, q_idx, kv_idx):
585
- # is_causal = q_idx >= kv_idx
586
- # is_current = (kv_idx >= cache_len) & (kv_idx < kv_length)
587
- # current_idx = kv_idx - cache_len
588
-
589
- # is_valid = (~is_current) | (
590
- # (current_idx >= 0) & (current_idx < seq_len) & (input_ids != pad_token_id)[b, current_idx % seq_len]
591
- # )
592
-
593
- # return is_causal & is_valid
594
-
595
- # # Define mask_mod function
596
- # def mask_mod(b, h, q_idx, kv_idx):
597
- # # Always apply causal constraint
598
- # is_causal = q_idx >= kv_idx
599
-
600
- # # Handle cache vs current tokens
601
- # is_cache = kv_idx < cache_len
602
- # current_idx = kv_idx - cache_len
603
- # in_bounds = (current_idx >= 0) & (current_idx < seq_len)
604
-
605
- # # For cache: always valid; For current: check padding
606
- # not_pad = (input_ids[b, current_idx % seq_len] != pad_token_id) | ~in_bounds
607
- # valid = is_cache | (not_pad & in_bounds)
608
-
609
- # # Apply attention mask if provided
610
- # if attention_mask is not None:
611
- # q_idx_curr = q_idx - cache_len
612
- # q_in_bounds = (q_idx_curr >= 0) & (q_idx_curr < seq_len)
613
- # attn_valid = attention_mask[b, q_idx_curr % seq_len, current_idx % seq_len] | ~(in_bounds & q_in_bounds)
614
- # valid = valid & (is_cache | attn_valid)
615
-
616
- # return is_causal & valid
617
-
618
- # Create block mask
619
- block_mask = create_block_mask(
620
- mask_mod,
621
- B=batch_size,
622
- H=None,
623
- Q_LEN=seq_len,
624
- KV_LEN=kv_length,
625
- device=input_ids.device,
626
  )
627
 
628
  return block_mask
@@ -748,7 +753,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
748
 
749
  for grad_step in range(num_steps_with_grad):
750
  xk = x
751
- x, block_idx = self.core_block_forward(
752
  xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, num_steps_no_grad + grad_step
753
  )
754
  return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx # type: ignore # types broken in 2.6+
@@ -763,13 +768,73 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
763
  block_idx: torch.Tensor,
764
  current_step: int | Tensor,
765
  ):
 
766
  x = self._maybe_inject_noise(x, current_step)
767
  x = self.transformer.adapter(torch.cat([x, input_embeds.to(x.device)], dim=-1)) # type: ignore # types broken in 2.6+
768
  for block in self.transformer.core_block: # type: ignore # types broken in 2.6+
769
  block_idx += 1
770
  x = block(x, freqs_cis, block_idx, mask, past_key_values)
 
771
  return x, block_idx
772
 
773
  @torch.no_grad()
774
  def iterate_one_step(
775
  self,
@@ -865,61 +930,135 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
865
  input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
866
  return input_embeds, block_idx
867
 
868
- @torch._dynamo.disable(recursive=False) # type: ignore
869
- def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
870
- """Outputs are long tensors so that they can be passed through compiled functions"""
871
- t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
872
- s = self.config.mean_backprop_depth
873
- if torch.rand((1,)).is_meta: # annoying clause to make meta-tensor-based flop counting work
874
- # these values are only the mean TFLOPs of the randomized sampler
875
- # Note that this clause also breaks the contract, and returns ints in meta tensor mode
876
- return t, s # type: ignore
877
- if self.training:
878
- sigma = 0.5
879
- mu = math.log(t + s) - (sigma**2 / 2)
880
- rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma)
881
- p = torch.poisson(torch.tensor([rate], dtype=torch.float)) + 1
882
- n = torch.clamp(p - s, min=0)
883
- k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
884
- else:
885
- n, k = torch.as_tensor(self.config.mean_recurrence), torch.as_tensor(0)
886
 
887
- return n.to(dtype=torch.long), k.to(dtype=torch.long)
 
888
 
889
- def initialize_state(self, input_embeds, scale: float = 1.0):
890
- x = torch.randn_like(input_embeds)
891
- std = self.config.init_values["std"] * scale
892
- if std > 0:
893
- torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
894
- if self.emb_scale != 1:
895
- x = x * self.emb_scale
896
  else:
897
- x.zero_()
898
- return x
899
 
900
- def _maybe_inject_noise(self, x, current_step, renorm=False):
901
- if self.config.test_time_noise > 0:
902
- n = self.config.test_time_noise * self.config.init_values["std"] * self.emb_scale
903
- if self.config.test_time_noise_type == "geom":
904
- step1 = torch.as_tensor(current_step + 1, device=x.device) # need to cast for compile
905
- x = x * (1 - n / step1) + torch.randn_like(x) * n / step1
906
- elif self.config.test_time_noise_type == "sqrt":
907
- step1sqrt = torch.as_tensor(current_step + 1, device=x.device).sqrt() # need to cast for compile
908
- x = x * (1 - n / step1sqrt) + torch.randn_like(x) * n / step1sqrt
909
- elif self.config.test_time_noise_type == "line":
910
- noise = max(n, (self.config.mean_recurrence - current_step) / self.config.mean_recurrence) # type: ignore
911
- x = x * (1 - noise) + torch.randn_like(x) * noise
912
- elif self.config.test_time_noise_type == "chi":
913
- noise = 2 * torch.rand(1, device=x.device, dtype=x.dtype) * n
914
- x = x * (1 - noise) + torch.randn_like(x) * noise
915
- elif self.config.test_time_noise_type == "fixed":
916
- x = x * (1 - n) + torch.randn_like(x) * n
917
- else:
918
- raise ValueError()
919
 
920
- if renorm:
921
- x = self.transformer.core_block[-1].norm_4(x) # type: ignore moduledict types still broken in pytorch
922
- return x
923
 
924
  def prepare_inputs_for_generation(
925
  self,
@@ -971,11 +1110,11 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
971
  def generate(self, *args, **kwargs):
972
  """Dispatcher - use HF generate in all normal cases."""
973
  self.generation_config = args[1] if len(args) > 1 else self.generation_config
974
- if any(k in kwargs for k in ("criterion", "exit_threshold")):
975
- # print("Dispatching to custom generate_adaptive function call")
976
  return self.generate_with_adaptive_compute(*args, **kwargs)
977
  elif "continuous_compute" in kwargs:
978
- # print("Dispatching to custom generate_minimal function call")
979
  return self.generate_minimal(*args, **kwargs)
980
  else:
981
  return super().generate(*args, **kwargs)
@@ -1013,7 +1152,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
1013
  lookup_strategy=cache_lookup_strategy,
1014
  )
1015
  model_kwargs["use_cache"] = True
1016
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
1017
  return model_kwargs, generation_config, max_new_tokens
1018
 
1019
  @torch.no_grad()
@@ -1030,7 +1169,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
1030
  ) -> Union[torch.Tensor, dict[str, Any]]:
1031
  """Minimal single-sequence generation. Template for more complicated generate tasks"""
1032
  model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1033
- input_ids, generation_config, cache_lookup_strategy
1034
  )
1035
  stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1036
  unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
@@ -1093,15 +1232,25 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
1093
  tokenizer=None,
1094
  streamer=None,
1095
  continuous_compute=False, # warm-start state / continuous CoT
1096
- criterion="none", # off by default, turn on by choosing an exit criterion
1097
  exit_threshold: Union[str, float, int] = "auto",
1098
  init_scale: float = 1.0,
1099
  cache_lookup_strategy: str = "full",
1100
  **model_kwargs,
1101
  ) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
1102
  """
1103
  Generate tokens with adaptive compute. This is NOT the most efficient implementation.
1104
  For batches, on each token, we iterate until the entire batch finishes.
1105
  """
1106
  model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1107
  input_ids, generation_config, cache_lookup_strategy, model_kwargs
@@ -1120,8 +1269,11 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
1120
  # Track which sequences have finished (using unfinished_sequences to match generate_minimal)
1121
  unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
1122
 
1123
  # Generate tokens
1124
- for _ in range(max_new_tokens):
1125
  # Adaptive compute forward
1126
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1127
  aux_inputs = {
@@ -1134,38 +1286,20 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
1134
  else model_kwargs["input_states"]
1135
  )
1136
 
1137
  # Initialize criterion tracking for each sequence in batch
1138
  exit_values_per_seq = [[] for _ in range(batch_size)]
1139
  compute_steps_per_seq = [0] * batch_size
1140
  exit_reached = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
1141
 
1142
- # Set up criterions based on selected strategy
1143
- if criterion == "entropy-diff":
1144
- entropy = torch.ones(batch_size, device=input_ids.device) * 100.0
1145
- exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
1146
- elif criterion == "latent-diff":
1147
- exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold)
1148
- elif "kl" in criterion:
1149
- V = self.config.padded_vocab_size
1150
- log_probs = ((1 / V) * torch.ones(batch_size, V, dtype=torch.float, device=input_ids.device)).log()
1151
- if criterion == "minp-kl":
1152
- exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold)
1153
- else:
1154
- exit_threshold = 5e-4 if exit_threshold == "auto" else float(exit_threshold)
1155
- elif criterion == "argmax-stability":
1156
- stable_for_n_steps = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)
1157
- current_argmax = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * -1
1158
- exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold)
1159
- elif criterion == "none":
1160
- exit_threshold = 1.0 if exit_threshold == "auto" else float(exit_threshold)
1161
- else:
1162
- raise ValueError("Invalid adaptive compute strategy.")
1163
-
1164
- next_token_logits = None
1165
 
1166
  # Iterate through compute steps
1167
  for compute_step in range(max_steps):
1168
- prev_latents = current_latents.clone()
1169
  current_latents, block_idx, _ = self.iterate_one_step(
1170
  embedded_inputs,
1171
  current_latents,
@@ -1174,94 +1308,70 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
1174
  current_step=compute_step,
1175
  )
1176
 
1177
- if _ > 0: # do not exit in prefill
1178
- # Check exit condition for each sequence in batch
1179
- if criterion == "entropy-diff":
1180
- prev_entropy = entropy
1181
- outputs = self.predict_from_latents(current_latents, **aux_inputs)
1182
- logits: torch.Tensor = outputs.logits # type: ignore
1183
- probs = F.softmax(logits[:, -1, :], dim=-1)
1184
- entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
1185
- exit_values = (entropy - prev_entropy).abs()
1186
- elif criterion == "latent-diff":
1187
- norm_diff = (prev_latents - current_latents).norm(dim=-1) / current_latents.norm(dim=-1)
1188
- exit_values = norm_diff.mean(dim=-1)
1189
- elif "kl" in criterion:
1190
- outputs = self.predict_from_latents(current_latents, **aux_inputs)
1191
- logits: torch.Tensor = outputs.logits # type: ignore
1192
- prev_log_probs = log_probs
1193
- if criterion == "minp-kl":
1194
- probs = F.softmax(logits[:, -1, :].float(), dim=-1)
1195
- max_probs = probs.max(dim=-1, keepdim=True)[0]
1196
- probs_mask = probs < (0.1 * max_probs)
1197
- masked_probs = probs.clone()
1198
- masked_probs[probs_mask] = 1 / V
1199
- probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
1200
- log_probs = probs.log()
1201
- else:
1202
- log_probs = F.log_softmax(logits[:, -1, :].float(), dim=-1)
1203
- exit_values = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
1204
- elif criterion == "argmax-stability":
1205
- prev_argmax = current_argmax
1206
  outputs = self.predict_from_latents(current_latents, **aux_inputs)
1207
- logits: torch.Tensor = outputs.logits # type: ignore
1208
- current_argmax = logits[:, -1, :].argmax(dim=-1)
1209
- stable_for_n_steps = torch.where(
1210
- current_argmax == prev_argmax, stable_for_n_steps + 1, torch.zeros_like(stable_for_n_steps)
1211
- )
1212
- exit_values = stable_for_n_steps
1213
- elif criterion == "none":
1214
- exit_values = torch.ones(batch_size, device=input_ids.device) * 2.0 * exit_threshold
1215
-
1216
- # Record values and check exits for each sequence
1217
  for i in range(batch_size):
1218
- if not exit_reached[i] and unfinished_sequences[i].bool():
1219
- exit_values_per_seq[i].append(exit_values[i].item())
1220
 
1221
- # Check for new exits, respecting unfinished_sequences
1222
- new_exits = (
1223
- exit_values < exit_threshold
1224
- if criterion != "argmax-stability"
1225
- else exit_values >= exit_threshold
1226
- )
1227
- new_exits = new_exits & ~exit_reached & unfinished_sequences.bool()
1228
-
1229
- if new_exits.any():
1230
- exit_reached = exit_reached | new_exits
1231
- if criterion == "latent-diff":
1232
- # Normally we don't compute the output for latent-diff, but when there is an exit,
1233
- # we need to compute and save the output
1234
- outputs = self.predict_from_latents(current_latents, **aux_inputs)
1235
- logits: torch.Tensor = outputs.logits # type: ignore
1236
- if next_token_logits is None:
1237
- next_token_logits = logits[:, -1, :].to(**logit_type) # type: ignore
1238
- else:
1239
- for i in range(batch_size):
1240
- if new_exits[i]:
1241
- next_token_logits[i] = logits[i, -1, :].to(**logit_type) # type: ignore
1242
- for i in range(batch_size):
1243
- if new_exits[i]:
1244
- compute_steps_per_seq[i] = compute_step + 1
1245
-
1246
- # If all sequences have exited or finished, break early
1247
- if (exit_reached | ~unfinished_sequences.bool()).all():
1248
- break
1249
- # This else is if the for loop finished without breaking
1250
  else:
1251
- outputs = self.predict_from_latents(current_latents, **aux_inputs)
 
1252
 
1253
  # For sequences that didn't exit early, use the final logits
1254
  if next_token_logits is None:
1255
  next_token_logits = outputs.logits[:, -1, :].to(**logit_type) # type: ignore
1256
  else:
1257
  for i in range(batch_size):
1258
  if not exit_reached[i] and unfinished_sequences[i].bool():
1259
  next_token_logits[i] = outputs.logits[i, -1, :].to(**logit_type) # type: ignore
1260
  compute_steps_per_seq[i] = max_steps
1261
-
1262
  # Save latent states for continuous compute if enabled
1263
  if continuous_compute:
1264
- model_kwargs["input_states"] = current_latents[:, -1:, :]
1265
 
1266
  # Record compute steps for this token generation
1267
  compute_steps.append([compute_steps_per_seq, exit_values_per_seq])
@@ -1276,7 +1386,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
1276
  streamer.put(next_token.cpu())
1277
 
1278
  # Update model kwargs for next iteration
1279
- model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
1280
 
1281
  # Check for stop tokens and update unfinished sequences
1282
  for i in range(batch_size):
@@ -1309,62 +1419,6 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
1309
  )
1310
  return input_ids
1311
 
1312
- def _get_stops(self, generation_config, tokenizer, model_kwargs):
1313
- stop_tokens = {65504, 65505, 65508} # begin_text, end_text, end_turn
1314
- if generation_config.eos_token_id is not None:
1315
- stop_tokens.add(generation_config.eos_token_id)
1316
- if "stopping_criteria" in model_kwargs and tokenizer is None:
1317
- tokenizer = model_kwargs["stopping_criteria"][0].tokenizer
1318
- if hasattr(generation_config, "stop_strings") and tokenizer and generation_config.stop_strings:
1319
- for s in generation_config.stop_strings:
1320
- token_id = tokenizer(s, add_special_tokens=False)["input_ids"][0]
1321
- stop_tokens.add(token_id)
1322
- return torch.tensor(list(stop_tokens))
1323
-
1324
- def _sample_next_token(self, next_token_logits, generation_config):
1325
- """Helper function to sample the next token."""
1326
- if generation_config.do_sample:
1327
- if generation_config.temperature:
1328
- next_token_logits = next_token_logits.float() / generation_config.temperature
1329
-
1330
- probs = F.softmax(next_token_logits, dim=-1)
1331
-
1332
- # Apply top_k
1333
- if generation_config.top_k:
1334
- top_k_values, _ = torch.topk(probs, generation_config.top_k, dim=-1)
1335
- min_values = top_k_values[:, -1].unsqueeze(-1).expand_as(probs)
1336
- probs = torch.where(probs < min_values, torch.zeros_like(probs), probs)
1337
-
1338
- # Apply top_p (nucleus sampling)
1339
- if generation_config.top_p:
1340
- sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
1341
- cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
1342
-
1343
- # Create mask for probs to keep
1344
- remove_indices = cumulative_probs > generation_config.top_p
1345
- remove_indices[:, 0] = False # Keep at least the top probability
1346
-
1347
- # Convert sorted indices mask back to original indices mask
1348
- mask = torch.zeros_like(probs, dtype=torch.bool)
1349
- for i in range(probs.shape[0]):
1350
- mask[i, sorted_indices[i, remove_indices[i]]] = True
1351
-
1352
- probs = torch.where(mask, torch.zeros_like(probs), probs)
1353
-
1354
- # Apply min_p
1355
- if generation_config.min_p:
1356
- max_probs = probs.max(dim=-1, keepdim=True)[0]
1357
- min_p_threshold = generation_config.min_p * max_probs
1358
- probs = torch.where(probs < min_p_threshold, torch.zeros_like(probs), probs)
1359
-
1360
- # Renormalize probabilities
1361
- probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-10)
1362
-
1363
- # Sample from the distribution
1364
- return torch.multinomial(probs, num_samples=1)
1365
- else:
1366
- return torch.argmax(next_token_logits, dim=-1, keepdim=True)
1367
-
1368
  @torch.no_grad()
1369
  def generate_speculative(
1370
  self,
@@ -1546,22 +1600,69 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
1546
  )
1547
  return input_ids
1548
 
1549
- def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
1550
- probs = torch.softmax(logits.float(), dim=-1)
1551
- prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1)
1552
- residual_diff = (x - latent_states).norm(dim=-1)
1553
- rel_residual = residual_diff / latent_states.norm(dim=-1)
1554
- stats = {
1555
- "entropy": prob_entropy,
1556
- "residual_diff": residual_diff,
1557
- "rel_residual": rel_residual,
1558
- "num_steps_no_grad": num_steps_no_grad,
1559
- "num_steps_with_grad": num_steps_with_grad,
1560
- }
1561
- return stats
1562
 
1563
 
1564
- #################################### Utils #######################################################################
1565
  def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1):
1566
  with torch.autocast("cuda", enabled=False):
1567
  inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
@@ -1587,6 +1688,204 @@ def apply_rotary_emb_complex_like(q: Tensor, k: Tensor, freqs_cis: Tensor) -> tu
1587
  return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) # type: ignore
1588
 
1589
 
1590
  #################################### HF registration ############################################################
1591
 
1592
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 
6
  from torch import Tensor
7
  from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention
8
  from torch.nn.attention import bias as attn_bias
9
+ from torch.utils.checkpoint import checkpoint
10
  from dataclasses import dataclass
11
+ from typing import Union, Optional, Any, Tuple, Callable, List
12
+ from functools import cache, cached_property
13
 
14
  from .raven_config_minimal import RavenConfig
15
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
 
22
  import torch.nn.functional as F
23
  from transformers import GenerationConfig
24
 
25
 
26
  class RavenPreTrainedModel(PreTrainedModel):
27
  config_class = RavenConfig
 
37
  _supports_static_cache = True
38
  _tp_plan = {}
39
 
40
+ @cache
41
+ def _init_func(self, dim, num_layers):
42
+ return {
43
+ "std": math.sqrt(2 / (5 * dim)),
44
+ "out_proj": math.sqrt(2 / (5 * dim)) / math.sqrt(2 * num_layers),
45
+ "embedding": math.sqrt(2 / (5 * dim)),
46
+ "embed_scale": math.sqrt(dim),
47
+ }
48
+
49
+ @property
50
+ def emb_scale(self):
51
+ return self._init_func(self.config.n_embd, self.config.effective_expected_depth)["embed_scale"]
52
+
53
+ def _normal_(self, tensor, std):
54
+ return torch.nn.init.trunc_normal_(tensor, mean=0.0, std=std, a=-3 * std, b=3 * std)
55
+
56
+ @torch.no_grad()
57
+ def init_qkv(self, qkv_tensor, init_fn, qk_std, v_std, dim, head_dim):
58
+ s = qkv_tensor.shape[0]
59
+ n_kv_heads = (s - dim) // (2 * head_dim)
60
+ shapes = [dim, n_kv_heads * head_dim, n_kv_heads * head_dim]
61
+
62
+ Q, K, V = (
63
+ qkv_tensor.new_empty([shapes[0], dim]),
64
+ qkv_tensor.new_empty([shapes[1], dim]),
65
+ qkv_tensor.new_empty([shapes[2], dim]),
66
+ )
67
+ init_fn(Q, qk_std)
68
+ init_fn(K, qk_std)
69
+ init_fn(V, v_std)
70
+ qkv_tensor.data.copy_(torch.cat([Q, K, V], dim=0).contiguous())
71
+
72
+ @torch.no_grad()
73
+ def init_glu(self, glu_tensor, init_fn, w1_std, w2_std):
74
+ g, h = glu_tensor.shape
75
+ W1, W2 = (
76
+ glu_tensor.new_empty([g // 2, h]),
77
+ glu_tensor.new_empty([g // 2, h]),
78
+ )
79
+ init_fn(W1, w1_std)
80
+ init_fn(W2, w2_std)
81
+ glu_tensor.data.copy_(torch.cat([W1, W2], dim=0).contiguous())
82
+
83
+ @cached_property
84
+ def _full_name_of_module_lookup(self):
85
+ return {id(m): n for n, m in self.named_modules()}
86
+
87
+ @torch.no_grad()
88
  def _init_weights(self, module):
89
+ _init_values = self._init_func(self.config.n_embd, self.config.effective_expected_depth)
90
+ name = self._full_name_of_module_lookup[id(module)]
91
+ if isinstance(module, RMSNorm):
92
+ torch.nn.init.ones_(module.weight)
93
+ elif isinstance(module, torch.nn.Linear):
94
+ if "Wqkv" in name:
95
+ self.init_qkv(
96
+ module.weight,
97
+ self._normal_,
98
+ float(_init_values["std"]),
99
+ float(_init_values["std"]),
100
+ self.config.n_embd,
101
+ self.config.head_dim,
102
+ )
103
+ elif "fc" in name:
104
+ self.init_glu(module.weight, self._normal_, float(_init_values["std"]), float(_init_values["out_proj"]))
105
+ elif "mlp.proj" in name or "attn.proj" in name:
106
+ self._normal_(module.weight, std=float(_init_values["out_proj"]))
107
+ elif "adapter" in name or "lm_head" in name:
108
+ self._normal_(module.weight, std=float(_init_values["std"]))
109
+ elif isinstance(module, torch.nn.Embedding):
110
+ self._normal_(module.weight, std=float(_init_values["embedding"]))
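For intuition, a small standalone sketch of the scale values this init scheme produces (the dim and num_layers values below are hypothetical, purely for illustration, not read from the model config):
import math
dim, num_layers = 4096, 32                   # hypothetical width/depth
std = math.sqrt(2 / (5 * dim))               # ≈ 0.0099: qk/v, embeddings, adapter, lm_head
out_proj = std / math.sqrt(2 * num_layers)   # ≈ 0.0012: attn/mlp output projections
embed_scale = math.sqrt(dim)                 # = 64.0: the "embed_scale" / emb_scale value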
111
 
112
 
113
  @dataclass
 
535
  return x
536
 
537
 
538
+ #################################### Main Model ##################################################################
539
+
540
+
541
  class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
542
  freqs_cis: torch.Tensor
543
 
 
568
  ln_f=RMSNorm(config.n_embd, eps=config.norm_eps), # used twice :>
569
  )
570
  )
 
571
  # Head
572
  self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
573
  if self.config.tie_embeddings:
574
  self.tie_weights()
575
  # rope
576
  self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
577
+ self.gradient_checkpointing = False
578
+ # Call weight init through HF post init:
579
+ self.post_init()
580
 
581
  def get_input_embeddings(self):
582
  return self.transformer.wte
 
585
  return self.lm_head
586
 
587
  def _precompute_freqs_cis(self):
588
+ return precompute_freqs_cis(
 
589
  self.config.n_embd // self.config.num_attention_heads, self.config.block_size, self.config.rope_base, 1
590
  )
 
591
 
592
  def compile_mask(
593
  self,
 
627
  H=None,
628
  Q_LEN=seq_len,
629
  KV_LEN=kv_length,
630
+ device=str(input_ids.device),
631
  )
632
 
633
  return block_mask
 
753
 
754
  for grad_step in range(num_steps_with_grad):
755
  xk = x
756
+ x, block_idx = self._maybe_checkpoint_core_block(
757
  xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, num_steps_no_grad + grad_step
758
  )
759
  return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx # type: ignore # types broken in 2.6+
 
768
  block_idx: torch.Tensor,
769
  current_step: int | Tensor,
770
  ):
771
+ block_idx = block_idx.detach().clone() # only included to keep torch checkpointing happy
772
  x = self._maybe_inject_noise(x, current_step)
773
  x = self.transformer.adapter(torch.cat([x, input_embeds.to(x.device)], dim=-1)) # type: ignore # types broken in 2.6+
774
  for block in self.transformer.core_block: # type: ignore # types broken in 2.6+
775
  block_idx += 1
776
  x = block(x, freqs_cis, block_idx, mask, past_key_values)
777
+
778
  return x, block_idx
779
 
780
+ @torch._dynamo.disable(recursive=False) # type: ignore
781
+ def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
782
+ """Outputs are long tensors so that they can be passed through compiled functions"""
783
+ t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
784
+ s = self.config.mean_backprop_depth
785
+ if torch.rand((1,)).is_meta: # annoying clause to make meta-tensor-based flop counting work
786
+ # these values are only the mean TFLOPs of the randomized sampler
787
+ # Note that this clause also breaks the contract, and returns ints in meta tensor mode
788
+ return t, s # type: ignore
789
+ if self.training:
790
+ sigma = 0.5
791
+ mu = math.log(t + s) - (sigma**2 / 2)
792
+ rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma)
793
+ p = torch.poisson(torch.tensor([rate], dtype=torch.float)) + 1
794
+ n = torch.clamp(p - s, min=0)
795
+ k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
796
+ else:
797
+ n, k = torch.as_tensor(self.config.mean_recurrence), torch.as_tensor(0)
798
+
799
+ return n.to(dtype=torch.long), k.to(dtype=torch.long)
800
+
801
+ def initialize_state(self, input_embeds, scale: float = 1.0):
802
+ x = torch.randn_like(input_embeds)
803
+ std = self.config.init_values["std"] * scale
804
+ if std > 0:
805
+ torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
806
+ if self.emb_scale != 1:
807
+ x = x * self.emb_scale
808
+ else:
809
+ x.zero_()
810
+ return x
811
+
812
+ def _maybe_inject_noise(self, x, current_step, renorm=True):
813
+ if self.config.test_time_noise > 0:
814
+ n = self.config.test_time_noise * self.config.init_values["std"] * self.emb_scale
815
+ if self.config.test_time_noise_type == "geom":
816
+ step1 = torch.as_tensor(current_step + 1, device=x.device) # need to cast for compile
817
+ x = x * (1 - n / step1) + torch.randn_like(x) * n / step1
818
+ elif self.config.test_time_noise_type == "sqrt":
819
+ step1sqrt = torch.as_tensor(current_step + 1, device=x.device).sqrt() # need to cast for compile
820
+ x = x * (1 - n / step1sqrt) + torch.randn_like(x) * n / step1sqrt
821
+ elif self.config.test_time_noise_type == "line":
822
+ noise = max(n, (self.config.mean_recurrence - current_step) / self.config.mean_recurrence) # type: ignore
823
+ x = x * (1 - noise) + torch.randn_like(x) * noise
824
+ elif self.config.test_time_noise_type == "chi":
825
+ noise = 2 * torch.rand(1, device=x.device, dtype=x.dtype) * n
826
+ x = x * (1 - noise) + torch.randn_like(x) * noise
827
+ elif self.config.test_time_noise_type == "fixed":
828
+ x = x * (1 - n) + torch.randn_like(x) * n
829
+ else:
830
+ raise ValueError()
831
+
832
+ if renorm:
833
+ x = self.transformer.core_block[-1].norm_4(x) # type: ignore moduledict types still broken in pytorch
834
+ return x
835
+
836
+ """ ------------------ Alternative interfaces into the model forward ---------------------------------------- """
837
+
838
  @torch.no_grad()
839
  def iterate_one_step(
840
  self,
 
930
  input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
931
  return input_embeds, block_idx
932
 
933
+ @torch.no_grad()
934
+ def _prefill_with_varied_exit_steps(
935
+ self,
936
+ input_ids: torch.Tensor,
937
+ exit_evaluator: "PerIterationExitEvaluator",
938
+ past_key_values: Optional[ValidCache] = None,
939
+ init_scale: float = 1.0,
940
+ **kwargs,
941
+ ) -> Tuple[torch.Tensor, ValidCache, List[int]]:
942
+ """
943
+ Note that this is the opposite of a real prefill; it goes token by token and can adaptively exit on each.
944
+ Use for scientific experiments.
945
+ """
946
+ # currently the cache doesn't support batching with adaptive compute
947
+ assert input_ids.shape[0] == 1
 
 
949
+ if past_key_values is None:
950
+ past_key_values = HuginnDynamicCache()
951
+ attention_mask = None
952
+ output = torch.empty(
953
+ (input_ids.shape[0], 0, self.config.vocab_size), device=input_ids.device, dtype=torch.float
954
+ )
955
+ compute_steps = []
956
+ for pos in range(input_ids.shape[1]):
957
+ aux_inputs = {
958
+ "cache_position": pos,
959
+ "past_key_values": past_key_values,
960
+ "attention_mask": attention_mask,
961
+ }
962
+ freqs_cis = self.freqs_cis[:, pos]
963
+ embedded_inputs, block_idx = self.embed_inputs(input_ids[:, pos].unsqueeze(1), **aux_inputs)
964
 
965
+ current_latents = self.initialize_state(embedded_inputs, scale=init_scale)
966
+ exit_evaluator.init(current_latents)
967
+
968
+ # Main recurrence
969
+ for compute_step in range(self.config.mean_recurrence):
970
+ current_latents, block_idx, _ = self.iterate_one_step(
971
+ embedded_inputs,
972
+ current_latents,
973
+ block_idx=block_idx,
974
+ **aux_inputs,
975
+ current_step=compute_step,
976
+ )
977
+ new_exits, _, _ = exit_evaluator.check(self, current_latents, aux_inputs)
978
+ if new_exits.any():
979
+ break
980
+ compute_steps.append(compute_step + 1)
981
+
982
+ x = self.transformer.ln_f(current_latents) # type: ignore
983
+
984
+ # Coda layers
985
+ block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long) # use negative indices for head
986
+ for block in self.transformer.coda: # type: ignore # types broken in 2.6+
987
+ block_idx -= 1
988
+ x = block(x, freqs_cis, block_idx, attention_mask, past_key_values)
989
+
990
+ x = self.transformer.ln_f(x) # type: ignore
991
+ logits = self.lm_head(x).float()
992
+ output = torch.cat([output, logits], dim=1)
993
+ return output, past_key_values, compute_steps # type: ignore
994
+
995
+ @torch.no_grad()
996
+ def forward_with_adaptive_compute(
997
+ self,
998
+ input_ids: torch.Tensor,
999
+ exit_evaluator: "PerIterationExitEvaluator",
1000
+ labels: Optional[torch.Tensor] = None,
1001
+ past_key_values: Optional[ValidCache] = None,
1002
+ output_details: dict = {
1003
+ "return_logits": True,
1004
+ "return_latents": True,
1005
+ "return_head": False,
1006
+ "return_stats": False,
1007
+ },
1008
+ init_scale: float = 1.0,
1009
+ **kwargs,
1010
+ ) -> CausalLMOutputRecurrentLatents:
1011
+ """This forward call does not make use of the causal nature of transformers; it runs token by token!
1012
+ Do not use this function for anything other than scientific experiments with adaptive compute!
1013
+ """
1014
+ logits, past_key_values, compute_steps = self._prefill_with_varied_exit_steps(
1015
+ input_ids, exit_evaluator, past_key_values, init_scale
1016
+ )
1017
+ if labels is not None:
1018
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1))
1019
+ log_ppl = loss.clone().detach()
1020
  else:
1021
+ loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0)
 
1022
 
1023
+ return CausalLMOutputRecurrentLatents(
1024
+ loss=loss,
1025
+ log_ppl=log_ppl,
1026
+ logits=logits if output_details["return_logits"] else None,
1027
+ past_key_values=None,
1028
+ hidden_states=None,
1029
+ latent_states=None,
1030
+ attention_maps=None,
1031
+ stats={"compute_steps": compute_steps},
1032
+ )
1033
 
1034
+ def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
1035
+ probs = torch.softmax(logits.float(), dim=-1)
1036
+ prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1)
1037
+ residual_diff = (x - latent_states).norm(dim=-1)
1038
+ rel_residual = residual_diff / latent_states.norm(dim=-1)
1039
+ stats = {
1040
+ "entropy": prob_entropy,
1041
+ "residual_diff": residual_diff,
1042
+ "rel_residual": rel_residual,
1043
+ "num_steps_no_grad": num_steps_no_grad,
1044
+ "num_steps_with_grad": num_steps_with_grad,
1045
+ }
1046
+ return stats
1047
+
1048
+ def _maybe_checkpoint_core_block(self, *args, **kwargs) -> tuple[Tensor, Tensor]:
1049
+ if self.gradient_checkpointing:
1050
+ return checkpoint(
1051
+ self.core_block_forward,
1052
+ *args,
1053
+ use_reentrant=False,
1054
+ preserve_rng_state=False,
1055
+ determinism_check="none",
1056
+ **kwargs,
1057
+ ) # type: ignore
1058
+ else:
1059
+ return self.core_block_forward(*args)
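A brief usage sketch for the flag introduced above (hedged: the loading call and labels-based loss interface are assumed HF conventions, not shown in this diff). Enabling gradient checkpointing trades extra compute for memory by recomputing the recurrent core block during backward:
model = AutoModelForCausalLM.from_pretrained("...", trust_remote_code=True)  # placeholder path
model.gradient_checkpointing = True  # recompute core_block_forward activations in backward
model.train()
input_ids = torch.randint(0, 1000, (1, 32))  # toy batch
loss = model(input_ids, labels=input_ids).loss  # assumes the usual labels-based loss interface
loss.backward()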
1060
+
1061
+ """------------------------------------------Generation Utilities from here----------------------------------"""
1062
 
1063
  def prepare_inputs_for_generation(
1064
  self,
 
1110
  def generate(self, *args, **kwargs):
1111
  """Dispatcher - use HF generate in all normal cases."""
1112
  self.generation_config = args[1] if len(args) > 1 else self.generation_config
1113
+ if any(k in kwargs for k in ("criterion", "exit_threshold", "exit_evaluator")):
 
1114
  return self.generate_with_adaptive_compute(*args, **kwargs)
1115
+ elif any(k in kwargs for k in ("draft_steps", "lookahead_for_draft", "verification_threshold")):
1116
+ return self.generate_speculative(*args, **kwargs)
1117
  elif "continuous_compute" in kwargs:
 
1118
  return self.generate_minimal(*args, **kwargs)
1119
  else:
1120
  return super().generate(*args, **kwargs)
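To illustrate the dispatch rules above (a sketch; argument values are arbitrary): criterion/exit_threshold/exit_evaluator route to generate_with_adaptive_compute, draft_steps/lookahead_for_draft/verification_threshold to generate_speculative, continuous_compute to generate_minimal, and everything else falls through to HF generate.
out = model.generate(input_ids, generation_config, criterion="latent-diff", exit_threshold="auto")  # adaptive compute
out = model.generate(input_ids, generation_config, draft_steps=8)  # speculative decoding
out = model.generate(input_ids, generation_config)  # plain HF generate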
 
1152
  lookup_strategy=cache_lookup_strategy,
1153
  )
1154
  model_kwargs["use_cache"] = True
1155
+ model_kwargs = self._get_initial_cache_position(input_ids.shape[1], input_ids.device, model_kwargs)
1156
  return model_kwargs, generation_config, max_new_tokens
1157
 
1158
  @torch.no_grad()
 
1169
  ) -> Union[torch.Tensor, dict[str, Any]]:
1170
  """Minimal single-sequence generation. Template for more complicated generate tasks"""
1171
  model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1172
+ input_ids, generation_config, cache_lookup_strategy, model_kwargs
1173
  )
1174
  stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1175
  unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
 
1232
  tokenizer=None,
1233
  streamer=None,
1234
  continuous_compute=False, # warm-start state / continuous CoT
1235
+ criterion="none", # adaptive compute is off by default, turn on by choosing an exit criterion
1236
  exit_threshold: Union[str, float, int] = "auto",
1237
  init_scale: float = 1.0,
1238
  cache_lookup_strategy: str = "full",
1239
+ do_not_exit_in_prefill: bool = False,
1240
+ min_steps: int = 0,
1241
+ check_criterion_every_n_steps=1,
1242
+ exit_evaluator: "Optional[PerIterationExitEvaluator]" = None, # optional plugin of a new exit eval object
1243
  **model_kwargs,
1244
  ) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
1245
  """
1246
  Generate tokens with adaptive compute. This is NOT the most efficient implementation.
1247
  For batches, on each token, we iterate until the entire batch finishes.
1248
+ Note: While this method can be run batched and will produce sensible results, it cannot be used to evaluate
+ the success of adaptive compute methods, which should only ever be benchmarked with batch_size=1.
+ This is because the KV-cache entries are necessarily batched, so they contain entries for the sequence
+ with the largest number of compute steps in the batch; these KV states, which would not have been computed
+ if only one (short-compute) sequence were in the batch, are picked up by later compute steps,
+ making early exits look better than they are.
1254
  """
1255
  model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1256
  input_ids, generation_config, cache_lookup_strategy, model_kwargs
 
1269
  # Track which sequences have finished (using unfinished_sequences to match generate_minimal)
1270
  unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
1271
 
1272
+ if exit_evaluator is None:
1273
+ exit_evaluator = get_adaptive_exit_evaluator(self, criterion, exit_threshold)
1274
+
1275
  # Generate tokens
1276
+ for token_step_in_sequence in range(max_new_tokens):
1277
  # Adaptive compute forward
1278
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1279
  aux_inputs = {
 
1286
  else model_kwargs["input_states"]
1287
  )
1288
 
1289
+ # Initialize next_states for continuous compute
1290
+ if continuous_compute:
1291
+ next_states = current_latents[:, -1:, :].clone()
1292
+
1293
  # Initialize criterion tracking for each sequence in batch
1294
  exit_values_per_seq = [[] for _ in range(batch_size)]
1295
  compute_steps_per_seq = [0] * batch_size
1296
  exit_reached = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
1297
 
1298
+ outputs, next_token_logits = None, None
1299
+ exit_evaluator.init(current_latents)
1300
 
1301
  # Iterate through compute steps
1302
  for compute_step in range(max_steps):
 
1303
  current_latents, block_idx, _ = self.iterate_one_step(
1304
  embedded_inputs,
1305
  current_latents,
 
1308
  current_step=compute_step,
1309
  )
1310
 
1311
+ # Skip checking exit conditions if min_steps not met, or not checking this step, or in prefill
1312
+ if (
1313
+ compute_step < min_steps
1314
+ or (compute_step - min_steps) % check_criterion_every_n_steps != 0
1315
+ or (do_not_exit_in_prefill and token_step_in_sequence == 0)
1316
+ ):
1317
+ continue
1318
+
1319
+ # Otherwise check for new exits, potentially by evaluating the coda:
1320
+ new_exits, outputs, exit_values = exit_evaluator.check(self, current_latents, aux_inputs)
1321
+
1322
+ # Record values and check exits for each sequence
1323
+ for i in range(batch_size):
1324
+ if not exit_reached[i] and unfinished_sequences[i].bool():
1325
+ exit_values_per_seq[i].append(exit_values[i].item())
1326
+
1327
+ new_exits = new_exits & ~exit_reached & unfinished_sequences.bool()
1328
+
1329
+ if new_exits.any():
1330
+ exit_reached = exit_reached | new_exits
1331
+ if outputs is not None:
1332
+ logits = outputs.logits
1333
+ else:
1334
+ # For latent-based criteria, compute outputs when we need them
1335
  outputs = self.predict_from_latents(current_latents, **aux_inputs)
1336
+ logits = outputs.logits
1337
+
1338
+ if next_token_logits is None:
1339
+ next_token_logits = logits[:, -1, :].to(**logit_type) # type: ignore
1340
+ else:
1341
+ next_token_logits[new_exits] = logits[new_exits, -1, :].to(**logit_type) # type: ignore
1342
+
1343
  for i in range(batch_size):
1344
+ if new_exits[i]:
1345
+ compute_steps_per_seq[i] = compute_step + 1
1346
 
1347
+ # Update continuous compute states for newly exited sequences
1348
+ if continuous_compute:
1349
+ next_states[new_exits] = current_latents[new_exits, -1:, :]
1350
+
1351
+ # If all sequences have exited or finished, break early
1352
+ if (exit_reached | ~unfinished_sequences.bool()).all():
1353
+ break
1354
+
1355
+ # This else triggers if the for loop finishes without breaking:
1356
  else:
1357
+ if outputs is None:
1358
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
1359
 
1360
  # For sequences that didn't exit early, use the final logits
1361
  if next_token_logits is None:
1362
  next_token_logits = outputs.logits[:, -1, :].to(**logit_type) # type: ignore
1363
+ for i in range(batch_size):
1364
+ compute_steps_per_seq[i] = max_steps
1365
  else:
1366
  for i in range(batch_size):
1367
  if not exit_reached[i] and unfinished_sequences[i].bool():
1368
  next_token_logits[i] = outputs.logits[i, -1, :].to(**logit_type) # type: ignore
1369
  compute_steps_per_seq[i] = max_steps
 
1370
  # Save latent states for continuous compute if enabled
1371
  if continuous_compute:
1372
+ still_running = ~exit_reached & unfinished_sequences.bool()
1373
+ next_states[still_running] = current_latents[still_running, -1:, :]
1374
+ model_kwargs["input_states"] = next_states
1375
 
1376
  # Record compute steps for this token generation
1377
  compute_steps.append([compute_steps_per_seq, exit_values_per_seq])
 
1386
  streamer.put(next_token.cpu())
1387
 
1388
  # Update model kwargs for next iteration
1389
+ model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs) # type: ignore
1390
 
1391
  # Check for stop tokens and update unfinished sequences
1392
  for i in range(batch_size):
 
1419
  )
1420
  return input_ids
1421
 
 
1422
  @torch.no_grad()
1423
  def generate_speculative(
1424
  self,
 
1600
  )
1601
  return input_ids
1602
 
1603
+ def _get_stops(self, generation_config, tokenizer, model_kwargs):
1604
+ stop_tokens = {65504, 65505, 65508} # begin_text, end_text, end_turn
1605
+ if generation_config.eos_token_id is not None:
1606
+ try:
1607
+ stop_tokens.update(generation_config.eos_token_id)
1608
+ except TypeError:
1609
+ stop_tokens.add(generation_config.eos_token_id)
1610
+ if "stopping_criteria" in model_kwargs and tokenizer is None:
1611
+ tokenizer = model_kwargs["stopping_criteria"][0].tokenizer
1612
+ if hasattr(generation_config, "stop_strings") and tokenizer and generation_config.stop_strings:
1613
+ for s in generation_config.stop_strings:
1614
+ token_id = tokenizer(s, add_special_tokens=False)["input_ids"][0]
1615
+ stop_tokens.add(token_id)
1616
+ return torch.tensor(list(stop_tokens))
1617
+
1618
+ def _sample_next_token(self, next_token_logits, generation_config):
1619
+ """Helper function to sample the next token."""
1620
+ if generation_config.do_sample:
1621
+ if generation_config.temperature:
1622
+ next_token_logits = next_token_logits.float() / generation_config.temperature
1623
+
1624
+ probs = F.softmax(next_token_logits, dim=-1)
1625
+
1626
+ # Apply top_k
1627
+ if generation_config.top_k:
1628
+ top_k_values, _ = torch.topk(probs, generation_config.top_k, dim=-1)
1629
+ min_values = top_k_values[:, -1].unsqueeze(-1).expand_as(probs)
1630
+ probs = torch.where(probs < min_values, torch.zeros_like(probs), probs)
1631
+
1632
+ # Apply top_p (nucleus sampling)
1633
+ if generation_config.top_p:
1634
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
1635
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
1636
+
1637
+ # Create mask for probs to keep
1638
+ remove_indices = cumulative_probs > generation_config.top_p
1639
+ remove_indices[:, 0] = False # Keep at least the top probability
1640
+
1641
+ # Convert sorted indices mask back to original indices mask
1642
+ mask = torch.zeros_like(probs, dtype=torch.bool)
1643
+ for i in range(probs.shape[0]):
1644
+ mask[i, sorted_indices[i, remove_indices[i]]] = True
1645
+
1646
+ probs = torch.where(mask, torch.zeros_like(probs), probs)
1647
+
1648
+ # Apply min_p
1649
+ if generation_config.min_p:
1650
+ max_probs = probs.max(dim=-1, keepdim=True)[0]
1651
+ min_p_threshold = generation_config.min_p * max_probs
1652
+ probs = torch.where(probs < min_p_threshold, torch.zeros_like(probs), probs)
1653
+
1654
+ # Renormalize probabilities
1655
+ probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-10)
1656
+
1657
+ # Sample from the distribution
1658
+ return torch.multinomial(probs, num_samples=1)
1659
+ else:
1660
+ return torch.argmax(next_token_logits, dim=-1, keepdim=True)
1661
+
1662
+
1663
+ ################################ Model Utils #######################################################################
1664
 
1665
 
 
1666
  def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1):
1667
  with torch.autocast("cuda", enabled=False):
1668
  inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
 
1688
  return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) # type: ignore
1689
 
1690
 
1691
+ #################################### Adaptive Compute Exit Evaluators ##########################################
1692
+
1693
+ Exit = Tuple[torch.Tensor, Optional[CausalLMOutputRecurrentLatents], torch.Tensor]
1694
+
1695
+
1696
+ class PerIterationExitEvaluator:
1697
+ """Base class for exit evaluators that check after each recurrent step."""
1698
+
1699
+ def init(self, initial_latents: torch.Tensor):
1700
+ """Initialize evaluator state."""
1701
+
1702
+ def check(self, model: "RavenForCausalLM", latents: torch.Tensor, aux_inputs: dict) -> Exit:
1703
+ """Returns (should_exit, outputs (or None), exit_values)"""
1704
+ raise NotImplementedError()
1705
+
1706
+
1707
+ class NoOpExitEvaluator(PerIterationExitEvaluator):
1708
+ """Exit evaluator that never exits early."""
1709
+
1710
+ def check(self, model: "RavenForCausalLM", latents: torch.Tensor, aux_inputs: dict) -> Exit:
1711
+ return (
1712
+ torch.zeros(latents.shape[0], device=latents.device, dtype=torch.bool),
1713
+ None,
1714
+ torch.zeros(latents.shape[0], device=latents.device),
1715
+ )
1716
+
1717
+
1718
+ class EntropyDiffExitEvaluator(PerIterationExitEvaluator):
1719
+ """Exit based on change in output entropy."""
1720
+
1721
+ def __init__(self, exit_threshold: Union[str, float] = "auto"):
1722
+ self.exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
1723
+
1724
+ def init(self, initial_latents: torch.Tensor):
1725
+ batch_size = initial_latents.shape[0]
1726
+ self.prev_entropy = torch.ones(batch_size, device=initial_latents.device) * 100.0
1727
+
1728
+ def check(self, model: "RavenForCausalLM", latents: torch.Tensor, aux_inputs: dict) -> Exit:
1729
+ outputs = model.predict_from_latents(latents, **aux_inputs)
1730
+ logits: torch.Tensor = outputs.logits # type: ignore
1731
+ probs = F.softmax(logits[:, -1, :], dim=-1)
1732
+ entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
1733
+ exit_values = (entropy - self.prev_entropy).abs()
1734
+ self.prev_entropy = entropy
1735
+ return exit_values < self.exit_threshold, outputs, exit_values
1736
+
1737
+
1738
+ class LatentDiffExitEvaluator(PerIterationExitEvaluator):
1739
+ """Exit based on change in latent states."""
1740
+
1741
+ def __init__(self, exit_threshold: Union[str, float] = "auto"):
1742
+ self.exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold)
1743
+
1744
+ def init(self, initial_latents: torch.Tensor):
1745
+ self.prev_latents = initial_latents.clone().detach()
1746
+
1747
+ def check(self, model: "RavenForCausalLM", latents: torch.Tensor, aux_inputs: dict) -> Exit:
1748
+ exit_values = ((latents - self.prev_latents).norm(dim=-1) / latents.norm(dim=-1)).mean(dim=-1)
1749
+ self.prev_latents = latents.clone().detach()
1750
+ return exit_values < self.exit_threshold, None, exit_values
1751
+
1752
+
1753
+ class KLExitEvaluator(PerIterationExitEvaluator):
1754
+ """Exit based on KL divergence between successive outputs."""
1755
+
1756
+ def __init__(self, model: "RavenForCausalLM", exit_threshold: Union[str, float] = "auto"):
1757
+ self.exit_threshold = 0.001 if exit_threshold == "auto" else float(exit_threshold)
1758
+ self.V = model.config.padded_vocab_size
1759
+
1760
+ def init(self, initial_latents: torch.Tensor):
1761
+ batch_size = initial_latents.shape[0]
1762
+ self.prev_log_probs = ((1 / self.V) * torch.ones(batch_size, self.V, device=initial_latents.device)).log()
1763
+
1764
+ def check(self, model: "RavenForCausalLM", latents: torch.Tensor, aux_inputs: dict) -> Exit:
1765
+ outputs = model.predict_from_latents(latents, **aux_inputs)
1766
+ logits: torch.Tensor = outputs.logits # type: ignore
1767
+ log_probs = F.log_softmax(logits[:, -1, :].float(), dim=-1)
1768
+ exit_values = F.kl_div(log_probs, self.prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
1769
+ self.prev_log_probs = log_probs
1770
+ return exit_values < self.exit_threshold, outputs, exit_values
1771
+
1772
+
1773
+ class MinKLExitEvaluator(PerIterationExitEvaluator):
1774
+ """Exit based on min-p filtered KL divergence."""
1775
+
1776
+ def __init__(self, model: "RavenForCausalLM", exit_threshold: Union[str, float] = "auto"):
1777
+ self.exit_threshold = 1e-5 if exit_threshold == "auto" else float(exit_threshold)
1778
+ self.V = model.config.padded_vocab_size
1779
+
1780
+ def init(self, initial_latents: torch.Tensor):
1781
+ batch_size = initial_latents.shape[0]
1782
+ self.prev_log_probs = ((1 / self.V) * torch.ones(batch_size, self.V, device=initial_latents.device)).log()
1783
+
1784
+ def _calc_minp_log_probs(self, logits: torch.Tensor) -> torch.Tensor:
1785
+ probs = F.softmax(logits[:, -1, :], dim=-1)
1786
+ max_probs = probs.max(dim=-1, keepdim=True)[0]
1787
+ probs_mask = probs < (0.1 * max_probs)
1788
+ masked_probs = probs
1789
+ masked_probs[probs_mask] = 1 / self.V
1790
+ probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
1791
+ return probs.log()
1792
+
1793
+ def check(self, model: "RavenForCausalLM", latents: torch.Tensor, aux_inputs: dict) -> Exit:
1794
+ outputs = model.predict_from_latents(latents, **aux_inputs)
1795
+ logits: torch.Tensor = outputs.logits # type: ignore
1796
+ log_probs = self._calc_minp_log_probs(logits)
1797
+ exit_values = F.kl_div(log_probs, self.prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
1798
+ self.prev_log_probs = log_probs
1799
+ return exit_values < self.exit_threshold, outputs, exit_values
1800
+
1801
+
1802
+ class ArgmaxStabilityExitEvaluator(PerIterationExitEvaluator):
1803
+ """Exit based on argmax stability over consecutive steps."""
1804
+
1805
+ def __init__(self, exit_threshold: Union[str, int] = "auto"):
1806
+ self.exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold)
1807
+
1808
+ def init(self, initial_latents: torch.Tensor):
1809
+ batch_size = initial_latents.shape[0]
1810
+ self.prev_argmax = torch.ones(batch_size, dtype=torch.long, device=initial_latents.device) * -1
1811
+ self.stable_for_n_steps = torch.zeros(batch_size, dtype=torch.long, device=initial_latents.device)
1812
+
1813
+ def check(self, model: "RavenForCausalLM", latents: torch.Tensor, aux_inputs: dict) -> Exit:
1814
+ outputs = model.predict_from_latents(latents, **aux_inputs)
1815
+ logits: torch.Tensor = outputs.logits # type: ignore
1816
+ current_argmax = logits[:, -1, :].argmax(dim=-1)
1817
+ stable_for_n_steps = torch.where(
1818
+ current_argmax == self.prev_argmax, self.stable_for_n_steps + 1, torch.zeros_like(self.stable_for_n_steps)
1819
+ )
1820
+ exit_values = stable_for_n_steps
1821
+ self.prev_argmax = current_argmax
1822
+ self.stable_for_n_steps = stable_for_n_steps
1823
+ return exit_values >= self.exit_threshold, outputs, exit_values
1824
+
1825
+
1826
+ class CosineExitEvaluator(PerIterationExitEvaluator):
1827
+ """Exit based on cosine similarity between successive latent states."""
1828
+
1829
+ def __init__(self, exit_threshold: Union[str, float] = "auto"):
1830
+ self.exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
1831
+
1832
+ def init(self, initial_latents: torch.Tensor):
1833
+ self.prev_latents = initial_latents.clone().detach()
1834
+
1835
+ def check(self, model: "RavenForCausalLM", latents: torch.Tensor, aux_inputs: dict) -> Exit:
1836
+ cosine_sim = (
1837
+ (latents * self.prev_latents).sum(dim=-1) / latents.norm(dim=-1) / self.prev_latents.norm(dim=-1)
1838
+ ).mean(dim=1)
1839
+ exit_values = 1 - cosine_sim
1840
+ self.prev_latents = latents.clone().detach()
1841
+ return exit_values < self.exit_threshold, None, exit_values
1842
+
1843
+
1844
+ class NumStepsGenerator(PerIterationExitEvaluator):
1845
+ def __init__(self, steps_fn: Callable):
1846
+ self.steps_fn = steps_fn
1847
+ self.counter = 0
1848
+ self.target_steps = 0
1849
+ self.current_step = 0
1850
+
1851
+ def init(self, initial_latents):
1852
+ self.target_steps = self.steps_fn(self.counter)
1853
+ self.counter += 1
1854
+ self.current_step = 0
1855
+
1856
+ def check(self, model: "RavenForCausalLM", latents: torch.Tensor, aux_inputs: dict) -> Exit:
1857
+ self.current_step += 1
1858
+ should_exit = self.current_step >= self.target_steps
1859
+ return (
1860
+ torch.full((latents.shape[0],), should_exit, dtype=torch.bool, device=latents.device),
1861
+ None,
1862
+ torch.zeros(latents.shape[0], device=latents.device),
1863
+ )
1864
+
1865
+
1866
+ def get_adaptive_exit_evaluator(
1867
+ model: "RavenForCausalLM", criterion: str, exit_threshold: Union[str, float, int]
1868
+ ) -> PerIterationExitEvaluator:
1869
+ """Factory function to create appropriate exit evaluator."""
1870
+ if criterion == "entropy-diff":
1871
+ return EntropyDiffExitEvaluator(exit_threshold)
1872
+ elif criterion == "latent-diff":
1873
+ return LatentDiffExitEvaluator(exit_threshold)
1874
+ elif criterion == "cosine":
1875
+ return CosineExitEvaluator(exit_threshold)
1876
+ elif "kl" in criterion:
1877
+ if criterion == "minp-kl":
1878
+ return MinKLExitEvaluator(model, exit_threshold)
1879
+ else:
1880
+ return KLExitEvaluator(model, exit_threshold)
1881
+ elif criterion == "argmax-stability":
1882
+ return ArgmaxStabilityExitEvaluator(exit_threshold) # type: ignore
1883
+ elif criterion == "none":
1884
+ return NoOpExitEvaluator()
1885
+ else:
1886
+ raise ValueError(f"Invalid adaptive compute strategy: {criterion}")
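To show how the exit_evaluator plugin point composes with generate, a hedged sketch of a custom evaluator (the class below and its threshold are illustrative, not part of this commit):
class MaxProbExitEvaluator(PerIterationExitEvaluator):
    """Illustrative evaluator: exit once the most likely next token exceeds a fixed probability."""

    def __init__(self, threshold: float = 0.95):
        self.threshold = threshold

    def check(self, model: "RavenForCausalLM", latents: torch.Tensor, aux_inputs: dict) -> Exit:
        outputs = model.predict_from_latents(latents, **aux_inputs)
        probs = F.softmax(outputs.logits[:, -1, :].float(), dim=-1)  # type: ignore
        exit_values = probs.max(dim=-1).values
        return exit_values > self.threshold, outputs, exit_values

# usage: model.generate(input_ids, exit_evaluator=MaxProbExitEvaluator(0.9)) routes to adaptive compute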
1887
+
1888
+
1889
  #################################### HF registration ############################################################
1890
 
1891
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM