VLA-Adapter committed on
Commit 109e941 · verified · 1 Parent(s): 6fc5c4b

Upload 19 files

Inference-Spatial--20000_chkpt--97.8.log ADDED
The diff for this file is too large to render.
 
README.md CHANGED
@@ -1,3 +1,217 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ tags:
+ - Vision-Language-Action
+ - OpenHelix Team
+ base_model:
+ - Qwen/Qwen2.5-0.5B
+ language:
+ - en
+ pipeline_tag: robotics
+ ---
+
+
+ # Model Card for VLA-Adapter Libero-Long
+ VLA-Adapter: An Effective Paradigm for Tiny-Scale Vision-Language-Action Model, trained on Libero-Goal.
+ - Project page: [https://vla-adapter.github.io/](https://vla-adapter.github.io/)
+ - Dataset: [https://huggingface.co/datasets/openvla/modified_libero_rlds/tree/main](https://huggingface.co/datasets/openvla/modified_libero_rlds/tree/main)
+
+ ## Model Details
+ We have developed and released the VLA-Adapter family of VLA models, a series of fine-tuned generative
+ action models. The VLA-Adapter VLM follows the Prismatic-VLM architecture and uses only a very small
+ backbone (Qwen2.5-0.5B) as the LLM. On common robotics benchmarks, it surpasses open-source VLA models
+ built on 8.5B, 7B, 4B, 3B, and 2B backbones.
+
+ **Input:** The model takes an image and a text instruction as input.
+
+ **Output:** The model generates actions only.
+
+ **Model Architecture:** VLA-Adapter consists of a VLM that receives and processes the image and text
+ information, and a policy that generates actions. We systematically analyzed the benefits the VLM
+ provides to different types of policy conditions and derived a unified framework. We then used our
+ Bridge Attention module to fuse the conditions generated by the VLM with the initial action
+ information in the policy, bridging the gap between vision-language representations and actions as far
+ as possible. The result is a high-performance VLA model built on a tiny-scale backbone.
+
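The config.json added later in this commit maps the checkpoint to OpenVLA-style remote-code classes (`configuration_prismatic.OpenVLAConfig` and `modeling_prismatic.OpenVLAForActionPrediction` via `auto_map`). The snippet below is a rough, unofficial sketch of what inference could look like under that interface; the repo id is a placeholder, and the availability of an `AutoProcessor` and an OpenVLA-style `predict_action` helper are assumptions to verify against the repository, not guarantees from this model card.

```python
# Hypothetical inference sketch (not an official example). Assumes:
#   * the repo loads through the remote-code classes named in config.json's auto_map,
#   * an AutoProcessor is shipped alongside the model,
#   * the remote code exposes an OpenVLA-style predict_action() helper.
from PIL import Image
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor

MODEL_ID = "VLA-Adapter/<this-repo>"  # placeholder: substitute the actual Hub id

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda")

image = Image.open("observation.png")  # RGB observation; config.json uses 224x224 inputs
prompt = "In: What action should the robot take to pick up the bowl?\nOut:"

inputs = processor(prompt, image).to("cuda", dtype=torch.bfloat16)
# unnorm_key selects which norm_stats entry from config.json is used to
# un-normalize the predicted 7-D action; pick the key matching your setup.
action = model.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
print(action)
```
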
+ ### Success Rate Comparison
+ <table>
+ <tr>
+ <td><strong>Category</strong></td>
+ <td><strong>Methods</strong></td>
+ <td><strong>Scale</strong></td>
+ <td><strong>LIBERO-Spatial</strong></td>
+ <td><strong>LIBERO-Object</strong></td>
+ <td><strong>LIBERO-Goal</strong></td>
+ <td><strong>LIBERO-Long</strong></td>
+ <td><strong>Avg.</strong></td>
+ </tr>
+ <tr>
+ <td rowspan="9">Large-scale</td>
+ <td>FlowVLA (Zhong et al., 2025)</td>
+ <td>8.5B</td><td>93.2</td><td>95.0</td><td>91.6</td><td>72.6</td><td>88.1</td>
+ </tr>
+
+ <tr>
+ <td>OpenVLA (Kim et al., 2024)</td>
+ <td>7B</td><td>84.7</td><td>88.4</td><td>79.2</td><td>53.7</td><td>76.5</td>
+ </tr>
+
+ <tr>
+ <td>OpenVLA-OFT (Kim et al., 2025)</td>
+ <td>7B</td><td><i><u>97.6*</u></i></td><td>98.4</td><td><b>97.9</b></td><td><b>94.5</b></td><td><b>97.1</b></td>
+ </tr>
+
+ <tr>
+ <td>UniVLA (Bu et al., 2025)</td>
+ <td>7B</td><td>96.5</td><td>96.8</td><td>95.6</td><td>92.0</td><td>95.2</td>
+ </tr>
+
+ <tr>
+ <td>CoT-VLA (Zhao et al., 2025)</td>
+ <td>7B</td><td>87.5</td><td>91.6</td><td>87.6</td><td>69.0</td><td>81.1</td>
+ </tr>
+
+ <tr>
+ <td>WorldVLA (Cen et al., 2025)</td>
+ <td>7B</td><td>87.6</td><td>96.2</td><td>83.4</td><td>60.0</td><td>81.8</td>
+ </tr>
+
+ <tr>
+ <td>TraceVLA (Zheng et al., 2025)</td>
+ <td>7B</td><td>84.6</td><td>85.2</td><td>75.1</td><td>54.1</td><td>74.8</td>
+ </tr>
+
+ <tr>
+ <td>MolmoAct (Lee et al., 2025)</td>
+ <td>7B</td><td>87.0</td><td>95.4</td><td>87.6</td><td>77.2</td><td>86.6</td>
+ </tr>
+
+ <tr>
+ <td>ThinkAct (Huang et al., 2025)</td>
+ <td>7B</td><td>88.3</td><td>91.4</td><td>87.1</td><td>70.9</td><td>84.4</td>
+ </tr>
+
+ <tr>
+ <td rowspan="7">Small-scale</td>
+ <td>4D-VLA (Zhang et al., 2025)</td>
+ <td>4B</td><td>88.9</td><td>95.2</td><td>90.9</td><td>79.1</td><td>88.6</td>
+ </tr>
+
+ <tr>
+ <td>SpatialVLA (Qu et al., 2025)</td>
+ <td>4B</td><td>88.2</td><td>89.9</td><td>78.6</td><td>55.5</td><td>78.1</td>
+ </tr>
+
+ <tr>
+ <td>π0 (Black et al., 2024)</td>
+ <td>3B</td><td>96.8</td><td><i><u>98.8*</u></i></td><td>95.8</td><td>85.2</td><td>94.2</td>
+ </tr>
+
+ <tr>
+ <td>π0-FAST (Pertsch et al., 2025)</td>
+ <td>3B</td><td>96.4</td><td>96.8</td><td>88.6</td><td>60.2</td><td>85.5</td>
+ </tr>
+
+ <tr>
+ <td>NORA (Hung et al., 2025)</td>
+ <td>3B</td><td>92.2</td><td>95.4</td><td>89.4</td><td>74.6</td><td>87.9</td>
+ </tr>
+
+ <tr>
+ <td>SmolVLA (Shukor et al., 2025)</td>
+ <td>2.2B</td><td>93.0</td><td>94.0</td><td>91.0</td><td>77.0</td><td>88.8</td>
+ </tr>
+
+ <tr>
+ <td>GR00T N1 (NVIDIA et al., 2025)</td>
+ <td>2B</td><td>94.4</td><td>97.6</td><td>93.0</td><td>90.6</td><td>93.9</td>
+ </tr>
+
+ <tr>
+ <td rowspan="4">Tiny-scale</td>
+ <td>Seer (Tian et al., 2025)</td>
+ <td>0.57B</td><td>-</td><td>-</td><td>-</td><td>78.7</td><td>78.7</td>
+ </tr>
+
+ <tr>
+ <td>VLA-OS (Gao et al., 2025)</td>
+ <td>0.5B</td><td>87.0</td><td>96.5</td><td>92.7</td><td>66.0</td><td>85.6</td>
+ </tr>
+
+ <tr>
+ <td>Diffusion Policy (Chi et al., 2023)</td>
+ <td>-</td><td>78.3</td><td>92.5</td><td>68.3</td><td>50.5</td><td>72.4</td>
+ </tr>
+
+ <tr>
+ <td><b>VLA-Adapter (Ours)</b></td>
+ <td><b>0.5B</b></td><td><b>97.8</b></td><td><b>99.2</b></td><td><i><u>97.2*</u></i></td><td><i><u>93.4*</u></i></td><td><i><u>96.9*</u></i></td>
+ </tr>
+
+ </table>
+
+ ### Effectiveness Comparison
+
+ <table>
+ <tr>
+ <td></td>
+ <td><strong>OpenVLA-OFT</strong></td>
+ <td><strong>VLA-Adapter</strong></td>
+ <td></td>
+ </tr>
+
+ <tr>
+ <td>Backbone</td>
+ <td>7B</td>
+ <td><strong>0.5B</strong></td>
+ <td>1/14×</td>
+ </tr>
+
+ <tr>
+ <td>Fine-Tuning Cost</td>
+ <td>304 GPU·h</td>
+ <td><strong>8 GPU·h</strong></td>
+ <td>1/38×</td>
+ </tr>
+
+ <tr>
+ <td>Training VRAM (batch size 8)</td>
+ <td>62 GB</td>
+ <td><strong>24.7 GB</strong></td>
+ <td>0.4×</td>
+ </tr>
+
+ <tr>
+ <td>Throughput (chunk size 8)</td>
+ <td>109.7 Hz</td>
+ <td><strong>219.2 Hz</strong></td>
+ <td>2×</td>
+ </tr>
+
+ <tr>
+ <td>Performance</td>
+ <td>97.1%</td>
+ <td><strong>96.9%</strong></td>
+ <td>Maintained</td>
+ </tr>
+ </table>
+
+ ## Citation instructions
+
+ ```BibTeX
+ @article{Wang2025VLAAdapter,
+   author  = {Wang, Yihao and Ding, Pengxiang and Li, Lingxiao and Cui, Can and Ge, Zirui and Tong, Xinyang and Song, Wenxuan and Zhao, Han and Zhao, Wei and Hou, Pengxu and Huang, Siteng and Tang, Yifan and Wang, Wenhui and Zhang, Ru and Liu, Jianyi and Wang, Donglin},
+   title   = {VLA-Adapter: An Effective Paradigm for Tiny-Scale Vision-Language-Action Model},
+   journal = {ArXiv},
+   year    = {2025}
+ }
+ ```
action_head--20000_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:587679b0388d6e559344ca67ce8d10fe167fccbfeb1945c8d3d01deda2107f6b
+ size 204387786
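The block above is a Git LFS pointer rather than the checkpoint itself: the ~204 MB action-head weights live in LFS storage and should match the recorded SHA-256 and byte size. A minimal verification sketch for a downloaded copy (it assumes the file has been fetched into the working directory):

```python
# Verify a downloaded action_head checkpoint against its Git LFS pointer fields.
import hashlib
import os

EXPECTED_OID = "587679b0388d6e559344ca67ce8d10fe167fccbfeb1945c8d3d01deda2107f6b"
EXPECTED_SIZE = 204387786
path = "action_head--20000_checkpoint.pt"

# A size mismatch usually means the file is still the small LFS pointer text.
assert os.path.getsize(path) == EXPECTED_SIZE, "size mismatch (LFS object not fetched?)"

sha256 = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        sha256.update(chunk)
assert sha256.hexdigest() == EXPECTED_OID, "sha256 mismatch"
print("action head checkpoint matches the LFS pointer")
```
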
added_tokens.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "</tool_call>": 151658,
+   "<tool_call>": 151657,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|video_pad|>": 151656,
+   "<|vision_end|>": 151653,
+   "<|vision_pad|>": 151654,
+   "<|vision_start|>": 151652
+ }
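These are the standard Qwen2.5 special/added tokens (chat, vision, FIM, and tool-call markers) with their vocabulary ids. Assuming the uploaded tokenizer files load through `transformers.AutoTokenizer` (an assumption this diff does not confirm), a quick sanity check against the JSON above could look like the sketch below; the Hub id is a placeholder.

```python
# Sanity-check sketch: the added-token ids should match added_tokens.json above.
import json
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

MODEL_ID = "VLA-Adapter/<this-repo>"  # placeholder for this repository's Hub id

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
added_path = hf_hub_download(repo_id=MODEL_ID, filename="added_tokens.json")
with open(added_path) as f:
    added = json.load(f)

for token, expected_id in added.items():
    actual_id = tokenizer.convert_tokens_to_ids(token)
    assert actual_id == expected_id, f"{token}: {actual_id} != {expected_id}"
print(f"all {len(added)} added tokens map to their expected ids")
```
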
config.json ADDED
@@ -0,0 +1,3182 @@
1
+ {
2
+ "_name_or_path": "../openvla-oft/pretrained_models/minivla/config.json",
3
+ "arch_specifier": "no-align+fused-gelu-mlp",
4
+ "architectures": [
5
+ "OpenVLAForActionPrediction"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_prismatic.OpenVLAConfig",
9
+ "AutoModelForVision2Seq": "modeling_prismatic.OpenVLAForActionPrediction"
10
+ },
11
+ "hf_llm_id": "",
12
+ "image_resize_strategy": "resize-naive",
13
+ "image_sizes": [
14
+ 224,
15
+ 224
16
+ ],
17
+ "llm_backbone_id": "qwen25-0_5b-extra",
18
+ "llm_max_length": 2048,
19
+ "model_type": "openvla",
20
+ "n_action_bins": 256,
21
+ "norm_stats": {
22
+ "austin_buds_dataset_converted_externally_to_rlds": {
23
+ "action": {
24
+ "mask": [
25
+ true,
26
+ true,
27
+ true,
28
+ true,
29
+ true,
30
+ true,
31
+ false
32
+ ],
33
+ "max": [
34
+ 1.0,
35
+ 1.0,
36
+ 1.0,
37
+ 0.0,
38
+ 0.0,
39
+ 0.0,
40
+ 1.0
41
+ ],
42
+ "mean": [
43
+ -0.07678354531526566,
44
+ 0.0036849044263362885,
45
+ 0.05644911900162697,
46
+ 0.0,
47
+ 0.0,
48
+ 0.0,
49
+ 0.3510494828224182
50
+ ],
51
+ "min": [
52
+ -1.0,
53
+ -1.0,
54
+ -1.0,
55
+ 0.0,
56
+ 0.0,
57
+ 0.0,
58
+ 0.0
59
+ ],
60
+ "q01": [
61
+ -1.0,
62
+ -0.9599999785423279,
63
+ -0.8714285492897034,
64
+ 0.0,
65
+ 0.0,
66
+ 0.0,
67
+ 0.0
68
+ ],
69
+ "q99": [
70
+ 1.0,
71
+ 0.8600000143051147,
72
+ 1.0,
73
+ 0.0,
74
+ 0.0,
75
+ 0.0,
76
+ 1.0
77
+ ],
78
+ "std": [
79
+ 0.6367740631103516,
80
+ 0.37889179587364197,
81
+ 0.47796326875686646,
82
+ 0.0,
83
+ 0.0,
84
+ 0.0,
85
+ 0.47721168398857117
86
+ ]
87
+ },
88
+ "num_trajectories": 50,
89
+ "num_transitions": 34112,
90
+ "proprio": {
91
+ "max": [
92
+ 0.0,
93
+ 0.0,
94
+ 0.0,
95
+ 0.0,
96
+ 0.0,
97
+ 0.0,
98
+ 0.0
99
+ ],
100
+ "mean": [
101
+ 0.0,
102
+ 0.0,
103
+ 0.0,
104
+ 0.0,
105
+ 0.0,
106
+ 0.0,
107
+ 0.0
108
+ ],
109
+ "min": [
110
+ 0.0,
111
+ 0.0,
112
+ 0.0,
113
+ 0.0,
114
+ 0.0,
115
+ 0.0,
116
+ 0.0
117
+ ],
118
+ "q01": [
119
+ 0.0,
120
+ 0.0,
121
+ 0.0,
122
+ 0.0,
123
+ 0.0,
124
+ 0.0,
125
+ 0.0
126
+ ],
127
+ "q99": [
128
+ 0.0,
129
+ 0.0,
130
+ 0.0,
131
+ 0.0,
132
+ 0.0,
133
+ 0.0,
134
+ 0.0
135
+ ],
136
+ "std": [
137
+ 0.0,
138
+ 0.0,
139
+ 0.0,
140
+ 0.0,
141
+ 0.0,
142
+ 0.0,
143
+ 0.0
144
+ ]
145
+ }
146
+ },
147
+ "austin_sailor_dataset_converted_externally_to_rlds": {
148
+ "action": {
149
+ "mask": [
150
+ true,
151
+ true,
152
+ true,
153
+ true,
154
+ true,
155
+ true,
156
+ false
157
+ ],
158
+ "max": [
159
+ 1.0,
160
+ 1.0,
161
+ 1.0,
162
+ 0.0,
163
+ 0.0,
164
+ 0.375,
165
+ 1.0
166
+ ],
167
+ "mean": [
168
+ 0.011825348250567913,
169
+ 0.006461074110120535,
170
+ 0.06023626774549484,
171
+ 0.0,
172
+ 0.0,
173
+ 0.0016465914668515325,
174
+ 0.5260950326919556
175
+ ],
176
+ "min": [
177
+ -1.0,
178
+ -1.0,
179
+ -1.0,
180
+ 0.0,
181
+ 0.0,
182
+ -0.375,
183
+ 0.0
184
+ ],
185
+ "q01": [
186
+ -1.0,
187
+ -0.9828571677207947,
188
+ -0.6000000238418579,
189
+ 0.0,
190
+ 0.0,
191
+ -0.17249999940395355,
192
+ 0.0
193
+ ],
194
+ "q99": [
195
+ 1.0,
196
+ 0.9457142949104309,
197
+ 1.0,
198
+ 0.0,
199
+ 0.0,
200
+ 0.17892856895923615,
201
+ 1.0
202
+ ],
203
+ "std": [
204
+ 0.46348899602890015,
205
+ 0.41240179538726807,
206
+ 0.411862850189209,
207
+ 0.0,
208
+ 0.0,
209
+ 0.0578610822558403,
210
+ 0.49894046783447266
211
+ ]
212
+ },
213
+ "num_trajectories": 240,
214
+ "num_transitions": 353094,
215
+ "proprio": {
216
+ "max": [
217
+ 0.0,
218
+ 0.0,
219
+ 0.0,
220
+ 0.0,
221
+ 0.0,
222
+ 0.0,
223
+ 0.0
224
+ ],
225
+ "mean": [
226
+ 0.0,
227
+ 0.0,
228
+ 0.0,
229
+ 0.0,
230
+ 0.0,
231
+ 0.0,
232
+ 0.0
233
+ ],
234
+ "min": [
235
+ 0.0,
236
+ 0.0,
237
+ 0.0,
238
+ 0.0,
239
+ 0.0,
240
+ 0.0,
241
+ 0.0
242
+ ],
243
+ "q01": [
244
+ 0.0,
245
+ 0.0,
246
+ 0.0,
247
+ 0.0,
248
+ 0.0,
249
+ 0.0,
250
+ 0.0
251
+ ],
252
+ "q99": [
253
+ 0.0,
254
+ 0.0,
255
+ 0.0,
256
+ 0.0,
257
+ 0.0,
258
+ 0.0,
259
+ 0.0
260
+ ],
261
+ "std": [
262
+ 0.0,
263
+ 0.0,
264
+ 0.0,
265
+ 0.0,
266
+ 0.0,
267
+ 0.0,
268
+ 0.0
269
+ ]
270
+ }
271
+ },
272
+ "austin_sirius_dataset_converted_externally_to_rlds": {
273
+ "action": {
274
+ "mask": [
275
+ true,
276
+ true,
277
+ true,
278
+ true,
279
+ true,
280
+ true,
281
+ false
282
+ ],
283
+ "max": [
284
+ 1.0002285242080688,
285
+ 0.960608720779419,
286
+ 1.105179786682129,
287
+ 0.0,
288
+ 0.0,
289
+ 0.341785728931427,
290
+ 1.0
291
+ ],
292
+ "mean": [
293
+ 0.07747682929039001,
294
+ 0.03195561468601227,
295
+ 0.04244732856750488,
296
+ 0.0,
297
+ 0.0,
298
+ -0.01603456400334835,
299
+ 0.43260177969932556
300
+ ],
301
+ "min": [
302
+ -1.0183025598526,
303
+ -0.9800000190734863,
304
+ -0.9774575233459473,
305
+ 0.0,
306
+ 0.0,
307
+ -0.34607142210006714,
308
+ 0.0
309
+ ],
310
+ "q01": [
311
+ -0.780905865430832,
312
+ -0.5667179036140442,
313
+ -0.5254343223571777,
314
+ 0.0,
315
+ 0.0,
316
+ -0.28495091378688814,
317
+ 0.0
318
+ ],
319
+ "q99": [
320
+ 0.9569637751579284,
321
+ 0.6971374487876891,
322
+ 0.8124888157844541,
323
+ 0.0,
324
+ 0.0,
325
+ 0.1971428543329239,
326
+ 1.0
327
+ ],
328
+ "std": [
329
+ 0.3906329572200775,
330
+ 0.2998155355453491,
331
+ 0.2782271206378937,
332
+ 0.0,
333
+ 0.0,
334
+ 0.08120622485876083,
335
+ 0.49528297781944275
336
+ ]
337
+ },
338
+ "num_trajectories": 559,
339
+ "num_transitions": 279939,
340
+ "proprio": {
341
+ "max": [
342
+ 0.0,
343
+ 0.0,
344
+ 0.0,
345
+ 0.0,
346
+ 0.0,
347
+ 0.0,
348
+ 0.0
349
+ ],
350
+ "mean": [
351
+ 0.0,
352
+ 0.0,
353
+ 0.0,
354
+ 0.0,
355
+ 0.0,
356
+ 0.0,
357
+ 0.0
358
+ ],
359
+ "min": [
360
+ 0.0,
361
+ 0.0,
362
+ 0.0,
363
+ 0.0,
364
+ 0.0,
365
+ 0.0,
366
+ 0.0
367
+ ],
368
+ "q01": [
369
+ 0.0,
370
+ 0.0,
371
+ 0.0,
372
+ 0.0,
373
+ 0.0,
374
+ 0.0,
375
+ 0.0
376
+ ],
377
+ "q99": [
378
+ 0.0,
379
+ 0.0,
380
+ 0.0,
381
+ 0.0,
382
+ 0.0,
383
+ 0.0,
384
+ 0.0
385
+ ],
386
+ "std": [
387
+ 0.0,
388
+ 0.0,
389
+ 0.0,
390
+ 0.0,
391
+ 0.0,
392
+ 0.0,
393
+ 0.0
394
+ ]
395
+ }
396
+ },
397
+ "bc_z": {
398
+ "action": {
399
+ "mask": [
400
+ true,
401
+ true,
402
+ true,
403
+ true,
404
+ true,
405
+ true,
406
+ false
407
+ ],
408
+ "max": [
409
+ 0.2165454924106598,
410
+ 0.1251407265663147,
411
+ 0.10772687941789627,
412
+ 0.33544227480888367,
413
+ 0.28117990493774414,
414
+ 0.40614867210388184,
415
+ 1.0
416
+ ],
417
+ "mean": [
418
+ -0.009958467446267605,
419
+ 0.0008958321413956583,
420
+ 0.004995597992092371,
421
+ 0.00029755113064311445,
422
+ -0.008735382929444313,
423
+ -0.030693737789988518,
424
+ 0.8344562649726868
425
+ ],
426
+ "min": [
427
+ -0.1677047461271286,
428
+ -0.14630407094955444,
429
+ -0.10066790133714676,
430
+ -0.29421567916870117,
431
+ -0.32101404666900635,
432
+ -0.4635624885559082,
433
+ 0.0
434
+ ],
435
+ "q01": [
436
+ -0.09220654994249344,
437
+ -0.06456145539879798,
438
+ -0.049121275544166565,
439
+ -0.11594625547528267,
440
+ -0.14152548640966414,
441
+ -0.2251061636209488,
442
+ 0.0
443
+ ],
444
+ "q99": [
445
+ 0.07628866866230968,
446
+ 0.058019736707210584,
447
+ 0.052540797740221024,
448
+ 0.11740604028105736,
449
+ 0.11703975558280955,
450
+ 0.16729306846857078,
451
+ 1.0
452
+ ],
453
+ "std": [
454
+ 0.03053455986082554,
455
+ 0.0231423731893301,
456
+ 0.020641816779971123,
457
+ 0.04155943542718887,
458
+ 0.046427831053733826,
459
+ 0.0769818127155304,
460
+ 0.3610210120677948
461
+ ]
462
+ },
463
+ "num_trajectories": 43264,
464
+ "num_transitions": 6015535,
465
+ "proprio": {
466
+ "max": [
467
+ 0.0,
468
+ 0.0,
469
+ 0.0,
470
+ 0.0,
471
+ 0.0,
472
+ 0.0,
473
+ 0.0
474
+ ],
475
+ "mean": [
476
+ 0.0,
477
+ 0.0,
478
+ 0.0,
479
+ 0.0,
480
+ 0.0,
481
+ 0.0,
482
+ 0.0
483
+ ],
484
+ "min": [
485
+ 0.0,
486
+ 0.0,
487
+ 0.0,
488
+ 0.0,
489
+ 0.0,
490
+ 0.0,
491
+ 0.0
492
+ ],
493
+ "q01": [
494
+ 0.0,
495
+ 0.0,
496
+ 0.0,
497
+ 0.0,
498
+ 0.0,
499
+ 0.0,
500
+ 0.0
501
+ ],
502
+ "q99": [
503
+ 0.0,
504
+ 0.0,
505
+ 0.0,
506
+ 0.0,
507
+ 0.0,
508
+ 0.0,
509
+ 0.0
510
+ ],
511
+ "std": [
512
+ 0.0,
513
+ 0.0,
514
+ 0.0,
515
+ 0.0,
516
+ 0.0,
517
+ 0.0,
518
+ 0.0
519
+ ]
520
+ }
521
+ },
522
+ "berkeley_autolab_ur5": {
523
+ "action": {
524
+ "mask": [
525
+ true,
526
+ true,
527
+ true,
528
+ true,
529
+ true,
530
+ true,
531
+ false
532
+ ],
533
+ "max": [
534
+ 0.019999999552965164,
535
+ 0.019999999552965164,
536
+ 0.019999999552965164,
537
+ 0.06666667014360428,
538
+ 0.06666667014360428,
539
+ 0.06666667014360428,
540
+ 1.0
541
+ ],
542
+ "mean": [
543
+ 0.0005683620693162084,
544
+ 0.001217700308188796,
545
+ -0.0005296372692100704,
546
+ 0.00021029810886830091,
547
+ 6.0695128922816366e-05,
548
+ 0.001204986940138042,
549
+ 0.6298308372497559
550
+ ],
551
+ "min": [
552
+ -0.019999999552965164,
553
+ -0.019999999552965164,
554
+ -0.019999999552965164,
555
+ -0.06666667014360428,
556
+ -0.06666667014360428,
557
+ -0.06666667014360428,
558
+ 0.0
559
+ ],
560
+ "q01": [
561
+ -0.019999999552965164,
562
+ -0.019999999552965164,
563
+ -0.019999999552965164,
564
+ -0.02628571353852749,
565
+ -0.06666667014360428,
566
+ -0.03847619146108627,
567
+ 0.0
568
+ ],
569
+ "q99": [
570
+ 0.019999999552965164,
571
+ 0.019999999552965164,
572
+ 0.019999999552965164,
573
+ 0.031809523701667786,
574
+ 0.06666667014360428,
575
+ 0.036571428179740906,
576
+ 1.0
577
+ ],
578
+ "std": [
579
+ 0.0115329809486866,
580
+ 0.007990492507815361,
581
+ 0.009577835910022259,
582
+ 0.009432995691895485,
583
+ 0.016427582129836082,
584
+ 0.011053967289626598,
585
+ 0.48267969489097595
586
+ ]
587
+ },
588
+ "num_trajectories": 1000,
589
+ "num_transitions": 97939,
590
+ "proprio": {
591
+ "max": [
592
+ 0.0,
593
+ 0.0,
594
+ 0.0,
595
+ 0.0,
596
+ 0.0,
597
+ 0.0,
598
+ 0.0
599
+ ],
600
+ "mean": [
601
+ 0.0,
602
+ 0.0,
603
+ 0.0,
604
+ 0.0,
605
+ 0.0,
606
+ 0.0,
607
+ 0.0
608
+ ],
609
+ "min": [
610
+ 0.0,
611
+ 0.0,
612
+ 0.0,
613
+ 0.0,
614
+ 0.0,
615
+ 0.0,
616
+ 0.0
617
+ ],
618
+ "q01": [
619
+ 0.0,
620
+ 0.0,
621
+ 0.0,
622
+ 0.0,
623
+ 0.0,
624
+ 0.0,
625
+ 0.0
626
+ ],
627
+ "q99": [
628
+ 0.0,
629
+ 0.0,
630
+ 0.0,
631
+ 0.0,
632
+ 0.0,
633
+ 0.0,
634
+ 0.0
635
+ ],
636
+ "std": [
637
+ 0.0,
638
+ 0.0,
639
+ 0.0,
640
+ 0.0,
641
+ 0.0,
642
+ 0.0,
643
+ 0.0
644
+ ]
645
+ }
646
+ },
647
+ "berkeley_cable_routing": {
648
+ "action": {
649
+ "mask": [
650
+ true,
651
+ true,
652
+ true,
653
+ true,
654
+ true,
655
+ true,
656
+ false
657
+ ],
658
+ "max": [
659
+ 0.9633283019065857,
660
+ 1.0,
661
+ 1.0,
662
+ 0.0,
663
+ 0.0,
664
+ 1.0,
665
+ 0.0
666
+ ],
667
+ "mean": [
668
+ -0.07139874249696732,
669
+ 0.023609008640050888,
670
+ 0.10241943597793579,
671
+ 0.0,
672
+ 0.0,
673
+ 0.049671024084091187,
674
+ 0.0
675
+ ],
676
+ "min": [
677
+ -0.9809081554412842,
678
+ -0.9554349184036255,
679
+ -0.9994775056838989,
680
+ 0.0,
681
+ 0.0,
682
+ -1.0,
683
+ 0.0
684
+ ],
685
+ "q01": [
686
+ -0.5534318816661835,
687
+ -0.4797285574674606,
688
+ -0.5314934802055359,
689
+ 0.0,
690
+ 0.0,
691
+ -0.8855219376087189,
692
+ 0.0
693
+ ],
694
+ "q99": [
695
+ 0.42652835428714786,
696
+ 0.5000944086909298,
697
+ 0.639823433756829,
698
+ 0.0,
699
+ 0.0,
700
+ 0.984243879914284,
701
+ 0.0
702
+ ],
703
+ "std": [
704
+ 0.1815500408411026,
705
+ 0.1810990273952484,
706
+ 0.21220779418945312,
707
+ 0.0,
708
+ 0.0,
709
+ 0.3475511968135834,
710
+ 0.0
711
+ ]
712
+ },
713
+ "num_trajectories": 1647,
714
+ "num_transitions": 42328,
715
+ "proprio": {
716
+ "max": [
717
+ 0.0,
718
+ 0.0,
719
+ 0.0,
720
+ 0.0,
721
+ 0.0,
722
+ 0.0,
723
+ 0.0
724
+ ],
725
+ "mean": [
726
+ 0.0,
727
+ 0.0,
728
+ 0.0,
729
+ 0.0,
730
+ 0.0,
731
+ 0.0,
732
+ 0.0
733
+ ],
734
+ "min": [
735
+ 0.0,
736
+ 0.0,
737
+ 0.0,
738
+ 0.0,
739
+ 0.0,
740
+ 0.0,
741
+ 0.0
742
+ ],
743
+ "q01": [
744
+ 0.0,
745
+ 0.0,
746
+ 0.0,
747
+ 0.0,
748
+ 0.0,
749
+ 0.0,
750
+ 0.0
751
+ ],
752
+ "q99": [
753
+ 0.0,
754
+ 0.0,
755
+ 0.0,
756
+ 0.0,
757
+ 0.0,
758
+ 0.0,
759
+ 0.0
760
+ ],
761
+ "std": [
762
+ 0.0,
763
+ 0.0,
764
+ 0.0,
765
+ 0.0,
766
+ 0.0,
767
+ 0.0,
768
+ 0.0
769
+ ]
770
+ }
771
+ },
772
+ "berkeley_fanuc_manipulation": {
773
+ "action": {
774
+ "mask": [
775
+ true,
776
+ true,
777
+ true,
778
+ true,
779
+ true,
780
+ true,
781
+ false
782
+ ],
783
+ "max": [
784
+ 0.009999999776482582,
785
+ 0.009999999776482582,
786
+ 0.009999999776482582,
787
+ 0.03490658476948738,
788
+ 0.03490658476948738,
789
+ 0.03490658476948738,
790
+ 1.0
791
+ ],
792
+ "mean": [
793
+ 0.0007744057802483439,
794
+ -0.00031240080716088414,
795
+ -0.0015001941937953234,
796
+ -0.0007515158504247665,
797
+ -0.00015832878125365824,
798
+ 0.00014327642566058785,
799
+ 0.699295699596405
800
+ ],
801
+ "min": [
802
+ -0.009999999776482582,
803
+ -0.009999999776482582,
804
+ -0.009999999776482582,
805
+ -0.03490658476948738,
806
+ -0.03490658476948738,
807
+ -0.03490658476948738,
808
+ 0.0
809
+ ],
810
+ "q01": [
811
+ -0.009999999776482582,
812
+ -0.009999999776482582,
813
+ -0.009999999776482582,
814
+ -0.03490658476948738,
815
+ 0.0,
816
+ -0.03490658476948738,
817
+ 0.0
818
+ ],
819
+ "q99": [
820
+ 0.009999999776482582,
821
+ 0.009999999776482582,
822
+ 0.009999999776482582,
823
+ 0.03490658476948738,
824
+ 0.0,
825
+ 0.03490658476948738,
826
+ 1.0
827
+ ],
828
+ "std": [
829
+ 0.0034070091787725687,
830
+ 0.0049921851605176926,
831
+ 0.005344334989786148,
832
+ 0.00759894959628582,
833
+ 0.004081866703927517,
834
+ 0.008568956516683102,
835
+ 0.4586937427520752
836
+ ]
837
+ },
838
+ "num_trajectories": 415,
839
+ "num_transitions": 62613,
840
+ "proprio": {
841
+ "max": [
842
+ 0.0,
843
+ 0.0,
844
+ 0.0,
845
+ 0.0,
846
+ 0.0,
847
+ 0.0,
848
+ 0.0
849
+ ],
850
+ "mean": [
851
+ 0.0,
852
+ 0.0,
853
+ 0.0,
854
+ 0.0,
855
+ 0.0,
856
+ 0.0,
857
+ 0.0
858
+ ],
859
+ "min": [
860
+ 0.0,
861
+ 0.0,
862
+ 0.0,
863
+ 0.0,
864
+ 0.0,
865
+ 0.0,
866
+ 0.0
867
+ ],
868
+ "q01": [
869
+ 0.0,
870
+ 0.0,
871
+ 0.0,
872
+ 0.0,
873
+ 0.0,
874
+ 0.0,
875
+ 0.0
876
+ ],
877
+ "q99": [
878
+ 0.0,
879
+ 0.0,
880
+ 0.0,
881
+ 0.0,
882
+ 0.0,
883
+ 0.0,
884
+ 0.0
885
+ ],
886
+ "std": [
887
+ 0.0,
888
+ 0.0,
889
+ 0.0,
890
+ 0.0,
891
+ 0.0,
892
+ 0.0,
893
+ 0.0
894
+ ]
895
+ }
896
+ },
897
+ "bridge_orig": {
898
+ "action": {
899
+ "mask": [
900
+ true,
901
+ true,
902
+ true,
903
+ true,
904
+ true,
905
+ true,
906
+ false
907
+ ],
908
+ "max": [
909
+ 0.41691166162490845,
910
+ 0.25864794850349426,
911
+ 0.21218234300613403,
912
+ 3.122201919555664,
913
+ 1.8618112802505493,
914
+ 6.280478477478027,
915
+ 1.0
916
+ ],
917
+ "mean": [
918
+ 0.0002334194869035855,
919
+ 0.00013004911306779832,
920
+ -0.00012762474943883717,
921
+ -0.0001556558854645118,
922
+ -0.0004039328487124294,
923
+ 0.00023557482927571982,
924
+ 0.5764579176902771
925
+ ],
926
+ "min": [
927
+ -0.4007510244846344,
928
+ -0.13874775171279907,
929
+ -0.22553899884223938,
930
+ -3.2010786533355713,
931
+ -1.8618112802505493,
932
+ -6.279075622558594,
933
+ 0.0
934
+ ],
935
+ "q01": [
936
+ -0.02872725307941437,
937
+ -0.04170349963009357,
938
+ -0.026093858778476715,
939
+ -0.08092105075716972,
940
+ -0.09288699507713317,
941
+ -0.20718276381492615,
942
+ 0.0
943
+ ],
944
+ "q99": [
945
+ 0.028309678435325586,
946
+ 0.040855254605412394,
947
+ 0.040161586627364146,
948
+ 0.08192047759890528,
949
+ 0.07792850524187081,
950
+ 0.20382574498653397,
951
+ 1.0
952
+ ],
953
+ "std": [
954
+ 0.009765930473804474,
955
+ 0.013689135201275349,
956
+ 0.012667362578213215,
957
+ 0.028534092009067535,
958
+ 0.030637972056865692,
959
+ 0.07691419124603271,
960
+ 0.4973701536655426
961
+ ]
962
+ },
963
+ "num_trajectories": 60064,
964
+ "num_transitions": 2135463,
965
+ "proprio": {
966
+ "max": [
967
+ 0.0,
968
+ 0.0,
969
+ 0.0,
970
+ 0.0,
971
+ 0.0,
972
+ 0.0,
973
+ 0.0
974
+ ],
975
+ "mean": [
976
+ 0.0,
977
+ 0.0,
978
+ 0.0,
979
+ 0.0,
980
+ 0.0,
981
+ 0.0,
982
+ 0.0
983
+ ],
984
+ "min": [
985
+ 0.0,
986
+ 0.0,
987
+ 0.0,
988
+ 0.0,
989
+ 0.0,
990
+ 0.0,
991
+ 0.0
992
+ ],
993
+ "q01": [
994
+ 0.0,
995
+ 0.0,
996
+ 0.0,
997
+ 0.0,
998
+ 0.0,
999
+ 0.0,
1000
+ 0.0
1001
+ ],
1002
+ "q99": [
1003
+ 0.0,
1004
+ 0.0,
1005
+ 0.0,
1006
+ 0.0,
1007
+ 0.0,
1008
+ 0.0,
1009
+ 0.0
1010
+ ],
1011
+ "std": [
1012
+ 0.0,
1013
+ 0.0,
1014
+ 0.0,
1015
+ 0.0,
1016
+ 0.0,
1017
+ 0.0,
1018
+ 0.0
1019
+ ]
1020
+ }
1021
+ },
1022
+ "cmu_stretch": {
1023
+ "action": {
1024
+ "mask": [
1025
+ true,
1026
+ true,
1027
+ true,
1028
+ true,
1029
+ true,
1030
+ true,
1031
+ false
1032
+ ],
1033
+ "max": [
1034
+ 0.02338407188653946,
1035
+ 0.0,
1036
+ 0.023404927924275398,
1037
+ 0.0,
1038
+ 0.0,
1039
+ 0.0,
1040
+ 1.0
1041
+ ],
1042
+ "mean": [
1043
+ 0.00036304505192674696,
1044
+ 0.0,
1045
+ 0.0016466958913952112,
1046
+ 0.0,
1047
+ 0.0,
1048
+ 0.0,
1049
+ 0.3987048268318176
1050
+ ],
1051
+ "min": [
1052
+ -0.019353797659277916,
1053
+ 0.0,
1054
+ -0.02019215188920498,
1055
+ 0.0,
1056
+ 0.0,
1057
+ 0.0,
1058
+ 0.0
1059
+ ],
1060
+ "q01": [
1061
+ -0.011175686959177256,
1062
+ 0.0,
1063
+ -0.0032206363626755773,
1064
+ 0.0,
1065
+ 0.0,
1066
+ 0.0,
1067
+ 0.0
1068
+ ],
1069
+ "q99": [
1070
+ 0.014501785952597848,
1071
+ 0.0,
1072
+ 0.015056106168776728,
1073
+ 0.0,
1074
+ 0.0,
1075
+ 0.0,
1076
+ 1.0
1077
+ ],
1078
+ "std": [
1079
+ 0.004081828519701958,
1080
+ 0.0,
1081
+ 0.0037743328139185905,
1082
+ 0.0,
1083
+ 0.0,
1084
+ 0.0,
1085
+ 0.48963725566864014
1086
+ ]
1087
+ },
1088
+ "num_trajectories": 135,
1089
+ "num_transitions": 25016,
1090
+ "proprio": {
1091
+ "max": [
1092
+ 0.0,
1093
+ 0.0,
1094
+ 0.0,
1095
+ 0.0,
1096
+ 0.0,
1097
+ 0.0,
1098
+ 0.0
1099
+ ],
1100
+ "mean": [
1101
+ 0.0,
1102
+ 0.0,
1103
+ 0.0,
1104
+ 0.0,
1105
+ 0.0,
1106
+ 0.0,
1107
+ 0.0
1108
+ ],
1109
+ "min": [
1110
+ 0.0,
1111
+ 0.0,
1112
+ 0.0,
1113
+ 0.0,
1114
+ 0.0,
1115
+ 0.0,
1116
+ 0.0
1117
+ ],
1118
+ "q01": [
1119
+ 0.0,
1120
+ 0.0,
1121
+ 0.0,
1122
+ 0.0,
1123
+ 0.0,
1124
+ 0.0,
1125
+ 0.0
1126
+ ],
1127
+ "q99": [
1128
+ 0.0,
1129
+ 0.0,
1130
+ 0.0,
1131
+ 0.0,
1132
+ 0.0,
1133
+ 0.0,
1134
+ 0.0
1135
+ ],
1136
+ "std": [
1137
+ 0.0,
1138
+ 0.0,
1139
+ 0.0,
1140
+ 0.0,
1141
+ 0.0,
1142
+ 0.0,
1143
+ 0.0
1144
+ ]
1145
+ }
1146
+ },
1147
+ "dlr_edan_shared_control_converted_externally_to_rlds": {
1148
+ "action": {
1149
+ "mask": [
1150
+ true,
1151
+ true,
1152
+ true,
1153
+ true,
1154
+ true,
1155
+ true,
1156
+ false
1157
+ ],
1158
+ "max": [
1159
+ 0.18991442024707794,
1160
+ 0.0739002525806427,
1161
+ 0.18064819276332855,
1162
+ 0.0866486132144928,
1163
+ 0.13464981317520142,
1164
+ 0.16910280287265778,
1165
+ 1.0
1166
+ ],
1167
+ "mean": [
1168
+ 0.006647810339927673,
1169
+ -0.0007657372043468058,
1170
+ 0.006522852927446365,
1171
+ 0.0011679717572405934,
1172
+ -0.006395625416189432,
1173
+ -0.011902998201549053,
1174
+ 0.6985887289047241
1175
+ ],
1176
+ "min": [
1177
+ -0.10054297000169754,
1178
+ -0.08427435159683228,
1179
+ -0.13533438742160797,
1180
+ -0.17556548118591309,
1181
+ -0.18485672771930695,
1182
+ -0.2680685818195343,
1183
+ 0.0
1184
+ ],
1185
+ "q01": [
1186
+ -0.02987122368067503,
1187
+ -0.06013262912631035,
1188
+ -0.08286409199237824,
1189
+ -0.05924444157630205,
1190
+ -0.15986866518855095,
1191
+ -0.15636983573436739,
1192
+ 0.0
1193
+ ],
1194
+ "q99": [
1195
+ 0.08832092039287087,
1196
+ 0.042126184627413736,
1197
+ 0.11311905644834042,
1198
+ 0.0643695573508739,
1199
+ 0.03941855944693088,
1200
+ 0.156646853685379,
1201
+ 1.0
1202
+ ],
1203
+ "std": [
1204
+ 0.021393608301877975,
1205
+ 0.01814231649041176,
1206
+ 0.03374375030398369,
1207
+ 0.01743541844189167,
1208
+ 0.03394376486539841,
1209
+ 0.04641875624656677,
1210
+ 0.4588589072227478
1211
+ ]
1212
+ },
1213
+ "num_trajectories": 104,
1214
+ "num_transitions": 8928,
1215
+ "proprio": {
1216
+ "max": [
1217
+ 0.0,
1218
+ 0.0,
1219
+ 0.0,
1220
+ 0.0,
1221
+ 0.0,
1222
+ 0.0,
1223
+ 0.0
1224
+ ],
1225
+ "mean": [
1226
+ 0.0,
1227
+ 0.0,
1228
+ 0.0,
1229
+ 0.0,
1230
+ 0.0,
1231
+ 0.0,
1232
+ 0.0
1233
+ ],
1234
+ "min": [
1235
+ 0.0,
1236
+ 0.0,
1237
+ 0.0,
1238
+ 0.0,
1239
+ 0.0,
1240
+ 0.0,
1241
+ 0.0
1242
+ ],
1243
+ "q01": [
1244
+ 0.0,
1245
+ 0.0,
1246
+ 0.0,
1247
+ 0.0,
1248
+ 0.0,
1249
+ 0.0,
1250
+ 0.0
1251
+ ],
1252
+ "q99": [
1253
+ 0.0,
1254
+ 0.0,
1255
+ 0.0,
1256
+ 0.0,
1257
+ 0.0,
1258
+ 0.0,
1259
+ 0.0
1260
+ ],
1261
+ "std": [
1262
+ 0.0,
1263
+ 0.0,
1264
+ 0.0,
1265
+ 0.0,
1266
+ 0.0,
1267
+ 0.0,
1268
+ 0.0
1269
+ ]
1270
+ }
1271
+ },
1272
+ "dobbe": {
1273
+ "action": {
1274
+ "mask": [
1275
+ true,
1276
+ true,
1277
+ true,
1278
+ true,
1279
+ true,
1280
+ true,
1281
+ false
1282
+ ],
1283
+ "max": [
1284
+ 38.590423583984375,
1285
+ 17.932697296142578,
1286
+ 4.843764305114746,
1287
+ 1.4372116327285767,
1288
+ 0.4340403974056244,
1289
+ 1.2057193517684937,
1290
+ 0.9998947381973267
1291
+ ],
1292
+ "mean": [
1293
+ -0.0001120665911003016,
1294
+ 0.0011229600058868527,
1295
+ -0.00010194431524723768,
1296
+ -7.371398532995954e-05,
1297
+ -0.00067531579406932,
1298
+ -5.6643435527803376e-05,
1299
+ 0.6318281888961792
1300
+ ],
1301
+ "min": [
1302
+ -5.700923442840576,
1303
+ -21.605947494506836,
1304
+ -123.72489929199219,
1305
+ -1.7229845523834229,
1306
+ -0.4998578727245331,
1307
+ -0.8867913484573364,
1308
+ 1.4196479014572105e-06
1309
+ ],
1310
+ "q01": [
1311
+ -0.01119564864784479,
1312
+ -0.014266146533191203,
1313
+ -0.0071747214533388615,
1314
+ -0.009444301575422287,
1315
+ -0.03990109823644161,
1316
+ -0.017422311007976532,
1317
+ 4.003279136668425e-05
1318
+ ],
1319
+ "q99": [
1320
+ 0.01015154086053368,
1321
+ 0.017181577533483497,
1322
+ 0.007216989761218411,
1323
+ 0.010380979906767595,
1324
+ 0.03556173853576176,
1325
+ 0.018032474815845446,
1326
+ 0.9982578039169312
1327
+ ],
1328
+ "std": [
1329
+ 0.04264938458800316,
1330
+ 0.04428559169173241,
1331
+ 0.12224084138870239,
1332
+ 0.005388413090258837,
1333
+ 0.011246449314057827,
1334
+ 0.006287882570177317,
1335
+ 0.39732322096824646
1336
+ ]
1337
+ },
1338
+ "num_trajectories": 5208,
1339
+ "num_transitions": 1139911,
1340
+ "proprio": {
1341
+ "max": [
1342
+ 0.0,
1343
+ 0.0,
1344
+ 0.0,
1345
+ 0.0,
1346
+ 0.0,
1347
+ 0.0,
1348
+ 0.0
1349
+ ],
1350
+ "mean": [
1351
+ 0.0,
1352
+ 0.0,
1353
+ 0.0,
1354
+ 0.0,
1355
+ 0.0,
1356
+ 0.0,
1357
+ 0.0
1358
+ ],
1359
+ "min": [
1360
+ 0.0,
1361
+ 0.0,
1362
+ 0.0,
1363
+ 0.0,
1364
+ 0.0,
1365
+ 0.0,
1366
+ 0.0
1367
+ ],
1368
+ "q01": [
1369
+ 0.0,
1370
+ 0.0,
1371
+ 0.0,
1372
+ 0.0,
1373
+ 0.0,
1374
+ 0.0,
1375
+ 0.0
1376
+ ],
1377
+ "q99": [
1378
+ 0.0,
1379
+ 0.0,
1380
+ 0.0,
1381
+ 0.0,
1382
+ 0.0,
1383
+ 0.0,
1384
+ 0.0
1385
+ ],
1386
+ "std": [
1387
+ 0.0,
1388
+ 0.0,
1389
+ 0.0,
1390
+ 0.0,
1391
+ 0.0,
1392
+ 0.0,
1393
+ 0.0
1394
+ ]
1395
+ }
1396
+ },
1397
+ "fmb_dataset": {
1398
+ "action": {
1399
+ "mask": [
1400
+ true,
1401
+ true,
1402
+ true,
1403
+ true,
1404
+ true,
1405
+ true,
1406
+ false
1407
+ ],
1408
+ "max": [
1409
+ 1.399999976158142,
1410
+ 1.0,
1411
+ 1.399999976158142,
1412
+ 1.0,
1413
+ 1.0,
1414
+ 1.0,
1415
+ 1.0
1416
+ ],
1417
+ "mean": [
1418
+ 0.059029702097177505,
1419
+ -0.06476633995771408,
1420
+ -0.09787475317716599,
1421
+ 0.004325388930737972,
1422
+ 0.00028963794466108084,
1423
+ -0.04457257315516472,
1424
+ 0.7336440086364746
1425
+ ],
1426
+ "min": [
1427
+ -1.399999976158142,
1428
+ -1.399999976158142,
1429
+ -1.0,
1430
+ -1.0,
1431
+ -1.0,
1432
+ -1.0,
1433
+ 0.0
1434
+ ],
1435
+ "q01": [
1436
+ -0.8257142901420593,
1437
+ -1.399999976158142,
1438
+ -1.0,
1439
+ -1.0,
1440
+ -0.3028571307659149,
1441
+ -1.0,
1442
+ 0.0
1443
+ ],
1444
+ "q99": [
1445
+ 1.0,
1446
+ 0.5257142782211304,
1447
+ 1.0,
1448
+ 1.0,
1449
+ 0.3400000035762787,
1450
+ 1.0,
1451
+ 1.0
1452
+ ],
1453
+ "std": [
1454
+ 0.28809213638305664,
1455
+ 0.2820415794849396,
1456
+ 0.4626740515232086,
1457
+ 0.3266514539718628,
1458
+ 0.10842999070882797,
1459
+ 0.3440099358558655,
1460
+ 0.4435282051563263
1461
+ ]
1462
+ },
1463
+ "num_trajectories": 8612,
1464
+ "num_transitions": 1137459,
1465
+ "proprio": {
1466
+ "max": [
1467
+ 0.0,
1468
+ 0.0,
1469
+ 0.0,
1470
+ 0.0,
1471
+ 0.0,
1472
+ 0.0,
1473
+ 0.0
1474
+ ],
1475
+ "mean": [
1476
+ 0.0,
1477
+ 0.0,
1478
+ 0.0,
1479
+ 0.0,
1480
+ 0.0,
1481
+ 0.0,
1482
+ 0.0
1483
+ ],
1484
+ "min": [
1485
+ 0.0,
1486
+ 0.0,
1487
+ 0.0,
1488
+ 0.0,
1489
+ 0.0,
1490
+ 0.0,
1491
+ 0.0
1492
+ ],
1493
+ "q01": [
1494
+ 0.0,
1495
+ 0.0,
1496
+ 0.0,
1497
+ 0.0,
1498
+ 0.0,
1499
+ 0.0,
1500
+ 0.0
1501
+ ],
1502
+ "q99": [
1503
+ 0.0,
1504
+ 0.0,
1505
+ 0.0,
1506
+ 0.0,
1507
+ 0.0,
1508
+ 0.0,
1509
+ 0.0
1510
+ ],
1511
+ "std": [
1512
+ 0.0,
1513
+ 0.0,
1514
+ 0.0,
1515
+ 0.0,
1516
+ 0.0,
1517
+ 0.0,
1518
+ 0.0
1519
+ ]
1520
+ }
1521
+ },
1522
+ "fractal20220817_data": {
1523
+ "action": {
1524
+ "mask": [
1525
+ true,
1526
+ true,
1527
+ true,
1528
+ true,
1529
+ true,
1530
+ true,
1531
+ false
1532
+ ],
1533
+ "max": [
1534
+ 2.9984593391418457,
1535
+ 22.09052848815918,
1536
+ 2.7507524490356445,
1537
+ 1.570636510848999,
1538
+ 1.5321086645126343,
1539
+ 1.5691522359848022,
1540
+ 1.0
1541
+ ],
1542
+ "mean": [
1543
+ 0.006987582892179489,
1544
+ 0.006265917327255011,
1545
+ -0.01262515690177679,
1546
+ 0.04333311319351196,
1547
+ -0.005756212864071131,
1548
+ 0.0009130256366916001,
1549
+ 0.5354204773902893
1550
+ ],
1551
+ "min": [
1552
+ -2.0204520225524902,
1553
+ -5.497899532318115,
1554
+ -2.031663417816162,
1555
+ -1.569917917251587,
1556
+ -1.569892168045044,
1557
+ -1.570419430732727,
1558
+ 0.0
1559
+ ],
1560
+ "q01": [
1561
+ -0.22453527510166169,
1562
+ -0.14820013284683228,
1563
+ -0.231589707583189,
1564
+ -0.3517994859814644,
1565
+ -0.4193011274933815,
1566
+ -0.43643461108207704,
1567
+ 0.0
1568
+ ],
1569
+ "q99": [
1570
+ 0.17824687153100965,
1571
+ 0.14938379630446405,
1572
+ 0.21842354819178575,
1573
+ 0.5892666035890578,
1574
+ 0.35272657424211445,
1575
+ 0.44796681255102094,
1576
+ 1.0
1577
+ ],
1578
+ "std": [
1579
+ 0.0692116990685463,
1580
+ 0.05970962345600128,
1581
+ 0.07353084534406662,
1582
+ 0.15610496699810028,
1583
+ 0.13164450228214264,
1584
+ 0.14593800902366638,
1585
+ 0.497110515832901
1586
+ ]
1587
+ },
1588
+ "num_trajectories": 87212,
1589
+ "num_transitions": 3786400,
1590
+ "proprio": {
1591
+ "max": [
1592
+ 0.0,
1593
+ 0.0,
1594
+ 0.0,
1595
+ 0.0,
1596
+ 0.0,
1597
+ 0.0,
1598
+ 0.0
1599
+ ],
1600
+ "mean": [
1601
+ 0.0,
1602
+ 0.0,
1603
+ 0.0,
1604
+ 0.0,
1605
+ 0.0,
1606
+ 0.0,
1607
+ 0.0
1608
+ ],
1609
+ "min": [
1610
+ 0.0,
1611
+ 0.0,
1612
+ 0.0,
1613
+ 0.0,
1614
+ 0.0,
1615
+ 0.0,
1616
+ 0.0
1617
+ ],
1618
+ "q01": [
1619
+ 0.0,
1620
+ 0.0,
1621
+ 0.0,
1622
+ 0.0,
1623
+ 0.0,
1624
+ 0.0,
1625
+ 0.0
1626
+ ],
1627
+ "q99": [
1628
+ 0.0,
1629
+ 0.0,
1630
+ 0.0,
1631
+ 0.0,
1632
+ 0.0,
1633
+ 0.0,
1634
+ 0.0
1635
+ ],
1636
+ "std": [
1637
+ 0.0,
1638
+ 0.0,
1639
+ 0.0,
1640
+ 0.0,
1641
+ 0.0,
1642
+ 0.0,
1643
+ 0.0
1644
+ ]
1645
+ }
1646
+ },
1647
+ "furniture_bench_dataset_converted_externally_to_rlds": {
1648
+ "action": {
1649
+ "mask": [
1650
+ true,
1651
+ true,
1652
+ true,
1653
+ true,
1654
+ true,
1655
+ true,
1656
+ false
1657
+ ],
1658
+ "max": [
1659
+ 0.10000000149011612,
1660
+ 0.10000000149011612,
1661
+ 0.10000000149011612,
1662
+ 0.8651833534240723,
1663
+ 1.0909736156463623,
1664
+ 2.863185405731201,
1665
+ 1.0
1666
+ ],
1667
+ "mean": [
1668
+ 0.00014610752987209707,
1669
+ 0.0010830952087417245,
1670
+ 0.0006224989192560315,
1671
+ -0.003303206292912364,
1672
+ -0.0026880695950239897,
1673
+ 0.018242603167891502,
1674
+ 0.48854944109916687
1675
+ ],
1676
+ "min": [
1677
+ -0.10495579987764359,
1678
+ -0.10939455777406693,
1679
+ -0.10000000149011612,
1680
+ -0.971906840801239,
1681
+ -1.0475432872772217,
1682
+ -3.06000018119812,
1683
+ 0.0
1684
+ ],
1685
+ "q01": [
1686
+ -0.053988199681043625,
1687
+ -0.05049169331789017,
1688
+ -0.032499241530895236,
1689
+ -0.1953887003660202,
1690
+ -0.41674559473991396,
1691
+ -0.8886768388748169,
1692
+ 0.0
1693
+ ],
1694
+ "q99": [
1695
+ 0.05414841488003723,
1696
+ 0.04965164884924884,
1697
+ 0.060055799782276154,
1698
+ 0.18231668293476103,
1699
+ 0.39867786407470646,
1700
+ 0.8772023963928218,
1701
+ 1.0
1702
+ ],
1703
+ "std": [
1704
+ 0.01610708422958851,
1705
+ 0.014891477301716805,
1706
+ 0.014014219865202904,
1707
+ 0.058274295181035995,
1708
+ 0.11417088657617569,
1709
+ 0.33479776978492737,
1710
+ 0.49991825222969055
1711
+ ]
1712
+ },
1713
+ "num_trajectories": 5100,
1714
+ "num_transitions": 3948057,
1715
+ "proprio": {
1716
+ "max": [
1717
+ 0.0,
1718
+ 0.0,
1719
+ 0.0,
1720
+ 0.0,
1721
+ 0.0,
1722
+ 0.0,
1723
+ 0.0
1724
+ ],
1725
+ "mean": [
1726
+ 0.0,
1727
+ 0.0,
1728
+ 0.0,
1729
+ 0.0,
1730
+ 0.0,
1731
+ 0.0,
1732
+ 0.0
1733
+ ],
1734
+ "min": [
1735
+ 0.0,
1736
+ 0.0,
1737
+ 0.0,
1738
+ 0.0,
1739
+ 0.0,
1740
+ 0.0,
1741
+ 0.0
1742
+ ],
1743
+ "q01": [
1744
+ 0.0,
1745
+ 0.0,
1746
+ 0.0,
1747
+ 0.0,
1748
+ 0.0,
1749
+ 0.0,
1750
+ 0.0
1751
+ ],
1752
+ "q99": [
1753
+ 0.0,
1754
+ 0.0,
1755
+ 0.0,
1756
+ 0.0,
1757
+ 0.0,
1758
+ 0.0,
1759
+ 0.0
1760
+ ],
1761
+ "std": [
1762
+ 0.0,
1763
+ 0.0,
1764
+ 0.0,
1765
+ 0.0,
1766
+ 0.0,
1767
+ 0.0,
1768
+ 0.0
1769
+ ]
1770
+ }
1771
+ },
1772
+ "iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
1773
+ "action": {
1774
+ "mask": [
1775
+ true,
1776
+ true,
1777
+ true,
1778
+ true,
1779
+ true,
1780
+ true,
1781
+ false
1782
+ ],
1783
+ "max": [
1784
+ 0.6634981632232666,
1785
+ 0.23428471386432648,
1786
+ 0.4308285415172577,
1787
+ 3.1415927410125732,
1788
+ 0.13647015392780304,
1789
+ 3.141592502593994,
1790
+ 1.0
1791
+ ],
1792
+ "mean": [
1793
+ 0.5274372696876526,
1794
+ 0.02858201041817665,
1795
+ 0.18712575733661652,
1796
+ 1.2339589595794678,
1797
+ 0.03226623684167862,
1798
+ -1.4199490547180176,
1799
+ 0.5550631880760193
1800
+ ],
1801
+ "min": [
1802
+ 0.3071657121181488,
1803
+ -0.29754969477653503,
1804
+ 0.06578229367733002,
1805
+ -3.1415927410125732,
1806
+ -0.04584203287959099,
1807
+ -3.141592502593994,
1808
+ 0.0
1809
+ ],
1810
+ "q01": [
1811
+ 0.3148897051811218,
1812
+ -0.20317550599575043,
1813
+ 0.06785467118024827,
1814
+ -3.140952730178833,
1815
+ -0.029743434861302376,
1816
+ -3.141091251373291,
1817
+ 0.0
1818
+ ],
1819
+ "q99": [
1820
+ 0.6472805738449097,
1821
+ 0.20846802592277527,
1822
+ 0.36855655312538155,
1823
+ 3.1409926891326903,
1824
+ 0.11424950212240226,
1825
+ 3.1410969257354737,
1826
+ 1.0
1827
+ ],
1828
+ "std": [
1829
+ 0.08108345419168472,
1830
+ 0.1116757020354271,
1831
+ 0.07747554779052734,
1832
+ 2.8737246990203857,
1833
+ 0.02774704433977604,
1834
+ 2.7678682804107666,
1835
+ 0.49695101380348206
1836
+ ]
1837
+ },
1838
+ "num_trajectories": 631,
1839
+ "num_transitions": 146241,
1840
+ "proprio": {
1841
+ "max": [
1842
+ 0.0,
1843
+ 0.0,
1844
+ 0.0,
1845
+ 0.0,
1846
+ 0.0,
1847
+ 0.0,
1848
+ 0.0
1849
+ ],
1850
+ "mean": [
1851
+ 0.0,
1852
+ 0.0,
1853
+ 0.0,
1854
+ 0.0,
1855
+ 0.0,
1856
+ 0.0,
1857
+ 0.0
1858
+ ],
1859
+ "min": [
1860
+ 0.0,
1861
+ 0.0,
1862
+ 0.0,
1863
+ 0.0,
1864
+ 0.0,
1865
+ 0.0,
1866
+ 0.0
1867
+ ],
1868
+ "q01": [
1869
+ 0.0,
1870
+ 0.0,
1871
+ 0.0,
1872
+ 0.0,
1873
+ 0.0,
1874
+ 0.0,
1875
+ 0.0
1876
+ ],
1877
+ "q99": [
1878
+ 0.0,
1879
+ 0.0,
1880
+ 0.0,
1881
+ 0.0,
1882
+ 0.0,
1883
+ 0.0,
1884
+ 0.0
1885
+ ],
1886
+ "std": [
1887
+ 0.0,
1888
+ 0.0,
1889
+ 0.0,
1890
+ 0.0,
1891
+ 0.0,
1892
+ 0.0,
1893
+ 0.0
1894
+ ]
1895
+ }
1896
+ },
1897
+ "jaco_play": {
1898
+ "action": {
1899
+ "mask": [
1900
+ true,
1901
+ true,
1902
+ true,
1903
+ true,
1904
+ true,
1905
+ true,
1906
+ false
1907
+ ],
1908
+ "max": [
1909
+ 0.20000000298023224,
1910
+ 0.20000000298023224,
1911
+ 0.20000000298023224,
1912
+ 0.0,
1913
+ 0.0,
1914
+ 0.0,
1915
+ 1.0
1916
+ ],
1917
+ "mean": [
1918
+ 0.0009658430935814977,
1919
+ -0.00580078037455678,
1920
+ -0.00395062193274498,
1921
+ 0.0,
1922
+ 0.0,
1923
+ 0.0,
1924
+ 0.34934908151626587
1925
+ ],
1926
+ "min": [
1927
+ -0.20000000298023224,
1928
+ -0.20000000298023224,
1929
+ -0.20000000298023224,
1930
+ 0.0,
1931
+ 0.0,
1932
+ 0.0,
1933
+ 0.0
1934
+ ],
1935
+ "q01": [
1936
+ -0.20000000298023224,
1937
+ -0.20000000298023224,
1938
+ -0.20000000298023224,
1939
+ 0.0,
1940
+ 0.0,
1941
+ 0.0,
1942
+ 0.0
1943
+ ],
1944
+ "q99": [
1945
+ 0.20000000298023224,
1946
+ 0.20000000298023224,
1947
+ 0.20000000298023224,
1948
+ 0.0,
1949
+ 0.0,
1950
+ 0.0,
1951
+ 1.0
1952
+ ],
1953
+ "std": [
1954
+ 0.12235074490308762,
1955
+ 0.09678777307271957,
1956
+ 0.11155334860086441,
1957
+ 0.0,
1958
+ 0.0,
1959
+ 0.0,
1960
+ 0.4768252968788147
1961
+ ]
1962
+ },
1963
+ "num_trajectories": 1085,
1964
+ "num_transitions": 77965,
1965
+ "proprio": {
1966
+ "max": [
1967
+ 0.0,
1968
+ 0.0,
1969
+ 0.0,
1970
+ 0.0,
1971
+ 0.0,
1972
+ 0.0,
1973
+ 0.0
1974
+ ],
1975
+ "mean": [
1976
+ 0.0,
1977
+ 0.0,
1978
+ 0.0,
1979
+ 0.0,
1980
+ 0.0,
1981
+ 0.0,
1982
+ 0.0
1983
+ ],
1984
+ "min": [
1985
+ 0.0,
1986
+ 0.0,
1987
+ 0.0,
1988
+ 0.0,
1989
+ 0.0,
1990
+ 0.0,
1991
+ 0.0
1992
+ ],
1993
+ "q01": [
1994
+ 0.0,
1995
+ 0.0,
1996
+ 0.0,
1997
+ 0.0,
1998
+ 0.0,
1999
+ 0.0,
2000
+ 0.0
2001
+ ],
2002
+ "q99": [
2003
+ 0.0,
2004
+ 0.0,
2005
+ 0.0,
2006
+ 0.0,
2007
+ 0.0,
2008
+ 0.0,
2009
+ 0.0
2010
+ ],
2011
+ "std": [
2012
+ 0.0,
2013
+ 0.0,
2014
+ 0.0,
2015
+ 0.0,
2016
+ 0.0,
2017
+ 0.0,
2018
+ 0.0
2019
+ ]
2020
+ }
2021
+ },
2022
+ "kuka": {
2023
+ "action": {
2024
+ "mask": [
2025
+ true,
2026
+ true,
2027
+ true,
2028
+ true,
2029
+ true,
2030
+ true,
2031
+ false
2032
+ ],
2033
+ "max": [
2034
+ 0.1697135865688324,
2035
+ 0.2777623236179352,
2036
+ 0.43710532784461975,
2037
+ 0.0,
2038
+ 0.0,
2039
+ 1.9684287309646606,
2040
+ 1.0
2041
+ ],
2042
+ "mean": [
2043
+ -0.0004668905457947403,
2044
+ 0.00040138536132872105,
2045
+ -0.001280792523175478,
2046
+ 0.0,
2047
+ 0.0,
2048
+ -0.03722453489899635,
2049
+ 0.4131543040275574
2050
+ ],
2051
+ "min": [
2052
+ -0.159867063164711,
2053
+ -0.2892282009124756,
2054
+ -0.2795473635196686,
2055
+ 0.0,
2056
+ 0.0,
2057
+ -1.9875637292861938,
2058
+ 0.0
2059
+ ],
2060
+ "q01": [
2061
+ -0.06619441494345665,
2062
+ -0.08713878810405731,
2063
+ -0.15083016991615295,
2064
+ 0.0,
2065
+ 0.0,
2066
+ -0.5415697038173676,
2067
+ 0.0
2068
+ ],
2069
+ "q99": [
2070
+ 0.06601839080452929,
2071
+ 0.08732476785779003,
2072
+ 0.18168179214000715,
2073
+ 0.0,
2074
+ 0.0,
2075
+ 0.2923380345106127,
2076
+ 1.0
2077
+ ],
2078
+ "std": [
2079
+ 0.02083250693976879,
2080
+ 0.02915887162089348,
2081
+ 0.06422865390777588,
2082
+ 0.0,
2083
+ 0.0,
2084
+ 0.14224295318126678,
2085
+ 0.49086448550224304
2086
+ ]
2087
+ },
2088
+ "num_trajectories": 209880,
2089
+ "num_transitions": 2455879,
2090
+ "proprio": {
2091
+ "max": [
2092
+ 0.0,
2093
+ 0.0,
2094
+ 0.0,
2095
+ 0.0,
2096
+ 0.0,
2097
+ 0.0,
2098
+ 0.0
2099
+ ],
2100
+ "mean": [
2101
+ 0.0,
2102
+ 0.0,
2103
+ 0.0,
2104
+ 0.0,
2105
+ 0.0,
2106
+ 0.0,
2107
+ 0.0
2108
+ ],
2109
+ "min": [
2110
+ 0.0,
2111
+ 0.0,
2112
+ 0.0,
2113
+ 0.0,
2114
+ 0.0,
2115
+ 0.0,
2116
+ 0.0
2117
+ ],
2118
+ "q01": [
2119
+ 0.0,
2120
+ 0.0,
2121
+ 0.0,
2122
+ 0.0,
2123
+ 0.0,
2124
+ 0.0,
2125
+ 0.0
2126
+ ],
2127
+ "q99": [
2128
+ 0.0,
2129
+ 0.0,
2130
+ 0.0,
2131
+ 0.0,
2132
+ 0.0,
2133
+ 0.0,
2134
+ 0.0
2135
+ ],
2136
+ "std": [
2137
+ 0.0,
2138
+ 0.0,
2139
+ 0.0,
2140
+ 0.0,
2141
+ 0.0,
2142
+ 0.0,
2143
+ 0.0
2144
+ ]
2145
+ }
2146
+ },
2147
+ "nyu_franka_play_dataset_converted_externally_to_rlds": {
2148
+ "action": {
2149
+ "mask": [
2150
+ true,
2151
+ true,
2152
+ true,
2153
+ true,
2154
+ true,
2155
+ true,
2156
+ false
2157
+ ],
2158
+ "max": [
2159
+ 0.06424188613891602,
2160
+ 0.07027634978294373,
2161
+ 0.06129661202430725,
2162
+ 6.281067848205566,
2163
+ 0.1967729926109314,
2164
+ 0.26377415657043457,
2165
+ 1.0
2166
+ ],
2167
+ "mean": [
2168
+ 0.001021989737637341,
2169
+ -0.00012002651783404872,
2170
+ 0.00032894269679673016,
2171
+ 0.0015034361276775599,
2172
+ -0.002198522910475731,
2173
+ -0.001663230243138969,
2174
+ 0.7230083346366882
2175
+ ],
2176
+ "min": [
2177
+ -0.05952230095863342,
2178
+ -0.07232445478439331,
2179
+ -0.06730806827545166,
2180
+ -6.278434753417969,
2181
+ -0.21479034423828125,
2182
+ -0.3627619743347168,
2183
+ 0.0
2184
+ ],
2185
+ "q01": [
2186
+ -0.03199600875377655,
2187
+ -0.032861671447753905,
2188
+ -0.03368805110454559,
2189
+ -0.12080862045288086,
2190
+ -0.12175218224525451,
2191
+ -0.11370223641395569,
2192
+ 0.0
2193
+ ],
2194
+ "q99": [
2195
+ 0.03101520001888276,
2196
+ 0.0373908892273903,
2197
+ 0.03646374464035038,
2198
+ 0.11764093399047852,
2199
+ 0.1258920183777809,
2200
+ 0.09366151213645942,
2201
+ 1.0
2202
+ ],
2203
+ "std": [
2204
+ 0.01327415369451046,
2205
+ 0.013215910643339157,
2206
+ 0.012822109274566174,
2207
+ 0.2732451558113098,
2208
+ 0.057022541761398315,
2209
+ 0.039172880351543427,
2210
+ 0.44752755761146545
2211
+ ]
2212
+ },
2213
+ "num_trajectories": 456,
2214
+ "num_transitions": 44875,
2215
+ "proprio": {
2216
+ "max": [
2217
+ 0.0,
2218
+ 0.0,
2219
+ 0.0,
2220
+ 0.0,
2221
+ 0.0,
2222
+ 0.0,
2223
+ 0.0
2224
+ ],
2225
+ "mean": [
2226
+ 0.0,
2227
+ 0.0,
2228
+ 0.0,
2229
+ 0.0,
2230
+ 0.0,
2231
+ 0.0,
2232
+ 0.0
2233
+ ],
2234
+ "min": [
2235
+ 0.0,
2236
+ 0.0,
2237
+ 0.0,
2238
+ 0.0,
2239
+ 0.0,
2240
+ 0.0,
2241
+ 0.0
2242
+ ],
2243
+ "q01": [
2244
+ 0.0,
2245
+ 0.0,
2246
+ 0.0,
2247
+ 0.0,
2248
+ 0.0,
2249
+ 0.0,
2250
+ 0.0
2251
+ ],
2252
+ "q99": [
2253
+ 0.0,
2254
+ 0.0,
2255
+ 0.0,
2256
+ 0.0,
2257
+ 0.0,
2258
+ 0.0,
2259
+ 0.0
2260
+ ],
2261
+ "std": [
2262
+ 0.0,
2263
+ 0.0,
2264
+ 0.0,
2265
+ 0.0,
2266
+ 0.0,
2267
+ 0.0,
2268
+ 0.0
2269
+ ]
2270
+ }
2271
+ },
2272
+ "roboturk": {
2273
+ "action": {
2274
+ "mask": [
2275
+ true,
2276
+ true,
2277
+ true,
2278
+ true,
2279
+ true,
2280
+ true,
2281
+ false
2282
+ ],
2283
+ "max": [
2284
+ 0.39124172925949097,
2285
+ 0.4601028263568878,
2286
+ 0.4870833456516266,
2287
+ 1.816888689994812,
2288
+ 1.8240282535552979,
2289
+ 1.4824820756912231,
2290
+ 1.0
2291
+ ],
2292
+ "mean": [
2293
+ 0.0014448732836171985,
2294
+ -0.0015945249469950795,
2295
+ -0.0011753785656765103,
2296
+ 0.0023012510500848293,
2297
+ -0.0009382463176734746,
2298
+ -0.00011485807772260159,
2299
+ 0.5746025443077087
2300
+ ],
2301
+ "min": [
2302
+ -0.6546999216079712,
2303
+ -0.6365841031074524,
2304
+ -0.4217723608016968,
2305
+ -1.6695482730865479,
2306
+ -1.8023357391357422,
2307
+ -1.4630827903747559,
2308
+ 0.0
2309
+ ],
2310
+ "q01": [
2311
+ -0.1342635464668274,
2312
+ -0.19996687173843383,
2313
+ -0.1482972100377083,
2314
+ -0.20720748245716095,
2315
+ -0.09676413893699647,
2316
+ -0.18075634717941286,
2317
+ 0.0
2318
+ ],
2319
+ "q99": [
2320
+ 0.14956976801157001,
2321
+ 0.1805950567126275,
2322
+ 0.18841815620660796,
2323
+ 0.21615413755178453,
2324
+ 0.09457383215427405,
2325
+ 0.18543301910162005,
2326
+ 1.0
2327
+ ],
2328
+ "std": [
2329
+ 0.04935386776924133,
2330
+ 0.0635455846786499,
2331
+ 0.061164740473032,
2332
+ 0.09553450345993042,
2333
+ 0.08420111238956451,
2334
+ 0.06517903506755829,
2335
+ 0.49452081322669983
2336
+ ]
2337
+ },
2338
+ "num_trajectories": 1995,
2339
+ "num_transitions": 187507,
2340
+ "proprio": {
2341
+ "max": [
2342
+ 0.0,
2343
+ 0.0,
2344
+ 0.0,
2345
+ 0.0,
2346
+ 0.0,
2347
+ 0.0,
2348
+ 0.0
2349
+ ],
2350
+ "mean": [
2351
+ 0.0,
2352
+ 0.0,
2353
+ 0.0,
2354
+ 0.0,
2355
+ 0.0,
2356
+ 0.0,
2357
+ 0.0
2358
+ ],
2359
+ "min": [
2360
+ 0.0,
2361
+ 0.0,
2362
+ 0.0,
2363
+ 0.0,
2364
+ 0.0,
2365
+ 0.0,
2366
+ 0.0
2367
+ ],
2368
+ "q01": [
2369
+ 0.0,
2370
+ 0.0,
2371
+ 0.0,
2372
+ 0.0,
2373
+ 0.0,
2374
+ 0.0,
2375
+ 0.0
2376
+ ],
2377
+ "q99": [
2378
+ 0.0,
2379
+ 0.0,
2380
+ 0.0,
2381
+ 0.0,
2382
+ 0.0,
2383
+ 0.0,
2384
+ 0.0
2385
+ ],
2386
+ "std": [
2387
+ 0.0,
2388
+ 0.0,
2389
+ 0.0,
2390
+ 0.0,
2391
+ 0.0,
2392
+ 0.0,
2393
+ 0.0
2394
+ ]
2395
+ }
2396
+ },
2397
+ "stanford_hydra_dataset_converted_externally_to_rlds": {
2398
+ "action": {
2399
+ "mask": [
2400
+ true,
2401
+ true,
2402
+ true,
2403
+ true,
2404
+ true,
2405
+ true,
2406
+ false
2407
+ ],
2408
+ "max": [
2409
+ 0.02499854564666748,
2410
+ 0.02499903365969658,
2411
+ 0.024999922141432762,
2412
+ 0.24974457919597626,
2413
+ 0.24997030198574066,
2414
+ 0.24999946355819702,
2415
+ 1.0
2416
+ ],
2417
+ "mean": [
2418
+ 0.0007790001109242439,
2419
+ 0.00013707754260394722,
2420
+ -0.0002548607881180942,
2421
+ 0.0012903271708637476,
2422
+ -0.004751681815832853,
2423
+ 0.002692886395379901,
2424
+ 0.48855218291282654
2425
+ ],
2426
+ "min": [
2427
+ -0.024999044835567474,
2428
+ -0.024999700486660004,
2429
+ -0.02499929815530777,
2430
+ -0.24993225932121277,
2431
+ -0.2499666064977646,
2432
+ -0.2499932497739792,
2433
+ 0.0
2434
+ ],
2435
+ "q01": [
2436
+ -0.019992006458342076,
2437
+ -0.02415412735193968,
2438
+ -0.022941758055239916,
2439
+ -0.11085530579090118,
2440
+ -0.12024572037160397,
2441
+ -0.13314770206809043,
2442
+ 0.0
2443
+ ],
2444
+ "q99": [
2445
+ 0.022886231057345868,
2446
+ 0.022358838934451335,
2447
+ 0.02410089675337076,
2448
+ 0.12370114490389822,
2449
+ 0.11323311634361738,
2450
+ 0.18474749639630164,
2451
+ 1.0
2452
+ ],
2453
+ "std": [
2454
+ 0.008022161200642586,
2455
+ 0.009131459519267082,
2456
+ 0.009574338793754578,
2457
+ 0.04122216999530792,
2458
+ 0.0384303517639637,
2459
+ 0.04606688767671585,
2460
+ 0.49976691603660583
2461
+ ]
2462
+ },
2463
+ "num_trajectories": 570,
2464
+ "num_transitions": 358234,
2465
+ "proprio": {
2466
+ "max": [
2467
+ 0.0,
2468
+ 0.0,
2469
+ 0.0,
2470
+ 0.0,
2471
+ 0.0,
2472
+ 0.0,
2473
+ 0.0
2474
+ ],
2475
+ "mean": [
2476
+ 0.0,
2477
+ 0.0,
2478
+ 0.0,
2479
+ 0.0,
2480
+ 0.0,
2481
+ 0.0,
2482
+ 0.0
2483
+ ],
2484
+ "min": [
2485
+ 0.0,
2486
+ 0.0,
2487
+ 0.0,
2488
+ 0.0,
2489
+ 0.0,
2490
+ 0.0,
2491
+ 0.0
2492
+ ],
2493
+ "q01": [
2494
+ 0.0,
2495
+ 0.0,
2496
+ 0.0,
2497
+ 0.0,
2498
+ 0.0,
2499
+ 0.0,
2500
+ 0.0
2501
+ ],
2502
+ "q99": [
2503
+ 0.0,
2504
+ 0.0,
2505
+ 0.0,
2506
+ 0.0,
2507
+ 0.0,
2508
+ 0.0,
2509
+ 0.0
2510
+ ],
2511
+ "std": [
2512
+ 0.0,
2513
+ 0.0,
2514
+ 0.0,
2515
+ 0.0,
2516
+ 0.0,
2517
+ 0.0,
2518
+ 0.0
2519
+ ]
2520
+ }
2521
+ },
2522
+ "taco_play": {
2523
+ "action": {
2524
+ "mask": [
2525
+ true,
2526
+ true,
2527
+ true,
2528
+ true,
2529
+ true,
2530
+ true,
2531
+ false
2532
+ ],
2533
+ "max": [
2534
+ 1.4915844202041626,
2535
+ 2.1842432022094727,
2536
+ 2.6836395263671875,
2537
+ 5.035226821899414,
2538
+ 2.665864944458008,
2539
+ 4.250768661499023,
2540
+ 1.0
2541
+ ],
2542
+ "mean": [
2543
+ -0.003845922416076064,
2544
+ 0.009671456180512905,
2545
+ 0.012780580669641495,
2546
+ -0.005403771996498108,
2547
+ -0.009606587700545788,
2548
+ -0.002480733208358288,
2549
+ 0.4263913035392761
2550
+ ],
2551
+ "min": [
2552
+ -4.242457866668701,
2553
+ -3.192805051803589,
2554
+ -1.3371467590332031,
2555
+ -4.202683448791504,
2556
+ -2.6722638607025146,
2557
+ -3.3467135429382324,
2558
+ 0.0
2559
+ ],
2560
+ "q01": [
2561
+ -0.7106140398979186,
2562
+ -1.056944659948349,
2563
+ -0.5878450274467468,
2564
+ -0.7682853937149048,
2565
+ -0.7180147767066956,
2566
+ -1.5527938604354858,
2567
+ 0.0
2568
+ ],
2569
+ "q99": [
2570
+ 0.6482916426658629,
2571
+ 1.0051310062408447,
2572
+ 0.9480248689651489,
2573
+ 0.6926478147506714,
2574
+ 0.6351067513227462,
2575
+ 1.628010264635086,
2576
+ 1.0
2577
+ ],
2578
+ "std": [
2579
+ 0.23254038393497467,
2580
+ 0.36298269033432007,
2581
+ 0.28692901134490967,
2582
+ 0.2617705166339874,
2583
+ 0.2438892275094986,
2584
+ 0.5216503143310547,
2585
+ 0.4946896731853485
2586
+ ]
2587
+ },
2588
+ "num_trajectories": 3603,
2589
+ "num_transitions": 237798,
2590
+ "proprio": {
2591
+ "max": [
2592
+ 0.0,
2593
+ 0.0,
2594
+ 0.0,
2595
+ 0.0,
2596
+ 0.0,
2597
+ 0.0,
2598
+ 0.0
2599
+ ],
2600
+ "mean": [
2601
+ 0.0,
2602
+ 0.0,
2603
+ 0.0,
2604
+ 0.0,
2605
+ 0.0,
2606
+ 0.0,
2607
+ 0.0
2608
+ ],
2609
+ "min": [
2610
+ 0.0,
2611
+ 0.0,
2612
+ 0.0,
2613
+ 0.0,
2614
+ 0.0,
2615
+ 0.0,
2616
+ 0.0
2617
+ ],
2618
+ "q01": [
2619
+ 0.0,
2620
+ 0.0,
2621
+ 0.0,
2622
+ 0.0,
2623
+ 0.0,
2624
+ 0.0,
2625
+ 0.0
2626
+ ],
2627
+ "q99": [
2628
+ 0.0,
2629
+ 0.0,
2630
+ 0.0,
2631
+ 0.0,
2632
+ 0.0,
2633
+ 0.0,
2634
+ 0.0
2635
+ ],
2636
+ "std": [
2637
+ 0.0,
2638
+ 0.0,
2639
+ 0.0,
2640
+ 0.0,
2641
+ 0.0,
2642
+ 0.0,
2643
+ 0.0
2644
+ ]
2645
+ }
2646
+ },
2647
+ "toto": {
2648
+ "action": {
2649
+ "mask": [
2650
+ true,
2651
+ true,
2652
+ true,
2653
+ true,
2654
+ true,
2655
+ true,
2656
+ false
2657
+ ],
2658
+ "max": [
2659
+ 0.6839867234230042,
2660
+ 0.4454185664653778,
2661
+ 0.7984078526496887,
2662
+ 2.120781660079956,
2663
+ 1.371164321899414,
2664
+ 1.4118704795837402,
2665
+ 0.0
2666
+ ],
2667
+ "mean": [
2668
+ 0.38542115688323975,
2669
+ 0.007769413758069277,
2670
+ 0.3632740378379822,
2671
+ -0.6652036905288696,
2672
+ 0.1890396922826767,
2673
+ 0.03298724442720413,
2674
+ 0.0
2675
+ ],
2676
+ "min": [
2677
+ 0.09922284632921219,
2678
+ -0.5180193781852722,
2679
+ 0.13791072368621826,
2680
+ -2.635117530822754,
2681
+ -1.0734480619430542,
2682
+ -1.9282547235488892,
2683
+ 0.0
2684
+ ],
2685
+ "q01": [
2686
+ 0.1756722891330719,
2687
+ -0.3077590811252594,
2688
+ 0.235383919775486,
2689
+ -2.0908505964279174,
2690
+ -0.6191593289375306,
2691
+ -0.7488683319091797,
2692
+ 0.0
2693
+ ],
2694
+ "q99": [
2695
+ 0.6136963081359863,
2696
+ 0.33704194784164443,
2697
+ 0.6681221985816956,
2698
+ 0.7422861719131538,
2699
+ 0.7955395007133507,
2700
+ 0.740464625358582,
2701
+ 0.0
2702
+ ],
2703
+ "std": [
2704
+ 0.12211652100086212,
2705
+ 0.19378550350666046,
2706
+ 0.10178236663341522,
2707
+ 0.5725259184837341,
2708
+ 0.29884573817253113,
2709
+ 0.3259911835193634,
2710
+ 0.0
2711
+ ]
2712
+ },
2713
+ "num_trajectories": 1003,
2714
+ "num_transitions": 325699,
2715
+ "proprio": {
2716
+ "max": [
2717
+ 0.0,
2718
+ 0.0,
2719
+ 0.0,
2720
+ 0.0,
2721
+ 0.0,
2722
+ 0.0,
2723
+ 0.0
2724
+ ],
2725
+ "mean": [
2726
+ 0.0,
2727
+ 0.0,
2728
+ 0.0,
2729
+ 0.0,
2730
+ 0.0,
2731
+ 0.0,
2732
+ 0.0
2733
+ ],
2734
+ "min": [
2735
+ 0.0,
2736
+ 0.0,
2737
+ 0.0,
2738
+ 0.0,
2739
+ 0.0,
2740
+ 0.0,
2741
+ 0.0
2742
+ ],
2743
+ "q01": [
2744
+ 0.0,
2745
+ 0.0,
2746
+ 0.0,
2747
+ 0.0,
2748
+ 0.0,
2749
+ 0.0,
2750
+ 0.0
2751
+ ],
2752
+ "q99": [
2753
+ 0.0,
2754
+ 0.0,
2755
+ 0.0,
2756
+ 0.0,
2757
+ 0.0,
2758
+ 0.0,
2759
+ 0.0
2760
+ ],
2761
+ "std": [
2762
+ 0.0,
2763
+ 0.0,
2764
+ 0.0,
2765
+ 0.0,
2766
+ 0.0,
2767
+ 0.0,
2768
+ 0.0
2769
+ ]
2770
+ }
2771
+ },
2772
+ "ucsd_kitchen_dataset_converted_externally_to_rlds": {
2773
+ "action": {
2774
+ "mask": [
2775
+ true,
2776
+ true,
2777
+ true,
2778
+ true,
2779
+ true,
2780
+ true,
2781
+ false
2782
+ ],
2783
+ "max": [
2784
+ 678.0,
2785
+ 400.0,
2786
+ 507.0,
2787
+ 180.00001525878906,
2788
+ 6.000013828277588,
2789
+ 116.99998474121094,
2790
+ 1.0
2791
+ ],
2792
+ "mean": [
2793
+ 410.37567138671875,
2794
+ 116.9518814086914,
2795
+ 192.35032653808594,
2796
+ -121.22441864013672,
2797
+ -33.84893035888672,
2798
+ 50.016136169433594,
2799
+ 0.741813600063324
2800
+ ],
2801
+ "min": [
2802
+ 172.0,
2803
+ -166.0,
2804
+ -99.99999237060547,
2805
+ -180.00001525878906,
2806
+ -89.0,
2807
+ -96.00010681152344,
2808
+ 0.0
2809
+ ],
2810
+ "q01": [
2811
+ 200.00001052856445,
2812
+ -102.31004211425781,
2813
+ -94.99993370056153,
2814
+ -180.00001525878906,
2815
+ -88.00001525878906,
2816
+ -38.999977111816406,
2817
+ 0.0
2818
+ ],
2819
+ "q99": [
2820
+ 637.0,
2821
+ 368.30999999999995,
2822
+ 493.0,
2823
+ 180.00001525878906,
2824
+ 0.999983012676239,
2825
+ 105.00001525878906,
2826
+ 1.0
2827
+ ],
2828
+ "std": [
2829
+ 122.81494903564453,
2830
+ 108.8009033203125,
2831
+ 130.303466796875,
2832
+ 116.28205108642578,
2833
+ 27.621843338012695,
2834
+ 41.02094650268555,
2835
+ 0.43763357400894165
2836
+ ]
2837
+ },
2838
+ "num_trajectories": 150,
2839
+ "num_transitions": 3970,
2840
+ "proprio": {
2841
+ "max": [
2842
+ 0.0,
2843
+ 0.0,
2844
+ 0.0,
2845
+ 0.0,
2846
+ 0.0,
2847
+ 0.0,
2848
+ 0.0
2849
+ ],
2850
+ "mean": [
2851
+ 0.0,
2852
+ 0.0,
2853
+ 0.0,
2854
+ 0.0,
2855
+ 0.0,
2856
+ 0.0,
2857
+ 0.0
2858
+ ],
2859
+ "min": [
2860
+ 0.0,
2861
+ 0.0,
2862
+ 0.0,
2863
+ 0.0,
2864
+ 0.0,
2865
+ 0.0,
2866
+ 0.0
2867
+ ],
2868
+ "q01": [
2869
+ 0.0,
2870
+ 0.0,
2871
+ 0.0,
2872
+ 0.0,
2873
+ 0.0,
2874
+ 0.0,
2875
+ 0.0
2876
+ ],
2877
+ "q99": [
2878
+ 0.0,
2879
+ 0.0,
2880
+ 0.0,
2881
+ 0.0,
2882
+ 0.0,
2883
+ 0.0,
2884
+ 0.0
2885
+ ],
2886
+ "std": [
2887
+ 0.0,
2888
+ 0.0,
2889
+ 0.0,
2890
+ 0.0,
2891
+ 0.0,
2892
+ 0.0,
2893
+ 0.0
2894
+ ]
2895
+ }
2896
+ },
2897
+ "utaustin_mutex": {
2898
+ "action": {
2899
+ "mask": [
2900
+ true,
2901
+ true,
2902
+ true,
2903
+ true,
2904
+ true,
2905
+ true,
2906
+ false
2907
+ ],
2908
+ "max": [
2909
+ 1.0,
2910
+ 1.0,
2911
+ 1.0,
2912
+ 0.375,
2913
+ 0.375,
2914
+ 0.375,
2915
+ 1.0
2916
+ ],
2917
+ "mean": [
2918
+ 0.06176406890153885,
2919
+ -0.005005486309528351,
2920
+ 0.10216785222291946,
2921
+ -0.03314131125807762,
2922
+ 0.013895004987716675,
2923
+ -0.011317633092403412,
2924
+ 0.5038976669311523
2925
+ ],
2926
+ "min": [
2927
+ -1.0,
2928
+ -1.0,
2929
+ -1.0,
2930
+ -0.375,
2931
+ -0.375,
2932
+ -0.375,
2933
+ 0.0
2934
+ ],
2935
+ "q01": [
2936
+ -0.4285714328289032,
2937
+ -0.9800000190734863,
2938
+ -0.5571428537368774,
2939
+ -0.375,
2940
+ -0.15642857551574707,
2941
+ -0.335357129573822,
2942
+ 0.0
2943
+ ],
2944
+ "q99": [
2945
+ 0.5914285778999329,
2946
+ 0.9714285731315613,
2947
+ 1.0,
2948
+ 0.3278571367263794,
2949
+ 0.207857146859169,
2950
+ 0.25607141852378845,
2951
+ 1.0
2952
+ ],
2953
+ "std": [
2954
+ 0.1875014752149582,
2955
+ 0.4468473494052887,
2956
+ 0.3792876601219177,
2957
+ 0.14097853004932404,
2958
+ 0.06453701853752136,
2959
+ 0.11765272170305252,
2960
+ 0.501045286655426
2961
+ ]
2962
+ },
2963
+ "num_trajectories": 1500,
2964
+ "num_transitions": 361883,
2965
+ "proprio": {
2966
+ "max": [
2967
+ 0.0,
2968
+ 0.0,
2969
+ 0.0,
2970
+ 0.0,
2971
+ 0.0,
2972
+ 0.0,
2973
+ 0.0
2974
+ ],
2975
+ "mean": [
2976
+ 0.0,
2977
+ 0.0,
2978
+ 0.0,
2979
+ 0.0,
2980
+ 0.0,
2981
+ 0.0,
2982
+ 0.0
2983
+ ],
2984
+ "min": [
2985
+ 0.0,
2986
+ 0.0,
2987
+ 0.0,
2988
+ 0.0,
2989
+ 0.0,
2990
+ 0.0,
2991
+ 0.0
2992
+ ],
2993
+ "q01": [
2994
+ 0.0,
2995
+ 0.0,
2996
+ 0.0,
2997
+ 0.0,
2998
+ 0.0,
2999
+ 0.0,
3000
+ 0.0
3001
+ ],
3002
+ "q99": [
3003
+ 0.0,
3004
+ 0.0,
3005
+ 0.0,
3006
+ 0.0,
3007
+ 0.0,
3008
+ 0.0,
3009
+ 0.0
3010
+ ],
3011
+ "std": [
3012
+ 0.0,
3013
+ 0.0,
3014
+ 0.0,
3015
+ 0.0,
3016
+ 0.0,
3017
+ 0.0,
3018
+ 0.0
3019
+ ]
3020
+ }
3021
+ },
3022
+ "viola": {
3023
+ "action": {
3024
+ "mask": [
3025
+ true,
3026
+ true,
3027
+ true,
3028
+ true,
3029
+ true,
3030
+ true,
3031
+ false
3032
+ ],
3033
+ "max": [
3034
+ 1.0,
3035
+ 1.0,
3036
+ 1.0,
3037
+ 0.375,
3038
+ 0.36321428418159485,
3039
+ 0.375,
3040
+ 1.0
3041
+ ],
3042
+ "mean": [
3043
+ 0.04761844128370285,
3044
+ -0.029204415157437325,
3045
+ 0.05586736649274826,
3046
+ -0.002618510741740465,
3047
+ 0.006867344491183758,
3048
+ -0.01682133786380291,
3049
+ 0.7323777675628662
3050
+ ],
3051
+ "min": [
3052
+ -1.0,
3053
+ -1.0,
3054
+ -1.0,
3055
+ -0.375,
3056
+ -0.375,
3057
+ -0.375,
3058
+ 0.0
3059
+ ],
3060
+ "q01": [
3061
+ -0.9628571271896362,
3062
+ -1.0,
3063
+ -1.0,
3064
+ -0.26249998807907104,
3065
+ -0.21321429312229156,
3066
+ -0.3385714292526245,
3067
+ 0.0
3068
+ ],
3069
+ "q99": [
3070
+ 0.9114285707473755,
3071
+ 0.868571400642395,
3072
+ 1.0,
3073
+ 0.2817857265472412,
3074
+ 0.2239285707473755,
3075
+ 0.3557142913341522,
3076
+ 1.0
3077
+ ],
3078
+ "std": [
3079
+ 0.39157867431640625,
3080
+ 0.4076525568962097,
3081
+ 0.40077948570251465,
3082
+ 0.10023996233940125,
3083
+ 0.0844319611787796,
3084
+ 0.10375042259693146,
3085
+ 0.44260647892951965
3086
+ ]
3087
+ },
3088
+ "num_trajectories": 150,
3089
+ "num_transitions": 76324,
3090
+ "proprio": {
3091
+ "max": [
3092
+ 0.0,
3093
+ 0.0,
3094
+ 0.0,
3095
+ 0.0,
3096
+ 0.0,
3097
+ 0.0,
3098
+ 0.0
3099
+ ],
3100
+ "mean": [
3101
+ 0.0,
3102
+ 0.0,
3103
+ 0.0,
3104
+ 0.0,
3105
+ 0.0,
3106
+ 0.0,
3107
+ 0.0
3108
+ ],
3109
+ "min": [
3110
+ 0.0,
3111
+ 0.0,
3112
+ 0.0,
3113
+ 0.0,
3114
+ 0.0,
3115
+ 0.0,
3116
+ 0.0
3117
+ ],
3118
+ "q01": [
3119
+ 0.0,
3120
+ 0.0,
3121
+ 0.0,
3122
+ 0.0,
3123
+ 0.0,
3124
+ 0.0,
3125
+ 0.0
3126
+ ],
3127
+ "q99": [
3128
+ 0.0,
3129
+ 0.0,
3130
+ 0.0,
3131
+ 0.0,
3132
+ 0.0,
3133
+ 0.0,
3134
+ 0.0
3135
+ ],
3136
+ "std": [
3137
+ 0.0,
3138
+ 0.0,
3139
+ 0.0,
3140
+ 0.0,
3141
+ 0.0,
3142
+ 0.0,
3143
+ 0.0
3144
+ ]
3145
+ }
3146
+ }
3147
+ },
3148
+ "output_projector_states": false,
3149
+ "pad_to_multiple_of": 64,
3150
+ "pad_token_id": 32000,
3151
+ "text_config": {
3152
+ "bos_token_id": 151643,
3153
+ "eos_token_id": 151643,
3154
+ "hidden_size": 896,
3155
+ "intermediate_size": 4864,
3156
+ "max_position_embeddings": 32768,
3157
+ "max_window_layers": 24,
3158
+ "model_type": "qwen2",
3159
+ "num_attention_heads": 14,
3160
+ "num_hidden_layers": 24,
3161
+ "num_key_value_heads": 2,
3162
+ "rope_theta": 1000000.0,
3163
+ "sliding_window": 32768,
3164
+ "tie_word_embeddings": true,
3165
+ "torch_dtype": "bfloat16",
3166
+ "use_mrope": false,
3167
+ "use_sliding_window": false,
3168
+ "vocab_size": 151936
3169
+ },
3170
+ "timm_model_ids": [
3171
+ "vit_large_patch14_reg4_dinov2.lvd142m",
3172
+ "vit_so400m_patch14_siglip_224"
3173
+ ],
3174
+ "timm_override_act_layers": [
3175
+ null,
3176
+ null
3177
+ ],
3178
+ "torch_dtype": "bfloat16",
3179
+ "transformers_version": "4.40.1",
3180
+ "use_fused_vision_backbone": true,
3181
+ "vision_backbone_id": "dinosiglip-vit-so-224px"
3182
+ }
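For reference, the `config.json` above pairs the fused `dinosiglip-vit-so-224px` vision backbone (DINOv2 + SigLIP) with a Qwen2-0.5B `text_config`. A minimal loading sketch, assuming the checkpoint's `config.json` registers the custom classes below via `auto_map` and that the placeholder repo id is replaced with the real Hub path:

    from transformers import AutoConfig

    # Placeholder repo id -- substitute the actual Hub path of this checkpoint.
    config = AutoConfig.from_pretrained("ORG/VLA-Adapter-checkpoint", trust_remote_code=True)
    print(config.vision_backbone_id)      # "dinosiglip-vit-so-224px"
    print(config.text_config.model_type)  # "qwen2"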
configuration_prismatic.py ADDED
@@ -0,0 +1,144 @@
1
+ """
2
+ configuration_prismatic.py
3
+
4
+ HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.
5
+ Default configuration specifies `siglip-224px+7b`.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ from transformers import PretrainedConfig
11
+ from transformers.models.auto import CONFIG_MAPPING
12
+
13
+ # === Utilities for Mapping Prismatic names to HF names ===
14
+ # fmt: off
15
+ VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {
16
+ "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224],
17
+
18
+ "clip-vit-l-336px": [336],
19
+ "siglip-vit-so400m-384px": [384],
20
+
21
+ "dinoclip-vit-l-336px": [336, 336],
22
+ "dinosiglip-vit-so-224px": [224, 224],
23
+ "dinosiglip-vit-so-384px": [384, 384],
24
+ }
25
+ VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {
26
+ "clip-vit-l": ["vit_large_patch14_clip_224.openai"],
27
+ "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"],
28
+
29
+ "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"],
30
+ "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"],
31
+
32
+ "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"],
33
+ "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"],
34
+
35
+ "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"],
36
+ "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"],
37
+ "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"],
38
+ }
39
+ TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {
40
+ "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"],
41
+ "dinov2-vit-l": [None], "in1k-vit-l": [None],
42
+ "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None],
43
+ "dinoclip-vit-l-336px": [None, "quick_gelu"],
44
+ "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None]
45
+ }
46
+
47
+ LLM_BACKBONE_TO_HF_PATH = {
48
+ "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf",
49
+ "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
50
+
51
+ "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5",
52
+
53
+ "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1",
54
+ "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1",
55
+
56
+ "phi-2-3b": "microsoft/phi-2",
57
+ "qwen25-0_5b-extra": "Qwen/Qwen2.5-0.5B", "qwen25-0_5b-pure": "Qwen/Qwen2.5-0.5B"
58
+
59
+
60
+ }
61
+ LLM_BACKBONE_TO_HF_METACLASS = {
62
+ "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama",
63
+ "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama",
64
+
65
+ "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral",
66
+
67
+ "phi-2-3b": "phi",
68
+ "qwen25-0_5b-extra": "qwen2" ,"qwen25-0_5b-pure": "qwen2"
69
+ }
70
+
71
+ VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())
72
+ VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)
73
+ # fmt: on
74
+
75
+
76
+ class PrismaticConfig(PretrainedConfig):
77
+ model_type: str = "prismatic"
78
+ is_composition: bool = False
79
+
80
+ def __init__(
81
+ self,
82
+ vision_backbone_id: str = "siglip-vit-so400m",
83
+ llm_backbone_id: str = "vicuna-v15-7b",
84
+ arch_specifier: str = "no-align+gelu-mlp",
85
+ use_fused_vision_backbone: Optional[bool] = None,
86
+ image_resize_strategy: str = "letterbox",
87
+ text_config: Optional[Dict[str, Any]] = None,
88
+ llm_max_length: int = 2048,
89
+ pad_token_id: int = 32000,
90
+ pad_to_multiple_of: int = 64,
91
+ output_projector_states: bool = False,
92
+ **kwargs: str,
93
+ ) -> None:
94
+ if vision_backbone_id not in VALID_VISION_BACKBONES:
95
+ raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }")
96
+
97
+ if llm_backbone_id not in VALID_LLM_BACKBONES:
98
+ raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }")
99
+
100
+ # Set Prismatic Configuration Fields
101
+ self.vision_backbone_id = vision_backbone_id
102
+ self.llm_backbone_id = llm_backbone_id
103
+ self.arch_specifier = arch_specifier
104
+ self.output_projector_states = output_projector_states
105
+
106
+ # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing
107
+ self.use_fused_vision_backbone = (
108
+ use_fused_vision_backbone
109
+ if use_fused_vision_backbone is not None
110
+ else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"])
111
+ )
112
+
113
+ self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]
114
+ self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]
115
+ self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]
116
+ self.image_resize_strategy = image_resize_strategy
117
+
118
+ self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]
119
+ self.llm_max_length = llm_max_length
120
+ self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of
121
+
122
+ # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!
123
+ self.text_config = (
124
+ CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)
125
+ if text_config is not None
126
+ else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()
127
+ )
128
+
129
+ # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...
130
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
131
+
132
+
133
+ class OpenVLAConfig(PrismaticConfig):
134
+ model_type: str = "openvla"
135
+
136
+ def __init__(
137
+ self,
138
+ norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,
139
+ n_action_bins: int = 256,
140
+ **kwargs: str,
141
+ ) -> None:
142
+ self.norm_stats, self.n_action_bins = norm_stats, n_action_bins
143
+
144
+ super().__init__(**kwargs)
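A quick usage sketch for the configuration class above (assuming `configuration_prismatic.py` is importable from the checkpoint directory; the values mirror the `config.json` shown earlier):

    from configuration_prismatic import OpenVLAConfig

    cfg = OpenVLAConfig(
        vision_backbone_id="dinosiglip-vit-so-224px",
        llm_backbone_id="qwen25-0_5b-pure",
        n_action_bins=256,
    )
    assert cfg.use_fused_vision_backbone          # inferred from the "dinosiglip" prefix
    assert cfg.hf_llm_id == "Qwen/Qwen2.5-0.5B"
    assert cfg.timm_model_ids == [
        "vit_large_patch14_reg4_dinov2.lvd142m",
        "vit_so400m_patch14_siglip_224",
    ]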
dataset_statistics.json ADDED
@@ -0,0 +1,133 @@
1
+ {
2
+ "libero_spatial_no_noops": {
3
+ "action": {
4
+ "mean": [
5
+ 0.15312479436397552,
6
+ 0.13707277178764343,
7
+ -0.15526802837848663,
8
+ -0.005176450591534376,
9
+ -0.01120874285697937,
10
+ -0.020194264128804207,
11
+ 0.4578818082809448
12
+ ],
13
+ "std": [
14
+ 0.41272708773612976,
15
+ 0.34724321961402893,
16
+ 0.50869220495224,
17
+ 0.037266165018081665,
18
+ 0.07244449853897095,
19
+ 0.05762382969260216,
20
+ 0.49827873706817627
21
+ ],
22
+ "max": [
23
+ 0.9375,
24
+ 0.9375,
25
+ 0.9375,
26
+ 0.1971428543329239,
27
+ 0.33642858266830444,
28
+ 0.375,
29
+ 1.0
30
+ ],
31
+ "min": [
32
+ -0.9375,
33
+ -0.9375,
34
+ -0.9375,
35
+ -0.1875,
36
+ -0.3675000071525574,
37
+ -0.36000001430511475,
38
+ 0.0
39
+ ],
40
+ "q01": [
41
+ -0.7454732114076613,
42
+ -0.6616071462631226,
43
+ -0.9375,
44
+ -0.1071428582072258,
45
+ -0.20678570866584778,
46
+ -0.1842857152223587,
47
+ 0.0
48
+ ],
49
+ "q99": [
50
+ 0.9375,
51
+ 0.8758928775787354,
52
+ 0.9321428537368774,
53
+ 0.1039285734295845,
54
+ 0.17678570747375488,
55
+ 0.14571428298950195,
56
+ 1.0
57
+ ],
58
+ "mask": [
59
+ true,
60
+ true,
61
+ true,
62
+ true,
63
+ true,
64
+ true,
65
+ false
66
+ ]
67
+ },
68
+ "proprio": {
69
+ "mean": [
70
+ -0.024462558329105377,
71
+ 0.106529600918293,
72
+ 1.0580483675003052,
73
+ 3.0628468990325928,
74
+ -0.10464039444923401,
75
+ 0.08307311683893204,
76
+ 0.01995457336306572,
77
+ -0.020162804052233696
78
+ ],
79
+ "std": [
80
+ 0.1101478561758995,
81
+ 0.13784688711166382,
82
+ 0.1044282391667366,
83
+ 0.10451053828001022,
84
+ 0.4112098217010498,
85
+ 0.2176690548658371,
86
+ 0.017260896041989326,
87
+ 0.0171116404235363
88
+ ],
89
+ "max": [
90
+ 0.1759040206670761,
91
+ 0.3904820382595062,
92
+ 1.3290715217590332,
93
+ 3.4566118717193604,
94
+ 1.2268599271774292,
95
+ 1.0429412126541138,
96
+ 0.041053611785173416,
97
+ 0.000775813648942858
98
+ ],
99
+ "min": [
100
+ -0.3095473051071167,
101
+ -0.29250794649124146,
102
+ 0.9095591306686401,
103
+ 2.497488260269165,
104
+ -1.8006486892700195,
105
+ -0.7207611203193665,
106
+ -0.0004703797458205372,
107
+ -0.041536275297403336
108
+ ],
109
+ "q01": [
110
+ -0.2727657300233841,
111
+ -0.23721413239836692,
112
+ 0.9160063165426254,
113
+ 2.77949666261673,
114
+ -1.3187511622905732,
115
+ -0.41989982962608335,
116
+ 0.001503719249740243,
117
+ -0.03989770736545324
118
+ ],
119
+ "q99": [
120
+ 0.13529365032911292,
121
+ 0.3629165390133857,
122
+ 1.2862326657772063,
123
+ 3.2829698753356933,
124
+ 0.9332760351896285,
125
+ 0.6325724506378171,
126
+ 0.039933966137468815,
127
+ -0.001671919699292631
128
+ ]
129
+ },
130
+ "num_transitions": 52970,
131
+ "num_trajectories": 432
132
+ }
133
+ }
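The `q01`/`q99` bounds above are what OpenVLA-style pipelines typically use to map actions into [-1, 1]; the last (gripper) dimension is excluded via `mask`. A minimal sketch of the inverse mapping, offered only as an illustration; the authoritative logic lives in the training/inference code, not in this file:

    import json
    import numpy as np

    stats = json.load(open("dataset_statistics.json"))["libero_spatial_no_noops"]["action"]
    q01, q99 = np.array(stats["q01"]), np.array(stats["q99"])
    mask = np.array(stats["mask"])  # False for the gripper dim, which stays un-normalized

    def unnormalize(action_in_minus1_1: np.ndarray) -> np.ndarray:
        # Invert x_norm = 2 * (x - q01) / (q99 - q01) - 1, only where mask is True.
        raw = 0.5 * (action_in_minus1_1 + 1.0) * (q99 - q01) + q01
        return np.where(mask, raw, action_in_minus1_1)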
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 151643,
4
+ "eos_token_id": 151643,
5
+ "pad_token_id": 32000,
6
+ "transformers_version": "4.40.1"
7
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff4549dd63f819f7da0ca174f60310f229f214b92c702407e983b3120ece2dea
3
+ size 2505232584
modeling_prismatic.py ADDED
@@ -0,0 +1,1499 @@
1
+ """
2
+ modeling_prismatic.py
3
+
4
+ Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
5
+ Inherits from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained,
6
+ but exactly replicate the logic in `prismatic.models.vlms.prismatic.py`.
7
+ """
8
+
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from functools import partial
12
+ from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
13
+
14
+ import numpy as np
15
+ import timm
16
+ import tokenizers
17
+ import torch
18
+ import torch.nn as nn
19
+ import transformers
20
+ from timm.models.vision_transformer import LayerScale
21
+ from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import ModelOutput
23
+
24
+ from prismatic.training.train_utils import (
25
+ get_current_action_mask,
26
+ get_next_actions_mask,
27
+ )
28
+ from prismatic.vla.constants import (
29
+ ACTION_DIM,
30
+ ACTION_PROPRIO_NORMALIZATION_TYPE,
31
+ ACTION_TOKEN_BEGIN_IDX,
32
+ IGNORE_INDEX,
33
+ NUM_ACTIONS_CHUNK,
34
+ STOP_INDEX,
35
+ NormalizationType,
36
+ NUM_TOKENS
37
+ )
38
+
39
+ from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
40
+
41
+
42
+
43
+ # Set up logger
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
+ # === Utility Functions for Monkey-Patching ===
48
+ def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
49
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
50
+ result = fn(*args, **kwargs)
51
+ return result[0] if isinstance(result, tuple) else result
52
+
53
+ return wrapper
54
+
55
+
56
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
57
+ # =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
58
+ # =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
59
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
60
+ return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
61
+
62
+
63
+ def ls_apply_patch(ls_module: LayerScale):
64
+ ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
65
+ ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
66
+ del ls_module.gamma
67
+
68
+
69
+ # === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
70
+ class PrismaticVisionBackbone(nn.Module):
71
+ """
72
+ Vision backbone for Prismatic models that handles image feature extraction.
73
+
74
+ Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
75
+ For fused backbones, features from both models are concatenated along the feature dimension.
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ use_fused_vision_backbone: bool,
81
+ image_sizes: List[int],
82
+ timm_model_ids: List[str],
83
+ timm_override_act_layers: List[Optional[str]],
84
+ ) -> None:
85
+ """
86
+ Initialize the vision backbone.
87
+
88
+ Args:
89
+ use_fused_vision_backbone: Whether to use two backbones and fuse their features
90
+ image_sizes: List of image sizes for each backbone
91
+ timm_model_ids: List of TIMM model IDs to use for each backbone
92
+ timm_override_act_layers: List of activation layer overrides for each backbone
93
+ """
94
+ super().__init__()
95
+ self.use_fused_vision_backbone = use_fused_vision_backbone
96
+ self.num_images_in_input = 1 # Default value, can be overridden later
97
+
98
+ # Validate number of (fused) vision backbones
99
+ if len(timm_model_ids) > 2:
100
+ raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")
101
+
102
+ # Create primary featurizer
103
+ self.featurizer = self._create_featurizer(
104
+ model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
105
+ )
106
+ self.embed_dim = self.featurizer.embed_dim
107
+
108
+ # Create secondary featurizer if using fused backbone
109
+ if self.use_fused_vision_backbone:
110
+ self.fused_featurizer = self._create_featurizer(
111
+ model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
112
+ )
113
+ self.embed_dim += self.fused_featurizer.embed_dim
114
+
115
+ # Patch LayerScale modules for HF compatibility
116
+ self._patch_layer_scales()
117
+
118
+ def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
119
+ """
120
+ Create a TIMM-based featurizer model with appropriate configurations.
121
+
122
+ Args:
123
+ model_id: The TIMM model ID to load
124
+ img_size: Input image size for the model
125
+ act_layer: Override for the activation layer type
126
+
127
+ Returns:
128
+ A configured featurizer model
129
+ """
130
+ featurizer = timm.create_model(
131
+ model_id,
132
+ pretrained=False,
133
+ num_classes=0,
134
+ img_size=img_size,
135
+ act_layer=act_layer,
136
+ )
137
+
138
+ # Monkey-patch the forward function to extract the second-to-last layer features
139
+ num_blocks = len(featurizer.blocks)
140
+ featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))
141
+
142
+ return featurizer
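# --- Illustrative sketch (editor annotation, not part of the original file) ---
# The monkey-patched forward above makes `featurizer(pixels)` return the patch tokens of the
# second-to-last transformer block. Assuming the 224px SigLIP backbone used by this checkpoint:
#
#   import timm, torch
#   vit = timm.create_model("vit_so400m_patch14_siglip_224", pretrained=False, num_classes=0, img_size=224)
#   x = torch.randn(1, 3, 224, 224)
#   (feats,) = vit.get_intermediate_layers(x, n={len(vit.blocks) - 2})
#   feats.shape  # (1, 256, 1152): 16x16 patches of size 14x14, embed dim 1152
# -------------------------------------------------------------------------------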
143
+
144
+ def _patch_layer_scales(self) -> None:
145
+ """
146
+ Patch all LayerScale modules to be compatible with HF's parameter naming.
147
+
148
+ HF Transformers overwrites parameters with names containing 'gamma',
149
+ so we need to rename and modify the forward method.
150
+ """
151
+ # Patch primary featurizer
152
+ for module in self.featurizer.modules():
153
+ if isinstance(module, LayerScale):
154
+ ls_apply_patch(module)
155
+
156
+ # Patch secondary featurizer if it exists
157
+ if self.use_fused_vision_backbone:
158
+ for module in self.fused_featurizer.modules():
159
+ if isinstance(module, LayerScale):
160
+ ls_apply_patch(module)
161
+
162
+ def get_num_patches(self) -> int:
163
+ """
164
+ Returns the number of vision patches output by the vision backbone.
165
+
166
+ Returns:
167
+ Number of patches per image
168
+ """
169
+ return self.featurizer.patch_embed.num_patches
170
+
171
+ def get_num_images_in_input(self) -> int:
172
+ """
173
+ Returns the number of input images for the vision backbone.
174
+
175
+ Returns:
176
+ Number of images expected in the input
177
+ """
178
+ return self.num_images_in_input
179
+
180
+ def set_num_images_in_input(self, num_images_in_input: int) -> None:
181
+ """
182
+ Sets the number of input images for the vision backbone.
183
+
184
+ Args:
185
+ num_images_in_input: Number of images to expect in the input
186
+ """
187
+ self.num_images_in_input = num_images_in_input
188
+
189
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
190
+ """
191
+ Implements the forward pass for the vision backbone.
192
+
193
+ If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
194
+ (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).
195
+
196
+ Args:
197
+ pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).
198
+ """
199
+ if self.num_images_in_input == 1:
200
+ if not self.use_fused_vision_backbone:
201
+ return self.featurizer(pixel_values)
202
+
203
+ # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
204
+ img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
205
+ patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
206
+
207
+ return torch.cat([patches, patches_fused], dim=2)
208
+
209
+ else:
210
+ assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"
211
+
212
+ # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)
213
+ images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)
214
+
215
+ # Process each image and collect patches
216
+ all_patches = []
217
+ for img in images:
218
+ # Split each image further into two stacks of channels (each with 3 channels)
219
+ img_regular, img_fused = torch.split(img, [3, 3], dim=1)
220
+
221
+ # Get patches from both SigLIP and DINOv2 vision transformers
222
+ patches = self.featurizer(img_regular)
223
+ patches_fused = self.fused_featurizer(img_fused)
224
+
225
+ # Concatenate SigLIP and DINOv2 patches along the hidden dimension
226
+ combined_patches = torch.cat([patches, patches_fused], dim=2)
227
+ all_patches.append(combined_patches)
228
+
229
+ # Concatenate all patches along the patch dimension
230
+ return torch.cat(all_patches, dim=1)
231
+
232
+
233
+ # === Prismatic Projector (nn.Module) Definitions ===
234
+ class PrismaticProjector(nn.Module):
235
+ def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
236
+ super().__init__()
237
+ self.use_fused_vision_backbone = use_fused_vision_backbone
238
+ self.vision_dim, self.llm_dim = vision_dim, llm_dim
239
+
240
+ # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
241
+ if not self.use_fused_vision_backbone:
242
+ self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
243
+ self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
244
+ self.act_fn1 = nn.GELU()
245
+ else:
246
+ initial_projection_dim = 4 * vision_dim
247
+ self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
248
+ self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
249
+ self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
250
+ self.act_fn1 = nn.GELU()
251
+ self.act_fn2 = nn.GELU()
252
+
253
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
254
+ if not self.use_fused_vision_backbone:
255
+ projected_features = self.fc1(img_patches)
256
+ projected_features = self.act_fn1(projected_features)
257
+ projected_features = self.fc2(projected_features)
258
+ else:
259
+ projected_features = self.fc1(img_patches)
260
+ projected_features = self.act_fn1(projected_features)
261
+ projected_features = self.fc2(projected_features)
262
+ projected_features = self.act_fn2(projected_features)
263
+ projected_features = self.fc3(projected_features)
264
+
265
+ return projected_features
266
+
267
+
268
+ # === Main HF Class Definitions ===
269
+ @dataclass
270
+ class PrismaticCausalLMOutputWithPast(ModelOutput):
271
+ """Base class for Prismatic casual (visually-conditioned) language model outputs; also exposes visual features."""
272
+
273
+ loss: Optional[torch.FloatTensor] = None
274
+ logits: torch.FloatTensor = None
275
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
276
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
277
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
278
+
279
+ # Additions for VLMs
280
+ projector_features: Optional[torch.FloatTensor] = None
281
+
282
+
283
+ class PrismaticPreTrainedModel(PreTrainedModel):
284
+ config_class: PretrainedConfig = PrismaticConfig
285
+ base_model_prefix: str = "model"
286
+ supports_gradient_checkpointing: bool = True
287
+
288
+ _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
289
+ _skip_keys_device_placement: str = "past_key_values"
290
+ _supports_flash_attn_2: bool = True
291
+
292
+ def _init_weights(self, module: nn.Module) -> None:
293
+ # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
294
+ # => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
295
+ # https://github.com/TRI-ML/prismatic-vlms
296
+ std = (
297
+ self.config.initializer_range
298
+ if hasattr(self.config, "initializer_range")
299
+ else self.config.text_config.initializer_range
300
+ )
301
+
302
+ if hasattr(module, "class_embedding"):
303
+ module.class_embedding.data.normal_(mean=0.0, std=std)
304
+
305
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
306
+ module.weight.data.normal_(mean=0.0, std=std)
307
+ if module.bias is not None:
308
+ module.bias.data.zero_()
309
+ elif isinstance(module, nn.Embedding):
310
+ module.weight.data.normal_(mean=0.0, std=std)
311
+ if module.padding_idx is not None:
312
+ module.weight.data[module.padding_idx].zero_()
313
+
314
+ @property
315
+ def _supports_sdpa(self) -> bool:
316
+ """Check LLM supports SDPA Attention"""
317
+ return self.language_model._supports_sdpa
318
+
319
+
320
+ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
321
+ def __init__(self, config: PrismaticConfig) -> None:
322
+ super().__init__(config)
323
+
324
+ # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
325
+ if config.use_fused_vision_backbone is None:
326
+ raise ValueError("Missing config field `use_fused_vision_backbone`")
327
+
328
+ if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
329
+ raise NotImplementedError(
330
+ "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
331
+ "if you urgently need support for latest TIMM versions."
332
+ )
333
+
334
+ if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
335
+ logger.warning(
336
+ f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
337
+ f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
338
+ f"there might be inference-time regressions due to dependency changes. If in doubt, please"
339
+ f"use the above versions."
340
+ )
341
+ # import pdb; pdb.set_trace()
342
+ # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
343
+ self.vision_backbone = PrismaticVisionBackbone(
344
+ config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
345
+ )
346
+
347
+ # Create Multimodal Projector
348
+ self.projector = PrismaticProjector(
349
+ config.use_fused_vision_backbone,
350
+ vision_dim=self.vision_backbone.embed_dim,
351
+ llm_dim=config.text_config.hidden_size,
352
+ )
353
+
354
+ # Instantiate LLM Backbone
355
+ self.language_model = AutoModelForCausalLM.from_config(
356
+ config.text_config, attn_implementation=config._attn_implementation
357
+ )
358
+
359
+ self.vocab_size = config.text_config.vocab_size
360
+ self.pad_token_id = config.pad_token_id
361
+ self.llm_dim = config.text_config.hidden_size
362
+
363
+ # Learnable action query tokens, spliced in at the action-token positions of the input sequence
364
+ self.action_queries = nn.Embedding(NUM_TOKENS, self.llm_dim)
365
+ self.action_queries.weight.data.zero_()
366
+
367
+ # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
368
+ self.post_init()
369
+
370
+ # === `PreTrainedModel` Boilerplate ===
371
+ def get_input_embeddings(self) -> nn.Module:
372
+ return self.language_model.get_input_embeddings()
373
+ def set_version(self, version: str):
374
+ self.version = version
375
+ return self.version
376
+
377
+
378
+ def set_input_embeddings(self, value: nn.Module) -> None:
379
+ self.language_model.set_input_embeddings(value)
380
+
381
+ def get_output_embeddings(self) -> nn.Module:
382
+ return self.language_model.get_output_embeddings()
383
+
384
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
385
+ self.language_model.set_output_embeddings(new_embeddings)
386
+
387
+ def get_decoder(self) -> nn.Module:
388
+ return self.language_model.get_decoder()
389
+
390
+ def set_decoder(self, decoder: nn.Module) -> None:
391
+ self.language_model.set_decoder(decoder)
392
+
393
+ def tie_weights(self) -> None:
394
+ self.language_model.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
395
+
396
+ def resize_token_embeddings(
397
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
398
+ ) -> nn.Embedding:
399
+ updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
400
+
401
+ # Update config/instance variables
402
+ self.config.text_config.vocab_size = updated_embeddings.num_embeddings
403
+ self.vocab_size = updated_embeddings.num_embeddings
404
+
405
+ return updated_embeddings
406
+
407
+ def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
408
+ """
409
+ Replace embeddings in input_embeddings at positions where all_actions_mask is True
410
+ with embeddings from noisy_action_features, using vectorized operations.
411
+
412
+ Args:
413
+ input_embeddings: Tensor of shape (B, S, D)
414
+ all_actions_mask: Boolean tensor of shape (B, S)
415
+ noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample
416
+
417
+ Returns:
418
+ Modified input_embeddings tensor
419
+ """
420
+ # Clone input to avoid modifying the original tensor
421
+ new_input_embeddings = input_embeddings.clone()
422
+
423
+ # Create a tensor with the same shape of input_embeddings to hold the noisy action features
424
+ repositioned_noisy_action_features = torch.zeros_like(input_embeddings)
425
+
426
+ # Create batch indices for splicing
427
+ batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
428
+ batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])
429
+
430
+ # Get indices where mask is True for each sample
431
+ masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])
432
+
433
+ # Move the noisy action features into their correct positions
434
+ # print(noisy_action_features.size())
435
+ # import pdb; pdb.set_trace()
436
+ repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features
437
+
438
+ # Combine original input embeddings and noisy action embeddings using the mask
439
+ new_input_embeddings = torch.where(
440
+ all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings
441
+ )
442
+
443
+ return new_input_embeddings
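# --- Illustrative sketch (editor annotation, not part of the original file) ---
# Toy example of the splice above, with B=1, S=4, D=2 and a mask selecting positions 1 and 3:
#
#   emb   = torch.zeros(1, 4, 2)
#   mask  = torch.tensor([[False, True, False, True]])
#   feats = torch.ones(1, 2, 2)
#   out   = self._replace_input_embeddings(emb, mask, feats)
#   # out[0] == [[0, 0], [1, 1], [0, 0], [1, 1]] -- rows of `feats` land at the masked slots
# -------------------------------------------------------------------------------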
444
+
445
+ def _process_action_masks(self, labels):
446
+ """Helper to get action masks from labels"""
447
+ current_action_mask = get_current_action_mask(labels)
448
+ next_actions_mask = get_next_actions_mask(labels)
449
+ all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
450
+ return all_actions_mask
451
+
452
+ def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
453
+ """Process vision features with optional FiLM conditioning"""
454
+ if use_film:
455
+ # FiLM: Infuse language inputs into visual features
456
+ patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D)
457
+ else:
458
+ patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
459
+
460
+ # Project patch embeddings into language embedding space
461
+ return self.projector(patch_features)
462
+
463
+ def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
464
+ """Process proprioceptive features and append to vision features"""
465
+ if proprio_projector is not None and proprio is not None:
466
+ # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
467
+ # proprio: (bsz, proprio_dim) or (proprio_dim,)
468
+ proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) # (bsz, proprio_dim)
469
+ proprio_features = proprio_projector(proprio) # (bsz, llm_dim)
470
+ proprio_features = proprio_features.unsqueeze(dim=1) # (bsz, 1, llm_dim)
471
+ # For simplicity, just append proprio token to the end of projected vision patch tokens
472
+ return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
473
+ return projected_patch_embeddings
474
+
475
+ def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
476
+ """Build multimodal embeddings and attention mask"""
477
+ # Update attention mask
478
+ # import pdb; pdb.set_trace()
479
+ projected_patch_attention_mask = None
480
+ if attention_mask is not None:
481
+ projected_patch_attention_mask = torch.full(
482
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
483
+ fill_value=True,
484
+ dtype=attention_mask.dtype,
485
+ device=attention_mask.device,
486
+ )
487
+
488
+ # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
489
+ multimodal_embeddings = torch.cat(
490
+ [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
491
+ )
492
+
493
+ multimodal_attention_mask = None
494
+ if attention_mask is not None:
495
+ multimodal_attention_mask = torch.cat(
496
+ [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
497
+ )
498
+
499
+ return multimodal_embeddings, multimodal_attention_mask
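# --- Illustrative sketch (editor annotation, not part of the original file) ---
# Layout of the multimodal sequence returned above:
#
#   [ <BOS> | projected patch tokens (plus any proprio / diffusion-timestep tokens already
#     appended to them) | remaining prompt and action tokens ... ]
#
# The attention mask is extended with all-True entries covering the spliced-in patch tokens.
# -------------------------------------------------------------------------------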
500
+
501
+ def _build_multimodal_labels(self, labels, projected_patch_embeddings):
502
+ """Build multimodal labels with IGNORE_INDEX for patch embeddings"""
503
+ if labels is not None:
504
+ projected_patch_labels = torch.full(
505
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
506
+ fill_value=IGNORE_INDEX,
507
+ dtype=labels.dtype,
508
+ device=labels.device,
509
+ )
510
+ return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
511
+ return None
512
+
513
+ # === Core Prismatic VLM `forward()` Logic ===
514
+ def forward(
515
+ self,
516
+ input_ids: Optional[torch.LongTensor] = None,
517
+ attention_mask: Optional[torch.Tensor] = None,
518
+ pixel_values: Optional[torch.FloatTensor] = None,
519
+ labels: Optional[torch.LongTensor] = None,
520
+ inputs_embeds: Optional[torch.FloatTensor] = None,
521
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
522
+ use_cache: Optional[bool] = None,
523
+ output_attentions: Optional[bool] = None,
524
+ output_hidden_states: Optional[bool] = None,
525
+ output_projector_features: Optional[bool] = None,
526
+ return_dict: Optional[bool] = None,
527
+ proprio=None,
528
+ proprio_projector=None,
529
+ noisy_actions=None,
530
+ noisy_action_projector=None,
531
+ diffusion_timestep_embeddings=None,
532
+ use_film: bool = False,
533
+ ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
534
+ """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
535
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
536
+ output_hidden_states = (
537
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
538
+ )
539
+ output_projector_features = output_projector_features if output_projector_features is not None else False
540
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
541
+
542
+ # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
543
+ use_cache = use_cache and not self.training
544
+
545
+ # Instantiate Placeholder for Projector Features
546
+ projected_patch_embeddings = None
547
+
548
+ # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
549
+ if input_ids.shape[1] == 1:
550
+ assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
551
+ assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
552
+ assert labels is None, "Unexpected key `labels` provided during cached generation!"
553
+
554
+ language_model_output = self.language_model(
555
+ input_ids=input_ids,
556
+ attention_mask=None,
557
+ position_ids=None,
558
+ past_key_values=past_key_values,
559
+ inputs_embeds=None,
560
+ labels=None,
561
+ use_cache=use_cache,
562
+ output_attentions=output_attentions,
563
+ output_hidden_states=output_hidden_states,
564
+ return_dict=return_dict,
565
+ )
566
+
567
+ # === Handle Unimodal Forward ===
568
+ elif pixel_values is None:
569
+ assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
570
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
571
+
572
+ language_model_output = self.language_model(
573
+ input_ids=input_ids,
574
+ attention_mask=attention_mask,
575
+ position_ids=None,
576
+ past_key_values=None,
577
+ inputs_embeds=None,
578
+ labels=labels,
579
+ use_cache=use_cache,
580
+ output_attentions=output_attentions,
581
+ output_hidden_states=output_hidden_states,
582
+ return_dict=return_dict,
583
+ )
584
+
585
+ # === Handle Multimodal Forward ===
586
+ elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
587
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
588
+
589
+ # Get input embeddings (from language model embeddings)
590
+ input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
591
+
592
+ # import pdb; pdb.set_trace()
593
+ # Extract action masks
594
+ all_actions_mask = self._process_action_masks(labels)
595
+
596
+ # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
597
+ # import pdb; pdb.set_trace()
598
+ # print(input_embeddings[~all_actions_mask].size())
599
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
600
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
601
+ ) # (B, lang_seq_len, llm_dim)
602
+
603
+ # Get visual features
604
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
605
+
606
+ # Add proprioceptive state if provided
607
+ if self.version == 'v1':
608
+ pass
609
+ else:
610
+ projected_patch_embeddings = self._process_proprio_features(
611
+ projected_patch_embeddings, proprio, proprio_projector
612
+ )
613
+
614
+ # [Diffusion] Add diffusion timestep embedding if provided
615
+ if diffusion_timestep_embeddings is not None:
616
+ if self.version == 'v1':
617
+ pass
618
+ else:
619
+ # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens
620
+ projected_patch_embeddings = torch.cat(
621
+ (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
622
+ )
623
+
624
+
625
+ # Process action embeddings
626
+ if noisy_actions is not None:
627
+ # import pdb; pdb.set_trace()
628
+ if self.version == 'v1':
629
+ # action_queries = self.action_queries.weight # (1, h)
630
+ # action_queries = action_queries.view(1, 1, action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
631
+ # input_embeddings = torch.cat((input_embeddings, action_queries), dim=1) # (b, n_tokens+chunk_size, h)
632
+ # action_attention_mask = None
633
+ # action_attention_mask = torch.full(
634
+ # (action_queries.shape[0], action_queries.shape[1]),
635
+ # fill_value=True,
636
+ # dtype=attention_mask.dtype,
637
+ # device=attention_mask.device,)
638
+ # attention_mask = torch.cat([attention_mask, action_attention_mask], dim=1)
639
+
640
+ action_queries = self.action_queries.weight # (NUM_TOKENS, h)
641
+ action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, NUM_TOKENS, h)
642
+ all_actions_mask = self._process_action_masks(labels)
643
+ input_embeddings = self._replace_input_embeddings(
644
+ input_embeddings, all_actions_mask, action_queries)
645
+ # import pdb; pdb.set_trace()
646
+
647
+ else:
648
+ # Get mask corresponding to all action tokens
649
+ all_actions_mask = self._process_action_masks(labels)
650
+
651
+ # Reshape noisy actions into individual action tokens
652
+ # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
653
+ B = noisy_actions.shape[0]
654
+ noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)
655
+ # Project noisy action tokens into language model embedding space
656
+ noisy_action_features = noisy_action_projector(noisy_actions) # (B, chunk_len * action_dim, llm_dim)
657
+ # Replace embeddings of the action tokens with noisy action embeddings
658
+ input_embeddings = self._replace_input_embeddings(
659
+ input_embeddings, all_actions_mask, noisy_action_features)
660
+
661
+ else:
662
+ if self.version == 'v1':
663
+ action_queries = self.action_queries.weight # (NUM_TOKENS, h)
664
+ action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, NUM_TOKENS, h)
665
+ all_actions_mask = self._process_action_masks(labels)
666
+ input_embeddings = self._replace_input_embeddings(
667
+ input_embeddings, all_actions_mask, action_queries)
668
+ # import pdb; pdb.set_trace()
669
+ else:
670
+ # Replace the embeddings of the action tokens with zeros
671
+ # (Later on, the positional embeddings will be added to them)
672
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
673
+ input_embeddings = input_embeddings * ~all_actions_mask
674
+
675
+
676
+ # Build multimodal embeddings & attention mask
677
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
678
+ input_embeddings, projected_patch_embeddings, attention_mask
679
+ )
680
+ # import pdb; pdb.set_trace()
681
+ # Build labels for multimodal sequence if needed
682
+ multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
683
+
684
+ # import pdb; pdb.set_trace()
685
+ # Dispatch to language model
686
+ if self.version == 'v1':
687
+ # import pdb; pdb.set_trace()
688
+ language_model_output = self.language_model(
689
+ input_ids=None,
690
+ attention_mask=multimodal_attention_mask,
691
+ position_ids=None,
692
+ past_key_values=None,
693
+ inputs_embeds=multimodal_embeddings,
694
+ labels=None,
695
+ use_cache=use_cache,
696
+ output_attentions=output_attentions,
697
+ output_hidden_states=output_hidden_states,
698
+ return_dict=return_dict,
699
+ )
700
+ # import pdb; pdb.set_trace()
701
+ else:
702
+ language_model_output = self.language_model(
703
+ input_ids=None,
704
+ attention_mask=multimodal_attention_mask,
705
+ position_ids=None,
706
+ past_key_values=None,
707
+ inputs_embeds=multimodal_embeddings,
708
+ labels=multimodal_labels,
709
+ use_cache=use_cache,
710
+ output_attentions=output_attentions,
711
+ output_hidden_states=output_hidden_states,
712
+ return_dict=return_dict,
713
+ )
714
+
715
+ # === Otherwise =>> Assume Invalid! ===
716
+ elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
717
+ raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
718
+
719
+ else:
720
+ raise ValueError(
721
+ "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
722
+ f"=> `input_ids` = {input_ids is not None}\n"
723
+ f"=> `attention_mask` = {attention_mask is not None}\n"
724
+ f"=> `pixel_values` = {pixel_values is not None}\n"
725
+ f"=> `labels` = {labels is not None}\n"
726
+ f"=> `input_embeds` = {inputs_embeds is not None}\n"
727
+ f"=> `past_key_values` = {past_key_values is not None}\n"
728
+ f"=> `use_cache` = {use_cache}"
729
+ )
730
+
731
+ # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
732
+ if not return_dict:
733
+ if output_projector_features and (projected_patch_embeddings is not None):
734
+ return *language_model_output, projected_patch_embeddings
735
+
736
+ return language_model_output
737
+
738
+ if self.version == 'v1':
739
+ return PrismaticCausalLMOutputWithPast(
740
+ loss=language_model_output.loss,
741
+ past_key_values=language_model_output.past_key_values,
742
+ hidden_states=language_model_output.hidden_states,
743
+ attentions=language_model_output.attentions,
744
+ projector_features=projected_patch_embeddings,
745
+ )
746
+ else:
747
+ return PrismaticCausalLMOutputWithPast(
748
+ loss=language_model_output.loss,
749
+ logits=language_model_output.logits,
750
+ past_key_values=language_model_output.past_key_values,
751
+ hidden_states=language_model_output.hidden_states,
752
+ attentions=language_model_output.attentions,
753
+ projector_features=projected_patch_embeddings,
754
+ )
755
+
756
+ # === GenerationMixin Methods ===
757
+ def prepare_inputs_for_generation(
758
+ self,
759
+ input_ids: Optional[torch.Tensor] = None,
760
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
761
+ inputs_embeds: Optional[torch.FloatTensor] = None,
762
+ pixel_values: Optional[torch.FloatTensor] = None,
763
+ attention_mask: Optional[torch.Tensor] = None,
764
+ **kwargs: str,
765
+ ) -> Dict[str, torch.Tensor]:
766
+ """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
767
+ if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
768
+ (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
769
+ ):
770
+ raise ValueError("Generation with batch size > 1 is not currently supported!")
771
+
772
+ # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
773
+ if past_key_values is not None:
774
+ input_ids = input_ids[:, -1:]
775
+
776
+ # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
777
+ if inputs_embeds is not None and past_key_values is None:
778
+ model_inputs = {"inputs_embeds": inputs_embeds}
779
+ else:
780
+ model_inputs = {"input_ids": input_ids}
781
+
782
+ # Make sure `pixel_values` are preserved in `model_inputs`
783
+ model_inputs.update(
784
+ {
785
+ "attention_mask": attention_mask,
786
+ "pixel_values": pixel_values,
787
+ "past_key_values": past_key_values,
788
+ "use_cache": kwargs.get("use_cache"),
789
+ }
790
+ )
791
+
792
+ return model_inputs
793
+
794
+ # Defer to Language Model (all handle this differently, with different return types)
795
+ def _reorder_cache(self, *args, **kwargs) -> Any:
796
+ return self.language_model._reorder_cache(*args, **kwargs)
797
+
798
+
799
+ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
800
+ config_class: PretrainedConfig = OpenVLAConfig
801
+
802
+ def __init__(self, config: OpenVLAConfig) -> None:
803
+ super().__init__(config)
804
+ self.norm_stats = config.norm_stats
805
+ # import pdb; pdb.set_trace()
806
+
807
+ # Compute action bins
808
+ self.bins = np.linspace(-1, 1, config.n_action_bins)
809
+ self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
810
+
811
+ # Compute vocab size for de-tokenization -- strip the padding that rounded the vocab up to a multiple of `pad_to_multiple_of`
812
+ self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
813
+
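For readers following the discrete-token path, the sketch below shows how the bin edges, bin centers, and the padding-adjusted vocabulary size defined above interact when a predicted token id is mapped back to a normalized action value. All constants are illustrative assumptions, not values read from this checkpoint's config.

```python
import numpy as np

# Illustrative values only -- the real ones come from OpenVLAConfig.
n_action_bins = 256
text_vocab_size = 151_665   # hypothetical tokenizer vocab size
pad_to_multiple_of = 64     # hypothetical embedding-resize padding

bins = np.linspace(-1, 1, n_action_bins)           # 256 bin edges over [-1, 1]
bin_centers = (bins[:-1] + bins[1:]) / 2.0         # 255 bin centers
vocab_size = text_vocab_size - pad_to_multiple_of  # drop the rounding padding

# De-tokenization (mirrors _regression_or_discrete_prediction below):
predicted_token_id = vocab_size - 37               # a token near the end of the vocab
discretized = np.clip(vocab_size - predicted_token_id - 1, 0, bin_centers.shape[0] - 1)
print(bin_centers[discretized])                    # normalized action value in [-1, 1]
```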
814
+ def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
815
+ """Prepares input for action prediction by adding necessary tokens"""
816
+ # Add NUM_TOKENS placeholder tokens to input_ids to stand in for the action tokens
817
+ placeholder_action_token_ids = (
818
+ torch.ones((input_ids.shape[0], NUM_TOKENS)).to(input_ids.device).to(input_ids.dtype)
819
+ )
820
+ input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
821
+
822
+ # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
823
+ stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
824
+ input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
825
+
826
+ # Extend the attention mask to fit the new shape of input
827
+ # Note: Only batch size == 1 supported right now
828
+ mask_extension = (
829
+ torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
830
+ .to(attention_mask.device)
831
+ .to(attention_mask.dtype)
832
+ )
833
+ attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
834
+
835
+ return input_ids, attention_mask
836
+
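The shape bookkeeping done by `_prepare_input_for_action_prediction` can be summarized with the toy example below; `NUM_TOKENS` and `STOP_INDEX` are assumed placeholder values, since the real constants are defined elsewhere in the package.

```python
import torch

NUM_TOKENS, STOP_INDEX = 64, 2                     # assumed values for illustration
input_ids = torch.tensor([[101, 5, 6, 7]])         # (B=1, prompt_len=4)
attention_mask = torch.ones_like(input_ids)

placeholders = torch.ones((1, NUM_TOKENS), dtype=input_ids.dtype)
input_ids = torch.cat([input_ids, placeholders], dim=-1)                    # prompt + action placeholders
input_ids = torch.cat([input_ids, torch.full((1, 1), STOP_INDEX)], dim=-1)  # + stop token

extension = torch.ones((1, input_ids.shape[-1] - attention_mask.shape[-1]), dtype=attention_mask.dtype)
attention_mask = torch.cat([attention_mask, extension], dim=-1)

print(input_ids.shape, attention_mask.shape)       # both (1, 4 + 64 + 1)
```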
837
+ def _prepare_labels_for_action_prediction(self, labels, input_ids):
838
+ """Creates labels tensor for action prediction if not provided"""
839
+ # Extend labels tensor with fake action labels
840
+ ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
841
+ labels_extension = (
842
+ torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
843
+ * ARBITRARY_ACTION_TOKEN_IDX
844
+ )
845
+ labels = torch.cat([labels, labels_extension], dim=-1)
846
+
847
+ # Replace last label token with stop token
848
+ labels[:, -1] = STOP_INDEX
849
+
850
+ return labels
851
+
852
+ def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
853
+ """Unnormalize actions using dataset statistics"""
854
+ action_norm_stats = self.get_action_stats(unnorm_key)
855
+
856
+ if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
857
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
858
+ action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
859
+ elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
860
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
861
+ action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
862
+ else:
863
+ raise ValueError("Unsupported action/proprio normalization type detected!")
864
+
865
+ actions = np.where(
866
+ mask,
867
+ 0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
868
+ normalized_actions,
869
+ )
870
+
871
+ return actions
872
+
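As a quick numeric check of the affine unnormalization above (BOUNDS_Q99 branch), with made-up 1st/99th-percentile statistics for a single action dimension:

```python
import numpy as np

q01, q99 = -0.05, 0.07          # made-up dataset statistics for one action dimension
normalized = np.array([0.5])    # model output in [-1, 1]

unnormalized = 0.5 * (normalized + 1) * (q99 - q01 + 1e-8) + q01
print(unnormalized)             # ~0.04, i.e. q01 + 0.75 * (q99 - q01)
```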
873
+ def _run_flow_matching_prediction(
874
+ self,
875
+ input_embeddings,
876
+ all_actions_mask,
877
+ noise,
878
+ action_head,
879
+ projected_patch_embeddings,
880
+ labels,
881
+ attention_mask,
882
+ NUM_PATCHES,
883
+ NUM_PROMPT_TOKENS,
884
+ noisy_action_projector
885
+ ):
886
+ """Run flow matching-based action prediction"""
887
+ # Clone embedding for reuse in each timestep
888
+ # orig_projected_patch_embeddings = projected_patch_embeddings.clone()
889
+
890
+ dt = -1.0 / action_head.num_flow_steps
891
+ dt = torch.tensor(dt, dtype=torch.bfloat16, device=labels.device)
892
+
893
+ curr_noisy_actions = noise
894
+ time = torch.tensor(1.0, dtype=torch.bfloat16, device=labels.device)
895
+ while time >= -dt / 2:
896
+ B = curr_noisy_actions.shape[0]
897
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
898
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
899
+ noisy_action_features = noisy_action_projector(curr_noisy_actions)
900
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
901
+
902
+ # Replace action token embeddings with noisy action embeddings
903
+ input_embeddings = self._replace_input_embeddings(
904
+ input_embeddings.clone(), all_actions_mask, noisy_action_features
905
+ )
906
+
907
+ # Build multimodal embeddings and attention mask
908
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
909
+ input_embeddings, projected_patch_embeddings, attention_mask
910
+ )
911
+
912
+ # Forward pass through language model
913
+ language_model_output = self.language_model(
914
+ input_ids=None,
915
+ attention_mask=multimodal_attention_mask,
916
+ position_ids=None,
917
+ past_key_values=None,
918
+ inputs_embeds=multimodal_embeddings,
919
+ labels=None,
920
+ use_cache=None,
921
+ output_attentions=False,
922
+ output_hidden_states=True,
923
+ return_dict=True,
924
+ )
925
+
926
+ # Extract hidden states for action portion of response
927
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
928
+ actions_hidden_states = last_hidden_states[
929
+ :,
930
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
931
+ :,
932
+ ] # (B, act_chunk_len, D)
933
+
934
+ # Predict the flow (velocity) and take an Euler step: x_t -> x_{t+dt}
935
+ flow_pred = action_head.predict_flow(actions_hidden_states)
936
+ curr_noisy_actions += dt * flow_pred
937
+ time += dt
938
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
939
+
940
+ # Return final actions
941
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
942
+
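The loop above is plain Euler integration of a learned velocity field from t = 1 (noise) to t = 0 (actions). The self-contained sketch below reproduces that integration with a stand-in for `action_head.predict_flow`; the constant-velocity "model" and the chunk geometry are assumptions for illustration.

```python
import torch

NUM_ACTIONS_CHUNK, ACTION_DIM = 8, 7          # assumed chunk geometry
num_flow_steps = 10

# With the linear path x_t = t * noise + (1 - t) * data, the target velocity is
# (noise - data); pretend the flow head has learned it exactly for a fixed target.
target = torch.full((1, NUM_ACTIONS_CHUNK, ACTION_DIM), 0.3)
noise = torch.randn_like(target)

def predict_flow_stub(x_t, t):
    # Real model: action_head.predict_flow(actions_hidden_states), conditioned on the VLM.
    return noise - target

dt = torch.tensor(-1.0 / num_flow_steps)
x, t = noise.clone(), torch.tensor(1.0)
while t >= -dt / 2:                            # same stopping rule as the loop above
    x = x + dt * predict_flow_stub(x, t)       # Euler step: x_t -> x_{t+dt}
    t = t + dt

print(torch.allclose(x, target, atol=1e-5))    # integrates back to the target chunk
```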
943
+
944
+ def _run_diffusion_prediction(
945
+ self,
946
+ input_embeddings,
947
+ all_actions_mask,
948
+ noise,
949
+ action_head,
950
+ projected_patch_embeddings,
951
+ labels,
952
+ attention_mask,
953
+ NUM_PATCHES,
954
+ NUM_PROMPT_TOKENS,
955
+ noisy_action_projector,
956
+ ):
957
+ """Run diffusion-based action prediction"""
958
+ # Set diffusion timestep values
959
+ action_head.noise_scheduler.set_timesteps(action_head.num_diffusion_steps)
960
+ # Clone embedding for reuse in each timestep
961
+ orig_projected_patch_embeddings = projected_patch_embeddings.clone()
962
+ curr_noisy_actions = noise
963
+
964
+ # Reverse diffusion: Iteratively denoise to generate action prediction
965
+ for t in action_head.noise_scheduler.timesteps:
966
+ # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
967
+ # embedding, and diffusion timestep embedding)
968
+ timesteps = torch.Tensor([t]).to(labels.device)
969
+ diffusion_timestep_embeddings = (
970
+ action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
971
+ ) # (B, llm_dim)
972
+ diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
973
+
974
+ # [Diffusion] Replace the embeddings of the action tokens with noisy actions
975
+ # (Later on, the positional embeddings will be added to them)
976
+
977
+ # For simplicity, append diffusion timestep embedding to the end of projected vision tokens
978
+ projected_patch_embeddings = torch.cat(
979
+ (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
980
+ )
981
+
982
+ # Reshape and project noisy actions into language embedding space
983
+ B = curr_noisy_actions.shape[0]
984
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
985
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
986
+ noisy_action_features = noisy_action_projector(curr_noisy_actions)
987
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
988
+
989
+ # Replace action token embeddings with noisy action embeddings
990
+ input_embeddings = self._replace_input_embeddings(
991
+ input_embeddings.clone(), all_actions_mask, noisy_action_features
992
+ )
993
+
994
+ # Build multimodal embeddings and attention mask
995
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
996
+ input_embeddings, projected_patch_embeddings, attention_mask
997
+ )
998
+
999
+ # Forward pass through language model
1000
+ language_model_output = self.language_model(
1001
+ input_ids=None,
1002
+ attention_mask=multimodal_attention_mask,
1003
+ position_ids=None,
1004
+ past_key_values=None,
1005
+ inputs_embeds=multimodal_embeddings,
1006
+ labels=None,
1007
+ use_cache=None,
1008
+ output_attentions=False,
1009
+ output_hidden_states=True,
1010
+ return_dict=True,
1011
+ )
1012
+
1013
+ # Extract hidden states for action portion of response
1014
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
1015
+ actions_hidden_states = last_hidden_states[
1016
+ :,
1017
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1018
+ :,
1019
+ ] # (B, act_chunk_len, D)
1020
+
1021
+ # Predict noise and update noisy actions: x_t -> x_{t-1}
1022
+ noise_pred = action_head.predict_noise(actions_hidden_states)
1023
+ curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
1024
+
1025
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1026
+
1027
+ # Return final actions
1028
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
1029
+
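Structurally, the reverse-diffusion loop above is the standard scheduler-driven denoising loop. The sketch below shows the same skeleton with a `diffusers` DDIM scheduler and a zero-noise stub in place of the VLM-conditioned `predict_noise`; the scheduler choice, step count, and chunk geometry are assumptions, not read from the action head shipped with this checkpoint.

```python
import torch
from diffusers import DDIMScheduler            # assumed; the real action head bundles its own scheduler

NUM_ACTIONS_CHUNK, ACTION_DIM = 8, 7           # assumed chunk geometry
scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(10)

def predict_noise_stub(x_t, t):
    # Stand-in for action_head.predict_noise(actions_hidden_states), which is
    # conditioned on the VLM's action hidden states and a timestep embedding.
    return torch.zeros_like(x_t)

x = torch.randn(1, NUM_ACTIONS_CHUNK, ACTION_DIM)     # start from Gaussian noise
for t in scheduler.timesteps:
    noise_pred = predict_noise_stub(x, t)
    x = scheduler.step(noise_pred, t, x).prev_sample  # x_t -> x_{t-1}

print(x.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM).shape) # (8, 7) normalized action chunk
```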
1030
+ def _run_diffusion_prediction_V1(
1031
+ self,
1032
+ input_embeddings,
1033
+ all_actions_mask,
1034
+ noise,
1035
+ action_head,
1036
+ projected_patch_embeddings,
1037
+ labels,
1038
+ attention_mask,
1039
+ NUM_PATCHES,
1040
+ NUM_PROMPT_TOKENS,
1041
+ noisy_action_projector,
1042
+ proprio,
1043
+ proprio_projector,
1044
+ ):
1045
+ """Run diffusion-based action prediction"""
1046
+ # Set diffusion timestep values
1047
+ action_head.noise_scheduler.set_timesteps(action_head.num_diffusion_steps)
1048
+ # Clone embedding for reuse in each timestep
1049
+ curr_noisy_actions = noise
1050
+
1051
+ # import pdb; pdb.set_trace()
1052
+
1053
+ action_queries = self.action_queries.weight # (1, h)
1054
+ action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
1055
+ # Replace action token embeddings with the learned action query embeddings
1056
+ input_embeddings = self._replace_input_embeddings(input_embeddings.clone(), all_actions_mask, action_queries)
1057
+ # input_embeddings = torch.cat((input_embeddings, action_queries), dim=1) # (b, n_tokens+chunk_size, h)
1058
+ # action_attention_mask = None
1059
+ # action_attention_mask = torch.full(
1060
+ # (action_queries.shape[0], action_queries.shape[1]),
1061
+ # fill_value=True,
1062
+ # dtype=attention_mask.dtype,
1063
+ # device=attention_mask.device,)
1064
+ # attention_mask = torch.cat([attention_mask, action_attention_mask], dim=1)
1065
+
1066
+ # Build multimodal embeddings and attention mask
1067
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1068
+ input_embeddings, projected_patch_embeddings, attention_mask
1069
+ )
1070
+
1071
+ # import pdb; pdb.set_trace()
1072
+ # Forward pass through language model
1073
+ language_model_output = self.language_model(
1074
+ input_ids=None,
1075
+ attention_mask=multimodal_attention_mask,
1076
+ position_ids=None,
1077
+ past_key_values=None,
1078
+ inputs_embeds=multimodal_embeddings,
1079
+ labels=None,
1080
+ use_cache=None,
1081
+ output_attentions=False,
1082
+ output_hidden_states=True,
1083
+ return_dict=True,
1084
+ )
1085
+ multi_layer_hidden_states = []
1086
+ # import pdb; pdb.set_trace()
1087
+ for item in language_model_output.hidden_states[0:]:
1088
+ # last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
1089
+ # Get hidden states for text portion of prompt+response (after the vision patches)
1090
+ text_hidden_states = item
1091
+ # Get hidden states for action portion of response
1092
+ actions_hidden_states = text_hidden_states[:, NUM_PATCHES+ NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :,].reshape(1, 1, NUM_TOKENS, -1).to(torch.bfloat16)
1093
+ # import pdb; pdb.set_trace()
1094
+ batch_size = item.shape[0]
1095
+ task_latent_states = item[:, :NUM_PATCHES].reshape(batch_size, 1, NUM_PATCHES, -1)
1096
+ all_hidden_states = torch.cat((task_latent_states, actions_hidden_states), 2)
1097
+ multi_layer_hidden_states.append(all_hidden_states)
1098
+ # import pdb; pdb.set_trace()
1099
+ multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim = 1)
1100
+ # import pdb; pdb.set_trace()
1101
+
1102
+
1103
+
1104
+ # Reverse diffusion: Iteratively denoise to generate action prediction
1105
+ for t in action_head.noise_scheduler.timesteps:
1106
+ # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
1107
+ # embedding, and diffusion timestep embedding)
1108
+ timesteps = torch.Tensor([t]).to(labels.device)
1109
+ diffusion_timestep_embeddings = (
1110
+ action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
1111
+ ) # (B, llm_dim)
1112
+ diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
1113
+
1114
+ # [Diffusion] Replace the embeddings of the action tokens with noisy actions
1115
+ # (Later on, the positional embeddings will be added to them)
1116
+
1117
+ # Reshape and project noisy actions into language embedding space
1118
+ B = curr_noisy_actions.shape[0]
1119
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
1120
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
1121
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
1122
+
1123
+ # Predict noise and update noisy actions: x_t -> x_{t-1}
1124
+ # noise_pred = action_head.predict_noise(actions_hidden_states)
1125
+ noise_pred = action_head.predict_noise(multi_layer_hidden_states,
1126
+ noisy_actions=curr_noisy_actions,
1127
+ timestep_embeddings=diffusion_timestep_embeddings,
1128
+ noisy_action_projector=noisy_action_projector,
1129
+ proprio=proprio,
1130
+ proprio_projector=proprio_projector)
1131
+
1132
+ curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
1133
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1134
+
1135
+ # Return final actions
1136
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
1137
+
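The distinguishing step of the `v1` path is that it keeps the hidden states of every LLM layer, slices out the vision-patch positions (task latents) and the action-query positions, and stacks them into one tensor for the policy. A shape-only sketch with random tensors and assumed dimensions:

```python
import torch

B, NUM_LAYERS = 1, 25                                               # hidden_states include the embedding layer
NUM_PATCHES, NUM_PROMPT_TOKENS, NUM_TOKENS, D = 512, 28, 64, 896    # assumed geometry

seq_len = NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS + 1          # +1 for the stop token
hidden_states = [torch.randn(B, seq_len, D) for _ in range(NUM_LAYERS)]

per_layer = []
for item in hidden_states:
    start = NUM_PATCHES + NUM_PROMPT_TOKENS
    action_states = item[:, start : start + NUM_TOKENS].reshape(B, 1, NUM_TOKENS, D)
    task_latents = item[:, :NUM_PATCHES].reshape(B, 1, NUM_PATCHES, D)
    per_layer.append(torch.cat((task_latents, action_states), dim=2))

multi_layer = torch.cat(per_layer, dim=1)
print(multi_layer.shape)   # (1, 25, 512 + 64, 896) -- consumed by the policy's Bridge Attention
```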
1138
+ def _regression_or_discrete_prediction_V1(
1139
+ self,
1140
+ input_embeddings,
1141
+ all_actions_mask,
1142
+ projected_patch_embeddings,
1143
+ attention_mask,
1144
+ labels,
1145
+ NUM_PATCHES,
1146
+ NUM_PROMPT_TOKENS,
1147
+ action_head=None,
1148
+ proprio=None,
1149
+ proprio_projector=None,
1150
+ ):
1151
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
1152
+
1153
+ action_queries = self.action_queries.weight # (1, h)
1154
+ action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
1155
+ # Replace action token embeddings with the learned action query embeddings
1156
+ input_embeddings = self._replace_input_embeddings(input_embeddings.clone(), all_actions_mask, action_queries)
1157
+
1158
+ # Build multimodal embeddings and attention mask
1159
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1160
+ input_embeddings, projected_patch_embeddings, attention_mask
1161
+ )
1162
+
1163
+ # Forward pass through language model
1164
+ language_model_output = self.language_model(
1165
+ input_ids=None,
1166
+ attention_mask=multimodal_attention_mask,
1167
+ position_ids=None,
1168
+ past_key_values=None,
1169
+ inputs_embeds=multimodal_embeddings,
1170
+ labels=None,
1171
+ use_cache=None,
1172
+ output_attentions=False,
1173
+ output_hidden_states=True,
1174
+ return_dict=True,
1175
+ )
1176
+
1177
+ # Extract hidden states for action tokens
1178
+ multi_layer_hidden_states = []
1179
+ # import pdb; pdb.set_trace()
1180
+ for item in language_model_output.hidden_states[0:]:
1181
+ # last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
1182
+ # Get hidden states for text portion of prompt+response (after the vision patches)
1183
+ text_hidden_states = item
1184
+ # Get hidden states for action portion of response
1185
+ actions_hidden_states = text_hidden_states[:, NUM_PATCHES+ NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :,].reshape(1, 1, NUM_TOKENS, -1).to(torch.bfloat16)
1186
+ # import pdb; pdb.set_trace()
1187
+ batch_size = item.shape[0]
1188
+ task_latent_states = item[:, :NUM_PATCHES].reshape(batch_size, 1, NUM_PATCHES, -1)
1189
+ all_hidden_states = torch.cat((task_latent_states, actions_hidden_states), 2)
1190
+ multi_layer_hidden_states.append(all_hidden_states)
1191
+ # import pdb; pdb.set_trace()
1192
+ multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim = 1)
1193
+ # import pdb; pdb.set_trace()
1194
+
1195
+ # Handle different prediction methods
1196
+ if action_head is not None:
1197
+ # L1 regression prediction
1198
+ normalized_actions = action_head.predict_action(multi_layer_hidden_states,
1199
+ proprio=proprio,
1200
+ proprio_projector=proprio_projector)
1201
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1202
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
1203
+ else:
1204
+ # Discrete token-based prediction
1205
+ predicted_action_token_ids = (
1206
+ language_model_output.logits[
1207
+ :,
1208
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1209
+ ]
1210
+ .argmax(dim=2)
1211
+ .cpu()
1212
+ .numpy()
1213
+ )
1214
+ discretized_actions = self.vocab_size - predicted_action_token_ids
1215
+ discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
1216
+ normalized_actions = self.bin_centers[discretized_actions]
1217
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1218
+
1219
+ return normalized_actions, actions_hidden_states
1220
+
1221
+ def _regression_or_discrete_prediction(
1222
+ self,
1223
+ input_embeddings,
1224
+ all_actions_mask,
1225
+ projected_patch_embeddings,
1226
+ attention_mask,
1227
+ labels,
1228
+ NUM_PATCHES,
1229
+ NUM_PROMPT_TOKENS,
1230
+ action_head=None,
1231
+ ):
1232
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
1233
+ # Zero out action token embeddings
1234
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
1235
+ input_embeddings = input_embeddings * ~all_actions_mask
1236
+
1237
+ # Build multimodal embeddings and attention mask
1238
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1239
+ input_embeddings, projected_patch_embeddings, attention_mask
1240
+ )
1241
+
1242
+ # Forward pass through language model
1243
+ language_model_output = self.language_model(
1244
+ input_ids=None,
1245
+ attention_mask=multimodal_attention_mask,
1246
+ position_ids=None,
1247
+ past_key_values=None,
1248
+ inputs_embeds=multimodal_embeddings,
1249
+ labels=None,
1250
+ use_cache=None,
1251
+ output_attentions=False,
1252
+ output_hidden_states=True,
1253
+ return_dict=True,
1254
+ )
1255
+
1256
+ # Extract hidden states for action tokens
1257
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
1258
+ actions_hidden_states = last_hidden_states[
1259
+ :,
1260
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1261
+ :,
1262
+ ] # (B, act_chunk_len, D)
1263
+
1264
+ # Handle different prediction methods
1265
+ if action_head is not None:
1266
+ # L1 regression prediction
1267
+ normalized_actions = action_head.predict_action(actions_hidden_states)
1268
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1269
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
1270
+ else:
1271
+ # Discrete token-based prediction
1272
+ predicted_action_token_ids = (
1273
+ language_model_output.logits[
1274
+ :,
1275
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1276
+ ]
1277
+ .argmax(dim=2)
1278
+ .cpu()
1279
+ .numpy()
1280
+ )
1281
+ discretized_actions = self.vocab_size - predicted_action_token_ids
1282
+ discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
1283
+ normalized_actions = self.bin_centers[discretized_actions]
1284
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1285
+
1286
+ return normalized_actions, actions_hidden_states
1287
+
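A stand-in for the L1-regression branch above: slice the last layer's hidden states at the action positions and map each one to a scalar with a small MLP, then reshape into an action chunk. The MLP is a hypothetical placeholder for the trained action head; dimensions are assumed.

```python
import torch
import torch.nn as nn

B, NUM_PATCHES, NUM_PROMPT_TOKENS = 1, 512, 28            # assumed geometry
NUM_ACTIONS_CHUNK, ACTION_DIM, D = 8, 7, 896

seq_len = NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK + 1
last_hidden = torch.randn(B, seq_len, D)

start = NUM_PATCHES + NUM_PROMPT_TOKENS                   # same slice as above
actions_hidden = last_hidden[:, start : start + ACTION_DIM * NUM_ACTIONS_CHUNK, :]

# Placeholder head; the real one is trained with an L1 loss on normalized actions.
action_head = nn.Sequential(nn.Linear(D, D), nn.GELU(), nn.Linear(D, 1))
normalized = action_head(actions_hidden).reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
print(normalized.shape)                                   # (8, 7) normalized action chunk
```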
1288
+ def predict_action(
1289
+ self,
1290
+ input_ids: Optional[torch.LongTensor] = None,
1291
+ unnorm_key: Optional[str] = None,
1292
+ proprio=None,
1293
+ proprio_projector=None,
1294
+ action_head=None,
1295
+ noisy_action_projector=None,
1296
+ use_film: bool = False,
1297
+ **kwargs: str,
1298
+ ) -> Tuple[np.ndarray, torch.Tensor]:
1299
+ """Predict actions from input sequence, with options for different prediction methods.
1300
+
1301
+ Args:
1302
+ input_ids: Input token ids
1303
+ unnorm_key: Key for unnormalization statistics
1304
+ proprio: Proprioceptive features
1305
+ proprio_projector: Projector for proprioceptive features
1306
+ action_head: Optional head for L1 regression or diffusion-based prediction
1307
+ noisy_action_projector: Projector for noisy actions in diffusion-based prediction
1308
+ use_film: Whether to use FiLM conditioning
1309
+ **kwargs: Additional arguments including pixel_values and attention_mask
1310
+
1311
+ Returns:
1312
+ Tuple of (unnormalized_actions, action_hidden_states)
1313
+ """
1314
+ # import pdb; pdb.set_trace()
1315
+ # If the special empty token ('') does not already appear after the colon (':') token in the prompt
1316
+ # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
1317
+
1318
+
1319
+ # if not torch.all(input_ids[:, -1] == 29871):
1320
+ # input_ids = torch.cat(
1321
+ # (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
1322
+ # )
1323
+
1324
+
1325
+ pixel_values = kwargs["pixel_values"] # [1, 12, 224, 224]
1326
+ attention_mask = kwargs["attention_mask"]
1327
+
1328
+ # Create fake labels tensor (needed for action mask)
1329
+ labels = input_ids.clone()
1330
+ labels[:] = IGNORE_INDEX
1331
+
1332
+ # Get number of tokens in prompt (excluding the start token)
1333
+ NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Exclude the start (BOS) token; action tokens are appended below
1334
+
1335
+ # import pdb; pdb.set_trace()
1336
+
1337
+ # Prepare inputs by adding necessary tokens
1338
+ input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
1339
+
1340
+ # Update labels tensor for action mask computation later
1341
+ labels = self._prepare_labels_for_action_prediction(labels, input_ids)
1342
+
1343
+ # Get input embeddings and action masks
1344
+ input_embeddings = self.get_input_embeddings()(input_ids)
1345
+ all_actions_mask = self._process_action_masks(labels)
1346
+
1347
+ # Extract language embeddings
1348
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
1349
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
1350
+ )
1351
+
1352
+ # Process vision features
1353
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
1354
+
1355
+ # Add proprioceptive features if provided
1356
+ use_proprio = proprio_projector is not None and proprio is not None
1357
+ if use_proprio:
1358
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1359
+ if self.version == 'v1':
1360
+ pass
1361
+ else:
1362
+ projected_patch_embeddings = self._process_proprio_features(
1363
+ projected_patch_embeddings, proprio, proprio_projector
1364
+ )
1365
+ # import pdb; pdb.set_trace()
1366
+ # Use flow matching or diffusion if the corresponding head is provided; otherwise fall back to regression or discrete prediction
1367
+ use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
1368
+ use_flow_matching = noisy_action_projector is not None and hasattr(action_head, "sample_actions")
1369
+
1370
+
1371
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1372
+ NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1373
+ if self.version == 'v1':
1374
+ # if use_diffusion:
1375
+ # NUM_PATCHES += 1
1376
+ pass
1377
+ else:
1378
+ if use_proprio:
1379
+ NUM_PATCHES += 1
1380
+ if use_diffusion:
1381
+ NUM_PATCHES += 1
1382
+
1383
+ # import pdb; pdb.set_trace()
1384
+ if use_flow_matching:
1385
+ # Sample random noise with shape equal to output action, used as the starting state for flow matching
1386
+ noise = action_head.sample_noise((1, NUM_ACTIONS_CHUNK, ACTION_DIM),device=input_embeddings.device, dtype=input_embeddings.dtype)
1387
+
1388
+ # Run flow matching-based prediction
1389
+ normalized_actions, actions_hidden_states = self._run_flow_matching_prediction(
1390
+ input_embeddings,
1391
+ all_actions_mask,
1392
+ noise,
1393
+ action_head,
1394
+ projected_patch_embeddings,
1395
+ labels,
1396
+ attention_mask,
1397
+ NUM_PATCHES,
1398
+ NUM_PROMPT_TOKENS,
1399
+ noisy_action_projector
1400
+ )
1401
+ elif use_diffusion:
1402
+ # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion
1403
+ noise = torch.randn(
1404
+ size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
1405
+ )
1406
+ # import pdb; pdb.set_trace()
1407
+ if self.version == 'v1':
1408
+
1409
+ # import pdb; pdb.set_trace()
1410
+ # Run diffusion-based prediction
1411
+ normalized_actions, actions_hidden_states = self._run_diffusion_prediction_V1(
1412
+ input_embeddings, # [1, 86, 4096]
1413
+ all_actions_mask, # [1, 86]
1414
+ noise, # [1,8, 7]
1415
+ action_head,
1416
+ projected_patch_embeddings, # [1, 512, 4096]
1417
+ labels, # [1, 86]
1418
+ attention_mask, # [1, 86]
1419
+ NUM_PATCHES, # 512
1420
+ NUM_PROMPT_TOKENS, # 28
1421
+ noisy_action_projector,
1422
+ proprio, # [8]
1423
+ proprio_projector,
1424
+ )
1425
+ else:
1426
+ # Run diffusion-based prediction
1427
+ normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
1428
+ input_embeddings,
1429
+ all_actions_mask,
1430
+ noise,
1431
+ action_head,
1432
+ projected_patch_embeddings,
1433
+ labels,
1434
+ attention_mask,
1435
+ NUM_PATCHES,
1436
+ NUM_PROMPT_TOKENS,
1437
+ noisy_action_projector,
1438
+ )
1439
+
1440
+ else:
1441
+ if self.version == 'v1':
1442
+ # Run regression or discrete token-based prediction
1443
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction_V1(
1444
+ input_embeddings,
1445
+ all_actions_mask,
1446
+ projected_patch_embeddings,
1447
+ attention_mask,
1448
+ labels,
1449
+ NUM_PATCHES,
1450
+ NUM_PROMPT_TOKENS,
1451
+ action_head=action_head,
1452
+ proprio=proprio, # [8]
1453
+ proprio_projector=proprio_projector,
1454
+ )
1455
+ else:
1456
+ # Run regression or discrete token-based prediction
1457
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
1458
+ input_embeddings,
1459
+ all_actions_mask,
1460
+ projected_patch_embeddings,
1461
+ attention_mask,
1462
+ labels,
1463
+ NUM_PATCHES,
1464
+ NUM_PROMPT_TOKENS,
1465
+ action_head,
1466
+ )
1467
+
1468
+ # import pdb; pdb.set_trace()
1469
+ # Unnormalize predicted actions
1470
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1471
+
1472
+ return actions, actions_hidden_states
1473
+
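For orientation, a hedged end-to-end usage sketch in the OpenVLA style: load the processor and model with `trust_remote_code`, build a prompt, and call `predict_action`. The repo id, prompt wording, and `unnorm_key` are assumptions; without the externally loaded action head and projectors the call falls back to the discrete-token branch, whereas the evaluation scripts load those modules separately.

```python
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

MODEL_ID = "VLA-Adapter/LIBERO-Spatial"        # hypothetical repo id -- substitute the real one
UNNORM_KEY = "libero_spatial_no_noops"         # hypothetical dataset key from norm_stats

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda:0")

image = Image.open("observation.png")          # current camera frame
prompt = "In: What action should the robot take to pick up the black bowl?\nOut:"

inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
actions, _ = vla.predict_action(**inputs, unnorm_key=UNNORM_KEY)
print(actions.shape)                           # (NUM_ACTIONS_CHUNK, ACTION_DIM), unnormalized
```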
1474
+ @staticmethod
1475
+ def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
1476
+ """Validate and resolve the unnormalization key for action statistics"""
1477
+ if unnorm_key is None:
1478
+ assert len(norm_stats) == 1, (
1479
+ f"Your model was trained on more than one dataset, "
1480
+ f"please pass a `unnorm_key` from the following options to choose the statistics "
1481
+ f"used for un-normalizing actions: {norm_stats.keys()}"
1482
+ )
1483
+ unnorm_key = next(iter(norm_stats.keys()))
1484
+
1485
+ assert unnorm_key in norm_stats, (
1486
+ f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
1487
+ f"please choose from: {norm_stats.keys()}"
1488
+ )
1489
+ return unnorm_key
1490
+
1491
+ def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
1492
+ """Get the dimensionality of the policy's action space."""
1493
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1494
+ return len(self.norm_stats[unnorm_key]["action"]["min"])
1495
+
1496
+ def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
1497
+ """Get all the logged statistics for the given dataset."""
1498
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1499
+ return self.norm_stats[unnorm_key]["action"]
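The `norm_stats` consumed by these helpers follow the OpenVLA dataset-statistics layout. A minimal sketch with made-up numbers, showing how a single-dataset checkpoint resolves the key and reports its action dimensionality:

```python
# Made-up statistics in the expected layout; real values live in config.norm_stats.
norm_stats = {
    "libero_spatial_no_noops": {
        "action": {
            "q01": [-0.05, -0.04, -0.06, -0.1, -0.1, -0.1, 0.0],
            "q99": [0.07, 0.05, 0.06, 0.1, 0.1, 0.1, 1.0],
            "min": [-0.09, -0.08, -0.09, -0.2, -0.2, -0.2, 0.0],
            "max": [0.10, 0.09, 0.10, 0.2, 0.2, 0.2, 1.0],
            "mask": [True, True, True, True, True, True, False],  # gripper stays unnormalized
        }
    }
}

# With a single dataset, unnorm_key may be omitted and resolves to the only entry.
unnorm_key = next(iter(norm_stats))
action_dim = len(norm_stats[unnorm_key]["action"]["min"])
print(unnorm_key, action_dim)   # libero_spatial_no_noops 7
```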
preprocessor_config.json ADDED
@@ -0,0 +1,114 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "processing_prismatic.PrismaticImageProcessor",
4
+ "AutoProcessor": "processing_prismatic.PrismaticProcessor"
5
+ },
6
+ "image_processor_type": "PrismaticImageProcessor",
7
+ "image_resize_strategy": "resize-naive",
8
+ "input_sizes": [
9
+ [
10
+ 3,
11
+ 224,
12
+ 224
13
+ ],
14
+ [
15
+ 3,
16
+ 224,
17
+ 224
18
+ ]
19
+ ],
20
+ "interpolations": [
21
+ "bicubic",
22
+ "bicubic"
23
+ ],
24
+ "means": [
25
+ [
26
+ 0.485,
27
+ 0.456,
28
+ 0.406
29
+ ],
30
+ [
31
+ 0.5,
32
+ 0.5,
33
+ 0.5
34
+ ]
35
+ ],
36
+ "processor_class": "PrismaticProcessor",
37
+ "stds": [
38
+ [
39
+ 0.229,
40
+ 0.224,
41
+ 0.225
42
+ ],
43
+ [
44
+ 0.5,
45
+ 0.5,
46
+ 0.5
47
+ ]
48
+ ],
49
+ "tvf_crop_params": [
50
+ {
51
+ "output_size": [
52
+ 224,
53
+ 224
54
+ ]
55
+ },
56
+ {
57
+ "output_size": [
58
+ 224,
59
+ 224
60
+ ]
61
+ }
62
+ ],
63
+ "tvf_do_letterbox": false,
64
+ "tvf_letterbox_fill": null,
65
+ "tvf_normalize_params": [
66
+ {
67
+ "inplace": false,
68
+ "mean": [
69
+ 0.484375,
70
+ 0.455078125,
71
+ 0.40625
72
+ ],
73
+ "std": [
74
+ 0.228515625,
75
+ 0.2236328125,
76
+ 0.224609375
77
+ ]
78
+ },
79
+ {
80
+ "inplace": false,
81
+ "mean": [
82
+ 0.5,
83
+ 0.5,
84
+ 0.5
85
+ ],
86
+ "std": [
87
+ 0.5,
88
+ 0.5,
89
+ 0.5
90
+ ]
91
+ }
92
+ ],
93
+ "tvf_resize_params": [
94
+ {
95
+ "antialias": true,
96
+ "interpolation": 3,
97
+ "max_size": null,
98
+ "size": [
99
+ 224,
100
+ 224
101
+ ]
102
+ },
103
+ {
104
+ "antialias": true,
105
+ "interpolation": 3,
106
+ "max_size": null,
107
+ "size": [
108
+ 224,
109
+ 224
110
+ ]
111
+ }
112
+ ],
113
+ "use_fused_vision_backbone": true
114
+ }
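The two entries in `means`/`stds` (and the two `tvf_*` parameter sets) correspond to the fused dual vision backbone: each view of the input image is resized to 224×224 with its own normalization and the results are channel-stacked. A sketch of that pipeline, assuming a SigLIP-style backbone for the (0.5, 0.5, 0.5) statistics and an ImageNet-normalized backbone for the other set:

```python
import torch
import torchvision.transforms.functional as TVF
from PIL import Image

norms = [  # the two normalization sets from this config
    {"mean": [0.484375, 0.455078125, 0.40625], "std": [0.228515625, 0.2236328125, 0.224609375]},
    {"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5]},
]

img = Image.new("RGB", (640, 480), color=(128, 64, 32))   # stand-in camera frame
per_backbone = []
for n in norms:
    x = TVF.resize(img, [224, 224], interpolation=TVF.InterpolationMode.BICUBIC)  # "resize-naive"
    x = TVF.center_crop(x, [224, 224])                     # no-op here, kept for parity with the config
    x = TVF.to_tensor(x)
    x = TVF.normalize(x, mean=n["mean"], std=n["std"])
    per_backbone.append(x)

pixel_values = torch.vstack(per_backbone)                  # channel-stacked fused-backbone input
print(pixel_values.shape)                                  # torch.Size([6, 224, 224])
```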
processing_prismatic.py ADDED
@@ -0,0 +1,257 @@
1
+ """
2
+ processing_prismatic.py
3
+
4
+ HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
5
+ specifies `siglip-224px+7b`.
6
+ """
7
+
8
+ from typing import Any, ClassVar, List, Optional, Tuple, Union
9
+
10
+ import timm.data
11
+ import torch
12
+ import torchvision.transforms.functional as TVF
13
+ from PIL import Image
14
+ from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
15
+ from transformers import PreTrainedTokenizerBase
16
+ from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
17
+ from transformers.processing_utils import ProcessorMixin
18
+ from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
19
+ from transformers.utils import TensorType
20
+
21
+
22
+ # === Image Processing ===
23
+ def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
24
+ """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
25
+ (w, h), max_wh = image.size, max(image.size)
26
+ horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
27
+ padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
28
+
29
+ return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")
30
+
31
+
32
+ class PrismaticImageProcessor(ImageProcessingMixin):
33
+ model_input_names: ClassVar[List[str]] = ["pixel_values"]
34
+
35
+ def __init__(
36
+ self,
37
+ use_fused_vision_backbone: bool = False,
38
+ image_resize_strategy: str = "letterbox",
39
+ input_sizes: Optional[List[Tuple[int, int, int]]] = None,
40
+ interpolations: Optional[List[str]] = None,
41
+ means: Optional[List[Tuple[float, float, float]]] = None,
42
+ stds: Optional[List[Tuple[float, float, float]]] = None,
43
+ **kwargs: str,
44
+ ) -> None:
45
+ """
46
+ Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
47
+ created by TIMM, and edited to follow our custom `image_resize_strategy` logic.
48
+
49
+ @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
50
+ @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
51
+ @param input_size: [TIMM :: `data_cfg`] Input image size as tuple (channels, width, height)
52
+ @param interpolation: [TIMM :: `data_cfg`] Interpolation as string (default: "bicubic")
53
+ @param mean: [TIMM :: `data_cfg`] Normalization mean as float tuple (or two-tuple if `fused_backbone`)
54
+ @param std: [TIMM :: `data_cfg`] Normalization std as float tuple (or two-tuple if `fused_backbone`)
55
+ """
56
+ self.use_fused_vision_backbone = use_fused_vision_backbone
57
+ self.image_resize_strategy = image_resize_strategy
58
+
59
+ # Handle `None` default values
60
+ input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
61
+ means = [(0.5, 0.5, 0.5)] if means is None else means
62
+ stds = [(0.5, 0.5, 0.5)] if stds is None else stds
63
+
64
+ # TIMM `data_cfg` Parameters
65
+ self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds
66
+
67
+ # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
68
+ self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
69
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
70
+
71
+ for idx in range(len(input_sizes)):
72
+ transform = timm.data.create_transform(
73
+ input_size=self.input_sizes[idx],
74
+ interpolation=self.interpolations[idx],
75
+ mean=self.means[idx],
76
+ std=self.stds[idx],
77
+ crop_pct=1.0, # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
78
+ crop_mode="center", # Default crop mode -- no-op when `crop_pct == 1.0`
79
+ is_training=False, # No image augmentations when loading the transform!
80
+ )
81
+
82
+ # [Validation] Ensure appropriate transform structure, expected sizes
83
+ if not (
84
+ isinstance(transform, Compose)
85
+ and (len(transform.transforms) == 4)
86
+ and isinstance(transform.transforms[0], Resize)
87
+ and isinstance(transform.transforms[1], CenterCrop)
88
+ and isinstance(transform.transforms[2], ToTensor)
89
+ and isinstance(transform.transforms[3], Normalize)
90
+ and (transform.transforms[0].size == self.input_sizes[idx][-1])
91
+ and (transform.transforms[1].size == self.input_sizes[idx][-2:])
92
+ ):
93
+ raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")
94
+
95
+ # HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision. as an attribute.
96
+ # => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`tvf`)
97
+ resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
98
+ self.tvf_resize_params.append(
99
+ {
100
+ "size": resize_t.size,
101
+ "interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
102
+ "max_size": None,
103
+ "antialias": True,
104
+ }
105
+ )
106
+ self.tvf_crop_params.append({"output_size": crop_t.size})
107
+ self.tvf_normalize_params.append(
108
+ {
109
+ "mean": norm_t.mean.float().numpy().tolist(),
110
+ "std": norm_t.std.float().numpy().tolist(),
111
+ "inplace": False,
112
+ }
113
+ )
114
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
115
+
116
+ # Handle Prismatic `image_resize_strategy`
117
+ if self.image_resize_strategy == "resize-naive":
118
+ self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
119
+ elif self.image_resize_strategy == "letterbox":
120
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
121
+ elif self.image_resize_strategy == "resize-crop":
122
+ pass
123
+ else:
124
+ raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")
125
+
126
+ # Dispatch **kwargs to super()
127
+ super().__init__(**kwargs)
128
+
129
+ def apply_transform(self, img: Image.Image) -> torch.Tensor:
130
+ """Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])"""
131
+ if self.tvf_do_letterbox:
132
+ img = letterbox_pad_transform(img, self.tvf_letterbox_fill)
133
+
134
+ # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
135
+ imgs_t = []
136
+ for idx in range(len(self.input_sizes)):
137
+ img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
138
+ img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
139
+ img_idx_t = TVF.to_tensor(img_idx)
140
+ img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
141
+ imgs_t.append(img_idx_t)
142
+
143
+ # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
144
+ img_t = torch.vstack(imgs_t)
145
+
146
+ return img_t
147
+
148
+ def preprocess(
149
+ self,
150
+ images: Union[Image.Image, List[Image.Image]],
151
+ return_tensors: Optional[Union[str, TensorType]] = None,
152
+ **_: str,
153
+ ) -> BatchFeature:
154
+ """
155
+ Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we
156
+ explicitly only handle PIL.Image.Image instances for simplicity.
157
+
158
+ @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
159
+ @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray
160
+
161
+ @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
162
+ """
163
+ if not isinstance(images, list):
164
+ images = [images]
165
+
166
+ # Apply `self.img_transform` to each image (will return list of torch.Tensors); stack into "batched" Tensor
167
+ pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])
168
+
169
+ # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert
170
+ return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)
171
+
172
+ def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
173
+ return self.preprocess(images, **kwargs)
174
+
175
+
176
+ # === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
177
+ # =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
178
+ class PrismaticProcessor(ProcessorMixin):
179
+ attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
180
+ image_processor_class: str = "AutoImageProcessor"
181
+ tokenizer_class: str = "AutoTokenizer"
182
+
183
+ def __init__(
184
+ self,
185
+ image_processor: Optional[ImageProcessingMixin] = None,
186
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
187
+ ) -> None:
188
+ super().__init__(image_processor, tokenizer)
189
+
190
+ def __call__(
191
+ self,
192
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
193
+ images: Union[Image.Image, List[Image.Image]],
194
+ padding: Union[bool, str, PaddingStrategy] = False,
195
+ truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
196
+ max_length: Optional[int] = None,
197
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
198
+ ) -> BatchFeature:
199
+ """
200
+ Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer,
201
+ forwards images to PrismaticImageProcessor.
202
+
203
+ @param text: The (batch) of text to encode; must be a string or list of strings.
204
+ @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
205
+ @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False >
206
+ @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified
207
+ @param max_length: Maximum length (in tokens) to truncate
208
+ @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH)
209
+
210
+ @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`.
211
+ """
212
+ pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
213
+ text_inputs = self.tokenizer(
214
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
215
+ )
216
+
217
+ # [Validate] Need same number of images and text inputs!
218
+ if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
219
+ raise ValueError("Batch is malformed; expected same number of images and text inputs!")
220
+
221
+ return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
222
+
223
+ # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
224
+ def batch_decode(
225
+ self,
226
+ sequences: Union[List[int], List[List[int]], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
227
+ skip_special_tokens: bool = False,
228
+ clean_up_tokenization_spaces: Optional[bool] = None,
229
+ **kwargs: str,
230
+ ) -> List[str]:
231
+ return self.tokenizer.batch_decode(
232
+ sequences=sequences,
233
+ skip_special_tokens=skip_special_tokens,
234
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
235
+ **kwargs,
236
+ )
237
+
238
+ def decode(
239
+ self,
240
+ token_ids: Union[int, List[int], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
241
+ skip_special_tokens: bool = False,
242
+ clean_up_tokenization_spaces: Optional[bool] = None,
243
+ **kwargs: str,
244
+ ) -> str:
245
+ return self.tokenizer.decode(
246
+ token_ids=token_ids,
247
+ skip_special_tokens=skip_special_tokens,
248
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
249
+ **kwargs,
250
+ )
251
+
252
+ @property
253
+ def model_input_names(self) -> List[str]:
254
+ tokenizer_input_names = self.tokenizer.model_input_names
255
+ image_processor_input_names = self.image_processor.model_input_names
256
+
257
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
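A minimal sketch of using the processor defined above on its own, assuming the files from this commit are saved locally under `./checkpoint`; it shows the keys and shapes the model's `forward`/`predict_action` expect:

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("./checkpoint", trust_remote_code=True)  # assumed local path

image = Image.new("RGB", (256, 256))   # stand-in observation
batch = processor("In: What action should the robot take to close the drawer?\nOut:", image)

print(batch["pixel_values"].shape)     # (1, 6, 224, 224) -- two 3-channel views, channel-stacked
print(batch["input_ids"].shape)        # (1, prompt_len) from the Qwen2 tokenizer
print(batch["attention_mask"].shape)   # (1, prompt_len)
```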
processor_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_prismatic.PrismaticProcessor"
4
+ },
5
+ "processor_class": "PrismaticProcessor"
6
+ }
proprio_projector--20000_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8481fa9895e1500037c7f059f8722f915055a9db1105a474edcdac18a2333364
3
+ size 1626096
special_tokens_map.json ADDED
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,211 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ }
181
+ },
182
+ "additional_special_tokens": [
183
+ "<|im_start|>",
184
+ "<|im_end|>",
185
+ "<|object_ref_start|>",
186
+ "<|object_ref_end|>",
187
+ "<|box_start|>",
188
+ "<|box_end|>",
189
+ "<|quad_start|>",
190
+ "<|quad_end|>",
191
+ "<|vision_start|>",
192
+ "<|vision_end|>",
193
+ "<|vision_pad|>",
194
+ "<|image_pad|>",
195
+ "<|video_pad|>"
196
+ ],
197
+ "auto_map": {
198
+ "AutoProcessor": "processing_prismatic.PrismaticProcessor"
199
+ },
200
+ "bos_token": null,
201
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
202
+ "clean_up_tokenization_spaces": false,
203
+ "eos_token": "<|endoftext|>",
204
+ "errors": "replace",
205
+ "model_max_length": 131072,
206
+ "pad_token": "<|endoftext|>",
207
+ "processor_class": "PrismaticProcessor",
208
+ "split_special_tokens": false,
209
+ "tokenizer_class": "Qwen2Tokenizer",
210
+ "unk_token": null
211
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff