# WCNegentropy's picture
# 🤖 Updated BitTransformerLM from development space
# 36c78b1 verified
# raw
# history blame
# 3.37 kB
import numpy as np
from typing import Dict, List, TYPE_CHECKING
import torch
from sklearn.cluster import KMeans
if TYPE_CHECKING: # pragma: no cover
from .model import BitTransformerLM
class TelemetrySynthesizer:
    """Analyze telemetry batches and cluster activation patterns.

    Each telemetry dict is reduced to a flat feature vector (mean, variance
    and attention entropy per activation/attention pair) and the resulting
    vectors are grouped with KMeans.
    """

    def __init__(self, n_clusters: int = 2) -> None:
        """Store the requested number of clusters.

        Parameters
        ----------
        n_clusters: int
            Number of KMeans clusters to request. It is also the length of
            the ``representatives`` list returned by :meth:`synthesize`.
        """
        self.n_clusters = n_clusters

    def _summary(self, telemetry: Dict[str, List[torch.Tensor]]) -> np.ndarray:
        """Compute activation/attention summaries for a single telemetry dict.

        Parameters
        ----------
        telemetry: dict
            Must contain the keys ``"activations"`` and ``"attention_maps"``,
            each holding a list of tensors. The two lists are walked in
            lockstep; extra entries in the longer list are ignored.

        Returns
        -------
        numpy.ndarray
            1-D array laid out as ``[mean, var, entropy]`` repeated once per
            activation/attention pair.
        """
        acts = telemetry["activations"]
        attn = telemetry["attention_maps"]
        summaries = []
        for a, m in zip(acts, attn):
            mean = a.mean().item()
            # Population variance (divide by N, not N - 1).
            var = a.var(unbiased=False).item()
            # NOTE(review): softmax is applied here regardless of whether the
            # maps are already normalized -- confirm against the model.
            prob = m.softmax(-1)
            # Shannon entropy over the last axis, averaged over the rest;
            # the clamp keeps log() finite for zero probabilities.
            entropy = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean().item()
            summaries.append([mean, var, entropy])
        return np.array(summaries).ravel()

    def synthesize(
        self, telemetries: List[Dict[str, List[torch.Tensor]]], bit_seqs: torch.Tensor
    ) -> Dict[str, List]:
        """Cluster telemetry summaries and return cluster info.

        Parameters
        ----------
        telemetries: list of dict
            One telemetry dict per sequence, in the same order as ``bit_seqs``.
        bit_seqs: torch.Tensor
            Indexable collection of sequences; ``bit_seqs[i]`` corresponds to
            ``telemetries[i]``.

        Returns
        -------
        dict
            ``"cluster_assignments"``: cluster label for each telemetry.
            ``"representatives"``: a list of length ``n_clusters`` holding the
            first sequence assigned to each cluster, or ``[]`` for a cluster
            that received no members.
        """
        if not telemetries:
            # Nothing to cluster: np.stack would raise on an empty list.
            return {
                "cluster_assignments": [],
                "representatives": [[] for _ in range(self.n_clusters)],
            }

        data = np.stack([self._summary(t) for t in telemetries])
        # KMeans rejects n_clusters > n_samples, so fit with at most one
        # cluster per sample; the unused cluster slots are reported as empty.
        effective_clusters = min(self.n_clusters, len(data))
        km = KMeans(n_clusters=effective_clusters, n_init=1)
        labels = km.fit_predict(data)

        representatives: List[List[int]] = []
        for c in range(self.n_clusters):
            members = np.flatnonzero(labels == c)
            if members.size > 0:
                representatives.append(bit_seqs[int(members[0])].tolist())
            else:
                representatives.append([])
        return {"cluster_assignments": labels.tolist(), "representatives": representatives}

    def cluster_sequences(
        self, model: "BitTransformerLM", bit_seqs: torch.Tensor
    ) -> List[List[int]]:
        """Run the model to gather telemetry and return representative sequences.

        Parameters
        ----------
        model: BitTransformerLM
            Model used to compute telemetry for each sequence. It is called
            with a single-sequence batch and must return a pair whose second
            element is the telemetry dict.
        bit_seqs: torch.Tensor
            Tensor containing one bit sequence per row.

        Returns
        -------
        list[list[int]]
            Representative sequences chosen from KMeans clusters.
        """
        telemetries: List[Dict[str, List[torch.Tensor]]] = []
        # Gradients are not needed for telemetry collection.
        with torch.no_grad():
            for seq in bit_seqs:
                # Add a leading batch dimension of size one.
                _, tele = model(seq.unsqueeze(0))
                telemetries.append(tele)
        info = self.synthesize(telemetries, bit_seqs)
        return info["representatives"]
def detect_metric_drift(
    metrics_log: Dict[str, List[float]],
    window: int = 10,
    threshold: float = 0.2,
) -> Dict[str, bool]:
    """Detect metric drift between consecutive windows.

    For every metric, the mean of the last ``window`` values is compared with
    the mean of the ``window`` values immediately before them.

    Args:
        metrics_log: History of scalar metrics keyed by name.
        window: Number of recent steps to compare. A metric with fewer than
            ``2 * window`` recorded values, or any non-positive ``window``,
            is reported as not drifting.
        threshold: Absolute difference between the two window means that must
            be strictly exceeded to flag drift.

    Returns:
        Dictionary mapping metric keys to a plain ``bool`` drift indicator.
    """
    drift: Dict[str, bool] = {}
    for key, values in metrics_log.items():
        # Without two full windows there is nothing to compare. A window
        # below 1 would average an empty slice (NaN plus a RuntimeWarning).
        if window < 1 or len(values) < window * 2:
            drift[key] = False
            continue
        recent = np.mean(values[-window:])
        prev = np.mean(values[-2 * window : -window])
        # Comparing numpy scalars yields numpy.bool_; coerce so the result
        # matches the annotation and stays JSON-serializable.
        drift[key] = bool(abs(recent - prev) > threshold)
    return drift