Robotics
Transformers
Safetensors
English
VLA
Hume-vla committed
Commit 72bf50e · verified · Parent(s): 60b421a

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
array_typing.py ADDED
@@ -0,0 +1,80 @@
+from typing import Annotated, TypeAlias, TypedDict
+
+import numpy as np
+from jaxtyping import Float
+from numpy.typing import NDArray
+from torch import Tensor
+
+
+class InferConfig(TypedDict):
+    """Configuration for inference."""
+
+    replan_steps: int
+    s2_replan_steps: int
+    s2_candidates_num: int
+    noise_temp_lower_bound: float
+    noise_temp_upper_bound: float
+    time_temp_lower_bound: float
+    time_temp_upper_bound: float
+    post_process_action: bool
+    device: str
+
+
+ImageArray: TypeAlias = Annotated[NDArray[np.uint8], "Shape[B, H, W, C]"]
+StateArray: TypeAlias = Annotated[
+    NDArray[np.float32], "Shape[B, state_horizon, state_dim]"
+]
+ActionArray: TypeAlias = Annotated[NDArray[np.float32], "Shape[B, action_dim]"]
+
+InferBatchObs = TypedDict(
+    "InferBatchObs",
+    {
+        "observation.images.image": ImageArray,
+        "observation.images.wrist_image": ImageArray,
+        "observation.state": StateArray,
+        "task": list[str],
+    },
+)
+
+
+class InferOutput(TypedDict):
+    noise_action: Float[Tensor, "batch s2_chunksize padded_action_dim"]
+    s1_action: Float[Tensor, "batch s1_chunksize unpadded_action_dim"]
+    s2_action: Float[Tensor, "batch s2_chunksize unpadded_action_dim"]
+
+
+class CalQlBatch(TypedDict):
+    encoded_observations: Float[Tensor, "batch encoded_dim"]
+    encoded_next_observations: Float[Tensor, "batch encoded_dim"]
+    actions: Float[Tensor, "batch action_dim"]
+    rewards: Float[Tensor, " batch"]
+    mc_returns: Float[Tensor, " batch"]
+    masks: Float[Tensor, " batch"]
+
+
+class EnvArgs(TypedDict):
+    """Environment arguments."""
+
+    # necessary args
+    num_trials_per_task: int
+    num_steps_wait: int
+    task_suite_name: str
+    seed: int
+    ckpt_path: str | None
+    eval_name: str | None
+
+
+class Request(TypedDict):
+    """Message received by the environment."""
+
+    frame_type: str  # "init" | "action"
+    env_args: EnvArgs | None
+    action: ActionArray | None
+
+
+class Response(TypedDict):
+    """Message sent by the environment."""
+
+    status: str  # "new_episode" | "eval_finished" | "in_episode"
+    success_rate: float | None
+    observation: InferBatchObs | None
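These TypedDicts only document keys and shapes; nothing is enforced at runtime. Below is a minimal sketch of how a client might populate them before calling HumePolicy.init_infer / infer (defined in modeling_hume.py further down). It assumes array_typing.py and its dependencies (numpy, torch, jaxtyping) are importable; every numeric value is an illustrative placeholder, with the chunk-related settings mirroring the config.json in this commit.

import numpy as np

from array_typing import InferBatchObs, InferConfig

# Inference settings consumed by HumePolicy.init_infer (illustrative values;
# replan_steps / s2_replan_steps mirror the s1/s2 chunk sizes in config.json).
infer_cfg: InferConfig = {
    "replan_steps": 8,
    "s2_replan_steps": 16,
    "s2_candidates_num": 5,
    "noise_temp_lower_bound": 1.0,
    "noise_temp_upper_bound": 1.0,
    "time_temp_lower_bound": 1.0,
    "time_temp_upper_bound": 1.0,
    "post_process_action": True,
    "device": "cpu",
}

# One observation batch with the documented layouts: images are (B, H, W, C)
# uint8, the state is (B, state_horizon, state_dim) float32.
obs: InferBatchObs = {
    "observation.images.image": np.zeros((1, 256, 256, 3), dtype=np.uint8),
    "observation.images.wrist_image": np.zeros((1, 256, 256, 3), dtype=np.uint8),
    "observation.state": np.zeros((1, 1, 8), dtype=np.float32),
    "task": ["pick up the cube"],
}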
config.json ADDED
@@ -0,0 +1,283 @@
+{
+    "n_obs_steps": 1,
+    "normalization_mapping": {
+        "VISUAL": "IDENTITY",
+        "STATE": "MEAN_STD",
+        "ACTION": "MEAN_STD"
+    },
+    "input_features": {
+        "observation.images.image": {
+            "type": "VISUAL",
+            "shape": [
+                3,
+                256,
+                256
+            ]
+        },
+        "observation.images.wrist_image": {
+            "type": "VISUAL",
+            "shape": [
+                3,
+                256,
+                256
+            ]
+        },
+        "observation.state": {
+            "type": "STATE",
+            "shape": [
+                8
+            ]
+        }
+    },
+    "output_features": {
+        "action": {
+            "type": "ACTION",
+            "shape": [
+                7
+            ]
+        }
+    },
+    "device": "cpu",
+    "use_amp": false,
+    "type": "hume",
+    "s1_chunk_size": 8,
+    "s2_chunk_size": 16,
+    "n_action_steps": 16,
+    "max_state_dim": 32,
+    "max_action_dim": 32,
+    "resize_imgs_with_padding": [
+        224,
+        224
+    ],
+    "empty_cameras": 0,
+    "adapt_to_pi_aloha": false,
+    "use_delta_joint_actions_aloha": false,
+    "tokenizer_max_length": 48,
+    "proj_width": 1024,
+    "num_steps": 10,
+    "use_cache": true,
+    "attention_implementation": "eager",
+    "freeze_vision_encoder": true,
+    "train_expert_only": false,
+    "train_state_proj": true,
+    "optimizer_lr": 5e-05,
+    "optimizer_betas": [
+        0.9,
+        0.95
+    ],
+    "optimizer_eps": 1e-08,
+    "optimizer_weight_decay": 1e-10,
+    "scheduler_warmup_steps": 1000,
+    "scheduler_decay_steps": 1600000,
+    "scheduler_decay_lr": 2.5e-06,
+    "freeze_s2": true,
+    "s1_his_state_size": 4,
+    "cache_s2_actions": false,
+    "theta2": 1.0,
+    "theta1": 1.0,
+    "noise_slides_eps": 0.0,
+    "noise_slides_alp": 0.0,
+    "s1_proj_width": 512,
+    "freeze_s1_vision_encoder": false,
+    "s1_num_steps": 10,
+    "num_pos": 3,
+    "discount": 0.98,
+    "actor_lr": 1e-05,
+    "critic_lr": 1e-05,
+    "temp_lr": 2e-05,
+    "qf_lr": 0.0003,
+    "next_obs_offset": 1,
+    "vqh_chunk_size": 1,
+    "paligemma_config": {
+        "bos_token_id": 2,
+        "eos_token_id": 1,
+        "hidden_size": 2048,
+        "ignore_index": -100,
+        "image_token_index": 257152,
+        "model_type": "paligemma",
+        "pad_token_id": 0,
+        "projection_dim": 2048,
+        "text_config": {
+            "hidden_activation": "gelu_pytorch_tanh",
+            "hidden_size": 2048,
+            "intermediate_size": 16384,
+            "model_type": "gemma",
+            "num_attention_heads": 8,
+            "num_hidden_layers": 18,
+            "num_image_tokens": 256,
+            "num_key_value_heads": 1,
+            "torch_dtype": "float32",
+            "vocab_size": 257152
+        },
+        "torch_dtype": "float32",
+        "transformers_version": "4.48.1",
+        "vision_config": {
+            "hidden_size": 1152,
+            "intermediate_size": 4304,
+            "model_type": "siglip_vision_model",
+            "num_attention_heads": 16,
+            "num_hidden_layers": 27,
+            "num_image_tokens": 256,
+            "patch_size": 14,
+            "projection_dim": 2048,
+            "projector_hidden_act": "gelu_fast",
+            "vision_use_head": false
+        },
+        "vocab_size": 257152
+    },
+    "gemma_expert_config": {
+        "attention_bias": false,
+        "attention_dropout": 0.0,
+        "bos_token_id": 2,
+        "eos_token_id": 1,
+        "head_dim": 256,
+        "hidden_act": "gelu_pytorch_tanh",
+        "hidden_activation": "gelu_pytorch_tanh",
+        "hidden_size": 1024,
+        "initializer_range": 0.02,
+        "intermediate_size": 4096,
+        "max_position_embeddings": 8192,
+        "model_type": "gemma",
+        "num_attention_heads": 8,
+        "num_hidden_layers": 18,
+        "num_key_value_heads": 1,
+        "pad_token_id": 0,
+        "rms_norm_eps": 1e-06,
+        "rope_theta": 10000.0,
+        "torch_dtype": "float32",
+        "transformers_version": "4.48.1",
+        "use_cache": true,
+        "vocab_size": 257152
+    },
+    "s1_dino_config": {
+        "return_dict": true,
+        "output_hidden_states": false,
+        "output_attentions": false,
+        "torchscript": false,
+        "torch_dtype": "float32",
+        "use_bfloat16": false,
+        "tf_legacy_loss": false,
+        "pruned_heads": {},
+        "tie_word_embeddings": true,
+        "chunk_size_feed_forward": 0,
+        "is_encoder_decoder": false,
+        "is_decoder": false,
+        "cross_attention_hidden_size": null,
+        "add_cross_attention": false,
+        "tie_encoder_decoder": false,
+        "max_length": 20,
+        "min_length": 0,
+        "do_sample": false,
+        "early_stopping": false,
+        "num_beams": 1,
+        "num_beam_groups": 1,
+        "diversity_penalty": 0.0,
+        "temperature": 1.0,
+        "top_k": 50,
+        "top_p": 1.0,
+        "typical_p": 1.0,
+        "repetition_penalty": 1.0,
+        "length_penalty": 1.0,
+        "no_repeat_ngram_size": 0,
+        "encoder_no_repeat_ngram_size": 0,
+        "bad_words_ids": null,
+        "num_return_sequences": 1,
+        "output_scores": false,
+        "return_dict_in_generate": false,
+        "forced_bos_token_id": null,
+        "forced_eos_token_id": null,
+        "remove_invalid_values": false,
+        "exponential_decay_length_penalty": null,
+        "suppress_tokens": null,
+        "begin_suppress_tokens": null,
+        "architectures": [
+            "Dinov2Model"
+        ],
+        "finetuning_task": null,
+        "id2label": {
+            "0": "LABEL_0",
+            "1": "LABEL_1"
+        },
+        "label2id": {
+            "LABEL_0": 0,
+            "LABEL_1": 1
+        },
+        "tokenizer_class": null,
+        "prefix": null,
+        "bos_token_id": null,
+        "pad_token_id": null,
+        "eos_token_id": null,
+        "sep_token_id": null,
+        "decoder_start_token_id": null,
+        "task_specific_params": null,
+        "problem_type": null,
+        "_name_or_path": "../pretrained/dinov2-small",
+        "_attn_implementation_autoset": false,
+        "transformers_version": "4.52.0.dev0",
+        "model_type": "dinov2",
+        "hidden_size": 384,
+        "num_hidden_layers": 12,
+        "num_attention_heads": 6,
+        "mlp_ratio": 4,
+        "hidden_act": "gelu",
+        "hidden_dropout_prob": 0.0,
+        "attention_probs_dropout_prob": 0.0,
+        "initializer_range": 0.02,
+        "layer_norm_eps": 1e-06,
+        "image_size": 518,
+        "patch_size": 14,
+        "num_channels": 3,
+        "qkv_bias": true,
+        "layerscale_value": 1.0,
+        "drop_path_rate": 0.0,
+        "use_swiglu_ffn": false,
+        "stage_names": [
+            "stem",
+            "stage1",
+            "stage2",
+            "stage3",
+            "stage4",
+            "stage5",
+            "stage6",
+            "stage7",
+            "stage8",
+            "stage9",
+            "stage10",
+            "stage11",
+            "stage12"
+        ],
+        "apply_layernorm": true,
+        "reshape_hidden_states": true,
+        "use_mask_token": true,
+        "out_features": [
+            "stage12"
+        ],
+        "out_indices": [
+            12
+        ]
+    },
+    "s1_gemma_expert_config": {
+        "attention_bias": false,
+        "attention_dropout": 0.0,
+        "bos_token_id": 2,
+        "eos_token_id": 1,
+        "head_dim": 128,
+        "hidden_act": "gelu_pytorch_tanh",
+        "hidden_activation": "gelu_pytorch_tanh",
+        "hidden_size": 512,
+        "initializer_range": 0.02,
+        "intermediate_size": 2048,
+        "max_position_embeddings": 8192,
+        "model_type": "gemma",
+        "num_attention_heads": 8,
+        "num_hidden_layers": 13,
+        "num_key_value_heads": 1,
+        "pad_token_id": 0,
+        "rms_norm_eps": 1e-06,
+        "rope_theta": 10000.0,
+        "torch_dtype": "float32",
+        "transformers_version": "4.48.1",
+        "use_cache": true,
+        "vocab_size": 257152
+    }
+}
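For orientation, the fields most relevant at inference time can be read straight from this file with the standard library. A minimal sketch, assuming config.json has been downloaded to the working directory:

import json

with open("config.json") as f:
    cfg = json.load(f)

# Two 3x256x256 camera views plus an 8-dim state go in; a 7-dim action comes out.
for name, feat in cfg["input_features"].items():
    print(f"{name}: type={feat['type']} shape={feat['shape']}")
print("action shape:", cfg["output_features"]["action"]["shape"])

# System 1 refines slices of the longer System 2 chunk (8 vs. 16 steps here).
print(cfg["s1_chunk_size"], cfg["s2_chunk_size"], cfg["n_action_steps"])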
configuration_hume.py ADDED
@@ -0,0 +1,528 @@
+from dataclasses import dataclass, field
+
+from lerobot.common.optim.optimizers import AdamWConfig
+from lerobot.common.optim.schedulers import (
+    CosineDecayWithWarmupSchedulerConfig,
+)
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+
+
+@PreTrainedConfig.register_subclass("hume")
+@dataclass
+class HumeConfig(PreTrainedConfig):
+    # Input / output structure.
+    type: str = "hume"
+    n_obs_steps: int = 1
+    s1_chunk_size: int = 10
+    s2_chunk_size: int = 50
+    n_action_steps: int = 50
+
+    normalization_mapping: dict[str, NormalizationMode] = field(
+        default_factory=lambda: {
+            "VISUAL": NormalizationMode.IDENTITY,
+            "STATE": NormalizationMode.MEAN_STD,
+            "ACTION": NormalizationMode.MEAN_STD,
+        }
+    )
+
+    # Shorter state and action vectors will be padded
+    max_state_dim: int = 32
+    max_action_dim: int = 32
+
+    # Image preprocessing
+    resize_imgs_with_padding: tuple[int, int] = (224, 224)
+
+    # Add empty images. Used by pi0_aloha_sim which adds the empty
+    # left and right wrist cameras in addition to the top camera.
+    empty_cameras: int = 0
+
+    # Converts the joint and gripper values from the standard Aloha space to
+    # the space used by the pi internal runtime which was used to train the base model.
+    adapt_to_pi_aloha: bool = False
+
+    # Converts joint dimensions to deltas with respect to the current state before passing to the model.
+    # Gripper dimensions will remain in absolute values.
+    use_delta_joint_actions_aloha: bool = False
+
+    # Tokenizer
+    tokenizer_max_length: int = 48
+
+    # Projector
+    proj_width: int = 1024
+
+    # Decoding
+    num_steps: int = 10
+
+    # Attention utils
+    use_cache: bool = True
+    attention_implementation: str = "eager"  # or fa2, flex
+
+    # Finetuning settings
+    freeze_vision_encoder: bool = True
+    train_expert_only: bool = False
+    train_state_proj: bool = True
+
+    # Training presets
+    optimizer_lr: float = 2.5e-5
+    optimizer_betas: tuple[float, float] = (0.9, 0.95)
+    optimizer_eps: float = 1e-8
+    optimizer_weight_decay: float = 1e-10
+
+    scheduler_warmup_steps: int = 1_000
+    scheduler_decay_steps: int = 30_000
+    scheduler_decay_lr: float = 2.5e-6
+
+    # + Additional attributes for s1 / s2
+    # freeze system
+    freeze_s2: bool = False
+    s1_his_state_size: int = 1
+    cache_s2_actions: bool = False
+
+    # denoise ratio
+    theta2: float = 1.0
+    theta1: float = 1.0
+    noise_slides_eps: float = 0.0
+    noise_slides_alp: float = 0.0
+
+    # projector
+    s1_proj_width: int = 512  # NOTE: consistent with the s1_gemma_expert_config
+    freeze_s1_vision_encoder: bool = False
+
+    # decoding
+    s1_num_steps: int = 10
+
+    # vqh
+    num_pos: int = 3
+    discount: float = 0.98
+    actor_lr: float = 1e-4  # actor learning rate
+    critic_lr: float = 3e-4
+    temp_lr: float = 3e-4
+    qf_lr: float = 3e-4  # Critics learning rate
+    next_obs_offset: int = 10  # should be equal to vqh_chunk_size
+    vqh_chunk_size: int = 10
+
+    paligemma_config: dict = field(
+        default_factory=lambda: {
+            "bos_token_id": 2,
+            "eos_token_id": 1,
+            "hidden_size": 2048,
+            "ignore_index": -100,
+            "image_token_index": 257152,
+            "model_type": "paligemma",
+            "pad_token_id": 0,
+            "projection_dim": 2048,
+            "text_config": {
+                "hidden_activation": "gelu_pytorch_tanh",
+                "hidden_size": 2048,
+                "intermediate_size": 16384,
+                "model_type": "gemma",
+                "num_attention_heads": 8,
+                "num_hidden_layers": 18,
+                "num_image_tokens": 256,
+                "num_key_value_heads": 1,
+                "torch_dtype": "float32",
+                "vocab_size": 257152,
+            },
+            "torch_dtype": "float32",
+            "transformers_version": "4.48.1",
+            "vision_config": {
+                "hidden_size": 1152,
+                "intermediate_size": 4304,
+                "model_type": "siglip_vision_model",
+                "num_attention_heads": 16,
+                "num_hidden_layers": 27,
+                "num_image_tokens": 256,
+                "patch_size": 14,
+                "projection_dim": 2048,
+                "projector_hidden_act": "gelu_fast",
+                "vision_use_head": False,
+            },
+            "vocab_size": 257152,
+        }
+    )
+
+    gemma_expert_config: dict = field(
+        default_factory=lambda: {
+            "attention_bias": False,
+            "attention_dropout": 0.0,
+            "bos_token_id": 2,
+            "eos_token_id": 1,
+            "head_dim": 256,
+            "hidden_act": "gelu_pytorch_tanh",
+            "hidden_activation": "gelu_pytorch_tanh",
+            "hidden_size": 1024,
+            "initializer_range": 0.02,
+            "intermediate_size": 4096,
+            "max_position_embeddings": 8192,
+            "model_type": "gemma",
+            "num_attention_heads": 8,
+            "num_hidden_layers": 18,
+            "num_key_value_heads": 1,
+            "pad_token_id": 0,
+            "rms_norm_eps": 1e-06,
+            "rope_theta": 10000.0,
+            "torch_dtype": "float32",
+            "transformers_version": "4.48.1",
+            "use_cache": True,
+            "vocab_size": 257152,
+        }
+    )
+
+    # TODO: Add EMA
+
+    # system1 configurations
+    s1_dino_config: dict = field(
+        default_factory=lambda: {
+            "model_type": "dinov2",
+            "attention_probs_dropout_prob": 0.0,
+            "drop_path_rate": 0.0,
+            "hidden_act": "gelu",
+            "hidden_dropout_prob": 0.0,
+            "hidden_size": 384,
+            "image_size": 518,
+            "initializer_range": 0.02,
+            "layer_norm_eps": 1e-06,
+            "layerscale_value": 1.0,
+            "mlp_ratio": 4,
+            "num_attention_heads": 6,
+            "num_channels": 3,
+            "num_hidden_layers": 12,
+            "patch_size": 14,
+            "qkv_bias": True,
+            "torch_dtype": "float32",
+            "use_swiglu_ffn": False,
+        }
+    )
+
+    s1_gemma_expert_config: dict = field(
+        default_factory=lambda: {
+            "attention_bias": False,
+            "attention_dropout": 0.0,
+            "bos_token_id": 2,
+            "eos_token_id": 1,
+            "head_dim": 128,
+            "hidden_act": "gelu_pytorch_tanh",
+            "hidden_activation": "gelu_pytorch_tanh",
+            "hidden_size": 512,
+            "initializer_range": 0.02,
+            "intermediate_size": 2048,
+            "max_position_embeddings": 8192,
+            "model_type": "gemma",
+            "num_attention_heads": 8,
+            "num_hidden_layers": 13,
+            "num_key_value_heads": 1,
+            "pad_token_id": 0,
+            "rms_norm_eps": 1e-06,
+            "rope_theta": 10000.0,
+            "torch_dtype": "float32",
+            "transformers_version": "4.48.1",
+            "use_cache": True,
+            "vocab_size": 257152,
+        }
+    )
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        """Input validation (not exhaustive)."""
+        if self.n_action_steps > self.s2_chunk_size:
+            raise ValueError(
+                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
+                f"{self.n_action_steps} for `n_action_steps` and {self.s2_chunk_size} for `s2_chunk_size`."
+            )
+        if self.n_obs_steps != 1:
+            raise ValueError(
+                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
+            )
+
+        if self.use_delta_joint_actions_aloha:
+            raise NotImplementedError(
+                "`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not yet ported to LeRobot."
+            )
+
+    def validate_features(self) -> None:
+        # TODO: implement value error
+        # if not self.image_features and not self.env_state_feature:
+        #     raise ValueError("You must provide at least one image or the environment state among the inputs.")
+
+        for i in range(self.empty_cameras):
+            key = f"observation.images.empty_camera_{i}"
+            empty_camera = PolicyFeature(
+                type=FeatureType.VISUAL,
+                shape=(3, 480, 640),
+            )
+            self.input_features[key] = empty_camera
+
+    def get_optimizer_preset(self) -> dict[str, AdamWConfig]:
+        qf_optimizer = AdamWConfig(
+            lr=self.qf_lr,
+            weight_decay=0,
+            grad_clip_norm=10,
+        )
+        actor_optimizer = AdamWConfig(
+            lr=self.actor_lr,
+            weight_decay=0,
+            grad_clip_norm=10,
+        )
+
+        trunk_optimizer = AdamWConfig(
+            lr=self.optimizer_lr,
+            betas=self.optimizer_betas,
+            eps=self.optimizer_eps,
+            weight_decay=self.optimizer_weight_decay,
+        )
+
+        optimizer_dict = dict(
+            qf_optimizer=qf_optimizer,
+            actor_optimizer=actor_optimizer,
+            trunk_optimizer=trunk_optimizer,
+        )
+
+        return optimizer_dict
+
+    def get_scheduler_preset(self):
+        return CosineDecayWithWarmupSchedulerConfig(
+            peak_lr=self.optimizer_lr,
+            decay_lr=self.scheduler_decay_lr,
+            num_warmup_steps=self.scheduler_warmup_steps,
+            num_decay_steps=self.scheduler_decay_steps,
+        )
+
+    @property
+    def observation_delta_indices(self) -> None:
+        return None
+
+    @property
+    def action_delta_indices(self) -> list:
+        return list(range(self.s2_chunk_size))
+
+    @property
+    def reward_delta_indices(self) -> None:
+        return None
+
+    @property
+    def slide(self) -> int:
+        return self.s2_chunk_size // self.s1_chunk_size
+
+    @property
+    def s1_action_steps(self) -> int:
+        return self.s1_chunk_size
+
+    @property
+    def s2_action_steps(self) -> int:
+        return self.s2_chunk_size
+
+
+from dataclasses import dataclass, field
+
+from lerobot.common.optim.optimizers import AdamWConfig
+from lerobot.common.optim.schedulers import (
+    CosineDecayWithWarmupSchedulerConfig,
+)
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+
+
+@PreTrainedConfig.register_subclass("system2")
+@dataclass
+class System2Config(PreTrainedConfig):
+    # Input / output structure.
+    num_pos: int = 3
+    discount: float = 0.98
+    n_obs_steps: int = 1
+    chunk_size: int = 50
+    n_action_steps: int = 50
+    next_obs_offset: int = 1
+    s1_his_state_size: int = 1
+
+    normalization_mapping: dict[str, NormalizationMode] = field(
+        default_factory=lambda: {
+            "VISUAL": NormalizationMode.IDENTITY,
+            "STATE": NormalizationMode.MEAN_STD,
+            "ACTION": NormalizationMode.MEAN_STD,
+        }
+    )
+
+    # Shorter state and action vectors will be padded
+    max_state_dim: int = 32
+    max_action_dim: int = 32
+
+    # Image preprocessing
+    resize_imgs_with_padding: tuple[int, int] = (224, 224)
+
+    # Add empty images. Used by pi0_aloha_sim which adds the empty
+    # left and right wrist cameras in addition to the top camera.
+    empty_cameras: int = 0
+
+    # Converts the joint and gripper values from the standard Aloha space to
+    # the space used by the pi internal runtime which was used to train the base model.
+    adapt_to_pi_aloha: bool = False
+
+    # Converts joint dimensions to deltas with respect to the current state before passing to the model.
+    # Gripper dimensions will remain in absolute values.
+    use_delta_joint_actions_aloha: bool = False
+
+    # Tokenizer
+    tokenizer_max_length: int = 48
+
+    # Projector
+    proj_width: int = 1024
+
+    # Decoding
+    num_steps: int = 10
+
+    # Attention utils
+    use_cache: bool = True
+    attention_implementation: str = "eager"  # or fa2, flex
+
+    # Finetuning settings
+    freeze_vision_encoder: bool = True
+    train_expert_only: bool = False
+    train_state_proj: bool = True
+
+    # Training presets
+    optimizer_lr: float = 2.5e-5
+    optimizer_betas: tuple[float, float] = (0.9, 0.95)
+    optimizer_eps: float = 1e-8
+    optimizer_weight_decay: float = 1e-10
+
+    scheduler_warmup_steps: int = 1_000
+    scheduler_decay_steps: int = 30_000
+    scheduler_decay_lr: float = 2.5e-6
+
+    paligemma_config: dict = field(
+        default_factory=lambda: {
+            "bos_token_id": 2,
+            "eos_token_id": 1,
+            "hidden_size": 2048,
+            "ignore_index": -100,
+            "image_token_index": 257152,
+            "model_type": "paligemma",
+            "pad_token_id": 0,
+            "projection_dim": 2048,
+            "text_config": {
+                "hidden_activation": "gelu_pytorch_tanh",
+                "hidden_size": 2048,
+                "intermediate_size": 16384,
+                "model_type": "gemma",
+                "num_attention_heads": 8,
+                "num_hidden_layers": 18,
+                "num_image_tokens": 256,
+                "num_key_value_heads": 1,
+                "torch_dtype": "float32",
+                "vocab_size": 257152,
+            },
+            "torch_dtype": "float32",
+            "transformers_version": "4.48.1",
+            "vision_config": {
+                "hidden_size": 1152,
+                "intermediate_size": 4304,
+                "model_type": "siglip_vision_model",
+                "num_attention_heads": 16,
+                "num_hidden_layers": 27,
+                "num_image_tokens": 256,
+                "patch_size": 14,
+                "projection_dim": 2048,
+                "projector_hidden_act": "gelu_fast",
+                "vision_use_head": False,
+            },
+            "vocab_size": 257152,
+        }
+    )
+
+    gemma_expert_config: dict = field(
+        default_factory=lambda: {
+            "attention_bias": False,
+            "attention_dropout": 0.0,
+            "bos_token_id": 2,
+            "eos_token_id": 1,
+            "head_dim": 256,
+            "hidden_act": "gelu_pytorch_tanh",
+            "hidden_activation": "gelu_pytorch_tanh",
+            "hidden_size": 1024,
+            "initializer_range": 0.02,
+            "intermediate_size": 4096,
+            "max_position_embeddings": 8192,
+            "model_type": "gemma",
+            "num_attention_heads": 8,
+            "num_hidden_layers": 18,
+            "num_key_value_heads": 1,
+            "pad_token_id": 0,
+            "rms_norm_eps": 1e-06,
+            "rope_theta": 10000.0,
+            "torch_dtype": "float32",
+            "transformers_version": "4.48.1",
+            "use_cache": True,
+            "vocab_size": 257152,
+        }
+    )
+
+    # TODO: Add EMA
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        """Input validation (not exhaustive)."""
+        if self.n_action_steps > self.chunk_size:
+            raise ValueError(
+                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
+                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
+            )
+        if self.n_obs_steps != 1:
+            raise ValueError(
+                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
+            )
+
+        if self.use_delta_joint_actions_aloha:
+            raise NotImplementedError(
+                "`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not yet ported to LeRobot."
+            )
+
+    def validate_features(self) -> None:
+        # TODO: implement value error
+        # if not self.image_features and not self.env_state_feature:
+        #     raise ValueError("You must provide at least one image or the environment state among the inputs.")
+
+        for i in range(self.empty_cameras):
+            key = f"observation.images.empty_camera_{i}"
+            empty_camera = PolicyFeature(
+                type=FeatureType.VISUAL,
+                shape=(3, 480, 640),
+            )
+            self.input_features[key] = empty_camera
+
+    def get_optimizer_preset(self) -> AdamWConfig:
+        return AdamWConfig(
+            lr=self.optimizer_lr,
+            betas=self.optimizer_betas,
+            eps=self.optimizer_eps,
+            weight_decay=self.optimizer_weight_decay,
+        )
+
+    def get_scheduler_preset(self):
+        return CosineDecayWithWarmupSchedulerConfig(
+            peak_lr=self.optimizer_lr,
+            decay_lr=self.scheduler_decay_lr,
+            num_warmup_steps=self.scheduler_warmup_steps,
+            num_decay_steps=self.scheduler_decay_steps,
+        )
+
+    @property
+    def observation_delta_indices(self) -> None:
+        return None
+
+    @property
+    def action_delta_indices(self) -> list:
+        return list(range(self.chunk_size))
+
+    @property
+    def reward_delta_indices(self) -> None:
+        return None
+
+    @property
+    def slide(self) -> int:
+        return 1
+
+    @property
+    def s1_action_steps(self) -> int:
+        return 1
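The chunking properties above reduce to simple integer arithmetic. A standalone sketch with the values shipped in config.json, kept free of the lerobot imports so it only illustrates the relationships rather than instantiating the real classes:

s1_chunk_size = 8    # steps System 1 predicts per call (config.json)
s2_chunk_size = 16   # steps System 2 predicts per call (config.json)
n_action_steps = 16  # must not exceed s2_chunk_size (checked in __post_init__)

assert n_action_steps <= s2_chunk_size

slide = s2_chunk_size // s1_chunk_size             # HumeConfig.slide -> 2
action_delta_indices = list(range(s2_chunk_size))  # indices spanning one System 2 chunk

print(slide, len(action_delta_indices))  # 2 16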
fast_visuo_expert.py ADDED
@@ -0,0 +1,321 @@
+from typing import Optional
+
+import torch
+from torch import nn
+from transformers import (
+    AutoConfig,
+    Dinov2Model,
+    GemmaForCausalLM,
+    PretrainedConfig,
+    PreTrainedModel,
+)
+from transformers.models.auto import CONFIG_MAPPING
+
+
+def apply_rope(x, positions, max_wavelength=10_000):
+    """
+    Applies RoPE positions [B, L] to x [B, L, H, D].
+    """
+    d_half = x.shape[-1] // 2
+    device = x.device
+    dtype = x.dtype
+    x = x.to(torch.float32)
+
+    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
+        d_half, dtype=torch.float32, device=device
+    )
+    timescale = max_wavelength**freq_exponents
+    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
+        torch.float32
+    )
+
+    radians = radians[..., None, :]
+
+    sin = torch.sin(radians)  # .to(dtype=dtype)
+    cos = torch.cos(radians)  # .to(dtype=dtype)
+
+    x1, x2 = x.split(d_half, dim=-1)
+    res = torch.empty_like(x)
+    res[..., :d_half] = x1 * cos - x2 * sin
+    res[..., d_half:] = x2 * cos + x1 * sin
+
+    return res.to(dtype)
+
+
+class FastVisuoExpertConfig(PretrainedConfig):
+    model_type = "FastVisuoExpertModel"
+    sub_configs = {"dino_config": AutoConfig, "gemma_expert_config": AutoConfig}
+
+    def __init__(
+        self,
+        dino_config: dict | None = None,
+        gemma_expert_config: dict | None = None,
+        freeze_vision_encoder: bool = True,
+        attention_implementation: str = "eager",
+        **kwargs,
+    ):
+        self.freeze_vision_encoder = freeze_vision_encoder
+        self.attention_implementation = attention_implementation
+
+        if dino_config is None:
+            self.dino_config = CONFIG_MAPPING["dinov2"](
+                transformers_version="4.48.1",
+                model_type="dinov2",
+                attention_probs_dropout_prob=0.0,
+                drop_path_rate=0.0,
+                hidden_act="gelu",
+                hidden_dropout_prob=0.0,
+                hidden_size=384,
+                image_size=518,
+                initializer_range=0.02,
+                layer_norm_eps=1e-06,
+                layerscale_value=1.0,
+                mlp_ratio=4,
+                num_attention_heads=6,
+                num_channels=3,
+                num_hidden_layers=12,
+                patch_size=14,
+                qkv_bias=True,
+                torch_dtype="float32",
+                use_swiglu_ffn=False,
+            )
+        elif isinstance(dino_config, dict):
+            if "model_type" not in dino_config:
+                dino_config["model_type"] = "dinov2"
+            cfg_cls = CONFIG_MAPPING[dino_config["model_type"]]
+            self.dino_config = cfg_cls(**dino_config)
+
+        if gemma_expert_config is None:
+            self.gemma_expert_config = CONFIG_MAPPING["gemma"](
+                attention_bias=False,
+                attention_dropout=0.0,
+                bos_token_id=2,
+                eos_token_id=1,
+                head_dim=256,
+                hidden_act="gelu_pytorch_tanh",
+                hidden_activation="gelu_pytorch_tanh",
+                hidden_size=1024,
+                initializer_range=0.02,
+                intermediate_size=4096,
+                max_position_embeddings=8192,
+                model_type="gemma",
+                num_attention_heads=8,
+                num_hidden_layers=8,
+                num_key_value_heads=1,
+                pad_token_id=0,
+                rms_norm_eps=1e-06,
+                rope_theta=10000.0,
+                torch_dtype="float32",
+                transformers_version="4.48.1",
+                use_cache=True,
+                vocab_size=257152,
+            )
+        elif isinstance(gemma_expert_config, dict):
+            if "model_type" not in gemma_expert_config:
+                gemma_expert_config["model_type"] = "gemma"
+            cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
+            self.gemma_expert_config = cfg_cls(**gemma_expert_config)
+
+        super().__init__(**kwargs)
+
+    def __post_init__(self):
+        super().__post_init__()
+        if self.attention_implementation not in ["eager", "fa2", "flex"]:
+            raise ValueError(
+                f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
+            )
+
+
+class FastVisuoExpertModel(PreTrainedModel):
+    config_class = FastVisuoExpertConfig
+
+    def __init__(self, config: FastVisuoExpertConfig):
+        super().__init__(config=config)
+        self.config = config
+        self.vision_tower = Dinov2Model(config=config.dino_config)
+        self.gemma_expert = GemmaForCausalLM(
+            config=config.gemma_expert_config
+        )  # GemmaModel
+        self.multi_modal_projector = nn.Linear(
+            config.dino_config.hidden_size, config.gemma_expert_config.hidden_size
+        )
+        self.gemma_expert.model.embed_tokens = None
+        self.gemma_expert.lm_head = None
+
+        self.to_bfloat16_like_physical_intelligence()
+        self.set_requires_grad()
+
+    def set_requires_grad(self):
+        if self.config.freeze_vision_encoder:
+            self.vision_tower.eval()
+            for params in self.vision_tower.parameters():
+                params.requires_grad = False
+
+    def train(self, mode: bool = True):
+        super().train(mode)
+
+        if self.config.freeze_vision_encoder:
+            self.vision_tower.eval()
+
+    def to_bfloat16_like_physical_intelligence(self):
+        self.vision_tower = self.vision_tower.to(dtype=torch.bfloat16)
+        params_to_change_dtype = [
+            "language_model.model.layers",
+            "gemma_expert.model.layers",
+            "vision_tower",
+            "multi_modal",
+        ]
+        for name, param in self.named_parameters():
+            if any(selector in name for selector in params_to_change_dtype):
+                param.data = param.data.to(dtype=torch.bfloat16)
+
+    def embed_image(self, image: torch.Tensor):
+        selected_image_feature = self.vision_tower(image).last_hidden_state
+        image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = image_features / (
+            self.config.gemma_expert_config.hidden_size**0.5
+        )
+        return image_features
+
+    # TODO: break down this huge forward into modules or functions
+    def forward(
+        self,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ):
+        # RMSNorm
+        head_dim = self.gemma_expert.config.head_dim
+
+        hidden_states = inputs_embeds
+        batch_size = hidden_states.shape[0]
+        for layer in self.gemma_expert.model.layers[
+            : self.gemma_expert.config.num_hidden_layers
+        ]:
+            # normalizer = torch.tensor(model.config.hidden_size**0.5, dtype=hidden_states.dtype)
+            # hidden_states = hidden_states * normalizer
+            hidden_states = layer.input_layernorm(hidden_states)
+            input_shape = hidden_states.shape[:-1]
+            hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
+
+            # self attention
+            hidden_states = hidden_states.to(dtype=torch.bfloat16)
+            query_states = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
+            key_states = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
+            value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
+
+            query_states = apply_rope(query_states, position_ids)
+            key_states = apply_rope(key_states, position_ids)
+
+            attention_interface = self.get_attention_interface()
+            att_output = attention_interface(
+                attention_mask,
+                batch_size,
+                head_dim,
+                query_states,
+                key_states,
+                value_states,
+            )
+
+            if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
+                att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
+
+            out_emb = layer.self_attn.o_proj(att_output)
+
+            # first residual
+            out_emb += hidden_states
+            after_first_residual = out_emb.clone()
+            out_emb = layer.post_attention_layernorm(out_emb)
+            out_emb = layer.mlp(out_emb)
+            # second residual
+            out_emb += after_first_residual
+            hidden_states = out_emb
+
+        # final norm
+        hidden_states = self.gemma_expert.model.norm(hidden_states)
+
+        return hidden_states
+
+    def get_attention_interface(self):
+        if self.config.attention_implementation == "fa2":
+            attention_interface = self.flash_attention_forward
+        else:
+            attention_interface = self.eager_attention_forward
+        return attention_interface
+
+    def eager_attention_forward(
+        self,
+        attention_mask,
+        batch_size,
+        head_dim,
+        query_states,
+        key_states,
+        value_states,
+    ):
+        num_att_heads = self.config.gemma_expert_config.num_attention_heads
+        num_key_value_heads = self.config.gemma_expert_config.num_key_value_heads
+        num_key_value_groups = num_att_heads // num_key_value_heads
+
+        # query_states: batch_size, sequence_length, num_att_head, head_dim
+        # key_states: batch_size, sequence_length, num_key_value_head, head_dim
+        # value_states: batch_size, sequence_length, num_key_value_head, head_dim
+        sequence_length = key_states.shape[1]
+
+        key_states = key_states[:, :, :, None, :].expand(
+            batch_size,
+            sequence_length,
+            num_key_value_heads,
+            num_key_value_groups,
+            head_dim,
+        )
+        key_states = key_states.reshape(
+            batch_size,
+            sequence_length,
+            num_key_value_heads * num_key_value_groups,
+            head_dim,
+        )
+
+        value_states = value_states[:, :, :, None, :].expand(
+            batch_size,
+            sequence_length,
+            num_key_value_heads,
+            num_key_value_groups,
+            head_dim,
+        )
+        value_states = value_states.reshape(
+            batch_size,
+            sequence_length,
+            num_key_value_heads * num_key_value_groups,
+            head_dim,
+        )
+
+        # Attention here is upcasted to float32 to match the original eager implementation.
+        query_states = query_states.to(dtype=torch.float32)
+        key_states = key_states.to(dtype=torch.float32)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+
+        att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+        att_weights *= head_dim**-0.5
+        big_neg = -2.3819763e38  # See gemma/modules.py
+
+        masked_att_weights = torch.where(
+            attention_mask[:, None, :, :], att_weights, big_neg
+        )
+
+        probs = nn.functional.softmax(masked_att_weights, dim=-1)
+        probs = probs.to(dtype=value_states.dtype)
+
+        # probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
+        # value_states: batch_size, sequence_length, num_att_heads, head_dim
+
+        att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
+
+        att_output = att_output.permute(0, 2, 1, 3)
+        # we use -1 because sequence length can change
+        att_output = att_output.reshape(
+            batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
+        )
+
+        return att_output
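A quick shape and sanity check for `apply_rope`, assuming fast_visuo_expert.py and its dependencies (torch, transformers) are importable from the working directory; the sizes below are arbitrary:

import torch

from fast_visuo_expert import apply_rope

B, L, H, D = 2, 6, 8, 256                 # batch, seq length, heads, head_dim (even)
x = torch.randn(B, L, H, D)
positions = torch.arange(L).expand(B, L)  # RoPE positions [B, L]

out = apply_rope(x, positions)
print(out.shape, out.dtype)               # torch.Size([2, 6, 8, 256]) torch.float32

# Position 0 applies a zero rotation, so the first token should come back unchanged.
assert torch.allclose(out[:, 0], x[:, 0], atol=1e-5)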
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58518ffa9166223aafee65453b3ed9cd4abde834949a61bbf78c3a5e99c1fe42
+size 9038608596
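The entry above is only a Git LFS pointer; the ~9 GB checkpoint itself has to be fetched (git lfs pull or huggingface_hub) before it can be read. A minimal sketch that lists tensor names and shapes without loading the full weight set, assuming the real file is present locally and the safetensors package is installed:

from safetensors import safe_open

with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    names = list(f.keys())
    print(len(names), "tensors in the checkpoint")
    for name in names[:5]:   # peek at a few entries lazily
        t = f.get_tensor(name)
        print(name, tuple(t.shape), t.dtype)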
modeling_hume.py ADDED
@@ -0,0 +1,1909 @@
1
+ import collections
2
+ import math
3
+ from argparse import Namespace
4
+ from collections import deque
5
+
6
+ import array_typing as at
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F # noqa: N812
10
+ import torchvision.transforms.functional as TF
11
+ from beartype import beartype as typechecker
12
+ from configuration_hume import HumeConfig, System2Config
13
+ from fast_visuo_expert import FastVisuoExpertConfig, FastVisuoExpertModel
14
+ from jaxtyping import Bool, Float, Int64, jaxtyped
15
+ from lerobot.common.constants import ACTION, OBS_ROBOT
16
+ from lerobot.common.policies.normalize import Normalize, Unnormalize
17
+ from lerobot.common.policies.pretrained import PreTrainedPolicy
18
+ from lerobot.common.utils.utils import get_safe_dtype
19
+ from paligemma_with_expert import (
20
+ PaliGemmaWithExpertConfig,
21
+ PaliGemmaWithExpertModel,
22
+ )
23
+ from torch import Tensor, nn
24
+ from transformers import AutoTokenizer
25
+ from value_query import (
26
+ CalQL,
27
+ CalQlConfig,
28
+ VQHBackbone,
29
+ VQHBackboneConfig,
30
+ )
31
+
32
+
33
+ def create_sinusoidal_pos_embedding(
34
+ time: torch.tensor,
35
+ dimension: int,
36
+ min_period: float,
37
+ max_period: float,
38
+ device="cpu",
39
+ ) -> Tensor:
40
+ """Computes sine-cosine positional embedding vectors for scalar positions."""
41
+ if dimension % 2 != 0:
42
+ raise ValueError(f"dimension ({dimension}) must be divisible by 2")
43
+
44
+ if time.ndim != 1:
45
+ raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
46
+
47
+ dtype = get_safe_dtype(torch.float64, device.type)
48
+ fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
49
+ period = min_period * (max_period / min_period) ** fraction
50
+
51
+ # Compute the outer product
52
+ scaling_factor = 1.0 / period * 2 * math.pi
53
+ sin_input = scaling_factor[None, :] * time[:, None]
54
+ pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
55
+ return pos_emb
56
+
57
+
58
+ def sample_beta(alpha, beta, bsize, device):
59
+ gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
60
+ gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
61
+ return gamma1 / (gamma1 + gamma2)
62
+
63
+
64
+ def make_att_2d_masks(pad_masks, att_masks):
65
+ """Copied from big_vision.
66
+
67
+ Tokens can attend to valid inputs tokens which have a cumulative mask_ar
68
+ smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
69
+ setup several types of attention, for example:
70
+
71
+ [[1 1 1 1 1 1]]: pure causal attention.
72
+
73
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
74
+ themselves and the last 3 tokens have a causal attention. The first
75
+ entry could also be a 1 without changing behaviour.
76
+
77
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
78
+ block can attend all previous blocks and all tokens on the same block.
79
+
80
+ Args:
81
+ input_mask: bool[B, N] true if its part of the input, false if padding.
82
+ mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
83
+ it and 0 where it shares the same attention mask as the previous token.
84
+ """
85
+ if att_masks.ndim != 2:
86
+ raise ValueError(att_masks.ndim)
87
+ if pad_masks.ndim != 2:
88
+ raise ValueError(pad_masks.ndim)
89
+
90
+ cumsum = torch.cumsum(att_masks, dim=1)
91
+ att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
92
+ pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
93
+ att_2d_masks = att_2d_masks & pad_2d_masks
94
+ return att_2d_masks
95
+
96
+
97
+ def resize_with_pad(img, width, height, pad_value=-1):
98
+ # assume no-op when width height fits already
99
+ if img.ndim != 4:
100
+ raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
101
+
102
+ cur_height, cur_width = img.shape[2:]
103
+
104
+ ratio = max(cur_width / width, cur_height / height)
105
+ resized_height = int(cur_height / ratio)
106
+ resized_width = int(cur_width / ratio)
107
+ resized_img = F.interpolate(
108
+ img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
109
+ )
110
+
111
+ pad_height = max(0, int(height - resized_height))
112
+ pad_width = max(0, int(width - resized_width))
113
+
114
+ # pad on left and top of image
115
+ padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
116
+ return padded_img
117
+
118
+
119
+ def pad_vector(vector, new_dim):
120
+ """Can be (batch_size x sequence_length x features_dimension)
121
+ or (batch_size x features_dimension)
122
+ """
123
+ if vector.shape[-1] == new_dim:
124
+ return vector
125
+ shape = list(vector.shape)
126
+ current_dim = shape[-1]
127
+ shape[-1] = new_dim
128
+ new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
129
+ new_vector[..., :current_dim] = vector
130
+ return new_vector
131
+
132
+
133
+ def normalize(x, min_val, max_val):
134
+ return (x - min_val) / (max_val - min_val)
135
+
136
+
137
+ def unnormalize(x, min_val, max_val):
138
+ return x * (max_val - min_val) + min_val
139
+
140
+
141
+ def safe_arcsin(value):
142
+ # This ensures that the input stays within
143
+ # [−1,1] to avoid invalid values for arcsin
144
+ return torch.arcsin(torch.clamp(value, -1.0, 1.0))
145
+
146
+
147
+ def aloha_gripper_to_angular(value):
148
+ # Aloha transforms the gripper positions into a linear space. The following code
149
+ # reverses this transformation to be consistent with pi0 which is pretrained in
150
+ # angular space.
151
+ #
152
+ # These values are coming from the Aloha code:
153
+ # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
154
+ value = unnormalize(value, min_val=0.01844, max_val=0.05800)
155
+
156
+ # This is the inverse of the angular to linear transformation inside the Interbotix code.
157
+ def linear_to_radian(linear_position, arm_length, horn_radius):
158
+ value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
159
+ 2 * horn_radius * linear_position
160
+ )
161
+ return safe_arcsin(value)
162
+
163
+ # The constants are taken from the Interbotix code.
164
+ value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
165
+
166
+ # Normalize to [0, 1].
167
+ # The values 0.4 and 1.5 were measured on an actual Trossen robot.
168
+ return normalize(value, min_val=0.4, max_val=1.5)
169
+
170
+
171
+ def aloha_gripper_from_angular(value):
172
+ # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
173
+ # Note that the units are still angular but the range is different.
174
+
175
+ # The values 0.4 and 1.5 were measured on an actual Trossen robot.
176
+ value = unnormalize(value, min_val=0.4, max_val=1.5)
177
+
178
+ # These values are coming from the Aloha code:
179
+ # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
180
+ return normalize(value, min_val=-0.6213, max_val=1.4910)
181
+
182
+
183
+ def aloha_gripper_from_angular_inv(value):
184
+ # Directly inverts the gripper_from_angular function.
185
+ value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
186
+ return normalize(value, min_val=0.4, max_val=1.5)
187
+
188
+
189
+ class HumePolicy(PreTrainedPolicy):
190
+ """Wrapper class around System2 model to train and run inference within LeRobot."""
191
+
192
+ config_class = HumeConfig
193
+ name = "hume"
194
+
195
+ def __init__(
196
+ self,
197
+ config: HumeConfig,
198
+ dataset_stats: dict[str, dict[str, Tensor]] | None = None,
199
+ ):
200
+ super().__init__(config)
201
+ config.validate_features()
202
+ self.config = config
203
+
204
+ # TODO: input / output features / normalizer for mutiple datasets
205
+ self.normalize_inputs = Normalize(
206
+ config.input_features, config.normalization_mapping, dataset_stats
207
+ )
208
+ self.normalize_targets = Normalize(
209
+ config.output_features, config.normalization_mapping, dataset_stats
210
+ )
211
+ self.unnormalize_outputs = Unnormalize(
212
+ config.output_features, config.normalization_mapping, dataset_stats
213
+ )
214
+
215
+ self.language_tokenizer = None
216
+ self.s2_model = System2(config)
217
+ self.s1_model = FastVisuoMatching(config)
218
+ self.value_query_head = ValueQueryHead(
219
+ paligemma_with_expert=self.s2_model.paligemma_with_expert, config=config
220
+ )
221
+ self.reset()
222
+
223
+ self.set_requires_grad()
224
+
225
+ def set_requires_grad(self):
226
+ if self.config.freeze_s2:
227
+ self.s2_model.eval()
228
+ for params in self.s2_model.parameters():
229
+ params.requires_grad = False
230
+
231
+ def train(self, mode: bool = True):
232
+ super().train(mode)
233
+ if self.config.freeze_s2:
234
+ self.s2_model.eval()
235
+
236
+ def reset(self):
237
+ """This should be called whenever the environment is reset."""
238
+ self._action_queue = deque([], maxlen=self.config.n_action_steps)
239
+ self.s2_action_cache = {}
240
+
241
+ def get_trunk_params(self) -> dict:
242
+ exclude_params = set()
243
+ exclude_modules = [
244
+ self.value_query_head.calql.policy,
245
+ self.value_query_head.calql.critics,
246
+ self.value_query_head.calql.temperature,
247
+ ]
248
+
249
+ for module in exclude_modules:
250
+ for param in module.parameters():
251
+ exclude_params.add(id(param))
252
+
253
+ return [param for param in self.parameters() if id(param) not in exclude_params]
254
+
255
+ def get_optim_params(self) -> dict:
256
+ return self.parameters()
257
+
258
+ def get_actor_optim_params(self) -> dict:
259
+ return self.value_query_head.calql.policy.parameters()
260
+
261
+ def get_critics_optim_params(self) -> dict:
262
+ return self.value_query_head.calql.critics.parameters()
263
+
264
+ def get_temperature_optim_params(self) -> dict:
265
+ return self.value_query_head.calql.temperature.parameters()
266
+
267
+ def init_infer(self, infer_cfg: at.InferConfig):
268
+ self.infer_cfg = Namespace(**infer_cfg)
269
+ self.action_plan = collections.deque()
270
+ self.history_state = collections.deque(maxlen=self.config.s1_his_state_size)
271
+ self.infer_step = 0
272
+ self.outputs = {}
273
+ self.q_value_cache = []
274
+ self.action_cache = []
275
+
276
+ self.reset()
277
+ print("Initializing inference with config:", infer_cfg)
278
+
279
+ return True
280
+
281
+ def infer(self, observation: at.InferBatchObs) -> at.ActionArray:
282
+ # prcoess observation
283
+ # from np.array -> torch.tensor -> add batch, change shape
284
+ if not self.history_state:
285
+ self.history_state.extend(
286
+ np.expand_dims(observation["observation.state"], 1)
287
+ .repeat(self.config.s1_his_state_size, axis=1)
288
+ .transpose(1, 0, 2)
289
+ )
290
+ else:
291
+ self.history_state.append(observation["observation.state"])
292
+
293
+ observation["observation.state"] = np.asarray(self.history_state).transpose(
294
+ 1, 0, 2
295
+ )
296
+
297
+ observation: dict[str, torch.tensor | list[str]] = {
298
+ **{
299
+ k: torch.tensor(v / 255) # b, h, w ,c
300
+ .permute(0, 3, 1, 2) # b, c, h, w
301
+ .to(self.infer_cfg.device)
302
+ .float()
303
+ for k, v in observation.items()
304
+ if k
305
+ in {
306
+ "observation.images.image",
307
+ "observation.images.wrist_image",
308
+ "observation.images.image_0",
309
+ }
310
+ },
311
+ **{k: v for k, v in observation.items() if k in {"task"}}, # len = batch
312
+ **{
313
+ k: torch.tensor(v)
314
+ .to(self.infer_cfg.device)
315
+ .float() # b, state_horizon, state_dim
316
+ for k, v in observation.items()
317
+ if k in {"observation.state"}
318
+ },
319
+ }
320
+ batch_size = len(observation["task"])
321
+
322
+ if not self.action_plan:
323
+ # Finished executing previous action chunk -- compute new chunk
324
+ # Prepare observations dict
325
+ # infer the action
326
+ if self.infer_step % self.infer_cfg.s2_replan_steps == 0:
327
+ self.outputs = {} # infer with s1 or s2
328
+ stamp = (
329
+ torch.tensor(
330
+ [
331
+ self.infer_step
332
+ % self.infer_cfg.s2_replan_steps
333
+ / self.config.s2_chunk_size
334
+ ]
335
+ )
336
+ .expand(batch_size)
337
+ .to(self.infer_cfg.device)
338
+ .float()
339
+ )
340
+ self.outputs = self.select_action(
341
+ observation,
342
+ self.outputs,
343
+ stamp,
344
+ s2_candidates_num=self.infer_cfg.s2_candidates_num,
345
+ noise_temp_bounds=(
346
+ self.infer_cfg.noise_temp_lower_bound,
347
+ self.infer_cfg.noise_temp_upper_bound,
348
+ ),
349
+ time_temp_bounds=(
350
+ self.infer_cfg.time_temp_lower_bound,
351
+ self.infer_cfg.time_temp_upper_bound,
352
+ ),
353
+ )
354
+ action_chunk = self.outputs["s1_action"].cpu().numpy()
355
+
356
+ if self.infer_cfg.post_process_action:
357
+ action_chunk[..., -1] = 2 * (1 - action_chunk[..., -1]) - 1
358
+
359
+ # convert action chunk shape to (replan_steps, batch, action_dim)
360
+ action_chunk = action_chunk.transpose(1, 0, 2)
361
+ assert (
362
+ len(action_chunk) >= self.infer_cfg.replan_steps
363
+ ), f"We want to replan every {self.infer_cfg.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
364
+ self.action_plan.extend(action_chunk[: self.infer_cfg.replan_steps])
365
+
366
+ self.infer_step += 1
367
+ action = self.action_plan.popleft()
368
+ return np.asarray(action)
369
+
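A hedged sketch of how `init_infer` / `infer` might be driven from an evaluation loop. The config values are illustrative rather than tuned defaults, `policy` stands for an already-loaded instance of the policy class defined in this file, and `make_env` plus the observation field names it returns are hypothetical placeholders.

import numpy as np

infer_cfg = {
    "replan_steps": 8,
    "s2_replan_steps": 40,
    "s2_candidates_num": 5,
    "noise_temp_lower_bound": 1.0,
    "noise_temp_upper_bound": 1.0,
    "time_temp_lower_bound": 1.0,
    "time_temp_upper_bound": 1.0,
    "post_process_action": True,
    "device": "cuda",
}
policy.init_infer(infer_cfg)

env = make_env()      # hypothetical environment wrapper
obs = env.reset()
for _ in range(200):
    batch_obs = {
        "observation.images.image": obs["image"][None],              # (1, H, W, C) uint8
        "observation.images.wrist_image": obs["wrist_image"][None],  # (1, H, W, C) uint8
        "observation.state": obs["state"][None].astype(np.float32),  # (1, state_dim)
        "task": ["pick up the cube"],
    }
    action = policy.infer(batch_obs)  # (1, action_dim) action for the current step
    obs = env.step(action[0])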
370
+ @torch.no_grad
371
+ @jaxtyped(typechecker=typechecker)
372
+ def select_action(
373
+ self,
374
+ batch: at.InferBatchObs,
375
+ outputs: at.InferOutput = {},
376
+ stamp: Float[Tensor, " batch"] | None = None,
377
+ s2_candidates_num: int = 5,
378
+ noise_temp_bounds: tuple = (1.0, 1.0),
379
+ time_temp_bounds: tuple = (1.0, 1.0),
380
+ ) -> at.InferOutput:
381
+ """Select a single action given environment observations.
382
+
383
+ Candidate noise-action chunks are sampled from the System-2 model (unless a cached
384
+ `noise_action` is provided in `outputs`), ranked by the value query head, and the selected
385
+ chunk is sliced at `stamp` and refined by the System-1 model into `s1_action`.
386
+ """
387
+ self.eval()
388
+
389
+ if self.config.adapt_to_pi_aloha:
390
+ batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
391
+
392
+ batch = self.normalize_inputs(batch)
393
+
394
+ # Prepare the model inputs for querying the policy.
395
+ images, img_masks = self.prepare_images(batch)
396
+ state = self.prepare_state(batch)
397
+ lang_tokens, lang_masks = self.prepare_language(batch)
398
+
399
+ original_action_dim = self.config.action_feature.shape[0]
400
+
401
+ if "noise_action" not in outputs:
402
+ noise_actions = [] # [(Batch, Chunksize, Action dim),]
403
+ for i in range(s2_candidates_num):
404
+ noise_actions.append(
405
+ self.s2_model.sample_actions(
406
+ images,
407
+ img_masks,
408
+ lang_tokens,
409
+ lang_masks,
410
+ state[:, -1, :], # s2 does not support history states yet
411
+ time_temp=(i / s2_candidates_num)
412
+ * (time_temp_bounds[1] - time_temp_bounds[0])
413
+ + time_temp_bounds[0],
414
+ noise_temp=(i / s2_candidates_num)
415
+ * (noise_temp_bounds[1] - noise_temp_bounds[0])
416
+ + noise_temp_bounds[0],
417
+ )
418
+ )
419
+ noise_actions = torch.stack(noise_actions, dim=1)
420
+ # (Batch, s2_candidates_num, Chunksize, Actiondim)
421
+ batch_size = noise_actions.shape[0]
422
+ batch_idx = torch.arange(batch_size, device=noise_actions.device)
423
+
424
+ noise_actions_wo_pad = noise_actions[
425
+ :, :, : self.config.vqh_chunk_size, :original_action_dim
426
+ ]
427
+ action_index, q_values = self.value_query_head.select_q_actions(
428
+ images, img_masks, lang_tokens, lang_masks, noise_actions_wo_pad
429
+ )
430
+ self.q_value_cache.append(q_values.squeeze())
431
+ unnormalized_noise_actions = self.unnormalize_outputs(
432
+ {"action": noise_actions_wo_pad}
433
+ )["action"]
434
+ self.action_cache.append(unnormalized_noise_actions.squeeze())
435
+ selected_noise_action = noise_actions[batch_idx, action_index]
436
+
437
+ outputs = {"noise_action": selected_noise_action}
438
+
439
+ noise_action: Float[Tensor, "batch s2_chunksize action_dim"] = outputs[
440
+ "noise_action"
441
+ ]
442
+ idcs = (stamp * self.config.s2_chunk_size).long().unsqueeze(1) + torch.arange(
443
+ self.config.s1_chunk_size, device=noise_action.device
444
+ )
445
+ batch_idcs = torch.arange(
446
+ noise_action.shape[0], device=noise_action.device
447
+ ).unsqueeze(1)
448
+ noise_action_slides = noise_action[batch_idcs, idcs]
449
+ s1_actions = self.s1_model.sample_actions(
450
+ images, img_masks, state, noise_action_slides, stamp=stamp
451
+ )
452
+
453
+ # Unpad actions
454
+ actions = s1_actions[:, :, :original_action_dim]
455
+ actions = self.unnormalize_outputs({"action": actions})["action"]
456
+
457
+ if self.config.adapt_to_pi_aloha:
458
+ actions = self._pi_aloha_encode_actions_inv(actions)
459
+
460
+ outputs["s1_action"] = actions
461
+
462
+ return outputs
463
+
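The candidate sweep in `select_action` above spaces the noise and time temperatures linearly between the configured bounds. A minimal sketch with illustrative (non-default) numbers:

n, lo, hi = 5, 1.0, 1.4  # s2_candidates_num and example temperature bounds
temps = [(i / n) * (hi - lo) + lo for i in range(n)]
# approximately [1.0, 1.08, 1.16, 1.24, 1.32]; the upper bound itself is never reached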
464
+ def post_normalize(self, batch):
465
+ """additional keys {obervation.x}.s1 are merged in to the batch,
466
+ so these keys also need to be normalized.
467
+ """
468
+ merge_keys = filter(lambda k: k.endswith(".s1"), batch.keys())
469
+ for k in merge_keys:
470
+ _k = k.replace(".s1", "")
471
+ batch[k] = self.normalize_inputs({_k: batch[k]})[_k]
472
+ return batch
473
+
474
+ def get_noise_action_slides(self, action: Tensor, stamp: Tensor) -> Tensor:
475
+ """Augment the action with the previous actions in the queue."""
476
+ # idcs = (torch.rand_like(stamp) * (self.config.s2_chunk_size - self.config.s1_chunk_size)).long()
477
+ idcs = (
478
+ (
479
+ self.config.noise_slides_alp * torch.rand_like(stamp)
480
+ - self.config.noise_slides_alp / 2
481
+ + stamp
482
+ )
483
+ * self.config.s2_chunk_size
484
+ ).long()
485
+ idcs = torch.clamp(idcs, 0, action.shape[1] - self.config.s1_chunk_size)
486
+ idcs = idcs + torch.arange(self.config.s1_chunk_size, device=action.device)
487
+ batch_idcs = torch.arange(action.shape[0], device=action.device).unsqueeze(1)
488
+ noise_action_slides = action[batch_idcs, idcs]
489
+
490
+ noise_action_slides += (
491
+ torch.randn_like(noise_action_slides) * self.config.noise_slides_eps
492
+ )
493
+ return noise_action_slides
494
+
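A hedged toy sketch of the window indexing used in `get_noise_action_slides`: a per-sample start index around `stamp * s2_chunk_size` is jittered, clamped, and expanded into a contiguous window of `s1_chunk_size` steps. The sizes below are illustrative.

import torch

s2_chunk, s1_chunk = 16, 4
action = torch.randn(2, s2_chunk, 3)                # (B, s2_chunk_size, action_dim)
start = torch.tensor([[2], [13]])                   # (B, 1) jittered start indices
start = torch.clamp(start, 0, s2_chunk - s1_chunk)  # -> [[2], [12]]
idcs = start + torch.arange(s1_chunk)               # (B, s1_chunk) window indices
batch_idcs = torch.arange(2).unsqueeze(1)           # (B, 1)
window = action[batch_idcs, idcs]                   # (B, s1_chunk, action_dim)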
495
+ def forward(
496
+ self, batch: dict[str, Tensor], noise=None, time=None
497
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, dict[str, Tensor]]:
498
+ """Do a full training forward pass to compute the loss"""
499
+ if self.config.adapt_to_pi_aloha:
500
+ batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
501
+ batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
502
+
503
+ batch = self.normalize_inputs(batch)
504
+ batch = self.post_normalize(batch)
505
+ batch = self.normalize_targets(batch)
506
+
507
+ # prepare images
508
+ images, img_masks = self.prepare_images(batch)
509
+ state = self.prepare_state(batch)
510
+ lang_tokens, lang_masks = self.prepare_language(batch)
511
+
512
+ s1_images, s1_img_masks = self.prepare_images(
513
+ batch, map(lambda x: f"{x}.s1", self.config.image_features)
514
+ ) # 0
515
+ s1_state = self.prepare_state(batch, f"{OBS_ROBOT}.s1")
516
+
517
+ # prepare actions
518
+ actions = self.prepare_action(batch)
519
+ actions_is_pad = batch.get("action_is_pad")
520
+
521
+ b, s, _ = actions.shape
522
+ device = actions.device
523
+ batch_idcs = torch.arange(b, device=device).unsqueeze(1)
524
+ stamp = batch["stamp"]
525
+ idcs = (stamp * self.config.s2_chunk_size).long() + torch.arange(
526
+ self.config.s1_chunk_size, device=device
527
+ )
528
+ s1_actions = actions[batch_idcs, idcs]
529
+ s1_actions_is_pad = (
530
+ None if actions_is_pad is None else actions_is_pad[batch_idcs, idcs]
531
+ )
532
+
533
+ # s2 forward pass
534
+ with torch.no_grad():
535
+ if self.config.cache_s2_actions:
536
+ is_noised = []
537
+ noise_actions = torch.zeros_like(actions)
538
+ for idx, s2_idx in enumerate(batch["s2_idx"]):
539
+ if s2_idx in self.s2_action_cache:
540
+ noise_actions[idx] = self.s2_action_cache[s2_idx]
541
+ is_noised.append(False)
542
+ else:
543
+ is_noised.append(True)
544
+ # noise batch
545
+ is_noised = torch.tensor(is_noised, device=batch["s2_idx"].device)
546
+ s2_actions_infered = self.s2_model.sample_actions(
547
+ [img[is_noised] for img in images],
548
+ [mask[is_noised] for mask in img_masks],
549
+ lang_tokens[is_noised],
550
+ lang_masks[is_noised],
551
+ state[is_noised],
552
+ )
553
+ noise_actions[is_noised] = s2_actions_infered
554
+ else:
555
+ noise_actions = self.s2_model.sample_actions(
556
+ images,
557
+ img_masks,
558
+ lang_tokens,
559
+ lang_masks,
560
+ state,
561
+ )
562
+
563
+ # vgps: embs[q] -> layers -> [q] -> mlp
564
+ # value query head feature keys end with .vqh: xx.vqh
565
+ vqh_images, vqh_img_masks = self.prepare_images(
566
+ batch, map(lambda x: f"{x}.vqh", self.config.image_features)
567
+ ) # 1
568
+
569
+ temperature_loss, policy_loss, critic_loss, log_dict = (
570
+ self.value_query_head.forward(
571
+ images,
572
+ img_masks,
573
+ lang_tokens,
574
+ lang_masks,
575
+ vqh_images,
576
+ vqh_img_masks,
577
+ batch["action"][:, : self.config.vqh_chunk_size, :],
578
+ batch["reward.vqh"],
579
+ batch["mc.vqh"],
580
+ batch["reward.vqh"].to(dtype=torch.float),
581
+ )
582
+ )
583
+
584
+ noise_action_slides = self.get_noise_action_slides(noise_actions, stamp)
585
+ s1_losses = self.s1_model.forward(
586
+ s1_images,
587
+ s1_img_masks,
588
+ s1_state,
589
+ s1_actions,
590
+ noise_action_slides,
591
+ time,
592
+ stamp=stamp.squeeze(),
593
+ )
594
+
595
+ total_loss, loss_dict = 0.0, {}
596
+
597
+ if s1_actions_is_pad is not None:
598
+ in_episode_bound = ~s1_actions_is_pad
599
+ s1_losses = s1_losses * in_episode_bound.unsqueeze(-1)
600
+
601
+ s1_losses = s1_losses[..., : self.config.max_action_dim]
602
+ s1_losses = s1_losses.mean()
603
+
604
+ loss_dict["s1_loss"] = s1_losses.item()
605
+ total_loss += s1_losses
606
+
607
+ # add ValueQueryHead log dict to loss_dict
608
+ # loss_dict = {**loss_dict, **log_dict}
609
+ loss_dict["entropy"] = log_dict["entropy"].item()
610
+ loss_dict["actions_mse"] = log_dict["actions_mse"].item()
611
+ loss_dict["td_err"] = log_dict["td_err"].item()
612
+ loss_dict["temperature"] = log_dict["temperature"].item()
613
+ loss_dict["cql_loss"] = log_dict["cql_loss"].item()
614
+ loss_dict["cql_alpha"] = log_dict["cql_alpha"]
615
+ loss_dict["cql_diff"] = log_dict["cql_diff"].item()
616
+ loss_dict["critic_loss"] = log_dict["critic_loss"].item()
617
+ loss_dict["cql_ood_values"] = log_dict["cql_ood_values"].item()
618
+ loss_dict["calql_bound_rate"] = log_dict["calql_bound_rate"].item()
619
+ loss_dict["online_q"] = log_dict["online_q"].item()
620
+ loss_dict["target_q"] = log_dict["target_q"].item()
621
+ loss_dict["positive_qs"] = log_dict["positive_qs"].item()
622
+ loss_dict["actor_loss"] = log_dict["actor_loss"].item()
623
+
624
+ return total_loss, temperature_loss, policy_loss, critic_loss, loss_dict
625
+
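A hedged sketch of how the loss terms returned by `forward` might be stepped with separate optimizers, mirroring `get_critics_optim_params` / `get_temperature_optim_params` above. The actual training script is not part of this file, so the learning rates, the actor parameter group, and the `policy` / `batch` objects are assumptions.

import torch

critic_optim = torch.optim.AdamW(policy.get_critics_optim_params(), lr=3e-4)
temp_optim = torch.optim.AdamW(policy.get_temperature_optim_params(), lr=3e-4)
# In practice the actor group would exclude the critic / temperature parameters.
actor_optim = torch.optim.AdamW(
    [p for p in policy.parameters() if p.requires_grad], lr=2.5e-5
)

total_loss, temperature_loss, policy_loss, critic_loss, loss_dict = policy.forward(batch)

# Accumulate all gradients before stepping so the shared graph is reused safely.
for opt in (actor_optim, critic_optim, temp_optim):
    opt.zero_grad()
total_loss.backward(retain_graph=True)    # s1 flow-matching loss
policy_loss.backward(retain_graph=True)   # Cal-QL actor loss
critic_loss.backward(retain_graph=True)
temperature_loss.backward()
for opt in (actor_optim, critic_optim, temp_optim):
    opt.step()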
626
+ def prepare_images(self, batch, image_features=None):
627
+ """Apply preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
628
+ converting the pixel range from [0.0, 1.0] to [-1.0, 1.0] as expected by SigLIP.
629
+ """
630
+ images = []
631
+ img_masks = []
632
+
633
+ image_features = image_features or self.config.image_features
634
+ present_img_keys = [key for key in image_features if key in batch]
635
+ missing_img_keys = [key for key in image_features if key not in batch]
636
+
637
+ if len(present_img_keys) == 0:
638
+ raise ValueError(
639
+ f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
640
+ )
641
+
642
+ # Preprocess image features present in the batch
643
+ for key in present_img_keys:
644
+ img = batch[key]
645
+
646
+ if self.config.resize_imgs_with_padding is not None:
647
+ img = resize_with_pad(
648
+ img, *self.config.resize_imgs_with_padding, pad_value=0
649
+ )
650
+
651
+ # Normalize from range [0,1] to [-1,1] as expected by SigLIP
652
+ img = img * 2.0 - 1.0
653
+
654
+ bsize = img.shape[0]
655
+ device = img.device
656
+ mask = torch.ones(bsize, dtype=torch.bool, device=device)
657
+ images.append(img)
658
+ img_masks.append(mask)
659
+
660
+ # Create image features not present in the batch
661
+ # as black images filled with -1 (matching the normalized range).
662
+ for num_empty_cameras in range(len(missing_img_keys)):
663
+ if num_empty_cameras >= self.config.empty_cameras:
664
+ break
665
+ img = torch.ones_like(img) * -1
666
+ mask = torch.zeros_like(mask)
667
+ images.append(img)
668
+ img_masks.append(mask)
669
+
670
+ return images, img_masks
671
+
672
+ def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
673
+ """Tokenize the text input"""
674
+ device = batch[OBS_ROBOT].device
675
+ tasks = batch["task"]
676
+
677
+ # PaliGemma prompt has to end with a new line
678
+ tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
679
+
680
+ tokenized_prompt = self.language_tokenizer.__call__(
681
+ tasks,
682
+ padding="max_length",
683
+ padding_side="right",
684
+ max_length=self.config.tokenizer_max_length,
685
+ return_tensors="pt",
686
+ truncation=True,
687
+ )
688
+ lang_tokens = tokenized_prompt["input_ids"].to(device=device)
689
+ lang_masks = tokenized_prompt["attention_mask"].to(
690
+ device=device, dtype=torch.bool
691
+ )
692
+
693
+ return lang_tokens, lang_masks
694
+
695
+ def _pi_aloha_decode_state(self, state):
696
+ # Flip the joints.
697
+ for motor_idx in [1, 2, 8, 9]:
698
+ state[:, motor_idx] *= -1
699
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
700
+ for motor_idx in [6, 13]:
701
+ state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
702
+ return state
703
+
704
+ def _pi_aloha_encode_actions(self, actions):
705
+ # Flip the joints.
706
+ for motor_idx in [1, 2, 8, 9]:
707
+ actions[:, :, motor_idx] *= -1
708
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
709
+ for motor_idx in [6, 13]:
710
+ actions[:, :, motor_idx] = aloha_gripper_from_angular(
711
+ actions[:, :, motor_idx]
712
+ )
713
+ return actions
714
+
715
+ def _pi_aloha_encode_actions_inv(self, actions):
716
+ # Flip the joints again.
717
+ for motor_idx in [1, 2, 8, 9]:
718
+ actions[:, :, motor_idx] *= -1
719
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
720
+ for motor_idx in [6, 13]:
721
+ actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
722
+ actions[:, :, motor_idx]
723
+ )
724
+ return actions
725
+
726
+ def prepare_state(self, batch, feature=None):
727
+ """Pad state"""
728
+ feature = feature or OBS_ROBOT
729
+ state = pad_vector(batch[feature], self.config.max_state_dim)
730
+ return state
731
+
732
+ def prepare_action(self, batch):
733
+ """Pad action"""
734
+ actions = pad_vector(batch[ACTION], self.config.max_action_dim)
735
+ return actions
736
+
737
+ def _save_pretrained(self, save_directory) -> None:
738
+ super()._save_pretrained(save_directory)
739
+ print(f"Saving the language tokenizer to {save_directory} ...")
740
+ self.language_tokenizer.save_pretrained(save_directory)
741
+
742
+ import shutil
743
+
744
+ files = [
745
+ "src/hume/models/array_typing.py",
746
+ "src/hume/models/configuration_hume.py",
747
+ "src/hume/models/fast_visuo_expert.py",
748
+ "src/hume/models/modeling_hume.py",
749
+ "src/hume/models/paligemma_with_expert.py",
750
+ "src/hume/models/value_query.py",
751
+ ]
752
+ try:
753
+ for file in files:
754
+ shutil.copy(file, save_directory)
755
+ except Exception:
756
+ print("Failed to copy files to save_directory")
757
+
758
+ @classmethod
759
+ def from_pretrained(
760
+ cls,
761
+ pretrained_name_or_path,
762
+ **kwargs,
763
+ ):
764
+ policy = super().from_pretrained(pretrained_name_or_path, **kwargs)
765
+ print(f"Loading the language tokenizer from {pretrained_name_or_path} ...")
766
+ policy.language_tokenizer = AutoTokenizer.from_pretrained(
767
+ pretrained_name_or_path
768
+ )
769
+ return policy
770
+
771
+
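A hedged sketch of restoring a checkpoint with the `from_pretrained` override above, which re-attaches the language tokenizer from the same path. The path is a placeholder, and `HumePolicy` is assumed to be the name of the policy class defined in this file.

policy = HumePolicy.from_pretrained("path/to/checkpoint")  # local dir or Hub repo id
policy = policy.to("cuda").eval()
assert policy.language_tokenizer is not None  # tokenizer restored alongside the weights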
772
+ class System2Policy(PreTrainedPolicy):
773
+ """Wrapper class around System2FlowMatching model to train and run inference within LeRobot."""
774
+
775
+ config_class = System2Config
776
+ name = "system2"
777
+
778
+ def __init__(
779
+ self,
780
+ config: System2Config,
781
+ dataset_stats: dict[str, dict[str, Tensor]] | None = None,
782
+ ):
783
+ """
784
+ Args:
785
+ config: Policy configuration class instance or None, in which case the default instantiation of
786
+ the configuration class is used.
787
+ dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
788
+ that they will be passed with a call to `load_state_dict` before the policy is used.
789
+ """
790
+
791
+ super().__init__(config)
792
+ config.validate_features()
793
+ self.config = config
794
+
795
+ # TODO: input / output features / normalizer for multiple datasets
796
+ self.normalize_inputs = Normalize(
797
+ config.input_features, config.normalization_mapping, dataset_stats
798
+ )
799
+ self.normalize_targets = Normalize(
800
+ config.output_features, config.normalization_mapping, dataset_stats
801
+ )
802
+ self.unnormalize_outputs = Unnormalize(
803
+ config.output_features, config.normalization_mapping, dataset_stats
804
+ )
805
+
806
+ self.language_tokenizer = None
807
+ self.model = System2(config)
808
+
809
+ self.reset()
810
+
811
+ def reset(self):
812
+ """This should be called whenever the environment is reset."""
813
+ self._action_queue = deque([], maxlen=self.config.n_action_steps)
814
+
815
+ def get_optim_params(self) -> dict:
816
+ return self.parameters()
817
+
818
+ @torch.no_grad
819
+ def select_action(
820
+ self, batch: dict[str, Tensor], noise: Tensor | None = None
821
+ ) -> Tensor:
822
+ """Select a single action given environment observations.
823
+
824
+ Observations are normalized and encoded, the flow-matching model samples a full action
825
+ chunk, and the actions are unpadded and unnormalized before being returned; the whole
826
+ chunk is returned at once rather than queued one step at a time.
827
+ """
828
+ self.eval()
829
+
830
+ if self.config.adapt_to_pi_aloha:
831
+ batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
832
+
833
+ batch = self.normalize_inputs(batch)
834
+
835
+ # Prepare the model inputs and sample a full action chunk by
836
+ # querying the policy.
837
+ images, img_masks = self.prepare_images(batch)
838
+ state = self.prepare_state(batch)
839
+ lang_tokens, lang_masks = self.prepare_language(batch)
840
+
841
+ actions = self.model.sample_actions(
842
+ images, img_masks, lang_tokens, lang_masks, state, noise=noise
843
+ )
844
+
845
+ # Unpad actions
846
+ original_action_dim = self.config.action_feature.shape[0]
847
+ actions = actions[:, :, :original_action_dim]
848
+
849
+ actions = self.unnormalize_outputs({"action": actions})["action"]
850
+
851
+ if self.config.adapt_to_pi_aloha:
852
+ actions = self._pi_aloha_encode_actions(actions)
853
+ return actions
854
+
855
+ def forward(
856
+ self, batch: dict[str, Tensor], noise=None, time=None
857
+ ) -> tuple[Tensor, dict[str, Tensor]]:
858
+ """Do a full training forward pass to compute the loss"""
859
+ if self.config.adapt_to_pi_aloha:
860
+ batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
861
+ batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
862
+
863
+ batch = self.normalize_inputs(batch)
864
+ batch = self.normalize_targets(batch)
865
+
866
+ images, img_masks = self.prepare_images(batch)
867
+ state = self.prepare_state(batch)
868
+ lang_tokens, lang_masks = self.prepare_language(batch)
869
+ actions = self.prepare_action(batch)
870
+ actions_is_pad = batch.get("action_is_pad")
871
+
872
+ loss_dict = {}
873
+ losses, _ = self.model.forward(
874
+ images, img_masks, lang_tokens, lang_masks, state, actions, noise, time
875
+ )
876
+ # loss_dict["losses_after_forward"] = losses.detach().mean().item()
877
+
878
+ if actions_is_pad is not None:
879
+ in_episode_bound = ~actions_is_pad
880
+ losses = losses * in_episode_bound.unsqueeze(-1)
881
+ # loss_dict["losses_after_in_ep_bound"] = losses.detach().mean().item()
882
+
883
+ # Remove padding
884
+ losses = losses[:, :, : self.config.max_action_dim]
885
+ # loss_dict["losses_after_rm_padding"] = losses.detach().mean().item()
886
+
887
+ # For backward pass
888
+ loss = losses.mean()
889
+ # For logging
890
+ loss_dict["l2_loss"] = loss.item()
891
+
892
+ return loss, loss_dict
893
+
894
+ def prepare_images(self, batch):
895
+ """Apply preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
896
+ converting the pixel range from [0.0, 1.0] to [-1.0, 1.0] as expected by SigLIP.
897
+ """
898
+ images = []
899
+ img_masks = []
900
+
901
+ present_img_keys = [key for key in self.config.image_features if key in batch]
902
+ missing_img_keys = [
903
+ key for key in self.config.image_features if key not in batch
904
+ ]
905
+
906
+ if len(present_img_keys) == 0:
907
+ raise ValueError(
908
+ f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
909
+ )
910
+
911
+ # Preprocess image features present in the batch
912
+ for key in present_img_keys:
913
+ img = batch[key]
914
+
915
+ if self.config.resize_imgs_with_padding is not None:
916
+ img = resize_with_pad(
917
+ img, *self.config.resize_imgs_with_padding, pad_value=0
918
+ )
919
+
920
+ # Normalize from range [0,1] to [-1,1] as expected by SigLIP
921
+ img = img * 2.0 - 1.0
922
+
923
+ bsize = img.shape[0]
924
+ device = img.device
925
+ mask = torch.ones(bsize, dtype=torch.bool, device=device)
926
+ images.append(img)
927
+ img_masks.append(mask)
928
+
929
+ # Create image features not present in the batch
930
+ # as black images filled with -1 (matching the normalized range).
931
+ for num_empty_cameras in range(len(missing_img_keys)):
932
+ if num_empty_cameras >= self.config.empty_cameras:
933
+ break
934
+ img = torch.ones_like(img) * -1
935
+ mask = torch.zeros_like(mask)
936
+ images.append(img)
937
+ img_masks.append(mask)
938
+
939
+ return images, img_masks
940
+
941
+ def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
942
+ """Tokenize the text input"""
943
+ device = batch[OBS_ROBOT].device
944
+ tasks = batch["task"]
945
+
946
+ # PaliGemma prompt has to end with a new line
947
+ tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
948
+
949
+ tokenized_prompt = self.language_tokenizer.__call__(
950
+ tasks,
951
+ padding="max_length",
952
+ padding_side="right",
953
+ max_length=self.config.tokenizer_max_length,
954
+ return_tensors="pt",
955
+ truncation=True,
956
+ )
957
+ lang_tokens = tokenized_prompt["input_ids"].to(device=device)
958
+ lang_masks = tokenized_prompt["attention_mask"].to(
959
+ device=device, dtype=torch.bool
960
+ )
961
+
962
+ return lang_tokens, lang_masks
963
+
964
+ def _pi_aloha_decode_state(self, state):
965
+ # Flip the joints.
966
+ for motor_idx in [1, 2, 8, 9]:
967
+ state[:, motor_idx] *= -1
968
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
969
+ for motor_idx in [6, 13]:
970
+ state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
971
+ return state
972
+
973
+ def _pi_aloha_encode_actions(self, actions):
974
+ # Flip the joints.
975
+ for motor_idx in [1, 2, 8, 9]:
976
+ actions[:, :, motor_idx] *= -1
977
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
978
+ for motor_idx in [6, 13]:
979
+ actions[:, :, motor_idx] = aloha_gripper_from_angular(
980
+ actions[:, :, motor_idx]
981
+ )
982
+ return actions
983
+
984
+ def _pi_aloha_encode_actions_inv(self, actions):
985
+ # Flip the joints again.
986
+ for motor_idx in [1, 2, 8, 9]:
987
+ actions[:, :, motor_idx] *= -1
988
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
989
+ for motor_idx in [6, 13]:
990
+ actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
991
+ actions[:, :, motor_idx]
992
+ )
993
+ return actions
994
+
995
+ def prepare_state(self, batch):
996
+ """Pad state"""
997
+ state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
998
+ return state
999
+
1000
+ def prepare_action(self, batch):
1001
+ """Pad action"""
1002
+ actions = pad_vector(batch[ACTION], self.config.max_action_dim)
1003
+ return actions
1004
+
1005
+ def _save_pretrained(self, save_directory) -> None:
1006
+ super()._save_pretrained(save_directory)
1007
+ print(f"Saving the language tokenizer to {save_directory} ...")
1008
+ self.language_tokenizer.save_pretrained(save_directory)
1009
+
1010
+ import shutil
1011
+
1012
+ files = [
1013
+ "src/hume/models/array_typing.py",
1014
+ "src/hume/models/configuration_hume.py",
1015
+ "src/hume/models/fast_visuo_expert.py",
1016
+ "src/hume/models/modeling_hume.py",
1017
+ "src/hume/models/paligemma_with_expert.py",
1018
+ "src/hume/models/value_query.py",
1019
+ ]
1020
+ try:
1021
+ for file in files:
1022
+ shutil.copy(file, save_directory)
1023
+ except Exception:
1024
+ print("Failed to copy files to save_directory")
1025
+
1026
+ @classmethod
1027
+ def from_pretrained(
1028
+ cls,
1029
+ pretrained_name_or_path,
1030
+ **kwargs,
1031
+ ):
1032
+ policy = super().from_pretrained(pretrained_name_or_path, **kwargs)
1033
+ print(f"Loading the language tokenizer from {pretrained_name_or_path} ...")
1034
+ policy.language_tokenizer = AutoTokenizer.from_pretrained(
1035
+ pretrained_name_or_path
1036
+ )
1037
+ return policy
1038
+
1039
+
1040
+ class System2(nn.Module):
1041
+ def __init__(self, config):
1042
+ super().__init__()
1043
+ self.config = config
1044
+
1045
+ paligemma_with_expert_config = PaliGemmaWithExpertConfig(
1046
+ freeze_vision_encoder=self.config.freeze_vision_encoder,
1047
+ train_expert_only=self.config.train_expert_only,
1048
+ attention_implementation=self.config.attention_implementation,
1049
+ paligemma_config=self.config.paligemma_config,
1050
+ gemma_expert_config=self.config.gemma_expert_config,
1051
+ )
1052
+ self.paligemma_with_expert = PaliGemmaWithExpertModel(
1053
+ paligemma_with_expert_config
1054
+ )
1055
+
1056
+ # Projections are float32
1057
+ self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
1058
+ self.action_in_proj = nn.Linear(
1059
+ self.config.max_action_dim, self.config.proj_width
1060
+ )
1061
+ self.action_out_proj = nn.Linear(
1062
+ self.config.proj_width, self.config.max_action_dim
1063
+ )
1064
+
1065
+ self.action_time_mlp_in = nn.Linear(
1066
+ self.config.proj_width * 2, self.config.proj_width
1067
+ )
1068
+ self.action_time_mlp_out = nn.Linear(
1069
+ self.config.proj_width, self.config.proj_width
1070
+ )
1071
+
1072
+ self.set_requires_grad()
1073
+
1074
+ def set_requires_grad(self):
1075
+ for params in self.state_proj.parameters():
1076
+ params.requires_grad = self.config.train_state_proj
1077
+
1078
+ def sample_noise(self, shape, device):
1079
+ noise = torch.normal(
1080
+ mean=0.0,
1081
+ std=1.0,
1082
+ size=shape,
1083
+ dtype=torch.float32,
1084
+ device=device,
1085
+ )
1086
+ return noise
1087
+
1088
+ def sample_time(self, bsize, device):
1089
+ time_beta = sample_beta(1.5, 1.0, bsize, device)
1090
+ time = time_beta * 0.999 + 0.001
1091
+ return time.to(dtype=torch.float32, device=device)
1092
+
1093
+ def embed_prefix(
1094
+ self, images, img_masks, lang_tokens, lang_masks
1095
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1096
+ """Embed images with SigLIP and language tokens with embedding layer to prepare
1097
+ for PaliGemma transformer processing.
1098
+ """
1099
+ # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
1100
+ embs = []
1101
+ pad_masks = []
1102
+ att_masks = []
1103
+
1104
+ # TODO: remove for loop
1105
+ for (
1106
+ img,
1107
+ img_mask,
1108
+ ) in zip(images, img_masks, strict=False):
1109
+ img_emb = self.paligemma_with_expert.embed_image(img)
1110
+ img_emb = img_emb.to(dtype=torch.bfloat16)
1111
+
1112
+ # Normalize image embeddings
1113
+ img_emb_dim = img_emb.shape[-1]
1114
+ img_emb = img_emb * torch.tensor(
1115
+ img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
1116
+ )
1117
+
1118
+ bsize, num_img_embs = img_emb.shape[:2]
1119
+ img_mask = img_mask[:, None].expand(bsize, num_img_embs)
1120
+
1121
+ embs.append(img_emb)
1122
+ pad_masks.append(img_mask)
1123
+
1124
+ # Create attention masks so that image tokens attend to each other
1125
+ att_masks += [0] * num_img_embs
1126
+
1127
+ lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
1128
+
1129
+ # Normalize language embeddings
1130
+ lang_emb_dim = lang_emb.shape[-1]
1131
+ lang_emb = lang_emb * math.sqrt(lang_emb_dim)
1132
+
1133
+ embs.append(lang_emb)
1134
+ pad_masks.append(lang_masks)
1135
+
1136
+ # full attention between image and language inputs
1137
+ num_lang_embs = lang_emb.shape[1]
1138
+ att_masks += [0] * num_lang_embs
1139
+
1140
+ embs = torch.cat(embs, dim=1)
1141
+ pad_masks = torch.cat(pad_masks, dim=1)
1142
+ att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
1143
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
1144
+
1145
+ return embs, pad_masks, att_masks
1146
+
1147
+ def embed_suffix(self, state, noisy_actions, timestep):
1148
+ """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
1149
+ embs = []
1150
+ pad_masks = []
1151
+ att_masks = []
1152
+
1153
+ # Embed state
1154
+ state_emb = self.state_proj(state)
1155
+ state_emb = state_emb.to(dtype=torch.bfloat16)
1156
+ embs.append(state_emb[:, None, :])
1157
+ bsize = state_emb.shape[0]
1158
+ dtype = state_emb.dtype
1159
+ device = state_emb.device
1160
+
1161
+ state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
1162
+ pad_masks.append(state_mask)
1163
+
1164
+ # Set attention masks so that image and language inputs do not attend to state or actions
1165
+ att_masks += [1]
1166
+
1167
+ # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
1168
+ time_emb = create_sinusoidal_pos_embedding(
1169
+ timestep,
1170
+ self.config.proj_width,
1171
+ min_period=4e-3,
1172
+ max_period=4.0,
1173
+ device=device,
1174
+ )
1175
+ time_emb = time_emb.type(dtype=dtype)
1176
+
1177
+ # Fuse timestep + action information using an MLP
1178
+ action_emb = self.action_in_proj(noisy_actions)
1179
+
1180
+ time_emb = time_emb[:, None, :].expand_as(action_emb)
1181
+ action_time_emb = torch.cat([action_emb, time_emb], dim=2)
1182
+
1183
+ action_time_emb = self.action_time_mlp_in(action_time_emb)
1184
+ action_time_emb = F.silu(action_time_emb) # swish == silu
1185
+ action_time_emb = self.action_time_mlp_out(action_time_emb)
1186
+
1187
+ # Add to input tokens
1188
+ embs.append(action_time_emb)
1189
+
1190
+ bsize, action_time_dim = action_time_emb.shape[:2]
1191
+ action_time_mask = torch.ones(
1192
+ bsize, action_time_dim, dtype=torch.bool, device=device
1193
+ )
1194
+ pad_masks.append(action_time_mask)
1195
+
1196
+ # Set attention masks so that image, language and state inputs do not attend to action tokens
1197
+ att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
1198
+
1199
+ embs = torch.cat(embs, dim=1)
1200
+ pad_masks = torch.cat(pad_masks, dim=1)
1201
+ att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
1202
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
1203
+
1204
+ return embs, pad_masks, att_masks
1205
+
1206
+ def forward(
1207
+ self,
1208
+ images,
1209
+ img_masks,
1210
+ lang_tokens,
1211
+ lang_masks,
1212
+ state,
1213
+ actions,
1214
+ noise=None,
1215
+ time=None,
1216
+ ) -> Tensor:
1217
+ """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
1218
+ if noise is None:
1219
+ noise = self.sample_noise(actions.shape, actions.device)
1220
+
1221
+ if time is None:
1222
+ time = self.sample_time(actions.shape[0], actions.device)
1223
+ time_expanded = time[:, None, None]
1224
+ x_t = time_expanded * noise + (1 - time_expanded) * actions
1225
+ u_t = noise - actions
1226
+
1227
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
1228
+ images, img_masks, lang_tokens, lang_masks
1229
+ )
1230
+ suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
1231
+ state, x_t, time
1232
+ )
1233
+
1234
+ pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
1235
+ att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
1236
+
1237
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
1238
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
1239
+
1240
+ (_, suffix_out), past_key_values = self.paligemma_with_expert.forward(
1241
+ attention_mask=att_2d_masks,
1242
+ position_ids=position_ids,
1243
+ past_key_values=None,
1244
+ inputs_embeds=[prefix_embs, suffix_embs],
1245
+ use_cache=True,
1246
+ fill_kv_cache=True,
1247
+ )
1248
+ suffix_out = suffix_out[:, -self.config.n_action_steps :]
1249
+ # Original openpi code, upcast attention output
1250
+ suffix_out = suffix_out.to(dtype=torch.float32)
1251
+ v_t = self.action_out_proj(suffix_out)
1252
+
1253
+ losses = F.mse_loss(u_t, v_t, reduction="none")
1254
+
1255
+ return losses, past_key_values
1256
+
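A hedged toy sketch of the flow-matching target built in `forward` above: actions are blended with Gaussian noise at a sampled time t, and the network regresses the constant velocity (noise - actions) along that straight-line path.

import torch

actions = torch.randn(2, 50, 32)       # (batch, n_action_steps, max_action_dim)
noise = torch.randn_like(actions)
time = torch.rand(2)                   # stand-in for the beta-distributed time sample
t = time[:, None, None]
x_t = t * noise + (1 - t) * actions    # interpolant fed to the expert
u_t = noise - actions                  # regression target for the predicted v_t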
1257
+ def sample_actions(
1258
+ self,
1259
+ images,
1260
+ img_masks,
1261
+ lang_tokens,
1262
+ lang_masks,
1263
+ state,
1264
+ noise=None,
1265
+ past_key_values=None,
1266
+ time_temp=1.0,
1267
+ noise_temp=1.0,
1268
+ ) -> Tensor:
1269
+ """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
1270
+ bsize = state.shape[0]
1271
+ device = state.device
1272
+
1273
+ if noise is None:
1274
+ actions_shape = (
1275
+ bsize,
1276
+ self.config.n_action_steps,
1277
+ self.config.max_action_dim,
1278
+ )
1279
+ noise = self.sample_noise(actions_shape, device)
1280
+
1281
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
1282
+ images, img_masks, lang_tokens, lang_masks
1283
+ )
1284
+ prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
1285
+ prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
1286
+
1287
+ # Compute image and language key value cache
1288
+ if past_key_values is None:
1289
+ _, past_key_values = self.paligemma_with_expert.forward(
1290
+ attention_mask=prefix_att_2d_masks,
1291
+ position_ids=prefix_position_ids,
1292
+ past_key_values=None,
1293
+ inputs_embeds=[prefix_embs, None],
1294
+ use_cache=self.config.use_cache,
1295
+ fill_kv_cache=True,
1296
+ )
1297
+
1298
+ dt = -1.0 / self.config.num_steps
1299
+ dt = torch.tensor(dt, dtype=torch.float32, device=device)
1300
+
1301
+ x_t = noise
1302
+ time = torch.tensor(
1303
+ time_temp, dtype=torch.float32, device=device
1304
+ ) # TODO: Add temp
1305
+ while time >= -dt / 2 + (1 - self.config.theta2):
1306
+ expanded_time = time.expand(bsize)
1307
+ v_t = self.denoise_step(
1308
+ state,
1309
+ prefix_pad_masks,
1310
+ past_key_values,
1311
+ x_t,
1312
+ expanded_time,
1313
+ )
1314
+
1315
+ # Euler step
1316
+ x_t += dt * v_t * noise_temp # TODO: Add noise temp
1317
+ time += dt
1318
+ return x_t
1319
+
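A hedged toy sketch of the Euler integration pattern used in `sample_actions`: integration starts from noise at the initial time and steps with dt = -1 / num_steps until the stopping time (1 - theta2). The velocity field below is a placeholder for `denoise_step`.

import torch

num_steps, theta2 = 10, 1.0            # theta2 = 1.0 integrates all the way to t = 0
dt = torch.tensor(-1.0 / num_steps)
x_t = torch.randn(2, 50, 32)
time = torch.tensor(1.0)
while time >= -dt / 2 + (1 - theta2):
    v_t = -x_t                          # placeholder velocity field
    x_t = x_t + dt * v_t
    time = time + dt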
1320
+ def denoise_step(
1321
+ self,
1322
+ state,
1323
+ prefix_pad_masks,
1324
+ past_key_values,
1325
+ x_t,
1326
+ timestep,
1327
+ ):
1328
+ """Apply one denoising step of the noise `x_t` at a given timestep."""
1329
+ suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
1330
+ state, x_t, timestep
1331
+ )
1332
+
1333
+ suffix_len = suffix_pad_masks.shape[1]
1334
+ batch_size = prefix_pad_masks.shape[0]
1335
+ prefix_len = prefix_pad_masks.shape[1]
1336
+ prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
1337
+ batch_size, suffix_len, prefix_len
1338
+ )
1339
+
1340
+ suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
1341
+
1342
+ full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
1343
+
1344
+ prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
1345
+ position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
1346
+
1347
+ outputs_embeds, _ = self.paligemma_with_expert.forward(
1348
+ attention_mask=full_att_2d_masks,
1349
+ position_ids=position_ids,
1350
+ past_key_values=past_key_values,
1351
+ inputs_embeds=[None, suffix_embs],
1352
+ use_cache=self.config.use_cache,
1353
+ fill_kv_cache=False,
1354
+ )
1355
+ suffix_out = outputs_embeds[1]
1356
+ suffix_out = suffix_out[:, -self.config.n_action_steps :]
1357
+ suffix_out = suffix_out.to(dtype=torch.float32)
1358
+ v_t = self.action_out_proj(suffix_out)
1359
+ return v_t
1360
+
1361
+
1362
+ class FastVisuoMatching(nn.Module):
1363
+ def __init__(self, config):
1364
+ super().__init__()
1365
+ self.config = config
1366
+
1367
+ # FastVisuoExpertConfig, FastVisuoExpertModel
1368
+ fast_visuo_expert_config = FastVisuoExpertConfig(
1369
+ freeze_vision_encoder=self.config.freeze_s1_vision_encoder,
1370
+ attention_implementation=self.config.attention_implementation,
1371
+ dino_config=self.config.s1_dino_config,
1372
+ gemma_expert_config=self.config.s1_gemma_expert_config,
1373
+ )
1374
+ self.fast_visuo_expert = FastVisuoExpertModel(fast_visuo_expert_config)
1375
+
1376
+ # Projections are float32
1377
+ self.state_proj = nn.Linear(
1378
+ self.config.max_state_dim, self.config.s1_proj_width
1379
+ )
1380
+ self.action_in_proj = nn.Linear(
1381
+ self.config.max_action_dim, self.config.s1_proj_width
1382
+ )
1383
+ self.action_out_proj = nn.Linear(
1384
+ self.config.s1_proj_width, self.config.max_action_dim
1385
+ )
1386
+ self.action_time_mlp_in = nn.Linear(
1387
+ self.config.s1_proj_width * 2, self.config.s1_proj_width
1388
+ )
1389
+ self.action_time_mlp_out = nn.Linear(
1390
+ self.config.s1_proj_width, self.config.s1_proj_width
1391
+ )
1392
+
1393
+ self.set_requires_grad()
1394
+
1395
+ def set_requires_grad(self):
1396
+ for params in self.state_proj.parameters():
1397
+ params.requires_grad = self.config.train_state_proj
1398
+
1399
+ def sample_noise(self, shape, device):
1400
+ noise = torch.normal(
1401
+ mean=0.0,
1402
+ std=1.0,
1403
+ size=shape,
1404
+ dtype=torch.float32,
1405
+ device=device,
1406
+ )
1407
+ return noise
1408
+
1409
+ def sample_time(self, bsize, device):
1410
+ time_beta = sample_beta(1.5, 1.0, bsize, device)
1411
+ time = time_beta * 0.999 + 0.001
1412
+ return time.to(dtype=torch.float32, device=device)
1413
+
1414
+ def embed_prefix(
1415
+ self, images, img_masks
1416
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1417
+ """Embed images with SigLIP and language tokens with embedding layer to prepare
1418
+ for the fast visuo expert (Gemma) transformer processing.
1419
+ """
1420
+ # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
1421
+ embs = []
1422
+ pad_masks = []
1423
+ att_masks = []
1424
+
1425
+ # TODO: remove for loop
1426
+ for img, img_mask in zip(images, img_masks, strict=False):
1427
+ DINO_MEAN, DINO_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
1428
+ img = TF.normalize(img * 0.5 + 0.5, mean=DINO_MEAN, std=DINO_STD)
1429
+ img_emb = self.fast_visuo_expert.embed_image(img)
1430
+ img_emb = img_emb.to(dtype=torch.bfloat16)
1431
+
1432
+ # Normalize image embeddings
1433
+ img_emb_dim = img_emb.shape[-1]
1434
+ img_emb = img_emb * torch.tensor(
1435
+ img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
1436
+ )
1437
+
1438
+ bsize, num_img_embs = img_emb.shape[:2]
1439
+ img_mask = img_mask[:, None].expand(bsize, num_img_embs)
1440
+
1441
+ embs.append(img_emb)
1442
+ pad_masks.append(img_mask)
1443
+
1444
+ # Create attention masks so that image tokens attend to each other
1445
+ att_masks += [0] * num_img_embs
1446
+
1447
+ embs = torch.cat(embs, dim=1)
1448
+ pad_masks = torch.cat(pad_masks, dim=1)
1449
+ att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
1450
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
1451
+
1452
+ return embs, pad_masks, att_masks
1453
+
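A hedged sketch of the image re-normalization in `embed_prefix` above: inputs arrive in SigLIP's [-1, 1] range, are mapped back to [0, 1], and are then standardized with the ImageNet statistics expected by the DINO encoder (`TF` is assumed to be `torchvision.transforms.functional`).

import torch
import torchvision.transforms.functional as TF

DINO_MEAN, DINO_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
img = torch.rand(1, 3, 224, 224) * 2 - 1  # [-1, 1], as produced by prepare_images
img_01 = img * 0.5 + 0.5                  # back to [0, 1]
img_dino = TF.normalize(img_01, mean=DINO_MEAN, std=DINO_STD)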
1454
+ def embed_suffix(self, state, noisy_actions, timestep, stamp):
1455
+ """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
1456
+ embs = []
1457
+ pad_masks = []
1458
+ att_masks = []
1459
+
1460
+ # Embed state
1461
+ state_emb = self.state_proj(state)
1462
+ state_emb = state_emb.to(dtype=torch.bfloat16)
1463
+ embs.append(state_emb)
1464
+ bsize = state_emb.shape[0]
1465
+ state_horizon = state_emb.shape[1]
1466
+ dtype = state_emb.dtype
1467
+ device = state_emb.device
1468
+
1469
+ state_mask = torch.ones(bsize, state_horizon, dtype=torch.bool, device=device)
1470
+ pad_masks.append(state_mask)
1471
+
1472
+ # Set attention masks so that image and language inputs do not attend to state or actions
1473
+ att_masks += [1] * state_horizon
1474
+
1475
+ # Embed stamp
1476
+ stamp_emb = create_sinusoidal_pos_embedding(
1477
+ stamp,
1478
+ self.config.s1_proj_width,
1479
+ min_period=4e-3,
1480
+ max_period=4.0,
1481
+ device=device,
1482
+ )
1483
+ stamp_emb = stamp_emb.type(dtype=dtype)[:, None, :]
1484
+ embs.append(stamp_emb)
1485
+ stamp_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
1486
+ pad_masks.append(stamp_mask)
1487
+ att_masks += [1]
1488
+
1489
+ # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
1490
+ time_emb = create_sinusoidal_pos_embedding(
1491
+ timestep,
1492
+ self.config.s1_proj_width,
1493
+ min_period=4e-3,
1494
+ max_period=4.0,
1495
+ device=device,
1496
+ )
1497
+ time_emb = time_emb.type(dtype=dtype)
1498
+
1499
+ # Fuse timestep + action information using an MLP
1500
+ action_emb = self.action_in_proj(noisy_actions)
1501
+
1502
+ time_emb = time_emb[:, None, :].expand_as(action_emb)
1503
+ action_time_emb = torch.cat([action_emb, time_emb], dim=2)
1504
+
1505
+ action_time_emb = self.action_time_mlp_in(action_time_emb)
1506
+ action_time_emb = F.silu(action_time_emb) # swish == silu
1507
+ action_time_emb = self.action_time_mlp_out(action_time_emb)
1508
+
1509
+ # Add to input tokens
1510
+ embs.append(action_time_emb)
1511
+
1512
+ bsize, action_time_dim = action_time_emb.shape[:2]
1513
+ action_time_mask = torch.ones(
1514
+ bsize, action_time_dim, dtype=torch.bool, device=device
1515
+ )
1516
+ pad_masks.append(action_time_mask)
1517
+
1518
+ # Set attention masks so that image, language and state inputs do not attend to action tokens
1519
+ att_masks += [1] + ([0] * (self.config.s1_action_steps - 1))
1520
+
1521
+ embs = torch.cat(embs, dim=1)
1522
+ pad_masks = torch.cat(pad_masks, dim=1)
1523
+ att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
1524
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
1525
+
1526
+ return embs, pad_masks, att_masks
1527
+
1528
+ def forward(
1529
+ self, images, img_masks, state, actions, noise=None, time=None, stamp=None
1530
+ ) -> Float[
1531
+ Tensor, "batch {self.config.s1_action_steps} {self.config.max_action_dim}"
1532
+ ]:
1533
+ """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
1534
+ if noise is None:
1535
+ noise = self.sample_noise(actions.shape, actions.device)
1536
+ if time is None:
1537
+ time = (
1538
+ self.sample_time(actions.shape[0], actions.device) * self.config.theta1
1539
+ ) # s2: [1, 0.1] -> s1: [0.1, 0]
1540
+ time_expanded = time[:, None, None]
1541
+ x_t = time_expanded * noise + (1 - time_expanded) * actions
1542
+ u_t = noise - actions
1543
+
1544
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
1545
+ images, img_masks
1546
+ )
1547
+ suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
1548
+ state, x_t, time, stamp
1549
+ )
1550
+
1551
+ pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
1552
+ att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
1553
+
1554
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
1555
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
1556
+
1557
+ inputs_embeds = torch.cat(
1558
+ [prefix_embs, suffix_embs], dim=1
1559
+ ) # torch.Size([16, 565]), torch.Size([16, 565])
1560
+
1561
+ suffix_out = self.fast_visuo_expert.forward(
1562
+ attention_mask=att_2d_masks,
1563
+ position_ids=position_ids,
1564
+ inputs_embeds=inputs_embeds,
1565
+ )
1566
+ suffix_out = suffix_out[:, -self.config.s1_action_steps :]
1567
+ # Original openpi code, upcast attention output
1568
+ suffix_out = suffix_out.to(dtype=torch.float32)
1569
+ v_t = self.action_out_proj(suffix_out)
1570
+
1571
+ losses = F.mse_loss(u_t, v_t, reduction="none")
1572
+ return losses
1573
+
1574
+ def sample_actions(
1575
+ self, images, img_masks, state, noise=None, stamp=None
1576
+ ) -> Tensor:
1577
+ """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
1578
+ bsize = state.shape[0]
1579
+ device = state.device
1580
+
1581
+ if noise is None:
1582
+ actions_shape = (
1583
+ bsize,
1584
+ self.config.s1_action_steps,
1585
+ self.config.max_action_dim,
1586
+ )
1587
+ noise = self.sample_noise(actions_shape, device)
1588
+
1589
+ if stamp is None:
1590
+ stamp = torch.rand(bsize, device=device)
1591
+
1592
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
1593
+ images, img_masks
1594
+ )
1595
+
1596
+ dt = -self.config.theta1 / self.config.s1_num_steps
1597
+ dt = torch.tensor(dt, dtype=torch.float32, device=device)
1598
+
1599
+ x_t = noise
1600
+ time = torch.tensor(self.config.theta1, dtype=torch.float32, device=device)
1601
+ while time >= -dt / 2:
1602
+ expanded_time = time.expand(bsize)
1603
+ v_t = self.denoise_step(
1604
+ state,
1605
+ prefix_embs,
1606
+ prefix_pad_masks,
1607
+ prefix_att_masks,
1608
+ x_t,
1609
+ expanded_time,
1610
+ stamp,
1611
+ )
1612
+ # Euler step
1613
+ x_t += dt * v_t
1614
+ time += dt
1615
+ return x_t
1616
+
1617
+ def denoise_step(
1618
+ self,
1619
+ state,
1620
+ prefix_embs,
1621
+ prefix_pad_masks,
1622
+ prefix_att_masks,
1623
+ x_t,
1624
+ timestep,
1625
+ stamp,
1626
+ ):
1627
+ """Apply one denoising step of the noise `x_t` at a given timestep."""
1628
+ suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
1629
+ state, x_t, timestep, stamp
1630
+ )
1631
+ pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
1632
+ att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
1633
+
1634
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
1635
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
1636
+
1637
+ inputs_embeds = torch.cat(
1638
+ [prefix_embs, suffix_embs], dim=1
1639
+ ) # torch.Size([16, 565]), torch.Size([16, 565])
1640
+ suffix_out = self.fast_visuo_expert.forward(
1641
+ attention_mask=att_2d_masks,
1642
+ position_ids=position_ids,
1643
+ inputs_embeds=inputs_embeds,
1644
+ )
1645
+ suffix_out = suffix_out[:, -self.config.s1_action_steps :]
1646
+ suffix_out = suffix_out.to(dtype=torch.float32)
1647
+ v_t = self.action_out_proj(suffix_out)
1648
+ return v_t
1649
+
1650
+
1651
+ class ValueQueryHead(nn.Module):
1652
+ def __init__(self, paligemma_with_expert, config):
1653
+ super().__init__()
1654
+ # gemma_expert for processing image and language tokens
1655
+ # paligemma with expert for processing image features
1656
+ self.config = config
1657
+ self.paligemma_with_expert = paligemma_with_expert
1658
+
1659
+ vqh_backbone_config = VQHBackboneConfig()
1660
+ self.vqh_backbone = VQHBackbone(config=vqh_backbone_config)
1661
+
1662
+ cal_ql_config = CalQlConfig(
1663
+ obs_encoded_dim=self.paligemma_with_expert.config.paligemma_config.hidden_size,
1664
+ action_dim=config.vqh_chunk_size * config.action_feature.shape[0],
1665
+ actor_lr=config.actor_lr,
1666
+ critic_lr=config.critic_lr,
1667
+ temp_lr=config.temp_lr,
1668
+ )
1669
+ self.calql = CalQL(config=cal_ql_config)
1670
+
1671
+ self.query_embedding = nn.Parameter(
1672
+ torch.zeros(
1673
+ self.paligemma_with_expert.config.paligemma_config.hidden_size,
1674
+ dtype=torch.bfloat16,
1675
+ )
1676
+ )
1677
+
1678
+ def embed_prefix(
1679
+ self, images, img_masks, lang_tokens, lang_masks
1680
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1681
+ """Embed images with SigLIP and language tokens with embedding layer to prepare
1682
+ for PaliGemma transformer processing.
1683
+ """
1684
+ # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
1685
+ embs = []
1686
+ pad_masks = []
1687
+ att_masks = []
1688
+
1689
+ # TODO: remove for loop
1690
+ for (
1691
+ img,
1692
+ img_mask,
1693
+ ) in zip(images, img_masks, strict=False):
1694
+ img_emb = self.paligemma_with_expert.embed_image(img)
1695
+ img_emb = img_emb.to(dtype=torch.bfloat16)
1696
+
1697
+ # Normalize image embeddings
1698
+ img_emb_dim = img_emb.shape[-1]
1699
+ img_emb = img_emb * torch.tensor(
1700
+ img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
1701
+ )
1702
+
1703
+ bsize, num_img_embs = img_emb.shape[:2]
1704
+ img_mask = img_mask[:, None].expand(bsize, num_img_embs)
1705
+
1706
+ embs.append(img_emb)
1707
+ pad_masks.append(img_mask)
1708
+
1709
+ # Create attention masks so that image tokens attend to each other
1710
+ att_masks += [0] * num_img_embs
1711
+
1712
+ lang_emb = self.paligemma_with_expert.embed_language_tokens(
1713
+ lang_tokens
1714
+ ).detach()
1715
+
1716
+ # Normalize language embeddings
1717
+ lang_emb_dim = lang_emb.shape[-1]
1718
+ lang_emb = lang_emb * math.sqrt(lang_emb_dim)
1719
+
1720
+ embs.append(lang_emb)
1721
+ pad_masks.append(lang_masks)
1722
+
1723
+ # full attention between image and language inputs
1724
+ num_lang_embs = lang_emb.shape[1]
1725
+ att_masks += [0] * num_lang_embs
1726
+
1727
+ embs = torch.cat(embs, dim=1)
1728
+ pad_masks = torch.cat(pad_masks, dim=1)
1729
+ att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
1730
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
1731
+
1732
+ # NOTE: add query embedding for each sequence
1733
+ seq_lengths = pad_masks.sum(dim=1).long() # sequence lengths without padding
1734
+ seq_len = embs.shape[1]
1735
+
1736
+ new_seq_len = seq_len + 1
1737
+ new_embs = torch.zeros(
1738
+ (bsize, new_seq_len, embs.shape[-1]), dtype=embs.dtype, device=embs.device
1739
+ )
1740
+ new_pad_masks = torch.zeros(
1741
+ (bsize, new_seq_len), dtype=pad_masks.dtype, device=pad_masks.device
1742
+ )
1743
+ new_att_masks = torch.zeros(
1744
+ (bsize, new_seq_len), dtype=att_masks.dtype, device=att_masks.device
1745
+ )
1746
+
1747
+ batch_idx = torch.arange(bsize, device=embs.device).view(-1, 1)
1748
+ seq_idx = (
1749
+ torch.arange(seq_len, device=embs.device).view(1, -1).expand(bsize, -1)
1750
+ )
1751
+
1752
+ mask = seq_idx >= seq_lengths.unsqueeze(1)
1753
+ new_seq_idx = seq_idx + mask.long()
1754
+
1755
+ new_embs[batch_idx, new_seq_idx] = embs
1756
+ new_pad_masks[batch_idx, new_seq_idx] = pad_masks
1757
+ new_att_masks[batch_idx, new_seq_idx] = att_masks
1758
+ new_embs[torch.arange(bsize), seq_lengths] = self.query_embedding.unsqueeze(
1759
+ 0
1760
+ ).expand(bsize, -1)
1761
+ new_pad_masks[torch.arange(bsize), seq_lengths] = True
1762
+ new_att_masks[torch.arange(bsize), seq_lengths] = False
1763
+
1764
+ return new_embs, new_pad_masks, new_att_masks
1765
+
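A hedged toy sketch of the query-token insertion performed in `embed_prefix`: every sequence keeps its non-padded tokens in place, padded tokens are shifted right by one slot, and the learned query embedding is written at the first padding position so it always sits directly after the real tokens.

import torch

bsize, seq_len, dim = 2, 4, 3
embs = torch.randn(bsize, seq_len, dim)
seq_lengths = torch.tensor([4, 2])              # un-padded length of each sequence
query = torch.full((dim,), -1.0)                # stand-in for self.query_embedding

new_embs = torch.zeros(bsize, seq_len + 1, dim)
seq_idx = torch.arange(seq_len).expand(bsize, -1)
shift = (seq_idx >= seq_lengths.unsqueeze(1)).long()   # padded positions move right by 1
new_embs[torch.arange(bsize).unsqueeze(1), seq_idx + shift] = embs
new_embs[torch.arange(bsize), seq_lengths] = query
# seq 0 -> [t0, t1, t2, t3, query]; seq 1 -> [t0, t1, query, t2(pad), t3(pad)]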
1766
+ def process_next_obs(
1767
+ self,
1768
+ images: list[torch.Tensor],
1769
+ img_masks: list[torch.Tensor],
1770
+ vqh_images: list[torch.Tensor],
1771
+ vqh_img_masks: list[torch.Tensor],
1772
+ lang_tokens: torch.Tensor,
1773
+ lang_masks: torch.Tensor,
1774
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor], torch.Tensor, torch.Tensor]:
1775
+ """Process next observation for ValueQueryHead model.
1776
+ Args:
1777
+ images (list): List of image tensors.
1778
+ img_masks (list): List of image mask tensors.
1779
+ vqh_images (list): List of ValueQueryHead image tensors.
1780
+ vqh_img_masks (list): List of ValueQueryHead image mask tensors.
1781
+ lang_tokens (torch.Tensor): Language token tensor.
1782
+ lang_masks (torch.Tensor): Language mask tensor.
1783
+
1784
+ Returns:
1785
+ tuple: Tuple containing processed images, masks, and language tokens.
1786
+ """
1787
+ new_images = []
1788
+ new_img_masks = []
1789
+
1790
+ for img, next_img, img_mask, next_img_mask in zip(
1791
+ images, vqh_images, img_masks, vqh_img_masks
1792
+ ):
1793
+ new_images.append(torch.cat([img, next_img], dim=0))
1794
+ new_img_masks.append(torch.cat([img_mask, next_img_mask], dim=0))
1795
+
1796
+ new_lang_tokens = torch.cat([lang_tokens, lang_tokens], dim=0)
1797
+ new_lang_masks = torch.cat([lang_masks, lang_masks], dim=0)
1798
+
1799
+ return (
1800
+ new_images,
1801
+ new_img_masks,
1802
+ new_lang_tokens,
1803
+ new_lang_masks,
1804
+ )
1805
+
1806
+ @jaxtyped(typechecker=typechecker)
1807
+ def forward(
1808
+ self,
1809
+ images: list[Float[Tensor, "batch 3 224 224"]],
1810
+ img_masks: list[Bool[Tensor, " batch"]],
1811
+ lang_tokens: Int64[Tensor, "batch seq_len"],
1812
+ lang_masks: Bool[Tensor, "batch seq_len"],
1813
+ vqh_images: list[Float[Tensor, "batch 3 224 224"]],
1814
+ vqh_img_masks: list[Bool[Tensor, " batch"]],
1815
+ actions: Float[
1816
+ Tensor,
1817
+ "batch {self.config.vqh_chunk_size} {self.config.action_feature.shape[0]}",
1818
+ ],
1819
+ rewards: Float[Tensor, " batch"],
1820
+ mc_returns: Float[Tensor, " batch"],
1821
+ masks: Float[Tensor, " batch"],
1822
+ ) -> tuple[Tensor, Tensor, Tensor, dict]:
1823
+ """Forward pass for ValueQueryHead model.
1824
+ Args:
1825
+ images (torch.Tensor): Image input tensor.
1826
+ img_masks (torch.Tensor): Image mask tensor.
1827
+ lang_tokens (torch.Tensor): Language token tensor.
1828
+ lang_masks (torch.Tensor): Language mask tensor.
1829
+
1830
+ Returns:
1831
+ tuple: (temperature_loss, policy_loss, critic_loss, log_dict) from the Cal-QL update.
1832
+ """
1833
+ images, img_masks, lang_tokens, lang_masks = self.process_next_obs(
1834
+ images, img_masks, vqh_images, vqh_img_masks, lang_tokens, lang_masks
1835
+ )
1836
+
1837
+ embs, pad_masks, att_masks = self.embed_prefix(
1838
+ images, img_masks, lang_tokens, lang_masks
1839
+ )
1840
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
1841
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
1842
+
1843
+ suffix_out = self.vqh_backbone.forward(
1844
+ attention_mask=att_2d_masks,
1845
+ position_ids=position_ids,
1846
+ inputs_embeds=embs,
1847
+ ) # (2B, S, E)
1848
+
1849
+ batch_indices = torch.arange(suffix_out.shape[0], device=suffix_out.device)
1850
+ query_embedding_idx = pad_masks.sum(-1).long() - 1
1851
+ query_embedding = suffix_out[batch_indices, query_embedding_idx]
1852
+
1853
+ cal_ql_batch: at.CalQlBatch = dict(
1854
+ encoded_observations=query_embedding[
1855
+ : int(query_embedding.shape[0] / 2)
1856
+ ].to(dtype=torch.float32),
1857
+ encoded_next_observations=query_embedding[
1858
+ int(query_embedding.shape[0] / 2) :
1859
+ ].to(dtype=torch.float32),
1860
+ actions=actions.view(actions.shape[0], -1),
1861
+ rewards=rewards,
1862
+ mc_returns=mc_returns,
1863
+ masks=masks,
1864
+ )
1865
+ temperature_loss, policy_loss, critic_loss, log_dict = self.calql(cal_ql_batch)
1866
+
1867
+ return temperature_loss, policy_loss, critic_loss, log_dict
1868
+
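A hedged toy sketch of the doubled-batch trick used in `forward`: current and next observations are concatenated along the batch axis by `process_next_obs`, encoded in a single backbone pass, and the resulting query embeddings are split back into the (obs, next_obs) halves consumed by Cal-QL.

import torch

B, E = 4, 8
query_embedding = torch.randn(2 * B, E)   # backbone output for [obs; next_obs]
encoded_observations = query_embedding[:B]
encoded_next_observations = query_embedding[B:]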
1869
+ @jaxtyped(typechecker=typechecker)
1870
+ def select_q_actions(
1871
+ self,
1872
+ images: list[Float[Tensor, "Batch 3 224 224"]],
1873
+ img_masks: list[Bool[Tensor, " Batch"]],
1874
+ lang_tokens: Int64[Tensor, "Batch seq_len"],
1875
+ lang_masks: Bool[Tensor, "Batch seq_len"],
1876
+ noise_actions: Float[
1877
+ Tensor,
1878
+ "Batch s2_candidates_num {self.config.vqh_chunk_size} {self.config.action_feature.shape[0]}",
1879
+ ],
1880
+ ) -> tuple[Int64[Tensor, " Batch"], Float[Tensor, "Batch s2_candidates_num"]]:
1881
+ batch_size = noise_actions.shape[0]
1882
+ s2_candidates_num = noise_actions.shape[1]
1883
+ embs, pad_masks, att_masks = self.embed_prefix(
1884
+ images, img_masks, lang_tokens, lang_masks
1885
+ )
1886
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
1887
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
1888
+
1889
+ suffix_out = self.vqh_backbone.forward(
1890
+ attention_mask=att_2d_masks,
1891
+ position_ids=position_ids,
1892
+ inputs_embeds=embs,
1893
+ ) # (B, S, E)
1894
+
1895
+ batch_indices = torch.arange(suffix_out.shape[0], device=suffix_out.device)
1896
+ query_embedding_idx = pad_masks.sum(-1).long() - 1
1897
+ query_embedding = suffix_out[batch_indices, query_embedding_idx]
1898
+
1899
+ noise_actions = noise_actions.reshape(batch_size, s2_candidates_num, -1)
1900
+ q_values = self.calql.get_q_values(query_embedding, noise_actions)
1901
+
1902
+ action_index = torch.argmax(q_values, dim=1)
1903
+
1904
+ print(f"MaxValues: {q_values.max(dim=1)[0].tolist()}")
1905
+ print(f"MinValues: {q_values.min(dim=1)[0].tolist()}")
1906
+ print(f"MeanValues: {q_values.mean(dim=1)[0].tolist()}")
1907
+ print(f"ActionIndex: {action_index.tolist()}")
1908
+
1909
+ return action_index, q_values
paligemma_with_expert.py ADDED
@@ -0,0 +1,444 @@
1
+ from typing import List, Optional, Union
2
+
3
+ import torch
4
+ import torch.version
5
+ from transformers.cache_utils import Cache
6
+ from torch import nn
7
+ from transformers import (
8
+ AutoConfig,
9
+ GemmaForCausalLM,
10
+ PaliGemmaForConditionalGeneration,
11
+ PretrainedConfig,
12
+ PreTrainedModel,
13
+ )
14
+ from transformers.models.auto import CONFIG_MAPPING
15
+
16
+
17
+ def apply_rope(x, positions, max_wavelength=10_000):
18
+ """
19
+ Applies RoPE positions [B, L] to x [B, L, H, D].
20
+ """
21
+ d_half = x.shape[-1] // 2
22
+ device = x.device
23
+ dtype = x.dtype
24
+ x = x.to(torch.float32)
25
+
26
+ freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
27
+ d_half, dtype=torch.float32, device=device
28
+ )
29
+ timescale = max_wavelength**freq_exponents
30
+ radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
31
+ torch.float32
32
+ )
33
+
34
+ radians = radians[..., None, :]
35
+
36
+ sin = torch.sin(radians) # .to(dtype=dtype)
37
+ cos = torch.cos(radians) # .to(dtype=dtype)
38
+
39
+ x1, x2 = x.split(d_half, dim=-1)
40
+ res = torch.empty_like(x)
41
+ res[..., :d_half] = x1 * cos - x2 * sin
42
+ res[..., d_half:] = x2 * cos + x1 * sin
43
+
44
+ return res.to(dtype)
45
+
46
+
47
+ class PaliGemmaWithExpertConfig(PretrainedConfig):
48
+ model_type = "PaliGemmaWithExpertModel"
49
+ sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
50
+
51
+ def __init__(
52
+ self,
53
+ paligemma_config: dict | None = None,
54
+ gemma_expert_config: dict | None = None,
55
+ freeze_vision_encoder: bool = True,
56
+ train_expert_only: bool = True,
57
+ attention_implementation: str = "eager",
58
+ **kwargs,
59
+ ):
60
+ self.freeze_vision_encoder = freeze_vision_encoder
61
+ self.train_expert_only = train_expert_only
62
+ self.attention_implementation = attention_implementation
63
+
64
+ if paligemma_config is None:
65
+ self.paligemma_config = CONFIG_MAPPING["paligemma"](
66
+ transformers_version="4.48.1",
67
+ _vocab_size=257152,
68
+ bos_token_id=2,
69
+ eos_token_id=1,
70
+ hidden_size=2048,
71
+ image_token_index=257152,
72
+ model_type="paligemma",
73
+ pad_token_id=0,
74
+ projection_dim=2048,
75
+ text_config={
76
+ "hidden_activation": "gelu_pytorch_tanh",
77
+ "hidden_size": 2048,
78
+ "intermediate_size": 16384,
79
+ "model_type": "gemma",
80
+ "num_attention_heads": 8,
81
+ "num_hidden_layers": 18,
82
+ "num_image_tokens": 256,
83
+ "num_key_value_heads": 1,
84
+ "torch_dtype": "float32",
85
+ "vocab_size": 257152,
86
+ },
87
+ vision_config={
88
+ "hidden_size": 1152,
89
+ "intermediate_size": 4304,
90
+ "model_type": "siglip_vision_model",
91
+ "num_attention_heads": 16,
92
+ "num_hidden_layers": 27,
93
+ "num_image_tokens": 256,
94
+ "patch_size": 14,
95
+ "projection_dim": 2048,
96
+ "projector_hidden_act": "gelu_fast",
97
+ "torch_dtype": "float32",
98
+ "vision_use_head": False,
99
+ },
100
+ )
101
+ elif isinstance(paligemma_config, dict):
102
+ if "model_type" not in paligemma_config:
103
+ paligemma_config["model_type"] = "paligemma"
104
+
105
+ cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
106
+ self.paligemma_config = cfg_cls(**paligemma_config)
107
+
108
+ if gemma_expert_config is None:
109
+ self.gemma_expert_config = CONFIG_MAPPING["gemma"](
110
+ attention_bias=False,
111
+ attention_dropout=0.0,
112
+ bos_token_id=2,
113
+ eos_token_id=1,
114
+ head_dim=256,
115
+ hidden_act="gelu_pytorch_tanh",
116
+ hidden_activation="gelu_pytorch_tanh",
117
+ hidden_size=1024,
118
+ initializer_range=0.02,
119
+ intermediate_size=4096,
120
+ max_position_embeddings=8192,
121
+ model_type="gemma",
122
+ num_attention_heads=8,
123
+ num_hidden_layers=18,
124
+ num_key_value_heads=1,
125
+ pad_token_id=0,
126
+ rms_norm_eps=1e-06,
127
+ rope_theta=10000.0,
128
+ torch_dtype="float32",
129
+ transformers_version="4.48.1",
130
+ use_cache=True,
131
+ vocab_size=257152,
132
+ )
133
+ elif isinstance(gemma_expert_config, dict):
134
+ if "model_type" not in gemma_expert_config:
135
+ gemma_expert_config["model_type"] = "gemma"
136
+
137
+ cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
138
+ self.gemma_expert_config = cfg_cls(**gemma_expert_config)
139
+
140
+ super().__init__(**kwargs)
141
+
142
+ def __post_init__(self):
143
+ super().__post_init__()
144
+ if self.train_expert_only and not self.freeze_vision_encoder:
145
+ raise ValueError(
146
+ "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
147
+ )
148
+
149
+ if self.attention_implementation not in ["eager", "fa2", "flex"]:
150
+ raise ValueError(
151
+ f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
152
+ )
153
+
154
+
155
+ class PaliGemmaWithExpertModel(PreTrainedModel):
156
+ config_class = PaliGemmaWithExpertConfig
157
+
158
+ def __init__(self, config: PaliGemmaWithExpertConfig):
159
+ super().__init__(config=config)
160
+ self.config = config
161
+ self.paligemma = PaliGemmaForConditionalGeneration(
162
+ config=config.paligemma_config
163
+ )
164
+ self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
165
+ # Remove unused embed_tokens
166
+ self.gemma_expert.model.embed_tokens = None
167
+ self.gemma_expert.lm_head = None
168
+ self.to_bfloat16_like_physical_intelligence()
169
+ self.set_requires_grad()
170
+
171
+ def set_requires_grad(self):
172
+ if self.config.freeze_vision_encoder:
173
+ self.paligemma.vision_tower.eval()
174
+ for params in self.paligemma.vision_tower.parameters():
175
+ params.requires_grad = False
176
+
177
+ if self.config.train_expert_only:
178
+ self.paligemma.eval()
179
+ for params in self.paligemma.parameters():
180
+ params.requires_grad = False
181
+
182
+ def train(self, mode: bool = True):
183
+ super().train(mode)
184
+
185
+ if self.config.freeze_vision_encoder:
186
+ self.paligemma.vision_tower.eval()
187
+
188
+ if self.config.train_expert_only:
189
+ self.paligemma.eval()
190
+
191
+ def to_bfloat16_like_physical_intelligence(self):
192
+ self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
193
+
194
+ params_to_change_dtype = [
195
+ "language_model.model.layers",
196
+ "gemma_expert.model.layers",
197
+ "vision_tower",
198
+ "multi_modal",
199
+ ]
200
+ for name, param in self.named_parameters():
201
+ if any(selector in name for selector in params_to_change_dtype):
202
+ param.data = param.data.to(dtype=torch.bfloat16)
203
+
204
+ def embed_image(self, image: torch.Tensor):
205
+ return self.paligemma.get_image_features(image)
206
+
207
+ def embed_language_tokens(self, tokens: torch.Tensor):
208
+ return self.paligemma.language_model.model.embed_tokens(tokens)
209
+
210
+ # TODO: break down this huge forward into modules or functions
211
+ def forward(
212
+ self,
213
+ attention_mask: Optional[torch.Tensor] = None,
214
+ position_ids: Optional[torch.LongTensor] = None,
215
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
216
+ inputs_embeds: Optional[List[torch.FloatTensor]] = None,
217
+ use_cache: Optional[bool] = None,
218
+ fill_kv_cache: Optional[bool] = None,
219
+ ):
220
+ models = [self.paligemma.language_model.model, self.gemma_expert.model]
221
+
222
+ for hidden_states in inputs_embeds:
223
+ # TODO this is very inefficient
224
+ # dtype is always the same, batch size too (if > 1 len)
225
+ # device could be trickier in multi gpu edge cases but that's it
226
+ if hidden_states is None:
227
+ continue
228
+ batch_size = hidden_states.shape[0]
229
+
230
+ # RMSNorm
231
+ num_layers = self.paligemma.config.text_config.num_hidden_layers
232
+ head_dim = self.paligemma.config.text_config.head_dim
233
+ for layer_idx in range(num_layers):
234
+ query_states = []
235
+ key_states = []
236
+ value_states = []
237
+ for i, hidden_states in enumerate(inputs_embeds):
238
+ if hidden_states is None:
239
+ continue
240
+ layer = models[i].layers[layer_idx]
241
+ # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
242
+ # hidden_states = hidden_states * normalizer
243
+ hidden_states = layer.input_layernorm(hidden_states)
244
+
245
+ input_shape = hidden_states.shape[
246
+ :-1
247
+ ] # (b s e) -> layer* -> (b s e) -> mlp
248
+ hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) # (b s h d)
249
+
250
+ hidden_states = hidden_states.to(dtype=torch.bfloat16)
251
+
252
+ query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
253
+ key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
254
+ value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
255
+
256
+ query_states.append(query_state)
257
+ key_states.append(key_state)
258
+ value_states.append(value_state)
259
+
260
+ # B,L,H,D with L sequence length, H number of heads, D head dim
261
+ # concatenate on the number of embeddings/tokens
262
+ query_states = torch.cat(query_states, dim=1)
263
+ key_states = torch.cat(key_states, dim=1)
264
+ value_states = torch.cat(value_states, dim=1)
265
+
266
+ query_states = apply_rope(query_states, position_ids)
267
+ key_states = apply_rope(key_states, position_ids)
268
+
269
+ if use_cache and past_key_values is None:
270
+ past_key_values = {}
271
+
272
+ if use_cache:
273
+ if fill_kv_cache:
274
+ past_key_values[layer_idx] = {
275
+ "key_states": key_states,
276
+ "value_states": value_states,
277
+ }
278
+ else:
279
+ # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
280
+ # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
281
+ # the max len, then we (for instance) double the cache size. This implementation already exists
282
+ # in `transformers`. (molbap)
283
+ key_states = torch.cat(
284
+ [past_key_values[layer_idx]["key_states"], key_states], dim=1
285
+ )
286
+ value_states = torch.cat(
287
+ [past_key_values[layer_idx]["value_states"], value_states],
288
+ dim=1,
289
+ )
290
+
291
+ attention_interface = self.get_attention_interface()
292
+ att_output = attention_interface(
293
+ attention_mask,
294
+ batch_size,
295
+ head_dim,
296
+ query_states,
297
+ key_states,
298
+ value_states,
299
+ )
300
+ att_output = att_output.to(dtype=torch.bfloat16)
301
+
302
+ # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
303
+ outputs_embeds = []
304
+ start = 0
305
+ for i, hidden_states in enumerate(inputs_embeds):
306
+ layer = models[i].layers[layer_idx]
307
+
308
+ if hidden_states is not None:
309
+ end = start + hidden_states.shape[1]
310
+
311
+ if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
312
+ att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
313
+ out_emb = layer.self_attn.o_proj(att_output[:, start:end])
314
+
315
+ # TODO: first dropout (by default 0.0)
316
+
317
+ # first residual
318
+ out_emb += hidden_states
319
+ after_first_residual = out_emb.clone()
320
+
321
+ out_emb = layer.post_attention_layernorm(out_emb)
322
+ out_emb = layer.mlp(out_emb)
323
+
324
+ # TODO: second dropout (by default 0.0)
325
+
326
+ # second residual
327
+ out_emb += after_first_residual
328
+
329
+ outputs_embeds.append(out_emb)
330
+
331
+ start = end
332
+ else:
333
+ outputs_embeds.append(None)
334
+
335
+ inputs_embeds = outputs_embeds
336
+
337
+ # final norm
338
+ outputs_embeds = []
339
+ for i, hidden_states in enumerate(inputs_embeds):
340
+ if hidden_states is not None:
341
+ out_emb = models[i].norm(hidden_states)
342
+ outputs_embeds.append(out_emb)
343
+ else:
344
+ outputs_embeds.append(None)
345
+
346
+ return outputs_embeds, past_key_values
347
+
348
+ def get_attention_interface(self):
349
+ if self.config.attention_implementation == "fa2":
350
+ attention_interface = self.flash_attention_forward
351
+ else:
352
+ attention_interface = self.eager_attention_forward
353
+ return attention_interface
354
+
355
+ def flash_attention_forward(
356
+ self,
357
+ attention_mask,
358
+ batch_size,
359
+ head_dim,
360
+ query_states,
361
+ key_states,
362
+ value_states,
363
+ ):
364
+ raise NotImplementedError("FA2 is not implemented (yet)")
365
+
366
+ def eager_attention_forward(
367
+ self,
368
+ attention_mask,
369
+ batch_size,
370
+ head_dim,
371
+ query_states,
372
+ key_states,
373
+ value_states,
374
+ ):
375
+ num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
376
+ num_key_value_heads = (
377
+ self.config.paligemma_config.text_config.num_key_value_heads
378
+ )
379
+ num_key_value_groups = num_att_heads // num_key_value_heads
380
+
381
+ # query_states: batch_size, sequence_length, num_att_head, head_dim
382
+ # key_states: batch_size, sequence_length, num_key_value_head, head_dim
383
+ # value_states: batch_size, sequence_length, num_key_value_head, head_dim
384
+ sequence_length = key_states.shape[1]
385
+
386
+ key_states = key_states[:, :, :, None, :].expand(
387
+ batch_size,
388
+ sequence_length,
389
+ num_key_value_heads,
390
+ num_key_value_groups,
391
+ head_dim,
392
+ )
393
+ key_states = key_states.reshape(
394
+ batch_size,
395
+ sequence_length,
396
+ num_key_value_heads * num_key_value_groups,
397
+ head_dim,
398
+ )
399
+
400
+ value_states = value_states[:, :, :, None, :].expand(
401
+ batch_size,
402
+ sequence_length,
403
+ num_key_value_heads,
404
+ num_key_value_groups,
405
+ head_dim,
406
+ )
407
+ value_states = value_states.reshape(
408
+ batch_size,
409
+ sequence_length,
410
+ num_key_value_heads * num_key_value_groups,
411
+ head_dim,
412
+ )
413
+
414
+ # Attention here is upcasted to float32 to match the original eager implementation.
415
+
416
+ query_states = query_states.to(dtype=torch.float32)
417
+ key_states = key_states.to(dtype=torch.float32)
418
+
419
+ query_states = query_states.transpose(1, 2)
420
+ key_states = key_states.transpose(1, 2)
421
+
422
+ att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
423
+ att_weights *= head_dim**-0.5
424
+ big_neg = -2.3819763e38 # See gemma/modules.py
425
+
426
+ masked_att_weights = torch.where(
427
+ attention_mask[:, None, :, :], att_weights, big_neg
428
+ )
429
+
430
+ probs = nn.functional.softmax(masked_att_weights, dim=-1)
431
+ probs = probs.to(dtype=value_states.dtype)
432
+
433
+ # probs: batch_size, num_att_heads, sequence_length, sequence_length
434
+ # value_states: batch_size, sequence_length, num_att_heads, head_dim
435
+
436
+ att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
437
+
438
+ att_output = att_output.permute(0, 2, 1, 3)
439
+ # we use -1 because sequence length can change
440
+ att_output = att_output.reshape(
441
+ batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
442
+ )
443
+
444
+ return att_output
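For reference, apply_rope above is self-contained and can be exercised on dummy tensors. A minimal sketch, assuming the module is importable as paligemma_with_expert (the import path is an assumption; the signature matches the function defined above):

import torch
from paligemma_with_expert import apply_rope  # import path assumed

B, L, H, D = 1, 6, 8, 256                  # batch, sequence length, heads, head dim
x = torch.randn(B, L, H, D)
positions = torch.arange(L).unsqueeze(0)   # RoPE positions, shape (B, L)
x_rot = apply_rope(x, positions)
print(x_rot.shape, x_rot.dtype)            # torch.Size([1, 6, 8, 256]) torch.float32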
special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<image>"
4
+ ],
5
+ "bos_token": {
6
+ "content": "<bos>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "eos_token": {
13
+ "content": "<eos>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false
18
+ },
19
+ "pad_token": {
20
+ "content": "<pad>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "unk_token": {
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:537cbe6b94581ee7b70f7f39453d5c52f2590069aa75d76ceee458fde442523c
3
+ size 34387383
tokenizer_config.json ADDED
@@ -0,0 +1,1772 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<pad>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<eos>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "<bos>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "3": {
30
+ "content": "<unk>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "4": {
38
+ "content": "<mask>",
39
+ "lstrip": false,
40
+ "normalized": true,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": false
44
+ },
45
+ "5": {
46
+ "content": "<2mass>",
47
+ "lstrip": false,
48
+ "normalized": true,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": false
52
+ },
53
+ "6": {
54
+ "content": "[@BOS@]",
55
+ "lstrip": false,
56
+ "normalized": true,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": false
60
+ },
61
+ "7": {
62
+ "content": "<unused0>",
63
+ "lstrip": false,
64
+ "normalized": true,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": false
68
+ },
69
+ "8": {
70
+ "content": "<unused1>",
71
+ "lstrip": false,
72
+ "normalized": true,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": false
76
+ },
77
+ "9": {
78
+ "content": "<unused2>",
79
+ "lstrip": false,
80
+ "normalized": true,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": false
84
+ },
85
+ "10": {
86
+ "content": "<unused3>",
87
+ "lstrip": false,
88
+ "normalized": true,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": false
92
+ },
93
+ "11": {
94
+ "content": "<unused4>",
95
+ "lstrip": false,
96
+ "normalized": true,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": false
100
+ },
101
+ "12": {
102
+ "content": "<unused5>",
103
+ "lstrip": false,
104
+ "normalized": true,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": false
108
+ },
109
+ "13": {
110
+ "content": "<unused6>",
111
+ "lstrip": false,
112
+ "normalized": true,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": false
116
+ },
117
+ "14": {
118
+ "content": "<unused7>",
119
+ "lstrip": false,
120
+ "normalized": true,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "15": {
126
+ "content": "<unused8>",
127
+ "lstrip": false,
128
+ "normalized": true,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "16": {
134
+ "content": "<unused9>",
135
+ "lstrip": false,
136
+ "normalized": true,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "17": {
142
+ "content": "<unused10>",
143
+ "lstrip": false,
144
+ "normalized": true,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "18": {
150
+ "content": "<unused11>",
151
+ "lstrip": false,
152
+ "normalized": true,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "19": {
158
+ "content": "<unused12>",
159
+ "lstrip": false,
160
+ "normalized": true,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "20": {
166
+ "content": "<unused13>",
167
+ "lstrip": false,
168
+ "normalized": true,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "21": {
174
+ "content": "<unused14>",
175
+ "lstrip": false,
176
+ "normalized": true,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "22": {
182
+ "content": "<unused15>",
183
+ "lstrip": false,
184
+ "normalized": true,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "23": {
190
+ "content": "<unused16>",
191
+ "lstrip": false,
192
+ "normalized": true,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "24": {
198
+ "content": "<unused17>",
199
+ "lstrip": false,
200
+ "normalized": true,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "25": {
206
+ "content": "<unused18>",
207
+ "lstrip": false,
208
+ "normalized": true,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ },
213
+ "26": {
214
+ "content": "<unused19>",
215
+ "lstrip": false,
216
+ "normalized": true,
217
+ "rstrip": false,
218
+ "single_word": false,
219
+ "special": false
220
+ },
221
+ "27": {
222
+ "content": "<unused20>",
223
+ "lstrip": false,
224
+ "normalized": true,
225
+ "rstrip": false,
226
+ "single_word": false,
227
+ "special": false
228
+ },
229
+ "28": {
230
+ "content": "<unused21>",
231
+ "lstrip": false,
232
+ "normalized": true,
233
+ "rstrip": false,
234
+ "single_word": false,
235
+ "special": false
236
+ },
237
+ "29": {
238
+ "content": "<unused22>",
239
+ "lstrip": false,
240
+ "normalized": true,
241
+ "rstrip": false,
242
+ "single_word": false,
243
+ "special": false
244
+ },
245
+ "30": {
246
+ "content": "<unused23>",
247
+ "lstrip": false,
248
+ "normalized": true,
249
+ "rstrip": false,
250
+ "single_word": false,
251
+ "special": false
252
+ },
253
+ "31": {
254
+ "content": "<unused24>",
255
+ "lstrip": false,
256
+ "normalized": true,
257
+ "rstrip": false,
258
+ "single_word": false,
259
+ "special": false
260
+ },
261
+ "32": {
262
+ "content": "<unused25>",
263
+ "lstrip": false,
264
+ "normalized": true,
265
+ "rstrip": false,
266
+ "single_word": false,
267
+ "special": false
268
+ },
269
+ "33": {
270
+ "content": "<unused26>",
271
+ "lstrip": false,
272
+ "normalized": true,
273
+ "rstrip": false,
274
+ "single_word": false,
275
+ "special": false
276
+ },
277
+ "34": {
278
+ "content": "<unused27>",
279
+ "lstrip": false,
280
+ "normalized": true,
281
+ "rstrip": false,
282
+ "single_word": false,
283
+ "special": false
284
+ },
285
+ "35": {
286
+ "content": "<unused28>",
287
+ "lstrip": false,
288
+ "normalized": true,
289
+ "rstrip": false,
290
+ "single_word": false,
291
+ "special": false
292
+ },
293
+ "36": {
294
+ "content": "<unused29>",
295
+ "lstrip": false,
296
+ "normalized": true,
297
+ "rstrip": false,
298
+ "single_word": false,
299
+ "special": false
300
+ },
301
+ "37": {
302
+ "content": "<unused30>",
303
+ "lstrip": false,
304
+ "normalized": true,
305
+ "rstrip": false,
306
+ "single_word": false,
307
+ "special": false
308
+ },
309
+ "38": {
310
+ "content": "<unused31>",
311
+ "lstrip": false,
312
+ "normalized": true,
313
+ "rstrip": false,
314
+ "single_word": false,
315
+ "special": false
316
+ },
317
+ "39": {
318
+ "content": "<unused32>",
319
+ "lstrip": false,
320
+ "normalized": true,
321
+ "rstrip": false,
322
+ "single_word": false,
323
+ "special": false
324
+ },
325
+ "40": {
326
+ "content": "<unused33>",
327
+ "lstrip": false,
328
+ "normalized": true,
329
+ "rstrip": false,
330
+ "single_word": false,
331
+ "special": false
332
+ },
333
+ "41": {
334
+ "content": "<unused34>",
335
+ "lstrip": false,
336
+ "normalized": true,
337
+ "rstrip": false,
338
+ "single_word": false,
339
+ "special": false
340
+ },
341
+ "42": {
342
+ "content": "<unused35>",
343
+ "lstrip": false,
344
+ "normalized": true,
345
+ "rstrip": false,
346
+ "single_word": false,
347
+ "special": false
348
+ },
349
+ "43": {
350
+ "content": "<unused36>",
351
+ "lstrip": false,
352
+ "normalized": true,
353
+ "rstrip": false,
354
+ "single_word": false,
355
+ "special": false
356
+ },
357
+ "44": {
358
+ "content": "<unused37>",
359
+ "lstrip": false,
360
+ "normalized": true,
361
+ "rstrip": false,
362
+ "single_word": false,
363
+ "special": false
364
+ },
365
+ "45": {
366
+ "content": "<unused38>",
367
+ "lstrip": false,
368
+ "normalized": true,
369
+ "rstrip": false,
370
+ "single_word": false,
371
+ "special": false
372
+ },
373
+ "46": {
374
+ "content": "<unused39>",
375
+ "lstrip": false,
376
+ "normalized": true,
377
+ "rstrip": false,
378
+ "single_word": false,
379
+ "special": false
380
+ },
381
+ "47": {
382
+ "content": "<unused40>",
383
+ "lstrip": false,
384
+ "normalized": true,
385
+ "rstrip": false,
386
+ "single_word": false,
387
+ "special": false
388
+ },
389
+ "48": {
390
+ "content": "<unused41>",
391
+ "lstrip": false,
392
+ "normalized": true,
393
+ "rstrip": false,
394
+ "single_word": false,
395
+ "special": false
396
+ },
397
+ "49": {
398
+ "content": "<unused42>",
399
+ "lstrip": false,
400
+ "normalized": true,
401
+ "rstrip": false,
402
+ "single_word": false,
403
+ "special": false
404
+ },
405
+ "50": {
406
+ "content": "<unused43>",
407
+ "lstrip": false,
408
+ "normalized": true,
409
+ "rstrip": false,
410
+ "single_word": false,
411
+ "special": false
412
+ },
413
+ "51": {
414
+ "content": "<unused44>",
415
+ "lstrip": false,
416
+ "normalized": true,
417
+ "rstrip": false,
418
+ "single_word": false,
419
+ "special": false
420
+ },
421
+ "52": {
422
+ "content": "<unused45>",
423
+ "lstrip": false,
424
+ "normalized": true,
425
+ "rstrip": false,
426
+ "single_word": false,
427
+ "special": false
428
+ },
429
+ "53": {
430
+ "content": "<unused46>",
431
+ "lstrip": false,
432
+ "normalized": true,
433
+ "rstrip": false,
434
+ "single_word": false,
435
+ "special": false
436
+ },
437
+ "54": {
438
+ "content": "<unused47>",
439
+ "lstrip": false,
440
+ "normalized": true,
441
+ "rstrip": false,
442
+ "single_word": false,
443
+ "special": false
444
+ },
445
+ "55": {
446
+ "content": "<unused48>",
447
+ "lstrip": false,
448
+ "normalized": true,
449
+ "rstrip": false,
450
+ "single_word": false,
451
+ "special": false
452
+ },
453
+ "56": {
454
+ "content": "<unused49>",
455
+ "lstrip": false,
456
+ "normalized": true,
457
+ "rstrip": false,
458
+ "single_word": false,
459
+ "special": false
460
+ },
461
+ "57": {
462
+ "content": "<unused50>",
463
+ "lstrip": false,
464
+ "normalized": true,
465
+ "rstrip": false,
466
+ "single_word": false,
467
+ "special": false
468
+ },
469
+ "58": {
470
+ "content": "<unused51>",
471
+ "lstrip": false,
472
+ "normalized": true,
473
+ "rstrip": false,
474
+ "single_word": false,
475
+ "special": false
476
+ },
477
+ "59": {
478
+ "content": "<unused52>",
479
+ "lstrip": false,
480
+ "normalized": true,
481
+ "rstrip": false,
482
+ "single_word": false,
483
+ "special": false
484
+ },
485
+ "60": {
486
+ "content": "<unused53>",
487
+ "lstrip": false,
488
+ "normalized": true,
489
+ "rstrip": false,
490
+ "single_word": false,
491
+ "special": false
492
+ },
493
+ "61": {
494
+ "content": "<unused54>",
495
+ "lstrip": false,
496
+ "normalized": true,
497
+ "rstrip": false,
498
+ "single_word": false,
499
+ "special": false
500
+ },
501
+ "62": {
502
+ "content": "<unused55>",
503
+ "lstrip": false,
504
+ "normalized": true,
505
+ "rstrip": false,
506
+ "single_word": false,
507
+ "special": false
508
+ },
509
+ "63": {
510
+ "content": "<unused56>",
511
+ "lstrip": false,
512
+ "normalized": true,
513
+ "rstrip": false,
514
+ "single_word": false,
515
+ "special": false
516
+ },
517
+ "64": {
518
+ "content": "<unused57>",
519
+ "lstrip": false,
520
+ "normalized": true,
521
+ "rstrip": false,
522
+ "single_word": false,
523
+ "special": false
524
+ },
525
+ "65": {
526
+ "content": "<unused58>",
527
+ "lstrip": false,
528
+ "normalized": true,
529
+ "rstrip": false,
530
+ "single_word": false,
531
+ "special": false
532
+ },
533
+ "66": {
534
+ "content": "<unused59>",
535
+ "lstrip": false,
536
+ "normalized": true,
537
+ "rstrip": false,
538
+ "single_word": false,
539
+ "special": false
540
+ },
541
+ "67": {
542
+ "content": "<unused60>",
543
+ "lstrip": false,
544
+ "normalized": true,
545
+ "rstrip": false,
546
+ "single_word": false,
547
+ "special": false
548
+ },
549
+ "68": {
550
+ "content": "<unused61>",
551
+ "lstrip": false,
552
+ "normalized": true,
553
+ "rstrip": false,
554
+ "single_word": false,
555
+ "special": false
556
+ },
557
+ "69": {
558
+ "content": "<unused62>",
559
+ "lstrip": false,
560
+ "normalized": true,
561
+ "rstrip": false,
562
+ "single_word": false,
563
+ "special": false
564
+ },
565
+ "70": {
566
+ "content": "<unused63>",
567
+ "lstrip": false,
568
+ "normalized": true,
569
+ "rstrip": false,
570
+ "single_word": false,
571
+ "special": false
572
+ },
573
+ "71": {
574
+ "content": "<unused64>",
575
+ "lstrip": false,
576
+ "normalized": true,
577
+ "rstrip": false,
578
+ "single_word": false,
579
+ "special": false
580
+ },
581
+ "72": {
582
+ "content": "<unused65>",
583
+ "lstrip": false,
584
+ "normalized": true,
585
+ "rstrip": false,
586
+ "single_word": false,
587
+ "special": false
588
+ },
589
+ "73": {
590
+ "content": "<unused66>",
591
+ "lstrip": false,
592
+ "normalized": true,
593
+ "rstrip": false,
594
+ "single_word": false,
595
+ "special": false
596
+ },
597
+ "74": {
598
+ "content": "<unused67>",
599
+ "lstrip": false,
600
+ "normalized": true,
601
+ "rstrip": false,
602
+ "single_word": false,
603
+ "special": false
604
+ },
605
+ "75": {
606
+ "content": "<unused68>",
607
+ "lstrip": false,
608
+ "normalized": true,
609
+ "rstrip": false,
610
+ "single_word": false,
611
+ "special": false
612
+ },
613
+ "76": {
614
+ "content": "<unused69>",
615
+ "lstrip": false,
616
+ "normalized": true,
617
+ "rstrip": false,
618
+ "single_word": false,
619
+ "special": false
620
+ },
621
+ "77": {
622
+ "content": "<unused70>",
623
+ "lstrip": false,
624
+ "normalized": true,
625
+ "rstrip": false,
626
+ "single_word": false,
627
+ "special": false
628
+ },
629
+ "78": {
630
+ "content": "<unused71>",
631
+ "lstrip": false,
632
+ "normalized": true,
633
+ "rstrip": false,
634
+ "single_word": false,
635
+ "special": false
636
+ },
637
+ "79": {
638
+ "content": "<unused72>",
639
+ "lstrip": false,
640
+ "normalized": true,
641
+ "rstrip": false,
642
+ "single_word": false,
643
+ "special": false
644
+ },
645
+ "80": {
646
+ "content": "<unused73>",
647
+ "lstrip": false,
648
+ "normalized": true,
649
+ "rstrip": false,
650
+ "single_word": false,
651
+ "special": false
652
+ },
653
+ "81": {
654
+ "content": "<unused74>",
655
+ "lstrip": false,
656
+ "normalized": true,
657
+ "rstrip": false,
658
+ "single_word": false,
659
+ "special": false
660
+ },
661
+ "82": {
662
+ "content": "<unused75>",
663
+ "lstrip": false,
664
+ "normalized": true,
665
+ "rstrip": false,
666
+ "single_word": false,
667
+ "special": false
668
+ },
669
+ "83": {
670
+ "content": "<unused76>",
671
+ "lstrip": false,
672
+ "normalized": true,
673
+ "rstrip": false,
674
+ "single_word": false,
675
+ "special": false
676
+ },
677
+ "84": {
678
+ "content": "<unused77>",
679
+ "lstrip": false,
680
+ "normalized": true,
681
+ "rstrip": false,
682
+ "single_word": false,
683
+ "special": false
684
+ },
685
+ "85": {
686
+ "content": "<unused78>",
687
+ "lstrip": false,
688
+ "normalized": true,
689
+ "rstrip": false,
690
+ "single_word": false,
691
+ "special": false
692
+ },
693
+ "86": {
694
+ "content": "<unused79>",
695
+ "lstrip": false,
696
+ "normalized": true,
697
+ "rstrip": false,
698
+ "single_word": false,
699
+ "special": false
700
+ },
701
+ "87": {
702
+ "content": "<unused80>",
703
+ "lstrip": false,
704
+ "normalized": true,
705
+ "rstrip": false,
706
+ "single_word": false,
707
+ "special": false
708
+ },
709
+ "88": {
710
+ "content": "<unused81>",
711
+ "lstrip": false,
712
+ "normalized": true,
713
+ "rstrip": false,
714
+ "single_word": false,
715
+ "special": false
716
+ },
717
+ "89": {
718
+ "content": "<unused82>",
719
+ "lstrip": false,
720
+ "normalized": true,
721
+ "rstrip": false,
722
+ "single_word": false,
723
+ "special": false
724
+ },
725
+ "90": {
726
+ "content": "<unused83>",
727
+ "lstrip": false,
728
+ "normalized": true,
729
+ "rstrip": false,
730
+ "single_word": false,
731
+ "special": false
732
+ },
733
+ "91": {
734
+ "content": "<unused84>",
735
+ "lstrip": false,
736
+ "normalized": true,
737
+ "rstrip": false,
738
+ "single_word": false,
739
+ "special": false
740
+ },
741
+ "92": {
742
+ "content": "<unused85>",
743
+ "lstrip": false,
744
+ "normalized": true,
745
+ "rstrip": false,
746
+ "single_word": false,
747
+ "special": false
748
+ },
749
+ "93": {
750
+ "content": "<unused86>",
751
+ "lstrip": false,
752
+ "normalized": true,
753
+ "rstrip": false,
754
+ "single_word": false,
755
+ "special": false
756
+ },
757
+ "94": {
758
+ "content": "<unused87>",
759
+ "lstrip": false,
760
+ "normalized": true,
761
+ "rstrip": false,
762
+ "single_word": false,
763
+ "special": false
764
+ },
765
+ "95": {
766
+ "content": "<unused88>",
767
+ "lstrip": false,
768
+ "normalized": true,
769
+ "rstrip": false,
770
+ "single_word": false,
771
+ "special": false
772
+ },
773
+ "96": {
774
+ "content": "<unused89>",
775
+ "lstrip": false,
776
+ "normalized": true,
777
+ "rstrip": false,
778
+ "single_word": false,
779
+ "special": false
780
+ },
781
+ "97": {
782
+ "content": "<unused90>",
783
+ "lstrip": false,
784
+ "normalized": true,
785
+ "rstrip": false,
786
+ "single_word": false,
787
+ "special": false
788
+ },
789
+ "98": {
790
+ "content": "<unused91>",
791
+ "lstrip": false,
792
+ "normalized": true,
793
+ "rstrip": false,
794
+ "single_word": false,
795
+ "special": false
796
+ },
797
+ "99": {
798
+ "content": "<unused92>",
799
+ "lstrip": false,
800
+ "normalized": true,
801
+ "rstrip": false,
802
+ "single_word": false,
803
+ "special": false
804
+ },
805
+ "100": {
806
+ "content": "<unused93>",
807
+ "lstrip": false,
808
+ "normalized": true,
809
+ "rstrip": false,
810
+ "single_word": false,
811
+ "special": false
812
+ },
813
+ "101": {
814
+ "content": "<unused94>",
815
+ "lstrip": false,
816
+ "normalized": true,
817
+ "rstrip": false,
818
+ "single_word": false,
819
+ "special": false
820
+ },
821
+ "102": {
822
+ "content": "<unused95>",
823
+ "lstrip": false,
824
+ "normalized": true,
825
+ "rstrip": false,
826
+ "single_word": false,
827
+ "special": false
828
+ },
829
+ "103": {
830
+ "content": "<unused96>",
831
+ "lstrip": false,
832
+ "normalized": true,
833
+ "rstrip": false,
834
+ "single_word": false,
835
+ "special": false
836
+ },
837
+ "104": {
838
+ "content": "<unused97>",
839
+ "lstrip": false,
840
+ "normalized": true,
841
+ "rstrip": false,
842
+ "single_word": false,
843
+ "special": false
844
+ },
845
+ "105": {
846
+ "content": "<unused98>",
847
+ "lstrip": false,
848
+ "normalized": true,
849
+ "rstrip": false,
850
+ "single_word": false,
851
+ "special": false
852
+ },
853
+ "106": {
854
+ "content": "<start_of_turn>",
855
+ "lstrip": false,
856
+ "normalized": true,
857
+ "rstrip": false,
858
+ "single_word": false,
859
+ "special": false
860
+ },
861
+ "107": {
862
+ "content": "<end_of_turn>",
863
+ "lstrip": false,
864
+ "normalized": true,
865
+ "rstrip": false,
866
+ "single_word": false,
867
+ "special": false
868
+ },
869
+ "108": {
870
+ "content": "\n",
871
+ "lstrip": false,
872
+ "normalized": true,
873
+ "rstrip": false,
874
+ "single_word": false,
875
+ "special": false
876
+ },
877
+ "109": {
878
+ "content": "\n\n",
879
+ "lstrip": false,
880
+ "normalized": true,
881
+ "rstrip": false,
882
+ "single_word": false,
883
+ "special": false
884
+ },
885
+ "110": {
886
+ "content": "\n\n\n",
887
+ "lstrip": false,
888
+ "normalized": true,
889
+ "rstrip": false,
890
+ "single_word": false,
891
+ "special": false
892
+ },
893
+ "111": {
894
+ "content": "\n\n\n\n",
895
+ "lstrip": false,
896
+ "normalized": true,
897
+ "rstrip": false,
898
+ "single_word": false,
899
+ "special": false
900
+ },
901
+ "112": {
902
+ "content": "\n\n\n\n\n",
903
+ "lstrip": false,
904
+ "normalized": true,
905
+ "rstrip": false,
906
+ "single_word": false,
907
+ "special": false
908
+ },
909
+ "113": {
910
+ "content": "\n\n\n\n\n\n",
911
+ "lstrip": false,
912
+ "normalized": true,
913
+ "rstrip": false,
914
+ "single_word": false,
915
+ "special": false
916
+ },
917
+ "114": {
918
+ "content": "\n\n\n\n\n\n\n",
919
+ "lstrip": false,
920
+ "normalized": true,
921
+ "rstrip": false,
922
+ "single_word": false,
923
+ "special": false
924
+ },
925
+ "115": {
926
+ "content": "\n\n\n\n\n\n\n\n",
927
+ "lstrip": false,
928
+ "normalized": true,
929
+ "rstrip": false,
930
+ "single_word": false,
931
+ "special": false
932
+ },
933
+ "116": {
934
+ "content": "\n\n\n\n\n\n\n\n\n",
935
+ "lstrip": false,
936
+ "normalized": true,
937
+ "rstrip": false,
938
+ "single_word": false,
939
+ "special": false
940
+ },
941
+ "117": {
942
+ "content": "\n\n\n\n\n\n\n\n\n\n",
943
+ "lstrip": false,
944
+ "normalized": true,
945
+ "rstrip": false,
946
+ "single_word": false,
947
+ "special": false
948
+ },
949
+ "118": {
950
+ "content": "\n\n\n\n\n\n\n\n\n\n\n",
951
+ "lstrip": false,
952
+ "normalized": true,
953
+ "rstrip": false,
954
+ "single_word": false,
955
+ "special": false
956
+ },
957
+ "119": {
958
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n",
959
+ "lstrip": false,
960
+ "normalized": true,
961
+ "rstrip": false,
962
+ "single_word": false,
963
+ "special": false
964
+ },
965
+ "120": {
966
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n",
967
+ "lstrip": false,
968
+ "normalized": true,
969
+ "rstrip": false,
970
+ "single_word": false,
971
+ "special": false
972
+ },
973
+ "121": {
974
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
975
+ "lstrip": false,
976
+ "normalized": true,
977
+ "rstrip": false,
978
+ "single_word": false,
979
+ "special": false
980
+ },
981
+ "122": {
982
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
983
+ "lstrip": false,
984
+ "normalized": true,
985
+ "rstrip": false,
986
+ "single_word": false,
987
+ "special": false
988
+ },
989
+ "123": {
990
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
991
+ "lstrip": false,
992
+ "normalized": true,
993
+ "rstrip": false,
994
+ "single_word": false,
995
+ "special": false
996
+ },
997
+ "124": {
998
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
999
+ "lstrip": false,
1000
+ "normalized": true,
1001
+ "rstrip": false,
1002
+ "single_word": false,
1003
+ "special": false
1004
+ },
1005
+ "125": {
1006
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1007
+ "lstrip": false,
1008
+ "normalized": true,
1009
+ "rstrip": false,
1010
+ "single_word": false,
1011
+ "special": false
1012
+ },
1013
+ "126": {
1014
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1015
+ "lstrip": false,
1016
+ "normalized": true,
1017
+ "rstrip": false,
1018
+ "single_word": false,
1019
+ "special": false
1020
+ },
1021
+ "127": {
1022
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1023
+ "lstrip": false,
1024
+ "normalized": true,
1025
+ "rstrip": false,
1026
+ "single_word": false,
1027
+ "special": false
1028
+ },
1029
+ "128": {
1030
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1031
+ "lstrip": false,
1032
+ "normalized": true,
1033
+ "rstrip": false,
1034
+ "single_word": false,
1035
+ "special": false
1036
+ },
1037
+ "129": {
1038
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1039
+ "lstrip": false,
1040
+ "normalized": true,
1041
+ "rstrip": false,
1042
+ "single_word": false,
1043
+ "special": false
1044
+ },
1045
+ "130": {
1046
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1047
+ "lstrip": false,
1048
+ "normalized": true,
1049
+ "rstrip": false,
1050
+ "single_word": false,
1051
+ "special": false
1052
+ },
1053
+ "131": {
1054
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1055
+ "lstrip": false,
1056
+ "normalized": true,
1057
+ "rstrip": false,
1058
+ "single_word": false,
1059
+ "special": false
1060
+ },
1061
+ "132": {
1062
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1063
+ "lstrip": false,
1064
+ "normalized": true,
1065
+ "rstrip": false,
1066
+ "single_word": false,
1067
+ "special": false
1068
+ },
1069
+ "133": {
1070
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1071
+ "lstrip": false,
1072
+ "normalized": true,
1073
+ "rstrip": false,
1074
+ "single_word": false,
1075
+ "special": false
1076
+ },
1077
+ "134": {
1078
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1079
+ "lstrip": false,
1080
+ "normalized": true,
1081
+ "rstrip": false,
1082
+ "single_word": false,
1083
+ "special": false
1084
+ },
1085
+ "135": {
1086
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1087
+ "lstrip": false,
1088
+ "normalized": true,
1089
+ "rstrip": false,
1090
+ "single_word": false,
1091
+ "special": false
1092
+ },
1093
+ "136": {
1094
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1095
+ "lstrip": false,
1096
+ "normalized": true,
1097
+ "rstrip": false,
1098
+ "single_word": false,
1099
+ "special": false
1100
+ },
1101
+ "137": {
1102
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1103
+ "lstrip": false,
1104
+ "normalized": true,
1105
+ "rstrip": false,
1106
+ "single_word": false,
1107
+ "special": false
1108
+ },
1109
+ "138": {
1110
+ "content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
1111
+ "lstrip": false,
1112
+ "normalized": true,
1113
+ "rstrip": false,
1114
+ "single_word": false,
1115
+ "special": false
1116
+ },
1117
+ "139": {
1118
+ "content": "▁▁",
1119
+ "lstrip": false,
1120
+ "normalized": true,
1121
+ "rstrip": false,
1122
+ "single_word": false,
1123
+ "special": false
1124
+ },
1125
+ "140": {
1126
+ "content": "▁▁▁",
1127
+ "lstrip": false,
1128
+ "normalized": true,
1129
+ "rstrip": false,
1130
+ "single_word": false,
1131
+ "special": false
1132
+ },
1133
+ "141": {
1134
+ "content": "▁▁▁▁",
1135
+ "lstrip": false,
1136
+ "normalized": true,
1137
+ "rstrip": false,
1138
+ "single_word": false,
1139
+ "special": false
1140
+ },
1141
+ "142": {
1142
+ "content": "▁▁▁▁▁",
1143
+ "lstrip": false,
1144
+ "normalized": true,
1145
+ "rstrip": false,
1146
+ "single_word": false,
1147
+ "special": false
1148
+ },
1149
+ "143": {
1150
+ "content": "▁▁▁▁▁▁",
1151
+ "lstrip": false,
1152
+ "normalized": true,
1153
+ "rstrip": false,
1154
+ "single_word": false,
1155
+ "special": false
1156
+ },
1157
+ "144": {
1158
+ "content": "▁▁▁▁▁▁▁",
1159
+ "lstrip": false,
1160
+ "normalized": true,
1161
+ "rstrip": false,
1162
+ "single_word": false,
1163
+ "special": false
1164
+ },
1165
+ "145": {
1166
+ "content": "▁▁▁▁▁▁▁▁",
1167
+ "lstrip": false,
1168
+ "normalized": true,
1169
+ "rstrip": false,
1170
+ "single_word": false,
1171
+ "special": false
1172
+ },
1173
+ "146": {
1174
+ "content": "▁▁▁▁▁▁▁▁▁",
1175
+ "lstrip": false,
1176
+ "normalized": true,
1177
+ "rstrip": false,
1178
+ "single_word": false,
1179
+ "special": false
1180
+ },
1181
+ "147": {
1182
+ "content": "▁▁▁▁▁▁▁▁▁▁",
1183
+ "lstrip": false,
1184
+ "normalized": true,
1185
+ "rstrip": false,
1186
+ "single_word": false,
1187
+ "special": false
1188
+ },
1189
+ "148": {
1190
+ "content": "▁▁▁▁▁▁▁▁▁▁▁",
1191
+ "lstrip": false,
1192
+ "normalized": true,
1193
+ "rstrip": false,
1194
+ "single_word": false,
1195
+ "special": false
1196
+ },
1197
+ "149": {
1198
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁",
1199
+ "lstrip": false,
1200
+ "normalized": true,
1201
+ "rstrip": false,
1202
+ "single_word": false,
1203
+ "special": false
1204
+ },
1205
+ "150": {
1206
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁",
1207
+ "lstrip": false,
1208
+ "normalized": true,
1209
+ "rstrip": false,
1210
+ "single_word": false,
1211
+ "special": false
1212
+ },
1213
+ "151": {
1214
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1215
+ "lstrip": false,
1216
+ "normalized": true,
1217
+ "rstrip": false,
1218
+ "single_word": false,
1219
+ "special": false
1220
+ },
1221
+ "152": {
1222
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1223
+ "lstrip": false,
1224
+ "normalized": true,
1225
+ "rstrip": false,
1226
+ "single_word": false,
1227
+ "special": false
1228
+ },
1229
+ "153": {
1230
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1231
+ "lstrip": false,
1232
+ "normalized": true,
1233
+ "rstrip": false,
1234
+ "single_word": false,
1235
+ "special": false
1236
+ },
1237
+ "154": {
1238
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1239
+ "lstrip": false,
1240
+ "normalized": true,
1241
+ "rstrip": false,
1242
+ "single_word": false,
1243
+ "special": false
1244
+ },
1245
+ "155": {
1246
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1247
+ "lstrip": false,
1248
+ "normalized": true,
1249
+ "rstrip": false,
1250
+ "single_word": false,
1251
+ "special": false
1252
+ },
1253
+ "156": {
1254
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1255
+ "lstrip": false,
1256
+ "normalized": true,
1257
+ "rstrip": false,
1258
+ "single_word": false,
1259
+ "special": false
1260
+ },
1261
+ "157": {
1262
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1263
+ "lstrip": false,
1264
+ "normalized": true,
1265
+ "rstrip": false,
1266
+ "single_word": false,
1267
+ "special": false
1268
+ },
1269
+ "158": {
1270
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1271
+ "lstrip": false,
1272
+ "normalized": true,
1273
+ "rstrip": false,
1274
+ "single_word": false,
1275
+ "special": false
1276
+ },
1277
+ "159": {
1278
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1279
+ "lstrip": false,
1280
+ "normalized": true,
1281
+ "rstrip": false,
1282
+ "single_word": false,
1283
+ "special": false
1284
+ },
1285
+ "160": {
1286
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1287
+ "lstrip": false,
1288
+ "normalized": true,
1289
+ "rstrip": false,
1290
+ "single_word": false,
1291
+ "special": false
1292
+ },
1293
+ "161": {
1294
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1295
+ "lstrip": false,
1296
+ "normalized": true,
1297
+ "rstrip": false,
1298
+ "single_word": false,
1299
+ "special": false
1300
+ },
1301
+ "162": {
1302
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1303
+ "lstrip": false,
1304
+ "normalized": true,
1305
+ "rstrip": false,
1306
+ "single_word": false,
1307
+ "special": false
1308
+ },
1309
+ "163": {
1310
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1311
+ "lstrip": false,
1312
+ "normalized": true,
1313
+ "rstrip": false,
1314
+ "single_word": false,
1315
+ "special": false
1316
+ },
1317
+ "164": {
1318
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1319
+ "lstrip": false,
1320
+ "normalized": true,
1321
+ "rstrip": false,
1322
+ "single_word": false,
1323
+ "special": false
1324
+ },
1325
+ "165": {
1326
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1327
+ "lstrip": false,
1328
+ "normalized": true,
1329
+ "rstrip": false,
1330
+ "single_word": false,
1331
+ "special": false
1332
+ },
1333
+ "166": {
1334
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1335
+ "lstrip": false,
1336
+ "normalized": true,
1337
+ "rstrip": false,
1338
+ "single_word": false,
1339
+ "special": false
1340
+ },
1341
+ "167": {
1342
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1343
+ "lstrip": false,
1344
+ "normalized": true,
1345
+ "rstrip": false,
1346
+ "single_word": false,
1347
+ "special": false
1348
+ },
1349
+ "168": {
1350
+ "content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
1351
+ "lstrip": false,
1352
+ "normalized": true,
1353
+ "rstrip": false,
1354
+ "single_word": false,
1355
+ "special": false
1356
+ },
1357
+ "169": {
1358
+ "content": "<table>",
1359
+ "lstrip": false,
1360
+ "normalized": true,
1361
+ "rstrip": false,
1362
+ "single_word": false,
1363
+ "special": false
1364
+ },
1365
+ "170": {
1366
+ "content": "<caption>",
1367
+ "lstrip": false,
1368
+ "normalized": true,
1369
+ "rstrip": false,
1370
+ "single_word": false,
1371
+ "special": false
1372
+ },
1373
+ "171": {
1374
+ "content": "<thead>",
1375
+ "lstrip": false,
1376
+ "normalized": true,
1377
+ "rstrip": false,
1378
+ "single_word": false,
1379
+ "special": false
1380
+ },
1381
+ "172": {
1382
+ "content": "<tbody>",
1383
+ "lstrip": false,
1384
+ "normalized": true,
1385
+ "rstrip": false,
1386
+ "single_word": false,
1387
+ "special": false
1388
+ },
1389
+ "173": {
1390
+ "content": "<tfoot>",
1391
+ "lstrip": false,
1392
+ "normalized": true,
1393
+ "rstrip": false,
1394
+ "single_word": false,
1395
+ "special": false
1396
+ },
1397
+ "174": {
1398
+ "content": "<tr>",
1399
+ "lstrip": false,
1400
+ "normalized": true,
1401
+ "rstrip": false,
1402
+ "single_word": false,
1403
+ "special": false
1404
+ },
1405
+ "175": {
1406
+ "content": "<th>",
1407
+ "lstrip": false,
1408
+ "normalized": true,
1409
+ "rstrip": false,
1410
+ "single_word": false,
1411
+ "special": false
1412
+ },
1413
+ "176": {
1414
+ "content": "<td>",
1415
+ "lstrip": false,
1416
+ "normalized": true,
1417
+ "rstrip": false,
1418
+ "single_word": false,
1419
+ "special": false
1420
+ },
1421
+ "177": {
1422
+ "content": "</table>",
1423
+ "lstrip": false,
1424
+ "normalized": true,
1425
+ "rstrip": false,
1426
+ "single_word": false,
1427
+ "special": false
1428
+ },
1429
+ "178": {
1430
+ "content": "</caption>",
1431
+ "lstrip": false,
1432
+ "normalized": true,
1433
+ "rstrip": false,
1434
+ "single_word": false,
1435
+ "special": false
1436
+ },
1437
+ "179": {
1438
+ "content": "</thead>",
1439
+ "lstrip": false,
1440
+ "normalized": true,
1441
+ "rstrip": false,
1442
+ "single_word": false,
1443
+ "special": false
1444
+ },
1445
+ "180": {
1446
+ "content": "</tbody>",
1447
+ "lstrip": false,
1448
+ "normalized": true,
1449
+ "rstrip": false,
1450
+ "single_word": false,
1451
+ "special": false
1452
+ },
1453
+ "181": {
1454
+ "content": "</tfoot>",
1455
+ "lstrip": false,
1456
+ "normalized": true,
1457
+ "rstrip": false,
1458
+ "single_word": false,
1459
+ "special": false
1460
+ },
1461
+ "182": {
1462
+ "content": "</tr>",
1463
+ "lstrip": false,
1464
+ "normalized": true,
1465
+ "rstrip": false,
1466
+ "single_word": false,
1467
+ "special": false
1468
+ },
1469
+ "183": {
1470
+ "content": "</th>",
1471
+ "lstrip": false,
1472
+ "normalized": true,
1473
+ "rstrip": false,
1474
+ "single_word": false,
1475
+ "special": false
1476
+ },
1477
+ "184": {
1478
+ "content": "</td>",
1479
+ "lstrip": false,
1480
+ "normalized": true,
1481
+ "rstrip": false,
1482
+ "single_word": false,
1483
+ "special": false
1484
+ },
1485
+ "185": {
1486
+ "content": "<h1>",
1487
+ "lstrip": false,
1488
+ "normalized": true,
1489
+ "rstrip": false,
1490
+ "single_word": false,
1491
+ "special": false
1492
+ },
1493
+ "186": {
1494
+ "content": "<h2>",
1495
+ "lstrip": false,
1496
+ "normalized": true,
1497
+ "rstrip": false,
1498
+ "single_word": false,
1499
+ "special": false
1500
+ },
1501
+ "187": {
1502
+ "content": "<h3>",
1503
+ "lstrip": false,
1504
+ "normalized": true,
1505
+ "rstrip": false,
1506
+ "single_word": false,
1507
+ "special": false
1508
+ },
1509
+ "188": {
1510
+ "content": "<h4>",
1511
+ "lstrip": false,
1512
+ "normalized": true,
1513
+ "rstrip": false,
1514
+ "single_word": false,
1515
+ "special": false
1516
+ },
1517
+ "189": {
1518
+ "content": "<h5>",
1519
+ "lstrip": false,
1520
+ "normalized": true,
1521
+ "rstrip": false,
1522
+ "single_word": false,
1523
+ "special": false
1524
+ },
1525
+ "190": {
1526
+ "content": "<h6>",
1527
+ "lstrip": false,
1528
+ "normalized": true,
1529
+ "rstrip": false,
1530
+ "single_word": false,
1531
+ "special": false
1532
+ },
1533
+ "191": {
1534
+ "content": "<blockquote>",
1535
+ "lstrip": false,
1536
+ "normalized": true,
1537
+ "rstrip": false,
1538
+ "single_word": false,
1539
+ "special": false
1540
+ },
1541
+ "192": {
1542
+ "content": "</h1>",
1543
+ "lstrip": false,
1544
+ "normalized": true,
1545
+ "rstrip": false,
1546
+ "single_word": false,
1547
+ "special": false
1548
+ },
1549
+ "193": {
1550
+ "content": "</h2>",
1551
+ "lstrip": false,
1552
+ "normalized": true,
1553
+ "rstrip": false,
1554
+ "single_word": false,
1555
+ "special": false
1556
+ },
1557
+ "194": {
1558
+ "content": "</h3>",
1559
+ "lstrip": false,
1560
+ "normalized": true,
1561
+ "rstrip": false,
1562
+ "single_word": false,
1563
+ "special": false
1564
+ },
1565
+ "195": {
1566
+ "content": "</h4>",
1567
+ "lstrip": false,
1568
+ "normalized": true,
1569
+ "rstrip": false,
1570
+ "single_word": false,
1571
+ "special": false
1572
+ },
1573
+ "196": {
1574
+ "content": "</h5>",
1575
+ "lstrip": false,
1576
+ "normalized": true,
1577
+ "rstrip": false,
1578
+ "single_word": false,
1579
+ "special": false
1580
+ },
1581
+ "197": {
1582
+ "content": "</h6>",
1583
+ "lstrip": false,
1584
+ "normalized": true,
1585
+ "rstrip": false,
1586
+ "single_word": false,
1587
+ "special": false
1588
+ },
1589
+ "198": {
1590
+ "content": "</blockquote>",
1591
+ "lstrip": false,
1592
+ "normalized": true,
1593
+ "rstrip": false,
1594
+ "single_word": false,
1595
+ "special": false
1596
+ },
1597
+ "199": {
1598
+ "content": "<strong>",
1599
+ "lstrip": false,
1600
+ "normalized": true,
1601
+ "rstrip": false,
1602
+ "single_word": false,
1603
+ "special": false
1604
+ },
1605
+ "200": {
1606
+ "content": "<em>",
1607
+ "lstrip": false,
1608
+ "normalized": true,
1609
+ "rstrip": false,
1610
+ "single_word": false,
1611
+ "special": false
1612
+ },
1613
+ "201": {
1614
+ "content": "<b>",
1615
+ "lstrip": false,
1616
+ "normalized": true,
1617
+ "rstrip": false,
1618
+ "single_word": false,
1619
+ "special": false
1620
+ },
1621
+ "202": {
1622
+ "content": "<i>",
1623
+ "lstrip": false,
1624
+ "normalized": true,
1625
+ "rstrip": false,
1626
+ "single_word": false,
1627
+ "special": false
1628
+ },
1629
+ "203": {
1630
+ "content": "<u>",
1631
+ "lstrip": false,
1632
+ "normalized": true,
1633
+ "rstrip": false,
1634
+ "single_word": false,
1635
+ "special": false
1636
+ },
1637
+ "204": {
1638
+ "content": "<s>",
1639
+ "lstrip": false,
1640
+ "normalized": true,
1641
+ "rstrip": false,
1642
+ "single_word": false,
1643
+ "special": false
1644
+ },
1645
+ "205": {
1646
+ "content": "<sub>",
1647
+ "lstrip": false,
1648
+ "normalized": true,
1649
+ "rstrip": false,
1650
+ "single_word": false,
1651
+ "special": false
1652
+ },
1653
+ "206": {
1654
+ "content": "<sup>",
1655
+ "lstrip": false,
1656
+ "normalized": true,
1657
+ "rstrip": false,
1658
+ "single_word": false,
1659
+ "special": false
1660
+ },
1661
+ "207": {
1662
+ "content": "<code>",
1663
+ "lstrip": false,
1664
+ "normalized": true,
1665
+ "rstrip": false,
1666
+ "single_word": false,
1667
+ "special": false
1668
+ },
1669
+ "208": {
1670
+ "content": "</strong>",
1671
+ "lstrip": false,
1672
+ "normalized": true,
1673
+ "rstrip": false,
1674
+ "single_word": false,
1675
+ "special": false
1676
+ },
1677
+ "209": {
1678
+ "content": "</em>",
1679
+ "lstrip": false,
1680
+ "normalized": true,
1681
+ "rstrip": false,
1682
+ "single_word": false,
1683
+ "special": false
1684
+ },
1685
+ "210": {
1686
+ "content": "</b>",
1687
+ "lstrip": false,
1688
+ "normalized": true,
1689
+ "rstrip": false,
1690
+ "single_word": false,
1691
+ "special": false
1692
+ },
1693
+ "211": {
1694
+ "content": "</i>",
1695
+ "lstrip": false,
1696
+ "normalized": true,
1697
+ "rstrip": false,
1698
+ "single_word": false,
1699
+ "special": false
1700
+ },
1701
+ "212": {
1702
+ "content": "</u>",
1703
+ "lstrip": false,
1704
+ "normalized": true,
1705
+ "rstrip": false,
1706
+ "single_word": false,
1707
+ "special": false
1708
+ },
1709
+ "213": {
1710
+ "content": "</s>",
1711
+ "lstrip": false,
1712
+ "normalized": true,
1713
+ "rstrip": false,
1714
+ "single_word": false,
1715
+ "special": false
1716
+ },
1717
+ "214": {
1718
+ "content": "</sub>",
1719
+ "lstrip": false,
1720
+ "normalized": true,
1721
+ "rstrip": false,
1722
+ "single_word": false,
1723
+ "special": false
1724
+ },
1725
+ "215": {
1726
+ "content": "</sup>",
1727
+ "lstrip": false,
1728
+ "normalized": true,
1729
+ "rstrip": false,
1730
+ "single_word": false,
1731
+ "special": false
1732
+ },
1733
+ "216": {
1734
+ "content": "</code>",
1735
+ "lstrip": false,
1736
+ "normalized": true,
1737
+ "rstrip": false,
1738
+ "single_word": false,
1739
+ "special": false
1740
+ },
1741
+ "257152": {
1742
+ "content": "<image>",
1743
+ "lstrip": false,
1744
+ "normalized": false,
1745
+ "rstrip": false,
1746
+ "single_word": false,
1747
+ "special": true
1748
+ }
1749
+ },
1750
+ "additional_special_tokens": [
1751
+ "<image>"
1752
+ ],
1753
+ "bos_token": "<bos>",
1754
+ "clean_up_tokenization_spaces": false,
1755
+ "eos_token": "<eos>",
1756
+ "extra_special_tokens": {},
1757
+ "max_length": 48,
1758
+ "model_max_length": 1000000000000000019884624838656,
1759
+ "pad_to_multiple_of": null,
1760
+ "pad_token": "<pad>",
1761
+ "pad_token_type_id": 0,
1762
+ "padding_side": "right",
1763
+ "processor_class": "PaliGemmaProcessor",
1764
+ "sp_model_kwargs": {},
1765
+ "spaces_between_special_tokens": false,
1766
+ "stride": 0,
1767
+ "tokenizer_class": "GemmaTokenizer",
1768
+ "truncation_side": "right",
1769
+ "truncation_strategy": "longest_first",
1770
+ "unk_token": "<unk>",
1771
+ "use_default_system_prompt": false
1772
+ }
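The tokenizer_config.json above registers the HTML-style markup tags (<h1> through </code>) as ordinary, normalized tokens and "<image>" (id 257152) as the only extra special token, on top of a GemmaTokenizer with right padding and max_length 48. A minimal loading sketch, assuming the full tokenizer files from this upload sit in a local directory (the path below is hypothetical):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./hume-checkpoint")  # hypothetical local path
ids = tokenizer(
    "<image> place the cup on the plate",
    max_length=48,
    truncation=True,
)
# "<image>" maps to the single special id 257152; the HTML-style tags are
# ordinary tokens and are normalized like any other text.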
value_query.py ADDED
@@ -0,0 +1,1155 @@
1
+ import math
2
+ from copy import deepcopy
3
+ from functools import partial
4
+ from typing import Callable, Optional, Sequence, Tuple, Union
5
+
6
+ import array_typing as at
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from beartype import beartype as typechecker
11
+ from jaxtyping import Float, jaxtyped
12
+ from torch.distributions import Independent, Normal, TransformedDistribution
13
+ from torch.distributions.transforms import (
14
+ AffineTransform,
15
+ ComposeTransform,
16
+ TanhTransform,
17
+ )
18
+ from torch.optim import Adam, AdamW, Optimizer
19
+ from torch.optim.lr_scheduler import (
20
+ LambdaLR,
21
+ )
22
+ from transformers import (
23
+ AutoConfig,
24
+ GemmaForCausalLM,
25
+ PretrainedConfig,
26
+ PreTrainedModel,
27
+ )
28
+ from transformers.models.auto import CONFIG_MAPPING
29
+
30
+
31
+ def extend_and_repeat(tensor: torch.Tensor, dim: int, repeat: int) -> torch.Tensor:
32
+ return tensor.unsqueeze(dim).repeat_interleave(repeat, dim=dim)
33
+
34
+
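A quick shape check for extend_and_repeat (illustrative only; assumes it is run with this module imported): it inserts a new axis at `dim` and repeats the tensor along it.

import torch

x = torch.zeros(4, 7)                      # (batch, action_dim)
y = extend_and_repeat(x, dim=1, repeat=3)  # unsqueeze then repeat_interleave
assert y.shape == (4, 3, 7)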
35
+ def init_module_weights(module: torch.nn.Module, orthogonal_init: bool = False):
36
+ if isinstance(module, nn.Linear):
37
+ if orthogonal_init:
38
+ nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
39
+ nn.init.constant_(module.bias, 0.0)
40
+ else:
41
+ nn.init.xavier_uniform_(module.weight, gain=1e-2)
42
+
43
+
44
+ class VQHBackboneConfig(PretrainedConfig):
45
+ model_type = "VQHBackbone"
46
+ sub_configs = {"gemma_expert_config": AutoConfig}
47
+
48
+ def __init__(
49
+ self,
50
+ gemma_expert_config: dict | None = None,
51
+ attention_implementation: str = "eager",
52
+ **kwargs,
53
+ ):
54
+ self.attention_implementation = attention_implementation
55
+
56
+ if gemma_expert_config is None:
57
+ self.gemma_expert_config = CONFIG_MAPPING["gemma"](
58
+ attention_bias=False,
59
+ attention_dropout=0.0,
60
+ bos_token_id=2,
61
+ eos_token_id=1,
62
+ head_dim=256,
63
+ hidden_act="gelu_pytorch_tanh",
64
+ hidden_activation="gelu_pytorch_tanh",
65
+ hidden_size=2048,
66
+ initializer_range=0.02,
67
+ intermediate_size=4096,
68
+ max_position_embeddings=8192,
69
+ model_type="gemma",
70
+ num_attention_heads=8,
71
+ num_hidden_layers=4,
72
+ num_key_value_heads=1,
73
+ pad_token_id=0,
74
+ rms_norm_eps=1e-06,
75
+ rope_theta=10000.0,
76
+ torch_dtype="float32",
77
+ transformers_version="4.48.1",
78
+ use_cache=True,
79
+ vocab_size=257152,
80
+ )
81
+ elif isinstance(gemma_expert_config, dict):
82
+ if "model_type" not in gemma_expert_config:
83
+ gemma_expert_config["model_type"] = "gemma"
84
+ cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
85
+ self.gemma_expert_config = cfg_cls(**gemma_expert_config)
86
+
87
+ super().__init__(**kwargs)
88
+
89
+ def __post_init__(self):
90
+ super().__post_init__()
91
+ if self.attention_implementation not in ["eager", "fa2", "flex"]:
92
+ raise ValueError(
93
+ f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
94
+ )
95
+
96
+
97
+ def apply_rope(x, positions, max_wavelength=10_000):
98
+ """
99
+ Applies RoPE positions [B, L] to x [B, L, H, D].
100
+ """
101
+ d_half = x.shape[-1] // 2
102
+ device = x.device
103
+ dtype = x.dtype
104
+ x = x.to(torch.float32)
105
+
106
+ freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
107
+ d_half, dtype=torch.float32, device=device
108
+ )
109
+ timescale = max_wavelength**freq_exponents
110
+ radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
111
+ torch.float32
112
+ )
113
+
114
+ radians = radians[..., None, :]
115
+
116
+ sin = torch.sin(radians) # .to(dtype=dtype)
117
+ cos = torch.cos(radians) # .to(dtype=dtype)
118
+
119
+ x1, x2 = x.split(d_half, dim=-1)
120
+ res = torch.empty_like(x)
121
+ res[..., :d_half] = x1 * cos - x2 * sin
122
+ res[..., d_half:] = x2 * cos + x1 * sin
123
+
124
+ return res.to(dtype)
125
+
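A shape-only usage sketch for apply_rope (values are arbitrary placeholders; assumes the function is called from this module):

import torch

B, L, H, D = 2, 10, 8, 256
x = torch.randn(B, L, H, D)               # query or key states [B, L, H, D]
positions = torch.arange(L).expand(B, L)  # RoPE position ids [B, L]
out = apply_rope(x, positions)
assert out.shape == x.shape and out.dtype == x.dtype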
126
+
127
+ class VQHBackbone(PreTrainedModel):
128
+ config_class = VQHBackboneConfig
129
+
130
+ def __init__(self, config: VQHBackboneConfig):
131
+ super().__init__(config=config)
132
+ self.config = config
133
+ self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
134
+
135
+ self.to_bfloat16_like_physical_intelligence()
136
+
137
+ def train(self, mode: bool = True):
138
+ super().train(mode)
139
+
140
+ def to_bfloat16_like_physical_intelligence(self):
141
+ params_to_change_dtype = [
142
+ "language_model.model.layers",
143
+ "gemma_expert.model.layers",
144
+ ]
145
+ for name, param in self.named_parameters():
146
+ if any(selector in name for selector in params_to_change_dtype):
147
+ param.data = param.data.to(dtype=torch.bfloat16)
148
+
149
+ def forward(
150
+ self,
151
+ attention_mask: Optional[torch.Tensor] = None,
152
+ position_ids: Optional[torch.LongTensor] = None,
153
+ inputs_embeds: Optional[torch.FloatTensor] = None,
154
+ ):
155
+ # RMSNorm
156
+ head_dim = self.gemma_expert.config.head_dim
157
+
158
+ hidden_states = inputs_embeds
159
+ batch_size = hidden_states.shape[0]
160
+ for layer in self.gemma_expert.model.layers[
161
+ : self.gemma_expert.config.num_hidden_layers
162
+ ]:
163
+ # normalizer = torch.tensor(model.config.hidden_size**0.5, dtype=hidden_states.dtype)
164
+ # hidden_states = hidden_states * normalizer
165
+ hidden_states = layer.input_layernorm(hidden_states)
166
+ input_shape = hidden_states.shape[:-1]
167
+ hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
168
+
169
+ # self attention
170
+ hidden_states = hidden_states.to(dtype=torch.bfloat16)
171
+ query_states = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
172
+ key_states = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
173
+ value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
174
+
175
+ query_states = apply_rope(query_states, position_ids)
176
+ key_states = apply_rope(key_states, position_ids)
177
+
178
+ attention_interface = self.get_attention_interface()
179
+ att_output = attention_interface(
180
+ attention_mask,
181
+ batch_size,
182
+ head_dim,
183
+ query_states,
184
+ key_states,
185
+ value_states,
186
+ )
187
+
188
+ if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
189
+ att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
190
+
191
+ out_emb = layer.self_attn.o_proj(att_output)
192
+
193
+ # first residual
194
+ out_emb += hidden_states
195
+ after_first_residual = out_emb.clone()
196
+ out_emb = layer.post_attention_layernorm(out_emb)
197
+ out_emb = layer.mlp(out_emb)
198
+ # second residual
199
+ out_emb += after_first_residual
200
+ hidden_states = out_emb
201
+
202
+ # final norm
203
+ hidden_states = self.gemma_expert.model.norm(hidden_states)
204
+
205
+ return hidden_states
206
+
207
+ def get_attention_interface(self):
208
+ if self.config.attention_implementation == "fa2":
209
+ attention_interface = self.flash_attention_forward
210
+ else:
211
+ attention_interface = self.eager_attention_forward
212
+ return attention_interface
213
+
214
+ def eager_attention_forward(
215
+ self,
216
+ attention_mask,
217
+ batch_size,
218
+ head_dim,
219
+ query_states,
220
+ key_states,
221
+ value_states,
222
+ ):
223
+ num_att_heads = self.config.gemma_expert_config.num_attention_heads
224
+ num_key_value_heads = self.config.gemma_expert_config.num_key_value_heads
225
+ num_key_value_groups = num_att_heads // num_key_value_heads
226
+
227
+ # query_states: batch_size, sequence_length, num_att_head, head_dim
228
+ # key_states: batch_size, sequence_length, num_key_value_head, head_dim
229
+ # value_states: batch_size, sequence_length, num_key_value_head, head_dim
230
+ sequence_length = key_states.shape[1]
231
+
232
+ key_states = key_states[:, :, :, None, :].expand(
233
+ batch_size,
234
+ sequence_length,
235
+ num_key_value_heads,
236
+ num_key_value_groups,
237
+ head_dim,
238
+ )
239
+ key_states = key_states.reshape(
240
+ batch_size,
241
+ sequence_length,
242
+ num_key_value_heads * num_key_value_groups,
243
+ head_dim,
244
+ )
245
+
246
+ value_states = value_states[:, :, :, None, :].expand(
247
+ batch_size,
248
+ sequence_length,
249
+ num_key_value_heads,
250
+ num_key_value_groups,
251
+ head_dim,
252
+ )
253
+ value_states = value_states.reshape(
254
+ batch_size,
255
+ sequence_length,
256
+ num_key_value_heads * num_key_value_groups,
257
+ head_dim,
258
+ )
259
+
260
+ # Attention here is upcasted to float32 to match the original eager implementation.
261
+ query_states = query_states.to(dtype=torch.float32)
262
+ key_states = key_states.to(dtype=torch.float32)
263
+
264
+ query_states = query_states.transpose(1, 2)
265
+ key_states = key_states.transpose(1, 2)
266
+
267
+ att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
268
+ att_weights *= head_dim**-0.5
269
+ big_neg = -2.3819763e38 # See gemma/modules.py
270
+
271
+ masked_att_weights = torch.where(
272
+ attention_mask[:, None, :, :], att_weights, big_neg
273
+ )
274
+
275
+ probs = nn.functional.softmax(masked_att_weights, dim=-1)
276
+ probs = probs.to(dtype=value_states.dtype)
277
+
278
+ # probs: batch_size, num_att_heads, sequence_length, sequence_length
279
+ # value_states: batch_size, sequence_length, num_att_heads, head_dim
280
+
281
+ att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
282
+
283
+ att_output = att_output.permute(0, 2, 1, 3)
284
+ # we use -1 because sequence length can change
285
+ att_output = att_output.reshape(
286
+ batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
287
+ )
288
+
289
+ return att_output
290
+
291
+
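A rough forward-pass sketch for VQHBackbone under the default config, with a full (unmasked) boolean attention mask of shape [B, L, L], which is what eager_attention_forward expects. All sizes below are placeholders; note the default config carries the 257k-token Gemma vocabulary, so instantiating it is not lightweight.

import torch

config = VQHBackboneConfig()               # 4-layer Gemma expert, hidden_size 2048
backbone = VQHBackbone(config)

B, L = 1, 16
inputs_embeds = torch.randn(B, L, config.gemma_expert_config.hidden_size)
position_ids = torch.arange(L).expand(B, L)
attention_mask = torch.ones(B, L, L, dtype=torch.bool)

with torch.no_grad():
    hidden = backbone(
        attention_mask=attention_mask,
        position_ids=position_ids,
        inputs_embeds=inputs_embeds,
    )
# hidden: (B, L, hidden_size), after the per-layer attention/MLP blocks and the final RMSNorm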
292
+ class LagrangeMultiplier(nn.Module):
293
+ def __init__(
294
+ self,
295
+ init_value: float = 1.0,
296
+ constraint_shape: Tuple[int, ...] = (),
297
+ constraint_type: str = "eq", # One of ("eq", "leq", "geq")
298
+ parameterization: Optional[
299
+ str
300
+ ] = None, # One of ("softplus", "exp"), or None for equality constraints
301
+ ):
302
+ super().__init__()
303
+ self.constraint_type = constraint_type
304
+ self.parameterization = parameterization
305
+
306
+ if constraint_type != "eq":
307
+ assert (
308
+ init_value > 0
309
+ ), "Inequality constraints must have non-negative initial multiplier values"
310
+
311
+ if parameterization == "softplus":
312
+ init_value = torch.log(torch.exp(torch.tensor(init_value)) - 1).item()
313
+ elif parameterization == "exp":
314
+ init_value = torch.log(torch.tensor(init_value)).item()
315
+ else:
316
+ raise ValueError(
317
+ f"Invalid multiplier parameterization {parameterization}"
318
+ )
319
+ else:
320
+ assert (
321
+ parameterization is None
322
+ ), "Equality constraints must have no parameterization"
323
+
324
+ self.multiplier = nn.Parameter(torch.full(constraint_shape, init_value))
325
+
326
+ def forward(
327
+ self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None
328
+ ) -> torch.Tensor:
329
+ multiplier = self.multiplier
330
+
331
+ if self.constraint_type != "eq":
332
+ if self.parameterization == "softplus":
333
+ multiplier = torch.nn.functional.softplus(multiplier)
334
+ elif self.parameterization == "exp":
335
+ multiplier = torch.exp(multiplier)
336
+ else:
337
+ raise ValueError(
338
+ f"Invalid multiplier parameterization {self.parameterization}"
339
+ )
340
+
341
+ if lhs is None:
342
+ return multiplier
343
+
344
+ if rhs is None:
345
+ rhs = torch.zeros_like(lhs)
346
+
347
+ diff = lhs - rhs
348
+
349
+ assert (
350
+ diff.shape == multiplier.shape
351
+ ), f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
352
+
353
+ if self.constraint_type == "eq":
354
+ return multiplier * diff
355
+ elif self.constraint_type == "geq":
356
+ return multiplier * diff
357
+ elif self.constraint_type == "leq":
358
+ return -multiplier * diff
359
+
360
+
361
+ GeqLagrangeMultiplier = partial(
362
+ LagrangeMultiplier, constraint_type="geq", parameterization="softplus"
363
+ )
364
+
365
+ LeqLagrangeMultiplier = partial(
366
+ LagrangeMultiplier, constraint_type="leq", parameterization="softplus"
367
+ )
368
+
369
+
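A small sketch of how the softplus-parameterized GeqLagrangeMultiplier acts as the SAC-style temperature, mirroring temperature_loss_fn further down (the entropy and target values here are placeholders):

import torch

temperature = GeqLagrangeMultiplier(init_value=1.0, constraint_shape=())

entropy = torch.tensor(1.3)            # placeholder policy entropy
target_entropy = torch.tensor(-70.0)   # e.g. -action_dim

alpha = temperature()                                            # current multiplier (1.0 at init)
temperature_loss = temperature(lhs=entropy, rhs=target_entropy)  # multiplier * (lhs - rhs)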
370
+ class MLP(nn.Module):
371
+ def __init__(
372
+ self,
373
+ input_dim: int,
374
+ hidden_dims: Sequence[int],
375
+ activations: Union[Callable[[torch.Tensor], torch.Tensor], str] = "silu",
376
+ activate_final: bool = False,
377
+ use_layer_norm: bool = False,
378
+ use_group_norm: bool = False,
379
+ dropout_rate: Optional[float] = None,
380
+ ):
381
+ super().__init__()
382
+
383
+ assert not (use_layer_norm and use_group_norm)
384
+
385
+ self.activate_final = activate_final
386
+ self.dropout_rate = dropout_rate
387
+ self.input_dim = input_dim
388
+ self.hidden_dims = hidden_dims
389
+
390
+ if isinstance(activations, str):
391
+ if activations == "silu" or activations == "swish":
392
+ self.activations = nn.SiLU()
393
+ else:
394
+ self.activations = getattr(nn, activations)()
395
+ else:
396
+ self.activations = activations
397
+
398
+ layers = []
399
+
400
+ for i, hidden_dim in enumerate(hidden_dims):
401
+ layers.append(nn.Linear(input_dim, hidden_dim))
402
+ nn.init.xavier_uniform_(layers[-1].weight)
403
+ nn.init.zeros_(layers[-1].bias)
404
+
405
+ input_dim = hidden_dim
406
+
407
+ if i + 1 < len(hidden_dims) or activate_final:
408
+ if dropout_rate is not None and dropout_rate > 0:
409
+ layers.append(nn.Dropout(p=dropout_rate))
410
+
411
+ if use_layer_norm:
412
+ layers.append(nn.LayerNorm(hidden_dim))
413
+ elif use_group_norm:
414
+ num_groups = min(hidden_dim, 32)
415
+ layers.append(nn.GroupNorm(num_groups, hidden_dim))
416
+ layers.append(self.activations)
417
+
418
+ self.layers = nn.ModuleList(layers)
419
+
420
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
421
+ for layer in self.layers:
422
+ x = layer(x)
423
+
424
+ return x
425
+
426
+
427
+ class TanhMultivariateNormalDiag(TransformedDistribution):
428
+ def __init__(
429
+ self,
430
+ loc: torch.Tensor,
431
+ scale_diag: torch.Tensor,
432
+ low: Optional[torch.Tensor] = None,
433
+ high: Optional[torch.Tensor] = None,
434
+ ):
435
+ self.loc = loc
436
+ self.scale_diag = scale_diag
437
+ base_distribution = Independent(Normal(loc, scale_diag), 1)
438
+
439
+ transforms = []
440
+ transforms.append(TanhTransform())
441
+ if not (low is None or high is None):
442
+ transforms.append(
443
+ AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2)
444
+ )
445
+ transform = ComposeTransform(transforms)
446
+
447
+ super().__init__(base_distribution, transform)
448
+
449
+ def mode(self) -> torch.Tensor:
450
+ """返回分布的众数"""
451
+ # For a Gaussian, the mode equals the mean
452
+ mode = self.loc
453
+ # Apply the transforms to map it into the squashed action space
454
+ for transform in self.transforms:
455
+ mode = transform(mode)
456
+ return mode
457
+
458
+ def stddev(self) -> torch.Tensor:
459
+ """返回变换后的标准差(近似值)"""
460
+ # Note: this is only an approximation; the exact std after a nonlinear transform is hard to compute
461
+ return ComposeTransform(self.transforms)(self.loc + self.scale_diag) - ComposeTransform(self.transforms)(self.loc)
462
+
463
+ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
464
+ eps = 1e-6
465
+ value = torch.clamp(value, -1 + eps, 1 - eps)
466
+ return super().log_prob(value)
467
+
468
+
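Sampling sketch for the tanh-squashed diagonal Gaussian (no low/high bounds passed, so samples land in (-1, 1); sizes are placeholders):

import torch

loc = torch.zeros(2, 7)
scale_diag = 0.5 * torch.ones(2, 7)
dist = TanhMultivariateNormalDiag(loc=loc, scale_diag=scale_diag)

actions = dist.rsample()         # (2, 7), squashed into (-1, 1)
log_pi = dist.log_prob(actions)  # (2,), inputs clamped away from the tanh boundary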
469
+ class Policy(nn.Module):
470
+ def __init__(
471
+ self,
472
+ obs_encoded_dim: int,
473
+ network: nn.Module,
474
+ action_dim: int,
475
+ std_parameterization: str = "exp", # "exp", "softplus", "fixed", or "uniform"
476
+ std_min: Optional[float] = 1e-5,
477
+ std_max: Optional[float] = 10.0,
478
+ tanh_squash_distribution: bool = False,
479
+ fixed_std: Optional[torch.Tensor] = None,
480
+ ):
481
+ super().__init__()
482
+
483
+ self.obs_encoded_dim = obs_encoded_dim
484
+ self.network = network
485
+ self.action_dim = action_dim
486
+ self.std_parameterization = std_parameterization
487
+ self.std_min = std_min
488
+ self.std_max = std_max
489
+ self.tanh_squash_distribution = tanh_squash_distribution
490
+ self.fixed_std = fixed_std
491
+
492
+ self.mean_layer = nn.Linear(network.hidden_dims[-1], action_dim)
493
+
494
+ if fixed_std is None:
495
+ if std_parameterization in ["exp", "softplus"]:
496
+ self.std_layer = nn.Linear(network.hidden_dims[-1], action_dim)
497
+ elif std_parameterization == "uniform":
498
+ self.log_stds = nn.Parameter(torch.zeros(action_dim))
499
+ else:
500
+ raise ValueError(
501
+ f"Invalid std_parameterization: {self.std_parameterization}"
502
+ )
503
+ else:
504
+ assert std_parameterization == "fixed"
505
+
506
+ nn.init.xavier_uniform_(self.mean_layer.weight)
507
+ nn.init.zeros_(self.mean_layer.bias)
508
+
509
+ if fixed_std is None and std_parameterization in ["exp", "softplus"]:
510
+ nn.init.xavier_uniform_(self.std_layer.weight)
511
+ nn.init.zeros_(self.std_layer.bias)
512
+
513
+ def forward(
514
+ self, encoded_observations: torch.Tensor, temperature: float = 1.0
515
+ ) -> Union[TransformedDistribution, Normal]:
516
+ outputs = self.network(encoded_observations)
517
+
518
+ means = self.mean_layer(outputs)
519
+
520
+ if self.fixed_std is None:
521
+ if self.std_parameterization == "exp":
522
+ log_stds = self.std_layer(outputs)
523
+ stds = torch.exp(log_stds)
524
+ elif self.std_parameterization == "softplus":
525
+ stds = self.std_layer(outputs)
526
+ stds = nn.functional.softplus(stds)
527
+ elif self.std_parameterization == "uniform":
528
+ stds = torch.exp(self.log_stds).expand_as(means)
529
+ else:
530
+ raise ValueError(
531
+ f"Invalid std_parameterization: {self.std_parameterization}"
532
+ )
533
+ else:
534
+ stds = self.fixed_std.to(means.device).expand_as(means)
535
+
536
+ stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(
537
+ torch.tensor(temperature)
538
+ )
539
+
540
+ if self.tanh_squash_distribution:
541
+ distribution = TanhMultivariateNormalDiag(
542
+ loc=means,
543
+ scale_diag=stds,
544
+ )
545
+ else:
546
+ distribution = Normal(loc=means, scale=stds)
547
+
548
+ return distribution
549
+
550
+
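Actor-head sketch, built the same way CalQL builds it below; the 2048/70 dimensions simply echo the defaults hard-coded in CalQlConfig and are otherwise placeholders:

import torch

obs_dim, act_dim = 2048, 70
policy = Policy(
    obs_encoded_dim=obs_dim,
    network=MLP(input_dim=obs_dim, hidden_dims=[256, 256], activate_final=True),
    action_dim=act_dim,
    tanh_squash_distribution=True,
    std_parameterization="exp",
)

obs = torch.randn(4, obs_dim)
dist = policy(obs)               # TanhMultivariateNormalDiag
actions = dist.rsample()         # (4, 70)
log_pi = dist.log_prob(actions)  # (4,)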
551
+ class Critics(nn.Module):
552
+ def __init__(
553
+ self,
554
+ obs_encoded_dim: int,
555
+ networks: list[nn.Module],
556
+ num_backbones: int = 2,
557
+ init_final: Optional[float] = None,
558
+ ):
559
+ super().__init__()
560
+ assert len(networks) == num_backbones
561
+ self.obs_encoded_dim = obs_encoded_dim
562
+ self.networks = nn.ModuleList(networks)
563
+ self.num_backbones = num_backbones
564
+ self.init_final = init_final
565
+
566
+ self.backbone_output_dims = networks[0].hidden_dims[-1]
567
+
568
+ if init_final is not None:
569
+ self.output_layer = nn.Linear(self.backbone_output_dims, 1)
570
+ nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
571
+ nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
572
+ else:
573
+ self.output_layer = nn.Linear(self.backbone_output_dims, 1)
574
+ nn.init.xavier_uniform_(self.output_layer.weight)
575
+ nn.init.zeros_(self.output_layer.bias)
576
+
577
+ @jaxtyped(typechecker=typechecker)
578
+ def forward(
579
+ self,
580
+ encoded_observations: Float[torch.Tensor, "batch {self.obs_encoded_dim}"],
581
+ actions: Float[torch.Tensor, "batch *num_actions action_dim"],
582
+ ) -> Float[torch.Tensor, "{self.num_backbones} batch *num_actions"]:
583
+ if actions.ndim == 3:
584
+ # forward the q function with multiple actions on each state
585
+ encoded_observations = encoded_observations.unsqueeze(1).expand(
586
+ -1, actions.shape[1], -1
587
+ )
588
+ # HACK: check dimensions here
589
+ inputs = torch.cat([encoded_observations, actions], dim=-1)
590
+
591
+ backbone_outputs = []
592
+ for network in self.networks:
593
+ backbone_outputs.append(network(inputs))
594
+ backbone_outputs: Float[
595
+ torch.Tensor,
596
+ "{self.num_backbones} batch *num_actions {self.backbone_output_dims}",
597
+ ] = torch.stack(backbone_outputs, dim=0)
598
+
599
+ value = self.output_layer(backbone_outputs)
600
+ # HACK: check output shape here
601
+ # if actions.ndim == 3:
602
+ # value = value.squeeze(-1).permute(0, 2, 1)
603
+ # else:
604
+ value = value.squeeze(-1)
605
+ return value # (num_backbones, batch, *num_actions)
606
+
607
+
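Ensemble-critic sketch: the forward pass accepts either one action per state or a set of candidate actions per state, and always returns one value per ensemble member (all sizes below are placeholders):

import torch

obs_dim, act_dim, ensemble = 2048, 70, 2
critics = Critics(
    obs_encoded_dim=obs_dim,
    networks=[
        MLP(input_dim=obs_dim + act_dim, hidden_dims=[256, 256], activate_final=True)
        for _ in range(ensemble)
    ],
    num_backbones=ensemble,
)

obs = torch.randn(4, obs_dim)
q_single = critics(obs, torch.randn(4, act_dim))     # (ensemble, batch)              -> (2, 4)
q_multi = critics(obs, torch.randn(4, 10, act_dim))  # (ensemble, batch, num_actions) -> (2, 4, 10)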
608
+ class CalQlConfig(PretrainedConfig):
609
+ model_type = "calql"
610
+
611
+ def __init__(
612
+ self,
613
+ obs_encoded_dim=2048,
614
+ action_dim=70,
615
+ actor_lr=1e-4,
616
+ critic_lr=3e-4,
617
+ temp_lr=3e-4,
618
+ actor_wps=2000,
619
+ critic_wps=2000,
620
+ **kwargs,
621
+ ):
622
+ self.cql_clip_diff_min = -np.inf
623
+ self.cql_clip_diff_max = np.inf
624
+ self.cql_alpha = 5.0
625
+ self.cql_autotune_alpha = False
626
+ self.action_dim = action_dim
627
+ self.target_entropy = -self.action_dim
628
+ self.obs_encoded_dim = obs_encoded_dim
629
+ self.cql_temperature_init_value = 1.0
630
+ self.critic_ensemble_size = 2
631
+ self.cql_n_actions = 4
632
+ self.cql_max_target_backup = True
633
+ self.policy_network_kwargs = dict(
634
+ input_dim=self.obs_encoded_dim,
635
+ hidden_dims=[256, 256],
636
+ activate_final=True,
637
+ use_layer_norm=False,
638
+ )
639
+ self.critic_network_kwargs = dict(
640
+ input_dim=self.obs_encoded_dim + self.action_dim,
641
+ hidden_dims=[256, 256],
642
+ activate_final=True,
643
+ use_layer_norm=False,
644
+ )
645
+ self.policy_kwargs = dict(
646
+ tanh_squash_distribution=True,
647
+ std_parameterization="exp",
648
+ )
649
+ self.critic_subsample_size = None
650
+ self.cql_max_target_backup = True
651
+ self.backup_entropy = False
652
+ self.discount = 0.98
653
+ self.goal_conditioned = True
654
+ self.gc_kwargs = dict(
655
+ negative_proportion=0.0,
656
+ )
657
+ self.use_td_loss = True
658
+ self.cql_action_sample_method = "uniform"
659
+ self.cql_importance_sample = True
660
+ self.cql_temp = 1.0
661
+ self.use_calql = True
662
+
663
+ self.actor_optimizer_kwargs = dict(
664
+ learning_rate=actor_lr,
665
+ warmup_steps=actor_wps,
666
+ )
667
+ self.critic_optimizer_kwargs = dict(
668
+ learning_rate=critic_lr,
669
+ warmup_steps=critic_wps,
670
+ )
671
+ self.temperature_optimizer_kwargs = dict(learning_rate=temp_lr)
672
+
673
+ super().__init__(**kwargs)
674
+
675
+
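CalQlConfig keeps most Cal-QL hyper-parameters hard-coded and only exposes the dimensions and the learning-rate/warmup settings through its constructor; a construction sketch (values are the defaults shown above):

config = CalQlConfig(
    obs_encoded_dim=2048,   # matches the Gemma expert hidden size
    action_dim=70,          # padded action dimension
    actor_lr=1e-4,
    critic_lr=3e-4,
    temp_lr=3e-4,
)
# Everything else (cql_alpha=5.0, cql_n_actions=4, discount=0.98, use_calql=True, ...)
# is fixed in __init__ and would have to be edited there to change.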
676
+ class CalQL(PreTrainedModel):
677
+ config_class = CalQlConfig
678
+
679
+ def __init__(self, config: CalQlConfig):
680
+ super(CalQL, self).__init__(config=config)
681
+ self.config = config
682
+
683
+ self.temperature = GeqLagrangeMultiplier(
684
+ init_value=self.config.cql_temperature_init_value,
685
+ constraint_shape=(),
686
+ )
687
+
688
+ self.policy = Policy(
689
+ obs_encoded_dim=self.config.obs_encoded_dim,
690
+ network=MLP(**self.config.policy_network_kwargs),
691
+ action_dim=self.config.action_dim,
692
+ **self.config.policy_kwargs,
693
+ )
694
+
695
+ self.critics = Critics(
696
+ obs_encoded_dim=self.config.obs_encoded_dim,
697
+ networks=[
698
+ MLP(**self.config.critic_network_kwargs)
699
+ for _ in range(self.config.critic_ensemble_size)
700
+ ],
701
+ num_backbones=self.config.critic_ensemble_size,
702
+ )
703
+
704
+ self.target_critics = deepcopy(self.critics)
705
+
706
+ def forward_policy_and_sample(
707
+ self,
708
+ encoded_obs: Float[torch.Tensor, "batch {self.config.obs_encoded_dim}"],
709
+ repeat: int = None,
710
+ ):
711
+ action_dist = self.policy.forward(encoded_obs)
712
+ if repeat:
713
+ new_actions = action_dist.rsample(
714
+ torch.tensor([repeat])
715
+ ) # repeat, tensor, act_dim
716
+ log_pi = action_dist.log_prob(new_actions)
717
+ new_actions = new_actions.permute(1, 0, 2) # (batch, repeat, action_dim)
718
+ log_pi = log_pi.permute(1, 0) # (batch, repeat)
719
+
720
+ else:
721
+ new_actions = action_dist.rsample() # (batch, action_dim)
722
+ log_pi = action_dist.log_prob(new_actions) # (batch)
723
+ # NOTE: detach gradient here
724
+ new_actions = new_actions.detach()
725
+ log_pi = log_pi.detach()
726
+ return new_actions, log_pi
727
+
728
+ def _compute_next_actions(self, batch: at.CalQlBatch):
729
+ """
730
+ Compute the next actions, repeated cql_n_actions times.
731
+ This should only be used when computing the critic loss with
732
+ cql_max_target_backup enabled.
733
+ """
734
+ sample_n_actions = (
735
+ self.config.cql_n_actions if self.config.cql_max_target_backup else None
736
+ )
737
+
738
+ next_actions, next_actions_log_probs = self.forward_policy_and_sample(
739
+ batch["encoded_next_observations"],
740
+ repeat=sample_n_actions,
741
+ )
742
+ return next_actions, next_actions_log_probs
743
+
744
+ def temperature_loss_fn(self, batch: at.CalQlBatch):
745
+ next_actions, next_actions_log_probs = self._compute_next_actions(batch)
746
+
747
+ entropy = -next_actions_log_probs.mean()
748
+ temperature_loss = self.temperature.forward(
749
+ lhs=entropy,
750
+ rhs=self.config.target_entropy,
751
+ )
752
+ return temperature_loss, {"temperature_loss": temperature_loss}
753
+
754
+ def policy_loss_fn(self, batch: at.CalQlBatch):
755
+ batch_size = batch["rewards"].shape[0]
756
+ temperature = self.temperature.forward().detach() # detach gradient
757
+
758
+ action_distributions = self.policy.forward(batch["encoded_observations"])
759
+ actions = action_distributions.rsample()
760
+ log_probs = action_distributions.log_prob(actions)
761
+
762
+ predicted_qs = self.critics.forward(
763
+ batch["encoded_observations"],
764
+ actions,
765
+ ).detach() # NOTE: detach grads
766
+ predicted_q = predicted_qs.min(dim=0)[0]
767
+
768
+ assert predicted_q.shape == (batch_size,)
769
+ assert log_probs.shape == (batch_size,)
770
+
771
+ nll_objective = -torch.mean(
772
+ action_distributions.log_prob(torch.clip(batch["actions"], -0.99, 0.99))
773
+ )
774
+ actor_objective = predicted_q
775
+ actor_loss = -torch.mean(actor_objective) + torch.mean(temperature * log_probs)
776
+
777
+ info = {
778
+ "actor_loss": actor_loss,
779
+ "actor_nll": nll_objective,
780
+ "temperature": temperature,
781
+ "entropy": -log_probs.mean(),
782
+ "log_probs": log_probs,
783
+ "actions_mse": ((actions - batch["actions"]) ** 2).sum(dim=-1).mean(),
784
+ "dataset_rewards": batch["rewards"],
785
+ "mc_returns": batch.get("mc_returns", None),
786
+ }
787
+
788
+ return actor_loss, info
789
+
790
+ def sac_critic_loss_fn(self, batch: at.CalQlBatch):
791
+ """classes that inherit this class can change this function"""
792
+ batch_size = batch["rewards"].shape[0]
793
+ next_actions, next_actions_log_probs = self._compute_next_actions(batch)
794
+ # (batch_size, ) for sac, (batch_size, cql_n_actions) for cql
795
+
796
+ # Evaluate next Qs for all ensemble members (cheap because we're only doing the forward pass)
797
+ with torch.no_grad():
798
+ self.target_critics.eval()
799
+ target_next_qs = self.target_critics.forward(
800
+ batch["encoded_next_observations"],
801
+ next_actions,
802
+ ) # (critic_ensemble_size, batch_size, cql_n_actions)
803
+ self.target_critics.train()
804
+
805
+ # Subsample if requested
806
+ if self.config.critic_subsample_size is not None:
807
+ subsample_idcs = torch.randint(
808
+ 0,
809
+ self.config.critic_ensemble_size,
810
+ (self.config.critic_ensemble_size,),
811
+ device=target_next_qs.device,
812
+ )
813
+ target_next_qs = target_next_qs[subsample_idcs]
814
+
815
+ # Minimum Q across (subsampled) ensemble members
816
+ target_next_min_q = target_next_qs.min(dim=0)[0]
817
+ assert target_next_min_q.shape == next_actions_log_probs.shape
818
+ # (batch_size,) for sac, (batch_size, cql_n_actions) for cql
819
+
820
+ target_next_min_q = self._process_target_next_qs(
821
+ target_next_min_q,
822
+ next_actions_log_probs,
823
+ )
824
+
825
+ target_q = (
826
+ batch["rewards"] + self.config.discount * batch["masks"] * target_next_min_q
827
+ )
828
+ assert target_q.shape == (batch_size,)
829
+
830
+ predicted_qs = self.critics.forward(
831
+ batch["encoded_observations"], batch["actions"]
832
+ )
833
+ assert predicted_qs.shape == (self.config.critic_ensemble_size, batch_size)
834
+
835
+ target_qs = target_q.unsqueeze(0).expand(self.config.critic_ensemble_size, -1)
836
+ assert predicted_qs.shape == target_qs.shape
837
+ critic_loss = torch.mean((predicted_qs - target_qs) ** 2)
838
+
839
+ info = {
840
+ "td_err": critic_loss,
841
+ "online_q": torch.mean(predicted_qs),
842
+ "target_q": torch.mean(target_qs),
843
+ }
844
+
845
+ if self.config.goal_conditioned:
846
+ num_negatives = int(
847
+ self.config.gc_kwargs["negative_proportion"] * batch_size
848
+ )
849
+ info["negative_qs"] = torch.mean(predicted_qs, dim=-1)[
850
+ :num_negatives
851
+ ].mean()
852
+ info["positive_qs"] = torch.mean(predicted_qs, dim=-1)[
853
+ num_negatives:
854
+ ].mean()
855
+
856
+ return critic_loss, info
857
+
858
+ def _process_target_next_qs(self, target_next_qs, next_actions_log_probs):
859
+ """add cql_max_target_backup option"""
860
+
861
+ if self.config.cql_max_target_backup:
862
+ max_target_indices = torch.argmax(target_next_qs, dim=-1, keepdim=True)
863
+ target_next_qs = torch.gather(
864
+ target_next_qs, -1, max_target_indices
865
+ ).squeeze(-1)
866
+ next_actions_log_probs = torch.gather(
867
+ next_actions_log_probs, -1, max_target_indices
868
+ ).squeeze(-1)
869
+
870
+ target_next_qs = self.sac_process_target_next_qs(
871
+ target_next_qs,
872
+ next_actions_log_probs,
873
+ )
874
+
875
+ return target_next_qs
876
+
877
+ def sac_process_target_next_qs(self, target_next_qs, next_actions_log_probs):
878
+ """classes that inherit this class can add to this function
879
+ e.g. CQL will add the cql_max_target_backup option
880
+ """
881
+ if self.config.backup_entropy:
882
+ temperature = self.temperature.forward()
883
+ target_next_qs = target_next_qs - temperature * next_actions_log_probs
884
+
885
+ return target_next_qs
886
+
887
+ def critic_loss_fn(self, batch: at.CalQlBatch):
888
+ """add CQL loss on top of SAC loss"""
889
+ if self.config.use_td_loss:
890
+ td_loss, td_loss_info = self.sac_critic_loss_fn(batch)
891
+ else:
892
+ td_loss, td_loss_info = 0.0, {}
893
+
894
+ cql_q_diff, cql_intermediate_results = self._get_cql_q_diff(batch)
895
+
896
+ """auto tune cql alpha"""
897
+ if self.config.cql_autotune_alpha:
898
+ raise NotImplementedError
899
+ # alpha = self.forward_cql_alpha_lagrange()
900
+ # cql_loss = (cql_q_diff - self.config["cql_target_action_gap"]).mean()
901
+ else:
902
+ alpha = self.config.cql_alpha
903
+ cql_loss = torch.clip(
904
+ cql_q_diff, self.config.cql_clip_diff_min, self.config.cql_clip_diff_max
905
+ ).mean()
906
+
907
+ critic_loss = td_loss + alpha * cql_loss
908
+
909
+ info = {
910
+ **td_loss_info,
911
+ "critic_loss": critic_loss,
912
+ "td_err": td_loss,
913
+ "cql_loss": cql_loss,
914
+ "cql_alpha": alpha,
915
+ "cql_diff": cql_q_diff.mean(),
916
+ **cql_intermediate_results,
917
+ }
918
+
919
+ return critic_loss, info
920
+
921
+ def _get_cql_q_diff(self, batch: at.CalQlBatch):
922
+ """
923
+ most of the CQL loss logic is here
924
+ It is needed for both critic_loss_fn and cql_alpha_loss_fn
925
+ """
926
+ batch_size = batch["rewards"].shape[0]
927
+
928
+ q_pred = self.critics.forward(batch["encoded_observations"], batch["actions"])
929
+ # HACK: shape changed from jax implementation
930
+ assert q_pred.shape == (self.config.critic_ensemble_size, batch_size)
931
+
932
+ """sample random actions"""
933
+ action_dim = batch["actions"].shape[-1]
934
+ if self.config.cql_action_sample_method == "uniform":
935
+ cql_random_actions = (
936
+ torch.rand(
937
+ (batch_size, self.config.cql_n_actions, action_dim),
938
+ device=batch["actions"].device,
939
+ )
940
+ * 2.0
941
+ - 1.0
942
+ )
943
+ elif self.config.cql_action_sample_method == "normal":
944
+ cql_random_actions = torch.randn(
945
+ (batch_size, self.config.cql_n_actions, action_dim),
946
+ device=batch["actions"].device,
947
+ )
948
+ else:
949
+ raise NotImplementedError
950
+
951
+ cql_current_actions, cql_current_log_pis = self.forward_policy_and_sample(
952
+ batch["encoded_observations"],
953
+ repeat=self.config.cql_n_actions,
954
+ )
955
+ assert cql_current_log_pis.shape == (batch_size, self.config.cql_n_actions)
956
+
957
+ cql_next_actions, cql_next_log_pis = self.forward_policy_and_sample(
958
+ batch["encoded_next_observations"],
959
+ repeat=self.config.cql_n_actions,
960
+ )
961
+
962
+ all_sampled_actions = torch.cat(
963
+ [
964
+ cql_random_actions,
965
+ cql_current_actions,
966
+ cql_next_actions,
967
+ ],
968
+ dim=1,
969
+ )
970
+
971
+ """q values of randomly sampled actions"""
972
+ cql_q_samples = self.critics.forward(
973
+ batch["encoded_observations"], all_sampled_actions
974
+ )
975
+ # HACK: shape changed from jax implementation
976
+ assert cql_q_samples.shape == (
977
+ self.config.critic_ensemble_size,
978
+ batch_size,
979
+ self.config.cql_n_actions * 3,
980
+ )
981
+
982
+ if self.config.critic_subsample_size is not None:
983
+ subsample_idcs = torch.randint(
984
+ 0,
985
+ self.config.critic_ensemble_size,
986
+ (self.config.critic_ensemble_size,),
987
+ device=cql_q_samples.device,
988
+ )
989
+ cql_q_samples = cql_q_samples[subsample_idcs]
990
+
991
+ """Cal-QL"""
992
+ if self.config.use_calql:
993
+ # HACK: check shape of mc_returns
994
+ mc_lower_bound = (
995
+ batch["mc_returns"]
996
+ .reshape(-1, 1)
997
+ .repeat(1, self.config.cql_n_actions * 2)
998
+ )
999
+ assert mc_lower_bound.shape == (
1000
+ batch_size,
1001
+ self.config.cql_n_actions * 2,
1002
+ )
1003
+
1004
+ cql_q_pi = cql_q_samples[:, :, self.config.cql_n_actions :]
1005
+ num_vals = cql_q_pi.numel()
1006
+ calql_bound_rate = torch.sum((cql_q_pi < mc_lower_bound).float()) / num_vals
1007
+ cql_q_pi = torch.maximum(cql_q_pi, mc_lower_bound)
1008
+ cql_q_samples = torch.cat(
1009
+ [
1010
+ cql_q_samples[:, :, : self.config.cql_n_actions],
1011
+ cql_q_pi,
1012
+ ],
1013
+ dim=-1,
1014
+ )
1015
+
1016
+ if self.config.cql_importance_sample:
1017
+ random_density = torch.log(
1018
+ torch.tensor(0.5**action_dim, device=cql_q_samples.device)
1019
+ )
1020
+
1021
+ importance_prob = torch.cat(
1022
+ [
1023
+ random_density.expand(batch_size, self.config.cql_n_actions),
1024
+ cql_current_log_pis,
1025
+ cql_next_log_pis,
1026
+ ],
1027
+ dim=1,
1028
+ )
1029
+ # HACK: check dim
1030
+ cql_q_samples = cql_q_samples - importance_prob.unsqueeze(0)
1031
+ else:
1032
+ cql_q_samples = torch.cat([cql_q_samples, q_pred.unsqueeze(-1)], dim=-1)
1033
+
1034
+ cql_q_samples -= (
1035
+ torch.log(
1036
+ torch.tensor(
1037
+ cql_q_samples.shape[-1],
1038
+ dtype=torch.float,
1039
+ device=cql_q_samples.device,
1040
+ )
1041
+ )
1042
+ * self.config.cql_temp
1043
+ )
1044
+ # HACK: shape diff from jax implementation
1045
+ assert cql_q_samples.shape == (
1046
+ self.config.critic_ensemble_size,
1047
+ batch_size,
1048
+ 3 * self.config.cql_n_actions + 1,
1049
+ )
1050
+
1051
+ """log sum exp of the ood actions"""
1052
+ cql_ood_values = (
1053
+ torch.logsumexp(cql_q_samples / self.config.cql_temp, dim=-1)
1054
+ * self.config.cql_temp
1055
+ )
1056
+ assert cql_ood_values.shape == (self.config.critic_ensemble_size, batch_size)
1057
+
1058
+ cql_q_diff = cql_ood_values - q_pred
1059
+ info = {
1060
+ "cql_ood_values": cql_ood_values.mean(),
1061
+ }
1062
+ if self.config.use_calql:
1063
+ info["calql_bound_rate"] = calql_bound_rate
1064
+
1065
+ return cql_q_diff, info
1066
+
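A tiny numeric illustration (not from the file) of the Cal-QL step above: Q-values of policy-sampled actions are clipped from below by the Monte-Carlo return before the logsumexp, and calql_bound_rate records how often that clipping fires.

import torch

cql_q_pi = torch.tensor([[-5.0, 0.2, 3.0]])              # critic values for sampled actions
mc_lower_bound = torch.full_like(cql_q_pi, 1.0)          # observed discounted return
bound_rate = (cql_q_pi < mc_lower_bound).float().mean()  # 2/3 of the values get lifted
cql_q_pi = torch.maximum(cql_q_pi, mc_lower_bound)       # -> [[1.0, 1.0, 3.0]]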
1067
+ @staticmethod
1068
+ def make_optimizer(
1069
+ params: torch.nn.Module,
1070
+ learning_rate: float = 3e-4,
1071
+ warmup_steps: int = 0,
1072
+ cosine_decay_steps: Optional[int] = None,
1073
+ weight_decay: Optional[float] = None,
1074
+ return_lr_schedule: bool = True,
1075
+ ) -> Union[Optimizer, Tuple[Optimizer, LambdaLR]]:
1076
+ optimizer: Optimizer
1077
+ if weight_decay is not None:
1078
+ optimizer = AdamW(
1079
+ params=params,
1080
+ lr=learning_rate,
1081
+ weight_decay=weight_decay,
1082
+ )
1083
+ else:
1084
+ optimizer = Adam(params=params, lr=learning_rate)
1085
+
1086
+ def _lr_lambda(step: int) -> float:
1087
+ if warmup_steps > 0 and step < warmup_steps:
1088
+ return step / warmup_steps
1089
+
1090
+ if cosine_decay_steps is not None:
1091
+ decay_step = step - warmup_steps
1092
+ if decay_step < 0:
1093
+ return 0.0
1094
+ if decay_step >= cosine_decay_steps:
1095
+ return 0.0
1096
+ progress = decay_step / cosine_decay_steps
1097
+ return 0.5 * (1.0 + math.cos(math.pi * progress))
1098
+
1099
+ return 1.0
1100
+
1101
+ scheduler = LambdaLR(optimizer, lr_lambda=_lr_lambda)
1102
+
1103
+ if return_lr_schedule:
1104
+ return optimizer, scheduler
1105
+ else:
1106
+ return optimizer
1107
+
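make_optimizer pairs a plain Adam/AdamW with a LambdaLR whose multiplier ramps linearly from 0 to 1 over warmup_steps and then, with cosine_decay_steps left at None as prepare_optimizers does, stays constant. A quick sketch (run once the module is imported; the linear head is a stand-in for a real parameter group):

import torch.nn as nn

head = nn.Linear(8, 8)
opt, sched = CalQL.make_optimizer(head.parameters(), learning_rate=1e-4, warmup_steps=100)

for step in range(200):
    opt.step()    # gradients omitted in this sketch
    sched.step()  # linear ramp to 1e-4 over the first 100 steps, then constant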
1108
+ def prepare_optimizers(self):
1109
+ actor_optimizer, actor_scheduler = self.make_optimizer(
1110
+ self.policy.parameters(), **self.config.actor_optimizer_kwargs
1111
+ )
1112
+ critic_optimizer, critic_scheduler = self.make_optimizer(
1113
+ self.critics.parameters(), **self.config.critic_optimizer_kwargs
1114
+ )
1115
+ temperature_optimizer, temperature_scheduler = self.make_optimizer(
1116
+ self.temperature.parameters(), **self.config.temperature_optimizer_kwargs
1117
+ )
1118
+
1119
+ return (
1120
+ actor_optimizer,
1121
+ actor_scheduler,
1122
+ critic_optimizer,
1123
+ critic_scheduler,
1124
+ temperature_optimizer,
1125
+ temperature_scheduler,
1126
+ )
1127
+
1128
+ def forward(self, batch: at.CalQlBatch):
1129
+ temperature_loss, temperature_loss_info = self.temperature_loss_fn(batch)
1130
+ policy_loss, policy_loss_info = self.policy_loss_fn(batch)
1131
+ critic_loss, critic_loss_info = self.critic_loss_fn(batch)
1132
+
1133
+ return (
1134
+ temperature_loss,
1135
+ policy_loss,
1136
+ critic_loss,
1137
+ {
1138
+ **temperature_loss_info,
1139
+ **policy_loss_info,
1140
+ **critic_loss_info,
1141
+ },
1142
+ )
1143
+
1144
+ @jaxtyped(typechecker=typechecker)
1145
+ def get_q_values(
1146
+ self,
1147
+ encoded_observations: Float[
1148
+ torch.Tensor, "batch {self.config.obs_encoded_dim}"
1149
+ ],
1150
+ noise_actions: Float[torch.Tensor, "batch num_actions action_dim"],
1151
+ ) -> Float[torch.Tensor, "batch num_actions"]:
1152
+ # (num_backbones, batch, *num_actions)
1153
+ q_values = self.target_critics.forward(encoded_observations, noise_actions)
1154
+ q_values = q_values.min(dim=0)[0]
1155
+ return q_values
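Finally, a sketch of the value-query path itself: get_q_values scores a set of candidate action chunks with the target critics and keeps the ensemble minimum, which the inference side can then use to pick the highest-value candidate. Batch and candidate counts below are placeholders.

import torch

agent = CalQL(CalQlConfig()).eval()

obs = torch.randn(2, agent.config.obs_encoded_dim)               # encoded observations
candidates = torch.rand(2, 5, agent.config.action_dim) * 2 - 1   # 5 candidate actions per state

with torch.no_grad():
    q = agent.get_q_values(obs, candidates)  # (batch, num_actions) -> (2, 5)
best = q.argmax(dim=-1)                      # index of the highest-value candidate per state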