|
from typing import TYPE_CHECKING, Dict, List, Optional

import numpy as np
import torch
from sklearn.cluster import KMeans
|
|
|
if TYPE_CHECKING: |
|
from .model import BitTransformerLM |
|
|
|
|
|
class TelemetrySynthesizer:
    """Analyze telemetry batches and cluster activation patterns.

    Each telemetry dict is reduced to a flat feature vector made of one
    ``[mean, variance, attention entropy]`` triple per (activation,
    attention map) pair. The vectors are then clustered with scikit-learn's
    ``KMeans``.
    """

    def __init__(self, n_clusters: int = 2, random_state: Optional[int] = None) -> None:
        """Configure the synthesizer.

        Args:
            n_clusters: Number of clusters requested from KMeans; must be >= 1.
                When fewer samples than clusters are supplied, the surplus
                clusters are reported as empty.
            random_state: Seed forwarded to ``KMeans``. ``None`` (the default)
                keeps scikit-learn's unseeded behaviour. Pass an int for
                reproducible cluster assignments and representatives.

        Raises:
            ValueError: If ``n_clusters`` is less than 1.
        """
        if n_clusters < 1:
            raise ValueError(f"n_clusters must be >= 1, got {n_clusters}")
        self.n_clusters = n_clusters
        self.random_state = random_state

    def _summary(self, telemetry: Dict[str, List[torch.Tensor]]) -> np.ndarray:
        """Compute activation/attention summaries for a single telemetry dict.

        Args:
            telemetry: Mapping with ``"activations"`` and ``"attention_maps"``
                lists. Entry ``i`` of one list is paired with entry ``i`` of
                the other.

        Returns:
            A 1-D array ``[mean_0, var_0, entropy_0, mean_1, ...]`` holding
            three values per pair.

        Raises:
            ValueError: If the two lists have different lengths.
        """
        acts = telemetry["activations"]
        attn = telemetry["attention_maps"]
        # zip() would silently drop the unmatched tail and yield a shorter
        # feature vector; fail loudly instead.
        if len(acts) != len(attn):
            raise ValueError(
                f"activations ({len(acts)}) and attention_maps ({len(attn)}) "
                "must have the same length"
            )
        summaries = []
        for a, m in zip(acts, attn):
            mean = a.mean().item()
            # Population variance (no Bessel correction).
            var = a.var(unbiased=False).item()
            # Softmax over the last dim, then Shannon entropy along that dim,
            # averaged over the remaining elements. The clamp avoids log(0).
            prob = m.softmax(-1)
            entropy = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean().item()
            summaries.append([mean, var, entropy])
        return np.array(summaries).ravel()

    def synthesize(
        self, telemetries: List[Dict[str, List[torch.Tensor]]], bit_seqs: torch.Tensor
    ) -> Dict[str, List]:
        """Cluster telemetry summaries and return cluster info.

        Args:
            telemetries: One telemetry dict per sequence, in the same order as
                the rows of ``bit_seqs``.
            bit_seqs: Tensor whose row ``i`` corresponds to ``telemetries[i]``.

        Returns:
            A dict with two keys:

            - ``"cluster_assignments"``: the cluster label for each telemetry.
            - ``"representatives"``: ``self.n_clusters`` entries, each the
              first sequence (in input order) assigned to that cluster, or
              ``[]`` if the cluster received no members.
        """
        if not telemetries:
            # np.stack cannot stack zero arrays; report every cluster as empty.
            return {
                "cluster_assignments": [],
                "representatives": [[] for _ in range(self.n_clusters)],
            }
        data = np.stack([self._summary(t) for t in telemetries])
        # KMeans rejects n_clusters > n_samples. Clamp it; the clusters that
        # are never fitted show up below as empty representatives.
        k = min(self.n_clusters, len(data))
        km = KMeans(n_clusters=k, n_init=1, random_state=self.random_state)
        labels = km.fit_predict(data)
        representatives: List[List[int]] = []
        for c in range(self.n_clusters):
            idx = np.flatnonzero(labels == c)
            if idx.size > 0:
                representatives.append(bit_seqs[int(idx[0])].tolist())
            else:
                representatives.append([])
        return {"cluster_assignments": labels.tolist(), "representatives": representatives}

    def cluster_sequences(
        self, model: "BitTransformerLM", bit_seqs: torch.Tensor
    ) -> List[List[int]]:
        """Run the model to gather telemetry and return representative sequences.

        Each row of ``bit_seqs`` is fed to ``model`` individually as a batch of
        one, under ``torch.no_grad()``. The model is expected to return a pair
        whose second element is the telemetry dict.

        Parameters
        ----------
        model: BitTransformerLM
            Model used to compute telemetry for each sequence.
        bit_seqs: torch.Tensor
            Tensor containing one bit sequence per row.

        Returns
        -------
        list[list[int]]
            Representative sequences chosen from KMeans clusters (one entry per
            cluster; ``[]`` for a cluster with no members).
        """
        telemetries: List[Dict[str, List[torch.Tensor]]] = []
        with torch.no_grad():
            for seq in bit_seqs:
                _, tele = model(seq.unsqueeze(0))
                telemetries.append(tele)
        info = self.synthesize(telemetries, bit_seqs)
        return info["representatives"]
|
|
|
|
|
def detect_metric_drift(
    metrics_log: Dict[str, List[float]],
    window: int = 10,
    threshold: float = 0.2,
) -> Dict[str, bool]:
    """Detect metric drift between consecutive windows.

    For each metric, the mean of the last ``window`` values is compared with
    the mean of the ``window`` values immediately before them.

    Args:
        metrics_log: History of scalar metrics keyed by name. Each history
            must support negative slicing (e.g. a list).
        window: Number of recent steps to compare; must be >= 1.
        threshold: Drift is flagged when the absolute difference between the
            two window means is strictly greater than this value.

    Returns:
        Dictionary mapping metric keys to a plain Python ``bool`` drift
        indicator. Metrics with fewer than ``2 * window`` values map to
        ``False``.

    Raises:
        ValueError: If ``window`` is less than 1.
    """
    if window < 1:
        # A zero/negative window produces empty or overlapping slices, and
        # np.mean([]) yields NaN plus a RuntimeWarning.
        raise ValueError(f"window must be >= 1, got {window}")
    drift: Dict[str, bool] = {}
    for key, values in metrics_log.items():
        if len(values) < window * 2:
            drift[key] = False
            continue
        recent = np.mean(values[-window:])
        prev = np.mean(values[-2 * window : -window])
        # Cast numpy.bool_ to bool so every value has the same type as the
        # short-history branch and the result stays JSON-serializable.
        drift[key] = bool(abs(recent - prev) > threshold)
    return drift
|
|