Create README.md
Browse files
    	
        README.md
    ADDED
    
    | @@ -0,0 +1,91 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ---
         | 
| 2 | 
            +
            license: apache-2.0
         | 
| 3 | 
            +
            ---
         | 
| 4 | 
            +
            # TRM Model - Pretrained on ARC AGI II
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            This repo contains model checkpoints.
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            See Wandb logs [here](https://wandb.ai/trelis/Arc2concept-aug-1000-ACT-torch).
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            We ran on 4xH200 SXM for ~48 hours.
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            Final score ~ 10.5%.
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            Config:
         | 
| 15 | 
            +
            ```yaml
         | 
| 16 | 
            +
            # ARC training config
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            defaults:
         | 
| 19 | 
            +
              - arch: trm
         | 
| 20 | 
            +
              - _self_
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            hydra:
         | 
| 23 | 
            +
              output_subdir: null
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            entity: "trelis"
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            # Data path
         | 
| 28 | 
            +
            data_paths: ['data/arc2concept-aug-1000']
         | 
| 29 | 
            +
            data_paths_test: []
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            evaluators:
         | 
| 32 | 
            +
              - name: arc@ARC
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            # Hyperparams - Training
         | 
| 35 | 
            +
            global_batch_size: 768
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            epochs: 100000
         | 
| 38 | 
            +
            eval_interval: 10000
         | 
| 39 | 
            +
            checkpoint_every_eval: True
         | 
| 40 | 
            +
            checkpoint_every_n_steps: null
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            lr: 1e-4
         | 
| 43 | 
            +
            lr_min_ratio: 1.0
         | 
| 44 | 
            +
            lr_warmup_steps: 2000
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            # Standard hyperparameter settings for LM, as used in Llama
         | 
| 47 | 
            +
            beta1: 0.9
         | 
| 48 | 
            +
            beta2: 0.95
         | 
| 49 | 
            +
            weight_decay: 0.1
         | 
| 50 | 
            +
            puzzle_emb_weight_decay: 0.1
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            # Hyperparams - Puzzle embeddings training
         | 
| 53 | 
            +
            puzzle_emb_lr: 1e-2
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            seed: 0
         | 
| 56 | 
            +
            min_eval_interval: 0 # when to start the eval
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            ema: True # use Exponential-Moving-Average
         | 
| 59 | 
            +
            ema_rate: 0.999 # EMA-rate
         | 
| 60 | 
            +
            freeze_weights: False # If True, freeze weights and only learn the embeddings
         | 
| 61 | 
            +
            ```
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            trm.yaml
         | 
| 64 | 
            +
            ```yaml
         | 
| 65 | 
            +
            name: recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1
         | 
| 66 | 
            +
            loss:
         | 
| 67 | 
            +
              name: losses@ACTLossHead
         | 
| 68 | 
            +
              loss_type: stablemax_cross_entropy
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            halt_exploration_prob: 0.1
         | 
| 71 | 
            +
            halt_max_steps: 16
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            H_cycles: 3
         | 
| 74 | 
            +
            L_cycles: 4 # NOTE THAT THIS IS DIFFERENT THAN THE PAPER, THAT USES 6. THE DIFFERENCE WAS ACCIDENTAL.
         | 
| 75 | 
            +
             | 
| 76 | 
            +
            H_layers: 0
         | 
| 77 | 
            +
            L_layers: 2
         | 
| 78 | 
            +
             | 
| 79 | 
            +
            hidden_size: 512
         | 
| 80 | 
            +
            num_heads: 8  # min(2, hidden_size // 64)
         | 
| 81 | 
            +
            expansion: 4
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            puzzle_emb_ndim: ${.hidden_size}
         | 
| 84 | 
            +
             | 
| 85 | 
            +
            pos_encodings: rope
         | 
| 86 | 
            +
            forward_dtype: bfloat16
         | 
| 87 | 
            +
             | 
| 88 | 
            +
            mlp_t: False # use mlp on L instead of transformer
         | 
| 89 | 
            +
            puzzle_emb_len: 16 # if non-zero, its specified to this value
         | 
| 90 | 
            +
            no_ACT_continue: True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
         | 
| 91 | 
            +
            ```
         | 

