File size: 3,371 Bytes
36c78b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import TYPE_CHECKING, Dict, List, Optional

import numpy as np
import torch
from sklearn.cluster import KMeans

if TYPE_CHECKING:  # pragma: no cover
    from .model import BitTransformerLM


class TelemetrySynthesizer:
    """Analyze telemetry batches and cluster activation patterns.

    Each telemetry dict is reduced to a flat feature vector made of one
    ``[mean, variance, attention entropy]`` triple per layer, and those
    vectors are clustered with :class:`sklearn.cluster.KMeans`.
    """

    def __init__(
        self, n_clusters: int = 2, random_state: Optional[int] = None
    ) -> None:
        """Create a synthesizer.

        Parameters
        ----------
        n_clusters: int
            Number of clusters (and representative slots) to produce.
        random_state: int or None
            Seed forwarded to ``KMeans`` so clustering is reproducible.
            ``None`` keeps KMeans' unseeded behaviour.
        """
        self.n_clusters = n_clusters
        self.random_state = random_state

    def _summary(self, telemetry: Dict[str, List[torch.Tensor]]) -> np.ndarray:
        """Compute activation/attention summaries for a single telemetry dict.

        Parameters
        ----------
        telemetry: dict
            Must contain ``"activations"`` and ``"attention_maps"`` lists.
            Entries are paired positionally; ``zip`` stops at the shorter list.

        Returns
        -------
        np.ndarray
            1-D array ``[mean_0, var_0, entropy_0, mean_1, ...]`` with one
            triple per (activation, attention map) pair.
        """
        acts = telemetry["activations"]
        attn = telemetry["attention_maps"]
        summaries = []
        for a, m in zip(acts, attn):
            mean = a.mean().item()
            # Population variance (divide by N), so a one-element tensor gives
            # 0 instead of NaN.
            var = a.var(unbiased=False).item()
            # Softmax over the last dim, then Shannon entropy per row averaged
            # over all rows. The clamp keeps log() away from -inf.
            prob = m.softmax(-1)
            entropy = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean().item()
            summaries.append([mean, var, entropy])
        return np.array(summaries).ravel()

    def synthesize(
        self, telemetries: List[Dict[str, List[torch.Tensor]]], bit_seqs: torch.Tensor
    ) -> Dict[str, List]:
        """Cluster telemetry summaries and return cluster info.

        Parameters
        ----------
        telemetries: list[dict]
            One telemetry dict per sample (see ``_summary``).
        bit_seqs: torch.Tensor
            Sequences aligned with ``telemetries``. Row ``i`` is returned as
            the representative when sample ``i`` is chosen.

        Returns
        -------
        dict
            ``"cluster_assignments"``: cluster id of every sample, in order.
            ``"representatives"``: exactly ``n_clusters`` entries, each the
            first sequence (as a list) assigned to that cluster, or ``[]``
            when the cluster received no samples.

        Notes
        -----
        KMeans cannot be fit with fewer samples than ``n_clusters``. In that
        case (including an empty batch) every sample becomes its own cluster
        and the remaining clusters stay empty.
        """
        n_samples = len(telemetries)
        if n_samples < self.n_clusters:
            # KMeans would raise here; with fewer points than clusters the
            # optimal partition puts each sample in a cluster of its own.
            labels = np.arange(n_samples)
        else:
            data = np.stack([self._summary(t) for t in telemetries])
            km = KMeans(
                n_clusters=self.n_clusters,
                n_init=1,
                random_state=self.random_state,
            )
            labels = km.fit_predict(data)
        representatives: List[List[int]] = []
        for c in range(self.n_clusters):
            members = np.flatnonzero(labels == c)
            if members.size:
                representatives.append(bit_seqs[int(members[0])].tolist())
            else:
                representatives.append([])
        return {"cluster_assignments": labels.tolist(), "representatives": representatives}

    def cluster_sequences(
        self, model: "BitTransformerLM", bit_seqs: torch.Tensor
    ) -> List[List[int]]:
        """Run the model to gather telemetry and return representative sequences.

        Each row of ``bit_seqs`` is fed to ``model`` separately (as a batch of
        one) under ``torch.no_grad()``. The model's second return value is
        collected as that row's telemetry.

        Parameters
        ----------
        model: BitTransformerLM
            Model used to compute telemetry for each sequence. It must return
            an ``(output, telemetry)`` pair.
        bit_seqs: torch.Tensor
            Tensor containing one bit sequence per row.

        Returns
        -------
        list[list[int]]
            Representative sequences chosen from KMeans clusters; ``[]`` for
            clusters that received no sequences.
        """
        telemetries: List[Dict[str, List[torch.Tensor]]] = []
        # NOTE(review): the model's train/eval mode is not changed here; if it
        # uses dropout, callers presumably want model.eval() first — confirm.
        with torch.no_grad():
            for seq in bit_seqs:
                _, tele = model(seq.unsqueeze(0))
                telemetries.append(tele)
        info = self.synthesize(telemetries, bit_seqs)
        return info["representatives"]


def detect_metric_drift(
    metrics_log: Dict[str, List[float]],
    window: int = 10,
    threshold: float = 0.2,
) -> Dict[str, bool]:
    """Detect metric drift between consecutive windows.

    For every metric, the mean of the most recent ``window`` values is
    compared with the mean of the ``window`` values immediately before them.
    Drift is flagged when the absolute difference of the two means strictly
    exceeds ``threshold``. Metrics with fewer than ``2 * window`` values are
    reported as not drifting. Any values older than the last ``2 * window``
    are ignored.

    Args:
        metrics_log: History of scalar metrics keyed by name.
        window: Number of recent steps to compare; must be at least 1.
        threshold: Absolute difference required to flag drift.

    Returns:
        Dictionary mapping metric keys to a plain Python ``bool`` drift
        indicator.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        # With window=0, values[-0:] is the whole history and the previous
        # window is empty (its mean is NaN), which silently disables detection.
        raise ValueError(f"window must be >= 1, got {window}")
    drift: Dict[str, bool] = {}
    for key, values in metrics_log.items():
        if len(values) < 2 * window:
            drift[key] = False
            continue
        recent = np.mean(values[-window:])
        prev = np.mean(values[-2 * window : -window])
        # np.mean returns NumPy scalars, so the comparison yields numpy.bool_;
        # convert so callers get real bools (JSON-serializable, `is True` works).
        drift[key] = bool(abs(recent - prev) > threshold)
    return drift