Upload folder using huggingface_hub
- .gitattributes +1 -0
- array_typing.py +80 -0
- config.json +283 -0
- configuration_hume.py +528 -0
- fast_visuo_expert.py +321 -0
- model.safetensors +3 -0
- modeling_hume.py +1909 -0
- paligemma_with_expert.py +444 -0
- special_tokens_map.json +33 -0
- tokenizer.json +3 -0
- tokenizer_config.json +1772 -0
- value_query.py +1155 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
array_typing.py
ADDED
@@ -0,0 +1,80 @@
from typing import Annotated, TypeAlias, TypedDict

import numpy as np
from jaxtyping import Float
from numpy.typing import NDArray
from torch import Tensor


class InferConfig(TypedDict):
    """Configuration for inference."""

    replan_steps: int
    s2_replan_steps: int
    s2_candidates_num: int
    noise_temp_lower_bound: float
    noise_temp_upper_bound: float
    time_temp_lower_bound: float
    time_temp_upper_bound: float
    post_process_action: bool
    device: str


ImageArray: TypeAlias = Annotated[NDArray[np.uint8], "Shape[B, H, W, C]"]
StateArray: TypeAlias = Annotated[
    NDArray[np.float32], "Shape[B, state_horizon, state_dim]"
]
ActionArray: TypeAlias = Annotated[NDArray[np.float32], "Shape[B, action_dim]"]

InferBatchObs = TypedDict(
    "BatchObs",
    {
        "observation.images.image": ImageArray,
        "observation.images.wrist_image": ImageArray,
        "observation.state": StateArray,
        "task": list[str],
    },
)


class InferOutput(TypedDict):
    noise_action: Float[Tensor, "batch s2_chunksize padded_action_dim"]
    s1_action: Float[Tensor, "batch s1_chunksize unpadded_action_dim"]
    s2_action: Float[Tensor, "batch s2_chunksize unpadded_action_dim"]


class CalQlBatch(TypedDict):
    encoded_observations: Float[Tensor, "batch encoded_dim"]
    encoded_next_observations: Float[Tensor, "batch encoded_dim"]
    actions: Float[Tensor, "batch action_dim"]
    rewards: Float[Tensor, " batch"]
    mc_returns: Float[Tensor, " batch"]
    masks: Float[Tensor, " batch"]


class EnvArgs(TypedDict):
    """Environment arguments."""

    # necessary args
    num_trials_per_task: int
    num_steps_wait: int
    task_suite_name: str
    seed: int
    ckpt_path: str | None
    eval_name: str | None


class Request(TypedDict):
    """Message received by the environment."""

    frame_type: str  # "init" | "action"
    env_args: EnvArgs | None
    action: ActionArray | None


class Response(TypedDict):
    """Message sent by the environment."""

    status: str  # "new_episode" | "eval_finished" | "in_episode"
    success_rate: float | None
    observation: InferBatchObs | None
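For reference, a minimal sketch of how a client might fill these typed containers (the array shapes, history length, and task string below are illustrative assumptions, not values taken from this commit):

# Illustrative usage of the typed containers above (shapes are assumptions).
import numpy as np

from array_typing import InferBatchObs, Request

batch_obs: InferBatchObs = {
    "observation.images.image": np.zeros((1, 256, 256, 3), dtype=np.uint8),
    "observation.images.wrist_image": np.zeros((1, 256, 256, 3), dtype=np.uint8),
    "observation.state": np.zeros((1, 4, 8), dtype=np.float32),  # (B, state_horizon, state_dim)
    "task": ["pick up the cube"],
}

# An "init" frame carries env_args only; "action" frames carry the action array.
init_request: Request = {"frame_type": "init", "env_args": None, "action": None}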
config.json
ADDED
@@ -0,0 +1,283 @@
{
  "n_obs_steps": 1,
  "normalization_mapping": {
    "VISUAL": "IDENTITY",
    "STATE": "MEAN_STD",
    "ACTION": "MEAN_STD"
  },
  "input_features": {
    "observation.images.image": {
      "type": "VISUAL",
      "shape": [3, 256, 256]
    },
    "observation.images.wrist_image": {
      "type": "VISUAL",
      "shape": [3, 256, 256]
    },
    "observation.state": {
      "type": "STATE",
      "shape": [8]
    }
  },
  "output_features": {
    "action": {
      "type": "ACTION",
      "shape": [7]
    }
  },
  "device": "cpu",
  "use_amp": false,
  "type": "hume",
  "s1_chunk_size": 8,
  "s2_chunk_size": 16,
  "n_action_steps": 16,
  "max_state_dim": 32,
  "max_action_dim": 32,
  "resize_imgs_with_padding": [224, 224],
  "empty_cameras": 0,
  "adapt_to_pi_aloha": false,
  "use_delta_joint_actions_aloha": false,
  "tokenizer_max_length": 48,
  "proj_width": 1024,
  "num_steps": 10,
  "use_cache": true,
  "attention_implementation": "eager",
  "freeze_vision_encoder": true,
  "train_expert_only": false,
  "train_state_proj": true,
  "optimizer_lr": 5e-05,
  "optimizer_betas": [0.9, 0.95],
  "optimizer_eps": 1e-08,
  "optimizer_weight_decay": 1e-10,
  "scheduler_warmup_steps": 1000,
  "scheduler_decay_steps": 1600000,
  "scheduler_decay_lr": 2.5e-06,
  "freeze_s2": true,
  "s1_his_state_size": 4,
  "cache_s2_actions": false,
  "theta2": 1.0,
  "theta1": 1.0,
  "noise_slides_eps": 0.0,
  "noise_slides_alp": 0.0,
  "s1_proj_width": 512,
  "freeze_s1_vision_encoder": false,
  "s1_num_steps": 10,
  "num_pos": 3,
  "discount": 0.98,
  "actor_lr": 1e-05,
  "critic_lr": 1e-05,
  "temp_lr": 2e-05,
  "qf_lr": 0.0003,
  "next_obs_offset": 1,
  "vqh_chunk_size": 1,
  "paligemma_config": {
    "bos_token_id": 2,
    "eos_token_id": 1,
    "hidden_size": 2048,
    "ignore_index": -100,
    "image_token_index": 257152,
    "model_type": "paligemma",
    "pad_token_id": 0,
    "projection_dim": 2048,
    "text_config": {
      "hidden_activation": "gelu_pytorch_tanh",
      "hidden_size": 2048,
      "intermediate_size": 16384,
      "model_type": "gemma",
      "num_attention_heads": 8,
      "num_hidden_layers": 18,
      "num_image_tokens": 256,
      "num_key_value_heads": 1,
      "torch_dtype": "float32",
      "vocab_size": 257152
    },
    "torch_dtype": "float32",
    "transformers_version": "4.48.1",
    "vision_config": {
      "hidden_size": 1152,
      "intermediate_size": 4304,
      "model_type": "siglip_vision_model",
      "num_attention_heads": 16,
      "num_hidden_layers": 27,
      "num_image_tokens": 256,
      "patch_size": 14,
      "projection_dim": 2048,
      "projector_hidden_act": "gelu_fast",
      "vision_use_head": false
    },
    "vocab_size": 257152
  },
  "gemma_expert_config": {
    "attention_bias": false,
    "attention_dropout": 0.0,
    "bos_token_id": 2,
    "eos_token_id": 1,
    "head_dim": 256,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_activation": "gelu_pytorch_tanh",
    "hidden_size": 1024,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "max_position_embeddings": 8192,
    "model_type": "gemma",
    "num_attention_heads": 8,
    "num_hidden_layers": 18,
    "num_key_value_heads": 1,
    "pad_token_id": 0,
    "rms_norm_eps": 1e-06,
    "rope_theta": 10000.0,
    "torch_dtype": "float32",
    "transformers_version": "4.48.1",
    "use_cache": true,
    "vocab_size": 257152
  },
  "s1_dino_config": {
    "return_dict": true,
    "output_hidden_states": false,
    "output_attentions": false,
    "torchscript": false,
    "torch_dtype": "float32",
    "use_bfloat16": false,
    "tf_legacy_loss": false,
    "pruned_heads": {},
    "tie_word_embeddings": true,
    "chunk_size_feed_forward": 0,
    "is_encoder_decoder": false,
    "is_decoder": false,
    "cross_attention_hidden_size": null,
    "add_cross_attention": false,
    "tie_encoder_decoder": false,
    "max_length": 20,
    "min_length": 0,
    "do_sample": false,
    "early_stopping": false,
    "num_beams": 1,
    "num_beam_groups": 1,
    "diversity_penalty": 0.0,
    "temperature": 1.0,
    "top_k": 50,
    "top_p": 1.0,
    "typical_p": 1.0,
    "repetition_penalty": 1.0,
    "length_penalty": 1.0,
    "no_repeat_ngram_size": 0,
    "encoder_no_repeat_ngram_size": 0,
    "bad_words_ids": null,
    "num_return_sequences": 1,
    "output_scores": false,
    "return_dict_in_generate": false,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "remove_invalid_values": false,
    "exponential_decay_length_penalty": null,
    "suppress_tokens": null,
    "begin_suppress_tokens": null,
    "architectures": ["Dinov2Model"],
    "finetuning_task": null,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "tokenizer_class": null,
    "prefix": null,
    "bos_token_id": null,
    "pad_token_id": null,
    "eos_token_id": null,
    "sep_token_id": null,
    "decoder_start_token_id": null,
    "task_specific_params": null,
    "problem_type": null,
    "_name_or_path": "../pretrained/dinov2-small",
    "_attn_implementation_autoset": false,
    "transformers_version": "4.52.0.dev0",
    "model_type": "dinov2",
    "hidden_size": 384,
    "num_hidden_layers": 12,
    "num_attention_heads": 6,
    "mlp_ratio": 4,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.0,
    "attention_probs_dropout_prob": 0.0,
    "initializer_range": 0.02,
    "layer_norm_eps": 1e-06,
    "image_size": 518,
    "patch_size": 14,
    "num_channels": 3,
    "qkv_bias": true,
    "layerscale_value": 1.0,
    "drop_path_rate": 0.0,
    "use_swiglu_ffn": false,
    "stage_names": ["stem", "stage1", "stage2", "stage3", "stage4", "stage5", "stage6", "stage7", "stage8", "stage9", "stage10", "stage11", "stage12"],
    "apply_layernorm": true,
    "reshape_hidden_states": true,
    "use_mask_token": true,
    "out_features": ["stage12"],
    "out_indices": [12]
  },
  "s1_gemma_expert_config": {
    "attention_bias": false,
    "attention_dropout": 0.0,
    "bos_token_id": 2,
    "eos_token_id": 1,
    "head_dim": 128,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_activation": "gelu_pytorch_tanh",
    "hidden_size": 512,
    "initializer_range": 0.02,
    "intermediate_size": 2048,
    "max_position_embeddings": 8192,
    "model_type": "gemma",
    "num_attention_heads": 8,
    "num_hidden_layers": 13,
    "num_key_value_heads": 1,
    "pad_token_id": 0,
    "rms_norm_eps": 1e-06,
    "rope_theta": 10000.0,
    "torch_dtype": "float32",
    "transformers_version": "4.48.1",
    "use_cache": true,
    "vocab_size": 257152
  }
}
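A small sketch of reading this file directly with the standard library, for example to check feature shapes and chunk sizes before loading the weights (the local path is an assumption):

# Assumes config.json has been downloaded into the working directory.
import json

with open("config.json") as f:
    cfg = json.load(f)

print(cfg["type"])                                           # "hume"
print(cfg["input_features"]["observation.state"]["shape"])   # [8]
print(cfg["s1_chunk_size"], cfg["s2_chunk_size"])            # 8 16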
configuration_hume.py
ADDED
@@ -0,0 +1,528 @@
from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
    CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature


@PreTrainedConfig.register_subclass("hume")
@dataclass
class HumeConfig(PreTrainedConfig):
    # Input / output structure.
    type: str = "hume"
    n_obs_steps: int = 1
    s1_chunk_size: int = 10
    s2_chunk_size: int = 50
    n_action_steps: int = 50

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Shorter state and action vectors will be padded
    max_state_dim: int = 32
    max_action_dim: int = 32

    # Image preprocessing
    resize_imgs_with_padding: tuple[int, int] = (224, 224)

    # Add empty images. Used by pi0_aloha_sim which adds the empty
    # left and right wrist cameras in addition to the top camera.
    empty_cameras: int = 0

    # Converts the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi_aloha: bool = False

    # Converts joint dimensions to deltas with respect to the current state before passing to the model.
    # Gripper dimensions will remain in absolute values.
    use_delta_joint_actions_aloha: bool = False

    # Tokenizer
    tokenizer_max_length: int = 48

    # Projector
    proj_width: int = 1024

    # Decoding
    num_steps: int = 10

    # Attention utils
    use_cache: bool = True
    attention_implementation: str = "eager"  # or fa2, flex

    # Finetuning settings
    freeze_vision_encoder: bool = True
    train_expert_only: bool = False
    train_state_proj: bool = True

    # Training presets
    optimizer_lr: float = 2.5e-5
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-10

    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    # Additional attributes for s1 / s2
    # freeze system
    freeze_s2: bool = False
    s1_his_state_size: int = 1
    cache_s2_actions: bool = False

    # denoise ratio
    theta2: float = 1.0
    theta1: float = 1.0
    noise_slides_eps: float = 0.0
    noise_slides_alp: float = 0.0

    # projector
    s1_proj_width: int = 512  # NOTE: consistent with the s1_gemma_expert_config
    freeze_s1_vision_encoder: bool = False

    # decoding
    s1_num_steps: int = 10

    # vqh
    num_pos: int = 3
    discount: float = 0.98
    actor_lr: float = 1e-4  # actor learning rate
    critic_lr: float = 3e-4
    temp_lr: float = 3e-4
    qf_lr: float = 3e-4  # Critics learning rate
    next_obs_offset: int = 10  # should be equal to vqh_chunk_size
    vqh_chunk_size: int = 10

    paligemma_config: dict = field(
        default_factory=lambda: {
            "bos_token_id": 2,
            "eos_token_id": 1,
            "hidden_size": 2048,
            "ignore_index": -100,
            "image_token_index": 257152,
            "model_type": "paligemma",
            "pad_token_id": 0,
            "projection_dim": 2048,
            "text_config": {
                "hidden_activation": "gelu_pytorch_tanh",
                "hidden_size": 2048,
                "intermediate_size": 16384,
                "model_type": "gemma",
                "num_attention_heads": 8,
                "num_hidden_layers": 18,
                "num_image_tokens": 256,
                "num_key_value_heads": 1,
                "torch_dtype": "float32",
                "vocab_size": 257152,
            },
            "torch_dtype": "float32",
            "transformers_version": "4.48.1",
            "vision_config": {
                "hidden_size": 1152,
                "intermediate_size": 4304,
                "model_type": "siglip_vision_model",
                "num_attention_heads": 16,
                "num_hidden_layers": 27,
                "num_image_tokens": 256,
                "patch_size": 14,
                "projection_dim": 2048,
                "projector_hidden_act": "gelu_fast",
                "vision_use_head": False,
            },
            "vocab_size": 257152,
        }
    )

    gemma_expert_config: dict = field(
        default_factory=lambda: {
            "attention_bias": False,
            "attention_dropout": 0.0,
            "bos_token_id": 2,
            "eos_token_id": 1,
            "head_dim": 256,
            "hidden_act": "gelu_pytorch_tanh",
            "hidden_activation": "gelu_pytorch_tanh",
            "hidden_size": 1024,
            "initializer_range": 0.02,
            "intermediate_size": 4096,
            "max_position_embeddings": 8192,
            "model_type": "gemma",
            "num_attention_heads": 8,
            "num_hidden_layers": 18,
            "num_key_value_heads": 1,
            "pad_token_id": 0,
            "rms_norm_eps": 1e-06,
            "rope_theta": 10000.0,
            "torch_dtype": "float32",
            "transformers_version": "4.48.1",
            "use_cache": True,
            "vocab_size": 257152,
        }
    )

    # TODO: Add EMA

    # system2 configurations
    s1_dino_config: dict = field(
        default_factory=lambda: {
            "model_type": "dinov2",
            "attention_probs_dropout_prob": 0.0,
            "drop_path_rate": 0.0,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.0,
            "hidden_size": 384,
            "image_size": 518,
            "initializer_range": 0.02,
            "layer_norm_eps": 1e-06,
            "layerscale_value": 1.0,
            "mlp_ratio": 4,
            "num_attention_heads": 6,
            "num_channels": 3,
            "num_hidden_layers": 12,
            "patch_size": 14,
            "qkv_bias": True,
            "torch_dtype": "float32",
            "use_swiglu_ffn": False,
        }
    )

    s1_gemma_expert_config: dict = field(
        default_factory=lambda: {
            "attention_bias": False,
            "attention_dropout": 0.0,
            "bos_token_id": 2,
            "eos_token_id": 1,
            "head_dim": 128,
            "hidden_act": "gelu_pytorch_tanh",
            "hidden_activation": "gelu_pytorch_tanh",
            "hidden_size": 512,
            "initializer_range": 0.02,
            "intermediate_size": 2048,
            "max_position_embeddings": 8192,
            "model_type": "gemma",
            "num_attention_heads": 8,
            "num_hidden_layers": 13,
            "num_key_value_heads": 1,
            "pad_token_id": 0,
            "rms_norm_eps": 1e-06,
            "rope_theta": 10000.0,
            "torch_dtype": "float32",
            "transformers_version": "4.48.1",
            "use_cache": True,
            "vocab_size": 257152,
        }
    )

    def __post_init__(self):
        super().__post_init__()

        """Input validation (not exhaustive)."""
        if self.n_action_steps > self.s2_chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.s2_chunk_size} for `chunk_size`."
            )
        if self.n_obs_steps != 1:
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
            )

        if self.use_delta_joint_actions_aloha:
            raise NotImplementedError(
                "`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
            )

    def validate_features(self) -> None:
        # TODO: implement value error
        # if not self.image_features and not self.env_state_feature:
        #     raise ValueError("You must provide at least one image or the environment state among the inputs.")

        for i in range(self.empty_cameras):
            key = f"observation.images.empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, 480, 640),
            )
            self.input_features[key] = empty_camera

    def get_optimizer_preset(self) -> dict[str, AdamWConfig]:
        qf_optimizer = AdamWConfig(
            lr=self.qf_lr,
            weight_decay=0,
            grad_clip_norm=10,
        )
        actor_optimizer = AdamWConfig(
            lr=self.actor_lr,
            weight_decay=0,
            grad_clip_norm=10,
        )

        trunk_optimizer = AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
        )

        optimizer_dict = dict(
            qf_optimizer=qf_optimizer,
            actor_optimizer=actor_optimizer,
            trunk_optimizer=trunk_optimizer,
        )

        return optimizer_dict

    def get_scheduler_preset(self):
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> None:
        return None

    @property
    def action_delta_indices(self) -> list:
        return list(range(self.s2_chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        return None

    @property
    def slide(self) -> int:
        return self.s2_chunk_size // self.s1_chunk_size

    @property
    def s1_action_steps(self) -> int:
        return self.s1_chunk_size

    @property
    def s2_action_steps(self) -> int:
        return self.s2_chunk_size


from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
    CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature


@PreTrainedConfig.register_subclass("system2")
@dataclass
class System2Config(PreTrainedConfig):
    # Input / output structure.
    num_pos: int = 3
    discount: float = 0.98
    n_obs_steps: int = 1
    chunk_size: int = 50
    n_action_steps: int = 50
    next_obs_offset: int = 1
    s1_his_state_size: int = 1

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Shorter state and action vectors will be padded
    max_state_dim: int = 32
    max_action_dim: int = 32

    # Image preprocessing
    resize_imgs_with_padding: tuple[int, int] = (224, 224)

    # Add empty images. Used by pi0_aloha_sim which adds the empty
    # left and right wrist cameras in addition to the top camera.
    empty_cameras: int = 0

    # Converts the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi_aloha: bool = False

    # Converts joint dimensions to deltas with respect to the current state before passing to the model.
    # Gripper dimensions will remain in absolute values.
    use_delta_joint_actions_aloha: bool = False

    # Tokenizer
    tokenizer_max_length: int = 48

    # Projector
    proj_width: int = 1024

    # Decoding
    num_steps: int = 10

    # Attention utils
    use_cache: bool = True
    attention_implementation: str = "eager"  # or fa2, flex

    # Finetuning settings
    freeze_vision_encoder: bool = True
    train_expert_only: bool = False
    train_state_proj: bool = True

    # Training presets
    optimizer_lr: float = 2.5e-5
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-10

    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    paligemma_config: dict = field(
        default_factory=lambda: {
            "bos_token_id": 2,
            "eos_token_id": 1,
            "hidden_size": 2048,
            "ignore_index": -100,
            "image_token_index": 257152,
            "model_type": "paligemma",
            "pad_token_id": 0,
            "projection_dim": 2048,
            "text_config": {
                "hidden_activation": "gelu_pytorch_tanh",
                "hidden_size": 2048,
                "intermediate_size": 16384,
                "model_type": "gemma",
                "num_attention_heads": 8,
                "num_hidden_layers": 18,
                "num_image_tokens": 256,
                "num_key_value_heads": 1,
                "torch_dtype": "float32",
                "vocab_size": 257152,
            },
            "torch_dtype": "float32",
            "transformers_version": "4.48.1",
            "vision_config": {
                "hidden_size": 1152,
                "intermediate_size": 4304,
                "model_type": "siglip_vision_model",
                "num_attention_heads": 16,
                "num_hidden_layers": 27,
                "num_image_tokens": 256,
                "patch_size": 14,
                "projection_dim": 2048,
                "projector_hidden_act": "gelu_fast",
                "vision_use_head": False,
            },
            "vocab_size": 257152,
        }
    )

    gemma_expert_config: dict = field(
        default_factory=lambda: {
            "attention_bias": False,
            "attention_dropout": 0.0,
            "bos_token_id": 2,
            "eos_token_id": 1,
            "head_dim": 256,
            "hidden_act": "gelu_pytorch_tanh",
            "hidden_activation": "gelu_pytorch_tanh",
            "hidden_size": 1024,
            "initializer_range": 0.02,
            "intermediate_size": 4096,
            "max_position_embeddings": 8192,
            "model_type": "gemma",
            "num_attention_heads": 8,
            "num_hidden_layers": 18,
            "num_key_value_heads": 1,
            "pad_token_id": 0,
            "rms_norm_eps": 1e-06,
            "rope_theta": 10000.0,
            "torch_dtype": "float32",
            "transformers_version": "4.48.1",
            "use_cache": True,
            "vocab_size": 257152,
        }
    )

    # TODO: Add EMA

    def __post_init__(self):
        super().__post_init__()

        """Input validation (not exhaustive)."""
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.n_obs_steps != 1:
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
            )

        if self.use_delta_joint_actions_aloha:
            raise NotImplementedError(
                "`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
            )

    def validate_features(self) -> None:
        # TODO: implement value error
        # if not self.image_features and not self.env_state_feature:
        #     raise ValueError("You must provide at least one image or the environment state among the inputs.")

        for i in range(self.empty_cameras):
            key = f"observation.images.empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, 480, 640),
            )
            self.input_features[key] = empty_camera

    def get_optimizer_preset(self) -> AdamWConfig:
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self):
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> None:
        return None

    @property
    def action_delta_indices(self) -> list:
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        return None

    @property
    def slide(self) -> int:
        return 1

    @property
    def s1_action_steps(self) -> int:
        return 1
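A short sketch of how the two chunk sizes interact through the `slide` property; this assumes a working lerobot install so that `HumeConfig` can be instantiated, and the numbers mirror the shipped config.json rather than the dataclass defaults:

# Sketch only: requires lerobot plus this repo's configuration_hume.py on the path.
from configuration_hume import HumeConfig

cfg = HumeConfig(s1_chunk_size=8, s2_chunk_size=16, n_action_steps=16)

# System 1 refines 8-step sub-chunks inside each 16-step System 2 chunk,
# so slide (= s2_chunk_size // s1_chunk_size) is 2 here.
print(cfg.slide)            # 2
print(cfg.s2_action_steps)  # 16

# Asking for more action steps than one s2 chunk is rejected in __post_init__:
# HumeConfig(s2_chunk_size=16, n_action_steps=32)  # -> ValueError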
fast_visuo_expert.py
ADDED
@@ -0,0 +1,321 @@
from typing import Optional

import torch
from torch import nn
from transformers import (
    AutoConfig,
    Dinov2Model,
    GemmaForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.models.auto import CONFIG_MAPPING


def apply_rope(x, positions, max_wavelength=10_000):
    """
    Applies RoPE positions [B, L] to x [B, L, H, D].
    """
    d_half = x.shape[-1] // 2
    device = x.device
    dtype = x.dtype
    x = x.to(torch.float32)

    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
        d_half, dtype=torch.float32, device=device
    )
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
        torch.float32
    )

    radians = radians[..., None, :]

    sin = torch.sin(radians)  # .to(dtype=dtype)
    cos = torch.cos(radians)  # .to(dtype=dtype)

    x1, x2 = x.split(d_half, dim=-1)
    res = torch.empty_like(x)
    res[..., :d_half] = x1 * cos - x2 * sin
    res[..., d_half:] = x2 * cos + x1 * sin

    return res.to(dtype)


class FastVisuoExpertConfig(PretrainedConfig):
    model_type = "FastVisuoExpertModel"
    sub_configs = {"dino_config": AutoConfig, "gemma_expert_config": AutoConfig}

    def __init__(
        self,
        dino_config: dict | None = None,
        gemma_expert_config: dict | None = None,
        freeze_vision_encoder: bool = True,
        attention_implementation: str = "eager",
        **kwargs,
    ):
        self.freeze_vision_encoder = freeze_vision_encoder
        self.attention_implementation = attention_implementation

        if dino_config is None:
            self.dino_config = CONFIG_MAPPING["dinov2"](
                transformers_version="4.48.1",
                model_type="dinov2",
                attention_probs_dropout_prob=0.0,
                drop_path_rate=0.0,
                hidden_act="gelu",
                hidden_dropout_prob=0.0,
                hidden_size=384,
                image_size=518,
                initializer_range=0.02,
                layer_norm_eps=1e-06,
                layerscale_value=1.0,
                mlp_ratio=4,
                num_attention_heads=6,
                num_channels=3,
                num_hidden_layers=12,
                patch_size=14,
                qkv_bias=True,
                torch_dtype="float32",
                use_swiglu_ffn=False,
            )
        elif isinstance(dino_config, dict):
            if "model_type" not in dino_config:
                dino_config["model_type"] = "dinov2"
            cfg_cls = CONFIG_MAPPING[dino_config["model_type"]]
            self.dino_config = cfg_cls(**dino_config)

        if gemma_expert_config is None:
            self.gemma_expert_config = CONFIG_MAPPING["gemma"](
                attention_bias=False,
                attention_dropout=0.0,
                bos_token_id=2,
                eos_token_id=1,
                head_dim=256,
                hidden_act="gelu_pytorch_tanh",
                hidden_activation="gelu_pytorch_tanh",
                hidden_size=1024,
                initializer_range=0.02,
                intermediate_size=4096,
                max_position_embeddings=8192,
                model_type="gemma",
                num_attention_heads=8,
                num_hidden_layers=8,
                num_key_value_heads=1,
                pad_token_id=0,
                rms_norm_eps=1e-06,
                rope_theta=10000.0,
                torch_dtype="float32",
                transformers_version="4.48.1",
                use_cache=True,
                vocab_size=257152,
            )
        elif isinstance(gemma_expert_config, dict):
            if "model_type" not in gemma_expert_config:
                gemma_expert_config["model_type"] = "gemma"
            cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
            self.gemma_expert_config = cfg_cls(**gemma_expert_config)

        super().__init__(**kwargs)

    def __post_init__(self):
        super().__post_init__()
        if self.attention_implementation not in ["eager", "fa2", "flex"]:
            raise ValueError(
                f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
            )


class FastVisuoExpertModel(PreTrainedModel):
    config_class = FastVisuoExpertConfig

    def __init__(self, config: FastVisuoExpertConfig):
        super().__init__(config=config)
        self.config = config
        self.vision_tower = Dinov2Model(config=config.dino_config)
        self.gemma_expert = GemmaForCausalLM(
            config=config.gemma_expert_config
        )  # GemmaModel
        self.multi_modal_projector = nn.Linear(
            config.dino_config.hidden_size, config.gemma_expert_config.hidden_size
        )
        self.gemma_expert.model.embed_tokens = None
        self.gemma_expert.lm_head = None

        self.to_bfloat16_like_physical_intelligence()
        self.set_requires_grad()

    def set_requires_grad(self):
        if self.config.freeze_vision_encoder:
            self.vision_tower.eval()
            for params in self.vision_tower.parameters():
                params.requires_grad = False

    def train(self, mode: bool = True):
        super().train(mode)

        if self.config.freeze_vision_encoder:
            self.vision_tower.eval()

    def to_bfloat16_like_physical_intelligence(self):
        self.vision_tower = self.vision_tower.to(dtype=torch.bfloat16)
        params_to_change_dtype = [
            "language_model.model.layers",
            "gemma_expert.model.layers",
            "vision_tower",
            "multi_modal",
        ]
        for name, param in self.named_parameters():
            if any(selector in name for selector in params_to_change_dtype):
                param.data = param.data.to(dtype=torch.bfloat16)

    def embed_image(self, image: torch.Tensor):
        selected_image_feature = self.vision_tower(image).last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = image_features / (
            self.config.gemma_expert_config.hidden_size**0.5
        )
        return image_features

    # TODO: break down this huge forward into modules or functions
    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        # RMSNorm
        head_dim = self.gemma_expert.config.head_dim

        hidden_states = inputs_embeds
        batch_size = hidden_states.shape[0]
        for layer in self.gemma_expert.model.layers[
            : self.gemma_expert.config.num_hidden_layers
        ]:
            # normalizer = torch.tensor(model.config.hidden_size**0.5, dtype=hidden_states.dtype)
            # hidden_states = hidden_states * normalizer
            hidden_states = layer.input_layernorm(hidden_states)
            input_shape = hidden_states.shape[:-1]
            hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)

            # self attention
            hidden_states = hidden_states.to(dtype=torch.bfloat16)
            query_states = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
            key_states = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
            value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)

            query_states = apply_rope(query_states, position_ids)
            key_states = apply_rope(key_states, position_ids)

            attention_interface = self.get_attention_interface()
            att_output = attention_interface(
                attention_mask,
                batch_size,
                head_dim,
                query_states,
                key_states,
                value_states,
            )

            if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
                att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)

            out_emb = layer.self_attn.o_proj(att_output)

            # first residual
            out_emb += hidden_states
            after_first_residual = out_emb.clone()
            out_emb = layer.post_attention_layernorm(out_emb)
            out_emb = layer.mlp(out_emb)
            # second residual
            out_emb += after_first_residual
            hidden_states = out_emb

        # final norm
        hidden_states = self.gemma_expert.model.norm(hidden_states)

        return hidden_states

    def get_attention_interface(self):
        if self.config.attention_implementation == "fa2":
            attention_interface = self.flash_attention_forward
        else:
            attention_interface = self.eager_attention_forward
        return attention_interface

    def eager_attention_forward(
        self,
        attention_mask,
        batch_size,
        head_dim,
        query_states,
        key_states,
        value_states,
    ):
        num_att_heads = self.config.gemma_expert_config.num_attention_heads
        num_key_value_heads = self.config.gemma_expert_config.num_key_value_heads
        num_key_value_groups = num_att_heads // num_key_value_heads

        # query_states: batch_size, sequence_length, num_att_head, head_dim
        # key_states: batch_size, sequence_length, num_key_value_head, head_dim
        # value_states: batch_size, sequence_length, num_key_value_head, head_dim
        sequence_length = key_states.shape[1]

        key_states = key_states[:, :, :, None, :].expand(
            batch_size,
            sequence_length,
            num_key_value_heads,
            num_key_value_groups,
            head_dim,
        )
        key_states = key_states.reshape(
            batch_size,
            sequence_length,
            num_key_value_heads * num_key_value_groups,
            head_dim,
        )

        value_states = value_states[:, :, :, None, :].expand(
            batch_size,
            sequence_length,
            num_key_value_heads,
            num_key_value_groups,
            head_dim,
        )
        value_states = value_states.reshape(
            batch_size,
            sequence_length,
            num_key_value_heads * num_key_value_groups,
            head_dim,
        )

        # Attention here is upcasted to float32 to match the original eager implementation.
        query_states = query_states.to(dtype=torch.float32)
        key_states = key_states.to(dtype=torch.float32)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)

        att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        att_weights *= head_dim**-0.5
        big_neg = -2.3819763e38  # See gemma/modules.py

        masked_att_weights = torch.where(
            attention_mask[:, None, :, :], att_weights, big_neg
        )

        probs = nn.functional.softmax(masked_att_weights, dim=-1)
        probs = probs.to(dtype=value_states.dtype)

        # probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
        # value_states: batch_size, sequence_length, num_att_heads, head_dim

        att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))

        att_output = att_output.permute(0, 2, 1, 3)
        # we use -1 because sequence length can change
        att_output = att_output.reshape(
            batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
        )

        return att_output
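As a quick sanity check, `apply_rope` preserves the input's shape and dtype; a tiny sketch with made-up sizes (assumes fast_visuo_expert.py is importable):

import torch

from fast_visuo_expert import apply_rope

B, L, H, D = 2, 7, 8, 64                  # illustrative sizes; D must be even
x = torch.randn(B, L, H, D)
positions = torch.arange(L).expand(B, L)  # integer positions, shape [B, L]

out = apply_rope(x, positions)
print(out.shape, out.dtype)               # torch.Size([2, 7, 8, 64]) torch.float32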
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:58518ffa9166223aafee65453b3ed9cd4abde834949a61bbf78c3a5e99c1fe42
size 9038608596
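The entry above is only a Git LFS pointer; the ~9 GB weight file itself is fetched by LFS or the Hub client. A small sketch of inspecting the downloaded checkpoint with the safetensors library without loading it fully into memory (the local file path is an assumption):

# Peek at tensor names and shapes in the downloaded checkpoint.
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for name in list(f.keys())[:5]:
        print(name, f.get_slice(name).get_shape())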
modeling_hume.py
ADDED
@@ -0,0 +1,1909 @@
import collections
import math
from argparse import Namespace
from collections import deque

import array_typing as at
import numpy as np
import torch
import torch.nn.functional as F  # noqa: N812
import torchvision.transforms.functional as TF
from beartype import beartype as typechecker
from configuration_hume import HumeConfig, System2Config
from fast_visuo_expert import FastVisuoExpertConfig, FastVisuoExpertModel
from jaxtyping import Bool, Float, Int64, jaxtyped
from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.utils import get_safe_dtype
from paligemma_with_expert import (
    PaliGemmaWithExpertConfig,
    PaliGemmaWithExpertModel,
)
from torch import Tensor, nn
from transformers import AutoTokenizer
from value_query import (
    CalQL,
    CalQlConfig,
    VQHBackbone,
    VQHBackboneConfig,
)


def create_sinusoidal_pos_embedding(
    time: torch.tensor,
    dimension: int,
    min_period: float,
    max_period: float,
    device="cpu",
) -> Tensor:
    """Computes sine-cosine positional embedding vectors for scalar positions."""
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Compute the outer product
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
    return pos_emb


def sample_beta(alpha, beta, bsize, device):
    gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
    gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
    return gamma1 / (gamma1 + gamma2)


def make_att_2d_masks(pad_masks, att_masks):
    """Copied from big_vision.

    Tokens can attend to valid inputs tokens which have a cumulative mask_ar
    smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
    setup several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      input_mask: bool[B, N] true if its part of the input, false if padding.
      mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
        it and 0 where it shares the same attention mask as the previous token.
    """
    if att_masks.ndim != 2:
        raise ValueError(att_masks.ndim)
    if pad_masks.ndim != 2:
        raise ValueError(pad_masks.ndim)

    cumsum = torch.cumsum(att_masks, dim=1)
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    att_2d_masks = att_2d_masks & pad_2d_masks
    return att_2d_masks
95 |
+
|
96 |
+
|
97 |
+
def resize_with_pad(img, width, height, pad_value=-1):
|
98 |
+
# assume no-op when width height fits already
|
99 |
+
if img.ndim != 4:
|
100 |
+
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
101 |
+
|
102 |
+
cur_height, cur_width = img.shape[2:]
|
103 |
+
|
104 |
+
ratio = max(cur_width / width, cur_height / height)
|
105 |
+
resized_height = int(cur_height / ratio)
|
106 |
+
resized_width = int(cur_width / ratio)
|
107 |
+
resized_img = F.interpolate(
|
108 |
+
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
109 |
+
)
|
110 |
+
|
111 |
+
pad_height = max(0, int(height - resized_height))
|
112 |
+
pad_width = max(0, int(width - resized_width))
|
113 |
+
|
114 |
+
# pad on left and top of image
|
115 |
+
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
116 |
+
return padded_img
|
117 |
+
|
118 |
+
|
119 |
+
def pad_vector(vector, new_dim):
|
120 |
+
"""Can be (batch_size x sequence_length x features_dimension)
|
121 |
+
or (batch_size x features_dimension)
|
122 |
+
"""
|
123 |
+
if vector.shape[-1] == new_dim:
|
124 |
+
return vector
|
125 |
+
shape = list(vector.shape)
|
126 |
+
current_dim = shape[-1]
|
127 |
+
shape[-1] = new_dim
|
128 |
+
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
|
129 |
+
new_vector[..., :current_dim] = vector
|
130 |
+
return new_vector
|
131 |
+
|
132 |
+
|
133 |
+
def normalize(x, min_val, max_val):
|
134 |
+
return (x - min_val) / (max_val - min_val)
|
135 |
+
|
136 |
+
|
137 |
+
def unnormalize(x, min_val, max_val):
|
138 |
+
return x * (max_val - min_val) + min_val
|
139 |
+
|
140 |
+
|
141 |
+
def safe_arcsin(value):
|
142 |
+
# This ensures that the input stays within
|
143 |
+
# [−1,1] to avoid invalid values for arcsin
|
144 |
+
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
145 |
+
|
146 |
+
|
147 |
+
def aloha_gripper_to_angular(value):
|
148 |
+
# Aloha transforms the gripper positions into a linear space. The following code
|
149 |
+
# reverses this transformation to be consistent with pi0 which is pretrained in
|
150 |
+
# angular space.
|
151 |
+
#
|
152 |
+
# These values are coming from the Aloha code:
|
153 |
+
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
154 |
+
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
155 |
+
|
156 |
+
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
157 |
+
def linear_to_radian(linear_position, arm_length, horn_radius):
|
158 |
+
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (
|
159 |
+
2 * horn_radius * linear_position
|
160 |
+
)
|
161 |
+
return safe_arcsin(value)
|
162 |
+
|
163 |
+
# The constants are taken from the Interbotix code.
|
164 |
+
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
165 |
+
|
166 |
+
# Normalize to [0, 1].
|
167 |
+
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
168 |
+
return normalize(value, min_val=0.4, max_val=1.5)
|
169 |
+
|
170 |
+
|
171 |
+
def aloha_gripper_from_angular(value):
|
172 |
+
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
173 |
+
# Note that the units are still angular but the range is different.
|
174 |
+
|
175 |
+
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
176 |
+
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
177 |
+
|
178 |
+
# These values are coming from the Aloha code:
|
179 |
+
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
180 |
+
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
181 |
+
|
182 |
+
|
183 |
+
def aloha_gripper_from_angular_inv(value):
|
184 |
+
# Directly inverts the gripper_from_angular function.
|
185 |
+
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
186 |
+
return normalize(value, min_val=0.4, max_val=1.5)
|
187 |
+
|
188 |
+
|
189 |
+
class HumePolicy(PreTrainedPolicy):
|
190 |
+
"""Wrapper class around System2 model to train and run inference within LeRobot."""
|
191 |
+
|
192 |
+
config_class = HumeConfig
|
193 |
+
name = "hume"
|
194 |
+
|
195 |
+
def __init__(
|
196 |
+
self,
|
197 |
+
config: HumeConfig,
|
198 |
+
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
199 |
+
):
|
200 |
+
super().__init__(config)
|
201 |
+
config.validate_features()
|
202 |
+
self.config = config
|
203 |
+
|
204 |
+
# TODO: input / output features / normalizer for mutiple datasets
|
205 |
+
self.normalize_inputs = Normalize(
|
206 |
+
config.input_features, config.normalization_mapping, dataset_stats
|
207 |
+
)
|
208 |
+
self.normalize_targets = Normalize(
|
209 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
210 |
+
)
|
211 |
+
self.unnormalize_outputs = Unnormalize(
|
212 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
213 |
+
)
|
214 |
+
|
215 |
+
self.language_tokenizer = None
|
216 |
+
self.s2_model = System2(config)
|
217 |
+
self.s1_model = FastVisuoMatching(config)
|
218 |
+
self.value_query_head = ValueQueryHead(
|
219 |
+
paligemma_with_expert=self.s2_model.paligemma_with_expert, config=config
|
220 |
+
)
|
221 |
+
self.reset()
|
222 |
+
|
223 |
+
self.set_requires_grad()
|
224 |
+
|
225 |
+
def set_requires_grad(self):
|
226 |
+
if self.config.freeze_s2:
|
227 |
+
self.s2_model.eval()
|
228 |
+
for params in self.s2_model.parameters():
|
229 |
+
params.requires_grad = False
|
230 |
+
|
231 |
+
def train(self, mode: bool = True):
|
232 |
+
super().train(mode)
|
233 |
+
if self.config.freeze_s2:
|
234 |
+
self.s2_model.eval()
|
235 |
+
|
236 |
+
def reset(self):
|
237 |
+
"""This should be called whenever the environment is reset."""
|
238 |
+
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
239 |
+
self.s2_action_cache = {}
|
240 |
+
|
241 |
+
def get_trunk_params(self) -> dict:
|
242 |
+
exclude_params = set()
|
243 |
+
exclude_modules = [
|
244 |
+
self.value_query_head.calql.policy,
|
245 |
+
self.value_query_head.calql.critics,
|
246 |
+
self.value_query_head.calql.temperature,
|
247 |
+
]
|
248 |
+
|
249 |
+
for module in exclude_modules:
|
250 |
+
for param in module.parameters():
|
251 |
+
exclude_params.add(id(param))
|
252 |
+
|
253 |
+
return [param for param in self.parameters() if id(param) not in exclude_params]
|
254 |
+
|
255 |
+
def get_optim_params(self) -> dict:
|
256 |
+
return self.parameters()
|
257 |
+
|
258 |
+
def get_actor_optim_params(self) -> dict:
|
259 |
+
return self.value_query_head.calql.policy.parameters()
|
260 |
+
|
261 |
+
def get_critics_optim_params(self) -> dict:
|
262 |
+
return self.value_query_head.calql.critics.parameters()
|
263 |
+
|
264 |
+
def get_temperature_optim_params(self) -> dict:
|
265 |
+
return self.value_query_head.calql.temperature.parameters()
|
266 |
+
|
267 |
+
def init_infer(self, infer_cfg: at.InferConfig):
|
268 |
+
self.infer_cfg = Namespace(**infer_cfg)
|
269 |
+
self.action_plan = collections.deque()
|
270 |
+
self.history_state = collections.deque(maxlen=self.config.s1_his_state_size)
|
271 |
+
self.infer_step = 0
|
272 |
+
self.outputs = {}
|
273 |
+
self.q_value_cache = []
|
274 |
+
self.action_cache = []
|
275 |
+
|
276 |
+
self.reset()
|
277 |
+
print("Initializing inference with config:", infer_cfg)
|
278 |
+
|
279 |
+
return True
|
280 |
+
|
281 |
+
def infer(self, observation: at.InferBatchObs) -> at.ActionArray:
|
282 |
+
# prcoess observation
|
283 |
+
# from np.array -> torch.tensor -> add batch, change shape
|
284 |
+
if not self.history_state:
|
285 |
+
self.history_state.extend(
|
286 |
+
np.expand_dims(observation["observation.state"], 1)
|
287 |
+
.repeat(self.config.s1_his_state_size, axis=1)
|
288 |
+
.transpose(1, 0, 2)
|
289 |
+
)
|
290 |
+
else:
|
291 |
+
self.history_state.append(observation["observation.state"])
|
292 |
+
|
293 |
+
observation["observation.state"] = np.asarray(self.history_state).transpose(
|
294 |
+
1, 0, 2
|
295 |
+
)
|
296 |
+
|
297 |
+
observation: dict[str, torch.tensor | list[str]] = {
|
298 |
+
**{
|
299 |
+
k: torch.tensor(v / 255) # b, h, w ,c
|
300 |
+
.permute(0, 3, 1, 2) # b, c, h, w
|
301 |
+
.to(self.infer_cfg.device)
|
302 |
+
.float()
|
303 |
+
for k, v in observation.items()
|
304 |
+
if k
|
305 |
+
in {
|
306 |
+
"observation.images.image",
|
307 |
+
"observation.images.wrist_image",
|
308 |
+
"observation.images.image_0",
|
309 |
+
}
|
310 |
+
},
|
311 |
+
**{k: v for k, v in observation.items() if k in {"task"}}, # len = batch
|
312 |
+
**{
|
313 |
+
k: torch.tensor(v)
|
314 |
+
.to(self.infer_cfg.device)
|
315 |
+
.float() # b, state_horizon, state_dim
|
316 |
+
for k, v in observation.items()
|
317 |
+
if k in {"observation.state"}
|
318 |
+
},
|
319 |
+
}
|
320 |
+
batch_size = len(observation["task"])
|
321 |
+
|
322 |
+
if not self.action_plan:
|
323 |
+
# Finished executing previous action chunk -- compute new chunk
|
324 |
+
# Prepare observations dict
|
325 |
+
# infer the action
|
326 |
+
if self.infer_step % self.infer_cfg.s2_replan_steps == 0:
|
327 |
+
self.outputs = {} # infer with s1 or s2
|
328 |
+
stamp = (
|
329 |
+
torch.tensor(
|
330 |
+
[
|
331 |
+
self.infer_step
|
332 |
+
% self.infer_cfg.s2_replan_steps
|
333 |
+
/ self.config.s2_chunk_size
|
334 |
+
]
|
335 |
+
)
|
336 |
+
.expand(batch_size)
|
337 |
+
.to(self.infer_cfg.device)
|
338 |
+
.float()
|
339 |
+
)
|
340 |
+
self.outputs = self.select_action(
|
341 |
+
observation,
|
342 |
+
self.outputs,
|
343 |
+
stamp,
|
344 |
+
s2_candidates_num=self.infer_cfg.s2_candidates_num,
|
345 |
+
noise_temp_bounds=(
|
346 |
+
self.infer_cfg.noise_temp_lower_bound,
|
347 |
+
self.infer_cfg.noise_temp_upper_bound,
|
348 |
+
),
|
349 |
+
time_temp_bounds=(
|
350 |
+
self.infer_cfg.time_temp_lower_bound,
|
351 |
+
self.infer_cfg.time_temp_upper_bound,
|
352 |
+
),
|
353 |
+
)
|
354 |
+
action_chunk = self.outputs["s1_action"].cpu().numpy()
|
355 |
+
|
356 |
+
if self.infer_cfg.post_process_action:
|
357 |
+
action_chunk[..., -1] = 2 * (1 - action_chunk[..., -1]) - 1
|
358 |
+
|
359 |
+
# convert action chunk shape to (replan_steps, batch, action_dim)
|
360 |
+
action_chunk = action_chunk.transpose(1, 0, 2)
|
361 |
+
assert (
|
362 |
+
len(action_chunk) >= self.infer_cfg.replan_steps
|
363 |
+
), f"We want to replan every {self.infer_cfg.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
|
364 |
+
self.action_plan.extend(action_chunk[: self.infer_cfg.replan_steps])
|
365 |
+
|
366 |
+
self.infer_step += 1
|
367 |
+
action = self.action_plan.popleft()
|
368 |
+
return np.asarray(action)
|
369 |
+
|
370 |
+
@torch.no_grad
|
371 |
+
@jaxtyped(typechecker=typechecker)
|
372 |
+
def select_action(
|
373 |
+
self,
|
374 |
+
batch: at.InferBatchObs,
|
375 |
+
outputs: at.InferOutput = {},
|
376 |
+
stamp: Float[Tensor, " batch"] | None = None,
|
377 |
+
s2_candidates_num: int = 5,
|
378 |
+
noise_temp_bounds: tuple = (1.0, 1.0),
|
379 |
+
time_temp_bounds: tuple = (1.0, 1.0),
|
380 |
+
) -> at.InferOutput:
|
381 |
+
"""Select a single action given environment observations.
|
382 |
+
|
383 |
+
This method wraps `select_actions` in order to return one action at a time for execution in the
|
384 |
+
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
385 |
+
queue is empty.
|
386 |
+
"""
|
387 |
+
self.eval()
|
388 |
+
|
389 |
+
if self.config.adapt_to_pi_aloha:
|
390 |
+
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
391 |
+
|
392 |
+
batch = self.normalize_inputs(batch)
|
393 |
+
|
394 |
+
# querying the policy.
|
395 |
+
images, img_masks = self.prepare_images(batch)
|
396 |
+
state = self.prepare_state(batch)
|
397 |
+
lang_tokens, lang_masks = self.prepare_language(batch)
|
398 |
+
|
399 |
+
original_action_dim = self.config.action_feature.shape[0]
|
400 |
+
|
401 |
+
if "noise_action" not in outputs:
|
402 |
+
noise_actions = [] # [(Batch, Chunksize, Action dim),]
|
403 |
+
for i in range(s2_candidates_num):
|
404 |
+
noise_actions.append(
|
405 |
+
self.s2_model.sample_actions(
|
406 |
+
images,
|
407 |
+
img_masks,
|
408 |
+
lang_tokens,
|
409 |
+
lang_masks,
|
410 |
+
state[:, -1, :], # s2 not supported history state yet
|
411 |
+
time_temp=(i / s2_candidates_num)
|
412 |
+
* (time_temp_bounds[1] - time_temp_bounds[0])
|
413 |
+
+ time_temp_bounds[0],
|
414 |
+
noise_temp=(i / s2_candidates_num)
|
415 |
+
* (noise_temp_bounds[1] - noise_temp_bounds[0])
|
416 |
+
+ noise_temp_bounds[0],
|
417 |
+
)
|
418 |
+
)
|
419 |
+
noise_actions = torch.stack(noise_actions, dim=1)
|
420 |
+
# (Batch, s2_candidates_num, Chunksize, Actiondim)
|
421 |
+
batch_size = noise_actions.shape[0]
|
422 |
+
batch_idx = torch.arange(batch_size, device=noise_actions.device)
|
423 |
+
|
424 |
+
noise_actions_wo_pad = noise_actions[
|
425 |
+
:, :, : self.config.vqh_chunk_size, :original_action_dim
|
426 |
+
]
|
427 |
+
action_index, q_values = self.value_query_head.select_q_actions(
|
428 |
+
images, img_masks, lang_tokens, lang_masks, noise_actions_wo_pad
|
429 |
+
)
|
430 |
+
self.q_value_cache.append(q_values.squeeze())
|
431 |
+
unnormalized_noise_actions = self.unnormalize_outputs(
|
432 |
+
{"action": noise_actions_wo_pad}
|
433 |
+
)["action"]
|
434 |
+
self.action_cache.append(unnormalized_noise_actions.squeeze())
|
435 |
+
selected_noise_action = noise_actions[batch_idx, action_index]
|
436 |
+
|
437 |
+
outputs = {"noise_action": selected_noise_action}
|
438 |
+
|
439 |
+
noise_action: Float[Tensor, "batch s2_chunksize action_dim"] = outputs[
|
440 |
+
"noise_action"
|
441 |
+
]
|
442 |
+
idcs = (stamp * self.config.s2_chunk_size).long().unsqueeze(1) + torch.arange(
|
443 |
+
self.config.s1_chunk_size, device=noise_action.device
|
444 |
+
)
|
445 |
+
batch_idcs = torch.arange(
|
446 |
+
noise_action.shape[0], device=noise_action.device
|
447 |
+
).unsqueeze(1)
|
448 |
+
noise_action_slides = noise_action[batch_idcs, idcs]
|
449 |
+
s1_actions = self.s1_model.sample_actions(
|
450 |
+
images, img_masks, state, noise_action_slides, stamp=stamp
|
451 |
+
)
|
452 |
+
|
453 |
+
# Unpad actions
|
454 |
+
actions = s1_actions[:, :, :original_action_dim]
|
455 |
+
actions = self.unnormalize_outputs({"action": actions})["action"]
|
456 |
+
|
457 |
+
if self.config.adapt_to_pi_aloha:
|
458 |
+
actions = self._pi_aloha_encode_actions_inv(actions)
|
459 |
+
|
460 |
+
outputs["s1_action"] = actions
|
461 |
+
|
462 |
+
return outputs
|
463 |
+
|
464 |
+
def post_normalize(self, batch):
|
465 |
+
"""additional keys {obervation.x}.s1 are merged in to the batch,
|
466 |
+
so we need to normalize these keys
|
467 |
+
"""
|
468 |
+
merge_keys = filter(lambda k: k.endswith(".s1"), batch.keys())
|
469 |
+
for k in merge_keys:
|
470 |
+
_k = k.replace(".s1", "")
|
471 |
+
batch[k] = self.normalize_inputs({_k: batch[k]})[_k]
|
472 |
+
return batch
|
473 |
+
|
474 |
+
def get_noise_action_slides(self, action: Tensor, stamp: Tensor) -> Tensor:
|
475 |
+
"""Augment the action with the previous actions in the queue."""
|
476 |
+
# idcs = (torch.rand_like(stamp) * (self.config.s2_chunk_size - self.config.s1_chunk_size)).long()
|
477 |
+
idcs = (
|
478 |
+
(
|
479 |
+
self.config.noise_slides_alp * torch.rand_like(stamp)
|
480 |
+
- self.config.noise_slides_alp / 2
|
481 |
+
+ stamp
|
482 |
+
)
|
483 |
+
* self.config.s2_chunk_size
|
484 |
+
).long()
|
485 |
+
idcs = torch.clamp(idcs, 0, action.shape[1] - self.config.s1_chunk_size)
|
486 |
+
idcs = idcs + torch.arange(self.config.s1_chunk_size, device=action.device)
|
487 |
+
batch_idcs = torch.arange(action.shape[0], device=action.device).unsqueeze(1)
|
488 |
+
noise_action_slides = action[batch_idcs, idcs]
|
489 |
+
|
490 |
+
noise_action_slides += (
|
491 |
+
torch.randn_like(noise_action_slides) * self.config.noise_slides_eps
|
492 |
+
)
|
493 |
+
return noise_action_slides
|
494 |
+
|
495 |
+
def forward(
|
496 |
+
self, batch: dict[str, Tensor], noise=None, time=None
|
497 |
+
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, dict[str, Tensor]]:
|
498 |
+
"""Do a full training forward pass to compute the loss"""
|
499 |
+
if self.config.adapt_to_pi_aloha:
|
500 |
+
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
501 |
+
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
502 |
+
|
503 |
+
batch = self.normalize_inputs(batch)
|
504 |
+
batch = self.post_normalize(batch)
|
505 |
+
batch = self.normalize_targets(batch)
|
506 |
+
|
507 |
+
# prepare images
|
508 |
+
images, img_masks = self.prepare_images(batch)
|
509 |
+
state = self.prepare_state(batch)
|
510 |
+
lang_tokens, lang_masks = self.prepare_language(batch)
|
511 |
+
|
512 |
+
s1_images, s1_img_masks = self.prepare_images(
|
513 |
+
batch, map(lambda x: f"{x}.s1", self.config.image_features)
|
514 |
+
) # 0
|
515 |
+
s1_state = self.prepare_state(batch, f"{OBS_ROBOT}.s1")
|
516 |
+
|
517 |
+
# prepare actions
|
518 |
+
actions = self.prepare_action(batch)
|
519 |
+
actions_is_pad = batch.get("action_is_pad")
|
520 |
+
|
521 |
+
b, s, _ = actions.shape
|
522 |
+
device = actions.device
|
523 |
+
batch_idcs = torch.arange(b, device=device).unsqueeze(1)
|
524 |
+
stamp = batch["stamp"]
|
525 |
+
idcs = (stamp * self.config.s2_chunk_size).long() + torch.arange(
|
526 |
+
self.config.s1_chunk_size, device=device
|
527 |
+
)
|
528 |
+
s1_actions = actions[batch_idcs, idcs]
|
529 |
+
s1_actions_is_pad = (
|
530 |
+
None if actions_is_pad is None else actions_is_pad[batch_idcs, idcs]
|
531 |
+
)
|
532 |
+
|
533 |
+
# s2 forward pass
|
534 |
+
with torch.no_grad():
|
535 |
+
if self.config.cache_s2_actions:
|
536 |
+
is_noised = []
|
537 |
+
noise_actions = torch.zeros_like(actions)
|
538 |
+
for idx, s2_idx in enumerate(batch["s2_idx"]):
|
539 |
+
if s2_idx in self.s2_action_cache:
|
540 |
+
noise_actions[idx] = self.s2_action_cache[s2_idx]
|
541 |
+
is_noised.append(False)
|
542 |
+
else:
|
543 |
+
is_noised.append(True)
|
544 |
+
# noise batch
|
545 |
+
is_noised = torch.tensor(is_noised, device=batch["s2_idx"].device)
|
546 |
+
s2_actions_infered = self.s2_model.sample_actions(
|
547 |
+
[img[is_noised] for img in images],
|
548 |
+
[mask[is_noised] for mask in img_masks],
|
549 |
+
lang_tokens[is_noised],
|
550 |
+
lang_masks[is_noised],
|
551 |
+
state[is_noised],
|
552 |
+
)
|
553 |
+
noise_actions[is_noised] = s2_actions_infered
|
554 |
+
else:
|
555 |
+
noise_actions = self.s2_model.sample_actions(
|
556 |
+
images,
|
557 |
+
img_masks,
|
558 |
+
lang_tokens,
|
559 |
+
lang_masks,
|
560 |
+
state,
|
561 |
+
)
|
562 |
+
|
563 |
+
# vgps: embs[q] -> layers -> [q] -> mlp
|
564 |
+
# value query head features are end with vqh: xx.vqh
|
565 |
+
vqh_images, vqh_img_masks = self.prepare_images(
|
566 |
+
batch, map(lambda x: f"{x}.vqh", self.config.image_features)
|
567 |
+
) # 1
|
568 |
+
|
569 |
+
temperature_loss, policy_loss, critic_loss, log_dict = (
|
570 |
+
self.value_query_head.forward(
|
571 |
+
images,
|
572 |
+
img_masks,
|
573 |
+
lang_tokens,
|
574 |
+
lang_masks,
|
575 |
+
vqh_images,
|
576 |
+
vqh_img_masks,
|
577 |
+
batch["action"][:, : self.config.vqh_chunk_size, :],
|
578 |
+
batch["reward.vqh"],
|
579 |
+
batch["mc.vqh"],
|
580 |
+
batch["reward.vqh"].to(dtype=torch.float),
|
581 |
+
)
|
582 |
+
)
|
583 |
+
|
584 |
+
noise_action_slides = self.get_noise_action_slides(noise_actions, stamp)
|
585 |
+
s1_losses = self.s1_model.forward(
|
586 |
+
s1_images,
|
587 |
+
s1_img_masks,
|
588 |
+
s1_state,
|
589 |
+
s1_actions,
|
590 |
+
noise_action_slides,
|
591 |
+
time,
|
592 |
+
stamp=stamp.squeeze(),
|
593 |
+
)
|
594 |
+
|
595 |
+
total_loss, loss_dict = 0.0, {}
|
596 |
+
|
597 |
+
if s1_actions_is_pad is not None:
|
598 |
+
in_episode_bound = ~s1_actions_is_pad
|
599 |
+
s1_losses = s1_losses * in_episode_bound.unsqueeze(-1)
|
600 |
+
|
601 |
+
s1_losses = s1_losses[..., : self.config.max_action_dim]
|
602 |
+
s1_losses = s1_losses.mean()
|
603 |
+
|
604 |
+
loss_dict["s1_loss"] = s1_losses.item()
|
605 |
+
total_loss += s1_losses
|
606 |
+
|
607 |
+
# add ValueQueryHead log dict to loss_dict
|
608 |
+
# loss_dict = {**loss_dict, **log_dict}
|
609 |
+
loss_dict["entropy"] = log_dict["entropy"].item()
|
610 |
+
loss_dict["actions_mse"] = log_dict["actions_mse"].item()
|
611 |
+
loss_dict["td_err"] = log_dict["td_err"].item()
|
612 |
+
loss_dict["temperature"] = log_dict["temperature"].item()
|
613 |
+
loss_dict["cql_loss"] = log_dict["cql_loss"].item()
|
614 |
+
loss_dict["cql_alpha"] = log_dict["cql_alpha"]
|
615 |
+
loss_dict["cql_diff"] = log_dict["cql_diff"].item()
|
616 |
+
loss_dict["critic_loss"] = log_dict["critic_loss"].item()
|
617 |
+
loss_dict["cql_ood_values"] = log_dict["cql_ood_values"].item()
|
618 |
+
loss_dict["calql_bound_rate"] = log_dict["calql_bound_rate"].item()
|
619 |
+
loss_dict["online_q"] = log_dict["online_q"].item()
|
620 |
+
loss_dict["target_q"] = log_dict["target_q"].item()
|
621 |
+
loss_dict["positive_qs"] = log_dict["positive_qs"].item()
|
622 |
+
loss_dict["actor_loss"] = log_dict["actor_loss"].item()
|
623 |
+
|
624 |
+
return total_loss, temperature_loss, policy_loss, critic_loss, loss_dict
|
625 |
+
|
626 |
+
def prepare_images(self, batch, image_features=None):
|
627 |
+
"""Apply preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
628 |
+
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
|
629 |
+
"""
|
630 |
+
images = []
|
631 |
+
img_masks = []
|
632 |
+
|
633 |
+
image_features = image_features or self.config.image_features
|
634 |
+
present_img_keys = [key for key in image_features if key in batch]
|
635 |
+
missing_img_keys = [key for key in image_features if key not in batch]
|
636 |
+
|
637 |
+
if len(present_img_keys) == 0:
|
638 |
+
raise ValueError(
|
639 |
+
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
640 |
+
)
|
641 |
+
|
642 |
+
# Preprocess image features present in the batch
|
643 |
+
for key in present_img_keys:
|
644 |
+
img = batch[key]
|
645 |
+
|
646 |
+
if self.config.resize_imgs_with_padding is not None:
|
647 |
+
img = resize_with_pad(
|
648 |
+
img, *self.config.resize_imgs_with_padding, pad_value=0
|
649 |
+
)
|
650 |
+
|
651 |
+
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
652 |
+
img = img * 2.0 - 1.0
|
653 |
+
|
654 |
+
bsize = img.shape[0]
|
655 |
+
device = img.device
|
656 |
+
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
657 |
+
images.append(img)
|
658 |
+
img_masks.append(mask)
|
659 |
+
|
660 |
+
# Create image features not present in the batch
|
661 |
+
# as fully 0 padded images.
|
662 |
+
for num_empty_cameras in range(len(missing_img_keys)):
|
663 |
+
if num_empty_cameras >= self.config.empty_cameras:
|
664 |
+
break
|
665 |
+
img = torch.ones_like(img) * -1
|
666 |
+
mask = torch.zeros_like(mask)
|
667 |
+
images.append(img)
|
668 |
+
img_masks.append(mask)
|
669 |
+
|
670 |
+
return images, img_masks
|
671 |
+
|
672 |
+
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
673 |
+
"""Tokenize the text input"""
|
674 |
+
device = batch[OBS_ROBOT].device
|
675 |
+
tasks = batch["task"]
|
676 |
+
|
677 |
+
# PaliGemma prompt has to end with a new line
|
678 |
+
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
679 |
+
|
680 |
+
tokenized_prompt = self.language_tokenizer.__call__(
|
681 |
+
tasks,
|
682 |
+
padding="max_length",
|
683 |
+
padding_side="right",
|
684 |
+
max_length=self.config.tokenizer_max_length,
|
685 |
+
return_tensors="pt",
|
686 |
+
truncation=True,
|
687 |
+
)
|
688 |
+
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
689 |
+
lang_masks = tokenized_prompt["attention_mask"].to(
|
690 |
+
device=device, dtype=torch.bool
|
691 |
+
)
|
692 |
+
|
693 |
+
return lang_tokens, lang_masks
|
694 |
+
|
695 |
+
def _pi_aloha_decode_state(self, state):
|
696 |
+
# Flip the joints.
|
697 |
+
for motor_idx in [1, 2, 8, 9]:
|
698 |
+
state[:, motor_idx] *= -1
|
699 |
+
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
700 |
+
for motor_idx in [6, 13]:
|
701 |
+
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
702 |
+
return state
|
703 |
+
|
704 |
+
def _pi_aloha_encode_actions(self, actions):
|
705 |
+
# Flip the joints.
|
706 |
+
for motor_idx in [1, 2, 8, 9]:
|
707 |
+
actions[:, :, motor_idx] *= -1
|
708 |
+
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
709 |
+
for motor_idx in [6, 13]:
|
710 |
+
actions[:, :, motor_idx] = aloha_gripper_from_angular(
|
711 |
+
actions[:, :, motor_idx]
|
712 |
+
)
|
713 |
+
return actions
|
714 |
+
|
715 |
+
def _pi_aloha_encode_actions_inv(self, actions):
|
716 |
+
# Flip the joints again.
|
717 |
+
for motor_idx in [1, 2, 8, 9]:
|
718 |
+
actions[:, :, motor_idx] *= -1
|
719 |
+
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
720 |
+
for motor_idx in [6, 13]:
|
721 |
+
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
|
722 |
+
actions[:, :, motor_idx]
|
723 |
+
)
|
724 |
+
return actions
|
725 |
+
|
726 |
+
def prepare_state(self, batch, feature=None):
|
727 |
+
"""Pad state"""
|
728 |
+
feature = feature or OBS_ROBOT
|
729 |
+
state = pad_vector(batch[feature], self.config.max_state_dim)
|
730 |
+
return state
|
731 |
+
|
732 |
+
def prepare_action(self, batch):
|
733 |
+
"""Pad action"""
|
734 |
+
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
735 |
+
return actions
|
736 |
+
|
737 |
+
def _save_pretrained(self, save_directory) -> None:
|
738 |
+
super()._save_pretrained(save_directory)
|
739 |
+
print(f"Saving the language tokenizer to {save_directory} ...")
|
740 |
+
self.language_tokenizer.save_pretrained(save_directory)
|
741 |
+
|
742 |
+
import shutil
|
743 |
+
|
744 |
+
files = [
|
745 |
+
"src/hume/models/array_typing.py",
|
746 |
+
"src/hume/models/configuration_hume.py",
|
747 |
+
"src/hume/models/fast_visuo_expert.py",
|
748 |
+
"src/hume/models/modeling_hume.py",
|
749 |
+
"src/hume/models/paligemma_with_expert.py",
|
750 |
+
"src/hume/models/value_query.py",
|
751 |
+
]
|
752 |
+
try:
|
753 |
+
for file in files:
|
754 |
+
shutil.copy(file, save_directory)
|
755 |
+
except Exception:
|
756 |
+
print("Failed to copy files to save_directory")
|
757 |
+
|
758 |
+
@classmethod
|
759 |
+
def from_pretrained(
|
760 |
+
cls,
|
761 |
+
pretrained_name_or_path,
|
762 |
+
**kwargs,
|
763 |
+
):
|
764 |
+
policy = super().from_pretrained(pretrained_name_or_path, **kwargs)
|
765 |
+
print(f"Loading the language tokenizer from {pretrained_name_or_path} ...")
|
766 |
+
policy.language_tokenizer = AutoTokenizer.from_pretrained(
|
767 |
+
pretrained_name_or_path
|
768 |
+
)
|
769 |
+
return policy
|
770 |
+
|
771 |
+
|
772 |
+
class System2Policy(PreTrainedPolicy):
|
773 |
+
"""Wrapper class around System2FlowMatching model to train and run inference within LeRobot."""
|
774 |
+
|
775 |
+
config_class = System2Config
|
776 |
+
name = "system2"
|
777 |
+
|
778 |
+
def __init__(
|
779 |
+
self,
|
780 |
+
config: System2Config,
|
781 |
+
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
782 |
+
):
|
783 |
+
"""
|
784 |
+
Args:
|
785 |
+
config: Policy configuration class instance or None, in which case the default instantiation of
|
786 |
+
the configuration class is used.
|
787 |
+
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
788 |
+
that they will be passed with a call to `load_state_dict` before the policy is used.
|
789 |
+
"""
|
790 |
+
|
791 |
+
super().__init__(config)
|
792 |
+
config.validate_features()
|
793 |
+
self.config = config
|
794 |
+
|
795 |
+
# TODO: input / output features / normalizer for mutiple datasets
|
796 |
+
self.normalize_inputs = Normalize(
|
797 |
+
config.input_features, config.normalization_mapping, dataset_stats
|
798 |
+
)
|
799 |
+
self.normalize_targets = Normalize(
|
800 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
801 |
+
)
|
802 |
+
self.unnormalize_outputs = Unnormalize(
|
803 |
+
config.output_features, config.normalization_mapping, dataset_stats
|
804 |
+
)
|
805 |
+
|
806 |
+
self.language_tokenizer = None
|
807 |
+
self.model = System2(config)
|
808 |
+
|
809 |
+
self.reset()
|
810 |
+
|
811 |
+
def reset(self):
|
812 |
+
"""This should be called whenever the environment is reset."""
|
813 |
+
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
814 |
+
|
815 |
+
def get_optim_params(self) -> dict:
|
816 |
+
return self.parameters()
|
817 |
+
|
818 |
+
@torch.no_grad
|
819 |
+
def select_action(
|
820 |
+
self, batch: dict[str, Tensor], noise: Tensor | None = None
|
821 |
+
) -> Tensor:
|
822 |
+
"""Select a single action given environment observations.
|
823 |
+
|
824 |
+
This method wraps `select_actions` in order to return one action at a time for execution in the
|
825 |
+
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
826 |
+
queue is empty.
|
827 |
+
"""
|
828 |
+
self.eval()
|
829 |
+
|
830 |
+
if self.config.adapt_to_pi_aloha:
|
831 |
+
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
832 |
+
|
833 |
+
batch = self.normalize_inputs(batch)
|
834 |
+
|
835 |
+
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
836 |
+
# querying the policy.
|
837 |
+
images, img_masks = self.prepare_images(batch)
|
838 |
+
state = self.prepare_state(batch)
|
839 |
+
lang_tokens, lang_masks = self.prepare_language(batch)
|
840 |
+
|
841 |
+
actions = self.model.sample_actions(
|
842 |
+
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
843 |
+
)
|
844 |
+
|
845 |
+
# Unpad actions
|
846 |
+
original_action_dim = self.config.action_feature.shape[0]
|
847 |
+
actions = actions[:, :, :original_action_dim]
|
848 |
+
|
849 |
+
actions = self.unnormalize_outputs({"action": actions})["action"]
|
850 |
+
|
851 |
+
if self.config.adapt_to_pi_aloha:
|
852 |
+
actions = self._pi_aloha_encode_actions(actions)
|
853 |
+
return actions
|
854 |
+
|
855 |
+
def forward(
|
856 |
+
self, batch: dict[str, Tensor], noise=None, time=None
|
857 |
+
) -> tuple[Tensor, dict[str, Tensor]]:
|
858 |
+
"""Do a full training forward pass to compute the loss"""
|
859 |
+
if self.config.adapt_to_pi_aloha:
|
860 |
+
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
861 |
+
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
862 |
+
|
863 |
+
batch = self.normalize_inputs(batch)
|
864 |
+
batch = self.normalize_targets(batch)
|
865 |
+
|
866 |
+
images, img_masks = self.prepare_images(batch)
|
867 |
+
state = self.prepare_state(batch)
|
868 |
+
lang_tokens, lang_masks = self.prepare_language(batch)
|
869 |
+
actions = self.prepare_action(batch)
|
870 |
+
actions_is_pad = batch.get("action_is_pad")
|
871 |
+
|
872 |
+
loss_dict = {}
|
873 |
+
losses, _ = self.model.forward(
|
874 |
+
images, img_masks, lang_tokens, lang_masks, state, actions, noise, time
|
875 |
+
)
|
876 |
+
# loss_dict["losses_after_forward"] = losses.detach().mean().item()
|
877 |
+
|
878 |
+
if actions_is_pad is not None:
|
879 |
+
in_episode_bound = ~actions_is_pad
|
880 |
+
losses = losses * in_episode_bound.unsqueeze(-1)
|
881 |
+
# loss_dict["losses_after_in_ep_bound"] = losses.detach().mean().item()
|
882 |
+
|
883 |
+
# Remove padding
|
884 |
+
losses = losses[:, :, : self.config.max_action_dim]
|
885 |
+
# loss_dict["losses_after_rm_padding"] = losses.detach().mean().item()
|
886 |
+
|
887 |
+
# For backward pass
|
888 |
+
loss = losses.mean()
|
889 |
+
# For logging
|
890 |
+
loss_dict["l2_loss"] = loss.item()
|
891 |
+
|
892 |
+
return loss, loss_dict
|
893 |
+
|
894 |
+
def prepare_images(self, batch):
|
895 |
+
"""Apply preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
896 |
+
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
|
897 |
+
"""
|
898 |
+
images = []
|
899 |
+
img_masks = []
|
900 |
+
|
901 |
+
present_img_keys = [key for key in self.config.image_features if key in batch]
|
902 |
+
missing_img_keys = [
|
903 |
+
key for key in self.config.image_features if key not in batch
|
904 |
+
]
|
905 |
+
|
906 |
+
if len(present_img_keys) == 0:
|
907 |
+
raise ValueError(
|
908 |
+
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
909 |
+
)
|
910 |
+
|
911 |
+
# Preprocess image features present in the batch
|
912 |
+
for key in present_img_keys:
|
913 |
+
img = batch[key]
|
914 |
+
|
915 |
+
if self.config.resize_imgs_with_padding is not None:
|
916 |
+
img = resize_with_pad(
|
917 |
+
img, *self.config.resize_imgs_with_padding, pad_value=0
|
918 |
+
)
|
919 |
+
|
920 |
+
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
921 |
+
img = img * 2.0 - 1.0
|
922 |
+
|
923 |
+
bsize = img.shape[0]
|
924 |
+
device = img.device
|
925 |
+
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
926 |
+
images.append(img)
|
927 |
+
img_masks.append(mask)
|
928 |
+
|
929 |
+
# Create image features not present in the batch
|
930 |
+
# as fully 0 padded images.
|
931 |
+
for num_empty_cameras in range(len(missing_img_keys)):
|
932 |
+
if num_empty_cameras >= self.config.empty_cameras:
|
933 |
+
break
|
934 |
+
img = torch.ones_like(img) * -1
|
935 |
+
mask = torch.zeros_like(mask)
|
936 |
+
images.append(img)
|
937 |
+
img_masks.append(mask)
|
938 |
+
|
939 |
+
return images, img_masks
|
940 |
+
|
941 |
+
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
942 |
+
"""Tokenize the text input"""
|
943 |
+
device = batch[OBS_ROBOT].device
|
944 |
+
tasks = batch["task"]
|
945 |
+
|
946 |
+
# PaliGemma prompt has to end with a new line
|
947 |
+
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
948 |
+
|
949 |
+
tokenized_prompt = self.language_tokenizer.__call__(
|
950 |
+
tasks,
|
951 |
+
padding="max_length",
|
952 |
+
padding_side="right",
|
953 |
+
max_length=self.config.tokenizer_max_length,
|
954 |
+
return_tensors="pt",
|
955 |
+
truncation=True,
|
956 |
+
)
|
957 |
+
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
958 |
+
lang_masks = tokenized_prompt["attention_mask"].to(
|
959 |
+
device=device, dtype=torch.bool
|
960 |
+
)
|
961 |
+
|
962 |
+
return lang_tokens, lang_masks
|
963 |
+
|
964 |
+
def _pi_aloha_decode_state(self, state):
|
965 |
+
# Flip the joints.
|
966 |
+
for motor_idx in [1, 2, 8, 9]:
|
967 |
+
state[:, motor_idx] *= -1
|
968 |
+
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
969 |
+
for motor_idx in [6, 13]:
|
970 |
+
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
971 |
+
return state
|
972 |
+
|
973 |
+
def _pi_aloha_encode_actions(self, actions):
|
974 |
+
# Flip the joints.
|
975 |
+
for motor_idx in [1, 2, 8, 9]:
|
976 |
+
actions[:, :, motor_idx] *= -1
|
977 |
+
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
978 |
+
for motor_idx in [6, 13]:
|
979 |
+
actions[:, :, motor_idx] = aloha_gripper_from_angular(
|
980 |
+
actions[:, :, motor_idx]
|
981 |
+
)
|
982 |
+
return actions
|
983 |
+
|
984 |
+
def _pi_aloha_encode_actions_inv(self, actions):
|
985 |
+
# Flip the joints again.
|
986 |
+
for motor_idx in [1, 2, 8, 9]:
|
987 |
+
actions[:, :, motor_idx] *= -1
|
988 |
+
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
989 |
+
for motor_idx in [6, 13]:
|
990 |
+
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(
|
991 |
+
actions[:, :, motor_idx]
|
992 |
+
)
|
993 |
+
return actions
|
994 |
+
|
995 |
+
def prepare_state(self, batch):
|
996 |
+
"""Pad state"""
|
997 |
+
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
|
998 |
+
return state
|
999 |
+
|
1000 |
+
def prepare_action(self, batch):
|
1001 |
+
"""Pad action"""
|
1002 |
+
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
1003 |
+
return actions
|
1004 |
+
|
1005 |
+
def _save_pretrained(self, save_directory) -> None:
|
1006 |
+
super()._save_pretrained(save_directory)
|
1007 |
+
print(f"Saving the language tokenizer to {save_directory} ...")
|
1008 |
+
self.language_tokenizer.save_pretrained(save_directory)
|
1009 |
+
|
1010 |
+
import shutil
|
1011 |
+
|
1012 |
+
files = [
|
1013 |
+
"src/hume/models/array_typing.py",
|
1014 |
+
"src/hume/models/configuration_hume.py",
|
1015 |
+
"src/hume/models/fast_visuo_expert.py",
|
1016 |
+
"src/hume/models/modeling_hume.py",
|
1017 |
+
"src/hume/models/paligemma_with_expert.py",
|
1018 |
+
"src/hume/models/value_query.py",
|
1019 |
+
]
|
1020 |
+
try:
|
1021 |
+
for file in files:
|
1022 |
+
shutil.copy(file, save_directory)
|
1023 |
+
except Exception:
|
1024 |
+
print("Failed to copy files to save_directory")
|
1025 |
+
|
1026 |
+
@classmethod
|
1027 |
+
def from_pretrained(
|
1028 |
+
cls,
|
1029 |
+
pretrained_name_or_path,
|
1030 |
+
**kwargs,
|
1031 |
+
):
|
1032 |
+
policy = super().from_pretrained(pretrained_name_or_path, **kwargs)
|
1033 |
+
print(f"Loading the language tokenizer from {pretrained_name_or_path} ...")
|
1034 |
+
policy.language_tokenizer = AutoTokenizer.from_pretrained(
|
1035 |
+
pretrained_name_or_path
|
1036 |
+
)
|
1037 |
+
return policy
|
1038 |
+
|
1039 |
+
|
1040 |
+
class System2(nn.Module):
|
1041 |
+
def __init__(self, config):
|
1042 |
+
super().__init__()
|
1043 |
+
self.config = config
|
1044 |
+
|
1045 |
+
paligemma_with_export_config = PaliGemmaWithExpertConfig(
|
1046 |
+
freeze_vision_encoder=self.config.freeze_vision_encoder,
|
1047 |
+
train_expert_only=self.config.train_expert_only,
|
1048 |
+
attention_implementation=self.config.attention_implementation,
|
1049 |
+
paligemma_config=self.config.paligemma_config,
|
1050 |
+
gemma_expert_config=self.config.gemma_expert_config,
|
1051 |
+
)
|
1052 |
+
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
1053 |
+
paligemma_with_export_config
|
1054 |
+
)
|
1055 |
+
|
1056 |
+
# Projections are float32
|
1057 |
+
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
|
1058 |
+
self.action_in_proj = nn.Linear(
|
1059 |
+
self.config.max_action_dim, self.config.proj_width
|
1060 |
+
)
|
1061 |
+
self.action_out_proj = nn.Linear(
|
1062 |
+
self.config.proj_width, self.config.max_action_dim
|
1063 |
+
)
|
1064 |
+
|
1065 |
+
self.action_time_mlp_in = nn.Linear(
|
1066 |
+
self.config.proj_width * 2, self.config.proj_width
|
1067 |
+
)
|
1068 |
+
self.action_time_mlp_out = nn.Linear(
|
1069 |
+
self.config.proj_width, self.config.proj_width
|
1070 |
+
)
|
1071 |
+
|
1072 |
+
self.set_requires_grad()
|
1073 |
+
|
1074 |
+
def set_requires_grad(self):
|
1075 |
+
for params in self.state_proj.parameters():
|
1076 |
+
params.requires_grad = self.config.train_state_proj
|
1077 |
+
|
1078 |
+
def sample_noise(self, shape, device):
|
1079 |
+
noise = torch.normal(
|
1080 |
+
mean=0.0,
|
1081 |
+
std=1.0,
|
1082 |
+
size=shape,
|
1083 |
+
dtype=torch.float32,
|
1084 |
+
device=device,
|
1085 |
+
)
|
1086 |
+
return noise
|
1087 |
+
|
1088 |
+
def sample_time(self, bsize, device):
|
1089 |
+
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
1090 |
+
time = time_beta * 0.999 + 0.001
|
1091 |
+
return time.to(dtype=torch.float32, device=device)
|
1092 |
+
|
1093 |
+
def embed_prefix(
|
1094 |
+
self, images, img_masks, lang_tokens, lang_masks
|
1095 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
1096 |
+
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
1097 |
+
for PaliGemma transformer processing.
|
1098 |
+
"""
|
1099 |
+
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
|
1100 |
+
embs = []
|
1101 |
+
pad_masks = []
|
1102 |
+
att_masks = []
|
1103 |
+
|
1104 |
+
# TODO: remove for loop
|
1105 |
+
for (
|
1106 |
+
img,
|
1107 |
+
img_mask,
|
1108 |
+
) in zip(images, img_masks, strict=False):
|
1109 |
+
img_emb = self.paligemma_with_expert.embed_image(img)
|
1110 |
+
img_emb = img_emb.to(dtype=torch.bfloat16)
|
1111 |
+
|
1112 |
+
# Normalize image embeddings
|
1113 |
+
img_emb_dim = img_emb.shape[-1]
|
1114 |
+
img_emb = img_emb * torch.tensor(
|
1115 |
+
img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
|
1116 |
+
)
|
1117 |
+
|
1118 |
+
bsize, num_img_embs = img_emb.shape[:2]
|
1119 |
+
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
1120 |
+
|
1121 |
+
embs.append(img_emb)
|
1122 |
+
pad_masks.append(img_mask)
|
1123 |
+
|
1124 |
+
# Create attention masks so that image tokens attend to each other
|
1125 |
+
att_masks += [0] * num_img_embs
|
1126 |
+
|
1127 |
+
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
1128 |
+
|
1129 |
+
# Normalize language embeddings
|
1130 |
+
lang_emb_dim = lang_emb.shape[-1]
|
1131 |
+
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
|
1132 |
+
|
1133 |
+
embs.append(lang_emb)
|
1134 |
+
pad_masks.append(lang_masks)
|
1135 |
+
|
1136 |
+
# full attention between image and language inputs
|
1137 |
+
num_lang_embs = lang_emb.shape[1]
|
1138 |
+
att_masks += [0] * num_lang_embs
|
1139 |
+
|
1140 |
+
embs = torch.cat(embs, dim=1)
|
1141 |
+
pad_masks = torch.cat(pad_masks, dim=1)
|
1142 |
+
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
1143 |
+
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
1144 |
+
|
1145 |
+
return embs, pad_masks, att_masks
|
1146 |
+
|
1147 |
+
def embed_suffix(self, state, noisy_actions, timestep):
|
1148 |
+
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
1149 |
+
embs = []
|
1150 |
+
pad_masks = []
|
1151 |
+
att_masks = []
|
1152 |
+
|
1153 |
+
# Embed state
|
1154 |
+
state_emb = self.state_proj(state)
|
1155 |
+
state_emb = state_emb.to(dtype=torch.bfloat16)
|
1156 |
+
embs.append(state_emb[:, None, :])
|
1157 |
+
bsize = state_emb.shape[0]
|
1158 |
+
dtype = state_emb.dtype
|
1159 |
+
device = state_emb.device
|
1160 |
+
|
1161 |
+
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
1162 |
+
pad_masks.append(state_mask)
|
1163 |
+
|
1164 |
+
# Set attention masks so that image and language inputs do not attend to state or actions
|
1165 |
+
att_masks += [1]
|
1166 |
+
|
1167 |
+
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
1168 |
+
time_emb = create_sinusoidal_pos_embedding(
|
1169 |
+
timestep,
|
1170 |
+
self.config.proj_width,
|
1171 |
+
min_period=4e-3,
|
1172 |
+
max_period=4.0,
|
1173 |
+
device=device,
|
1174 |
+
)
|
1175 |
+
time_emb = time_emb.type(dtype=dtype)
|
1176 |
+
|
1177 |
+
# Fuse timestep + action information using an MLP
|
1178 |
+
action_emb = self.action_in_proj(noisy_actions)
|
1179 |
+
|
1180 |
+
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
1181 |
+
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
1182 |
+
|
1183 |
+
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
1184 |
+
action_time_emb = F.silu(action_time_emb) # swish == silu
|
1185 |
+
action_time_emb = self.action_time_mlp_out(action_time_emb)
|
1186 |
+
|
1187 |
+
# Add to input tokens
|
1188 |
+
embs.append(action_time_emb)
|
1189 |
+
|
1190 |
+
bsize, action_time_dim = action_time_emb.shape[:2]
|
1191 |
+
action_time_mask = torch.ones(
|
1192 |
+
bsize, action_time_dim, dtype=torch.bool, device=device
|
1193 |
+
)
|
1194 |
+
pad_masks.append(action_time_mask)
|
1195 |
+
|
1196 |
+
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
1197 |
+
att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
|
1198 |
+
|
1199 |
+
embs = torch.cat(embs, dim=1)
|
1200 |
+
pad_masks = torch.cat(pad_masks, dim=1)
|
1201 |
+
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
1202 |
+
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
1203 |
+
|
1204 |
+
return embs, pad_masks, att_masks
|
1205 |
+
|
1206 |
+
def forward(
|
1207 |
+
self,
|
1208 |
+
images,
|
1209 |
+
img_masks,
|
1210 |
+
lang_tokens,
|
1211 |
+
lang_masks,
|
1212 |
+
state,
|
1213 |
+
actions,
|
1214 |
+
noise=None,
|
1215 |
+
time=None,
|
1216 |
+
) -> Tensor:
|
1217 |
+
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
1218 |
+
if noise is None:
|
1219 |
+
noise = self.sample_noise(actions.shape, actions.device)
|
1220 |
+
|
1221 |
+
if time is None:
|
1222 |
+
time = self.sample_time(actions.shape[0], actions.device)
|
1223 |
+
time_expanded = time[:, None, None]
|
1224 |
+
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
1225 |
+
u_t = noise - actions
|
1226 |
+
|
1227 |
+
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
1228 |
+
images, img_masks, lang_tokens, lang_masks
|
1229 |
+
)
|
1230 |
+
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
|
1231 |
+
state, x_t, time
|
1232 |
+
)
|
1233 |
+
|
1234 |
+
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
1235 |
+
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
1236 |
+
|
1237 |
+
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
1238 |
+
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
1239 |
+
|
1240 |
+
(_, suffix_out), past_key_values = self.paligemma_with_expert.forward(
|
1241 |
+
attention_mask=att_2d_masks,
|
1242 |
+
position_ids=position_ids,
|
1243 |
+
past_key_values=None,
|
1244 |
+
inputs_embeds=[prefix_embs, suffix_embs],
|
1245 |
+
use_cache=True,
|
1246 |
+
fill_kv_cache=True,
|
1247 |
+
)
|
1248 |
+
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
1249 |
+
# Original openpi code, upcast attention output
|
1250 |
+
suffix_out = suffix_out.to(dtype=torch.float32)
|
1251 |
+
v_t = self.action_out_proj(suffix_out)
|
1252 |
+
|
1253 |
+
losses = F.mse_loss(u_t, v_t, reduction="none")
|
1254 |
+
|
1255 |
+
return losses, past_key_values
|
1256 |
+
|
1257 |
+
def sample_actions(
|
1258 |
+
self,
|
1259 |
+
images,
|
1260 |
+
img_masks,
|
1261 |
+
lang_tokens,
|
1262 |
+
lang_masks,
|
1263 |
+
state,
|
1264 |
+
noise=None,
|
1265 |
+
past_key_values=None,
|
1266 |
+
time_temp=1.0,
|
1267 |
+
noise_temp=1.0,
|
1268 |
+
) -> Tensor:
|
1269 |
+
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
1270 |
+
bsize = state.shape[0]
|
1271 |
+
device = state.device
|
1272 |
+
|
1273 |
+
if noise is None:
|
1274 |
+
actions_shape = (
|
1275 |
+
bsize,
|
1276 |
+
self.config.n_action_steps,
|
1277 |
+
self.config.max_action_dim,
|
1278 |
+
)
|
1279 |
+
noise = self.sample_noise(actions_shape, device)
|
1280 |
+
|
1281 |
+
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
1282 |
+
images, img_masks, lang_tokens, lang_masks
|
1283 |
+
)
|
1284 |
+
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
1285 |
+
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
1286 |
+
|
1287 |
+
# Compute image and language key value cache
|
1288 |
+
if past_key_values is None:
|
1289 |
+
_, past_key_values = self.paligemma_with_expert.forward(
|
1290 |
+
attention_mask=prefix_att_2d_masks,
|
1291 |
+
position_ids=prefix_position_ids,
|
1292 |
+
past_key_values=None,
|
1293 |
+
inputs_embeds=[prefix_embs, None],
|
1294 |
+
use_cache=self.config.use_cache,
|
1295 |
+
fill_kv_cache=True,
|
1296 |
+
)
|
1297 |
+
|
1298 |
+
dt = -1.0 / self.config.num_steps
|
1299 |
+
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
1300 |
+
|
1301 |
+
x_t = noise
|
1302 |
+
time = torch.tensor(
|
1303 |
+
time_temp, dtype=torch.float32, device=device
|
1304 |
+
) # TODO: Add temp
|
1305 |
+
while time >= -dt / 2 + (1 - self.config.theta2):
|
1306 |
+
expanded_time = time.expand(bsize)
|
1307 |
+
v_t = self.denoise_step(
|
1308 |
+
state,
|
1309 |
+
prefix_pad_masks,
|
1310 |
+
past_key_values,
|
1311 |
+
x_t,
|
1312 |
+
expanded_time,
|
1313 |
+
)
|
1314 |
+
|
1315 |
+
# Euler step
|
1316 |
+
x_t += dt * v_t * noise_temp # TODO: Add noise temp
|
1317 |
+
time += dt
|
1318 |
+
return x_t
|
1319 |
+
|
1320 |
+
def denoise_step(
|
1321 |
+
self,
|
1322 |
+
state,
|
1323 |
+
prefix_pad_masks,
|
1324 |
+
past_key_values,
|
1325 |
+
x_t,
|
1326 |
+
timestep,
|
1327 |
+
):
|
1328 |
+
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
1329 |
+
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
|
1330 |
+
state, x_t, timestep
|
1331 |
+
)
|
1332 |
+
|
1333 |
+
suffix_len = suffix_pad_masks.shape[1]
|
1334 |
+
batch_size = prefix_pad_masks.shape[0]
|
1335 |
+
prefix_len = prefix_pad_masks.shape[1]
|
1336 |
+
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
|
1337 |
+
batch_size, suffix_len, prefix_len
|
1338 |
+
)
|
1339 |
+
|
1340 |
+
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
1341 |
+
|
1342 |
+
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
1343 |
+
|
1344 |
+
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
1345 |
+
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
1346 |
+
|
1347 |
+
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
1348 |
+
attention_mask=full_att_2d_masks,
|
1349 |
+
position_ids=position_ids,
|
1350 |
+
past_key_values=past_key_values,
|
1351 |
+
inputs_embeds=[None, suffix_embs],
|
1352 |
+
use_cache=self.config.use_cache,
|
1353 |
+
fill_kv_cache=False,
|
1354 |
+
)
|
1355 |
+
suffix_out = outputs_embeds[1]
|
1356 |
+
suffix_out = suffix_out[:, -self.config.n_action_steps :]
|
1357 |
+
suffix_out = suffix_out.to(dtype=torch.float32)
|
1358 |
+
v_t = self.action_out_proj(suffix_out)
|
1359 |
+
return v_t
|
1360 |
+
|
1361 |
+
|
1362 |
+
class FastVisuoMatching(nn.Module):
|
1363 |
+
def __init__(self, config):
|
1364 |
+
super().__init__()
|
1365 |
+
self.config = config
|
1366 |
+
|
1367 |
+
# FastVisuoExpertConfig, FastVisuoExpertModel
|
1368 |
+
fast_visuo_expertConfig = FastVisuoExpertConfig(
|
1369 |
+
freeze_vision_encoder=self.config.freeze_s1_vision_encoder,
|
1370 |
+
attention_implementation=self.config.attention_implementation,
|
1371 |
+
dino_config=self.config.s1_dino_config,
|
1372 |
+
gemma_expert_config=self.config.s1_gemma_expert_config,
|
1373 |
+
)
|
1374 |
+
self.fast_visuo_expert = FastVisuoExpertModel(fast_visuo_expertConfig)
|
1375 |
+
|
1376 |
+
# Projections are float32
|
1377 |
+
self.state_proj = nn.Linear(
|
1378 |
+
self.config.max_state_dim, self.config.s1_proj_width
|
1379 |
+
)
|
1380 |
+
self.action_in_proj = nn.Linear(
|
1381 |
+
self.config.max_action_dim, self.config.s1_proj_width
|
1382 |
+
)
|
1383 |
+
self.action_out_proj = nn.Linear(
|
1384 |
+
self.config.s1_proj_width, self.config.max_action_dim
|
1385 |
+
)
|
1386 |
+
self.action_time_mlp_in = nn.Linear(
|
1387 |
+
self.config.s1_proj_width * 2, self.config.s1_proj_width
|
1388 |
+
)
|
1389 |
+
        self.action_time_mlp_out = nn.Linear(
            self.config.s1_proj_width, self.config.s1_proj_width
        )

        self.set_requires_grad()

    def set_requires_grad(self):
        for params in self.state_proj.parameters():
            params.requires_grad = self.config.train_state_proj

    def sample_noise(self, shape, device):
        noise = torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=torch.float32,
            device=device,
        )
        return noise

    def sample_time(self, bsize, device):
        time_beta = sample_beta(1.5, 1.0, bsize, device)
        time = time_beta * 0.999 + 0.001
        return time.to(dtype=torch.float32, device=device)

    def embed_prefix(
        self, images, img_masks
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed images with SigLIP and language tokens with embedding layer to prepare
        for PaliGemma transformer processing.
        """
        # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
        embs = []
        pad_masks = []
        att_masks = []

        # TODO: remove for loop
        for img, img_mask in zip(images, img_masks, strict=False):
            DINO_MEAN, DINO_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
            img = TF.normalize(img * 0.5 + 0.5, mean=DINO_MEAN, std=DINO_STD)
            img_emb = self.fast_visuo_expert.embed_image(img)
            img_emb = img_emb.to(dtype=torch.bfloat16)

            # Normalize image embeddings
            img_emb_dim = img_emb.shape[-1]
            img_emb = img_emb * torch.tensor(
                img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
            )

            bsize, num_img_embs = img_emb.shape[:2]
            img_mask = img_mask[:, None].expand(bsize, num_img_embs)

            embs.append(img_emb)
            pad_masks.append(img_mask)

            # Create attention masks so that image tokens attend to each other
            att_masks += [0] * num_img_embs

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks

    def embed_suffix(self, state, noisy_actions, timestep, stamp):
        """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
        embs = []
        pad_masks = []
        att_masks = []

        # Embed state
        state_emb = self.state_proj(state)
        state_emb = state_emb.to(dtype=torch.bfloat16)
        embs.append(state_emb)
        bsize = state_emb.shape[0]
        state_horizon = state_emb.shape[1]
        dtype = state_emb.dtype
        device = state_emb.device

        state_mask = torch.ones(bsize, state_horizon, dtype=torch.bool, device=device)
        pad_masks.append(state_mask)

        # Set attention masks so that image and language inputs do not attend to state or actions
        att_masks += [1] * state_horizon

        # Embed stamp
        stamp_emb = create_sinusoidal_pos_embedding(
            stamp,
            self.config.s1_proj_width,
            min_period=4e-3,
            max_period=4.0,
            device=device,
        )
        stamp_emb = stamp_emb.type(dtype=dtype)[:, None, :]
        embs.append(stamp_emb)
        stamp_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
        pad_masks.append(stamp_mask)
        att_masks += [1]

        # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
        time_emb = create_sinusoidal_pos_embedding(
            timestep,
            self.config.s1_proj_width,
            min_period=4e-3,
            max_period=4.0,
            device=device,
        )
        time_emb = time_emb.type(dtype=dtype)

        # Fuse timestep + action information using an MLP
        action_emb = self.action_in_proj(noisy_actions)

        time_emb = time_emb[:, None, :].expand_as(action_emb)
        action_time_emb = torch.cat([action_emb, time_emb], dim=2)

        action_time_emb = self.action_time_mlp_in(action_time_emb)
        action_time_emb = F.silu(action_time_emb)  # swish == silu
        action_time_emb = self.action_time_mlp_out(action_time_emb)

        # Add to input tokens
        embs.append(action_time_emb)

        bsize, action_time_dim = action_time_emb.shape[:2]
        action_time_mask = torch.ones(
            bsize, action_time_dim, dtype=torch.bool, device=device
        )
        pad_masks.append(action_time_mask)

        # Set attention masks so that image, language and state inputs do not attend to action tokens
        att_masks += [1] + ([0] * (self.config.s1_action_steps - 1))

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks

    def forward(
        self, images, img_masks, state, actions, noise=None, time=None, stamp=None
    ) -> Float[
        Tensor, "batch {self.config.s1_action_steps} {self.config.max_action_dim}"
    ]:
        """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
        if noise is None:
            noise = self.sample_noise(actions.shape, actions.device)
        if time is None:
            time = (
                self.sample_time(actions.shape[0], actions.device) * self.config.theta1
            )  # s2: [1, 0.1] -> s1: [0.1, 0]
        time_expanded = time[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
            images, img_masks
        )
        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
            state, x_t, time, stamp
        )

        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)

        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        inputs_embeds = torch.cat(
            [prefix_embs, suffix_embs], dim=1
        )  # torch.Size([16, 565]), torch.Size([16, 565])

        suffix_out = self.fast_visuo_expert.forward(
            attention_mask=att_2d_masks,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )
        suffix_out = suffix_out[:, -self.config.s1_action_steps :]
        # Original openpi code, upcast attention output
        suffix_out = suffix_out.to(dtype=torch.float32)
        v_t = self.action_out_proj(suffix_out)

        losses = F.mse_loss(u_t, v_t, reduction="none")
        return losses

    def sample_actions(
        self, images, img_masks, state, noise=None, stamp=None
    ) -> Tensor:
        """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
        bsize = state.shape[0]
        device = state.device

        if noise is None:
            actions_shape = (
                bsize,
                self.config.s1_action_steps,
                self.config.max_action_dim,
            )
            noise = self.sample_noise(actions_shape, device)

        if stamp is None:
            stamp = torch.rand(bsize, device=device)

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
            images, img_masks
        )

        dt = -self.config.theta1 / self.config.s1_num_steps
        dt = torch.tensor(dt, dtype=torch.float32, device=device)

        x_t = noise
        time = torch.tensor(self.config.theta1, dtype=torch.float32, device=device)
        while time >= -dt / 2:
            expanded_time = time.expand(bsize)
            v_t = self.denoise_step(
                state,
                prefix_embs,
                prefix_pad_masks,
                prefix_att_masks,
                x_t,
                expanded_time,
                stamp,
            )
            # Euler step
            x_t += dt * v_t
            time += dt
        return x_t

    def denoise_step(
        self,
        state,
        prefix_embs,
        prefix_pad_masks,
        prefix_att_masks,
        x_t,
        timestep,
        stamp,
    ):
        """Apply one denoising step of the noise `x_t` at a given timestep."""
        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
            state, x_t, timestep, stamp
        )
        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)

        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        inputs_embeds = torch.cat(
            [prefix_embs, suffix_embs], dim=1
        )  # torch.Size([16, 565]), torch.Size([16, 565])
        suffix_out = self.fast_visuo_expert.forward(
            attention_mask=att_2d_masks,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )
        suffix_out = suffix_out[:, -self.config.s1_action_steps :]
        suffix_out = suffix_out.to(dtype=torch.float32)
        v_t = self.action_out_proj(suffix_out)
        return v_t


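# Note on the flow-matching objective implemented above: training builds the noisy
# action as x_t = t * noise + (1 - t) * actions, with t drawn from sample_time() and
# scaled by config.theta1, and regresses the predicted v_t onto the constant velocity
# u_t = noise - actions. At inference, sample_actions() starts from pure noise at
# t = theta1 and repeatedly applies denoise_step() with the Euler update
# x_t <- x_t + dt * v_t, where dt = -theta1 / s1_num_steps, until t reaches 0.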
class ValueQueryHead(nn.Module):
    def __init__(self, paligemma_with_expert, config):
        super().__init__()
        # gemma_expert for processing img and language tokens
        # paligemma with expert for processing image features
        self.config = config
        self.paligemma_with_expert = paligemma_with_expert

        vqh_backbone_config = VQHBackboneConfig()
        self.vqh_backbone = VQHBackbone(config=vqh_backbone_config)

        cal_ql_config = CalQlConfig(
            obs_encoded_dim=self.paligemma_with_expert.config.paligemma_config.hidden_size,
            action_dim=config.vqh_chunk_size * config.action_feature.shape[0],
            actor_lr=config.actor_lr,
            critic_lr=config.critic_lr,
            temp_lr=config.temp_lr,
        )
        self.calql = CalQL(config=cal_ql_config)

        self.query_embedding = nn.Parameter(
            torch.zeros(
                self.paligemma_with_expert.config.paligemma_config.hidden_size,
                dtype=torch.bfloat16,
            )
        )

    def embed_prefix(
        self, images, img_masks, lang_tokens, lang_masks
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed images with SigLIP and language tokens with embedding layer to prepare
        for PaliGemma transformer processing.
        """
        # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
        embs = []
        pad_masks = []
        att_masks = []

        # TODO: remove for loop
        for (
            img,
            img_mask,
        ) in zip(images, img_masks, strict=False):
            img_emb = self.paligemma_with_expert.embed_image(img)
            img_emb = img_emb.to(dtype=torch.bfloat16)

            # Normalize image embeddings
            img_emb_dim = img_emb.shape[-1]
            img_emb = img_emb * torch.tensor(
                img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device
            )

            bsize, num_img_embs = img_emb.shape[:2]
            img_mask = img_mask[:, None].expand(bsize, num_img_embs)

            embs.append(img_emb)
            pad_masks.append(img_mask)

            # Create attention masks so that image tokens attend to each other
            att_masks += [0] * num_img_embs

        lang_emb = self.paligemma_with_expert.embed_language_tokens(
            lang_tokens
        ).detach()

        # Normalize language embeddings
        lang_emb_dim = lang_emb.shape[-1]
        lang_emb = lang_emb * math.sqrt(lang_emb_dim)

        embs.append(lang_emb)
        pad_masks.append(lang_masks)

        # full attention between image and language inputs
        num_lang_embs = lang_emb.shape[1]
        att_masks += [0] * num_lang_embs

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        # NOTE: add query embedding for each sequence
        seq_lengths = pad_masks.sum(dim=1).long()  # w/o padding length
        seq_len = embs.shape[1]

        new_seq_len = seq_len + 1
        new_embs = torch.zeros(
            (bsize, new_seq_len, embs.shape[-1]), dtype=embs.dtype, device=embs.device
        )
        new_pad_masks = torch.zeros(
            (bsize, new_seq_len), dtype=pad_masks.dtype, device=pad_masks.device
        )
        new_att_masks = torch.zeros(
            (bsize, new_seq_len), dtype=att_masks.dtype, device=att_masks.device
        )

        batch_idx = torch.arange(bsize, device=embs.device).view(-1, 1)
        seq_idx = (
            torch.arange(seq_len, device=embs.device).view(1, -1).expand(bsize, -1)
        )

        mask = seq_idx >= seq_lengths.unsqueeze(1)
        new_seq_idx = seq_idx + mask.long()

        new_embs[batch_idx, new_seq_idx] = embs
        new_pad_masks[batch_idx, new_seq_idx] = pad_masks
        new_att_masks[batch_idx, new_seq_idx] = att_masks
        new_embs[torch.arange(bsize), seq_lengths] = self.query_embedding.unsqueeze(
            0
        ).expand(bsize, -1)
        new_pad_masks[torch.arange(bsize), seq_lengths] = True
        new_att_masks[torch.arange(bsize), seq_lengths] = False

        return new_embs, new_pad_masks, new_att_masks

    def process_next_obs(
        self,
        images: list[torch.Tensor],
        img_masks: list[torch.Tensor],
        vqh_images: list[torch.Tensor],
        vqh_img_masks: list[torch.Tensor],
        lang_tokens: torch.Tensor,
        lang_masks: torch.Tensor,
    ) -> tuple[list[torch.Tensor], list[torch.Tensor], torch.Tensor, torch.Tensor]:
        """Process next observation for ValueQueryHead model.
        Args:
            images (list): List of image tensors.
            img_masks (list): List of image mask tensors.
            vqh_images (list): List of ValueQueryHead image tensors.
            vqh_img_masks (list): List of ValueQueryHead image mask tensors.
            lang_tokens (torch.Tensor): Language token tensor.
            lang_masks (torch.Tensor): Language mask tensor.

        Returns:
            tuple: Tuple containing processed images, masks, and language tokens.
        """
        new_images = []
        new_img_masks = []

        for img, next_img, img_mask, next_img_mask in zip(
            images, vqh_images, img_masks, vqh_img_masks
        ):
            new_images.append(torch.cat([img, next_img], dim=0))
            new_img_masks.append(torch.cat([img_mask, next_img_mask], dim=0))

        new_lang_tokens = torch.cat([lang_tokens, lang_tokens], dim=0)
        new_lang_masks = torch.cat([lang_masks, lang_masks], dim=0)

        return (
            new_images,
            new_img_masks,
            new_lang_tokens,
            new_lang_masks,
        )

    @jaxtyped(typechecker=typechecker)
    def forward(
        self,
        images: list[Float[Tensor, "batch 3 224 224"]],
        img_masks: list[Bool[Tensor, " batch"]],
        lang_tokens: Int64[Tensor, "batch seq_len"],
        lang_masks: Bool[Tensor, "batch seq_len"],
        vqh_images: list[Float[Tensor, "batch 3 224 224"]],
        vqh_img_masks: list[Bool[Tensor, " batch"]],
        actions: Float[
            Tensor,
            "batch {self.config.vqh_chunk_size} {self.config.action_feature.shape[0]}",
        ],
        rewards: Float[Tensor, " batch"],
        mc_returns: Float[Tensor, " batch"],
        masks: Float[Tensor, " batch"],
    ) -> tuple[Tensor, Tensor, Tensor, dict]:
        """Forward pass for ValueQueryHead model.
        Args:
            images (torch.Tensor): Image input tensor.
            img_masks (torch.Tensor): Image mask tensor.
            lang_tokens (torch.Tensor): Language token tensor.
            lang_masks (torch.Tensor): Language mask tensor.

        Returns:
            tuple: Tuple containing the output tensors.
        """
        images, img_masks, lang_tokens, lang_masks = self.process_next_obs(
            images, img_masks, vqh_images, vqh_img_masks, lang_tokens, lang_masks
        )

        embs, pad_masks, att_masks = self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        suffix_out = self.vqh_backbone.forward(
            attention_mask=att_2d_masks,
            position_ids=position_ids,
            inputs_embeds=embs,
        )  # (2B, S, E)

        batch_indices = torch.arange(suffix_out.shape[0], device=suffix_out.device)
        query_embedding_idx = pad_masks.sum(-1).long() - 1
        query_embedding = suffix_out[batch_indices, query_embedding_idx]

        cal_ql_batch: at.CalQlBatch = dict(
            encoded_observations=query_embedding[
                : int(query_embedding.shape[0] / 2)
            ].to(dtype=torch.float32),
            encoded_next_observations=query_embedding[
                int(query_embedding.shape[0] / 2) :
            ].to(dtype=torch.float32),
            actions=actions.view(actions.shape[0], -1),
            rewards=rewards,
            mc_returns=mc_returns,
            masks=masks,
        )
        temperature_loss, policy_loss, critic_loss, log_dict = self.calql(cal_ql_batch)

        return temperature_loss, policy_loss, critic_loss, log_dict

    @jaxtyped(typechecker=typechecker)
    def select_q_actions(
        self,
        images: list[Float[Tensor, "Batch 3 224 224"]],
        img_masks: list[Bool[Tensor, " Batch"]],
        lang_tokens: Int64[Tensor, "Batch seq_len"],
        lang_masks: Bool[Tensor, "Batch seq_len"],
        noise_actions: Float[
            Tensor,
            "Batch s2_candidates_num {self.config.vqh_chunk_size} {self.config.action_feature.shape[0]}",
        ],
    ) -> tuple[Int64[Tensor, " Batch"], Float[Tensor, "Batch s2_candidates_num"]]:
        batch_size = noise_actions.shape[0]
        s2_candidates_num = noise_actions.shape[1]
        embs, pad_masks, att_masks = self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        suffix_out = self.vqh_backbone.forward(
            attention_mask=att_2d_masks,
            position_ids=position_ids,
            inputs_embeds=embs,
        )  # (B, S, E)

        batch_indices = torch.arange(suffix_out.shape[0], device=suffix_out.device)
        query_embedding_idx = pad_masks.sum(-1).long() - 1
        query_embedding = suffix_out[batch_indices, query_embedding_idx]

        noise_actions = noise_actions.reshape(batch_size, s2_candidates_num, -1)
        q_values = self.calql.get_q_values(query_embedding, noise_actions)

        action_index = torch.argmax(q_values, dim=1)

        print(f"MaxValues: {q_values.max(dim=1)[0].tolist()}")
        print(f"MinValues: {q_values.min(dim=1)[0].tolist()}")
        print(f"MeanValues: {q_values.mean(dim=1)[0].tolist()}")
        print(f"ActionIndex: {action_index.tolist()}")

        return action_index, q_values
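For orientation, `sample_actions` above is a plain Euler integrator over the learned velocity field. The following standalone sketch shows the same integration pattern; the toy velocity field and the values `theta1=0.1`, `num_steps=4` are illustrative assumptions, not taken from this repository.

import torch

def euler_integrate(v_fn, x0, theta1=0.1, num_steps=4):
    """Integrate x' = v(x, t) from t = theta1 down to t = 0, mirroring sample_actions."""
    dt = -theta1 / num_steps
    x_t = x0.clone()
    t = torch.tensor(theta1)
    while t >= -dt / 2:
        x_t = x_t + dt * v_fn(x_t, t)  # Euler step
        t = t + dt
    return x_t

# Toy velocity field standing in for the learned action expert.
x0 = torch.randn(2, 8, 32)  # (batch, action_steps, action_dim) shaped noise
actions = euler_integrate(lambda x, t: x, x0)
print(actions.shape)  # torch.Size([2, 8, 32])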
paligemma_with_expert.py
ADDED
@@ -0,0 +1,444 @@
from typing import List, Optional, Union

import torch
import torch.version
from transformers.cache_utils import Cache
from torch import nn
from transformers import (
    AutoConfig,
    GemmaForCausalLM,
    PaliGemmaForConditionalGeneration,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.models.auto import CONFIG_MAPPING


def apply_rope(x, positions, max_wavelength=10_000):
    """
    Applies RoPE positions [B, L] to x [B, L, H, D].
    """
    d_half = x.shape[-1] // 2
    device = x.device
    dtype = x.dtype
    x = x.to(torch.float32)

    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
        d_half, dtype=torch.float32, device=device
    )
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
        torch.float32
    )

    radians = radians[..., None, :]

    sin = torch.sin(radians)  # .to(dtype=dtype)
    cos = torch.cos(radians)  # .to(dtype=dtype)

    x1, x2 = x.split(d_half, dim=-1)
    res = torch.empty_like(x)
    res[..., :d_half] = x1 * cos - x2 * sin
    res[..., d_half:] = x2 * cos + x1 * sin

    return res.to(dtype)


class PaliGemmaWithExpertConfig(PretrainedConfig):
    model_type = "PaliGemmaWithExpertModel"
    sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}

    def __init__(
        self,
        paligemma_config: dict | None = None,
        gemma_expert_config: dict | None = None,
        freeze_vision_encoder: bool = True,
        train_expert_only: bool = True,
        attention_implementation: str = "eager",
        **kwargs,
    ):
        self.freeze_vision_encoder = freeze_vision_encoder
        self.train_expert_only = train_expert_only
        self.attention_implementation = attention_implementation

        if paligemma_config is None:
            self.paligemma_config = CONFIG_MAPPING["paligemma"](
                transformers_version="4.48.1",
                _vocab_size=257152,
                bos_token_id=2,
                eos_token_id=1,
                hidden_size=2048,
                image_token_index=257152,
                model_type="paligemma",
                pad_token_id=0,
                projection_dim=2048,
                text_config={
                    "hidden_activation": "gelu_pytorch_tanh",
                    "hidden_size": 2048,
                    "intermediate_size": 16384,
                    "model_type": "gemma",
                    "num_attention_heads": 8,
                    "num_hidden_layers": 18,
                    "num_image_tokens": 256,
                    "num_key_value_heads": 1,
                    "torch_dtype": "float32",
                    "vocab_size": 257152,
                },
                vision_config={
                    "hidden_size": 1152,
                    "intermediate_size": 4304,
                    "model_type": "siglip_vision_model",
                    "num_attention_heads": 16,
                    "num_hidden_layers": 27,
                    "num_image_tokens": 256,
                    "patch_size": 14,
                    "projection_dim": 2048,
                    "projector_hidden_act": "gelu_fast",
                    "torch_dtype": "float32",
                    "vision_use_head": False,
                },
            )
        elif isinstance(paligemma_config, dict):
            if "model_type" not in paligemma_config:
                paligemma_config["model_type"] = "paligemma"

            cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
            self.paligemma_config = cfg_cls(**paligemma_config)

        if gemma_expert_config is None:
            self.gemma_expert_config = CONFIG_MAPPING["gemma"](
                attention_bias=False,
                attention_dropout=0.0,
                bos_token_id=2,
                eos_token_id=1,
                head_dim=256,
                hidden_act="gelu_pytorch_tanh",
                hidden_activation="gelu_pytorch_tanh",
                hidden_size=1024,
                initializer_range=0.02,
                intermediate_size=4096,
                max_position_embeddings=8192,
                model_type="gemma",
                num_attention_heads=8,
                num_hidden_layers=18,
                num_key_value_heads=1,
                pad_token_id=0,
                rms_norm_eps=1e-06,
                rope_theta=10000.0,
                torch_dtype="float32",
                transformers_version="4.48.1",
                use_cache=True,
                vocab_size=257152,
            )
        elif isinstance(gemma_expert_config, dict):
            if "model_type" not in gemma_expert_config:
                gemma_expert_config["model_type"] = "gemma"

            cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
            self.gemma_expert_config = cfg_cls(**gemma_expert_config)

        super().__init__(**kwargs)

    def __post_init__(self):
        super().__post_init__()
        if self.train_expert_only and not self.freeze_vision_encoder:
            raise ValueError(
                "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
            )

        if self.attention_implementation not in ["eager", "fa2", "flex"]:
            raise ValueError(
                f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
            )


class PaliGemmaWithExpertModel(PreTrainedModel):
    config_class = PaliGemmaWithExpertConfig

    def __init__(self, config: PaliGemmaWithExpertConfig):
        super().__init__(config=config)
        self.config = config
        self.paligemma = PaliGemmaForConditionalGeneration(
            config=config.paligemma_config
        )
        self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
        # Remove unused embed_tokens
        self.gemma_expert.model.embed_tokens = None
        self.gemma_expert.lm_head = None
        self.to_bfloat16_like_physical_intelligence()
        self.set_requires_grad()

    def set_requires_grad(self):
        if self.config.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()
            for params in self.paligemma.vision_tower.parameters():
                params.requires_grad = False

        if self.config.train_expert_only:
            self.paligemma.eval()
            for params in self.paligemma.parameters():
                params.requires_grad = False

    def train(self, mode: bool = True):
        super().train(mode)

        if self.config.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()

        if self.config.train_expert_only:
            self.paligemma.eval()

    def to_bfloat16_like_physical_intelligence(self):
        self.paligemma = self.paligemma.to(dtype=torch.bfloat16)

        params_to_change_dtype = [
            "language_model.model.layers",
            "gemma_expert.model.layers",
            "vision_tower",
            "multi_modal",
        ]
        for name, param in self.named_parameters():
            if any(selector in name for selector in params_to_change_dtype):
                param.data = param.data.to(dtype=torch.bfloat16)

    def embed_image(self, image: torch.Tensor):
        return self.paligemma.get_image_features(image)

    def embed_language_tokens(self, tokens: torch.Tensor):
        return self.paligemma.language_model.model.embed_tokens(tokens)

    # TODO: break down this huge forward into modules or functions
    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        inputs_embeds: List[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        fill_kv_cache: Optional[bool] = None,
    ):
        models = [self.paligemma.language_model.model, self.gemma_expert.model]

        for hidden_states in inputs_embeds:
            # TODO this is very inefficient
            # dtype is always the same, batch size too (if > 1 len)
            # device could be trickier in multi gpu edge cases but that's it
            if hidden_states is None:
                continue
            batch_size = hidden_states.shape[0]

        # RMSNorm
        num_layers = self.paligemma.config.text_config.num_hidden_layers
        head_dim = self.paligemma.config.text_config.head_dim
        for layer_idx in range(num_layers):
            query_states = []
            key_states = []
            value_states = []
            for i, hidden_states in enumerate(inputs_embeds):
                if hidden_states is None:
                    continue
                layer = models[i].layers[layer_idx]
                # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
                # hidden_states = hidden_states * normalizer
                hidden_states = layer.input_layernorm(hidden_states)

                input_shape = hidden_states.shape[
                    :-1
                ]  # (b s e) -> layer* -> (b s e) -> mlp
                hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)  # (b s h d)

                hidden_states = hidden_states.to(dtype=torch.bfloat16)

                query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
                key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
                value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)

                query_states.append(query_state)
                key_states.append(key_state)
                value_states.append(value_state)

            # B,L,H,D with L sequence length, H number of heads, D head dim
            # concatenate on the number of embeddings/tokens
            query_states = torch.cat(query_states, dim=1)
            key_states = torch.cat(key_states, dim=1)
            value_states = torch.cat(value_states, dim=1)

            query_states = apply_rope(query_states, position_ids)
            key_states = apply_rope(key_states, position_ids)

            if use_cache and past_key_values is None:
                past_key_values = {}

            if use_cache:
                if fill_kv_cache:
                    past_key_values[layer_idx] = {
                        "key_states": key_states,
                        "value_states": value_states,
                    }
                else:
                    # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
                    # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
                    # the max len, then we (for instance) double the cache size. This implementation already exists
                    # in `transformers`. (molbap)
                    key_states = torch.cat(
                        [past_key_values[layer_idx]["key_states"], key_states], dim=1
                    )
                    value_states = torch.cat(
                        [past_key_values[layer_idx]["value_states"], value_states],
                        dim=1,
                    )

            attention_interface = self.get_attention_interface()
            att_output = attention_interface(
                attention_mask,
                batch_size,
                head_dim,
                query_states,
                key_states,
                value_states,
            )
            att_output = att_output.to(dtype=torch.bfloat16)

            # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
            outputs_embeds = []
            start = 0
            for i, hidden_states in enumerate(inputs_embeds):
                layer = models[i].layers[layer_idx]

                if hidden_states is not None:
                    end = start + hidden_states.shape[1]

                    if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
                        att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
                    out_emb = layer.self_attn.o_proj(att_output[:, start:end])

                    # TODO: first dropout (by default 0.0)

                    # first residual
                    out_emb += hidden_states
                    after_first_residual = out_emb.clone()

                    out_emb = layer.post_attention_layernorm(out_emb)
                    out_emb = layer.mlp(out_emb)

                    # TODO: second dropout (by default 0.0)

                    # second residual
                    out_emb += after_first_residual

                    outputs_embeds.append(out_emb)

                    start = end
                else:
                    outputs_embeds.append(None)

            inputs_embeds = outputs_embeds

        # final norm
        outputs_embeds = []
        for i, hidden_states in enumerate(inputs_embeds):
            if hidden_states is not None:
                out_emb = models[i].norm(hidden_states)
                outputs_embeds.append(out_emb)
            else:
                outputs_embeds.append(None)

        return outputs_embeds, past_key_values

    def get_attention_interface(self):
        if self.config.attention_implementation == "fa2":
            attention_interface = self.flash_attention_forward
        else:
            attention_interface = self.eager_attention_forward
        return attention_interface

    def flash_attention_forward(
        self,
        attention_mask,
        batch_size,
        head_dim,
        query_states,
        key_states,
        value_states,
    ):
        raise NotImplementedError("FA2 is not implemented (yet)")

    def eager_attention_forward(
        self,
        attention_mask,
        batch_size,
        head_dim,
        query_states,
        key_states,
        value_states,
    ):
        num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
        num_key_value_heads = (
            self.config.paligemma_config.text_config.num_key_value_heads
        )
        num_key_value_groups = num_att_heads // num_key_value_heads

        # query_states: batch_size, sequence_length, num_att_head, head_dim
        # key_states: batch_size, sequence_length, num_key_value_head, head_dim
        # value_states: batch_size, sequence_length, num_key_value_head, head_dim
        sequence_length = key_states.shape[1]

        key_states = key_states[:, :, :, None, :].expand(
            batch_size,
            sequence_length,
            num_key_value_heads,
            num_key_value_groups,
            head_dim,
        )
        key_states = key_states.reshape(
            batch_size,
            sequence_length,
            num_key_value_heads * num_key_value_groups,
            head_dim,
        )

        value_states = value_states[:, :, :, None, :].expand(
            batch_size,
            sequence_length,
            num_key_value_heads,
            num_key_value_groups,
            head_dim,
        )
        value_states = value_states.reshape(
            batch_size,
            sequence_length,
            num_key_value_heads * num_key_value_groups,
            head_dim,
        )

        # Attention here is upcasted to float32 to match the original eager implementation.

        query_states = query_states.to(dtype=torch.float32)
        key_states = key_states.to(dtype=torch.float32)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)

        att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        att_weights *= head_dim**-0.5
        big_neg = -2.3819763e38  # See gemma/modules.py

        masked_att_weights = torch.where(
            attention_mask[:, None, :, :], att_weights, big_neg
        )

        probs = nn.functional.softmax(masked_att_weights, dim=-1)
        probs = probs.to(dtype=value_states.dtype)

        # probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
        # value_states: batch_size, sequence_length, num_att_heads, head_dim

        att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))

        att_output = att_output.permute(0, 2, 1, 3)
        # we use -1 because sequence length can change
        att_output = att_output.reshape(
            batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
        )

        return att_output
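The eager attention path above implements grouped-query attention by hand: the single key/value head of the default Gemma text config is expanded to match all eight query heads before the score matmul. A standalone sketch of just that expand-and-reshape step; the batch, sequence, and head sizes below are illustrative assumptions, not the model's.

import torch

batch, seq, n_q_heads, n_kv_heads, head_dim = 2, 5, 8, 1, 16
groups = n_q_heads // n_kv_heads  # 8 query heads share 1 key/value head

k = torch.randn(batch, seq, n_kv_heads, head_dim)
# Broadcast each key/value head across its query-head group, then fold the group
# dimension back into the head dimension, as eager_attention_forward does.
k_expanded = k[:, :, :, None, :].expand(batch, seq, n_kv_heads, groups, head_dim)
k_expanded = k_expanded.reshape(batch, seq, n_kv_heads * groups, head_dim)
print(k_expanded.shape)  # torch.Size([2, 5, 8, 16])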
special_tokens_map.json
ADDED
@@ -0,0 +1,33 @@
{
  "additional_special_tokens": [
    "<image>"
  ],
  "bos_token": {
    "content": "<bos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<eos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
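special_tokens_map.json registers `<image>` as an additional special token next to the usual `<bos>`/`<eos>`/`<pad>`/`<unk>`. A short sketch for inspecting these entries once the folder is on the Hub; the repo id below is a placeholder, not part of this commit.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("<org>/<repo>")  # placeholder repo id
print(tokenizer.special_tokens_map)  # bos/eos/pad/unk plus additional_special_tokens
print(tokenizer.convert_tokens_to_ids("<image>"))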
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:537cbe6b94581ee7b70f7f39453d5c52f2590069aa75d76ceee458fde442523c
size 34387383
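tokenizer.json is committed as a Git LFS pointer (oid and size only), so the actual file must be fetched before use. A minimal sketch for checking a downloaded copy against the pointer using only the standard library; the local path is an assumption.

import hashlib, os

path = "tokenizer.json"  # local copy after `git lfs pull` or hf_hub_download
digest = hashlib.sha256(open(path, "rb").read()).hexdigest()
print(digest == "537cbe6b94581ee7b70f7f39453d5c52f2590069aa75d76ceee458fde442523c")
print(os.path.getsize(path) == 34387383)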
tokenizer_config.json
ADDED
@@ -0,0 +1,1772 @@
{
  "add_bos_token": true,
  "add_eos_token": false,
  "added_tokens_decoder": {
    "0": {"content": "<pad>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "1": {"content": "<eos>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "2": {"content": "<bos>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "3": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "4": {"content": "<mask>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "5": {"content": "<2mass>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "6": {"content": "[@BOS@]", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "7": {"content": "<unused0>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "8": {"content": "<unused1>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "9": {"content": "<unused2>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "10": {"content": "<unused3>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "11": {"content": "<unused4>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "12": {"content": "<unused5>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "13": {"content": "<unused6>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "14": {"content": "<unused7>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "15": {"content": "<unused8>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "16": {"content": "<unused9>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "17": {"content": "<unused10>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "18": {"content": "<unused11>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "19": {"content": "<unused12>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "20": {"content": "<unused13>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "21": {"content": "<unused14>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "22": {"content": "<unused15>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "23": {"content": "<unused16>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "24": {"content": "<unused17>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "25": {"content": "<unused18>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "26": {"content": "<unused19>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "27": {"content": "<unused20>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "28": {"content": "<unused21>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "29": {"content": "<unused22>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "30": {"content": "<unused23>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "31": {"content": "<unused24>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "32": {"content": "<unused25>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "33": {"content": "<unused26>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "34": {"content": "<unused27>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "35": {"content": "<unused28>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "36": {"content": "<unused29>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "37": {"content": "<unused30>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "38": {"content": "<unused31>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "39": {"content": "<unused32>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "40": {"content": "<unused33>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "41": {"content": "<unused34>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "42": {"content": "<unused35>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "43": {"content": "<unused36>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "44": {"content": "<unused37>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "45": {"content": "<unused38>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "46": {"content": "<unused39>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "47": {"content": "<unused40>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "48": {"content": "<unused41>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "49": {"content": "<unused42>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50": {"content": "<unused43>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "51": {"content": "<unused44>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "52": {"content": "<unused45>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "53": {"content": "<unused46>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "54": {"content": "<unused47>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "55": {"content": "<unused48>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "56": {"content": "<unused49>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "57": {"content": "<unused50>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "58": {"content": "<unused51>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "59": {"content": "<unused52>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "60": {"content": "<unused53>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "61": {"content": "<unused54>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "62": {"content": "<unused55>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "63": {"content": "<unused56>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "64": {"content": "<unused57>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "65": {"content": "<unused58>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "66": {"content": "<unused59>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "67": {"content": "<unused60>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "68": {"content": "<unused61>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "69": {"content": "<unused62>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "70": {"content": "<unused63>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "71": {"content": "<unused64>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "72": {"content": "<unused65>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "73": {"content": "<unused66>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "74": {"content": "<unused67>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "75": {"content": "<unused68>", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "76": {"content": "<unused69>", "lstrip": false, "normalized": true, "rstrip": false,
|
618 |
+
"single_word": false,
|
619 |
+
"special": false
|
620 |
+
},
|
621 |
+
"77": {
|
622 |
+
"content": "<unused70>",
|
623 |
+
"lstrip": false,
|
624 |
+
"normalized": true,
|
625 |
+
"rstrip": false,
|
626 |
+
"single_word": false,
|
627 |
+
"special": false
|
628 |
+
},
|
629 |
+
"78": {
|
630 |
+
"content": "<unused71>",
|
631 |
+
"lstrip": false,
|
632 |
+
"normalized": true,
|
633 |
+
"rstrip": false,
|
634 |
+
"single_word": false,
|
635 |
+
"special": false
|
636 |
+
},
|
637 |
+
"79": {
|
638 |
+
"content": "<unused72>",
|
639 |
+
"lstrip": false,
|
640 |
+
"normalized": true,
|
641 |
+
"rstrip": false,
|
642 |
+
"single_word": false,
|
643 |
+
"special": false
|
644 |
+
},
|
645 |
+
"80": {
|
646 |
+
"content": "<unused73>",
|
647 |
+
"lstrip": false,
|
648 |
+
"normalized": true,
|
649 |
+
"rstrip": false,
|
650 |
+
"single_word": false,
|
651 |
+
"special": false
|
652 |
+
},
|
653 |
+
"81": {
|
654 |
+
"content": "<unused74>",
|
655 |
+
"lstrip": false,
|
656 |
+
"normalized": true,
|
657 |
+
"rstrip": false,
|
658 |
+
"single_word": false,
|
659 |
+
"special": false
|
660 |
+
},
|
661 |
+
"82": {
|
662 |
+
"content": "<unused75>",
|
663 |
+
"lstrip": false,
|
664 |
+
"normalized": true,
|
665 |
+
"rstrip": false,
|
666 |
+
"single_word": false,
|
667 |
+
"special": false
|
668 |
+
},
|
669 |
+
"83": {
|
670 |
+
"content": "<unused76>",
|
671 |
+
"lstrip": false,
|
672 |
+
"normalized": true,
|
673 |
+
"rstrip": false,
|
674 |
+
"single_word": false,
|
675 |
+
"special": false
|
676 |
+
},
|
677 |
+
"84": {
|
678 |
+
"content": "<unused77>",
|
679 |
+
"lstrip": false,
|
680 |
+
"normalized": true,
|
681 |
+
"rstrip": false,
|
682 |
+
"single_word": false,
|
683 |
+
"special": false
|
684 |
+
},
|
685 |
+
"85": {
|
686 |
+
"content": "<unused78>",
|
687 |
+
"lstrip": false,
|
688 |
+
"normalized": true,
|
689 |
+
"rstrip": false,
|
690 |
+
"single_word": false,
|
691 |
+
"special": false
|
692 |
+
},
|
693 |
+
"86": {
|
694 |
+
"content": "<unused79>",
|
695 |
+
"lstrip": false,
|
696 |
+
"normalized": true,
|
697 |
+
"rstrip": false,
|
698 |
+
"single_word": false,
|
699 |
+
"special": false
|
700 |
+
},
|
701 |
+
"87": {
|
702 |
+
"content": "<unused80>",
|
703 |
+
"lstrip": false,
|
704 |
+
"normalized": true,
|
705 |
+
"rstrip": false,
|
706 |
+
"single_word": false,
|
707 |
+
"special": false
|
708 |
+
},
|
709 |
+
"88": {
|
710 |
+
"content": "<unused81>",
|
711 |
+
"lstrip": false,
|
712 |
+
"normalized": true,
|
713 |
+
"rstrip": false,
|
714 |
+
"single_word": false,
|
715 |
+
"special": false
|
716 |
+
},
|
717 |
+
"89": {
|
718 |
+
"content": "<unused82>",
|
719 |
+
"lstrip": false,
|
720 |
+
"normalized": true,
|
721 |
+
"rstrip": false,
|
722 |
+
"single_word": false,
|
723 |
+
"special": false
|
724 |
+
},
|
725 |
+
"90": {
|
726 |
+
"content": "<unused83>",
|
727 |
+
"lstrip": false,
|
728 |
+
"normalized": true,
|
729 |
+
"rstrip": false,
|
730 |
+
"single_word": false,
|
731 |
+
"special": false
|
732 |
+
},
|
733 |
+
"91": {
|
734 |
+
"content": "<unused84>",
|
735 |
+
"lstrip": false,
|
736 |
+
"normalized": true,
|
737 |
+
"rstrip": false,
|
738 |
+
"single_word": false,
|
739 |
+
"special": false
|
740 |
+
},
|
741 |
+
"92": {
|
742 |
+
"content": "<unused85>",
|
743 |
+
"lstrip": false,
|
744 |
+
"normalized": true,
|
745 |
+
"rstrip": false,
|
746 |
+
"single_word": false,
|
747 |
+
"special": false
|
748 |
+
},
|
749 |
+
"93": {
|
750 |
+
"content": "<unused86>",
|
751 |
+
"lstrip": false,
|
752 |
+
"normalized": true,
|
753 |
+
"rstrip": false,
|
754 |
+
"single_word": false,
|
755 |
+
"special": false
|
756 |
+
},
|
757 |
+
"94": {
|
758 |
+
"content": "<unused87>",
|
759 |
+
"lstrip": false,
|
760 |
+
"normalized": true,
|
761 |
+
"rstrip": false,
|
762 |
+
"single_word": false,
|
763 |
+
"special": false
|
764 |
+
},
|
765 |
+
"95": {
|
766 |
+
"content": "<unused88>",
|
767 |
+
"lstrip": false,
|
768 |
+
"normalized": true,
|
769 |
+
"rstrip": false,
|
770 |
+
"single_word": false,
|
771 |
+
"special": false
|
772 |
+
},
|
773 |
+
"96": {
|
774 |
+
"content": "<unused89>",
|
775 |
+
"lstrip": false,
|
776 |
+
"normalized": true,
|
777 |
+
"rstrip": false,
|
778 |
+
"single_word": false,
|
779 |
+
"special": false
|
780 |
+
},
|
781 |
+
"97": {
|
782 |
+
"content": "<unused90>",
|
783 |
+
"lstrip": false,
|
784 |
+
"normalized": true,
|
785 |
+
"rstrip": false,
|
786 |
+
"single_word": false,
|
787 |
+
"special": false
|
788 |
+
},
|
789 |
+
"98": {
|
790 |
+
"content": "<unused91>",
|
791 |
+
"lstrip": false,
|
792 |
+
"normalized": true,
|
793 |
+
"rstrip": false,
|
794 |
+
"single_word": false,
|
795 |
+
"special": false
|
796 |
+
},
|
797 |
+
"99": {
|
798 |
+
"content": "<unused92>",
|
799 |
+
"lstrip": false,
|
800 |
+
"normalized": true,
|
801 |
+
"rstrip": false,
|
802 |
+
"single_word": false,
|
803 |
+
"special": false
|
804 |
+
},
|
805 |
+
"100": {
|
806 |
+
"content": "<unused93>",
|
807 |
+
"lstrip": false,
|
808 |
+
"normalized": true,
|
809 |
+
"rstrip": false,
|
810 |
+
"single_word": false,
|
811 |
+
"special": false
|
812 |
+
},
|
813 |
+
"101": {
|
814 |
+
"content": "<unused94>",
|
815 |
+
"lstrip": false,
|
816 |
+
"normalized": true,
|
817 |
+
"rstrip": false,
|
818 |
+
"single_word": false,
|
819 |
+
"special": false
|
820 |
+
},
|
821 |
+
"102": {
|
822 |
+
"content": "<unused95>",
|
823 |
+
"lstrip": false,
|
824 |
+
"normalized": true,
|
825 |
+
"rstrip": false,
|
826 |
+
"single_word": false,
|
827 |
+
"special": false
|
828 |
+
},
|
829 |
+
"103": {
|
830 |
+
"content": "<unused96>",
|
831 |
+
"lstrip": false,
|
832 |
+
"normalized": true,
|
833 |
+
"rstrip": false,
|
834 |
+
"single_word": false,
|
835 |
+
"special": false
|
836 |
+
},
|
837 |
+
"104": {
|
838 |
+
"content": "<unused97>",
|
839 |
+
"lstrip": false,
|
840 |
+
"normalized": true,
|
841 |
+
"rstrip": false,
|
842 |
+
"single_word": false,
|
843 |
+
"special": false
|
844 |
+
},
|
845 |
+
"105": {
|
846 |
+
"content": "<unused98>",
|
847 |
+
"lstrip": false,
|
848 |
+
"normalized": true,
|
849 |
+
"rstrip": false,
|
850 |
+
"single_word": false,
|
851 |
+
"special": false
|
852 |
+
},
|
853 |
+
"106": {
|
854 |
+
"content": "<start_of_turn>",
|
855 |
+
"lstrip": false,
|
856 |
+
"normalized": true,
|
857 |
+
"rstrip": false,
|
858 |
+
"single_word": false,
|
859 |
+
"special": false
|
860 |
+
},
|
861 |
+
"107": {
|
862 |
+
"content": "<end_of_turn>",
|
863 |
+
"lstrip": false,
|
864 |
+
"normalized": true,
|
865 |
+
"rstrip": false,
|
866 |
+
"single_word": false,
|
867 |
+
"special": false
|
868 |
+
},
|
869 |
+
"108": {
|
870 |
+
"content": "\n",
|
871 |
+
"lstrip": false,
|
872 |
+
"normalized": true,
|
873 |
+
"rstrip": false,
|
874 |
+
"single_word": false,
|
875 |
+
"special": false
|
876 |
+
},
|
877 |
+
"109": {
|
878 |
+
"content": "\n\n",
|
879 |
+
"lstrip": false,
|
880 |
+
"normalized": true,
|
881 |
+
"rstrip": false,
|
882 |
+
"single_word": false,
|
883 |
+
"special": false
|
884 |
+
},
|
885 |
+
"110": {
|
886 |
+
"content": "\n\n\n",
|
887 |
+
"lstrip": false,
|
888 |
+
"normalized": true,
|
889 |
+
"rstrip": false,
|
890 |
+
"single_word": false,
|
891 |
+
"special": false
|
892 |
+
},
|
893 |
+
"111": {
|
894 |
+
"content": "\n\n\n\n",
|
895 |
+
"lstrip": false,
|
896 |
+
"normalized": true,
|
897 |
+
"rstrip": false,
|
898 |
+
"single_word": false,
|
899 |
+
"special": false
|
900 |
+
},
|
901 |
+
"112": {
|
902 |
+
"content": "\n\n\n\n\n",
|
903 |
+
"lstrip": false,
|
904 |
+
"normalized": true,
|
905 |
+
"rstrip": false,
|
906 |
+
"single_word": false,
|
907 |
+
"special": false
|
908 |
+
},
|
909 |
+
"113": {
|
910 |
+
"content": "\n\n\n\n\n\n",
|
911 |
+
"lstrip": false,
|
912 |
+
"normalized": true,
|
913 |
+
"rstrip": false,
|
914 |
+
"single_word": false,
|
915 |
+
"special": false
|
916 |
+
},
|
917 |
+
"114": {
|
918 |
+
"content": "\n\n\n\n\n\n\n",
|
919 |
+
"lstrip": false,
|
920 |
+
"normalized": true,
|
921 |
+
"rstrip": false,
|
922 |
+
"single_word": false,
|
923 |
+
"special": false
|
924 |
+
},
|
925 |
+
"115": {
|
926 |
+
"content": "\n\n\n\n\n\n\n\n",
|
927 |
+
"lstrip": false,
|
928 |
+
"normalized": true,
|
929 |
+
"rstrip": false,
|
930 |
+
"single_word": false,
|
931 |
+
"special": false
|
932 |
+
},
|
933 |
+
"116": {
|
934 |
+
"content": "\n\n\n\n\n\n\n\n\n",
|
935 |
+
"lstrip": false,
|
936 |
+
"normalized": true,
|
937 |
+
"rstrip": false,
|
938 |
+
"single_word": false,
|
939 |
+
"special": false
|
940 |
+
},
|
941 |
+
"117": {
|
942 |
+
"content": "\n\n\n\n\n\n\n\n\n\n",
|
943 |
+
"lstrip": false,
|
944 |
+
"normalized": true,
|
945 |
+
"rstrip": false,
|
946 |
+
"single_word": false,
|
947 |
+
"special": false
|
948 |
+
},
|
949 |
+
"118": {
|
950 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n",
|
951 |
+
"lstrip": false,
|
952 |
+
"normalized": true,
|
953 |
+
"rstrip": false,
|
954 |
+
"single_word": false,
|
955 |
+
"special": false
|
956 |
+
},
|
957 |
+
"119": {
|
958 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n",
|
959 |
+
"lstrip": false,
|
960 |
+
"normalized": true,
|
961 |
+
"rstrip": false,
|
962 |
+
"single_word": false,
|
963 |
+
"special": false
|
964 |
+
},
|
965 |
+
"120": {
|
966 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
967 |
+
"lstrip": false,
|
968 |
+
"normalized": true,
|
969 |
+
"rstrip": false,
|
970 |
+
"single_word": false,
|
971 |
+
"special": false
|
972 |
+
},
|
973 |
+
"121": {
|
974 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
975 |
+
"lstrip": false,
|
976 |
+
"normalized": true,
|
977 |
+
"rstrip": false,
|
978 |
+
"single_word": false,
|
979 |
+
"special": false
|
980 |
+
},
|
981 |
+
"122": {
|
982 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
983 |
+
"lstrip": false,
|
984 |
+
"normalized": true,
|
985 |
+
"rstrip": false,
|
986 |
+
"single_word": false,
|
987 |
+
"special": false
|
988 |
+
},
|
989 |
+
"123": {
|
990 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
991 |
+
"lstrip": false,
|
992 |
+
"normalized": true,
|
993 |
+
"rstrip": false,
|
994 |
+
"single_word": false,
|
995 |
+
"special": false
|
996 |
+
},
|
997 |
+
"124": {
|
998 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
999 |
+
"lstrip": false,
|
1000 |
+
"normalized": true,
|
1001 |
+
"rstrip": false,
|
1002 |
+
"single_word": false,
|
1003 |
+
"special": false
|
1004 |
+
},
|
1005 |
+
"125": {
|
1006 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1007 |
+
"lstrip": false,
|
1008 |
+
"normalized": true,
|
1009 |
+
"rstrip": false,
|
1010 |
+
"single_word": false,
|
1011 |
+
"special": false
|
1012 |
+
},
|
1013 |
+
"126": {
|
1014 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1015 |
+
"lstrip": false,
|
1016 |
+
"normalized": true,
|
1017 |
+
"rstrip": false,
|
1018 |
+
"single_word": false,
|
1019 |
+
"special": false
|
1020 |
+
},
|
1021 |
+
"127": {
|
1022 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1023 |
+
"lstrip": false,
|
1024 |
+
"normalized": true,
|
1025 |
+
"rstrip": false,
|
1026 |
+
"single_word": false,
|
1027 |
+
"special": false
|
1028 |
+
},
|
1029 |
+
"128": {
|
1030 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1031 |
+
"lstrip": false,
|
1032 |
+
"normalized": true,
|
1033 |
+
"rstrip": false,
|
1034 |
+
"single_word": false,
|
1035 |
+
"special": false
|
1036 |
+
},
|
1037 |
+
"129": {
|
1038 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1039 |
+
"lstrip": false,
|
1040 |
+
"normalized": true,
|
1041 |
+
"rstrip": false,
|
1042 |
+
"single_word": false,
|
1043 |
+
"special": false
|
1044 |
+
},
|
1045 |
+
"130": {
|
1046 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1047 |
+
"lstrip": false,
|
1048 |
+
"normalized": true,
|
1049 |
+
"rstrip": false,
|
1050 |
+
"single_word": false,
|
1051 |
+
"special": false
|
1052 |
+
},
|
1053 |
+
"131": {
|
1054 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1055 |
+
"lstrip": false,
|
1056 |
+
"normalized": true,
|
1057 |
+
"rstrip": false,
|
1058 |
+
"single_word": false,
|
1059 |
+
"special": false
|
1060 |
+
},
|
1061 |
+
"132": {
|
1062 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1063 |
+
"lstrip": false,
|
1064 |
+
"normalized": true,
|
1065 |
+
"rstrip": false,
|
1066 |
+
"single_word": false,
|
1067 |
+
"special": false
|
1068 |
+
},
|
1069 |
+
"133": {
|
1070 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1071 |
+
"lstrip": false,
|
1072 |
+
"normalized": true,
|
1073 |
+
"rstrip": false,
|
1074 |
+
"single_word": false,
|
1075 |
+
"special": false
|
1076 |
+
},
|
1077 |
+
"134": {
|
1078 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1079 |
+
"lstrip": false,
|
1080 |
+
"normalized": true,
|
1081 |
+
"rstrip": false,
|
1082 |
+
"single_word": false,
|
1083 |
+
"special": false
|
1084 |
+
},
|
1085 |
+
"135": {
|
1086 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1087 |
+
"lstrip": false,
|
1088 |
+
"normalized": true,
|
1089 |
+
"rstrip": false,
|
1090 |
+
"single_word": false,
|
1091 |
+
"special": false
|
1092 |
+
},
|
1093 |
+
"136": {
|
1094 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1095 |
+
"lstrip": false,
|
1096 |
+
"normalized": true,
|
1097 |
+
"rstrip": false,
|
1098 |
+
"single_word": false,
|
1099 |
+
"special": false
|
1100 |
+
},
|
1101 |
+
"137": {
|
1102 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1103 |
+
"lstrip": false,
|
1104 |
+
"normalized": true,
|
1105 |
+
"rstrip": false,
|
1106 |
+
"single_word": false,
|
1107 |
+
"special": false
|
1108 |
+
},
|
1109 |
+
"138": {
|
1110 |
+
"content": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
|
1111 |
+
"lstrip": false,
|
1112 |
+
"normalized": true,
|
1113 |
+
"rstrip": false,
|
1114 |
+
"single_word": false,
|
1115 |
+
"special": false
|
1116 |
+
},
|
1117 |
+
"139": {
|
1118 |
+
"content": "▁▁",
|
1119 |
+
"lstrip": false,
|
1120 |
+
"normalized": true,
|
1121 |
+
"rstrip": false,
|
1122 |
+
"single_word": false,
|
1123 |
+
"special": false
|
1124 |
+
},
|
1125 |
+
"140": {
|
1126 |
+
"content": "▁▁▁",
|
1127 |
+
"lstrip": false,
|
1128 |
+
"normalized": true,
|
1129 |
+
"rstrip": false,
|
1130 |
+
"single_word": false,
|
1131 |
+
"special": false
|
1132 |
+
},
|
1133 |
+
"141": {
|
1134 |
+
"content": "▁▁▁▁",
|
1135 |
+
"lstrip": false,
|
1136 |
+
"normalized": true,
|
1137 |
+
"rstrip": false,
|
1138 |
+
"single_word": false,
|
1139 |
+
"special": false
|
1140 |
+
},
|
1141 |
+
"142": {
|
1142 |
+
"content": "▁▁▁▁▁",
|
1143 |
+
"lstrip": false,
|
1144 |
+
"normalized": true,
|
1145 |
+
"rstrip": false,
|
1146 |
+
"single_word": false,
|
1147 |
+
"special": false
|
1148 |
+
},
|
1149 |
+
"143": {
|
1150 |
+
"content": "▁▁▁▁▁▁",
|
1151 |
+
"lstrip": false,
|
1152 |
+
"normalized": true,
|
1153 |
+
"rstrip": false,
|
1154 |
+
"single_word": false,
|
1155 |
+
"special": false
|
1156 |
+
},
|
1157 |
+
"144": {
|
1158 |
+
"content": "▁▁▁▁▁▁▁",
|
1159 |
+
"lstrip": false,
|
1160 |
+
"normalized": true,
|
1161 |
+
"rstrip": false,
|
1162 |
+
"single_word": false,
|
1163 |
+
"special": false
|
1164 |
+
},
|
1165 |
+
"145": {
|
1166 |
+
"content": "▁▁▁▁▁▁▁▁",
|
1167 |
+
"lstrip": false,
|
1168 |
+
"normalized": true,
|
1169 |
+
"rstrip": false,
|
1170 |
+
"single_word": false,
|
1171 |
+
"special": false
|
1172 |
+
},
|
1173 |
+
"146": {
|
1174 |
+
"content": "▁▁▁▁▁▁▁▁▁",
|
1175 |
+
"lstrip": false,
|
1176 |
+
"normalized": true,
|
1177 |
+
"rstrip": false,
|
1178 |
+
"single_word": false,
|
1179 |
+
"special": false
|
1180 |
+
},
|
1181 |
+
"147": {
|
1182 |
+
"content": "▁▁▁▁▁▁▁▁▁▁",
|
1183 |
+
"lstrip": false,
|
1184 |
+
"normalized": true,
|
1185 |
+
"rstrip": false,
|
1186 |
+
"single_word": false,
|
1187 |
+
"special": false
|
1188 |
+
},
|
1189 |
+
"148": {
|
1190 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁",
|
1191 |
+
"lstrip": false,
|
1192 |
+
"normalized": true,
|
1193 |
+
"rstrip": false,
|
1194 |
+
"single_word": false,
|
1195 |
+
"special": false
|
1196 |
+
},
|
1197 |
+
"149": {
|
1198 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁",
|
1199 |
+
"lstrip": false,
|
1200 |
+
"normalized": true,
|
1201 |
+
"rstrip": false,
|
1202 |
+
"single_word": false,
|
1203 |
+
"special": false
|
1204 |
+
},
|
1205 |
+
"150": {
|
1206 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1207 |
+
"lstrip": false,
|
1208 |
+
"normalized": true,
|
1209 |
+
"rstrip": false,
|
1210 |
+
"single_word": false,
|
1211 |
+
"special": false
|
1212 |
+
},
|
1213 |
+
"151": {
|
1214 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1215 |
+
"lstrip": false,
|
1216 |
+
"normalized": true,
|
1217 |
+
"rstrip": false,
|
1218 |
+
"single_word": false,
|
1219 |
+
"special": false
|
1220 |
+
},
|
1221 |
+
"152": {
|
1222 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1223 |
+
"lstrip": false,
|
1224 |
+
"normalized": true,
|
1225 |
+
"rstrip": false,
|
1226 |
+
"single_word": false,
|
1227 |
+
"special": false
|
1228 |
+
},
|
1229 |
+
"153": {
|
1230 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1231 |
+
"lstrip": false,
|
1232 |
+
"normalized": true,
|
1233 |
+
"rstrip": false,
|
1234 |
+
"single_word": false,
|
1235 |
+
"special": false
|
1236 |
+
},
|
1237 |
+
"154": {
|
1238 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1239 |
+
"lstrip": false,
|
1240 |
+
"normalized": true,
|
1241 |
+
"rstrip": false,
|
1242 |
+
"single_word": false,
|
1243 |
+
"special": false
|
1244 |
+
},
|
1245 |
+
"155": {
|
1246 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1247 |
+
"lstrip": false,
|
1248 |
+
"normalized": true,
|
1249 |
+
"rstrip": false,
|
1250 |
+
"single_word": false,
|
1251 |
+
"special": false
|
1252 |
+
},
|
1253 |
+
"156": {
|
1254 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1255 |
+
"lstrip": false,
|
1256 |
+
"normalized": true,
|
1257 |
+
"rstrip": false,
|
1258 |
+
"single_word": false,
|
1259 |
+
"special": false
|
1260 |
+
},
|
1261 |
+
"157": {
|
1262 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1263 |
+
"lstrip": false,
|
1264 |
+
"normalized": true,
|
1265 |
+
"rstrip": false,
|
1266 |
+
"single_word": false,
|
1267 |
+
"special": false
|
1268 |
+
},
|
1269 |
+
"158": {
|
1270 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1271 |
+
"lstrip": false,
|
1272 |
+
"normalized": true,
|
1273 |
+
"rstrip": false,
|
1274 |
+
"single_word": false,
|
1275 |
+
"special": false
|
1276 |
+
},
|
1277 |
+
"159": {
|
1278 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1279 |
+
"lstrip": false,
|
1280 |
+
"normalized": true,
|
1281 |
+
"rstrip": false,
|
1282 |
+
"single_word": false,
|
1283 |
+
"special": false
|
1284 |
+
},
|
1285 |
+
"160": {
|
1286 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1287 |
+
"lstrip": false,
|
1288 |
+
"normalized": true,
|
1289 |
+
"rstrip": false,
|
1290 |
+
"single_word": false,
|
1291 |
+
"special": false
|
1292 |
+
},
|
1293 |
+
"161": {
|
1294 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1295 |
+
"lstrip": false,
|
1296 |
+
"normalized": true,
|
1297 |
+
"rstrip": false,
|
1298 |
+
"single_word": false,
|
1299 |
+
"special": false
|
1300 |
+
},
|
1301 |
+
"162": {
|
1302 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1303 |
+
"lstrip": false,
|
1304 |
+
"normalized": true,
|
1305 |
+
"rstrip": false,
|
1306 |
+
"single_word": false,
|
1307 |
+
"special": false
|
1308 |
+
},
|
1309 |
+
"163": {
|
1310 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1311 |
+
"lstrip": false,
|
1312 |
+
"normalized": true,
|
1313 |
+
"rstrip": false,
|
1314 |
+
"single_word": false,
|
1315 |
+
"special": false
|
1316 |
+
},
|
1317 |
+
"164": {
|
1318 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1319 |
+
"lstrip": false,
|
1320 |
+
"normalized": true,
|
1321 |
+
"rstrip": false,
|
1322 |
+
"single_word": false,
|
1323 |
+
"special": false
|
1324 |
+
},
|
1325 |
+
"165": {
|
1326 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1327 |
+
"lstrip": false,
|
1328 |
+
"normalized": true,
|
1329 |
+
"rstrip": false,
|
1330 |
+
"single_word": false,
|
1331 |
+
"special": false
|
1332 |
+
},
|
1333 |
+
"166": {
|
1334 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1335 |
+
"lstrip": false,
|
1336 |
+
"normalized": true,
|
1337 |
+
"rstrip": false,
|
1338 |
+
"single_word": false,
|
1339 |
+
"special": false
|
1340 |
+
},
|
1341 |
+
"167": {
|
1342 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1343 |
+
"lstrip": false,
|
1344 |
+
"normalized": true,
|
1345 |
+
"rstrip": false,
|
1346 |
+
"single_word": false,
|
1347 |
+
"special": false
|
1348 |
+
},
|
1349 |
+
"168": {
|
1350 |
+
"content": "▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁",
|
1351 |
+
"lstrip": false,
|
1352 |
+
"normalized": true,
|
1353 |
+
"rstrip": false,
|
1354 |
+
"single_word": false,
|
1355 |
+
"special": false
|
1356 |
+
},
|
1357 |
+
"169": {
|
1358 |
+
"content": "<table>",
|
1359 |
+
"lstrip": false,
|
1360 |
+
"normalized": true,
|
1361 |
+
"rstrip": false,
|
1362 |
+
"single_word": false,
|
1363 |
+
"special": false
|
1364 |
+
},
|
1365 |
+
"170": {
|
1366 |
+
"content": "<caption>",
|
1367 |
+
"lstrip": false,
|
1368 |
+
"normalized": true,
|
1369 |
+
"rstrip": false,
|
1370 |
+
"single_word": false,
|
1371 |
+
"special": false
|
1372 |
+
},
|
1373 |
+
"171": {
|
1374 |
+
"content": "<thead>",
|
1375 |
+
"lstrip": false,
|
1376 |
+
"normalized": true,
|
1377 |
+
"rstrip": false,
|
1378 |
+
"single_word": false,
|
1379 |
+
"special": false
|
1380 |
+
},
|
1381 |
+
"172": {
|
1382 |
+
"content": "<tbody>",
|
1383 |
+
"lstrip": false,
|
1384 |
+
"normalized": true,
|
1385 |
+
"rstrip": false,
|
1386 |
+
"single_word": false,
|
1387 |
+
"special": false
|
1388 |
+
},
|
1389 |
+
"173": {
|
1390 |
+
"content": "<tfoot>",
|
1391 |
+
"lstrip": false,
|
1392 |
+
"normalized": true,
|
1393 |
+
"rstrip": false,
|
1394 |
+
"single_word": false,
|
1395 |
+
"special": false
|
1396 |
+
},
|
1397 |
+
"174": {
|
1398 |
+
"content": "<tr>",
|
1399 |
+
"lstrip": false,
|
1400 |
+
"normalized": true,
|
1401 |
+
"rstrip": false,
|
1402 |
+
"single_word": false,
|
1403 |
+
"special": false
|
1404 |
+
},
|
1405 |
+
"175": {
|
1406 |
+
"content": "<th>",
|
1407 |
+
"lstrip": false,
|
1408 |
+
"normalized": true,
|
1409 |
+
"rstrip": false,
|
1410 |
+
"single_word": false,
|
1411 |
+
"special": false
|
1412 |
+
},
|
1413 |
+
"176": {
|
1414 |
+
"content": "<td>",
|
1415 |
+
"lstrip": false,
|
1416 |
+
"normalized": true,
|
1417 |
+
"rstrip": false,
|
1418 |
+
"single_word": false,
|
1419 |
+
"special": false
|
1420 |
+
},
|
1421 |
+
"177": {
|
1422 |
+
"content": "</table>",
|
1423 |
+
"lstrip": false,
|
1424 |
+
"normalized": true,
|
1425 |
+
"rstrip": false,
|
1426 |
+
"single_word": false,
|
1427 |
+
"special": false
|
1428 |
+
},
|
1429 |
+
"178": {
|
1430 |
+
"content": "</caption>",
|
1431 |
+
"lstrip": false,
|
1432 |
+
"normalized": true,
|
1433 |
+
"rstrip": false,
|
1434 |
+
"single_word": false,
|
1435 |
+
"special": false
|
1436 |
+
},
|
1437 |
+
"179": {
|
1438 |
+
"content": "</thead>",
|
1439 |
+
"lstrip": false,
|
1440 |
+
"normalized": true,
|
1441 |
+
"rstrip": false,
|
1442 |
+
"single_word": false,
|
1443 |
+
"special": false
|
1444 |
+
},
|
1445 |
+
"180": {
|
1446 |
+
"content": "</tbody>",
|
1447 |
+
"lstrip": false,
|
1448 |
+
"normalized": true,
|
1449 |
+
"rstrip": false,
|
1450 |
+
"single_word": false,
|
1451 |
+
"special": false
|
1452 |
+
},
|
1453 |
+
"181": {
|
1454 |
+
"content": "</tfoot>",
|
1455 |
+
"lstrip": false,
|
1456 |
+
"normalized": true,
|
1457 |
+
"rstrip": false,
|
1458 |
+
"single_word": false,
|
1459 |
+
"special": false
|
1460 |
+
},
|
1461 |
+
"182": {
|
1462 |
+
"content": "</tr>",
|
1463 |
+
"lstrip": false,
|
1464 |
+
"normalized": true,
|
1465 |
+
"rstrip": false,
|
1466 |
+
"single_word": false,
|
1467 |
+
"special": false
|
1468 |
+
},
|
1469 |
+
"183": {
|
1470 |
+
"content": "</th>",
|
1471 |
+
"lstrip": false,
|
1472 |
+
"normalized": true,
|
1473 |
+
"rstrip": false,
|
1474 |
+
"single_word": false,
|
1475 |
+
"special": false
|
1476 |
+
},
|
1477 |
+
"184": {
|
1478 |
+
"content": "</td>",
|
1479 |
+
"lstrip": false,
|
1480 |
+
"normalized": true,
|
1481 |
+
"rstrip": false,
|
1482 |
+
"single_word": false,
|
1483 |
+
"special": false
|
1484 |
+
},
|
1485 |
+
"185": {
|
1486 |
+
"content": "<h1>",
|
1487 |
+
"lstrip": false,
|
1488 |
+
"normalized": true,
|
1489 |
+
"rstrip": false,
|
1490 |
+
"single_word": false,
|
1491 |
+
"special": false
|
1492 |
+
},
|
1493 |
+
"186": {
|
1494 |
+
"content": "<h2>",
|
1495 |
+
"lstrip": false,
|
1496 |
+
"normalized": true,
|
1497 |
+
"rstrip": false,
|
1498 |
+
"single_word": false,
|
1499 |
+
"special": false
|
1500 |
+
},
|
1501 |
+
"187": {
|
1502 |
+
"content": "<h3>",
|
1503 |
+
"lstrip": false,
|
1504 |
+
"normalized": true,
|
1505 |
+
"rstrip": false,
|
1506 |
+
"single_word": false,
|
1507 |
+
"special": false
|
1508 |
+
},
|
1509 |
+
"188": {
|
1510 |
+
"content": "<h4>",
|
1511 |
+
"lstrip": false,
|
1512 |
+
"normalized": true,
|
1513 |
+
"rstrip": false,
|
1514 |
+
"single_word": false,
|
1515 |
+
"special": false
|
1516 |
+
},
|
1517 |
+
"189": {
|
1518 |
+
"content": "<h5>",
|
1519 |
+
"lstrip": false,
|
1520 |
+
"normalized": true,
|
1521 |
+
"rstrip": false,
|
1522 |
+
"single_word": false,
|
1523 |
+
"special": false
|
1524 |
+
},
|
1525 |
+
"190": {
|
1526 |
+
"content": "<h6>",
|
1527 |
+
"lstrip": false,
|
1528 |
+
"normalized": true,
|
1529 |
+
"rstrip": false,
|
1530 |
+
"single_word": false,
|
1531 |
+
"special": false
|
1532 |
+
},
|
1533 |
+
"191": {
|
1534 |
+
"content": "<blockquote>",
|
1535 |
+
"lstrip": false,
|
1536 |
+
"normalized": true,
|
1537 |
+
"rstrip": false,
|
1538 |
+
"single_word": false,
|
1539 |
+
"special": false
|
1540 |
+
},
|
1541 |
+
"192": {
|
1542 |
+
"content": "</h1>",
|
1543 |
+
"lstrip": false,
|
1544 |
+
"normalized": true,
|
1545 |
+
"rstrip": false,
|
1546 |
+
"single_word": false,
|
1547 |
+
"special": false
|
1548 |
+
},
|
1549 |
+
"193": {
|
1550 |
+
"content": "</h2>",
|
1551 |
+
"lstrip": false,
|
1552 |
+
"normalized": true,
|
1553 |
+
"rstrip": false,
|
1554 |
+
"single_word": false,
|
1555 |
+
"special": false
|
1556 |
+
},
|
1557 |
+
"194": {
|
1558 |
+
"content": "</h3>",
|
1559 |
+
"lstrip": false,
|
1560 |
+
"normalized": true,
|
1561 |
+
"rstrip": false,
|
1562 |
+
"single_word": false,
|
1563 |
+
"special": false
|
1564 |
+
},
|
1565 |
+
"195": {
|
1566 |
+
"content": "</h4>",
|
1567 |
+
"lstrip": false,
|
1568 |
+
"normalized": true,
|
1569 |
+
"rstrip": false,
|
1570 |
+
"single_word": false,
|
1571 |
+
"special": false
|
1572 |
+
},
|
1573 |
+
"196": {
|
1574 |
+
"content": "</h5>",
|
1575 |
+
"lstrip": false,
|
1576 |
+
"normalized": true,
|
1577 |
+
"rstrip": false,
|
1578 |
+
"single_word": false,
|
1579 |
+
"special": false
|
1580 |
+
},
|
1581 |
+
"197": {
|
1582 |
+
"content": "</h6>",
|
1583 |
+
"lstrip": false,
|
1584 |
+
"normalized": true,
|
1585 |
+
"rstrip": false,
|
1586 |
+
"single_word": false,
|
1587 |
+
"special": false
|
1588 |
+
},
|
1589 |
+
"198": {
|
1590 |
+
"content": "</blockquote>",
|
1591 |
+
"lstrip": false,
|
1592 |
+
"normalized": true,
|
1593 |
+
"rstrip": false,
|
1594 |
+
"single_word": false,
|
1595 |
+
"special": false
|
1596 |
+
},
|
1597 |
+
"199": {
|
1598 |
+
"content": "<strong>",
|
1599 |
+
"lstrip": false,
|
1600 |
+
"normalized": true,
|
1601 |
+
"rstrip": false,
|
1602 |
+
"single_word": false,
|
1603 |
+
"special": false
|
1604 |
+
},
|
1605 |
+
"200": {
|
1606 |
+
"content": "<em>",
|
1607 |
+
"lstrip": false,
|
1608 |
+
"normalized": true,
|
1609 |
+
"rstrip": false,
|
1610 |
+
"single_word": false,
|
1611 |
+
"special": false
|
1612 |
+
},
|
1613 |
+
"201": {
|
1614 |
+
"content": "<b>",
|
1615 |
+
"lstrip": false,
|
1616 |
+
"normalized": true,
|
1617 |
+
"rstrip": false,
|
1618 |
+
"single_word": false,
|
1619 |
+
"special": false
|
1620 |
+
},
|
1621 |
+
"202": {
|
1622 |
+
"content": "<i>",
|
1623 |
+
"lstrip": false,
|
1624 |
+
"normalized": true,
|
1625 |
+
"rstrip": false,
|
1626 |
+
"single_word": false,
|
1627 |
+
"special": false
|
1628 |
+
},
|
1629 |
+
"203": {
|
1630 |
+
"content": "<u>",
|
1631 |
+
"lstrip": false,
|
1632 |
+
"normalized": true,
|
1633 |
+
"rstrip": false,
|
1634 |
+
"single_word": false,
|
1635 |
+
"special": false
|
1636 |
+
},
|
1637 |
+
"204": {
|
1638 |
+
"content": "<s>",
|
1639 |
+
"lstrip": false,
|
1640 |
+
"normalized": true,
|
1641 |
+
"rstrip": false,
|
1642 |
+
"single_word": false,
|
1643 |
+
"special": false
|
1644 |
+
},
|
1645 |
+
"205": {
|
1646 |
+
"content": "<sub>",
|
1647 |
+
"lstrip": false,
|
1648 |
+
"normalized": true,
|
1649 |
+
"rstrip": false,
|
1650 |
+
"single_word": false,
|
1651 |
+
"special": false
|
1652 |
+
},
|
1653 |
+
"206": {
|
1654 |
+
"content": "<sup>",
|
1655 |
+
"lstrip": false,
|
1656 |
+
"normalized": true,
|
1657 |
+
"rstrip": false,
|
1658 |
+
"single_word": false,
|
1659 |
+
"special": false
|
1660 |
+
},
|
1661 |
+
"207": {
|
1662 |
+
"content": "<code>",
|
1663 |
+
"lstrip": false,
|
1664 |
+
"normalized": true,
|
1665 |
+
"rstrip": false,
|
1666 |
+
"single_word": false,
|
1667 |
+
"special": false
|
1668 |
+
},
|
1669 |
+
"208": {
|
1670 |
+
"content": "</strong>",
|
1671 |
+
"lstrip": false,
|
1672 |
+
"normalized": true,
|
1673 |
+
"rstrip": false,
|
1674 |
+
"single_word": false,
|
1675 |
+
"special": false
|
1676 |
+
},
|
1677 |
+
"209": {
|
1678 |
+
"content": "</em>",
|
1679 |
+
"lstrip": false,
|
1680 |
+
"normalized": true,
|
1681 |
+
"rstrip": false,
|
1682 |
+
"single_word": false,
|
1683 |
+
"special": false
|
1684 |
+
},
|
1685 |
+
"210": {
|
1686 |
+
"content": "</b>",
|
1687 |
+
"lstrip": false,
|
1688 |
+
"normalized": true,
|
1689 |
+
"rstrip": false,
|
1690 |
+
"single_word": false,
|
1691 |
+
"special": false
|
1692 |
+
},
|
1693 |
+
"211": {
|
1694 |
+
"content": "</i>",
|
1695 |
+
"lstrip": false,
|
1696 |
+
"normalized": true,
|
1697 |
+
"rstrip": false,
|
1698 |
+
"single_word": false,
|
1699 |
+
"special": false
|
1700 |
+
},
|
1701 |
+
"212": {
|
1702 |
+
"content": "</u>",
|
1703 |
+
"lstrip": false,
|
1704 |
+
"normalized": true,
|
1705 |
+
"rstrip": false,
|
1706 |
+
"single_word": false,
|
1707 |
+
"special": false
|
1708 |
+
},
|
1709 |
+
"213": {
|
1710 |
+
"content": "</s>",
|
1711 |
+
"lstrip": false,
|
1712 |
+
"normalized": true,
|
1713 |
+
"rstrip": false,
|
1714 |
+
"single_word": false,
|
1715 |
+
"special": false
|
1716 |
+
},
|
1717 |
+
"214": {
|
1718 |
+
"content": "</sub>",
|
1719 |
+
"lstrip": false,
|
1720 |
+
"normalized": true,
|
1721 |
+
"rstrip": false,
|
1722 |
+
"single_word": false,
|
1723 |
+
"special": false
|
1724 |
+
},
|
1725 |
+
"215": {
|
1726 |
+
"content": "</sup>",
|
1727 |
+
"lstrip": false,
|
1728 |
+
"normalized": true,
|
1729 |
+
"rstrip": false,
|
1730 |
+
"single_word": false,
|
1731 |
+
"special": false
|
1732 |
+
},
|
1733 |
+
"216": {
|
1734 |
+
"content": "</code>",
|
1735 |
+
"lstrip": false,
|
1736 |
+
"normalized": true,
|
1737 |
+
"rstrip": false,
|
1738 |
+
"single_word": false,
|
1739 |
+
"special": false
|
1740 |
+
},
|
1741 |
+
"257152": {
|
1742 |
+
"content": "<image>",
|
1743 |
+
"lstrip": false,
|
1744 |
+
"normalized": false,
|
1745 |
+
"rstrip": false,
|
1746 |
+
"single_word": false,
|
1747 |
+
"special": true
|
1748 |
+
}
|
1749 |
+
},
|
1750 |
+
"additional_special_tokens": [
|
1751 |
+
"<image>"
|
1752 |
+
],
|
1753 |
+
"bos_token": "<bos>",
|
1754 |
+
"clean_up_tokenization_spaces": false,
|
1755 |
+
"eos_token": "<eos>",
|
1756 |
+
"extra_special_tokens": {},
|
1757 |
+
"max_length": 48,
|
1758 |
+
"model_max_length": 1000000000000000019884624838656,
|
1759 |
+
"pad_to_multiple_of": null,
|
1760 |
+
"pad_token": "<pad>",
|
1761 |
+
"pad_token_type_id": 0,
|
1762 |
+
"padding_side": "right",
|
1763 |
+
"processor_class": "PaliGemmaProcessor",
|
1764 |
+
"sp_model_kwargs": {},
|
1765 |
+
"spaces_between_special_tokens": false,
|
1766 |
+
"stride": 0,
|
1767 |
+
"tokenizer_class": "GemmaTokenizer",
|
1768 |
+
"truncation_side": "right",
|
1769 |
+
"truncation_strategy": "longest_first",
|
1770 |
+
"unk_token": "<unk>",
|
1771 |
+
"use_default_system_prompt": false
|
1772 |
+
}
|
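The settings above register "<image>" (id 257152) as the only additional special token and declare PaliGemmaProcessor / GemmaTokenizer as the processor and tokenizer classes. A minimal loading sketch, assuming a placeholder repository id REPO_ID (the placeholder is not part of the upload):

from transformers import AutoProcessor, AutoTokenizer

REPO_ID = "path/to/this-repo"  # hypothetical placeholder for this model repository

# PaliGemmaProcessor bundles the image processor with the Gemma tokenizer,
# as declared by "processor_class" in tokenizer_config.json.
processor = AutoProcessor.from_pretrained(REPO_ID)

# The tokenizer can also be loaded on its own; padding_side is "right" and
# "<image>" is the only additional special token.
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
print(tokenizer.convert_tokens_to_ids("<image>"))  # expected: 257152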
value_query.py
ADDED
@@ -0,0 +1,1155 @@
1 |
+
import math
|
2 |
+
from copy import deepcopy
|
3 |
+
from functools import partial
|
4 |
+
from typing import Callable, Optional, Sequence, Tuple, Union
|
5 |
+
|
6 |
+
import array_typing as at
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from beartype import beartype as typechecker
|
11 |
+
from jaxtyping import Float, jaxtyped
|
12 |
+
from torch.distributions import Independent, Normal, TransformedDistribution
|
13 |
+
from torch.distributions.transforms import (
|
14 |
+
AffineTransform,
|
15 |
+
ComposeTransform,
|
16 |
+
TanhTransform,
|
17 |
+
)
|
18 |
+
from torch.optim import Adam, AdamW, Optimizer
|
19 |
+
from torch.optim.lr_scheduler import (
|
20 |
+
LambdaLR,
|
21 |
+
)
|
22 |
+
from transformers import (
|
23 |
+
AutoConfig,
|
24 |
+
GemmaForCausalLM,
|
25 |
+
PretrainedConfig,
|
26 |
+
PreTrainedModel,
|
27 |
+
)
|
28 |
+
from transformers.models.auto import CONFIG_MAPPING
|
29 |
+
|
30 |
+
|
31 |
+
def extend_and_repeat(tensor: torch.Tensor, dim: int, repeat: int) -> torch.Tensor:
|
32 |
+
return tensor.unsqueeze(dim).repeat_interleave(repeat, dim=dim)
|
33 |
+
|
34 |
+
|
35 |
+
def init_module_weights(module: torch.nn.Module, orthogonal_init: bool = False):
|
36 |
+
if isinstance(module, nn.Linear):
|
37 |
+
if orthogonal_init:
|
38 |
+
nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
|
39 |
+
nn.init.constant_(module.bias, 0.0)
|
40 |
+
else:
|
41 |
+
nn.init.xavier_uniform_(module.weight, gain=1e-2)
|
42 |
+
|
43 |
+
|
44 |
+
class VQHBackboneConfig(PretrainedConfig):
|
45 |
+
model_type = "VQHBackbone"
|
46 |
+
sub_configs = {"gemma_expert_config": AutoConfig}
|
47 |
+
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
gemma_expert_config: dict | None = None,
|
51 |
+
attention_implementation: str = "eager",
|
52 |
+
**kwargs,
|
53 |
+
):
|
54 |
+
self.attention_implementation = attention_implementation
|
55 |
+
|
56 |
+
if gemma_expert_config is None:
|
57 |
+
self.gemma_expert_config = CONFIG_MAPPING["gemma"](
|
58 |
+
attention_bias=False,
|
59 |
+
attention_dropout=0.0,
|
60 |
+
bos_token_id=2,
|
61 |
+
eos_token_id=1,
|
62 |
+
head_dim=256,
|
63 |
+
hidden_act="gelu_pytorch_tanh",
|
64 |
+
hidden_activation="gelu_pytorch_tanh",
|
65 |
+
hidden_size=2048,
|
66 |
+
initializer_range=0.02,
|
67 |
+
intermediate_size=4096,
|
68 |
+
max_position_embeddings=8192,
|
69 |
+
model_type="gemma",
|
70 |
+
num_attention_heads=8,
|
71 |
+
num_hidden_layers=4,
|
72 |
+
num_key_value_heads=1,
|
73 |
+
pad_token_id=0,
|
74 |
+
rms_norm_eps=1e-06,
|
75 |
+
rope_theta=10000.0,
|
76 |
+
torch_dtype="float32",
|
77 |
+
transformers_version="4.48.1",
|
78 |
+
use_cache=True,
|
79 |
+
vocab_size=257152,
|
80 |
+
)
|
81 |
+
elif isinstance(gemma_expert_config, dict):
|
82 |
+
if "model_type" not in gemma_expert_config:
|
83 |
+
gemma_expert_config["model_type"] = "gemma"
|
84 |
+
cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
|
85 |
+
self.gemma_expert_config = cfg_cls(**gemma_expert_config)
|
86 |
+
|
87 |
+
super().__init__(**kwargs)
|
88 |
+
|
89 |
+
def __post_init__(self):
|
90 |
+
super().__post_init__()
|
91 |
+
if self.attention_implementation not in ["eager", "fa2", "flex"]:
|
92 |
+
raise ValueError(
|
93 |
+
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
def apply_rope(x, positions, max_wavelength=10_000):
|
98 |
+
"""
|
99 |
+
Applies RoPE positions [B, L] to x [B, L, H, D].
|
100 |
+
"""
|
101 |
+
d_half = x.shape[-1] // 2
|
102 |
+
device = x.device
|
103 |
+
dtype = x.dtype
|
104 |
+
x = x.to(torch.float32)
|
105 |
+
|
106 |
+
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(
|
107 |
+
d_half, dtype=torch.float32, device=device
|
108 |
+
)
|
109 |
+
timescale = max_wavelength**freq_exponents
|
110 |
+
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(
|
111 |
+
torch.float32
|
112 |
+
)
|
113 |
+
|
114 |
+
radians = radians[..., None, :]
|
115 |
+
|
116 |
+
sin = torch.sin(radians) # .to(dtype=dtype)
|
117 |
+
cos = torch.cos(radians) # .to(dtype=dtype)
|
118 |
+
|
119 |
+
x1, x2 = x.split(d_half, dim=-1)
|
120 |
+
res = torch.empty_like(x)
|
121 |
+
res[..., :d_half] = x1 * cos - x2 * sin
|
122 |
+
res[..., d_half:] = x2 * cos + x1 * sin
|
123 |
+
|
124 |
+
return res.to(dtype)
|
125 |
+
|
126 |
+
|
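# A minimal usage sketch for apply_rope (hypothetical shapes, not taken from the
# surrounding code): x is [B, L, H, D], positions is an integer tensor [B, L];
# the rotation is computed in float32 and the result is cast back to x.dtype.
#
#   x = torch.randn(2, 16, 8, 256, dtype=torch.bfloat16)
#   positions = torch.arange(16).expand(2, 16)
#   out = apply_rope(x, positions)  # same shape and dtype as x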
127 |
+
class VQHBackbone(PreTrainedModel):
|
128 |
+
config_class = VQHBackboneConfig
|
129 |
+
|
130 |
+
def __init__(self, config: VQHBackboneConfig):
|
131 |
+
super().__init__(config=config)
|
132 |
+
self.config = config
|
133 |
+
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
|
134 |
+
|
135 |
+
self.to_bfloat16_like_physical_intelligence()
|
136 |
+
|
137 |
+
def train(self, mode: bool = True):
|
138 |
+
super().train(mode)
|
139 |
+
|
140 |
+
def to_bfloat16_like_physical_intelligence(self):
|
141 |
+
params_to_change_dtype = [
|
142 |
+
"language_model.model.layers",
|
143 |
+
"gemma_expert.model.layers",
|
144 |
+
]
|
145 |
+
for name, param in self.named_parameters():
|
146 |
+
if any(selector in name for selector in params_to_change_dtype):
|
147 |
+
param.data = param.data.to(dtype=torch.bfloat16)
|
148 |
+
|
149 |
+
def forward(
|
150 |
+
self,
|
151 |
+
attention_mask: Optional[torch.Tensor] = None,
|
152 |
+
position_ids: Optional[torch.LongTensor] = None,
|
153 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
154 |
+
):
|
155 |
+
# RMSNorm
|
156 |
+
head_dim = self.gemma_expert.config.head_dim
|
157 |
+
|
158 |
+
hidden_states = inputs_embeds
|
159 |
+
batch_size = hidden_states.shape[0]
|
160 |
+
for layer in self.gemma_expert.model.layers[
|
161 |
+
: self.gemma_expert.config.num_hidden_layers
|
162 |
+
]:
|
163 |
+
# normalizer = torch.tensor(model.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
164 |
+
# hidden_states = hidden_states * normalizer
|
165 |
+
hidden_states = layer.input_layernorm(hidden_states)
|
166 |
+
input_shape = hidden_states.shape[:-1]
|
167 |
+
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
168 |
+
|
169 |
+
# self attention
|
170 |
+
hidden_states = hidden_states.to(dtype=torch.bfloat16)
|
171 |
+
query_states = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
172 |
+
key_states = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
173 |
+
value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
174 |
+
|
175 |
+
query_states = apply_rope(query_states, position_ids)
|
176 |
+
key_states = apply_rope(key_states, position_ids)
|
177 |
+
|
178 |
+
attention_interface = self.get_attention_interface()
|
179 |
+
att_output = attention_interface(
|
180 |
+
attention_mask,
|
181 |
+
batch_size,
|
182 |
+
head_dim,
|
183 |
+
query_states,
|
184 |
+
key_states,
|
185 |
+
value_states,
|
186 |
+
)
|
187 |
+
|
188 |
+
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
189 |
+
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
190 |
+
|
191 |
+
out_emb = layer.self_attn.o_proj(att_output)
|
192 |
+
|
193 |
+
# first residual
|
194 |
+
out_emb += hidden_states
|
195 |
+
after_first_residual = out_emb.clone()
|
196 |
+
out_emb = layer.post_attention_layernorm(out_emb)
|
197 |
+
out_emb = layer.mlp(out_emb)
|
198 |
+
# second residual
|
199 |
+
out_emb += after_first_residual
|
200 |
+
hidden_states = out_emb
|
201 |
+
|
202 |
+
# final norm
|
203 |
+
hidden_states = self.gemma_expert.model.norm(hidden_states)
|
204 |
+
|
205 |
+
return hidden_states
|
206 |
+
|
207 |
+
def get_attention_interface(self):
|
208 |
+
if self.config.attention_implementation == "fa2":
|
209 |
+
attention_interface = self.flash_attention_forward
|
210 |
+
else:
|
211 |
+
attention_interface = self.eager_attention_forward
|
212 |
+
return attention_interface
|
213 |
+
|
214 |
+
def eager_attention_forward(
|
215 |
+
self,
|
216 |
+
attention_mask,
|
217 |
+
batch_size,
|
218 |
+
head_dim,
|
219 |
+
query_states,
|
220 |
+
key_states,
|
221 |
+
value_states,
|
222 |
+
):
|
223 |
+
num_att_heads = self.config.gemma_expert_config.num_attention_heads
|
224 |
+
num_key_value_heads = self.config.gemma_expert_config.num_key_value_heads
|
225 |
+
num_key_value_groups = num_att_heads // num_key_value_heads
|
226 |
+
|
227 |
+
# query_states: batch_size, sequence_length, num_att_head, head_dim
|
228 |
+
# key_states: batch_size, sequence_length, num_key_value_head, head_dim
|
229 |
+
# value_states: batch_size, sequence_length, num_key_value_head, head_dim
|
230 |
+
sequence_length = key_states.shape[1]
|
231 |
+
|
232 |
+
key_states = key_states[:, :, :, None, :].expand(
|
233 |
+
batch_size,
|
234 |
+
sequence_length,
|
235 |
+
num_key_value_heads,
|
236 |
+
num_key_value_groups,
|
237 |
+
head_dim,
|
238 |
+
)
|
239 |
+
key_states = key_states.reshape(
|
240 |
+
batch_size,
|
241 |
+
sequence_length,
|
242 |
+
num_key_value_heads * num_key_value_groups,
|
243 |
+
head_dim,
|
244 |
+
)
|
245 |
+
|
246 |
+
value_states = value_states[:, :, :, None, :].expand(
|
247 |
+
batch_size,
|
248 |
+
sequence_length,
|
249 |
+
num_key_value_heads,
|
250 |
+
num_key_value_groups,
|
251 |
+
head_dim,
|
252 |
+
)
|
253 |
+
value_states = value_states.reshape(
|
254 |
+
batch_size,
|
255 |
+
sequence_length,
|
256 |
+
num_key_value_heads * num_key_value_groups,
|
257 |
+
head_dim,
|
258 |
+
)
|
259 |
+
|
260 |
+
# Attention here is upcasted to float32 to match the original eager implementation.
|
261 |
+
query_states = query_states.to(dtype=torch.float32)
|
262 |
+
key_states = key_states.to(dtype=torch.float32)
|
263 |
+
|
264 |
+
query_states = query_states.transpose(1, 2)
|
265 |
+
key_states = key_states.transpose(1, 2)
|
266 |
+
|
267 |
+
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
268 |
+
att_weights *= head_dim**-0.5
|
269 |
+
big_neg = -2.3819763e38 # See gemma/modules.py
|
270 |
+
|
271 |
+
masked_att_weights = torch.where(
|
272 |
+
attention_mask[:, None, :, :], att_weights, big_neg
|
273 |
+
)
|
274 |
+
|
275 |
+
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
276 |
+
probs = probs.to(dtype=value_states.dtype)
|
277 |
+
|
278 |
+
# probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
|
279 |
+
# value_states: batch_size, sequence_length, num_att_heads, head_dim
|
280 |
+
|
281 |
+
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
|
282 |
+
|
283 |
+
att_output = att_output.permute(0, 2, 1, 3)
|
284 |
+
# we use -1 because sequence length can change
|
285 |
+
att_output = att_output.reshape(
|
286 |
+
batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim
|
287 |
+
)
|
288 |
+
|
289 |
+
return att_output
|
290 |
+
|
291 |
+
|
292 |
+
class LagrangeMultiplier(nn.Module):
|
293 |
+
def __init__(
|
294 |
+
self,
|
295 |
+
init_value: float = 1.0,
|
296 |
+
constraint_shape: Tuple[int, ...] = (),
|
297 |
+
constraint_type: str = "eq", # One of ("eq", "leq", "geq")
|
298 |
+
parameterization: Optional[
|
299 |
+
str
|
300 |
+
] = None, # One of ("softplus", "exp"), or None for equality constraints
|
301 |
+
):
|
302 |
+
super().__init__()
|
303 |
+
self.constraint_type = constraint_type
|
304 |
+
self.parameterization = parameterization
|
305 |
+
|
306 |
+
if constraint_type != "eq":
|
307 |
+
assert (
|
308 |
+
init_value > 0
|
309 |
+
), "Inequality constraints must have non-negative initial multiplier values"
|
310 |
+
|
311 |
+
if parameterization == "softplus":
|
312 |
+
init_value = torch.log(torch.exp(torch.tensor(init_value)) - 1).item()
|
313 |
+
elif parameterization == "exp":
|
314 |
+
init_value = torch.log(torch.tensor(init_value)).item()
|
315 |
+
else:
|
316 |
+
raise ValueError(
|
317 |
+
f"Invalid multiplier parameterization {parameterization}"
|
318 |
+
)
|
319 |
+
else:
|
320 |
+
assert (
|
321 |
+
parameterization is None
|
322 |
+
), "Equality constraints must have no parameterization"
|
323 |
+
|
324 |
+
self.multiplier = nn.Parameter(torch.full(constraint_shape, init_value))
|
325 |
+
|
326 |
+
def forward(
|
327 |
+
self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None
|
328 |
+
) -> torch.Tensor:
|
329 |
+
multiplier = self.multiplier
|
330 |
+
|
331 |
+
if self.constraint_type != "eq":
|
332 |
+
if self.parameterization == "softplus":
|
333 |
+
multiplier = torch.nn.functional.softplus(multiplier)
|
334 |
+
elif self.parameterization == "exp":
|
335 |
+
multiplier = torch.exp(multiplier)
|
336 |
+
else:
|
337 |
+
raise ValueError(
|
338 |
+
f"Invalid multiplier parameterization {self.parameterization}"
|
339 |
+
)
|
340 |
+
|
341 |
+
if lhs is None:
|
342 |
+
return multiplier
|
343 |
+
|
344 |
+
if rhs is None:
|
345 |
+
rhs = torch.zeros_like(lhs)
|
346 |
+
|
347 |
+
diff = lhs - rhs
|
348 |
+
|
349 |
+
assert (
|
350 |
+
diff.shape == multiplier.shape
|
351 |
+
), f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
|
352 |
+
|
353 |
+
if self.constraint_type == "eq":
|
354 |
+
return multiplier * diff
|
355 |
+
elif self.constraint_type == "geq":
|
356 |
+
return multiplier * diff
|
357 |
+
elif self.constraint_type == "leq":
|
358 |
+
return -multiplier * diff
|
359 |
+
|
360 |
+
|
361 |
+
GeqLagrangeMultiplier = partial(
|
362 |
+
LagrangeMultiplier, constraint_type="geq", parameterization="softplus"
|
363 |
+
)
|
364 |
+
|
365 |
+
LeqLagrangeMultiplier = partial(
|
366 |
+
LagrangeMultiplier, constraint_type="leq", parameterization="softplus"
|
367 |
+
)
|
368 |
+
|
369 |
+
|
370 |
+
class MLP(nn.Module):
|
371 |
+
def __init__(
|
372 |
+
self,
|
373 |
+
input_dim: int,
|
374 |
+
hidden_dims: Sequence[int],
|
375 |
+
activations: Union[Callable[[torch.Tensor], torch.Tensor], str] = "silu",
|
376 |
+
activate_final: bool = False,
|
377 |
+
use_layer_norm: bool = False,
|
378 |
+
use_group_norm: bool = False,
|
379 |
+
dropout_rate: Optional[float] = None,
|
380 |
+
):
|
381 |
+
super().__init__()
|
382 |
+
|
383 |
+
assert not (use_layer_norm and use_group_norm)
|
384 |
+
|
385 |
+
self.activate_final = activate_final
|
386 |
+
self.dropout_rate = dropout_rate
|
387 |
+
self.input_dim = input_dim
|
388 |
+
self.hidden_dims = hidden_dims
|
389 |
+
|
390 |
+
if isinstance(activations, str):
|
391 |
+
if activations == "silu" or activations == "swish":
|
392 |
+
self.activations = nn.SiLU()
|
393 |
+
else:
|
394 |
+
self.activations = getattr(nn, activations)()
|
395 |
+
else:
|
396 |
+
self.activations = activations
|
397 |
+
|
398 |
+
layers = []
|
399 |
+
|
400 |
+
for i, hidden_dim in enumerate(hidden_dims):
|
401 |
+
layers.append(nn.Linear(input_dim, hidden_dim))
|
402 |
+
nn.init.xavier_uniform_(layers[-1].weight)
|
403 |
+
nn.init.zeros_(layers[-1].bias)
|
404 |
+
|
405 |
+
input_dim = hidden_dim
|
406 |
+
|
407 |
+
if i + 1 < len(hidden_dims) or activate_final:
|
408 |
+
if dropout_rate is not None and dropout_rate > 0:
|
409 |
+
layers.append(nn.Dropout(p=dropout_rate))
|
410 |
+
|
411 |
+
if use_layer_norm:
|
412 |
+
layers.append(nn.LayerNorm(hidden_dim))
|
413 |
+
elif use_group_norm:
|
414 |
+
num_groups = min(hidden_dim, 32)
|
415 |
+
layers.append(nn.GroupNorm(num_groups, hidden_dim))
|
416 |
+
layers.append(self.activations)
|
417 |
+
|
418 |
+
self.layers = nn.ModuleList(layers)
|
419 |
+
|
420 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
421 |
+
for layer in self.layers:
|
422 |
+
x = layer(x)
|
423 |
+
|
424 |
+
return x
|
425 |
+
|
426 |
+
|
427 |
+
class TanhMultivariateNormalDiag(TransformedDistribution):
|
428 |
+
def __init__(
|
429 |
+
self,
|
430 |
+
loc: torch.Tensor,
|
431 |
+
scale_diag: torch.Tensor,
|
432 |
+
low: Optional[torch.Tensor] = None,
|
433 |
+
high: Optional[torch.Tensor] = None,
|
434 |
+
):
|
435 |
+
self.loc = loc
|
436 |
+
self.scale_diag = scale_diag
|
437 |
+
base_distribution = Independent(Normal(loc, scale_diag), 1)
|
438 |
+
|
439 |
+
transforms = []
|
440 |
+
transforms.append(TanhTransform())
|
441 |
+
if not (low is None or high is None):
|
442 |
+
transforms.append(
|
443 |
+
AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2)
|
444 |
+
)
|
445 |
+
transform = ComposeTransform(transforms)
|
446 |
+
|
447 |
+
super().__init__(base_distribution, transform)
|
448 |
+
|
449 |
+
def mode(self) -> torch.Tensor:
|
450 |
+
"""返回分布的众数"""
|
451 |
+
# For a normal distribution, the mode equals the mean.
|
452 |
+
mode = self.loc
|
453 |
+
# Apply the transforms.
|
454 |
+
for transform in self.transforms:
|
455 |
+
mode = transform(mode)
|
456 |
+
return mode
|
457 |
+
|
458 |
+
def stddev(self) -> torch.Tensor:
|
459 |
+
"""返回变换后的标准差(近似值)"""
|
460 |
+
# Note: this is only an approximation; the standard deviation after a nonlinear transform has no simple closed form.
|
461 |
+
return self.transforms[0](self.loc + self.scale_diag) - self.transforms[0](self.loc)
|
462 |
+
|
463 |
+
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
|
464 |
+
eps = 1e-6
|
465 |
+
value = torch.clamp(value, -1 + eps, 1 - eps)
|
466 |
+
return super().log_prob(value)
|
467 |
+
|
468 |
+
|
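# Note on TanhMultivariateNormalDiag: log_prob inherits TransformedDistribution's
# change-of-variables correction, log p(a) = log p(u) - sum_i log(1 - tanh(u_i)^2)
# with a = tanh(u); the clamp in log_prob above only guards atanh against overflow at +/-1.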
469 |
+
class Policy(nn.Module):
|
470 |
+
def __init__(
|
471 |
+
self,
|
472 |
+
obs_encoded_dim: int,
|
473 |
+
network: nn.Module,
|
474 |
+
action_dim: int,
|
475 |
+
std_parameterization: str = "exp", # "exp", "softplus", "fixed", or "uniform"
|
476 |
+
std_min: Optional[float] = 1e-5,
|
477 |
+
std_max: Optional[float] = 10.0,
|
478 |
+
tanh_squash_distribution: bool = False,
|
479 |
+
fixed_std: Optional[torch.Tensor] = None,
|
480 |
+
):
|
481 |
+
super().__init__()
|
482 |
+
|
483 |
+
self.obs_encoded_dim = obs_encoded_dim
|
484 |
+
self.network = network
|
485 |
+
self.action_dim = action_dim
|
486 |
+
self.std_parameterization = std_parameterization
|
487 |
+
self.std_min = std_min
|
488 |
+
self.std_max = std_max
|
489 |
+
self.tanh_squash_distribution = tanh_squash_distribution
|
490 |
+
self.fixed_std = fixed_std
|
491 |
+
|
492 |
+
self.mean_layer = nn.Linear(network.hidden_dims[-1], action_dim)
|
493 |
+
|
494 |
+
if fixed_std is None:
|
495 |
+
if std_parameterization in ["exp", "softplus"]:
|
496 |
+
self.std_layer = nn.Linear(network.hidden_dims[-1], action_dim)
|
497 |
+
elif std_parameterization == "uniform":
|
498 |
+
self.log_stds = nn.Parameter(torch.zeros(action_dim))
|
499 |
+
else:
|
500 |
+
raise ValueError(
|
501 |
+
f"Invalid std_parameterization: {self.std_parameterization}"
|
502 |
+
)
|
503 |
+
else:
|
504 |
+
assert std_parameterization == "fixed"
|
505 |
+
|
506 |
+
nn.init.xavier_uniform_(self.mean_layer.weight)
|
507 |
+
nn.init.zeros_(self.mean_layer.bias)
|
508 |
+
|
509 |
+
if fixed_std is None and std_parameterization in ["exp", "softplus"]:
|
510 |
+
nn.init.xavier_uniform_(self.std_layer.weight)
|
511 |
+
nn.init.zeros_(self.std_layer.bias)
|
512 |
+
|
513 |
+
def forward(
|
514 |
+
self, encoded_observations: torch.Tensor, temperature: float = 1.0
|
515 |
+
) -> Union[TransformedDistribution, Normal]:
|
516 |
+
outputs = self.network(encoded_observations)
|
517 |
+
|
518 |
+
means = self.mean_layer(outputs)
|
519 |
+
|
520 |
+
if self.fixed_std is None:
|
521 |
+
if self.std_parameterization == "exp":
|
522 |
+
log_stds = self.std_layer(outputs)
|
523 |
+
stds = torch.exp(log_stds)
|
524 |
+
elif self.std_parameterization == "softplus":
|
525 |
+
stds = self.std_layer(outputs)
|
526 |
+
stds = nn.functional.softplus(stds)
|
527 |
+
elif self.std_parameterization == "uniform":
|
528 |
+
stds = torch.exp(self.log_stds).expand_as(means)
|
529 |
+
else:
|
530 |
+
raise ValueError(
|
531 |
+
f"Invalid std_parameterization: {self.std_parameterization}"
|
532 |
+
)
|
533 |
+
else:
|
534 |
+
stds = self.fixed_std.to(means.device).expand_as(means)
|
535 |
+
|
536 |
+
stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(
|
537 |
+
torch.tensor(temperature)
|
538 |
+
)
|
539 |
+
|
540 |
+
if self.tanh_squash_distribution:
|
541 |
+
distribution = TanhMultivariateNormalDiag(
|
542 |
+
loc=means,
|
543 |
+
scale_diag=stds,
|
544 |
+
)
|
545 |
+
else:
|
546 |
+
distribution = Normal(loc=means, scale=stds)
|
547 |
+
|
548 |
+
return distribution
|
549 |
+
|
550 |
+
|
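# Illustrative sketch of the Policy head above (shapes assume the tanh-squashed
# distribution selected by CalQlConfig, which reduces log_prob over the action dim):
#   dist = policy(encoded_obs)      # encoded_obs: (batch, obs_encoded_dim)
#   action = dist.rsample()         # (batch, action_dim), reparameterized sample
#   log_pi = dist.log_prob(action)  # (batch,)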
551 |
+
class Critics(nn.Module):
|
552 |
+
def __init__(
|
553 |
+
self,
|
554 |
+
obs_encoded_dim: int,
|
555 |
+
networks: list[nn.Module],
|
556 |
+
num_backbones: int = 2,
|
557 |
+
init_final: Optional[float] = None,
|
558 |
+
):
|
559 |
+
super().__init__()
|
560 |
+
assert len(networks) == num_backbones
|
561 |
+
self.obs_encoded_dim = obs_encoded_dim
|
562 |
+
self.networks = nn.ModuleList(networks)
|
563 |
+
self.num_backbones = num_backbones
|
564 |
+
self.init_final = init_final
|
565 |
+
|
566 |
+
self.backbone_output_dims = networks[0].hidden_dims[-1]
|
567 |
+
|
568 |
+
if init_final is not None:
|
569 |
+
self.output_layer = nn.Linear(self.backbone_output_dims, 1)
|
570 |
+
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
571 |
+
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
572 |
+
else:
|
573 |
+
self.output_layer = nn.Linear(self.backbone_output_dims, 1)
|
574 |
+
nn.init.xavier_uniform_(self.output_layer.weight)
|
575 |
+
nn.init.zeros_(self.output_layer.bias)
|
576 |
+
|
577 |
+
@jaxtyped(typechecker=typechecker)
|
578 |
+
def forward(
|
579 |
+
self,
|
580 |
+
encoded_observations: Float[torch.Tensor, "batch {self.obs_encoded_dim}"],
|
581 |
+
actions: Float[torch.Tensor, "batch *num_actions action_dim"],
|
582 |
+
) -> Float[torch.Tensor, "{self.num_backbones} batch *num_actions"]:
|
583 |
+
if actions.ndim == 3:
|
584 |
+
# forward the q function with multiple actions on each state
|
585 |
+
encoded_observations = encoded_observations.unsqueeze(1).expand(
|
586 |
+
-1, actions.shape[1], -1
|
587 |
+
)
|
588 |
+
# HACK: check dimensions here
|
589 |
+
inputs = torch.cat([encoded_observations, actions], dim=-1)
|
590 |
+
|
591 |
+
backbone_outputs = []
|
592 |
+
for network in self.networks:
|
593 |
+
backbone_outputs.append(network(inputs))
|
594 |
+
backbone_outputs: Float[
|
595 |
+
torch.Tensor,
|
596 |
+
"{self.num_backbones} batch *num_actions {self.backbone_output_dims}",
|
597 |
+
] = torch.stack(backbone_outputs, dim=0)
|
598 |
+
|
599 |
+
value = self.output_layer(backbone_outputs)
|
600 |
+
# HACK: check output shape here
|
601 |
+
# if actions.ndim == 3:
|
602 |
+
# value = value.squeeze(-1).permute(0, 2, 1)
|
603 |
+
# else:
|
604 |
+
value = value.squeeze(-1)
|
605 |
+
return value # (num_backbones, batch, *num_actions)
|
606 |
+
|
607 |
+
|
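# Shape summary for the Critics ensemble above (names assumed for illustration):
#   critics(encoded_obs, actions)          # actions: (batch, action_dim)    -> (num_backbones, batch)
#   critics(encoded_obs, action_samples)   # actions: (batch, K, action_dim) -> (num_backbones, batch, K)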
608 |
+
class CalQlConfig(PretrainedConfig):
|
609 |
+
model_type = "calql"
|
610 |
+
|
611 |
+
def __init__(
|
612 |
+
self,
|
613 |
+
obs_encoded_dim=2048,
|
614 |
+
action_dim=70,
|
615 |
+
actor_lr=1e-4,
|
616 |
+
critic_lr=3e-4,
|
617 |
+
temp_lr=3e-4,
|
618 |
+
actor_wps=2000,
|
619 |
+
critic_wps=2000,
|
620 |
+
**kwargs,
|
621 |
+
):
|
622 |
+
self.cql_clip_diff_min = -np.inf
|
623 |
+
self.cql_clip_diff_max = np.inf
|
624 |
+
self.cql_alpha = 5.0
|
625 |
+
self.cql_autotune_alpha = False
|
626 |
+
self.action_dim = action_dim
|
627 |
+
self.target_entropy = -self.action_dim
|
628 |
+
self.obs_encoded_dim = obs_encoded_dim
|
629 |
+
self.cql_temperature_init_value = 1.0
|
630 |
+
self.critic_ensemble_size = 2
|
631 |
+
self.cql_n_actions = 4
|
632 |
+
self.cql_max_target_backup = True
|
633 |
+
self.policy_network_kwargs = dict(
|
634 |
+
input_dim=self.obs_encoded_dim,
|
635 |
+
hidden_dims=[256, 256],
|
636 |
+
activate_final=True,
|
637 |
+
use_layer_norm=False,
|
638 |
+
)
|
639 |
+
self.critic_network_kwargs = dict(
|
640 |
+
input_dim=self.obs_encoded_dim + self.action_dim,
|
641 |
+
hidden_dims=[256, 256],
|
642 |
+
activate_final=True,
|
643 |
+
use_layer_norm=False,
|
644 |
+
)
|
645 |
+
self.policy_kwargs = dict(
|
646 |
+
tanh_squash_distribution=True,
|
647 |
+
std_parameterization="exp",
|
648 |
+
)
|
649 |
+
self.critic_subsample_size = None
|
650 |
+
self.cql_max_target_backup = True
|
651 |
+
self.backup_entropy = False
|
652 |
+
self.discount = 0.98
|
653 |
+
self.goal_conditioned = True
|
654 |
+
self.gc_kwargs = dict(
|
655 |
+
negative_proportion=0.0,
|
656 |
+
)
|
657 |
+
self.use_td_loss = True
|
658 |
+
self.cql_action_sample_method = "uniform"
|
659 |
+
self.cql_importance_sample = True
|
660 |
+
self.cql_temp = 1.0
|
661 |
+
self.use_calql = True
|
662 |
+
|
663 |
+
self.actor_optimizer_kwargs = dict(
|
664 |
+
learning_rate=actor_lr,
|
665 |
+
warmup_steps=actor_wps,
|
666 |
+
)
|
667 |
+
self.critic_optimizer_kwargs = dict(
|
668 |
+
learning_rate=critic_lr,
|
669 |
+
warmup_steps=critic_wps,
|
670 |
+
)
|
671 |
+
self.temperature_optimizer_kwargs = dict(learning_rate=temp_lr)
|
672 |
+
|
673 |
+
super().__init__(**kwargs)
|
674 |
+
|
675 |
+
|
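# Illustrative sketch (assumed values; these simply restate the constructor defaults above):
#   config = CalQlConfig(obs_encoded_dim=2048, action_dim=70, actor_lr=1e-4, critic_lr=3e-4)
#   model = CalQL(config)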
676 |
+
class CalQL(PreTrainedModel):
|
677 |
+
config_class = CalQlConfig
|
678 |
+
|
679 |
+
def __init__(self, config: CalQlConfig):
|
680 |
+
super(CalQL, self).__init__(config=config)
|
681 |
+
self.config = config
|
682 |
+
|
683 |
+
self.temperature = GeqLagrangeMultiplier(
|
684 |
+
init_value=self.config.cql_temperature_init_value,
|
685 |
+
constraint_shape=(),
|
686 |
+
)
|
687 |
+
|
688 |
+
self.policy = Policy(
|
689 |
+
obs_encoded_dim=self.config.obs_encoded_dim,
|
690 |
+
network=MLP(**self.config.policy_network_kwargs),
|
691 |
+
action_dim=self.config.action_dim,
|
692 |
+
**self.config.policy_kwargs,
|
693 |
+
)
|
694 |
+
|
695 |
+
self.critics = Critics(
|
696 |
+
obs_encoded_dim=self.config.obs_encoded_dim,
|
697 |
+
networks=[
|
698 |
+
MLP(**self.config.critic_network_kwargs)
|
699 |
+
for _ in range(self.config.critic_ensemble_size)
|
700 |
+
],
|
701 |
+
num_backbones=self.config.critic_ensemble_size,
|
702 |
+
)
|
703 |
+
|
704 |
+
self.target_critics = deepcopy(self.critics)
|
705 |
+
|
706 |
+
def forward_policy_and_sample(
|
707 |
+
self,
|
708 |
+
encoded_obs: Float[torch.Tensor, "batch {self.config.obs_encoded_dim}"],
|
709 |
+
repeat: Optional[int] = None,
|
710 |
+
):
|
711 |
+
action_dist = self.policy.forward(encoded_obs)
|
712 |
+
if repeat:
|
713 |
+
new_actions = action_dist.rsample(
|
714 |
+
torch.Size([repeat])
|
715 |
+
)  # (repeat, batch, action_dim)
|
716 |
+
log_pi = action_dist.log_prob(new_actions)
|
717 |
+
new_actions = new_actions.permute(1, 0, 2) # (batch, repeat, action_dim)
|
718 |
+
log_pi = log_pi.permute(1, 0) # (batch, repeat)
|
719 |
+
|
720 |
+
else:
|
721 |
+
new_actions = action_dist.rsample() # (batch, action_dim)
|
722 |
+
log_pi = action_dist.log_prob(new_actions) # (batch)
|
723 |
+
# NOTE: detach gradient here
|
724 |
+
new_actions = new_actions.detach()
|
725 |
+
log_pi = log_pi.detach()
|
726 |
+
return new_actions, log_pi
|
727 |
+
|
728 |
+
def _compute_next_actions(self, batch: at.CalQlBatch):
|
729 |
+
"""
|
730 |
+
Compute the next actions, sampled cql_n_actions times per state.
|
731 |
+
This should only be used when computing the critic loss with
|
732 |
+
cql_max_target_backup
|
733 |
+
"""
|
734 |
+
sample_n_actions = (
|
735 |
+
self.config.cql_n_actions if self.config.cql_max_target_backup else None
|
736 |
+
)
|
737 |
+
|
738 |
+
next_actions, next_actions_log_probs = self.forward_policy_and_sample(
|
739 |
+
batch["encoded_next_observations"],
|
740 |
+
repeat=sample_n_actions,
|
741 |
+
)
|
742 |
+
return next_actions, next_actions_log_probs
|
743 |
+
|
744 |
+
def temperature_loss_fn(self, batch: at.CalQlBatch):
|
745 |
+
next_actions, next_actions_log_probs = self._compute_next_actions(batch)
|
746 |
+
|
747 |
+
entropy = -next_actions_log_probs.mean()
|
748 |
+
temperature_loss = self.temperature.forward(
|
749 |
+
lhs=entropy,
|
750 |
+
rhs=self.config.target_entropy,
|
751 |
+
)
|
752 |
+
return temperature_loss, {"temperature_loss": temperature_loss}
|
753 |
+
|
754 |
+
def policy_loss_fn(self, batch: at.CalQlBatch):
|
755 |
+
batch_size = batch["rewards"].shape[0]
|
756 |
+
temperature = self.temperature.forward().detach() # detach gradient
|
757 |
+
|
758 |
+
action_distributions = self.policy.forward(batch["encoded_observations"])
|
759 |
+
actions = action_distributions.rsample()
|
760 |
+
log_probs = action_distributions.log_prob(actions)
|
761 |
+
|
762 |
+
predicted_qs = self.critics.forward(
|
763 |
+
batch["encoded_observations"],
|
764 |
+
actions,
|
765 |
+
).detach() # NOTE: detach grads
|
766 |
+
predicted_q = predicted_qs.min(dim=0)[0]
|
767 |
+
|
768 |
+
assert predicted_q.shape == (batch_size,)
|
769 |
+
assert log_probs.shape == (batch_size,)
|
770 |
+
|
771 |
+
nll_objective = -torch.mean(
|
772 |
+
action_distributions.log_prob(torch.clip(batch["actions"], -0.99, 0.99))
|
773 |
+
)
|
774 |
+
actor_objective = predicted_q
|
775 |
+
actor_loss = -torch.mean(actor_objective) + torch.mean(temperature * log_probs)
|
776 |
+
|
777 |
+
info = {
|
778 |
+
"actor_loss": actor_loss,
|
779 |
+
"actor_nll": nll_objective,
|
780 |
+
"temperature": temperature,
|
781 |
+
"entropy": -log_probs.mean(),
|
782 |
+
"log_probs": log_probs,
|
783 |
+
"actions_mse": ((actions - batch["actions"]) ** 2).sum(dim=-1).mean(),
|
784 |
+
"dataset_rewards": batch["rewards"],
|
785 |
+
"mc_returns": batch.get("mc_returns", None),
|
786 |
+
}
|
787 |
+
|
788 |
+
return actor_loss, info
|
789 |
+
|
790 |
+
def sac_critic_loss_fn(self, batch: at.CalQlBatch):
|
791 |
+
"""classes that inherit this class can change this function"""
|
792 |
+
batch_size = batch["rewards"].shape[0]
|
793 |
+
next_actions, next_actions_log_probs = self._compute_next_actions(batch)
|
794 |
+
# (batch_size, ) for sac, (batch_size, cql_n_actions) for cql
|
795 |
+
|
796 |
+
# Evaluate next Qs for all ensemble members (cheap because we're only doing the forward pass)
|
797 |
+
with torch.no_grad():
|
798 |
+
self.target_critics.eval()
|
799 |
+
target_next_qs = self.target_critics.forward(
|
800 |
+
batch["encoded_next_observations"],
|
801 |
+
next_actions,
|
802 |
+
) # (critic_ensemble_size, batch_size, cql_n_actions)
|
803 |
+
self.target_critics.train()
|
804 |
+
|
805 |
+
# Subsample if requested
|
806 |
+
if self.config.critic_subsample_size is not None:
|
807 |
+
subsample_idcs = torch.randint(
|
808 |
+
0,
|
809 |
+
self.config.critic_ensemble_size,
|
810 |
+
(self.config.critic_ensemble_size,),
|
811 |
+
device=target_next_qs.device,
|
812 |
+
)
|
813 |
+
target_next_qs = target_next_qs[subsample_idcs]
|
814 |
+
|
815 |
+
# Minimum Q across (subsampled) ensemble members
|
816 |
+
target_next_min_q = target_next_qs.min(dim=0)[0]
|
817 |
+
assert target_next_min_q.shape == next_actions_log_probs.shape
|
818 |
+
# (batch_size,) for sac, (batch_size, cql_n_actions) for cql
|
819 |
+
|
820 |
+
target_next_min_q = self._process_target_next_qs(
|
821 |
+
target_next_min_q,
|
822 |
+
next_actions_log_probs,
|
823 |
+
)
|
824 |
+
|
825 |
+
target_q = (
|
826 |
+
batch["rewards"] + self.config.discount * batch["masks"] * target_next_min_q
|
827 |
+
)
|
828 |
+
assert target_q.shape == (batch_size,)
|
829 |
+
|
830 |
+
predicted_qs = self.critics.forward(
|
831 |
+
batch["encoded_observations"], batch["actions"]
|
832 |
+
)
|
833 |
+
assert predicted_qs.shape == (self.config.critic_ensemble_size, batch_size)
|
834 |
+
|
835 |
+
target_qs = target_q.unsqueeze(0).expand(self.config.critic_ensemble_size, -1)
|
836 |
+
assert predicted_qs.shape == target_qs.shape
|
837 |
+
critic_loss = torch.mean((predicted_qs - target_qs) ** 2)
|
838 |
+
|
839 |
+
info = {
|
840 |
+
"td_err": critic_loss,
|
841 |
+
"online_q": torch.mean(predicted_qs),
|
842 |
+
"target_q": torch.mean(target_qs),
|
843 |
+
}
|
844 |
+
|
845 |
+
if self.config.goal_conditioned:
|
846 |
+
num_negatives = int(
|
847 |
+
self.config.gc_kwargs["negative_proportion"] * batch_size
|
848 |
+
)
|
849 |
+
info["negative_qs"] = torch.mean(predicted_qs, dim=-1)[
|
850 |
+
:num_negatives
|
851 |
+
].mean()
|
852 |
+
info["positive_qs"] = torch.mean(predicted_qs, dim=-1)[
|
853 |
+
num_negatives:
|
854 |
+
].mean()
|
855 |
+
|
856 |
+
return critic_loss, info
|
857 |
+
|
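# For reference, the TD target assembled above is
#   y = r + discount * mask * min_i Q_target_i(s', a')
# with a' sampled from the policy (minus temperature * log pi(a'|s') when backup_entropy
# is enabled, and taken as the argmax over cql_n_actions candidates when
# cql_max_target_backup is enabled; see _process_target_next_qs below).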
858 |
+
def _process_target_next_qs(self, target_next_qs, next_actions_log_probs):
|
859 |
+
"""add cql_max_target_backup option"""
|
860 |
+
|
861 |
+
if self.config.cql_max_target_backup:
|
862 |
+
max_target_indices = torch.argmax(target_next_qs, dim=-1, keepdim=True)
|
863 |
+
target_next_qs = torch.gather(
|
864 |
+
target_next_qs, -1, max_target_indices
|
865 |
+
).squeeze(-1)
|
866 |
+
next_actions_log_probs = torch.gather(
|
867 |
+
next_actions_log_probs, -1, max_target_indices
|
868 |
+
).squeeze(-1)
|
869 |
+
|
870 |
+
target_next_qs = self.sac_process_target_next_qs(
|
871 |
+
target_next_qs,
|
872 |
+
next_actions_log_probs,
|
873 |
+
)
|
874 |
+
|
875 |
+
return target_next_qs
|
876 |
+
|
877 |
+
def sac_process_target_next_qs(self, target_next_qs, next_actions_log_probs):
|
878 |
+
"""classes that inherit this class can add to this function
|
879 |
+
e.g. CQL adds the cql_max_target_backup option.
|
880 |
+
"""
|
881 |
+
if self.config.backup_entropy:
|
882 |
+
temperature = self.temperature.forward().detach()  # detach: do not backprop into the temperature here
|
883 |
+
target_next_qs = target_next_qs - temperature * next_actions_log_probs
|
884 |
+
|
885 |
+
return target_next_qs
|
886 |
+
|
887 |
+
def critic_loss_fn(self, batch: at.CalQlBatch):
|
888 |
+
"""add CQL loss on top of SAC loss"""
|
889 |
+
if self.config.use_td_loss:
|
890 |
+
td_loss, td_loss_info = self.sac_critic_loss_fn(batch)
|
891 |
+
else:
|
892 |
+
td_loss, td_loss_info = 0.0, {}
|
893 |
+
|
894 |
+
cql_q_diff, cql_intermediate_results = self._get_cql_q_diff(batch)
|
895 |
+
|
896 |
+
"""auto tune cql alpha"""
|
897 |
+
if self.config.cql_autotune_alpha:
|
898 |
+
raise NotImplementedError
|
899 |
+
# alpha = self.forward_cql_alpha_lagrange()
|
900 |
+
# cql_loss = (cql_q_diff - self.config["cql_target_action_gap"]).mean()
|
901 |
+
else:
|
902 |
+
alpha = self.config.cql_alpha
|
903 |
+
cql_loss = torch.clip(
|
904 |
+
cql_q_diff, self.config.cql_clip_diff_min, self.config.cql_clip_diff_max
|
905 |
+
).mean()
|
906 |
+
|
907 |
+
critic_loss = td_loss + alpha * cql_loss
|
908 |
+
|
909 |
+
info = {
|
910 |
+
**td_loss_info,
|
911 |
+
"critic_loss": critic_loss,
|
912 |
+
"td_err": td_loss,
|
913 |
+
"cql_loss": cql_loss,
|
914 |
+
"cql_alpha": alpha,
|
915 |
+
"cql_diff": cql_q_diff.mean(),
|
916 |
+
**cql_intermediate_results,
|
917 |
+
}
|
918 |
+
|
919 |
+
return critic_loss, info
|
920 |
+
|
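# For reference (a summary of the code in critic_loss_fn above and _get_cql_q_diff below,
# not a separate method): the conservative penalty follows the CQL-style objective
#   L_critic = L_TD + cql_alpha * E_s[ cql_temp * logsumexp_a(Q(s, a) / cql_temp) - Q(s, a_data) ]
# where the logsumexp runs over random, current-policy and next-policy action samples
# (importance-weighted when cql_importance_sample is set), and Cal-QL additionally clamps
# the policy-action Q-values from below by the Monte-Carlo return (batch["mc_returns"]).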
921 |
+
def _get_cql_q_diff(self, batch: at.CalQlBatch):
|
922 |
+
"""
|
923 |
+
Most of the CQL loss logic lives here;
|
924 |
+
it is needed by both critic_loss_fn and cql_alpha_loss_fn.
|
925 |
+
"""
|
926 |
+
batch_size = batch["rewards"].shape[0]
|
927 |
+
|
928 |
+
q_pred = self.critics.forward(batch["encoded_observations"], batch["actions"])
|
929 |
+
# HACK: shape changed from jax implementation
|
930 |
+
assert q_pred.shape == (self.config.critic_ensemble_size, batch_size)
|
931 |
+
|
932 |
+
"""sample random actions"""
|
933 |
+
action_dim = batch["actions"].shape[-1]
|
934 |
+
if self.config.cql_action_sample_method == "uniform":
|
935 |
+
cql_random_actions = (
|
936 |
+
torch.rand(
|
937 |
+
(batch_size, self.config.cql_n_actions, action_dim),
|
938 |
+
device=batch["actions"].device,
|
939 |
+
)
|
940 |
+
* 2.0
|
941 |
+
- 1.0
|
942 |
+
)
|
943 |
+
elif self.config.cql_action_sample_method == "normal":
|
944 |
+
cql_random_actions = torch.randn(
|
945 |
+
(batch_size, self.config.cql_n_actions, action_dim),
|
946 |
+
device=batch["actions"].device,
|
947 |
+
)
|
948 |
+
else:
|
949 |
+
raise NotImplementedError
|
950 |
+
|
951 |
+
cql_current_actions, cql_current_log_pis = self.forward_policy_and_sample(
|
952 |
+
batch["encoded_observations"],
|
953 |
+
repeat=self.config.cql_n_actions,
|
954 |
+
)
|
955 |
+
assert cql_current_log_pis.shape == (batch_size, self.config.cql_n_actions)
|
956 |
+
|
957 |
+
cql_next_actions, cql_next_log_pis = self.forward_policy_and_sample(
|
958 |
+
batch["encoded_next_observations"],
|
959 |
+
repeat=self.config.cql_n_actions,
|
960 |
+
)
|
961 |
+
|
962 |
+
all_sampled_actions = torch.cat(
|
963 |
+
[
|
964 |
+
cql_random_actions,
|
965 |
+
cql_current_actions,
|
966 |
+
cql_next_actions,
|
967 |
+
],
|
968 |
+
dim=1,
|
969 |
+
)
|
970 |
+
|
971 |
+
"""q values of randomly sampled actions"""
|
972 |
+
cql_q_samples = self.critics.forward(
|
973 |
+
batch["encoded_observations"], all_sampled_actions
|
974 |
+
)
|
975 |
+
# HACK: shape changed from jax implementation
|
976 |
+
assert cql_q_samples.shape == (
|
977 |
+
self.config.critic_ensemble_size,
|
978 |
+
batch_size,
|
979 |
+
self.config.cql_n_actions * 3,
|
980 |
+
)
|
981 |
+
|
982 |
+
if self.config.critic_subsample_size is not None:
|
983 |
+
subsample_idcs = torch.randint(
|
984 |
+
0,
|
985 |
+
self.config.critic_ensemble_size,
|
986 |
+
(self.config.critic_ensemble_size,),
|
987 |
+
device=cql_q_samples.device,
|
988 |
+
)
|
989 |
+
cql_q_samples = cql_q_samples[subsample_idcs]
|
990 |
+
|
991 |
+
"""Cal-QL"""
|
992 |
+
if self.config.use_calql:
|
993 |
+
# HACK: check shape of mc_returns
|
994 |
+
mc_lower_bound = (
|
995 |
+
batch["mc_returns"]
|
996 |
+
.reshape(-1, 1)
|
997 |
+
.repeat(1, self.config.cql_n_actions * 2)
|
998 |
+
)
|
999 |
+
assert mc_lower_bound.shape == (
|
1000 |
+
batch_size,
|
1001 |
+
self.config.cql_n_actions * 2,
|
1002 |
+
)
|
1003 |
+
|
1004 |
+
cql_q_pi = cql_q_samples[:, :, self.config.cql_n_actions :]
|
1005 |
+
num_vals = cql_q_pi.numel()
|
1006 |
+
calql_bound_rate = torch.sum((cql_q_pi < mc_lower_bound).float()) / num_vals
|
1007 |
+
cql_q_pi = torch.maximum(cql_q_pi, mc_lower_bound)
|
1008 |
+
cql_q_samples = torch.cat(
|
1009 |
+
[
|
1010 |
+
cql_q_samples[:, :, : self.config.cql_n_actions],
|
1011 |
+
cql_q_pi,
|
1012 |
+
],
|
1013 |
+
dim=-1,
|
1014 |
+
)
|
1015 |
+
|
1016 |
+
if self.config.cql_importance_sample:
|
1017 |
+
random_density = torch.log(
|
1018 |
+
torch.tensor(0.5**action_dim, device=cql_q_samples.device)
|
1019 |
+
)
|
1020 |
+
|
1021 |
+
importance_prob = torch.cat(
|
1022 |
+
[
|
1023 |
+
random_density.expand(batch_size, self.config.cql_n_actions),
|
1024 |
+
cql_current_log_pis,
|
1025 |
+
cql_next_log_pis,
|
1026 |
+
],
|
1027 |
+
dim=1,
|
1028 |
+
)
|
1029 |
+
# HACK: check dim
|
1030 |
+
cql_q_samples = cql_q_samples - importance_prob.unsqueeze(0)
|
1031 |
+
else:
|
1032 |
+
cql_q_samples = torch.cat([cql_q_samples, q_pred.unsqueeze(-1)], dim=-1)
|
1033 |
+
|
1034 |
+
cql_q_samples -= (
|
1035 |
+
torch.log(
|
1036 |
+
torch.tensor(
|
1037 |
+
cql_q_samples.shape[-1],
|
1038 |
+
dtype=torch.float,
|
1039 |
+
device=cql_q_samples.device,
|
1040 |
+
)
|
1041 |
+
)
|
1042 |
+
* self.config.cql_temp
|
1043 |
+
)
|
1044 |
+
# HACK: shape diff from jax implementation
|
1045 |
+
assert cql_q_samples.shape == (
|
1046 |
+
self.config.critic_ensemble_size,
|
1047 |
+
batch_size,
|
1048 |
+
3 * self.config.cql_n_actions + 1,
|
1049 |
+
)
|
1050 |
+
|
1051 |
+
"""log sum exp of the ood actions"""
|
1052 |
+
cql_ood_values = (
|
1053 |
+
torch.logsumexp(cql_q_samples / self.config.cql_temp, dim=-1)
|
1054 |
+
* self.config.cql_temp
|
1055 |
+
)
|
1056 |
+
assert cql_ood_values.shape == (self.config.critic_ensemble_size, batch_size)
|
1057 |
+
|
1058 |
+
cql_q_diff = cql_ood_values - q_pred
|
1059 |
+
info = {
|
1060 |
+
"cql_ood_values": cql_ood_values.mean(),
|
1061 |
+
}
|
1062 |
+
if self.config.use_calql:
|
1063 |
+
info["calql_bound_rate"] = calql_bound_rate
|
1064 |
+
|
1065 |
+
return cql_q_diff, info
|
1066 |
+
|
1067 |
+
@staticmethod
|
1068 |
+
def make_optimizer(
|
1069 |
+
params: torch.nn.Module,
|
1070 |
+
learning_rate: float = 3e-4,
|
1071 |
+
warmup_steps: int = 0,
|
1072 |
+
cosine_decay_steps: Optional[int] = None,
|
1073 |
+
weight_decay: Optional[float] = None,
|
1074 |
+
return_lr_schedule: bool = True,
|
1075 |
+
) -> Union[Optimizer, Tuple[Optimizer, LambdaLR]]:
|
1076 |
+
optimizer: Optimizer
|
1077 |
+
if weight_decay is not None:
|
1078 |
+
optimizer = AdamW(
|
1079 |
+
params=params,
|
1080 |
+
lr=learning_rate,
|
1081 |
+
weight_decay=weight_decay,
|
1082 |
+
)
|
1083 |
+
else:
|
1084 |
+
optimizer = Adam(params=params, lr=learning_rate)
|
1085 |
+
|
1086 |
+
def _lr_lambda(step: int) -> float:
|
1087 |
+
if warmup_steps > 0 and step < warmup_steps:
|
1088 |
+
return step / warmup_steps
|
1089 |
+
|
1090 |
+
if cosine_decay_steps is not None:
|
1091 |
+
decay_step = step - warmup_steps
|
1092 |
+
if decay_step < 0:
|
1093 |
+
return 0.0
|
1094 |
+
if decay_step >= cosine_decay_steps:
|
1095 |
+
return 0.0
|
1096 |
+
progress = decay_step / cosine_decay_steps
|
1097 |
+
return 0.5 * (1.0 + math.cos(math.pi * progress))
|
1098 |
+
|
1099 |
+
return 1.0
|
1100 |
+
|
1101 |
+
scheduler = LambdaLR(optimizer, lr_lambda=_lr_lambda)
|
1102 |
+
|
1103 |
+
if return_lr_schedule:
|
1104 |
+
return optimizer, scheduler
|
1105 |
+
else:
|
1106 |
+
return optimizer
|
1107 |
+
|
1108 |
+
def prepare_optimizers(self):
|
1109 |
+
actor_optimizer, actor_scheduler = self.make_optimizer(
|
1110 |
+
self.policy.parameters(), **self.config.actor_optimizer_kwargs
|
1111 |
+
)
|
1112 |
+
critic_optimizer, critic_scheduler = self.make_optimizer(
|
1113 |
+
self.critics.parameters(), **self.config.critic_optimizer_kwargs
|
1114 |
+
)
|
1115 |
+
temperature_optimizer, temperature_scheduler = self.make_optimizer(
|
1116 |
+
self.temperature.parameters(), **self.config.temperature_optimizer_kwargs
|
1117 |
+
)
|
1118 |
+
|
1119 |
+
return (
|
1120 |
+
actor_optimizer,
|
1121 |
+
actor_scheduler,
|
1122 |
+
critic_optimizer,
|
1123 |
+
critic_scheduler,
|
1124 |
+
temperature_optimizer,
|
1125 |
+
temperature_scheduler,
|
1126 |
+
)
|
1127 |
+
|
1128 |
+
def forward(self, batch: at.CalQlBatch):
|
1129 |
+
temperature_loss, temperature_loss_info = self.temperature_loss_fn(batch)
|
1130 |
+
policy_loss, policy_loss_info = self.policy_loss_fn(batch)
|
1131 |
+
critic_loss, critic_loss_info = self.critic_loss_fn(batch)
|
1132 |
+
|
1133 |
+
return (
|
1134 |
+
temperature_loss,
|
1135 |
+
policy_loss,
|
1136 |
+
critic_loss,
|
1137 |
+
{
|
1138 |
+
**temperature_loss_info,
|
1139 |
+
**policy_loss_info,
|
1140 |
+
**critic_loss_info,
|
1141 |
+
},
|
1142 |
+
)
|
1143 |
+
|
1144 |
+
@jaxtyped(typechecker=typechecker)
|
1145 |
+
def get_q_values(
|
1146 |
+
self,
|
1147 |
+
encoded_observations: Float[
|
1148 |
+
torch.Tensor, "batch {self.config.obs_encoded_dim}"
|
1149 |
+
],
|
1150 |
+
noise_actions: Float[torch.Tensor, "batch num_actions action_dim"],
|
1151 |
+
) -> Float[torch.Tensor, "batch num_actions"]:
|
1152 |
+
# (num_backbones, batch, *num_actions)
|
1153 |
+
q_values = self.target_critics.forward(encoded_observations, noise_actions)
|
1154 |
+
q_values = q_values.min(dim=0)[0]
|
1155 |
+
return q_values
|
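# Illustrative end-to-end sketch (names such as `batch`, `encoded_obs` and `candidate_actions`
# are assumed here, not defined in this file):
#   model = CalQL(CalQlConfig())
#   opts = model.prepare_optimizers()  # actor/critic/temperature optimizers + schedulers
#   temperature_loss, policy_loss, critic_loss, info = model(batch)
#   q = model.get_q_values(encoded_obs, candidate_actions)  # -> (batch, num_actions)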