import json
import os
from typing import Dict, List, Optional, Tuple

import torch

from .model import BitTransformerLM
from .training import train_loop


def collapse_submodel(
    cluster_data: List[List[int]],
    target_params: Dict,
    floors: Optional[Dict[str, float]] = None,
    max_rounds: int = 3,
    width_scale: float = 1.5,
    forward_kwargs: Optional[Dict] = None,
) -> Tuple[BitTransformerLM, Dict[str, float]]:
    """Distill a submodel from clustered bit sequences.

    The routine deepens the target model whenever the telemetry floors are
    unmet and, after the first deepening fails, widens the hidden dimensions
    by ``width_scale`` once before falling back to further deepening.
    Returns the distilled model and its final telemetry metrics.
    """
    if max_rounds < 1:
        raise ValueError("max_rounds must be at least 1")
    if floors is None:
        floors = {"negentropy": 0.5, "lz_complexity": 0.3, "symbiosis_score": 0.5}

    # Cluster sequences must share a common length to stack into one tensor.
    bit_tensor = torch.tensor(cluster_data, dtype=torch.long)

    # 80/20 train/validation split; fall back to the training bits when the
    # cluster is too small to hold out a validation slice.
    n = len(bit_tensor)
    split = max(1, int(0.8 * n))
    train_bits = bit_tensor[:split]
    val_bits = bit_tensor[split:]
    if len(val_bits) == 0:
        val_bits = train_bits

    params = target_params.copy()
    metrics: Dict[str, float] = {}
    width_scaled = False
    for round_idx in range(max_rounds):
        model = BitTransformerLM(**params)
        train_loop(
            model,
            train_bits,
            epochs=2,
            compress_prob=0.5,
            direct_prob=0.0,
            log=False,
            forward_kwargs=forward_kwargs,
        )

        # Score the candidate on held-out bits without tracking gradients.
        model.eval()
        with torch.no_grad():
            logits, telemetry = model(val_bits, **(forward_kwargs or {}))
            neg_k = model.negentropy_logits(logits).mean().item()
            lz_c = model.lz_complexity_logits(logits).mean().item()
            sym_s = telemetry["symbiosis_score"].mean().item()
        metrics = {
            "negentropy": neg_k,
            "lz_complexity": lz_c,
            "symbiosis_score": sym_s,
        }

        # Stop as soon as every telemetry metric clears its floor.
        if (
            neg_k >= floors["negentropy"]
            and lz_c >= floors["lz_complexity"]
            and sym_s >= floors["symbiosis_score"]
        ):
            break

        # Escalation schedule: deepen first, widen once, then keep deepening.
        if round_idx == 0:
            params["num_layers"] = max(1, params.get("num_layers", 1)) + 1
        elif not width_scaled:
            params["d_model"] = int(params.get("d_model", 32) * width_scale)
            params["dim_feedforward"] = int(
                params.get("dim_feedforward", 64) * width_scale
            )
            width_scaled = True
        else:
            params["num_layers"] = max(1, params.get("num_layers", 1)) + 1
    return model, metrics


def save_distilled_model(
    model: BitTransformerLM,
    path: str,
    metrics: Dict[str, float],
    floors: Optional[Dict[str, float]] = None,
) -> None:
    """Serialize a distilled model and its metric summary to disk.

    Weights are written to ``path``, and a ``metrics.json`` file is placed
    in the same directory, containing the achieved metrics alongside the
    target floors.
    """
    torch.save(model.state_dict(), path)
    payload = {"metrics": metrics, "floors": floors or {}}
    metrics_path = os.path.join(os.path.dirname(path), "metrics.json")
    with open(metrics_path, "w") as f:
        json.dump(payload, f)
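

# ---------------------------------------------------------------------------
# Reload sketch (illustrative, not part of the module's public API).
# ``save_distilled_model`` persists only the state dict, and
# ``collapse_submodel`` mutates its ``params`` copy internally without
# returning it, so the caller must track the constructor arguments of the
# final round to rebuild the model. ``load_distilled_model`` is a
# hypothetical helper added here to make that round-trip explicit.
# ---------------------------------------------------------------------------
def load_distilled_model(path: str, params: Dict) -> BitTransformerLM:
    """Rebuild a distilled model from weights written by ``save_distilled_model``.

    ``params`` must match the constructor arguments of the model whose state
    dict was saved; a mismatch will raise in ``load_state_dict``.
    """
    model = BitTransformerLM(**params)
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model.eval()
    return model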
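

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). Because of the relative imports above,
# run it as a module, e.g. ``python -m <package>.<this_module>``. The
# ``target_params`` keys mirror the ones touched by the escalation schedule;
# BitTransformerLM may require additional constructor arguments, so treat
# this dict as an assumption rather than the model's full signature.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Toy cluster: 32 copies of the same 8-bit pattern, all equal length so
    # they stack into a single tensor.
    cluster = [[0, 1, 1, 0, 1, 0, 0, 1] for _ in range(32)]
    target_params = {"num_layers": 1, "d_model": 32, "dim_feedforward": 64}

    model, metrics = collapse_submodel(cluster, dict(target_params), max_rounds=2)
    print("final telemetry:", metrics)

    # ``save_distilled_model`` does not create directories, so make sure the
    # destination exists before writing the weights and metrics.json.
    os.makedirs("distilled", exist_ok=True)
    save_distilled_model(model, os.path.join("distilled", "model.pt"), metrics)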