#!/usr/bin/env python3
"""Activation-statistics monitor for Hugging Face causal LMs.

Registers forward hooks on selected submodules, runs a set of prompts
through the model, and reports per-module running statistics
(mean/std/min/max, zero/NaN/Inf counts) plus the RMS of per-token hidden
activations. Optionally logs to TensorBoard and computes per-layer
attention entropy.
"""
import argparse
import json
import math
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


@dataclass
class RunningStat:
    """Streaming elementwise statistics over every tensor observed."""

    count: int = 0
    sum: float = 0.0
    sumsq: float = 0.0
    min: Optional[float] = None
    max: Optional[float] = None
    zero_count: int = 0
    nan_count: int = 0
    inf_count: int = 0

    def update_from_tensor(self, t: torch.Tensor) -> None:
        with torch.no_grad():
            if t.numel() == 0:
                return
            # Count NaN/Inf before sanitizing so they appear in the report.
            self.nan_count += int(torch.isnan(t).sum().item())
            self.inf_count += int(torch.isinf(t).sum().item())
            t = torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0)
            self.zero_count += int((t == 0).sum().item())
            tf = t.float()  # accumulate in fp32 to avoid bf16/fp16 overflow
            self.sum += float(tf.sum().item())
            self.sumsq += float((tf * tf).sum().item())
            self.count += t.numel()
            t_min = float(t.min().item())
            t_max = float(t.max().item())
            if self.min is None or t_min < self.min:
                self.min = t_min
            if self.max is None or t_max > self.max:
                self.max = t_max

    @property
    def mean(self) -> Optional[float]:
        if self.count == 0:
            return None
        return self.sum / self.count

    @property
    def var(self) -> Optional[float]:
        if self.count == 0:
            return None
        m = self.mean
        # E[x^2] - E[x]^2, clamped against negative rounding error.
        return max(0.0, self.sumsq / self.count - m * m)

    @property
    def std(self) -> Optional[float]:
        v = self.var
        if v is None:
            return None
        return math.sqrt(v)

    def to_dict(self) -> Dict[str, Any]:
        d = asdict(self)
        d["mean"] = self.mean
        d["std"] = self.std
        return d


@dataclass
class TokenRMSStat:
    """Streaming statistics over the per-token RMS along the last dim."""

    count: int = 0
    sum: float = 0.0
    sumsq: float = 0.0

    def update_from_tensor(self, t: torch.Tensor) -> None:
        with torch.no_grad():
            if t.ndim == 1:
                feats = t.unsqueeze(0)
            else:
                # reshape (not view) tolerates non-contiguous outputs
                feats = t.reshape(-1, t.shape[-1])
            rms = feats.float().pow(2).mean(dim=-1).sqrt()
            rms = torch.nan_to_num(rms, nan=0.0, posinf=0.0, neginf=0.0)
            self.count += int(rms.numel())
            self.sum += float(rms.sum().item())
            self.sumsq += float((rms * rms).sum().item())

    @property
    def mean(self) -> Optional[float]:
        if self.count == 0:
            return None
        return self.sum / self.count

    @property
    def var(self) -> Optional[float]:
        if self.count == 0:
            return None
        m = self.mean
        return max(0.0, self.sumsq / self.count - m * m)

    @property
    def std(self) -> Optional[float]:
        v = self.var
        if v is None:
            return None
        return math.sqrt(v)

    def to_dict(self) -> Dict[str, Any]:
        return {"count": self.count, "mean": self.mean, "std": self.std}


class ActivationMonitor:
    """Owns per-module stats and produces the forward hooks that feed them."""

    def __init__(self, use_tensorboard: bool = False, tb_dir: Optional[str] = None):
        self.stats: Dict[str, RunningStat] = {}
        self.token_rms: Dict[str, TokenRMSStat] = {}
        self.use_tensorboard = use_tensorboard
        self.tb = None
        self._global_step = 0
        if self.use_tensorboard and tb_dir is not None:
            try:
                from torch.utils.tensorboard import SummaryWriter

                self.tb = SummaryWriter(log_dir=tb_dir)
            except Exception as e:
                print(f"TensorBoard not available: {e}")

    def _get_stat(self, name: str) -> RunningStat:
        if name not in self.stats:
            self.stats[name] = RunningStat()
        return self.stats[name]

    def _get_token_rms(self, name: str) -> TokenRMSStat:
        if name not in self.token_rms:
            self.token_rms[name] = TokenRMSStat()
        return self.token_rms[name]

    def hook(self, name: str):
        def _hook(module, inputs, output):
            with torch.no_grad():
                # Many HF submodules return tuples; the activation is first.
                t = output[0] if isinstance(output, tuple) else output
                if not isinstance(t, torch.Tensor):
                    return
                self._get_stat(name).update_from_tensor(t)
                self._get_token_rms(name).update_from_tensor(t)
                # Log running values to TensorBoard every 10 batches.
                if self.tb is not None and self._global_step % 10 == 0:
                    rs = self.stats[name]
                    tr = self.token_rms[name]
                    if rs.count > 0:
                        self.tb.add_scalar(f"{name}/mean", rs.mean, self._global_step)
                        if rs.std is not None:
                            self.tb.add_scalar(f"{name}/std", rs.std, self._global_step)
                        self.tb.add_scalar(
                            f"{name}/zero_frac",
                            rs.zero_count / max(1, rs.count),
                            self._global_step,
                        )
                    if tr.count > 0 and tr.mean is not None:
                        self.tb.add_scalar(
                            f"{name}/token_rms_mean", tr.mean, self._global_step
                        )

        return _hook

    def step(self) -> None:
        self._global_step += 1

    def close(self) -> None:
        if self.tb is not None:
            self.tb.flush()
            self.tb.close()

    def to_dict(self) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        for k in sorted(self.stats.keys()):
            out[k] = {
                "global": self.stats[k].to_dict(),
                "token_rms": self.token_rms[k].to_dict(),
            }
        return out


def find_modules_to_hook(model: torch.nn.Module, patterns: List[str]) -> List[str]:
    """Select transformer-block submodules whose names match any pattern.

    The "model.layers." prefix assumes a Llama-style module layout; adjust
    it for architectures that name their decoder blocks differently.
    """
    names: List[str] = []
    for name, _ in model.named_modules():
        lname = name.lower()
        if not lname.startswith("model.layers."):
            continue
        if any(p in lname for p in patterns):
            names.append(name)
    return sorted(set(names))


def compute_attention_entropy(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    prompts: List[str],
    max_length: int,
    input_device: torch.device,
) -> Dict[int, float]:
    """Mean attention entropy per layer, averaged over heads and real tokens."""
    prev = getattr(model.config, "output_attentions", False)
    model.config.output_attentions = True
    with torch.inference_mode():
        enc = tok(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        for k in enc:
            enc[k] = enc[k].to(input_device)
        out = model(**enc, output_attentions=True, use_cache=False)
        # Exclude padded query positions from the average.
        query_mask = enc["attention_mask"].bool()  # (batch, seq)
        entropies: Dict[int, float] = {}
        for i, att in enumerate(out.attentions):  # att: (batch, heads, q, k)
            probs = att.float().clamp_min(1e-12)
            ent = -(probs * probs.log()).sum(dim=-1)  # (batch, heads, q)
            mask = query_mask[:, None, :].expand_as(ent)
            entropies[i] = float(ent[mask].mean().item())
    model.config.output_attentions = prev
    return entropies


def load_prompts(prompts: Optional[str], prompts_file: Optional[str]) -> List[str]:
    lines: List[str] = []
    if prompts_file:
        with open(prompts_file, "r", encoding="utf-8") as f:
            for line in f:
                s = line.rstrip("\r\n")
                if s:
                    lines.append(s)
    if prompts:
        for s in prompts.split("\n"):
            s = s.strip()
            if s:
                lines.append(s)
    if not lines:
        # Fallback prompts so the tool runs out of the box.
        lines = [
            "Hello! Briefly introduce yourself.",
            "Explain the concept of attention in transformers.",
            "List three use cases for large language models.",
        ]
    return lines


def main():
    ap = argparse.ArgumentParser(
        description="Activation statistics monitor for HF CausalLM models."
    )
    ap.add_argument("--model", type=str, required=True, help="Model path or HF ID.")
    ap.add_argument("--prompts", type=str, help="Newline-separated prompts.")
    ap.add_argument("--prompts_file", type=str, help="File with one prompt per line.")
    ap.add_argument("--max_length", type=int, default=256)
    ap.add_argument("--batch_size", type=int, default=4)
    ap.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
    )
    ap.add_argument("--device_map", type=str, default="auto")
    ap.add_argument(
        "--patterns",
        type=str,
        default=(
            "q_proj,k_proj,v_proj,o_proj,mlp.up_proj,mlp.gate_proj,"
            "mlp.down_proj,layernorm,norm"
        ),
        help="Comma-separated substrings of module names to hook.",
    )
    ap.add_argument("--save_json", type=str)
    ap.add_argument("--tensorboard_dir", type=str)
    ap.add_argument("--attention_entropy", action="store_true")
    args = ap.parse_args()

    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map[args.dtype]

    print(f"Loading tokenizer/model: {args.model}")
    tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tok.pad_token is None:
        # Many causal-LM tokenizers ship without a pad token, but the
        # batched encoding below uses padding=True.
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        device_map=args.device_map,
    )
    model.eval()
    # With device_map="auto" the model may be sharded across devices; send
    # inputs to the device holding the input embeddings.
    embed_device = model.get_input_embeddings().weight.device
    print(f"Sending inputs to: {embed_device}")

    patterns = [p.strip().lower() for p in args.patterns.split(",") if p.strip()]
    to_hook = find_modules_to_hook(model, patterns)

    mon = ActivationMonitor(
        use_tensorboard=args.tensorboard_dir is not None,
        tb_dir=args.tensorboard_dir,
    )
    handles = []
    for name, module in model.named_modules():
        if name in to_hook:
            handles.append(module.register_forward_hook(mon.hook(name)))
    print(f"Registered hooks on {len(handles)} modules.")

    prompts = load_prompts(args.prompts, args.prompts_file)

    with torch.inference_mode():
        for i in range(0, len(prompts), args.batch_size):
            batch_prompts = prompts[i : i + args.batch_size]
            enc = tok(
                batch_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=args.max_length,
            )
            for k in enc:
                enc[k] = enc[k].to(embed_device)
            _ = model(**enc, use_cache=False)
            mon.step()

    # Remove hooks before the optional entropy pass so the extra forward
    # does not contaminate the activation statistics.
    for h in handles:
        h.remove()
    mon.close()

    attn_entropy: Dict[int, float] = {}
    if args.attention_entropy:
        subset = prompts[: min(len(prompts), args.batch_size)]
        attn_entropy = compute_attention_entropy(
            model, tok, subset, args.max_length, embed_device
        )

    stats = mon.to_dict()
    if args.attention_entropy:
        stats["_attention_entropy"] = attn_entropy

    print("\nActivation summary (top 10 by token_rms mean):")
    ranked = sorted(
        [
            (name, d["token_rms"]["mean"] or 0.0)
            for name, d in stats.items()
            if name != "_attention_entropy" and d["global"]["count"] > 0
        ],
        key=lambda x: x[1],
        reverse=True,
    )[:10]
    for name, rms_mean in ranked:
        g = stats[name]["global"]
        zero_frac = g["zero_count"] / max(1, g["count"])
        print(
            f"- {name}: token_rms_mean={rms_mean:.4f}, "
            f"mean={g['mean']:.4f} std={g['std']:.4f} "
            f"min={g['min']:.4f} max={g['max']:.4f} "
            f"zero_frac={zero_frac:.4f}"
        )

    if args.save_json:
        out_path = Path(args.save_json)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "w") as f:
            json.dump(stats, f, indent=2)
        print(f"\nSaved stats JSON to: {out_path}")


if __name__ == "__main__":
    main()
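
# Example invocation (illustrative only; the script filename, model ID, and
# file paths below are assumptions, not fixed by this tool):
#
#   python activation_monitor.py \
#       --model meta-llama/Llama-3.1-8B-Instruct \
#       --prompts_file prompts.txt \
#       --dtype bfloat16 \
#       --attention_entropy \
#       --save_json stats/run1.json \
#       --tensorboard_dir runs/act_monitor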