from dataclasses import dataclass, field, fields
from typing import Optional

import torch
import pyrallis
from transformers import PretrainedConfig


@dataclass
class TrainingConfig:
    # Model settings
    model_name: str = "unsloth/Meta-Llama-3.1-8B"
    layer: int = 12
    hook_point: str = "resid_mid"
    act_size: Optional[int] = None  # Will be set after model initialization
    
    # SAE settings
    sae_type: str = "batchtopk"
    dict_size: int = 2**15
    aux_penalty: float = 1/32
    input_unit_norm: bool = True
    
    # TopK specific settings
    top_k: int = 50
    top_k_warmup_steps_fraction: float = 0.1
    start_top_k: int = 4096
    top_k_aux: int = 512

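    # Assumed semantics: number of batches a latent may stay inactive before it
    # is treated as dead and targeted by the auxiliary (top_k_aux) loss.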
    n_batches_to_dead: int = 10
    
    # Training settings
    lr: float = 3e-4
    bandwidth: float = 0.001
    l1_coeff: float = 0.0018
    num_tokens: int = int(1e9)
    seq_len: int = 1024
    model_batch_size: int = 16
    num_batches_in_buffer: int = 5
    max_grad_norm: float = 1.0
    batch_size: int = 8192

    # Scheduler settings
    warmup_fraction: float = 0.1
    scheduler_type: str = 'linear'
    
    # Hardware settings
    device: str = "cuda"
    dtype: torch.dtype = field(default=torch.float32)
    sae_dtype: torch.dtype = field(default=torch.float32)
    
    # Dataset settings
    dataset_path: str = "cerebras/SlimPajama-627B"
    
    # Logging settings
    wandb_project: str = "turbo-llama-lens"

    performance_log_steps: int = 100
    save_checkpoint_steps: int = 10_000

    def __post_init__(self):
        if self.device == "cuda" and not torch.cuda.is_available():
            print("CUDA not available, falling back to CPU")
            self.device = "cpu"
        
        # Convert string dtypes (e.g. from CLI/YAML values) to torch.dtype if needed
        if isinstance(self.dtype, str):
            self.dtype = getattr(torch, self.dtype)
        if isinstance(self.sae_dtype, str):
            self.sae_dtype = getattr(torch, self.sae_dtype)


class SAEConfig(PretrainedConfig):
    model_type = "sae"
    
    def __init__(
        self,
        # SAE architecture
        act_size: Optional[int] = None,
        dict_size: int = 2**15,
        sae_type: str = "batchtopk",
        input_unit_norm: bool = True,
        
        # TopK specific settings
        top_k: int = 50,
        top_k_aux: int = 512,
        n_batches_to_dead: int = 10,
        
        # Training hyperparameters
        aux_penalty: float = 1/32,
        l1_coeff: float = 0.0018,
        bandwidth: float = 0.001,
        
        # Hardware settings
        dtype: str = "float32",
        sae_dtype: str = "float32",
        
        # Optional parent model info
        parent_model_name: Optional[str] = None,
        parent_layer: Optional[int] = None,
        parent_hook_point: Optional[str] = None,
        
        **kwargs
    ):
        super().__init__(**kwargs)
        self.act_size = act_size
        self.dict_size = dict_size
        self.sae_type = sae_type
        self.input_unit_norm = input_unit_norm
        
        self.top_k = top_k
        self.top_k_aux = top_k_aux
        self.n_batches_to_dead = n_batches_to_dead
        
        self.aux_penalty = aux_penalty
        self.l1_coeff = l1_coeff
        self.bandwidth = bandwidth
        
        self.dtype = dtype
        self.sae_dtype = sae_dtype
        
        self.parent_model_name = parent_model_name
        self.parent_layer = parent_layer
        self.parent_hook_point = parent_hook_point
    
    def get_torch_dtype(self, dtype_str: str) -> torch.dtype:
        dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }
        return dtype_map.get(dtype_str, torch.float32)
    
    @classmethod
    def from_training_config(cls, cfg: TrainingConfig):
        """Convert TrainingConfig to SAEConfig"""
        return cls(
            act_size=cfg.act_size,
            dict_size=cfg.dict_size,
            sae_type=cfg.sae_type,
            input_unit_norm=cfg.input_unit_norm,
            top_k=cfg.top_k,
            top_k_aux=cfg.top_k_aux,
            n_batches_to_dead=cfg.n_batches_to_dead,
            aux_penalty=cfg.aux_penalty,
            l1_coeff=cfg.l1_coeff,
            bandwidth=cfg.bandwidth,
            dtype=str(cfg.dtype).split('.')[-1],
            sae_dtype=str(cfg.sae_dtype).split('.')[-1],
            parent_model_name=cfg.model_name,
            parent_layer=cfg.layer,
            parent_hook_point=cfg.hook_point,
        )
    
    def to_training_config(self) -> TrainingConfig:
        """Convert SAEConfig back to TrainingConfig."""
        # PretrainedConfig is not a dataclass, so asdict() would fail here;
        # build the kwargs from to_dict() and keep only TrainingConfig fields.
        field_names = {f.name for f in fields(TrainingConfig)}
        config_dict = {k: v for k, v in self.to_dict().items() if k in field_names}
        config_dict['dtype'] = self.get_torch_dtype(self.dtype)
        config_dict['sae_dtype'] = self.get_torch_dtype(self.sae_dtype)
        config_dict['model_name'] = self.parent_model_name
        config_dict['layer'] = self.parent_layer
        config_dict['hook_point'] = self.parent_hook_point
        return TrainingConfig(**config_dict)


@pyrallis.wrap()
def get_config() -> TrainingConfig:
    return TrainingConfig()
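
# Usage note (sketch): pyrallis exposes each TrainingConfig field as a CLI flag,
# so a training entry point can override defaults, e.g. (script name is a placeholder):
#   python train.py --lr 1e-4 --dict_size 65536 --sae_type batchtopk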


# For backward compatibility
def get_default_cfg() -> TrainingConfig:
    return get_config()


def post_init_cfg(cfg: TrainingConfig) -> TrainingConfig:
    """
    Any additional configuration setup that needs to happen after model initialization
    """
    return cfg
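

# Illustrative round-trip between the two configs (paths and values are
# placeholders, not part of the training pipeline):
#   cfg = TrainingConfig(act_size=4096)
#   sae_cfg = SAEConfig.from_training_config(cfg)
#   sae_cfg.save_pretrained("checkpoints/sae")  # standard PretrainedConfig API
#   restored = sae_cfg.to_training_config()     # unstored fields fall back to defaults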