xiaoyuxi committed
Commit
9193cab
·
1 Parent(s): cd14f82

support HubMixin

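The commit message says "support HubMixin": in inference.py below, the models are loaded with `VGGT4Track.from_pretrained(...)` and `Predictor.from_pretrained(...)`, which is the pattern provided by the huggingface_hub model mixins. As a rough illustration only (not the repository's actual class definitions), here is a minimal sketch of how a PyTorch module gains `from_pretrained` / `save_pretrained` / `push_to_hub` by inheriting `PyTorchModelHubMixin` from a recent huggingface_hub; the class `TinyTracker` and its layers are hypothetical:

```python
# Minimal, hypothetical sketch of HubMixin support (not the repo's actual model class).
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class TinyTracker(nn.Module, PyTorchModelHubMixin):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


model = TinyTracker(dim=64)
model.save_pretrained("tiny-tracker-local")                # writes config + weights locally
reloaded = TinyTracker.from_pretrained("tiny-tracker-local")
# model.push_to_hub("user/tiny-tracker")                   # needs authentication; shown for reference
```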
Files changed (41)
  1. LICENSE.txt +409 -0
  2. config/magic_infer_offline.yaml +47 -0
  3. config/magic_infer_online.yaml +47 -0
  4. docs/PAPER.md +4 -0
  5. inference.py +184 -0
  6. models/SpaTrackV2/models/vggt4track/__init__.py +1 -0
  7. models/SpaTrackV2/models/vggt4track/heads/camera_head.py +162 -0
  8. models/SpaTrackV2/models/vggt4track/heads/dpt_head.py +497 -0
  9. models/SpaTrackV2/models/vggt4track/heads/head_act.py +125 -0
  10. models/SpaTrackV2/models/vggt4track/heads/scale_head.py +162 -0
  11. models/SpaTrackV2/models/vggt4track/heads/track_head.py +108 -0
  12. models/SpaTrackV2/models/vggt4track/heads/track_modules/__init__.py +5 -0
  13. models/SpaTrackV2/models/vggt4track/heads/track_modules/base_track_predictor.py +209 -0
  14. models/SpaTrackV2/models/vggt4track/heads/track_modules/blocks.py +246 -0
  15. models/SpaTrackV2/models/vggt4track/heads/track_modules/modules.py +218 -0
  16. models/SpaTrackV2/models/vggt4track/heads/track_modules/utils.py +226 -0
  17. models/SpaTrackV2/models/vggt4track/heads/utils.py +109 -0
  18. models/SpaTrackV2/models/vggt4track/layers/__init__.py +11 -0
  19. models/SpaTrackV2/models/vggt4track/layers/attention.py +98 -0
  20. models/SpaTrackV2/models/vggt4track/layers/block.py +259 -0
  21. models/SpaTrackV2/models/vggt4track/layers/drop_path.py +34 -0
  22. models/SpaTrackV2/models/vggt4track/layers/layer_scale.py +27 -0
  23. models/SpaTrackV2/models/vggt4track/layers/mlp.py +40 -0
  24. models/SpaTrackV2/models/vggt4track/layers/patch_embed.py +88 -0
  25. models/SpaTrackV2/models/vggt4track/layers/rope.py +188 -0
  26. models/SpaTrackV2/models/vggt4track/layers/swiglu_ffn.py +72 -0
  27. models/SpaTrackV2/models/vggt4track/layers/vision_transformer.py +407 -0
  28. models/SpaTrackV2/models/vggt4track/models/aggregator.py +338 -0
  29. models/SpaTrackV2/models/vggt4track/models/aggregator_front.py +342 -0
  30. models/SpaTrackV2/models/vggt4track/models/tracker_front.py +132 -0
  31. models/SpaTrackV2/models/vggt4track/models/vggt.py +96 -0
  32. models/SpaTrackV2/models/vggt4track/models/vggt_moe.py +107 -0
  33. models/SpaTrackV2/models/vggt4track/utils/__init__.py +1 -0
  34. models/SpaTrackV2/models/vggt4track/utils/geometry.py +166 -0
  35. models/SpaTrackV2/models/vggt4track/utils/load_fn.py +200 -0
  36. models/SpaTrackV2/models/vggt4track/utils/loss.py +123 -0
  37. models/SpaTrackV2/models/vggt4track/utils/pose_enc.py +130 -0
  38. models/SpaTrackV2/models/vggt4track/utils/rotation.py +138 -0
  39. models/SpaTrackV2/models/vggt4track/utils/visual_track.py +239 -0
  40. scripts/download.sh +5 -0
  41. viz.html +2115 -0
LICENSE.txt ADDED
@@ -0,0 +1,409 @@
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. Copyright and Similar Rights means copyright and/or similar rights
88
+ closely related to copyright including, without limitation,
89
+ performance, broadcast, sound recording, and Sui Generis Database
90
+ Rights, without regard to how the rights are labeled or
91
+ categorized. For purposes of this Public License, the rights
92
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
93
+ Rights.
94
+ d. Effective Technological Measures means those measures that, in the
95
+ absence of proper authority, may not be circumvented under laws
96
+ fulfilling obligations under Article 11 of the WIPO Copyright
97
+ Treaty adopted on December 20, 1996, and/or similar international
98
+ agreements.
99
+
100
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
101
+ any other exception or limitation to Copyright and Similar Rights
102
+ that applies to Your use of the Licensed Material.
103
+
104
+ f. Licensed Material means the artistic or literary work, database,
105
+ or other material to which the Licensor applied this Public
106
+ License.
107
+
108
+ g. Licensed Rights means the rights granted to You subject to the
109
+ terms and conditions of this Public License, which are limited to
110
+ all Copyright and Similar Rights that apply to Your use of the
111
+ Licensed Material and that the Licensor has authority to license.
112
+
113
+ h. Licensor means the individual(s) or entity(ies) granting rights
114
+ under this Public License.
115
+
116
+ i. NonCommercial means not primarily intended for or directed towards
117
+ commercial advantage or monetary compensation. For purposes of
118
+ this Public License, the exchange of the Licensed Material for
119
+ other material subject to Copyright and Similar Rights by digital
120
+ file-sharing or similar means is NonCommercial provided there is
121
+ no payment of monetary compensation in connection with the
122
+ exchange.
123
+
124
+ j. Share means to provide material to the public by any means or
125
+ process that requires permission under the Licensed Rights, such
126
+ as reproduction, public display, public performance, distribution,
127
+ dissemination, communication, or importation, and to make material
128
+ available to the public including in ways that members of the
129
+ public may access the material from a place and at a time
130
+ individually chosen by them.
131
+
132
+ k. Sui Generis Database Rights means rights other than copyright
133
+ resulting from Directive 96/9/EC of the European Parliament and of
134
+ the Council of 11 March 1996 on the legal protection of databases,
135
+ as amended and/or succeeded, as well as other essentially
136
+ equivalent rights anywhere in the world.
137
+
138
+ l. You means the individual or entity exercising the Licensed Rights
139
+ under this Public License. Your has a corresponding meaning.
140
+
141
+
142
+ Section 2 -- Scope.
143
+
144
+ a. License grant.
145
+
146
+ 1. Subject to the terms and conditions of this Public License,
147
+ the Licensor hereby grants You a worldwide, royalty-free,
148
+ non-sublicensable, non-exclusive, irrevocable license to
149
+ exercise the Licensed Rights in the Licensed Material to:
150
+
151
+ a. reproduce and Share the Licensed Material, in whole or
152
+ in part, for NonCommercial purposes only; and
153
+
154
+ b. produce, reproduce, and Share Adapted Material for
155
+ NonCommercial purposes only.
156
+
157
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
158
+ Exceptions and Limitations apply to Your use, this Public
159
+ License does not apply, and You do not need to comply with
160
+ its terms and conditions.
161
+
162
+ 3. Term. The term of this Public License is specified in Section
163
+ 6(a).
164
+
165
+ 4. Media and formats; technical modifications allowed. The
166
+ Licensor authorizes You to exercise the Licensed Rights in
167
+ all media and formats whether now known or hereafter created,
168
+ and to make technical modifications necessary to do so. The
169
+ Licensor waives and/or agrees not to assert any right or
170
+ authority to forbid You from making technical modifications
171
+ necessary to exercise the Licensed Rights, including
172
+ technical modifications necessary to circumvent Effective
173
+ Technological Measures. For purposes of this Public License,
174
+ simply making modifications authorized by this Section 2(a)
175
+ (4) never produces Adapted Material.
176
+
177
+ 5. Downstream recipients.
178
+
179
+ a. Offer from the Licensor -- Licensed Material. Every
180
+ recipient of the Licensed Material automatically
181
+ receives an offer from the Licensor to exercise the
182
+ Licensed Rights under the terms and conditions of this
183
+ Public License.
184
+
185
+ b. No downstream restrictions. You may not offer or impose
186
+ any additional or different terms or conditions on, or
187
+ apply any Effective Technological Measures to, the
188
+ Licensed Material if doing so restricts exercise of the
189
+ Licensed Rights by any recipient of the Licensed
190
+ Material.
191
+
192
+ 6. No endorsement. Nothing in this Public License constitutes or
193
+ may be construed as permission to assert or imply that You
194
+ are, or that Your use of the Licensed Material is, connected
195
+ with, or sponsored, endorsed, or granted official status by,
196
+ the Licensor or others designated to receive attribution as
197
+ provided in Section 3(a)(1)(A)(i).
198
+
199
+ b. Other rights.
200
+
201
+ 1. Moral rights, such as the right of integrity, are not
202
+ licensed under this Public License, nor are publicity,
203
+ privacy, and/or other similar personality rights; however, to
204
+ the extent possible, the Licensor waives and/or agrees not to
205
+ assert any such rights held by the Licensor to the limited
206
+ extent necessary to allow You to exercise the Licensed
207
+ Rights, but not otherwise.
208
+
209
+ 2. Patent and trademark rights are not licensed under this
210
+ Public License.
211
+
212
+ 3. To the extent possible, the Licensor waives any right to
213
+ collect royalties from You for the exercise of the Licensed
214
+ Rights, whether directly or through a collecting society
215
+ under any voluntary or waivable statutory or compulsory
216
+ licensing scheme. In all other cases the Licensor expressly
217
+ reserves any right to collect such royalties, including when
218
+ the Licensed Material is used other than for NonCommercial
219
+ purposes.
220
+
221
+
222
+ Section 3 -- License Conditions.
223
+
224
+ Your exercise of the Licensed Rights is expressly made subject to the
225
+ following conditions.
226
+
227
+ a. Attribution.
228
+
229
+ 1. If You Share the Licensed Material (including in modified
230
+ form), You must:
231
+
232
+ a. retain the following if it is supplied by the Licensor
233
+ with the Licensed Material:
234
+
235
+ i. identification of the creator(s) of the Licensed
236
+ Material and any others designated to receive
237
+ attribution, in any reasonable manner requested by
238
+ the Licensor (including by pseudonym if
239
+ designated);
240
+
241
+ ii. a copyright notice;
242
+
243
+ iii. a notice that refers to this Public License;
244
+
245
+ iv. a notice that refers to the disclaimer of
246
+ warranties;
247
+
248
+ v. a URI or hyperlink to the Licensed Material to the
249
+ extent reasonably practicable;
250
+
251
+ b. indicate if You modified the Licensed Material and
252
+ retain an indication of any previous modifications; and
253
+
254
+ c. indicate the Licensed Material is licensed under this
255
+ Public License, and include the text of, or the URI or
256
+ hyperlink to, this Public License.
257
+
258
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
259
+ reasonable manner based on the medium, means, and context in
260
+ which You Share the Licensed Material. For example, it may be
261
+ reasonable to satisfy the conditions by providing a URI or
262
+ hyperlink to a resource that includes the required
263
+ information.
264
+
265
+ 3. If requested by the Licensor, You must remove any of the
266
+ information required by Section 3(a)(1)(A) to the extent
267
+ reasonably practicable.
268
+
269
+ 4. If You Share Adapted Material You produce, the Adapter's
270
+ License You apply must not prevent recipients of the Adapted
271
+ Material from complying with this Public License.
272
+
273
+
274
+ Section 4 -- Sui Generis Database Rights.
275
+
276
+ Where the Licensed Rights include Sui Generis Database Rights that
277
+ apply to Your use of the Licensed Material:
278
+
279
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280
+ to extract, reuse, reproduce, and Share all or a substantial
281
+ portion of the contents of the database for NonCommercial purposes
282
+ only;
283
+
284
+ b. if You include all or a substantial portion of the database
285
+ contents in a database in which You have Sui Generis Database
286
+ Rights, then the database in which You have Sui Generis Database
287
+ Rights (but not its individual contents) is Adapted Material; and
288
+
289
+ c. You must comply with the conditions in Section 3(a) if You Share
290
+ all or a substantial portion of the contents of the database.
291
+
292
+ For the avoidance of doubt, this Section 4 supplements and does not
293
+ replace Your obligations under this Public License where the Licensed
294
+ Rights include other Copyright and Similar Rights.
295
+
296
+
297
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298
+
299
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309
+
310
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319
+
320
+ c. The disclaimer of warranties and limitation of liability provided
321
+ above shall be interpreted in a manner that, to the extent
322
+ possible, most closely approximates an absolute disclaimer and
323
+ waiver of all liability.
324
+
325
+
326
+ Section 6 -- Term and Termination.
327
+
328
+ a. This Public License applies for the term of the Copyright and
329
+ Similar Rights licensed here. However, if You fail to comply with
330
+ this Public License, then Your rights under this Public License
331
+ terminate automatically.
332
+
333
+ b. Where Your right to use the Licensed Material has terminated under
334
+ Section 6(a), it reinstates:
335
+
336
+ 1. automatically as of the date the violation is cured, provided
337
+ it is cured within 30 days of Your discovery of the
338
+ violation; or
339
+
340
+ 2. upon express reinstatement by the Licensor.
341
+
342
+ For the avoidance of doubt, this Section 6(b) does not affect any
343
+ right the Licensor may have to seek remedies for Your violations
344
+ of this Public License.
345
+
346
+ c. For the avoidance of doubt, the Licensor may also offer the
347
+ Licensed Material under separate terms or conditions or stop
348
+ distributing the Licensed Material at any time; however, doing so
349
+ will not terminate this Public License.
350
+
351
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352
+ License.
353
+
354
+
355
+ Section 7 -- Other Terms and Conditions.
356
+
357
+ a. The Licensor shall not be bound by any additional or different
358
+ terms or conditions communicated by You unless expressly agreed.
359
+
360
+ b. Any arrangements, understandings, or agreements regarding the
361
+ Licensed Material not stated herein are separate from and
362
+ independent of the terms and conditions of this Public License.
363
+
364
+
365
+ Section 8 -- Interpretation.
366
+
367
+ a. For the avoidance of doubt, this Public License does not, and
368
+ shall not be interpreted to, reduce, limit, restrict, or impose
369
+ conditions on any use of the Licensed Material that could lawfully
370
+ be made without permission under this Public License.
371
+
372
+ b. To the extent possible, if any provision of this Public License is
373
+ deemed unenforceable, it shall be automatically reformed to the
374
+ minimum extent necessary to make it enforceable. If the provision
375
+ cannot be reformed, it shall be severed from this Public License
376
+ without affecting the enforceability of the remaining terms and
377
+ conditions.
378
+
379
+ c. No term or condition of this Public License will be waived and no
380
+ failure to comply consented to unless expressly agreed to by the
381
+ Licensor.
382
+
383
+ d. Nothing in this Public License constitutes or may be interpreted
384
+ as a limitation upon, or waiver of, any privileges and immunities
385
+ that apply to the Licensor or You, including from the legal
386
+ processes of any jurisdiction or authority.
387
+
388
+ =======================================================================
389
+
390
+ Creative Commons is not a party to its public
391
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
392
+ its public licenses to material it publishes and in those instances
393
+ will be considered the “Licensor.” The text of the Creative Commons
394
+ public licenses is dedicated to the public domain under the CC0 Public
395
+ Domain Dedication. Except for the limited purpose of indicating that
396
+ material is shared under a Creative Commons public license or as
397
+ otherwise permitted by the Creative Commons policies published at
398
+ creativecommons.org/policies, Creative Commons does not authorize the
399
+ use of the trademark "Creative Commons" or any other trademark or logo
400
+ of Creative Commons without its prior written consent including,
401
+ without limitation, in connection with any unauthorized modifications
402
+ to any of its public licenses or any other arrangements,
403
+ understandings, or agreements concerning use of licensed material. For
404
+ the avoidance of doubt, this paragraph does not form part of the
405
+ public licenses.
406
+
407
+ Creative Commons may be contacted at creativecommons.org.
408
+
409
+
config/magic_infer_offline.yaml ADDED
@@ -0,0 +1,47 @@
1
+ seed: 0
2
+ # configure the hydra logger; `$` is only resolved as a reference inside hydra
3
+ data: ./assets/room
4
+ vis_track: false
5
+ hydra:
6
+ run:
7
+ dir: .
8
+ output_subdir: null
9
+ job_logging: {}
10
+ hydra_logging: {}
11
+ mixed_precision: bf16
12
+ visdom:
13
+ viz_ip: "localhost"
14
+ port: 6666
15
+ relax_load: false
16
+ res_all: 336
17
+ # config the ckpt path
18
+ ckpts: "Yuxihenry/SpatialTrackerCkpts"
19
+ batch_size: 1
20
+ input:
21
+ type: image
22
+ fps: 1
23
+ model_wind_size: 32
24
+ model:
25
+ backbone_cfg:
26
+ ckpt_dir: "checkpoints/model.pt"
27
+ chunk_size: 24 # downsample factor for patchified features
28
+ ckpt_fwd: true
29
+ ft_cfg:
30
+ mode: "fix"
31
+ paras_name: []
32
+ resolution: 336
33
+ max_len: 512
34
+ Track_cfg:
35
+ base_ckpt: "checkpoints/scaled_offline.pth"
36
+ base:
37
+ stride: 4
38
+ corr_radius: 3
39
+ window_len: 60
40
+ stablizer: True
41
+ mode: "online"
42
+ s_wind: 200
43
+ overlap: 4
44
+ track_num: 0
45
+
46
+ dist_train:
47
+ num_nodes: 1
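The commented-out block in inference.py further down (`yaml.load(...)` wrapped in `easydict.EasyDict`) hints at how these YAML files are meant to be consumed outside of Hydra. A minimal sketch of that pattern follows, assuming the file path `config/magic_infer_offline.yaml` relative to the repo root; note that the diff view above flattens the YAML indentation, so only clearly top-level keys are accessed here:

```python
# Sketch: load the offline config the way the commented-out block in inference.py suggests.
import yaml
import easydict

with open("config/magic_infer_offline.yaml", "r") as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)
cfg = easydict.EasyDict(cfg)  # attribute-style access, e.g. cfg.seed instead of cfg["seed"]

print(cfg.seed)       # 0
print(cfg.ckpts)      # "Yuxihenry/SpatialTrackerCkpts"
print(cfg.res_all)    # 336
```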
config/magic_infer_online.yaml ADDED
@@ -0,0 +1,47 @@
1
+ seed: 0
2
+ # configure the hydra logger; `$` is only resolved as a reference inside hydra
3
+ data: ./assets/room
4
+ vis_track: false
5
+ hydra:
6
+ run:
7
+ dir: .
8
+ output_subdir: null
9
+ job_logging: {}
10
+ hydra_logging: {}
11
+ mixed_precision: bf16
12
+ visdom:
13
+ viz_ip: "localhost"
14
+ port: 6666
15
+ relax_load: false
16
+ res_all: 336
17
+ # config the ckpt path
18
+ ckpts: "Yuxihenry/SpatialTrackerCkpts"
19
+ batch_size: 1
20
+ input:
21
+ type: image
22
+ fps: 1
23
+ model_wind_size: 32
24
+ model:
25
+ backbone_cfg:
26
+ ckpt_dir: "checkpoints/model.pt"
27
+ chunk_size: 24 # downsample factor for patchified features
28
+ ckpt_fwd: true
29
+ ft_cfg:
30
+ mode: "fix"
31
+ paras_name: []
32
+ resolution: 336
33
+ max_len: 512
34
+ Track_cfg:
35
+ base_ckpt: "checkpoints/scaled_online.pth"
36
+ base:
37
+ stride: 4
38
+ corr_radius: 3
39
+ window_len: 20
40
+ stablizer: False
41
+ mode: "online"
42
+ s_wind: 20
43
+ overlap: 6
44
+ track_num: 0
45
+
46
+ dist_train:
47
+ num_nodes: 1
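The `hydra:` block (run.dir, output_subdir, and the emptied job/hydra logging) indicates these configs are meant to be driven through Hydra rather than parsed by hand. A minimal, hypothetical entry point for the online config is sketched below; it assumes the YAML lives in a `config/` directory next to the script, and the function body only prints the resolved config:

```python
# Hypothetical Hydra entry point for config/magic_infer_online.yaml.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path="config", config_name="magic_infer_online")
def main(cfg: DictConfig) -> None:
    # Hydra parses the YAML and hands it over as a DictConfig.
    print(OmegaConf.to_yaml(cfg))
    print(cfg.res_all, cfg.mixed_precision)  # 336 bf16


if __name__ == "__main__":
    main()
```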
docs/PAPER.md ADDED
@@ -0,0 +1,4 @@
1
+ # SpatialTrackerV2: Final version of the paper is still being polished; ETA in one week.
2
+
3
+ ## Overall
4
+ SpatialTrackerV2 proposes an end-to-end, differentiable pipeline that unifies video depth, camera pose, and 3D tracking. This unified pipeline enables large-scale joint training of both parts on diverse types of data.
inference.py ADDED
@@ -0,0 +1,184 @@
1
+ import pycolmap
2
+ from models.SpaTrackV2.models.predictor import Predictor
3
+ import yaml
4
+ import easydict
5
+ import os
6
+ import numpy as np
7
+ import cv2
8
+ import torch
9
+ import torchvision.transforms as T
10
+ from PIL import Image
11
+ import io
12
+ import moviepy.editor as mp
13
+ from models.SpaTrackV2.utils.visualizer import Visualizer
14
+ import tqdm
15
+ from models.SpaTrackV2.models.utils import get_points_on_a_grid
16
+ import glob
17
+ from rich import print
18
+ import argparse
19
+ import decord
20
+ from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
21
+ from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
22
+ from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri
23
+
24
+ def parse_args():
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument("--track_mode", type=str, default="offline")
27
+ parser.add_argument("--data_type", type=str, default="RGBD")
28
+ parser.add_argument("--data_dir", type=str, default="assets/example0")
29
+ parser.add_argument("--video_name", type=str, default="snowboard")
30
+ parser.add_argument("--grid_size", type=int, default=10)
31
+ parser.add_argument("--vo_points", type=int, default=756)
32
+ parser.add_argument("--fps", type=int, default=1)
33
+ return parser.parse_args()
34
+
35
+ if __name__ == "__main__":
36
+ args = parse_args()
37
+ out_dir = args.data_dir + "/results"
38
+ # fps
39
+ fps = int(args.fps)
40
+ mask_dir = args.data_dir + f"/{args.video_name}.png"
41
+
42
+ vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
43
+ vggt4track_model.eval()
44
+ vggt4track_model = vggt4track_model.to("cuda")
45
+
46
+ if args.data_type == "RGBD":
47
+ npz_dir = args.data_dir + f"/{args.video_name}.npz"
48
+ data_npz_load = dict(np.load(npz_dir, allow_pickle=True))
49
+ #TODO: tapip format
50
+ video_tensor = data_npz_load["video"] * 255
51
+ video_tensor = torch.from_numpy(video_tensor)
52
+ video_tensor = video_tensor[::fps]
53
+ depth_tensor = data_npz_load["depths"]
54
+ depth_tensor = depth_tensor[::fps]
55
+ intrs = data_npz_load["intrinsics"]
56
+ intrs = intrs[::fps]
57
+ extrs = np.linalg.inv(data_npz_load["extrinsics"])
58
+ extrs = extrs[::fps]
59
+ unc_metric = None
60
+ elif args.data_type == "RGB":
61
+ vid_dir = os.path.join(args.data_dir, f"{args.video_name}.mp4")
62
+ video_reader = decord.VideoReader(vid_dir)
63
+ video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2) # Convert to tensor and permute to (N, C, H, W)
64
+ video_tensor = video_tensor[::fps].float()
65
+
66
+ # process the image tensor
67
+ video_tensor = preprocess_image(video_tensor)[None]
68
+ with torch.no_grad():
69
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
70
+ # Predict attributes including cameras, depth maps, and point maps.
71
+ predictions = vggt4track_model(video_tensor.cuda()/255)
72
+ extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
73
+ depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
74
+
75
+ depth_tensor = depth_map.squeeze().cpu().numpy()
76
+ extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
77
+ extrs = extrinsic.squeeze().cpu().numpy()
78
+ intrs = intrinsic.squeeze().cpu().numpy()
79
+ video_tensor = video_tensor.squeeze()
80
+ #NOTE: 20% of the depth is not reliable
81
+ # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
82
+ unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
83
+
84
+ data_npz_load = {}
85
+
86
+ if os.path.exists(mask_dir):
87
+ mask_files = mask_dir
88
+ mask = cv2.imread(mask_files)
89
+ mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
90
+ mask = mask.sum(axis=-1)>0
91
+ else:
92
+ mask = np.ones_like(video_tensor[0,0].numpy())>0
93
+
94
+ # get all data pieces
95
+ viz = True
96
+ os.makedirs(out_dir, exist_ok=True)
97
+
98
+ # with open(cfg_dir, "r") as f:
99
+ # cfg = yaml.load(f, Loader=yaml.FullLoader)
100
+ # cfg = easydict.EasyDict(cfg)
101
+ # cfg.out_dir = out_dir
102
+ # cfg.model.track_num = args.vo_points
103
+ # print(f"Downloading model from HuggingFace: {cfg.ckpts}")
104
+ if args.track_mode == "offline":
105
+ model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
106
+ else:
107
+ model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
108
+
109
+ # config the model; the track_num is the number of points in the grid
110
+ model.spatrack.track_num = args.vo_points
111
+
112
+ model.eval()
113
+ model.to("cuda")
114
+ viser = Visualizer(save_dir=out_dir, grayscale=True,
115
+ fps=10, pad_value=0, tracks_leave_trace=5)
116
+
117
+ grid_size = args.grid_size
118
+
119
+ # get frame H W
120
+ if video_tensor is None:
121
+ cap = cv2.VideoCapture(video_path)
122
+ frame_H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
123
+ frame_W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
124
+ else:
125
+ frame_H, frame_W = video_tensor.shape[2:]
126
+ grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
127
+
128
+ # Sample mask values at grid points and filter out points where mask=0
129
+ if os.path.exists(mask_dir):
130
+ grid_pts_int = grid_pts[0].long()
131
+ mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
132
+ grid_pts = grid_pts[:, mask_values]
133
+
134
+ query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
135
+
136
+ # Run model inference
137
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
138
+ (
139
+ c2w_traj, intrs, point_map, conf_depth,
140
+ track3d_pred, track2d_pred, vis_pred, conf_pred, video
141
+ ) = model.forward(video_tensor, depth=depth_tensor,
142
+ intrs=intrs, extrs=extrs,
143
+ queries=query_xyt,
144
+ fps=1, full_point=False, iters_track=4,
145
+ query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
146
+ support_frame=len(video_tensor)-1, replace_ratio=0.2)
147
+
148
+ # resize the results to avoid an excessive I/O burden
149
+ # depth and image, the maximum side is 336
150
+ max_size = 336
151
+ h, w = video.shape[2:]
152
+ scale = min(max_size / h, max_size / w)
153
+ if scale < 1:
154
+ new_h, new_w = int(h * scale), int(w * scale)
155
+ video = T.Resize((new_h, new_w))(video)
156
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
157
+ point_map = T.Resize((new_h, new_w))(point_map)
158
+ conf_depth = T.Resize((new_h, new_w))(conf_depth)
159
+ track2d_pred[...,:2] = track2d_pred[...,:2] * scale
160
+ intrs[:,:2,:] = intrs[:,:2,:] * scale
161
+ if depth_tensor is not None:
162
+ if isinstance(depth_tensor, torch.Tensor):
163
+ depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
164
+ else:
165
+ depth_tensor = T.Resize((new_h, new_w))(torch.from_numpy(depth_tensor))
166
+
167
+ if viz:
168
+ viser.visualize(video=video[None],
169
+ tracks=track2d_pred[None][...,:2],
170
+ visibility=vis_pred[None],filename="test")
171
+
172
+ # save as the tapip3d format
173
+ data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
174
+ data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
175
+ data_npz_load["intrinsics"] = intrs.cpu().numpy()
176
+ depth_save = point_map[:,2,...]
177
+ depth_save[conf_depth<0.5] = 0
178
+ data_npz_load["depths"] = depth_save.cpu().numpy()
179
+ data_npz_load["video"] = (video_tensor).cpu().numpy()/255
180
+ data_npz_load["visibs"] = vis_pred.cpu().numpy()
181
+ data_npz_load["unc_metric"] = conf_depth.cpu().numpy()
182
+ np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
183
+
184
+ print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")
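inference.py writes everything into a single `result.npz` in the TAPIP3D-style layout assembled just above (`coords`, `extrinsics`, `intrinsics`, `depths`, `video`, `visibs`, `unc_metric`). A small sketch for sanity-checking that file; the path uses the script's default arguments (`assets/example0/results`), and the shape comments reflect what the code above implies (T frames, N tracked points) rather than guaranteed values:

```python
# Sketch: inspect the result.npz written by inference.py (default output location).
import numpy as np

data = np.load("assets/example0/results/result.npz")
for key in data.files:
    print(key, data[key].shape, data[key].dtype)

coords = data["coords"]          # (T, N, 3) 3D tracks lifted into world coordinates
extrinsics = data["extrinsics"]  # (T, 4, 4) world-to-camera matrices (inverse of c2w_traj)
depths = data["depths"]          # (T, H, W) depth with low-confidence pixels zeroed out
video = data["video"]            # (T, 3, H, W) RGB frames scaled to [0, 1]
```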
models/SpaTrackV2/models/vggt4track/__init__.py ADDED
@@ -0,0 +1 @@
1
+
models/SpaTrackV2/models/vggt4track/heads/camera_head.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from models.SpaTrackV2.models.vggt4track.layers import Mlp
15
+ from models.SpaTrackV2.models.vggt4track.layers.block import Block
16
+ from models.SpaTrackV2.models.vggt4track.heads.head_act import activate_pose
17
+
18
+
19
+ class CameraHead(nn.Module):
20
+ """
21
+ CameraHead predicts camera parameters from token representations using iterative refinement.
22
+
23
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ dim_in: int = 2048,
29
+ trunk_depth: int = 4,
30
+ pose_encoding_type: str = "absT_quaR_FoV",
31
+ num_heads: int = 16,
32
+ mlp_ratio: int = 4,
33
+ init_values: float = 0.01,
34
+ trans_act: str = "linear",
35
+ quat_act: str = "linear",
36
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
37
+ ):
38
+ super().__init__()
39
+
40
+ if pose_encoding_type == "absT_quaR_FoV":
41
+ self.target_dim = 9
42
+ else:
43
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
44
+
45
+ self.trans_act = trans_act
46
+ self.quat_act = quat_act
47
+ self.fl_act = fl_act
48
+ self.trunk_depth = trunk_depth
49
+
50
+ # Build the trunk using a sequence of transformer blocks.
51
+ self.trunk = nn.Sequential(
52
+ *[
53
+ Block(
54
+ dim=dim_in,
55
+ num_heads=num_heads,
56
+ mlp_ratio=mlp_ratio,
57
+ init_values=init_values,
58
+ )
59
+ for _ in range(trunk_depth)
60
+ ]
61
+ )
62
+
63
+ # Normalizations for camera token and trunk output.
64
+ self.token_norm = nn.LayerNorm(dim_in)
65
+ self.trunk_norm = nn.LayerNorm(dim_in)
66
+
67
+ # Learnable empty camera pose token.
68
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
69
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
70
+
71
+ # Module for producing modulation parameters: shift, scale, and a gate.
72
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
73
+
74
+ # Adaptive layer normalization without affine parameters.
75
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
76
+ self.pose_branch = Mlp(
77
+ in_features=dim_in,
78
+ hidden_features=dim_in // 2,
79
+ out_features=self.target_dim,
80
+ drop=0,
81
+ )
82
+
83
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
84
+ """
85
+ Forward pass to predict camera parameters.
86
+
87
+ Args:
88
+ aggregated_tokens_list (list): List of token tensors from the network;
89
+ the last tensor is used for prediction.
90
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
91
+
92
+ Returns:
93
+ list: A list of predicted camera encodings (post-activation) from each iteration.
94
+ """
95
+ # Use tokens from the last block for camera prediction.
96
+ tokens = aggregated_tokens_list[-1]
97
+
98
+ # Extract the camera tokens
99
+ pose_tokens = tokens[:, :, 0]
100
+ pose_tokens = self.token_norm(pose_tokens)
101
+
102
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
103
+ return pred_pose_enc_list
104
+
105
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
106
+ """
107
+ Iteratively refine camera pose predictions.
108
+
109
+ Args:
110
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
111
+ num_iterations (int): Number of refinement iterations.
112
+
113
+ Returns:
114
+ list: List of activated camera encodings from each iteration.
115
+ """
116
+ B, S, C = pose_tokens.shape # S is expected to be 1.
117
+ pred_pose_enc = None
118
+ pred_pose_enc_list = []
119
+
120
+ for _ in range(num_iterations):
121
+ # Use a learned empty pose for the first iteration.
122
+ if pred_pose_enc is None:
123
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
124
+ else:
125
+ # Detach the previous prediction to avoid backprop through time.
126
+ pred_pose_enc = pred_pose_enc.detach()
127
+ module_input = self.embed_pose(pred_pose_enc)
128
+
129
+ # Generate modulation parameters and split them into shift, scale, and gate components.
130
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
131
+
132
+ # Adaptive layer normalization and modulation.
133
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
134
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
135
+
136
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
137
+ # Compute the delta update for the pose encoding.
138
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
139
+
140
+ if pred_pose_enc is None:
141
+ pred_pose_enc = pred_pose_enc_delta
142
+ else:
143
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
144
+
145
+ # Apply final activation functions for translation, quaternion, and field-of-view.
146
+ activated_pose = activate_pose(
147
+ pred_pose_enc,
148
+ trans_act=self.trans_act,
149
+ quat_act=self.quat_act,
150
+ fl_act=self.fl_act,
151
+ )
152
+ pred_pose_enc_list.append(activated_pose)
153
+
154
+ return pred_pose_enc_list
155
+
156
+
157
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
158
+ """
159
+ Modulate the input tensor using scaling and shifting parameters.
160
+ """
161
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
162
+ return x * (1 + scale) + shift
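To see the tensor shapes CameraHead expects, the hedged smoke test below feeds it a single dummy entry of aggregated tokens (random values, so only the shapes are meaningful). It assumes the repository is importable and treats the second dimension as the frame axis, with the camera token at index 0 of the token dimension:

```python
# Hedged smoke test for CameraHead with dummy data (assumes the repo is on PYTHONPATH).
import torch
from models.SpaTrackV2.models.vggt4track.heads.camera_head import CameraHead

head = CameraHead(dim_in=2048, trunk_depth=4).eval()

# Aggregated tokens: [B, S, num_tokens, C]; forward() reads the camera token at index 0.
tokens = torch.randn(1, 8, 5, 2048)

with torch.no_grad():
    pose_enc_list = head([tokens])   # one entry per refinement step, default num_iterations=4

print(len(pose_enc_list))            # 4
print(pose_enc_list[-1].shape)       # torch.Size([1, 8, 9]): absT (3) + quaR (4) + FoV (2)
```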
models/SpaTrackV2/models/vggt4track/heads/dpt_head.py ADDED
@@ -0,0 +1,497 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
+
10
+
11
+ import os
12
+ from typing import List, Dict, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from .head_act import activate_head
18
+ from .utils import create_uv_grid, position_grid_to_embed
19
+
20
+
21
+ class DPTHead(nn.Module):
22
+ """
23
+ DPT Head for dense prediction tasks.
24
+
25
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
26
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
27
+ backbone and produces dense predictions by fusing multi-scale features.
28
+
29
+ Args:
30
+ dim_in (int): Input dimension (channels).
31
+ patch_size (int, optional): Patch size. Default is 14.
32
+ output_dim (int, optional): Number of output channels. Default is 4.
33
+ activation (str, optional): Activation type. Default is "inv_log".
34
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
35
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
36
+ out_channels (List[int], optional): Output channels for each intermediate layer.
37
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
38
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
39
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
40
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ dim_in: int,
46
+ patch_size: int = 14,
47
+ output_dim: int = 4,
48
+ activation: str = "inv_log",
49
+ conf_activation: str = "expp1",
50
+ features: int = 256,
51
+ out_channels: List[int] = [256, 512, 1024, 1024],
52
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
53
+ pos_embed: bool = True,
54
+ feature_only: bool = False,
55
+ down_ratio: int = 1,
56
+ ) -> None:
57
+ super(DPTHead, self).__init__()
58
+ self.patch_size = patch_size
59
+ self.activation = activation
60
+ self.conf_activation = conf_activation
61
+ self.pos_embed = pos_embed
62
+ self.feature_only = feature_only
63
+ self.down_ratio = down_ratio
64
+ self.intermediate_layer_idx = intermediate_layer_idx
65
+
66
+ self.norm = nn.LayerNorm(dim_in)
67
+
68
+ # Projection layers for each output channel from tokens.
69
+ self.projects = nn.ModuleList(
70
+ [
71
+ nn.Conv2d(
72
+ in_channels=dim_in,
73
+ out_channels=oc,
74
+ kernel_size=1,
75
+ stride=1,
76
+ padding=0,
77
+ )
78
+ for oc in out_channels
79
+ ]
80
+ )
81
+
82
+ # Resize layers for upsampling feature maps.
83
+ self.resize_layers = nn.ModuleList(
84
+ [
85
+ nn.ConvTranspose2d(
86
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
87
+ ),
88
+ nn.ConvTranspose2d(
89
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
90
+ ),
91
+ nn.Identity(),
92
+ nn.Conv2d(
93
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
94
+ ),
95
+ ]
96
+ )
97
+
98
+ self.scratch = _make_scratch(
99
+ out_channels,
100
+ features,
101
+ expand=False,
102
+ )
103
+
104
+ # Attach additional modules to scratch.
105
+ self.scratch.stem_transpose = None
106
+ self.scratch.refinenet1 = _make_fusion_block(features)
107
+ self.scratch.refinenet2 = _make_fusion_block(features)
108
+ self.scratch.refinenet3 = _make_fusion_block(features)
109
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
110
+
111
+ head_features_1 = features
112
+ head_features_2 = 32
113
+
114
+ if feature_only:
115
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
116
+ else:
117
+ self.scratch.output_conv1 = nn.Conv2d(
118
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
119
+ )
120
+ conv2_in_channels = head_features_1 // 2
121
+
122
+ self.scratch.output_conv2 = nn.Sequential(
123
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
124
+ nn.ReLU(inplace=True),
125
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
126
+ )
127
+
128
+ def forward(
129
+ self,
130
+ aggregated_tokens_list: List[torch.Tensor],
131
+ images: torch.Tensor,
132
+ patch_start_idx: int,
133
+ frames_chunk_size: int = 8,
134
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
135
+ """
136
+ Forward pass through the DPT head, supports processing by chunking frames.
137
+ Args:
138
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
139
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
140
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
141
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
142
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
143
+ If None or larger than S, all frames are processed at once. Default: 8.
144
+
145
+ Returns:
146
+ Tensor or Tuple[Tensor, Tensor]:
147
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
148
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
149
+ """
150
+ B, S, _, H, W = images.shape
151
+
152
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
153
+ if frames_chunk_size is None or frames_chunk_size >= S:
154
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
155
+
156
+ # Otherwise, process frames in chunks to manage memory usage
157
+ assert frames_chunk_size > 0
158
+
159
+ # Process frames in batches
160
+ all_preds = []
161
+ all_conf = []
162
+
163
+ for frames_start_idx in range(0, S, frames_chunk_size):
164
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
165
+
166
+ # Process batch of frames
167
+ if self.feature_only:
168
+ chunk_output = self._forward_impl(
169
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
170
+ )
171
+ all_preds.append(chunk_output)
172
+ else:
173
+ chunk_preds, chunk_conf = self._forward_impl(
174
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
175
+ )
176
+ all_preds.append(chunk_preds)
177
+ all_conf.append(chunk_conf)
178
+
179
+ # Concatenate results along the sequence dimension
180
+ if self.feature_only:
181
+ return torch.cat(all_preds, dim=1)
182
+ else:
183
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
184
+
185
+ def _forward_impl(
186
+ self,
187
+ aggregated_tokens_list: List[torch.Tensor],
188
+ images: torch.Tensor,
189
+ patch_start_idx: int,
190
+ frames_start_idx: int = None,
191
+ frames_end_idx: int = None,
192
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
193
+ """
194
+ Implementation of the forward pass through the DPT head.
195
+
196
+ This method processes a specific chunk of frames from the sequence.
197
+
198
+ Args:
199
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
200
+ images (Tensor): Input images with shape [B, S, 3, H, W].
201
+ patch_start_idx (int): Starting index for patch tokens.
202
+ frames_start_idx (int, optional): Starting index for frames to process.
203
+ frames_end_idx (int, optional): Ending index for frames to process.
204
+
205
+ Returns:
206
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
207
+ """
208
+ if frames_start_idx is not None and frames_end_idx is not None:
209
+ images = images[:, frames_start_idx:frames_end_idx].contiguous()
210
+
211
+ B, S, _, H, W = images.shape
212
+
213
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
214
+
215
+ out = []
216
+ dpt_idx = 0
217
+
218
+ for layer_idx in self.intermediate_layer_idx:
219
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
220
+
221
+ # Select frames if processing a chunk
222
+ if frames_start_idx is not None and frames_end_idx is not None:
223
+ x = x[:, frames_start_idx:frames_end_idx]
224
+
225
+ x = x.view(B * S, -1, x.shape[-1])
226
+
227
+ x = self.norm(x)
228
+
229
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
230
+
231
+ x = self.projects[dpt_idx](x)
232
+ if self.pos_embed:
233
+ x = self._apply_pos_embed(x, W, H)
234
+ x = self.resize_layers[dpt_idx](x)
235
+
236
+ out.append(x)
237
+ dpt_idx += 1
238
+
239
+ # Fuse features from multiple layers.
240
+ out = self.scratch_forward(out)
241
+ # Interpolate fused output to match target image resolution.
242
+ out = custom_interpolate(
243
+ out,
244
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
245
+ mode="bilinear",
246
+ align_corners=True,
247
+ )
248
+
249
+ if self.pos_embed:
250
+ out = self._apply_pos_embed(out, W, H)
251
+
252
+ if self.feature_only:
253
+ return out.view(B, S, *out.shape[1:])
254
+
255
+ out = self.scratch.output_conv2(out)
256
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
257
+
258
+ preds = preds.view(B, S, *preds.shape[1:])
259
+ conf = conf.view(B, S, *conf.shape[1:])
260
+ return preds, conf
261
+
262
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
263
+ """
264
+ Apply positional embedding to tensor x.
265
+ """
266
+ patch_w = x.shape[-1]
267
+ patch_h = x.shape[-2]
268
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
269
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
270
+ pos_embed = pos_embed * ratio
271
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
272
+ return x + pos_embed
273
+
274
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
275
+ """
276
+ Forward pass through the fusion blocks.
277
+
278
+ Args:
279
+ features (List[Tensor]): List of feature maps from different layers.
280
+
281
+ Returns:
282
+ Tensor: Fused feature map.
283
+ """
284
+ layer_1, layer_2, layer_3, layer_4 = features
285
+
286
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
287
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
288
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
289
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
290
+
291
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
292
+ del layer_4_rn, layer_4
293
+
294
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
295
+ del layer_3_rn, layer_3
296
+
297
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
298
+ del layer_2_rn, layer_2
299
+
300
+ out = self.scratch.refinenet1(out, layer_1_rn)
301
+ del layer_1_rn, layer_1
302
+
303
+ out = self.scratch.output_conv1(out)
304
+ return out
305
+
306
+
307
+ ################################################################################
308
+ # Modules
309
+ ################################################################################
310
+
311
+
312
+ def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
313
+ return FeatureFusionBlock(
314
+ features,
315
+ nn.ReLU(inplace=True),
316
+ deconv=False,
317
+ bn=False,
318
+ expand=False,
319
+ align_corners=True,
320
+ size=size,
321
+ has_residual=has_residual,
322
+ groups=groups,
323
+ )
324
+
325
+
326
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
327
+ scratch = nn.Module()
328
+ out_shape1 = out_shape
329
+ out_shape2 = out_shape
330
+ out_shape3 = out_shape
331
+ if len(in_shape) >= 4:
332
+ out_shape4 = out_shape
333
+
334
+ if expand:
335
+ out_shape1 = out_shape
336
+ out_shape2 = out_shape * 2
337
+ out_shape3 = out_shape * 4
338
+ if len(in_shape) >= 4:
339
+ out_shape4 = out_shape * 8
340
+
341
+ scratch.layer1_rn = nn.Conv2d(
342
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
343
+ )
344
+ scratch.layer2_rn = nn.Conv2d(
345
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
346
+ )
347
+ scratch.layer3_rn = nn.Conv2d(
348
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
349
+ )
350
+ if len(in_shape) >= 4:
351
+ scratch.layer4_rn = nn.Conv2d(
352
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
353
+ )
354
+ return scratch
355
+
356
+
357
+ class ResidualConvUnit(nn.Module):
358
+ """Residual convolution module."""
359
+
360
+ def __init__(self, features, activation, bn, groups=1):
361
+ """Init.
362
+
363
+ Args:
364
+ features (int): number of features
365
+ """
366
+ super().__init__()
367
+
368
+ self.bn = bn
369
+ self.groups = groups
370
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
371
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
372
+
373
+ self.norm1 = None
374
+ self.norm2 = None
375
+
376
+ self.activation = activation
377
+ self.skip_add = nn.quantized.FloatFunctional()
378
+
379
+ def forward(self, x):
380
+ """Forward pass.
381
+
382
+ Args:
383
+ x (tensor): input
384
+
385
+ Returns:
386
+ tensor: output
387
+ """
388
+
389
+ out = self.activation(x)
390
+ out = self.conv1(out)
391
+ if self.norm1 is not None:
392
+ out = self.norm1(out)
393
+
394
+ out = self.activation(out)
395
+ out = self.conv2(out)
396
+ if self.norm2 is not None:
397
+ out = self.norm2(out)
398
+
399
+ return self.skip_add.add(out, x)
400
+
401
+
402
+ class FeatureFusionBlock(nn.Module):
403
+ """Feature fusion block."""
404
+
405
+ def __init__(
406
+ self,
407
+ features,
408
+ activation,
409
+ deconv=False,
410
+ bn=False,
411
+ expand=False,
412
+ align_corners=True,
413
+ size=None,
414
+ has_residual=True,
415
+ groups=1,
416
+ ):
417
+ """Init.
418
+
419
+ Args:
420
+ features (int): number of features
421
+ """
422
+ super(FeatureFusionBlock, self).__init__()
423
+
424
+ self.deconv = deconv
425
+ self.align_corners = align_corners
426
+ self.groups = groups
427
+ self.expand = expand
428
+ out_features = features
429
+ if self.expand == True:
430
+ out_features = features // 2
431
+
432
+ self.out_conv = nn.Conv2d(
433
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
434
+ )
435
+
436
+ if has_residual:
437
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
438
+
439
+ self.has_residual = has_residual
440
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
441
+
442
+ self.skip_add = nn.quantized.FloatFunctional()
443
+ self.size = size
444
+
445
+ def forward(self, *xs, size=None):
446
+ """Forward pass.
447
+
448
+ Returns:
449
+ tensor: output
450
+ """
451
+ output = xs[0]
452
+
453
+ if self.has_residual:
454
+ res = self.resConfUnit1(xs[1])
455
+ output = self.skip_add.add(output, res)
456
+
457
+ output = self.resConfUnit2(output)
458
+
459
+ if (size is None) and (self.size is None):
460
+ modifier = {"scale_factor": 2}
461
+ elif size is None:
462
+ modifier = {"size": self.size}
463
+ else:
464
+ modifier = {"size": size}
465
+
466
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
467
+ output = self.out_conv(output)
468
+
469
+ return output
470
+
471
+
472
+ def custom_interpolate(
473
+ x: torch.Tensor,
474
+ size: Tuple[int, int] = None,
475
+ scale_factor: float = None,
476
+ mode: str = "bilinear",
477
+ align_corners: bool = True,
478
+ ) -> torch.Tensor:
479
+ """
480
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
481
+ """
482
+ if size is None:
483
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
484
+
485
+ INT_MAX = 1610612736
486
+
487
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
488
+
489
+ if input_elements > INT_MAX:
490
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
491
+ interpolated_chunks = [
492
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
493
+ ]
494
+ x = torch.cat(interpolated_chunks, dim=0)
495
+ return x.contiguous()
496
+ else:
497
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
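
The chunked call in custom_interpolate above is easy to sanity-check in isolation. A minimal sketch, assuming only plain PyTorch and made-up tensor sizes: splitting the batch before nn.functional.interpolate and concatenating the results matches a single call, which is what lets the helper stay under the INT_MAX element limit.

# Self-contained sketch; sizes and chunk count are invented for illustration.
import torch
import torch.nn.functional as F

x = torch.randn(8, 16, 32, 32)            # (B, C, H, W) dummy feature map
size = (64, 64)

full = F.interpolate(x, size=size, mode="bilinear", align_corners=True)

# Pretend the single call would exceed the element limit and split along the batch dim.
chunks = torch.chunk(x, chunks=4, dim=0)
chunked = torch.cat(
    [F.interpolate(c, size=size, mode="bilinear", align_corners=True) for c in chunks],
    dim=0,
)

assert torch.allclose(full, chunked)      # per-sample interpolation is independent
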
models/SpaTrackV2/models/vggt4track/heads/head_act.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
13
+ """
14
+ Activate pose parameters with specified activation functions.
15
+
16
+ Args:
17
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
18
+ trans_act: Activation type for translation component
19
+ quat_act: Activation type for quaternion component
20
+ fl_act: Activation type for focal length component
21
+
22
+ Returns:
23
+ Activated pose parameters tensor
24
+ """
25
+ T = pred_pose_enc[..., :3]
26
+ quat = pred_pose_enc[..., 3:7]
27
+ fl = pred_pose_enc[..., 7:] # or fov
28
+
29
+ T = base_pose_act(T, trans_act)
30
+ quat = base_pose_act(quat, quat_act)
31
+ fl = base_pose_act(fl, fl_act) # or fov
32
+
33
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
34
+
35
+ return pred_pose_enc
36
+
37
+
38
+ def base_pose_act(pose_enc, act_type="linear"):
39
+ """
40
+ Apply basic activation function to pose parameters.
41
+
42
+ Args:
43
+ pose_enc: Tensor containing encoded pose parameters
44
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
45
+
46
+ Returns:
47
+ Activated pose parameters
48
+ """
49
+ if act_type == "linear":
50
+ return pose_enc
51
+ elif act_type == "inv_log":
52
+ return inverse_log_transform(pose_enc)
53
+ elif act_type == "exp":
54
+ return torch.exp(pose_enc)
55
+ elif act_type == "relu":
56
+ return F.relu(pose_enc)
57
+ else:
58
+ raise ValueError(f"Unknown act_type: {act_type}")
59
+
60
+
61
+ def activate_head(out, activation="norm_exp", conf_activation="expp1"):
62
+ """
63
+ Process network output to extract 3D points and confidence values.
64
+
65
+ Args:
66
+ out: Network output tensor (B, C, H, W)
67
+ activation: Activation type for 3D points
68
+ conf_activation: Activation type for confidence values
69
+
70
+ Returns:
71
+ Tuple of (3D points tensor, confidence tensor)
72
+ """
73
+ # Move the channel dim to the end: (B, C, H, W) -> (B, H, W, C)
74
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
75
+
76
+ # Split into xyz (first C-1 channels) and confidence (last channel)
77
+ xyz = fmap[:, :, :, :-1]
78
+ conf = fmap[:, :, :, -1]
79
+
80
+ if activation == "norm_exp":
81
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
82
+ xyz_normed = xyz / d
83
+ pts3d = xyz_normed * torch.expm1(d)
84
+ elif activation == "norm":
85
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
86
+ elif activation == "exp":
87
+ pts3d = torch.exp(xyz)
88
+ elif activation == "relu":
89
+ pts3d = F.relu(xyz)
90
+ elif activation == "inv_log":
91
+ pts3d = inverse_log_transform(xyz)
92
+ elif activation == "xy_inv_log":
93
+ xy, z = xyz.split([2, 1], dim=-1)
94
+ z = inverse_log_transform(z)
95
+ pts3d = torch.cat([xy * z, z], dim=-1)
96
+ elif activation == "sigmoid":
97
+ pts3d = torch.sigmoid(xyz)
98
+ elif activation == "linear":
99
+ pts3d = xyz
100
+ else:
101
+ raise ValueError(f"Unknown activation: {activation}")
102
+
103
+ if conf_activation == "expp1":
104
+ conf_out = 1 + conf.exp()
105
+ elif conf_activation == "expp0":
106
+ conf_out = conf.exp()
107
+ elif conf_activation == "sigmoid":
108
+ conf_out = torch.sigmoid(conf)
109
+ else:
110
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
111
+
112
+ return pts3d, conf_out
113
+
114
+
115
+ def inverse_log_transform(y):
116
+ """
117
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
118
+
119
+ Args:
120
+ y: Input tensor
121
+
122
+ Returns:
123
+ Transformed tensor
124
+ """
125
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
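
For a concrete feel of the default branches of activate_head ("norm_exp" for points, "expp1" for confidence), here is a small self-contained sketch with a dummy 4-channel head output; the shapes are invented for illustration.

import torch

out = torch.randn(2, 4, 8, 8)                       # (B, C, H, W): 3 point channels + 1 confidence
fmap = out.permute(0, 2, 3, 1)                      # (B, H, W, C)
xyz, conf = fmap[..., :-1], fmap[..., -1]

d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)  # "norm_exp": direction * expm1(length)
pts3d = (xyz / d) * torch.expm1(d)
conf_out = 1 + conf.exp()                           # "expp1": confidence is always > 1

print(pts3d.shape, conf_out.shape)                  # torch.Size([2, 8, 8, 3]) torch.Size([2, 8, 8])
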
models/SpaTrackV2/models/vggt4track/heads/scale_head.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from models.SpaTrackV2.models.vggt4track.layers import Mlp
15
+ from models.SpaTrackV2.models.vggt4track.layers.block import Block
16
+ from models.SpaTrackV2.models.vggt4track.heads.head_act import activate_pose
17
+
18
+
19
+ class ScaleHead(nn.Module):
20
+ """
21
+ ScaleHead predicts a 2-dimensional scale encoding from token representations using iterative refinement.
22
+
23
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ dim_in: int = 2048,
29
+ trunk_depth: int = 4,
30
+ pose_encoding_type: str = "absT_quaR_FoV",
31
+ num_heads: int = 16,
32
+ mlp_ratio: int = 4,
33
+ init_values: float = 0.01,
34
+ trans_act: str = "linear",
35
+ quat_act: str = "linear",
36
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
37
+ ):
38
+ super().__init__()
39
+
40
+ self.target_dim = 2
41
+
42
+ self.trans_act = trans_act
43
+ self.quat_act = quat_act
44
+ self.fl_act = fl_act
45
+ self.trunk_depth = trunk_depth
46
+
47
+ # Build the trunk using a sequence of transformer blocks.
48
+ self.trunk = nn.Sequential(
49
+ *[
50
+ Block(
51
+ dim=dim_in,
52
+ num_heads=num_heads,
53
+ mlp_ratio=mlp_ratio,
54
+ init_values=init_values,
55
+ )
56
+ for _ in range(trunk_depth)
57
+ ]
58
+ )
59
+
60
+ # Normalizations for camera token and trunk output.
61
+ self.token_norm = nn.LayerNorm(dim_in)
62
+ self.trunk_norm = nn.LayerNorm(dim_in)
63
+
64
+ # Learnable empty camera pose token.
65
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
66
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
67
+
68
+ # Module for producing modulation parameters: shift, scale, and a gate.
69
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
70
+
71
+ # Adaptive layer normalization without affine parameters.
72
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
73
+ self.pose_branch = Mlp(
74
+ in_features=dim_in,
75
+ hidden_features=dim_in // 2,
76
+ out_features=self.target_dim,
77
+ drop=0,
78
+ )
79
+
80
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
81
+ """
82
+ Forward pass to predict camera parameters.
83
+
84
+ Args:
85
+ aggregated_tokens_list (list): List of token tensors from the network;
86
+ the last tensor is used for prediction.
87
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
88
+
89
+ Returns:
90
+ list: A list of predicted camera encodings (post-activation) from each iteration.
91
+ """
92
+ # Use tokens from the last block for camera prediction.
93
+ tokens = aggregated_tokens_list[-1]
94
+
95
+ # Extract the camera tokens
96
+ pose_tokens = tokens[:, :, 5]
97
+ pose_tokens = self.token_norm(pose_tokens)
98
+
99
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
100
+ return pred_pose_enc_list
101
+
102
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
103
+ """
104
+ Iteratively refine camera pose predictions.
105
+
106
+ Args:
107
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
108
+ num_iterations (int): Number of refinement iterations.
109
+
110
+ Returns:
111
+ list: List of activated camera encodings from each iteration.
112
+ """
113
+ B, S, C = pose_tokens.shape # S is expected to be 1.
114
+ pred_pose_enc = None
115
+ pred_pose_enc_list = []
116
+
117
+ for _ in range(num_iterations):
118
+ # Use a learned empty pose for the first iteration.
119
+ if pred_pose_enc is None:
120
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
121
+ else:
122
+ # Detach the previous prediction to avoid backprop through time.
123
+ pred_pose_enc = pred_pose_enc.detach()
124
+ module_input = self.embed_pose(pred_pose_enc)
125
+
126
+ # Generate modulation parameters and split them into shift, scale, and gate components.
127
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
128
+
129
+ # Adaptive layer normalization and modulation.
130
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
131
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
132
+
133
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
134
+ # Compute the delta update for the pose encoding.
135
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
136
+
137
+ if pred_pose_enc is None:
138
+ pred_pose_enc = pred_pose_enc_delta
139
+ else:
140
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
141
+
142
+ # Apply final activation functions for translation, quaternion, and field-of-view.
143
+ activated_pose = activate_pose(
144
+ pred_pose_enc,
145
+ trans_act=self.trans_act,
146
+ quat_act=self.quat_act,
147
+ fl_act=self.fl_act,
148
+ )
149
+ activated_pose_proc = activated_pose.clone()
150
+ activated_pose_proc[...,:1] = activated_pose_proc[...,:1].clamp(min=1e-5, max=1e3)
151
+ activated_pose_proc[...,1:] = activated_pose_proc[...,1:]*1e-2
152
+ pred_pose_enc_list.append(activated_pose_proc)
153
+
154
+ return pred_pose_enc_list
155
+
156
+
157
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
158
+ """
159
+ Modulate the input tensor using scaling and shifting parameters.
160
+ """
161
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
162
+ return x * (1 + scale) + shift
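
The shift/scale/gate modulation that ScaleHead.trunk_fn applies before its trunk can be reproduced in a few lines. A sketch with assumed dimensions, using only plain PyTorch:

import torch
import torch.nn as nn

dim = 32
tokens = torch.randn(2, 1, dim)                     # (B, S, C) with S = 1
cond = torch.randn(2, 1, dim)                       # embedded previous estimate (or the empty token)

to_msa = nn.Sequential(nn.SiLU(), nn.Linear(dim, 3 * dim))
adaln = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

shift, scale, gate = to_msa(cond).chunk(3, dim=-1)
modulated = gate * (adaln(tokens) * (1 + scale) + shift)   # modulate(...) inlined
tokens = tokens + modulated                                 # residual connection
print(tokens.shape)                                         # torch.Size([2, 1, 32])
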
models/SpaTrackV2/models/vggt4track/heads/track_head.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch.nn as nn
8
+ from .dpt_head import DPTHead
9
+ from .track_modules.base_track_predictor import BaseTrackerPredictor
10
+
11
+
12
+ class TrackHead(nn.Module):
13
+ """
14
+ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
15
+ The tracking is performed iteratively, refining predictions over multiple iterations.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ dim_in,
21
+ patch_size=14,
22
+ features=128,
23
+ iters=4,
24
+ predict_conf=True,
25
+ stride=2,
26
+ corr_levels=7,
27
+ corr_radius=4,
28
+ hidden_size=384,
29
+ ):
30
+ """
31
+ Initialize the TrackHead module.
32
+
33
+ Args:
34
+ dim_in (int): Input dimension of tokens from the backbone.
35
+ patch_size (int): Size of image patches used in the vision transformer.
36
+ features (int): Number of feature channels in the feature extractor output.
37
+ iters (int): Number of refinement iterations for tracking predictions.
38
+ predict_conf (bool): Whether to predict confidence scores for tracked points.
39
+ stride (int): Stride value for the tracker predictor.
40
+ corr_levels (int): Number of correlation pyramid levels
41
+ corr_radius (int): Radius for correlation computation, controlling the search area.
42
+ hidden_size (int): Size of hidden layers in the tracker network.
43
+ """
44
+ super().__init__()
45
+
46
+ self.patch_size = patch_size
47
+
48
+ # Feature extractor based on DPT architecture
49
+ # Processes tokens into feature maps for tracking
50
+ self.feature_extractor = DPTHead(
51
+ dim_in=dim_in,
52
+ patch_size=patch_size,
53
+ features=features,
54
+ feature_only=True, # Only output features, no activation
55
+ down_ratio=2, # Reduces spatial dimensions by factor of 2
56
+ pos_embed=False,
57
+ )
58
+
59
+ # Tracker module that predicts point trajectories
60
+ # Takes feature maps and predicts coordinates and visibility
61
+ self.tracker = BaseTrackerPredictor(
62
+ latent_dim=features, # Match the output_dim of feature extractor
63
+ predict_conf=predict_conf,
64
+ stride=stride,
65
+ corr_levels=corr_levels,
66
+ corr_radius=corr_radius,
67
+ hidden_size=hidden_size,
68
+ )
69
+
70
+ self.iters = iters
71
+
72
+ def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
73
+ """
74
+ Forward pass of the TrackHead.
75
+
76
+ Args:
77
+ aggregated_tokens_list (list): List of aggregated tokens from the backbone.
78
+ images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
79
+ B = batch size, S = sequence length.
80
+ patch_start_idx (int): Starting index for patch tokens.
81
+ query_points (torch.Tensor, optional): Initial query points to track.
82
+ If None, points are initialized by the tracker.
83
+ iters (int, optional): Number of refinement iterations. If None, uses self.iters.
84
+
85
+ Returns:
86
+ tuple:
87
+ - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
88
+ - vis_scores (torch.Tensor): Visibility scores for tracked points.
89
+ - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
90
+ """
91
+ B, S, _, H, W = images.shape
92
+
93
+ # Extract features from tokens
94
+ # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
95
+ feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
96
+
97
+ # Use default iterations if not specified
98
+ if iters is None:
99
+ iters = self.iters
100
+
101
+ # Perform tracking using the extracted features
102
+ coord_preds, vis_scores, conf_scores = self.tracker(
103
+ query_points=query_points,
104
+ fmaps=feature_maps,
105
+ iters=iters,
106
+ )
107
+
108
+ return coord_preds, vis_scores, conf_scores
models/SpaTrackV2/models/vggt4track/heads/track_modules/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
models/SpaTrackV2/models/vggt4track/heads/track_modules/base_track_predictor.py ADDED
@@ -0,0 +1,209 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange, repeat
10
+
11
+
12
+ from .blocks import EfficientUpdateFormer, CorrBlock
13
+ from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
14
+ from .modules import Mlp
15
+
16
+
17
+ class BaseTrackerPredictor(nn.Module):
18
+ def __init__(
19
+ self,
20
+ stride=1,
21
+ corr_levels=5,
22
+ corr_radius=4,
23
+ latent_dim=128,
24
+ hidden_size=384,
25
+ use_spaceatt=True,
26
+ depth=6,
27
+ max_scale=518,
28
+ predict_conf=True,
29
+ ):
30
+ super(BaseTrackerPredictor, self).__init__()
31
+ """
32
+ The base template to create a track predictor
33
+
34
+ Modified from https://github.com/facebookresearch/co-tracker/
35
+ and https://github.com/facebookresearch/vggsfm
36
+ """
37
+
38
+ self.stride = stride
39
+ self.latent_dim = latent_dim
40
+ self.corr_levels = corr_levels
41
+ self.corr_radius = corr_radius
42
+ self.hidden_size = hidden_size
43
+ self.max_scale = max_scale
44
+ self.predict_conf = predict_conf
45
+
46
+ self.flows_emb_dim = latent_dim // 2
47
+
48
+ self.corr_mlp = Mlp(
49
+ in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
50
+ hidden_features=self.hidden_size,
51
+ out_features=self.latent_dim,
52
+ )
53
+
54
+ self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
55
+
56
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
57
+
58
+ space_depth = depth if use_spaceatt else 0
59
+ time_depth = depth
60
+
61
+ self.updateformer = EfficientUpdateFormer(
62
+ space_depth=space_depth,
63
+ time_depth=time_depth,
64
+ input_dim=self.transformer_dim,
65
+ hidden_size=self.hidden_size,
66
+ output_dim=self.latent_dim + 2,
67
+ mlp_ratio=4.0,
68
+ add_space_attn=use_spaceatt,
69
+ )
70
+
71
+ self.fmap_norm = nn.LayerNorm(self.latent_dim)
72
+ self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
73
+
74
+ # A linear layer to update track feats at each iteration
75
+ self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
76
+
77
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
78
+
79
+ if predict_conf:
80
+ self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
81
+
82
+ def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
83
+ """
84
+ query_points: B x N x 2, giving batch size, number of tracks, and (x, y) coordinates
85
+ fmaps: B x S x C x HH x WW, giving batch size, frames, feature channels, and feature-map size.
86
+ Note that HH and WW are the feature-map dimensions, not the original image size.
87
+ """
88
+ B, N, D = query_points.shape
89
+ B, S, C, HH, WW = fmaps.shape
90
+
91
+ assert D == 2, "Input points must be 2D coordinates"
92
+
93
+ # apply a layernorm to fmaps here
94
+ fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
95
+ fmaps = fmaps.permute(0, 1, 4, 2, 3)
96
+
97
+ # Scale the input query_points because we may downsample the images
98
+ # by down_ratio or self.stride
99
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
100
+ # its query_points should be query_points/4
101
+ if down_ratio > 1:
102
+ query_points = query_points / float(down_ratio)
103
+
104
+ query_points = query_points / float(self.stride)
105
+
106
+ # Init with coords as the query points
107
+ # It means the search will start from the position of query points at the reference frames
108
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
109
+
110
+ # Sample/extract the features of the query points in the query frame
111
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
112
+
113
+ # init track feats by query feats
114
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
115
+ # back up the init coords
116
+ coords_backup = coords.clone()
117
+
118
+ fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
119
+
120
+ coord_preds = []
121
+
122
+ # Iterative Refinement
123
+ for _ in range(iters):
124
+ # Detach the gradients from the last iteration
125
+ # (in my experience, not very important for performance)
126
+ coords = coords.detach()
127
+
128
+ fcorrs = fcorr_fn.corr_sample(track_feats, coords)
129
+
130
+ corr_dim = fcorrs.shape[3]
131
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
132
+ fcorrs_ = self.corr_mlp(fcorrs_)
133
+
134
+ # Movement of current coords relative to query points
135
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
136
+
137
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
138
+
139
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
140
+ flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
141
+
142
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
143
+
144
+ # Concatenate them as the input for the transformers
145
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
146
+
147
+ # 2D positional embed
148
+ # TODO: this can be much simplified
149
+ pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
150
+ sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
151
+
152
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
153
+
154
+ x = transformer_input + sampled_pos_emb
155
+
156
+ # Add the query ref token to the track feats
157
+ query_ref_token = torch.cat(
158
+ [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
159
+ )
160
+ x = x + query_ref_token.to(x.device).to(x.dtype)
161
+
162
+ # B, N, S, C
163
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
164
+
165
+ # Compute the delta coordinates and delta track features
166
+ delta, _ = self.updateformer(x)
167
+
168
+ # BN, S, C
169
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
170
+ delta_coords_ = delta[:, :, :2]
171
+ delta_feats_ = delta[:, :, 2:]
172
+
173
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
174
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
175
+
176
+ # Update the track features
177
+ track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
178
+
179
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
180
+
181
+ # B x S x N x 2
182
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
183
+
184
+ # Force coord0 as query
185
+ # because we assume the query points should not be changed
186
+ coords[:, 0] = coords_backup[:, 0]
187
+
188
+ # The predicted tracks are in the original image scale
189
+ if down_ratio > 1:
190
+ coord_preds.append(coords * self.stride * down_ratio)
191
+ else:
192
+ coord_preds.append(coords * self.stride)
193
+
194
+ # B, S, N
195
+ vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
196
+ if apply_sigmoid:
197
+ vis_e = torch.sigmoid(vis_e)
198
+
199
+ if self.predict_conf:
200
+ conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
201
+ if apply_sigmoid:
202
+ conf_e = torch.sigmoid(conf_e)
203
+ else:
204
+ conf_e = None
205
+
206
+ if return_feat:
207
+ return coord_preds, vis_e, track_feats, query_track_feat, conf_e
208
+ else:
209
+ return coord_preds, vis_e, conf_e
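
To see how BaseTrackerPredictor seeds its per-frame coordinates before refinement, the sketch below (dummy numbers, plain PyTorch only) rescales query points to feature-map resolution, tiles them across frames, and scales back, mirroring the stride/down_ratio handling above.

import torch

B, N, S, stride, down_ratio = 1, 3, 5, 2, 1
query_points = torch.tensor([[[100.0, 40.0], [12.0, 8.0], [64.0, 64.0]]])   # (B, N, 2), image scale

coords = (query_points / (stride * down_ratio)).reshape(B, 1, N, 2).repeat(1, S, 1, 1)
print(coords.shape)                      # torch.Size([1, 5, 3, 2]) -> (B, S, N, 2)
print(coords[0, 0, 0] * stride)          # back to image scale: tensor([100., 40.])
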
models/SpaTrackV2/models/vggt4track/heads/track_modules/blocks.py ADDED
@@ -0,0 +1,246 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Modified from https://github.com/facebookresearch/co-tracker/
9
+
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from .utils import bilinear_sampler
16
+ from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
17
+
18
+
19
+ class EfficientUpdateFormer(nn.Module):
20
+ """
21
+ Transformer model that updates track estimates.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ space_depth=6,
27
+ time_depth=6,
28
+ input_dim=320,
29
+ hidden_size=384,
30
+ num_heads=8,
31
+ output_dim=130,
32
+ mlp_ratio=4.0,
33
+ add_space_attn=True,
34
+ num_virtual_tracks=64,
35
+ ):
36
+ super().__init__()
37
+
38
+ self.out_channels = 2
39
+ self.num_heads = num_heads
40
+ self.hidden_size = hidden_size
41
+ self.add_space_attn = add_space_attn
42
+
43
+ # Add input LayerNorm before linear projection
44
+ self.input_norm = nn.LayerNorm(input_dim)
45
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
46
+
47
+ # Add output LayerNorm before final projection
48
+ self.output_norm = nn.LayerNorm(hidden_size)
49
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
50
+ self.num_virtual_tracks = num_virtual_tracks
51
+
52
+ if self.add_space_attn:
53
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
54
+ else:
55
+ self.virual_tracks = None
56
+
57
+ self.time_blocks = nn.ModuleList(
58
+ [
59
+ AttnBlock(
60
+ hidden_size,
61
+ num_heads,
62
+ mlp_ratio=mlp_ratio,
63
+ attn_class=nn.MultiheadAttention,
64
+ )
65
+ for _ in range(time_depth)
66
+ ]
67
+ )
68
+
69
+ if add_space_attn:
70
+ self.space_virtual_blocks = nn.ModuleList(
71
+ [
72
+ AttnBlock(
73
+ hidden_size,
74
+ num_heads,
75
+ mlp_ratio=mlp_ratio,
76
+ attn_class=nn.MultiheadAttention,
77
+ )
78
+ for _ in range(space_depth)
79
+ ]
80
+ )
81
+ self.space_point2virtual_blocks = nn.ModuleList(
82
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
83
+ )
84
+ self.space_virtual2point_blocks = nn.ModuleList(
85
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
86
+ )
87
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
88
+ self.initialize_weights()
89
+
90
+ def initialize_weights(self):
91
+ def _basic_init(module):
92
+ if isinstance(module, nn.Linear):
93
+ torch.nn.init.xavier_uniform_(module.weight)
94
+ if module.bias is not None:
95
+ nn.init.constant_(module.bias, 0)
96
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
97
+
98
+ self.apply(_basic_init)
99
+
100
+ def forward(self, input_tensor, mask=None):
101
+ # Apply input LayerNorm
102
+ input_tensor = self.input_norm(input_tensor)
103
+ tokens = self.input_transform(input_tensor)
104
+
105
+ init_tokens = tokens
106
+
107
+ B, _, T, _ = tokens.shape
108
+
109
+ if self.add_space_attn:
110
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
111
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
112
+
113
+ _, N, _, _ = tokens.shape
114
+
115
+ j = 0
116
+ for i in range(len(self.time_blocks)):
117
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
118
+
119
+ time_tokens = self.time_blocks[i](time_tokens)
120
+
121
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
122
+ if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
123
+ space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
124
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
125
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
126
+
127
+ virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
128
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
129
+ point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
130
+
131
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
132
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
133
+ j += 1
134
+
135
+ if self.add_space_attn:
136
+ tokens = tokens[:, : N - self.num_virtual_tracks]
137
+
138
+ tokens = tokens + init_tokens
139
+
140
+ # Apply output LayerNorm before final projection
141
+ tokens = self.output_norm(tokens)
142
+ flow = self.flow_head(tokens)
143
+
144
+ return flow, None
145
+
146
+
147
+ class CorrBlock:
148
+ def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
149
+ """
150
+ Build a pyramid of feature maps from the input.
151
+
152
+ fmaps: Tensor (B, S, C, H, W)
153
+ num_levels: number of pyramid levels (each downsampled by factor 2)
154
+ radius: search radius for sampling correlation
155
+ multiple_track_feats: if True, split the target features per pyramid level
156
+ padding_mode: passed to grid_sample / bilinear_sampler
157
+ """
158
+ B, S, C, H, W = fmaps.shape
159
+ self.S, self.C, self.H, self.W = S, C, H, W
160
+ self.num_levels = num_levels
161
+ self.radius = radius
162
+ self.padding_mode = padding_mode
163
+ self.multiple_track_feats = multiple_track_feats
164
+
165
+ # Build pyramid: each level is half the spatial resolution of the previous
166
+ self.fmaps_pyramid = [fmaps] # level 0 is full resolution
167
+ current_fmaps = fmaps
168
+ for i in range(num_levels - 1):
169
+ B, S, C, H, W = current_fmaps.shape
170
+ # Merge batch & sequence dimensions
171
+ current_fmaps = current_fmaps.reshape(B * S, C, H, W)
172
+ # Avg pool down by factor 2
173
+ current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
174
+ _, _, H_new, W_new = current_fmaps.shape
175
+ current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
176
+ self.fmaps_pyramid.append(current_fmaps)
177
+
178
+ # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
179
+ # This grid is added to the (scaled) coordinate centroids.
180
+ r = self.radius
181
+ dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
182
+ dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
183
+ # delta: for every (dy,dx) displacement (i.e. Δx, Δy)
184
+ self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
185
+
186
+ def corr_sample(self, targets, coords):
187
+ """
188
+ Instead of storing the entire correlation pyramid, we compute each level's correlation
189
+ volume, sample it immediately, then discard it. This saves GPU memory.
190
+
191
+ Args:
192
+ targets: Tensor (B, S, N, C) — features for the current targets.
193
+ coords: Tensor (B, S, N, 2) — coordinates at full resolution.
194
+
195
+ Returns:
196
+ Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
197
+ """
198
+ B, S, N, C = targets.shape
199
+
200
+ # If you have multiple track features, split them per level.
201
+ if self.multiple_track_feats:
202
+ targets_split = torch.split(targets, C // self.num_levels, dim=-1)
203
+
204
+ out_pyramid = []
205
+ for i, fmaps in enumerate(self.fmaps_pyramid):
206
+ # Get current spatial resolution H, W for this pyramid level.
207
+ B, S, C, H, W = fmaps.shape
208
+ # Reshape feature maps for correlation computation:
209
+ # fmap2s: (B, S, C, H*W)
210
+ fmap2s = fmaps.view(B, S, C, H * W)
211
+ # Choose appropriate target features.
212
+ fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
213
+
214
+ # Compute correlation directly
215
+ corrs = compute_corr_level(fmap1, fmap2s, C)
216
+ corrs = corrs.view(B, S, N, H, W)
217
+
218
+ # Prepare sampling grid:
219
+ # Scale down the coordinates for the current level.
220
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
221
+ # Make sure our precomputed delta grid is on the same device/dtype.
222
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype)
223
+ # Now the grid for grid_sample is:
224
+ # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
225
+ coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
226
+
227
+ # Sample from the correlation volume using bilinear interpolation.
228
+ # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
229
+ corrs_sampled = bilinear_sampler(
230
+ corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
231
+ )
232
+ # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
233
+ corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
234
+ out_pyramid.append(corrs_sampled)
235
+
236
+ # Concatenate all levels along the last dimension.
237
+ out = torch.cat(out_pyramid, dim=-1).contiguous()
238
+ return out
239
+
240
+
241
+ def compute_corr_level(fmap1, fmap2s, C):
242
+ # fmap1: (B, S, N, C)
243
+ # fmap2s: (B, S, C, H*W)
244
+ corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
245
+ corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
246
+ return corrs / math.sqrt(C)
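
The correlation computed per pyramid level reduces to a scaled dot product between track features and every feature-map location. A self-contained sketch with invented shapes, matching compute_corr_level:

import math
import torch

B, S, N, C, H, W = 1, 2, 3, 16, 8, 8
fmap1 = torch.randn(B, S, N, C)                  # track features
fmaps = torch.randn(B, S, C, H, W)               # one pyramid level

corrs = torch.matmul(fmap1, fmaps.view(B, S, C, H * W)) / math.sqrt(C)
corrs = corrs.view(B, S, N, H, W)                # correlation map per track and frame
print(corrs.shape)                               # torch.Size([1, 2, 3, 8, 8])
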
models/SpaTrackV2/models/vggt4track/heads/track_modules/modules.py ADDED
@@ -0,0 +1,218 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from functools import partial
12
+ from typing import Callable
13
+ import collections
14
+ from torch import Tensor
15
+ from itertools import repeat
16
+
17
+
18
+ # From PyTorch internals
19
+ def _ntuple(n):
20
+ def parse(x):
21
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
22
+ return tuple(x)
23
+ return tuple(repeat(x, n))
24
+
25
+ return parse
26
+
27
+
28
+ def exists(val):
29
+ return val is not None
30
+
31
+
32
+ def default(val, d):
33
+ return val if exists(val) else d
34
+
35
+
36
+ to_2tuple = _ntuple(2)
37
+
38
+
39
+ class ResidualBlock(nn.Module):
40
+ """
41
+ ResidualBlock: construct a block of two conv layers with residual connections
42
+ """
43
+
44
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
45
+ super(ResidualBlock, self).__init__()
46
+
47
+ self.conv1 = nn.Conv2d(
48
+ in_planes,
49
+ planes,
50
+ kernel_size=kernel_size,
51
+ padding=1,
52
+ stride=stride,
53
+ padding_mode="zeros",
54
+ )
55
+ self.conv2 = nn.Conv2d(
56
+ planes,
57
+ planes,
58
+ kernel_size=kernel_size,
59
+ padding=1,
60
+ padding_mode="zeros",
61
+ )
62
+ self.relu = nn.ReLU(inplace=True)
63
+
64
+ num_groups = planes // 8
65
+
66
+ if norm_fn == "group":
67
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
68
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
69
+ if not stride == 1:
70
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
71
+
72
+ elif norm_fn == "batch":
73
+ self.norm1 = nn.BatchNorm2d(planes)
74
+ self.norm2 = nn.BatchNorm2d(planes)
75
+ if not stride == 1:
76
+ self.norm3 = nn.BatchNorm2d(planes)
77
+
78
+ elif norm_fn == "instance":
79
+ self.norm1 = nn.InstanceNorm2d(planes)
80
+ self.norm2 = nn.InstanceNorm2d(planes)
81
+ if not stride == 1:
82
+ self.norm3 = nn.InstanceNorm2d(planes)
83
+
84
+ elif norm_fn == "none":
85
+ self.norm1 = nn.Sequential()
86
+ self.norm2 = nn.Sequential()
87
+ if not stride == 1:
88
+ self.norm3 = nn.Sequential()
89
+ else:
90
+ raise NotImplementedError
91
+
92
+ if stride == 1:
93
+ self.downsample = None
94
+ else:
95
+ self.downsample = nn.Sequential(
96
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
97
+ self.norm3,
98
+ )
99
+
100
+ def forward(self, x):
101
+ y = x
102
+ y = self.relu(self.norm1(self.conv1(y)))
103
+ y = self.relu(self.norm2(self.conv2(y)))
104
+
105
+ if self.downsample is not None:
106
+ x = self.downsample(x)
107
+
108
+ return self.relu(x + y)
109
+
110
+
111
+ class Mlp(nn.Module):
112
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
113
+
114
+ def __init__(
115
+ self,
116
+ in_features,
117
+ hidden_features=None,
118
+ out_features=None,
119
+ act_layer=nn.GELU,
120
+ norm_layer=None,
121
+ bias=True,
122
+ drop=0.0,
123
+ use_conv=False,
124
+ ):
125
+ super().__init__()
126
+ out_features = out_features or in_features
127
+ hidden_features = hidden_features or in_features
128
+ bias = to_2tuple(bias)
129
+ drop_probs = to_2tuple(drop)
130
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
131
+
132
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
133
+ self.act = act_layer()
134
+ self.drop1 = nn.Dropout(drop_probs[0])
135
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
136
+ self.drop2 = nn.Dropout(drop_probs[1])
137
+
138
+ def forward(self, x):
139
+ x = self.fc1(x)
140
+ x = self.act(x)
141
+ x = self.drop1(x)
142
+ x = self.fc2(x)
143
+ x = self.drop2(x)
144
+ return x
145
+
146
+
147
+ class AttnBlock(nn.Module):
148
+ def __init__(
149
+ self,
150
+ hidden_size,
151
+ num_heads,
152
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
153
+ mlp_ratio=4.0,
154
+ **block_kwargs
155
+ ):
156
+ """
157
+ Self attention block
158
+ """
159
+ super().__init__()
160
+
161
+ self.norm1 = nn.LayerNorm(hidden_size)
162
+ self.norm2 = nn.LayerNorm(hidden_size)
163
+
164
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
165
+
166
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
167
+
168
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
169
+
170
+ def forward(self, x, mask=None):
171
+ # Prepare the mask for PyTorch's attention (it expects a different format)
172
+ # attn_mask = mask if mask is not None else None
173
+ # Normalize before attention
174
+ x = self.norm1(x)
175
+
176
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
177
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
178
+
179
+ attn_output, _ = self.attn(x, x, x)
180
+
181
+ # Add & Norm
182
+ x = x + attn_output
183
+ x = x + self.mlp(self.norm2(x))
184
+ return x
185
+
186
+
187
+ class CrossAttnBlock(nn.Module):
188
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
189
+ """
190
+ Cross attention block
191
+ """
192
+ super().__init__()
193
+
194
+ self.norm1 = nn.LayerNorm(hidden_size)
195
+ self.norm_context = nn.LayerNorm(hidden_size)
196
+ self.norm2 = nn.LayerNorm(hidden_size)
197
+
198
+ self.cross_attn = nn.MultiheadAttention(
199
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
200
+ )
201
+
202
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
203
+
204
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
205
+
206
+ def forward(self, x, context, mask=None):
207
+ # Normalize inputs
208
+ x = self.norm1(x)
209
+ context = self.norm_context(context)
210
+
211
+ # Apply cross attention
212
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
213
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
214
+
215
+ # Add & Norm
216
+ x = x + attn_output
217
+ x = x + self.mlp(self.norm2(x))
218
+ return x
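
AttnBlock follows a pre-norm residual pattern around nn.MultiheadAttention. The sketch below reproduces that pattern with assumed sizes; note that, as in the class above, the attention residual is added to the already-normalized input.

import torch
import torch.nn as nn

hidden, heads = 64, 4
x = torch.randn(2, 10, hidden)                           # (batch, tokens, channels)

norm1, norm2 = nn.LayerNorm(hidden), nn.LayerNorm(hidden)
attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=heads, batch_first=True)
mlp = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden))

h = norm1(x)
attn_out, _ = attn(h, h, h)
x = h + attn_out                  # residual added to the normalized input, as in AttnBlock
x = x + mlp(norm2(x))
print(x.shape)                    # torch.Size([2, 10, 64])
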
models/SpaTrackV2/models/vggt4track/heads/track_modules/utils.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from https://github.com/facebookresearch/vggsfm
8
+ # and https://github.com/facebookresearch/co-tracker/tree/main
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from typing import Optional, Tuple, Union
16
+
17
+
18
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
19
+ """
20
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
21
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
22
+ Args:
23
+ - embed_dim: The embedding dimension.
24
+ - grid_size: The grid size.
25
+ Returns:
26
+ - pos_embed: The generated 2D positional embedding.
27
+ """
28
+ if isinstance(grid_size, tuple):
29
+ grid_size_h, grid_size_w = grid_size
30
+ else:
31
+ grid_size_h = grid_size_w = grid_size
32
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
33
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
34
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
35
+ grid = torch.stack(grid, dim=0)
36
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
37
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
38
+ if return_grid:
39
+ return (
40
+ pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
41
+ grid,
42
+ )
43
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
44
+
45
+
46
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
47
+ """
48
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
49
+
50
+ Args:
51
+ - embed_dim: The embedding dimension.
52
+ - grid: The grid to generate the embedding from.
53
+
54
+ Returns:
55
+ - emb: The generated 2D positional embedding.
56
+ """
57
+ assert embed_dim % 2 == 0
58
+
59
+ # use half of dimensions to encode grid_h
60
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
61
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
62
+
63
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
64
+ return emb
65
+
66
+
67
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
68
+ """
69
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
70
+
71
+ Args:
72
+ - embed_dim: The embedding dimension.
73
+ - pos: The position to generate the embedding from.
74
+
75
+ Returns:
76
+ - emb: The generated 1D positional embedding.
77
+ """
78
+ assert embed_dim % 2 == 0
79
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
80
+ omega /= embed_dim / 2.0
81
+ omega = 1.0 / 10000**omega # (D/2,)
82
+
83
+ pos = pos.reshape(-1) # (M,)
84
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
85
+
86
+ emb_sin = torch.sin(out) # (M, D/2)
87
+ emb_cos = torch.cos(out) # (M, D/2)
88
+
89
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
90
+ return emb[None].float()
91
+
92
+
93
+ def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
94
+ """
95
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
96
+
97
+ Args:
98
+ - xy: The coordinates to generate the embedding from.
99
+ - C: The size of the embedding.
100
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
101
+
102
+ Returns:
103
+ - pe: The generated 2D positional embedding.
104
+ """
105
+ B, N, D = xy.shape
106
+ assert D == 2
107
+
108
+ x = xy[:, :, 0:1]
109
+ y = xy[:, :, 1:2]
110
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
111
+
112
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
113
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
114
+
115
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
116
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
117
+
118
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
119
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
120
+
121
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
122
+ if cat_coords:
123
+ pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
124
+ return pe
125
+
126
+
127
+ def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
128
+ r"""Sample a tensor using bilinear interpolation
129
+
130
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
131
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
132
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
133
+ convention.
134
+
135
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
136
+ :math:`B` is the batch size, :math:`C` is the number of channels,
137
+ :math:`H` is the height of the image, and :math:`W` is the width of the
138
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
139
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
140
+
141
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
142
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
143
+ that in this case the order of the components is slightly different
144
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
145
+
146
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
147
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
148
+ left-most image pixel :math:`W-1` to the center of the right-most
149
+ pixel.
150
+
151
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
152
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
153
+ the left-most pixel :math:`W` to the right edge of the right-most
154
+ pixel.
155
+
156
+ Similar conventions apply to the :math:`y` for the range
157
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
158
+ :math:`[0,T-1]` and :math:`[0,T]`.
159
+
160
+ Args:
161
+ input (Tensor): batch of input images.
162
+ coords (Tensor): batch of coordinates.
163
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
164
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
165
+
166
+ Returns:
167
+ Tensor: sampled points.
168
+ """
169
+ coords = coords.detach().clone()
170
+ ############################################################
171
+ # IMPORTANT:
172
+ coords = coords.to(input.device).to(input.dtype)
173
+ ############################################################
174
+
175
+ sizes = input.shape[2:]
176
+
177
+ assert len(sizes) in [2, 3]
178
+
179
+ if len(sizes) == 3:
180
+ # t x y -> x y t to match dimensions T H W in grid_sample
181
+ coords = coords[..., [1, 2, 0]]
182
+
183
+ if align_corners:
184
+ scale = torch.tensor(
185
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
186
+ )
187
+ else:
188
+ scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
189
+
190
+ coords.mul_(scale) # coords = coords * scale
191
+ coords.sub_(1) # coords = coords - 1
192
+
193
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
194
+
195
+
196
+ def sample_features4d(input, coords):
197
+ r"""Sample spatial features
198
+
199
+ `sample_features4d(input, coords)` samples the spatial features
200
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
201
+
202
+ The field is sampled at coordinates :attr:`coords` using bilinear
203
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
204
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
205
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
206
+
207
+ The output tensor has one feature per point, and has shape :math:`(B,
208
+ R, C)`.
209
+
210
+ Args:
211
+ input (Tensor): spatial features.
212
+ coords (Tensor): points.
213
+
214
+ Returns:
215
+ Tensor: sampled features.
216
+ """
217
+
218
+ B, _, _, _ = input.shape
219
+
220
+ # B R 2 -> B R 1 2
221
+ coords = coords.unsqueeze(2)
222
+
223
+ # B C R 1
224
+ feats = bilinear_sampler(input, coords)
225
+
226
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
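
bilinear_sampler's coordinate convention (pixel coordinates with align_corners=True) boils down to a per-axis rescaling into grid_sample's [-1, 1] range. A small sketch with a hand-built image, assuming plain PyTorch:

import torch
import torch.nn.functional as F

B, C, H, W = 1, 1, 4, 6
img = torch.arange(H * W, dtype=torch.float32).view(B, C, H, W)

coords = torch.tensor([[[[5.0, 3.0]]]])                  # (B, 1, 1, 2) pixel (x, y)
scale = torch.tensor([2 / (W - 1), 2 / (H - 1)])         # per-axis scaling, x first
grid = coords * scale - 1                                # now in [-1, 1]

sampled = F.grid_sample(img, grid, align_corners=True, padding_mode="border")
print(sampled.view(-1))                                  # tensor([23.]) == img[0, 0, 3, 5]
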
models/SpaTrackV2/models/vggt4track/heads/utils.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
12
+ """
13
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
14
+
15
+ Args:
16
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
17
+ embed_dim: Output channel dimension for embeddings
18
+
19
+ Returns:
20
+ Tensor of shape (H, W, embed_dim) with positional embeddings
21
+ """
22
+ H, W, grid_dim = pos_grid.shape
23
+ assert grid_dim == 2
24
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
25
+
26
+ # Process x and y coordinates separately
27
+ emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
28
+ emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
29
+
30
+ # Combine and reshape
31
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
32
+
33
+ return emb.view(H, W, embed_dim) # [H, W, D]
34
+
35
+
36
+ def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
37
+ """
38
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
39
+
40
+ Args:
41
+ - embed_dim: The embedding dimension.
42
+ - pos: The position to generate the embedding from.
43
+
44
+ Returns:
45
+ - emb: The generated 1D positional embedding.
46
+ """
47
+ assert embed_dim % 2 == 0
48
+ device = pos.device
49
+ omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
50
+ omega /= embed_dim / 2.0
51
+ omega = 1.0 / omega_0**omega # (D/2,)
52
+
53
+ pos = pos.reshape(-1) # (M,)
54
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
55
+
56
+ emb_sin = torch.sin(out) # (M, D/2)
57
+ emb_cos = torch.cos(out) # (M, D/2)
58
+
59
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
60
+ return emb.float()
61
+
62
+
63
+ # Inspired by https://github.com/microsoft/moge
64
+
65
+
66
+ def create_uv_grid(
67
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
68
+ ) -> torch.Tensor:
69
+ """
70
+ Create a normalized UV grid of shape (width, height, 2).
71
+
72
+ The grid spans horizontally and vertically according to an aspect ratio,
73
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
74
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
75
+
76
+ Args:
77
+ width (int): Number of points horizontally.
78
+ height (int): Number of points vertically.
79
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
80
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
81
+ device (torch.device, optional): Device on which the tensor is created.
82
+
83
+ Returns:
84
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
85
+ """
86
+ # Derive aspect ratio if not explicitly provided
87
+ if aspect_ratio is None:
88
+ aspect_ratio = float(width) / float(height)
89
+
90
+ # Compute normalized spans for X and Y
91
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
92
+ span_x = aspect_ratio / diag_factor
93
+ span_y = 1.0 / diag_factor
94
+
95
+ # Establish the linspace boundaries
96
+ left_x = -span_x * (width - 1) / width
97
+ right_x = span_x * (width - 1) / width
98
+ top_y = -span_y * (height - 1) / height
99
+ bottom_y = span_y * (height - 1) / height
100
+
101
+ # Generate 1D coordinates
102
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
103
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
104
+
105
+ # Create 2D meshgrid (width x height) and stack into UV
106
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
107
+ uv_grid = torch.stack((uu, vv), dim=-1)
108
+
109
+ return uv_grid
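
The normalization in create_uv_grid can be checked by hand: the horizontal and vertical spans are the aspect ratio and 1 divided by the plane diagonal, so the outer corner always has unit norm. A worked example with an assumed 16:9 aspect ratio:

aspect_ratio = 16 / 9
diag = (aspect_ratio ** 2 + 1.0) ** 0.5
span_x, span_y = aspect_ratio / diag, 1.0 / diag

print(round(span_x, 4), round(span_y, 4))        # 0.8716 0.4903
print(round(span_x ** 2 + span_y ** 2, 6))       # 1.0
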
models/SpaTrackV2/models/vggt4track/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
models/SpaTrackV2/models/vggt4track/layers/attention.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+
18
+ XFORMERS_AVAILABLE = False
19
+
20
+
21
+ class Attention(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dim: int,
25
+ num_heads: int = 8,
26
+ qkv_bias: bool = True,
27
+ proj_bias: bool = True,
28
+ attn_drop: float = 0.0,
29
+ proj_drop: float = 0.0,
30
+ norm_layer: nn.Module = nn.LayerNorm,
31
+ qk_norm: bool = False,
32
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
33
+ rope=None,
34
+ ) -> None:
35
+ super().__init__()
36
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
37
+ self.num_heads = num_heads
38
+ self.head_dim = dim // num_heads
39
+ self.scale = self.head_dim**-0.5
40
+ self.fused_attn = fused_attn
41
+
42
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
43
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
44
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+ self.rope = rope
49
+
50
+ def forward(self, x: Tensor, pos=None) -> Tensor:
51
+ B, N, C = x.shape
52
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
53
+ q, k, v = qkv.unbind(0)
54
+ q, k = self.q_norm(q), self.k_norm(k)
55
+
56
+ if self.rope is not None:
57
+ q = self.rope(q, pos)
58
+ k = self.rope(k, pos)
59
+
60
+ if self.fused_attn:
61
+ x = F.scaled_dot_product_attention(
62
+ q,
63
+ k,
64
+ v,
65
+ dropout_p=self.attn_drop.p if self.training else 0.0,
66
+ )
67
+ else:
68
+ q = q * self.scale
69
+ attn = q @ k.transpose(-2, -1)
70
+ attn = attn.softmax(dim=-1)
71
+ attn = self.attn_drop(attn)
72
+ x = attn @ v
73
+
74
+ x = x.transpose(1, 2).reshape(B, N, C)
75
+ x = self.proj(x)
76
+ x = self.proj_drop(x)
77
+ return x
78
+
79
+
80
+ class MemEffAttention(Attention):
81
+ def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
82
+ assert pos is None
83
+ if not XFORMERS_AVAILABLE:
84
+ if attn_bias is not None:
85
+ raise AssertionError("xFormers is required for using nested tensors")
86
+ return super().forward(x)
87
+
88
+ B, N, C = x.shape
89
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
90
+
91
+ q, k, v = unbind(qkv, 2)
92
+
93
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
94
+ x = x.reshape([B, N, C])
95
+
96
+ x = self.proj(x)
97
+ x = self.proj_drop(x)
98
+ return x
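
A minimal shape check for the plain Attention path, sketched under the assumption that the repo root is importable. Since XFORMERS_AVAILABLE is hard-coded to False in this file, MemEffAttention with attn_bias=None simply falls back to Attention.forward:

import torch
from models.SpaTrackV2.models.vggt4track.layers.attention import Attention, MemEffAttention

tokens = torch.randn(2, 16, 64)          # (batch, tokens, dim)
attn = Attention(dim=64, num_heads=8)    # fused path: F.scaled_dot_product_attention
assert attn(tokens).shape == tokens.shape

mem_attn = MemEffAttention(dim=64, num_heads=8)
assert mem_attn(tokens).shape == tokens.shape  # xFormers-free fallback path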
models/SpaTrackV2/models/vggt4track/layers/block.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ XFORMERS_AVAILABLE = False
25
+
26
+
27
+ class Block(nn.Module):
28
+ def __init__(
29
+ self,
30
+ dim: int,
31
+ num_heads: int,
32
+ mlp_ratio: float = 4.0,
33
+ qkv_bias: bool = True,
34
+ proj_bias: bool = True,
35
+ ffn_bias: bool = True,
36
+ drop: float = 0.0,
37
+ attn_drop: float = 0.0,
38
+ init_values=None,
39
+ drop_path: float = 0.0,
40
+ act_layer: Callable[..., nn.Module] = nn.GELU,
41
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
42
+ attn_class: Callable[..., nn.Module] = Attention,
43
+ ffn_layer: Callable[..., nn.Module] = Mlp,
44
+ qk_norm: bool = False,
45
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
46
+ rope=None,
47
+ ) -> None:
48
+ super().__init__()
49
+
50
+ self.norm1 = norm_layer(dim)
51
+
52
+ self.attn = attn_class(
53
+ dim,
54
+ num_heads=num_heads,
55
+ qkv_bias=qkv_bias,
56
+ proj_bias=proj_bias,
57
+ attn_drop=attn_drop,
58
+ proj_drop=drop,
59
+ qk_norm=qk_norm,
60
+ fused_attn=fused_attn,
61
+ rope=rope,
62
+ )
63
+
64
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
65
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
66
+
67
+ self.norm2 = norm_layer(dim)
68
+ mlp_hidden_dim = int(dim * mlp_ratio)
69
+ self.mlp = ffn_layer(
70
+ in_features=dim,
71
+ hidden_features=mlp_hidden_dim,
72
+ act_layer=act_layer,
73
+ drop=drop,
74
+ bias=ffn_bias,
75
+ )
76
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
77
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
78
+
79
+ self.sample_drop_ratio = drop_path
80
+
81
+ def forward(self, x: Tensor, pos=None) -> Tensor:
82
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
83
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
84
+
85
+ def ffn_residual_func(x: Tensor) -> Tensor:
86
+ return self.ls2(self.mlp(self.norm2(x)))
87
+
88
+ if self.training and self.sample_drop_ratio > 0.1:
89
+ # the overhead is compensated only for a drop path rate larger than 0.1
90
+ x = drop_add_residual_stochastic_depth(
91
+ x,
92
+ pos=pos,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
103
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104
+ else:
105
+ x = x + attn_residual_func(x, pos=pos)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ pos=None,
115
+ ) -> Tensor:
116
+ # 1) extract subset using permutation
117
+ b, n, d = x.shape
118
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
119
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
120
+ x_subset = x[brange]
121
+
122
+ # 2) apply residual_func to get residual
123
+ if pos is not None:
124
+ # if necessary, apply rope to the subset
125
+ pos = pos[brange]
126
+ residual = residual_func(x_subset, pos=pos)
127
+ else:
128
+ residual = residual_func(x_subset)
129
+
130
+ x_flat = x.flatten(1)
131
+ residual = residual.flatten(1)
132
+
133
+ residual_scale_factor = b / sample_subset_size
134
+
135
+ # 3) add the residual
136
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
137
+ return x_plus_residual.view_as(x)
138
+
139
+
140
+ def get_branges_scales(x, sample_drop_ratio=0.0):
141
+ b, n, d = x.shape
142
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
143
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
144
+ residual_scale_factor = b / sample_subset_size
145
+ return brange, residual_scale_factor
146
+
147
+
148
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
149
+ if scaling_vector is None:
150
+ x_flat = x.flatten(1)
151
+ residual = residual.flatten(1)
152
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
153
+ else:
154
+ x_plus_residual = scaled_index_add(
155
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
156
+ )
157
+ return x_plus_residual
158
+
159
+
160
+ attn_bias_cache: Dict[Tuple, Any] = {}
161
+
162
+
163
+ def get_attn_bias_and_cat(x_list, branges=None):
164
+ """
165
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
166
+ """
167
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
168
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
169
+ if all_shapes not in attn_bias_cache.keys():
170
+ seqlens = []
171
+ for b, x in zip(batch_sizes, x_list):
172
+ for _ in range(b):
173
+ seqlens.append(x.shape[1])
174
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
175
+ attn_bias._batch_sizes = batch_sizes
176
+ attn_bias_cache[all_shapes] = attn_bias
177
+
178
+ if branges is not None:
179
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
180
+ else:
181
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
182
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
183
+
184
+ return attn_bias_cache[all_shapes], cat_tensors
185
+
186
+
187
+ def drop_add_residual_stochastic_depth_list(
188
+ x_list: List[Tensor],
189
+ residual_func: Callable[[Tensor, Any], Tensor],
190
+ sample_drop_ratio: float = 0.0,
191
+ scaling_vector=None,
192
+ ) -> Tensor:
193
+ # 1) generate random set of indices for dropping samples in the batch
194
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
195
+ branges = [s[0] for s in branges_scales]
196
+ residual_scale_factors = [s[1] for s in branges_scales]
197
+
198
+ # 2) get attention bias and index+concat the tensors
199
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
200
+
201
+ # 3) apply residual_func to get residual, and split the result
202
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
203
+
204
+ outputs = []
205
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
206
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
207
+ return outputs
208
+
209
+
210
+ class NestedTensorBlock(Block):
211
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
212
+ """
213
+ x_list contains a list of tensors to nest together and run
214
+ """
215
+ assert isinstance(self.attn, MemEffAttention)
216
+
217
+ if self.training and self.sample_drop_ratio > 0.0:
218
+
219
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
220
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
221
+
222
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
223
+ return self.mlp(self.norm2(x))
224
+
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=attn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ x_list = drop_add_residual_stochastic_depth_list(
232
+ x_list,
233
+ residual_func=ffn_residual_func,
234
+ sample_drop_ratio=self.sample_drop_ratio,
235
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
236
+ )
237
+ return x_list
238
+ else:
239
+
240
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
241
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
242
+
243
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
244
+ return self.ls2(self.mlp(self.norm2(x)))
245
+
246
+ attn_bias, x = get_attn_bias_and_cat(x_list)
247
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
248
+ x = x + ffn_residual_func(x)
249
+ return attn_bias.split(x)
250
+
251
+ def forward(self, x_or_x_list):
252
+ if isinstance(x_or_x_list, Tensor):
253
+ return super().forward(x_or_x_list)
254
+ elif isinstance(x_or_x_list, list):
255
+ if not XFORMERS_AVAILABLE:
256
+ raise AssertionError("xFormers is required for using nested tensors")
257
+ return self.forward_nested(x_or_x_list)
258
+ else:
259
+ raise AssertionError
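
The stochastic-depth batching above only activates during training with drop_path > 0.1; in the common inference path a Block is pre-norm attention plus an MLP, each wrapped in LayerScale. A minimal sketch (assuming the repo root is on PYTHONPATH):

import torch
from models.SpaTrackV2.models.vggt4track.layers.block import Block

blk = Block(dim=64, num_heads=8, mlp_ratio=4.0, init_values=0.01).eval()
x = torch.randn(2, 16, 64)               # (batch, tokens, dim)
with torch.no_grad():
    y = blk(x)                           # x + ls1(attn(norm1(x))), then + ls2(mlp(norm2(...)))
assert y.shape == x.shape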
models/SpaTrackV2/models/vggt4track/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
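
The 1/keep_prob rescaling keeps the expected activation unchanged while whole samples are dropped. A short sketch of that behavior (same import assumptions as above):

import torch
from models.SpaTrackV2.models.vggt4track.layers.drop_path import DropPath

torch.manual_seed(0)
dp = DropPath(drop_prob=0.3).train()
x = torch.ones(10000, 8)                 # one residual "sample" per row
y = dp(x)
print(y.mean().item())                   # close to 1.0: rows are zeroed or scaled by 1/0.7
dp.eval()
assert torch.equal(dp(x), x)             # identity at inference time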
models/SpaTrackV2/models/vggt4track/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ init_values: Union[float, Tensor] = 1e-5,
20
+ inplace: bool = False,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.inplace = inplace
24
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
25
+
26
+ def forward(self, x: Tensor) -> Tensor:
27
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
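
LayerScale is a learned per-channel gain on the residual branch, initialized to a small constant; a one-line sketch:

import torch
from models.SpaTrackV2.models.vggt4track.layers.layer_scale import LayerScale

ls = LayerScale(dim=64, init_values=1e-5)
x = torch.randn(2, 16, 64)
assert torch.allclose(ls(x), x * 1e-5)   # starts as a near-zero residual contribution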
models/SpaTrackV2/models/vggt4track/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
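
Shape-preserving as expected; a minimal sketch:

import torch
from models.SpaTrackV2.models.vggt4track.layers.mlp import Mlp

mlp = Mlp(in_features=64, hidden_features=256)   # Linear -> GELU -> Dropout -> Linear -> Dropout
x = torch.randn(2, 16, 64)
assert mlp(x).shape == x.shape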
models/SpaTrackV2/models/vggt4track/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
+ def make_2tuple(x):
17
+ if isinstance(x, tuple):
18
+ assert len(x) == 2
19
+ return x
20
+
21
+ assert isinstance(x, int)
22
+ return (x, x)
23
+
24
+
25
+ class PatchEmbed(nn.Module):
26
+ """
27
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28
+
29
+ Args:
30
+ img_size: Image size.
31
+ patch_size: Patch token size.
32
+ in_chans: Number of input image channels.
33
+ embed_dim: Number of linear projection output channels.
34
+ norm_layer: Normalization layer.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ img_size: Union[int, Tuple[int, int]] = 224,
40
+ patch_size: Union[int, Tuple[int, int]] = 16,
41
+ in_chans: int = 3,
42
+ embed_dim: int = 768,
43
+ norm_layer: Optional[Callable] = None,
44
+ flatten_embedding: bool = True,
45
+ ) -> None:
46
+ super().__init__()
47
+
48
+ image_HW = make_2tuple(img_size)
49
+ patch_HW = make_2tuple(patch_size)
50
+ patch_grid_size = (
51
+ image_HW[0] // patch_HW[0],
52
+ image_HW[1] // patch_HW[1],
53
+ )
54
+
55
+ self.img_size = image_HW
56
+ self.patch_size = patch_HW
57
+ self.patches_resolution = patch_grid_size
58
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59
+
60
+ self.in_chans = in_chans
61
+ self.embed_dim = embed_dim
62
+
63
+ self.flatten_embedding = flatten_embedding
64
+
65
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
66
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
67
+
68
+ def forward(self, x: Tensor) -> Tensor:
69
+ _, _, H, W = x.shape
70
+ patch_H, patch_W = self.patch_size
71
+
72
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
73
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
74
+
75
+ x = self.proj(x) # B C H W
76
+ H, W = x.size(2), x.size(3)
77
+ x = x.flatten(2).transpose(1, 2) # B HW C
78
+ x = self.norm(x)
79
+ if not self.flatten_embedding:
80
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
81
+ return x
82
+
83
+ def flops(self) -> float:
84
+ Ho, Wo = self.patches_resolution
85
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
86
+ if self.norm is not None:
87
+ flops += Ho * Wo * self.embed_dim
88
+ return flops
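
A sketch of the (B, C, H, W) -> (B, N, D) mapping for the 14-pixel patches used elsewhere in this commit (repo root on PYTHONPATH assumed):

import torch
from models.SpaTrackV2.models.vggt4track.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=1024)
imgs = torch.randn(2, 3, 224, 224)
tokens = embed(imgs)
print(tokens.shape)        # torch.Size([2, 256, 1024]); (224 // 14) ** 2 = 256 patches
print(embed.num_patches)   # 256, taken from the patch grid computed at construction time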
models/SpaTrackV2/models/vggt4track/layers/rope.py ADDED
@@ -0,0 +1,188 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ # Implementation of 2D Rotary Position Embeddings (RoPE).
8
+
9
+ # This module provides a clean implementation of 2D Rotary Position Embeddings,
10
+ # which extends the original RoPE concept to handle 2D spatial positions.
11
+
12
+ # Inspired by:
13
+ # https://github.com/meta-llama/codellama/blob/main/llama/model.py
14
+ # https://github.com/naver-ai/rope-vit
15
+
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from typing import Dict, Tuple
22
+
23
+
24
+ class PositionGetter:
25
+ """Generates and caches 2D spatial positions for patches in a grid.
26
+
27
+ This class efficiently manages the generation of spatial coordinates for patches
28
+ in a 2D grid, caching results to avoid redundant computations.
29
+
30
+ Attributes:
31
+ position_cache: Dictionary storing precomputed position tensors for different
32
+ grid dimensions.
33
+ """
34
+
35
+ def __init__(self):
36
+ """Initializes the position generator with an empty cache."""
37
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
38
+
39
+ def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
40
+ """Generates spatial positions for a batch of patches.
41
+
42
+ Args:
43
+ batch_size: Number of samples in the batch.
44
+ height: Height of the grid in patches.
45
+ width: Width of the grid in patches.
46
+ device: Target device for the position tensor.
47
+
48
+ Returns:
49
+ Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
50
+ for each position in the grid, repeated for each batch item.
51
+ """
52
+ if (height, width) not in self.position_cache:
53
+ y_coords = torch.arange(height, device=device)
54
+ x_coords = torch.arange(width, device=device)
55
+ positions = torch.cartesian_prod(y_coords, x_coords)
56
+ self.position_cache[height, width] = positions
57
+
58
+ cached_positions = self.position_cache[height, width]
59
+ return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
60
+
61
+
62
+ class RotaryPositionEmbedding2D(nn.Module):
63
+ """2D Rotary Position Embedding implementation.
64
+
65
+ This module applies rotary position embeddings to input tokens based on their
66
+ 2D spatial positions. It handles the position-dependent rotation of features
67
+ separately for vertical and horizontal dimensions.
68
+
69
+ Args:
70
+ frequency: Base frequency for the position embeddings. Default: 100.0
71
+ scaling_factor: Scaling factor for frequency computation. Default: 1.0
72
+
73
+ Attributes:
74
+ base_frequency: Base frequency for computing position embeddings.
75
+ scaling_factor: Factor to scale the computed frequencies.
76
+ frequency_cache: Cache for storing precomputed frequency components.
77
+ """
78
+
79
+ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
80
+ """Initializes the 2D RoPE module."""
81
+ super().__init__()
82
+ self.base_frequency = frequency
83
+ self.scaling_factor = scaling_factor
84
+ self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
85
+
86
+ def _compute_frequency_components(
87
+ self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
88
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
89
+ """Computes frequency components for rotary embeddings.
90
+
91
+ Args:
92
+ dim: Feature dimension (must be even).
93
+ seq_len: Maximum sequence length.
94
+ device: Target device for computations.
95
+ dtype: Data type for the computed tensors.
96
+
97
+ Returns:
98
+ Tuple of (cosine, sine) tensors for frequency components.
99
+ """
100
+ cache_key = (dim, seq_len, device, dtype)
101
+ if cache_key not in self.frequency_cache:
102
+ # Compute frequency bands
103
+ exponents = torch.arange(0, dim, 2, device=device).float() / dim
104
+ inv_freq = 1.0 / (self.base_frequency**exponents)
105
+
106
+ # Generate position-dependent frequencies
107
+ positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
108
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
109
+
110
+ # Compute and cache frequency components
111
+ angles = angles.to(dtype)
112
+ angles = torch.cat((angles, angles), dim=-1)
113
+ cos_components = angles.cos().to(dtype)
114
+ sin_components = angles.sin().to(dtype)
115
+ self.frequency_cache[cache_key] = (cos_components, sin_components)
116
+
117
+ return self.frequency_cache[cache_key]
118
+
119
+ @staticmethod
120
+ def _rotate_features(x: torch.Tensor) -> torch.Tensor:
121
+ """Performs feature rotation by splitting and recombining feature dimensions.
122
+
123
+ Args:
124
+ x: Input tensor to rotate.
125
+
126
+ Returns:
127
+ Rotated feature tensor.
128
+ """
129
+ feature_dim = x.shape[-1]
130
+ x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
131
+ return torch.cat((-x2, x1), dim=-1)
132
+
133
+ def _apply_1d_rope(
134
+ self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
135
+ ) -> torch.Tensor:
136
+ """Applies 1D rotary position embeddings along one dimension.
137
+
138
+ Args:
139
+ tokens: Input token features.
140
+ positions: Position indices.
141
+ cos_comp: Cosine components for rotation.
142
+ sin_comp: Sine components for rotation.
143
+
144
+ Returns:
145
+ Tokens with applied rotary position embeddings.
146
+ """
147
+ # Embed positions with frequency components
148
+ cos = F.embedding(positions, cos_comp)[:, None, :, :]
149
+ sin = F.embedding(positions, sin_comp)[:, None, :, :]
150
+
151
+ # Apply rotation
152
+ return (tokens * cos) + (self._rotate_features(tokens) * sin)
153
+
154
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
155
+ """Applies 2D rotary position embeddings to input tokens.
156
+
157
+ Args:
158
+ tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
159
+ The feature dimension (dim) must be divisible by 4.
160
+ positions: Position tensor of shape (batch_size, n_tokens, 2) containing
161
+ the y and x coordinates for each token.
162
+
163
+ Returns:
164
+ Tensor of same shape as input with applied 2D rotary position embeddings.
165
+
166
+ Raises:
167
+ AssertionError: If input dimensions are invalid or positions are malformed.
168
+ """
169
+ # Validate inputs
170
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
171
+ assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
172
+
173
+ # Compute feature dimension for each spatial direction
174
+ feature_dim = tokens.size(-1) // 2
175
+
176
+ # Get frequency components
177
+ max_position = int(positions.max()) + 1
178
+ cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
179
+
180
+ # Split features for vertical and horizontal processing
181
+ vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
182
+
183
+ # Apply RoPE separately for each dimension
184
+ vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
185
+ horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
186
+
187
+ # Combine processed features
188
+ return torch.cat((vertical_features, horizontal_features), dim=-1)
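
PositionGetter and RotaryPositionEmbedding2D are meant to be used together: the former yields integer (y, x) indices per patch, the latter rotates the two halves of each head's features by position-dependent phases. A sketch (head_dim must be divisible by 4, i.e. two rotatable halves):

import torch
from models.SpaTrackV2.models.vggt4track.layers.rope import PositionGetter, RotaryPositionEmbedding2D

B, heads, head_dim = 2, 8, 64
H = W = 16                                                          # 256 patch positions
positions = PositionGetter()(B, H, W, device=torch.device("cpu"))   # (B, H*W, 2) integer y,x
tokens = torch.randn(B, heads, H * W, head_dim)

rope = RotaryPositionEmbedding2D(frequency=100.0)
rotated = rope(tokens, positions)
assert rotated.shape == tokens.shape
# Rotations are norm-preserving, so RoPE encodes position purely as phase.
assert torch.allclose(rotated.norm(dim=-1), tokens.norm(dim=-1), atol=1e-4)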
models/SpaTrackV2/models/vggt4track/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ hidden_features: Optional[int] = None,
19
+ out_features: Optional[int] = None,
20
+ act_layer: Callable[..., nn.Module] = None,
21
+ drop: float = 0.0,
22
+ bias: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+ out_features = out_features or in_features
26
+ hidden_features = hidden_features or in_features
27
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ x12 = self.w12(x)
32
+ x1, x2 = x12.chunk(2, dim=-1)
33
+ hidden = F.silu(x1) * x2
34
+ return self.w3(hidden)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38
+ # try:
39
+ # if XFORMERS_ENABLED:
40
+ # from xformers.ops import SwiGLU
41
+
42
+ # XFORMERS_AVAILABLE = True
43
+ # warnings.warn("xFormers is available (SwiGLU)")
44
+ # else:
45
+ # warnings.warn("xFormers is disabled (SwiGLU)")
46
+ # raise ImportError
47
+ # except ImportError:
48
+ SwiGLU = SwiGLUFFN
49
+ XFORMERS_AVAILABLE = False
50
+
51
+ # warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
+ class SwiGLUFFNFused(SwiGLU):
55
+ def __init__(
56
+ self,
57
+ in_features: int,
58
+ hidden_features: Optional[int] = None,
59
+ out_features: Optional[int] = None,
60
+ act_layer: Callable[..., nn.Module] = None,
61
+ drop: float = 0.0,
62
+ bias: bool = True,
63
+ ) -> None:
64
+ out_features = out_features or in_features
65
+ hidden_features = hidden_features or in_features
66
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67
+ super().__init__(
68
+ in_features=in_features,
69
+ hidden_features=hidden_features,
70
+ out_features=out_features,
71
+ bias=bias,
72
+ )
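
Since the xFormers import is commented out, SwiGLUFFNFused is just SwiGLUFFN with the hidden width rescaled to 2/3 and rounded up to a multiple of 8, which keeps its parameter count close to a plain 4x MLP. A sketch:

import torch
from models.SpaTrackV2.models.vggt4track.layers.swiglu_ffn import SwiGLUFFNFused

ffn = SwiGLUFFNFused(in_features=1024, hidden_features=4096)
# (4096 * 2 / 3) rounded up to a multiple of 8 -> 2736; w12 packs both gating projections.
assert ffn.w12.out_features == 2 * 2736
x = torch.randn(2, 16, 1024)
assert ffn(x).shape == x.shape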
models/SpaTrackV2/models/vggt4track/layers/vision_transformer.py ADDED
@@ -0,0 +1,407 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.checkpoint import checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+ from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
20
+
21
+ logger = logging.getLogger("dinov2")
22
+
23
+
24
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
25
+ if not depth_first and include_root:
26
+ fn(module=module, name=name)
27
+ for child_name, child_module in module.named_children():
28
+ child_name = ".".join((name, child_name)) if name else child_name
29
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
30
+ if depth_first and include_root:
31
+ fn(module=module, name=name)
32
+ return module
33
+
34
+
35
+ class BlockChunk(nn.ModuleList):
36
+ def forward(self, x):
37
+ for b in self:
38
+ x = b(x)
39
+ return x
40
+
41
+
42
+ class DinoVisionTransformer(nn.Module):
43
+ def __init__(
44
+ self,
45
+ img_size=224,
46
+ patch_size=16,
47
+ in_chans=3,
48
+ embed_dim=768,
49
+ depth=12,
50
+ num_heads=12,
51
+ mlp_ratio=4.0,
52
+ qkv_bias=True,
53
+ ffn_bias=True,
54
+ proj_bias=True,
55
+ drop_path_rate=0.0,
56
+ drop_path_uniform=False,
57
+ init_values=None, # for layerscale: None or 0 => no layerscale
58
+ embed_layer=PatchEmbed,
59
+ act_layer=nn.GELU,
60
+ block_fn=Block,
61
+ ffn_layer="mlp",
62
+ block_chunks=1,
63
+ num_register_tokens=0,
64
+ interpolate_antialias=False,
65
+ interpolate_offset=0.1,
66
+ qk_norm=False,
67
+ ):
68
+ """
69
+ Args:
70
+ img_size (int, tuple): input image size
71
+ patch_size (int, tuple): patch size
72
+ in_chans (int): number of input channels
73
+ embed_dim (int): embedding dimension
74
+ depth (int): depth of transformer
75
+ num_heads (int): number of attention heads
76
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
77
+ qkv_bias (bool): enable bias for qkv if True
78
+ proj_bias (bool): enable bias for proj in attn if True
79
+ ffn_bias (bool): enable bias for ffn if True
80
+ drop_path_rate (float): stochastic depth rate
81
+ drop_path_uniform (bool): apply uniform drop rate across blocks
82
+ weight_init (str): weight init scheme
83
+ init_values (float): layer-scale init values
84
+ embed_layer (nn.Module): patch embedding layer
85
+ act_layer (nn.Module): MLP activation layer
86
+ block_fn (nn.Module): transformer block class
87
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
88
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
89
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
90
+ interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
91
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
92
+ """
93
+ super().__init__()
94
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
95
+
96
+ # tricky but makes it work
97
+ self.use_checkpoint = False
98
+ #
99
+
100
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
101
+ self.num_tokens = 1
102
+ self.n_blocks = depth
103
+ self.num_heads = num_heads
104
+ self.patch_size = patch_size
105
+ self.num_register_tokens = num_register_tokens
106
+ self.interpolate_antialias = interpolate_antialias
107
+ self.interpolate_offset = interpolate_offset
108
+
109
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
110
+ num_patches = self.patch_embed.num_patches
111
+
112
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
113
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
114
+ assert num_register_tokens >= 0
115
+ self.register_tokens = (
116
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
117
+ )
118
+
119
+ if drop_path_uniform is True:
120
+ dpr = [drop_path_rate] * depth
121
+ else:
122
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
123
+
124
+ if ffn_layer == "mlp":
125
+ logger.info("using MLP layer as FFN")
126
+ ffn_layer = Mlp
127
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
128
+ logger.info("using SwiGLU layer as FFN")
129
+ ffn_layer = SwiGLUFFNFused
130
+ elif ffn_layer == "identity":
131
+ logger.info("using Identity layer as FFN")
132
+
133
+ def f(*args, **kwargs):
134
+ return nn.Identity()
135
+
136
+ ffn_layer = f
137
+ else:
138
+ raise NotImplementedError
139
+
140
+ blocks_list = [
141
+ block_fn(
142
+ dim=embed_dim,
143
+ num_heads=num_heads,
144
+ mlp_ratio=mlp_ratio,
145
+ qkv_bias=qkv_bias,
146
+ proj_bias=proj_bias,
147
+ ffn_bias=ffn_bias,
148
+ drop_path=dpr[i],
149
+ norm_layer=norm_layer,
150
+ act_layer=act_layer,
151
+ ffn_layer=ffn_layer,
152
+ init_values=init_values,
153
+ qk_norm=qk_norm,
154
+ )
155
+ for i in range(depth)
156
+ ]
157
+ if block_chunks > 0:
158
+ self.chunked_blocks = True
159
+ chunked_blocks = []
160
+ chunksize = depth // block_chunks
161
+ for i in range(0, depth, chunksize):
162
+ # this is to keep the block index consistent if we chunk the block list
163
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
164
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
165
+ else:
166
+ self.chunked_blocks = False
167
+ self.blocks = nn.ModuleList(blocks_list)
168
+
169
+ self.norm = norm_layer(embed_dim)
170
+ self.head = nn.Identity()
171
+
172
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
173
+
174
+ self.init_weights()
175
+
176
+ def init_weights(self):
177
+ trunc_normal_(self.pos_embed, std=0.02)
178
+ nn.init.normal_(self.cls_token, std=1e-6)
179
+ if self.register_tokens is not None:
180
+ nn.init.normal_(self.register_tokens, std=1e-6)
181
+ named_apply(init_weights_vit_timm, self)
182
+
183
+ def interpolate_pos_encoding(self, x, w, h):
184
+ previous_dtype = x.dtype
185
+ npatch = x.shape[1] - 1
186
+ N = self.pos_embed.shape[1] - 1
187
+ if npatch == N and w == h:
188
+ return self.pos_embed
189
+ pos_embed = self.pos_embed.float()
190
+ class_pos_embed = pos_embed[:, 0]
191
+ patch_pos_embed = pos_embed[:, 1:]
192
+ dim = x.shape[-1]
193
+ w0 = w // self.patch_size
194
+ h0 = h // self.patch_size
195
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
196
+ assert N == M * M
197
+ kwargs = {}
198
+ if self.interpolate_offset:
199
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
200
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
201
+ sx = float(w0 + self.interpolate_offset) / M
202
+ sy = float(h0 + self.interpolate_offset) / M
203
+ kwargs["scale_factor"] = (sx, sy)
204
+ else:
205
+ # Simply specify an output size instead of a scale factor
206
+ kwargs["size"] = (w0, h0)
207
+ patch_pos_embed = nn.functional.interpolate(
208
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
209
+ mode="bicubic",
210
+ antialias=self.interpolate_antialias,
211
+ **kwargs,
212
+ )
213
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
214
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
215
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
216
+
217
+ def prepare_tokens_with_masks(self, x, masks=None):
218
+ B, nc, w, h = x.shape
219
+ x = self.patch_embed(x)
220
+ if masks is not None:
221
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
222
+
223
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
224
+ x = x + self.interpolate_pos_encoding(x, w, h)
225
+
226
+ if self.register_tokens is not None:
227
+ x = torch.cat(
228
+ (
229
+ x[:, :1],
230
+ self.register_tokens.expand(x.shape[0], -1, -1),
231
+ x[:, 1:],
232
+ ),
233
+ dim=1,
234
+ )
235
+
236
+ return x
237
+
238
+ def forward_features_list(self, x_list, masks_list):
239
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
240
+
241
+ for blk in self.blocks:
242
+ if self.use_checkpoint:
243
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
244
+ else:
245
+ x = blk(x)
246
+
247
+ all_x = x
248
+ output = []
249
+ for x, masks in zip(all_x, masks_list):
250
+ x_norm = self.norm(x)
251
+ output.append(
252
+ {
253
+ "x_norm_clstoken": x_norm[:, 0],
254
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
255
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
256
+ "x_prenorm": x,
257
+ "masks": masks,
258
+ }
259
+ )
260
+ return output
261
+
262
+ def forward_features(self, x, masks=None):
263
+ if isinstance(x, list):
264
+ return self.forward_features_list(x, masks)
265
+
266
+ x = self.prepare_tokens_with_masks(x, masks)
267
+
268
+ for blk in self.blocks:
269
+ if self.use_checkpoint:
270
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
271
+ else:
272
+ x = blk(x)
273
+
274
+ x_norm = self.norm(x)
275
+ return {
276
+ "x_norm_clstoken": x_norm[:, 0],
277
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
278
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
279
+ "x_prenorm": x,
280
+ "masks": masks,
281
+ }
282
+
283
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
284
+ x = self.prepare_tokens_with_masks(x)
285
+ # If n is an int, take the n last blocks. If it's a list, take them
286
+ output, total_block_len = [], len(self.blocks)
287
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
288
+ for i, blk in enumerate(self.blocks):
289
+ x = blk(x)
290
+ if i in blocks_to_take:
291
+ output.append(x)
292
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
293
+ return output
294
+
295
+ def _get_intermediate_layers_chunked(self, x, n=1):
296
+ x = self.prepare_tokens_with_masks(x)
297
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
298
+ # If n is an int, take the n last blocks. If it's a list, take them
299
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
300
+ for block_chunk in self.blocks:
301
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
302
+ x = blk(x)
303
+ if i in blocks_to_take:
304
+ output.append(x)
305
+ i += 1
306
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
307
+ return output
308
+
309
+ def get_intermediate_layers(
310
+ self,
311
+ x: torch.Tensor,
312
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
313
+ reshape: bool = False,
314
+ return_class_token: bool = False,
315
+ norm=True,
316
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
317
+ if self.chunked_blocks:
318
+ outputs = self._get_intermediate_layers_chunked(x, n)
319
+ else:
320
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
321
+ if norm:
322
+ outputs = [self.norm(out) for out in outputs]
323
+ class_tokens = [out[:, 0] for out in outputs]
324
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
325
+ if reshape:
326
+ B, _, w, h = x.shape
327
+ outputs = [
328
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
329
+ for out in outputs
330
+ ]
331
+ if return_class_token:
332
+ return tuple(zip(outputs, class_tokens))
333
+ return tuple(outputs)
334
+
335
+ def forward(self, *args, is_training=True, **kwargs):
336
+ ret = self.forward_features(*args, **kwargs)
337
+ if is_training:
338
+ return ret
339
+ else:
340
+ return self.head(ret["x_norm_clstoken"])
341
+
342
+
343
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
344
+ """ViT weight initialization, original timm impl (for reproducibility)"""
345
+ if isinstance(module, nn.Linear):
346
+ trunc_normal_(module.weight, std=0.02)
347
+ if module.bias is not None:
348
+ nn.init.zeros_(module.bias)
349
+
350
+
351
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
352
+ model = DinoVisionTransformer(
353
+ patch_size=patch_size,
354
+ embed_dim=384,
355
+ depth=12,
356
+ num_heads=6,
357
+ mlp_ratio=4,
358
+ block_fn=partial(Block, attn_class=MemEffAttention),
359
+ num_register_tokens=num_register_tokens,
360
+ **kwargs,
361
+ )
362
+ return model
363
+
364
+
365
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
366
+ model = DinoVisionTransformer(
367
+ patch_size=patch_size,
368
+ embed_dim=768,
369
+ depth=12,
370
+ num_heads=12,
371
+ mlp_ratio=4,
372
+ block_fn=partial(Block, attn_class=MemEffAttention),
373
+ num_register_tokens=num_register_tokens,
374
+ **kwargs,
375
+ )
376
+ return model
377
+
378
+
379
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
380
+ model = DinoVisionTransformer(
381
+ patch_size=patch_size,
382
+ embed_dim=1024,
383
+ depth=24,
384
+ num_heads=16,
385
+ mlp_ratio=4,
386
+ block_fn=partial(Block, attn_class=MemEffAttention),
387
+ num_register_tokens=num_register_tokens,
388
+ **kwargs,
389
+ )
390
+ return model
391
+
392
+
393
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
394
+ """
395
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
396
+ """
397
+ model = DinoVisionTransformer(
398
+ patch_size=patch_size,
399
+ embed_dim=1536,
400
+ depth=40,
401
+ num_heads=24,
402
+ mlp_ratio=4,
403
+ block_fn=partial(Block, attn_class=MemEffAttention),
404
+ num_register_tokens=num_register_tokens,
405
+ **kwargs,
406
+ )
407
+ return model
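
A minimal sketch of driving the backbone directly: it builds a randomly initialized ViT-S/14 with register tokens the same way the Aggregator below constructs its DINOv2 variants (a real run would load pretrained weights on top of this skeleton):

import torch
from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small

backbone = vit_small(
    img_size=224,
    patch_size=14,
    num_register_tokens=4,
    block_chunks=0,
    init_values=1.0,
).eval()

imgs = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = backbone.forward_features(imgs)
print(feats["x_norm_patchtokens"].shape)  # torch.Size([1, 256, 384]); (224 // 14) ** 2 patches
print(feats["x_norm_regtokens"].shape)    # torch.Size([1, 4, 384]); the register tokens
print(feats["x_norm_clstoken"].shape)     # torch.Size([1, 384])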
models/SpaTrackV2/models/vggt4track/models/aggregator.py ADDED
@@ -0,0 +1,338 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from typing import Optional, Tuple, Union, List, Dict, Any
12
+
13
+ from models.SpaTrackV2.models.vggt4track.layers import PatchEmbed
14
+ from models.SpaTrackV2.models.vggt4track.layers.block import Block
15
+ from models.SpaTrackV2.models.vggt4track.layers.rope import RotaryPositionEmbedding2D, PositionGetter
16
+ from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
17
+ from torch.utils.checkpoint import checkpoint
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
22
+ _RESNET_STD = [0.229, 0.224, 0.225]
23
+
24
+
25
+ class Aggregator(nn.Module):
26
+ """
27
+ The Aggregator applies alternating-attention over input frames,
28
+ as described in VGGT: Visual Geometry Grounded Transformer.
29
+
30
+
31
+ Args:
32
+ img_size (int): Image size in pixels.
33
+ patch_size (int): Size of each patch for PatchEmbed.
34
+ embed_dim (int): Dimension of the token embeddings.
35
+ depth (int): Number of blocks.
36
+ num_heads (int): Number of attention heads.
37
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
38
+ num_register_tokens (int): Number of register tokens.
39
+ block_fn (nn.Module): The block type used for attention (Block by default).
40
+ qkv_bias (bool): Whether to include bias in QKV projections.
41
+ proj_bias (bool): Whether to include bias in the output projection.
42
+ ffn_bias (bool): Whether to include bias in MLP layers.
43
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
44
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
45
+ aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
46
+ qk_norm (bool): Whether to apply QK normalization.
47
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
48
+ init_values (float): Init scale for layer scale.
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ img_size=518,
54
+ patch_size=14,
55
+ embed_dim=1024,
56
+ depth=24,
57
+ num_heads=16,
58
+ mlp_ratio=4.0,
59
+ num_register_tokens=4,
60
+ block_fn=Block,
61
+ qkv_bias=True,
62
+ proj_bias=True,
63
+ ffn_bias=True,
64
+ patch_embed="dinov2_vitl14_reg",
65
+ aa_order=["frame", "global"],
66
+ aa_block_size=1,
67
+ qk_norm=True,
68
+ rope_freq=100,
69
+ init_values=0.01,
70
+ ):
71
+ super().__init__()
72
+
73
+ self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
74
+
75
+ # Initialize rotary position embedding if frequency > 0
76
+ self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
77
+ self.position_getter = PositionGetter() if self.rope is not None else None
78
+
79
+ self.frame_blocks = nn.ModuleList(
80
+ [
81
+ block_fn(
82
+ dim=embed_dim,
83
+ num_heads=num_heads,
84
+ mlp_ratio=mlp_ratio,
85
+ qkv_bias=qkv_bias,
86
+ proj_bias=proj_bias,
87
+ ffn_bias=ffn_bias,
88
+ init_values=init_values,
89
+ qk_norm=qk_norm,
90
+ rope=self.rope,
91
+ )
92
+ for _ in range(depth)
93
+ ]
94
+ )
95
+
96
+ self.global_blocks = nn.ModuleList(
97
+ [
98
+ block_fn(
99
+ dim=embed_dim,
100
+ num_heads=num_heads,
101
+ mlp_ratio=mlp_ratio,
102
+ qkv_bias=qkv_bias,
103
+ proj_bias=proj_bias,
104
+ ffn_bias=ffn_bias,
105
+ init_values=init_values,
106
+ qk_norm=qk_norm,
107
+ rope=self.rope,
108
+ )
109
+ for _ in range(depth)
110
+ ]
111
+ )
112
+
113
+ self.depth = depth
114
+ self.aa_order = aa_order
115
+ self.patch_size = patch_size
116
+ self.aa_block_size = aa_block_size
117
+
118
+ # Validate that depth is divisible by aa_block_size
119
+ if self.depth % self.aa_block_size != 0:
120
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
121
+
122
+ self.aa_block_num = self.depth // self.aa_block_size
123
+
124
+ # Note: We have two camera tokens, one for the first frame and one for the rest
125
+ # The same applies for register tokens
126
+ self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
127
+ self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
128
+
129
+ # The patch tokens start after the camera and register tokens
130
+ self.patch_start_idx = 1 + num_register_tokens
131
+
132
+ # Initialize parameters with small values
133
+ nn.init.normal_(self.camera_token, std=1e-6)
134
+ nn.init.normal_(self.register_token, std=1e-6)
135
+
136
+ # Register normalization constants as buffers
137
+ for name, value in (
138
+ ("_resnet_mean", _RESNET_MEAN),
139
+ ("_resnet_std", _RESNET_STD),
140
+ ):
141
+ self.register_buffer(
142
+ name,
143
+ torch.FloatTensor(value).view(1, 1, 3, 1, 1),
144
+ persistent=False,
145
+ )
146
+
147
+ def __build_patch_embed__(
148
+ self,
149
+ patch_embed,
150
+ img_size,
151
+ patch_size,
152
+ num_register_tokens,
153
+ interpolate_antialias=True,
154
+ interpolate_offset=0.0,
155
+ block_chunks=0,
156
+ init_values=1.0,
157
+ embed_dim=1024,
158
+ ):
159
+ """
160
+ Build the patch embed layer. If 'conv', we use a
161
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
162
+ """
163
+
164
+ if "conv" in patch_embed:
165
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
166
+ else:
167
+ vit_models = {
168
+ "dinov2_vitl14_reg": vit_large,
169
+ "dinov2_vitb14_reg": vit_base,
170
+ "dinov2_vits14_reg": vit_small,
171
+ "dinov2_vitg2_reg": vit_giant2,
172
+ }
173
+
174
+ self.patch_embed = vit_models[patch_embed](
175
+ img_size=img_size,
176
+ patch_size=patch_size,
177
+ num_register_tokens=num_register_tokens,
178
+ interpolate_antialias=interpolate_antialias,
179
+ interpolate_offset=interpolate_offset,
180
+ block_chunks=block_chunks,
181
+ init_values=init_values,
182
+ )
183
+
184
+ # Disable gradient updates for mask token
185
+ if hasattr(self.patch_embed, "mask_token"):
186
+ self.patch_embed.mask_token.requires_grad_(False)
187
+
188
+ def forward(
189
+ self,
190
+ images: torch.Tensor,
191
+ ) -> Tuple[List[torch.Tensor], int]:
192
+ """
193
+ Args:
194
+ images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
195
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
196
+
197
+ Returns:
198
+ (list[torch.Tensor], int):
199
+ The list of outputs from the attention blocks,
200
+ and the patch_start_idx indicating where patch tokens begin.
201
+ """
202
+ B, S, C_in, H, W = images.shape
203
+
204
+ if C_in != 3:
205
+ raise ValueError(f"Expected 3 input channels, got {C_in}")
206
+
207
+ # Normalize images and reshape for patch embed
208
+ images = (images - self._resnet_mean) / self._resnet_std
209
+
210
+ # Reshape to [B*S, C, H, W] for patch embedding
211
+ images = images.view(B * S, C_in, H, W)
212
+ patch_tokens = self.patch_embed(images)
213
+
214
+ if isinstance(patch_tokens, dict):
215
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
216
+
217
+ _, P, C = patch_tokens.shape
218
+
219
+ # Expand camera and register tokens to match batch size and sequence length
220
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S)
221
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
222
+
223
+ # Concatenate special tokens with patch tokens
224
+ tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
225
+
226
+ pos = None
227
+ if self.rope is not None:
228
+ pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
229
+
230
+ if self.patch_start_idx > 0:
231
+ # do not use position embedding for special tokens (camera and register tokens)
232
+ # so set pos to 0 for the special tokens
233
+ pos = pos + 1
234
+ pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
235
+ pos = torch.cat([pos_special, pos], dim=1)
236
+
237
+ # update P because we added special tokens
238
+ _, P, C = tokens.shape
239
+
240
+ frame_idx = 0
241
+ global_idx = 0
242
+ output_list = []
243
+
244
+ for _ in range(self.aa_block_num):
245
+ for attn_type in self.aa_order:
246
+ if attn_type == "frame":
247
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
248
+ tokens, B, S, P, C, frame_idx, pos=pos
249
+ )
250
+ elif attn_type == "global":
251
+ tokens, global_idx, global_intermediates = self._process_global_attention(
252
+ tokens, B, S, P, C, global_idx, pos=pos
253
+ )
254
+ else:
255
+ raise ValueError(f"Unknown attention type: {attn_type}")
256
+
257
+ for i in range(len(frame_intermediates)):
258
+ # concat frame and global intermediates, [B x S x P x 2C]
259
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
260
+ output_list.append(concat_inter)
261
+
262
+ del concat_inter
263
+ del frame_intermediates
264
+ del global_intermediates
265
+ return output_list, self.patch_start_idx
266
+
267
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
268
+ """
269
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
270
+ """
271
+ # If needed, reshape tokens or positions:
272
+ if tokens.shape != (B * S, P, C):
273
+ tokens = tokens.view(B, S, P, C).view(B * S, P, C)
274
+
275
+ if pos is not None and pos.shape != (B * S, P, 2):
276
+ pos = pos.view(B, S, P, 2).view(B * S, P, 2)
277
+
278
+ intermediates = []
279
+
280
+ # by default, self.aa_block_size=1, which processes one block at a time
281
+ for _ in range(self.aa_block_size):
282
+ if self.training:
283
+ tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=False)
284
+ else:
285
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
286
+ frame_idx += 1
287
+ intermediates.append(tokens.view(B, S, P, C))
288
+
289
+ return tokens, frame_idx, intermediates
290
+
291
+ def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
292
+ """
293
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
294
+ """
295
+ if tokens.shape != (B, S * P, C):
296
+ tokens = tokens.view(B, S, P, C).view(B, S * P, C)
297
+
298
+ if pos is not None and pos.shape != (B, S * P, 2):
299
+ pos = pos.view(B, S, P, 2).view(B, S * P, 2)
300
+
301
+ intermediates = []
302
+
303
+ # by default, self.aa_block_size=1, which processes one block at a time
304
+ for _ in range(self.aa_block_size):
305
+ if self.training:
306
+ tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=False)
307
+ else:
308
+ tokens = self.global_blocks[global_idx](tokens, pos=pos)
309
+ global_idx += 1
310
+ intermediates.append(tokens.view(B, S, P, C))
311
+
312
+ return tokens, global_idx, intermediates
313
+
314
+
315
+ def slice_expand_and_flatten(token_tensor, B, S):
316
+ """
317
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
318
+ 1) Uses the first position (index=0) for the first frame only
319
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
320
+ 3) Expands both to match batch size B
321
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
322
+ followed by (S-1) second-position tokens
323
+ 5) Flattens to (B*S, X, C) for processing
324
+
325
+ Returns:
326
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
327
+ """
328
+
329
+ # Slice out the "query" tokens => shape (1, 1, ...)
330
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
331
+ # Slice out the "other" tokens => shape (1, S-1, ...)
332
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
333
+ # Concatenate => shape (B, S, ...)
334
+ combined = torch.cat([query, others], dim=1)
335
+
336
+ # Finally flatten => shape (B*S, ...)
337
+ combined = combined.view(B * S, *combined.shape[2:])
338
+ return combined
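To make the token layout concrete, here is a small standalone sketch of what slice_expand_and_flatten produces; it only assumes torch and the function defined above, with toy sizes chosen for readability.

import torch

# toy special tokens: 1 batch slot, 2 roles (first frame vs. all later frames), X=1 token, C=4 channels
camera_token = torch.arange(2 * 4, dtype=torch.float32).view(1, 2, 1, 4)

B, S = 2, 3
flat = slice_expand_and_flatten(camera_token, B, S)   # (B*S, X, C)
print(flat.shape)                                     # torch.Size([6, 1, 4])

# frame 0 of every sequence receives the index-0 token, frames 1..S-1 the index-1 token
per_frame = flat.view(B, S, 1, 4)
assert torch.equal(per_frame[:, 0], camera_token[:, 0].expand(B, 1, 4))
assert torch.equal(per_frame[:, 1:], camera_token[:, 1:].expand(B, S - 1, 1, 4))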
models/SpaTrackV2/models/vggt4track/models/aggregator_front.py ADDED
@@ -0,0 +1,342 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from typing import Optional, Tuple, Union, List, Dict, Any
12
+
13
+ from models.SpaTrackV2.models.vggt4track.layers import PatchEmbed
14
+ from models.SpaTrackV2.models.vggt4track.layers.block import Block
15
+ from models.SpaTrackV2.models.vggt4track.layers.rope import RotaryPositionEmbedding2D, PositionGetter
16
+ from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
17
+ from torch.utils.checkpoint import checkpoint
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
22
+ _RESNET_STD = [0.229, 0.224, 0.225]
23
+
24
+
25
+ class Aggregator(nn.Module):
26
+ """
27
+ The Aggregator applies alternating-attention over input frames,
28
+ as described in VGGT: Visual Geometry Grounded Transformer.
29
+
30
+
31
+ Args:
32
+ img_size (int): Image size in pixels.
33
+ patch_size (int): Size of each patch for PatchEmbed.
34
+ embed_dim (int): Dimension of the token embeddings.
35
+ depth (int): Number of blocks.
36
+ num_heads (int): Number of attention heads.
37
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
38
+ num_register_tokens (int): Number of register tokens.
39
+ block_fn (nn.Module): The block type used for attention (Block by default).
40
+ qkv_bias (bool): Whether to include bias in QKV projections.
41
+ proj_bias (bool): Whether to include bias in the output projection.
42
+ ffn_bias (bool): Whether to include bias in MLP layers.
43
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
44
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
45
+ aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
46
+ qk_norm (bool): Whether to apply QK normalization.
47
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
48
+ init_values (float): Init scale for layer scale.
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ img_size=518,
54
+ patch_size=14,
55
+ embed_dim=1024,
56
+ depth=24,
57
+ num_heads=16,
58
+ mlp_ratio=4.0,
59
+ num_register_tokens=4,
60
+ block_fn=Block,
61
+ qkv_bias=True,
62
+ proj_bias=True,
63
+ ffn_bias=True,
64
+ patch_embed="dinov2_vitl14_reg",
65
+ aa_order=["frame", "global"],
66
+ aa_block_size=1,
67
+ qk_norm=True,
68
+ rope_freq=100,
69
+ init_values=0.01,
70
+ ):
71
+ super().__init__()
72
+
73
+ # self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
74
+
75
+ self.use_reentrant = False
76
+ # Initialize rotary position embedding if frequency > 0
77
+ self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
78
+ self.position_getter = PositionGetter() if self.rope is not None else None
79
+
80
+ self.frame_blocks = nn.ModuleList(
81
+ [
82
+ block_fn(
83
+ dim=embed_dim,
84
+ num_heads=num_heads,
85
+ mlp_ratio=mlp_ratio,
86
+ qkv_bias=qkv_bias,
87
+ proj_bias=proj_bias,
88
+ ffn_bias=ffn_bias,
89
+ init_values=init_values,
90
+ qk_norm=qk_norm,
91
+ rope=self.rope,
92
+ )
93
+ for _ in range(depth)
94
+ ]
95
+ )
96
+
97
+ self.global_blocks = nn.ModuleList(
98
+ [
99
+ block_fn(
100
+ dim=embed_dim,
101
+ num_heads=num_heads,
102
+ mlp_ratio=mlp_ratio,
103
+ qkv_bias=qkv_bias,
104
+ proj_bias=proj_bias,
105
+ ffn_bias=ffn_bias,
106
+ init_values=init_values,
107
+ qk_norm=qk_norm,
108
+ rope=self.rope,
109
+ )
110
+ for _ in range(depth)
111
+ ]
112
+ )
113
+
114
+ self.depth = depth
115
+ self.aa_order = aa_order
116
+ self.patch_size = patch_size
117
+ self.aa_block_size = aa_block_size
118
+
119
+ # Validate that depth is divisible by aa_block_size
120
+ if self.depth % self.aa_block_size != 0:
121
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
122
+
123
+ self.aa_block_num = self.depth // self.aa_block_size
124
+
125
+ # Note: We have two camera tokens, one for the first frame and one for the rest
126
+ # The same applies for register tokens
127
+ self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
128
+ self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
129
+ self.scale_shift_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
130
+
131
+ # The patch tokens start after the camera and register tokens
132
+ self.patch_start_idx = 1 + num_register_tokens + 1
133
+
134
+ # Initialize parameters with small values
135
+ nn.init.normal_(self.camera_token, std=1e-6)
136
+ nn.init.normal_(self.register_token, std=1e-6)
137
+ nn.init.normal_(self.scale_shift_token, std=1e-6)
138
+
139
+ # Register normalization constants as buffers
140
+ for name, value in (
141
+ ("_resnet_mean", _RESNET_MEAN),
142
+ ("_resnet_std", _RESNET_STD),
143
+ ):
144
+ self.register_buffer(
145
+ name,
146
+ torch.FloatTensor(value).view(1, 1, 3, 1, 1),
147
+ persistent=False,
148
+ )
149
+
150
+ def __build_patch_embed__(
151
+ self,
152
+ patch_embed,
153
+ img_size,
154
+ patch_size,
155
+ num_register_tokens,
156
+ interpolate_antialias=True,
157
+ interpolate_offset=0.0,
158
+ block_chunks=0,
159
+ init_values=1.0,
160
+ embed_dim=1024,
161
+ ):
162
+ """
163
+ Build the patch embed layer. If 'conv', we use a
164
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
165
+ """
166
+
167
+ if "conv" in patch_embed:
168
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
169
+ else:
170
+ vit_models = {
171
+ "dinov2_vitl14_reg": vit_large,
172
+ "dinov2_vitb14_reg": vit_base,
173
+ "dinov2_vits14_reg": vit_small,
174
+ "dinov2_vitg2_reg": vit_giant2,
175
+ }
176
+
177
+ self.patch_embed = vit_models[patch_embed](
178
+ img_size=img_size,
179
+ patch_size=patch_size,
180
+ num_register_tokens=num_register_tokens,
181
+ interpolate_antialias=interpolate_antialias,
182
+ interpolate_offset=interpolate_offset,
183
+ block_chunks=block_chunks,
184
+ init_values=init_values,
185
+ )
186
+
187
+ # Disable gradient updates for mask token
188
+ if hasattr(self.patch_embed, "mask_token"):
189
+ self.patch_embed.mask_token.requires_grad_(False)
190
+
191
+ def forward(
192
+ self,
193
+ images: torch.Tensor,
194
+ patch_tokens: torch.Tensor,
195
+ ) -> Tuple[List[torch.Tensor], int]:
196
+ """
197
+ Args:
198
+ images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
199
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
200
+
201
+ Returns:
202
+ (list[torch.Tensor], int):
203
+ The list of outputs from the attention blocks,
204
+ and the patch_start_idx indicating where patch tokens begin.
205
+ """
206
+ B, S, C_in, H, W = images.shape
207
+
208
+ # if C_in != 3:
209
+ # raise ValueError(f"Expected 3 input channels, got {C_in}")
210
+
211
+ # # Normalize images and reshape for patch embed
212
+ # images = (images - self._resnet_mean) / self._resnet_std
213
+
214
+ # # Reshape to [B*S, C, H, W] for patch embedding
215
+ # images = images.view(B * S, C_in, H, W)
216
+ # patch_tokens = self.patch_embed(images)
217
+
218
+ if isinstance(patch_tokens, dict):
219
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
220
+
221
+ _, P, C = patch_tokens.shape
222
+ # Expand camera and register tokens to match batch size and sequence length
223
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S)
224
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
225
+ scale_shift_token = slice_expand_and_flatten(self.scale_shift_token, B, S)
226
+
227
+ # Concatenate special tokens with patch tokens
228
+ tokens = torch.cat([camera_token, register_token, scale_shift_token, patch_tokens], dim=1)
229
+
230
+ pos = None
231
+ if self.rope is not None:
232
+ pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
233
+
234
+ if self.patch_start_idx > 0:
235
+ # do not use position embedding for special tokens (camera and register tokens)
236
+ # so set pos to 0 for the special tokens
237
+ pos = pos + 1
238
+ pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
239
+ pos = torch.cat([pos_special, pos], dim=1)
240
+
241
+ # update P because we added special tokens
242
+ _, P, C = tokens.shape
243
+
244
+ frame_idx = 0
245
+ global_idx = 0
246
+ output_list = []
247
+
248
+ for _ in range(self.aa_block_num):
249
+ for attn_type in self.aa_order:
250
+ if attn_type == "frame":
251
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
252
+ tokens, B, S, P, C, frame_idx, pos=pos
253
+ )
254
+ elif attn_type == "global":
255
+ tokens, global_idx, global_intermediates = self._process_global_attention(
256
+ tokens, B, S, P, C, global_idx, pos=pos
257
+ )
258
+ else:
259
+ raise ValueError(f"Unknown attention type: {attn_type}")
260
+
261
+ for i in range(len(frame_intermediates)):
262
+ # concat frame and global intermediates, [B x S x P x 2C]
263
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
264
+ output_list.append(concat_inter)
265
+
266
+ del concat_inter
267
+ del frame_intermediates
268
+ del global_intermediates
269
+ return output_list, self.patch_start_idx
270
+
271
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
272
+ """
273
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
274
+ """
275
+ # If needed, reshape tokens or positions:
276
+ if tokens.shape != (B * S, P, C):
277
+ tokens = tokens.view(B, S, P, C).view(B * S, P, C)
278
+
279
+ if pos is not None and pos.shape != (B * S, P, 2):
280
+ pos = pos.view(B, S, P, 2).view(B * S, P, 2)
281
+
282
+ intermediates = []
283
+
284
+ # by default, self.aa_block_size=1, which processes one block at a time
285
+ for _ in range(self.aa_block_size):
286
+ if self.training:
287
+ tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
288
+ else:
289
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
290
+ frame_idx += 1
291
+ intermediates.append(tokens.view(B, S, P, C))
292
+
293
+ return tokens, frame_idx, intermediates
294
+
295
+ def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
296
+ """
297
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
298
+ """
299
+ if tokens.shape != (B, S * P, C):
300
+ tokens = tokens.view(B, S, P, C).view(B, S * P, C)
301
+
302
+ if pos is not None and pos.shape != (B, S * P, 2):
303
+ pos = pos.view(B, S, P, 2).view(B, S * P, 2)
304
+
305
+ intermediates = []
306
+
307
+ # by default, self.aa_block_size=1, which processes one block at a time
308
+ for _ in range(self.aa_block_size):
309
+ if self.training:
310
+ tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
311
+ else:
312
+ tokens = self.global_blocks[global_idx](tokens, pos=pos)
313
+ global_idx += 1
314
+ intermediates.append(tokens.view(B, S, P, C))
315
+
316
+ return tokens, global_idx, intermediates
317
+
318
+
319
+ def slice_expand_and_flatten(token_tensor, B, S):
320
+ """
321
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
322
+ 1) Uses the first position (index=0) for the first frame only
323
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
324
+ 3) Expands both to match batch size B
325
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
326
+ followed by (S-1) second-position tokens
327
+ 5) Flattens to (B*S, X, C) for processing
328
+
329
+ Returns:
330
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
331
+ """
332
+
333
+ # Slice out the "query" tokens => shape (1, 1, ...)
334
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
335
+ # Slice out the "other" tokens => shape (1, S-1, ...)
336
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
337
+ # Concatenate => shape (B, S, ...)
338
+ combined = torch.cat([query, others], dim=1)
339
+
340
+ # Finally flatten => shape (B*S, ...)
341
+ combined = combined.view(B * S, *combined.shape[2:])
342
+ return combined
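For orientation, a minimal sketch (toy sizes, plain torch, no model weights) of the reshaping that the alternating attention above relies on: frame attention treats the tokens as (B*S, P, C), so attention stays inside a single frame, while global attention views the same storage as (B, S*P, C), so every token can attend across all frames.

import torch

B, S, P, C = 2, 4, 10, 8            # batch, frames, tokens per frame, channels
tokens = torch.randn(B * S, P, C)   # layout used by _process_frame_attention

# frame attention: B*S independent sequences of P tokens
frame_view = tokens

# global attention: regroup into B sequences of S*P tokens, as in _process_global_attention
global_view = tokens.view(B, S, P, C).view(B, S * P, C)

# both views address the same values; only the attention boundaries differ
assert torch.equal(global_view.view(B, S, P, C).reshape(B * S, P, C), frame_view)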
models/SpaTrackV2/models/vggt4track/models/tracker_front.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.utils.checkpoint import checkpoint
10
+ from huggingface_hub import PyTorchModelHubMixin # used for model hub
11
+
12
+ from models.SpaTrackV2.models.vggt4track.models.aggregator_front import Aggregator
13
+ from models.SpaTrackV2.models.vggt4track.heads.camera_head import CameraHead
14
+ from models.SpaTrackV2.models.vggt4track.heads.scale_head import ScaleHead
15
+ from einops import rearrange
16
+ from models.SpaTrackV2.utils.loss import compute_loss
17
+ from models.SpaTrackV2.utils.pose_enc import pose_encoding_to_extri_intri
18
+ import torch.nn.functional as F
19
+
20
+ class FrontTracker(nn.Module, PyTorchModelHubMixin):
21
+ def __init__(self, img_size=518,
22
+ patch_size=14, embed_dim=1024, base_model=None, use_checkpoint=True, use_scale_head=False):
23
+ super().__init__()
24
+
25
+ self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
26
+ self.camera_head = CameraHead(dim_in=2 * embed_dim)
27
+ if use_scale_head:
28
+ self.scale_head = ScaleHead(dim_in=2 * embed_dim)
29
+ else:
30
+ self.scale_head = None
31
+ self.base_model = base_model
32
+ self.use_checkpoint = use_checkpoint
33
+ self.intermediate_layers = [4, 11, 17, 23]
34
+ self.residual_proj = nn.ModuleList([nn.Linear(2048, 1024) for _ in range(len(self.intermediate_layers))])
35
+ # init the residual proj
36
+ for i in range(len(self.intermediate_layers)):
37
+ nn.init.xavier_uniform_(self.residual_proj[i].weight)
38
+ nn.init.zeros_(self.residual_proj[i].bias)
39
+ # self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
40
+ # self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
41
+ # self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)
42
+
43
+ def forward(self,
44
+ images: torch.Tensor,
45
+ annots = {},
46
+ **kwargs):
47
+ """
48
+ Forward pass of the FrontTracker model.
49
+
50
+ Args:
51
+ images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
52
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
53
+ query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
54
+ Shape: [N, 2] or [B, N, 2], where N is the number of query points.
55
+ Default: None
56
+
57
+ Returns:
58
+ dict: A dictionary containing the following predictions:
59
+ - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
60
+ - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
61
+ - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
62
+ - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
63
+ - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
64
+ - images (torch.Tensor): Original input images, preserved for visualization
65
+
66
+ If query_points is provided, also includes:
67
+ - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
68
+ - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
69
+ - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
70
+ """
71
+
72
+ # If without batch dimension, add it
73
+ if len(images.shape) == 4:
74
+ images = images.unsqueeze(0)
75
+ B, T, C, H, W = images.shape
76
+ images = (images - self.base_model.image_mean) / self.base_model.image_std
77
+ H_14 = H // 14 * 14
78
+ W_14 = W // 14 * 14
79
+ image_14 = F.interpolate(images.view(B*T, C, H, W), (H_14, W_14), mode="bilinear", align_corners=False, antialias=True).view(B, T, C, H_14, W_14)
80
+
81
+ with torch.no_grad():
82
+ features = self.base_model.backbone.get_intermediate_layers(rearrange(image_14, 'b t c h w -> (b t) c h w'),
83
+ self.base_model.intermediate_layers, return_class_token=True)
84
+ # aggregate the features with checkpoint
85
+ aggregated_tokens_list, patch_start_idx = self.aggregator(image_14, patch_tokens=features[-1][0])
86
+
87
+ # enhance the features
88
+ enhanced_features = []
89
+ for layer_i, layer in enumerate(self.intermediate_layers):
90
+ # patch_feat_i = features[layer_i][0] + self.residual_proj[layer_i](aggregated_tokens_list[layer][:,:,patch_start_idx:,:].view(B*T, features[layer_i][0].shape[1], -1))
91
+ patch_feat_i = self.residual_proj[layer_i](aggregated_tokens_list[layer][:,:,patch_start_idx:,:].view(B*T, features[layer_i][0].shape[1], -1))
92
+ enhance_i = (patch_feat_i, features[layer_i][1])
93
+ enhanced_features.append(enhance_i)
94
+
95
+ predictions = {}
96
+
97
+ with torch.cuda.amp.autocast(enabled=False):
98
+ if self.camera_head is not None:
99
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
100
+ predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
101
+ if self.scale_head is not None:
102
+ scale_list = self.scale_head(aggregated_tokens_list)
103
+ predictions["scale"] = scale_list[-1] # scale of the last iteration
104
+ # Predict points (and mask) with checkpoint
105
+ output = self.base_model.head(enhanced_features, image_14)
106
+ points, mask = output
107
+
108
+ # Post-process points and mask
109
+ points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
110
+ points = self.base_model._remap_points(points) # slightly improves the performance in case of very large output values
111
+ # prepare the predictions
112
+ predictions["images"] = (images * self.base_model.image_std + self.base_model.image_mean)*255.0
113
+ points = F.interpolate(points.permute(0, 3, 1, 2), (H, W), mode="bilinear", align_corners=False, antialias=True).permute(0, 2, 3, 1)
114
+ predictions["points_map"] = points
115
+ mask = F.interpolate(mask.unsqueeze(1), (H, W), mode="bilinear", align_corners=False, antialias=True).squeeze(1)
116
+ predictions["unc_metric"] = mask
117
+ predictions["pose_enc_list"] = pose_enc_list
118
+
119
+ if self.training:
120
+ loss = compute_loss(predictions, annots)
121
+ predictions["loss"] = loss
122
+
123
+ # rescale the points
124
+ if self.scale_head is not None:
125
+ points_scale = points * predictions["scale"].view(B*T, 1, 1, 2)[..., :1]
126
+ points_scale[..., 2:] += predictions["scale"].view(B*T, 1, 1, 2)[..., 1:]
127
+ predictions["points_map"] = points_scale
128
+
129
+ predictions["poses_pred"] = torch.eye(4)[None].repeat(predictions["images"].shape[1], 1, 1)[None]
130
+ predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
131
+ predictions["images"].shape[-2:])
132
+ return predictions
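Since FrontTracker now mixes in PyTorchModelHubMixin (the point of this commit), checkpoints can be saved and reloaded with the mixin's helpers. A hedged sketch: the local directory name is a placeholder, and forward() still requires base_model to be attached separately, since it is not part of the serialized config.

import torch
from models.SpaTrackV2.models.vggt4track.models.tracker_front import FrontTracker

model = FrontTracker(img_size=518, patch_size=14, embed_dim=1024, use_scale_head=False)

# write the model config and weights locally in the layout the mixin expects (placeholder path)
model.save_pretrained("./front_tracker_ckpt")

# reload; with push_to_hub() / from_pretrained("<user>/<repo>") the same round trip works via the Hub
restored = FrontTracker.from_pretrained("./front_tracker_ckpt")
restored.eval()
# note: restored.base_model is None here and must be set before calling forward()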
models/SpaTrackV2/models/vggt4track/models/vggt.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from huggingface_hub import PyTorchModelHubMixin # used for model hub
10
+
11
+ from vggt.models.aggregator import Aggregator
12
+ from vggt.heads.camera_head import CameraHead
13
+ from vggt.heads.dpt_head import DPTHead
14
+ from vggt.heads.track_head import TrackHead
15
+
16
+
17
+ class VGGT(nn.Module, PyTorchModelHubMixin):
18
+ def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
19
+ super().__init__()
20
+
21
+ self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
22
+ self.camera_head = CameraHead(dim_in=2 * embed_dim)
23
+ self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
24
+ self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
25
+ self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)
26
+
27
+ def forward(
28
+ self,
29
+ images: torch.Tensor,
30
+ query_points: torch.Tensor = None,
31
+ ):
32
+ """
33
+ Forward pass of the VGGT model.
34
+
35
+ Args:
36
+ images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
37
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
38
+ query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
39
+ Shape: [N, 2] or [B, N, 2], where N is the number of query points.
40
+ Default: None
41
+
42
+ Returns:
43
+ dict: A dictionary containing the following predictions:
44
+ - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
45
+ - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
46
+ - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
47
+ - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
48
+ - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
49
+ - images (torch.Tensor): Original input images, preserved for visualization
50
+
51
+ If query_points is provided, also includes:
52
+ - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
53
+ - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
54
+ - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
55
+ """
56
+
57
+ # If without batch dimension, add it
58
+ if len(images.shape) == 4:
59
+ images = images.unsqueeze(0)
60
+ if query_points is not None and len(query_points.shape) == 2:
61
+ query_points = query_points.unsqueeze(0)
62
+
63
+ aggregated_tokens_list, patch_start_idx = self.aggregator(images)
64
+
65
+ predictions = {}
66
+
67
+ with torch.cuda.amp.autocast(enabled=False):
68
+ if self.camera_head is not None:
69
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
70
+ predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
71
+
72
+ if self.depth_head is not None:
73
+ depth, depth_conf = self.depth_head(
74
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
75
+ )
76
+ predictions["depth"] = depth
77
+ predictions["depth_conf"] = depth_conf
78
+
79
+ if self.point_head is not None:
80
+ pts3d, pts3d_conf = self.point_head(
81
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
82
+ )
83
+ predictions["world_points"] = pts3d
84
+ predictions["world_points_conf"] = pts3d_conf
85
+
86
+ if self.track_head is not None and query_points is not None:
87
+ track_list, vis, conf = self.track_head(
88
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
89
+ )
90
+ predictions["track"] = track_list[-1] # track of the last iteration
91
+ predictions["vis"] = vis
92
+ predictions["conf"] = conf
93
+
94
+ predictions["images"] = images
95
+
96
+ return predictions
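A minimal inference sketch for the VGGT wrapper above, using random inputs purely to illustrate the expected shapes (per the docstring); it assumes the standalone vggt package referenced by this file's imports is importable, and real use would load pretrained weights and properly preprocessed images instead.

import torch
from models.SpaTrackV2.models.vggt4track.models.vggt import VGGT

model = VGGT(img_size=518, patch_size=14, embed_dim=1024).eval()

images = torch.rand(2, 3, 518, 518)              # S=2 frames, RGB in [0, 1]; batch dim is added inside forward
query_points = torch.tensor([[100.0, 200.0],
                             [300.0, 150.0]])    # (N, 2) pixel coordinates

with torch.no_grad():
    preds = model(images, query_points=query_points)

# per the docstring: pose_enc (1, S, 9), depth (1, S, H, W, 1), world_points (1, S, H, W, 3), track (1, S, N, 2)
print({k: tuple(v.shape) for k, v in preds.items() if torch.is_tensor(v)})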
models/SpaTrackV2/models/vggt4track/models/vggt_moe.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from huggingface_hub import PyTorchModelHubMixin # used for model hub
10
+
11
+ from models.SpaTrackV2.models.vggt4track.models.aggregator import Aggregator
12
+ from models.SpaTrackV2.models.vggt4track.heads.camera_head import CameraHead
13
+ from models.SpaTrackV2.models.vggt4track.heads.dpt_head import DPTHead
14
+ from models.SpaTrackV2.models.vggt4track.heads.track_head import TrackHead
15
+ from models.SpaTrackV2.models.vggt4track.utils.loss import compute_loss
16
+ from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri
17
+ from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
18
+ from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
19
+ from einops import rearrange
20
+ import torch.nn.functional as F
21
+
22
+ class VGGT4Track(nn.Module, PyTorchModelHubMixin):
23
+ def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
24
+ super().__init__()
25
+
26
+ self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
27
+ self.camera_head = CameraHead(dim_in=2 * embed_dim)
28
+ self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="sigmoid")
29
+
30
+ def forward(
31
+ self,
32
+ images: torch.Tensor,
33
+ annots = {},
34
+ **kwargs):
35
+ """
36
+ Forward pass of the VGGT4Track model.
37
+
38
+ Args:
39
+ images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
40
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
41
+ query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
42
+ Shape: [N, 2] or [B, N, 2], where N is the number of query points.
43
+ Default: None
44
+
45
+ Returns:
46
+ dict: A dictionary containing the following predictions:
47
+ - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
48
+ - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
49
+ - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
50
+ - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
51
+ - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
52
+ - images (torch.Tensor): Original input images, preserved for visualization
53
+
54
+ If query_points is provided, also includes:
55
+ - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
56
+ - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
57
+ - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
58
+ """
59
+
60
+ # If without batch dimension, add it
61
+ B, T, C, H, W = images.shape
62
+ images_proc = preprocess_image(images.view(B*T, C, H, W).clone())
63
+ images_proc = rearrange(images_proc, '(b t) c h w -> b t c h w', b=B, t=T)
64
+ _, _, _, H_proc, W_proc = images_proc.shape
65
+
66
+ if len(images.shape) == 4:
67
+ images = images.unsqueeze(0)
68
+
69
+ with torch.no_grad():
70
+ aggregated_tokens_list, patch_start_idx = self.aggregator(images_proc)
71
+
72
+ predictions = {}
73
+
74
+ with torch.cuda.amp.autocast(enabled=False):
75
+ if self.camera_head is not None:
76
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
77
+ predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
78
+ predictions["pose_enc_list"] = pose_enc_list
79
+
80
+ if self.depth_head is not None:
81
+ depth, depth_conf = self.depth_head(
82
+ aggregated_tokens_list, images=images_proc, patch_start_idx=patch_start_idx
83
+ )
84
+ predictions["depth"] = depth
85
+ predictions["unc_metric"] = depth_conf.view(B*T, H_proc, W_proc)
86
+
87
+ predictions["images"] = (images)*255.0
88
+ # output the camera pose
89
+ predictions["poses_pred"] = torch.eye(4)[None].repeat(T, 1, 1)[None]
90
+ predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
91
+ images_proc.shape[-2:])
92
+ predictions["poses_pred"] = torch.inverse(predictions["poses_pred"])
93
+ points_map = depth_to_points_colmap(depth.view(B*T, H_proc, W_proc), predictions["intrs"].view(B*T, 3, 3))
94
+ predictions["points_map"] = points_map
95
+ #NOTE: resize back
96
+ predictions["points_map"] = F.interpolate(points_map.permute(0,3,1,2),
97
+ size=(H, W), mode='bilinear', align_corners=True).permute(0,2,3,1)
98
+ predictions["unc_metric"] = F.interpolate(predictions["unc_metric"][:,None],
99
+ size=(H, W), mode='bilinear', align_corners=True)[:,0]
100
+ predictions["intrs"][..., :1, :] *= W/W_proc
101
+ predictions["intrs"][..., 1:2, :] *= H/H_proc
102
+
103
+ if self.training:
104
+ loss = compute_loss(predictions, annots)
105
+ predictions["loss"] = loss
106
+
107
+ return predictions
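The last two intrinsics lines above apply the standard rescaling rule: when predictions made at the processed resolution (H_proc, W_proc) are resized back to (H, W), the first row of K (fx, cx) scales by W/W_proc and the second row (fy, cy) by H/H_proc. A tiny standalone check with made-up numbers:

import torch

W_proc, H_proc = 518, 392          # resolution the network saw
W, H = 1280, 960                   # original resolution the outputs are resized back to

K = torch.tensor([[400.0, 0.0, W_proc / 2],
                  [0.0, 400.0, H_proc / 2],
                  [0.0,   0.0, 1.0]])

K_rescaled = K.clone()
K_rescaled[0, :] *= W / W_proc     # fx and cx follow the width
K_rescaled[1, :] *= H / H_proc     # fy and cy follow the height

# the processed principal point maps back to the original image center
assert torch.allclose(K_rescaled[:2, 2], torch.tensor([W / 2, H / 2]))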
models/SpaTrackV2/models/vggt4track/utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+
models/SpaTrackV2/models/vggt4track/utils/geometry.py ADDED
@@ -0,0 +1,166 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import torch
9
+ import numpy as np
10
+
11
+
12
+ def unproject_depth_map_to_point_map(
13
+ depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
14
+ ) -> np.ndarray:
15
+ """
16
+ Unproject a batch of depth maps to 3D world coordinates.
17
+
18
+ Args:
19
+ depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
20
+ extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
21
+ intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
22
+
23
+ Returns:
24
+ np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
25
+ """
26
+ if isinstance(depth_map, torch.Tensor):
27
+ depth_map = depth_map.cpu().numpy()
28
+ if isinstance(extrinsics_cam, torch.Tensor):
29
+ extrinsics_cam = extrinsics_cam.cpu().numpy()
30
+ if isinstance(intrinsics_cam, torch.Tensor):
31
+ intrinsics_cam = intrinsics_cam.cpu().numpy()
32
+
33
+ world_points_list = []
34
+ for frame_idx in range(depth_map.shape[0]):
35
+ cur_world_points, _, _ = depth_to_world_coords_points(
36
+ depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
37
+ )
38
+ world_points_list.append(cur_world_points)
39
+ world_points_array = np.stack(world_points_list, axis=0)
40
+
41
+ return world_points_array
42
+
43
+
44
+ def depth_to_world_coords_points(
45
+ depth_map: np.ndarray,
46
+ extrinsic: np.ndarray,
47
+ intrinsic: np.ndarray,
48
+ eps=1e-8,
49
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
50
+ """
51
+ Convert a depth map to world coordinates.
52
+
53
+ Args:
54
+ depth_map (np.ndarray): Depth map of shape (H, W).
55
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
56
+ extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
57
+
58
+ Returns:
59
+ tuple[np.ndarray, np.ndarray, np.ndarray]: World coordinates (H, W, 3), camera coordinates (H, W, 3), and valid depth mask (H, W).
60
+ """
61
+ if depth_map is None:
62
+ return None, None, None
63
+
64
+ # Valid depth mask
65
+ point_mask = depth_map > eps
66
+
67
+ # Convert depth map to camera coordinates
68
+ cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
69
+
70
+ # Multiply with the inverse of extrinsic matrix to transform to world coordinates
71
+ # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
72
+ cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
73
+
74
+ R_cam_to_world = cam_to_world_extrinsic[:3, :3]
75
+ t_cam_to_world = cam_to_world_extrinsic[:3, 3]
76
+
77
+ # Apply the rotation and translation to the camera coordinates
78
+ world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
79
+ # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
80
+
81
+ return world_coords_points, cam_coords_points, point_mask
82
+
83
+
84
+ def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> np.ndarray:
85
+ """
86
+ Convert a depth map to camera coordinates.
87
+
88
+ Args:
89
+ depth_map (np.ndarray): Depth map of shape (H, W).
90
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
91
+
92
+ Returns:
93
+ np.ndarray: Camera coordinates of shape (H, W, 3).
94
+ """
95
+ H, W = depth_map.shape
96
+ assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
97
+ assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
98
+
99
+ # Intrinsic parameters
100
+ fu, fv = intrinsic[0, 0], intrinsic[1, 1]
101
+ cu, cv = intrinsic[0, 2], intrinsic[1, 2]
102
+
103
+ # Generate grid of pixel coordinates
104
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
105
+
106
+ # Unproject to camera coordinates
107
+ x_cam = (u - cu) * depth_map / fu
108
+ y_cam = (v - cv) * depth_map / fv
109
+ z_cam = depth_map
110
+
111
+ # Stack to form camera coordinates
112
+ cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
113
+
114
+ return cam_coords
115
+
116
+
117
+ def closed_form_inverse_se3(se3, R=None, T=None):
118
+ """
119
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
120
+
121
+ If `R` and `T` are provided, they must correspond to the rotation and translation
122
+ components of `se3`. Otherwise, they will be extracted from `se3`.
123
+
124
+ Args:
125
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
126
+ R (optional): Nx3x3 array or tensor of rotation matrices.
127
+ T (optional): Nx3x1 array or tensor of translation vectors.
128
+
129
+ Returns:
130
+ Inverted SE3 matrices with the same type and device as `se3`.
131
+
132
+ Shapes:
133
+ se3: (N, 4, 4)
134
+ R: (N, 3, 3)
135
+ T: (N, 3, 1)
136
+ """
137
+ # Check if se3 is a numpy array or a torch tensor
138
+ is_numpy = isinstance(se3, np.ndarray)
139
+
140
+ # Validate shapes
141
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
142
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
143
+
144
+ # Extract R and T if not provided
145
+ if R is None:
146
+ R = se3[:, :3, :3] # (N,3,3)
147
+ if T is None:
148
+ T = se3[:, :3, 3:] # (N,3,1)
149
+
150
+ # Transpose R
151
+ if is_numpy:
152
+ # Compute the transpose of the rotation for NumPy
153
+ R_transposed = np.transpose(R, (0, 2, 1))
154
+ # -R^T t for NumPy
155
+ top_right = -np.matmul(R_transposed, T)
156
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
157
+ else:
158
+ R_transposed = R.transpose(1, 2) # (N,3,3)
159
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
160
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
161
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
162
+
163
+ inverted_matrix[:, :3, :3] = R_transposed
164
+ inverted_matrix[:, :3, 3:] = top_right
165
+
166
+ return inverted_matrix
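A quick sanity check of closed_form_inverse_se3 against a generic 4x4 inverse, using only numpy and the function defined above; the rotation is a small turn about z with an arbitrary translation.

import numpy as np

theta = 0.3
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])
t = np.array([[0.5], [-1.0], [2.0]])

se3 = np.eye(4)
se3[:3, :3] = R
se3[:3, 3:] = t

inv = closed_form_inverse_se3(se3[None])[0]            # [R^T | -R^T t]
assert np.allclose(inv, np.linalg.inv(se3), atol=1e-8)

# this is exactly the cam-to-world transform that depth_to_world_coords_points applies:
# X_world = R^T X_cam - R^T t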
models/SpaTrackV2/models/vggt4track/utils/load_fn.py ADDED
@@ -0,0 +1,200 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision import transforms as TF
10
+
11
+
12
+ def load_and_preprocess_images(image_path_list, mode="crop"):
13
+ """
14
+ A quick start function to load and preprocess images for model input.
15
+ This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
16
+
17
+ Args:
18
+ image_path_list (list): List of paths to image files
19
+ mode (str, optional): Preprocessing mode, either "crop" or "pad".
20
+ - "crop" (default): Sets width to 518px and center crops height if needed.
21
+ - "pad": Preserves all pixels by making the largest dimension 518px
22
+ and padding the smaller dimension to reach a square shape.
23
+
24
+ Returns:
25
+ torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
26
+
27
+ Raises:
28
+ ValueError: If the input list is empty or if mode is invalid
29
+
30
+ Notes:
31
+ - Images with different dimensions will be padded with white (value=1.0)
32
+ - A warning is printed when images have different shapes
33
+ - When mode="crop": The function ensures width=518px while maintaining aspect ratio
34
+ and height is center-cropped if larger than 518px
35
+ - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
36
+ and the smaller dimension is padded to reach a square shape (518x518)
37
+ - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
38
+ """
39
+ # Check for empty list
40
+ if len(image_path_list) == 0:
41
+ raise ValueError("At least 1 image is required")
42
+
43
+ # Validate mode
44
+ if mode not in ["crop", "pad"]:
45
+ raise ValueError("Mode must be either 'crop' or 'pad'")
46
+
47
+ images = []
48
+ shapes = set()
49
+ to_tensor = TF.ToTensor()
50
+ target_size = 518
51
+
52
+ # First process all images and collect their shapes
53
+ for image_path in image_path_list:
54
+
55
+ # Open image
56
+ img = Image.open(image_path)
57
+
58
+ # If there's an alpha channel, blend onto white background:
59
+ if img.mode == "RGBA":
60
+ # Create white background
61
+ background = Image.new("RGBA", img.size, (255, 255, 255, 255))
62
+ # Alpha composite onto the white background
63
+ img = Image.alpha_composite(background, img)
64
+
65
+ # Now convert to "RGB" (this step assigns white for transparent areas)
66
+ img = img.convert("RGB")
67
+
68
+ width, height = img.size
69
+
70
+ if mode == "pad":
71
+ # Make the largest dimension 518px while maintaining aspect ratio
72
+ if width >= height:
73
+ new_width = target_size
74
+ new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14
75
+ else:
76
+ new_height = target_size
77
+ new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14
78
+ else: # mode == "crop"
79
+ # Original behavior: set width to 518px
80
+ new_width = target_size
81
+ # Calculate height maintaining aspect ratio, divisible by 14
82
+ new_height = round(height * (new_width / width) / 14) * 14
83
+
84
+ # Resize with new dimensions (width, height)
85
+ img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
86
+ img = to_tensor(img) # Convert to tensor (0, 1)
87
+
88
+ # Center crop height if it's larger than 518 (only in crop mode)
89
+ if mode == "crop" and new_height > target_size:
90
+ start_y = (new_height - target_size) // 2
91
+ img = img[:, start_y : start_y + target_size, :]
92
+
93
+ # For pad mode, pad to make a square of target_size x target_size
94
+ if mode == "pad":
95
+ h_padding = target_size - img.shape[1]
96
+ w_padding = target_size - img.shape[2]
97
+
98
+ if h_padding > 0 or w_padding > 0:
99
+ pad_top = h_padding // 2
100
+ pad_bottom = h_padding - pad_top
101
+ pad_left = w_padding // 2
102
+ pad_right = w_padding - pad_left
103
+
104
+ # Pad with white (value=1.0)
105
+ img = torch.nn.functional.pad(
106
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
107
+ )
108
+
109
+ shapes.add((img.shape[1], img.shape[2]))
110
+ images.append(img)
111
+
112
+ # Check if we have different shapes
113
+ # In theory our model can also work well with different shapes
114
+ if len(shapes) > 1:
115
+ print(f"Warning: Found images with different shapes: {shapes}")
116
+ # Find maximum dimensions
117
+ max_height = max(shape[0] for shape in shapes)
118
+ max_width = max(shape[1] for shape in shapes)
119
+
120
+ # Pad images if necessary
121
+ padded_images = []
122
+ for img in images:
123
+ h_padding = max_height - img.shape[1]
124
+ w_padding = max_width - img.shape[2]
125
+
126
+ if h_padding > 0 or w_padding > 0:
127
+ pad_top = h_padding // 2
128
+ pad_bottom = h_padding - pad_top
129
+ pad_left = w_padding // 2
130
+ pad_right = w_padding - pad_left
131
+
132
+ img = torch.nn.functional.pad(
133
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
134
+ )
135
+ padded_images.append(img)
136
+ images = padded_images
137
+
138
+ images = torch.stack(images) # concatenate images
139
+
140
+ # Ensure correct shape when single image
141
+ if len(image_path_list) == 1:
142
+ # Verify shape is (1, C, H, W)
143
+ if images.dim() == 3:
144
+ images = images.unsqueeze(0)
145
+
146
+ return images
147
+
148
+ def preprocess_image(img_tensor, mode="crop", target_size=518):
149
+ """
150
+ Preprocess image tensor(s) to target size with crop or pad mode.
151
+ Args:
152
+ img_tensor (torch.Tensor): Image tensor of shape (C, H, W) or (T, C, H, W), values in [0, 1]
153
+ mode (str): 'crop' or 'pad'
154
+ target_size (int): Target size for width/height
155
+ Returns:
156
+ torch.Tensor: Preprocessed image tensor(s), same batch dim as input
157
+ """
158
+ if mode not in ["crop", "pad"]:
159
+ raise ValueError("Mode must be either 'crop' or 'pad'")
160
+ if img_tensor.dim() == 3:
161
+ tensors = [img_tensor]
162
+ squeeze = True
163
+ elif img_tensor.dim() == 4:
164
+ tensors = list(img_tensor)
165
+ squeeze = False
166
+ else:
167
+ raise ValueError("Input tensor must be (C, H, W) or (T, C, H, W)")
168
+ processed = []
169
+ for img in tensors:
170
+ C, H, W = img.shape
171
+ if mode == "pad":
172
+ if W >= H:
173
+ new_W = target_size
174
+ new_H = round(H * (new_W / W) / 14) * 14
175
+ else:
176
+ new_H = target_size
177
+ new_W = round(W * (new_H / H) / 14) * 14
178
+ out = torch.nn.functional.interpolate(img.unsqueeze(0), size=(new_H, new_W), mode="bicubic", align_corners=False).squeeze(0)
179
+ h_padding = target_size - new_H
180
+ w_padding = target_size - new_W
181
+ pad_top = h_padding // 2
182
+ pad_bottom = h_padding - pad_top
183
+ pad_left = w_padding // 2
184
+ pad_right = w_padding - pad_left
185
+ if h_padding > 0 or w_padding > 0:
186
+ out = torch.nn.functional.pad(
187
+ out, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
188
+ )
189
+ else: # crop
190
+ new_W = target_size
191
+ new_H = round(H * (new_W / W) / 14) * 14
192
+ out = torch.nn.functional.interpolate(img.unsqueeze(0), size=(new_H, new_W), mode="bicubic", align_corners=False).squeeze(0)
193
+ if new_H > target_size:
194
+ start_y = (new_H - target_size) // 2
195
+ out = out[:, start_y : start_y + target_size, :]
196
+ processed.append(out)
197
+ result = torch.stack(processed)
198
+ if squeeze:
199
+ return result[0]
200
+ return result
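A short usage sketch for preprocess_image, showing that both modes return spatial sizes divisible by 14 (the ViT patch size): "crop" pins the width to 518 while "pad" yields a 518x518 square. Only torch and the function defined above are assumed.

import torch

frames = torch.rand(4, 3, 480, 640)               # (T, C, H, W) clip with values in [0, 1]

cropped = preprocess_image(frames, mode="crop")
padded = preprocess_image(frames, mode="pad")

print(cropped.shape)    # torch.Size([4, 3, 392, 518]): width 518, height rounded to a multiple of 14
print(padded.shape)     # torch.Size([4, 3, 518, 518]): resized then padded with white to a square

assert cropped.shape[-1] == 518 and cropped.shape[-2] % 14 == 0
assert padded.shape[-2:] == (518, 518)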
models/SpaTrackV2/models/vggt4track/utils/loss.py ADDED
@@ -0,0 +1,123 @@
1
+ # This file contains the loss functions for FrontTracker
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import utils3d
6
+ from models.moge.train.losses import (
7
+ affine_invariant_global_loss,
8
+ affine_invariant_local_loss,
9
+ edge_loss,
10
+ normal_loss,
11
+ mask_l2_loss,
12
+ mask_bce_loss,
13
+ monitoring,
14
+ )
15
+ import torch.nn.functional as F
16
+ from models.SpaTrackV2.models.utils import pose_enc2mat, matrix_to_quaternion, get_track_points, normalize_rgb
17
+ from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
18
+ from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri, extri_intri_to_pose_encoding
19
+
20
+ def compute_loss(predictions, annots):
21
+ """
22
+ Compute the loss for the FrontTracker model.
23
+ """
24
+
25
+ B, T, C, H, W = predictions["images"].shape
26
+ H_resize, W_resize = H, W
27
+
28
+ if "poses_gt" in annots.keys():
29
+ intrs, c2w_traj_gt = pose_enc2mat(annots["poses_gt"],
30
+ H_resize, W_resize, min(H, W))
31
+ else:
32
+ c2w_traj_gt = None
33
+
34
+ if "intrs_gt" in annots.keys():
35
+ intrs = annots["intrs_gt"].view(B, T, 3, 3)
36
+ fx_factor = W_resize / W
37
+ fy_factor = H_resize / H
38
+ intrs[:,:,0,:] *= fx_factor
39
+ intrs[:,:,1,:] *= fy_factor
40
+
41
+ if "depth_gt" in annots.keys():
42
+
43
+ metric_depth_gt = annots['depth_gt'].view(B*T, 1, H, W)
44
+ metric_depth_gt = F.interpolate(metric_depth_gt,
45
+ size=(H_resize, W_resize), mode='nearest')
46
+
47
+ _depths = metric_depth_gt[metric_depth_gt > 0].reshape(-1)
48
+ q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
49
+ q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
50
+ iqr = q75 - q25
51
+ upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
52
+ _depth_roi = torch.tensor(
53
+ [1e-1, upper_bound.item()],
54
+ dtype=metric_depth_gt.dtype,
55
+ device=metric_depth_gt.device
56
+ )
57
+ mask_roi = (metric_depth_gt > _depth_roi[0]) & (metric_depth_gt < _depth_roi[1])
58
+ # fin mask
59
+ gt_mask_fin = ((metric_depth_gt > 0)*(mask_roi)).float()
60
+ # filter the sky
61
+ inf_thres = 50*q25.clamp(min=200, max=1e3)
62
+ gt_mask_inf = (metric_depth_gt > inf_thres).float()
63
+ # gt mask
64
+ gt_mask = (metric_depth_gt > 0)*(metric_depth_gt < 10*q25)
65
+
66
+ points_map_gt = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))
67
+
68
+ if annots["syn_real"] == 1:
69
+ ln_msk_l2, _ = mask_l2_loss(predictions["unc_metric"], gt_mask_fin[:,0], gt_mask_inf[:,0])
70
+ ln_msk_l2 = 50*ln_msk_l2.mean()
71
+ else:
72
+ ln_msk_l2 = 0 * points_map_gt.mean()
73
+
74
+ # loss1: global invariant loss
75
+ ln_depth_glob, _, gt_metric_scale, gt_metric_shift = affine_invariant_global_loss(predictions["points_map"], points_map_gt, gt_mask[:,0], align_resolution=32)
76
+ ln_depth_glob = 100*ln_depth_glob.mean()
77
+ # loss2: edge loss
78
+ ln_edge, _ = edge_loss(predictions["points_map"], points_map_gt, gt_mask[:,0])
79
+ ln_edge = ln_edge.mean()
80
+ # loss3: normal loss
81
+ ln_normal, _ = normal_loss(predictions["points_map"], points_map_gt, gt_mask[:,0])
82
+ ln_normal = ln_normal.mean()
83
+ #NOTE: loss4: consistent loss
84
+ norm_rescale = gt_metric_scale.mean()
85
+ points_map_gt_cons = points_map_gt.clone() / norm_rescale
86
+ if "scale" in predictions.keys():
87
+ scale_ = predictions["scale"].view(B*T, 2, 1, 1)[:,:1]
88
+ shift_ = predictions["scale"].view(B*T, 2, 1, 1)[:,1:]
89
+ else:
90
+ scale_ = torch.ones_like(predictions["points_map"])
91
+ shift_ = torch.zeros_like(predictions["points_map"])[..., 2:]
92
+
93
+ points_pred_cons = predictions["points_map"] * scale_
94
+ points_pred_cons[..., 2:] += shift_
95
+ pred_mask = predictions["unc_metric"].clone().clamp(min=5e-2)
96
+ ln_cons = torch.abs(points_pred_cons - points_map_gt_cons).norm(dim=-1) * pred_mask - 0.05 * torch.log(pred_mask)
97
+ ln_cons = 0.5*ln_cons[(1-gt_mask_inf.squeeze()).bool()].clamp(max=100).mean()
98
+ # loss5: scale shift loss
99
+ if "scale" in predictions.keys():
100
+ ln_scale_shift = torch.abs(scale_.squeeze() - gt_metric_scale / norm_rescale) + torch.abs(shift_.squeeze() - gt_metric_shift[:,2] / norm_rescale)
101
+ ln_scale_shift = 10*ln_scale_shift.mean()
102
+ else:
103
+ ln_scale_shift = 0 * ln_cons.mean()
104
+ # loss6: pose loss
105
+ c2w_traj_gt[...,:3, 3] /= norm_rescale
106
+ ln_pose = 0
107
+ for i_t, pose_enc_i in enumerate(predictions["pose_enc_list"]):
108
+ pose_enc_gt = extri_intri_to_pose_encoding(torch.inverse(c2w_traj_gt)[...,:3,:4], intrs, predictions["images"].shape[-2:])
109
+ T_loss = torch.abs(pose_enc_i[..., :3] - pose_enc_gt[..., :3]).mean()
110
+ R_loss = torch.abs(pose_enc_i[..., 3:7] - pose_enc_gt[..., 3:7]).mean()
111
+ K_loss = torch.abs(pose_enc_i[..., 7:] - pose_enc_gt[..., 7:]).mean()
112
+ pose_loss_i = 25*(T_loss + R_loss) + K_loss
113
+ ln_pose += 0.8**(len(predictions["pose_enc_list"]) - i_t - 1)*(pose_loss_i)
114
+ ln_pose = 5*ln_pose
115
+ if annots["syn_real"] == 1:
116
+ loss = ln_depth_glob + ln_edge + ln_normal + ln_cons + ln_scale_shift + ln_pose + ln_msk_l2
117
+ else:
118
+ loss = ln_cons + ln_pose
119
+ ln_scale_shift = 0*ln_scale_shift
120
+ return {"loss": loss, "ln_depth_glob": ln_depth_glob, "ln_edge": ln_edge, "ln_normal": ln_normal,
121
+ "ln_cons": ln_cons, "ln_scale_shift": ln_scale_shift,
122
+ "ln_pose": ln_pose, "ln_msk_l2": ln_msk_l2, "norm_scale": norm_rescale}
123
+
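The consistency term ln_cons above uses the common confidence-weighted formulation: the per-pixel error is scaled by a predicted confidence w and offset by -0.05*log(w), so the network may down-weight unreliable pixels but pays for doing so. A self-contained sketch of just that term (names here are illustrative, not from the repo):

import torch

def confidence_weighted_error(pred_pts, gt_pts, conf, lam=0.05, min_conf=5e-2):
    w = conf.clamp(min=min_conf)                     # mirrors pred_mask.clamp(min=5e-2) above
    err = (pred_pts - gt_pts).norm(dim=-1)           # per-pixel Euclidean error
    return (err * w - lam * torch.log(w)).mean()     # small w shrinks the error but costs -lam*log(w)

pred = torch.randn(2, 8, 8, 3)
gt = pred + 0.1 * torch.randn_like(pred)
conf = torch.rand(2, 8, 8).clamp(min=1e-3)
print(confidence_weighted_error(pred, gt, conf))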
models/SpaTrackV2/models/vggt4track/utils/pose_enc.py ADDED
@@ -0,0 +1,130 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from .rotation import quat_to_mat, mat_to_quat
9
+
10
+
11
+ def extri_intri_to_pose_encoding(
12
+ extrinsics,
13
+ intrinsics,
14
+ image_size_hw=None, # e.g., (256, 512)
15
+ pose_encoding_type="absT_quaR_FoV",
16
+ ):
17
+ """Convert camera extrinsics and intrinsics to a compact pose encoding.
18
+
19
+ This function transforms camera parameters into a unified pose encoding format,
20
+ which can be used for various downstream tasks like pose prediction or representation.
21
+
22
+ Args:
23
+ extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
24
+ where B is batch size and S is sequence length.
25
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
26
+ The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
27
+ intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
28
+ Defined in pixels, with format:
29
+ [[fx, 0, cx],
30
+ [0, fy, cy],
31
+ [0, 0, 1]]
32
+ where fx, fy are focal lengths and (cx, cy) is the principal point
33
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
34
+ Required for computing field of view values. For example: (256, 512).
35
+ pose_encoding_type (str): Type of pose encoding to use. Currently only
36
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
37
+
38
+ Returns:
39
+ torch.Tensor: Encoded camera pose parameters with shape BxSx9.
40
+ For "absT_quaR_FoV" type, the 9 dimensions are:
41
+ - [:3] = absolute translation vector T (3D)
42
+ - [3:7] = rotation as quaternion quat (4D)
43
+ - [7:] = field of view (2D)
44
+ """
45
+
46
+ # extrinsics: BxSx3x4
47
+ # intrinsics: BxSx3x3
48
+
49
+ if pose_encoding_type == "absT_quaR_FoV":
50
+ R = extrinsics[:, :, :3, :3] # BxSx3x3
51
+ T = extrinsics[:, :, :3, 3] # BxSx3
52
+
53
+ quat = mat_to_quat(R)
54
+ # Note the order of h and w here
55
+ H, W = image_size_hw
56
+ fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
57
+ fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
58
+ pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
59
+ else:
60
+ raise NotImplementedError
61
+
62
+ return pose_encoding
63
+
64
+
65
+ def pose_encoding_to_extri_intri(
66
+ pose_encoding,
67
+ image_size_hw=None, # e.g., (256, 512)
68
+ pose_encoding_type="absT_quaR_FoV",
69
+ build_intrinsics=True,
70
+ ):
71
+ """Convert a pose encoding back to camera extrinsics and intrinsics.
72
+
73
+ This function performs the inverse operation of extri_intri_to_pose_encoding,
74
+ reconstructing the full camera parameters from the compact encoding.
75
+
76
+ Args:
77
+ pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
78
+ where B is batch size and S is sequence length.
79
+ For "absT_quaR_FoV" type, the 9 dimensions are:
80
+ - [:3] = absolute translation vector T (3D)
81
+ - [3:7] = rotation as quaternion quat (4D)
82
+ - [7:] = field of view (2D)
83
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
84
+ Required for reconstructing intrinsics from field of view values.
85
+ For example: (256, 512).
86
+ pose_encoding_type (str): Type of pose encoding used. Currently only
87
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
88
+ build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
89
+ If False, only extrinsics are returned and intrinsics will be None.
90
+
91
+ Returns:
92
+ tuple: (extrinsics, intrinsics)
93
+ - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
94
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
95
+ transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
96
+ a 3x1 translation vector.
97
+ - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
98
+ or None if build_intrinsics is False. Defined in pixels, with format:
99
+ [[fx, 0, cx],
100
+ [0, fy, cy],
101
+ [0, 0, 1]]
102
+ where fx, fy are focal lengths and (cx, cy) is the principal point,
103
+ assumed to be at the center of the image (W/2, H/2).
104
+ """
105
+
106
+ intrinsics = None
107
+
108
+ if pose_encoding_type == "absT_quaR_FoV":
109
+ T = pose_encoding[..., :3]
110
+ quat = pose_encoding[..., 3:7]
111
+ fov_h = pose_encoding[..., 7]
112
+ fov_w = pose_encoding[..., 8]
113
+
114
+ R = quat_to_mat(quat)
115
+ extrinsics = torch.cat([R, T[..., None]], dim=-1)
116
+
117
+ if build_intrinsics:
118
+ H, W = image_size_hw
119
+ fy = (H / 2.0) / torch.tan(fov_h / 2.0)
120
+ fx = (W / 2.0) / torch.tan(fov_w / 2.0)
121
+ intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
122
+ intrinsics[..., 0, 0] = fx
123
+ intrinsics[..., 1, 1] = fy
124
+ intrinsics[..., 0, 2] = W / 2
125
+ intrinsics[..., 1, 2] = H / 2
126
+ intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
127
+ else:
128
+ raise NotImplementedError
129
+
130
+ return extrinsics, intrinsics
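A minimal round-trip sketch of the two helpers above (assuming `extri_intri_to_pose_encoding` takes `(extrinsics, intrinsics, image_size_hw, ...)`, mirroring the converse function shown here; the import path follows this repo's layout):

```python
# Hypothetical usage sketch; shapes follow the docstrings above (BxSx3x4 / BxSx3x3).
import torch
from models.SpaTrackV2.models.vggt4track.utils.pose_enc import (
    extri_intri_to_pose_encoding,
    pose_encoding_to_extri_intri,
)

B, S, H, W = 1, 4, 256, 512
extrinsics = torch.eye(3, 4).expand(B, S, 3, 4).contiguous()          # identity [R|t]
intrinsics = torch.tensor([[300.0, 0.0, W / 2],
                           [0.0, 300.0, H / 2],
                           [0.0, 0.0, 1.0]]).expand(B, S, 3, 3).contiguous()

enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, (H, W))    # BxSx9
extri_hat, intri_hat = pose_encoding_to_extri_intri(enc, (H, W))      # inverse mapping

assert torch.allclose(extri_hat, extrinsics, atol=1e-5)
assert torch.allclose(intri_hat, intrinsics, atol=1e-3)
```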
models/SpaTrackV2/models/vggt4track/utils/rotation.py ADDED
@@ -0,0 +1,138 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
8
+
9
+ import torch
10
+ import numpy as np
11
+ import torch.nn.functional as F
12
+
13
+
14
+ def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
15
+ """
16
+ Quaternion Order: XYZW or say ijkr, scalar-last
17
+
18
+ Convert rotations given as quaternions to rotation matrices.
19
+ Args:
20
+ quaternions: quaternions with real part last,
21
+ as tensor of shape (..., 4).
22
+
23
+ Returns:
24
+ Rotation matrices as tensor of shape (..., 3, 3).
25
+ """
26
+ i, j, k, r = torch.unbind(quaternions, -1)
27
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
28
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
29
+
30
+ o = torch.stack(
31
+ (
32
+ 1 - two_s * (j * j + k * k),
33
+ two_s * (i * j - k * r),
34
+ two_s * (i * k + j * r),
35
+ two_s * (i * j + k * r),
36
+ 1 - two_s * (i * i + k * k),
37
+ two_s * (j * k - i * r),
38
+ two_s * (i * k - j * r),
39
+ two_s * (j * k + i * r),
40
+ 1 - two_s * (i * i + j * j),
41
+ ),
42
+ -1,
43
+ )
44
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
45
+
46
+
47
+ def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
48
+ """
49
+ Convert rotations given as rotation matrices to quaternions.
50
+
51
+ Args:
52
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
53
+
54
+ Returns:
55
+ quaternions with real part last, as tensor of shape (..., 4).
56
+ Quaternion Order: XYZW or say ijkr, scalar-last
57
+ """
58
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
59
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
60
+
61
+ batch_dim = matrix.shape[:-2]
62
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
63
+
64
+ q_abs = _sqrt_positive_part(
65
+ torch.stack(
66
+ [
67
+ 1.0 + m00 + m11 + m22,
68
+ 1.0 + m00 - m11 - m22,
69
+ 1.0 - m00 + m11 - m22,
70
+ 1.0 - m00 - m11 + m22,
71
+ ],
72
+ dim=-1,
73
+ )
74
+ )
75
+
76
+ # we produce the desired quaternion multiplied by each of r, i, j, k
77
+ quat_by_rijk = torch.stack(
78
+ [
79
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
80
+ # `int`.
81
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
82
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
83
+ # `int`.
84
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
85
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
86
+ # `int`.
87
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
88
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
89
+ # `int`.
90
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
91
+ ],
92
+ dim=-2,
93
+ )
94
+
95
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
96
+ # the candidate won't be picked.
97
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
98
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
99
+
100
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
101
+ # forall i; we pick the best-conditioned one (with the largest denominator)
102
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
103
+
104
+ # Convert from rijk to ijkr
105
+ out = out[..., [1, 2, 3, 0]]
106
+
107
+ out = standardize_quaternion(out)
108
+
109
+ return out
110
+
111
+
112
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
113
+ """
114
+ Returns torch.sqrt(torch.max(0, x))
115
+ but with a zero subgradient where x is 0.
116
+ """
117
+ ret = torch.zeros_like(x)
118
+ positive_mask = x > 0
119
+ if torch.is_grad_enabled():
120
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
121
+ else:
122
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
123
+ return ret
124
+
125
+
126
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
127
+ """
128
+ Convert a unit quaternion to a standard form: one in which the real
129
+ part is non-negative.
130
+
131
+ Args:
132
+ quaternions: Quaternions with real part last,
133
+ as tensor of shape (..., 4).
134
+
135
+ Returns:
136
+ Standardized quaternions as tensor of shape (..., 4).
137
+ """
138
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
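A quick sanity-check sketch for the scalar-last (XYZW) convention used by `quat_to_mat` / `mat_to_quat`; the round trip recovers a unit quaternion up to sign:

```python
# Round-trip check; compare after standardizing signs (real part >= 0).
import torch
from models.SpaTrackV2.models.vggt4track.utils.rotation import quat_to_mat, mat_to_quat

torch.manual_seed(0)
q = torch.randn(8, 4)
q = q / q.norm(dim=-1, keepdim=True)         # random unit quaternions in (i, j, k, r) order
R = quat_to_mat(q)                           # (8, 3, 3) rotation matrices
q_back = mat_to_quat(R)                      # standardized output (real part non-negative)

q_std = torch.where(q[..., 3:4] < 0, -q, q)  # standardize the inputs the same way
assert torch.allclose(q_back, q_std, atol=1e-4)
```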
models/SpaTrackV2/models/vggt4track/utils/visual_track.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import cv2
8
+ import torch
9
+ import numpy as np
10
+ import os
11
+
12
+
13
+ def color_from_xy(x, y, W, H, cmap_name="hsv"):
14
+ """
15
+ Map (x, y) -> color in (R, G, B).
16
+ 1) Normalize x,y to [0,1].
17
+ 2) Combine them into a single scalar c in [0,1].
18
+ 3) Use matplotlib's colormap to convert c -> (R,G,B).
19
+
20
+ You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
21
+ """
22
+ import matplotlib.cm
23
+ import matplotlib.colors
24
+
25
+ x_norm = x / max(W - 1, 1)
26
+ y_norm = y / max(H - 1, 1)
27
+ # Simple combination:
28
+ c = (x_norm + y_norm) / 2.0
29
+
30
+ cmap = matplotlib.cm.get_cmap(cmap_name)
31
+ # cmap(c) -> (r,g,b,a) in [0,1]
32
+ rgba = cmap(c)
33
+ r, g, b = rgba[0], rgba[1], rgba[2]
34
+ return (r, g, b) # in [0,1], RGB order
35
+
36
+
37
+ def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
38
+ """
39
+ Given all tracks in one sample (b), compute a (N,3) array of RGB color values
40
+ in [0,255]. The color is determined by the (x,y) position in the first
41
+ visible frame for each track.
42
+
43
+ Args:
44
+ tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
45
+ vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
46
+ image_width, image_height: used for normalizing (x, y).
47
+ cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
48
+
49
+ Returns:
50
+ track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
51
+ """
52
+ S, N, _ = tracks_b.shape
53
+ track_colors = np.zeros((N, 3), dtype=np.uint8)
54
+
55
+ if vis_mask_b is None:
56
+ # treat all as visible
57
+ vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
58
+
59
+ for i in range(N):
60
+ # Find first visible frame for track i
61
+ visible_frames = torch.where(vis_mask_b[:, i])[0]
62
+ if len(visible_frames) == 0:
63
+ # track is never visible; fall back to black
64
+ track_colors[i] = (0, 0, 0)
65
+ continue
66
+
67
+ first_s = int(visible_frames[0].item())
68
+ # use that frame's (x,y)
69
+ x, y = tracks_b[first_s, i].tolist()
70
+
71
+ # map (x,y) -> (R,G,B) in [0,1]
72
+ r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
73
+ # scale to [0,255]
74
+ r, g, b = int(r * 255), int(g * 255), int(b * 255)
75
+ track_colors[i] = (r, g, b)
76
+
77
+ return track_colors
78
+
79
+
80
+ def visualize_tracks_on_images(
81
+ images,
82
+ tracks,
83
+ track_vis_mask=None,
84
+ out_dir="track_visuals_concat_by_xy",
85
+ image_format="CHW", # "CHW" or "HWC"
86
+ normalize_mode="[0,1]",
87
+ cmap_name="hsv", # e.g. "hsv", "rainbow", "jet"
88
+ frames_per_row=4, # New parameter for grid layout
89
+ save_grid=True, # Flag to control whether to save the grid image
90
+ ):
91
+ """
92
+ Visualizes frames in a grid layout with specified frames per row.
93
+ Each track's color is determined by its (x,y) position
94
+ in the first visible frame (or frame 0 if always visible).
95
+ Finally convert the BGR result to RGB before saving.
96
+ Also saves each individual frame as a separate PNG file.
97
+
98
+ Args:
99
+ images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
100
+ tracks: torch.Tensor (S, N, 2), last dim = (x, y).
101
+ track_vis_mask: torch.Tensor (S, N) or None.
102
+ out_dir: folder to save visualizations.
103
+ image_format: "CHW" or "HWC".
104
+ normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
105
+ cmap_name: a matplotlib colormap name for color_from_xy.
106
+ frames_per_row: number of frames to display in each row of the grid.
107
+ save_grid: whether to save all frames in one grid image.
108
+
109
+ Returns:
110
+ None (saves images in out_dir).
111
+ """
112
+
113
+ if len(tracks.shape) == 4:
114
+ tracks = tracks.squeeze(0)
115
+ images = images.squeeze(0)
116
+ if track_vis_mask is not None:
117
+ track_vis_mask = track_vis_mask.squeeze(0)
118
+
119
+ import matplotlib
120
+
121
+ matplotlib.use("Agg") # for non-interactive (optional)
122
+
123
+ os.makedirs(out_dir, exist_ok=True)
124
+
125
+ S = images.shape[0]
126
+ _, N, _ = tracks.shape # (S, N, 2)
127
+
128
+ # Move to CPU
129
+ images = images.cpu().clone()
130
+ tracks = tracks.cpu().clone()
131
+ if track_vis_mask is not None:
132
+ track_vis_mask = track_vis_mask.cpu().clone()
133
+
134
+ # Infer H, W from images shape
135
+ if image_format == "CHW":
136
+ # e.g. images[s].shape = (3, H, W)
137
+ H, W = images.shape[2], images.shape[3]
138
+ else:
139
+ # e.g. images[s].shape = (H, W, 3)
140
+ H, W = images.shape[1], images.shape[2]
141
+
142
+ # Pre-compute the color for each track i based on first visible position
143
+ track_colors_rgb = get_track_colors_by_position(
144
+ tracks, # shape (S, N, 2)
145
+ vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
146
+ image_width=W,
147
+ image_height=H,
148
+ cmap_name=cmap_name,
149
+ )
150
+
151
+ # We'll accumulate each frame's drawn image in a list
152
+ frame_images = []
153
+
154
+ for s in range(S):
155
+ # shape => either (3, H, W) or (H, W, 3)
156
+ img = images[s]
157
+
158
+ # Convert to (H, W, 3)
159
+ if image_format == "CHW":
160
+ img = img.permute(1, 2, 0) # (H, W, 3)
161
+ # else "HWC", do nothing
162
+
163
+ img = img.numpy().astype(np.float32)
164
+
165
+ # Scale to [0,255] if needed
166
+ if normalize_mode == "[0,1]":
167
+ img = np.clip(img, 0, 1) * 255.0
168
+ elif normalize_mode == "[-1,1]":
169
+ img = (img + 1.0) * 0.5 * 255.0
170
+ img = np.clip(img, 0, 255.0)
171
+ # else no normalization
172
+
173
+ # Convert to uint8
174
+ img = img.astype(np.uint8)
175
+
176
+ # For drawing in OpenCV, convert to BGR
177
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
178
+
179
+ # Draw each visible track
180
+ cur_tracks = tracks[s] # shape (N, 2)
181
+ if track_vis_mask is not None:
182
+ valid_indices = torch.where(track_vis_mask[s])[0]
183
+ else:
184
+ valid_indices = range(N)
185
+
186
+ cur_tracks_np = cur_tracks.numpy()
187
+ for i in valid_indices:
188
+ x, y = cur_tracks_np[i]
189
+ pt = (int(round(x)), int(round(y)))
190
+
191
+ # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
192
+ R, G, B = track_colors_rgb[i]
193
+ color_bgr = (int(B), int(G), int(R))
194
+ cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
195
+
196
+ # Convert back to RGB for consistent final saving:
197
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
198
+
199
+ # Save individual frame
200
+ frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
201
+ # Convert to BGR for OpenCV imwrite
202
+ frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
203
+ cv2.imwrite(frame_path, frame_bgr)
204
+
205
+ frame_images.append(img_rgb)
206
+
207
+ # Only create and save the grid image if save_grid is True
208
+ if save_grid:
209
+ # Calculate grid dimensions
210
+ num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division
211
+
212
+ # Create a grid of images
213
+ grid_img = None
214
+ for row in range(num_rows):
215
+ start_idx = row * frames_per_row
216
+ end_idx = min(start_idx + frames_per_row, S)
217
+
218
+ # Concatenate this row horizontally
219
+ row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
220
+
221
+ # If this row has fewer than frames_per_row images, pad with black
222
+ if end_idx - start_idx < frames_per_row:
223
+ padding_width = (frames_per_row - (end_idx - start_idx)) * W
224
+ padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
225
+ row_img = np.concatenate([row_img, padding], axis=1)
226
+
227
+ # Add this row to the grid
228
+ if grid_img is None:
229
+ grid_img = row_img
230
+ else:
231
+ grid_img = np.concatenate([grid_img, row_img], axis=0)
232
+
233
+ out_path = os.path.join(out_dir, "tracks_grid.png")
234
+ # Convert back to BGR for OpenCV imwrite
235
+ grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
236
+ cv2.imwrite(out_path, grid_img_bgr)
237
+ print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
238
+
239
+ print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
scripts/download.sh ADDED
@@ -0,0 +1,5 @@
1
+ #!/bin/bash
2
+
3
+ # Download the example data using gdown
4
+ mkdir -p ./assets/example1
5
+ gdown --id 1q6n2R5ihfMoD-dU_u5vfcMALZSihNgiq -O ./assets/example1/snowboard.npz
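The script assumes `gdown` is available (`pip install gdown`). The contents of the downloaded archive are not documented here, so this inspection sketch only lists what it finds:

```python
# List the arrays stored in the example archive without assuming any particular keys.
import numpy as np

with np.load("./assets/example1/snowboard.npz") as data:
    for key in data.files:
        print(key, data[key].shape, data[key].dtype)
```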
viz.html ADDED
@@ -0,0 +1,2115 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>3D Point Cloud Visualizer</title>
7
+ <style>
8
+ :root {
9
+ --primary: #9b59b6; /* Brighter purple for dark mode */
10
+ --primary-light: #3a2e4a;
11
+ --secondary: #a86add;
12
+ --accent: #ff6e6e;
13
+ --bg: #1a1a1a;
14
+ --surface: #2c2c2c;
15
+ --text: #e0e0e0;
16
+ --text-secondary: #a0a0a0;
17
+ --border: #444444;
18
+ --shadow: rgba(0, 0, 0, 0.2);
19
+ --shadow-hover: rgba(0, 0, 0, 0.3);
20
+
21
+ --space-sm: 16px;
22
+ --space-md: 24px;
23
+ --space-lg: 32px;
24
+ }
25
+
26
+ body {
27
+ margin: 0;
28
+ overflow: hidden;
29
+ background: var(--bg);
30
+ color: var(--text);
31
+ font-family: 'Inter', sans-serif;
32
+ -webkit-font-smoothing: antialiased;
33
+ }
34
+
35
+ #canvas-container {
36
+ position: absolute;
37
+ width: 100%;
38
+ height: 100%;
39
+ }
40
+
41
+ #ui-container {
42
+ position: absolute;
43
+ top: 0;
44
+ left: 0;
45
+ width: 100%;
46
+ height: 100%;
47
+ pointer-events: none;
48
+ z-index: 10;
49
+ }
50
+
51
+ #status-bar {
52
+ position: absolute;
53
+ top: 16px;
54
+ left: 16px;
55
+ background: rgba(30, 30, 30, 0.9);
56
+ padding: 8px 16px;
57
+ border-radius: 8px;
58
+ pointer-events: auto;
59
+ box-shadow: 0 4px 6px var(--shadow);
60
+ backdrop-filter: blur(4px);
61
+ border: 1px solid var(--border);
62
+ color: var(--text);
63
+ transition: opacity 0.5s ease, transform 0.5s ease;
64
+ font-weight: 500;
65
+ }
66
+
67
+ #status-bar.hidden {
68
+ opacity: 0;
69
+ transform: translateY(-20px);
70
+ pointer-events: none;
71
+ }
72
+
73
+ #control-panel {
74
+ position: absolute;
75
+ bottom: 16px;
76
+ left: 50%;
77
+ transform: translateX(-50%);
78
+ background: rgba(44, 44, 44, 0.95);
79
+ padding: 6px 8px;
80
+ border-radius: 6px;
81
+ display: flex;
82
+ gap: 8px;
83
+ align-items: center;
84
+ justify-content: space-between;
85
+ pointer-events: auto;
86
+ box-shadow: 0 4px 10px var(--shadow);
87
+ backdrop-filter: blur(4px);
88
+ border: 1px solid var(--border);
89
+ }
90
+
91
+ #timeline {
92
+ width: 150px;
93
+ height: 4px;
94
+ background: rgba(255, 255, 255, 0.1);
95
+ border-radius: 2px;
96
+ position: relative;
97
+ cursor: pointer;
98
+ }
99
+
100
+ #progress {
101
+ position: absolute;
102
+ height: 100%;
103
+ background: var(--primary);
104
+ border-radius: 2px;
105
+ width: 0%;
106
+ }
107
+
108
+ #playback-controls {
109
+ display: flex;
110
+ gap: 4px;
111
+ align-items: center;
112
+ }
113
+
114
+ button {
115
+ background: rgba(255, 255, 255, 0.08);
116
+ border: 1px solid var(--border);
117
+ color: var(--text);
118
+ padding: 4px 6px;
119
+ border-radius: 3px;
120
+ cursor: pointer;
121
+ display: flex;
122
+ align-items: center;
123
+ justify-content: center;
124
+ transition: background 0.2s, transform 0.2s;
125
+ font-family: 'Inter', sans-serif;
126
+ font-weight: 500;
127
+ font-size: 6px;
128
+ }
129
+
130
+ button:hover {
131
+ background: rgba(255, 255, 255, 0.15);
132
+ transform: translateY(-1px);
133
+ }
134
+
135
+ button.active {
136
+ background: var(--primary);
137
+ color: white;
138
+ box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
139
+ }
140
+
141
+ select, input {
142
+ background: rgba(255, 255, 255, 0.08);
143
+ border: 1px solid var(--border);
144
+ color: var(--text);
145
+ padding: 4px 6px;
146
+ border-radius: 3px;
147
+ cursor: pointer;
148
+ font-family: 'Inter', sans-serif;
149
+ font-size: 6px;
150
+ }
151
+
152
+ .icon {
153
+ width: 10px;
154
+ height: 10px;
155
+ fill: currentColor;
156
+ }
157
+
158
+ .tooltip {
159
+ position: absolute;
160
+ bottom: 100%;
161
+ left: 50%;
162
+ transform: translateX(-50%);
163
+ background: var(--surface);
164
+ color: var(--text);
165
+ padding: 3px 6px;
166
+ border-radius: 3px;
167
+ font-size: 7px;
168
+ white-space: nowrap;
169
+ margin-bottom: 4px;
170
+ opacity: 0;
171
+ transition: opacity 0.2s;
172
+ pointer-events: none;
173
+ box-shadow: 0 2px 4px var(--shadow);
174
+ border: 1px solid var(--border);
175
+ }
176
+
177
+ button:hover .tooltip {
178
+ opacity: 1;
179
+ }
180
+
181
+ #settings-panel {
182
+ position: absolute;
183
+ top: 16px;
184
+ right: 16px;
185
+ background: rgba(44, 44, 44, 0.98);
186
+ padding: 10px;
187
+ border-radius: 6px;
188
+ width: 195px;
189
+ max-height: calc(100vh - 40px);
190
+ overflow-y: auto;
191
+ pointer-events: auto;
192
+ box-shadow: 0 4px 15px var(--shadow);
193
+ backdrop-filter: blur(4px);
194
+ border: 1px solid var(--border);
195
+ display: block;
196
+ opacity: 1;
197
+ scrollbar-width: thin;
198
+ scrollbar-color: var(--primary-light) transparent;
199
+ transition: transform 0.35s ease-in-out, opacity 0.3s ease-in-out;
200
+ }
201
+
202
+ #settings-panel.is-hidden {
203
+ transform: translateX(calc(100% + 20px));
204
+ opacity: 0;
205
+ pointer-events: none;
206
+ }
207
+
208
+ #settings-panel::-webkit-scrollbar {
209
+ width: 3px;
210
+ }
211
+
212
+ #settings-panel::-webkit-scrollbar-track {
213
+ background: transparent;
214
+ }
215
+
216
+ #settings-panel::-webkit-scrollbar-thumb {
217
+ background-color: var(--primary-light);
218
+ border-radius: 3px;
219
+ }
220
+
221
+ @media (max-height: 700px) {
222
+ #settings-panel {
223
+ max-height: calc(100vh - 40px);
224
+ }
225
+ }
226
+
227
+ @media (max-width: 768px) {
228
+ #control-panel {
229
+ width: 90%;
230
+ flex-wrap: wrap;
231
+ justify-content: center;
232
+ }
233
+
234
+ #timeline {
235
+ width: 100%;
236
+ order: 3;
237
+ margin-top: 10px;
238
+ }
239
+
240
+ #settings-panel {
241
+ width: 140px;
242
+ right: 10px;
243
+ top: 10px;
244
+ max-height: calc(100vh - 20px);
245
+ }
246
+ }
247
+
248
+ .settings-group {
249
+ margin-bottom: 8px;
250
+ }
251
+
252
+ .settings-group h3 {
253
+ margin: 0 0 6px 0;
254
+ font-size: 10px;
255
+ font-weight: 500;
256
+ color: var(--text-secondary);
257
+ }
258
+
259
+ .slider-container {
260
+ display: flex;
261
+ align-items: center;
262
+ gap: 6px;
263
+ width: 100%;
264
+ }
265
+
266
+ .slider-container label {
267
+ min-width: 60px;
268
+ font-size: 10px;
269
+ flex-shrink: 0;
270
+ }
271
+
272
+ input[type="range"] {
273
+ flex: 1;
274
+ height: 2px;
275
+ -webkit-appearance: none;
276
+ background: rgba(255, 255, 255, 0.1);
277
+ border-radius: 1px;
278
+ min-width: 0;
279
+ }
280
+
281
+ input[type="range"]::-webkit-slider-thumb {
282
+ -webkit-appearance: none;
283
+ width: 8px;
284
+ height: 8px;
285
+ border-radius: 50%;
286
+ background: var(--primary);
287
+ cursor: pointer;
288
+ }
289
+
290
+ .toggle-switch {
291
+ position: relative;
292
+ display: inline-block;
293
+ width: 20px;
294
+ height: 10px;
295
+ }
296
+
297
+ .toggle-switch input {
298
+ opacity: 0;
299
+ width: 0;
300
+ height: 0;
301
+ }
302
+
303
+ .toggle-slider {
304
+ position: absolute;
305
+ cursor: pointer;
306
+ top: 0;
307
+ left: 0;
308
+ right: 0;
309
+ bottom: 0;
310
+ background: rgba(255, 255, 255, 0.1);
311
+ transition: .4s;
312
+ border-radius: 10px;
313
+ }
314
+
315
+ .toggle-slider:before {
316
+ position: absolute;
317
+ content: "";
318
+ height: 8px;
319
+ width: 8px;
320
+ left: 1px;
321
+ bottom: 1px;
322
+ background: var(--surface);
323
+ border: 1px solid var(--border);
324
+ transition: .4s;
325
+ border-radius: 50%;
326
+ }
327
+
328
+ input:checked + .toggle-slider {
329
+ background: var(--primary);
330
+ }
331
+
332
+ input:checked + .toggle-slider:before {
333
+ transform: translateX(10px);
334
+ }
335
+
336
+ .checkbox-container {
337
+ display: flex;
338
+ align-items: center;
339
+ gap: 4px;
340
+ margin-bottom: 4px;
341
+ }
342
+
343
+ .checkbox-container label {
344
+ font-size: 10px;
345
+ cursor: pointer;
346
+ }
347
+
348
+ #loading-overlay {
349
+ position: absolute;
350
+ top: 0;
351
+ left: 0;
352
+ width: 100%;
353
+ height: 100%;
354
+ background: var(--bg);
355
+ display: flex;
356
+ flex-direction: column;
357
+ align-items: center;
358
+ justify-content: center;
359
+ z-index: 100;
360
+ transition: opacity 0.5s;
361
+ }
362
+
363
+ #loading-overlay.fade-out {
364
+ opacity: 0;
365
+ pointer-events: none;
366
+ }
367
+
368
+ .spinner {
369
+ width: 50px;
370
+ height: 50px;
371
+ border: 5px solid rgba(155, 89, 182, 0.2);
372
+ border-radius: 50%;
373
+ border-top-color: var(--primary);
374
+ animation: spin 1s ease-in-out infinite;
375
+ margin-bottom: 16px;
376
+ }
377
+
378
+ @keyframes spin {
379
+ to { transform: rotate(360deg); }
380
+ }
381
+
382
+ #loading-text {
383
+ margin-top: 16px;
384
+ font-size: 18px;
385
+ color: var(--text);
386
+ font-weight: 500;
387
+ }
388
+
389
+ #frame-counter {
390
+ color: var(--text-secondary);
391
+ font-size: 7px;
392
+ font-weight: 500;
393
+ min-width: 60px;
394
+ text-align: center;
395
+ padding: 0 4px;
396
+ }
397
+
398
+ .control-btn {
399
+ background: rgba(255, 255, 255, 0.08);
400
+ border: 1px solid var(--border);
401
+ padding: 4px 6px;
402
+ border-radius: 3px;
403
+ cursor: pointer;
404
+ display: flex;
405
+ align-items: center;
406
+ justify-content: center;
407
+ transition: all 0.2s ease;
408
+ font-size: 6px;
409
+ }
410
+
411
+ .control-btn:hover {
412
+ background: rgba(255, 255, 255, 0.15);
413
+ transform: translateY(-1px);
414
+ }
415
+
416
+ .control-btn.active {
417
+ background: var(--primary);
418
+ color: white;
419
+ }
420
+
421
+ .control-btn.active:hover {
422
+ background: var(--primary);
423
+ box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
424
+ }
425
+
426
+ #settings-toggle-btn {
427
+ position: relative;
428
+ border-radius: 6px;
429
+ z-index: 20;
430
+ }
431
+
432
+ #settings-toggle-btn.active {
433
+ background: var(--primary);
434
+ color: white;
435
+ }
436
+
437
+ #status-bar,
438
+ #control-panel,
439
+ #settings-panel,
440
+ button,
441
+ input,
442
+ select,
443
+ .toggle-switch {
444
+ pointer-events: auto;
445
+ }
446
+
447
+ h2 {
448
+ font-size: 0.9rem;
449
+ font-weight: 600;
450
+ margin-top: 0;
451
+ margin-bottom: 12px;
452
+ color: var(--primary);
453
+ cursor: move;
454
+ user-select: none;
455
+ display: flex;
456
+ align-items: center;
457
+ }
458
+
459
+ .drag-handle {
460
+ font-size: 10px;
461
+ margin-right: 4px;
462
+ opacity: 0.6;
463
+ }
464
+
465
+ h2:hover .drag-handle {
466
+ opacity: 1;
467
+ }
468
+
469
+ .loading-subtitle {
470
+ font-size: 7px;
471
+ color: var(--text-secondary);
472
+ margin-top: 4px;
473
+ }
474
+
475
+ #reset-view-btn {
476
+ background: var(--primary-light);
477
+ color: var(--primary);
478
+ border: 1px solid rgba(155, 89, 182, 0.2);
479
+ font-weight: 600;
480
+ font-size: 9px;
481
+ padding: 4px 6px;
482
+ transition: all 0.2s;
483
+ }
484
+
485
+ #reset-view-btn:hover {
486
+ background: var(--primary);
487
+ color: white;
488
+ transform: translateY(-2px);
489
+ box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
490
+ }
491
+
492
+ #show-settings-btn {
493
+ position: absolute;
494
+ top: 16px;
495
+ right: 16px;
496
+ z-index: 15;
497
+ display: none;
498
+ }
499
+
500
+ #settings-panel.visible {
501
+ display: block;
502
+ opacity: 1;
503
+ animation: slideIn 0.3s ease forwards;
504
+ }
505
+
506
+ @keyframes slideIn {
507
+ from {
508
+ transform: translateY(20px);
509
+ opacity: 0;
510
+ }
511
+ to {
512
+ transform: translateY(0);
513
+ opacity: 1;
514
+ }
515
+ }
516
+
517
+ .dragging {
518
+ opacity: 0.9;
519
+ box-shadow: 0 8px 20px rgba(0, 0, 0, 0.15) !important;
520
+ transition: none !important;
521
+ }
522
+
523
+ /* Tooltip for draggable element */
524
+ .tooltip-drag {
525
+ position: absolute;
526
+ left: 50%;
527
+ transform: translateX(-50%);
528
+ background: var(--primary);
529
+ color: white;
530
+ font-size: 9px;
531
+ padding: 2px 4px;
532
+ border-radius: 2px;
533
+ opacity: 0;
534
+ pointer-events: none;
535
+ transition: opacity 0.3s;
536
+ white-space: nowrap;
537
+ bottom: 100%;
538
+ margin-bottom: 4px;
539
+ }
540
+
541
+ h2:hover .tooltip-drag {
542
+ opacity: 1;
543
+ }
544
+
545
+ .btn-group {
546
+ display: flex;
547
+ margin-top: 8px;
548
+ }
549
+
550
+ #reset-settings-btn {
551
+ background: var(--primary-light);
552
+ color: var(--primary);
553
+ border: 1px solid rgba(155, 89, 182, 0.2);
554
+ font-weight: 600;
555
+ font-size: 9px;
556
+ padding: 4px 6px;
557
+ transition: all 0.2s;
558
+ }
559
+
560
+ #reset-settings-btn:hover {
561
+ background: var(--primary);
562
+ color: white;
563
+ transform: translateY(-2px);
564
+ box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
565
+ }
566
+ </style>
567
+ </head>
568
+ <body>
569
+ <link rel="preconnect" href="https://fonts.googleapis.com">
570
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
571
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
572
+
573
+ <div id="canvas-container"></div>
574
+
575
+ <div id="ui-container">
576
+ <div id="status-bar">Initializing...</div>
577
+
578
+ <div id="control-panel">
579
+ <button id="play-pause-btn" class="control-btn">
580
+ <svg class="icon" viewBox="0 0 24 24">
581
+ <path id="play-icon" d="M8 5v14l11-7z"/>
582
+ <path id="pause-icon" d="M6 19h4V5H6v14zm8-14v14h4V5h-4z" style="display: none;"/>
583
+ </svg>
584
+ <span class="tooltip">Play/Pause</span>
585
+ </button>
586
+
587
+ <div id="timeline">
588
+ <div id="progress"></div>
589
+ </div>
590
+
591
+ <div id="frame-counter">Frame 0 / 0</div>
592
+
593
+ <div id="playback-controls">
594
+ <button id="speed-btn" class="control-btn">1x</button>
595
+ </div>
596
+ </div>
597
+
598
+ <div id="settings-panel">
599
+ <h2>
600
+ <span class="drag-handle">☰</span>
601
+ Visualization Settings
602
+ <button id="hide-settings-btn" class="control-btn" style="margin-left: auto; padding: 2px;" title="Hide Panel">
603
+ <svg class="icon" viewBox="0 0 24 24" style="width: 9px; height: 9px;">
604
+ <path d="M14.59 7.41L18.17 11H4v2h14.17l-3.58 3.59L16 18l6-6-6-6-1.41 1.41z"/>
605
+ </svg>
606
+ </button>
607
+ </h2>
608
+
609
+ <div class="settings-group">
610
+ <h3>Point Cloud</h3>
611
+ <div class="slider-container">
612
+ <label for="point-size">Size</label>
613
+ <input type="range" id="point-size" min="0.005" max="0.1" step="0.005" value="0.03">
614
+ </div>
615
+ <div class="slider-container">
616
+ <label for="point-opacity">Opacity</label>
617
+ <input type="range" id="point-opacity" min="0.1" max="1" step="0.05" value="1">
618
+ </div>
619
+ <div class="slider-container">
620
+ <label for="max-depth">Max Depth</label>
621
+ <input type="range" id="max-depth" min="0.1" max="10" step="0.2" value="100">
622
+ </div>
623
+ </div>
624
+
625
+ <div class="settings-group">
626
+ <h3>Trajectory</h3>
627
+ <div class="checkbox-container">
628
+ <label class="toggle-switch">
629
+ <input type="checkbox" id="show-trajectory" checked>
630
+ <span class="toggle-slider"></span>
631
+ </label>
632
+ <label for="show-trajectory">Show Trajectory</label>
633
+ </div>
634
+ <div class="checkbox-container">
635
+ <label class="toggle-switch">
636
+ <input type="checkbox" id="enable-rich-trail">
637
+ <span class="toggle-slider"></span>
638
+ </label>
639
+ <label for="enable-rich-trail">Visual-Rich Trail</label>
640
+ </div>
641
+ <div class="slider-container">
642
+ <label for="trajectory-line-width">Line Width</label>
643
+ <input type="range" id="trajectory-line-width" min="0.5" max="5" step="0.5" value="1.5">
644
+ </div>
645
+ <div class="slider-container">
646
+ <label for="trajectory-ball-size">Ball Size</label>
647
+ <input type="range" id="trajectory-ball-size" min="0.005" max="0.05" step="0.001" value="0.02">
648
+ </div>
649
+ <div class="slider-container">
650
+ <label for="trajectory-history">History Frames</label>
651
+ <input type="range" id="trajectory-history" min="1" max="500" step="1" value="30">
652
+ </div>
653
+ <div class="slider-container" id="tail-opacity-container" style="display: none;">
654
+ <label for="trajectory-fade">Tail Opacity</label>
655
+ <input type="range" id="trajectory-fade" min="0" max="1" step="0.05" value="0.0">
656
+ </div>
657
+ </div>
658
+
659
+ <div class="settings-group">
660
+ <h3>Camera</h3>
661
+ <div class="checkbox-container">
662
+ <label class="toggle-switch">
663
+ <input type="checkbox" id="show-camera-frustum" checked>
664
+ <span class="toggle-slider"></span>
665
+ </label>
666
+ <label for="show-camera-frustum">Show Camera Frustum</label>
667
+ </div>
668
+ <div class="slider-container">
669
+ <label for="frustum-size">Size</label>
670
+ <input type="range" id="frustum-size" min="0.02" max="0.5" step="0.01" value="0.2">
671
+ </div>
672
+ </div>
673
+
674
+ <div class="settings-group">
675
+ <h3>Keep History</h3>
676
+ <div class="checkbox-container">
677
+ <label class="toggle-switch">
678
+ <input type="checkbox" id="enable-keep-history">
679
+ <span class="toggle-slider"></span>
680
+ </label>
681
+ <label for="enable-keep-history">Enable Keep History</label>
682
+ </div>
683
+ <div class="slider-container">
684
+ <label for="history-stride">Stride</label>
685
+ <select id="history-stride">
686
+ <option value="1">1</option>
687
+ <option value="2">2</option>
688
+ <option value="5" selected>5</option>
689
+ <option value="10">10</option>
690
+ <option value="20">20</option>
691
+ </select>
692
+ </div>
693
+ </div>
694
+
695
+ <div class="settings-group">
696
+ <h3>Background</h3>
697
+ <div class="checkbox-container">
698
+ <label class="toggle-switch">
699
+ <input type="checkbox" id="white-background">
700
+ <span class="toggle-slider"></span>
701
+ </label>
702
+ <label for="white-background">White Background</label>
703
+ </div>
704
+ </div>
705
+
706
+ <div class="settings-group">
707
+ <div class="btn-group">
708
+ <button id="reset-view-btn" style="flex: 1; margin-right: 5px;">Reset View</button>
709
+ <button id="reset-settings-btn" style="flex: 1; margin-left: 5px;">Reset Settings</button>
710
+ </div>
711
+ </div>
712
+ </div>
713
+
714
+ <button id="show-settings-btn" class="control-btn" title="Show Settings">
715
+ <svg class="icon" viewBox="0 0 24 24">
716
+ <path d="M19.14,12.94c0.04-0.3,0.06-0.61,0.06-0.94c0-0.32-0.02-0.64-0.07-0.94l2.03-1.58c0.18-0.14,0.23-0.41,0.12-0.61 l-1.92-3.32c-0.12-0.22-0.37-0.29-0.59-0.22l-2.39,0.96c-0.5-0.38-1.03-0.7-1.62-0.94L14.4,2.81c-0.04-0.24-0.24-0.41-0.48-0.41 h-3.84c-0.24,0-0.43,0.17-0.47,0.41L9.25,5.35C8.66,5.59,8.12,5.92,7.63,6.29L5.24,5.33c-0.22-0.08-0.47,0-0.59,0.22L2.74,8.87 C2.62,9.08,2.66,9.34,2.86,9.48l2.03,1.58C4.84,11.36,4.8,11.69,4.8,12s0.02,0.64,0.07,0.94l-2.03,1.58 c-0.18,0.14-0.23,0.41-0.12,0.61l1.92,3.32c0.12,0.22,0.37,0.29,0.59,0.22l2.39-0.96c0.5,0.38,1.03,0.7,1.62,0.94l0.36,2.54 c0.04,0.24,0.24,0.41,0.48,0.41h3.84c0.24,0,0.44-0.17,0.47-0.41l0.36-2.54c0.59-0.24,1.13-0.56,1.62-0.94l2.39,0.96 c0.22,0.08,0.47,0,0.59-0.22l1.92-3.32c0.12-0.22,0.07-0.47-0.12-0.61L19.14,12.94z M12,15.6c-1.98,0-3.6-1.62-3.6-3.6 s1.62-3.6,3.6-3.6s3.6,1.62,3.6,3.6S13.98,15.6,12,15.6z"/>
717
+ </svg>
718
+ </button>
719
+ </div>
720
+
721
+ <div id="loading-overlay">
722
+ <!-- <div class="spinner"></div> -->
723
+ <div id="loading-text"></div>
724
+ <div class="loading-subtitle" style="font-size: medium;">Interactive Viewer of 3D Tracking</div>
725
+ </div>
726
+
727
+ <!-- Libraries -->
728
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/pako/2.1.0/pako.min.js"></script>
729
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/build/three.min.js"></script>
730
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/controls/OrbitControls.js"></script>
731
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/build/dat.gui.min.js"></script>
732
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineSegmentsGeometry.js"></script>
733
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineGeometry.js"></script>
734
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineMaterial.js"></script>
735
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineSegments2.js"></script>
736
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/Line2.js"></script>
737
+
738
+ <script>
739
+ class PointCloudVisualizer {
740
+ constructor() {
741
+ this.data = null;
742
+ this.config = {};
743
+ this.currentFrame = 0;
744
+ this.isPlaying = false;
745
+ this.playbackSpeed = 1;
746
+ this.lastFrameTime = 0;
747
+ this.defaultSettings = null;
748
+
749
+ this.ui = {
750
+ statusBar: document.getElementById('status-bar'),
751
+ playPauseBtn: document.getElementById('play-pause-btn'),
752
+ speedBtn: document.getElementById('speed-btn'),
753
+ timeline: document.getElementById('timeline'),
754
+ progress: document.getElementById('progress'),
755
+ settingsPanel: document.getElementById('settings-panel'),
756
+ loadingOverlay: document.getElementById('loading-overlay'),
757
+ loadingText: document.getElementById('loading-text'),
758
+ settingsToggleBtn: document.getElementById('settings-toggle-btn'),
759
+ frameCounter: document.getElementById('frame-counter'),
760
+ pointSize: document.getElementById('point-size'),
761
+ pointOpacity: document.getElementById('point-opacity'),
762
+ maxDepth: document.getElementById('max-depth'),
763
+ showTrajectory: document.getElementById('show-trajectory'),
764
+ enableRichTrail: document.getElementById('enable-rich-trail'),
765
+ trajectoryLineWidth: document.getElementById('trajectory-line-width'),
766
+ trajectoryBallSize: document.getElementById('trajectory-ball-size'),
767
+ trajectoryHistory: document.getElementById('trajectory-history'),
768
+ trajectoryFade: document.getElementById('trajectory-fade'),
769
+ tailOpacityContainer: document.getElementById('tail-opacity-container'),
770
+ resetViewBtn: document.getElementById('reset-view-btn'),
771
+ showCameraFrustum: document.getElementById('show-camera-frustum'),
772
+ frustumSize: document.getElementById('frustum-size'),
773
+ hideSettingsBtn: document.getElementById('hide-settings-btn'),
774
+ showSettingsBtn: document.getElementById('show-settings-btn'),
775
+ enableKeepHistory: document.getElementById('enable-keep-history'),
776
+ historyStride: document.getElementById('history-stride'),
777
+ whiteBackground: document.getElementById('white-background')
778
+ };
779
+
780
+ this.scene = null;
781
+ this.camera = null;
782
+ this.renderer = null;
783
+ this.controls = null;
784
+ this.pointCloud = null;
785
+ this.trajectories = [];
786
+ this.cameraFrustum = null;
787
+
788
+ // Keep History functionality
789
+ this.historyPointClouds = [];
790
+ this.historyTrajectories = [];
791
+ this.historyFrames = [];
792
+ this.maxHistoryFrames = 20;
793
+
794
+ this.initThreeJS();
795
+ this.loadDefaultSettings().then(() => {
796
+ this.initEventListeners();
797
+ this.loadData();
798
+ });
799
+ }
800
+
801
+ async loadDefaultSettings() {
802
+ try {
803
+ const urlParams = new URLSearchParams(window.location.search);
804
+ const dataPath = urlParams.get('data') || '';
805
+
806
+ const defaultSettings = {
807
+ pointSize: 0.03,
808
+ pointOpacity: 1.0,
809
+ showTrajectory: true,
810
+ trajectoryLineWidth: 2.5,
811
+ trajectoryBallSize: 0.015,
812
+ trajectoryHistory: 0,
813
+ showCameraFrustum: true,
814
+ frustumSize: 0.2
815
+ };
816
+
817
+ if (!dataPath) {
818
+ this.defaultSettings = defaultSettings;
819
+ this.applyDefaultSettings();
820
+ return;
821
+ }
822
+
823
+ // Try to extract dataset and videoId from the data path
824
+ // Expected format: demos/datasetname/videoid.bin
825
+ const pathParts = dataPath.split('/');
826
+ if (pathParts.length < 3) {
827
+ this.defaultSettings = defaultSettings;
828
+ this.applyDefaultSettings();
829
+ return;
830
+ }
831
+
832
+ const datasetName = pathParts[pathParts.length - 2];
833
+ let videoId = pathParts[pathParts.length - 1].replace('.bin', '');
834
+
835
+ // Load settings from data.json
836
+ const response = await fetch('./data.json');
837
+ if (!response.ok) {
838
+ this.defaultSettings = defaultSettings;
839
+ this.applyDefaultSettings();
840
+ return;
841
+ }
842
+
843
+ const settingsData = await response.json();
844
+
845
+ // Check if this dataset and video exist
846
+ if (settingsData[datasetName] && settingsData[datasetName][videoId]) {
847
+ this.defaultSettings = settingsData[datasetName][videoId];
848
+ } else {
849
+ this.defaultSettings = defaultSettings;
850
+ }
851
+
852
+ this.applyDefaultSettings();
853
+ } catch (error) {
854
+ console.error("Error loading default settings:", error);
855
+
856
+ this.defaultSettings = {
857
+ pointSize: 0.03,
858
+ pointOpacity: 1.0,
859
+ showTrajectory: true,
860
+ trajectoryLineWidth: 2.5,
861
+ trajectoryBallSize: 0.015,
862
+ trajectoryHistory: 0,
863
+ showCameraFrustum: true,
864
+ frustumSize: 0.2
865
+ };
866
+
867
+ this.applyDefaultSettings();
868
+ }
869
+ }
870
+
871
+ applyDefaultSettings() {
872
+ if (!this.defaultSettings) return;
873
+
874
+ if (this.ui.pointSize) {
875
+ this.ui.pointSize.value = this.defaultSettings.pointSize;
876
+ }
877
+
878
+ if (this.ui.pointOpacity) {
879
+ this.ui.pointOpacity.value = this.defaultSettings.pointOpacity;
880
+ }
881
+
882
+ if (this.ui.maxDepth) {
883
+ this.ui.maxDepth.value = this.defaultSettings.maxDepth || 100.0;
884
+ }
885
+
886
+ if (this.ui.showTrajectory) {
887
+ this.ui.showTrajectory.checked = this.defaultSettings.showTrajectory;
888
+ }
889
+
890
+ if (this.ui.trajectoryLineWidth) {
891
+ this.ui.trajectoryLineWidth.value = this.defaultSettings.trajectoryLineWidth;
892
+ }
893
+
894
+ if (this.ui.trajectoryBallSize) {
895
+ this.ui.trajectoryBallSize.value = this.defaultSettings.trajectoryBallSize;
896
+ }
897
+
898
+ if (this.ui.trajectoryHistory) {
899
+ this.ui.trajectoryHistory.value = this.defaultSettings.trajectoryHistory;
900
+ }
901
+
902
+ if (this.ui.showCameraFrustum) {
903
+ this.ui.showCameraFrustum.checked = this.defaultSettings.showCameraFrustum;
904
+ }
905
+
906
+ if (this.ui.frustumSize) {
907
+ this.ui.frustumSize.value = this.defaultSettings.frustumSize;
908
+ }
909
+ }
910
+
911
+ initThreeJS() {
912
+ this.scene = new THREE.Scene();
913
+ this.scene.background = new THREE.Color(0x1a1a1a);
914
+
915
+ this.camera = new THREE.PerspectiveCamera(60, window.innerWidth / window.innerHeight, 0.1, 10000);
916
+ this.camera.position.set(0, 0, 0);
917
+
918
+ this.renderer = new THREE.WebGLRenderer({ antialias: true });
919
+ this.renderer.setPixelRatio(window.devicePixelRatio);
920
+ this.renderer.setSize(window.innerWidth, window.innerHeight);
921
+ document.getElementById('canvas-container').appendChild(this.renderer.domElement);
922
+
923
+ this.controls = new THREE.OrbitControls(this.camera, this.renderer.domElement);
924
+ this.controls.enableDamping = true;
925
+ this.controls.dampingFactor = 0.05;
926
+ this.controls.target.set(0, 0, 0);
927
+ this.controls.minDistance = 0.1;
928
+ this.controls.maxDistance = 1000;
929
+ this.controls.update();
930
+
931
+ const ambientLight = new THREE.AmbientLight(0xffffff, 0.5);
932
+ this.scene.add(ambientLight);
933
+
934
+ const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
935
+ directionalLight.position.set(1, 1, 1);
936
+ this.scene.add(directionalLight);
937
+ }
938
+
939
+ initEventListeners() {
940
+ window.addEventListener('resize', () => this.onWindowResize());
941
+
942
+ this.ui.playPauseBtn.addEventListener('click', () => this.togglePlayback());
943
+
944
+ this.ui.timeline.addEventListener('click', (e) => {
945
+ const rect = this.ui.timeline.getBoundingClientRect();
946
+ const pos = (e.clientX - rect.left) / rect.width;
947
+ this.seekTo(pos);
948
+ });
949
+
950
+ this.ui.speedBtn.addEventListener('click', () => this.cyclePlaybackSpeed());
951
+
952
+ this.ui.pointSize.addEventListener('input', () => this.updatePointCloudSettings());
953
+ this.ui.pointOpacity.addEventListener('input', () => this.updatePointCloudSettings());
954
+ this.ui.maxDepth.addEventListener('input', () => this.updatePointCloudSettings());
955
+ this.ui.showTrajectory.addEventListener('change', () => {
956
+ this.trajectories.forEach(trajectory => {
957
+ trajectory.visible = this.ui.showTrajectory.checked;
958
+ });
959
+ });
960
+
961
+ this.ui.enableRichTrail.addEventListener('change', () => {
962
+ this.ui.tailOpacityContainer.style.display = this.ui.enableRichTrail.checked ? 'flex' : 'none';
963
+ this.updateTrajectories(this.currentFrame);
964
+ });
965
+
966
+ this.ui.trajectoryLineWidth.addEventListener('input', () => this.updateTrajectorySettings());
967
+ this.ui.trajectoryBallSize.addEventListener('input', () => this.updateTrajectorySettings());
968
+ this.ui.trajectoryHistory.addEventListener('input', () => {
969
+ this.updateTrajectories(this.currentFrame);
970
+ });
971
+ this.ui.trajectoryFade.addEventListener('input', () => {
972
+ this.updateTrajectories(this.currentFrame);
973
+ });
974
+
975
+ this.ui.resetViewBtn.addEventListener('click', () => this.resetView());
976
+
977
+ const resetSettingsBtn = document.getElementById('reset-settings-btn');
978
+ if (resetSettingsBtn) {
979
+ resetSettingsBtn.addEventListener('click', () => this.resetSettings());
980
+ }
981
+
982
+ document.addEventListener('keydown', (e) => {
983
+ if (e.key === 'Escape' && this.ui.settingsPanel.classList.contains('visible')) {
984
+ this.ui.settingsPanel.classList.remove('visible');
985
+ this.ui.settingsToggleBtn.classList.remove('active');
986
+ }
987
+ });
988
+
989
+ if (this.ui.settingsToggleBtn) {
990
+ this.ui.settingsToggleBtn.addEventListener('click', () => {
991
+ const isVisible = this.ui.settingsPanel.classList.toggle('visible');
992
+ this.ui.settingsToggleBtn.classList.toggle('active', isVisible);
993
+
994
+ if (isVisible) {
995
+ const panelRect = this.ui.settingsPanel.getBoundingClientRect();
996
+ const viewportHeight = window.innerHeight;
997
+
998
+ if (panelRect.bottom > viewportHeight) {
999
+ this.ui.settingsPanel.style.bottom = 'auto';
1000
+ this.ui.settingsPanel.style.top = '80px';
1001
+ }
1002
+ }
1003
+ });
1004
+ }
1005
+
1006
+ if (this.ui.frustumSize) {
1007
+ this.ui.frustumSize.addEventListener('input', () => this.updateFrustumDimensions());
1008
+ }
1009
+
1010
+ if (this.ui.hideSettingsBtn && this.ui.showSettingsBtn && this.ui.settingsPanel) {
1011
+ this.ui.hideSettingsBtn.addEventListener('click', () => {
1012
+ this.ui.settingsPanel.classList.add('is-hidden');
1013
+ this.ui.showSettingsBtn.style.display = 'flex';
1014
+ });
1015
+
1016
+ this.ui.showSettingsBtn.addEventListener('click', () => {
1017
+ this.ui.settingsPanel.classList.remove('is-hidden');
1018
+ this.ui.showSettingsBtn.style.display = 'none';
1019
+ });
1020
+ }
1021
+
1022
+ // Keep History event listeners
1023
+ if (this.ui.enableKeepHistory) {
1024
+ this.ui.enableKeepHistory.addEventListener('change', () => {
1025
+ if (!this.ui.enableKeepHistory.checked) {
1026
+ this.clearHistory();
1027
+ }
1028
+ });
1029
+ }
1030
+
1031
+ if (this.ui.historyStride) {
1032
+ this.ui.historyStride.addEventListener('change', () => {
1033
+ this.clearHistory();
1034
+ });
1035
+ }
1036
+
1037
+ // Background toggle event listener
1038
+ if (this.ui.whiteBackground) {
1039
+ this.ui.whiteBackground.addEventListener('change', () => {
1040
+ this.toggleBackground();
1041
+ });
1042
+ }
1043
+ }
1044
+
1045
+ makeElementDraggable(element) {
1046
+ let pos1 = 0, pos2 = 0, pos3 = 0, pos4 = 0;
1047
+
1048
+ const dragHandle = element.querySelector('h2');
1049
+
1050
+ if (dragHandle) {
1051
+ dragHandle.onmousedown = dragMouseDown;
1052
+ dragHandle.title = "Drag to move panel";
1053
+ } else {
1054
+ element.onmousedown = dragMouseDown;
1055
+ }
1056
+
1057
+ function dragMouseDown(e) {
1058
+ e = e || window.event;
1059
+ e.preventDefault();
1060
+ pos3 = e.clientX;
1061
+ pos4 = e.clientY;
1062
+ document.onmouseup = closeDragElement;
1063
+ document.onmousemove = elementDrag;
1064
+
1065
+ element.classList.add('dragging');
1066
+ }
1067
+
1068
+ function elementDrag(e) {
1069
+ e = e || window.event;
1070
+ e.preventDefault();
1071
+ pos1 = pos3 - e.clientX;
1072
+ pos2 = pos4 - e.clientY;
1073
+ pos3 = e.clientX;
1074
+ pos4 = e.clientY;
1075
+
1076
+ const newTop = element.offsetTop - pos2;
1077
+ const newLeft = element.offsetLeft - pos1;
1078
+
1079
+ const viewportWidth = window.innerWidth;
1080
+ const viewportHeight = window.innerHeight;
1081
+
1082
+ const panelRect = element.getBoundingClientRect();
1083
+
1084
+ const maxTop = viewportHeight - 50;
1085
+ const maxLeft = viewportWidth - 50;
1086
+
1087
+ element.style.top = Math.min(Math.max(newTop, 0), maxTop) + "px";
1088
+ element.style.left = Math.min(Math.max(newLeft, 0), maxLeft) + "px";
1089
+
1090
+ // Remove bottom/right settings when dragging
1091
+ element.style.bottom = 'auto';
1092
+ element.style.right = 'auto';
1093
+ }
1094
+
1095
+ function closeDragElement() {
1096
+ document.onmouseup = null;
1097
+ document.onmousemove = null;
1098
+
1099
+ element.classList.remove('dragging');
1100
+ }
1101
+ }
1102
+
1103
+ async loadData() {
1104
+ try {
1105
+ // this.ui.loadingText.textContent = "Loading binary data...";
1106
+
1107
+ let arrayBuffer;
1108
+
1109
+ if (window.embeddedBase64) {
1110
+ // Base64 embedded path
1111
+ const binaryString = atob(window.embeddedBase64);
1112
+ const len = binaryString.length;
1113
+ const bytes = new Uint8Array(len);
1114
+ for (let i = 0; i < len; i++) {
1115
+ bytes[i] = binaryString.charCodeAt(i);
1116
+ }
1117
+ arrayBuffer = bytes.buffer;
1118
+ } else {
1119
+ // Default fetch path (fallback)
1120
+ const urlParams = new URLSearchParams(window.location.search);
1121
+ const dataPath = urlParams.get('data') || 'data.bin';
1122
+
1123
+ const response = await fetch(dataPath);
1124
+ if (!response.ok) throw new Error(`Failed to load ${dataPath}`);
1125
+ arrayBuffer = await response.arrayBuffer();
1126
+ }
1127
+
1128
+ const dataView = new DataView(arrayBuffer);
1129
+ const headerLen = dataView.getUint32(0, true);
1130
+
1131
+ const headerText = new TextDecoder("utf-8").decode(arrayBuffer.slice(4, 4 + headerLen));
1132
+ const header = JSON.parse(headerText);
1133
+
1134
+ const compressedBlob = new Uint8Array(arrayBuffer, 4 + headerLen);
1135
+ const decompressed = pako.inflate(compressedBlob).buffer;
1136
+
1137
+ const arrays = {};
1138
+ for (const key in header) {
1139
+ if (key === "meta") continue;
1140
+
1141
+ const meta = header[key];
1142
+ const { dtype, shape, offset, length } = meta;
1143
+ const slice = decompressed.slice(offset, offset + length);
1144
+
1145
+ let typedArray;
1146
+ switch (dtype) {
1147
+ case "uint8": typedArray = new Uint8Array(slice); break;
1148
+ case "uint16": typedArray = new Uint16Array(slice); break;
1149
+ case "float32": typedArray = new Float32Array(slice); break;
1150
+ case "float64": typedArray = new Float64Array(slice); break;
1151
+ default: throw new Error(`Unknown dtype: ${dtype}`);
1152
+ }
1153
+
1154
+ arrays[key] = { data: typedArray, shape: shape };
1155
+ }
1156
+
1157
+ this.data = arrays;
1158
+ this.config = header.meta;
1159
+
1160
+ this.initCameraWithCorrectFOV();
1161
+ this.ui.loadingText.textContent = "Creating point cloud...";
1162
+
1163
+ this.initPointCloud();
1164
+ this.initTrajectories();
1165
+
1166
+ setTimeout(() => {
1167
+ this.ui.loadingOverlay.classList.add('fade-out');
1168
+ this.ui.statusBar.classList.add('hidden');
1169
+ this.startAnimation();
1170
+ }, 500);
1171
+ } catch (error) {
1172
+ console.error("Error loading data:", error);
1173
+ this.ui.statusBar.textContent = `Error: ${error.message}`;
1174
+ // this.ui.loadingText.textContent = `Error loading data: ${error.message}`;
1175
+ }
1176
+ }
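`loadData()` above fully specifies the on-disk layout it expects: a little-endian `uint32` header length, a JSON header whose entries give `dtype`/`shape`/`offset`/`length` into a zlib-compressed blob, plus a `meta` entry for viewer config. A hedged Python writer sketch follows; the array keys and `meta` fields are taken from the reader code in this file, while the exact array shapes (and any additional `meta` fields the viewer may need, e.g. for camera FOV) are assumptions:

```python
# Writer sketch for the .bin format parsed by loadData():
#   [uint32 LE header length][JSON header][zlib-compressed concatenation of raw arrays]
# Offsets/lengths index into the *decompressed* blob, matching decompressed.slice(...).
import json, struct, zlib
import numpy as np

def write_viz_bin(path, arrays, meta):
    header, blob, offset = {"meta": meta}, bytearray(), 0
    for key, arr in arrays.items():
        raw = np.ascontiguousarray(arr).tobytes()
        header[key] = {"dtype": str(arr.dtype), "shape": list(arr.shape),
                       "offset": offset, "length": len(raw)}
        blob += raw
        offset += len(raw)
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<I", len(header_bytes)))   # little-endian uint32
        f.write(header_bytes)
        f.write(zlib.compress(bytes(blob)))             # pako.inflate-compatible stream

# Tiny synthetic payload using the keys this viewer reads (shapes are assumptions).
T, H, W, N = 2, 4, 6, 3
write_viz_bin("data.bin", {
    "rgb_video":      np.zeros((T, H, W, 3), dtype=np.uint8),
    "depths_rgb":     np.zeros((T, H, W, 3), dtype=np.uint8),
    "intrinsics":     np.tile(np.eye(3, dtype=np.float32), (T, 1, 1)),
    "inv_extrinsics": np.tile(np.eye(4, dtype=np.float32), (T, 1, 1)),
    "trajectories":   np.zeros((T, N, 3), dtype=np.float32),
}, meta={"resolution": [W, H], "depthRange": [0.0, 10.0]})
```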
1177
+
1178
+ initPointCloud() {
1179
+ const numPoints = this.config.resolution[0] * this.config.resolution[1];
1180
+ const positions = new Float32Array(numPoints * 3);
1181
+ const colors = new Float32Array(numPoints * 3);
1182
+
1183
+ const geometry = new THREE.BufferGeometry();
1184
+ geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3).setUsage(THREE.DynamicDrawUsage));
1185
+ geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3).setUsage(THREE.DynamicDrawUsage));
1186
+
1187
+ const pointSize = parseFloat(this.ui.pointSize.value) || this.defaultSettings.pointSize;
1188
+ const pointOpacity = parseFloat(this.ui.pointOpacity.value) || this.defaultSettings.pointOpacity;
1189
+
1190
+ const material = new THREE.PointsMaterial({
1191
+ size: pointSize,
1192
+ vertexColors: true,
1193
+ transparent: true,
1194
+ opacity: pointOpacity,
1195
+ sizeAttenuation: true
1196
+ });
1197
+
1198
+ this.pointCloud = new THREE.Points(geometry, material);
1199
+ this.scene.add(this.pointCloud);
1200
+ }
1201
+
1202
+ initTrajectories() {
1203
+ if (!this.data.trajectories) return;
1204
+
1205
+ this.trajectories.forEach(trajectory => {
1206
+ if (trajectory.userData.lineSegments) {
1207
+ trajectory.userData.lineSegments.forEach(segment => {
1208
+ segment.geometry.dispose();
1209
+ segment.material.dispose();
1210
+ });
1211
+ }
1212
+ this.scene.remove(trajectory);
1213
+ });
1214
+ this.trajectories = [];
1215
+
1216
+ const shape = this.data.trajectories.shape;
1217
+ if (!shape || shape.length < 2) return;
1218
+
1219
+ const [totalFrames, numTrajectories] = shape;
1220
+ const palette = this.createColorPalette(numTrajectories);
1221
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1222
+ const maxHistory = 500; // Max value of the history slider, for the object pool
1223
+
1224
+ for (let i = 0; i < numTrajectories; i++) {
1225
+ const trajectoryGroup = new THREE.Group();
1226
+
1227
+ const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
1228
+ const sphereGeometry = new THREE.SphereGeometry(ballSize, 16, 16);
1229
+ const sphereMaterial = new THREE.MeshBasicMaterial({ color: palette[i], transparent: true });
1230
+ const positionMarker = new THREE.Mesh(sphereGeometry, sphereMaterial);
1231
+ trajectoryGroup.add(positionMarker);
1232
+
1233
+ // High-Performance Line (default)
1234
+ const simpleLineGeometry = new THREE.BufferGeometry();
1235
+ const simpleLinePositions = new Float32Array(maxHistory * 3);
1236
+ simpleLineGeometry.setAttribute('position', new THREE.BufferAttribute(simpleLinePositions, 3).setUsage(THREE.DynamicDrawUsage));
1237
+ const simpleLine = new THREE.Line(simpleLineGeometry, new THREE.LineBasicMaterial({ color: palette[i] }));
1238
+ simpleLine.frustumCulled = false;
1239
+ trajectoryGroup.add(simpleLine);
1240
+
1241
+ // High-Quality Line Segments (for rich trail)
1242
+ const lineSegments = [];
1243
+ const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);
1244
+
1245
+ // Create a pool of line segment objects
1246
+ for (let j = 0; j < maxHistory - 1; j++) {
1247
+ const lineGeometry = new THREE.LineGeometry();
1248
+ lineGeometry.setPositions([0, 0, 0, 0, 0, 0]);
1249
+ const lineMaterial = new THREE.LineMaterial({
1250
+ color: palette[i],
1251
+ linewidth: lineWidth,
1252
+ resolution: resolution,
1253
+ transparent: true,
1254
+ depthWrite: false, // Skip depth writes so overlapping transparent segments blend correctly
1255
+ opacity: 0
1256
+ });
1257
+ const segment = new THREE.Line2(lineGeometry, lineMaterial);
1258
+ segment.frustumCulled = false;
1259
+ segment.visible = false; // Start with all segments hidden
1260
+ trajectoryGroup.add(segment);
1261
+ lineSegments.push(segment);
1262
+ }
1263
+
1264
+ trajectoryGroup.userData = {
1265
+ marker: positionMarker,
1266
+ simpleLine: simpleLine,
1267
+ lineSegments: lineSegments,
1268
+ color: palette[i]
1269
+ };
1270
+
1271
+ this.scene.add(trajectoryGroup);
1272
+ this.trajectories.push(trajectoryGroup);
1273
+ }
1274
+
1275
+ const showTrajectory = this.ui.showTrajectory.checked;
1276
+ this.trajectories.forEach(trajectory => trajectory.visible = showTrajectory);
1277
+ }
1278
+
1279
+ createColorPalette(count) {
1280
+ const colors = [];
1281
+ const hueStep = 360 / count;
1282
+
1283
+ for (let i = 0; i < count; i++) {
1284
+ const hue = (i * hueStep) % 360;
1285
+ const color = new THREE.Color().setHSL(hue / 360, 0.8, 0.6);
1286
+ colors.push(color);
1287
+ }
1288
+
1289
+ return colors;
1290
+ }
1291
+
1292
+ updatePointCloud(frameIndex) {
1293
+ if (!this.data || !this.pointCloud) return;
1294
+
1295
+ const positions = this.pointCloud.geometry.attributes.position.array;
1296
+ const colors = this.pointCloud.geometry.attributes.color.array;
1297
+
1298
+ const rgbVideo = this.data.rgb_video;
1299
+ const depthsRgb = this.data.depths_rgb;
1300
+ const intrinsics = this.data.intrinsics;
1301
+ const invExtrinsics = this.data.inv_extrinsics;
1302
+
1303
+ const width = this.config.resolution[0];
1304
+ const height = this.config.resolution[1];
1305
+ const numPoints = width * height;
1306
+
1307
+ const K = this.get3x3Matrix(intrinsics.data, intrinsics.shape, frameIndex);
1308
+ const fx = K[0][0], fy = K[1][1], cx = K[0][2], cy = K[1][2];
1309
+
1310
+ const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
1311
+ const transform = this.getTransformElements(invExtrMat);
1312
+
1313
+ const rgbFrame = this.getFrame(rgbVideo.data, rgbVideo.shape, frameIndex);
1314
+ const depthFrame = this.getFrame(depthsRgb.data, depthsRgb.shape, frameIndex);
1315
+
1316
+ const maxDepth = parseFloat(this.ui.maxDepth.value) || 10.0;
1317
+
1318
+ let validPointCount = 0;
1319
+
1320
+ for (let i = 0; i < numPoints; i++) {
1321
+ const xPix = i % width;
1322
+ const yPix = Math.floor(i / width);
1323
+
1324
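+ // Depth is packed as a 16-bit value across the first two channels of the depth frame;
+ // rebuild it and map it back into depth units using the recorded depthRange.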
+ const d0 = depthFrame[i * 3];
1325
+ const d1 = depthFrame[i * 3 + 1];
1326
+ const depthEncoded = d0 | (d1 << 8);
1327
+ const depthValue = (depthEncoded / ((1 << 16) - 1)) *
1328
+ (this.config.depthRange[1] - this.config.depthRange[0]) +
1329
+ this.config.depthRange[0];
1330
+
1331
+ if (depthValue === 0 || depthValue > maxDepth) {
1332
+ continue;
1333
+ }
1334
+
1335
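+ // Back-project the pixel into camera space using the pinhole intrinsics (fx, fy, cx, cy).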
+ const X = ((xPix - cx) * depthValue) / fx;
1336
+ const Y = ((yPix - cy) * depthValue) / fy;
1337
+ const Z = depthValue;
1338
+
1339
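+ // Move the point into world space with the camera-to-world (inverse extrinsics) matrix.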
+ const tx = transform.m11 * X + transform.m12 * Y + transform.m13 * Z + transform.m14;
1340
+ const ty = transform.m21 * X + transform.m22 * Y + transform.m23 * Z + transform.m24;
1341
+ const tz = transform.m31 * X + transform.m32 * Y + transform.m33 * Z + transform.m34;
1342
+
1343
+ const index = validPointCount * 3;
1344
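+ // Negate Y and Z to convert from the dataset's camera convention to three.js world coordinates.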
+ positions[index] = tx;
1345
+ positions[index + 1] = -ty;
1346
+ positions[index + 2] = -tz;
1347
+
1348
+ colors[index] = rgbFrame[i * 3] / 255;
1349
+ colors[index + 1] = rgbFrame[i * 3 + 1] / 255;
1350
+ colors[index + 2] = rgbFrame[i * 3 + 2] / 255;
1351
+
1352
+ validPointCount++;
1353
+ }
1354
+
1355
+ this.pointCloud.geometry.setDrawRange(0, validPointCount);
1356
+ this.pointCloud.geometry.attributes.position.needsUpdate = true;
1357
+ this.pointCloud.geometry.attributes.color.needsUpdate = true;
1358
+ this.pointCloud.geometry.computeBoundingSphere(); // Keep the bounding sphere current so frustum culling stays correct
1359
+
1360
+ this.updateTrajectories(frameIndex);
1361
+
1362
+ // Update the "Keep history" overlays
1363
+ this.updateHistory(frameIndex);
1364
+
1365
+ const progress = (frameIndex + 1) / this.config.totalFrames;
1366
+ this.ui.progress.style.width = `${progress * 100}%`;
1367
+
1368
+ if (this.ui.frameCounter && this.config.totalFrames) {
1369
+ this.ui.frameCounter.textContent = `Frame ${frameIndex} / ${this.config.totalFrames - 1}`;
1370
+ }
1371
+
1372
+ this.updateCameraFrustum(frameIndex);
1373
+ }
1374
+
1375
+ updateTrajectories(frameIndex) {
1376
+ if (!this.data.trajectories || this.trajectories.length === 0) return;
1377
+
1378
+ const trajectoryData = this.data.trajectories.data;
1379
+ const [totalFrames, numTrajectories] = this.data.trajectories.shape;
1380
+ const historyFrames = parseInt(this.ui.trajectoryHistory.value);
1381
+ const tailOpacity = parseFloat(this.ui.trajectoryFade.value);
1382
+
1383
+ const isRichMode = this.ui.enableRichTrail.checked;
1384
+
1385
+ for (let i = 0; i < numTrajectories; i++) {
1386
+ const trajectoryGroup = this.trajectories[i];
1387
+ const { marker, simpleLine, lineSegments } = trajectoryGroup.userData;
1388
+
1389
+ const currentPos = new THREE.Vector3();
1390
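+ // Trajectories are stored as a flat [frame][track][xyz] array, so index frame-major.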
+ const currentOffset = (frameIndex * numTrajectories + i) * 3;
1391
+
1392
+ currentPos.x = trajectoryData[currentOffset];
1393
+ currentPos.y = -trajectoryData[currentOffset + 1];
1394
+ currentPos.z = -trajectoryData[currentOffset + 2];
1395
+
1396
+ marker.position.copy(currentPos);
1397
+ marker.material.opacity = 1.0;
1398
+
1399
+ const historyToShow = Math.min(historyFrames, frameIndex + 1);
1400
+
1401
+ if (isRichMode) {
1402
+ // --- High-Quality Mode ---
1403
+ simpleLine.visible = false;
1404
+
1405
+ for (let j = 0; j < lineSegments.length; j++) {
1406
+ const segment = lineSegments[j];
1407
+ if (j < historyToShow - 1) {
1408
+ const headFrame = frameIndex - j;
1409
+ const tailFrame = frameIndex - j - 1;
1410
+ const headOffset = (headFrame * numTrajectories + i) * 3;
1411
+ const tailOffset = (tailFrame * numTrajectories + i) * 3;
1412
+ const positions = [
1413
+ trajectoryData[headOffset], -trajectoryData[headOffset + 1], -trajectoryData[headOffset + 2],
1414
+ trajectoryData[tailOffset], -trajectoryData[tailOffset + 1], -trajectoryData[tailOffset + 2]
1415
+ ];
1416
+ segment.geometry.setPositions(positions);
1417
+ const headOpacity = 1.0;
1418
+ const normalizedAge = j / Math.max(1, historyToShow - 2);
1419
+ const alpha = headOpacity - (headOpacity - tailOpacity) * normalizedAge;
1420
+ segment.material.opacity = Math.max(0, alpha);
1421
+ segment.visible = true;
1422
+ } else {
1423
+ segment.visible = false;
1424
+ }
1425
+ }
1426
+ } else {
1427
+ // --- Performance Mode ---
1428
+ lineSegments.forEach(s => s.visible = false);
1429
+ simpleLine.visible = true;
1430
+
1431
+ const positions = simpleLine.geometry.attributes.position.array;
1432
+ for (let j = 0; j < historyToShow; j++) {
1433
+ const historyFrame = Math.max(0, frameIndex - j);
1434
+ const offset = (historyFrame * numTrajectories + i) * 3;
1435
+ positions[j * 3] = trajectoryData[offset];
1436
+ positions[j * 3 + 1] = -trajectoryData[offset + 1];
1437
+ positions[j * 3 + 2] = -trajectoryData[offset + 2];
1438
+ }
1439
+ simpleLine.geometry.setDrawRange(0, historyToShow);
1440
+ simpleLine.geometry.attributes.position.needsUpdate = true;
1441
+ }
1442
+ }
1443
+ }
1444
+
1445
+ updateTrajectorySettings() {
1446
+ if (!this.trajectories || this.trajectories.length === 0) return;
1447
+
1448
+ const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
1449
+ const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);
1450
+
1451
+ this.trajectories.forEach(trajectoryGroup => {
1452
+ const { marker, lineSegments } = trajectoryGroup.userData;
1453
+
1454
+ marker.geometry.dispose();
1455
+ marker.geometry = new THREE.SphereGeometry(ballSize, 16, 16);
1456
+
1457
+ // Line width only affects rich mode
1458
+ lineSegments.forEach(segment => {
1459
+ if (segment.material) {
1460
+ segment.material.linewidth = lineWidth;
1461
+ }
1462
+ });
1463
+ });
1464
+
1465
+ this.updateTrajectories(this.currentFrame);
1466
+ }
1467
+
1468
+ getDepthColor(normalizedDepth) {
1469
+ const hue = (1 - normalizedDepth) * 240 / 360;
1470
+ const color = new THREE.Color().setHSL(hue, 1.0, 0.5);
1471
+ return color;
1472
+ }
1473
+
1474
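+ // Return a zero-copy view of a single frame from a flat [T, H, W, C] array.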
+ getFrame(typedArray, shape, frameIndex) {
1475
+ const [T, H, W, C] = shape;
1476
+ const frameSize = H * W * C;
1477
+ const offset = frameIndex * frameSize;
1478
+ return typedArray.subarray(offset, offset + frameSize);
1479
+ }
1480
+
1481
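+ // Per-frame intrinsics are stored row-major as 9 floats; unpack into a 3x3 nested array.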
+ get3x3Matrix(typedArray, shape, frameIndex) {
1482
+ const frameSize = 9;
1483
+ const offset = frameIndex * frameSize;
1484
+ const K = [];
1485
+ for (let i = 0; i < 3; i++) {
1486
+ const row = [];
1487
+ for (let j = 0; j < 3; j++) {
1488
+ row.push(typedArray[offset + i * 3 + j]);
1489
+ }
1490
+ K.push(row);
1491
+ }
1492
+ return K;
1493
+ }
1494
+
1495
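+ // Per-frame inverse extrinsics (camera-to-world) are stored row-major as 16 floats; unpack into a 4x4 nested array.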
+ get4x4Matrix(typedArray, shape, frameIndex) {
1496
+ const frameSize = 16;
1497
+ const offset = frameIndex * frameSize;
1498
+ const M = [];
1499
+ for (let i = 0; i < 4; i++) {
1500
+ const row = [];
1501
+ for (let j = 0; j < 4; j++) {
1502
+ row.push(typedArray[offset + i * 4 + j]);
1503
+ }
1504
+ M.push(row);
1505
+ }
1506
+ return M;
1507
+ }
1508
+
1509
+ getTransformElements(matrix) {
1510
+ return {
1511
+ m11: matrix[0][0], m12: matrix[0][1], m13: matrix[0][2], m14: matrix[0][3],
1512
+ m21: matrix[1][0], m22: matrix[1][1], m23: matrix[1][2], m24: matrix[1][3],
1513
+ m31: matrix[2][0], m32: matrix[2][1], m33: matrix[2][2], m34: matrix[2][3]
1514
+ };
1515
+ }
1516
+
1517
+ togglePlayback() {
1518
+ this.isPlaying = !this.isPlaying;
1519
+
1520
+ const playIcon = document.getElementById('play-icon');
1521
+ const pauseIcon = document.getElementById('pause-icon');
1522
+
1523
+ if (this.isPlaying) {
1524
+ playIcon.style.display = 'none';
1525
+ pauseIcon.style.display = 'block';
1526
+ this.lastFrameTime = performance.now();
1527
+ } else {
1528
+ playIcon.style.display = 'block';
1529
+ pauseIcon.style.display = 'none';
1530
+ }
1531
+ }
1532
+
1533
+ cyclePlaybackSpeed() {
1534
+ const speeds = [0.5, 1, 2, 4, 8];
1535
+ const speedRates = speeds.map(s => s * this.config.baseFrameRate);
1536
+
1537
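+ // Snap the current speed to the nearest preset, then step to the next one in the cycle.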
+ let currentIndex = 0;
1538
+ const normalizedSpeed = this.playbackSpeed / this.config.baseFrameRate;
1539
+
1540
+ for (let i = 0; i < speeds.length; i++) {
1541
+ if (Math.abs(normalizedSpeed - speeds[i]) < Math.abs(normalizedSpeed - speeds[currentIndex])) {
1542
+ currentIndex = i;
1543
+ }
1544
+ }
1545
+
1546
+ const nextIndex = (currentIndex + 1) % speeds.length;
1547
+ this.playbackSpeed = speedRates[nextIndex];
1548
+ this.ui.speedBtn.textContent = `${speeds[nextIndex]}x`;
1549
+
1550
+ if (speeds[nextIndex] === 1) {
1551
+ this.ui.speedBtn.classList.remove('active');
1552
+ } else {
1553
+ this.ui.speedBtn.classList.add('active');
1554
+ }
1555
+ }
1556
+
1557
+ seekTo(position) {
1558
+ const frameIndex = Math.floor(position * this.config.totalFrames);
1559
+ this.currentFrame = Math.max(0, Math.min(frameIndex, this.config.totalFrames - 1));
1560
+ this.updatePointCloud(this.currentFrame);
1561
+ }
1562
+
1563
+ updatePointCloudSettings() {
1564
+ if (!this.pointCloud) return;
1565
+
1566
+ const size = parseFloat(this.ui.pointSize.value);
1567
+ const opacity = parseFloat(this.ui.pointOpacity.value);
1568
+
1569
+ this.pointCloud.material.size = size;
1570
+ this.pointCloud.material.opacity = opacity;
1571
+ this.pointCloud.material.needsUpdate = true;
1572
+
1573
+ this.updatePointCloud(this.currentFrame);
1574
+ }
1575
+
1576
+ updateControls() {
1577
+ if (!this.controls) return;
1578
+ this.controls.update();
1579
+ }
1580
+
1581
+ resetView() {
1582
+ if (!this.camera || !this.controls) return;
1583
+
1584
+ // Reset camera position
1585
+ this.camera.position.set(0, 0, this.config.cameraZ || 0);
1586
+
1587
+ // Reset controls
1588
+ this.controls.reset();
1589
+
1590
+ // Set target slightly in front of camera
1591
+ this.controls.target.set(0, 0, -1);
1592
+ this.controls.update();
1593
+
1594
+ // Show status message
1595
+ this.ui.statusBar.textContent = "View reset";
1596
+ this.ui.statusBar.classList.remove('hidden');
1597
+
1598
+ // Hide status message after a few seconds
1599
+ setTimeout(() => {
1600
+ this.ui.statusBar.classList.add('hidden');
1601
+ }, 3000);
1602
+ }
1603
+
1604
+ onWindowResize() {
1605
+ if (!this.camera || !this.renderer) return;
1606
+
1607
+ const windowAspect = window.innerWidth / window.innerHeight;
1608
+ this.camera.aspect = windowAspect;
1609
+ this.camera.updateProjectionMatrix();
1610
+ this.renderer.setSize(window.innerWidth, window.innerHeight);
1611
+
1612
+ if (this.trajectories && this.trajectories.length > 0) {
1613
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1614
+ this.trajectories.forEach(trajectory => {
1615
+ const { lineSegments } = trajectory.userData;
1616
+ if (lineSegments && lineSegments.length > 0) {
1617
+ lineSegments.forEach(segment => {
1618
+ if (segment.material && segment.material.resolution) {
1619
+ segment.material.resolution.copy(resolution);
1620
+ }
1621
+ });
1622
+ }
1623
+ });
1624
+ }
1625
+
1626
+ if (this.cameraFrustum) {
1627
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1628
+ this.cameraFrustum.children.forEach(line => {
1629
+ if (line.material && line.material.resolution) {
1630
+ line.material.resolution.copy(resolution);
1631
+ }
1632
+ });
1633
+ }
1634
+ }
1635
+
1636
+ startAnimation() {
1637
+ this.isPlaying = true;
1638
+ this.lastFrameTime = performance.now();
1639
+
1640
+ this.camera.position.set(0, 0, this.config.cameraZ || 0);
1641
+ this.controls.target.set(0, 0, -1);
1642
+ this.controls.update();
1643
+
1644
+ this.playbackSpeed = this.config.baseFrameRate;
1645
+
1646
+ document.getElementById('play-icon').style.display = 'none';
1647
+ document.getElementById('pause-icon').style.display = 'block';
1648
+
1649
+ this.animate();
1650
+ }
1651
+
1652
+ animate() {
1653
+ requestAnimationFrame(() => this.animate());
1654
+
1655
+ if (this.controls) {
1656
+ this.controls.update();
1657
+ }
1658
+
1659
+ if (this.isPlaying && this.data) {
1660
+ const now = performance.now();
1661
+ const delta = (now - this.lastFrameTime) / 1000;
1662
+
1663
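+ // Advance playback by whole frames based on elapsed wall-clock time and the current speed.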
+ const framesToAdvance = Math.floor(delta * this.config.baseFrameRate * this.playbackSpeed);
1664
+ if (framesToAdvance > 0) {
1665
+ this.currentFrame = (this.currentFrame + framesToAdvance) % this.config.totalFrames;
1666
+ this.lastFrameTime = now;
1667
+ this.updatePointCloud(this.currentFrame);
1668
+ }
1669
+ }
1670
+
1671
+ if (this.renderer && this.scene && this.camera) {
1672
+ this.renderer.render(this.scene, this.camera);
1673
+ }
1674
+ }
1675
+
1676
+ initCameraWithCorrectFOV() {
1677
+ const fov = this.config.fov || 60;
1678
+
1679
+ const windowAspect = window.innerWidth / window.innerHeight;
1680
+
1681
+ this.camera = new THREE.PerspectiveCamera(
1682
+ fov,
1683
+ windowAspect,
1684
+ 0.1,
1685
+ 10000
1686
+ );
1687
+
1688
+ this.controls.object = this.camera;
1689
+ this.controls.update();
1690
+
1691
+ this.initCameraFrustum();
1692
+ }
1693
+
1694
+ initCameraFrustum() {
1695
+ this.cameraFrustum = new THREE.Group();
1696
+
1697
+ this.scene.add(this.cameraFrustum);
1698
+
1699
+ this.initCameraFrustumGeometry();
1700
+
1701
+ const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : (this.defaultSettings ? this.defaultSettings.showCameraFrustum : false);
1702
+
1703
+ this.cameraFrustum.visible = showCameraFrustum;
1704
+ }
1705
+
1706
+ initCameraFrustumGeometry() {
1707
+ const fov = this.config.fov || 60;
1708
+ const originalAspect = this.config.original_aspect_ratio || 1.33;
1709
+
1710
+ const size = parseFloat(this.ui.frustumSize.value) || this.defaultSettings.frustumSize;
1711
+
1712
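+ // Derive the frustum footprint at the chosen depth from the vertical FOV and the source aspect ratio.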
+ const halfHeight = Math.tan(THREE.MathUtils.degToRad(fov / 2)) * size;
1713
+ const halfWidth = halfHeight * originalAspect;
1714
+
1715
+ const vertices = [
1716
+ new THREE.Vector3(0, 0, 0),
1717
+ new THREE.Vector3(-halfWidth, -halfHeight, size),
1718
+ new THREE.Vector3(halfWidth, -halfHeight, size),
1719
+ new THREE.Vector3(halfWidth, halfHeight, size),
1720
+ new THREE.Vector3(-halfWidth, halfHeight, size)
1721
+ ];
1722
+
1723
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1724
+
1725
+ const linePairs = [
1726
+ [1, 2], [2, 3], [3, 4], [4, 1],
1727
+ [0, 1], [0, 2], [0, 3], [0, 4]
1728
+ ];
1729
+
1730
+ const colors = {
1731
+ edge: new THREE.Color(0x3366ff),
1732
+ ray: new THREE.Color(0x33cc66)
1733
+ };
1734
+
1735
+ linePairs.forEach((pair, index) => {
1736
+ const positions = [
1737
+ vertices[pair[0]].x, vertices[pair[0]].y, vertices[pair[0]].z,
1738
+ vertices[pair[1]].x, vertices[pair[1]].y, vertices[pair[1]].z
1739
+ ];
1740
+
1741
+ const lineGeometry = new THREE.LineGeometry();
1742
+ lineGeometry.setPositions(positions);
1743
+
1744
+ let color = index < 4 ? colors.edge : colors.ray;
1745
+
1746
+ const lineMaterial = new THREE.LineMaterial({
1747
+ color: color,
1748
+ linewidth: 2,
1749
+ resolution: resolution,
1750
+ dashed: false
1751
+ });
1752
+
1753
+ const line = new THREE.Line2(lineGeometry, lineMaterial);
1754
+ this.cameraFrustum.add(line);
1755
+ });
1756
+ }
1757
+
1758
+ updateCameraFrustum(frameIndex) {
1759
+ if (!this.cameraFrustum || !this.data) return;
1760
+
1761
+ const invExtrinsics = this.data.inv_extrinsics;
1762
+ if (!invExtrinsics) return;
1763
+
1764
+ const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
1765
+
1766
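+ // Recover the camera pose from the camera-to-world matrix, applying the same Y/Z flip used for the point cloud.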
+ const matrix = new THREE.Matrix4();
1767
+ matrix.set(
1768
+ invExtrMat[0][0], invExtrMat[0][1], invExtrMat[0][2], invExtrMat[0][3],
1769
+ invExtrMat[1][0], invExtrMat[1][1], invExtrMat[1][2], invExtrMat[1][3],
1770
+ invExtrMat[2][0], invExtrMat[2][1], invExtrMat[2][2], invExtrMat[2][3],
1771
+ invExtrMat[3][0], invExtrMat[3][1], invExtrMat[3][2], invExtrMat[3][3]
1772
+ );
1773
+
1774
+ const position = new THREE.Vector3();
1775
+ position.setFromMatrixPosition(matrix);
1776
+
1777
+ const rotMatrix = new THREE.Matrix4().extractRotation(matrix);
1778
+
1779
+ const coordinateCorrection = new THREE.Matrix4().makeRotationX(Math.PI);
1780
+
1781
+ const finalRotation = new THREE.Matrix4().multiplyMatrices(coordinateCorrection, rotMatrix);
1782
+
1783
+ const quaternion = new THREE.Quaternion();
1784
+ quaternion.setFromRotationMatrix(finalRotation);
1785
+
1786
+ position.y = -position.y;
1787
+ position.z = -position.z;
1788
+
1789
+ this.cameraFrustum.position.copy(position);
1790
+ this.cameraFrustum.quaternion.copy(quaternion);
1791
+
1792
+ const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : this.defaultSettings.showCameraFrustum;
1793
+
1794
+ if (this.cameraFrustum.visible !== showCameraFrustum) {
1795
+ this.cameraFrustum.visible = showCameraFrustum;
1796
+ }
1797
+
1798
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1799
+ this.cameraFrustum.children.forEach(line => {
1800
+ if (line.material && line.material.resolution) {
1801
+ line.material.resolution.copy(resolution);
1802
+ }
1803
+ });
1804
+ }
1805
+
1806
+ updateFrustumDimensions() {
1807
+ if (!this.cameraFrustum) return;
1808
+
1809
+ while(this.cameraFrustum.children.length > 0) {
1810
+ const child = this.cameraFrustum.children[0];
1811
+ if (child.geometry) child.geometry.dispose();
1812
+ if (child.material) child.material.dispose();
1813
+ this.cameraFrustum.remove(child);
1814
+ }
1815
+
1816
+ this.initCameraFrustumGeometry();
1817
+
1818
+ this.updateCameraFrustum(this.currentFrame);
1819
+ }
1820
+
1821
+ // "Keep history" helper methods
1822
+ updateHistory(frameIndex) {
1823
+ if (!this.ui.enableKeepHistory.checked || !this.data) return;
1824
+
1825
+ const stride = parseInt(this.ui.historyStride.value);
1826
+ const newHistoryFrames = this.calculateHistoryFrames(frameIndex, stride);
1827
+
1828
+ // Check if history frames changed
1829
+ if (this.arraysEqual(this.historyFrames, newHistoryFrames)) return;
1830
+
1831
+ this.clearHistory();
1832
+ this.historyFrames = newHistoryFrames;
1833
+
1834
+ // Create history point clouds and trajectories
1835
+ this.historyFrames.forEach(historyFrame => {
1836
+ if (historyFrame !== frameIndex) {
1837
+ this.createHistoryPointCloud(historyFrame);
1838
+ this.createHistoryTrajectories(historyFrame);
1839
+ }
1840
+ });
1841
+ }
1842
+
1843
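+ // Sample past frames at the chosen stride, capped at maxHistoryFrames, always including the current frame.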
+ calculateHistoryFrames(currentFrame, stride) {
1844
+ const frames = [];
1845
+ let frame = 1; // Start from frame 1
1846
+
1847
+ while (frame <= currentFrame && frames.length < this.maxHistoryFrames) {
1848
+ frames.push(frame);
1849
+ frame += stride;
1850
+ }
1851
+
1852
+ // Always include current frame
1853
+ if (!frames.includes(currentFrame)) {
1854
+ frames.push(currentFrame);
1855
+ }
1856
+
1857
+ return frames.sort((a, b) => a - b);
1858
+ }
1859
+
1860
+ createHistoryPointCloud(frameIndex) {
1861
+ const numPoints = this.config.resolution[0] * this.config.resolution[1];
1862
+ const positions = new Float32Array(numPoints * 3);
1863
+ const colors = new Float32Array(numPoints * 3);
1864
+
1865
+ const geometry = new THREE.BufferGeometry();
1866
+ geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3));
1867
+ geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3));
1868
+
1869
+ const material = new THREE.PointsMaterial({
1870
+ size: parseFloat(this.ui.pointSize.value),
1871
+ vertexColors: true,
1872
+ transparent: true,
1873
+ opacity: 0.5, // Render history frames semi-transparent
1874
+ sizeAttenuation: true
1875
+ });
1876
+
1877
+ const historyPointCloud = new THREE.Points(geometry, material);
1878
+ this.scene.add(historyPointCloud);
1879
+ this.historyPointClouds.push(historyPointCloud);
1880
+
1881
+ // Update the history point cloud with data
1882
+ this.updateHistoryPointCloud(historyPointCloud, frameIndex);
1883
+ }
1884
+
1885
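+ // Same unprojection pipeline as updatePointCloud, written once into this history cloud's buffers.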
+ updateHistoryPointCloud(pointCloud, frameIndex) {
1886
+ const positions = pointCloud.geometry.attributes.position.array;
1887
+ const colors = pointCloud.geometry.attributes.color.array;
1888
+
1889
+ const rgbVideo = this.data.rgb_video;
1890
+ const depthsRgb = this.data.depths_rgb;
1891
+ const intrinsics = this.data.intrinsics;
1892
+ const invExtrinsics = this.data.inv_extrinsics;
1893
+
1894
+ const width = this.config.resolution[0];
1895
+ const height = this.config.resolution[1];
1896
+ const numPoints = width * height;
1897
+
1898
+ const K = this.get3x3Matrix(intrinsics.data, intrinsics.shape, frameIndex);
1899
+ const fx = K[0][0], fy = K[1][1], cx = K[0][2], cy = K[1][2];
1900
+
1901
+ const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
1902
+ const transform = this.getTransformElements(invExtrMat);
1903
+
1904
+ const rgbFrame = this.getFrame(rgbVideo.data, rgbVideo.shape, frameIndex);
1905
+ const depthFrame = this.getFrame(depthsRgb.data, depthsRgb.shape, frameIndex);
1906
+
1907
+ const maxDepth = parseFloat(this.ui.maxDepth.value) || 10.0;
1908
+
1909
+ let validPointCount = 0;
1910
+
1911
+ for (let i = 0; i < numPoints; i++) {
1912
+ const xPix = i % width;
1913
+ const yPix = Math.floor(i / width);
1914
+
1915
+ const d0 = depthFrame[i * 3];
1916
+ const d1 = depthFrame[i * 3 + 1];
1917
+ const depthEncoded = d0 | (d1 << 8);
1918
+ const depthValue = (depthEncoded / ((1 << 16) - 1)) *
1919
+ (this.config.depthRange[1] - this.config.depthRange[0]) +
1920
+ this.config.depthRange[0];
1921
+
1922
+ if (depthValue === 0 || depthValue > maxDepth) {
1923
+ continue;
1924
+ }
1925
+
1926
+ const X = ((xPix - cx) * depthValue) / fx;
1927
+ const Y = ((yPix - cy) * depthValue) / fy;
1928
+ const Z = depthValue;
1929
+
1930
+ const tx = transform.m11 * X + transform.m12 * Y + transform.m13 * Z + transform.m14;
1931
+ const ty = transform.m21 * X + transform.m22 * Y + transform.m23 * Z + transform.m24;
1932
+ const tz = transform.m31 * X + transform.m32 * Y + transform.m33 * Z + transform.m34;
1933
+
1934
+ const index = validPointCount * 3;
1935
+ positions[index] = tx;
1936
+ positions[index + 1] = -ty;
1937
+ positions[index + 2] = -tz;
1938
+
1939
+ colors[index] = rgbFrame[i * 3] / 255;
1940
+ colors[index + 1] = rgbFrame[i * 3 + 1] / 255;
1941
+ colors[index + 2] = rgbFrame[i * 3 + 2] / 255;
1942
+
1943
+ validPointCount++;
1944
+ }
1945
+
1946
+ pointCloud.geometry.setDrawRange(0, validPointCount);
1947
+ pointCloud.geometry.attributes.position.needsUpdate = true;
1948
+ pointCloud.geometry.attributes.color.needsUpdate = true;
1949
+ }
1950
+
1951
+ createHistoryTrajectories(frameIndex) {
1952
+ if (!this.data.trajectories) return;
1953
+
1954
+ const trajectoryData = this.data.trajectories.data;
1955
+ const [totalFrames, numTrajectories] = this.data.trajectories.shape;
1956
+ const palette = this.createColorPalette(numTrajectories);
1957
+
1958
+ const historyTrajectoryGroup = new THREE.Group();
1959
+
1960
+ for (let i = 0; i < numTrajectories; i++) {
1961
+ const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
1962
+ const sphereGeometry = new THREE.SphereGeometry(ballSize, 16, 16);
1963
+ const sphereMaterial = new THREE.MeshBasicMaterial({
1964
+ color: palette[i],
1965
+ transparent: true,
1966
+ opacity: 0.3 // Render history markers semi-transparent
1967
+ });
1968
+ const positionMarker = new THREE.Mesh(sphereGeometry, sphereMaterial);
1969
+
1970
+ const currentOffset = (frameIndex * numTrajectories + i) * 3;
1971
+ positionMarker.position.set(
1972
+ trajectoryData[currentOffset],
1973
+ -trajectoryData[currentOffset + 1],
1974
+ -trajectoryData[currentOffset + 2]
1975
+ );
1976
+
1977
+ historyTrajectoryGroup.add(positionMarker);
1978
+ }
1979
+
1980
+ this.scene.add(historyTrajectoryGroup);
1981
+ this.historyTrajectories.push(historyTrajectoryGroup);
1982
+ }
1983
+
1984
+ clearHistory() {
1985
+ // Clear history point clouds
1986
+ this.historyPointClouds.forEach(pointCloud => {
1987
+ if (pointCloud.geometry) pointCloud.geometry.dispose();
1988
+ if (pointCloud.material) pointCloud.material.dispose();
1989
+ this.scene.remove(pointCloud);
1990
+ });
1991
+ this.historyPointClouds = [];
1992
+
1993
+ // Clear history trajectories
1994
+ this.historyTrajectories.forEach(trajectoryGroup => {
1995
+ trajectoryGroup.children.forEach(child => {
1996
+ if (child.geometry) child.geometry.dispose();
1997
+ if (child.material) child.material.dispose();
1998
+ });
1999
+ this.scene.remove(trajectoryGroup);
2000
+ });
2001
+ this.historyTrajectories = [];
2002
+
2003
+ this.historyFrames = [];
2004
+ }
2005
+
2006
+ arraysEqual(a, b) {
2007
+ if (a.length !== b.length) return false;
2008
+ for (let i = 0; i < a.length; i++) {
2009
+ if (a[i] !== b[i]) return false;
2010
+ }
2011
+ return true;
2012
+ }
2013
+
2014
+ toggleBackground() {
2015
+ const isWhiteBackground = this.ui.whiteBackground.checked;
2016
+
2017
+ if (isWhiteBackground) {
2018
+ // Switch to white background
2019
+ document.body.style.backgroundColor = '#ffffff';
2020
+ this.scene.background = new THREE.Color(0xffffff);
2021
+
2022
+ // Update UI elements for white background
2023
+ document.documentElement.style.setProperty('--bg', '#ffffff');
2024
+ document.documentElement.style.setProperty('--text', '#333333');
2025
+ document.documentElement.style.setProperty('--text-secondary', '#666666');
2026
+ document.documentElement.style.setProperty('--border', '#cccccc');
2027
+ document.documentElement.style.setProperty('--surface', '#f5f5f5');
2028
+ document.documentElement.style.setProperty('--shadow', 'rgba(0, 0, 0, 0.1)');
2029
+ document.documentElement.style.setProperty('--shadow-hover', 'rgba(0, 0, 0, 0.2)');
2030
+
2031
+ // Update status bar and control panel backgrounds
2032
+ this.ui.statusBar.style.background = 'rgba(245, 245, 245, 0.9)';
2033
+ this.ui.statusBar.style.color = '#333333';
2034
+
2035
+ const controlPanel = document.getElementById('control-panel');
2036
+ if (controlPanel) {
2037
+ controlPanel.style.background = 'rgba(245, 245, 245, 0.95)';
2038
+ }
2039
+
2040
+ const settingsPanel = document.getElementById('settings-panel');
2041
+ if (settingsPanel) {
2042
+ settingsPanel.style.background = 'rgba(245, 245, 245, 0.98)';
2043
+ }
2044
+
2045
+ } else {
2046
+ // Switch back to dark background
2047
+ document.body.style.backgroundColor = '#1a1a1a';
2048
+ this.scene.background = new THREE.Color(0x1a1a1a);
2049
+
2050
+ // Restore original dark theme variables
2051
+ document.documentElement.style.setProperty('--bg', '#1a1a1a');
2052
+ document.documentElement.style.setProperty('--text', '#e0e0e0');
2053
+ document.documentElement.style.setProperty('--text-secondary', '#a0a0a0');
2054
+ document.documentElement.style.setProperty('--border', '#444444');
2055
+ document.documentElement.style.setProperty('--surface', '#2c2c2c');
2056
+ document.documentElement.style.setProperty('--shadow', 'rgba(0, 0, 0, 0.2)');
2057
+ document.documentElement.style.setProperty('--shadow-hover', 'rgba(0, 0, 0, 0.3)');
2058
+
2059
+ // Restore original UI backgrounds
2060
+ this.ui.statusBar.style.background = 'rgba(30, 30, 30, 0.9)';
2061
+ this.ui.statusBar.style.color = '#e0e0e0';
2062
+
2063
+ const controlPanel = document.getElementById('control-panel');
2064
+ if (controlPanel) {
2065
+ controlPanel.style.background = 'rgba(44, 44, 44, 0.95)';
2066
+ }
2067
+
2068
+ const settingsPanel = document.getElementById('settings-panel');
2069
+ if (settingsPanel) {
2070
+ settingsPanel.style.background = 'rgba(44, 44, 44, 0.98)';
2071
+ }
2072
+ }
2073
+
2074
+ // Show status message
2075
+ this.ui.statusBar.textContent = isWhiteBackground ? "Switched to white background" : "Switched to dark background";
2076
+ this.ui.statusBar.classList.remove('hidden');
2077
+
2078
+ setTimeout(() => {
2079
+ this.ui.statusBar.classList.add('hidden');
2080
+ }, 2000);
2081
+ }
2082
+
2083
+ resetSettings() {
2084
+ if (!this.defaultSettings) return;
2085
+
2086
+ this.applyDefaultSettings();
2087
+
2088
+ // Reset background to dark theme
2089
+ if (this.ui.whiteBackground) {
2090
+ this.ui.whiteBackground.checked = false;
2091
+ this.toggleBackground();
2092
+ }
2093
+
2094
+ this.updatePointCloudSettings();
2095
+ this.updateTrajectorySettings();
2096
+ this.updateFrustumDimensions();
2097
+
2098
+ // Clear history when resetting settings
2099
+ this.clearHistory();
2100
+
2101
+ this.ui.statusBar.textContent = "Settings reset to defaults";
2102
+ this.ui.statusBar.classList.remove('hidden');
2103
+
2104
+ setTimeout(() => {
2105
+ this.ui.statusBar.classList.add('hidden');
2106
+ }, 3000);
2107
+ }
2108
+ }
2109
+
2110
+ window.addEventListener('DOMContentLoaded', () => {
2111
+ new PointCloudVisualizer();
2112
+ });
2113
+ </script>
2114
+ </body>
2115
+ </html>