import numpy as np
from typing import Dict, List, Optional, TYPE_CHECKING

import torch
from sklearn.cluster import KMeans

if TYPE_CHECKING:  # pragma: no cover
    from .model import BitTransformerLM


class TelemetrySynthesizer:
    """Analyze telemetry batches and cluster activation patterns."""

    def __init__(self, n_clusters: int = 2, random_state: Optional[int] = None) -> None:
        """Create a synthesizer.

        Args:
            n_clusters: Number of KMeans clusters requested.
            random_state: Seed forwarded to ``KMeans`` so clustering is
                reproducible. ``None`` (the default) leaves KMeans unseeded.
        """
        self.n_clusters = n_clusters
        self.random_state = random_state

    def _summary(self, telemetry: Dict[str, List[torch.Tensor]]) -> np.ndarray:
        """Compute activation/attention summaries for a single telemetry dict.

        ``telemetry["activations"]`` and ``telemetry["attention_maps"]`` are
        paired positionally (via ``zip``, so extra entries in the longer list
        are ignored). For each pair three scalars are recorded:

        * mean of the activation tensor,
        * population variance of the activation tensor (``unbiased=False``),
        * mean entropy of the attention map after a softmax over its last
          dimension.

        Returns:
            A 1-D array holding the concatenated ``[mean, var, entropy]``
            triples, i.e. ``3 * len(pairs)`` values.
        """
        acts = telemetry["activations"]
        attn = telemetry["attention_maps"]
        summaries = []
        for a, m in zip(acts, attn):
            mean = a.mean().item()
            var = a.var(unbiased=False).item()
            # NOTE(review): softmax is applied to the stored map; if the maps are
            # already normalized attention probabilities this re-normalizes
            # them — confirm this is intended.
            prob = m.softmax(-1)
            # Clamp before log so an underflowed 0 probability cannot yield 0 * -inf = nan.
            entropy = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean().item()
            summaries.append([mean, var, entropy])
        return np.array(summaries).ravel()

    def synthesize(
        self, telemetries: List[Dict[str, List[torch.Tensor]]], bit_seqs: torch.Tensor
    ) -> Dict[str, List]:
        """Cluster telemetry summaries and return cluster info.

        Args:
            telemetries: One telemetry dict per sequence (see ``_summary``).
            bit_seqs: Sequences indexed in the same order as ``telemetries``.

        Returns:
            A dict with:

            * ``"cluster_assignments"``: one integer cluster label per telemetry.
            * ``"representatives"``: exactly ``n_clusters`` entries — for each
              cluster, the first sequence (in input order) assigned to it as a
              list, or ``[]`` if the cluster has no members.

        If fewer telemetries than ``n_clusters`` are supplied, KMeans is fit with
        one cluster per sample (sklearn rejects ``n_clusters > n_samples``) and
        the surplus clusters are reported empty. An empty ``telemetries`` list
        yields no assignments and ``n_clusters`` empty representatives.
        """
        if not telemetries:
            return {
                "cluster_assignments": [],
                "representatives": [[] for _ in range(self.n_clusters)],
            }
        data = np.stack([self._summary(t) for t in telemetries])
        n_fit = min(self.n_clusters, data.shape[0])
        km = KMeans(
            n_clusters=n_fit,
            n_init=1,
            # getattr keeps subclasses that override __init__ without super() working.
            random_state=getattr(self, "random_state", None),
        )
        labels = km.fit_predict(data)
        representatives: List[List[int]] = []
        for c in range(self.n_clusters):
            members = np.flatnonzero(labels == c)
            representatives.append(bit_seqs[members[0]].tolist() if members.size else [])
        return {"cluster_assignments": labels.tolist(), "representatives": representatives}

    def cluster_sequences(
        self, model: "BitTransformerLM", bit_seqs: torch.Tensor
    ) -> List[List[int]]:
        """Run the model to gather telemetry and return representative sequences.

        Each row of ``bit_seqs`` is fed to ``model`` individually (as a batch of
        one) with gradient tracking disabled; the model's train/eval mode is not
        changed here. The second element of the model's output is used as the
        telemetry dict.

        Parameters
        ----------
        model: BitTransformerLM
            Model used to compute telemetry for each sequence.
        bit_seqs: torch.Tensor
            Tensor containing one bit sequence per row.

        Returns
        -------
        list[list[int]]
            Representative sequences chosen from KMeans clusters (one entry per
            cluster; ``[]`` for a cluster with no members).
        """
        telemetries: List[Dict[str, List[torch.Tensor]]] = []
        with torch.no_grad():
            for seq in bit_seqs:
                _, tele = model(seq.unsqueeze(0))
                telemetries.append(tele)
        info = self.synthesize(telemetries, bit_seqs)
        return info["representatives"]


def detect_metric_drift(
    metrics_log: Dict[str, List[float]],
    window: int = 10,
    threshold: float = 0.2,
) -> Dict[str, bool]:
    """Detect metric drift between consecutive windows.

    For each metric, the mean of the last ``window`` values is compared with
    the mean of the ``window`` values immediately before them.

    Args:
        metrics_log: History of scalar metrics keyed by name.
        window: Number of recent steps to compare. Must be at least 1.
        threshold: Absolute difference in window means required to flag drift.

    Returns:
        Dictionary mapping metric keys to a plain Python ``bool`` drift
        indicator. Metrics with fewer than ``2 * window`` values are reported
        as ``False``.

    Raises:
        ValueError: If ``window`` is less than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    drift: Dict[str, bool] = {}
    for key, values in metrics_log.items():
        if len(values) < 2 * window:
            drift[key] = False
            continue
        recent = np.mean(values[-window:])
        prev = np.mean(values[-2 * window : -window])
        # Cast: comparing numpy floats yields numpy.bool_, which is not a real
        # bool and is not JSON-serializable.
        drift[key] = bool(abs(recent - prev) > threshold)
    return drift