import numpy as np
from typing import Dict, List, TYPE_CHECKING

import torch
from sklearn.cluster import KMeans

if TYPE_CHECKING:
    from .model import BitTransformerLM


class TelemetrySynthesizer:
    """Analyze telemetry batches and cluster activation patterns."""

    def __init__(self, n_clusters: int = 2) -> None:
        self.n_clusters = n_clusters

    def _summary(self, telemetry: Dict[str, List[torch.Tensor]]) -> np.ndarray:
        """Compute activation/attention summaries for a single telemetry dict."""
        acts = telemetry["activations"]
        attn = telemetry["attention_maps"]
        summaries = []
        for a, m in zip(acts, attn):
            mean = a.mean().item()
            var = a.var(unbiased=False).item()
            prob = m.softmax(-1)
            entropy = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean().item()
            summaries.append([mean, var, entropy])
        return np.array(summaries).ravel()

    def synthesize(
        self, telemetries: List[Dict[str, List[torch.Tensor]]], bit_seqs: torch.Tensor
    ) -> Dict[str, List]:
        """Cluster telemetry summaries and return cluster info."""
        data = np.stack([self._summary(t) for t in telemetries])
        km = KMeans(n_clusters=self.n_clusters, n_init=1)
        labels = km.fit_predict(data)
        representatives: List[List[int]] = []
        for c in range(self.n_clusters):
            idx = np.where(labels == c)[0]
            if len(idx) > 0:
                representatives.append(bit_seqs[idx[0]].tolist())
            else:
                representatives.append([])
        return {"cluster_assignments": labels.tolist(), "representatives": representatives}

    def cluster_sequences(
        self, model: "BitTransformerLM", bit_seqs: torch.Tensor
    ) -> List[List[int]]:
        """Run the model to gather telemetry and return representative sequences.

        Parameters
        ----------
        model: BitTransformerLM
            Model used to compute telemetry for each sequence.
        bit_seqs: torch.Tensor
            Tensor containing one bit sequence per row.

        Returns
        -------
        list[list[int]]
            Representative sequences chosen from KMeans clusters.
        """
        telemetries: List[Dict[str, List[torch.Tensor]]] = []
        with torch.no_grad():
            for seq in bit_seqs:
                _, tele = model(seq.unsqueeze(0))
                telemetries.append(tele)
        info = self.synthesize(telemetries, bit_seqs)
        return info["representatives"]
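
# Illustrative sketch (not part of the module above): telemetry dicts can be
# built by hand with random tensors to see what ``synthesize`` returns.  The
# tensor shapes and sequence count below are assumptions chosen only for the
# example, not values prescribed by the model.
#
#     synth = TelemetrySynthesizer(n_clusters=2)
#     telemetries = [
#         {
#             "activations": [torch.randn(1, 8, 16)],
#             "attention_maps": [torch.randn(1, 4, 8, 8)],
#         }
#         for _ in range(4)
#     ]
#     bit_seqs = torch.randint(0, 2, (4, 8))
#     info = synth.synthesize(telemetries, bit_seqs)
#     info["cluster_assignments"]  # one cluster label per sequence
#     info["representatives"]      # one representative bit sequence per cluster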


def detect_metric_drift(
    metrics_log: Dict[str, List[float]],
    window: int = 10,
    threshold: float = 0.2,
) -> Dict[str, bool]:
    """Detect metric drift between consecutive windows.

    Args:
        metrics_log: History of scalar metrics keyed by name.
        window: Number of recent steps to compare.
        threshold: Absolute difference required to flag drift.

    Returns:
        Dictionary mapping metric keys to a boolean drift indicator.
    """
    drift = {}
    for key, values in metrics_log.items():
        if len(values) < window * 2:
            drift[key] = False
            continue
        recent = np.mean(values[-window:])
        prev = np.mean(values[-2 * window : -window])
        drift[key] = abs(recent - prev) > threshold
    return drift
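

if __name__ == "__main__":
    # Minimal sketch with synthetic numbers (the metric names are assumptions,
    # not real telemetry): a metric whose windowed mean shifts by more than
    # ``threshold`` between the previous and the most recent window is flagged.
    fake_log = {
        "negentropy": [0.5] * 10 + [0.8] * 10,  # mean shifts by 0.3 -> drift
        "lz_complexity": [0.4] * 20,  # stays flat -> no drift
    }
    print(detect_metric_drift(fake_log, window=10, threshold=0.2))
    # Expected: only "negentropy" is flagged as drifting.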