Upload 19 files
Browse files- Inference-Spatial--20000_chkpt--97.8.log +0 -0
- README.md +217 -3
- action_head--20000_checkpoint.pt +3 -0
- added_tokens.json +24 -0
- config.json +3182 -0
- configuration_prismatic.py +144 -0
- dataset_statistics.json +133 -0
- generation_config.json +7 -0
- merges.txt +0 -0
- model.safetensors +3 -0
- modeling_prismatic.py +1499 -0
- preprocessor_config.json +114 -0
- processing_prismatic.py +257 -0
- processor_config.json +6 -0
- proprio_projector--20000_checkpoint.pt +3 -0
- special_tokens_map.json +31 -0
- tokenizer.json +0 -0
- tokenizer_config.json +211 -0
- vocab.json +0 -0
Inference-Spatial--20000_chkpt--97.8.log
ADDED
The diff for this file is too large to render.
See raw diff
|
|
README.md
CHANGED
@@ -1,3 +1,217 @@
|
|
1 |
-
---
|
2 |
-
license: mit
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
tags:
|
4 |
+
- Vision-Language-Action
|
5 |
+
- OpenHelix Team
|
6 |
+
base_model:
|
7 |
+
- Qwen/Qwen2.5-0.5B
|
8 |
+
language:
|
9 |
+
- en
|
10 |
+
pipeline_tag: robotics
|
11 |
+
---
|
12 |
+
|
13 |
+
|
14 |
+
# Model Card for VLA-Adapter Libero-Long
|
15 |
+
VLA-Adapter: An Effective Paradigm for Tiny-Scale Vision-Language-Action Model trained on Libero-Goal.
|
16 |
+
- Project page: [https://vla-adapter.github.io/](https://vla-adapter.github.io/)
|
17 |
+
- Dataset: [https://huggingface.co/datasets/openvla/modified_libero_rlds/tree/main](https://huggingface.co/datasets/openvla/modified_libero_rlds/tree/main)
|
18 |
+
|
19 |
+
## Model Details
|
20 |
+
We have developed and released the VLA-Adapter family of VLA models, a series of fine-tuned generative
|
21 |
+
action models. The VLA-Adapter VLM follows the Prismatic-VLM architecture, using only a very small backbone
|
22 |
+
(Qwen2.5-0.5B) for the LLM. On common robotics benchmarks, it surpasses open-source VLA models with 8.5B,
|
23 |
+
7B, 4B, 3B, and 2B backbones.
|
24 |
+
|
25 |
+
**Input:** Models input image and text.
|
26 |
+
|
27 |
+
**Output:** Models generate action only.
|
28 |
+
|
29 |
+
**Model Architecture:** The VLA-Adapter consists of a VLM for receiving and processing image and text
|
30 |
+
information and a policy for generating actions. We systematically analyzed the benefits that the VLM
|
31 |
+
provides to different types of policy conditions and determined a unified framework. We then utilized
|
32 |
+
our designed Bridge Attention module to fuse the conditions generated by the VLM with the initial action
|
33 |
+
information in the policy, bridging the gap between VL and A to the greatest extent possible.
|
34 |
+
This resulted in a high-performance VLA model on a tiny-scale backbone.
|
35 |
+
|
36 |
+
|
37 |
+
### Success Rate Comparison
|
38 |
+
<table>
|
39 |
+
<tr>
|
40 |
+
<td><strong>Category</strong>
|
41 |
+
</td>
|
42 |
+
<td><strong>Methods</strong>
|
43 |
+
</td>
|
44 |
+
<td><strong>Scale</strong>
|
45 |
+
</td>
|
46 |
+
<td><strong>LIBERO-Spatial</strong>
|
47 |
+
</td>
|
48 |
+
<td><strong>LIBERO-Object</strong>
|
49 |
+
</td>
|
50 |
+
<td><strong>LIBERO-Goal</strong>
|
51 |
+
</td>
|
52 |
+
<td><strong>LIBERO-Long</strong>
|
53 |
+
</td>
|
54 |
+
<td><strong>Avg.</strong>
|
55 |
+
</td>
|
56 |
+
</tr>
|
57 |
+
<tr>
|
58 |
+
<td rowspan="9">Large-scale</td>
|
59 |
+
<td>FlowVLA (Zhong et al., 2025)</td>
|
60 |
+
<td>8.5B</td><td>93.2</td><td>95.0</td><td>91.6</td><td>72.6</td><td>88.1</td>
|
61 |
+
</tr>
|
62 |
+
|
63 |
+
<tr>
|
64 |
+
<td>OpenVLA (Kim et al., 2024)</td>
|
65 |
+
<td>7B</td><td>84.7</td><td>88.4</td><td>79.2</td><td>53.7</td><td>76.5</td>
|
66 |
+
</tr>
|
67 |
+
|
68 |
+
<tr>
|
69 |
+
<td>OpenVLA-OFT (Kim et al., 2025)</td>
|
70 |
+
<td>7B</td><td><i><u>97.6*</u></i></td><td>98.4</td><td><b>97.9</b></td><td><b>94.5</b></td><td><b>97.1</b></td>
|
71 |
+
</tr>
|
72 |
+
|
73 |
+
<tr>
|
74 |
+
<td>UniVLA (Bu et al., 2025)</td>
|
75 |
+
<td>7B</td><td>96.5</td><td> 96.8</td><td> 95.6 </td><td>92.0 </td><td>95.2</td>
|
76 |
+
</tr>
|
77 |
+
|
78 |
+
<tr>
|
79 |
+
<td>CoT-VLA (Zhao et al., 2025)</td>
|
80 |
+
<td>7B</td><td>87.5 </td><td>91.6 </td><td>87.6</td><td> 69.0</td><td> 81.1</td>
|
81 |
+
</tr>
|
82 |
+
|
83 |
+
<tr>
|
84 |
+
<td>WorldVLA (Cen et al., 2025)</td>
|
85 |
+
<td>7B</td><td>87.6</td><td> 96.2</td><td> 83.4</td><td> 60.0</td><td> 81.8</td>
|
86 |
+
</tr>
|
87 |
+
|
88 |
+
<tr>
|
89 |
+
<td>TraceVLA (Zheng et al., 2025)</td>
|
90 |
+
<td>7B</td><td>84.6</td><td> 85.2</td><td> 75.1</td><td> 54.1</td><td> 74.8</td>
|
91 |
+
</tr>
|
92 |
+
|
93 |
+
<tr>
|
94 |
+
<td>MolmoAct (Lee et al., 2025)</td>
|
95 |
+
<td>7B</td><td>87.0</td><td> 95.4 </td><td>87.6</td><td> 77.2 </td><td>86.6</td>
|
96 |
+
</tr>
|
97 |
+
|
98 |
+
<tr>
|
99 |
+
<td>ThinkAct (Huang et al., 2025)</td>
|
100 |
+
<td>7B</td><td>88.3 </td><td>91.4</td><td> 87.1</td><td> 70.9</td><td> 84.4</td>
|
101 |
+
</tr>
|
102 |
+
|
103 |
+
<tr>
|
104 |
+
<td rowspan="7">Small-scale</td>
|
105 |
+
<td>4D-VLA (Zhang et al., 2025)</td>
|
106 |
+
<td>4B</td><td>88.9</td><td> 95.2</td><td> 90.9</td><td> 79.1 </td><td>88.6</td>
|
107 |
+
</tr>
|
108 |
+
|
109 |
+
<tr>
|
110 |
+
<td>SpatialVLA (Qu et al., 2025)</td>
|
111 |
+
<td>4B</td><td>88.2</td><td> 89.9</td><td> 78.6</td><td> 55.5 </td><td>78.1</td>
|
112 |
+
</tr>
|
113 |
+
|
114 |
+
<tr>
|
115 |
+
<td>π0 (Black et al., 2024)</td>
|
116 |
+
<td>3B</td><td>96.8</td><td> <i><u>98.8*</u></i> </td><td>95.8</td><td> 85.2</td><td> 94.2</td>
|
117 |
+
</tr>
|
118 |
+
|
119 |
+
<tr>
|
120 |
+
<td>π0-FAST (Pertsch et al., 2025)</td>
|
121 |
+
<td>3B</td><td>96.4</td><td> 96.8 </td><td>88.6</td><td> 60.2</td><td> 85.5</td>
|
122 |
+
</tr>
|
123 |
+
|
124 |
+
<tr>
|
125 |
+
<td>NORA (Hung et al., 2025)</td>
|
126 |
+
<td>3B</td><td>92.2 </td><td>95.4 </td><td>89.4</td><td> 74.6 </td><td>87.9</td>
|
127 |
+
</tr>
|
128 |
+
|
129 |
+
<tr>
|
130 |
+
<td>SmolVLA (Shukor et al., 2025)</td>
|
131 |
+
<td>2.2B</td><td>93.0</td><td> 94.0 </td><td>91.0</td><td> 77.0 </td><td>88.8</td>
|
132 |
+
</tr>
|
133 |
+
|
134 |
+
<tr>
|
135 |
+
<td>GR00T N1 (NVIDIA et al., 2025)</td>
|
136 |
+
<td>2B</td><td>94.4</td><td> 97.6 </td><td>93.0 </td><td>90.6</td><td> 93.9</td>
|
137 |
+
</tr>
|
138 |
+
|
139 |
+
<tr>
|
140 |
+
<td rowspan="4">Tiny-scale</td>
|
141 |
+
<td>Seer (Tian et al., 2025)</td>
|
142 |
+
<td>0.57B</td><td>-</td><td> - </td><td>- </td><td>78.7</td><td> 78.7</td>
|
143 |
+
</tr>
|
144 |
+
|
145 |
+
<tr>
|
146 |
+
<td>VLA-OS (Gao et al., 2025)</td>
|
147 |
+
<td>0.5B</td><td>87.0 </td><td>96.5</td><td> 92.7 </td><td>66.0</td><td> 85.6</td>
|
148 |
+
</tr>
|
149 |
+
|
150 |
+
<tr>
|
151 |
+
<td>Diffusion Policy (Chi et al., 2023)</td>
|
152 |
+
<td>-</td><td>78.3</td><td> 92.5</td><td> 68.3 </td><td>50.5 </td><td>72.4</td>
|
153 |
+
</tr>
|
154 |
+
|
155 |
+
<tr>
|
156 |
+
<td><b>VLA-Adapter (Ours)</b></td>
|
157 |
+
<td><b>0.5B</b></td><td><b>97.8</b></td><td> <b>99.2</b> </td><td><i><u>97.2*</u></i></td><td> <i><u>93.4* </u></i></td><td><i><u>96.9*</u></i></td>
|
158 |
+
</tr>
|
159 |
+
|
160 |
+
</table>
|
161 |
+
|
162 |
+
### Effectiveness Comparison
|
163 |
+
|
164 |
+
<table>
|
165 |
+
<tr>
|
166 |
+
<td></td>
|
167 |
+
<td><strong>OpenVLA-OFT</strong></td>
|
168 |
+
<td><strong>VLA-Adapter</strong></td>
|
169 |
+
<td></td>
|
170 |
+
</tr>
|
171 |
+
|
172 |
+
<tr>
|
173 |
+
<td>Backbone</td>
|
174 |
+
<td>7B</td>
|
175 |
+
<td><strong>0.5B</strong></td>
|
176 |
+
<td>1/14×</td>
|
177 |
+
</tr>
|
178 |
+
|
179 |
+
<tr>
|
180 |
+
<td>Fine-Tuning Cost</td>
|
181 |
+
<td>304GPU·h</td>
|
182 |
+
<td><strong>8GPU·h</strong></td>
|
183 |
+
<td>1/38×</td>
|
184 |
+
</tr>
|
185 |
+
|
186 |
+
<tr>
|
187 |
+
<td>Training VRAM (8 batch)</td>
|
188 |
+
<td>62GB</td>
|
189 |
+
<td><strong>24.7GB</strong></td>
|
190 |
+
<td>0.4×</td>
|
191 |
+
</tr>
|
192 |
+
|
193 |
+
<tr>
|
194 |
+
<td>Throughput (8 chunk)</td>
|
195 |
+
<td>109.7Hz</td>
|
196 |
+
<td><strong>219.2Hz</strong></td>
|
197 |
+
<td>2×</td>
|
198 |
+
</tr>
|
199 |
+
|
200 |
+
<tr>
|
201 |
+
<td>Performance</td>
|
202 |
+
<td>97.1%</td>
|
203 |
+
<td><strong>96.9%</strong></td>
|
204 |
+
<td>Maintain</td>
|
205 |
+
</tr>
|
206 |
+
</table>
|
207 |
+
|
208 |
+
## Citation instructions
|
209 |
+
|
210 |
+
```BibTeX
|
211 |
+
@article{Wang2025VLAAdapter,
|
212 |
+
author = {Wang, Yihao and Ding, Pengxiang and Li, Lingxiao and Cui, Can and Ge, Zirui and Tong, Xinyang and Song, Wenxuan and Zhao, Han and Zhao, Wei and Hou, Pengxu and Huang, Siteng and Tang, Yifan and Wang, Wenhui and Zhang, Ru and Liu, Jianyi and Wang, Donglin},
|
213 |
+
title = {VLA-Adapter: An Effective Paradigm for Tiny-Scale Vision-Language-Action Model},
|
214 |
+
journal = {ArXiv},
|
215 |
+
year = {2025}
|
216 |
+
}
|
217 |
+
```
|
action_head--20000_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:587679b0388d6e559344ca67ce8d10fe167fccbfeb1945c8d3d01deda2107f6b
|
3 |
+
size 204387786
|
added_tokens.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"</tool_call>": 151658,
|
3 |
+
"<tool_call>": 151657,
|
4 |
+
"<|box_end|>": 151649,
|
5 |
+
"<|box_start|>": 151648,
|
6 |
+
"<|endoftext|>": 151643,
|
7 |
+
"<|file_sep|>": 151664,
|
8 |
+
"<|fim_middle|>": 151660,
|
9 |
+
"<|fim_pad|>": 151662,
|
10 |
+
"<|fim_prefix|>": 151659,
|
11 |
+
"<|fim_suffix|>": 151661,
|
12 |
+
"<|im_end|>": 151645,
|
13 |
+
"<|im_start|>": 151644,
|
14 |
+
"<|image_pad|>": 151655,
|
15 |
+
"<|object_ref_end|>": 151647,
|
16 |
+
"<|object_ref_start|>": 151646,
|
17 |
+
"<|quad_end|>": 151651,
|
18 |
+
"<|quad_start|>": 151650,
|
19 |
+
"<|repo_name|>": 151663,
|
20 |
+
"<|video_pad|>": 151656,
|
21 |
+
"<|vision_end|>": 151653,
|
22 |
+
"<|vision_pad|>": 151654,
|
23 |
+
"<|vision_start|>": 151652
|
24 |
+
}
|
config.json
ADDED
@@ -0,0 +1,3182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "../openvla-oft/pretrained_models/minivla/config.json",
|
3 |
+
"arch_specifier": "no-align+fused-gelu-mlp",
|
4 |
+
"architectures": [
|
5 |
+
"OpenVLAForActionPrediction"
|
6 |
+
],
|
7 |
+
"auto_map": {
|
8 |
+
"AutoConfig": "configuration_prismatic.OpenVLAConfig",
|
9 |
+
"AutoModelForVision2Seq": "modeling_prismatic.OpenVLAForActionPrediction"
|
10 |
+
},
|
11 |
+
"hf_llm_id": "",
|
12 |
+
"image_resize_strategy": "resize-naive",
|
13 |
+
"image_sizes": [
|
14 |
+
224,
|
15 |
+
224
|
16 |
+
],
|
17 |
+
"llm_backbone_id": "qwen25-0_5b-extra",
|
18 |
+
"llm_max_length": 2048,
|
19 |
+
"model_type": "openvla",
|
20 |
+
"n_action_bins": 256,
|
21 |
+
"norm_stats": {
|
22 |
+
"austin_buds_dataset_converted_externally_to_rlds": {
|
23 |
+
"action": {
|
24 |
+
"mask": [
|
25 |
+
true,
|
26 |
+
true,
|
27 |
+
true,
|
28 |
+
true,
|
29 |
+
true,
|
30 |
+
true,
|
31 |
+
false
|
32 |
+
],
|
33 |
+
"max": [
|
34 |
+
1.0,
|
35 |
+
1.0,
|
36 |
+
1.0,
|
37 |
+
0.0,
|
38 |
+
0.0,
|
39 |
+
0.0,
|
40 |
+
1.0
|
41 |
+
],
|
42 |
+
"mean": [
|
43 |
+
-0.07678354531526566,
|
44 |
+
0.0036849044263362885,
|
45 |
+
0.05644911900162697,
|
46 |
+
0.0,
|
47 |
+
0.0,
|
48 |
+
0.0,
|
49 |
+
0.3510494828224182
|
50 |
+
],
|
51 |
+
"min": [
|
52 |
+
-1.0,
|
53 |
+
-1.0,
|
54 |
+
-1.0,
|
55 |
+
0.0,
|
56 |
+
0.0,
|
57 |
+
0.0,
|
58 |
+
0.0
|
59 |
+
],
|
60 |
+
"q01": [
|
61 |
+
-1.0,
|
62 |
+
-0.9599999785423279,
|
63 |
+
-0.8714285492897034,
|
64 |
+
0.0,
|
65 |
+
0.0,
|
66 |
+
0.0,
|
67 |
+
0.0
|
68 |
+
],
|
69 |
+
"q99": [
|
70 |
+
1.0,
|
71 |
+
0.8600000143051147,
|
72 |
+
1.0,
|
73 |
+
0.0,
|
74 |
+
0.0,
|
75 |
+
0.0,
|
76 |
+
1.0
|
77 |
+
],
|
78 |
+
"std": [
|
79 |
+
0.6367740631103516,
|
80 |
+
0.37889179587364197,
|
81 |
+
0.47796326875686646,
|
82 |
+
0.0,
|
83 |
+
0.0,
|
84 |
+
0.0,
|
85 |
+
0.47721168398857117
|
86 |
+
]
|
87 |
+
},
|
88 |
+
"num_trajectories": 50,
|
89 |
+
"num_transitions": 34112,
|
90 |
+
"proprio": {
|
91 |
+
"max": [
|
92 |
+
0.0,
|
93 |
+
0.0,
|
94 |
+
0.0,
|
95 |
+
0.0,
|
96 |
+
0.0,
|
97 |
+
0.0,
|
98 |
+
0.0
|
99 |
+
],
|
100 |
+
"mean": [
|
101 |
+
0.0,
|
102 |
+
0.0,
|
103 |
+
0.0,
|
104 |
+
0.0,
|
105 |
+
0.0,
|
106 |
+
0.0,
|
107 |
+
0.0
|
108 |
+
],
|
109 |
+
"min": [
|
110 |
+
0.0,
|
111 |
+
0.0,
|
112 |
+
0.0,
|
113 |
+
0.0,
|
114 |
+
0.0,
|
115 |
+
0.0,
|
116 |
+
0.0
|
117 |
+
],
|
118 |
+
"q01": [
|
119 |
+
0.0,
|
120 |
+
0.0,
|
121 |
+
0.0,
|
122 |
+
0.0,
|
123 |
+
0.0,
|
124 |
+
0.0,
|
125 |
+
0.0
|
126 |
+
],
|
127 |
+
"q99": [
|
128 |
+
0.0,
|
129 |
+
0.0,
|
130 |
+
0.0,
|
131 |
+
0.0,
|
132 |
+
0.0,
|
133 |
+
0.0,
|
134 |
+
0.0
|
135 |
+
],
|
136 |
+
"std": [
|
137 |
+
0.0,
|
138 |
+
0.0,
|
139 |
+
0.0,
|
140 |
+
0.0,
|
141 |
+
0.0,
|
142 |
+
0.0,
|
143 |
+
0.0
|
144 |
+
]
|
145 |
+
}
|
146 |
+
},
|
147 |
+
"austin_sailor_dataset_converted_externally_to_rlds": {
|
148 |
+
"action": {
|
149 |
+
"mask": [
|
150 |
+
true,
|
151 |
+
true,
|
152 |
+
true,
|
153 |
+
true,
|
154 |
+
true,
|
155 |
+
true,
|
156 |
+
false
|
157 |
+
],
|
158 |
+
"max": [
|
159 |
+
1.0,
|
160 |
+
1.0,
|
161 |
+
1.0,
|
162 |
+
0.0,
|
163 |
+
0.0,
|
164 |
+
0.375,
|
165 |
+
1.0
|
166 |
+
],
|
167 |
+
"mean": [
|
168 |
+
0.011825348250567913,
|
169 |
+
0.006461074110120535,
|
170 |
+
0.06023626774549484,
|
171 |
+
0.0,
|
172 |
+
0.0,
|
173 |
+
0.0016465914668515325,
|
174 |
+
0.5260950326919556
|
175 |
+
],
|
176 |
+
"min": [
|
177 |
+
-1.0,
|
178 |
+
-1.0,
|
179 |
+
-1.0,
|
180 |
+
0.0,
|
181 |
+
0.0,
|
182 |
+
-0.375,
|
183 |
+
0.0
|
184 |
+
],
|
185 |
+
"q01": [
|
186 |
+
-1.0,
|
187 |
+
-0.9828571677207947,
|
188 |
+
-0.6000000238418579,
|
189 |
+
0.0,
|
190 |
+
0.0,
|
191 |
+
-0.17249999940395355,
|
192 |
+
0.0
|
193 |
+
],
|
194 |
+
"q99": [
|
195 |
+
1.0,
|
196 |
+
0.9457142949104309,
|
197 |
+
1.0,
|
198 |
+
0.0,
|
199 |
+
0.0,
|
200 |
+
0.17892856895923615,
|
201 |
+
1.0
|
202 |
+
],
|
203 |
+
"std": [
|
204 |
+
0.46348899602890015,
|
205 |
+
0.41240179538726807,
|
206 |
+
0.411862850189209,
|
207 |
+
0.0,
|
208 |
+
0.0,
|
209 |
+
0.0578610822558403,
|
210 |
+
0.49894046783447266
|
211 |
+
]
|
212 |
+
},
|
213 |
+
"num_trajectories": 240,
|
214 |
+
"num_transitions": 353094,
|
215 |
+
"proprio": {
|
216 |
+
"max": [
|
217 |
+
0.0,
|
218 |
+
0.0,
|
219 |
+
0.0,
|
220 |
+
0.0,
|
221 |
+
0.0,
|
222 |
+
0.0,
|
223 |
+
0.0
|
224 |
+
],
|
225 |
+
"mean": [
|
226 |
+
0.0,
|
227 |
+
0.0,
|
228 |
+
0.0,
|
229 |
+
0.0,
|
230 |
+
0.0,
|
231 |
+
0.0,
|
232 |
+
0.0
|
233 |
+
],
|
234 |
+
"min": [
|
235 |
+
0.0,
|
236 |
+
0.0,
|
237 |
+
0.0,
|
238 |
+
0.0,
|
239 |
+
0.0,
|
240 |
+
0.0,
|
241 |
+
0.0
|
242 |
+
],
|
243 |
+
"q01": [
|
244 |
+
0.0,
|
245 |
+
0.0,
|
246 |
+
0.0,
|
247 |
+
0.0,
|
248 |
+
0.0,
|
249 |
+
0.0,
|
250 |
+
0.0
|
251 |
+
],
|
252 |
+
"q99": [
|
253 |
+
0.0,
|
254 |
+
0.0,
|
255 |
+
0.0,
|
256 |
+
0.0,
|
257 |
+
0.0,
|
258 |
+
0.0,
|
259 |
+
0.0
|
260 |
+
],
|
261 |
+
"std": [
|
262 |
+
0.0,
|
263 |
+
0.0,
|
264 |
+
0.0,
|
265 |
+
0.0,
|
266 |
+
0.0,
|
267 |
+
0.0,
|
268 |
+
0.0
|
269 |
+
]
|
270 |
+
}
|
271 |
+
},
|
272 |
+
"austin_sirius_dataset_converted_externally_to_rlds": {
|
273 |
+
"action": {
|
274 |
+
"mask": [
|
275 |
+
true,
|
276 |
+
true,
|
277 |
+
true,
|
278 |
+
true,
|
279 |
+
true,
|
280 |
+
true,
|
281 |
+
false
|
282 |
+
],
|
283 |
+
"max": [
|
284 |
+
1.0002285242080688,
|
285 |
+
0.960608720779419,
|
286 |
+
1.105179786682129,
|
287 |
+
0.0,
|
288 |
+
0.0,
|
289 |
+
0.341785728931427,
|
290 |
+
1.0
|
291 |
+
],
|
292 |
+
"mean": [
|
293 |
+
0.07747682929039001,
|
294 |
+
0.03195561468601227,
|
295 |
+
0.04244732856750488,
|
296 |
+
0.0,
|
297 |
+
0.0,
|
298 |
+
-0.01603456400334835,
|
299 |
+
0.43260177969932556
|
300 |
+
],
|
301 |
+
"min": [
|
302 |
+
-1.0183025598526,
|
303 |
+
-0.9800000190734863,
|
304 |
+
-0.9774575233459473,
|
305 |
+
0.0,
|
306 |
+
0.0,
|
307 |
+
-0.34607142210006714,
|
308 |
+
0.0
|
309 |
+
],
|
310 |
+
"q01": [
|
311 |
+
-0.780905865430832,
|
312 |
+
-0.5667179036140442,
|
313 |
+
-0.5254343223571777,
|
314 |
+
0.0,
|
315 |
+
0.0,
|
316 |
+
-0.28495091378688814,
|
317 |
+
0.0
|
318 |
+
],
|
319 |
+
"q99": [
|
320 |
+
0.9569637751579284,
|
321 |
+
0.6971374487876891,
|
322 |
+
0.8124888157844541,
|
323 |
+
0.0,
|
324 |
+
0.0,
|
325 |
+
0.1971428543329239,
|
326 |
+
1.0
|
327 |
+
],
|
328 |
+
"std": [
|
329 |
+
0.3906329572200775,
|
330 |
+
0.2998155355453491,
|
331 |
+
0.2782271206378937,
|
332 |
+
0.0,
|
333 |
+
0.0,
|
334 |
+
0.08120622485876083,
|
335 |
+
0.49528297781944275
|
336 |
+
]
|
337 |
+
},
|
338 |
+
"num_trajectories": 559,
|
339 |
+
"num_transitions": 279939,
|
340 |
+
"proprio": {
|
341 |
+
"max": [
|
342 |
+
0.0,
|
343 |
+
0.0,
|
344 |
+
0.0,
|
345 |
+
0.0,
|
346 |
+
0.0,
|
347 |
+
0.0,
|
348 |
+
0.0
|
349 |
+
],
|
350 |
+
"mean": [
|
351 |
+
0.0,
|
352 |
+
0.0,
|
353 |
+
0.0,
|
354 |
+
0.0,
|
355 |
+
0.0,
|
356 |
+
0.0,
|
357 |
+
0.0
|
358 |
+
],
|
359 |
+
"min": [
|
360 |
+
0.0,
|
361 |
+
0.0,
|
362 |
+
0.0,
|
363 |
+
0.0,
|
364 |
+
0.0,
|
365 |
+
0.0,
|
366 |
+
0.0
|
367 |
+
],
|
368 |
+
"q01": [
|
369 |
+
0.0,
|
370 |
+
0.0,
|
371 |
+
0.0,
|
372 |
+
0.0,
|
373 |
+
0.0,
|
374 |
+
0.0,
|
375 |
+
0.0
|
376 |
+
],
|
377 |
+
"q99": [
|
378 |
+
0.0,
|
379 |
+
0.0,
|
380 |
+
0.0,
|
381 |
+
0.0,
|
382 |
+
0.0,
|
383 |
+
0.0,
|
384 |
+
0.0
|
385 |
+
],
|
386 |
+
"std": [
|
387 |
+
0.0,
|
388 |
+
0.0,
|
389 |
+
0.0,
|
390 |
+
0.0,
|
391 |
+
0.0,
|
392 |
+
0.0,
|
393 |
+
0.0
|
394 |
+
]
|
395 |
+
}
|
396 |
+
},
|
397 |
+
"bc_z": {
|
398 |
+
"action": {
|
399 |
+
"mask": [
|
400 |
+
true,
|
401 |
+
true,
|
402 |
+
true,
|
403 |
+
true,
|
404 |
+
true,
|
405 |
+
true,
|
406 |
+
false
|
407 |
+
],
|
408 |
+
"max": [
|
409 |
+
0.2165454924106598,
|
410 |
+
0.1251407265663147,
|
411 |
+
0.10772687941789627,
|
412 |
+
0.33544227480888367,
|
413 |
+
0.28117990493774414,
|
414 |
+
0.40614867210388184,
|
415 |
+
1.0
|
416 |
+
],
|
417 |
+
"mean": [
|
418 |
+
-0.009958467446267605,
|
419 |
+
0.0008958321413956583,
|
420 |
+
0.004995597992092371,
|
421 |
+
0.00029755113064311445,
|
422 |
+
-0.008735382929444313,
|
423 |
+
-0.030693737789988518,
|
424 |
+
0.8344562649726868
|
425 |
+
],
|
426 |
+
"min": [
|
427 |
+
-0.1677047461271286,
|
428 |
+
-0.14630407094955444,
|
429 |
+
-0.10066790133714676,
|
430 |
+
-0.29421567916870117,
|
431 |
+
-0.32101404666900635,
|
432 |
+
-0.4635624885559082,
|
433 |
+
0.0
|
434 |
+
],
|
435 |
+
"q01": [
|
436 |
+
-0.09220654994249344,
|
437 |
+
-0.06456145539879798,
|
438 |
+
-0.049121275544166565,
|
439 |
+
-0.11594625547528267,
|
440 |
+
-0.14152548640966414,
|
441 |
+
-0.2251061636209488,
|
442 |
+
0.0
|
443 |
+
],
|
444 |
+
"q99": [
|
445 |
+
0.07628866866230968,
|
446 |
+
0.058019736707210584,
|
447 |
+
0.052540797740221024,
|
448 |
+
0.11740604028105736,
|
449 |
+
0.11703975558280955,
|
450 |
+
0.16729306846857078,
|
451 |
+
1.0
|
452 |
+
],
|
453 |
+
"std": [
|
454 |
+
0.03053455986082554,
|
455 |
+
0.0231423731893301,
|
456 |
+
0.020641816779971123,
|
457 |
+
0.04155943542718887,
|
458 |
+
0.046427831053733826,
|
459 |
+
0.0769818127155304,
|
460 |
+
0.3610210120677948
|
461 |
+
]
|
462 |
+
},
|
463 |
+
"num_trajectories": 43264,
|
464 |
+
"num_transitions": 6015535,
|
465 |
+
"proprio": {
|
466 |
+
"max": [
|
467 |
+
0.0,
|
468 |
+
0.0,
|
469 |
+
0.0,
|
470 |
+
0.0,
|
471 |
+
0.0,
|
472 |
+
0.0,
|
473 |
+
0.0
|
474 |
+
],
|
475 |
+
"mean": [
|
476 |
+
0.0,
|
477 |
+
0.0,
|
478 |
+
0.0,
|
479 |
+
0.0,
|
480 |
+
0.0,
|
481 |
+
0.0,
|
482 |
+
0.0
|
483 |
+
],
|
484 |
+
"min": [
|
485 |
+
0.0,
|
486 |
+
0.0,
|
487 |
+
0.0,
|
488 |
+
0.0,
|
489 |
+
0.0,
|
490 |
+
0.0,
|
491 |
+
0.0
|
492 |
+
],
|
493 |
+
"q01": [
|
494 |
+
0.0,
|
495 |
+
0.0,
|
496 |
+
0.0,
|
497 |
+
0.0,
|
498 |
+
0.0,
|
499 |
+
0.0,
|
500 |
+
0.0
|
501 |
+
],
|
502 |
+
"q99": [
|
503 |
+
0.0,
|
504 |
+
0.0,
|
505 |
+
0.0,
|
506 |
+
0.0,
|
507 |
+
0.0,
|
508 |
+
0.0,
|
509 |
+
0.0
|
510 |
+
],
|
511 |
+
"std": [
|
512 |
+
0.0,
|
513 |
+
0.0,
|
514 |
+
0.0,
|
515 |
+
0.0,
|
516 |
+
0.0,
|
517 |
+
0.0,
|
518 |
+
0.0
|
519 |
+
]
|
520 |
+
}
|
521 |
+
},
|
522 |
+
"berkeley_autolab_ur5": {
|
523 |
+
"action": {
|
524 |
+
"mask": [
|
525 |
+
true,
|
526 |
+
true,
|
527 |
+
true,
|
528 |
+
true,
|
529 |
+
true,
|
530 |
+
true,
|
531 |
+
false
|
532 |
+
],
|
533 |
+
"max": [
|
534 |
+
0.019999999552965164,
|
535 |
+
0.019999999552965164,
|
536 |
+
0.019999999552965164,
|
537 |
+
0.06666667014360428,
|
538 |
+
0.06666667014360428,
|
539 |
+
0.06666667014360428,
|
540 |
+
1.0
|
541 |
+
],
|
542 |
+
"mean": [
|
543 |
+
0.0005683620693162084,
|
544 |
+
0.001217700308188796,
|
545 |
+
-0.0005296372692100704,
|
546 |
+
0.00021029810886830091,
|
547 |
+
6.0695128922816366e-05,
|
548 |
+
0.001204986940138042,
|
549 |
+
0.6298308372497559
|
550 |
+
],
|
551 |
+
"min": [
|
552 |
+
-0.019999999552965164,
|
553 |
+
-0.019999999552965164,
|
554 |
+
-0.019999999552965164,
|
555 |
+
-0.06666667014360428,
|
556 |
+
-0.06666667014360428,
|
557 |
+
-0.06666667014360428,
|
558 |
+
0.0
|
559 |
+
],
|
560 |
+
"q01": [
|
561 |
+
-0.019999999552965164,
|
562 |
+
-0.019999999552965164,
|
563 |
+
-0.019999999552965164,
|
564 |
+
-0.02628571353852749,
|
565 |
+
-0.06666667014360428,
|
566 |
+
-0.03847619146108627,
|
567 |
+
0.0
|
568 |
+
],
|
569 |
+
"q99": [
|
570 |
+
0.019999999552965164,
|
571 |
+
0.019999999552965164,
|
572 |
+
0.019999999552965164,
|
573 |
+
0.031809523701667786,
|
574 |
+
0.06666667014360428,
|
575 |
+
0.036571428179740906,
|
576 |
+
1.0
|
577 |
+
],
|
578 |
+
"std": [
|
579 |
+
0.0115329809486866,
|
580 |
+
0.007990492507815361,
|
581 |
+
0.009577835910022259,
|
582 |
+
0.009432995691895485,
|
583 |
+
0.016427582129836082,
|
584 |
+
0.011053967289626598,
|
585 |
+
0.48267969489097595
|
586 |
+
]
|
587 |
+
},
|
588 |
+
"num_trajectories": 1000,
|
589 |
+
"num_transitions": 97939,
|
590 |
+
"proprio": {
|
591 |
+
"max": [
|
592 |
+
0.0,
|
593 |
+
0.0,
|
594 |
+
0.0,
|
595 |
+
0.0,
|
596 |
+
0.0,
|
597 |
+
0.0,
|
598 |
+
0.0
|
599 |
+
],
|
600 |
+
"mean": [
|
601 |
+
0.0,
|
602 |
+
0.0,
|
603 |
+
0.0,
|
604 |
+
0.0,
|
605 |
+
0.0,
|
606 |
+
0.0,
|
607 |
+
0.0
|
608 |
+
],
|
609 |
+
"min": [
|
610 |
+
0.0,
|
611 |
+
0.0,
|
612 |
+
0.0,
|
613 |
+
0.0,
|
614 |
+
0.0,
|
615 |
+
0.0,
|
616 |
+
0.0
|
617 |
+
],
|
618 |
+
"q01": [
|
619 |
+
0.0,
|
620 |
+
0.0,
|
621 |
+
0.0,
|
622 |
+
0.0,
|
623 |
+
0.0,
|
624 |
+
0.0,
|
625 |
+
0.0
|
626 |
+
],
|
627 |
+
"q99": [
|
628 |
+
0.0,
|
629 |
+
0.0,
|
630 |
+
0.0,
|
631 |
+
0.0,
|
632 |
+
0.0,
|
633 |
+
0.0,
|
634 |
+
0.0
|
635 |
+
],
|
636 |
+
"std": [
|
637 |
+
0.0,
|
638 |
+
0.0,
|
639 |
+
0.0,
|
640 |
+
0.0,
|
641 |
+
0.0,
|
642 |
+
0.0,
|
643 |
+
0.0
|
644 |
+
]
|
645 |
+
}
|
646 |
+
},
|
647 |
+
"berkeley_cable_routing": {
|
648 |
+
"action": {
|
649 |
+
"mask": [
|
650 |
+
true,
|
651 |
+
true,
|
652 |
+
true,
|
653 |
+
true,
|
654 |
+
true,
|
655 |
+
true,
|
656 |
+
false
|
657 |
+
],
|
658 |
+
"max": [
|
659 |
+
0.9633283019065857,
|
660 |
+
1.0,
|
661 |
+
1.0,
|
662 |
+
0.0,
|
663 |
+
0.0,
|
664 |
+
1.0,
|
665 |
+
0.0
|
666 |
+
],
|
667 |
+
"mean": [
|
668 |
+
-0.07139874249696732,
|
669 |
+
0.023609008640050888,
|
670 |
+
0.10241943597793579,
|
671 |
+
0.0,
|
672 |
+
0.0,
|
673 |
+
0.049671024084091187,
|
674 |
+
0.0
|
675 |
+
],
|
676 |
+
"min": [
|
677 |
+
-0.9809081554412842,
|
678 |
+
-0.9554349184036255,
|
679 |
+
-0.9994775056838989,
|
680 |
+
0.0,
|
681 |
+
0.0,
|
682 |
+
-1.0,
|
683 |
+
0.0
|
684 |
+
],
|
685 |
+
"q01": [
|
686 |
+
-0.5534318816661835,
|
687 |
+
-0.4797285574674606,
|
688 |
+
-0.5314934802055359,
|
689 |
+
0.0,
|
690 |
+
0.0,
|
691 |
+
-0.8855219376087189,
|
692 |
+
0.0
|
693 |
+
],
|
694 |
+
"q99": [
|
695 |
+
0.42652835428714786,
|
696 |
+
0.5000944086909298,
|
697 |
+
0.639823433756829,
|
698 |
+
0.0,
|
699 |
+
0.0,
|
700 |
+
0.984243879914284,
|
701 |
+
0.0
|
702 |
+
],
|
703 |
+
"std": [
|
704 |
+
0.1815500408411026,
|
705 |
+
0.1810990273952484,
|
706 |
+
0.21220779418945312,
|
707 |
+
0.0,
|
708 |
+
0.0,
|
709 |
+
0.3475511968135834,
|
710 |
+
0.0
|
711 |
+
]
|
712 |
+
},
|
713 |
+
"num_trajectories": 1647,
|
714 |
+
"num_transitions": 42328,
|
715 |
+
"proprio": {
|
716 |
+
"max": [
|
717 |
+
0.0,
|
718 |
+
0.0,
|
719 |
+
0.0,
|
720 |
+
0.0,
|
721 |
+
0.0,
|
722 |
+
0.0,
|
723 |
+
0.0
|
724 |
+
],
|
725 |
+
"mean": [
|
726 |
+
0.0,
|
727 |
+
0.0,
|
728 |
+
0.0,
|
729 |
+
0.0,
|
730 |
+
0.0,
|
731 |
+
0.0,
|
732 |
+
0.0
|
733 |
+
],
|
734 |
+
"min": [
|
735 |
+
0.0,
|
736 |
+
0.0,
|
737 |
+
0.0,
|
738 |
+
0.0,
|
739 |
+
0.0,
|
740 |
+
0.0,
|
741 |
+
0.0
|
742 |
+
],
|
743 |
+
"q01": [
|
744 |
+
0.0,
|
745 |
+
0.0,
|
746 |
+
0.0,
|
747 |
+
0.0,
|
748 |
+
0.0,
|
749 |
+
0.0,
|
750 |
+
0.0
|
751 |
+
],
|
752 |
+
"q99": [
|
753 |
+
0.0,
|
754 |
+
0.0,
|
755 |
+
0.0,
|
756 |
+
0.0,
|
757 |
+
0.0,
|
758 |
+
0.0,
|
759 |
+
0.0
|
760 |
+
],
|
761 |
+
"std": [
|
762 |
+
0.0,
|
763 |
+
0.0,
|
764 |
+
0.0,
|
765 |
+
0.0,
|
766 |
+
0.0,
|
767 |
+
0.0,
|
768 |
+
0.0
|
769 |
+
]
|
770 |
+
}
|
771 |
+
},
|
772 |
+
"berkeley_fanuc_manipulation": {
|
773 |
+
"action": {
|
774 |
+
"mask": [
|
775 |
+
true,
|
776 |
+
true,
|
777 |
+
true,
|
778 |
+
true,
|
779 |
+
true,
|
780 |
+
true,
|
781 |
+
false
|
782 |
+
],
|
783 |
+
"max": [
|
784 |
+
0.009999999776482582,
|
785 |
+
0.009999999776482582,
|
786 |
+
0.009999999776482582,
|
787 |
+
0.03490658476948738,
|
788 |
+
0.03490658476948738,
|
789 |
+
0.03490658476948738,
|
790 |
+
1.0
|
791 |
+
],
|
792 |
+
"mean": [
|
793 |
+
0.0007744057802483439,
|
794 |
+
-0.00031240080716088414,
|
795 |
+
-0.0015001941937953234,
|
796 |
+
-0.0007515158504247665,
|
797 |
+
-0.00015832878125365824,
|
798 |
+
0.00014327642566058785,
|
799 |
+
0.699295699596405
|
800 |
+
],
|
801 |
+
"min": [
|
802 |
+
-0.009999999776482582,
|
803 |
+
-0.009999999776482582,
|
804 |
+
-0.009999999776482582,
|
805 |
+
-0.03490658476948738,
|
806 |
+
-0.03490658476948738,
|
807 |
+
-0.03490658476948738,
|
808 |
+
0.0
|
809 |
+
],
|
810 |
+
"q01": [
|
811 |
+
-0.009999999776482582,
|
812 |
+
-0.009999999776482582,
|
813 |
+
-0.009999999776482582,
|
814 |
+
-0.03490658476948738,
|
815 |
+
0.0,
|
816 |
+
-0.03490658476948738,
|
817 |
+
0.0
|
818 |
+
],
|
819 |
+
"q99": [
|
820 |
+
0.009999999776482582,
|
821 |
+
0.009999999776482582,
|
822 |
+
0.009999999776482582,
|
823 |
+
0.03490658476948738,
|
824 |
+
0.0,
|
825 |
+
0.03490658476948738,
|
826 |
+
1.0
|
827 |
+
],
|
828 |
+
"std": [
|
829 |
+
0.0034070091787725687,
|
830 |
+
0.0049921851605176926,
|
831 |
+
0.005344334989786148,
|
832 |
+
0.00759894959628582,
|
833 |
+
0.004081866703927517,
|
834 |
+
0.008568956516683102,
|
835 |
+
0.4586937427520752
|
836 |
+
]
|
837 |
+
},
|
838 |
+
"num_trajectories": 415,
|
839 |
+
"num_transitions": 62613,
|
840 |
+
"proprio": {
|
841 |
+
"max": [
|
842 |
+
0.0,
|
843 |
+
0.0,
|
844 |
+
0.0,
|
845 |
+
0.0,
|
846 |
+
0.0,
|
847 |
+
0.0,
|
848 |
+
0.0
|
849 |
+
],
|
850 |
+
"mean": [
|
851 |
+
0.0,
|
852 |
+
0.0,
|
853 |
+
0.0,
|
854 |
+
0.0,
|
855 |
+
0.0,
|
856 |
+
0.0,
|
857 |
+
0.0
|
858 |
+
],
|
859 |
+
"min": [
|
860 |
+
0.0,
|
861 |
+
0.0,
|
862 |
+
0.0,
|
863 |
+
0.0,
|
864 |
+
0.0,
|
865 |
+
0.0,
|
866 |
+
0.0
|
867 |
+
],
|
868 |
+
"q01": [
|
869 |
+
0.0,
|
870 |
+
0.0,
|
871 |
+
0.0,
|
872 |
+
0.0,
|
873 |
+
0.0,
|
874 |
+
0.0,
|
875 |
+
0.0
|
876 |
+
],
|
877 |
+
"q99": [
|
878 |
+
0.0,
|
879 |
+
0.0,
|
880 |
+
0.0,
|
881 |
+
0.0,
|
882 |
+
0.0,
|
883 |
+
0.0,
|
884 |
+
0.0
|
885 |
+
],
|
886 |
+
"std": [
|
887 |
+
0.0,
|
888 |
+
0.0,
|
889 |
+
0.0,
|
890 |
+
0.0,
|
891 |
+
0.0,
|
892 |
+
0.0,
|
893 |
+
0.0
|
894 |
+
]
|
895 |
+
}
|
896 |
+
},
|
897 |
+
"bridge_orig": {
|
898 |
+
"action": {
|
899 |
+
"mask": [
|
900 |
+
true,
|
901 |
+
true,
|
902 |
+
true,
|
903 |
+
true,
|
904 |
+
true,
|
905 |
+
true,
|
906 |
+
false
|
907 |
+
],
|
908 |
+
"max": [
|
909 |
+
0.41691166162490845,
|
910 |
+
0.25864794850349426,
|
911 |
+
0.21218234300613403,
|
912 |
+
3.122201919555664,
|
913 |
+
1.8618112802505493,
|
914 |
+
6.280478477478027,
|
915 |
+
1.0
|
916 |
+
],
|
917 |
+
"mean": [
|
918 |
+
0.0002334194869035855,
|
919 |
+
0.00013004911306779832,
|
920 |
+
-0.00012762474943883717,
|
921 |
+
-0.0001556558854645118,
|
922 |
+
-0.0004039328487124294,
|
923 |
+
0.00023557482927571982,
|
924 |
+
0.5764579176902771
|
925 |
+
],
|
926 |
+
"min": [
|
927 |
+
-0.4007510244846344,
|
928 |
+
-0.13874775171279907,
|
929 |
+
-0.22553899884223938,
|
930 |
+
-3.2010786533355713,
|
931 |
+
-1.8618112802505493,
|
932 |
+
-6.279075622558594,
|
933 |
+
0.0
|
934 |
+
],
|
935 |
+
"q01": [
|
936 |
+
-0.02872725307941437,
|
937 |
+
-0.04170349963009357,
|
938 |
+
-0.026093858778476715,
|
939 |
+
-0.08092105075716972,
|
940 |
+
-0.09288699507713317,
|
941 |
+
-0.20718276381492615,
|
942 |
+
0.0
|
943 |
+
],
|
944 |
+
"q99": [
|
945 |
+
0.028309678435325586,
|
946 |
+
0.040855254605412394,
|
947 |
+
0.040161586627364146,
|
948 |
+
0.08192047759890528,
|
949 |
+
0.07792850524187081,
|
950 |
+
0.20382574498653397,
|
951 |
+
1.0
|
952 |
+
],
|
953 |
+
"std": [
|
954 |
+
0.009765930473804474,
|
955 |
+
0.013689135201275349,
|
956 |
+
0.012667362578213215,
|
957 |
+
0.028534092009067535,
|
958 |
+
0.030637972056865692,
|
959 |
+
0.07691419124603271,
|
960 |
+
0.4973701536655426
|
961 |
+
]
|
962 |
+
},
|
963 |
+
"num_trajectories": 60064,
|
964 |
+
"num_transitions": 2135463,
|
965 |
+
"proprio": {
|
966 |
+
"max": [
|
967 |
+
0.0,
|
968 |
+
0.0,
|
969 |
+
0.0,
|
970 |
+
0.0,
|
971 |
+
0.0,
|
972 |
+
0.0,
|
973 |
+
0.0
|
974 |
+
],
|
975 |
+
"mean": [
|
976 |
+
0.0,
|
977 |
+
0.0,
|
978 |
+
0.0,
|
979 |
+
0.0,
|
980 |
+
0.0,
|
981 |
+
0.0,
|
982 |
+
0.0
|
983 |
+
],
|
984 |
+
"min": [
|
985 |
+
0.0,
|
986 |
+
0.0,
|
987 |
+
0.0,
|
988 |
+
0.0,
|
989 |
+
0.0,
|
990 |
+
0.0,
|
991 |
+
0.0
|
992 |
+
],
|
993 |
+
"q01": [
|
994 |
+
0.0,
|
995 |
+
0.0,
|
996 |
+
0.0,
|
997 |
+
0.0,
|
998 |
+
0.0,
|
999 |
+
0.0,
|
1000 |
+
0.0
|
1001 |
+
],
|
1002 |
+
"q99": [
|
1003 |
+
0.0,
|
1004 |
+
0.0,
|
1005 |
+
0.0,
|
1006 |
+
0.0,
|
1007 |
+
0.0,
|
1008 |
+
0.0,
|
1009 |
+
0.0
|
1010 |
+
],
|
1011 |
+
"std": [
|
1012 |
+
0.0,
|
1013 |
+
0.0,
|
1014 |
+
0.0,
|
1015 |
+
0.0,
|
1016 |
+
0.0,
|
1017 |
+
0.0,
|
1018 |
+
0.0
|
1019 |
+
]
|
1020 |
+
}
|
1021 |
+
},
|
1022 |
+
"cmu_stretch": {
|
1023 |
+
"action": {
|
1024 |
+
"mask": [
|
1025 |
+
true,
|
1026 |
+
true,
|
1027 |
+
true,
|
1028 |
+
true,
|
1029 |
+
true,
|
1030 |
+
true,
|
1031 |
+
false
|
1032 |
+
],
|
1033 |
+
"max": [
|
1034 |
+
0.02338407188653946,
|
1035 |
+
0.0,
|
1036 |
+
0.023404927924275398,
|
1037 |
+
0.0,
|
1038 |
+
0.0,
|
1039 |
+
0.0,
|
1040 |
+
1.0
|
1041 |
+
],
|
1042 |
+
"mean": [
|
1043 |
+
0.00036304505192674696,
|
1044 |
+
0.0,
|
1045 |
+
0.0016466958913952112,
|
1046 |
+
0.0,
|
1047 |
+
0.0,
|
1048 |
+
0.0,
|
1049 |
+
0.3987048268318176
|
1050 |
+
],
|
1051 |
+
"min": [
|
1052 |
+
-0.019353797659277916,
|
1053 |
+
0.0,
|
1054 |
+
-0.02019215188920498,
|
1055 |
+
0.0,
|
1056 |
+
0.0,
|
1057 |
+
0.0,
|
1058 |
+
0.0
|
1059 |
+
],
|
1060 |
+
"q01": [
|
1061 |
+
-0.011175686959177256,
|
1062 |
+
0.0,
|
1063 |
+
-0.0032206363626755773,
|
1064 |
+
0.0,
|
1065 |
+
0.0,
|
1066 |
+
0.0,
|
1067 |
+
0.0
|
1068 |
+
],
|
1069 |
+
"q99": [
|
1070 |
+
0.014501785952597848,
|
1071 |
+
0.0,
|
1072 |
+
0.015056106168776728,
|
1073 |
+
0.0,
|
1074 |
+
0.0,
|
1075 |
+
0.0,
|
1076 |
+
1.0
|
1077 |
+
],
|
1078 |
+
"std": [
|
1079 |
+
0.004081828519701958,
|
1080 |
+
0.0,
|
1081 |
+
0.0037743328139185905,
|
1082 |
+
0.0,
|
1083 |
+
0.0,
|
1084 |
+
0.0,
|
1085 |
+
0.48963725566864014
|
1086 |
+
]
|
1087 |
+
},
|
1088 |
+
"num_trajectories": 135,
|
1089 |
+
"num_transitions": 25016,
|
1090 |
+
"proprio": {
|
1091 |
+
"max": [
|
1092 |
+
0.0,
|
1093 |
+
0.0,
|
1094 |
+
0.0,
|
1095 |
+
0.0,
|
1096 |
+
0.0,
|
1097 |
+
0.0,
|
1098 |
+
0.0
|
1099 |
+
],
|
1100 |
+
"mean": [
|
1101 |
+
0.0,
|
1102 |
+
0.0,
|
1103 |
+
0.0,
|
1104 |
+
0.0,
|
1105 |
+
0.0,
|
1106 |
+
0.0,
|
1107 |
+
0.0
|
1108 |
+
],
|
1109 |
+
"min": [
|
1110 |
+
0.0,
|
1111 |
+
0.0,
|
1112 |
+
0.0,
|
1113 |
+
0.0,
|
1114 |
+
0.0,
|
1115 |
+
0.0,
|
1116 |
+
0.0
|
1117 |
+
],
|
1118 |
+
"q01": [
|
1119 |
+
0.0,
|
1120 |
+
0.0,
|
1121 |
+
0.0,
|
1122 |
+
0.0,
|
1123 |
+
0.0,
|
1124 |
+
0.0,
|
1125 |
+
0.0
|
1126 |
+
],
|
1127 |
+
"q99": [
|
1128 |
+
0.0,
|
1129 |
+
0.0,
|
1130 |
+
0.0,
|
1131 |
+
0.0,
|
1132 |
+
0.0,
|
1133 |
+
0.0,
|
1134 |
+
0.0
|
1135 |
+
],
|
1136 |
+
"std": [
|
1137 |
+
0.0,
|
1138 |
+
0.0,
|
1139 |
+
0.0,
|
1140 |
+
0.0,
|
1141 |
+
0.0,
|
1142 |
+
0.0,
|
1143 |
+
0.0
|
1144 |
+
]
|
1145 |
+
}
|
1146 |
+
},
|
1147 |
+
"dlr_edan_shared_control_converted_externally_to_rlds": {
|
1148 |
+
"action": {
|
1149 |
+
"mask": [
|
1150 |
+
true,
|
1151 |
+
true,
|
1152 |
+
true,
|
1153 |
+
true,
|
1154 |
+
true,
|
1155 |
+
true,
|
1156 |
+
false
|
1157 |
+
],
|
1158 |
+
"max": [
|
1159 |
+
0.18991442024707794,
|
1160 |
+
0.0739002525806427,
|
1161 |
+
0.18064819276332855,
|
1162 |
+
0.0866486132144928,
|
1163 |
+
0.13464981317520142,
|
1164 |
+
0.16910280287265778,
|
1165 |
+
1.0
|
1166 |
+
],
|
1167 |
+
"mean": [
|
1168 |
+
0.006647810339927673,
|
1169 |
+
-0.0007657372043468058,
|
1170 |
+
0.006522852927446365,
|
1171 |
+
0.0011679717572405934,
|
1172 |
+
-0.006395625416189432,
|
1173 |
+
-0.011902998201549053,
|
1174 |
+
0.6985887289047241
|
1175 |
+
],
|
1176 |
+
"min": [
|
1177 |
+
-0.10054297000169754,
|
1178 |
+
-0.08427435159683228,
|
1179 |
+
-0.13533438742160797,
|
1180 |
+
-0.17556548118591309,
|
1181 |
+
-0.18485672771930695,
|
1182 |
+
-0.2680685818195343,
|
1183 |
+
0.0
|
1184 |
+
],
|
1185 |
+
"q01": [
|
1186 |
+
-0.02987122368067503,
|
1187 |
+
-0.06013262912631035,
|
1188 |
+
-0.08286409199237824,
|
1189 |
+
-0.05924444157630205,
|
1190 |
+
-0.15986866518855095,
|
1191 |
+
-0.15636983573436739,
|
1192 |
+
0.0
|
1193 |
+
],
|
1194 |
+
"q99": [
|
1195 |
+
0.08832092039287087,
|
1196 |
+
0.042126184627413736,
|
1197 |
+
0.11311905644834042,
|
1198 |
+
0.0643695573508739,
|
1199 |
+
0.03941855944693088,
|
1200 |
+
0.156646853685379,
|
1201 |
+
1.0
|
1202 |
+
],
|
1203 |
+
"std": [
|
1204 |
+
0.021393608301877975,
|
1205 |
+
0.01814231649041176,
|
1206 |
+
0.03374375030398369,
|
1207 |
+
0.01743541844189167,
|
1208 |
+
0.03394376486539841,
|
1209 |
+
0.04641875624656677,
|
1210 |
+
0.4588589072227478
|
1211 |
+
]
|
1212 |
+
},
|
1213 |
+
"num_trajectories": 104,
|
1214 |
+
"num_transitions": 8928,
|
1215 |
+
"proprio": {
|
1216 |
+
"max": [
|
1217 |
+
0.0,
|
1218 |
+
0.0,
|
1219 |
+
0.0,
|
1220 |
+
0.0,
|
1221 |
+
0.0,
|
1222 |
+
0.0,
|
1223 |
+
0.0
|
1224 |
+
],
|
1225 |
+
"mean": [
|
1226 |
+
0.0,
|
1227 |
+
0.0,
|
1228 |
+
0.0,
|
1229 |
+
0.0,
|
1230 |
+
0.0,
|
1231 |
+
0.0,
|
1232 |
+
0.0
|
1233 |
+
],
|
1234 |
+
"min": [
|
1235 |
+
0.0,
|
1236 |
+
0.0,
|
1237 |
+
0.0,
|
1238 |
+
0.0,
|
1239 |
+
0.0,
|
1240 |
+
0.0,
|
1241 |
+
0.0
|
1242 |
+
],
|
1243 |
+
"q01": [
|
1244 |
+
0.0,
|
1245 |
+
0.0,
|
1246 |
+
0.0,
|
1247 |
+
0.0,
|
1248 |
+
0.0,
|
1249 |
+
0.0,
|
1250 |
+
0.0
|
1251 |
+
],
|
1252 |
+
"q99": [
|
1253 |
+
0.0,
|
1254 |
+
0.0,
|
1255 |
+
0.0,
|
1256 |
+
0.0,
|
1257 |
+
0.0,
|
1258 |
+
0.0,
|
1259 |
+
0.0
|
1260 |
+
],
|
1261 |
+
"std": [
|
1262 |
+
0.0,
|
1263 |
+
0.0,
|
1264 |
+
0.0,
|
1265 |
+
0.0,
|
1266 |
+
0.0,
|
1267 |
+
0.0,
|
1268 |
+
0.0
|
1269 |
+
]
|
1270 |
+
}
|
1271 |
+
},
|
1272 |
+
"dobbe": {
|
1273 |
+
"action": {
|
1274 |
+
"mask": [
|
1275 |
+
true,
|
1276 |
+
true,
|
1277 |
+
true,
|
1278 |
+
true,
|
1279 |
+
true,
|
1280 |
+
true,
|
1281 |
+
false
|
1282 |
+
],
|
1283 |
+
"max": [
|
1284 |
+
38.590423583984375,
|
1285 |
+
17.932697296142578,
|
1286 |
+
4.843764305114746,
|
1287 |
+
1.4372116327285767,
|
1288 |
+
0.4340403974056244,
|
1289 |
+
1.2057193517684937,
|
1290 |
+
0.9998947381973267
|
1291 |
+
],
|
1292 |
+
"mean": [
|
1293 |
+
-0.0001120665911003016,
|
1294 |
+
0.0011229600058868527,
|
1295 |
+
-0.00010194431524723768,
|
1296 |
+
-7.371398532995954e-05,
|
1297 |
+
-0.00067531579406932,
|
1298 |
+
-5.6643435527803376e-05,
|
1299 |
+
0.6318281888961792
|
1300 |
+
],
|
1301 |
+
"min": [
|
1302 |
+
-5.700923442840576,
|
1303 |
+
-21.605947494506836,
|
1304 |
+
-123.72489929199219,
|
1305 |
+
-1.7229845523834229,
|
1306 |
+
-0.4998578727245331,
|
1307 |
+
-0.8867913484573364,
|
1308 |
+
1.4196479014572105e-06
|
1309 |
+
],
|
1310 |
+
"q01": [
|
1311 |
+
-0.01119564864784479,
|
1312 |
+
-0.014266146533191203,
|
1313 |
+
-0.0071747214533388615,
|
1314 |
+
-0.009444301575422287,
|
1315 |
+
-0.03990109823644161,
|
1316 |
+
-0.017422311007976532,
|
1317 |
+
4.003279136668425e-05
|
1318 |
+
],
|
1319 |
+
"q99": [
|
1320 |
+
0.01015154086053368,
|
1321 |
+
0.017181577533483497,
|
1322 |
+
0.007216989761218411,
|
1323 |
+
0.010380979906767595,
|
1324 |
+
0.03556173853576176,
|
1325 |
+
0.018032474815845446,
|
1326 |
+
0.9982578039169312
|
1327 |
+
],
|
1328 |
+
"std": [
|
1329 |
+
0.04264938458800316,
|
1330 |
+
0.04428559169173241,
|
1331 |
+
0.12224084138870239,
|
1332 |
+
0.005388413090258837,
|
1333 |
+
0.011246449314057827,
|
1334 |
+
0.006287882570177317,
|
1335 |
+
0.39732322096824646
|
1336 |
+
]
|
1337 |
+
},
|
1338 |
+
"num_trajectories": 5208,
|
1339 |
+
"num_transitions": 1139911,
|
1340 |
+
"proprio": {
|
1341 |
+
"max": [
|
1342 |
+
0.0,
|
1343 |
+
0.0,
|
1344 |
+
0.0,
|
1345 |
+
0.0,
|
1346 |
+
0.0,
|
1347 |
+
0.0,
|
1348 |
+
0.0
|
1349 |
+
],
|
1350 |
+
"mean": [
|
1351 |
+
0.0,
|
1352 |
+
0.0,
|
1353 |
+
0.0,
|
1354 |
+
0.0,
|
1355 |
+
0.0,
|
1356 |
+
0.0,
|
1357 |
+
0.0
|
1358 |
+
],
|
1359 |
+
"min": [
|
1360 |
+
0.0,
|
1361 |
+
0.0,
|
1362 |
+
0.0,
|
1363 |
+
0.0,
|
1364 |
+
0.0,
|
1365 |
+
0.0,
|
1366 |
+
0.0
|
1367 |
+
],
|
1368 |
+
"q01": [
|
1369 |
+
0.0,
|
1370 |
+
0.0,
|
1371 |
+
0.0,
|
1372 |
+
0.0,
|
1373 |
+
0.0,
|
1374 |
+
0.0,
|
1375 |
+
0.0
|
1376 |
+
],
|
1377 |
+
"q99": [
|
1378 |
+
0.0,
|
1379 |
+
0.0,
|
1380 |
+
0.0,
|
1381 |
+
0.0,
|
1382 |
+
0.0,
|
1383 |
+
0.0,
|
1384 |
+
0.0
|
1385 |
+
],
|
1386 |
+
"std": [
|
1387 |
+
0.0,
|
1388 |
+
0.0,
|
1389 |
+
0.0,
|
1390 |
+
0.0,
|
1391 |
+
0.0,
|
1392 |
+
0.0,
|
1393 |
+
0.0
|
1394 |
+
]
|
1395 |
+
}
|
1396 |
+
},
|
1397 |
+
"fmb_dataset": {
|
1398 |
+
"action": {
|
1399 |
+
"mask": [
|
1400 |
+
true,
|
1401 |
+
true,
|
1402 |
+
true,
|
1403 |
+
true,
|
1404 |
+
true,
|
1405 |
+
true,
|
1406 |
+
false
|
1407 |
+
],
|
1408 |
+
"max": [
|
1409 |
+
1.399999976158142,
|
1410 |
+
1.0,
|
1411 |
+
1.399999976158142,
|
1412 |
+
1.0,
|
1413 |
+
1.0,
|
1414 |
+
1.0,
|
1415 |
+
1.0
|
1416 |
+
],
|
1417 |
+
"mean": [
|
1418 |
+
0.059029702097177505,
|
1419 |
+
-0.06476633995771408,
|
1420 |
+
-0.09787475317716599,
|
1421 |
+
0.004325388930737972,
|
1422 |
+
0.00028963794466108084,
|
1423 |
+
-0.04457257315516472,
|
1424 |
+
0.7336440086364746
|
1425 |
+
],
|
1426 |
+
"min": [
|
1427 |
+
-1.399999976158142,
|
1428 |
+
-1.399999976158142,
|
1429 |
+
-1.0,
|
1430 |
+
-1.0,
|
1431 |
+
-1.0,
|
1432 |
+
-1.0,
|
1433 |
+
0.0
|
1434 |
+
],
|
1435 |
+
"q01": [
|
1436 |
+
-0.8257142901420593,
|
1437 |
+
-1.399999976158142,
|
1438 |
+
-1.0,
|
1439 |
+
-1.0,
|
1440 |
+
-0.3028571307659149,
|
1441 |
+
-1.0,
|
1442 |
+
0.0
|
1443 |
+
],
|
1444 |
+
"q99": [
|
1445 |
+
1.0,
|
1446 |
+
0.5257142782211304,
|
1447 |
+
1.0,
|
1448 |
+
1.0,
|
1449 |
+
0.3400000035762787,
|
1450 |
+
1.0,
|
1451 |
+
1.0
|
1452 |
+
],
|
1453 |
+
"std": [
|
1454 |
+
0.28809213638305664,
|
1455 |
+
0.2820415794849396,
|
1456 |
+
0.4626740515232086,
|
1457 |
+
0.3266514539718628,
|
1458 |
+
0.10842999070882797,
|
1459 |
+
0.3440099358558655,
|
1460 |
+
0.4435282051563263
|
1461 |
+
]
|
1462 |
+
},
|
1463 |
+
"num_trajectories": 8612,
|
1464 |
+
"num_transitions": 1137459,
|
1465 |
+
"proprio": {
|
1466 |
+
"max": [
|
1467 |
+
0.0,
|
1468 |
+
0.0,
|
1469 |
+
0.0,
|
1470 |
+
0.0,
|
1471 |
+
0.0,
|
1472 |
+
0.0,
|
1473 |
+
0.0
|
1474 |
+
],
|
1475 |
+
"mean": [
|
1476 |
+
0.0,
|
1477 |
+
0.0,
|
1478 |
+
0.0,
|
1479 |
+
0.0,
|
1480 |
+
0.0,
|
1481 |
+
0.0,
|
1482 |
+
0.0
|
1483 |
+
],
|
1484 |
+
"min": [
|
1485 |
+
0.0,
|
1486 |
+
0.0,
|
1487 |
+
0.0,
|
1488 |
+
0.0,
|
1489 |
+
0.0,
|
1490 |
+
0.0,
|
1491 |
+
0.0
|
1492 |
+
],
|
1493 |
+
"q01": [
|
1494 |
+
0.0,
|
1495 |
+
0.0,
|
1496 |
+
0.0,
|
1497 |
+
0.0,
|
1498 |
+
0.0,
|
1499 |
+
0.0,
|
1500 |
+
0.0
|
1501 |
+
],
|
1502 |
+
"q99": [
|
1503 |
+
0.0,
|
1504 |
+
0.0,
|
1505 |
+
0.0,
|
1506 |
+
0.0,
|
1507 |
+
0.0,
|
1508 |
+
0.0,
|
1509 |
+
0.0
|
1510 |
+
],
|
1511 |
+
"std": [
|
1512 |
+
0.0,
|
1513 |
+
0.0,
|
1514 |
+
0.0,
|
1515 |
+
0.0,
|
1516 |
+
0.0,
|
1517 |
+
0.0,
|
1518 |
+
0.0
|
1519 |
+
]
|
1520 |
+
}
|
1521 |
+
},
|
1522 |
+
"fractal20220817_data": {
|
1523 |
+
"action": {
|
1524 |
+
"mask": [
|
1525 |
+
true,
|
1526 |
+
true,
|
1527 |
+
true,
|
1528 |
+
true,
|
1529 |
+
true,
|
1530 |
+
true,
|
1531 |
+
false
|
1532 |
+
],
|
1533 |
+
"max": [
|
1534 |
+
2.9984593391418457,
|
1535 |
+
22.09052848815918,
|
1536 |
+
2.7507524490356445,
|
1537 |
+
1.570636510848999,
|
1538 |
+
1.5321086645126343,
|
1539 |
+
1.5691522359848022,
|
1540 |
+
1.0
|
1541 |
+
],
|
1542 |
+
"mean": [
|
1543 |
+
0.006987582892179489,
|
1544 |
+
0.006265917327255011,
|
1545 |
+
-0.01262515690177679,
|
1546 |
+
0.04333311319351196,
|
1547 |
+
-0.005756212864071131,
|
1548 |
+
0.0009130256366916001,
|
1549 |
+
0.5354204773902893
|
1550 |
+
],
|
1551 |
+
"min": [
|
1552 |
+
-2.0204520225524902,
|
1553 |
+
-5.497899532318115,
|
1554 |
+
-2.031663417816162,
|
1555 |
+
-1.569917917251587,
|
1556 |
+
-1.569892168045044,
|
1557 |
+
-1.570419430732727,
|
1558 |
+
0.0
|
1559 |
+
],
|
1560 |
+
"q01": [
|
1561 |
+
-0.22453527510166169,
|
1562 |
+
-0.14820013284683228,
|
1563 |
+
-0.231589707583189,
|
1564 |
+
-0.3517994859814644,
|
1565 |
+
-0.4193011274933815,
|
1566 |
+
-0.43643461108207704,
|
1567 |
+
0.0
|
1568 |
+
],
|
1569 |
+
"q99": [
|
1570 |
+
0.17824687153100965,
|
1571 |
+
0.14938379630446405,
|
1572 |
+
0.21842354819178575,
|
1573 |
+
0.5892666035890578,
|
1574 |
+
0.35272657424211445,
|
1575 |
+
0.44796681255102094,
|
1576 |
+
1.0
|
1577 |
+
],
|
1578 |
+
"std": [
|
1579 |
+
0.0692116990685463,
|
1580 |
+
0.05970962345600128,
|
1581 |
+
0.07353084534406662,
|
1582 |
+
0.15610496699810028,
|
1583 |
+
0.13164450228214264,
|
1584 |
+
0.14593800902366638,
|
1585 |
+
0.497110515832901
|
1586 |
+
]
|
1587 |
+
},
|
1588 |
+
"num_trajectories": 87212,
|
1589 |
+
"num_transitions": 3786400,
|
1590 |
+
"proprio": {
|
1591 |
+
"max": [
|
1592 |
+
0.0,
|
1593 |
+
0.0,
|
1594 |
+
0.0,
|
1595 |
+
0.0,
|
1596 |
+
0.0,
|
1597 |
+
0.0,
|
1598 |
+
0.0
|
1599 |
+
],
|
1600 |
+
"mean": [
|
1601 |
+
0.0,
|
1602 |
+
0.0,
|
1603 |
+
0.0,
|
1604 |
+
0.0,
|
1605 |
+
0.0,
|
1606 |
+
0.0,
|
1607 |
+
0.0
|
1608 |
+
],
|
1609 |
+
"min": [
|
1610 |
+
0.0,
|
1611 |
+
0.0,
|
1612 |
+
0.0,
|
1613 |
+
0.0,
|
1614 |
+
0.0,
|
1615 |
+
0.0,
|
1616 |
+
0.0
|
1617 |
+
],
|
1618 |
+
"q01": [
|
1619 |
+
0.0,
|
1620 |
+
0.0,
|
1621 |
+
0.0,
|
1622 |
+
0.0,
|
1623 |
+
0.0,
|
1624 |
+
0.0,
|
1625 |
+
0.0
|
1626 |
+
],
|
1627 |
+
"q99": [
|
1628 |
+
0.0,
|
1629 |
+
0.0,
|
1630 |
+
0.0,
|
1631 |
+
0.0,
|
1632 |
+
0.0,
|
1633 |
+
0.0,
|
1634 |
+
0.0
|
1635 |
+
],
|
1636 |
+
"std": [
|
1637 |
+
0.0,
|
1638 |
+
0.0,
|
1639 |
+
0.0,
|
1640 |
+
0.0,
|
1641 |
+
0.0,
|
1642 |
+
0.0,
|
1643 |
+
0.0
|
1644 |
+
]
|
1645 |
+
}
|
1646 |
+
},
|
1647 |
+
"furniture_bench_dataset_converted_externally_to_rlds": {
|
1648 |
+
"action": {
|
1649 |
+
"mask": [
|
1650 |
+
true,
|
1651 |
+
true,
|
1652 |
+
true,
|
1653 |
+
true,
|
1654 |
+
true,
|
1655 |
+
true,
|
1656 |
+
false
|
1657 |
+
],
|
1658 |
+
"max": [
|
1659 |
+
0.10000000149011612,
|
1660 |
+
0.10000000149011612,
|
1661 |
+
0.10000000149011612,
|
1662 |
+
0.8651833534240723,
|
1663 |
+
1.0909736156463623,
|
1664 |
+
2.863185405731201,
|
1665 |
+
1.0
|
1666 |
+
],
|
1667 |
+
"mean": [
|
1668 |
+
0.00014610752987209707,
|
1669 |
+
0.0010830952087417245,
|
1670 |
+
0.0006224989192560315,
|
1671 |
+
-0.003303206292912364,
|
1672 |
+
-0.0026880695950239897,
|
1673 |
+
0.018242603167891502,
|
1674 |
+
0.48854944109916687
|
1675 |
+
],
|
1676 |
+
"min": [
|
1677 |
+
-0.10495579987764359,
|
1678 |
+
-0.10939455777406693,
|
1679 |
+
-0.10000000149011612,
|
1680 |
+
-0.971906840801239,
|
1681 |
+
-1.0475432872772217,
|
1682 |
+
-3.06000018119812,
|
1683 |
+
0.0
|
1684 |
+
],
|
1685 |
+
"q01": [
|
1686 |
+
-0.053988199681043625,
|
1687 |
+
-0.05049169331789017,
|
1688 |
+
-0.032499241530895236,
|
1689 |
+
-0.1953887003660202,
|
1690 |
+
-0.41674559473991396,
|
1691 |
+
-0.8886768388748169,
|
1692 |
+
0.0
|
1693 |
+
],
|
1694 |
+
"q99": [
|
1695 |
+
0.05414841488003723,
|
1696 |
+
0.04965164884924884,
|
1697 |
+
0.060055799782276154,
|
1698 |
+
0.18231668293476103,
|
1699 |
+
0.39867786407470646,
|
1700 |
+
0.8772023963928218,
|
1701 |
+
1.0
|
1702 |
+
],
|
1703 |
+
"std": [
|
1704 |
+
0.01610708422958851,
|
1705 |
+
0.014891477301716805,
|
1706 |
+
0.014014219865202904,
|
1707 |
+
0.058274295181035995,
|
1708 |
+
0.11417088657617569,
|
1709 |
+
0.33479776978492737,
|
1710 |
+
0.49991825222969055
|
1711 |
+
]
|
1712 |
+
},
|
1713 |
+
"num_trajectories": 5100,
|
1714 |
+
"num_transitions": 3948057,
|
1715 |
+
"proprio": {
|
1716 |
+
"max": [
|
1717 |
+
0.0,
|
1718 |
+
0.0,
|
1719 |
+
0.0,
|
1720 |
+
0.0,
|
1721 |
+
0.0,
|
1722 |
+
0.0,
|
1723 |
+
0.0
|
1724 |
+
],
|
1725 |
+
"mean": [
|
1726 |
+
0.0,
|
1727 |
+
0.0,
|
1728 |
+
0.0,
|
1729 |
+
0.0,
|
1730 |
+
0.0,
|
1731 |
+
0.0,
|
1732 |
+
0.0
|
1733 |
+
],
|
1734 |
+
"min": [
|
1735 |
+
0.0,
|
1736 |
+
0.0,
|
1737 |
+
0.0,
|
1738 |
+
0.0,
|
1739 |
+
0.0,
|
1740 |
+
0.0,
|
1741 |
+
0.0
|
1742 |
+
],
|
1743 |
+
"q01": [
|
1744 |
+
0.0,
|
1745 |
+
0.0,
|
1746 |
+
0.0,
|
1747 |
+
0.0,
|
1748 |
+
0.0,
|
1749 |
+
0.0,
|
1750 |
+
0.0
|
1751 |
+
],
|
1752 |
+
"q99": [
|
1753 |
+
0.0,
|
1754 |
+
0.0,
|
1755 |
+
0.0,
|
1756 |
+
0.0,
|
1757 |
+
0.0,
|
1758 |
+
0.0,
|
1759 |
+
0.0
|
1760 |
+
],
|
1761 |
+
"std": [
|
1762 |
+
0.0,
|
1763 |
+
0.0,
|
1764 |
+
0.0,
|
1765 |
+
0.0,
|
1766 |
+
0.0,
|
1767 |
+
0.0,
|
1768 |
+
0.0
|
1769 |
+
]
|
1770 |
+
}
|
1771 |
+
},
|
1772 |
+
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
|
1773 |
+
"action": {
|
1774 |
+
"mask": [
|
1775 |
+
true,
|
1776 |
+
true,
|
1777 |
+
true,
|
1778 |
+
true,
|
1779 |
+
true,
|
1780 |
+
true,
|
1781 |
+
false
|
1782 |
+
],
|
1783 |
+
"max": [
|
1784 |
+
0.6634981632232666,
|
1785 |
+
0.23428471386432648,
|
1786 |
+
0.4308285415172577,
|
1787 |
+
3.1415927410125732,
|
1788 |
+
0.13647015392780304,
|
1789 |
+
3.141592502593994,
|
1790 |
+
1.0
|
1791 |
+
],
|
1792 |
+
"mean": [
|
1793 |
+
0.5274372696876526,
|
1794 |
+
0.02858201041817665,
|
1795 |
+
0.18712575733661652,
|
1796 |
+
1.2339589595794678,
|
1797 |
+
0.03226623684167862,
|
1798 |
+
-1.4199490547180176,
|
1799 |
+
0.5550631880760193
|
1800 |
+
],
|
1801 |
+
"min": [
|
1802 |
+
0.3071657121181488,
|
1803 |
+
-0.29754969477653503,
|
1804 |
+
0.06578229367733002,
|
1805 |
+
-3.1415927410125732,
|
1806 |
+
-0.04584203287959099,
|
1807 |
+
-3.141592502593994,
|
1808 |
+
0.0
|
1809 |
+
],
|
1810 |
+
"q01": [
|
1811 |
+
0.3148897051811218,
|
1812 |
+
-0.20317550599575043,
|
1813 |
+
0.06785467118024827,
|
1814 |
+
-3.140952730178833,
|
1815 |
+
-0.029743434861302376,
|
1816 |
+
-3.141091251373291,
|
1817 |
+
0.0
|
1818 |
+
],
|
1819 |
+
"q99": [
|
1820 |
+
0.6472805738449097,
|
1821 |
+
0.20846802592277527,
|
1822 |
+
0.36855655312538155,
|
1823 |
+
3.1409926891326903,
|
1824 |
+
0.11424950212240226,
|
1825 |
+
3.1410969257354737,
|
1826 |
+
1.0
|
1827 |
+
],
|
1828 |
+
"std": [
|
1829 |
+
0.08108345419168472,
|
1830 |
+
0.1116757020354271,
|
1831 |
+
0.07747554779052734,
|
1832 |
+
2.8737246990203857,
|
1833 |
+
0.02774704433977604,
|
1834 |
+
2.7678682804107666,
|
1835 |
+
0.49695101380348206
|
1836 |
+
]
|
1837 |
+
},
|
1838 |
+
"num_trajectories": 631,
|
1839 |
+
"num_transitions": 146241,
|
1840 |
+
"proprio": {
|
1841 |
+
"max": [
|
1842 |
+
0.0,
|
1843 |
+
0.0,
|
1844 |
+
0.0,
|
1845 |
+
0.0,
|
1846 |
+
0.0,
|
1847 |
+
0.0,
|
1848 |
+
0.0
|
1849 |
+
],
|
1850 |
+
"mean": [
|
1851 |
+
0.0,
|
1852 |
+
0.0,
|
1853 |
+
0.0,
|
1854 |
+
0.0,
|
1855 |
+
0.0,
|
1856 |
+
0.0,
|
1857 |
+
0.0
|
1858 |
+
],
|
1859 |
+
"min": [
|
1860 |
+
0.0,
|
1861 |
+
0.0,
|
1862 |
+
0.0,
|
1863 |
+
0.0,
|
1864 |
+
0.0,
|
1865 |
+
0.0,
|
1866 |
+
0.0
|
1867 |
+
],
|
1868 |
+
"q01": [
|
1869 |
+
0.0,
|
1870 |
+
0.0,
|
1871 |
+
0.0,
|
1872 |
+
0.0,
|
1873 |
+
0.0,
|
1874 |
+
0.0,
|
1875 |
+
0.0
|
1876 |
+
],
|
1877 |
+
"q99": [
|
1878 |
+
0.0,
|
1879 |
+
0.0,
|
1880 |
+
0.0,
|
1881 |
+
0.0,
|
1882 |
+
0.0,
|
1883 |
+
0.0,
|
1884 |
+
0.0
|
1885 |
+
],
|
1886 |
+
"std": [
|
1887 |
+
0.0,
|
1888 |
+
0.0,
|
1889 |
+
0.0,
|
1890 |
+
0.0,
|
1891 |
+
0.0,
|
1892 |
+
0.0,
|
1893 |
+
0.0
|
1894 |
+
]
|
1895 |
+
}
|
1896 |
+
},
|
1897 |
+
"jaco_play": {
|
1898 |
+
"action": {
|
1899 |
+
"mask": [
|
1900 |
+
true,
|
1901 |
+
true,
|
1902 |
+
true,
|
1903 |
+
true,
|
1904 |
+
true,
|
1905 |
+
true,
|
1906 |
+
false
|
1907 |
+
],
|
1908 |
+
"max": [
|
1909 |
+
0.20000000298023224,
|
1910 |
+
0.20000000298023224,
|
1911 |
+
0.20000000298023224,
|
1912 |
+
0.0,
|
1913 |
+
0.0,
|
1914 |
+
0.0,
|
1915 |
+
1.0
|
1916 |
+
],
|
1917 |
+
"mean": [
|
1918 |
+
0.0009658430935814977,
|
1919 |
+
-0.00580078037455678,
|
1920 |
+
-0.00395062193274498,
|
1921 |
+
0.0,
|
1922 |
+
0.0,
|
1923 |
+
0.0,
|
1924 |
+
0.34934908151626587
|
1925 |
+
],
|
1926 |
+
"min": [
|
1927 |
+
-0.20000000298023224,
|
1928 |
+
-0.20000000298023224,
|
1929 |
+
-0.20000000298023224,
|
1930 |
+
0.0,
|
1931 |
+
0.0,
|
1932 |
+
0.0,
|
1933 |
+
0.0
|
1934 |
+
],
|
1935 |
+
"q01": [
|
1936 |
+
-0.20000000298023224,
|
1937 |
+
-0.20000000298023224,
|
1938 |
+
-0.20000000298023224,
|
1939 |
+
0.0,
|
1940 |
+
0.0,
|
1941 |
+
0.0,
|
1942 |
+
0.0
|
1943 |
+
],
|
1944 |
+
"q99": [
|
1945 |
+
0.20000000298023224,
|
1946 |
+
0.20000000298023224,
|
1947 |
+
0.20000000298023224,
|
1948 |
+
0.0,
|
1949 |
+
0.0,
|
1950 |
+
0.0,
|
1951 |
+
1.0
|
1952 |
+
],
|
1953 |
+
"std": [
|
1954 |
+
0.12235074490308762,
|
1955 |
+
0.09678777307271957,
|
1956 |
+
0.11155334860086441,
|
1957 |
+
0.0,
|
1958 |
+
0.0,
|
1959 |
+
0.0,
|
1960 |
+
0.4768252968788147
|
1961 |
+
]
|
1962 |
+
},
|
1963 |
+
"num_trajectories": 1085,
|
1964 |
+
"num_transitions": 77965,
|
1965 |
+
"proprio": {
|
1966 |
+
"max": [
|
1967 |
+
0.0,
|
1968 |
+
0.0,
|
1969 |
+
0.0,
|
1970 |
+
0.0,
|
1971 |
+
0.0,
|
1972 |
+
0.0,
|
1973 |
+
0.0
|
1974 |
+
],
|
1975 |
+
"mean": [
|
1976 |
+
0.0,
|
1977 |
+
0.0,
|
1978 |
+
0.0,
|
1979 |
+
0.0,
|
1980 |
+
0.0,
|
1981 |
+
0.0,
|
1982 |
+
0.0
|
1983 |
+
],
|
1984 |
+
"min": [
|
1985 |
+
0.0,
|
1986 |
+
0.0,
|
1987 |
+
0.0,
|
1988 |
+
0.0,
|
1989 |
+
0.0,
|
1990 |
+
0.0,
|
1991 |
+
0.0
|
1992 |
+
],
|
1993 |
+
"q01": [
|
1994 |
+
0.0,
|
1995 |
+
0.0,
|
1996 |
+
0.0,
|
1997 |
+
0.0,
|
1998 |
+
0.0,
|
1999 |
+
0.0,
|
2000 |
+
0.0
|
2001 |
+
],
|
2002 |
+
"q99": [
|
2003 |
+
0.0,
|
2004 |
+
0.0,
|
2005 |
+
0.0,
|
2006 |
+
0.0,
|
2007 |
+
0.0,
|
2008 |
+
0.0,
|
2009 |
+
0.0
|
2010 |
+
],
|
2011 |
+
"std": [
|
2012 |
+
0.0,
|
2013 |
+
0.0,
|
2014 |
+
0.0,
|
2015 |
+
0.0,
|
2016 |
+
0.0,
|
2017 |
+
0.0,
|
2018 |
+
0.0
|
2019 |
+
]
|
2020 |
+
}
|
2021 |
+
},
|
2022 |
+
"kuka": {
|
2023 |
+
"action": {
|
2024 |
+
"mask": [
|
2025 |
+
true,
|
2026 |
+
true,
|
2027 |
+
true,
|
2028 |
+
true,
|
2029 |
+
true,
|
2030 |
+
true,
|
2031 |
+
false
|
2032 |
+
],
|
2033 |
+
"max": [
|
2034 |
+
0.1697135865688324,
|
2035 |
+
0.2777623236179352,
|
2036 |
+
0.43710532784461975,
|
2037 |
+
0.0,
|
2038 |
+
0.0,
|
2039 |
+
1.9684287309646606,
|
2040 |
+
1.0
|
2041 |
+
],
|
2042 |
+
"mean": [
|
2043 |
+
-0.0004668905457947403,
|
2044 |
+
0.00040138536132872105,
|
2045 |
+
-0.001280792523175478,
|
2046 |
+
0.0,
|
2047 |
+
0.0,
|
2048 |
+
-0.03722453489899635,
|
2049 |
+
0.4131543040275574
|
2050 |
+
],
|
2051 |
+
"min": [
|
2052 |
+
-0.159867063164711,
|
2053 |
+
-0.2892282009124756,
|
2054 |
+
-0.2795473635196686,
|
2055 |
+
0.0,
|
2056 |
+
0.0,
|
2057 |
+
-1.9875637292861938,
|
2058 |
+
0.0
|
2059 |
+
],
|
2060 |
+
"q01": [
|
2061 |
+
-0.06619441494345665,
|
2062 |
+
-0.08713878810405731,
|
2063 |
+
-0.15083016991615295,
|
2064 |
+
0.0,
|
2065 |
+
0.0,
|
2066 |
+
-0.5415697038173676,
|
2067 |
+
0.0
|
2068 |
+
],
|
2069 |
+
"q99": [
|
2070 |
+
0.06601839080452929,
|
2071 |
+
0.08732476785779003,
|
2072 |
+
0.18168179214000715,
|
2073 |
+
0.0,
|
2074 |
+
0.0,
|
2075 |
+
0.2923380345106127,
|
2076 |
+
1.0
|
2077 |
+
],
|
2078 |
+
"std": [
|
2079 |
+
0.02083250693976879,
|
2080 |
+
0.02915887162089348,
|
2081 |
+
0.06422865390777588,
|
2082 |
+
0.0,
|
2083 |
+
0.0,
|
2084 |
+
0.14224295318126678,
|
2085 |
+
0.49086448550224304
|
2086 |
+
]
|
2087 |
+
},
|
2088 |
+
"num_trajectories": 209880,
|
2089 |
+
"num_transitions": 2455879,
|
2090 |
+
"proprio": {
|
2091 |
+
"max": [
|
2092 |
+
0.0,
|
2093 |
+
0.0,
|
2094 |
+
0.0,
|
2095 |
+
0.0,
|
2096 |
+
0.0,
|
2097 |
+
0.0,
|
2098 |
+
0.0
|
2099 |
+
],
|
2100 |
+
"mean": [
|
2101 |
+
0.0,
|
2102 |
+
0.0,
|
2103 |
+
0.0,
|
2104 |
+
0.0,
|
2105 |
+
0.0,
|
2106 |
+
0.0,
|
2107 |
+
0.0
|
2108 |
+
],
|
2109 |
+
"min": [
|
2110 |
+
0.0,
|
2111 |
+
0.0,
|
2112 |
+
0.0,
|
2113 |
+
0.0,
|
2114 |
+
0.0,
|
2115 |
+
0.0,
|
2116 |
+
0.0
|
2117 |
+
],
|
2118 |
+
"q01": [
|
2119 |
+
0.0,
|
2120 |
+
0.0,
|
2121 |
+
0.0,
|
2122 |
+
0.0,
|
2123 |
+
0.0,
|
2124 |
+
0.0,
|
2125 |
+
0.0
|
2126 |
+
],
|
2127 |
+
"q99": [
|
2128 |
+
0.0,
|
2129 |
+
0.0,
|
2130 |
+
0.0,
|
2131 |
+
0.0,
|
2132 |
+
0.0,
|
2133 |
+
0.0,
|
2134 |
+
0.0
|
2135 |
+
],
|
2136 |
+
"std": [
|
2137 |
+
0.0,
|
2138 |
+
0.0,
|
2139 |
+
0.0,
|
2140 |
+
0.0,
|
2141 |
+
0.0,
|
2142 |
+
0.0,
|
2143 |
+
0.0
|
2144 |
+
]
|
2145 |
+
}
|
2146 |
+
},
|
2147 |
+
"nyu_franka_play_dataset_converted_externally_to_rlds": {
|
2148 |
+
"action": {
|
2149 |
+
"mask": [
|
2150 |
+
true,
|
2151 |
+
true,
|
2152 |
+
true,
|
2153 |
+
true,
|
2154 |
+
true,
|
2155 |
+
true,
|
2156 |
+
false
|
2157 |
+
],
|
2158 |
+
"max": [
|
2159 |
+
0.06424188613891602,
|
2160 |
+
0.07027634978294373,
|
2161 |
+
0.06129661202430725,
|
2162 |
+
6.281067848205566,
|
2163 |
+
0.1967729926109314,
|
2164 |
+
0.26377415657043457,
|
2165 |
+
1.0
|
2166 |
+
],
|
2167 |
+
"mean": [
|
2168 |
+
0.001021989737637341,
|
2169 |
+
-0.00012002651783404872,
|
2170 |
+
0.00032894269679673016,
|
2171 |
+
0.0015034361276775599,
|
2172 |
+
-0.002198522910475731,
|
2173 |
+
-0.001663230243138969,
|
2174 |
+
0.7230083346366882
|
2175 |
+
],
|
2176 |
+
"min": [
|
2177 |
+
-0.05952230095863342,
|
2178 |
+
-0.07232445478439331,
|
2179 |
+
-0.06730806827545166,
|
2180 |
+
-6.278434753417969,
|
2181 |
+
-0.21479034423828125,
|
2182 |
+
-0.3627619743347168,
|
2183 |
+
0.0
|
2184 |
+
],
|
2185 |
+
"q01": [
|
2186 |
+
-0.03199600875377655,
|
2187 |
+
-0.032861671447753905,
|
2188 |
+
-0.03368805110454559,
|
2189 |
+
-0.12080862045288086,
|
2190 |
+
-0.12175218224525451,
|
2191 |
+
-0.11370223641395569,
|
2192 |
+
0.0
|
2193 |
+
],
|
2194 |
+
"q99": [
|
2195 |
+
0.03101520001888276,
|
2196 |
+
0.0373908892273903,
|
2197 |
+
0.03646374464035038,
|
2198 |
+
0.11764093399047852,
|
2199 |
+
0.1258920183777809,
|
2200 |
+
0.09366151213645942,
|
2201 |
+
1.0
|
2202 |
+
],
|
2203 |
+
"std": [
|
2204 |
+
0.01327415369451046,
|
2205 |
+
0.013215910643339157,
|
2206 |
+
0.012822109274566174,
|
2207 |
+
0.2732451558113098,
|
2208 |
+
0.057022541761398315,
|
2209 |
+
0.039172880351543427,
|
2210 |
+
0.44752755761146545
|
2211 |
+
]
|
2212 |
+
},
|
2213 |
+
"num_trajectories": 456,
|
2214 |
+
"num_transitions": 44875,
|
2215 |
+
"proprio": {
|
2216 |
+
"max": [
|
2217 |
+
0.0,
|
2218 |
+
0.0,
|
2219 |
+
0.0,
|
2220 |
+
0.0,
|
2221 |
+
0.0,
|
2222 |
+
0.0,
|
2223 |
+
0.0
|
2224 |
+
],
|
2225 |
+
"mean": [
|
2226 |
+
0.0,
|
2227 |
+
0.0,
|
2228 |
+
0.0,
|
2229 |
+
0.0,
|
2230 |
+
0.0,
|
2231 |
+
0.0,
|
2232 |
+
0.0
|
2233 |
+
],
|
2234 |
+
"min": [
|
2235 |
+
0.0,
|
2236 |
+
0.0,
|
2237 |
+
0.0,
|
2238 |
+
0.0,
|
2239 |
+
0.0,
|
2240 |
+
0.0,
|
2241 |
+
0.0
|
2242 |
+
],
|
2243 |
+
"q01": [
|
2244 |
+
0.0,
|
2245 |
+
0.0,
|
2246 |
+
0.0,
|
2247 |
+
0.0,
|
2248 |
+
0.0,
|
2249 |
+
0.0,
|
2250 |
+
0.0
|
2251 |
+
],
|
2252 |
+
"q99": [
|
2253 |
+
0.0,
|
2254 |
+
0.0,
|
2255 |
+
0.0,
|
2256 |
+
0.0,
|
2257 |
+
0.0,
|
2258 |
+
0.0,
|
2259 |
+
0.0
|
2260 |
+
],
|
2261 |
+
"std": [
|
2262 |
+
0.0,
|
2263 |
+
0.0,
|
2264 |
+
0.0,
|
2265 |
+
0.0,
|
2266 |
+
0.0,
|
2267 |
+
0.0,
|
2268 |
+
0.0
|
2269 |
+
]
|
2270 |
+
}
|
2271 |
+
},
|
2272 |
+
"roboturk": {
|
2273 |
+
"action": {
|
2274 |
+
"mask": [
|
2275 |
+
true,
|
2276 |
+
true,
|
2277 |
+
true,
|
2278 |
+
true,
|
2279 |
+
true,
|
2280 |
+
true,
|
2281 |
+
false
|
2282 |
+
],
|
2283 |
+
"max": [
|
2284 |
+
0.39124172925949097,
|
2285 |
+
0.4601028263568878,
|
2286 |
+
0.4870833456516266,
|
2287 |
+
1.816888689994812,
|
2288 |
+
1.8240282535552979,
|
2289 |
+
1.4824820756912231,
|
2290 |
+
1.0
|
2291 |
+
],
|
2292 |
+
"mean": [
|
2293 |
+
0.0014448732836171985,
|
2294 |
+
-0.0015945249469950795,
|
2295 |
+
-0.0011753785656765103,
|
2296 |
+
0.0023012510500848293,
|
2297 |
+
-0.0009382463176734746,
|
2298 |
+
-0.00011485807772260159,
|
2299 |
+
0.5746025443077087
|
2300 |
+
],
|
2301 |
+
"min": [
|
2302 |
+
-0.6546999216079712,
|
2303 |
+
-0.6365841031074524,
|
2304 |
+
-0.4217723608016968,
|
2305 |
+
-1.6695482730865479,
|
2306 |
+
-1.8023357391357422,
|
2307 |
+
-1.4630827903747559,
|
2308 |
+
0.0
|
2309 |
+
],
|
2310 |
+
"q01": [
|
2311 |
+
-0.1342635464668274,
|
2312 |
+
-0.19996687173843383,
|
2313 |
+
-0.1482972100377083,
|
2314 |
+
-0.20720748245716095,
|
2315 |
+
-0.09676413893699647,
|
2316 |
+
-0.18075634717941286,
|
2317 |
+
0.0
|
2318 |
+
],
|
2319 |
+
"q99": [
|
2320 |
+
0.14956976801157001,
|
2321 |
+
0.1805950567126275,
|
2322 |
+
0.18841815620660796,
|
2323 |
+
0.21615413755178453,
|
2324 |
+
0.09457383215427405,
|
2325 |
+
0.18543301910162005,
|
2326 |
+
1.0
|
2327 |
+
],
|
2328 |
+
"std": [
|
2329 |
+
0.04935386776924133,
|
2330 |
+
0.0635455846786499,
|
2331 |
+
0.061164740473032,
|
2332 |
+
0.09553450345993042,
|
2333 |
+
0.08420111238956451,
|
2334 |
+
0.06517903506755829,
|
2335 |
+
0.49452081322669983
|
2336 |
+
]
|
2337 |
+
},
|
2338 |
+
"num_trajectories": 1995,
|
2339 |
+
"num_transitions": 187507,
|
2340 |
+
"proprio": {
|
2341 |
+
"max": [
|
2342 |
+
0.0,
|
2343 |
+
0.0,
|
2344 |
+
0.0,
|
2345 |
+
0.0,
|
2346 |
+
0.0,
|
2347 |
+
0.0,
|
2348 |
+
0.0
|
2349 |
+
],
|
2350 |
+
"mean": [
|
2351 |
+
0.0,
|
2352 |
+
0.0,
|
2353 |
+
0.0,
|
2354 |
+
0.0,
|
2355 |
+
0.0,
|
2356 |
+
0.0,
|
2357 |
+
0.0
|
2358 |
+
],
|
2359 |
+
"min": [
|
2360 |
+
0.0,
|
2361 |
+
0.0,
|
2362 |
+
0.0,
|
2363 |
+
0.0,
|
2364 |
+
0.0,
|
2365 |
+
0.0,
|
2366 |
+
0.0
|
2367 |
+
],
|
2368 |
+
"q01": [
|
2369 |
+
0.0,
|
2370 |
+
0.0,
|
2371 |
+
0.0,
|
2372 |
+
0.0,
|
2373 |
+
0.0,
|
2374 |
+
0.0,
|
2375 |
+
0.0
|
2376 |
+
],
|
2377 |
+
"q99": [
|
2378 |
+
0.0,
|
2379 |
+
0.0,
|
2380 |
+
0.0,
|
2381 |
+
0.0,
|
2382 |
+
0.0,
|
2383 |
+
0.0,
|
2384 |
+
0.0
|
2385 |
+
],
|
2386 |
+
"std": [
|
2387 |
+
0.0,
|
2388 |
+
0.0,
|
2389 |
+
0.0,
|
2390 |
+
0.0,
|
2391 |
+
0.0,
|
2392 |
+
0.0,
|
2393 |
+
0.0
|
2394 |
+
]
|
2395 |
+
}
|
2396 |
+
},
|
2397 |
+
"stanford_hydra_dataset_converted_externally_to_rlds": {
|
2398 |
+
"action": {
|
2399 |
+
"mask": [
|
2400 |
+
true,
|
2401 |
+
true,
|
2402 |
+
true,
|
2403 |
+
true,
|
2404 |
+
true,
|
2405 |
+
true,
|
2406 |
+
false
|
2407 |
+
],
|
2408 |
+
"max": [
|
2409 |
+
0.02499854564666748,
|
2410 |
+
0.02499903365969658,
|
2411 |
+
0.024999922141432762,
|
2412 |
+
0.24974457919597626,
|
2413 |
+
0.24997030198574066,
|
2414 |
+
0.24999946355819702,
|
2415 |
+
1.0
|
2416 |
+
],
|
2417 |
+
"mean": [
|
2418 |
+
0.0007790001109242439,
|
2419 |
+
0.00013707754260394722,
|
2420 |
+
-0.0002548607881180942,
|
2421 |
+
0.0012903271708637476,
|
2422 |
+
-0.004751681815832853,
|
2423 |
+
0.002692886395379901,
|
2424 |
+
0.48855218291282654
|
2425 |
+
],
|
2426 |
+
"min": [
|
2427 |
+
-0.024999044835567474,
|
2428 |
+
-0.024999700486660004,
|
2429 |
+
-0.02499929815530777,
|
2430 |
+
-0.24993225932121277,
|
2431 |
+
-0.2499666064977646,
|
2432 |
+
-0.2499932497739792,
|
2433 |
+
0.0
|
2434 |
+
],
|
2435 |
+
"q01": [
|
2436 |
+
-0.019992006458342076,
|
2437 |
+
-0.02415412735193968,
|
2438 |
+
-0.022941758055239916,
|
2439 |
+
-0.11085530579090118,
|
2440 |
+
-0.12024572037160397,
|
2441 |
+
-0.13314770206809043,
|
2442 |
+
0.0
|
2443 |
+
],
|
2444 |
+
"q99": [
|
2445 |
+
0.022886231057345868,
|
2446 |
+
0.022358838934451335,
|
2447 |
+
0.02410089675337076,
|
2448 |
+
0.12370114490389822,
|
2449 |
+
0.11323311634361738,
|
2450 |
+
0.18474749639630164,
|
2451 |
+
1.0
|
2452 |
+
],
|
2453 |
+
"std": [
|
2454 |
+
0.008022161200642586,
|
2455 |
+
0.009131459519267082,
|
2456 |
+
0.009574338793754578,
|
2457 |
+
0.04122216999530792,
|
2458 |
+
0.0384303517639637,
|
2459 |
+
0.04606688767671585,
|
2460 |
+
0.49976691603660583
|
2461 |
+
]
|
2462 |
+
},
|
2463 |
+
"num_trajectories": 570,
|
2464 |
+
"num_transitions": 358234,
|
2465 |
+
"proprio": {
|
2466 |
+
"max": [
|
2467 |
+
0.0,
|
2468 |
+
0.0,
|
2469 |
+
0.0,
|
2470 |
+
0.0,
|
2471 |
+
0.0,
|
2472 |
+
0.0,
|
2473 |
+
0.0
|
2474 |
+
],
|
2475 |
+
"mean": [
|
2476 |
+
0.0,
|
2477 |
+
0.0,
|
2478 |
+
0.0,
|
2479 |
+
0.0,
|
2480 |
+
0.0,
|
2481 |
+
0.0,
|
2482 |
+
0.0
|
2483 |
+
],
|
2484 |
+
"min": [
|
2485 |
+
0.0,
|
2486 |
+
0.0,
|
2487 |
+
0.0,
|
2488 |
+
0.0,
|
2489 |
+
0.0,
|
2490 |
+
0.0,
|
2491 |
+
0.0
|
2492 |
+
],
|
2493 |
+
"q01": [
|
2494 |
+
0.0,
|
2495 |
+
0.0,
|
2496 |
+
0.0,
|
2497 |
+
0.0,
|
2498 |
+
0.0,
|
2499 |
+
0.0,
|
2500 |
+
0.0
|
2501 |
+
],
|
2502 |
+
"q99": [
|
2503 |
+
0.0,
|
2504 |
+
0.0,
|
2505 |
+
0.0,
|
2506 |
+
0.0,
|
2507 |
+
0.0,
|
2508 |
+
0.0,
|
2509 |
+
0.0
|
2510 |
+
],
|
2511 |
+
"std": [
|
2512 |
+
0.0,
|
2513 |
+
0.0,
|
2514 |
+
0.0,
|
2515 |
+
0.0,
|
2516 |
+
0.0,
|
2517 |
+
0.0,
|
2518 |
+
0.0
|
2519 |
+
]
|
2520 |
+
}
|
2521 |
+
},
|
2522 |
+
"taco_play": {
|
2523 |
+
"action": {
|
2524 |
+
"mask": [
|
2525 |
+
true,
|
2526 |
+
true,
|
2527 |
+
true,
|
2528 |
+
true,
|
2529 |
+
true,
|
2530 |
+
true,
|
2531 |
+
false
|
2532 |
+
],
|
2533 |
+
"max": [
|
2534 |
+
1.4915844202041626,
|
2535 |
+
2.1842432022094727,
|
2536 |
+
2.6836395263671875,
|
2537 |
+
5.035226821899414,
|
2538 |
+
2.665864944458008,
|
2539 |
+
4.250768661499023,
|
2540 |
+
1.0
|
2541 |
+
],
|
2542 |
+
"mean": [
|
2543 |
+
-0.003845922416076064,
|
2544 |
+
0.009671456180512905,
|
2545 |
+
0.012780580669641495,
|
2546 |
+
-0.005403771996498108,
|
2547 |
+
-0.009606587700545788,
|
2548 |
+
-0.002480733208358288,
|
2549 |
+
0.4263913035392761
|
2550 |
+
],
|
2551 |
+
"min": [
|
2552 |
+
-4.242457866668701,
|
2553 |
+
-3.192805051803589,
|
2554 |
+
-1.3371467590332031,
|
2555 |
+
-4.202683448791504,
|
2556 |
+
-2.6722638607025146,
|
2557 |
+
-3.3467135429382324,
|
2558 |
+
0.0
|
2559 |
+
],
|
2560 |
+
"q01": [
|
2561 |
+
-0.7106140398979186,
|
2562 |
+
-1.056944659948349,
|
2563 |
+
-0.5878450274467468,
|
2564 |
+
-0.7682853937149048,
|
2565 |
+
-0.7180147767066956,
|
2566 |
+
-1.5527938604354858,
|
2567 |
+
0.0
|
2568 |
+
],
|
2569 |
+
"q99": [
|
2570 |
+
0.6482916426658629,
|
2571 |
+
1.0051310062408447,
|
2572 |
+
0.9480248689651489,
|
2573 |
+
0.6926478147506714,
|
2574 |
+
0.6351067513227462,
|
2575 |
+
1.628010264635086,
|
2576 |
+
1.0
|
2577 |
+
],
|
2578 |
+
"std": [
|
2579 |
+
0.23254038393497467,
|
2580 |
+
0.36298269033432007,
|
2581 |
+
0.28692901134490967,
|
2582 |
+
0.2617705166339874,
|
2583 |
+
0.2438892275094986,
|
2584 |
+
0.5216503143310547,
|
2585 |
+
0.4946896731853485
|
2586 |
+
]
|
2587 |
+
},
|
2588 |
+
"num_trajectories": 3603,
|
2589 |
+
"num_transitions": 237798,
|
2590 |
+
"proprio": {
|
2591 |
+
"max": [
|
2592 |
+
0.0,
|
2593 |
+
0.0,
|
2594 |
+
0.0,
|
2595 |
+
0.0,
|
2596 |
+
0.0,
|
2597 |
+
0.0,
|
2598 |
+
0.0
|
2599 |
+
],
|
2600 |
+
"mean": [
|
2601 |
+
0.0,
|
2602 |
+
0.0,
|
2603 |
+
0.0,
|
2604 |
+
0.0,
|
2605 |
+
0.0,
|
2606 |
+
0.0,
|
2607 |
+
0.0
|
2608 |
+
],
|
2609 |
+
"min": [
|
2610 |
+
0.0,
|
2611 |
+
0.0,
|
2612 |
+
0.0,
|
2613 |
+
0.0,
|
2614 |
+
0.0,
|
2615 |
+
0.0,
|
2616 |
+
0.0
|
2617 |
+
],
|
2618 |
+
"q01": [
|
2619 |
+
0.0,
|
2620 |
+
0.0,
|
2621 |
+
0.0,
|
2622 |
+
0.0,
|
2623 |
+
0.0,
|
2624 |
+
0.0,
|
2625 |
+
0.0
|
2626 |
+
],
|
2627 |
+
"q99": [
|
2628 |
+
0.0,
|
2629 |
+
0.0,
|
2630 |
+
0.0,
|
2631 |
+
0.0,
|
2632 |
+
0.0,
|
2633 |
+
0.0,
|
2634 |
+
0.0
|
2635 |
+
],
|
2636 |
+
"std": [
|
2637 |
+
0.0,
|
2638 |
+
0.0,
|
2639 |
+
0.0,
|
2640 |
+
0.0,
|
2641 |
+
0.0,
|
2642 |
+
0.0,
|
2643 |
+
0.0
|
2644 |
+
]
|
2645 |
+
}
|
2646 |
+
},
|
2647 |
+
"toto": {
|
2648 |
+
"action": {
|
2649 |
+
"mask": [
|
2650 |
+
true,
|
2651 |
+
true,
|
2652 |
+
true,
|
2653 |
+
true,
|
2654 |
+
true,
|
2655 |
+
true,
|
2656 |
+
false
|
2657 |
+
],
|
2658 |
+
"max": [
|
2659 |
+
0.6839867234230042,
|
2660 |
+
0.4454185664653778,
|
2661 |
+
0.7984078526496887,
|
2662 |
+
2.120781660079956,
|
2663 |
+
1.371164321899414,
|
2664 |
+
1.4118704795837402,
|
2665 |
+
0.0
|
2666 |
+
],
|
2667 |
+
"mean": [
|
2668 |
+
0.38542115688323975,
|
2669 |
+
0.007769413758069277,
|
2670 |
+
0.3632740378379822,
|
2671 |
+
-0.6652036905288696,
|
2672 |
+
0.1890396922826767,
|
2673 |
+
0.03298724442720413,
|
2674 |
+
0.0
|
2675 |
+
],
|
2676 |
+
"min": [
|
2677 |
+
0.09922284632921219,
|
2678 |
+
-0.5180193781852722,
|
2679 |
+
0.13791072368621826,
|
2680 |
+
-2.635117530822754,
|
2681 |
+
-1.0734480619430542,
|
2682 |
+
-1.9282547235488892,
|
2683 |
+
0.0
|
2684 |
+
],
|
2685 |
+
"q01": [
|
2686 |
+
0.1756722891330719,
|
2687 |
+
-0.3077590811252594,
|
2688 |
+
0.235383919775486,
|
2689 |
+
-2.0908505964279174,
|
2690 |
+
-0.6191593289375306,
|
2691 |
+
-0.7488683319091797,
|
2692 |
+
0.0
|
2693 |
+
],
|
2694 |
+
"q99": [
|
2695 |
+
0.6136963081359863,
|
2696 |
+
0.33704194784164443,
|
2697 |
+
0.6681221985816956,
|
2698 |
+
0.7422861719131538,
|
2699 |
+
0.7955395007133507,
|
2700 |
+
0.740464625358582,
|
2701 |
+
0.0
|
2702 |
+
],
|
2703 |
+
"std": [
|
2704 |
+
0.12211652100086212,
|
2705 |
+
0.19378550350666046,
|
2706 |
+
0.10178236663341522,
|
2707 |
+
0.5725259184837341,
|
2708 |
+
0.29884573817253113,
|
2709 |
+
0.3259911835193634,
|
2710 |
+
0.0
|
2711 |
+
]
|
2712 |
+
},
|
2713 |
+
"num_trajectories": 1003,
|
2714 |
+
"num_transitions": 325699,
|
2715 |
+
"proprio": {
|
2716 |
+
"max": [
|
2717 |
+
0.0,
|
2718 |
+
0.0,
|
2719 |
+
0.0,
|
2720 |
+
0.0,
|
2721 |
+
0.0,
|
2722 |
+
0.0,
|
2723 |
+
0.0
|
2724 |
+
],
|
2725 |
+
"mean": [
|
2726 |
+
0.0,
|
2727 |
+
0.0,
|
2728 |
+
0.0,
|
2729 |
+
0.0,
|
2730 |
+
0.0,
|
2731 |
+
0.0,
|
2732 |
+
0.0
|
2733 |
+
],
|
2734 |
+
"min": [
|
2735 |
+
0.0,
|
2736 |
+
0.0,
|
2737 |
+
0.0,
|
2738 |
+
0.0,
|
2739 |
+
0.0,
|
2740 |
+
0.0,
|
2741 |
+
0.0
|
2742 |
+
],
|
2743 |
+
"q01": [
|
2744 |
+
0.0,
|
2745 |
+
0.0,
|
2746 |
+
0.0,
|
2747 |
+
0.0,
|
2748 |
+
0.0,
|
2749 |
+
0.0,
|
2750 |
+
0.0
|
2751 |
+
],
|
2752 |
+
"q99": [
|
2753 |
+
0.0,
|
2754 |
+
0.0,
|
2755 |
+
0.0,
|
2756 |
+
0.0,
|
2757 |
+
0.0,
|
2758 |
+
0.0,
|
2759 |
+
0.0
|
2760 |
+
],
|
2761 |
+
"std": [
|
2762 |
+
0.0,
|
2763 |
+
0.0,
|
2764 |
+
0.0,
|
2765 |
+
0.0,
|
2766 |
+
0.0,
|
2767 |
+
0.0,
|
2768 |
+
0.0
|
2769 |
+
]
|
2770 |
+
}
|
2771 |
+
},
|
2772 |
+
"ucsd_kitchen_dataset_converted_externally_to_rlds": {
|
2773 |
+
"action": {
|
2774 |
+
"mask": [
|
2775 |
+
true,
|
2776 |
+
true,
|
2777 |
+
true,
|
2778 |
+
true,
|
2779 |
+
true,
|
2780 |
+
true,
|
2781 |
+
false
|
2782 |
+
],
|
2783 |
+
"max": [
|
2784 |
+
678.0,
|
2785 |
+
400.0,
|
2786 |
+
507.0,
|
2787 |
+
180.00001525878906,
|
2788 |
+
6.000013828277588,
|
2789 |
+
116.99998474121094,
|
2790 |
+
1.0
|
2791 |
+
],
|
2792 |
+
"mean": [
|
2793 |
+
410.37567138671875,
|
2794 |
+
116.9518814086914,
|
2795 |
+
192.35032653808594,
|
2796 |
+
-121.22441864013672,
|
2797 |
+
-33.84893035888672,
|
2798 |
+
50.016136169433594,
|
2799 |
+
0.741813600063324
|
2800 |
+
],
|
2801 |
+
"min": [
|
2802 |
+
172.0,
|
2803 |
+
-166.0,
|
2804 |
+
-99.99999237060547,
|
2805 |
+
-180.00001525878906,
|
2806 |
+
-89.0,
|
2807 |
+
-96.00010681152344,
|
2808 |
+
0.0
|
2809 |
+
],
|
2810 |
+
"q01": [
|
2811 |
+
200.00001052856445,
|
2812 |
+
-102.31004211425781,
|
2813 |
+
-94.99993370056153,
|
2814 |
+
-180.00001525878906,
|
2815 |
+
-88.00001525878906,
|
2816 |
+
-38.999977111816406,
|
2817 |
+
0.0
|
2818 |
+
],
|
2819 |
+
"q99": [
|
2820 |
+
637.0,
|
2821 |
+
368.30999999999995,
|
2822 |
+
493.0,
|
2823 |
+
180.00001525878906,
|
2824 |
+
0.999983012676239,
|
2825 |
+
105.00001525878906,
|
2826 |
+
1.0
|
2827 |
+
],
|
2828 |
+
"std": [
|
2829 |
+
122.81494903564453,
|
2830 |
+
108.8009033203125,
|
2831 |
+
130.303466796875,
|
2832 |
+
116.28205108642578,
|
2833 |
+
27.621843338012695,
|
2834 |
+
41.02094650268555,
|
2835 |
+
0.43763357400894165
|
2836 |
+
]
|
2837 |
+
},
|
2838 |
+
"num_trajectories": 150,
|
2839 |
+
"num_transitions": 3970,
|
2840 |
+
"proprio": {
|
2841 |
+
"max": [
|
2842 |
+
0.0,
|
2843 |
+
0.0,
|
2844 |
+
0.0,
|
2845 |
+
0.0,
|
2846 |
+
0.0,
|
2847 |
+
0.0,
|
2848 |
+
0.0
|
2849 |
+
],
|
2850 |
+
"mean": [
|
2851 |
+
0.0,
|
2852 |
+
0.0,
|
2853 |
+
0.0,
|
2854 |
+
0.0,
|
2855 |
+
0.0,
|
2856 |
+
0.0,
|
2857 |
+
0.0
|
2858 |
+
],
|
2859 |
+
"min": [
|
2860 |
+
0.0,
|
2861 |
+
0.0,
|
2862 |
+
0.0,
|
2863 |
+
0.0,
|
2864 |
+
0.0,
|
2865 |
+
0.0,
|
2866 |
+
0.0
|
2867 |
+
],
|
2868 |
+
"q01": [
|
2869 |
+
0.0,
|
2870 |
+
0.0,
|
2871 |
+
0.0,
|
2872 |
+
0.0,
|
2873 |
+
0.0,
|
2874 |
+
0.0,
|
2875 |
+
0.0
|
2876 |
+
],
|
2877 |
+
"q99": [
|
2878 |
+
0.0,
|
2879 |
+
0.0,
|
2880 |
+
0.0,
|
2881 |
+
0.0,
|
2882 |
+
0.0,
|
2883 |
+
0.0,
|
2884 |
+
0.0
|
2885 |
+
],
|
2886 |
+
"std": [
|
2887 |
+
0.0,
|
2888 |
+
0.0,
|
2889 |
+
0.0,
|
2890 |
+
0.0,
|
2891 |
+
0.0,
|
2892 |
+
0.0,
|
2893 |
+
0.0
|
2894 |
+
]
|
2895 |
+
}
|
2896 |
+
},
|
2897 |
+
"utaustin_mutex": {
|
2898 |
+
"action": {
|
2899 |
+
"mask": [
|
2900 |
+
true,
|
2901 |
+
true,
|
2902 |
+
true,
|
2903 |
+
true,
|
2904 |
+
true,
|
2905 |
+
true,
|
2906 |
+
false
|
2907 |
+
],
|
2908 |
+
"max": [
|
2909 |
+
1.0,
|
2910 |
+
1.0,
|
2911 |
+
1.0,
|
2912 |
+
0.375,
|
2913 |
+
0.375,
|
2914 |
+
0.375,
|
2915 |
+
1.0
|
2916 |
+
],
|
2917 |
+
"mean": [
|
2918 |
+
0.06176406890153885,
|
2919 |
+
-0.005005486309528351,
|
2920 |
+
0.10216785222291946,
|
2921 |
+
-0.03314131125807762,
|
2922 |
+
0.013895004987716675,
|
2923 |
+
-0.011317633092403412,
|
2924 |
+
0.5038976669311523
|
2925 |
+
],
|
2926 |
+
"min": [
|
2927 |
+
-1.0,
|
2928 |
+
-1.0,
|
2929 |
+
-1.0,
|
2930 |
+
-0.375,
|
2931 |
+
-0.375,
|
2932 |
+
-0.375,
|
2933 |
+
0.0
|
2934 |
+
],
|
2935 |
+
"q01": [
|
2936 |
+
-0.4285714328289032,
|
2937 |
+
-0.9800000190734863,
|
2938 |
+
-0.5571428537368774,
|
2939 |
+
-0.375,
|
2940 |
+
-0.15642857551574707,
|
2941 |
+
-0.335357129573822,
|
2942 |
+
0.0
|
2943 |
+
],
|
2944 |
+
"q99": [
|
2945 |
+
0.5914285778999329,
|
2946 |
+
0.9714285731315613,
|
2947 |
+
1.0,
|
2948 |
+
0.3278571367263794,
|
2949 |
+
0.207857146859169,
|
2950 |
+
0.25607141852378845,
|
2951 |
+
1.0
|
2952 |
+
],
|
2953 |
+
"std": [
|
2954 |
+
0.1875014752149582,
|
2955 |
+
0.4468473494052887,
|
2956 |
+
0.3792876601219177,
|
2957 |
+
0.14097853004932404,
|
2958 |
+
0.06453701853752136,
|
2959 |
+
0.11765272170305252,
|
2960 |
+
0.501045286655426
|
2961 |
+
]
|
2962 |
+
},
|
2963 |
+
"num_trajectories": 1500,
|
2964 |
+
"num_transitions": 361883,
|
2965 |
+
"proprio": {
|
2966 |
+
"max": [
|
2967 |
+
0.0,
|
2968 |
+
0.0,
|
2969 |
+
0.0,
|
2970 |
+
0.0,
|
2971 |
+
0.0,
|
2972 |
+
0.0,
|
2973 |
+
0.0
|
2974 |
+
],
|
2975 |
+
"mean": [
|
2976 |
+
0.0,
|
2977 |
+
0.0,
|
2978 |
+
0.0,
|
2979 |
+
0.0,
|
2980 |
+
0.0,
|
2981 |
+
0.0,
|
2982 |
+
0.0
|
2983 |
+
],
|
2984 |
+
"min": [
|
2985 |
+
0.0,
|
2986 |
+
0.0,
|
2987 |
+
0.0,
|
2988 |
+
0.0,
|
2989 |
+
0.0,
|
2990 |
+
0.0,
|
2991 |
+
0.0
|
2992 |
+
],
|
2993 |
+
"q01": [
|
2994 |
+
0.0,
|
2995 |
+
0.0,
|
2996 |
+
0.0,
|
2997 |
+
0.0,
|
2998 |
+
0.0,
|
2999 |
+
0.0,
|
3000 |
+
0.0
|
3001 |
+
],
|
3002 |
+
"q99": [
|
3003 |
+
0.0,
|
3004 |
+
0.0,
|
3005 |
+
0.0,
|
3006 |
+
0.0,
|
3007 |
+
0.0,
|
3008 |
+
0.0,
|
3009 |
+
0.0
|
3010 |
+
],
|
3011 |
+
"std": [
|
3012 |
+
0.0,
|
3013 |
+
0.0,
|
3014 |
+
0.0,
|
3015 |
+
0.0,
|
3016 |
+
0.0,
|
3017 |
+
0.0,
|
3018 |
+
0.0
|
3019 |
+
]
|
3020 |
+
}
|
3021 |
+
},
|
3022 |
+
"viola": {
|
3023 |
+
"action": {
|
3024 |
+
"mask": [
|
3025 |
+
true,
|
3026 |
+
true,
|
3027 |
+
true,
|
3028 |
+
true,
|
3029 |
+
true,
|
3030 |
+
true,
|
3031 |
+
false
|
3032 |
+
],
|
3033 |
+
"max": [
|
3034 |
+
1.0,
|
3035 |
+
1.0,
|
3036 |
+
1.0,
|
3037 |
+
0.375,
|
3038 |
+
0.36321428418159485,
|
3039 |
+
0.375,
|
3040 |
+
1.0
|
3041 |
+
],
|
3042 |
+
"mean": [
|
3043 |
+
0.04761844128370285,
|
3044 |
+
-0.029204415157437325,
|
3045 |
+
0.05586736649274826,
|
3046 |
+
-0.002618510741740465,
|
3047 |
+
0.006867344491183758,
|
3048 |
+
-0.01682133786380291,
|
3049 |
+
0.7323777675628662
|
3050 |
+
],
|
3051 |
+
"min": [
|
3052 |
+
-1.0,
|
3053 |
+
-1.0,
|
3054 |
+
-1.0,
|
3055 |
+
-0.375,
|
3056 |
+
-0.375,
|
3057 |
+
-0.375,
|
3058 |
+
0.0
|
3059 |
+
],
|
3060 |
+
"q01": [
|
3061 |
+
-0.9628571271896362,
|
3062 |
+
-1.0,
|
3063 |
+
-1.0,
|
3064 |
+
-0.26249998807907104,
|
3065 |
+
-0.21321429312229156,
|
3066 |
+
-0.3385714292526245,
|
3067 |
+
0.0
|
3068 |
+
],
|
3069 |
+
"q99": [
|
3070 |
+
0.9114285707473755,
|
3071 |
+
0.868571400642395,
|
3072 |
+
1.0,
|
3073 |
+
0.2817857265472412,
|
3074 |
+
0.2239285707473755,
|
3075 |
+
0.3557142913341522,
|
3076 |
+
1.0
|
3077 |
+
],
|
3078 |
+
"std": [
|
3079 |
+
0.39157867431640625,
|
3080 |
+
0.4076525568962097,
|
3081 |
+
0.40077948570251465,
|
3082 |
+
0.10023996233940125,
|
3083 |
+
0.0844319611787796,
|
3084 |
+
0.10375042259693146,
|
3085 |
+
0.44260647892951965
|
3086 |
+
]
|
3087 |
+
},
|
3088 |
+
"num_trajectories": 150,
|
3089 |
+
"num_transitions": 76324,
|
3090 |
+
"proprio": {
|
3091 |
+
"max": [
|
3092 |
+
0.0,
|
3093 |
+
0.0,
|
3094 |
+
0.0,
|
3095 |
+
0.0,
|
3096 |
+
0.0,
|
3097 |
+
0.0,
|
3098 |
+
0.0
|
3099 |
+
],
|
3100 |
+
"mean": [
|
3101 |
+
0.0,
|
3102 |
+
0.0,
|
3103 |
+
0.0,
|
3104 |
+
0.0,
|
3105 |
+
0.0,
|
3106 |
+
0.0,
|
3107 |
+
0.0
|
3108 |
+
],
|
3109 |
+
"min": [
|
3110 |
+
0.0,
|
3111 |
+
0.0,
|
3112 |
+
0.0,
|
3113 |
+
0.0,
|
3114 |
+
0.0,
|
3115 |
+
0.0,
|
3116 |
+
0.0
|
3117 |
+
],
|
3118 |
+
"q01": [
|
3119 |
+
0.0,
|
3120 |
+
0.0,
|
3121 |
+
0.0,
|
3122 |
+
0.0,
|
3123 |
+
0.0,
|
3124 |
+
0.0,
|
3125 |
+
0.0
|
3126 |
+
],
|
3127 |
+
"q99": [
|
3128 |
+
0.0,
|
3129 |
+
0.0,
|
3130 |
+
0.0,
|
3131 |
+
0.0,
|
3132 |
+
0.0,
|
3133 |
+
0.0,
|
3134 |
+
0.0
|
3135 |
+
],
|
3136 |
+
"std": [
|
3137 |
+
0.0,
|
3138 |
+
0.0,
|
3139 |
+
0.0,
|
3140 |
+
0.0,
|
3141 |
+
0.0,
|
3142 |
+
0.0,
|
3143 |
+
0.0
|
3144 |
+
]
|
3145 |
+
}
|
3146 |
+
}
|
3147 |
+
},
|
3148 |
+
"output_projector_states": false,
|
3149 |
+
"pad_to_multiple_of": 64,
|
3150 |
+
"pad_token_id": 32000,
|
3151 |
+
"text_config": {
|
3152 |
+
"bos_token_id": 151643,
|
3153 |
+
"eos_token_id": 151643,
|
3154 |
+
"hidden_size": 896,
|
3155 |
+
"intermediate_size": 4864,
|
3156 |
+
"max_position_embeddings": 32768,
|
3157 |
+
"max_window_layers": 24,
|
3158 |
+
"model_type": "qwen2",
|
3159 |
+
"num_attention_heads": 14,
|
3160 |
+
"num_hidden_layers": 24,
|
3161 |
+
"num_key_value_heads": 2,
|
3162 |
+
"rope_theta": 1000000.0,
|
3163 |
+
"sliding_window": 32768,
|
3164 |
+
"tie_word_embeddings": true,
|
3165 |
+
"torch_dtype": "bfloat16",
|
3166 |
+
"use_mrope": false,
|
3167 |
+
"use_sliding_window": false,
|
3168 |
+
"vocab_size": 151936
|
3169 |
+
},
|
3170 |
+
"timm_model_ids": [
|
3171 |
+
"vit_large_patch14_reg4_dinov2.lvd142m",
|
3172 |
+
"vit_so400m_patch14_siglip_224"
|
3173 |
+
],
|
3174 |
+
"timm_override_act_layers": [
|
3175 |
+
null,
|
3176 |
+
null
|
3177 |
+
],
|
3178 |
+
"torch_dtype": "bfloat16",
|
3179 |
+
"transformers_version": "4.40.1",
|
3180 |
+
"use_fused_vision_backbone": true,
|
3181 |
+
"vision_backbone_id": "dinosiglip-vit-so-224px"
|
3182 |
+
}
|
configuration_prismatic.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
configuration_prismatic.py
|
3 |
+
|
4 |
+
HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.
|
5 |
+
Default configuration specifies `siglip-224px+7b`.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from typing import Any, Dict, List, Optional
|
9 |
+
|
10 |
+
from transformers import PretrainedConfig
|
11 |
+
from transformers.models.auto import CONFIG_MAPPING
|
12 |
+
|
13 |
+
# === Utilities for Mapping Prismatic names to HF names ===
|
14 |
+
# fmt: off
|
15 |
+
VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {
|
16 |
+
"clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224],
|
17 |
+
|
18 |
+
"clip-vit-l-336px": [336],
|
19 |
+
"siglip-vit-so400m-384px": [384],
|
20 |
+
|
21 |
+
"dinoclip-vit-l-336px": [336, 336],
|
22 |
+
"dinosiglip-vit-so-224px": [224, 224],
|
23 |
+
"dinosiglip-vit-so-384px": [384, 384],
|
24 |
+
}
|
25 |
+
VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {
|
26 |
+
"clip-vit-l": ["vit_large_patch14_clip_224.openai"],
|
27 |
+
"clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"],
|
28 |
+
|
29 |
+
"dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"],
|
30 |
+
"in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"],
|
31 |
+
|
32 |
+
"siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"],
|
33 |
+
"siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"],
|
34 |
+
|
35 |
+
"dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"],
|
36 |
+
"dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"],
|
37 |
+
"dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"],
|
38 |
+
}
|
39 |
+
TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {
|
40 |
+
"clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"],
|
41 |
+
"dinov2-vit-l": [None], "in1k-vit-l": [None],
|
42 |
+
"siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None],
|
43 |
+
"dinoclip-vit-l-336px": [None, "quick_gelu"],
|
44 |
+
"dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None]
|
45 |
+
}
|
46 |
+
|
47 |
+
LLM_BACKBONE_TO_HF_PATH = {
|
48 |
+
"llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf",
|
49 |
+
"llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
|
50 |
+
|
51 |
+
"vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5",
|
52 |
+
|
53 |
+
"mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1",
|
54 |
+
"mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1",
|
55 |
+
|
56 |
+
"phi-2-3b": "microsoft/phi-2",
|
57 |
+
"qwen25-0_5b-extra": "Qwen/Qwen2.5-0.5B", "qwen25-0_5b-pure": "Qwen/Qwen2.5-0.5B"
|
58 |
+
|
59 |
+
|
60 |
+
}
|
61 |
+
LLM_BACKBONE_TO_HF_METACLASS = {
|
62 |
+
"llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama",
|
63 |
+
"vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama",
|
64 |
+
|
65 |
+
"mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral",
|
66 |
+
|
67 |
+
"phi-2-3b": "phi",
|
68 |
+
"qwen25-0_5b-extra": "qwen2" ,"qwen25-0_5b-pure": "qwen2"
|
69 |
+
}
|
70 |
+
|
71 |
+
VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())
|
72 |
+
VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)
|
73 |
+
# fmt: on
|
74 |
+
|
75 |
+
|
76 |
+
class PrismaticConfig(PretrainedConfig):
|
77 |
+
model_type: str = "prismatic"
|
78 |
+
is_composition: bool = False
|
79 |
+
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
vision_backbone_id: str = "siglip-vit-so400m",
|
83 |
+
llm_backbone_id: str = "vicuna-v15-7b",
|
84 |
+
arch_specifier: str = "no-align+gelu-mlp",
|
85 |
+
use_fused_vision_backbone: Optional[bool] = None,
|
86 |
+
image_resize_strategy: str = "letterbox",
|
87 |
+
text_config: Optional[Dict[str, Any]] = None,
|
88 |
+
llm_max_length: int = 2048,
|
89 |
+
pad_token_id: int = 32000,
|
90 |
+
pad_to_multiple_of: int = 64,
|
91 |
+
output_projector_states: bool = False,
|
92 |
+
**kwargs: str,
|
93 |
+
) -> None:
|
94 |
+
if vision_backbone_id not in VALID_VISION_BACKBONES:
|
95 |
+
raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }")
|
96 |
+
|
97 |
+
if llm_backbone_id not in VALID_LLM_BACKBONES:
|
98 |
+
raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }")
|
99 |
+
|
100 |
+
# Set Prismatic Configuration Fields
|
101 |
+
self.vision_backbone_id = vision_backbone_id
|
102 |
+
self.llm_backbone_id = llm_backbone_id
|
103 |
+
self.arch_specifier = arch_specifier
|
104 |
+
self.output_projector_states = output_projector_states
|
105 |
+
|
106 |
+
# [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing
|
107 |
+
self.use_fused_vision_backbone = (
|
108 |
+
use_fused_vision_backbone
|
109 |
+
if use_fused_vision_backbone is not None
|
110 |
+
else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"])
|
111 |
+
)
|
112 |
+
|
113 |
+
self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]
|
114 |
+
self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]
|
115 |
+
self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]
|
116 |
+
self.image_resize_strategy = image_resize_strategy
|
117 |
+
|
118 |
+
self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]
|
119 |
+
self.llm_max_length = llm_max_length
|
120 |
+
self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of
|
121 |
+
|
122 |
+
# [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!
|
123 |
+
self.text_config = (
|
124 |
+
CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)
|
125 |
+
if text_config is not None
|
126 |
+
else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()
|
127 |
+
)
|
128 |
+
|
129 |
+
# Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...
|
130 |
+
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
131 |
+
|
132 |
+
|
133 |
+
class OpenVLAConfig(PrismaticConfig):
|
134 |
+
model_type: str = "openvla"
|
135 |
+
|
136 |
+
def __init__(
|
137 |
+
self,
|
138 |
+
norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,
|
139 |
+
n_action_bins: int = 256,
|
140 |
+
**kwargs: str,
|
141 |
+
) -> None:
|
142 |
+
self.norm_stats, self.n_action_bins = norm_stats, n_action_bins
|
143 |
+
|
144 |
+
super().__init__(**kwargs)
|
dataset_statistics.json
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"libero_spatial_no_noops": {
|
3 |
+
"action": {
|
4 |
+
"mean": [
|
5 |
+
0.15312479436397552,
|
6 |
+
0.13707277178764343,
|
7 |
+
-0.15526802837848663,
|
8 |
+
-0.005176450591534376,
|
9 |
+
-0.01120874285697937,
|
10 |
+
-0.020194264128804207,
|
11 |
+
0.4578818082809448
|
12 |
+
],
|
13 |
+
"std": [
|
14 |
+
0.41272708773612976,
|
15 |
+
0.34724321961402893,
|
16 |
+
0.50869220495224,
|
17 |
+
0.037266165018081665,
|
18 |
+
0.07244449853897095,
|
19 |
+
0.05762382969260216,
|
20 |
+
0.49827873706817627
|
21 |
+
],
|
22 |
+
"max": [
|
23 |
+
0.9375,
|
24 |
+
0.9375,
|
25 |
+
0.9375,
|
26 |
+
0.1971428543329239,
|
27 |
+
0.33642858266830444,
|
28 |
+
0.375,
|
29 |
+
1.0
|
30 |
+
],
|
31 |
+
"min": [
|
32 |
+
-0.9375,
|
33 |
+
-0.9375,
|
34 |
+
-0.9375,
|
35 |
+
-0.1875,
|
36 |
+
-0.3675000071525574,
|
37 |
+
-0.36000001430511475,
|
38 |
+
0.0
|
39 |
+
],
|
40 |
+
"q01": [
|
41 |
+
-0.7454732114076613,
|
42 |
+
-0.6616071462631226,
|
43 |
+
-0.9375,
|
44 |
+
-0.1071428582072258,
|
45 |
+
-0.20678570866584778,
|
46 |
+
-0.1842857152223587,
|
47 |
+
0.0
|
48 |
+
],
|
49 |
+
"q99": [
|
50 |
+
0.9375,
|
51 |
+
0.8758928775787354,
|
52 |
+
0.9321428537368774,
|
53 |
+
0.1039285734295845,
|
54 |
+
0.17678570747375488,
|
55 |
+
0.14571428298950195,
|
56 |
+
1.0
|
57 |
+
],
|
58 |
+
"mask": [
|
59 |
+
true,
|
60 |
+
true,
|
61 |
+
true,
|
62 |
+
true,
|
63 |
+
true,
|
64 |
+
true,
|
65 |
+
false
|
66 |
+
]
|
67 |
+
},
|
68 |
+
"proprio": {
|
69 |
+
"mean": [
|
70 |
+
-0.024462558329105377,
|
71 |
+
0.106529600918293,
|
72 |
+
1.0580483675003052,
|
73 |
+
3.0628468990325928,
|
74 |
+
-0.10464039444923401,
|
75 |
+
0.08307311683893204,
|
76 |
+
0.01995457336306572,
|
77 |
+
-0.020162804052233696
|
78 |
+
],
|
79 |
+
"std": [
|
80 |
+
0.1101478561758995,
|
81 |
+
0.13784688711166382,
|
82 |
+
0.1044282391667366,
|
83 |
+
0.10451053828001022,
|
84 |
+
0.4112098217010498,
|
85 |
+
0.2176690548658371,
|
86 |
+
0.017260896041989326,
|
87 |
+
0.0171116404235363
|
88 |
+
],
|
89 |
+
"max": [
|
90 |
+
0.1759040206670761,
|
91 |
+
0.3904820382595062,
|
92 |
+
1.3290715217590332,
|
93 |
+
3.4566118717193604,
|
94 |
+
1.2268599271774292,
|
95 |
+
1.0429412126541138,
|
96 |
+
0.041053611785173416,
|
97 |
+
0.000775813648942858
|
98 |
+
],
|
99 |
+
"min": [
|
100 |
+
-0.3095473051071167,
|
101 |
+
-0.29250794649124146,
|
102 |
+
0.9095591306686401,
|
103 |
+
2.497488260269165,
|
104 |
+
-1.8006486892700195,
|
105 |
+
-0.7207611203193665,
|
106 |
+
-0.0004703797458205372,
|
107 |
+
-0.041536275297403336
|
108 |
+
],
|
109 |
+
"q01": [
|
110 |
+
-0.2727657300233841,
|
111 |
+
-0.23721413239836692,
|
112 |
+
0.9160063165426254,
|
113 |
+
2.77949666261673,
|
114 |
+
-1.3187511622905732,
|
115 |
+
-0.41989982962608335,
|
116 |
+
0.001503719249740243,
|
117 |
+
-0.03989770736545324
|
118 |
+
],
|
119 |
+
"q99": [
|
120 |
+
0.13529365032911292,
|
121 |
+
0.3629165390133857,
|
122 |
+
1.2862326657772063,
|
123 |
+
3.2829698753356933,
|
124 |
+
0.9332760351896285,
|
125 |
+
0.6325724506378171,
|
126 |
+
0.039933966137468815,
|
127 |
+
-0.001671919699292631
|
128 |
+
]
|
129 |
+
},
|
130 |
+
"num_transitions": 52970,
|
131 |
+
"num_trajectories": 432
|
132 |
+
}
|
133 |
+
}
|
generation_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 151643,
|
4 |
+
"eos_token_id": 151643,
|
5 |
+
"pad_token_id": 32000,
|
6 |
+
"transformers_version": "4.40.1"
|
7 |
+
}
|
merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ff4549dd63f819f7da0ca174f60310f229f214b92c702407e983b3120ece2dea
|
3 |
+
size 2505232584
|
modeling_prismatic.py
ADDED
@@ -0,0 +1,1499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
modeling_prismatic.py
|
3 |
+
|
4 |
+
Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
|
5 |
+
Inherits from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained,
|
6 |
+
but exactly replicate the logic in `prismatic.models.vlms.prismatic.py`.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import logging
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from functools import partial
|
12 |
+
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import timm
|
16 |
+
import tokenizers
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import transformers
|
20 |
+
from timm.models.vision_transformer import LayerScale
|
21 |
+
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
|
22 |
+
from transformers.modeling_outputs import ModelOutput
|
23 |
+
|
24 |
+
from prismatic.training.train_utils import (
|
25 |
+
get_current_action_mask,
|
26 |
+
get_next_actions_mask,
|
27 |
+
)
|
28 |
+
from prismatic.vla.constants import (
|
29 |
+
ACTION_DIM,
|
30 |
+
ACTION_PROPRIO_NORMALIZATION_TYPE,
|
31 |
+
ACTION_TOKEN_BEGIN_IDX,
|
32 |
+
IGNORE_INDEX,
|
33 |
+
NUM_ACTIONS_CHUNK,
|
34 |
+
STOP_INDEX,
|
35 |
+
NormalizationType,
|
36 |
+
NUM_TOKENS
|
37 |
+
)
|
38 |
+
|
39 |
+
from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
# Set up logger
|
44 |
+
logger = logging.getLogger(__name__)
|
45 |
+
|
46 |
+
|
47 |
+
# === Utility Functions for Monkey-Patching ===
|
48 |
+
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
|
49 |
+
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
50 |
+
result = fn(*args, **kwargs)
|
51 |
+
return result[0] if isinstance(result, tuple) else result
|
52 |
+
|
53 |
+
return wrapper
|
54 |
+
|
55 |
+
|
56 |
+
# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
|
57 |
+
# =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
|
58 |
+
# =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
|
59 |
+
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
|
60 |
+
return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
|
61 |
+
|
62 |
+
|
63 |
+
def ls_apply_patch(ls_module: LayerScale):
|
64 |
+
ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
|
65 |
+
ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
|
66 |
+
del ls_module.gamma
|
67 |
+
|
68 |
+
|
69 |
+
# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
|
70 |
+
class PrismaticVisionBackbone(nn.Module):
|
71 |
+
"""
|
72 |
+
Vision backbone for Prismatic models that handles image feature extraction.
|
73 |
+
|
74 |
+
Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
|
75 |
+
For fused backbones, features from both models are concatenated along the feature dimension.
|
76 |
+
"""
|
77 |
+
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
use_fused_vision_backbone: bool,
|
81 |
+
image_sizes: List[int],
|
82 |
+
timm_model_ids: List[str],
|
83 |
+
timm_override_act_layers: List[Optional[str]],
|
84 |
+
) -> None:
|
85 |
+
"""
|
86 |
+
Initialize the vision backbone.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
use_fused_vision_backbone: Whether to use two backbones and fuse their features
|
90 |
+
image_sizes: List of image sizes for each backbone
|
91 |
+
timm_model_ids: List of TIMM model IDs to use for each backbone
|
92 |
+
timm_override_act_layers: List of activation layer overrides for each backbone
|
93 |
+
"""
|
94 |
+
super().__init__()
|
95 |
+
self.use_fused_vision_backbone = use_fused_vision_backbone
|
96 |
+
self.num_images_in_input = 1 # Default value, can be overridden later
|
97 |
+
|
98 |
+
# Validate number of (fused) vision backbones
|
99 |
+
if len(timm_model_ids) > 2:
|
100 |
+
raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")
|
101 |
+
|
102 |
+
# Create primary featurizer
|
103 |
+
self.featurizer = self._create_featurizer(
|
104 |
+
model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
|
105 |
+
)
|
106 |
+
self.embed_dim = self.featurizer.embed_dim
|
107 |
+
|
108 |
+
# Create secondary featurizer if using fused backbone
|
109 |
+
if self.use_fused_vision_backbone:
|
110 |
+
self.fused_featurizer = self._create_featurizer(
|
111 |
+
model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
|
112 |
+
)
|
113 |
+
self.embed_dim += self.fused_featurizer.embed_dim
|
114 |
+
|
115 |
+
# Patch LayerScale modules for HF compatibility
|
116 |
+
self._patch_layer_scales()
|
117 |
+
|
118 |
+
def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
|
119 |
+
"""
|
120 |
+
Create a TIMM-based featurizer model with appropriate configurations.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
model_id: The TIMM model ID to load
|
124 |
+
img_size: Input image size for the model
|
125 |
+
act_layer: Override for the activation layer type
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
A configured featurizer model
|
129 |
+
"""
|
130 |
+
featurizer = timm.create_model(
|
131 |
+
model_id,
|
132 |
+
pretrained=False,
|
133 |
+
num_classes=0,
|
134 |
+
img_size=img_size,
|
135 |
+
act_layer=act_layer,
|
136 |
+
)
|
137 |
+
|
138 |
+
# Monkey-patch the forward function to extract the second-to-last layer features
|
139 |
+
num_blocks = len(featurizer.blocks)
|
140 |
+
featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))
|
141 |
+
|
142 |
+
return featurizer
|
143 |
+
|
144 |
+
def _patch_layer_scales(self) -> None:
|
145 |
+
"""
|
146 |
+
Patch all LayerScale modules to be compatible with HF's parameter naming.
|
147 |
+
|
148 |
+
HF Transformers overwrites parameters with names containing 'gamma',
|
149 |
+
so we need to rename and modify the forward method.
|
150 |
+
"""
|
151 |
+
# Patch primary featurizer
|
152 |
+
for module in self.featurizer.modules():
|
153 |
+
if isinstance(module, LayerScale):
|
154 |
+
ls_apply_patch(module)
|
155 |
+
|
156 |
+
# Patch secondary featurizer if it exists
|
157 |
+
if self.use_fused_vision_backbone:
|
158 |
+
for module in self.fused_featurizer.modules():
|
159 |
+
if isinstance(module, LayerScale):
|
160 |
+
ls_apply_patch(module)
|
161 |
+
|
162 |
+
def get_num_patches(self) -> int:
|
163 |
+
"""
|
164 |
+
Returns the number of vision patches output by the vision backbone.
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
Number of patches per image
|
168 |
+
"""
|
169 |
+
return self.featurizer.patch_embed.num_patches
|
170 |
+
|
171 |
+
def get_num_images_in_input(self) -> int:
|
172 |
+
"""
|
173 |
+
Returns the number of input images for the vision backbone.
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
Number of images expected in the input
|
177 |
+
"""
|
178 |
+
return self.num_images_in_input
|
179 |
+
|
180 |
+
def set_num_images_in_input(self, num_images_in_input: int) -> None:
|
181 |
+
"""
|
182 |
+
Sets the number of input images for the vision backbone.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
num_images_in_input: Number of images to expect in the input
|
186 |
+
"""
|
187 |
+
self.num_images_in_input = num_images_in_input
|
188 |
+
|
189 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
190 |
+
"""
|
191 |
+
Implements the forward pass for the vision backbone.
|
192 |
+
|
193 |
+
If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
|
194 |
+
(otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).
|
195 |
+
|
196 |
+
Args:
|
197 |
+
pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).
|
198 |
+
"""
|
199 |
+
if self.num_images_in_input == 1:
|
200 |
+
if not self.use_fused_vision_backbone:
|
201 |
+
return self.featurizer(pixel_values)
|
202 |
+
|
203 |
+
# Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
|
204 |
+
img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
|
205 |
+
patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
|
206 |
+
|
207 |
+
return torch.cat([patches, patches_fused], dim=2)
|
208 |
+
|
209 |
+
else:
|
210 |
+
assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"
|
211 |
+
|
212 |
+
# Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)
|
213 |
+
images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)
|
214 |
+
|
215 |
+
# Process each image and collect patches
|
216 |
+
all_patches = []
|
217 |
+
for img in images:
|
218 |
+
# Split each image further into two stacks of channels (each with 3 channels)
|
219 |
+
img_regular, img_fused = torch.split(img, [3, 3], dim=1)
|
220 |
+
|
221 |
+
# Get patches from both SigLIP and DINOv2 vision transformers
|
222 |
+
patches = self.featurizer(img_regular)
|
223 |
+
patches_fused = self.fused_featurizer(img_fused)
|
224 |
+
|
225 |
+
# Concatenate SigLIP and DINOv2 patches along the hidden dimension
|
226 |
+
combined_patches = torch.cat([patches, patches_fused], dim=2)
|
227 |
+
all_patches.append(combined_patches)
|
228 |
+
|
229 |
+
# Concatenate all patches along the patch dimension
|
230 |
+
return torch.cat(all_patches, dim=1)
|
231 |
+
|
232 |
+
|
233 |
+
# === Prismatic Projector (nn.Module) Definitions ===
|
234 |
+
class PrismaticProjector(nn.Module):
|
235 |
+
def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
|
236 |
+
super().__init__()
|
237 |
+
self.use_fused_vision_backbone = use_fused_vision_backbone
|
238 |
+
self.vision_dim, self.llm_dim = vision_dim, llm_dim
|
239 |
+
|
240 |
+
# Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
|
241 |
+
if not self.use_fused_vision_backbone:
|
242 |
+
self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
|
243 |
+
self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
|
244 |
+
self.act_fn1 = nn.GELU()
|
245 |
+
else:
|
246 |
+
initial_projection_dim = 4 * vision_dim
|
247 |
+
self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
|
248 |
+
self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
|
249 |
+
self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
|
250 |
+
self.act_fn1 = nn.GELU()
|
251 |
+
self.act_fn2 = nn.GELU()
|
252 |
+
|
253 |
+
def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
|
254 |
+
if not self.use_fused_vision_backbone:
|
255 |
+
projected_features = self.fc1(img_patches)
|
256 |
+
projected_features = self.act_fn1(projected_features)
|
257 |
+
projected_features = self.fc2(projected_features)
|
258 |
+
else:
|
259 |
+
projected_features = self.fc1(img_patches)
|
260 |
+
projected_features = self.act_fn1(projected_features)
|
261 |
+
projected_features = self.fc2(projected_features)
|
262 |
+
projected_features = self.act_fn2(projected_features)
|
263 |
+
projected_features = self.fc3(projected_features)
|
264 |
+
|
265 |
+
return projected_features
|
266 |
+
|
267 |
+
|
268 |
+
# === Main HF Class Definitions ===
|
269 |
+
@dataclass
|
270 |
+
class PrismaticCausalLMOutputWithPast(ModelOutput):
|
271 |
+
"""Base class for Prismatic casual (visually-conditioned) language model outputs; also exposes visual features."""
|
272 |
+
|
273 |
+
loss: Optional[torch.FloatTensor] = None
|
274 |
+
logits: torch.FloatTensor = None
|
275 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
276 |
+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
277 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
278 |
+
|
279 |
+
# Additions for VLMs
|
280 |
+
projector_features: Optional[torch.FloatTensor] = None
|
281 |
+
|
282 |
+
|
283 |
+
class PrismaticPreTrainedModel(PreTrainedModel):
|
284 |
+
config_class: PretrainedConfig = PrismaticConfig
|
285 |
+
base_model_prefix: str = "model"
|
286 |
+
supports_gradient_checkpointing: bool = True
|
287 |
+
|
288 |
+
_no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
|
289 |
+
_skip_keys_device_placement: str = "past_key_values"
|
290 |
+
_supports_flash_attn_2: bool = True
|
291 |
+
|
292 |
+
def _init_weights(self, module: nn.Module) -> None:
|
293 |
+
# Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
|
294 |
+
# => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
|
295 |
+
# https://github.com/TRI-ML/prismatic-vlms
|
296 |
+
std = (
|
297 |
+
self.config.initializer_range
|
298 |
+
if hasattr(self.config, "initializer_range")
|
299 |
+
else self.config.text_config.initializer_range
|
300 |
+
)
|
301 |
+
|
302 |
+
if hasattr(module, "class_embedding"):
|
303 |
+
module.class_embedding.data.normal_(mean=0.0, std=std)
|
304 |
+
|
305 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
306 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
307 |
+
if module.bias is not None:
|
308 |
+
module.bias.data.zero_()
|
309 |
+
elif isinstance(module, nn.Embedding):
|
310 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
311 |
+
if module.padding_idx is not None:
|
312 |
+
module.weight.data[module.padding_idx].zero_()
|
313 |
+
|
314 |
+
@property
|
315 |
+
def _supports_sdpa(self) -> bool:
|
316 |
+
"""Check LLM supports SDPA Attention"""
|
317 |
+
return self.language_model._supports_sdpa
|
318 |
+
|
319 |
+
|
320 |
+
class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
|
321 |
+
def __init__(self, config: PrismaticConfig) -> None:
|
322 |
+
super().__init__(config)
|
323 |
+
|
324 |
+
# [Validation] Lightweight Validate on `config` Fields + Dependency Versions
|
325 |
+
if config.use_fused_vision_backbone is None:
|
326 |
+
raise ValueError("Missing config field `use_fused_vision_backbone`")
|
327 |
+
|
328 |
+
if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
|
329 |
+
raise NotImplementedError(
|
330 |
+
"TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
|
331 |
+
"if you urgently need support for latest TIMM versions."
|
332 |
+
)
|
333 |
+
|
334 |
+
if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
|
335 |
+
logger.warning(
|
336 |
+
f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
|
337 |
+
f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
|
338 |
+
f"there might be inference-time regressions due to dependency changes. If in doubt, please"
|
339 |
+
f"use the above versions."
|
340 |
+
)
|
341 |
+
# import pdb; pdb.set_trace()
|
342 |
+
# Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
|
343 |
+
self.vision_backbone = PrismaticVisionBackbone(
|
344 |
+
config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
|
345 |
+
)
|
346 |
+
|
347 |
+
# Create Multimodal Projector
|
348 |
+
self.projector = PrismaticProjector(
|
349 |
+
config.use_fused_vision_backbone,
|
350 |
+
vision_dim=self.vision_backbone.embed_dim,
|
351 |
+
llm_dim=config.text_config.hidden_size,
|
352 |
+
)
|
353 |
+
|
354 |
+
# Instantiate LLM Backbone
|
355 |
+
self.language_model = AutoModelForCausalLM.from_config(
|
356 |
+
config.text_config, attn_implementation=config._attn_implementation
|
357 |
+
)
|
358 |
+
|
359 |
+
self.vocab_size = config.text_config.vocab_size
|
360 |
+
self.pad_token_id = config.pad_token_id
|
361 |
+
self.llm_dim = config.text_config.hidden_size
|
362 |
+
|
363 |
+
#Action query token
|
364 |
+
self.action_queries = nn.Embedding(NUM_TOKENS, self.llm_dim)
|
365 |
+
self.action_queries.weight.data.zero_()
|
366 |
+
|
367 |
+
# HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
|
368 |
+
self.post_init()
|
369 |
+
|
370 |
+
# === `PreTrainedModel` Boilerplate ===
|
371 |
+
def get_input_embeddings(self) -> nn.Module:
|
372 |
+
return self.language_model.get_input_embeddings()
|
373 |
+
def set_version(self, version: str):
|
374 |
+
self.version = version
|
375 |
+
return self.version
|
376 |
+
|
377 |
+
|
378 |
+
def set_input_embeddings(self, value: nn.Module) -> None:
|
379 |
+
self.language_model.set_input_embeddings(value)
|
380 |
+
|
381 |
+
def get_output_embeddings(self) -> nn.Module:
|
382 |
+
return self.language_model.get_output_embeddings()
|
383 |
+
|
384 |
+
def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
|
385 |
+
self.language_model.set_output_embeddings(new_embeddings)
|
386 |
+
|
387 |
+
def get_decoder(self) -> nn.Module:
|
388 |
+
return self.language_model.get_decoder()
|
389 |
+
|
390 |
+
def set_decoder(self, decoder: nn.Module) -> None:
|
391 |
+
self.language_model.set_decoder(decoder)
|
392 |
+
|
393 |
+
def tie_weights(self) -> None:
|
394 |
+
self.language_model.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
|
395 |
+
|
396 |
+
def resize_token_embeddings(
|
397 |
+
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
|
398 |
+
) -> nn.Embedding:
|
399 |
+
updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
400 |
+
|
401 |
+
# Update config/instance variables
|
402 |
+
self.config.text_config.vocab_size = updated_embeddings.num_embeddings
|
403 |
+
self.vocab_size = updated_embeddings.num_embeddings
|
404 |
+
|
405 |
+
return updated_embeddings
|
406 |
+
|
407 |
+
def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
|
408 |
+
"""
|
409 |
+
Replace embeddings in input_embeddings at positions where all_actions_mask is True
|
410 |
+
with embeddings from noisy_action_features, using vectorized operations.
|
411 |
+
|
412 |
+
Args:
|
413 |
+
input_embeddings: Tensor of shape (B, S, D)
|
414 |
+
all_actions_mask: Boolean tensor of shape (B, S)
|
415 |
+
noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample
|
416 |
+
|
417 |
+
Returns:
|
418 |
+
Modified input_embeddings tensor
|
419 |
+
"""
|
420 |
+
# Clone input to avoid modifying the original tensor
|
421 |
+
new_input_embeddings = input_embeddings.clone()
|
422 |
+
|
423 |
+
# Create a tensor with the same shape of input_embeddings to hold the noisy action features
|
424 |
+
repositioned_noisy_action_features = torch.zeros_like(input_embeddings)
|
425 |
+
|
426 |
+
# Create batch indices for splicing
|
427 |
+
batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
|
428 |
+
batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])
|
429 |
+
|
430 |
+
# Get indices where mask is True for each sample
|
431 |
+
masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])
|
432 |
+
|
433 |
+
# Move the noisy action features into their correct positions
|
434 |
+
# print(noisy_action_features.size())
|
435 |
+
# import pdb; pdb.set_trace()
|
436 |
+
repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features
|
437 |
+
|
438 |
+
# Combine original input embeddings and noisy action embeddings using the mask
|
439 |
+
new_input_embeddings = torch.where(
|
440 |
+
all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings
|
441 |
+
)
|
442 |
+
|
443 |
+
return new_input_embeddings
|
444 |
+
|
445 |
+
def _process_action_masks(self, labels):
|
446 |
+
"""Helper to get action masks from labels"""
|
447 |
+
current_action_mask = get_current_action_mask(labels)
|
448 |
+
next_actions_mask = get_next_actions_mask(labels)
|
449 |
+
all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
|
450 |
+
return all_actions_mask
|
451 |
+
|
452 |
+
def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
|
453 |
+
"""Process vision features with optional FiLM conditioning"""
|
454 |
+
if use_film:
|
455 |
+
# FiLM: Infuse language inputs into visual features
|
456 |
+
patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D)
|
457 |
+
else:
|
458 |
+
patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
|
459 |
+
|
460 |
+
# Project patch embeddings into language embedding space
|
461 |
+
return self.projector(patch_features)
|
462 |
+
|
463 |
+
def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
|
464 |
+
"""Process proprioceptive features and append to vision features"""
|
465 |
+
if proprio_projector is not None and proprio is not None:
|
466 |
+
# projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
|
467 |
+
# proprio: (bsz, proprio_dim) or (propro_dim,)
|
468 |
+
proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) # (bsz, proprio_dim)
|
469 |
+
proprio_features = proprio_projector(proprio) # (bsz, llm_dim)
|
470 |
+
proprio_features = proprio_features.unsqueeze(dim=1) # (bsz, 1, llm_dim)
|
471 |
+
# For simplicity, just append proprio token to the end of projected vision patch tokens
|
472 |
+
return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
|
473 |
+
return projected_patch_embeddings
|
474 |
+
|
475 |
+
def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
|
476 |
+
"""Build multimodal embeddings and attention mask"""
|
477 |
+
# Update attention mask
|
478 |
+
# import pdb; pdb.set_trace()
|
479 |
+
projected_patch_attention_mask = None
|
480 |
+
if attention_mask is not None:
|
481 |
+
projected_patch_attention_mask = torch.full(
|
482 |
+
(projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
|
483 |
+
fill_value=True,
|
484 |
+
dtype=attention_mask.dtype,
|
485 |
+
device=attention_mask.device,
|
486 |
+
)
|
487 |
+
|
488 |
+
# Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
|
489 |
+
multimodal_embeddings = torch.cat(
|
490 |
+
[input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
|
491 |
+
)
|
492 |
+
|
493 |
+
multimodal_attention_mask = None
|
494 |
+
if attention_mask is not None:
|
495 |
+
multimodal_attention_mask = torch.cat(
|
496 |
+
[attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
|
497 |
+
)
|
498 |
+
|
499 |
+
return multimodal_embeddings, multimodal_attention_mask
|
500 |
+
|
501 |
+
def _build_multimodal_labels(self, labels, projected_patch_embeddings):
|
502 |
+
"""Build multimodal labels with IGNORE_INDEX for patch embeddings"""
|
503 |
+
if labels is not None:
|
504 |
+
projected_patch_labels = torch.full(
|
505 |
+
(projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
|
506 |
+
fill_value=IGNORE_INDEX,
|
507 |
+
dtype=labels.dtype,
|
508 |
+
device=labels.device,
|
509 |
+
)
|
510 |
+
return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
|
511 |
+
return None
|
512 |
+
|
513 |
+
# === Core Prismatic VLM `forward()` Logic ===
|
514 |
+
def forward(
|
515 |
+
self,
|
516 |
+
input_ids: Optional[torch.LongTensor] = None,
|
517 |
+
attention_mask: Optional[torch.Tensor] = None,
|
518 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
519 |
+
labels: Optional[torch.LongTensor] = None,
|
520 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
521 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
522 |
+
use_cache: Optional[bool] = None,
|
523 |
+
output_attentions: Optional[bool] = None,
|
524 |
+
output_hidden_states: Optional[bool] = None,
|
525 |
+
output_projector_features: Optional[bool] = None,
|
526 |
+
return_dict: Optional[bool] = None,
|
527 |
+
proprio=None,
|
528 |
+
proprio_projector=None,
|
529 |
+
noisy_actions=None,
|
530 |
+
noisy_action_projector=None,
|
531 |
+
diffusion_timestep_embeddings=None,
|
532 |
+
use_film: bool = False,
|
533 |
+
) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
|
534 |
+
"""Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
|
535 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
536 |
+
output_hidden_states = (
|
537 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
538 |
+
)
|
539 |
+
output_projector_features = output_projector_features if output_projector_features is not None else False
|
540 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
541 |
+
|
542 |
+
# Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
|
543 |
+
use_cache = use_cache and not self.training
|
544 |
+
|
545 |
+
# Instantiate Placeholder for Projector Features
|
546 |
+
projected_patch_embeddings = None
|
547 |
+
|
548 |
+
# === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
|
549 |
+
if input_ids.shape[1] == 1:
|
550 |
+
assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
|
551 |
+
assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
|
552 |
+
assert labels is None, "Unexpected key `labels` provided during cached generation!"
|
553 |
+
|
554 |
+
language_model_output = self.language_model(
|
555 |
+
input_ids=input_ids,
|
556 |
+
attention_mask=None,
|
557 |
+
position_ids=None,
|
558 |
+
past_key_values=past_key_values,
|
559 |
+
inputs_embeds=None,
|
560 |
+
labels=None,
|
561 |
+
use_cache=use_cache,
|
562 |
+
output_attentions=output_attentions,
|
563 |
+
output_hidden_states=output_hidden_states,
|
564 |
+
return_dict=return_dict,
|
565 |
+
)
|
566 |
+
|
567 |
+
# === Handle Unimodal Forward ===
|
568 |
+
elif pixel_values is None:
|
569 |
+
assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
|
570 |
+
assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
|
571 |
+
|
572 |
+
language_model_output = self.language_model(
|
573 |
+
input_ids=input_ids,
|
574 |
+
attention_mask=attention_mask,
|
575 |
+
position_ids=None,
|
576 |
+
past_key_values=None,
|
577 |
+
inputs_embeds=None,
|
578 |
+
labels=labels,
|
579 |
+
use_cache=use_cache,
|
580 |
+
output_attentions=output_attentions,
|
581 |
+
output_hidden_states=output_hidden_states,
|
582 |
+
return_dict=return_dict,
|
583 |
+
)
|
584 |
+
|
585 |
+
# === Handle Multimodal Forward ===
|
586 |
+
elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
|
587 |
+
assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
|
588 |
+
|
589 |
+
# Get input embeddings (from language model embeddings)
|
590 |
+
input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
|
591 |
+
|
592 |
+
# import pdb; pdb.set_trace()
|
593 |
+
# Extract action masks
|
594 |
+
all_actions_mask = self._process_action_masks(labels)
|
595 |
+
|
596 |
+
# Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
|
597 |
+
# import pdb; pdb.set_trace()
|
598 |
+
# print(input_embeddings[~all_actions_mask].size())
|
599 |
+
language_embeddings = input_embeddings[~all_actions_mask].reshape(
|
600 |
+
input_embeddings.shape[0], -1, input_embeddings.shape[2]
|
601 |
+
) # (B, lang_seq_len, llm_dim)
|
602 |
+
|
603 |
+
# Get visual features
|
604 |
+
projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
|
605 |
+
|
606 |
+
# Add proprioceptive state if provided
|
607 |
+
if self.version == 'v1':
|
608 |
+
pass
|
609 |
+
else:
|
610 |
+
projected_patch_embeddings = self._process_proprio_features(
|
611 |
+
projected_patch_embeddings, proprio, proprio_projector
|
612 |
+
)
|
613 |
+
|
614 |
+
# [Diffusion] Add diffusion timestep embedding if provided
|
615 |
+
if diffusion_timestep_embeddings is not None:
|
616 |
+
if self.version == 'v1':
|
617 |
+
pass
|
618 |
+
else:
|
619 |
+
# For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens
|
620 |
+
projected_patch_embeddings = torch.cat(
|
621 |
+
(projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
|
622 |
+
)
|
623 |
+
|
624 |
+
|
625 |
+
# Process action embeddings
|
626 |
+
if noisy_actions is not None:
|
627 |
+
# import pdb; pdb.set_trace()
|
628 |
+
if self.version == 'v1':
|
629 |
+
# action_queries = self.action_queries.weight # (1, h)
|
630 |
+
# action_queries = action_queries.view(1, 1, action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
|
631 |
+
# input_embeddings = torch.cat((input_embeddings, action_queries), dim=1) # (b, n_tokens+chunk_size, h)
|
632 |
+
# action_attention_mask = None
|
633 |
+
# action_attention_mask = torch.full(
|
634 |
+
# (action_queries.shape[0], action_queries.shape[1]),
|
635 |
+
# fill_value=True,
|
636 |
+
# dtype=attention_mask.dtype,
|
637 |
+
# device=attention_mask.device,)
|
638 |
+
# attention_mask = torch.cat([attention_mask, action_attention_mask], dim=1)
|
639 |
+
|
640 |
+
action_queries = self.action_queries.weight # (1, h)
|
641 |
+
action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
|
642 |
+
all_actions_mask = self._process_action_masks(labels)
|
643 |
+
input_embeddings = self._replace_input_embeddings(
|
644 |
+
input_embeddings, all_actions_mask, action_queries)
|
645 |
+
# import pdb; pdb.set_trace()
|
646 |
+
|
647 |
+
else:
|
648 |
+
# Get mask corresponding to all action tokens
|
649 |
+
all_actions_mask = self._process_action_masks(labels)
|
650 |
+
|
651 |
+
# Reshape noisy actions into individual action tokens
|
652 |
+
# noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
|
653 |
+
B = noisy_actions.shape[0]
|
654 |
+
noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)
|
655 |
+
# Project noisy action tokens into language model embedding space
|
656 |
+
noisy_action_features = noisy_action_projector(noisy_actions) # (B, chunk_len * action_dim, llm_dim)
|
657 |
+
# Replace embeddings of the action tokens with noisy action embeddings
|
658 |
+
input_embeddings = self._replace_input_embeddings(
|
659 |
+
input_embeddings, all_actions_mask, noisy_action_features)
|
660 |
+
|
661 |
+
else:
|
662 |
+
if self.version == 'v1':
|
663 |
+
action_queries = self.action_queries.weight # (1, h)
|
664 |
+
action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
|
665 |
+
all_actions_mask = self._process_action_masks(labels)
|
666 |
+
input_embeddings = self._replace_input_embeddings(
|
667 |
+
input_embeddings, all_actions_mask, action_queries)
|
668 |
+
# import pdb; pdb.set_trace()
|
669 |
+
else:
|
670 |
+
# Replace the embeddings of the action tokens with zeros
|
671 |
+
# (Later on, the positional embeddings will be added to them)
|
672 |
+
all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
|
673 |
+
input_embeddings = input_embeddings * ~all_actions_mask
|
674 |
+
|
675 |
+
|
676 |
+
# Build multimodal embeddings & attention mask
|
677 |
+
multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
|
678 |
+
input_embeddings, projected_patch_embeddings, attention_mask
|
679 |
+
)
|
680 |
+
# import pdb; pdb.set_trace()
|
681 |
+
# Build labels for multimodal sequence if needed
|
682 |
+
multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
|
683 |
+
|
684 |
+
# import pdb; pdb.set_trace()
|
685 |
+
# Dispatch to language model
|
686 |
+
if self.version == 'v1':
|
687 |
+
# import pdb; pdb.set_trace()
|
688 |
+
language_model_output = self.language_model(
|
689 |
+
input_ids=None,
|
690 |
+
attention_mask=multimodal_attention_mask,
|
691 |
+
position_ids=None,
|
692 |
+
past_key_values=None,
|
693 |
+
inputs_embeds=multimodal_embeddings,
|
694 |
+
labels=None,
|
695 |
+
use_cache=use_cache,
|
696 |
+
output_attentions=output_attentions,
|
697 |
+
output_hidden_states=output_hidden_states,
|
698 |
+
return_dict=return_dict,
|
699 |
+
)
|
700 |
+
# import pdb; pdb.set_trace()
|
701 |
+
else:
|
702 |
+
language_model_output = self.language_model(
|
703 |
+
input_ids=None,
|
704 |
+
attention_mask=multimodal_attention_mask,
|
705 |
+
position_ids=None,
|
706 |
+
past_key_values=None,
|
707 |
+
inputs_embeds=multimodal_embeddings,
|
708 |
+
labels=multimodal_labels,
|
709 |
+
use_cache=use_cache,
|
710 |
+
output_attentions=output_attentions,
|
711 |
+
output_hidden_states=output_hidden_states,
|
712 |
+
return_dict=return_dict,
|
713 |
+
)
|
714 |
+
|
715 |
+
# === Otherwise =>> Assume Invalid! ===
|
716 |
+
elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
|
717 |
+
raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
|
718 |
+
|
719 |
+
else:
|
720 |
+
raise ValueError(
|
721 |
+
"Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
|
722 |
+
f"=> `input_ids` = {input_ids is not None}\n"
|
723 |
+
f"=> `attention_mask` = {attention_mask is not None}\n"
|
724 |
+
f"=> `pixel_values` = {pixel_values is not None}\n"
|
725 |
+
f"=> `labels` = {labels is not None}\n"
|
726 |
+
f"=> `input_embeds` = {inputs_embeds is not None}\n"
|
727 |
+
f"=> `past_key_values` = {past_key_values is not None}\n"
|
728 |
+
f"=> `use_cache` = {use_cache}"
|
729 |
+
)
|
730 |
+
|
731 |
+
# Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
|
732 |
+
if not return_dict:
|
733 |
+
if output_projector_features and (projected_patch_embeddings is not None):
|
734 |
+
return *language_model_output, projected_patch_embeddings
|
735 |
+
|
736 |
+
return language_model_output
|
737 |
+
|
738 |
+
if self.version == 'v1':
|
739 |
+
return PrismaticCausalLMOutputWithPast(
|
740 |
+
loss=language_model_output.loss,
|
741 |
+
past_key_values=language_model_output.past_key_values,
|
742 |
+
hidden_states=language_model_output.hidden_states,
|
743 |
+
attentions=language_model_output.attentions,
|
744 |
+
projector_features=projected_patch_embeddings,
|
745 |
+
)
|
746 |
+
else:
|
747 |
+
return PrismaticCausalLMOutputWithPast(
|
748 |
+
loss=language_model_output.loss,
|
749 |
+
logits=language_model_output.logits,
|
750 |
+
past_key_values=language_model_output.past_key_values,
|
751 |
+
hidden_states=language_model_output.hidden_states,
|
752 |
+
attentions=language_model_output.attentions,
|
753 |
+
projector_features=projected_patch_embeddings,
|
754 |
+
)
|
755 |
+
|
756 |
+
# === GenerationMixin Methods ===
|
757 |
+
def prepare_inputs_for_generation(
|
758 |
+
self,
|
759 |
+
input_ids: Optional[torch.Tensor] = None,
|
760 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
761 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
762 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
763 |
+
attention_mask: Optional[torch.Tensor] = None,
|
764 |
+
**kwargs: str,
|
765 |
+
) -> Dict[str, torch.Tensor]:
|
766 |
+
"""Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
|
767 |
+
if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
|
768 |
+
(inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
|
769 |
+
):
|
770 |
+
raise ValueError("Generation with batch size > 1 is not currently supported!")
|
771 |
+
|
772 |
+
# Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
|
773 |
+
if past_key_values is not None:
|
774 |
+
input_ids = input_ids[:, -1:]
|
775 |
+
|
776 |
+
# If `input_embeds` are passed, we only want to use them in the 1st generation step
|
777 |
+
if inputs_embeds is not None and past_key_values is None:
|
778 |
+
model_inputs = {"input_embeds": inputs_embeds}
|
779 |
+
else:
|
780 |
+
model_inputs = {"input_ids": input_ids}
|
781 |
+
|
782 |
+
# Make sure `pixel_values` are preserved in `model_inputs`
|
783 |
+
model_inputs.update(
|
784 |
+
{
|
785 |
+
"attention_mask": attention_mask,
|
786 |
+
"pixel_values": pixel_values,
|
787 |
+
"past_key_values": past_key_values,
|
788 |
+
"use_cache": kwargs.get("use_cache"),
|
789 |
+
}
|
790 |
+
)
|
791 |
+
|
792 |
+
return model_inputs
|
793 |
+
|
794 |
+
# Defer to Language Model (all handle this differently, with different return types)
|
795 |
+
def _reorder_cache(self, *args, **kwargs) -> Any:
|
796 |
+
return self.language_model._reorder_cache(*args, **kwargs)
|
797 |
+
|
798 |
+
|
799 |
+
class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
|
800 |
+
config_class: PretrainedConfig = OpenVLAConfig
|
801 |
+
|
802 |
+
def __init__(self, config: OpenVLAConfig) -> None:
|
803 |
+
super().__init__(config)
|
804 |
+
self.norm_stats = config.norm_stats
|
805 |
+
# import pdb; pdb.set_trace()
|
806 |
+
|
807 |
+
# Compute action bins
|
808 |
+
self.bins = np.linspace(-1, 1, config.n_action_bins)
|
809 |
+
self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
|
810 |
+
|
811 |
+
# Compute vocab size for de-tokenization -- revert added "multiple of"
|
812 |
+
self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
|
813 |
+
|
814 |
+
def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
|
815 |
+
"""Prepares input for action prediction by adding necessary tokens"""
|
816 |
+
# Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
|
817 |
+
placeholder_action_token_ids = (
|
818 |
+
torch.ones((input_ids.shape[0], NUM_TOKENS)).to(input_ids.device).to(input_ids.dtype)
|
819 |
+
)
|
820 |
+
input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
|
821 |
+
|
822 |
+
# Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
|
823 |
+
stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
|
824 |
+
input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
|
825 |
+
|
826 |
+
# Extend the attention mask to fit the new shape of input
|
827 |
+
# Note: Only batch size == 1 supported right now
|
828 |
+
mask_extension = (
|
829 |
+
torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
|
830 |
+
.to(attention_mask.device)
|
831 |
+
.to(attention_mask.dtype)
|
832 |
+
)
|
833 |
+
attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
|
834 |
+
|
835 |
+
return input_ids, attention_mask
|
836 |
+
|
837 |
+
def _prepare_labels_for_action_prediction(self, labels, input_ids):
|
838 |
+
"""Creates labels tensor for action prediction if not provided"""
|
839 |
+
# Extend labels tensor with fake action labels
|
840 |
+
ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
|
841 |
+
labels_extension = (
|
842 |
+
torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
|
843 |
+
* ARBITRARY_ACTION_TOKEN_IDX
|
844 |
+
)
|
845 |
+
labels = torch.cat([labels, labels_extension], dim=-1)
|
846 |
+
|
847 |
+
# Replace last label token with stop token
|
848 |
+
labels[:, -1] = STOP_INDEX
|
849 |
+
|
850 |
+
return labels
|
851 |
+
|
852 |
+
def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
|
853 |
+
"""Unnormalize actions using dataset statistics"""
|
854 |
+
action_norm_stats = self.get_action_stats(unnorm_key)
|
855 |
+
|
856 |
+
if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
|
857 |
+
mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
|
858 |
+
action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
|
859 |
+
elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
|
860 |
+
mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
|
861 |
+
action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
|
862 |
+
else:
|
863 |
+
raise ValueError("Unsupported action/proprio normalization type detected!")
|
864 |
+
|
865 |
+
actions = np.where(
|
866 |
+
mask,
|
867 |
+
0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
|
868 |
+
normalized_actions,
|
869 |
+
)
|
870 |
+
|
871 |
+
return actions
|
872 |
+
|
873 |
+
def _run_flow_matching_prediction(
|
874 |
+
self,
|
875 |
+
input_embeddings,
|
876 |
+
all_actions_mask,
|
877 |
+
noise,
|
878 |
+
action_head,
|
879 |
+
projected_patch_embeddings,
|
880 |
+
labels,
|
881 |
+
attention_mask,
|
882 |
+
NUM_PATCHES,
|
883 |
+
NUM_PROMPT_TOKENS,
|
884 |
+
noisy_action_projector
|
885 |
+
):
|
886 |
+
"""Run flow matching-based action prediction"""
|
887 |
+
# Clone embedding for reuse in each timestep
|
888 |
+
# orig_projected_patch_embeddings = projected_patch_embeddings.clone()
|
889 |
+
|
890 |
+
dt = -1.0 / action_head.num_flow_steps
|
891 |
+
dt = torch.tensor(dt, dtype=torch.bfloat16, device=labels.device)
|
892 |
+
|
893 |
+
curr_noisy_actions = noise
|
894 |
+
time = torch.tensor(1.0, dtype=torch.bfloat16, device=labels.device)
|
895 |
+
while time >= -dt / 2:
|
896 |
+
B = curr_noisy_actions.shape[0]
|
897 |
+
orig_curr_noisy_actions_shape = curr_noisy_actions.shape
|
898 |
+
curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
|
899 |
+
noisy_action_features = noisy_action_projector(curr_noisy_actions)
|
900 |
+
curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
|
901 |
+
|
902 |
+
# Replace action token embeddings with noisy action embeddings
|
903 |
+
input_embeddings = self._replace_input_embeddings(
|
904 |
+
input_embeddings.clone(), all_actions_mask, noisy_action_features
|
905 |
+
)
|
906 |
+
|
907 |
+
# Build multimodal embeddings and attention mask
|
908 |
+
multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
|
909 |
+
input_embeddings, projected_patch_embeddings, attention_mask
|
910 |
+
)
|
911 |
+
|
912 |
+
# Forward pass through language model
|
913 |
+
language_model_output = self.language_model(
|
914 |
+
input_ids=None,
|
915 |
+
attention_mask=multimodal_attention_mask,
|
916 |
+
position_ids=None,
|
917 |
+
past_key_values=None,
|
918 |
+
inputs_embeds=multimodal_embeddings,
|
919 |
+
labels=None,
|
920 |
+
use_cache=None,
|
921 |
+
output_attentions=False,
|
922 |
+
output_hidden_states=True,
|
923 |
+
return_dict=True,
|
924 |
+
)
|
925 |
+
|
926 |
+
# Extract hidden states for action portion of response
|
927 |
+
last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
|
928 |
+
actions_hidden_states = last_hidden_states[
|
929 |
+
:,
|
930 |
+
NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
|
931 |
+
:,
|
932 |
+
] # (B, act_chunk_len, D)
|
933 |
+
|
934 |
+
# Predict noise and update noisy actions: x_t -> x_{t-1}
|
935 |
+
flow_pred = action_head.predict_flow(actions_hidden_states)
|
936 |
+
curr_noisy_actions += dt * flow_pred
|
937 |
+
time += dt
|
938 |
+
curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
|
939 |
+
|
940 |
+
# Return final actions
|
941 |
+
return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
|
942 |
+
|
943 |
+
|
944 |
+
def _run_diffusion_prediction(
|
945 |
+
self,
|
946 |
+
input_embeddings,
|
947 |
+
all_actions_mask,
|
948 |
+
noise,
|
949 |
+
action_head,
|
950 |
+
projected_patch_embeddings,
|
951 |
+
labels,
|
952 |
+
attention_mask,
|
953 |
+
NUM_PATCHES,
|
954 |
+
NUM_PROMPT_TOKENS,
|
955 |
+
noisy_action_projector,
|
956 |
+
):
|
957 |
+
"""Run diffusion-based action prediction"""
|
958 |
+
# Set diffusion timestep values
|
959 |
+
action_head.noise_scheduler.set_timesteps(action_head.num_diffusion_steps)
|
960 |
+
# Clone embedding for reuse in each timestep
|
961 |
+
orig_projected_patch_embeddings = projected_patch_embeddings.clone()
|
962 |
+
curr_noisy_actions = noise
|
963 |
+
|
964 |
+
# Reverse diffusion: Iteratively denoise to generate action prediction
|
965 |
+
for t in action_head.noise_scheduler.timesteps:
|
966 |
+
# Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
|
967 |
+
# embedding, and diffusion timestep embedding)
|
968 |
+
timesteps = torch.Tensor([t]).to(labels.device)
|
969 |
+
diffusion_timestep_embeddings = (
|
970 |
+
action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
|
971 |
+
) # (B, llm_dim)
|
972 |
+
diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
|
973 |
+
|
974 |
+
# [Diffusion] Replace the embeddings of the action tokens with noisy actions
|
975 |
+
# (Later on, the positional embeddings will be added to them)
|
976 |
+
|
977 |
+
# For simplicity, append diffusion timestep embedding to the end of projected vision tokens
|
978 |
+
projected_patch_embeddings = torch.cat(
|
979 |
+
(orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
|
980 |
+
)
|
981 |
+
|
982 |
+
# Reshape and project noisy actions into language embedding space
|
983 |
+
B = curr_noisy_actions.shape[0]
|
984 |
+
orig_curr_noisy_actions_shape = curr_noisy_actions.shape
|
985 |
+
curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
|
986 |
+
noisy_action_features = noisy_action_projector(curr_noisy_actions)
|
987 |
+
curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
|
988 |
+
|
989 |
+
# Replace action token embeddings with noisy action embeddings
|
990 |
+
input_embeddings = self._replace_input_embeddings(
|
991 |
+
input_embeddings.clone(), all_actions_mask, noisy_action_features
|
992 |
+
)
|
993 |
+
|
994 |
+
# Build multimodal embeddings and attention mask
|
995 |
+
multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
|
996 |
+
input_embeddings, projected_patch_embeddings, attention_mask
|
997 |
+
)
|
998 |
+
|
999 |
+
# Forward pass through language model
|
1000 |
+
language_model_output = self.language_model(
|
1001 |
+
input_ids=None,
|
1002 |
+
attention_mask=multimodal_attention_mask,
|
1003 |
+
position_ids=None,
|
1004 |
+
past_key_values=None,
|
1005 |
+
inputs_embeds=multimodal_embeddings,
|
1006 |
+
labels=None,
|
1007 |
+
use_cache=None,
|
1008 |
+
output_attentions=False,
|
1009 |
+
output_hidden_states=True,
|
1010 |
+
return_dict=True,
|
1011 |
+
)
|
1012 |
+
|
1013 |
+
# Extract hidden states for action portion of response
|
1014 |
+
last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
|
1015 |
+
actions_hidden_states = last_hidden_states[
|
1016 |
+
:,
|
1017 |
+
NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
|
1018 |
+
:,
|
1019 |
+
] # (B, act_chunk_len, D)
|
1020 |
+
|
1021 |
+
# Predict noise and update noisy actions: x_t -> x_{t-1}
|
1022 |
+
noise_pred = action_head.predict_noise(actions_hidden_states)
|
1023 |
+
curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
|
1024 |
+
|
1025 |
+
curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
|
1026 |
+
|
1027 |
+
# Return final actions
|
1028 |
+
return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
|
1029 |
+
|
1030 |
+
def _run_diffusion_prediction_V1(
|
1031 |
+
self,
|
1032 |
+
input_embeddings,
|
1033 |
+
all_actions_mask,
|
1034 |
+
noise,
|
1035 |
+
action_head,
|
1036 |
+
projected_patch_embeddings,
|
1037 |
+
labels,
|
1038 |
+
attention_mask,
|
1039 |
+
NUM_PATCHES,
|
1040 |
+
NUM_PROMPT_TOKENS,
|
1041 |
+
noisy_action_projector,
|
1042 |
+
proprio,
|
1043 |
+
proprio_projector,
|
1044 |
+
):
|
1045 |
+
"""Run diffusion-based action prediction"""
|
1046 |
+
# Set diffusion timestep values
|
1047 |
+
action_head.noise_scheduler.set_timesteps(action_head.num_diffusion_steps)
|
1048 |
+
# Clone embedding for reuse in each timestep
|
1049 |
+
curr_noisy_actions = noise
|
1050 |
+
|
1051 |
+
# import pdb; pdb.set_trace()
|
1052 |
+
|
1053 |
+
action_queries = self.action_queries.weight # (1, h)
|
1054 |
+
action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
|
1055 |
+
# Replace action token embeddings with noisy action embeddings
|
1056 |
+
input_embeddings = self._replace_input_embeddings(input_embeddings.clone(), all_actions_mask, action_queries)
|
1057 |
+
# input_embeddings = torch.cat((input_embeddings, action_queries), dim=1) # (b, n_tokens+chunk_size, h)
|
1058 |
+
# action_attention_mask = None
|
1059 |
+
# action_attention_mask = torch.full(
|
1060 |
+
# (action_queries.shape[0], action_queries.shape[1]),
|
1061 |
+
# fill_value=True,
|
1062 |
+
# dtype=attention_mask.dtype,
|
1063 |
+
# device=attention_mask.device,)
|
1064 |
+
# attention_mask = torch.cat([attention_mask, action_attention_mask], dim=1)
|
1065 |
+
|
1066 |
+
# Build multimodal embeddings and attention mask
|
1067 |
+
multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
|
1068 |
+
input_embeddings, projected_patch_embeddings, attention_mask
|
1069 |
+
)
|
1070 |
+
|
1071 |
+
# import pdb; pdb.set_trace()
|
1072 |
+
# Forward pass through language model
|
1073 |
+
language_model_output = self.language_model(
|
1074 |
+
input_ids=None,
|
1075 |
+
attention_mask=multimodal_attention_mask,
|
1076 |
+
position_ids=None,
|
1077 |
+
past_key_values=None,
|
1078 |
+
inputs_embeds=multimodal_embeddings,
|
1079 |
+
labels=None,
|
1080 |
+
use_cache=None,
|
1081 |
+
output_attentions=False,
|
1082 |
+
output_hidden_states=True,
|
1083 |
+
return_dict=True,
|
1084 |
+
)
|
1085 |
+
multi_layer_hidden_states = []
|
1086 |
+
# import pdb; pdb.set_trace()
|
1087 |
+
for item in language_model_output.hidden_states[0:]:
|
1088 |
+
# last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
|
1089 |
+
# Get hidden states for text portion of prompt+response (after the vision patches)
|
1090 |
+
text_hidden_states = item
|
1091 |
+
# Get hidden states for action portion of response
|
1092 |
+
actions_hidden_states = text_hidden_states[:, NUM_PATCHES+ NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :,].reshape(1, 1, NUM_TOKENS, -1).to(torch.bfloat16)
|
1093 |
+
# import pdb; pdb.set_trace()
|
1094 |
+
batch_size = item.shape[0]
|
1095 |
+
task_latten_states = item[:, :NUM_PATCHES].reshape(batch_size, 1, NUM_PATCHES , -1)
|
1096 |
+
all_hidden_states = torch.cat((task_latten_states, actions_hidden_states),2)
|
1097 |
+
multi_layer_hidden_states.append(all_hidden_states)
|
1098 |
+
# import pdb; pdb.set_trace()
|
1099 |
+
multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim = 1)
|
1100 |
+
# import pdb; pdb.set_trace()
|
1101 |
+
|
1102 |
+
|
1103 |
+
|
1104 |
+
# Reverse diffusion: Iteratively denoise to generate action prediction
|
1105 |
+
for t in action_head.noise_scheduler.timesteps:
|
1106 |
+
# Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
|
1107 |
+
# embedding, and diffusion timestep embedding)
|
1108 |
+
timesteps = torch.Tensor([t]).to(labels.device)
|
1109 |
+
diffusion_timestep_embeddings = (
|
1110 |
+
action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
|
1111 |
+
) # (B, llm_dim)
|
1112 |
+
diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
|
1113 |
+
|
1114 |
+
# [Diffusion] Replace the embeddings of the action tokens with noisy actions
|
1115 |
+
# (Later on, the positional embeddings will be added to them)
|
1116 |
+
|
1117 |
+
# Reshape and project noisy actions into language embedding space
|
1118 |
+
B = curr_noisy_actions.shape[0]
|
1119 |
+
orig_curr_noisy_actions_shape = curr_noisy_actions.shape
|
1120 |
+
curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
|
1121 |
+
curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
|
1122 |
+
|
1123 |
+
# Predict noise and update noisy actions: x_t -> x_{t-1}
|
1124 |
+
# noise_pred = action_head.predict_noise(actions_hidden_states)
|
1125 |
+
noise_pred = action_head.predict_noise(multi_layer_hidden_states,
|
1126 |
+
noisy_actions=curr_noisy_actions,
|
1127 |
+
timestep_embeddings = diffusion_timestep_embeddings,
|
1128 |
+
noisy_action_projector=noisy_action_projector,
|
1129 |
+
proprio=proprio ,
|
1130 |
+
proprio_projector=proprio_projector)
|
1131 |
+
|
1132 |
+
curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
|
1133 |
+
curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
|
1134 |
+
|
1135 |
+
# Return final actions
|
1136 |
+
return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
|
1137 |
+
|
1138 |
+
def _regression_or_discrete_prediction_V1(
|
1139 |
+
self,
|
1140 |
+
input_embeddings,
|
1141 |
+
all_actions_mask,
|
1142 |
+
projected_patch_embeddings,
|
1143 |
+
attention_mask,
|
1144 |
+
labels,
|
1145 |
+
NUM_PATCHES,
|
1146 |
+
NUM_PROMPT_TOKENS,
|
1147 |
+
action_head=None,
|
1148 |
+
proprio=None,
|
1149 |
+
proprio_projector=None,
|
1150 |
+
):
|
1151 |
+
"""Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
|
1152 |
+
|
1153 |
+
action_queries = self.action_queries.weight # (1, h)
|
1154 |
+
action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
|
1155 |
+
# Replace action token embeddings with noisy action embeddings
|
1156 |
+
input_embeddings = self._replace_input_embeddings(input_embeddings.clone(), all_actions_mask, action_queries)
|
1157 |
+
|
1158 |
+
# Build multimodal embeddings and attention mask
|
1159 |
+
multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
|
1160 |
+
input_embeddings, projected_patch_embeddings, attention_mask
|
1161 |
+
)
|
1162 |
+
|
1163 |
+
# Forward pass through language model
|
1164 |
+
language_model_output = self.language_model(
|
1165 |
+
input_ids=None,
|
1166 |
+
attention_mask=multimodal_attention_mask,
|
1167 |
+
position_ids=None,
|
1168 |
+
past_key_values=None,
|
1169 |
+
inputs_embeds=multimodal_embeddings,
|
1170 |
+
labels=None,
|
1171 |
+
use_cache=None,
|
1172 |
+
output_attentions=False,
|
1173 |
+
output_hidden_states=True,
|
1174 |
+
return_dict=True,
|
1175 |
+
)
|
1176 |
+
|
1177 |
+
# Extract hidden states for action tokens
|
1178 |
+
multi_layer_hidden_states = []
|
1179 |
+
# import pdb; pdb.set_trace()
|
1180 |
+
for item in language_model_output.hidden_states[0:]:
|
1181 |
+
# last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
|
1182 |
+
# Get hidden states for text portion of prompt+response (after the vision patches)
|
1183 |
+
text_hidden_states = item
|
1184 |
+
# Get hidden states for action portion of response
|
1185 |
+
actions_hidden_states = text_hidden_states[:, NUM_PATCHES+ NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :,].reshape(1, 1, NUM_TOKENS, -1).to(torch.bfloat16)
|
1186 |
+
# import pdb; pdb.set_trace()
|
1187 |
+
batch_size = item.shape[0]
|
1188 |
+
task_latten_states = item[:, :NUM_PATCHES].reshape(batch_size, 1, NUM_PATCHES , -1)
|
1189 |
+
all_hidden_states = torch.cat((task_latten_states, actions_hidden_states),2)
|
1190 |
+
multi_layer_hidden_states.append(all_hidden_states)
|
1191 |
+
# import pdb; pdb.set_trace()
|
1192 |
+
multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim = 1)
|
1193 |
+
# import pdb; pdb.set_trace()
|
1194 |
+
|
1195 |
+
# Handle different prediction methods
|
1196 |
+
if action_head is not None:
|
1197 |
+
# L1 regression prediction
|
1198 |
+
normalized_actions = action_head.predict_action(multi_layer_hidden_states,
|
1199 |
+
proprio=proprio,
|
1200 |
+
proprio_projector=proprio_projector)
|
1201 |
+
normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
|
1202 |
+
normalized_actions = normalized_actions.float().cpu().detach().numpy()
|
1203 |
+
else:
|
1204 |
+
# Discrete token-based prediction
|
1205 |
+
predicted_action_token_ids = (
|
1206 |
+
language_model_output.logits[
|
1207 |
+
:,
|
1208 |
+
NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
|
1209 |
+
]
|
1210 |
+
.argmax(dim=2)
|
1211 |
+
.cpu()
|
1212 |
+
.numpy()
|
1213 |
+
)
|
1214 |
+
discretized_actions = self.vocab_size - predicted_action_token_ids
|
1215 |
+
discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
|
1216 |
+
normalized_actions = self.bin_centers[discretized_actions]
|
1217 |
+
normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
|
1218 |
+
|
1219 |
+
return normalized_actions, actions_hidden_states
|
1220 |
+
|
1221 |
+
def _regression_or_discrete_prediction(
|
1222 |
+
self,
|
1223 |
+
input_embeddings,
|
1224 |
+
all_actions_mask,
|
1225 |
+
projected_patch_embeddings,
|
1226 |
+
attention_mask,
|
1227 |
+
labels,
|
1228 |
+
NUM_PATCHES,
|
1229 |
+
NUM_PROMPT_TOKENS,
|
1230 |
+
action_head=None,
|
1231 |
+
):
|
1232 |
+
"""Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
|
1233 |
+
# Zero out action token embeddings
|
1234 |
+
all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
|
1235 |
+
input_embeddings = input_embeddings * ~all_actions_mask
|
1236 |
+
|
1237 |
+
# Build multimodal embeddings and attention mask
|
1238 |
+
multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
|
1239 |
+
input_embeddings, projected_patch_embeddings, attention_mask
|
1240 |
+
)
|
1241 |
+
|
1242 |
+
# Forward pass through language model
|
1243 |
+
language_model_output = self.language_model(
|
1244 |
+
input_ids=None,
|
1245 |
+
attention_mask=multimodal_attention_mask,
|
1246 |
+
position_ids=None,
|
1247 |
+
past_key_values=None,
|
1248 |
+
inputs_embeds=multimodal_embeddings,
|
1249 |
+
labels=None,
|
1250 |
+
use_cache=None,
|
1251 |
+
output_attentions=False,
|
1252 |
+
output_hidden_states=True,
|
1253 |
+
return_dict=True,
|
1254 |
+
)
|
1255 |
+
|
1256 |
+
# Extract hidden states for action tokens
|
1257 |
+
last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
|
1258 |
+
actions_hidden_states = last_hidden_states[
|
1259 |
+
:,
|
1260 |
+
NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
|
1261 |
+
:,
|
1262 |
+
] # (B, act_chunk_len, D)
|
1263 |
+
|
1264 |
+
# Handle different prediction methods
|
1265 |
+
if action_head is not None:
|
1266 |
+
# L1 regression prediction
|
1267 |
+
normalized_actions = action_head.predict_action(actions_hidden_states)
|
1268 |
+
normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
|
1269 |
+
normalized_actions = normalized_actions.float().cpu().detach().numpy()
|
1270 |
+
else:
|
1271 |
+
# Discrete token-based prediction
|
1272 |
+
predicted_action_token_ids = (
|
1273 |
+
language_model_output.logits[
|
1274 |
+
:,
|
1275 |
+
NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
|
1276 |
+
]
|
1277 |
+
.argmax(dim=2)
|
1278 |
+
.cpu()
|
1279 |
+
.numpy()
|
1280 |
+
)
|
1281 |
+
discretized_actions = self.vocab_size - predicted_action_token_ids
|
1282 |
+
discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
|
1283 |
+
normalized_actions = self.bin_centers[discretized_actions]
|
1284 |
+
normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
|
1285 |
+
|
1286 |
+
return normalized_actions, actions_hidden_states
|
1287 |
+
|
1288 |
+
def predict_action(
|
1289 |
+
self,
|
1290 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1291 |
+
unnorm_key: Optional[str] = None,
|
1292 |
+
proprio=None,
|
1293 |
+
proprio_projector=None,
|
1294 |
+
action_head=None,
|
1295 |
+
noisy_action_projector=None,
|
1296 |
+
use_film: bool = False,
|
1297 |
+
**kwargs: str,
|
1298 |
+
) -> np.ndarray:
|
1299 |
+
"""Predict actions from input sequence, with options for different prediction methods.
|
1300 |
+
|
1301 |
+
Args:
|
1302 |
+
input_ids: Input token ids
|
1303 |
+
unnorm_key: Key for unnormalization statistics
|
1304 |
+
proprio: Proprioceptive features
|
1305 |
+
proprio_projector: Projector for proprioceptive features
|
1306 |
+
action_head: Optional head for L1 regression or diffusion-based prediction
|
1307 |
+
noisy_action_projector: Projector for noisy actions in diffusion-based prediction
|
1308 |
+
use_film: Whether to use FiLM conditioning
|
1309 |
+
**kwargs: Additional arguments including pixel_values and attention_mask
|
1310 |
+
|
1311 |
+
Returns:
|
1312 |
+
Tuple of (unnormalized_actions, action_hidden_states)
|
1313 |
+
"""
|
1314 |
+
# import pdb; pdb.set_trace()
|
1315 |
+
# If the special empty token ('') does not already appear after the colon (':') token in the prompt
|
1316 |
+
# (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
|
1317 |
+
|
1318 |
+
|
1319 |
+
# if not torch.all(input_ids[:, -1] == 29871):
|
1320 |
+
# input_ids = torch.cat(
|
1321 |
+
# (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
|
1322 |
+
# )
|
1323 |
+
|
1324 |
+
|
1325 |
+
pixel_values = kwargs["pixel_values"] # [1, 12, 224, 224]
|
1326 |
+
attention_mask = kwargs["attention_mask"] #
|
1327 |
+
|
1328 |
+
# Create fake labels tensor (needed for action mask)
|
1329 |
+
labels = input_ids.clone()
|
1330 |
+
labels[:] = IGNORE_INDEX
|
1331 |
+
|
1332 |
+
# Get number of tokens in prompt (excluding the start token)
|
1333 |
+
NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
|
1334 |
+
|
1335 |
+
# import pdb; pdb.set_trace()
|
1336 |
+
|
1337 |
+
# Prepare inputs by adding necessary tokens
|
1338 |
+
input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
|
1339 |
+
|
1340 |
+
# Update labels tensor for action mask computation later
|
1341 |
+
labels = self._prepare_labels_for_action_prediction(labels, input_ids)
|
1342 |
+
|
1343 |
+
# Get input embeddings and action masks
|
1344 |
+
input_embeddings = self.get_input_embeddings()(input_ids)
|
1345 |
+
all_actions_mask = self._process_action_masks(labels)
|
1346 |
+
|
1347 |
+
# Extract language embeddings
|
1348 |
+
language_embeddings = input_embeddings[~all_actions_mask].reshape(
|
1349 |
+
input_embeddings.shape[0], -1, input_embeddings.shape[2]
|
1350 |
+
)
|
1351 |
+
|
1352 |
+
# Process vision features
|
1353 |
+
projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
|
1354 |
+
|
1355 |
+
# Add proprioceptive features if provided
|
1356 |
+
use_proprio = proprio_projector is not None and proprio is not None
|
1357 |
+
if use_proprio:
|
1358 |
+
proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
|
1359 |
+
if self.version == 'v1':
|
1360 |
+
pass
|
1361 |
+
else:
|
1362 |
+
projected_patch_embeddings = self._process_proprio_features(
|
1363 |
+
projected_patch_embeddings, proprio, proprio_projector
|
1364 |
+
)
|
1365 |
+
# import pdb; pdb.set_trace()
|
1366 |
+
# Use diffusion if provided, otherwise use regression or discrete prediction
|
1367 |
+
use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
|
1368 |
+
use_flow_matching = noisy_action_projector is not None and hasattr(action_head, "sample_actions")
|
1369 |
+
|
1370 |
+
|
1371 |
+
# Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
|
1372 |
+
NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
|
1373 |
+
if self.version == 'v1':
|
1374 |
+
# if use_diffusion:
|
1375 |
+
# NUM_PATCHES += 1
|
1376 |
+
pass
|
1377 |
+
else:
|
1378 |
+
if use_proprio:
|
1379 |
+
NUM_PATCHES += 1
|
1380 |
+
if use_diffusion:
|
1381 |
+
NUM_PATCHES += 1
|
1382 |
+
|
1383 |
+
# import pdb; pdb.set_trace()
|
1384 |
+
if use_flow_matching:
|
1385 |
+
# Sample random noise with shape equal to output action, used as the starting state for flow matching
|
1386 |
+
noise = action_head.sample_noise((1, NUM_ACTIONS_CHUNK, ACTION_DIM),device=input_embeddings.device, dtype=input_embeddings.dtype)
|
1387 |
+
|
1388 |
+
# Run flow matching-based prediction
|
1389 |
+
normalized_actions, actions_hidden_states = self._run_flow_matching_prediction(
|
1390 |
+
input_embeddings,
|
1391 |
+
all_actions_mask,
|
1392 |
+
noise,
|
1393 |
+
action_head,
|
1394 |
+
projected_patch_embeddings,
|
1395 |
+
labels,
|
1396 |
+
attention_mask,
|
1397 |
+
NUM_PATCHES,
|
1398 |
+
NUM_PROMPT_TOKENS,
|
1399 |
+
noisy_action_projector
|
1400 |
+
)
|
1401 |
+
elif use_diffusion:
|
1402 |
+
# Sample random noise with shape equal to output action, used as the starting state for reverse diffusion
|
1403 |
+
noise = torch.randn(
|
1404 |
+
size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
|
1405 |
+
)
|
1406 |
+
# import pdb; pdb.set_trace()
|
1407 |
+
if self.version == 'v1':
|
1408 |
+
|
1409 |
+
# import pdb; pdb.set_trace()
|
1410 |
+
# Run diffusion-based prediction
|
1411 |
+
normalized_actions, actions_hidden_states = self._run_diffusion_prediction_V1(
|
1412 |
+
input_embeddings, # [1, 86, 4096]
|
1413 |
+
all_actions_mask, # [1, 86]
|
1414 |
+
noise, # [1,8, 7]
|
1415 |
+
action_head,
|
1416 |
+
projected_patch_embeddings, # [1, 512, 4096]
|
1417 |
+
labels, # [1, 86]
|
1418 |
+
attention_mask, # [1, 86]
|
1419 |
+
NUM_PATCHES, # 512
|
1420 |
+
NUM_PROMPT_TOKENS, # 28
|
1421 |
+
noisy_action_projector,
|
1422 |
+
proprio, # [8]
|
1423 |
+
proprio_projector,
|
1424 |
+
)
|
1425 |
+
else:
|
1426 |
+
# Run diffusion-based prediction
|
1427 |
+
normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
|
1428 |
+
input_embeddings,
|
1429 |
+
all_actions_mask,
|
1430 |
+
noise,
|
1431 |
+
action_head,
|
1432 |
+
projected_patch_embeddings,
|
1433 |
+
labels,
|
1434 |
+
attention_mask,
|
1435 |
+
NUM_PATCHES,
|
1436 |
+
NUM_PROMPT_TOKENS,
|
1437 |
+
noisy_action_projector,
|
1438 |
+
)
|
1439 |
+
|
1440 |
+
else:
|
1441 |
+
if self.version == 'v1':
|
1442 |
+
# Run regression or discrete token-based prediction
|
1443 |
+
normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction_V1(
|
1444 |
+
input_embeddings,
|
1445 |
+
all_actions_mask,
|
1446 |
+
projected_patch_embeddings,
|
1447 |
+
attention_mask,
|
1448 |
+
labels,
|
1449 |
+
NUM_PATCHES,
|
1450 |
+
NUM_PROMPT_TOKENS,
|
1451 |
+
action_head=action_head,
|
1452 |
+
proprio=proprio, # [8]
|
1453 |
+
proprio_projector=proprio_projector,
|
1454 |
+
)
|
1455 |
+
else:
|
1456 |
+
# Run regression or discrete token-based prediction
|
1457 |
+
normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
|
1458 |
+
input_embeddings,
|
1459 |
+
all_actions_mask,
|
1460 |
+
projected_patch_embeddings,
|
1461 |
+
attention_mask,
|
1462 |
+
labels,
|
1463 |
+
NUM_PATCHES,
|
1464 |
+
NUM_PROMPT_TOKENS,
|
1465 |
+
action_head,
|
1466 |
+
)
|
1467 |
+
|
1468 |
+
# import pdb; pdb.set_trace()
|
1469 |
+
# Unnormalize predicted actions
|
1470 |
+
actions = self._unnormalize_actions(normalized_actions, unnorm_key)
|
1471 |
+
|
1472 |
+
return actions, actions_hidden_states
|
1473 |
+
|
1474 |
+
@staticmethod
|
1475 |
+
def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
|
1476 |
+
"""Validate and resolve the unnormalization key for action statistics"""
|
1477 |
+
if unnorm_key is None:
|
1478 |
+
assert len(norm_stats) == 1, (
|
1479 |
+
f"Your model was trained on more than one dataset, "
|
1480 |
+
f"please pass a `unnorm_key` from the following options to choose the statistics "
|
1481 |
+
f"used for un-normalizing actions: {norm_stats.keys()}"
|
1482 |
+
)
|
1483 |
+
unnorm_key = next(iter(norm_stats.keys()))
|
1484 |
+
|
1485 |
+
assert unnorm_key in norm_stats, (
|
1486 |
+
f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
|
1487 |
+
f"please choose from: {norm_stats.keys()}"
|
1488 |
+
)
|
1489 |
+
return unnorm_key
|
1490 |
+
|
1491 |
+
def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
|
1492 |
+
"""Get the dimensionality of the policy's action space."""
|
1493 |
+
unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
|
1494 |
+
return len(self.norm_stats[unnorm_key]["action"]["min"])
|
1495 |
+
|
1496 |
+
def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
|
1497 |
+
"""Get all the logged statistics for the given dataset."""
|
1498 |
+
unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
|
1499 |
+
return self.norm_stats[unnorm_key]["action"]
|
preprocessor_config.json
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"auto_map": {
|
3 |
+
"AutoImageProcessor": "processing_prismatic.PrismaticImageProcessor",
|
4 |
+
"AutoProcessor": "processing_prismatic.PrismaticProcessor"
|
5 |
+
},
|
6 |
+
"image_processor_type": "PrismaticImageProcessor",
|
7 |
+
"image_resize_strategy": "resize-naive",
|
8 |
+
"input_sizes": [
|
9 |
+
[
|
10 |
+
3,
|
11 |
+
224,
|
12 |
+
224
|
13 |
+
],
|
14 |
+
[
|
15 |
+
3,
|
16 |
+
224,
|
17 |
+
224
|
18 |
+
]
|
19 |
+
],
|
20 |
+
"interpolations": [
|
21 |
+
"bicubic",
|
22 |
+
"bicubic"
|
23 |
+
],
|
24 |
+
"means": [
|
25 |
+
[
|
26 |
+
0.485,
|
27 |
+
0.456,
|
28 |
+
0.406
|
29 |
+
],
|
30 |
+
[
|
31 |
+
0.5,
|
32 |
+
0.5,
|
33 |
+
0.5
|
34 |
+
]
|
35 |
+
],
|
36 |
+
"processor_class": "PrismaticProcessor",
|
37 |
+
"stds": [
|
38 |
+
[
|
39 |
+
0.229,
|
40 |
+
0.224,
|
41 |
+
0.225
|
42 |
+
],
|
43 |
+
[
|
44 |
+
0.5,
|
45 |
+
0.5,
|
46 |
+
0.5
|
47 |
+
]
|
48 |
+
],
|
49 |
+
"tvf_crop_params": [
|
50 |
+
{
|
51 |
+
"output_size": [
|
52 |
+
224,
|
53 |
+
224
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"output_size": [
|
58 |
+
224,
|
59 |
+
224
|
60 |
+
]
|
61 |
+
}
|
62 |
+
],
|
63 |
+
"tvf_do_letterbox": false,
|
64 |
+
"tvf_letterbox_fill": null,
|
65 |
+
"tvf_normalize_params": [
|
66 |
+
{
|
67 |
+
"inplace": false,
|
68 |
+
"mean": [
|
69 |
+
0.484375,
|
70 |
+
0.455078125,
|
71 |
+
0.40625
|
72 |
+
],
|
73 |
+
"std": [
|
74 |
+
0.228515625,
|
75 |
+
0.2236328125,
|
76 |
+
0.224609375
|
77 |
+
]
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"inplace": false,
|
81 |
+
"mean": [
|
82 |
+
0.5,
|
83 |
+
0.5,
|
84 |
+
0.5
|
85 |
+
],
|
86 |
+
"std": [
|
87 |
+
0.5,
|
88 |
+
0.5,
|
89 |
+
0.5
|
90 |
+
]
|
91 |
+
}
|
92 |
+
],
|
93 |
+
"tvf_resize_params": [
|
94 |
+
{
|
95 |
+
"antialias": true,
|
96 |
+
"interpolation": 3,
|
97 |
+
"max_size": null,
|
98 |
+
"size": [
|
99 |
+
224,
|
100 |
+
224
|
101 |
+
]
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"antialias": true,
|
105 |
+
"interpolation": 3,
|
106 |
+
"max_size": null,
|
107 |
+
"size": [
|
108 |
+
224,
|
109 |
+
224
|
110 |
+
]
|
111 |
+
}
|
112 |
+
],
|
113 |
+
"use_fused_vision_backbone": true
|
114 |
+
}
|
processing_prismatic.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
processing_prismatic.py
|
3 |
+
|
4 |
+
HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
|
5 |
+
specifies `siglip-224px+7b`.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from typing import Any, ClassVar, List, Optional, Tuple, Union
|
9 |
+
|
10 |
+
import timm.data
|
11 |
+
import torch
|
12 |
+
import torchvision.transforms.functional as TVF
|
13 |
+
from PIL import Image
|
14 |
+
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
|
15 |
+
from transformers import PreTrainedTokenizerBase
|
16 |
+
from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
|
17 |
+
from transformers.processing_utils import ProcessorMixin
|
18 |
+
from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
19 |
+
from transformers.utils import TensorType
|
20 |
+
|
21 |
+
|
22 |
+
# === Image Processing ===
|
23 |
+
def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
|
24 |
+
"""Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
|
25 |
+
(w, h), max_wh = image.size, max(image.size)
|
26 |
+
horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
|
27 |
+
padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
|
28 |
+
|
29 |
+
return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")
|
30 |
+
|
31 |
+
|
32 |
+
class PrismaticImageProcessor(ImageProcessingMixin):
|
33 |
+
model_input_names: ClassVar[List[str]] = ["pixel_values"]
|
34 |
+
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
use_fused_vision_backbone: bool = False,
|
38 |
+
image_resize_strategy: str = "letterbox",
|
39 |
+
input_sizes: Optional[List[Tuple[int, int, int]]] = None,
|
40 |
+
interpolations: Optional[List[str]] = None,
|
41 |
+
means: Optional[List[Tuple[float, float, float]]] = None,
|
42 |
+
stds: Optional[List[Tuple[float, float, float]]] = None,
|
43 |
+
**kwargs: str,
|
44 |
+
) -> None:
|
45 |
+
"""
|
46 |
+
Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
|
47 |
+
created by TIMM, and edited to follow our custom `image_resize_strategy` logic.
|
48 |
+
|
49 |
+
@param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
|
50 |
+
@param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
|
51 |
+
@param input_size: [TIMM :: `data_cfg`] Input image size as tuple (channels, width, height)
|
52 |
+
@param interpolation: [TIMM :: `data_cfg`] Interpolation as string (default: "bicubic")
|
53 |
+
@param mean: [TIMM :: `data_cfg`] Normalization mean as float tuple (or two-tuple if `fused_backbone`)
|
54 |
+
@param std: [TIMM :: `data_cfg`] Normalization std as float tuple (or two-tuple if `fused_backbone`)
|
55 |
+
"""
|
56 |
+
self.use_fused_vision_backbone = use_fused_vision_backbone
|
57 |
+
self.image_resize_strategy = image_resize_strategy
|
58 |
+
|
59 |
+
# Handle `None` default values
|
60 |
+
input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
|
61 |
+
means = [(0.5, 0.5, 0.5)] if means is None else means
|
62 |
+
stds = [(0.5, 0.5, 0.5)] if stds is None else stds
|
63 |
+
|
64 |
+
# TIMM `data_cfg` Parameters
|
65 |
+
self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds
|
66 |
+
|
67 |
+
# Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
|
68 |
+
self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
|
69 |
+
self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
|
70 |
+
|
71 |
+
for idx in range(len(input_sizes)):
|
72 |
+
transform = timm.data.create_transform(
|
73 |
+
input_size=self.input_sizes[idx],
|
74 |
+
interpolation=self.interpolations[idx],
|
75 |
+
mean=self.means[idx],
|
76 |
+
std=self.stds[idx],
|
77 |
+
crop_pct=1.0, # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
|
78 |
+
crop_mode="center", # Default crop mode -- no-op when `crop_pct == 1.0`
|
79 |
+
is_training=False, # No image augmentations when loading the transform!
|
80 |
+
)
|
81 |
+
|
82 |
+
# [Validation] Ensure appropriate transform structure, expected sizes
|
83 |
+
if not (
|
84 |
+
isinstance(transform, Compose)
|
85 |
+
and (len(transform.transforms) == 4)
|
86 |
+
and isinstance(transform.transforms[0], Resize)
|
87 |
+
and isinstance(transform.transforms[1], CenterCrop)
|
88 |
+
and isinstance(transform.transforms[2], ToTensor)
|
89 |
+
and isinstance(transform.transforms[3], Normalize)
|
90 |
+
and (transform.transforms[0].size == self.input_sizes[idx][-1])
|
91 |
+
and (transform.transforms[1].size == self.input_sizes[idx][-2:])
|
92 |
+
):
|
93 |
+
raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")
|
94 |
+
|
95 |
+
# HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision. as an attribute.
|
96 |
+
# => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`tvf`)
|
97 |
+
resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
|
98 |
+
self.tvf_resize_params.append(
|
99 |
+
{
|
100 |
+
"size": resize_t.size,
|
101 |
+
"interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
|
102 |
+
"max_size": None,
|
103 |
+
"antialias": True,
|
104 |
+
}
|
105 |
+
)
|
106 |
+
self.tvf_crop_params.append({"output_size": crop_t.size})
|
107 |
+
self.tvf_normalize_params.append(
|
108 |
+
{
|
109 |
+
"mean": norm_t.mean.float().numpy().tolist(),
|
110 |
+
"std": norm_t.std.float().numpy().tolist(),
|
111 |
+
"inplace": False,
|
112 |
+
}
|
113 |
+
)
|
114 |
+
self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
|
115 |
+
|
116 |
+
# Handle Prismatic `image_resize_strategy`
|
117 |
+
if self.image_resize_strategy == "resize-naive":
|
118 |
+
self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
|
119 |
+
elif self.image_resize_strategy == "letterbox":
|
120 |
+
self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
|
121 |
+
elif self.image_resize_strategy == "resize-crop":
|
122 |
+
pass
|
123 |
+
else:
|
124 |
+
raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")
|
125 |
+
|
126 |
+
# Dispatch **kwargs to super()
|
127 |
+
super().__init__(**kwargs)
|
128 |
+
|
129 |
+
def apply_transform(self, img: Image.Image) -> torch.Tensor:
|
130 |
+
"""Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])"""
|
131 |
+
if self.tvf_do_letterbox:
|
132 |
+
img = letterbox_pad_transform(img, self.tvf_letterbox_fill)
|
133 |
+
|
134 |
+
# [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
|
135 |
+
imgs_t = []
|
136 |
+
for idx in range(len(self.input_sizes)):
|
137 |
+
img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
|
138 |
+
img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
|
139 |
+
img_idx_t = TVF.to_tensor(img_idx)
|
140 |
+
img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
|
141 |
+
imgs_t.append(img_idx_t)
|
142 |
+
|
143 |
+
# [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
|
144 |
+
img_t = torch.vstack(imgs_t)
|
145 |
+
|
146 |
+
return img_t
|
147 |
+
|
148 |
+
def preprocess(
|
149 |
+
self,
|
150 |
+
images: Union[Image.Image, List[Image.Image]],
|
151 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
152 |
+
**_: str,
|
153 |
+
) -> BatchFeature:
|
154 |
+
"""
|
155 |
+
Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we
|
156 |
+
explicitly only handle PIL.Image.Image instances for simplicity.
|
157 |
+
|
158 |
+
@param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
|
159 |
+
@param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray
|
160 |
+
|
161 |
+
@return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
|
162 |
+
"""
|
163 |
+
if not isinstance(images, list):
|
164 |
+
images = [images]
|
165 |
+
|
166 |
+
# Apply `self.img_transform` to each image (will return list of torch.Tensors); stack into "batched" Tensor
|
167 |
+
pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])
|
168 |
+
|
169 |
+
# Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert
|
170 |
+
return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)
|
171 |
+
|
172 |
+
def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
|
173 |
+
return self.preprocess(images, **kwargs)
|
174 |
+
|
175 |
+
|
176 |
+
# === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
|
177 |
+
# =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
|
178 |
+
class PrismaticProcessor(ProcessorMixin):
|
179 |
+
attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
|
180 |
+
image_processor_class: str = "AutoImageProcessor"
|
181 |
+
tokenizer_class: str = "AutoTokenizer"
|
182 |
+
|
183 |
+
def __init__(
|
184 |
+
self,
|
185 |
+
image_processor: Optional[ImageProcessingMixin] = None,
|
186 |
+
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
187 |
+
) -> None:
|
188 |
+
super().__init__(image_processor, tokenizer)
|
189 |
+
|
190 |
+
def __call__(
|
191 |
+
self,
|
192 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
|
193 |
+
images: Union[Image.Image, List[Image.Image]],
|
194 |
+
padding: Union[bool, str, PaddingStrategy] = False,
|
195 |
+
truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
|
196 |
+
max_length: Optional[int] = None,
|
197 |
+
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
|
198 |
+
) -> BatchFeature:
|
199 |
+
"""
|
200 |
+
Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer,
|
201 |
+
forwards images to PrismaticImageProcessor.
|
202 |
+
|
203 |
+
@param text: The (batch) of text to encode; must be a string or list of strings.
|
204 |
+
@param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
|
205 |
+
@param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False >
|
206 |
+
@param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified
|
207 |
+
@param max_length: Maximum length (in tokens) to truncate
|
208 |
+
@param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH)
|
209 |
+
|
210 |
+
@return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`.
|
211 |
+
"""
|
212 |
+
pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
|
213 |
+
text_inputs = self.tokenizer(
|
214 |
+
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
|
215 |
+
)
|
216 |
+
|
217 |
+
# [Validate] Need same number of images and text inputs!
|
218 |
+
if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
|
219 |
+
raise ValueError("Batch is malformed; expected same number of images and text inputs!")
|
220 |
+
|
221 |
+
return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
|
222 |
+
|
223 |
+
# === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
|
224 |
+
def batch_decode(
|
225 |
+
self,
|
226 |
+
sequences: Union[List[int], List[List[int]], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
|
227 |
+
skip_special_tokens: bool = False,
|
228 |
+
clean_up_tokenization_spaces: Optional[bool] = None,
|
229 |
+
**kwargs: str,
|
230 |
+
) -> List[str]:
|
231 |
+
return self.tokenizer.batch_decode(
|
232 |
+
sequences=sequences,
|
233 |
+
skip_special_tokens=skip_special_tokens,
|
234 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
235 |
+
**kwargs,
|
236 |
+
)
|
237 |
+
|
238 |
+
def decode(
|
239 |
+
self,
|
240 |
+
token_ids: Union[int, List[int], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
|
241 |
+
skip_special_tokens: bool = False,
|
242 |
+
clean_up_tokenization_spaces: Optional[bool] = None,
|
243 |
+
**kwargs: str,
|
244 |
+
) -> str:
|
245 |
+
return self.tokenizer.decode(
|
246 |
+
token_ids=token_ids,
|
247 |
+
skip_special_tokens=skip_special_tokens,
|
248 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
249 |
+
**kwargs,
|
250 |
+
)
|
251 |
+
|
252 |
+
@property
|
253 |
+
def model_input_names(self) -> List[str]:
|
254 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
255 |
+
image_processor_input_names = self.image_processor.model_input_names
|
256 |
+
|
257 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
processor_config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"auto_map": {
|
3 |
+
"AutoProcessor": "processing_prismatic.PrismaticProcessor"
|
4 |
+
},
|
5 |
+
"processor_class": "PrismaticProcessor"
|
6 |
+
}
|
proprio_projector--20000_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8481fa9895e1500037c7f059f8722f915055a9db1105a474edcdac18a2333364
|
3 |
+
size 1626096
|
special_tokens_map.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"additional_special_tokens": [
|
3 |
+
"<|im_start|>",
|
4 |
+
"<|im_end|>",
|
5 |
+
"<|object_ref_start|>",
|
6 |
+
"<|object_ref_end|>",
|
7 |
+
"<|box_start|>",
|
8 |
+
"<|box_end|>",
|
9 |
+
"<|quad_start|>",
|
10 |
+
"<|quad_end|>",
|
11 |
+
"<|vision_start|>",
|
12 |
+
"<|vision_end|>",
|
13 |
+
"<|vision_pad|>",
|
14 |
+
"<|image_pad|>",
|
15 |
+
"<|video_pad|>"
|
16 |
+
],
|
17 |
+
"eos_token": {
|
18 |
+
"content": "<|endoftext|>",
|
19 |
+
"lstrip": false,
|
20 |
+
"normalized": false,
|
21 |
+
"rstrip": false,
|
22 |
+
"single_word": false
|
23 |
+
},
|
24 |
+
"pad_token": {
|
25 |
+
"content": "<|endoftext|>",
|
26 |
+
"lstrip": false,
|
27 |
+
"normalized": false,
|
28 |
+
"rstrip": false,
|
29 |
+
"single_word": false
|
30 |
+
}
|
31 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": false,
|
3 |
+
"add_prefix_space": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"151643": {
|
6 |
+
"content": "<|endoftext|>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"151644": {
|
14 |
+
"content": "<|im_start|>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"151645": {
|
22 |
+
"content": "<|im_end|>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": false,
|
25 |
+
"rstrip": false,
|
26 |
+
"single_word": false,
|
27 |
+
"special": true
|
28 |
+
},
|
29 |
+
"151646": {
|
30 |
+
"content": "<|object_ref_start|>",
|
31 |
+
"lstrip": false,
|
32 |
+
"normalized": false,
|
33 |
+
"rstrip": false,
|
34 |
+
"single_word": false,
|
35 |
+
"special": true
|
36 |
+
},
|
37 |
+
"151647": {
|
38 |
+
"content": "<|object_ref_end|>",
|
39 |
+
"lstrip": false,
|
40 |
+
"normalized": false,
|
41 |
+
"rstrip": false,
|
42 |
+
"single_word": false,
|
43 |
+
"special": true
|
44 |
+
},
|
45 |
+
"151648": {
|
46 |
+
"content": "<|box_start|>",
|
47 |
+
"lstrip": false,
|
48 |
+
"normalized": false,
|
49 |
+
"rstrip": false,
|
50 |
+
"single_word": false,
|
51 |
+
"special": true
|
52 |
+
},
|
53 |
+
"151649": {
|
54 |
+
"content": "<|box_end|>",
|
55 |
+
"lstrip": false,
|
56 |
+
"normalized": false,
|
57 |
+
"rstrip": false,
|
58 |
+
"single_word": false,
|
59 |
+
"special": true
|
60 |
+
},
|
61 |
+
"151650": {
|
62 |
+
"content": "<|quad_start|>",
|
63 |
+
"lstrip": false,
|
64 |
+
"normalized": false,
|
65 |
+
"rstrip": false,
|
66 |
+
"single_word": false,
|
67 |
+
"special": true
|
68 |
+
},
|
69 |
+
"151651": {
|
70 |
+
"content": "<|quad_end|>",
|
71 |
+
"lstrip": false,
|
72 |
+
"normalized": false,
|
73 |
+
"rstrip": false,
|
74 |
+
"single_word": false,
|
75 |
+
"special": true
|
76 |
+
},
|
77 |
+
"151652": {
|
78 |
+
"content": "<|vision_start|>",
|
79 |
+
"lstrip": false,
|
80 |
+
"normalized": false,
|
81 |
+
"rstrip": false,
|
82 |
+
"single_word": false,
|
83 |
+
"special": true
|
84 |
+
},
|
85 |
+
"151653": {
|
86 |
+
"content": "<|vision_end|>",
|
87 |
+
"lstrip": false,
|
88 |
+
"normalized": false,
|
89 |
+
"rstrip": false,
|
90 |
+
"single_word": false,
|
91 |
+
"special": true
|
92 |
+
},
|
93 |
+
"151654": {
|
94 |
+
"content": "<|vision_pad|>",
|
95 |
+
"lstrip": false,
|
96 |
+
"normalized": false,
|
97 |
+
"rstrip": false,
|
98 |
+
"single_word": false,
|
99 |
+
"special": true
|
100 |
+
},
|
101 |
+
"151655": {
|
102 |
+
"content": "<|image_pad|>",
|
103 |
+
"lstrip": false,
|
104 |
+
"normalized": false,
|
105 |
+
"rstrip": false,
|
106 |
+
"single_word": false,
|
107 |
+
"special": true
|
108 |
+
},
|
109 |
+
"151656": {
|
110 |
+
"content": "<|video_pad|>",
|
111 |
+
"lstrip": false,
|
112 |
+
"normalized": false,
|
113 |
+
"rstrip": false,
|
114 |
+
"single_word": false,
|
115 |
+
"special": true
|
116 |
+
},
|
117 |
+
"151657": {
|
118 |
+
"content": "<tool_call>",
|
119 |
+
"lstrip": false,
|
120 |
+
"normalized": false,
|
121 |
+
"rstrip": false,
|
122 |
+
"single_word": false,
|
123 |
+
"special": false
|
124 |
+
},
|
125 |
+
"151658": {
|
126 |
+
"content": "</tool_call>",
|
127 |
+
"lstrip": false,
|
128 |
+
"normalized": false,
|
129 |
+
"rstrip": false,
|
130 |
+
"single_word": false,
|
131 |
+
"special": false
|
132 |
+
},
|
133 |
+
"151659": {
|
134 |
+
"content": "<|fim_prefix|>",
|
135 |
+
"lstrip": false,
|
136 |
+
"normalized": false,
|
137 |
+
"rstrip": false,
|
138 |
+
"single_word": false,
|
139 |
+
"special": false
|
140 |
+
},
|
141 |
+
"151660": {
|
142 |
+
"content": "<|fim_middle|>",
|
143 |
+
"lstrip": false,
|
144 |
+
"normalized": false,
|
145 |
+
"rstrip": false,
|
146 |
+
"single_word": false,
|
147 |
+
"special": false
|
148 |
+
},
|
149 |
+
"151661": {
|
150 |
+
"content": "<|fim_suffix|>",
|
151 |
+
"lstrip": false,
|
152 |
+
"normalized": false,
|
153 |
+
"rstrip": false,
|
154 |
+
"single_word": false,
|
155 |
+
"special": false
|
156 |
+
},
|
157 |
+
"151662": {
|
158 |
+
"content": "<|fim_pad|>",
|
159 |
+
"lstrip": false,
|
160 |
+
"normalized": false,
|
161 |
+
"rstrip": false,
|
162 |
+
"single_word": false,
|
163 |
+
"special": false
|
164 |
+
},
|
165 |
+
"151663": {
|
166 |
+
"content": "<|repo_name|>",
|
167 |
+
"lstrip": false,
|
168 |
+
"normalized": false,
|
169 |
+
"rstrip": false,
|
170 |
+
"single_word": false,
|
171 |
+
"special": false
|
172 |
+
},
|
173 |
+
"151664": {
|
174 |
+
"content": "<|file_sep|>",
|
175 |
+
"lstrip": false,
|
176 |
+
"normalized": false,
|
177 |
+
"rstrip": false,
|
178 |
+
"single_word": false,
|
179 |
+
"special": false
|
180 |
+
}
|
181 |
+
},
|
182 |
+
"additional_special_tokens": [
|
183 |
+
"<|im_start|>",
|
184 |
+
"<|im_end|>",
|
185 |
+
"<|object_ref_start|>",
|
186 |
+
"<|object_ref_end|>",
|
187 |
+
"<|box_start|>",
|
188 |
+
"<|box_end|>",
|
189 |
+
"<|quad_start|>",
|
190 |
+
"<|quad_end|>",
|
191 |
+
"<|vision_start|>",
|
192 |
+
"<|vision_end|>",
|
193 |
+
"<|vision_pad|>",
|
194 |
+
"<|image_pad|>",
|
195 |
+
"<|video_pad|>"
|
196 |
+
],
|
197 |
+
"auto_map": {
|
198 |
+
"AutoProcessor": "processing_prismatic.PrismaticProcessor"
|
199 |
+
},
|
200 |
+
"bos_token": null,
|
201 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
|
202 |
+
"clean_up_tokenization_spaces": false,
|
203 |
+
"eos_token": "<|endoftext|>",
|
204 |
+
"errors": "replace",
|
205 |
+
"model_max_length": 131072,
|
206 |
+
"pad_token": "<|endoftext|>",
|
207 |
+
"processor_class": "PrismaticProcessor",
|
208 |
+
"split_special_tokens": false,
|
209 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
210 |
+
"unk_token": null
|
211 |
+
}
|
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|