#!/usr/bin/env python3 # -*- coding: utf-8 -*- # Measure per-layer causal influence via gating and/or swapping. # Mapping fix: correctly map composite layer indices to donor indices. import argparse import math from contextlib import contextmanager from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import torch from transformers import AutoModelForCausalLM, AutoTokenizer def parse_layers(spec: str) -> List[int]: out: List[int] = [] for chunk in spec.split(","): chunk = chunk.strip() if not chunk: continue if "-" in chunk: a, b = chunk.split("-") a, b = int(a), int(b) out.extend(list(range(a, b + 1))) else: out.append(int(chunk)) out = sorted(list({x for x in out})) return out def load_lines( prompts: Optional[str], prompts_file: Optional[str] ) -> List[str]: lines: List[str] = [] if prompts_file: with open(prompts_file, "r", encoding="utf-8") as f: for line in f: s = line.strip("\n") if s: lines.append(s) if prompts: for s in prompts.split("\n"): s = s.strip() if s: lines.append(s) if not lines: lines = [ "You are a helpful assistant. Say hi in one sentence.", "Explain transformers in 2-3 sentences.", "Summarize the benefits of bfloat16 training.", ] return lines def get_embed_device(model: torch.nn.Module) -> torch.device: return model.get_input_embeddings().weight.device @torch.inference_mode() def dataset_nll( model: AutoModelForCausalLM, tok: AutoTokenizer, texts: List[str], max_length: int = 512, batch_size: int = 4, input_device: Optional[torch.device] = None, ) -> Tuple[float, int]: if input_device is None: input_device = get_embed_device(model) total_nll = 0.0 total_tokens = 0 i = 0 while i < len(texts): batch = texts[i : i + batch_size] i += batch_size enc = tok( batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length, ) for k in enc: enc[k] = enc[k].to(input_device) input_ids = enc["input_ids"] attention_mask = enc["attention_mask"] labels = input_ids.clone() labels[labels == tok.pad_token_id] = -100 out = model( input_ids=input_ids, attention_mask=attention_mask, labels=labels, use_cache=False, ) loss = out.loss valid = labels.ne(-100) n_tokens = int(valid.sum().item()) total_nll += float(loss.item()) * n_tokens total_tokens += n_tokens return total_nll, total_tokens def ppl_from_nll(total_nll: float, total_tokens: int) -> float: if total_tokens == 0: return float("inf") return float(math.exp(total_nll / total_tokens)) @dataclass class GateSpec: layer: int attn_scale: float = 0.0 mlp_scale: float = 0.0 @contextmanager def gate_layer(model: AutoModelForCausalLM, spec: GateSpec): """ Temporarily scale a layer's residual contribution by scaling: - self_attn.o_proj.weight by attn_scale - mlp.down_proj.weight by mlp_scale Using 0.0 disables that sublayer's residual addition. """ backups: List[Tuple[torch.nn.Parameter, torch.Tensor]] = [] def scale_param(p: torch.nn.Parameter, s: float): backups.append((p, p.data.detach().clone())) p.data.mul_(s) layer = model.model.layers[spec.layer] # type: ignore[attr-defined] try: if hasattr(layer.self_attn, "o_proj"): scale_param(layer.self_attn.o_proj.weight, spec.attn_scale) else: raise AttributeError("No o_proj in self_attn") if hasattr(layer.mlp, "down_proj"): scale_param(layer.mlp.down_proj.weight, spec.mlp_scale) else: raise AttributeError("No down_proj in mlp") yield finally: for p, old in backups: p.data.copy_(old) backups.clear() @contextmanager def swap_layer_from_donor( model_dst: AutoModelForCausalLM, model_src: AutoModelForCausalLM, dst_layer_idx: int, src_layer_idx: int, ): """ Temporarily copy all parameters/buffers for dst_layer_idx from model_src's src_layer_idx, then restore. """ dst_prefix = f"model.layers.{dst_layer_idx}." src_prefix = f"model.layers.{src_layer_idx}." src_named_params = dict(model_src.named_parameters()) dst_named_params = dict(model_dst.named_parameters()) src_named_bufs = dict(model_src.named_buffers()) dst_named_bufs = dict(model_dst.named_buffers()) src_params_by_suffix: Dict[str, torch.Tensor] = {} for name, p in src_named_params.items(): if name.startswith(src_prefix): suffix = name[len(src_prefix) :] src_params_by_suffix[suffix] = p src_bufs_by_suffix: Dict[str, torch.Tensor] = {} for name, b in src_named_bufs.items(): if name.startswith(src_prefix): suffix = name[len(src_prefix) :] src_bufs_by_suffix[suffix] = b backups_p: List[Tuple[str, torch.Tensor]] = [] backups_b: List[Tuple[str, torch.Tensor]] = [] try: with torch.no_grad(): for name, p_dst in list(dst_named_params.items()): if not name.startswith(dst_prefix): continue suffix = name[len(dst_prefix) :] if suffix not in src_params_by_suffix: continue p_src = src_params_by_suffix[suffix] backups_p.append((name, p_dst.data.detach().clone())) p_dst.data.copy_( p_src.data.to(device=p_dst.device, dtype=p_dst.dtype) ) for name, b_dst in list(dst_named_bufs.items()): if not name.startswith(dst_prefix): continue suffix = name[len(dst_prefix) :] if suffix not in src_bufs_by_suffix: continue b_src = src_bufs_by_suffix[suffix] backups_b.append((name, b_dst.data.detach().clone())) b_dst.data.copy_( b_src.data.to(device=b_dst.device, dtype=b_dst.dtype) ) yield finally: with torch.no_grad(): for name, old in backups_p: p_dst = dst_named_params[name] p_dst.data.copy_(old) for name, old in backups_b: b_dst = dst_named_bufs[name] b_dst.data.copy_(old) def map_layer_idx( dst_idx: int, dst_total: int, src_total: int, mode: str = "ratio" ) -> int: """ Map a composite (dst) layer index to donor (src) layer index. - ratio (default): floor(dst_idx * src_total / dst_total) - wrap: dst_idx % src_total """ if src_total <= 0: raise ValueError("src_total must be > 0") if mode == "wrap": return dst_idx % src_total x = int(math.floor(dst_idx * src_total / max(1, dst_total))) return max(0, min(src_total - 1, x)) def main(): ap = argparse.ArgumentParser( description="Per-layer influence via gating and/or swapping." ) ap.add_argument("--model", type=str, required=True) ap.add_argument("--donor_model", type=str) ap.add_argument("--layers", type=str, required=True) ap.add_argument("--prompts", type=str) ap.add_argument("--prompts_file", type=str) ap.add_argument("--max_length", type=int, default=512) ap.add_argument("--batch_size", type=int, default=4) ap.add_argument( "--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"], ) ap.add_argument("--device_map", type=str, default="auto") ap.add_argument("--gate_scan", action="store_true") ap.add_argument("--swap_scan", action="store_true") ap.add_argument("--attn_only", action="store_true") ap.add_argument("--mlp_only", action="store_true") ap.add_argument( "--swap_map", type=str, default="ratio", choices=["ratio", "wrap"] ) args = ap.parse_args() dtype_map = { "bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32, } torch_dtype = dtype_map[args.dtype] layers = parse_layers(args.layers) texts = load_lines(args.prompts, args.prompts_file) print(f"Loading model: {args.model}") tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( args.model, torch_dtype=torch_dtype, trust_remote_code=True, device_map=args.device_map, ).eval() final_layers = int( getattr(model.config, "num_hidden_layers", len(model.model.layers)) ) print(f"Composite num_hidden_layers: {final_layers}") print("Computing baseline NLL/PPL...") base_nll, base_tokens = dataset_nll( model, tok, texts, max_length=args.max_length, batch_size=args.batch_size, input_device=get_embed_device(model), ) base_ppl = ppl_from_nll(base_nll, base_tokens) print(f"Baseline: tokens={base_tokens} NLL={base_nll:.3f} PPL={base_ppl:.3f}") if args.gate_scan: print("\n== Gate scan (disable residual contribution per layer) ==") a_scale = 0.0 if not args.mlp_only else 1.0 m_scale = 0.0 if not args.attn_only else 1.0 results: List[Tuple[int, float, float]] = [] for L in layers: spec = GateSpec(layer=L, attn_scale=a_scale, mlp_scale=m_scale) with gate_layer(model, spec): nll, ntok = dataset_nll( model, tok, texts, max_length=args.max_length, batch_size=args.batch_size, input_device=get_embed_device(model), ) ppl = ppl_from_nll(nll, ntok) delta_nll = nll - base_nll delta_ppl = ppl - base_ppl results.append((L, delta_nll, delta_ppl)) print( f"Layer {L:>3}: ΔNLL={delta_nll:+.3f} ΔPPL={delta_ppl:+.3f} " f"(NLL={nll:.3f}, PPL={ppl:.3f})" ) if args.swap_scan: if not args.donor_model: raise ValueError("--swap_scan requires --donor_model.") print(f"\nLoading donor model: {args.donor_model}") donor = AutoModelForCausalLM.from_pretrained( args.donor_model, torch_dtype=torch_dtype, trust_remote_code=True, device_map="cpu", ).eval() donor_layers = int( getattr(donor.config, "num_hidden_layers", len(donor.model.layers)) ) print( f"Donor num_hidden_layers: {donor_layers} " f"(mapping mode: {args.swap_map})" ) print("\n== Swap scan (replace composite layer with donor-mapped) ==") results_s: List[Tuple[int, int, float, float]] = [] for L in layers: src_L = map_layer_idx(L, final_layers, donor_layers, mode=args.swap_map) with swap_layer_from_donor(model, donor, L, src_L): nll, ntok = dataset_nll( model, tok, texts, max_length=args.max_length, batch_size=args.batch_size, input_device=get_embed_device(model), ) ppl = ppl_from_nll(nll, ntok) delta_nll = nll - base_nll delta_ppl = ppl - base_ppl results_s.append((L, src_L, delta_nll, delta_ppl)) print( f"Layer {L:>3} <- donor {src_L:>2}: " f"ΔNLL={delta_nll:+.3f} ΔPPL={delta_ppl:+.3f} " f"(NLL={nll:.3f}, PPL={ppl:.3f})" ) if __name__ == "__main__": main()