Spaces: Running on Zero

Commit 9193cab, committed by xiaoyuxi
Parent(s): cd14f82

support HubMixin
- LICENSE.txt +409 -0
- config/magic_infer_offline.yaml +47 -0
- config/magic_infer_online.yaml +47 -0
- docs/PAPER.md +4 -0
- inference.py +184 -0
- models/SpaTrackV2/models/vggt4track/__init__.py +1 -0
- models/SpaTrackV2/models/vggt4track/heads/camera_head.py +162 -0
- models/SpaTrackV2/models/vggt4track/heads/dpt_head.py +497 -0
- models/SpaTrackV2/models/vggt4track/heads/head_act.py +125 -0
- models/SpaTrackV2/models/vggt4track/heads/scale_head.py +162 -0
- models/SpaTrackV2/models/vggt4track/heads/track_head.py +108 -0
- models/SpaTrackV2/models/vggt4track/heads/track_modules/__init__.py +5 -0
- models/SpaTrackV2/models/vggt4track/heads/track_modules/base_track_predictor.py +209 -0
- models/SpaTrackV2/models/vggt4track/heads/track_modules/blocks.py +246 -0
- models/SpaTrackV2/models/vggt4track/heads/track_modules/modules.py +218 -0
- models/SpaTrackV2/models/vggt4track/heads/track_modules/utils.py +226 -0
- models/SpaTrackV2/models/vggt4track/heads/utils.py +109 -0
- models/SpaTrackV2/models/vggt4track/layers/__init__.py +11 -0
- models/SpaTrackV2/models/vggt4track/layers/attention.py +98 -0
- models/SpaTrackV2/models/vggt4track/layers/block.py +259 -0
- models/SpaTrackV2/models/vggt4track/layers/drop_path.py +34 -0
- models/SpaTrackV2/models/vggt4track/layers/layer_scale.py +27 -0
- models/SpaTrackV2/models/vggt4track/layers/mlp.py +40 -0
- models/SpaTrackV2/models/vggt4track/layers/patch_embed.py +88 -0
- models/SpaTrackV2/models/vggt4track/layers/rope.py +188 -0
- models/SpaTrackV2/models/vggt4track/layers/swiglu_ffn.py +72 -0
- models/SpaTrackV2/models/vggt4track/layers/vision_transformer.py +407 -0
- models/SpaTrackV2/models/vggt4track/models/aggregator.py +338 -0
- models/SpaTrackV2/models/vggt4track/models/aggregator_front.py +342 -0
- models/SpaTrackV2/models/vggt4track/models/tracker_front.py +132 -0
- models/SpaTrackV2/models/vggt4track/models/vggt.py +96 -0
- models/SpaTrackV2/models/vggt4track/models/vggt_moe.py +107 -0
- models/SpaTrackV2/models/vggt4track/utils/__init__.py +1 -0
- models/SpaTrackV2/models/vggt4track/utils/geometry.py +166 -0
- models/SpaTrackV2/models/vggt4track/utils/load_fn.py +200 -0
- models/SpaTrackV2/models/vggt4track/utils/loss.py +123 -0
- models/SpaTrackV2/models/vggt4track/utils/pose_enc.py +130 -0
- models/SpaTrackV2/models/vggt4track/utils/rotation.py +138 -0
- models/SpaTrackV2/models/vggt4track/utils/visual_track.py +239 -0
- scripts/download.sh +5 -0
- viz.html +2115 -0
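The commit message "support HubMixin" refers to making the model classes loadable directly from the Hugging Face Hub via `from_pretrained`, which is how the new `inference.py` below obtains its weights. A minimal sketch of that usage follows; the repository IDs are the ones that appear in this diff, while the input shape is an illustrative assumption.

    import torch
    from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
    from models.SpaTrackV2.models.predictor import Predictor

    # Load the front-end (video depth + camera) and the tracker straight from the Hub,
    # using the repo IDs that inference.py in this commit relies on.
    front_end = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front").eval().to("cuda")
    tracker = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline").eval()
    tracker.to("cuda")

    # Illustrative input: 8 RGB frames in [0, 255], shape (1, T, 3, H, W) with H = W = 336.
    dummy_video = torch.randint(0, 256, (1, 8, 3, 336, 336), dtype=torch.float32, device="cuda")
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        preds = front_end(dummy_video / 255)  # returns poses_pred, intrs, points_map, unc_metric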
LICENSE.txt
ADDED
@@ -0,0 +1,409 @@
Attribution-NonCommercial 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.

   Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. More considerations for licensors: wiki.creativecommons.org/Considerations_for_licensors

   Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason--for example, because of any applicable exception or limitation to copyright--then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More considerations for the public: wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial 4.0 International Public License

By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.

Section 1 -- Definitions.

a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.

b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.

c. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.

d. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.

e. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.

f. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.

g. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.

h. Licensor means the individual(s) or entity(ies) granting rights under this Public License.

i. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.

j. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.

k. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.

l. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.

Section 2 -- Scope.

a. License grant.

   1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:

      a. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and

      b. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.

   2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.

   3. Term. The term of this Public License is specified in Section 6(a).

   4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.

   5. Downstream recipients.

      a. Offer from the Licensor -- Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.

      b. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.

   6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).

b. Other rights.

   1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.

   2. Patent and trademark rights are not licensed under this Public License.

   3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.

Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the following conditions.

a. Attribution.

   1. If You Share the Licensed Material (including in modified form), You must:

      a. retain the following if it is supplied by the Licensor with the Licensed Material:

         i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);

         ii. a copyright notice;

         iii. a notice that refers to this Public License;

         iv. a notice that refers to the disclaimer of warranties;

         v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;

      b. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and

      c. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.

   2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.

   3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.

   4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.

Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:

a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;

b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and

c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.

Section 5 -- Disclaimer of Warranties and Limitation of Liability.

a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.

Section 6 -- Term and Termination.

a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.

b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:

   1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or

   2. upon express reinstatement by the Licensor.

   For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.

c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.

d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.

Section 7 -- Other Terms and Conditions.

a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.

b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.

Section 8 -- Interpretation.

a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.

b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.

c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.

d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the "Licensor." The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.

Creative Commons may be contacted at creativecommons.org.
config/magic_infer_offline.yaml
ADDED
@@ -0,0 +1,47 @@
seed: 0
# configure the hydra logger; only inside hydra can `$` be resolved as a reference
data: ./assets/room
vis_track: false
hydra:
  run:
    dir: .
  output_subdir: null
  job_logging: {}
  hydra_logging: {}
mixed_precision: bf16
visdom:
  viz_ip: "localhost"
  port: 6666
relax_load: false
res_all: 336
# checkpoint path
ckpts: "Yuxihenry/SpatialTrackerCkpts"
batch_size: 1
input:
  type: image
  fps: 1
model_wind_size: 32
model:
  backbone_cfg:
    ckpt_dir: "checkpoints/model.pt"
    chunk_size: 24  # downsample factor for patchified features
  ckpt_fwd: true
  ft_cfg:
    mode: "fix"
    paras_name: []
  resolution: 336
  max_len: 512
  Track_cfg:
    base_ckpt: "checkpoints/scaled_offline.pth"
    base:
      stride: 4
      corr_radius: 3
      window_len: 60
      stablizer: True
    mode: "online"
    s_wind: 200
    overlap: 4
  track_num: 0

dist_train:
  num_nodes: 1
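The commented-out block in inference.py below shows how configs like this one were loaded with yaml and easydict before the switch to the HubMixin `from_pretrained` path. A minimal sketch of that older usage, assuming the nesting shown above (the override values are just the defaults from this diff):

    import yaml
    import easydict

    # Mirror the commented-out config path in inference.py (assumed older usage):
    # load the YAML, wrap it for attribute access, then override run-time fields.
    with open("config/magic_infer_offline.yaml", "r") as f:
        cfg = easydict.EasyDict(yaml.load(f, Loader=yaml.FullLoader))

    cfg.out_dir = "assets/example0/results"   # where results would be written
    cfg.model.track_num = 756                 # same default as --vo_points
    print(cfg.ckpts, cfg.res_all)             # "Yuxihenry/SpatialTrackerCkpts" 336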
config/magic_infer_online.yaml
ADDED
@@ -0,0 +1,47 @@
seed: 0
# configure the hydra logger; only inside hydra can `$` be resolved as a reference
data: ./assets/room
vis_track: false
hydra:
  run:
    dir: .
  output_subdir: null
  job_logging: {}
  hydra_logging: {}
mixed_precision: bf16
visdom:
  viz_ip: "localhost"
  port: 6666
relax_load: false
res_all: 336
# checkpoint path
ckpts: "Yuxihenry/SpatialTrackerCkpts"
batch_size: 1
input:
  type: image
  fps: 1
model_wind_size: 32
model:
  backbone_cfg:
    ckpt_dir: "checkpoints/model.pt"
    chunk_size: 24  # downsample factor for patchified features
  ckpt_fwd: true
  ft_cfg:
    mode: "fix"
    paras_name: []
  resolution: 336
  max_len: 512
  Track_cfg:
    base_ckpt: "checkpoints/scaled_online.pth"
    base:
      stride: 4
      corr_radius: 3
      window_len: 20
      stablizer: False
    mode: "online"
    s_wind: 20
    overlap: 6
  track_num: 0

dist_train:
  num_nodes: 1
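Note: magic_infer_online.yaml matches the offline config above except for the tracker settings: base_ckpt points to checkpoints/scaled_online.pth instead of scaled_offline.pth, window_len is 20 instead of 60, stablizer is False instead of True, s_wind is 20 instead of 200, and overlap is 6 instead of 4.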
docs/PAPER.md
ADDED
@@ -0,0 +1,4 @@
# SpatialTrackerV2: the final version of the paper is still being polished, ETA in one week.

## Overall
SpatialTrackerV2 proposes an end-to-end, differentiable pipeline that unifies video depth, camera pose, and 3D tracking. This unified pipeline enables large-scale joint training of all components on diverse types of data.
inference.py
ADDED
@@ -0,0 +1,184 @@
import pycolmap
from models.SpaTrackV2.models.predictor import Predictor
import yaml
import easydict
import os
import numpy as np
import cv2
import torch
import torchvision.transforms as T
from PIL import Image
import io
import moviepy.editor as mp
from models.SpaTrackV2.utils.visualizer import Visualizer
import tqdm
from models.SpaTrackV2.models.utils import get_points_on_a_grid
import glob
from rich import print
import argparse
import decord
from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--track_mode", type=str, default="offline")
    parser.add_argument("--data_type", type=str, default="RGBD")
    parser.add_argument("--data_dir", type=str, default="assets/example0")
    parser.add_argument("--video_name", type=str, default="snowboard")
    parser.add_argument("--grid_size", type=int, default=10)
    parser.add_argument("--vo_points", type=int, default=756)
    parser.add_argument("--fps", type=int, default=1)
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    out_dir = args.data_dir + "/results"
    # fps
    fps = int(args.fps)
    mask_dir = args.data_dir + f"/{args.video_name}.png"

    vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
    vggt4track_model.eval()
    vggt4track_model = vggt4track_model.to("cuda")

    if args.data_type == "RGBD":
        npz_dir = args.data_dir + f"/{args.video_name}.npz"
        data_npz_load = dict(np.load(npz_dir, allow_pickle=True))
        #TODO: tapip format
        video_tensor = data_npz_load["video"] * 255
        video_tensor = torch.from_numpy(video_tensor)
        video_tensor = video_tensor[::fps]
        depth_tensor = data_npz_load["depths"]
        depth_tensor = depth_tensor[::fps]
        intrs = data_npz_load["intrinsics"]
        intrs = intrs[::fps]
        extrs = np.linalg.inv(data_npz_load["extrinsics"])
        extrs = extrs[::fps]
        unc_metric = None
    elif args.data_type == "RGB":
        vid_dir = os.path.join(args.data_dir, f"{args.video_name}.mp4")
        video_reader = decord.VideoReader(vid_dir)
        video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2)  # Convert to tensor and permute to (N, C, H, W)
        video_tensor = video_tensor[::fps].float()

        # process the image tensor
        video_tensor = preprocess_image(video_tensor)[None]
        with torch.no_grad():
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                # Predict attributes including cameras, depth maps, and point maps.
                predictions = vggt4track_model(video_tensor.cuda()/255)
                extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
                depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]

        depth_tensor = depth_map.squeeze().cpu().numpy()
        extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
        extrs = extrinsic.squeeze().cpu().numpy()
        intrs = intrinsic.squeeze().cpu().numpy()
        video_tensor = video_tensor.squeeze()
        #NOTE: 20% of the depth is not reliable
        # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
        unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5

        data_npz_load = {}

    if os.path.exists(mask_dir):
        mask_files = mask_dir
        mask = cv2.imread(mask_files)
        mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
        mask = mask.sum(axis=-1)>0
    else:
        mask = np.ones_like(video_tensor[0,0].numpy())>0

    # get all data pieces
    viz = True
    os.makedirs(out_dir, exist_ok=True)

    # with open(cfg_dir, "r") as f:
    #     cfg = yaml.load(f, Loader=yaml.FullLoader)
    # cfg = easydict.EasyDict(cfg)
    # cfg.out_dir = out_dir
    # cfg.model.track_num = args.vo_points
    # print(f"Downloading model from HuggingFace: {cfg.ckpts}")
    if args.track_mode == "offline":
        model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
    else:
        model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")

    # config the model; the track_num is the number of points in the grid
    model.spatrack.track_num = args.vo_points

    model.eval()
    model.to("cuda")
    viser = Visualizer(save_dir=out_dir, grayscale=True,
                       fps=10, pad_value=0, tracks_leave_trace=5)

    grid_size = args.grid_size

    # get frame H W
    if video_tensor is None:
        cap = cv2.VideoCapture(video_path)
        frame_H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    else:
        frame_H, frame_W = video_tensor.shape[2:]
    grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")

    # Sample mask values at grid points and filter out points where mask=0
    if os.path.exists(mask_dir):
        grid_pts_int = grid_pts[0].long()
        mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
        grid_pts = grid_pts[:, mask_values]

    query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()

    # Run model inference
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        (
            c2w_traj, intrs, point_map, conf_depth,
            track3d_pred, track2d_pred, vis_pred, conf_pred, video
        ) = model.forward(video_tensor, depth=depth_tensor,
                          intrs=intrs, extrs=extrs,
                          queries=query_xyt,
                          fps=1, full_point=False, iters_track=4,
                          query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
                          support_frame=len(video_tensor)-1, replace_ratio=0.2)

    # resize the results to avoid too large I/O burden
    # depth and image, the maximum side is 336
    max_size = 336
    h, w = video.shape[2:]
    scale = min(max_size / h, max_size / w)
    if scale < 1:
        new_h, new_w = int(h * scale), int(w * scale)
        video = T.Resize((new_h, new_w))(video)
        video_tensor = T.Resize((new_h, new_w))(video_tensor)
        point_map = T.Resize((new_h, new_w))(point_map)
        conf_depth = T.Resize((new_h, new_w))(conf_depth)
        track2d_pred[...,:2] = track2d_pred[...,:2] * scale
        intrs[:,:2,:] = intrs[:,:2,:] * scale
        if depth_tensor is not None:
            if isinstance(depth_tensor, torch.Tensor):
                depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
            else:
                depth_tensor = T.Resize((new_h, new_w))(torch.from_numpy(depth_tensor))

    if viz:
        viser.visualize(video=video[None],
                        tracks=track2d_pred[None][...,:2],
                        visibility=vis_pred[None], filename="test")

    # save as the tapip3d format
    data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
    data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
    data_npz_load["intrinsics"] = intrs.cpu().numpy()
    depth_save = point_map[:,2,...]
    depth_save[conf_depth<0.5] = 0
    data_npz_load["depths"] = depth_save.cpu().numpy()
    data_npz_load["video"] = (video_tensor).cpu().numpy()/255
    data_npz_load["visibs"] = vis_pred.cpu().numpy()
    data_npz_load["unc_metric"] = conf_depth.cpu().numpy()
    np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)

    print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")
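A typical run of the script above is `python inference.py --data_type RGB --data_dir assets/example0 --video_name snowboard`, assuming an `assets/example0/snowboard.mp4` exists (the file names simply follow the argparse defaults). The saved archive can then be inspected with the small sketch below; the shapes in the comments are inferred from the save code above, not verified outputs.

    import numpy as np

    # Inspect the result.npz written by inference.py (keys exactly as saved above).
    res = np.load("assets/example0/results/result.npz")
    coords = res["coords"]            # 3D tracks in world space, roughly (T, N, 3)
    extrinsics = res["extrinsics"]    # world-to-camera matrices, (T, 4, 4)
    intrinsics = res["intrinsics"]    # camera intrinsics, roughly (T, 3, 3)
    depths = res["depths"]            # per-frame depth with low-confidence pixels zeroed
    video = res["video"]              # frames rescaled to [0, 1]
    print(coords.shape, extrinsics.shape, depths.shape)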
models/SpaTrackV2/models/vggt4track/__init__.py
ADDED
@@ -0,0 +1 @@
(empty file: a single blank line)
models/SpaTrackV2/models/vggt4track/heads/camera_head.py
ADDED
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.SpaTrackV2.models.vggt4track.layers import Mlp
from models.SpaTrackV2.models.vggt4track.layers.block import Block
from models.SpaTrackV2.models.vggt4track.heads.head_act import activate_pose


class CameraHead(nn.Module):
    """
    CameraHead predicts camera parameters from token representations using iterative refinement.

    It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
    ):
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                Block(
                    dim=dim_in,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    init_values=init_values,
                )
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(
            in_features=dim_in,
            hidden_features=dim_in // 2,
            out_features=self.target_dim,
            drop=0,
        )

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.

        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
        """
        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
        return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
            num_iterations (int): Number of refinement iterations.

        Returns:
            list: List of activated camera encodings from each iteration.
        """
        B, S, C = pose_tokens.shape  # S is expected to be 1.
        pred_pose_enc = None
        pred_pose_enc_list = []

        for _ in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive layer normalization and modulation.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            pose_tokens_modulated = self.trunk(pose_tokens_modulated)
            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc,
                trans_act=self.trans_act,
                quat_act=self.quat_act,
                fl_act=self.fl_act,
            )
            pred_pose_enc_list.append(activated_pose)

        return pred_pose_enc_list


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Modulate the input tensor using scaling and shifting parameters.
    """
    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    return x * (1 + scale) + shift
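For orientation, a small smoke test of this head is sketched below. The token shapes are assumptions (batch 2, 4 frames, 5 tokens per frame, width 2048) chosen so that the camera token sits at index 0, as the forward pass above expects; with the default absT_quaR_FoV encoding each iteration returns a 9-dimensional pose encoding.

    import torch

    # Assumed shapes: [B, S, tokens, C] = [2, 4, 5, 2048]; CameraHead reads the
    # token at index 0 of the last tensor in the list.
    head = CameraHead(dim_in=2048, trunk_depth=2)
    aggregated_tokens_list = [torch.randn(2, 4, 5, 2048) for _ in range(3)]

    with torch.no_grad():
        pose_enc_list = head(aggregated_tokens_list, num_iterations=4)

    print(len(pose_enc_list), pose_enc_list[-1].shape)  # 4 refinement steps, each [2, 4, 9]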
models/SpaTrackV2/models/vggt4track/heads/dpt_head.py
ADDED
@@ -0,0 +1,497 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
|
9 |
+
|
10 |
+
|
11 |
+
import os
|
12 |
+
from typing import List, Dict, Tuple, Union
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from .head_act import activate_head
|
18 |
+
from .utils import create_uv_grid, position_grid_to_embed
|
19 |
+
|
20 |
+
|
21 |
+
class DPTHead(nn.Module):
|
22 |
+
"""
|
23 |
+
DPT Head for dense prediction tasks.
|
24 |
+
|
25 |
+
This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
|
26 |
+
(https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
|
27 |
+
backbone and produces dense predictions by fusing multi-scale features.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
dim_in (int): Input dimension (channels).
|
31 |
+
patch_size (int, optional): Patch size. Default is 14.
|
32 |
+
output_dim (int, optional): Number of output channels. Default is 4.
|
33 |
+
activation (str, optional): Activation type. Default is "inv_log".
|
34 |
+
conf_activation (str, optional): Confidence activation type. Default is "expp1".
|
35 |
+
features (int, optional): Feature channels for intermediate representations. Default is 256.
|
36 |
+
out_channels (List[int], optional): Output channels for each intermediate layer.
|
37 |
+
intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
|
38 |
+
pos_embed (bool, optional): Whether to use positional embedding. Default is True.
|
39 |
+
feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
|
40 |
+
down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
dim_in: int,
|
46 |
+
patch_size: int = 14,
|
47 |
+
output_dim: int = 4,
|
48 |
+
activation: str = "inv_log",
|
49 |
+
conf_activation: str = "expp1",
|
50 |
+
features: int = 256,
|
51 |
+
out_channels: List[int] = [256, 512, 1024, 1024],
|
52 |
+
intermediate_layer_idx: List[int] = [4, 11, 17, 23],
|
53 |
+
pos_embed: bool = True,
|
54 |
+
feature_only: bool = False,
|
55 |
+
down_ratio: int = 1,
|
56 |
+
) -> None:
|
57 |
+
super(DPTHead, self).__init__()
|
58 |
+
self.patch_size = patch_size
|
59 |
+
self.activation = activation
|
60 |
+
self.conf_activation = conf_activation
|
61 |
+
self.pos_embed = pos_embed
|
62 |
+
self.feature_only = feature_only
|
63 |
+
self.down_ratio = down_ratio
|
64 |
+
self.intermediate_layer_idx = intermediate_layer_idx
|
65 |
+
|
66 |
+
self.norm = nn.LayerNorm(dim_in)
|
67 |
+
|
68 |
+
# Projection layers for each output channel from tokens.
|
69 |
+
self.projects = nn.ModuleList(
|
70 |
+
[
|
71 |
+
nn.Conv2d(
|
72 |
+
in_channels=dim_in,
|
73 |
+
out_channels=oc,
|
74 |
+
kernel_size=1,
|
75 |
+
stride=1,
|
76 |
+
padding=0,
|
77 |
+
)
|
78 |
+
for oc in out_channels
|
79 |
+
]
|
80 |
+
)
|
81 |
+
|
82 |
+
# Resize layers for upsampling feature maps.
|
83 |
+
self.resize_layers = nn.ModuleList(
|
84 |
+
[
|
85 |
+
nn.ConvTranspose2d(
|
86 |
+
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
87 |
+
),
|
88 |
+
nn.ConvTranspose2d(
|
89 |
+
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
90 |
+
),
|
91 |
+
nn.Identity(),
|
92 |
+
nn.Conv2d(
|
93 |
+
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
94 |
+
),
|
95 |
+
]
|
96 |
+
)
|
97 |
+
|
98 |
+
self.scratch = _make_scratch(
|
99 |
+
out_channels,
|
100 |
+
features,
|
101 |
+
expand=False,
|
102 |
+
)
|
103 |
+
|
104 |
+
# Attach additional modules to scratch.
|
105 |
+
self.scratch.stem_transpose = None
|
106 |
+
self.scratch.refinenet1 = _make_fusion_block(features)
|
107 |
+
self.scratch.refinenet2 = _make_fusion_block(features)
|
108 |
+
self.scratch.refinenet3 = _make_fusion_block(features)
|
109 |
+
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
|
110 |
+
|
111 |
+
head_features_1 = features
|
112 |
+
head_features_2 = 32
|
113 |
+
|
114 |
+
if feature_only:
|
115 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
|
116 |
+
else:
|
117 |
+
self.scratch.output_conv1 = nn.Conv2d(
|
118 |
+
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
119 |
+
)
|
120 |
+
conv2_in_channels = head_features_1 // 2
|
121 |
+
|
122 |
+
self.scratch.output_conv2 = nn.Sequential(
|
123 |
+
nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
|
124 |
+
nn.ReLU(inplace=True),
|
125 |
+
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
|
126 |
+
)
|
127 |
+
|
128 |
+
def forward(
|
129 |
+
self,
|
130 |
+
aggregated_tokens_list: List[torch.Tensor],
|
131 |
+
images: torch.Tensor,
|
132 |
+
patch_start_idx: int,
|
133 |
+
frames_chunk_size: int = 8,
|
134 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
135 |
+
"""
|
136 |
+
Forward pass through the DPT head, supports processing by chunking frames.
|
137 |
+
Args:
|
138 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
139 |
+
images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
|
140 |
+
patch_start_idx (int): Starting index for patch tokens in the token sequence.
|
141 |
+
Used to separate patch tokens from other tokens (e.g., camera or register tokens).
|
142 |
+
frames_chunk_size (int, optional): Number of frames to process in each chunk.
|
143 |
+
If None or larger than S, all frames are processed at once. Default: 8.
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
Tensor or Tuple[Tensor, Tensor]:
|
147 |
+
- If feature_only=True: Feature maps with shape [B, S, C, H, W]
|
148 |
+
- Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
|
149 |
+
"""
|
150 |
+
B, S, _, H, W = images.shape
|
151 |
+
|
152 |
+
# If frames_chunk_size is not specified or greater than S, process all frames at once
|
153 |
+
if frames_chunk_size is None or frames_chunk_size >= S:
|
154 |
+
return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
|
155 |
+
|
156 |
+
# Otherwise, process frames in chunks to manage memory usage
|
157 |
+
assert frames_chunk_size > 0
|
158 |
+
|
159 |
+
# Process frames in batches
|
160 |
+
all_preds = []
|
161 |
+
all_conf = []
|
162 |
+
|
163 |
+
for frames_start_idx in range(0, S, frames_chunk_size):
|
164 |
+
frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
|
165 |
+
|
166 |
+
# Process batch of frames
|
167 |
+
if self.feature_only:
|
168 |
+
chunk_output = self._forward_impl(
|
169 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
170 |
+
)
|
171 |
+
all_preds.append(chunk_output)
|
172 |
+
else:
|
173 |
+
chunk_preds, chunk_conf = self._forward_impl(
|
174 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
175 |
+
)
|
176 |
+
all_preds.append(chunk_preds)
|
177 |
+
all_conf.append(chunk_conf)
|
178 |
+
|
179 |
+
# Concatenate results along the sequence dimension
|
180 |
+
if self.feature_only:
|
181 |
+
return torch.cat(all_preds, dim=1)
|
182 |
+
else:
|
183 |
+
return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
|
184 |
+
|
185 |
+
def _forward_impl(
|
186 |
+
self,
|
187 |
+
aggregated_tokens_list: List[torch.Tensor],
|
188 |
+
images: torch.Tensor,
|
189 |
+
patch_start_idx: int,
|
190 |
+
frames_start_idx: int = None,
|
191 |
+
frames_end_idx: int = None,
|
192 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
193 |
+
"""
|
194 |
+
Implementation of the forward pass through the DPT head.
|
195 |
+
|
196 |
+
This method processes a specific chunk of frames from the sequence.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
200 |
+
images (Tensor): Input images with shape [B, S, 3, H, W].
|
201 |
+
patch_start_idx (int): Starting index for patch tokens.
|
202 |
+
frames_start_idx (int, optional): Starting index for frames to process.
|
203 |
+
frames_end_idx (int, optional): Ending index for frames to process.
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
|
207 |
+
"""
|
208 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
209 |
+
images = images[:, frames_start_idx:frames_end_idx].contiguous()
|
210 |
+
|
211 |
+
B, S, _, H, W = images.shape
|
212 |
+
|
213 |
+
patch_h, patch_w = H // self.patch_size, W // self.patch_size
|
214 |
+
|
215 |
+
out = []
|
216 |
+
dpt_idx = 0
|
217 |
+
|
218 |
+
for layer_idx in self.intermediate_layer_idx:
|
219 |
+
x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
|
220 |
+
|
221 |
+
# Select frames if processing a chunk
|
222 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
223 |
+
x = x[:, frames_start_idx:frames_end_idx]
|
224 |
+
|
225 |
+
x = x.view(B * S, -1, x.shape[-1])
|
226 |
+
|
227 |
+
x = self.norm(x)
|
228 |
+
|
229 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
230 |
+
|
231 |
+
x = self.projects[dpt_idx](x)
|
232 |
+
if self.pos_embed:
|
233 |
+
x = self._apply_pos_embed(x, W, H)
|
234 |
+
x = self.resize_layers[dpt_idx](x)
|
235 |
+
|
236 |
+
out.append(x)
|
237 |
+
dpt_idx += 1
|
238 |
+
|
239 |
+
# Fuse features from multiple layers.
|
240 |
+
out = self.scratch_forward(out)
|
241 |
+
# Interpolate fused output to match target image resolution.
|
242 |
+
out = custom_interpolate(
|
243 |
+
out,
|
244 |
+
(int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
|
245 |
+
mode="bilinear",
|
246 |
+
align_corners=True,
|
247 |
+
)
|
248 |
+
|
249 |
+
if self.pos_embed:
|
250 |
+
out = self._apply_pos_embed(out, W, H)
|
251 |
+
|
252 |
+
if self.feature_only:
|
253 |
+
return out.view(B, S, *out.shape[1:])
|
254 |
+
|
255 |
+
out = self.scratch.output_conv2(out)
|
256 |
+
preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
|
257 |
+
|
258 |
+
preds = preds.view(B, S, *preds.shape[1:])
|
259 |
+
conf = conf.view(B, S, *conf.shape[1:])
|
260 |
+
return preds, conf
|
261 |
+
|
262 |
+
def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
263 |
+
"""
|
264 |
+
Apply positional embedding to tensor x.
|
265 |
+
"""
|
266 |
+
patch_w = x.shape[-1]
|
267 |
+
patch_h = x.shape[-2]
|
268 |
+
pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
269 |
+
pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
|
270 |
+
pos_embed = pos_embed * ratio
|
271 |
+
pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
|
272 |
+
return x + pos_embed
|
273 |
+
|
274 |
+
def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
|
275 |
+
"""
|
276 |
+
Forward pass through the fusion blocks.
|
277 |
+
|
278 |
+
Args:
|
279 |
+
features (List[Tensor]): List of feature maps from different layers.
|
280 |
+
|
281 |
+
Returns:
|
282 |
+
Tensor: Fused feature map.
|
283 |
+
"""
|
284 |
+
layer_1, layer_2, layer_3, layer_4 = features
|
285 |
+
|
286 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
287 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
288 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
289 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
290 |
+
|
291 |
+
out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
292 |
+
del layer_4_rn, layer_4
|
293 |
+
|
294 |
+
out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
|
295 |
+
del layer_3_rn, layer_3
|
296 |
+
|
297 |
+
out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
|
298 |
+
del layer_2_rn, layer_2
|
299 |
+
|
300 |
+
out = self.scratch.refinenet1(out, layer_1_rn)
|
301 |
+
del layer_1_rn, layer_1
|
302 |
+
|
303 |
+
out = self.scratch.output_conv1(out)
|
304 |
+
return out
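scratch_forward runs the DPT fusion coarse-to-fine: the deepest feature map is refined first, each shallower map is merged in at its own resolution, and intermediates are deleted as soon as they are consumed. The sketch below only illustrates that ordering, with a plain upsample-and-add standing in for FeatureFusionBlock; channel counts and sizes are made up:

import torch
import torch.nn.functional as F

def fuse_coarse_to_fine(feats):
    # feats: [B, C, H_i, W_i] maps ordered shallow (high-res) -> deep (low-res),
    # mirroring layer_1 .. layer_4 above; the real head fuses with residual conv units.
    out = feats[-1]
    for skip in reversed(feats[:-1]):
        out = F.interpolate(out, size=skip.shape[2:], mode="bilinear", align_corners=True)
        out = out + skip
    return out

feats = [torch.randn(1, 64, 2 ** (5 - i), 2 ** (5 - i)) for i in range(4)]  # 32, 16, 8, 4 px
assert fuse_coarse_to_fine(feats).shape == feats[0].shape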
|
305 |
+
|
306 |
+
|
307 |
+
################################################################################
|
308 |
+
# Modules
|
309 |
+
################################################################################
|
310 |
+
|
311 |
+
|
312 |
+
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
|
313 |
+
return FeatureFusionBlock(
|
314 |
+
features,
|
315 |
+
nn.ReLU(inplace=True),
|
316 |
+
deconv=False,
|
317 |
+
bn=False,
|
318 |
+
expand=False,
|
319 |
+
align_corners=True,
|
320 |
+
size=size,
|
321 |
+
has_residual=has_residual,
|
322 |
+
groups=groups,
|
323 |
+
)
|
324 |
+
|
325 |
+
|
326 |
+
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
|
327 |
+
scratch = nn.Module()
|
328 |
+
out_shape1 = out_shape
|
329 |
+
out_shape2 = out_shape
|
330 |
+
out_shape3 = out_shape
|
331 |
+
if len(in_shape) >= 4:
|
332 |
+
out_shape4 = out_shape
|
333 |
+
|
334 |
+
if expand:
|
335 |
+
out_shape1 = out_shape
|
336 |
+
out_shape2 = out_shape * 2
|
337 |
+
out_shape3 = out_shape * 4
|
338 |
+
if len(in_shape) >= 4:
|
339 |
+
out_shape4 = out_shape * 8
|
340 |
+
|
341 |
+
scratch.layer1_rn = nn.Conv2d(
|
342 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
343 |
+
)
|
344 |
+
scratch.layer2_rn = nn.Conv2d(
|
345 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
346 |
+
)
|
347 |
+
scratch.layer3_rn = nn.Conv2d(
|
348 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
349 |
+
)
|
350 |
+
if len(in_shape) >= 4:
|
351 |
+
scratch.layer4_rn = nn.Conv2d(
|
352 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
353 |
+
)
|
354 |
+
return scratch
|
355 |
+
|
356 |
+
|
357 |
+
class ResidualConvUnit(nn.Module):
|
358 |
+
"""Residual convolution module."""
|
359 |
+
|
360 |
+
def __init__(self, features, activation, bn, groups=1):
|
361 |
+
"""Init.
|
362 |
+
|
363 |
+
Args:
|
364 |
+
features (int): number of features
|
365 |
+
"""
|
366 |
+
super().__init__()
|
367 |
+
|
368 |
+
self.bn = bn
|
369 |
+
self.groups = groups
|
370 |
+
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
371 |
+
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
372 |
+
|
373 |
+
self.norm1 = None
|
374 |
+
self.norm2 = None
|
375 |
+
|
376 |
+
self.activation = activation
|
377 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
378 |
+
|
379 |
+
def forward(self, x):
|
380 |
+
"""Forward pass.
|
381 |
+
|
382 |
+
Args:
|
383 |
+
x (tensor): input
|
384 |
+
|
385 |
+
Returns:
|
386 |
+
tensor: output
|
387 |
+
"""
|
388 |
+
|
389 |
+
out = self.activation(x)
|
390 |
+
out = self.conv1(out)
|
391 |
+
if self.norm1 is not None:
|
392 |
+
out = self.norm1(out)
|
393 |
+
|
394 |
+
out = self.activation(out)
|
395 |
+
out = self.conv2(out)
|
396 |
+
if self.norm2 is not None:
|
397 |
+
out = self.norm2(out)
|
398 |
+
|
399 |
+
return self.skip_add.add(out, x)
|
400 |
+
|
401 |
+
|
402 |
+
class FeatureFusionBlock(nn.Module):
|
403 |
+
"""Feature fusion block."""
|
404 |
+
|
405 |
+
def __init__(
|
406 |
+
self,
|
407 |
+
features,
|
408 |
+
activation,
|
409 |
+
deconv=False,
|
410 |
+
bn=False,
|
411 |
+
expand=False,
|
412 |
+
align_corners=True,
|
413 |
+
size=None,
|
414 |
+
has_residual=True,
|
415 |
+
groups=1,
|
416 |
+
):
|
417 |
+
"""Init.
|
418 |
+
|
419 |
+
Args:
|
420 |
+
features (int): number of features
|
421 |
+
"""
|
422 |
+
super(FeatureFusionBlock, self).__init__()
|
423 |
+
|
424 |
+
self.deconv = deconv
|
425 |
+
self.align_corners = align_corners
|
426 |
+
self.groups = groups
|
427 |
+
self.expand = expand
|
428 |
+
out_features = features
|
429 |
+
if self.expand == True:
|
430 |
+
out_features = features // 2
|
431 |
+
|
432 |
+
self.out_conv = nn.Conv2d(
|
433 |
+
features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
|
434 |
+
)
|
435 |
+
|
436 |
+
if has_residual:
|
437 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
438 |
+
|
439 |
+
self.has_residual = has_residual
|
440 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
441 |
+
|
442 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
443 |
+
self.size = size
|
444 |
+
|
445 |
+
def forward(self, *xs, size=None):
|
446 |
+
"""Forward pass.
|
447 |
+
|
448 |
+
Returns:
|
449 |
+
tensor: output
|
450 |
+
"""
|
451 |
+
output = xs[0]
|
452 |
+
|
453 |
+
if self.has_residual:
|
454 |
+
res = self.resConfUnit1(xs[1])
|
455 |
+
output = self.skip_add.add(output, res)
|
456 |
+
|
457 |
+
output = self.resConfUnit2(output)
|
458 |
+
|
459 |
+
if (size is None) and (self.size is None):
|
460 |
+
modifier = {"scale_factor": 2}
|
461 |
+
elif size is None:
|
462 |
+
modifier = {"size": self.size}
|
463 |
+
else:
|
464 |
+
modifier = {"size": size}
|
465 |
+
|
466 |
+
output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
467 |
+
output = self.out_conv(output)
|
468 |
+
|
469 |
+
return output
|
470 |
+
|
471 |
+
|
472 |
+
def custom_interpolate(
|
473 |
+
x: torch.Tensor,
|
474 |
+
size: Tuple[int, int] = None,
|
475 |
+
scale_factor: float = None,
|
476 |
+
mode: str = "bilinear",
|
477 |
+
align_corners: bool = True,
|
478 |
+
) -> torch.Tensor:
|
479 |
+
"""
|
480 |
+
Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
|
481 |
+
"""
|
482 |
+
if size is None:
|
483 |
+
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
484 |
+
|
485 |
+
INT_MAX = 1610612736
|
486 |
+
|
487 |
+
input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
|
488 |
+
|
489 |
+
if input_elements > INT_MAX:
|
490 |
+
chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
|
491 |
+
interpolated_chunks = [
|
492 |
+
nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
|
493 |
+
]
|
494 |
+
x = torch.cat(interpolated_chunks, dim=0)
|
495 |
+
return x.contiguous()
|
496 |
+
else:
|
497 |
+
return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
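custom_interpolate splits the batch into chunks whenever the number of output elements would exceed a conservative 32-bit indexing threshold, which F.interpolate can otherwise trip over on very large batched inputs. A purely illustrative way to see how many chunks a given request would be split into (the threshold constant is the one used above):

INT_MAX = 1610612736  # same threshold as custom_interpolate above

def num_chunks(batch: int, channels: int, out_h: int, out_w: int) -> int:
    elements = out_h * out_w * batch * channels
    return 1 if elements <= INT_MAX else (elements // INT_MAX) + 1

print(num_chunks(8, 128, 518, 518))    # 1  -> a single F.interpolate call
print(num_chunks(512, 128, 518, 518))  # 11 -> the batch gets split into chunks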
|
models/SpaTrackV2/models/vggt4track/heads/head_act.py
ADDED
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import torch
import torch.nn.functional as F


def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
    """
    Activate pose parameters with specified activation functions.

    Args:
        pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
        trans_act: Activation type for translation component
        quat_act: Activation type for quaternion component
        fl_act: Activation type for focal length component

    Returns:
        Activated pose parameters tensor
    """
    T = pred_pose_enc[..., :3]
    quat = pred_pose_enc[..., 3:7]
    fl = pred_pose_enc[..., 7:]  # or fov

    T = base_pose_act(T, trans_act)
    quat = base_pose_act(quat, quat_act)
    fl = base_pose_act(fl, fl_act)  # or fov

    pred_pose_enc = torch.cat([T, quat, fl], dim=-1)

    return pred_pose_enc


def base_pose_act(pose_enc, act_type="linear"):
    """
    Apply basic activation function to pose parameters.

    Args:
        pose_enc: Tensor containing encoded pose parameters
        act_type: Activation type ("linear", "inv_log", "exp", "relu")

    Returns:
        Activated pose parameters
    """
    if act_type == "linear":
        return pose_enc
    elif act_type == "inv_log":
        return inverse_log_transform(pose_enc)
    elif act_type == "exp":
        return torch.exp(pose_enc)
    elif act_type == "relu":
        return F.relu(pose_enc)
    else:
        raise ValueError(f"Unknown act_type: {act_type}")


def activate_head(out, activation="norm_exp", conf_activation="expp1"):
    """
    Process network output to extract 3D points and confidence values.

    Args:
        out: Network output tensor (B, C, H, W)
        activation: Activation type for 3D points
        conf_activation: Activation type for confidence values

    Returns:
        Tuple of (3D points tensor, confidence tensor)
    """
    # Move channels to the last dimension => (B, H, W, C)
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,C expected

    # Split into xyz (first C-1 channels) and confidence (last channel)
    xyz = fmap[:, :, :, :-1]
    conf = fmap[:, :, :, -1]

    if activation == "norm_exp":
        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        xyz_normed = xyz / d
        pts3d = xyz_normed * torch.expm1(d)
    elif activation == "norm":
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
    elif activation == "exp":
        pts3d = torch.exp(xyz)
    elif activation == "relu":
        pts3d = F.relu(xyz)
    elif activation == "inv_log":
        pts3d = inverse_log_transform(xyz)
    elif activation == "xy_inv_log":
        xy, z = xyz.split([2, 1], dim=-1)
        z = inverse_log_transform(z)
        pts3d = torch.cat([xy * z, z], dim=-1)
    elif activation == "sigmoid":
        pts3d = torch.sigmoid(xyz)
    elif activation == "linear":
        pts3d = xyz
    else:
        raise ValueError(f"Unknown activation: {activation}")

    if conf_activation == "expp1":
        conf_out = 1 + conf.exp()
    elif conf_activation == "expp0":
        conf_out = conf.exp()
    elif conf_activation == "sigmoid":
        conf_out = torch.sigmoid(conf)
    else:
        raise ValueError(f"Unknown conf_activation: {conf_activation}")

    return pts3d, conf_out


def inverse_log_transform(y):
    """
    Apply inverse log transform: sign(y) * (exp(|y|) - 1)

    Args:
        y: Input tensor

    Returns:
        Transformed tensor
    """
    return torch.sign(y) * (torch.expm1(torch.abs(y)))
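inverse_log_transform and the norm_exp branch of activate_head are expm1-style magnitude expansions: the direction (or sign) is preserved and only the magnitude is stretched, which keeps values near zero roughly linear while still allowing large depths. A small numeric check that inverse_log_transform undoes the matching log compression, assuming the forward map is sign(x) * log(1 + |x|):

import torch

def log_transform(x):
    # assumed forward map: sign(x) * log(1 + |x|)
    return torch.sign(x) * torch.log1p(torch.abs(x))

x = torch.tensor([-5.0, -0.5, 0.0, 0.5, 5.0])
y = log_transform(x)
x_rec = torch.sign(y) * torch.expm1(torch.abs(y))  # inverse_log_transform(y)
assert torch.allclose(x, x_rec, atol=1e-6)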
models/SpaTrackV2/models/vggt4track/heads/scale_head.py
ADDED
@@ -0,0 +1,162 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
from models.SpaTrackV2.models.vggt4track.layers import Mlp
|
15 |
+
from models.SpaTrackV2.models.vggt4track.layers.block import Block
|
16 |
+
from models.SpaTrackV2.models.vggt4track.heads.head_act import activate_pose
|
17 |
+
|
18 |
+
|
19 |
+
class ScaleHead(nn.Module):
|
20 |
+
"""
|
21 |
+
ScaleHead predicts a 2-parameter scale encoding (target_dim=2) from token representations using iterative refinement.
|
22 |
+
|
23 |
+
It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
dim_in: int = 2048,
|
29 |
+
trunk_depth: int = 4,
|
30 |
+
pose_encoding_type: str = "absT_quaR_FoV",
|
31 |
+
num_heads: int = 16,
|
32 |
+
mlp_ratio: int = 4,
|
33 |
+
init_values: float = 0.01,
|
34 |
+
trans_act: str = "linear",
|
35 |
+
quat_act: str = "linear",
|
36 |
+
fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
self.target_dim = 2
|
41 |
+
|
42 |
+
self.trans_act = trans_act
|
43 |
+
self.quat_act = quat_act
|
44 |
+
self.fl_act = fl_act
|
45 |
+
self.trunk_depth = trunk_depth
|
46 |
+
|
47 |
+
# Build the trunk using a sequence of transformer blocks.
|
48 |
+
self.trunk = nn.Sequential(
|
49 |
+
*[
|
50 |
+
Block(
|
51 |
+
dim=dim_in,
|
52 |
+
num_heads=num_heads,
|
53 |
+
mlp_ratio=mlp_ratio,
|
54 |
+
init_values=init_values,
|
55 |
+
)
|
56 |
+
for _ in range(trunk_depth)
|
57 |
+
]
|
58 |
+
)
|
59 |
+
|
60 |
+
# Normalizations for camera token and trunk output.
|
61 |
+
self.token_norm = nn.LayerNorm(dim_in)
|
62 |
+
self.trunk_norm = nn.LayerNorm(dim_in)
|
63 |
+
|
64 |
+
# Learnable empty camera pose token.
|
65 |
+
self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
|
66 |
+
self.embed_pose = nn.Linear(self.target_dim, dim_in)
|
67 |
+
|
68 |
+
# Module for producing modulation parameters: shift, scale, and a gate.
|
69 |
+
self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
|
70 |
+
|
71 |
+
# Adaptive layer normalization without affine parameters.
|
72 |
+
self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
|
73 |
+
self.pose_branch = Mlp(
|
74 |
+
in_features=dim_in,
|
75 |
+
hidden_features=dim_in // 2,
|
76 |
+
out_features=self.target_dim,
|
77 |
+
drop=0,
|
78 |
+
)
|
79 |
+
|
80 |
+
def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
|
81 |
+
"""
|
82 |
+
Forward pass to predict camera parameters.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
aggregated_tokens_list (list): List of token tensors from the network;
|
86 |
+
the last tensor is used for prediction.
|
87 |
+
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
list: A list of predicted camera encodings (post-activation) from each iteration.
|
91 |
+
"""
|
92 |
+
# Use tokens from the last block for camera prediction.
|
93 |
+
tokens = aggregated_tokens_list[-1]
|
94 |
+
|
95 |
+
# Extract the camera tokens
|
96 |
+
pose_tokens = tokens[:, :, 5]
|
97 |
+
pose_tokens = self.token_norm(pose_tokens)
|
98 |
+
|
99 |
+
pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
|
100 |
+
return pred_pose_enc_list
|
101 |
+
|
102 |
+
def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
|
103 |
+
"""
|
104 |
+
Iteratively refine camera pose predictions.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
|
108 |
+
num_iterations (int): Number of refinement iterations.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
list: List of activated camera encodings from each iteration.
|
112 |
+
"""
|
113 |
+
B, S, C = pose_tokens.shape # S is expected to be 1.
|
114 |
+
pred_pose_enc = None
|
115 |
+
pred_pose_enc_list = []
|
116 |
+
|
117 |
+
for _ in range(num_iterations):
|
118 |
+
# Use a learned empty pose for the first iteration.
|
119 |
+
if pred_pose_enc is None:
|
120 |
+
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
|
121 |
+
else:
|
122 |
+
# Detach the previous prediction to avoid backprop through time.
|
123 |
+
pred_pose_enc = pred_pose_enc.detach()
|
124 |
+
module_input = self.embed_pose(pred_pose_enc)
|
125 |
+
|
126 |
+
# Generate modulation parameters and split them into shift, scale, and gate components.
|
127 |
+
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
|
128 |
+
|
129 |
+
# Adaptive layer normalization and modulation.
|
130 |
+
pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
|
131 |
+
pose_tokens_modulated = pose_tokens_modulated + pose_tokens
|
132 |
+
|
133 |
+
pose_tokens_modulated = self.trunk(pose_tokens_modulated)
|
134 |
+
# Compute the delta update for the pose encoding.
|
135 |
+
pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
|
136 |
+
|
137 |
+
if pred_pose_enc is None:
|
138 |
+
pred_pose_enc = pred_pose_enc_delta
|
139 |
+
else:
|
140 |
+
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
|
141 |
+
|
142 |
+
# Apply final activation functions for translation, quaternion, and field-of-view.
|
143 |
+
activated_pose = activate_pose(
|
144 |
+
pred_pose_enc,
|
145 |
+
trans_act=self.trans_act,
|
146 |
+
quat_act=self.quat_act,
|
147 |
+
fl_act=self.fl_act,
|
148 |
+
)
|
149 |
+
activated_pose_proc = activated_pose.clone()
|
150 |
+
activated_pose_proc[...,:1] = activated_pose_proc[...,:1].clamp(min=1e-5, max=1e3)
|
151 |
+
activated_pose_proc[...,1:] = activated_pose_proc[...,1:]*1e-2
|
152 |
+
pred_pose_enc_list.append(activated_pose_proc)
|
153 |
+
|
154 |
+
return pred_pose_enc_list
|
155 |
+
|
156 |
+
|
157 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
158 |
+
"""
|
159 |
+
Modulate the input tensor using scaling and shifting parameters.
|
160 |
+
"""
|
161 |
+
# modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
|
162 |
+
return x * (1 + scale) + shift
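modulate is the DiT-style conditioning used in trunk_fn above: the previous prediction is embedded, projected to (shift, scale, gate), and injected into the token stream around a LayerNorm that has no affine parameters of its own. A minimal sketch of that wiring with illustrative dimensions (the module names here are not the ones defined above):

import torch
import torch.nn as nn

dim = 64
adaln_norm = nn.LayerNorm(dim, elementwise_affine=False)
to_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 3 * dim))

tokens = torch.randn(2, 1, dim)  # [B, S, C] token being refined
cond = torch.randn(2, 1, dim)    # embedded previous estimate (or a learned empty token)

shift, scale, gate = to_mod(cond).chunk(3, dim=-1)
modulated = gate * (adaln_norm(tokens) * (1 + scale) + shift)  # modulate(...)
tokens = tokens + modulated                                    # residual update, as in trunk_fn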
|
models/SpaTrackV2/models/vggt4track/heads/track_head.py
ADDED
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
from .dpt_head import DPTHead
from .track_modules.base_track_predictor import BaseTrackerPredictor


class TrackHead(nn.Module):
    """
    Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
    The tracking is performed iteratively, refining predictions over multiple iterations.
    """

    def __init__(
        self,
        dim_in,
        patch_size=14,
        features=128,
        iters=4,
        predict_conf=True,
        stride=2,
        corr_levels=7,
        corr_radius=4,
        hidden_size=384,
    ):
        """
        Initialize the TrackHead module.

        Args:
            dim_in (int): Input dimension of tokens from the backbone.
            patch_size (int): Size of image patches used in the vision transformer.
            features (int): Number of feature channels in the feature extractor output.
            iters (int): Number of refinement iterations for tracking predictions.
            predict_conf (bool): Whether to predict confidence scores for tracked points.
            stride (int): Stride value for the tracker predictor.
            corr_levels (int): Number of correlation pyramid levels
            corr_radius (int): Radius for correlation computation, controlling the search area.
            hidden_size (int): Size of hidden layers in the tracker network.
        """
        super().__init__()

        self.patch_size = patch_size

        # Feature extractor based on DPT architecture
        # Processes tokens into feature maps for tracking
        self.feature_extractor = DPTHead(
            dim_in=dim_in,
            patch_size=patch_size,
            features=features,
            feature_only=True,  # Only output features, no activation
            down_ratio=2,  # Reduces spatial dimensions by factor of 2
            pos_embed=False,
        )

        # Tracker module that predicts point trajectories
        # Takes feature maps and predicts coordinates and visibility
        self.tracker = BaseTrackerPredictor(
            latent_dim=features,  # Match the output_dim of feature extractor
            predict_conf=predict_conf,
            stride=stride,
            corr_levels=corr_levels,
            corr_radius=corr_radius,
            hidden_size=hidden_size,
        )

        self.iters = iters

    def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
        """
        Forward pass of the TrackHead.

        Args:
            aggregated_tokens_list (list): List of aggregated tokens from the backbone.
            images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
                B = batch size, S = sequence length.
            patch_start_idx (int): Starting index for patch tokens.
            query_points (torch.Tensor, optional): Initial query points to track.
                If None, points are initialized by the tracker.
            iters (int, optional): Number of refinement iterations. If None, uses self.iters.

        Returns:
            tuple:
                - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
                - vis_scores (torch.Tensor): Visibility scores for tracked points.
                - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
        """
        B, S, _, H, W = images.shape

        # Extract features from tokens
        # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
        feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)

        # Use default iterations if not specified
        if iters is None:
            iters = self.iters

        # Perform tracking using the extracted features
        coord_preds, vis_scores, conf_scores = self.tracker(
            query_points=query_points,
            fmaps=feature_maps,
            iters=iters,
        )

        return coord_preds, vis_scores, conf_scores
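For shape bookkeeping: the feature extractor above is built with down_ratio=2, so for [B, S, 3, H, W] input images the tracker sees feature maps of size [B, S, features, H/2, W/2], and BaseTrackerPredictor rescales its predictions back to image pixels. A tiny illustrative check (H = W = 518, N = 64 and the iteration count are assumptions, not required values):

B, S, H, W = 1, 8, 518, 518
features, down_ratio = 128, 2
fmap_shape = (B, S, features, H // down_ratio, W // down_ratio)
print(fmap_shape)  # (1, 8, 128, 259, 259)

N, iters = 64, 4
query_points_shape = (B, N, 2)               # pixel coordinates in the first frame
coord_preds_shapes = [(B, S, N, 2)] * iters  # one refined estimate per iteration
vis_scores_shape = (B, S, N)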
models/SpaTrackV2/models/vggt4track/heads/track_modules/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
models/SpaTrackV2/models/vggt4track/heads/track_modules/base_track_predictor.py
ADDED
@@ -0,0 +1,209 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
|
12 |
+
from .blocks import EfficientUpdateFormer, CorrBlock
|
13 |
+
from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
|
14 |
+
from .modules import Mlp
|
15 |
+
|
16 |
+
|
17 |
+
class BaseTrackerPredictor(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
stride=1,
|
21 |
+
corr_levels=5,
|
22 |
+
corr_radius=4,
|
23 |
+
latent_dim=128,
|
24 |
+
hidden_size=384,
|
25 |
+
use_spaceatt=True,
|
26 |
+
depth=6,
|
27 |
+
max_scale=518,
|
28 |
+
predict_conf=True,
|
29 |
+
):
|
30 |
+
super(BaseTrackerPredictor, self).__init__()
|
31 |
+
"""
|
32 |
+
The base template to create a track predictor
|
33 |
+
|
34 |
+
Modified from https://github.com/facebookresearch/co-tracker/
|
35 |
+
and https://github.com/facebookresearch/vggsfm
|
36 |
+
"""
|
37 |
+
|
38 |
+
self.stride = stride
|
39 |
+
self.latent_dim = latent_dim
|
40 |
+
self.corr_levels = corr_levels
|
41 |
+
self.corr_radius = corr_radius
|
42 |
+
self.hidden_size = hidden_size
|
43 |
+
self.max_scale = max_scale
|
44 |
+
self.predict_conf = predict_conf
|
45 |
+
|
46 |
+
self.flows_emb_dim = latent_dim // 2
|
47 |
+
|
48 |
+
self.corr_mlp = Mlp(
|
49 |
+
in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
|
50 |
+
hidden_features=self.hidden_size,
|
51 |
+
out_features=self.latent_dim,
|
52 |
+
)
|
53 |
+
|
54 |
+
self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
|
55 |
+
|
56 |
+
self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
|
57 |
+
|
58 |
+
space_depth = depth if use_spaceatt else 0
|
59 |
+
time_depth = depth
|
60 |
+
|
61 |
+
self.updateformer = EfficientUpdateFormer(
|
62 |
+
space_depth=space_depth,
|
63 |
+
time_depth=time_depth,
|
64 |
+
input_dim=self.transformer_dim,
|
65 |
+
hidden_size=self.hidden_size,
|
66 |
+
output_dim=self.latent_dim + 2,
|
67 |
+
mlp_ratio=4.0,
|
68 |
+
add_space_attn=use_spaceatt,
|
69 |
+
)
|
70 |
+
|
71 |
+
self.fmap_norm = nn.LayerNorm(self.latent_dim)
|
72 |
+
self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
|
73 |
+
|
74 |
+
# A linear layer to update track feats at each iteration
|
75 |
+
self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
|
76 |
+
|
77 |
+
self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
|
78 |
+
|
79 |
+
if predict_conf:
|
80 |
+
self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
|
81 |
+
|
82 |
+
def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
|
83 |
+
"""
|
84 |
+
query_points: B x N x 2, the number of batches, tracks, and xy
|
85 |
+
fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
|
86 |
+
note that HH and WW are the size of the feature maps, not of the original images
|
87 |
+
"""
|
88 |
+
B, N, D = query_points.shape
|
89 |
+
B, S, C, HH, WW = fmaps.shape
|
90 |
+
|
91 |
+
assert D == 2, "Input points must be 2D coordinates"
|
92 |
+
|
93 |
+
# apply a layernorm to fmaps here
|
94 |
+
fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
|
95 |
+
fmaps = fmaps.permute(0, 1, 4, 2, 3)
|
96 |
+
|
97 |
+
# Scale the input query_points because we may downsample the images
|
98 |
+
# by down_ratio or self.stride
|
99 |
+
# e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
|
100 |
+
# its query_points should be query_points/4
|
101 |
+
if down_ratio > 1:
|
102 |
+
query_points = query_points / float(down_ratio)
|
103 |
+
|
104 |
+
query_points = query_points / float(self.stride)
|
105 |
+
|
106 |
+
# Init with coords as the query points
|
107 |
+
# It means the search will start from the position of query points at the reference frames
|
108 |
+
coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
|
109 |
+
|
110 |
+
# Sample/extract the features of the query points in the query frame
|
111 |
+
query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
|
112 |
+
|
113 |
+
# init track feats by query feats
|
114 |
+
track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
|
115 |
+
# back up the init coords
|
116 |
+
coords_backup = coords.clone()
|
117 |
+
|
118 |
+
fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
|
119 |
+
|
120 |
+
coord_preds = []
|
121 |
+
|
122 |
+
# Iterative Refinement
|
123 |
+
for _ in range(iters):
|
124 |
+
# Detach the gradients from the last iteration
|
125 |
+
# (in my experience, not very important for performance)
|
126 |
+
coords = coords.detach()
|
127 |
+
|
128 |
+
fcorrs = fcorr_fn.corr_sample(track_feats, coords)
|
129 |
+
|
130 |
+
corr_dim = fcorrs.shape[3]
|
131 |
+
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
|
132 |
+
fcorrs_ = self.corr_mlp(fcorrs_)
|
133 |
+
|
134 |
+
# Movement of current coords relative to query points
|
135 |
+
flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
136 |
+
|
137 |
+
flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
|
138 |
+
|
139 |
+
# (In my trials, it is also okay to just add the flows_emb instead of concat)
|
140 |
+
flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
|
141 |
+
|
142 |
+
track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
|
143 |
+
|
144 |
+
# Concatenate them as the input for the transformers
|
145 |
+
transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
|
146 |
+
|
147 |
+
# 2D positional embed
|
148 |
+
# TODO: this can be much simplified
|
149 |
+
pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
|
150 |
+
sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
|
151 |
+
|
152 |
+
sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
|
153 |
+
|
154 |
+
x = transformer_input + sampled_pos_emb
|
155 |
+
|
156 |
+
# Add the query ref token to the track feats
|
157 |
+
query_ref_token = torch.cat(
|
158 |
+
[self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
|
159 |
+
)
|
160 |
+
x = x + query_ref_token.to(x.device).to(x.dtype)
|
161 |
+
|
162 |
+
# B, N, S, C
|
163 |
+
x = rearrange(x, "(b n) s d -> b n s d", b=B)
|
164 |
+
|
165 |
+
# Compute the delta coordinates and delta track features
|
166 |
+
delta, _ = self.updateformer(x)
|
167 |
+
|
168 |
+
# BN, S, C
|
169 |
+
delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
|
170 |
+
delta_coords_ = delta[:, :, :2]
|
171 |
+
delta_feats_ = delta[:, :, 2:]
|
172 |
+
|
173 |
+
track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
|
174 |
+
delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
|
175 |
+
|
176 |
+
# Update the track features
|
177 |
+
track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
|
178 |
+
|
179 |
+
track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
|
180 |
+
|
181 |
+
# B x S x N x 2
|
182 |
+
coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
|
183 |
+
|
184 |
+
# Force coord0 as query
|
185 |
+
# because we assume the query points should not be changed
|
186 |
+
coords[:, 0] = coords_backup[:, 0]
|
187 |
+
|
188 |
+
# The predicted tracks are in the original image scale
|
189 |
+
if down_ratio > 1:
|
190 |
+
coord_preds.append(coords * self.stride * down_ratio)
|
191 |
+
else:
|
192 |
+
coord_preds.append(coords * self.stride)
|
193 |
+
|
194 |
+
# B, S, N
|
195 |
+
vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
|
196 |
+
if apply_sigmoid:
|
197 |
+
vis_e = torch.sigmoid(vis_e)
|
198 |
+
|
199 |
+
if self.predict_conf:
|
200 |
+
conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
|
201 |
+
if apply_sigmoid:
|
202 |
+
conf_e = torch.sigmoid(conf_e)
|
203 |
+
else:
|
204 |
+
conf_e = None
|
205 |
+
|
206 |
+
if return_feat:
|
207 |
+
return coord_preds, vis_e, track_feats, query_track_feat, conf_e
|
208 |
+
else:
|
209 |
+
return coord_preds, vis_e, conf_e
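The coordinate bookkeeping in forward above is worth spelling out: query points arrive in image pixels, are divided by down_ratio and then by self.stride to index the feature maps, and every per-iteration prediction is multiplied back by the same factors before being appended to coord_preds. A small round-trip check with illustrative values:

import torch

stride, down_ratio = 2, 2
query_points = torch.tensor([[[100.0, 40.0], [300.0, 200.0]]])  # [B, N, 2] in image pixels

coords_fmap = query_points / float(down_ratio) / float(stride)  # feature-map coordinates
# ... iterative refinement would nudge coords_fmap here ...
coords_img = coords_fmap * stride * down_ratio                  # back to image pixels
assert torch.allclose(coords_img, query_points)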
|
models/SpaTrackV2/models/vggt4track/heads/track_modules/blocks.py
ADDED
@@ -0,0 +1,246 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
# Modified from https://github.com/facebookresearch/co-tracker/
|
9 |
+
|
10 |
+
import math
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
from .utils import bilinear_sampler
|
16 |
+
from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
|
17 |
+
|
18 |
+
|
19 |
+
class EfficientUpdateFormer(nn.Module):
|
20 |
+
"""
|
21 |
+
Transformer model that updates track estimates.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
space_depth=6,
|
27 |
+
time_depth=6,
|
28 |
+
input_dim=320,
|
29 |
+
hidden_size=384,
|
30 |
+
num_heads=8,
|
31 |
+
output_dim=130,
|
32 |
+
mlp_ratio=4.0,
|
33 |
+
add_space_attn=True,
|
34 |
+
num_virtual_tracks=64,
|
35 |
+
):
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
self.out_channels = 2
|
39 |
+
self.num_heads = num_heads
|
40 |
+
self.hidden_size = hidden_size
|
41 |
+
self.add_space_attn = add_space_attn
|
42 |
+
|
43 |
+
# Add input LayerNorm before linear projection
|
44 |
+
self.input_norm = nn.LayerNorm(input_dim)
|
45 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
46 |
+
|
47 |
+
# Add output LayerNorm before final projection
|
48 |
+
self.output_norm = nn.LayerNorm(hidden_size)
|
49 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
50 |
+
self.num_virtual_tracks = num_virtual_tracks
|
51 |
+
|
52 |
+
if self.add_space_attn:
|
53 |
+
self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
|
54 |
+
else:
|
55 |
+
self.virual_tracks = None
|
56 |
+
|
57 |
+
self.time_blocks = nn.ModuleList(
|
58 |
+
[
|
59 |
+
AttnBlock(
|
60 |
+
hidden_size,
|
61 |
+
num_heads,
|
62 |
+
mlp_ratio=mlp_ratio,
|
63 |
+
attn_class=nn.MultiheadAttention,
|
64 |
+
)
|
65 |
+
for _ in range(time_depth)
|
66 |
+
]
|
67 |
+
)
|
68 |
+
|
69 |
+
if add_space_attn:
|
70 |
+
self.space_virtual_blocks = nn.ModuleList(
|
71 |
+
[
|
72 |
+
AttnBlock(
|
73 |
+
hidden_size,
|
74 |
+
num_heads,
|
75 |
+
mlp_ratio=mlp_ratio,
|
76 |
+
attn_class=nn.MultiheadAttention,
|
77 |
+
)
|
78 |
+
for _ in range(space_depth)
|
79 |
+
]
|
80 |
+
)
|
81 |
+
self.space_point2virtual_blocks = nn.ModuleList(
|
82 |
+
[CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
|
83 |
+
)
|
84 |
+
self.space_virtual2point_blocks = nn.ModuleList(
|
85 |
+
[CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
|
86 |
+
)
|
87 |
+
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
88 |
+
self.initialize_weights()
|
89 |
+
|
90 |
+
def initialize_weights(self):
|
91 |
+
def _basic_init(module):
|
92 |
+
if isinstance(module, nn.Linear):
|
93 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
94 |
+
if module.bias is not None:
|
95 |
+
nn.init.constant_(module.bias, 0)
|
96 |
+
torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
|
97 |
+
|
98 |
+
self.apply(_basic_init)
|
99 |
+
|
100 |
+
def forward(self, input_tensor, mask=None):
|
101 |
+
# Apply input LayerNorm
|
102 |
+
input_tensor = self.input_norm(input_tensor)
|
103 |
+
tokens = self.input_transform(input_tensor)
|
104 |
+
|
105 |
+
init_tokens = tokens
|
106 |
+
|
107 |
+
B, _, T, _ = tokens.shape
|
108 |
+
|
109 |
+
if self.add_space_attn:
|
110 |
+
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
111 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
112 |
+
|
113 |
+
_, N, _, _ = tokens.shape
|
114 |
+
|
115 |
+
j = 0
|
116 |
+
for i in range(len(self.time_blocks)):
|
117 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
118 |
+
|
119 |
+
time_tokens = self.time_blocks[i](time_tokens)
|
120 |
+
|
121 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
122 |
+
if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
|
123 |
+
space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
|
124 |
+
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
125 |
+
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
126 |
+
|
127 |
+
virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
|
128 |
+
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
|
129 |
+
point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
|
130 |
+
|
131 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
132 |
+
tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
|
133 |
+
j += 1
|
134 |
+
|
135 |
+
if self.add_space_attn:
|
136 |
+
tokens = tokens[:, : N - self.num_virtual_tracks]
|
137 |
+
|
138 |
+
tokens = tokens + init_tokens
|
139 |
+
|
140 |
+
# Apply output LayerNorm before final projection
|
141 |
+
tokens = self.output_norm(tokens)
|
142 |
+
flow = self.flow_head(tokens)
|
143 |
+
|
144 |
+
return flow, None
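The modulo test above decides on which layers the virtual-track space attention runs: with time_depth equal to space_depth (how BaseTrackerPredictor constructs it when use_spaceatt=True), space attention follows every time block; with fewer space blocks it fires every few layers. A quick way to print the schedule (the depths below are illustrative):

def space_schedule(time_depth: int, space_depth: int):
    # mirrors: i % (len(time_blocks) // len(space_virtual_blocks)) == 0
    step = time_depth // space_depth
    return [i for i in range(time_depth) if i % step == 0]

print(space_schedule(6, 6))  # [0, 1, 2, 3, 4, 5] -> space attention after every time block
print(space_schedule(6, 3))  # [0, 2, 4]          -> every other time block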
|
145 |
+
|
146 |
+
|
147 |
+
class CorrBlock:
|
148 |
+
def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
|
149 |
+
"""
|
150 |
+
Build a pyramid of feature maps from the input.
|
151 |
+
|
152 |
+
fmaps: Tensor (B, S, C, H, W)
|
153 |
+
num_levels: number of pyramid levels (each downsampled by factor 2)
|
154 |
+
radius: search radius for sampling correlation
|
155 |
+
multiple_track_feats: if True, split the target features per pyramid level
|
156 |
+
padding_mode: passed to grid_sample / bilinear_sampler
|
157 |
+
"""
|
158 |
+
B, S, C, H, W = fmaps.shape
|
159 |
+
self.S, self.C, self.H, self.W = S, C, H, W
|
160 |
+
self.num_levels = num_levels
|
161 |
+
self.radius = radius
|
162 |
+
self.padding_mode = padding_mode
|
163 |
+
self.multiple_track_feats = multiple_track_feats
|
164 |
+
|
165 |
+
# Build pyramid: each level is half the spatial resolution of the previous
|
166 |
+
self.fmaps_pyramid = [fmaps] # level 0 is full resolution
|
167 |
+
current_fmaps = fmaps
|
168 |
+
for i in range(num_levels - 1):
|
169 |
+
B, S, C, H, W = current_fmaps.shape
|
170 |
+
# Merge batch & sequence dimensions
|
171 |
+
current_fmaps = current_fmaps.reshape(B * S, C, H, W)
|
172 |
+
# Avg pool down by factor 2
|
173 |
+
current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
|
174 |
+
_, _, H_new, W_new = current_fmaps.shape
|
175 |
+
current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
|
176 |
+
self.fmaps_pyramid.append(current_fmaps)
|
177 |
+
|
178 |
+
# Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
|
179 |
+
# This grid is added to the (scaled) coordinate centroids.
|
180 |
+
r = self.radius
|
181 |
+
dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
|
182 |
+
dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
|
183 |
+
# delta: for every (dy,dx) displacement (i.e. Δx, Δy)
|
184 |
+
self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
|
185 |
+
|
186 |
+
def corr_sample(self, targets, coords):
|
187 |
+
"""
|
188 |
+
Instead of storing the entire correlation pyramid, we compute each level's correlation
|
189 |
+
volume, sample it immediately, then discard it. This saves GPU memory.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
targets: Tensor (B, S, N, C) — features for the current targets.
|
193 |
+
coords: Tensor (B, S, N, 2) — coordinates at full resolution.
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
|
197 |
+
"""
|
198 |
+
B, S, N, C = targets.shape
|
199 |
+
|
200 |
+
# If you have multiple track features, split them per level.
|
201 |
+
if self.multiple_track_feats:
|
202 |
+
targets_split = torch.split(targets, C // self.num_levels, dim=-1)
|
203 |
+
|
204 |
+
out_pyramid = []
|
205 |
+
for i, fmaps in enumerate(self.fmaps_pyramid):
|
206 |
+
# Get current spatial resolution H, W for this pyramid level.
|
207 |
+
B, S, C, H, W = fmaps.shape
|
208 |
+
# Reshape feature maps for correlation computation:
|
209 |
+
# fmap2s: (B, S, C, H*W)
|
210 |
+
fmap2s = fmaps.view(B, S, C, H * W)
|
211 |
+
# Choose appropriate target features.
|
212 |
+
fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
|
213 |
+
|
214 |
+
# Compute correlation directly
|
215 |
+
corrs = compute_corr_level(fmap1, fmap2s, C)
|
216 |
+
corrs = corrs.view(B, S, N, H, W)
|
217 |
+
|
218 |
+
# Prepare sampling grid:
|
219 |
+
# Scale down the coordinates for the current level.
|
220 |
+
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
|
221 |
+
# Make sure our precomputed delta grid is on the same device/dtype.
|
222 |
+
delta_lvl = self.delta.to(coords.device).to(coords.dtype)
|
223 |
+
# Now the grid for grid_sample is:
|
224 |
+
# coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
|
225 |
+
coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
|
226 |
+
|
227 |
+
# Sample from the correlation volume using bilinear interpolation.
|
228 |
+
# We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
|
229 |
+
corrs_sampled = bilinear_sampler(
|
230 |
+
corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
|
231 |
+
)
|
232 |
+
# The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
|
233 |
+
corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
|
234 |
+
out_pyramid.append(corrs_sampled)
|
235 |
+
|
236 |
+
# Concatenate all levels along the last dimension.
|
237 |
+
out = torch.cat(out_pyramid, dim=-1).contiguous()
|
238 |
+
return out
|
239 |
+
|
240 |
+
|
241 |
+
def compute_corr_level(fmap1, fmap2s, C):
|
242 |
+
# fmap1: (B, S, N, C)
|
243 |
+
# fmap2s: (B, S, C, H*W)
|
244 |
+
corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
|
245 |
+
corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
|
246 |
+
return corrs / math.sqrt(C)
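compute_corr_level is a scaled dot product, the same normalization used for attention logits: each track feature is matched against every spatial location of the frame's feature map and divided by sqrt(C). A standalone check that the batched matmul formulation matches an explicit einsum over pixels (all sizes below are made up):

import math
import torch

B, S, N, C, H, W = 1, 2, 3, 16, 8, 8
fmap1 = torch.randn(B, S, N, C)     # per-track features
fmaps = torch.randn(B, S, C, H, W)  # frame feature maps
fmap2s = fmaps.view(B, S, C, H * W)

corrs = torch.matmul(fmap1, fmap2s) / math.sqrt(C)                    # (B, S, N, H*W)
ref = torch.einsum("bsnc,bschw->bsnhw", fmap1, fmaps) / math.sqrt(C)  # explicit version
assert torch.allclose(corrs.view(B, S, N, H, W), ref, atol=1e-5)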
|
models/SpaTrackV2/models/vggt4track/heads/track_modules/modules.py
ADDED
@@ -0,0 +1,218 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from functools import partial
|
12 |
+
from typing import Callable
|
13 |
+
import collections
|
14 |
+
from torch import Tensor
|
15 |
+
from itertools import repeat
|
16 |
+
|
17 |
+
|
18 |
+
# From PyTorch internals
|
19 |
+
def _ntuple(n):
|
20 |
+
def parse(x):
|
21 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
22 |
+
return tuple(x)
|
23 |
+
return tuple(repeat(x, n))
|
24 |
+
|
25 |
+
return parse
|
26 |
+
|
27 |
+
|
28 |
+
def exists(val):
|
29 |
+
return val is not None
|
30 |
+
|
31 |
+
|
32 |
+
def default(val, d):
|
33 |
+
return val if exists(val) else d
|
34 |
+
|
35 |
+
|
36 |
+
to_2tuple = _ntuple(2)
|
37 |
+
|
38 |
+
|
39 |
+
class ResidualBlock(nn.Module):
|
40 |
+
"""
|
41 |
+
ResidualBlock: construct a block of two conv layers with residual connections
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
|
45 |
+
super(ResidualBlock, self).__init__()
|
46 |
+
|
47 |
+
self.conv1 = nn.Conv2d(
|
48 |
+
in_planes,
|
49 |
+
planes,
|
50 |
+
kernel_size=kernel_size,
|
51 |
+
padding=1,
|
52 |
+
stride=stride,
|
53 |
+
padding_mode="zeros",
|
54 |
+
)
|
55 |
+
self.conv2 = nn.Conv2d(
|
56 |
+
planes,
|
57 |
+
planes,
|
58 |
+
kernel_size=kernel_size,
|
59 |
+
padding=1,
|
60 |
+
padding_mode="zeros",
|
61 |
+
)
|
62 |
+
self.relu = nn.ReLU(inplace=True)
|
63 |
+
|
64 |
+
num_groups = planes // 8
|
65 |
+
|
66 |
+
if norm_fn == "group":
|
67 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
68 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
69 |
+
if not stride == 1:
|
70 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
71 |
+
|
72 |
+
elif norm_fn == "batch":
|
73 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
74 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
75 |
+
if not stride == 1:
|
76 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
77 |
+
|
78 |
+
elif norm_fn == "instance":
|
79 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
80 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
81 |
+
if not stride == 1:
|
82 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
83 |
+
|
84 |
+
elif norm_fn == "none":
|
85 |
+
self.norm1 = nn.Sequential()
|
86 |
+
self.norm2 = nn.Sequential()
|
87 |
+
if not stride == 1:
|
88 |
+
self.norm3 = nn.Sequential()
|
89 |
+
else:
|
90 |
+
raise NotImplementedError
|
91 |
+
|
92 |
+
if stride == 1:
|
93 |
+
self.downsample = None
|
94 |
+
else:
|
95 |
+
self.downsample = nn.Sequential(
|
96 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
|
97 |
+
self.norm3,
|
98 |
+
)
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
y = x
|
102 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
103 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
104 |
+
|
105 |
+
if self.downsample is not None:
|
106 |
+
x = self.downsample(x)
|
107 |
+
|
108 |
+
return self.relu(x + y)
|
109 |
+
|
110 |
+
|
111 |
+
class Mlp(nn.Module):
|
112 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
113 |
+
|
114 |
+
def __init__(
|
115 |
+
self,
|
116 |
+
in_features,
|
117 |
+
hidden_features=None,
|
118 |
+
out_features=None,
|
119 |
+
act_layer=nn.GELU,
|
120 |
+
norm_layer=None,
|
121 |
+
bias=True,
|
122 |
+
drop=0.0,
|
123 |
+
use_conv=False,
|
124 |
+
):
|
125 |
+
super().__init__()
|
126 |
+
out_features = out_features or in_features
|
127 |
+
hidden_features = hidden_features or in_features
|
128 |
+
bias = to_2tuple(bias)
|
129 |
+
drop_probs = to_2tuple(drop)
|
130 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
131 |
+
|
132 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
133 |
+
self.act = act_layer()
|
134 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
135 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
136 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
x = self.fc1(x)
|
140 |
+
x = self.act(x)
|
141 |
+
x = self.drop1(x)
|
142 |
+
x = self.fc2(x)
|
143 |
+
x = self.drop2(x)
|
144 |
+
return x
|
145 |
+
|
146 |
+
|
147 |
+
class AttnBlock(nn.Module):
|
148 |
+
def __init__(
|
149 |
+
self,
|
150 |
+
hidden_size,
|
151 |
+
num_heads,
|
152 |
+
attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
|
153 |
+
mlp_ratio=4.0,
|
154 |
+
**block_kwargs
|
155 |
+
):
|
156 |
+
"""
|
157 |
+
Self attention block
|
158 |
+
"""
|
159 |
+
super().__init__()
|
160 |
+
|
161 |
+
self.norm1 = nn.LayerNorm(hidden_size)
|
162 |
+
self.norm2 = nn.LayerNorm(hidden_size)
|
163 |
+
|
164 |
+
self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
|
165 |
+
|
166 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
167 |
+
|
168 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
|
169 |
+
|
170 |
+
def forward(self, x, mask=None):
|
171 |
+
# Prepare the mask for PyTorch's attention (it expects a different format)
|
172 |
+
# attn_mask = mask if mask is not None else None
|
173 |
+
# Normalize before attention
|
174 |
+
x = self.norm1(x)
|
175 |
+
|
176 |
+
# PyTorch's MultiheadAttention returns attn_output, attn_output_weights
|
177 |
+
# attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
|
178 |
+
|
179 |
+
attn_output, _ = self.attn(x, x, x)
|
180 |
+
|
181 |
+
# Add & Norm
|
182 |
+
x = x + attn_output
|
183 |
+
x = x + self.mlp(self.norm2(x))
|
184 |
+
return x
|
185 |
+
|
186 |
+
|
187 |
+
class CrossAttnBlock(nn.Module):
|
188 |
+
def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
|
189 |
+
"""
|
190 |
+
Cross attention block
|
191 |
+
"""
|
192 |
+
super().__init__()
|
193 |
+
|
194 |
+
self.norm1 = nn.LayerNorm(hidden_size)
|
195 |
+
self.norm_context = nn.LayerNorm(hidden_size)
|
196 |
+
self.norm2 = nn.LayerNorm(hidden_size)
|
197 |
+
|
198 |
+
self.cross_attn = nn.MultiheadAttention(
|
199 |
+
embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
|
200 |
+
)
|
201 |
+
|
202 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
203 |
+
|
204 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
|
205 |
+
|
206 |
+
def forward(self, x, context, mask=None):
|
207 |
+
# Normalize inputs
|
208 |
+
x = self.norm1(x)
|
209 |
+
context = self.norm_context(context)
|
210 |
+
|
211 |
+
# Apply cross attention
|
212 |
+
# Note: nn.MultiheadAttention returns attn_output, attn_output_weights
|
213 |
+
attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
|
214 |
+
|
215 |
+
# Add & Norm
|
216 |
+
x = x + attn_output
|
217 |
+
x = x + self.mlp(self.norm2(x))
|
218 |
+
return x
|
models/SpaTrackV2/models/vggt4track/heads/track_modules/utils.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Modified from https://github.com/facebookresearch/vggsfm
|
8 |
+
# and https://github.com/facebookresearch/co-tracker/tree/main
|
9 |
+
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
from typing import Optional, Tuple, Union
|
16 |
+
|
17 |
+
|
18 |
+
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
|
19 |
+
"""
|
20 |
+
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
|
21 |
+
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
|
22 |
+
Args:
|
23 |
+
- embed_dim: The embedding dimension.
|
24 |
+
- grid_size: The grid size.
|
25 |
+
Returns:
|
26 |
+
- pos_embed: The generated 2D positional embedding.
|
27 |
+
"""
|
28 |
+
if isinstance(grid_size, tuple):
|
29 |
+
grid_size_h, grid_size_w = grid_size
|
30 |
+
else:
|
31 |
+
grid_size_h = grid_size_w = grid_size
|
32 |
+
grid_h = torch.arange(grid_size_h, dtype=torch.float)
|
33 |
+
grid_w = torch.arange(grid_size_w, dtype=torch.float)
|
34 |
+
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
35 |
+
grid = torch.stack(grid, dim=0)
|
36 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
37 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
38 |
+
if return_grid:
|
39 |
+
return (
|
40 |
+
pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
|
41 |
+
grid,
|
42 |
+
)
|
43 |
+
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
|
44 |
+
|
45 |
+
|
46 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
|
47 |
+
"""
|
48 |
+
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
- embed_dim: The embedding dimension.
|
52 |
+
- grid: The grid to generate the embedding from.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
- emb: The generated 2D positional embedding.
|
56 |
+
"""
|
57 |
+
assert embed_dim % 2 == 0
|
58 |
+
|
59 |
+
# use half of dimensions to encode grid_h
|
60 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
61 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
62 |
+
|
63 |
+
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
|
64 |
+
return emb
|
65 |
+
|
66 |
+
|
67 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
|
68 |
+
"""
|
69 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
- embed_dim: The embedding dimension.
|
73 |
+
- pos: The position to generate the embedding from.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
- emb: The generated 1D positional embedding.
|
77 |
+
"""
|
78 |
+
assert embed_dim % 2 == 0
|
79 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
80 |
+
omega /= embed_dim / 2.0
|
81 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
82 |
+
|
83 |
+
pos = pos.reshape(-1) # (M,)
|
84 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
85 |
+
|
86 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
87 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
88 |
+
|
89 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
90 |
+
return emb[None].float()
|
91 |
+
|
92 |
+
|
93 |
+
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
|
94 |
+
"""
|
95 |
+
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
- xy: The coordinates to generate the embedding from.
|
99 |
+
- C: The size of the embedding.
|
100 |
+
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
- pe: The generated 2D positional embedding.
|
104 |
+
"""
|
105 |
+
B, N, D = xy.shape
|
106 |
+
assert D == 2
|
107 |
+
|
108 |
+
x = xy[:, :, 0:1]
|
109 |
+
y = xy[:, :, 1:2]
|
110 |
+
div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
|
111 |
+
|
112 |
+
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
113 |
+
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
114 |
+
|
115 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
116 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
117 |
+
|
118 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
119 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
120 |
+
|
121 |
+
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
|
122 |
+
if cat_coords:
|
123 |
+
pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
|
124 |
+
return pe
|
125 |
+
|
126 |
+
|
127 |
+
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
128 |
+
r"""Sample a tensor using bilinear interpolation
|
129 |
+
|
130 |
+
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
131 |
+
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
132 |
+
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
133 |
+
convention.
|
134 |
+
|
135 |
+
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
136 |
+
:math:`B` is the batch size, :math:`C` is the number of channels,
|
137 |
+
:math:`H` is the height of the image, and :math:`W` is the width of the
|
138 |
+
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
139 |
+
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
140 |
+
|
141 |
+
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
142 |
+
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
143 |
+
that in this case the order of the components is slightly different
|
144 |
+
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
145 |
+
|
146 |
+
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
147 |
+
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
148 |
+
left-most image pixel :math:`W-1` to the center of the right-most
|
149 |
+
pixel.
|
150 |
+
|
151 |
+
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
152 |
+
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
153 |
+
the left-most pixel :math:`W` to the right edge of the right-most
|
154 |
+
pixel.
|
155 |
+
|
156 |
+
Similar conventions apply to the :math:`y` for the range
|
157 |
+
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
158 |
+
:math:`[0,T-1]` and :math:`[0,T]`.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
input (Tensor): batch of input images.
|
162 |
+
coords (Tensor): batch of coordinates.
|
163 |
+
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
164 |
+
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
Tensor: sampled points.
|
168 |
+
"""
|
169 |
+
coords = coords.detach().clone()
|
170 |
+
############################################################
|
171 |
+
# IMPORTANT:
|
172 |
+
coords = coords.to(input.device).to(input.dtype)
|
173 |
+
############################################################
|
174 |
+
|
175 |
+
sizes = input.shape[2:]
|
176 |
+
|
177 |
+
assert len(sizes) in [2, 3]
|
178 |
+
|
179 |
+
if len(sizes) == 3:
|
180 |
+
# t x y -> x y t to match dimensions T H W in grid_sample
|
181 |
+
coords = coords[..., [1, 2, 0]]
|
182 |
+
|
183 |
+
if align_corners:
|
184 |
+
scale = torch.tensor(
|
185 |
+
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
|
186 |
+
)
|
187 |
+
else:
|
188 |
+
scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
|
189 |
+
|
190 |
+
coords.mul_(scale) # coords = coords * scale
|
191 |
+
coords.sub_(1) # coords = coords - 1
|
192 |
+
|
193 |
+
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
|
194 |
+
|
195 |
+
|
196 |
+
def sample_features4d(input, coords):
|
197 |
+
r"""Sample spatial features
|
198 |
+
|
199 |
+
`sample_features4d(input, coords)` samples the spatial features
|
200 |
+
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
201 |
+
|
202 |
+
The field is sampled at coordinates :attr:`coords` using bilinear
|
203 |
+
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
204 |
+
2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
205 |
+
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
206 |
+
|
207 |
+
The output tensor has one feature per point, and has shape :math:`(B,
|
208 |
+
R, C)`.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
input (Tensor): spatial features.
|
212 |
+
coords (Tensor): points.
|
213 |
+
|
214 |
+
Returns:
|
215 |
+
Tensor: sampled features.
|
216 |
+
"""
|
217 |
+
|
218 |
+
B, _, _, _ = input.shape
|
219 |
+
|
220 |
+
# B R 2 -> B R 1 2
|
221 |
+
coords = coords.unsqueeze(2)
|
222 |
+
|
223 |
+
# B C R 1
|
224 |
+
feats = bilinear_sampler(input, coords)
|
225 |
+
|
226 |
+
return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
|
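A short, hedged sketch exercising the utilities above, not part of the commit; the import path assumes the repo root is on sys.path, and the printed shapes follow from the function definitions in this file.

# Hypothetical usage sketch (assumption: repo root on sys.path).
import torch
from models.SpaTrackV2.models.vggt4track.heads.track_modules.utils import (
    get_2d_sincos_pos_embed, get_2d_embedding, sample_features4d,
)

# Dense sinusoidal embedding for an 8x8 patch grid -> (1, 64, 8, 8)
pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8)
print(pos_embed.shape)

# Per-point embedding for 5 query points in 2 images -> (2, 5, 64*2 + 2)
xy = torch.rand(2, 5, 2) * 32
print(get_2d_embedding(xy, C=64).shape)

# Bilinear sampling of a feature map at those points -> (2, 5, 16)
fmap = torch.randn(2, 16, 32, 32)
print(sample_features4d(fmap, xy).shape)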
models/SpaTrackV2/models/vggt4track/heads/utils.py
ADDED
@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
    """
    Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)

    Args:
        pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
        embed_dim: Output channel dimension for embeddings

    Returns:
        Tensor of shape (H, W, embed_dim) with positional embeddings
    """
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    pos_flat = pos_grid.reshape(-1, grid_dim)  # Flatten to (H*W, 2)

    # Process x and y coordinates separately
    emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)  # [1, H*W, D/2]
    emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)  # [1, H*W, D/2]

    # Combine and reshape
    emb = torch.cat([emb_x, emb_y], dim=-1)  # [1, H*W, D]

    return emb.view(H, W, embed_dim)  # [H, W, D]


def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """
    This function generates a 1D positional embedding from a given grid using sine and cosine functions.

    Args:
    - embed_dim: The embedding dimension.
    - pos: The position to generate the embedding from.

    Returns:
    - emb: The generated 1D positional embedding.
    """
    assert embed_dim % 2 == 0
    device = pos.device
    omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
    omega /= embed_dim / 2.0
    omega = 1.0 / omega_0**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb.float()


# Inspired by https://github.com/microsoft/moge


def create_uv_grid(
    width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
) -> torch.Tensor:
    """
    Create a normalized UV grid of shape (width, height, 2).

    The grid spans horizontally and vertically according to an aspect ratio,
    ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
    corner is at (x_span, y_span), normalized by the diagonal of the plane.

    Args:
        width (int): Number of points horizontally.
        height (int): Number of points vertically.
        aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
        dtype (torch.dtype, optional): Data type of the resulting tensor.
        device (torch.device, optional): Device on which the tensor is created.

    Returns:
        torch.Tensor: A (width, height, 2) tensor of UV coordinates.
    """
    # Derive aspect ratio if not explicitly provided
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Compute normalized spans for X and Y
    diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
    span_x = aspect_ratio / diag_factor
    span_y = 1.0 / diag_factor

    # Establish the linspace boundaries
    left_x = -span_x * (width - 1) / width
    right_x = span_x * (width - 1) / width
    top_y = -span_y * (height - 1) / height
    bottom_y = span_y * (height - 1) / height

    # Generate 1D coordinates
    x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
    y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)

    # Create 2D meshgrid (width x height) and stack into UV
    uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    uv_grid = torch.stack((uu, vv), dim=-1)

    return uv_grid
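A hedged sketch for the two head utilities above, not part of the commit; a square grid is used so the output layout is unambiguous, and the import path is an assumption about how the repo is run.

# Hypothetical usage sketch (assumption: repo root on sys.path).
import torch
from models.SpaTrackV2.models.vggt4track.heads.utils import create_uv_grid, position_grid_to_embed

# Normalized UV coordinates for a 32x32 token grid, lifted to a 128-channel
# sinusoidal embedding.
uv = create_uv_grid(width=32, height=32, dtype=torch.float32)  # (32, 32, 2)
emb = position_grid_to_embed(uv, embed_dim=128)                # (32, 32, 128)
print(uv.shape, emb.shape)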
models/SpaTrackV2/models/vggt4track/layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
models/SpaTrackV2/models/vggt4track/layers/attention.py
ADDED
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor
from torch import nn
import torch.nn.functional as F

XFORMERS_AVAILABLE = False


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = fused_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope

    def forward(self, x: Tensor, pos=None) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.rope is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
        assert pos is None
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
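A hedged sketch for the Attention layer above, not part of the commit; it compares the fused scaled-dot-product path against the explicit softmax path on the same weights (the import path is an assumption).

# Hypothetical usage sketch (assumption: repo root on sys.path).
import torch
from models.SpaTrackV2.models.vggt4track.layers.attention import Attention

attn = Attention(dim=256, num_heads=8, fused_attn=True).eval()
x = torch.randn(2, 197, 256)  # (B, N, C) token sequence
with torch.no_grad():
    y = attn(x)               # uses F.scaled_dot_product_attention

attn.fused_attn = False       # switch to the explicit softmax path
with torch.no_grad():
    y_ref = attn(x)

print(y.shape)                       # torch.Size([2, 197, 256])
print((y - y_ref).abs().max())       # small numerical difference between the two paths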
models/SpaTrackV2/models/vggt4track/layers/block.py
ADDED
@@ -0,0 +1,259 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


XFORMERS_AVAILABLE = False


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()

        self.norm1 = norm_layer(dim)

        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None) -> Tensor:
        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), pos=pos))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                pos=pos,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
    pos=None,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    if pos is not None:
        # if necessary, apply rope to the subset
        pos = pos[brange]
        residual = residual_func(x_subset, pos=pos)
    else:
        residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
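A hedged sketch for the pre-norm transformer Block above, not part of the commit; it only shows the single-tensor path (the nested/xFormers path is disabled in this file), and the import path is an assumption.

# Hypothetical usage sketch (assumption: repo root on sys.path).
import torch
from models.SpaTrackV2.models.vggt4track.layers.block import Block

block = Block(dim=384, num_heads=6, mlp_ratio=4.0, init_values=1e-5).eval()
x = torch.randn(1, 256, 384)  # (B, N, C)
with torch.no_grad():
    y = block(x)              # pre-norm attention + MLP, each added back through LayerScale
print(y.shape)                # torch.Size([1, 256, 384])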
models/SpaTrackV2/models/vggt4track/layers/drop_path.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
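A hedged sketch of the stochastic-depth behaviour above, not part of the commit: in training mode roughly drop_prob of the samples are zeroed and the survivors are rescaled by 1/keep_prob so the expectation is preserved (import path assumed).

# Hypothetical usage sketch (assumption: repo root on sys.path).
import torch
from models.SpaTrackV2.models.vggt4track.layers.drop_path import drop_path

x = torch.ones(1000, 8, 16)
y = drop_path(x, drop_prob=0.2, training=True)
kept = (y.flatten(1).abs().sum(dim=1) > 0).float().mean()
print(kept)     # roughly 0.8: fraction of surviving samples
print(y.max())  # 1.25 = 1 / keep_prob for the survivors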
models/SpaTrackV2/models/vggt4track/layers/layer_scale.py
ADDED
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
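A hedged sketch for LayerScale, not part of the commit: it is a learnable per-channel gain initialized to a small value, so each residual branch starts near zero (import path assumed).

# Hypothetical usage sketch (assumption: repo root on sys.path).
import torch
from models.SpaTrackV2.models.vggt4track.layers.layer_scale import LayerScale

ls = LayerScale(dim=64, init_values=1e-5)
x = torch.randn(2, 10, 64)
print(ls(x).abs().max())  # on the order of 1e-5 times the input scale
print(ls.gamma.shape)     # torch.Size([64])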
models/SpaTrackV2/models/vggt4track/layers/mlp.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
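A minimal, hedged sketch for the Mlp layer above, not part of the commit (import path assumed).

# Hypothetical usage sketch (assumption: repo root on sys.path).
import torch
from models.SpaTrackV2.models.vggt4track.layers.mlp import Mlp

mlp = Mlp(in_features=384, hidden_features=1536, drop=0.1)
x = torch.randn(4, 196, 384)
print(mlp(x).shape)  # torch.Size([4, 196, 384]); hidden width is 1536 (4x expansion)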
models/SpaTrackV2/models/vggt4track/layers/patch_embed.py
ADDED
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

from typing import Callable, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
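A hedged sketch for PatchEmbed, not part of the commit: a 14x14 patch convolution turns a 518x518 image (a DINOv2-style resolution, chosen here only for illustration) into 37*37 = 1369 tokens (import path assumed).

# Hypothetical usage sketch (assumption: repo root on sys.path).
import torch
from models.SpaTrackV2.models.vggt4track.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=384)
img = torch.randn(1, 3, 518, 518)
tokens = embed(img)
print(tokens.shape)  # torch.Size([1, 1369, 384])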
models/SpaTrackV2/models/vggt4track/layers/rope.py
ADDED
@@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.


# Implementation of 2D Rotary Position Embeddings (RoPE).

# This module provides a clean implementation of 2D Rotary Position Embeddings,
# which extends the original RoPE concept to handle 2D spatial positions.

# Inspired by:
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
# https://github.com/naver-ai/rope-vit


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple


class PositionGetter:
    """Generates and caches 2D spatial positions for patches in a grid.

    This class efficiently manages the generation of spatial coordinates for patches
    in a 2D grid, caching results to avoid redundant computations.

    Attributes:
        position_cache: Dictionary storing precomputed position tensors for different
            grid dimensions.
    """

    def __init__(self):
        """Initializes the position generator with an empty cache."""
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Generates spatial positions for a batch of patches.

        Args:
            batch_size: Number of samples in the batch.
            height: Height of the grid in patches.
            width: Width of the grid in patches.
            device: Target device for the position tensor.

        Returns:
            Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
            for each position in the grid, repeated for each batch item.
        """
        if (height, width) not in self.position_cache:
            y_coords = torch.arange(height, device=device)
            x_coords = torch.arange(width, device=device)
            positions = torch.cartesian_prod(y_coords, x_coords)
            self.position_cache[height, width] = positions

        cached_positions = self.position_cache[height, width]
        return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()


class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    This module applies rotary position embeddings to input tokens based on their
    2D spatial positions. It handles the position-dependent rotation of features
    separately for vertical and horizontal dimensions.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Factor to scale the computed frequencies.
        frequency_cache: Cache for storing precomputed frequency components.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes frequency components for rotary embeddings.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Maximum sequence length.
            device: Target device for computations.
            dtype: Data type for the computed tensors.

        Returns:
            Tuple of (cosine, sine) tensors for frequency components.
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Compute frequency bands
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # Generate position-dependent frequencies
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)

            # Compute and cache frequency components
            angles = angles.to(dtype)
            angles = torch.cat((angles, angles), dim=-1)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Performs feature rotation by splitting and recombining feature dimensions.

        Args:
            x: Input tensor to rotate.

        Returns:
            Rotated feature tensor.
        """
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one dimension.

        Args:
            tokens: Input token features.
            positions: Position indices.
            cos_comp: Cosine components for rotation.
            sin_comp: Sine components for rotation.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        # Embed positions with frequency components
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]

        # Apply rotation
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                The feature dimension (dim) must be divisible by 4.
            positions: Position tensor of shape (batch_size, n_tokens, 2) containing
                the y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        # Validate inputs
        assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        # Compute feature dimension for each spatial direction
        feature_dim = tokens.size(-1) // 2

        # Get frequency components
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)

        # Split features for vertical and horizontal processing
        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        # Apply RoPE separately for each dimension
        vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
        horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)

        # Combine processed features
        return torch.cat((vertical_features, horizontal_features), dim=-1)
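A hedged sketch tying the two classes above together, not part of the commit: PositionGetter supplies integer (y, x) patch coordinates and RotaryPositionEmbedding2D rotates half the channels by the y position and half by the x position (import path assumed).

# Hypothetical usage sketch (assumption: repo root on sys.path).
import torch
from models.SpaTrackV2.models.vggt4track.layers.rope import PositionGetter, RotaryPositionEmbedding2D

rope = RotaryPositionEmbedding2D(frequency=100.0)
get_positions = PositionGetter()

B, heads, dim = 2, 6, 64   # per-head dim, divisible by 4
H = W = 16                 # 16x16 patch grid -> 256 tokens
pos = get_positions(B, H, W, device=torch.device("cpu"))  # (B, 256, 2) y,x indices
q = torch.randn(B, heads, H * W, dim)
q_rot = rope(q, pos)
print(q_rot.shape)         # torch.Size([2, 6, 256, 64])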
models/SpaTrackV2/models/vggt4track/layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
from typing import Callable, Optional
import warnings

from torch import Tensor, nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
# try:
#     if XFORMERS_ENABLED:
#         from xformers.ops import SwiGLU

#         XFORMERS_AVAILABLE = True
#         warnings.warn("xFormers is available (SwiGLU)")
#     else:
#         warnings.warn("xFormers is disabled (SwiGLU)")
#         raise ImportError
# except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False

# warnings.warn("xFormers is not available (SwiGLU)")


class SwiGLUFFNFused(SwiGLU):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
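A hedged sketch for SwiGLUFFNFused, not part of the commit: the requested hidden width is shrunk to 2/3 and rounded up to a multiple of 8, so the gated FFN has a parameter count comparable to a plain 4x MLP (import path assumed).

# Hypothetical usage sketch (assumption: repo root on sys.path).
import torch
from models.SpaTrackV2.models.vggt4track.layers.swiglu_ffn import SwiGLUFFNFused

ffn = SwiGLUFFNFused(in_features=384, hidden_features=4 * 384)
print(ffn.w12.out_features // 2)  # 1024: round_to_8(1536 * 2 / 3)
x = torch.randn(2, 196, 384)
print(ffn(x).shape)               # torch.Size([2, 196, 384])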
models/SpaTrackV2/models/vggt4track/layers/vision_transformer.py
ADDED
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
from functools import partial
|
11 |
+
import math
|
12 |
+
import logging
|
13 |
+
from typing import Sequence, Tuple, Union, Callable
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
from torch.utils.checkpoint import checkpoint
|
18 |
+
from torch.nn.init import trunc_normal_
|
19 |
+
from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
20 |
+
|
21 |
+
logger = logging.getLogger("dinov2")
|
22 |
+
|
23 |
+
|
24 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        qk_norm=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        # tricky but makes it work
        self.use_checkpoint = False
        #

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                qk_norm=qk_norm,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]

        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=True, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
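Note: `get_intermediate_layers` above is the entry point the SpaTrackV2 front-end uses to pull multi-level DINOv2 features. A minimal usage sketch, assuming the vendored builders are importable and using illustrative sizes rather than values fixed by this repo:

# Sketch only: small ViT with random weights on a dummy 224x224 input.
import torch
from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small

backbone = vit_small(patch_size=16, num_register_tokens=4, img_size=224, block_chunks=0)
dummy = torch.randn(1, 3, 224, 224)
# Take the last two blocks, reshape patch tokens into feature maps, keep class tokens.
feats = backbone.get_intermediate_layers(dummy, n=2, reshape=True, return_class_token=True)
for patch_map, cls_tok in feats:
    print(patch_map.shape, cls_tok.shape)   # (1, 384, 14, 14) and (1, 384)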
models/SpaTrackV2/models/vggt4track/models/aggregator.py
ADDED
@@ -0,0 +1,338 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List, Dict, Any

from models.SpaTrackV2.models.vggt4track.layers import PatchEmbed
from models.SpaTrackV2.models.vggt4track.layers.block import Block
from models.SpaTrackV2.models.vggt4track.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
from torch.utils.checkpoint import checkpoint

logger = logging.getLogger(__name__)

_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


class Aggregator(nn.Module):
    """
    The Aggregator applies alternating-attention over input frames,
    as described in VGGT: Visual Geometry Grounded Transformer.


    Args:
        img_size (int): Image size in pixels.
        patch_size (int): Size of each patch for PatchEmbed.
        embed_dim (int): Dimension of the token embeddings.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        num_register_tokens (int): Number of register tokens.
        block_fn (nn.Module): The block type used for attention (Block by default).
        qkv_bias (bool): Whether to include bias in QKV projections.
        proj_bias (bool): Whether to include bias in the output projection.
        ffn_bias (bool): Whether to include bias in MLP layers.
        patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
        aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
        aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
        qk_norm (bool): Whether to apply QK normalization.
        rope_freq (int): Base frequency for rotary embedding. -1 to disable.
        init_values (float): Init scale for layer scale.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        patch_embed="dinov2_vitl14_reg",
        aa_order=["frame", "global"],
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        super().__init__()

        self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)

        # Initialize rotary position embedding if frequency > 0
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        self.frame_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.global_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.depth = depth
        self.aa_order = aa_order
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.aa_block_num = self.depth // self.aa_block_size

        # Note: We have two camera tokens, one for the first frame and one for the rest
        # The same applies for register tokens
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
        self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))

        # The patch tokens start after the camera and register tokens
        self.patch_start_idx = 1 + num_register_tokens

        # Initialize parameters with small values
        nn.init.normal_(self.camera_token, std=1e-6)
        nn.init.normal_(self.register_token, std=1e-6)

        # Register normalization constants as buffers
        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.FloatTensor(value).view(1, 1, 3, 1, 1),
                persistent=False,
            )

    def __build_patch_embed__(
        self,
        patch_embed,
        img_size,
        patch_size,
        num_register_tokens,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
        embed_dim=1024,
    ):
        """
        Build the patch embed layer. If 'conv', we use a
        simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
        """

        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

        # Disable gradient updates for mask token
        if hasattr(self.patch_embed, "mask_token"):
            self.patch_embed.mask_token.requires_grad_(False)

    def forward(
        self,
        images: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], int]:
        """
        Args:
            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width

        Returns:
            (list[torch.Tensor], int):
                The list of outputs from the attention blocks,
                and the patch_start_idx indicating where patch tokens begin.
        """
        B, S, C_in, H, W = images.shape

        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

        # Normalize images and reshape for patch embed
        images = (images - self._resnet_mean) / self._resnet_std

        # Reshape to [B*S, C, H, W] for patch embedding
        images = images.view(B * S, C_in, H, W)
        patch_tokens = self.patch_embed(images)

        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        _, P, C = patch_tokens.shape

        # Expand camera and register tokens to match batch size and sequence length
        camera_token = slice_expand_and_flatten(self.camera_token, B, S)
        register_token = slice_expand_and_flatten(self.register_token, B, S)

        # Concatenate special tokens with patch tokens
        tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)

        pos = None
        if self.rope is not None:
            pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)

        if self.patch_start_idx > 0:
            # do not use position embedding for special tokens (camera and register tokens)
            # so set pos to 0 for the special tokens
            pos = pos + 1
            pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
            pos = torch.cat([pos_special, pos], dim=1)

        # update P because we added special tokens
        _, P, C = tokens.shape

        frame_idx = 0
        global_idx = 0
        output_list = []

        for _ in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S, P, C, frame_idx, pos=pos
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S, P, C, global_idx, pos=pos
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            for i in range(len(frame_intermediates)):
                # concat frame and global intermediates, [B x S x P x 2C]
                concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
                output_list.append(concat_inter)

        del concat_inter
        del frame_intermediates
        del global_intermediates
        return output_list, self.patch_start_idx

    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        # If needed, reshape tokens or positions:
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=False)
            else:
                tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=False)
            else:
                tokens = self.global_blocks[global_idx](tokens, pos=pos)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates


def slice_expand_and_flatten(token_tensor, B, S):
    """
    Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
    1) Uses the first position (index=0) for the first frame only
    2) Uses the second position (index=1) for all remaining frames (S-1 frames)
    3) Expands both to match batch size B
    4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
       followed by (S-1) second-position tokens
    5) Flattens to (B*S, X, C) for processing

    Returns:
        torch.Tensor: Processed tokens with shape (B*S, X, C)
    """

    # Slice out the "query" tokens => shape (1, 1, ...)
    query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
    # Slice out the "other" tokens => shape (1, S-1, ...)
    others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
    # Concatenate => shape (B, S, ...)
    combined = torch.cat([query, others], dim=1)

    # Finally flatten => shape (B*S, ...)
    combined = combined.view(B * S, *combined.shape[2:])
    return combined
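Note: the shape contract of `slice_expand_and_flatten` is easier to see with a quick check; the tensor sizes below are illustrative assumptions:

# Sketch only: a (1, 2, X, C) special-token bank expanded for B=2 sequences of S=5 frames.
import torch
from models.SpaTrackV2.models.vggt4track.models.aggregator import slice_expand_and_flatten

bank = torch.randn(1, 2, 4, 1024)      # index 0 -> first frame, index 1 -> all later frames
flat = slice_expand_and_flatten(bank, B=2, S=5)
print(flat.shape)                      # torch.Size([10, 4, 1024]), i.e. (B*S, X, C)
# Frame 0 of each sequence gets the index-0 tokens; frames 1..S-1 share the index-1 tokens.
assert torch.equal(flat[0], bank[0, 0]) and torch.equal(flat[1], bank[0, 1])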
models/SpaTrackV2/models/vggt4track/models/aggregator_front.py
ADDED
@@ -0,0 +1,342 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List, Dict, Any

from models.SpaTrackV2.models.vggt4track.layers import PatchEmbed
from models.SpaTrackV2.models.vggt4track.layers.block import Block
from models.SpaTrackV2.models.vggt4track.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
from torch.utils.checkpoint import checkpoint

logger = logging.getLogger(__name__)

_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


class Aggregator(nn.Module):
    """
    The Aggregator applies alternating-attention over input frames,
    as described in VGGT: Visual Geometry Grounded Transformer.


    Args:
        img_size (int): Image size in pixels.
        patch_size (int): Size of each patch for PatchEmbed.
        embed_dim (int): Dimension of the token embeddings.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        num_register_tokens (int): Number of register tokens.
        block_fn (nn.Module): The block type used for attention (Block by default).
        qkv_bias (bool): Whether to include bias in QKV projections.
        proj_bias (bool): Whether to include bias in the output projection.
        ffn_bias (bool): Whether to include bias in MLP layers.
        patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
        aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
        aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
        qk_norm (bool): Whether to apply QK normalization.
        rope_freq (int): Base frequency for rotary embedding. -1 to disable.
        init_values (float): Init scale for layer scale.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        patch_embed="dinov2_vitl14_reg",
        aa_order=["frame", "global"],
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        super().__init__()

        # self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)

        self.use_reentrant = False
        # Initialize rotary position embedding if frequency > 0
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        self.frame_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.global_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.depth = depth
        self.aa_order = aa_order
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.aa_block_num = self.depth // self.aa_block_size

        # Note: We have two camera tokens, one for the first frame and one for the rest
        # The same applies for register tokens
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
        self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
        self.scale_shift_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))

        # The patch tokens start after the camera and register tokens
        self.patch_start_idx = 1 + num_register_tokens + 1

        # Initialize parameters with small values
        nn.init.normal_(self.camera_token, std=1e-6)
        nn.init.normal_(self.register_token, std=1e-6)
        nn.init.normal_(self.scale_shift_token, std=1e-6)

        # Register normalization constants as buffers
        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.FloatTensor(value).view(1, 1, 3, 1, 1),
                persistent=False,
            )

    def __build_patch_embed__(
        self,
        patch_embed,
        img_size,
        patch_size,
        num_register_tokens,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
        embed_dim=1024,
    ):
        """
        Build the patch embed layer. If 'conv', we use a
        simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
        """

        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

        # Disable gradient updates for mask token
        if hasattr(self.patch_embed, "mask_token"):
            self.patch_embed.mask_token.requires_grad_(False)

    def forward(
        self,
        images: torch.Tensor,
        patch_tokens: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], int]:
        """
        Args:
            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width

        Returns:
            (list[torch.Tensor], int):
                The list of outputs from the attention blocks,
                and the patch_start_idx indicating where patch tokens begin.
        """
        B, S, C_in, H, W = images.shape

        # if C_in != 3:
        #     raise ValueError(f"Expected 3 input channels, got {C_in}")

        # # Normalize images and reshape for patch embed
        # images = (images - self._resnet_mean) / self._resnet_std

        # # Reshape to [B*S, C, H, W] for patch embedding
        # images = images.view(B * S, C_in, H, W)
        # patch_tokens = self.patch_embed(images)

        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        _, P, C = patch_tokens.shape
        # Expand camera and register tokens to match batch size and sequence length
        camera_token = slice_expand_and_flatten(self.camera_token, B, S)
        register_token = slice_expand_and_flatten(self.register_token, B, S)
        scale_shift_token = slice_expand_and_flatten(self.scale_shift_token, B, S)

        # Concatenate special tokens with patch tokens
        tokens = torch.cat([camera_token, register_token, scale_shift_token, patch_tokens], dim=1)

        pos = None
        if self.rope is not None:
            pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)

        if self.patch_start_idx > 0:
            # do not use position embedding for special tokens (camera and register tokens)
            # so set pos to 0 for the special tokens
            pos = pos + 1
            pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
            pos = torch.cat([pos_special, pos], dim=1)

        # update P because we added special tokens
        _, P, C = tokens.shape

        frame_idx = 0
        global_idx = 0
        output_list = []

        for _ in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S, P, C, frame_idx, pos=pos
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S, P, C, global_idx, pos=pos
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            for i in range(len(frame_intermediates)):
                # concat frame and global intermediates, [B x S x P x 2C]
                concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
                output_list.append(concat_inter)

        del concat_inter
        del frame_intermediates
        del global_intermediates
        return output_list, self.patch_start_idx

    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        # If needed, reshape tokens or positions:
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = self.global_blocks[global_idx](tokens, pos=pos)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates


def slice_expand_and_flatten(token_tensor, B, S):
    """
    Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
    1) Uses the first position (index=0) for the first frame only
    2) Uses the second position (index=1) for all remaining frames (S-1 frames)
    3) Expands both to match batch size B
    4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
       followed by (S-1) second-position tokens
    5) Flattens to (B*S, X, C) for processing

    Returns:
        torch.Tensor: Processed tokens with shape (B*S, X, C)
    """

    # Slice out the "query" tokens => shape (1, 1, ...)
    query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
    # Slice out the "other" tokens => shape (1, S-1, ...)
    others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
    # Concatenate => shape (B, S, ...)
    combined = torch.cat([query, others], dim=1)

    # Finally flatten => shape (B*S, ...)
    combined = combined.view(B * S, *combined.shape[2:])
    return combined
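Note: compared with aggregator.py, this front-end variant drops its own patch embedding (patch tokens are supplied by the frozen base model) and prepends an extra scale/shift token, so patch tokens start one slot later. A tiny sketch of the resulting per-frame token layout, using the default num_register_tokens=4:

# Per-frame token axis in aggregator_front.Aggregator:
# [ camera | register x 4 | scale_shift | patch tokens ... ]
num_register_tokens = 4
patch_start_idx = 1 + num_register_tokens + 1   # camera + registers + scale/shift token
print(patch_start_idx)                          # 6 -> patch tokens begin at index 6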
models/SpaTrackV2/models/vggt4track/models/tracker_front.py
ADDED
@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from huggingface_hub import PyTorchModelHubMixin  # used for model hub

from models.SpaTrackV2.models.vggt4track.models.aggregator_front import Aggregator
from models.SpaTrackV2.models.vggt4track.heads.camera_head import CameraHead
from models.SpaTrackV2.models.vggt4track.heads.scale_head import ScaleHead
from einops import rearrange
from models.SpaTrackV2.utils.loss import compute_loss
from models.SpaTrackV2.utils.pose_enc import pose_encoding_to_extri_intri
import torch.nn.functional as F

class FrontTracker(nn.Module, PyTorchModelHubMixin):
    def __init__(self, img_size=518,
                 patch_size=14, embed_dim=1024, base_model=None, use_checkpoint=True, use_scale_head=False):
        super().__init__()

        self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
        self.camera_head = CameraHead(dim_in=2 * embed_dim)
        if use_scale_head:
            self.scale_head = ScaleHead(dim_in=2 * embed_dim)
        else:
            self.scale_head = None
        self.base_model = base_model
        self.use_checkpoint = use_checkpoint
        self.intermediate_layers = [4, 11, 17, 23]
        self.residual_proj = nn.ModuleList([nn.Linear(2048, 1024) for _ in range(len(self.intermediate_layers))])
        # init the residual proj
        for i in range(len(self.intermediate_layers)):
            nn.init.xavier_uniform_(self.residual_proj[i].weight)
            nn.init.zeros_(self.residual_proj[i].bias)
        # self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
        # self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
        # self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)

    def forward(self,
                images: torch.Tensor,
                annots = {},
                **kwargs):
        """
        Forward pass of the FrontTracker model.

        Args:
            images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
            query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
                Shape: [N, 2] or [B, N, 2], where N is the number of query points.
                Default: None

        Returns:
            dict: A dictionary containing the following predictions:
                - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
                - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
                - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
                - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
                - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
                - images (torch.Tensor): Original input images, preserved for visualization

                If query_points is provided, also includes:
                - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
                - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
                - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
        """

        # If without batch dimension, add it
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        B, T, C, H, W = images.shape
        images = (images - self.base_model.image_mean) / self.base_model.image_std
        H_14 = H // 14 * 14
        W_14 = W // 14 * 14
        image_14 = F.interpolate(images.view(B*T, C, H, W), (H_14, W_14), mode="bilinear", align_corners=False, antialias=True).view(B, T, C, H_14, W_14)

        with torch.no_grad():
            features = self.base_model.backbone.get_intermediate_layers(rearrange(image_14, 'b t c h w -> (b t) c h w'),
                                                                        self.base_model.intermediate_layers, return_class_token=True)
        # aggregate the features with checkpoint
        aggregated_tokens_list, patch_start_idx = self.aggregator(image_14, patch_tokens=features[-1][0])

        # enhance the features
        enhanced_features = []
        for layer_i, layer in enumerate(self.intermediate_layers):
            # patch_feat_i = features[layer_i][0] + self.residual_proj[layer_i](aggregated_tokens_list[layer][:,:,patch_start_idx:,:].view(B*T, features[layer_i][0].shape[1], -1))
            patch_feat_i = self.residual_proj[layer_i](aggregated_tokens_list[layer][:,:,patch_start_idx:,:].view(B*T, features[layer_i][0].shape[1], -1))
            enhance_i = (patch_feat_i, features[layer_i][1])
            enhanced_features.append(enhance_i)

        predictions = {}

        with torch.cuda.amp.autocast(enabled=False):
            if self.camera_head is not None:
                pose_enc_list = self.camera_head(aggregated_tokens_list)
                predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last iteration
            if self.scale_head is not None:
                scale_list = self.scale_head(aggregated_tokens_list)
                predictions["scale"] = scale_list[-1]  # scale of the last iteration

        # Predict points (and mask) with checkpoint
        output = self.base_model.head(enhanced_features, image_14)
        points, mask = output

        # Post-process points and mask
        points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
        points = self.base_model._remap_points(points)  # slightly improves the performance in case of very large output values
        # prepare the predictions
        predictions["images"] = (images * self.base_model.image_std + self.base_model.image_mean)*255.0
        points = F.interpolate(points.permute(0, 3, 1, 2), (H, W), mode="bilinear", align_corners=False, antialias=True).permute(0, 2, 3, 1)
        predictions["points_map"] = points
        mask = F.interpolate(mask.unsqueeze(1), (H, W), mode="bilinear", align_corners=False, antialias=True).squeeze(1)
        predictions["unc_metric"] = mask
        predictions["pose_enc_list"] = pose_enc_list

        if self.training:
            loss = compute_loss(predictions, annots)
            predictions["loss"] = loss

        # rescale the points
        if self.scale_head is not None:
            points_scale = points * predictions["scale"].view(B*T, 1, 1, 2)[..., :1]
            points_scale[..., 2:] += predictions["scale"].view(B*T, 1, 1, 2)[..., 1:]
            predictions["points_map"] = points_scale

        predictions["poses_pred"] = torch.eye(4)[None].repeat(predictions["images"].shape[1], 1, 1)[None]
        predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
                                                                                                  predictions["images"].shape[-2:])
        return predictions
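Note: `FrontTracker` inherits `PyTorchModelHubMixin`, which is what this commit's "support HubMixin" refers to: the class gains `save_pretrained` / `from_pretrained` / `push_to_hub`. A minimal sketch with a hypothetical local directory name (a placeholder, not a path defined by this repo):

# Sketch only: round-trip the tracker through the HubMixin serialization helpers.
from models.SpaTrackV2.models.vggt4track.models.tracker_front import FrontTracker

model = FrontTracker(img_size=518, patch_size=14, embed_dim=1024)  # base_model is attached separately before inference
model.save_pretrained("./front_tracker_ckpt")                      # writes the config plus model weights
restored = FrontTracker.from_pretrained("./front_tracker_ckpt")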
models/SpaTrackV2/models/vggt4track/models/vggt.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin  # used for model hub

from vggt.models.aggregator import Aggregator
from vggt.heads.camera_head import CameraHead
from vggt.heads.dpt_head import DPTHead
from vggt.heads.track_head import TrackHead


class VGGT(nn.Module, PyTorchModelHubMixin):
    def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
        super().__init__()

        self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
        self.camera_head = CameraHead(dim_in=2 * embed_dim)
        self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
        self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
        self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)

    def forward(
        self,
        images: torch.Tensor,
        query_points: torch.Tensor = None,
    ):
        """
        Forward pass of the VGGT model.

        Args:
            images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
            query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
                Shape: [N, 2] or [B, N, 2], where N is the number of query points.
                Default: None

        Returns:
            dict: A dictionary containing the following predictions:
                - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
                - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
                - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
                - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
                - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
                - images (torch.Tensor): Original input images, preserved for visualization

                If query_points is provided, also includes:
                - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
                - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
                - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
        """

        # If without batch dimension, add it
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        if query_points is not None and len(query_points.shape) == 2:
            query_points = query_points.unsqueeze(0)

        aggregated_tokens_list, patch_start_idx = self.aggregator(images)

        predictions = {}

        with torch.cuda.amp.autocast(enabled=False):
            if self.camera_head is not None:
                pose_enc_list = self.camera_head(aggregated_tokens_list)
                predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last iteration

            if self.depth_head is not None:
                depth, depth_conf = self.depth_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
                )
                predictions["depth"] = depth
                predictions["depth_conf"] = depth_conf

            if self.point_head is not None:
                pts3d, pts3d_conf = self.point_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
                )
                predictions["world_points"] = pts3d
                predictions["world_points_conf"] = pts3d_conf

        if self.track_head is not None and query_points is not None:
            track_list, vis, conf = self.track_head(
                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
            )
            predictions["track"] = track_list[-1]  # track of the last iteration
            predictions["vis"] = vis
            predictions["conf"] = conf

        predictions["images"] = images

        return predictions
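Note: this vendored `VGGT` keeps the upstream interface: a [B, S, 3, H, W] tensor in [0, 1], query points optional; pretrained weights are normally pulled in via `VGGT.from_pretrained` (the upstream release publishes them under an id like facebook/VGGT-1B). A minimal forward-pass sketch with random weights and dummy frames (shapes are assumptions, nothing is downloaded):

# Sketch only: randomly initialised VGGT on a two-frame dummy clip.
import torch
from vggt.models.vggt import VGGT   # assumes the upstream `vggt` package layout used by this file's imports

model = VGGT(img_size=518, patch_size=14, embed_dim=1024).eval()
images = torch.rand(1, 2, 3, 518, 518)   # B=1, S=2 frames, values in [0, 1]
with torch.no_grad():
    preds = model(images)
print(preds["pose_enc"].shape)           # expected [1, 2, 9]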
models/SpaTrackV2/models/vggt4track/models/vggt_moe.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin  # used for model hub

from models.SpaTrackV2.models.vggt4track.models.aggregator import Aggregator
from models.SpaTrackV2.models.vggt4track.heads.camera_head import CameraHead
from models.SpaTrackV2.models.vggt4track.heads.dpt_head import DPTHead
from models.SpaTrackV2.models.vggt4track.heads.track_head import TrackHead
from models.SpaTrackV2.models.vggt4track.utils.loss import compute_loss
from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri
from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
from einops import rearrange
import torch.nn.functional as F

class VGGT4Track(nn.Module, PyTorchModelHubMixin):
    def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
        super().__init__()

        self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
        self.camera_head = CameraHead(dim_in=2 * embed_dim)
        self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="sigmoid")

    def forward(
        self,
        images: torch.Tensor,
        annots = {},
        **kwargs):
        """
        Forward pass of the VGGT4Track model.

        Args:
            images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
            query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
                Shape: [N, 2] or [B, N, 2], where N is the number of query points.
                Default: None

        Returns:
            dict: A dictionary containing the following predictions:
                - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
                - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
                - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
                - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
                - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
                - images (torch.Tensor): Original input images, preserved for visualization

                If query_points is provided, also includes:
                - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
                - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
                - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
        """

        # If without batch dimension, add it
        B, T, C, H, W = images.shape
        images_proc = preprocess_image(images.view(B*T, C, H, W).clone())
        images_proc = rearrange(images_proc, '(b t) c h w -> b t c h w', b=B, t=T)
        _, _, _, H_proc, W_proc = images_proc.shape

        if len(images.shape) == 4:
            images = images.unsqueeze(0)

        with torch.no_grad():
            aggregated_tokens_list, patch_start_idx = self.aggregator(images_proc)

        predictions = {}

        with torch.cuda.amp.autocast(enabled=False):
            if self.camera_head is not None:
                pose_enc_list = self.camera_head(aggregated_tokens_list)
                predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last iteration
                predictions["pose_enc_list"] = pose_enc_list

            if self.depth_head is not None:
                depth, depth_conf = self.depth_head(
                    aggregated_tokens_list, images=images_proc, patch_start_idx=patch_start_idx
                )
                predictions["depth"] = depth
                predictions["unc_metric"] = depth_conf.view(B*T, H_proc, W_proc)

        predictions["images"] = (images)*255.0
        # output the camera pose
        predictions["poses_pred"] = torch.eye(4)[None].repeat(T, 1, 1)[None]
        predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
                                                                                                  images_proc.shape[-2:])
        predictions["poses_pred"] = torch.inverse(predictions["poses_pred"])
        points_map = depth_to_points_colmap(depth.view(B*T, H_proc, W_proc), predictions["intrs"].view(B*T, 3, 3))
        predictions["points_map"] = points_map
        #NOTE: resize back
        predictions["points_map"] = F.interpolate(points_map.permute(0,3,1,2),
                                                  size=(H, W), mode='bilinear', align_corners=True).permute(0,2,3,1)
        predictions["unc_metric"] = F.interpolate(predictions["unc_metric"][:,None],
                                                  size=(H, W), mode='bilinear', align_corners=True)[:,0]
        predictions["intrs"][..., :1, :] *= W/W_proc
        predictions["intrs"][..., 1:2, :] *= H/H_proc

        if self.training:
            loss = compute_loss(predictions, annots)
            predictions["loss"] = loss

        return predictions
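Since VGGT4Track mixes in PyTorchModelHubMixin, weights can be pulled straight from the Hugging Face Hub with the mixin's from_pretrained(). A minimal inference sketch follows; the repo id and the random clip are placeholders, not part of this commit.

import torch
from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track

# from_pretrained() is provided by PyTorchModelHubMixin; the repo id is a placeholder.
model = VGGT4Track.from_pretrained("your-org/vggt4track-checkpoint").eval()

video = torch.rand(1, 8, 3, 384, 512)   # (B, T, 3, H, W), RGB in [0, 1]
with torch.no_grad():
    preds = model(video)

print(preds["points_map"].shape)  # (B*T, H, W, 3) point maps, resized back to the input size
print(preds["poses_pred"].shape)  # (1, T, 4, 4) camera-to-world poses
print(preds["intrs"].shape)       # (1, T, 3, 3) intrinsics rescaled to the input size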
models/SpaTrackV2/models/vggt4track/utils/__init__.py
ADDED
@@ -0,0 +1 @@
models/SpaTrackV2/models/vggt4track/utils/geometry.py
ADDED
@@ -0,0 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import numpy as np


def unproject_depth_map_to_point_map(
    depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
) -> np.ndarray:
    """
    Unproject a batch of depth maps to 3D world coordinates.

    Args:
        depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
        extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
        intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)

    Returns:
        np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
    """
    if isinstance(depth_map, torch.Tensor):
        depth_map = depth_map.cpu().numpy()
    if isinstance(extrinsics_cam, torch.Tensor):
        extrinsics_cam = extrinsics_cam.cpu().numpy()
    if isinstance(intrinsics_cam, torch.Tensor):
        intrinsics_cam = intrinsics_cam.cpu().numpy()

    world_points_list = []
    for frame_idx in range(depth_map.shape[0]):
        cur_world_points, _, _ = depth_to_world_coords_points(
            depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
        )
        world_points_list.append(cur_world_points)
    world_points_array = np.stack(world_points_list, axis=0)

    return world_points_array


def depth_to_world_coords_points(
    depth_map: np.ndarray,
    extrinsic: np.ndarray,
    intrinsic: np.ndarray,
    eps=1e-8,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert a depth map to world coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
        extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.

    Returns:
        tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W).
    """
    if depth_map is None:
        return None, None, None

    # Valid depth mask
    point_mask = depth_map > eps

    # Convert depth map to camera coordinates
    cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)

    # Multiply with the inverse of extrinsic matrix to transform to world coordinates
    # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
    cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]

    R_cam_to_world = cam_to_world_extrinsic[:3, :3]
    t_cam_to_world = cam_to_world_extrinsic[:3, 3]

    # Apply the rotation and translation to the camera coordinates
    world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world  # HxWx3, 3x3 -> HxWx3
    # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world

    return world_coords_points, cam_coords_points, point_mask


def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Convert a depth map to camera coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).

    Returns:
        tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3)
    """
    H, W = depth_map.shape
    assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
    assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"

    # Intrinsic parameters
    fu, fv = intrinsic[0, 0], intrinsic[1, 1]
    cu, cv = intrinsic[0, 2], intrinsic[1, 2]

    # Generate grid of pixel coordinates
    u, v = np.meshgrid(np.arange(W), np.arange(H))

    # Unproject to camera coordinates
    x_cam = (u - cu) * depth_map / fu
    y_cam = (v - cv) * depth_map / fv
    z_cam = depth_map

    # Stack to form camera coordinates
    cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)

    return cam_coords


def closed_form_inverse_se3(se3, R=None, T=None):
    """
    Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.

    If `R` and `T` are provided, they must correspond to the rotation and translation
    components of `se3`. Otherwise, they will be extracted from `se3`.

    Args:
        se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
        R (optional): Nx3x3 array or tensor of rotation matrices.
        T (optional): Nx3x1 array or tensor of translation vectors.

    Returns:
        Inverted SE3 matrices with the same type and device as `se3`.

    Shapes:
        se3: (N, 4, 4)
        R: (N, 3, 3)
        T: (N, 3, 1)
    """
    # Check if se3 is a numpy array or a torch tensor
    is_numpy = isinstance(se3, np.ndarray)

    # Validate shapes
    if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
        raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")

    # Extract R and T if not provided
    if R is None:
        R = se3[:, :3, :3]  # (N,3,3)
    if T is None:
        T = se3[:, :3, 3:]  # (N,3,1)

    # Transpose R
    if is_numpy:
        # Compute the transpose of the rotation for NumPy
        R_transposed = np.transpose(R, (0, 2, 1))
        # -R^T t for NumPy
        top_right = -np.matmul(R_transposed, T)
        inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
    else:
        R_transposed = R.transpose(1, 2)  # (N,3,3)
        top_right = -torch.bmm(R_transposed, T)  # (N,3,1)
        inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
        inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)

    inverted_matrix[:, :3, :3] = R_transposed
    inverted_matrix[:, :3, 3:] = top_right

    return inverted_matrix
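For reference, a small self-contained sketch of calling the unprojection helper above with dummy inputs (identity cameras; shapes follow the docstrings):

import numpy as np
from models.SpaTrackV2.models.vggt4track.utils.geometry import unproject_depth_map_to_point_map

S, H, W = 2, 120, 160
depth = np.random.rand(S, H, W, 1).astype(np.float32) + 0.5     # (S, H, W, 1), strictly positive
extrinsics = np.tile(np.eye(4)[:3], (S, 1, 1))                  # (S, 3, 4), cam-from-world
intrinsics = np.tile(np.array([[100.0, 0, W / 2],
                               [0, 100.0, H / 2],
                               [0, 0, 1.0]]), (S, 1, 1))        # (S, 3, 3), zero skew

world_pts = unproject_depth_map_to_point_map(depth, extrinsics, intrinsics)
print(world_pts.shape)  # (S, H, W, 3)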
models/SpaTrackV2/models/vggt4track/utils/load_fn.py
ADDED
@@ -0,0 +1,200 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from PIL import Image
from torchvision import transforms as TF


def load_and_preprocess_images(image_path_list, mode="crop"):
    """
    A quick start function to load and preprocess images for model input.
    This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.

    Args:
        image_path_list (list): List of paths to image files
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Sets width to 518px and center crops height if needed.
            - "pad": Preserves all pixels by making the largest dimension 518px
              and padding the smaller dimension to reach a square shape.

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)

    Raises:
        ValueError: If the input list is empty or if mode is invalid

    Notes:
        - Images with different dimensions will be padded with white (value=1.0)
        - A warning is printed when images have different shapes
        - When mode="crop": The function ensures width=518px while maintaining aspect ratio
          and height is center-cropped if larger than 518px
        - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
          and the smaller dimension is padded to reach a square shape (518x518)
        - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    images = []
    shapes = set()
    to_tensor = TF.ToTensor()
    target_size = 518

    # First process all images and collect their shapes
    for image_path in image_path_list:

        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto white background:
        if img.mode == "RGBA":
            # Create white background
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            # Alpha composite onto the white background
            img = Image.alpha_composite(background, img)

        # Now convert to "RGB" (this step assigns white for transparent areas)
        img = img.convert("RGB")

        width, height = img.size

        if mode == "pad":
            # Make the largest dimension 518px while maintaining aspect ratio
            if width >= height:
                new_width = target_size
                new_height = round(height * (new_width / width) / 14) * 14  # Make divisible by 14
            else:
                new_height = target_size
                new_width = round(width * (new_height / height) / 14) * 14  # Make divisible by 14
        else:  # mode == "crop"
            # Original behavior: set width to 518px
            new_width = target_size
            # Calculate height maintaining aspect ratio, divisible by 14
            new_height = round(height * (new_width / width) / 14) * 14

        # Resize with new dimensions (width, height)
        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)  # Convert to tensor (0, 1)

        # Center crop height if it's larger than 518 (only in crop mode)
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]

        # For pad mode, pad to make a square of target_size x target_size
        if mode == "pad":
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                # Pad with white (value=1.0)
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)

    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        # Find maximum dimensions
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        # Pad images if necessary
        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
            padded_images.append(img)
        images = padded_images

    images = torch.stack(images)  # concatenate images

    # Ensure correct shape when single image
    if len(image_path_list) == 1:
        # Verify shape is (1, C, H, W)
        if images.dim() == 3:
            images = images.unsqueeze(0)

    return images

def preprocess_image(img_tensor, mode="crop", target_size=518):
    """
    Preprocess image tensor(s) to target size with crop or pad mode.
    Args:
        img_tensor (torch.Tensor): Image tensor of shape (C, H, W) or (T, C, H, W), values in [0, 1]
        mode (str): 'crop' or 'pad'
        target_size (int): Target size for width/height
    Returns:
        torch.Tensor: Preprocessed image tensor(s), same batch dim as input
    """
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")
    if img_tensor.dim() == 3:
        tensors = [img_tensor]
        squeeze = True
    elif img_tensor.dim() == 4:
        tensors = list(img_tensor)
        squeeze = False
    else:
        raise ValueError("Input tensor must be (C, H, W) or (T, C, H, W)")
    processed = []
    for img in tensors:
        C, H, W = img.shape
        if mode == "pad":
            if W >= H:
                new_W = target_size
                new_H = round(H * (new_W / W) / 14) * 14
            else:
                new_H = target_size
                new_W = round(W * (new_H / H) / 14) * 14
            out = torch.nn.functional.interpolate(img.unsqueeze(0), size=(new_H, new_W), mode="bicubic", align_corners=False).squeeze(0)
            h_padding = target_size - new_H
            w_padding = target_size - new_W
            pad_top = h_padding // 2
            pad_bottom = h_padding - pad_top
            pad_left = w_padding // 2
            pad_right = w_padding - pad_left
            if h_padding > 0 or w_padding > 0:
                out = torch.nn.functional.pad(
                    out, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
        else:  # crop
            new_W = target_size
            new_H = round(H * (new_W / W) / 14) * 14
            out = torch.nn.functional.interpolate(img.unsqueeze(0), size=(new_H, new_W), mode="bicubic", align_corners=False).squeeze(0)
            if new_H > target_size:
                start_y = (new_H - target_size) // 2
                out = out[:, start_y : start_y + target_size, :]
        processed.append(out)
    result = torch.stack(processed)
    if squeeze:
        return result[0]
    return result
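A short usage sketch for the tensor-based entry point above (fully synthetic input; the file paths mentioned in the comment are placeholders):

import torch
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image

# Tensor path (used by VGGT4Track.forward): "pad" returns square 518x518 frames,
# "crop" fixes the width at 518 and center-crops the height if needed.
clip = torch.rand(4, 3, 480, 640)                 # (T, C, H, W) in [0, 1]
print(preprocess_image(clip, mode="pad").shape)   # torch.Size([4, 3, 518, 518])
print(preprocess_image(clip, mode="crop").shape)  # torch.Size([4, 3, 392, 518])

# For images on disk, load_and_preprocess_images(["img0.png", "img1.png"], mode="crop")
# returns an (N, 3, H, W) batch; those paths are placeholders here.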
models/SpaTrackV2/models/vggt4track/utils/loss.py
ADDED
@@ -0,0 +1,123 @@
# This file contains the loss functions for FrontTracker

import torch
import torch.nn as nn
import utils3d
from models.moge.train.losses import (
    affine_invariant_global_loss,
    affine_invariant_local_loss,
    edge_loss,
    normal_loss,
    mask_l2_loss,
    mask_bce_loss,
    monitoring,
)
import torch.nn.functional as F
from models.SpaTrackV2.models.utils import pose_enc2mat, matrix_to_quaternion, get_track_points, normalize_rgb
from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri, extri_intri_to_pose_encoding

def compute_loss(predictions, annots):
    """
    Compute the loss for the FrontTracker model.
    """

    B, T, C, H, W = predictions["images"].shape
    H_resize, W_resize = H, W

    if "poses_gt" in annots.keys():
        intrs, c2w_traj_gt = pose_enc2mat(annots["poses_gt"],
                                          H_resize, W_resize, min(H, W))
    else:
        c2w_traj_gt = None

    if "intrs_gt" in annots.keys():
        intrs = annots["intrs_gt"].view(B, T, 3, 3)
        fx_factor = W_resize / W
        fy_factor = H_resize / H
        intrs[:,:,0,:] *= fx_factor
        intrs[:,:,1,:] *= fy_factor

    if "depth_gt" in annots.keys():

        metric_depth_gt = annots['depth_gt'].view(B*T, 1, H, W)
        metric_depth_gt = F.interpolate(metric_depth_gt,
                                        size=(H_resize, W_resize), mode='nearest')

        _depths = metric_depth_gt[metric_depth_gt > 0].reshape(-1)
        q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
        q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
        iqr = q75 - q25
        upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
        _depth_roi = torch.tensor(
            [1e-1, upper_bound.item()],
            dtype=metric_depth_gt.dtype,
            device=metric_depth_gt.device
        )
        mask_roi = (metric_depth_gt > _depth_roi[0]) & (metric_depth_gt < _depth_roi[1])
        # fin mask
        gt_mask_fin = ((metric_depth_gt > 0)*(mask_roi)).float()
        # filter the sky
        inf_thres = 50*q25.clamp(min=200, max=1e3)
        gt_mask_inf = (metric_depth_gt > inf_thres).float()
        # gt mask
        gt_mask = (metric_depth_gt > 0)*(metric_depth_gt < 10*q25)

        points_map_gt = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))

    if annots["syn_real"] == 1:
        ln_msk_l2, _ = mask_l2_loss(predictions["unc_metric"], gt_mask_fin[:,0], gt_mask_inf[:,0])
        ln_msk_l2 = 50*ln_msk_l2.mean()
    else:
        ln_msk_l2 = 0 * points_map_gt.mean()

    # loss1: global invariant loss
    ln_depth_glob, _, gt_metric_scale, gt_metric_shift = affine_invariant_global_loss(predictions["points_map"], points_map_gt, gt_mask[:,0], align_resolution=32)
    ln_depth_glob = 100*ln_depth_glob.mean()
    # loss2: edge loss
    ln_edge, _ = edge_loss(predictions["points_map"], points_map_gt, gt_mask[:,0])
    ln_edge = ln_edge.mean()
    # loss3: normal loss
    ln_normal, _ = normal_loss(predictions["points_map"], points_map_gt, gt_mask[:,0])
    ln_normal = ln_normal.mean()
    #NOTE: loss4: consistent loss
    norm_rescale = gt_metric_scale.mean()
    points_map_gt_cons = points_map_gt.clone() / norm_rescale
    if "scale" in predictions.keys():
        scale_ = predictions["scale"].view(B*T, 2, 1, 1)[:,:1]
        shift_ = predictions["scale"].view(B*T, 2, 1, 1)[:,1:]
    else:
        scale_ = torch.ones_like(predictions["points_map"])
        shift_ = torch.zeros_like(predictions["points_map"])[..., 2:]

    points_pred_cons = predictions["points_map"] * scale_
    points_pred_cons[..., 2:] += shift_
    pred_mask = predictions["unc_metric"].clone().clamp(min=5e-2)
    ln_cons = torch.abs(points_pred_cons - points_map_gt_cons).norm(dim=-1) * pred_mask - 0.05 * torch.log(pred_mask)
    ln_cons = 0.5*ln_cons[(1-gt_mask_inf.squeeze()).bool()].clamp(max=100).mean()
    # loss5: scale shift loss
    if "scale" in predictions.keys():
        ln_scale_shift = torch.abs(scale_.squeeze() - gt_metric_scale / norm_rescale) + torch.abs(shift_.squeeze() - gt_metric_shift[:,2] / norm_rescale)
        ln_scale_shift = 10*ln_scale_shift.mean()
    else:
        ln_scale_shift = 0 * ln_cons.mean()
    # loss6: pose loss
    c2w_traj_gt[...,:3, 3] /= norm_rescale
    ln_pose = 0
    for i_t, pose_enc_i in enumerate(predictions["pose_enc_list"]):
        pose_enc_gt = extri_intri_to_pose_encoding(torch.inverse(c2w_traj_gt)[...,:3,:4], intrs, predictions["images"].shape[-2:])
        T_loss = torch.abs(pose_enc_i[..., :3] - pose_enc_gt[..., :3]).mean()
        R_loss = torch.abs(pose_enc_i[..., 3:7] - pose_enc_gt[..., 3:7]).mean()
        K_loss = torch.abs(pose_enc_i[..., 7:] - pose_enc_gt[..., 7:]).mean()
        pose_loss_i = 25*(T_loss + R_loss) + K_loss
        ln_pose += 0.8**(len(predictions["pose_enc_list"]) - i_t - 1)*(pose_loss_i)
    ln_pose = 5*ln_pose
    if annots["syn_real"] == 1:
        loss = ln_depth_glob + ln_edge + ln_normal + ln_cons + ln_scale_shift + ln_pose + ln_msk_l2
    else:
        loss = ln_cons + ln_pose
        ln_scale_shift = 0*ln_scale_shift
    return {"loss": loss, "ln_depth_glob": ln_depth_glob, "ln_edge": ln_edge, "ln_normal": ln_normal,
            "ln_cons": ln_cons, "ln_scale_shift": ln_scale_shift,
            "ln_pose": ln_pose, "ln_msk_l2": ln_msk_l2, "norm_scale": norm_rescale}
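compute_loss reads a handful of keys from the annots dictionary; the sketch below only illustrates that expected structure with dummy tensors (the poses_gt layout is an assumption, since pose_enc2mat is defined elsewhere in the repo):

import torch

B, T, H, W = 1, 8, 384, 512
annots = {
    "poses_gt": torch.zeros(B, T, 7),            # consumed by pose_enc2mat; exact layout is an assumption
    "intrs_gt": torch.eye(3).repeat(B, T, 1, 1), # (B, T, 3, 3) pixel-space intrinsics
    "depth_gt": torch.rand(B, T, H, W),          # metric depth maps
    "syn_real": 1,                               # 1 -> synthetic-data branch with all loss terms
}
# predictions would come from VGGT4Track.forward(...) in training mode:
# loss_dict = compute_loss(predictions, annots)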
models/SpaTrackV2/models/vggt4track/utils/pose_enc.py
ADDED
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from .rotation import quat_to_mat, mat_to_quat


def extri_intri_to_pose_encoding(
    extrinsics,
    intrinsics,
    image_size_hw=None,  # e.g., (256, 512)
    pose_encoding_type="absT_quaR_FoV",
):
    """Convert camera extrinsics and intrinsics to a compact pose encoding.

    This function transforms camera parameters into a unified pose encoding format,
    which can be used for various downstream tasks like pose prediction or representation.

    Args:
        extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
            where B is batch size and S is sequence length.
            In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
            The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
        intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
            Defined in pixels, with format:
            [[fx, 0, cx],
             [0, fy, cy],
             [0, 0, 1]]
            where fx, fy are focal lengths and (cx, cy) is the principal point
        image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
            Required for computing field of view values. For example: (256, 512).
        pose_encoding_type (str): Type of pose encoding to use. Currently only
            supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).

    Returns:
        torch.Tensor: Encoded camera pose parameters with shape BxSx9.
            For "absT_quaR_FoV" type, the 9 dimensions are:
            - [:3] = absolute translation vector T (3D)
            - [3:7] = rotation as quaternion quat (4D)
            - [7:] = field of view (2D)
    """

    # extrinsics: BxSx3x4
    # intrinsics: BxSx3x3

    if pose_encoding_type == "absT_quaR_FoV":
        R = extrinsics[:, :, :3, :3]  # BxSx3x3
        T = extrinsics[:, :, :3, 3]   # BxSx3

        quat = mat_to_quat(R)
        # Note the order of h and w here
        H, W = image_size_hw
        fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
        fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
        pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
    else:
        raise NotImplementedError

    return pose_encoding


def pose_encoding_to_extri_intri(
    pose_encoding,
    image_size_hw=None,  # e.g., (256, 512)
    pose_encoding_type="absT_quaR_FoV",
    build_intrinsics=True,
):
    """Convert a pose encoding back to camera extrinsics and intrinsics.

    This function performs the inverse operation of extri_intri_to_pose_encoding,
    reconstructing the full camera parameters from the compact encoding.

    Args:
        pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
            where B is batch size and S is sequence length.
            For "absT_quaR_FoV" type, the 9 dimensions are:
            - [:3] = absolute translation vector T (3D)
            - [3:7] = rotation as quaternion quat (4D)
            - [7:] = field of view (2D)
        image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
            Required for reconstructing intrinsics from field of view values.
            For example: (256, 512).
        pose_encoding_type (str): Type of pose encoding used. Currently only
            supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
        build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
            If False, only extrinsics are returned and intrinsics will be None.

    Returns:
        tuple: (extrinsics, intrinsics)
            - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
              In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
              transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
              a 3x1 translation vector.
            - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
              or None if build_intrinsics is False. Defined in pixels, with format:
              [[fx, 0, cx],
               [0, fy, cy],
               [0, 0, 1]]
              where fx, fy are focal lengths and (cx, cy) is the principal point,
              assumed to be at the center of the image (W/2, H/2).
    """

    intrinsics = None

    if pose_encoding_type == "absT_quaR_FoV":
        T = pose_encoding[..., :3]
        quat = pose_encoding[..., 3:7]
        fov_h = pose_encoding[..., 7]
        fov_w = pose_encoding[..., 8]

        R = quat_to_mat(quat)
        extrinsics = torch.cat([R, T[..., None]], dim=-1)

        if build_intrinsics:
            H, W = image_size_hw
            fy = (H / 2.0) / torch.tan(fov_h / 2.0)
            fx = (W / 2.0) / torch.tan(fov_w / 2.0)
            intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
            intrinsics[..., 0, 0] = fx
            intrinsics[..., 1, 1] = fy
            intrinsics[..., 0, 2] = W / 2
            intrinsics[..., 1, 2] = H / 2
            intrinsics[..., 2, 2] = 1.0  # Set the homogeneous coordinate to 1
    else:
        raise NotImplementedError

    return extrinsics, intrinsics
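A quick round-trip sanity check of the absT_quaR_FoV encoding with identity cameras (dummy values chosen here, not from the repo):

import torch
from models.SpaTrackV2.models.vggt4track.utils.pose_enc import (
    extri_intri_to_pose_encoding, pose_encoding_to_extri_intri,
)

B, S, H, W = 1, 4, 256, 512
extrinsics = torch.eye(4)[:3].repeat(B, S, 1, 1)            # (B, S, 3, 4), identity cam-from-world
intrinsics = torch.tensor([[300.0, 0, W / 2],
                           [0, 300.0, H / 2],
                           [0, 0, 1.0]]).repeat(B, S, 1, 1)  # principal point at the image center

enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(H, W))  # (B, S, 9)
extr_rec, intr_rec = pose_encoding_to_extri_intri(enc, image_size_hw=(H, W))

print(torch.allclose(extr_rec, extrinsics, atol=1e-5),
      torch.allclose(intr_rec, intrinsics, atol=1e-3))  # True True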
models/SpaTrackV2/models/vggt4track/utils/rotation.py
ADDED
@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d

import torch
import numpy as np
import torch.nn.functional as F


def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Quaternion Order: XYZW or say ijkr, scalar-last

    Convert rotations given as quaternions to rotation matrices.
    Args:
        quaternions: quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    i, j, k, r = torch.unbind(quaternions, -1)
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part last, as tensor of shape (..., 4).
        Quaternion Order: XYZW or say ijkr, scalar-last
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)

    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))

    # Convert from rijk to ijkr
    out = out[..., [1, 2, 3, 0]]

    out = standardize_quaternion(out)

    return out


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    if torch.is_grad_enabled():
        ret[positive_mask] = torch.sqrt(x[positive_mask])
    else:
        ret = torch.where(positive_mask, torch.sqrt(x), ret)
    return ret


def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Args:
        quaternions: Quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
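A small numerical check of the scalar-last (XYZW) convention used by these helpers:

import torch
from models.SpaTrackV2.models.vggt4track.utils.rotation import quat_to_mat, mat_to_quat

# 90-degree rotation about Z, scalar-last: (x, y, z, w)
q = torch.tensor([[0.0, 0.0, 0.7071068, 0.7071068]])
R = quat_to_mat(q)       # (1, 3, 3)
q_rec = mat_to_quat(R)   # back to (1, 4), real part kept non-negative

print(torch.allclose(q, q_rec, atol=1e-5))   # True
print(R[0] @ torch.tensor([1.0, 0.0, 0.0]))  # ~[0, 1, 0]: the x-axis maps to the y-axis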
models/SpaTrackV2/models/vggt4track/utils/visual_track.py
ADDED
@@ -0,0 +1,239 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import cv2
import torch
import numpy as np
import os


def color_from_xy(x, y, W, H, cmap_name="hsv"):
    """
    Map (x, y) -> color in (R, G, B).
    1) Normalize x,y to [0,1].
    2) Combine them into a single scalar c in [0,1].
    3) Use matplotlib's colormap to convert c -> (R,G,B).

    You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
    """
    import matplotlib.cm
    import matplotlib.colors

    x_norm = x / max(W - 1, 1)
    y_norm = y / max(H - 1, 1)
    # Simple combination:
    c = (x_norm + y_norm) / 2.0

    cmap = matplotlib.cm.get_cmap(cmap_name)
    # cmap(c) -> (r,g,b,a) in [0,1]
    rgba = cmap(c)
    r, g, b = rgba[0], rgba[1], rgba[2]
    return (r, g, b)  # in [0,1], RGB order


def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
    """
    Given all tracks in one sample (b), compute a (N,3) array of RGB color values
    in [0,255]. The color is determined by the (x,y) position in the first
    visible frame for each track.

    Args:
        tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
        vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
        image_width, image_height: used for normalizing (x, y).
        cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').

    Returns:
        track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
    """
    S, N, _ = tracks_b.shape
    track_colors = np.zeros((N, 3), dtype=np.uint8)

    if vis_mask_b is None:
        # treat all as visible
        vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)

    for i in range(N):
        # Find first visible frame for track i
        visible_frames = torch.where(vis_mask_b[:, i])[0]
        if len(visible_frames) == 0:
            # track is never visible; just assign black or something
            track_colors[i] = (0, 0, 0)
            continue

        first_s = int(visible_frames[0].item())
        # use that frame's (x,y)
        x, y = tracks_b[first_s, i].tolist()

        # map (x,y) -> (R,G,B) in [0,1]
        r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
        # scale to [0,255]
        r, g, b = int(r * 255), int(g * 255), int(b * 255)
        track_colors[i] = (r, g, b)

    return track_colors


def visualize_tracks_on_images(
    images,
    tracks,
    track_vis_mask=None,
    out_dir="track_visuals_concat_by_xy",
    image_format="CHW",  # "CHW" or "HWC"
    normalize_mode="[0,1]",
    cmap_name="hsv",  # e.g. "hsv", "rainbow", "jet"
    frames_per_row=4,  # New parameter for grid layout
    save_grid=True,  # Flag to control whether to save the grid image
):
    """
    Visualizes frames in a grid layout with specified frames per row.
    Each track's color is determined by its (x,y) position
    in the first visible frame (or frame 0 if always visible).
    Finally convert the BGR result to RGB before saving.
    Also saves each individual frame as a separate PNG file.

    Args:
        images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
        tracks: torch.Tensor (S, N, 2), last dim = (x, y).
        track_vis_mask: torch.Tensor (S, N) or None.
        out_dir: folder to save visualizations.
        image_format: "CHW" or "HWC".
        normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
        cmap_name: a matplotlib colormap name for color_from_xy.
        frames_per_row: number of frames to display in each row of the grid.
        save_grid: whether to save all frames in one grid image.

    Returns:
        None (saves images in out_dir).
    """

    if len(tracks.shape) == 4:
        tracks = tracks.squeeze(0)
        images = images.squeeze(0)
        if track_vis_mask is not None:
            track_vis_mask = track_vis_mask.squeeze(0)

    import matplotlib

    matplotlib.use("Agg")  # for non-interactive (optional)

    os.makedirs(out_dir, exist_ok=True)

    S = images.shape[0]
    _, N, _ = tracks.shape  # (S, N, 2)

    # Move to CPU
    images = images.cpu().clone()
    tracks = tracks.cpu().clone()
    if track_vis_mask is not None:
        track_vis_mask = track_vis_mask.cpu().clone()

    # Infer H, W from images shape
    if image_format == "CHW":
        # e.g. images[s].shape = (3, H, W)
        H, W = images.shape[2], images.shape[3]
    else:
        # e.g. images[s].shape = (H, W, 3)
        H, W = images.shape[1], images.shape[2]

    # Pre-compute the color for each track i based on first visible position
    track_colors_rgb = get_track_colors_by_position(
        tracks,  # shape (S, N, 2)
        vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
        image_width=W,
        image_height=H,
        cmap_name=cmap_name,
    )

    # We'll accumulate each frame's drawn image in a list
    frame_images = []

    for s in range(S):
        # shape => either (3, H, W) or (H, W, 3)
        img = images[s]

        # Convert to (H, W, 3)
        if image_format == "CHW":
            img = img.permute(1, 2, 0)  # (H, W, 3)
        # else "HWC", do nothing

        img = img.numpy().astype(np.float32)

        # Scale to [0,255] if needed
        if normalize_mode == "[0,1]":
            img = np.clip(img, 0, 1) * 255.0
        elif normalize_mode == "[-1,1]":
            img = (img + 1.0) * 0.5 * 255.0
            img = np.clip(img, 0, 255.0)
        # else no normalization

        # Convert to uint8
        img = img.astype(np.uint8)

        # For drawing in OpenCV, convert to BGR
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Draw each visible track
        cur_tracks = tracks[s]  # shape (N, 2)
        if track_vis_mask is not None:
            valid_indices = torch.where(track_vis_mask[s])[0]
        else:
            valid_indices = range(N)

        cur_tracks_np = cur_tracks.numpy()
        for i in valid_indices:
            x, y = cur_tracks_np[i]
            pt = (int(round(x)), int(round(y)))

            # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
            R, G, B = track_colors_rgb[i]
            color_bgr = (int(B), int(G), int(R))
            cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)

        # Convert back to RGB for consistent final saving:
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        # Save individual frame
        frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
        # Convert to BGR for OpenCV imwrite
        frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        cv2.imwrite(frame_path, frame_bgr)

        frame_images.append(img_rgb)

    # Only create and save the grid image if save_grid is True
    if save_grid:
        # Calculate grid dimensions
        num_rows = (S + frames_per_row - 1) // frames_per_row  # Ceiling division

        # Create a grid of images
        grid_img = None
        for row in range(num_rows):
            start_idx = row * frames_per_row
            end_idx = min(start_idx + frames_per_row, S)

            # Concatenate this row horizontally
            row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)

            # If this row has fewer than frames_per_row images, pad with black
            if end_idx - start_idx < frames_per_row:
                padding_width = (frames_per_row - (end_idx - start_idx)) * W
                padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
                row_img = np.concatenate([row_img, padding], axis=1)

            # Add this row to the grid
            if grid_img is None:
                grid_img = row_img
            else:
                grid_img = np.concatenate([grid_img, row_img], axis=0)

        out_path = os.path.join(out_dir, "tracks_grid.png")
        # Convert back to BGR for OpenCV imwrite
        grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(out_path, grid_img_bgr)
        print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")

    print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
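A minimal call with synthetic data, just to show the expected tensor layout (writes PNGs under out_dir; the folder name is a placeholder):

import torch
from models.SpaTrackV2.models.vggt4track.utils.visual_track import visualize_tracks_on_images

S, N, H, W = 6, 32, 240, 320
images = torch.rand(S, 3, H, W)                              # CHW frames in [0, 1]
tracks = torch.stack([torch.rand(S, N) * (W - 1),
                      torch.rand(S, N) * (H - 1)], dim=-1)   # (S, N, 2) pixel coordinates
vis_mask = torch.ones(S, N, dtype=torch.bool)

visualize_tracks_on_images(images, tracks, track_vis_mask=vis_mask,
                           out_dir="track_vis_demo", frames_per_row=3)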
scripts/download.sh
ADDED
@@ -0,0 +1,5 @@
#!/bin/bash

# Download the example data using gdown
mkdir -p ./assets/example1
gdown --id 1q6n2R5ihfMoD-dU_u5vfcMALZSihNgiq -O ./assets/example1/snowboard.npz
viz.html
ADDED
@@ -0,0 +1,2115 @@
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
6 |
+
<title>3D Point Cloud Visualizer</title>
|
7 |
+
<style>
|
8 |
+
:root {
|
9 |
+
--primary: #9b59b6; /* Brighter purple for dark mode */
|
10 |
+
--primary-light: #3a2e4a;
|
11 |
+
--secondary: #a86add;
|
12 |
+
--accent: #ff6e6e;
|
13 |
+
--bg: #1a1a1a;
|
14 |
+
--surface: #2c2c2c;
|
15 |
+
--text: #e0e0e0;
|
16 |
+
--text-secondary: #a0a0a0;
|
17 |
+
--border: #444444;
|
18 |
+
--shadow: rgba(0, 0, 0, 0.2);
|
19 |
+
--shadow-hover: rgba(0, 0, 0, 0.3);
|
20 |
+
|
21 |
+
--space-sm: 16px;
|
22 |
+
--space-md: 24px;
|
23 |
+
--space-lg: 32px;
|
24 |
+
}
|
25 |
+
|
26 |
+
body {
|
27 |
+
margin: 0;
|
28 |
+
overflow: hidden;
|
29 |
+
background: var(--bg);
|
30 |
+
color: var(--text);
|
31 |
+
font-family: 'Inter', sans-serif;
|
32 |
+
-webkit-font-smoothing: antialiased;
|
33 |
+
}
|
34 |
+
|
35 |
+
#canvas-container {
|
36 |
+
position: absolute;
|
37 |
+
width: 100%;
|
38 |
+
height: 100%;
|
39 |
+
}
|
40 |
+
|
41 |
+
#ui-container {
|
42 |
+
position: absolute;
|
43 |
+
top: 0;
|
44 |
+
left: 0;
|
45 |
+
width: 100%;
|
46 |
+
height: 100%;
|
47 |
+
pointer-events: none;
|
48 |
+
z-index: 10;
|
49 |
+
}
|
50 |
+
|
51 |
+
#status-bar {
|
52 |
+
position: absolute;
|
53 |
+
top: 16px;
|
54 |
+
left: 16px;
|
55 |
+
background: rgba(30, 30, 30, 0.9);
|
56 |
+
padding: 8px 16px;
|
57 |
+
border-radius: 8px;
|
58 |
+
pointer-events: auto;
|
59 |
+
box-shadow: 0 4px 6px var(--shadow);
|
60 |
+
backdrop-filter: blur(4px);
|
61 |
+
border: 1px solid var(--border);
|
62 |
+
color: var(--text);
|
63 |
+
transition: opacity 0.5s ease, transform 0.5s ease;
|
64 |
+
font-weight: 500;
|
65 |
+
}
|
66 |
+
|
67 |
+
#status-bar.hidden {
|
68 |
+
opacity: 0;
|
69 |
+
transform: translateY(-20px);
|
70 |
+
pointer-events: none;
|
71 |
+
}
|
72 |
+
|
73 |
+
#control-panel {
|
74 |
+
position: absolute;
|
75 |
+
bottom: 16px;
|
76 |
+
left: 50%;
|
77 |
+
transform: translateX(-50%);
|
78 |
+
background: rgba(44, 44, 44, 0.95);
|
79 |
+
padding: 6px 8px;
|
80 |
+
border-radius: 6px;
|
81 |
+
display: flex;
|
82 |
+
gap: 8px;
|
83 |
+
align-items: center;
|
84 |
+
justify-content: space-between;
|
85 |
+
pointer-events: auto;
|
86 |
+
box-shadow: 0 4px 10px var(--shadow);
|
87 |
+
backdrop-filter: blur(4px);
|
88 |
+
border: 1px solid var(--border);
|
89 |
+
}
|
90 |
+
|
91 |
+
#timeline {
|
92 |
+
width: 150px;
|
93 |
+
height: 4px;
|
94 |
+
background: rgba(255, 255, 255, 0.1);
|
95 |
+
border-radius: 2px;
|
96 |
+
position: relative;
|
97 |
+
cursor: pointer;
|
98 |
+
}
|
99 |
+
|
100 |
+
#progress {
|
101 |
+
position: absolute;
|
102 |
+
height: 100%;
|
103 |
+
background: var(--primary);
|
104 |
+
border-radius: 2px;
|
105 |
+
width: 0%;
|
106 |
+
}
|
107 |
+
|
108 |
+
#playback-controls {
|
109 |
+
display: flex;
|
110 |
+
gap: 4px;
|
111 |
+
align-items: center;
|
112 |
+
}
|
113 |
+
|
114 |
+
button {
|
115 |
+
background: rgba(255, 255, 255, 0.08);
|
116 |
+
border: 1px solid var(--border);
|
117 |
+
color: var(--text);
|
118 |
+
padding: 4px 6px;
|
119 |
+
border-radius: 3px;
|
120 |
+
cursor: pointer;
|
121 |
+
display: flex;
|
122 |
+
align-items: center;
|
123 |
+
justify-content: center;
|
124 |
+
transition: background 0.2s, transform 0.2s;
|
125 |
+
font-family: 'Inter', sans-serif;
|
126 |
+
font-weight: 500;
|
127 |
+
font-size: 6px;
|
128 |
+
}
|
129 |
+
|
130 |
+
button:hover {
|
131 |
+
background: rgba(255, 255, 255, 0.15);
|
132 |
+
transform: translateY(-1px);
|
133 |
+
}
|
134 |
+
|
135 |
+
button.active {
|
136 |
+
background: var(--primary);
|
137 |
+
color: white;
|
138 |
+
box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
|
139 |
+
}
|
140 |
+
|
141 |
+
select, input {
|
142 |
+
background: rgba(255, 255, 255, 0.08);
|
143 |
+
border: 1px solid var(--border);
|
144 |
+
color: var(--text);
|
145 |
+
padding: 4px 6px;
|
146 |
+
border-radius: 3px;
|
147 |
+
cursor: pointer;
|
148 |
+
font-family: 'Inter', sans-serif;
|
149 |
+
font-size: 6px;
|
150 |
+
}
|
151 |
+
|
152 |
+
.icon {
|
153 |
+
width: 10px;
|
154 |
+
height: 10px;
|
155 |
+
fill: currentColor;
|
156 |
+
}
|
157 |
+
|
158 |
+
.tooltip {
|
159 |
+
position: absolute;
|
160 |
+
bottom: 100%;
|
161 |
+
left: 50%;
|
162 |
+
transform: translateX(-50%);
|
163 |
+
background: var(--surface);
|
164 |
+
color: var(--text);
|
165 |
+
padding: 3px 6px;
|
166 |
+
border-radius: 3px;
|
167 |
+
font-size: 7px;
|
168 |
+
white-space: nowrap;
|
169 |
+
margin-bottom: 4px;
|
170 |
+
opacity: 0;
|
171 |
+
transition: opacity 0.2s;
|
172 |
+
pointer-events: none;
|
173 |
+
box-shadow: 0 2px 4px var(--shadow);
|
174 |
+
border: 1px solid var(--border);
|
175 |
+
}
|
176 |
+
|
177 |
+
button:hover .tooltip {
|
178 |
+
opacity: 1;
|
179 |
+
}
|
180 |
+
|
181 |
+
#settings-panel {
|
182 |
+
position: absolute;
|
183 |
+
top: 16px;
|
184 |
+
right: 16px;
|
185 |
+
background: rgba(44, 44, 44, 0.98);
|
186 |
+
padding: 10px;
|
187 |
+
border-radius: 6px;
|
188 |
+
width: 195px;
|
189 |
+
max-height: calc(100vh - 40px);
|
190 |
+
overflow-y: auto;
|
191 |
+
pointer-events: auto;
|
192 |
+
box-shadow: 0 4px 15px var(--shadow);
|
193 |
+
backdrop-filter: blur(4px);
|
194 |
+
border: 1px solid var(--border);
|
195 |
+
display: block;
|
196 |
+
opacity: 1;
|
197 |
+
scrollbar-width: thin;
|
198 |
+
scrollbar-color: var(--primary-light) transparent;
|
199 |
+
transition: transform 0.35s ease-in-out, opacity 0.3s ease-in-out;
|
200 |
+
}
|
201 |
+
|
202 |
+
#settings-panel.is-hidden {
|
203 |
+
transform: translateX(calc(100% + 20px));
|
204 |
+
opacity: 0;
|
205 |
+
pointer-events: none;
|
206 |
+
}
|
207 |
+
|
208 |
+
#settings-panel::-webkit-scrollbar {
|
209 |
+
width: 3px;
|
210 |
+
}
|
211 |
+
|
212 |
+
#settings-panel::-webkit-scrollbar-track {
|
213 |
+
background: transparent;
|
214 |
+
}
|
215 |
+
|
216 |
+
#settings-panel::-webkit-scrollbar-thumb {
|
217 |
+
background-color: var(--primary-light);
|
218 |
+
border-radius: 3px;
|
219 |
+
}
|
220 |
+
|
221 |
+
@media (max-height: 700px) {
|
222 |
+
#settings-panel {
|
223 |
+
max-height: calc(100vh - 40px);
|
224 |
+
}
|
225 |
+
}
|
226 |
+
|
227 |
+
@media (max-width: 768px) {
|
228 |
+
#control-panel {
|
229 |
+
width: 90%;
|
230 |
+
flex-wrap: wrap;
|
231 |
+
justify-content: center;
|
232 |
+
}
|
233 |
+
|
234 |
+
#timeline {
|
235 |
+
width: 100%;
|
236 |
+
order: 3;
|
237 |
+
margin-top: 10px;
|
238 |
+
}
|
239 |
+
|
240 |
+
#settings-panel {
|
241 |
+
width: 140px;
|
242 |
+
right: 10px;
|
243 |
+
top: 10px;
|
244 |
+
max-height: calc(100vh - 20px);
|
245 |
+
}
|
246 |
+
}
|
247 |
+
|
248 |
+
.settings-group {
|
249 |
+
margin-bottom: 8px;
|
250 |
+
}
|
251 |
+
|
252 |
+
.settings-group h3 {
|
253 |
+
margin: 0 0 6px 0;
|
254 |
+
font-size: 10px;
|
255 |
+
font-weight: 500;
|
256 |
+
color: var(--text-secondary);
|
257 |
+
}
|
258 |
+
|
259 |
+
.slider-container {
|
260 |
+
display: flex;
|
261 |
+
align-items: center;
|
262 |
+
gap: 6px;
|
263 |
+
width: 100%;
|
264 |
+
}
|
265 |
+
|
266 |
+
.slider-container label {
|
267 |
+
min-width: 60px;
|
268 |
+
font-size: 10px;
|
269 |
+
flex-shrink: 0;
|
270 |
+
}
|
271 |
+
|
272 |
+
input[type="range"] {
|
273 |
+
flex: 1;
|
274 |
+
height: 2px;
|
275 |
+
-webkit-appearance: none;
|
276 |
+
background: rgba(255, 255, 255, 0.1);
|
277 |
+
border-radius: 1px;
|
278 |
+
min-width: 0;
|
279 |
+
}
|
280 |
+
|
281 |
+
input[type="range"]::-webkit-slider-thumb {
|
282 |
+
-webkit-appearance: none;
|
283 |
+
width: 8px;
|
284 |
+
height: 8px;
|
285 |
+
border-radius: 50%;
|
286 |
+
background: var(--primary);
|
287 |
+
cursor: pointer;
|
288 |
+
}
|
289 |
+
|
290 |
+
.toggle-switch {
|
291 |
+
position: relative;
|
292 |
+
display: inline-block;
|
293 |
+
width: 20px;
|
294 |
+
height: 10px;
|
295 |
+
}
|
296 |
+
|
297 |
+
.toggle-switch input {
|
298 |
+
opacity: 0;
|
299 |
+
width: 0;
|
300 |
+
height: 0;
|
301 |
+
}
|
302 |
+
|
303 |
+
.toggle-slider {
|
304 |
+
position: absolute;
|
305 |
+
cursor: pointer;
|
306 |
+
top: 0;
|
307 |
+
left: 0;
|
308 |
+
right: 0;
|
309 |
+
bottom: 0;
|
310 |
+
background: rgba(255, 255, 255, 0.1);
|
311 |
+
transition: .4s;
|
312 |
+
border-radius: 10px;
|
313 |
+
}
|
314 |
+
|
315 |
+
.toggle-slider:before {
|
316 |
+
position: absolute;
|
317 |
+
content: "";
|
318 |
+
height: 8px;
|
319 |
+
width: 8px;
|
320 |
+
left: 1px;
|
321 |
+
bottom: 1px;
|
322 |
+
background: var(--surface);
|
323 |
+
border: 1px solid var(--border);
|
324 |
+
transition: .4s;
|
325 |
+
border-radius: 50%;
|
326 |
+
}
|
327 |
+
|
328 |
+
input:checked + .toggle-slider {
|
329 |
+
background: var(--primary);
|
330 |
+
}
|
331 |
+
|
332 |
+
input:checked + .toggle-slider:before {
|
333 |
+
transform: translateX(10px);
|
334 |
+
}
|
335 |
+
|
336 |
+
.checkbox-container {
|
337 |
+
display: flex;
|
338 |
+
align-items: center;
|
339 |
+
gap: 4px;
|
340 |
+
margin-bottom: 4px;
|
341 |
+
}
|
342 |
+
|
343 |
+
.checkbox-container label {
|
344 |
+
font-size: 10px;
|
345 |
+
cursor: pointer;
|
346 |
+
}
|
347 |
+
|
348 |
+
#loading-overlay {
|
349 |
+
position: absolute;
|
350 |
+
top: 0;
|
351 |
+
left: 0;
|
352 |
+
width: 100%;
|
353 |
+
height: 100%;
|
354 |
+
background: var(--bg);
|
355 |
+
display: flex;
|
356 |
+
flex-direction: column;
|
357 |
+
align-items: center;
|
358 |
+
justify-content: center;
|
359 |
+
z-index: 100;
|
360 |
+
transition: opacity 0.5s;
|
361 |
+
}
|
362 |
+
|
363 |
+
#loading-overlay.fade-out {
|
364 |
+
opacity: 0;
|
365 |
+
pointer-events: none;
|
366 |
+
}
|
367 |
+
|
368 |
+
.spinner {
|
369 |
+
width: 50px;
|
370 |
+
height: 50px;
|
371 |
+
border: 5px solid rgba(155, 89, 182, 0.2);
|
372 |
+
border-radius: 50%;
|
373 |
+
border-top-color: var(--primary);
|
374 |
+
animation: spin 1s ease-in-out infinite;
|
375 |
+
margin-bottom: 16px;
|
376 |
+
}
|
377 |
+
|
378 |
+
@keyframes spin {
|
379 |
+
to { transform: rotate(360deg); }
|
380 |
+
}
|
381 |
+
|
382 |
+
#loading-text {
|
383 |
+
margin-top: 16px;
|
384 |
+
font-size: 18px;
|
385 |
+
color: var(--text);
|
386 |
+
font-weight: 500;
|
387 |
+
}
|
388 |
+
|
389 |
+
#frame-counter {
|
390 |
+
color: var(--text-secondary);
|
391 |
+
font-size: 7px;
|
392 |
+
font-weight: 500;
|
393 |
+
min-width: 60px;
|
394 |
+
text-align: center;
|
395 |
+
padding: 0 4px;
|
396 |
+
}
|
397 |
+
|
398 |
+
.control-btn {
|
399 |
+
background: rgba(255, 255, 255, 0.08);
|
400 |
+
border: 1px solid var(--border);
|
401 |
+
padding: 4px 6px;
|
402 |
+
border-radius: 3px;
|
403 |
+
cursor: pointer;
|
404 |
+
display: flex;
|
405 |
+
align-items: center;
|
406 |
+
justify-content: center;
|
407 |
+
transition: all 0.2s ease;
|
408 |
+
font-size: 6px;
|
409 |
+
}
|
410 |
+
|
411 |
+
.control-btn:hover {
|
412 |
+
background: rgba(255, 255, 255, 0.15);
|
413 |
+
transform: translateY(-1px);
|
414 |
+
}
|
415 |
+
|
416 |
+
.control-btn.active {
|
417 |
+
background: var(--primary);
|
418 |
+
color: white;
|
419 |
+
}
|
420 |
+
|
421 |
+
.control-btn.active:hover {
|
422 |
+
background: var(--primary);
|
423 |
+
box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
|
424 |
+
}
|
425 |
+
|
426 |
+
#settings-toggle-btn {
|
427 |
+
position: relative;
|
428 |
+
border-radius: 6px;
|
429 |
+
z-index: 20;
|
430 |
+
}
|
431 |
+
|
432 |
+
#settings-toggle-btn.active {
|
433 |
+
background: var(--primary);
|
434 |
+
color: white;
|
435 |
+
}
|
436 |
+
|
437 |
+
#status-bar,
|
438 |
+
#control-panel,
|
439 |
+
#settings-panel,
|
440 |
+
button,
|
441 |
+
input,
|
442 |
+
select,
|
443 |
+
.toggle-switch {
|
444 |
+
pointer-events: auto;
|
445 |
+
}
|
446 |
+
|
447 |
+
h2 {
|
448 |
+
font-size: 0.9rem;
|
449 |
+
font-weight: 600;
|
450 |
+
margin-top: 0;
|
451 |
+
margin-bottom: 12px;
|
452 |
+
color: var(--primary);
|
453 |
+
cursor: move;
|
454 |
+
user-select: none;
|
455 |
+
display: flex;
|
456 |
+
align-items: center;
|
457 |
+
}
|
458 |
+
|
459 |
+
.drag-handle {
|
460 |
+
font-size: 10px;
|
461 |
+
margin-right: 4px;
|
462 |
+
opacity: 0.6;
|
463 |
+
}
|
464 |
+
|
465 |
+
h2:hover .drag-handle {
|
466 |
+
opacity: 1;
|
467 |
+
}
|
468 |
+
|
469 |
+
.loading-subtitle {
|
470 |
+
font-size: 7px;
|
471 |
+
color: var(--text-secondary);
|
472 |
+
margin-top: 4px;
|
473 |
+
}
|
474 |
+
|
475 |
+
#reset-view-btn {
|
476 |
+
background: var(--primary-light);
|
477 |
+
color: var(--primary);
|
478 |
+
border: 1px solid rgba(155, 89, 182, 0.2);
|
479 |
+
font-weight: 600;
|
480 |
+
font-size: 9px;
|
481 |
+
padding: 4px 6px;
|
482 |
+
transition: all 0.2s;
|
483 |
+
}
|
484 |
+
|
485 |
+
#reset-view-btn:hover {
|
486 |
+
background: var(--primary);
|
487 |
+
color: white;
|
488 |
+
transform: translateY(-2px);
|
489 |
+
box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
|
490 |
+
}
|
491 |
+
|
492 |
+
#show-settings-btn {
|
493 |
+
position: absolute;
|
494 |
+
top: 16px;
|
495 |
+
right: 16px;
|
496 |
+
z-index: 15;
|
497 |
+
display: none;
|
498 |
+
}
|
499 |
+
|
500 |
+
#settings-panel.visible {
|
501 |
+
display: block;
|
502 |
+
opacity: 1;
|
503 |
+
animation: slideIn 0.3s ease forwards;
|
504 |
+
}
|
505 |
+
|
506 |
+
@keyframes slideIn {
|
507 |
+
from {
|
508 |
+
transform: translateY(20px);
|
509 |
+
opacity: 0;
|
510 |
+
}
|
511 |
+
to {
|
512 |
+
transform: translateY(0);
|
513 |
+
opacity: 1;
|
514 |
+
}
|
515 |
+
}
|
516 |
+
|
517 |
+
.dragging {
|
518 |
+
opacity: 0.9;
|
519 |
+
box-shadow: 0 8px 20px rgba(0, 0, 0, 0.15) !important;
|
520 |
+
transition: none !important;
|
521 |
+
}
|
522 |
+
|
523 |
+
/* Tooltip for draggable element */
|
524 |
+
.tooltip-drag {
|
525 |
+
position: absolute;
|
526 |
+
left: 50%;
|
527 |
+
transform: translateX(-50%);
|
528 |
+
background: var(--primary);
|
529 |
+
color: white;
|
530 |
+
font-size: 9px;
|
531 |
+
padding: 2px 4px;
|
532 |
+
border-radius: 2px;
|
533 |
+
opacity: 0;
|
534 |
+
pointer-events: none;
|
535 |
+
transition: opacity 0.3s;
|
536 |
+
white-space: nowrap;
|
537 |
+
bottom: 100%;
|
538 |
+
margin-bottom: 4px;
|
539 |
+
}
|
540 |
+
|
541 |
+
h2:hover .tooltip-drag {
|
542 |
+
opacity: 1;
|
543 |
+
}
|
544 |
+
|
545 |
+
.btn-group {
|
546 |
+
display: flex;
|
547 |
+
margin-top: 8px;
|
548 |
+
}
|
549 |
+
|
550 |
+
#reset-settings-btn {
|
551 |
+
background: var(--primary-light);
|
552 |
+
color: var(--primary);
|
553 |
+
border: 1px solid rgba(155, 89, 182, 0.2);
|
554 |
+
font-weight: 600;
|
555 |
+
font-size: 9px;
|
556 |
+
padding: 4px 6px;
|
557 |
+
transition: all 0.2s;
|
558 |
+
}
|
559 |
+
|
560 |
+
#reset-settings-btn:hover {
|
561 |
+
background: var(--primary);
|
562 |
+
color: white;
|
563 |
+
transform: translateY(-2px);
|
564 |
+
box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
|
565 |
+
}
|
566 |
+
</style>
|
567 |
+
</head>
|
568 |
+
<body>
|
569 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
570 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
571 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
|
572 |
+
|
573 |
+
<div id="canvas-container"></div>
|
574 |
+
|
575 |
+
<div id="ui-container">
|
576 |
+
<div id="status-bar">Initializing...</div>
|
577 |
+
|
578 |
+
<div id="control-panel">
|
579 |
+
<button id="play-pause-btn" class="control-btn">
|
580 |
+
<svg class="icon" viewBox="0 0 24 24">
|
581 |
+
<path id="play-icon" d="M8 5v14l11-7z"/>
|
582 |
+
<path id="pause-icon" d="M6 19h4V5H6v14zm8-14v14h4V5h-4z" style="display: none;"/>
|
583 |
+
</svg>
|
584 |
+
<span class="tooltip">Play/Pause</span>
|
585 |
+
</button>
|
586 |
+
|
587 |
+
<div id="timeline">
|
588 |
+
<div id="progress"></div>
|
589 |
+
</div>
|
590 |
+
|
591 |
+
<div id="frame-counter">Frame 0 / 0</div>
|
592 |
+
|
593 |
+
<div id="playback-controls">
|
594 |
+
<button id="speed-btn" class="control-btn">1x</button>
|
595 |
+
</div>
|
596 |
+
</div>
|
597 |
+
|
598 |
+
<div id="settings-panel">
|
599 |
+
<h2>
|
600 |
+
<span class="drag-handle">☰</span>
|
601 |
+
Visualization Settings
|
602 |
+
<button id="hide-settings-btn" class="control-btn" style="margin-left: auto; padding: 2px;" title="Hide Panel">
|
603 |
+
<svg class="icon" viewBox="0 0 24 24" style="width: 9px; height: 9px;">
|
604 |
+
<path d="M14.59 7.41L18.17 11H4v2h14.17l-3.58 3.59L16 18l6-6-6-6-1.41 1.41z"/>
|
605 |
+
</svg>
|
606 |
+
</button>
|
607 |
+
</h2>
|
608 |
+
|
609 |
+
<div class="settings-group">
|
610 |
+
<h3>Point Cloud</h3>
|
611 |
+
<div class="slider-container">
|
612 |
+
<label for="point-size">Size</label>
|
613 |
+
<input type="range" id="point-size" min="0.005" max="0.1" step="0.005" value="0.03">
|
614 |
+
</div>
|
615 |
+
<div class="slider-container">
|
616 |
+
<label for="point-opacity">Opacity</label>
|
617 |
+
<input type="range" id="point-opacity" min="0.1" max="1" step="0.05" value="1">
|
618 |
+
</div>
|
619 |
+
<div class="slider-container">
|
620 |
+
<label for="max-depth">Max Depth</label>
|
621 |
+
<input type="range" id="max-depth" min="0.1" max="10" step="0.2" value="100">
|
622 |
+
</div>
|
623 |
+
</div>
|
624 |
+
|
625 |
+
<div class="settings-group">
|
626 |
+
<h3>Trajectory</h3>
|
627 |
+
<div class="checkbox-container">
|
628 |
+
<label class="toggle-switch">
|
629 |
+
<input type="checkbox" id="show-trajectory" checked>
|
630 |
+
<span class="toggle-slider"></span>
|
631 |
+
</label>
|
632 |
+
<label for="show-trajectory">Show Trajectory</label>
|
633 |
+
</div>
|
634 |
+
<div class="checkbox-container">
|
635 |
+
<label class="toggle-switch">
|
636 |
+
<input type="checkbox" id="enable-rich-trail">
|
637 |
+
<span class="toggle-slider"></span>
|
638 |
+
</label>
|
639 |
+
<label for="enable-rich-trail">Visual-Rich Trail</label>
|
640 |
+
</div>
|
641 |
+
<div class="slider-container">
|
642 |
+
<label for="trajectory-line-width">Line Width</label>
|
643 |
+
<input type="range" id="trajectory-line-width" min="0.5" max="5" step="0.5" value="1.5">
|
644 |
+
</div>
|
645 |
+
<div class="slider-container">
|
646 |
+
<label for="trajectory-ball-size">Ball Size</label>
|
647 |
+
<input type="range" id="trajectory-ball-size" min="0.005" max="0.05" step="0.001" value="0.02">
|
648 |
+
</div>
|
649 |
+
<div class="slider-container">
|
650 |
+
<label for="trajectory-history">History Frames</label>
|
651 |
+
<input type="range" id="trajectory-history" min="1" max="500" step="1" value="30">
|
652 |
+
</div>
|
653 |
+
<div class="slider-container" id="tail-opacity-container" style="display: none;">
|
654 |
+
<label for="trajectory-fade">Tail Opacity</label>
|
655 |
+
<input type="range" id="trajectory-fade" min="0" max="1" step="0.05" value="0.0">
|
656 |
+
</div>
|
657 |
+
</div>
|
658 |
+
|
659 |
+
<div class="settings-group">
|
660 |
+
<h3>Camera</h3>
|
661 |
+
<div class="checkbox-container">
|
662 |
+
<label class="toggle-switch">
|
663 |
+
<input type="checkbox" id="show-camera-frustum" checked>
|
664 |
+
<span class="toggle-slider"></span>
|
665 |
+
</label>
|
666 |
+
<label for="show-camera-frustum">Show Camera Frustum</label>
|
667 |
+
</div>
|
668 |
+
<div class="slider-container">
|
669 |
+
<label for="frustum-size">Size</label>
|
670 |
+
<input type="range" id="frustum-size" min="0.02" max="0.5" step="0.01" value="0.2">
|
671 |
+
</div>
|
672 |
+
</div>
|
673 |
+
|
674 |
+
<div class="settings-group">
|
675 |
+
<h3>Keep History</h3>
|
676 |
+
<div class="checkbox-container">
|
677 |
+
<label class="toggle-switch">
|
678 |
+
<input type="checkbox" id="enable-keep-history">
|
679 |
+
<span class="toggle-slider"></span>
|
680 |
+
</label>
|
681 |
+
<label for="enable-keep-history">Enable Keep History</label>
|
682 |
+
</div>
|
683 |
+
<div class="slider-container">
|
684 |
+
<label for="history-stride">Stride</label>
|
685 |
+
<select id="history-stride">
|
686 |
+
<option value="1">1</option>
|
687 |
+
<option value="2">2</option>
|
688 |
+
<option value="5" selected>5</option>
|
689 |
+
<option value="10">10</option>
|
690 |
+
<option value="20">20</option>
|
691 |
+
</select>
|
692 |
+
</div>
|
693 |
+
</div>
|
694 |
+
|
695 |
+
<div class="settings-group">
|
696 |
+
<h3>Background</h3>
|
697 |
+
<div class="checkbox-container">
|
698 |
+
<label class="toggle-switch">
|
699 |
+
<input type="checkbox" id="white-background">
|
700 |
+
<span class="toggle-slider"></span>
|
701 |
+
</label>
|
702 |
+
<label for="white-background">White Background</label>
|
703 |
+
</div>
|
704 |
+
</div>
|
705 |
+
|
706 |
+
<div class="settings-group">
|
707 |
+
<div class="btn-group">
|
708 |
+
<button id="reset-view-btn" style="flex: 1; margin-right: 5px;">Reset View</button>
|
709 |
+
<button id="reset-settings-btn" style="flex: 1; margin-left: 5px;">Reset Settings</button>
|
710 |
+
</div>
|
711 |
+
</div>
|
712 |
+
</div>
|
713 |
+
|
714 |
+
<button id="show-settings-btn" class="control-btn" title="Show Settings">
|
715 |
+
<svg class="icon" viewBox="0 0 24 24">
|
716 |
+
<path d="M19.14,12.94c0.04-0.3,0.06-0.61,0.06-0.94c0-0.32-0.02-0.64-0.07-0.94l2.03-1.58c0.18-0.14,0.23-0.41,0.12-0.61 l-1.92-3.32c-0.12-0.22-0.37-0.29-0.59-0.22l-2.39,0.96c-0.5-0.38-1.03-0.7-1.62-0.94L14.4,2.81c-0.04-0.24-0.24-0.41-0.48-0.41 h-3.84c-0.24,0-0.43,0.17-0.47,0.41L9.25,5.35C8.66,5.59,8.12,5.92,7.63,6.29L5.24,5.33c-0.22-0.08-0.47,0-0.59,0.22L2.74,8.87 C2.62,9.08,2.66,9.34,2.86,9.48l2.03,1.58C4.84,11.36,4.8,11.69,4.8,12s0.02,0.64,0.07,0.94l-2.03,1.58 c-0.18,0.14-0.23,0.41-0.12,0.61l1.92,3.32c0.12,0.22,0.37,0.29,0.59,0.22l2.39-0.96c0.5,0.38,1.03,0.7,1.62,0.94l0.36,2.54 c0.04,0.24,0.24,0.41,0.48,0.41h3.84c0.24,0,0.44-0.17,0.47-0.41l0.36-2.54c0.59-0.24,1.13-0.56,1.62-0.94l2.39,0.96 c0.22,0.08,0.47,0,0.59-0.22l1.92-3.32c0.12-0.22,0.07-0.47-0.12-0.61L19.14,12.94z M12,15.6c-1.98,0-3.6-1.62-3.6-3.6 s1.62-3.6,3.6-3.6s3.6,1.62,3.6,3.6S13.98,15.6,12,15.6z"/>
|
717 |
+
</svg>
|
718 |
+
</button>
|
719 |
+
</div>
|
720 |
+
|
721 |
+
<div id="loading-overlay">
|
722 |
+
<!-- <div class="spinner"></div> -->
|
723 |
+
<div id="loading-text"></div>
|
724 |
+
<div class="loading-subtitle" style="font-size: medium;">Interactive Viewer of 3D Tracking</div>
|
725 |
+
</div>
|
726 |
+
|
727 |
+
<!-- Libraries -->
|
728 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/pako/2.1.0/pako.min.js"></script>
|
729 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/build/three.min.js"></script>
|
730 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/controls/OrbitControls.js"></script>
|
731 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/build/dat.gui.min.js"></script>
|
732 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineSegmentsGeometry.js"></script>
|
733 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineGeometry.js"></script>
|
734 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineMaterial.js"></script>
|
735 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineSegments2.js"></script>
|
736 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/Line2.js"></script>
|
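<!-- Note: these examples/js bundles are the non-module builds, which attach
     OrbitControls, LineSegmentsGeometry, LineGeometry, LineMaterial, LineSegments2
     and Line2 to the global THREE namespace; the inline script below depends on
     those globals. -->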
737 |
+
|
738 |
+
<script>
|
739 |
+
class PointCloudVisualizer {
|
740 |
+
constructor() {
|
741 |
+
this.data = null;
|
742 |
+
this.config = {};
|
743 |
+
this.currentFrame = 0;
|
744 |
+
this.isPlaying = false;
|
745 |
+
this.playbackSpeed = 1;
|
746 |
+
this.lastFrameTime = 0;
|
747 |
+
this.defaultSettings = null;
|
748 |
+
|
749 |
+
this.ui = {
|
750 |
+
statusBar: document.getElementById('status-bar'),
|
751 |
+
playPauseBtn: document.getElementById('play-pause-btn'),
|
752 |
+
speedBtn: document.getElementById('speed-btn'),
|
753 |
+
timeline: document.getElementById('timeline'),
|
754 |
+
progress: document.getElementById('progress'),
|
755 |
+
settingsPanel: document.getElementById('settings-panel'),
|
756 |
+
loadingOverlay: document.getElementById('loading-overlay'),
|
757 |
+
loadingText: document.getElementById('loading-text'),
|
758 |
+
settingsToggleBtn: document.getElementById('settings-toggle-btn'),
|
759 |
+
frameCounter: document.getElementById('frame-counter'),
|
760 |
+
pointSize: document.getElementById('point-size'),
|
761 |
+
pointOpacity: document.getElementById('point-opacity'),
|
762 |
+
maxDepth: document.getElementById('max-depth'),
|
763 |
+
showTrajectory: document.getElementById('show-trajectory'),
|
764 |
+
enableRichTrail: document.getElementById('enable-rich-trail'),
|
765 |
+
trajectoryLineWidth: document.getElementById('trajectory-line-width'),
|
766 |
+
trajectoryBallSize: document.getElementById('trajectory-ball-size'),
|
767 |
+
trajectoryHistory: document.getElementById('trajectory-history'),
|
768 |
+
trajectoryFade: document.getElementById('trajectory-fade'),
|
769 |
+
tailOpacityContainer: document.getElementById('tail-opacity-container'),
|
770 |
+
resetViewBtn: document.getElementById('reset-view-btn'),
|
771 |
+
showCameraFrustum: document.getElementById('show-camera-frustum'),
|
772 |
+
frustumSize: document.getElementById('frustum-size'),
|
773 |
+
hideSettingsBtn: document.getElementById('hide-settings-btn'),
|
774 |
+
showSettingsBtn: document.getElementById('show-settings-btn'),
|
775 |
+
enableKeepHistory: document.getElementById('enable-keep-history'),
|
776 |
+
historyStride: document.getElementById('history-stride'),
|
777 |
+
whiteBackground: document.getElementById('white-background')
|
778 |
+
};
|
779 |
+
|
780 |
+
this.scene = null;
|
781 |
+
this.camera = null;
|
782 |
+
this.renderer = null;
|
783 |
+
this.controls = null;
|
784 |
+
this.pointCloud = null;
|
785 |
+
this.trajectories = [];
|
786 |
+
this.cameraFrustum = null;
|
787 |
+
|
788 |
+
// Keep History functionality
|
789 |
+
this.historyPointClouds = [];
|
790 |
+
this.historyTrajectories = [];
|
791 |
+
this.historyFrames = [];
|
792 |
+
this.maxHistoryFrames = 20;
|
793 |
+
|
794 |
+
this.initThreeJS();
|
795 |
+
this.loadDefaultSettings().then(() => {
|
796 |
+
this.initEventListeners();
|
797 |
+
this.loadData();
|
798 |
+
});
|
799 |
+
}
|
800 |
+
|
801 |
+
async loadDefaultSettings() {
|
802 |
+
try {
|
803 |
+
const urlParams = new URLSearchParams(window.location.search);
|
804 |
+
const dataPath = urlParams.get('data') || '';
|
805 |
+
|
806 |
+
const defaultSettings = {
|
807 |
+
pointSize: 0.03,
|
808 |
+
pointOpacity: 1.0,
|
809 |
+
showTrajectory: true,
|
810 |
+
trajectoryLineWidth: 2.5,
|
811 |
+
trajectoryBallSize: 0.015,
|
812 |
+
trajectoryHistory: 0,
|
813 |
+
showCameraFrustum: true,
|
814 |
+
frustumSize: 0.2
|
815 |
+
};
|
816 |
+
|
817 |
+
if (!dataPath) {
|
818 |
+
this.defaultSettings = defaultSettings;
|
819 |
+
this.applyDefaultSettings();
|
820 |
+
return;
|
821 |
+
}
|
822 |
+
|
823 |
+
// Try to extract dataset and videoId from the data path
|
824 |
+
// Expected format: demos/datasetname/videoid.bin
|
825 |
+
const pathParts = dataPath.split('/');
|
826 |
+
if (pathParts.length < 3) {
|
827 |
+
this.defaultSettings = defaultSettings;
|
828 |
+
this.applyDefaultSettings();
|
829 |
+
return;
|
830 |
+
}
|
831 |
+
|
832 |
+
const datasetName = pathParts[pathParts.length - 2];
|
833 |
+
let videoId = pathParts[pathParts.length - 1].replace('.bin', '');
|
834 |
+
|
835 |
+
// Load settings from data.json
|
836 |
+
const response = await fetch('./data.json');
|
837 |
+
if (!response.ok) {
|
838 |
+
this.defaultSettings = defaultSettings;
|
839 |
+
this.applyDefaultSettings();
|
840 |
+
return;
|
841 |
+
}
|
842 |
+
|
843 |
+
const settingsData = await response.json();
|
844 |
+
|
845 |
+
// Check if this dataset and video exist
|
846 |
+
if (settingsData[datasetName] && settingsData[datasetName][videoId]) {
|
847 |
+
this.defaultSettings = settingsData[datasetName][videoId];
|
848 |
+
} else {
|
849 |
+
this.defaultSettings = defaultSettings;
|
850 |
+
}
|
851 |
+
|
852 |
+
this.applyDefaultSettings();
|
853 |
+
} catch (error) {
|
854 |
+
console.error("Error loading default settings:", error);
|
855 |
+
|
856 |
+
this.defaultSettings = {
|
857 |
+
pointSize: 0.03,
|
858 |
+
pointOpacity: 1.0,
|
859 |
+
showTrajectory: true,
|
860 |
+
trajectoryLineWidth: 2.5,
|
861 |
+
trajectoryBallSize: 0.015,
|
862 |
+
trajectoryHistory: 0,
|
863 |
+
showCameraFrustum: true,
|
864 |
+
frustumSize: 0.2
|
865 |
+
};
|
866 |
+
|
867 |
+
this.applyDefaultSettings();
|
868 |
+
}
|
869 |
+
}
|
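// The lookup above assumes per-video presets live in ./data.json, keyed first by the
// dataset folder and then by the video id, both parsed from a
// ?data=demos/<dataset>/<video>.bin query parameter. Sketch of the expected shape
// (the dataset/video names and values here are illustrative, not taken from this repo):
//
// {
//   "example_dataset": {
//     "example_video": { "pointSize": 0.03, "maxDepth": 8.0, "showTrajectory": true }
//   }
// }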
870 |
+
|
871 |
+
applyDefaultSettings() {
|
872 |
+
if (!this.defaultSettings) return;
|
873 |
+
|
874 |
+
if (this.ui.pointSize) {
|
875 |
+
this.ui.pointSize.value = this.defaultSettings.pointSize;
|
876 |
+
}
|
877 |
+
|
878 |
+
if (this.ui.pointOpacity) {
|
879 |
+
this.ui.pointOpacity.value = this.defaultSettings.pointOpacity;
|
880 |
+
}
|
881 |
+
|
882 |
+
if (this.ui.maxDepth) {
|
883 |
+
this.ui.maxDepth.value = this.defaultSettings.maxDepth || 100.0;
|
884 |
+
}
|
885 |
+
|
886 |
+
if (this.ui.showTrajectory) {
|
887 |
+
this.ui.showTrajectory.checked = this.defaultSettings.showTrajectory;
|
888 |
+
}
|
889 |
+
|
890 |
+
if (this.ui.trajectoryLineWidth) {
|
891 |
+
this.ui.trajectoryLineWidth.value = this.defaultSettings.trajectoryLineWidth;
|
892 |
+
}
|
893 |
+
|
894 |
+
if (this.ui.trajectoryBallSize) {
|
895 |
+
this.ui.trajectoryBallSize.value = this.defaultSettings.trajectoryBallSize;
|
896 |
+
}
|
897 |
+
|
898 |
+
if (this.ui.trajectoryHistory) {
|
899 |
+
this.ui.trajectoryHistory.value = this.defaultSettings.trajectoryHistory;
|
900 |
+
}
|
901 |
+
|
902 |
+
if (this.ui.showCameraFrustum) {
|
903 |
+
this.ui.showCameraFrustum.checked = this.defaultSettings.showCameraFrustum;
|
904 |
+
}
|
905 |
+
|
906 |
+
if (this.ui.frustumSize) {
|
907 |
+
this.ui.frustumSize.value = this.defaultSettings.frustumSize;
|
908 |
+
}
|
909 |
+
}
|
910 |
+
|
911 |
+
initThreeJS() {
|
912 |
+
this.scene = new THREE.Scene();
|
913 |
+
this.scene.background = new THREE.Color(0x1a1a1a);
|
914 |
+
|
915 |
+
this.camera = new THREE.PerspectiveCamera(60, window.innerWidth / window.innerHeight, 0.1, 10000);
|
916 |
+
this.camera.position.set(0, 0, 0);
|
917 |
+
|
918 |
+
this.renderer = new THREE.WebGLRenderer({ antialias: true });
|
919 |
+
this.renderer.setPixelRatio(window.devicePixelRatio);
|
920 |
+
this.renderer.setSize(window.innerWidth, window.innerHeight);
|
921 |
+
document.getElementById('canvas-container').appendChild(this.renderer.domElement);
|
922 |
+
|
923 |
+
this.controls = new THREE.OrbitControls(this.camera, this.renderer.domElement);
|
924 |
+
this.controls.enableDamping = true;
|
925 |
+
this.controls.dampingFactor = 0.05;
|
926 |
+
this.controls.target.set(0, 0, 0);
|
927 |
+
this.controls.minDistance = 0.1;
|
928 |
+
this.controls.maxDistance = 1000;
|
929 |
+
this.controls.update();
|
930 |
+
|
931 |
+
const ambientLight = new THREE.AmbientLight(0xffffff, 0.5);
|
932 |
+
this.scene.add(ambientLight);
|
933 |
+
|
934 |
+
const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
|
935 |
+
directionalLight.position.set(1, 1, 1);
|
936 |
+
this.scene.add(directionalLight);
|
937 |
+
}
|
938 |
+
|
939 |
+
initEventListeners() {
|
940 |
+
window.addEventListener('resize', () => this.onWindowResize());
|
941 |
+
|
942 |
+
this.ui.playPauseBtn.addEventListener('click', () => this.togglePlayback());
|
943 |
+
|
944 |
+
this.ui.timeline.addEventListener('click', (e) => {
|
945 |
+
const rect = this.ui.timeline.getBoundingClientRect();
|
946 |
+
const pos = (e.clientX - rect.left) / rect.width;
|
947 |
+
this.seekTo(pos);
|
948 |
+
});
|
949 |
+
|
950 |
+
this.ui.speedBtn.addEventListener('click', () => this.cyclePlaybackSpeed());
|
951 |
+
|
952 |
+
this.ui.pointSize.addEventListener('input', () => this.updatePointCloudSettings());
|
953 |
+
this.ui.pointOpacity.addEventListener('input', () => this.updatePointCloudSettings());
|
954 |
+
this.ui.maxDepth.addEventListener('input', () => this.updatePointCloudSettings());
|
955 |
+
this.ui.showTrajectory.addEventListener('change', () => {
|
956 |
+
this.trajectories.forEach(trajectory => {
|
957 |
+
trajectory.visible = this.ui.showTrajectory.checked;
|
958 |
+
});
|
959 |
+
});
|
960 |
+
|
961 |
+
this.ui.enableRichTrail.addEventListener('change', () => {
|
962 |
+
this.ui.tailOpacityContainer.style.display = this.ui.enableRichTrail.checked ? 'flex' : 'none';
|
963 |
+
this.updateTrajectories(this.currentFrame);
|
964 |
+
});
|
965 |
+
|
966 |
+
this.ui.trajectoryLineWidth.addEventListener('input', () => this.updateTrajectorySettings());
|
967 |
+
this.ui.trajectoryBallSize.addEventListener('input', () => this.updateTrajectorySettings());
|
968 |
+
this.ui.trajectoryHistory.addEventListener('input', () => {
|
969 |
+
this.updateTrajectories(this.currentFrame);
|
970 |
+
});
|
971 |
+
this.ui.trajectoryFade.addEventListener('input', () => {
|
972 |
+
this.updateTrajectories(this.currentFrame);
|
973 |
+
});
|
974 |
+
|
975 |
+
this.ui.resetViewBtn.addEventListener('click', () => this.resetView());
|
976 |
+
|
977 |
+
const resetSettingsBtn = document.getElementById('reset-settings-btn');
|
978 |
+
if (resetSettingsBtn) {
|
979 |
+
resetSettingsBtn.addEventListener('click', () => this.resetSettings());
|
980 |
+
}
|
981 |
+
|
982 |
+
document.addEventListener('keydown', (e) => {
|
983 |
+
if (e.key === 'Escape' && this.ui.settingsPanel.classList.contains('visible')) {
|
984 |
+
this.ui.settingsPanel.classList.remove('visible');
|
985 |
+
this.ui.settingsToggleBtn.classList.remove('active');
|
986 |
+
}
|
987 |
+
});
|
988 |
+
|
989 |
+
if (this.ui.settingsToggleBtn) {
|
990 |
+
this.ui.settingsToggleBtn.addEventListener('click', () => {
|
991 |
+
const isVisible = this.ui.settingsPanel.classList.toggle('visible');
|
992 |
+
this.ui.settingsToggleBtn.classList.toggle('active', isVisible);
|
993 |
+
|
994 |
+
if (isVisible) {
|
995 |
+
const panelRect = this.ui.settingsPanel.getBoundingClientRect();
|
996 |
+
const viewportHeight = window.innerHeight;
|
997 |
+
|
998 |
+
if (panelRect.bottom > viewportHeight) {
|
999 |
+
this.ui.settingsPanel.style.bottom = 'auto';
|
1000 |
+
this.ui.settingsPanel.style.top = '80px';
|
1001 |
+
}
|
1002 |
+
}
|
1003 |
+
});
|
1004 |
+
}
|
1005 |
+
|
1006 |
+
if (this.ui.frustumSize) {
|
1007 |
+
this.ui.frustumSize.addEventListener('input', () => this.updateFrustumDimensions());
|
1008 |
+
}
|
1009 |
+
|
1010 |
+
if (this.ui.hideSettingsBtn && this.ui.showSettingsBtn && this.ui.settingsPanel) {
|
1011 |
+
this.ui.hideSettingsBtn.addEventListener('click', () => {
|
1012 |
+
this.ui.settingsPanel.classList.add('is-hidden');
|
1013 |
+
this.ui.showSettingsBtn.style.display = 'flex';
|
1014 |
+
});
|
1015 |
+
|
1016 |
+
this.ui.showSettingsBtn.addEventListener('click', () => {
|
1017 |
+
this.ui.settingsPanel.classList.remove('is-hidden');
|
1018 |
+
this.ui.showSettingsBtn.style.display = 'none';
|
1019 |
+
});
|
1020 |
+
}
|
1021 |
+
|
1022 |
+
// Keep History event listeners
|
1023 |
+
if (this.ui.enableKeepHistory) {
|
1024 |
+
this.ui.enableKeepHistory.addEventListener('change', () => {
|
1025 |
+
if (!this.ui.enableKeepHistory.checked) {
|
1026 |
+
this.clearHistory();
|
1027 |
+
}
|
1028 |
+
});
|
1029 |
+
}
|
1030 |
+
|
1031 |
+
if (this.ui.historyStride) {
|
1032 |
+
this.ui.historyStride.addEventListener('change', () => {
|
1033 |
+
this.clearHistory();
|
1034 |
+
});
|
1035 |
+
}
|
1036 |
+
|
1037 |
+
// Background toggle event listener
|
1038 |
+
if (this.ui.whiteBackground) {
|
1039 |
+
this.ui.whiteBackground.addEventListener('change', () => {
|
1040 |
+
this.toggleBackground();
|
1041 |
+
});
|
1042 |
+
}
|
1043 |
+
}
|
1044 |
+
|
1045 |
+
makeElementDraggable(element) {
|
1046 |
+
let pos1 = 0, pos2 = 0, pos3 = 0, pos4 = 0;
|
1047 |
+
|
1048 |
+
const dragHandle = element.querySelector('h2');
|
1049 |
+
|
1050 |
+
if (dragHandle) {
|
1051 |
+
dragHandle.onmousedown = dragMouseDown;
|
1052 |
+
dragHandle.title = "Drag to move panel";
|
1053 |
+
} else {
|
1054 |
+
element.onmousedown = dragMouseDown;
|
1055 |
+
}
|
1056 |
+
|
1057 |
+
function dragMouseDown(e) {
|
1058 |
+
e = e || window.event;
|
1059 |
+
e.preventDefault();
|
1060 |
+
pos3 = e.clientX;
|
1061 |
+
pos4 = e.clientY;
|
1062 |
+
document.onmouseup = closeDragElement;
|
1063 |
+
document.onmousemove = elementDrag;
|
1064 |
+
|
1065 |
+
element.classList.add('dragging');
|
1066 |
+
}
|
1067 |
+
|
1068 |
+
function elementDrag(e) {
|
1069 |
+
e = e || window.event;
|
1070 |
+
e.preventDefault();
|
1071 |
+
pos1 = pos3 - e.clientX;
|
1072 |
+
pos2 = pos4 - e.clientY;
|
1073 |
+
pos3 = e.clientX;
|
1074 |
+
pos4 = e.clientY;
|
1075 |
+
|
1076 |
+
const newTop = element.offsetTop - pos2;
|
1077 |
+
const newLeft = element.offsetLeft - pos1;
|
1078 |
+
|
1079 |
+
const viewportWidth = window.innerWidth;
|
1080 |
+
const viewportHeight = window.innerHeight;
|
1081 |
+
|
1082 |
+
const panelRect = element.getBoundingClientRect();
|
1083 |
+
|
1084 |
+
const maxTop = viewportHeight - 50;
|
1085 |
+
const maxLeft = viewportWidth - 50;
|
1086 |
+
|
1087 |
+
element.style.top = Math.min(Math.max(newTop, 0), maxTop) + "px";
|
1088 |
+
element.style.left = Math.min(Math.max(newLeft, 0), maxLeft) + "px";
|
1089 |
+
|
1090 |
+
// Remove bottom/right settings when dragging
|
1091 |
+
element.style.bottom = 'auto';
|
1092 |
+
element.style.right = 'auto';
|
1093 |
+
}
|
1094 |
+
|
1095 |
+
function closeDragElement() {
|
1096 |
+
document.onmouseup = null;
|
1097 |
+
document.onmousemove = null;
|
1098 |
+
|
1099 |
+
element.classList.remove('dragging');
|
1100 |
+
}
|
1101 |
+
}
|
1102 |
+
|
1103 |
+
async loadData() {
|
1104 |
+
try {
|
1105 |
+
// this.ui.loadingText.textContent = "Loading binary data...";
|
1106 |
+
|
1107 |
+
let arrayBuffer;
|
1108 |
+
|
1109 |
+
if (window.embeddedBase64) {
|
1110 |
+
// Base64 embedded path
|
1111 |
+
const binaryString = atob(window.embeddedBase64);
|
1112 |
+
const len = binaryString.length;
|
1113 |
+
const bytes = new Uint8Array(len);
|
1114 |
+
for (let i = 0; i < len; i++) {
|
1115 |
+
bytes[i] = binaryString.charCodeAt(i);
|
1116 |
+
}
|
1117 |
+
arrayBuffer = bytes.buffer;
|
1118 |
+
} else {
|
1119 |
+
// Default fetch path (fallback)
|
1120 |
+
const urlParams = new URLSearchParams(window.location.search);
|
1121 |
+
const dataPath = urlParams.get('data') || 'data.bin';
|
1122 |
+
|
1123 |
+
const response = await fetch(dataPath);
|
1124 |
+
if (!response.ok) throw new Error(`Failed to load ${dataPath}`);
|
1125 |
+
arrayBuffer = await response.arrayBuffer();
|
1126 |
+
}
|
1127 |
+
|
1128 |
+
const dataView = new DataView(arrayBuffer);
|
1129 |
+
const headerLen = dataView.getUint32(0, true);
|
1130 |
+
|
1131 |
+
const headerText = new TextDecoder("utf-8").decode(arrayBuffer.slice(4, 4 + headerLen));
|
1132 |
+
const header = JSON.parse(headerText);
|
1133 |
+
|
1134 |
+
const compressedBlob = new Uint8Array(arrayBuffer, 4 + headerLen);
|
1135 |
+
const decompressed = pako.inflate(compressedBlob).buffer;
|
1136 |
+
|
1137 |
+
const arrays = {};
|
1138 |
+
for (const key in header) {
|
1139 |
+
if (key === "meta") continue;
|
1140 |
+
|
1141 |
+
const meta = header[key];
|
1142 |
+
const { dtype, shape, offset, length } = meta;
|
1143 |
+
const slice = decompressed.slice(offset, offset + length);
|
1144 |
+
|
1145 |
+
let typedArray;
|
1146 |
+
switch (dtype) {
|
1147 |
+
case "uint8": typedArray = new Uint8Array(slice); break;
|
1148 |
+
case "uint16": typedArray = new Uint16Array(slice); break;
|
1149 |
+
case "float32": typedArray = new Float32Array(slice); break;
|
1150 |
+
case "float64": typedArray = new Float64Array(slice); break;
|
1151 |
+
default: throw new Error(`Unknown dtype: ${dtype}`);
|
1152 |
+
}
|
1153 |
+
|
1154 |
+
arrays[key] = { data: typedArray, shape: shape };
|
1155 |
+
}
|
1156 |
+
|
1157 |
+
this.data = arrays;
|
1158 |
+
this.config = header.meta;
|
1159 |
+
|
1160 |
+
this.initCameraWithCorrectFOV();
|
1161 |
+
this.ui.loadingText.textContent = "Creating point cloud...";
|
1162 |
+
|
1163 |
+
this.initPointCloud();
|
1164 |
+
this.initTrajectories();
|
1165 |
+
|
1166 |
+
setTimeout(() => {
|
1167 |
+
this.ui.loadingOverlay.classList.add('fade-out');
|
1168 |
+
this.ui.statusBar.classList.add('hidden');
|
1169 |
+
this.startAnimation();
|
1170 |
+
}, 500);
|
1171 |
+
} catch (error) {
|
1172 |
+
console.error("Error loading data:", error);
|
1173 |
+
this.ui.statusBar.textContent = `Error: ${error.message}`;
|
1174 |
+
// this.ui.loadingText.textContent = `Error loading data: ${error.message}`;
|
1175 |
+
}
|
1176 |
+
}
|
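// Layout of the .bin payload consumed by loadData() above:
//   bytes 0-3   : uint32 (little-endian) length of the JSON header
//   next N bytes: UTF-8 JSON header; every key except "meta" describes one array
//                 as { dtype, shape, offset, length }, and "meta" carries the viewer
//                 config (resolution, depthRange, totalFrames, baseFrameRate, ...)
//   remainder   : zlib-compressed blob; offset/length index into the decompressed
//                 bytes of the corresponding array.
//
// Minimal sketch of how a compatible payload could be assembled with pako.
// Illustrative only: this method is an assumption and is not called by the viewer.
encodePayloadSketch(arrays, meta) {
    // arrays: { name: { data: TypedArray, dtype: "uint8"|"uint16"|"float32"|"float64", shape: [...] } }
    const header = { meta };
    const chunks = [];
    let byteOffset = 0;
    for (const [key, arr] of Object.entries(arrays)) {
        const bytes = new Uint8Array(arr.data.buffer, arr.data.byteOffset, arr.data.byteLength);
        header[key] = { dtype: arr.dtype, shape: arr.shape, offset: byteOffset, length: bytes.length };
        chunks.push(bytes);
        byteOffset += bytes.length;
    }
    const blob = new Uint8Array(byteOffset);
    let cursor = 0;
    for (const c of chunks) { blob.set(c, cursor); cursor += c.length; }
    const compressed = pako.deflate(blob); // inverse of pako.inflate above
    const headerBytes = new TextEncoder().encode(JSON.stringify(header));
    const out = new Uint8Array(4 + headerBytes.length + compressed.length);
    new DataView(out.buffer).setUint32(0, headerBytes.length, true); // little-endian header length
    out.set(headerBytes, 4);
    out.set(compressed, 4 + headerBytes.length);
    return out;
}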
1177 |
+
|
1178 |
+
initPointCloud() {
|
1179 |
+
const numPoints = this.config.resolution[0] * this.config.resolution[1];
|
1180 |
+
const positions = new Float32Array(numPoints * 3);
|
1181 |
+
const colors = new Float32Array(numPoints * 3);
|
1182 |
+
|
1183 |
+
const geometry = new THREE.BufferGeometry();
|
1184 |
+
geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3).setUsage(THREE.DynamicDrawUsage));
|
1185 |
+
geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3).setUsage(THREE.DynamicDrawUsage));
|
1186 |
+
|
1187 |
+
const pointSize = parseFloat(this.ui.pointSize.value) || this.defaultSettings.pointSize;
|
1188 |
+
const pointOpacity = parseFloat(this.ui.pointOpacity.value) || this.defaultSettings.pointOpacity;
|
1189 |
+
|
1190 |
+
const material = new THREE.PointsMaterial({
|
1191 |
+
size: pointSize,
|
1192 |
+
vertexColors: true,
|
1193 |
+
transparent: true,
|
1194 |
+
opacity: pointOpacity,
|
1195 |
+
sizeAttenuation: true
|
1196 |
+
});
|
1197 |
+
|
1198 |
+
this.pointCloud = new THREE.Points(geometry, material);
|
1199 |
+
this.scene.add(this.pointCloud);
|
1200 |
+
}
|
1201 |
+
|
1202 |
+
initTrajectories() {
|
1203 |
+
if (!this.data.trajectories) return;
|
1204 |
+
|
1205 |
+
this.trajectories.forEach(trajectory => {
|
1206 |
+
if (trajectory.userData.lineSegments) {
|
1207 |
+
trajectory.userData.lineSegments.forEach(segment => {
|
1208 |
+
segment.geometry.dispose();
|
1209 |
+
segment.material.dispose();
|
1210 |
+
});
|
1211 |
+
}
|
1212 |
+
this.scene.remove(trajectory);
|
1213 |
+
});
|
1214 |
+
this.trajectories = [];
|
1215 |
+
|
1216 |
+
const shape = this.data.trajectories.shape;
|
1217 |
+
if (!shape || shape.length < 2) return;
|
1218 |
+
|
1219 |
+
const [totalFrames, numTrajectories] = shape;
|
1220 |
+
const palette = this.createColorPalette(numTrajectories);
|
1221 |
+
const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
|
1222 |
+
const maxHistory = 500; // Max value of the history slider, for the object pool
|
1223 |
+
|
1224 |
+
for (let i = 0; i < numTrajectories; i++) {
|
1225 |
+
const trajectoryGroup = new THREE.Group();
|
1226 |
+
|
1227 |
+
const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
|
1228 |
+
const sphereGeometry = new THREE.SphereGeometry(ballSize, 16, 16);
|
1229 |
+
const sphereMaterial = new THREE.MeshBasicMaterial({ color: palette[i], transparent: true });
|
1230 |
+
const positionMarker = new THREE.Mesh(sphereGeometry, sphereMaterial);
|
1231 |
+
trajectoryGroup.add(positionMarker);
|
1232 |
+
|
1233 |
+
// High-Performance Line (default)
|
1234 |
+
const simpleLineGeometry = new THREE.BufferGeometry();
|
1235 |
+
const simpleLinePositions = new Float32Array(maxHistory * 3);
|
1236 |
+
simpleLineGeometry.setAttribute('position', new THREE.BufferAttribute(simpleLinePositions, 3).setUsage(THREE.DynamicDrawUsage));
|
1237 |
+
const simpleLine = new THREE.Line(simpleLineGeometry, new THREE.LineBasicMaterial({ color: palette[i] }));
|
1238 |
+
simpleLine.frustumCulled = false;
|
1239 |
+
trajectoryGroup.add(simpleLine);
|
1240 |
+
|
1241 |
+
// High-Quality Line Segments (for rich trail)
|
1242 |
+
const lineSegments = [];
|
1243 |
+
const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);
|
1244 |
+
|
1245 |
+
// Create a pool of line segment objects
|
1246 |
+
for (let j = 0; j < maxHistory - 1; j++) {
|
1247 |
+
const lineGeometry = new THREE.LineGeometry();
|
1248 |
+
lineGeometry.setPositions([0, 0, 0, 0, 0, 0]);
|
1249 |
+
const lineMaterial = new THREE.LineMaterial({
|
1250 |
+
color: palette[i],
|
1251 |
+
linewidth: lineWidth,
|
1252 |
+
resolution: resolution,
|
1253 |
+
transparent: true,
|
1254 |
+
depthWrite: false, // Correctly handle transparency
|
1255 |
+
opacity: 0
|
1256 |
+
});
|
1257 |
+
const segment = new THREE.Line2(lineGeometry, lineMaterial);
|
1258 |
+
segment.frustumCulled = false;
|
1259 |
+
segment.visible = false; // Start with all segments hidden
|
1260 |
+
trajectoryGroup.add(segment);
|
1261 |
+
lineSegments.push(segment);
|
1262 |
+
}
|
1263 |
+
|
1264 |
+
trajectoryGroup.userData = {
|
1265 |
+
marker: positionMarker,
|
1266 |
+
simpleLine: simpleLine,
|
1267 |
+
lineSegments: lineSegments,
|
1268 |
+
color: palette[i]
|
1269 |
+
};
|
1270 |
+
|
1271 |
+
this.scene.add(trajectoryGroup);
|
1272 |
+
this.trajectories.push(trajectoryGroup);
|
1273 |
+
}
|
1274 |
+
|
1275 |
+
const showTrajectory = this.ui.showTrajectory.checked;
|
1276 |
+
this.trajectories.forEach(trajectory => trajectory.visible = showTrajectory);
|
1277 |
+
}
|
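// Design note: each trajectory pre-allocates one THREE.Line (the cheap "performance"
// trail) plus a pool of maxHistory - 1 THREE.Line2 segments (the "Visual-Rich Trail").
// During playback, updateTrajectories() only rewrites buffer positions and toggles
// visibility/opacity, so no geometry is allocated per frame in either mode.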
1278 |
+
|
1279 |
+
createColorPalette(count) {
|
1280 |
+
const colors = [];
|
1281 |
+
const hueStep = 360 / count;
|
1282 |
+
|
1283 |
+
for (let i = 0; i < count; i++) {
|
1284 |
+
const hue = (i * hueStep) % 360;
|
1285 |
+
const color = new THREE.Color().setHSL(hue / 360, 0.8, 0.6);
|
1286 |
+
colors.push(color);
|
1287 |
+
}
|
1288 |
+
|
1289 |
+
return colors;
|
1290 |
+
}
|
1291 |
+
|
1292 |
+
updatePointCloud(frameIndex) {
|
1293 |
+
if (!this.data || !this.pointCloud) return;
|
1294 |
+
|
1295 |
+
const positions = this.pointCloud.geometry.attributes.position.array;
|
1296 |
+
const colors = this.pointCloud.geometry.attributes.color.array;
|
1297 |
+
|
1298 |
+
const rgbVideo = this.data.rgb_video;
|
1299 |
+
const depthsRgb = this.data.depths_rgb;
|
1300 |
+
const intrinsics = this.data.intrinsics;
|
1301 |
+
const invExtrinsics = this.data.inv_extrinsics;
|
1302 |
+
|
1303 |
+
const width = this.config.resolution[0];
|
1304 |
+
const height = this.config.resolution[1];
|
1305 |
+
const numPoints = width * height;
|
1306 |
+
|
1307 |
+
const K = this.get3x3Matrix(intrinsics.data, intrinsics.shape, frameIndex);
|
1308 |
+
const fx = K[0][0], fy = K[1][1], cx = K[0][2], cy = K[1][2];
|
1309 |
+
|
1310 |
+
const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
|
1311 |
+
const transform = this.getTransformElements(invExtrMat);
|
1312 |
+
|
1313 |
+
const rgbFrame = this.getFrame(rgbVideo.data, rgbVideo.shape, frameIndex);
|
1314 |
+
const depthFrame = this.getFrame(depthsRgb.data, depthsRgb.shape, frameIndex);
|
1315 |
+
|
1316 |
+
const maxDepth = parseFloat(this.ui.maxDepth.value) || 10.0;
|
1317 |
+
|
1318 |
+
let validPointCount = 0;
|
1319 |
+
|
1320 |
+
for (let i = 0; i < numPoints; i++) {
|
1321 |
+
const xPix = i % width;
|
1322 |
+
const yPix = Math.floor(i / width);
|
1323 |
+
|
1324 |
+
const d0 = depthFrame[i * 3];
|
1325 |
+
const d1 = depthFrame[i * 3 + 1];
|
1326 |
+
const depthEncoded = d0 | (d1 << 8);
|
1327 |
+
const depthValue = (depthEncoded / ((1 << 16) - 1)) *
|
1328 |
+
(this.config.depthRange[1] - this.config.depthRange[0]) +
|
1329 |
+
this.config.depthRange[0];
|
1330 |
+
|
1331 |
+
if (depthValue === 0 || depthValue > maxDepth) {
|
1332 |
+
continue;
|
1333 |
+
}
|
1334 |
+
|
1335 |
+
const X = ((xPix - cx) * depthValue) / fx;
|
1336 |
+
const Y = ((yPix - cy) * depthValue) / fy;
|
1337 |
+
const Z = depthValue;
|
1338 |
+
|
1339 |
+
const tx = transform.m11 * X + transform.m12 * Y + transform.m13 * Z + transform.m14;
|
1340 |
+
const ty = transform.m21 * X + transform.m22 * Y + transform.m23 * Z + transform.m24;
|
1341 |
+
const tz = transform.m31 * X + transform.m32 * Y + transform.m33 * Z + transform.m34;
|
1342 |
+
|
1343 |
+
const index = validPointCount * 3;
|
1344 |
+
positions[index] = tx;
|
1345 |
+
positions[index + 1] = -ty;
|
1346 |
+
positions[index + 2] = -tz;
|
1347 |
+
|
1348 |
+
colors[index] = rgbFrame[i * 3] / 255;
|
1349 |
+
colors[index + 1] = rgbFrame[i * 3 + 1] / 255;
|
1350 |
+
colors[index + 2] = rgbFrame[i * 3 + 2] / 255;
|
1351 |
+
|
1352 |
+
validPointCount++;
|
1353 |
+
}
|
1354 |
+
|
1355 |
+
this.pointCloud.geometry.setDrawRange(0, validPointCount);
|
1356 |
+
this.pointCloud.geometry.attributes.position.needsUpdate = true;
|
1357 |
+
this.pointCloud.geometry.attributes.color.needsUpdate = true;
|
1358 |
+
this.pointCloud.geometry.computeBoundingSphere(); // Important for camera culling
|
1359 |
+
|
1360 |
+
this.updateTrajectories(frameIndex);
|
1361 |
+
|
1362 |
+
// Keep History management
|
1363 |
+
this.updateHistory(frameIndex);
|
1364 |
+
|
1365 |
+
const progress = (frameIndex + 1) / this.config.totalFrames;
|
1366 |
+
this.ui.progress.style.width = `${progress * 100}%`;
|
1367 |
+
|
1368 |
+
if (this.ui.frameCounter && this.config.totalFrames) {
|
1369 |
+
this.ui.frameCounter.textContent = `Frame ${frameIndex} / ${this.config.totalFrames - 1}`;
|
1370 |
+
}
|
1371 |
+
|
1372 |
+
this.updateCameraFrustum(frameIndex);
|
1373 |
+
}
|
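// Notes on updatePointCloud() above:
// - depths_rgb packs a 16-bit depth code into the first two channels of each pixel
//   (little-endian: d0 | d1 << 8); it is divided by 65535 and mapped back to metric
//   depth over [depthRange[0], depthRange[1]] taken from the header.
// - Each pixel is unprojected with the per-frame intrinsics (fx, fy, cx, cy) and moved
//   to world space with the per-frame inv_extrinsics; Y and Z are then negated,
//   presumably to convert the data's y-down / z-forward camera convention into
//   three.js's y-up / z-toward-viewer convention.
//
// Illustrative helper mirroring the depth decode (an assumption, not used by the viewer):
decodeDepthSketch(d0, d1, depthRange) {
    const encoded = d0 | (d1 << 8);                              // 0 .. 65535
    const t = encoded / 65535;                                   // normalize to [0, 1]
    return t * (depthRange[1] - depthRange[0]) + depthRange[0];  // metric depth
}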
1374 |
+
|
1375 |
+
updateTrajectories(frameIndex) {
|
1376 |
+
if (!this.data.trajectories || this.trajectories.length === 0) return;
|
1377 |
+
|
1378 |
+
const trajectoryData = this.data.trajectories.data;
|
1379 |
+
const [totalFrames, numTrajectories] = this.data.trajectories.shape;
|
1380 |
+
const historyFrames = parseInt(this.ui.trajectoryHistory.value);
|
1381 |
+
const tailOpacity = parseFloat(this.ui.trajectoryFade.value);
|
1382 |
+
|
1383 |
+
const isRichMode = this.ui.enableRichTrail.checked;
|
1384 |
+
|
1385 |
+
for (let i = 0; i < numTrajectories; i++) {
|
1386 |
+
const trajectoryGroup = this.trajectories[i];
|
1387 |
+
const { marker, simpleLine, lineSegments } = trajectoryGroup.userData;
|
1388 |
+
|
1389 |
+
const currentPos = new THREE.Vector3();
|
1390 |
+
const currentOffset = (frameIndex * numTrajectories + i) * 3;
|
1391 |
+
|
1392 |
+
currentPos.x = trajectoryData[currentOffset];
|
1393 |
+
currentPos.y = -trajectoryData[currentOffset + 1];
|
1394 |
+
currentPos.z = -trajectoryData[currentOffset + 2];
|
1395 |
+
|
1396 |
+
marker.position.copy(currentPos);
|
1397 |
+
marker.material.opacity = 1.0;
|
1398 |
+
|
1399 |
+
const historyToShow = Math.min(historyFrames, frameIndex + 1);
|
1400 |
+
|
1401 |
+
if (isRichMode) {
|
1402 |
+
// --- High-Quality Mode ---
|
1403 |
+
simpleLine.visible = false;
|
1404 |
+
|
1405 |
+
for (let j = 0; j < lineSegments.length; j++) {
|
1406 |
+
const segment = lineSegments[j];
|
1407 |
+
if (j < historyToShow - 1) {
|
1408 |
+
const headFrame = frameIndex - j;
|
1409 |
+
const tailFrame = frameIndex - j - 1;
|
1410 |
+
const headOffset = (headFrame * numTrajectories + i) * 3;
|
1411 |
+
const tailOffset = (tailFrame * numTrajectories + i) * 3;
|
1412 |
+
const positions = [
|
1413 |
+
trajectoryData[headOffset], -trajectoryData[headOffset + 1], -trajectoryData[headOffset + 2],
|
1414 |
+
trajectoryData[tailOffset], -trajectoryData[tailOffset + 1], -trajectoryData[tailOffset + 2]
|
1415 |
+
];
|
1416 |
+
segment.geometry.setPositions(positions);
|
1417 |
+
const headOpacity = 1.0;
|
1418 |
+
const normalizedAge = j / Math.max(1, historyToShow - 2);
|
1419 |
+
const alpha = headOpacity - (headOpacity - tailOpacity) * normalizedAge;
|
1420 |
+
segment.material.opacity = Math.max(0, alpha);
|
1421 |
+
segment.visible = true;
|
1422 |
+
} else {
|
1423 |
+
segment.visible = false;
|
1424 |
+
}
|
1425 |
+
}
|
1426 |
+
} else {
|
1427 |
+
// --- Performance Mode ---
|
1428 |
+
lineSegments.forEach(s => s.visible = false);
|
1429 |
+
simpleLine.visible = true;
|
1430 |
+
|
1431 |
+
const positions = simpleLine.geometry.attributes.position.array;
|
1432 |
+
for (let j = 0; j < historyToShow; j++) {
|
1433 |
+
const historyFrame = Math.max(0, frameIndex - j);
|
1434 |
+
const offset = (historyFrame * numTrajectories + i) * 3;
|
1435 |
+
positions[j * 3] = trajectoryData[offset];
|
1436 |
+
positions[j * 3 + 1] = -trajectoryData[offset + 1];
|
1437 |
+
positions[j * 3 + 2] = -trajectoryData[offset + 2];
|
1438 |
+
}
|
1439 |
+
simpleLine.geometry.setDrawRange(0, historyToShow);
|
1440 |
+
simpleLine.geometry.attributes.position.needsUpdate = true;
|
1441 |
+
}
|
1442 |
+
}
|
1443 |
+
}
|
1444 |
+
|
1445 |
+
updateTrajectorySettings() {
|
1446 |
+
if (!this.trajectories || this.trajectories.length === 0) return;
|
1447 |
+
|
1448 |
+
const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
|
1449 |
+
const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);
|
1450 |
+
|
1451 |
+
this.trajectories.forEach(trajectoryGroup => {
|
1452 |
+
const { marker, lineSegments } = trajectoryGroup.userData;
|
1453 |
+
|
1454 |
+
marker.geometry.dispose();
|
1455 |
+
marker.geometry = new THREE.SphereGeometry(ballSize, 16, 16);
|
1456 |
+
|
1457 |
+
// Line width only affects rich mode
|
1458 |
+
lineSegments.forEach(segment => {
|
1459 |
+
if (segment.material) {
|
1460 |
+
segment.material.linewidth = lineWidth;
|
1461 |
+
}
|
1462 |
+
});
|
1463 |
+
});
|
1464 |
+
|
1465 |
+
this.updateTrajectories(this.currentFrame);
|
1466 |
+
}
|
1467 |
+
|
1468 |
+
getDepthColor(normalizedDepth) {
|
1469 |
+
const hue = (1 - normalizedDepth) * 240 / 360;
|
1470 |
+
const color = new THREE.Color().setHSL(hue, 1.0, 0.5);
|
1471 |
+
return color;
|
1472 |
+
}
|
1473 |
+
|
1474 |
+
getFrame(typedArray, shape, frameIndex) {
|
1475 |
+
const [T, H, W, C] = shape;
|
1476 |
+
const frameSize = H * W * C;
|
1477 |
+
const offset = frameIndex * frameSize;
|
1478 |
+
return typedArray.subarray(offset, offset + frameSize);
|
1479 |
+
}
|
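// getFrame() assumes the flattened arrays are stored frame-major with shape
// [T, H, W, C] in row-major order, so frame t occupies the H * W * C elements
// starting at t * H * W * C; subarray() returns a zero-copy view into that range.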
1480 |
+
|
1481 |
+
get3x3Matrix(typedArray, shape, frameIndex) {
|
1482 |
+
const frameSize = 9;
|
1483 |
+
const offset = frameIndex * frameSize;
|
1484 |
+
const K = [];
|
1485 |
+
for (let i = 0; i < 3; i++) {
|
1486 |
+
const row = [];
|
1487 |
+
for (let j = 0; j < 3; j++) {
|
1488 |
+
row.push(typedArray[offset + i * 3 + j]);
|
1489 |
+
}
|
1490 |
+
K.push(row);
|
1491 |
+
}
|
1492 |
+
return K;
|
1493 |
+
}
|
1494 |
+
|
1495 |
+
get4x4Matrix(typedArray, shape, frameIndex) {
|
1496 |
+
const frameSize = 16;
|
1497 |
+
const offset = frameIndex * frameSize;
|
1498 |
+
const M = [];
|
1499 |
+
for (let i = 0; i < 4; i++) {
|
1500 |
+
const row = [];
|
1501 |
+
for (let j = 0; j < 4; j++) {
|
1502 |
+
row.push(typedArray[offset + i * 4 + j]);
|
1503 |
+
}
|
1504 |
+
M.push(row);
|
1505 |
+
}
|
1506 |
+
return M;
|
1507 |
+
}
|
1508 |
+
|
1509 |
+
getTransformElements(matrix) {
|
1510 |
+
return {
|
1511 |
+
m11: matrix[0][0], m12: matrix[0][1], m13: matrix[0][2], m14: matrix[0][3],
|
1512 |
+
m21: matrix[1][0], m22: matrix[1][1], m23: matrix[1][2], m24: matrix[1][3],
|
1513 |
+
m31: matrix[2][0], m32: matrix[2][1], m33: matrix[2][2], m34: matrix[2][3]
|
1514 |
+
};
|
1515 |
+
}
|
1516 |
+
|
    togglePlayback() {
        this.isPlaying = !this.isPlaying;

        const playIcon = document.getElementById('play-icon');
        const pauseIcon = document.getElementById('pause-icon');

        if (this.isPlaying) {
            playIcon.style.display = 'none';
            pauseIcon.style.display = 'block';
            this.lastFrameTime = performance.now();
        } else {
            playIcon.style.display = 'block';
            pauseIcon.style.display = 'none';
        }
    }

    cyclePlaybackSpeed() {
        const speeds = [0.5, 1, 2, 4, 8];
        const speedRates = speeds.map(s => s * this.config.baseFrameRate);

        let currentIndex = 0;
        const normalizedSpeed = this.playbackSpeed / this.config.baseFrameRate;

        for (let i = 0; i < speeds.length; i++) {
            if (Math.abs(normalizedSpeed - speeds[i]) < Math.abs(normalizedSpeed - speeds[currentIndex])) {
                currentIndex = i;
            }
        }

        const nextIndex = (currentIndex + 1) % speeds.length;
        this.playbackSpeed = speedRates[nextIndex];
        this.ui.speedBtn.textContent = `${speeds[nextIndex]}x`;

        if (speeds[nextIndex] === 1) {
            this.ui.speedBtn.classList.remove('active');
        } else {
            this.ui.speedBtn.classList.add('active');
        }
    }

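    // cyclePlaybackSpeed stores playbackSpeed in baseFrameRate units: it snaps the current speed
    // to the nearest preset and advances to the next one. With a hypothetical baseFrameRate of
    // 10 fps and playbackSpeed of 20 (i.e. 2x), the closest preset is 2, so the next click
    // selects 4x and sets playbackSpeed to 40.
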
    seekTo(position) {
        const frameIndex = Math.floor(position * this.config.totalFrames);
        this.currentFrame = Math.max(0, Math.min(frameIndex, this.config.totalFrames - 1));
        this.updatePointCloud(this.currentFrame);
    }

    updatePointCloudSettings() {
        if (!this.pointCloud) return;

        const size = parseFloat(this.ui.pointSize.value);
        const opacity = parseFloat(this.ui.pointOpacity.value);

        this.pointCloud.material.size = size;
        this.pointCloud.material.opacity = opacity;
        this.pointCloud.material.needsUpdate = true;

        this.updatePointCloud(this.currentFrame);
    }

    updateControls() {
        if (!this.controls) return;
        this.controls.update();
    }

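    // seekTo above expects a normalized scrubber position in [0, 1]. For example, with a
    // hypothetical totalFrames of 100, seekTo(0.5) jumps to frame 50 and seekTo(1.0) clamps
    // to the last frame (99).
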
    resetView() {
        if (!this.camera || !this.controls) return;

        // Reset camera position
        this.camera.position.set(0, 0, this.config.cameraZ || 0);

        // Reset controls
        this.controls.reset();

        // Set target slightly in front of camera
        this.controls.target.set(0, 0, -1);
        this.controls.update();

        // Show status message
        this.ui.statusBar.textContent = "View reset";
        this.ui.statusBar.classList.remove('hidden');

        // Hide status message after a few seconds
        setTimeout(() => {
            this.ui.statusBar.classList.add('hidden');
        }, 3000);
    }

    onWindowResize() {
        if (!this.camera || !this.renderer) return;

        const windowAspect = window.innerWidth / window.innerHeight;
        this.camera.aspect = windowAspect;
        this.camera.updateProjectionMatrix();
        this.renderer.setSize(window.innerWidth, window.innerHeight);

        if (this.trajectories && this.trajectories.length > 0) {
            const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
            this.trajectories.forEach(trajectory => {
                const { lineSegments } = trajectory.userData;
                if (lineSegments && lineSegments.length > 0) {
                    lineSegments.forEach(segment => {
                        if (segment.material && segment.material.resolution) {
                            segment.material.resolution.copy(resolution);
                        }
                    });
                }
            });
        }

        if (this.cameraFrustum) {
            const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
            this.cameraFrustum.children.forEach(line => {
                if (line.material && line.material.resolution) {
                    line.material.resolution.copy(resolution);
                }
            });
        }
    }

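    // Line2/LineMaterial render screen-space "fat" lines, so their material.resolution has to
    // track the viewport; onWindowResize above refreshes it for both trajectory segments and
    // the camera frustum whenever the window size changes.
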
    startAnimation() {
        this.isPlaying = true;
        this.lastFrameTime = performance.now();

        this.camera.position.set(0, 0, this.config.cameraZ || 0);
        this.controls.target.set(0, 0, -1);
        this.controls.update();

        this.playbackSpeed = this.config.baseFrameRate;

        document.getElementById('play-icon').style.display = 'none';
        document.getElementById('pause-icon').style.display = 'block';

        this.animate();
    }

    animate() {
        requestAnimationFrame(() => this.animate());

        if (this.controls) {
            this.controls.update();
        }

        if (this.isPlaying && this.data) {
            const now = performance.now();
            const delta = (now - this.lastFrameTime) / 1000;

            const framesToAdvance = Math.floor(delta * this.config.baseFrameRate * this.playbackSpeed);
            if (framesToAdvance > 0) {
                this.currentFrame = (this.currentFrame + framesToAdvance) % this.config.totalFrames;
                this.lastFrameTime = now;
                this.updatePointCloud(this.currentFrame);
            }
        }

        if (this.renderer && this.scene && this.camera) {
            this.renderer.render(this.scene, this.camera);
        }
    }

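    // Frame advance in animate() is time-based rather than per-rAF-tick:
    // framesToAdvance = floor(delta * baseFrameRate * playbackSpeed). With hypothetical values
    // delta = 0.12 s, baseFrameRate = 10 and playbackSpeed = 1, that is floor(0.12 * 10 * 1) = 1
    // frame; smaller deltas accumulate until a whole frame is due, because lastFrameTime is only
    // reset when at least one frame actually advances.
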
    initCameraWithCorrectFOV() {
        const fov = this.config.fov || 60;

        const windowAspect = window.innerWidth / window.innerHeight;

        this.camera = new THREE.PerspectiveCamera(
            fov,
            windowAspect,
            0.1,
            10000
        );

        this.controls.object = this.camera;
        this.controls.update();

        this.initCameraFrustum();
    }

    initCameraFrustum() {
        this.cameraFrustum = new THREE.Group();

        this.scene.add(this.cameraFrustum);

        this.initCameraFrustumGeometry();

        const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : (this.defaultSettings ? this.defaultSettings.showCameraFrustum : false);

        this.cameraFrustum.visible = showCameraFrustum;
    }

    initCameraFrustumGeometry() {
        const fov = this.config.fov || 60;
        const originalAspect = this.config.original_aspect_ratio || 1.33;

        const size = parseFloat(this.ui.frustumSize.value) || this.defaultSettings.frustumSize;

        const halfHeight = Math.tan(THREE.MathUtils.degToRad(fov / 2)) * size;
        const halfWidth = halfHeight * originalAspect;

        const vertices = [
            new THREE.Vector3(0, 0, 0),
            new THREE.Vector3(-halfWidth, -halfHeight, size),
            new THREE.Vector3(halfWidth, -halfHeight, size),
            new THREE.Vector3(halfWidth, halfHeight, size),
            new THREE.Vector3(-halfWidth, halfHeight, size)
        ];

        const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);

        const linePairs = [
            [1, 2], [2, 3], [3, 4], [4, 1],
            [0, 1], [0, 2], [0, 3], [0, 4]
        ];

        const colors = {
            edge: new THREE.Color(0x3366ff),
            ray: new THREE.Color(0x33cc66)
        };

        linePairs.forEach((pair, index) => {
            const positions = [
                vertices[pair[0]].x, vertices[pair[0]].y, vertices[pair[0]].z,
                vertices[pair[1]].x, vertices[pair[1]].y, vertices[pair[1]].z
            ];

            const lineGeometry = new THREE.LineGeometry();
            lineGeometry.setPositions(positions);

            let color = index < 4 ? colors.edge : colors.ray;

            const lineMaterial = new THREE.LineMaterial({
                color: color,
                linewidth: 2,
                resolution: resolution,
                dashed: false
            });

            const line = new THREE.Line2(lineGeometry, lineMaterial);
            this.cameraFrustum.add(line);
        });
    }

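    // Frustum geometry above: the apex sits at the origin and the four image-plane corners at
    // distance `size`, with halfHeight = tan(fov/2) * size and halfWidth = halfHeight * aspect.
    // For example, with fov = 60°, aspect = 1.33 and a hypothetical size of 0.2, halfHeight ≈
    // tan(30°) * 0.2 ≈ 0.115 and halfWidth ≈ 0.154. The first four line pairs draw the image
    // rectangle, the last four the rays from the apex to its corners.
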
    updateCameraFrustum(frameIndex) {
        if (!this.cameraFrustum || !this.data) return;

        const invExtrinsics = this.data.inv_extrinsics;
        if (!invExtrinsics) return;

        const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);

        const matrix = new THREE.Matrix4();
        matrix.set(
            invExtrMat[0][0], invExtrMat[0][1], invExtrMat[0][2], invExtrMat[0][3],
            invExtrMat[1][0], invExtrMat[1][1], invExtrMat[1][2], invExtrMat[1][3],
            invExtrMat[2][0], invExtrMat[2][1], invExtrMat[2][2], invExtrMat[2][3],
            invExtrMat[3][0], invExtrMat[3][1], invExtrMat[3][2], invExtrMat[3][3]
        );

        const position = new THREE.Vector3();
        position.setFromMatrixPosition(matrix);

        const rotMatrix = new THREE.Matrix4().extractRotation(matrix);

        const coordinateCorrection = new THREE.Matrix4().makeRotationX(Math.PI);

        const finalRotation = new THREE.Matrix4().multiplyMatrices(coordinateCorrection, rotMatrix);

        const quaternion = new THREE.Quaternion();
        quaternion.setFromRotationMatrix(finalRotation);

        position.y = -position.y;
        position.z = -position.z;

        this.cameraFrustum.position.copy(position);
        this.cameraFrustum.quaternion.copy(quaternion);

        const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : this.defaultSettings.showCameraFrustum;

        if (this.cameraFrustum.visible !== showCameraFrustum) {
            this.cameraFrustum.visible = showCameraFrustum;
        }

        const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
        this.cameraFrustum.children.forEach(line => {
            if (line.material && line.material.resolution) {
                line.material.resolution.copy(resolution);
            }
        });
    }

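    // updateCameraFrustum poses the frustum with inv_extrinsics (the inverse extrinsics, i.e.
    // camera-to-world): translation is read from the last column, the rotation is extracted and
    // pre-multiplied by a 180° rotation about X, and the Y/Z of the position are negated,
    // mirroring the -y/-z flips applied to points and trajectories so everything lands in the
    // same viewer coordinates.
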
    updateFrustumDimensions() {
        if (!this.cameraFrustum) return;

        while (this.cameraFrustum.children.length > 0) {
            const child = this.cameraFrustum.children[0];
            if (child.geometry) child.geometry.dispose();
            if (child.material) child.material.dispose();
            this.cameraFrustum.remove(child);
        }

        this.initCameraFrustumGeometry();

        this.updateCameraFrustum(this.currentFrame);
    }

    // Keep History methods
    updateHistory(frameIndex) {
        if (!this.ui.enableKeepHistory.checked || !this.data) return;

        const stride = parseInt(this.ui.historyStride.value);
        const newHistoryFrames = this.calculateHistoryFrames(frameIndex, stride);

        // Check if history frames changed
        if (this.arraysEqual(this.historyFrames, newHistoryFrames)) return;

        this.clearHistory();
        this.historyFrames = newHistoryFrames;

        // Create history point clouds and trajectories
        this.historyFrames.forEach(historyFrame => {
            if (historyFrame !== frameIndex) {
                this.createHistoryPointCloud(historyFrame);
                this.createHistoryTrajectories(historyFrame);
            }
        });
    }

    calculateHistoryFrames(currentFrame, stride) {
        const frames = [];
        let frame = 1; // Start from frame 1

        while (frame <= currentFrame && frames.length < this.maxHistoryFrames) {
            frames.push(frame);
            frame += stride;
        }

        // Always include current frame
        if (!frames.includes(currentFrame)) {
            frames.push(currentFrame);
        }

        return frames.sort((a, b) => a - b);
    }

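    // calculateHistoryFrames walks 1, 1 + stride, 1 + 2*stride, ... up to the current frame
    // (capped at maxHistoryFrames) and always appends the current frame. For example, with
    // currentFrame = 10 and stride = 4 it returns [1, 5, 9, 10].
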
    createHistoryPointCloud(frameIndex) {
        const numPoints = this.config.resolution[0] * this.config.resolution[1];
        const positions = new Float32Array(numPoints * 3);
        const colors = new Float32Array(numPoints * 3);

        const geometry = new THREE.BufferGeometry();
        geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3));
        geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3));

        const material = new THREE.PointsMaterial({
            size: parseFloat(this.ui.pointSize.value),
            vertexColors: true,
            transparent: true,
            opacity: 0.5, // Transparent for history
            sizeAttenuation: true
        });

        const historyPointCloud = new THREE.Points(geometry, material);
        this.scene.add(historyPointCloud);
        this.historyPointClouds.push(historyPointCloud);

        // Update the history point cloud with data
        this.updateHistoryPointCloud(historyPointCloud, frameIndex);
    }

    updateHistoryPointCloud(pointCloud, frameIndex) {
        const positions = pointCloud.geometry.attributes.position.array;
        const colors = pointCloud.geometry.attributes.color.array;

        const rgbVideo = this.data.rgb_video;
        const depthsRgb = this.data.depths_rgb;
        const intrinsics = this.data.intrinsics;
        const invExtrinsics = this.data.inv_extrinsics;

        const width = this.config.resolution[0];
        const height = this.config.resolution[1];
        const numPoints = width * height;

        const K = this.get3x3Matrix(intrinsics.data, intrinsics.shape, frameIndex);
        const fx = K[0][0], fy = K[1][1], cx = K[0][2], cy = K[1][2];

        const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
        const transform = this.getTransformElements(invExtrMat);

        const rgbFrame = this.getFrame(rgbVideo.data, rgbVideo.shape, frameIndex);
        const depthFrame = this.getFrame(depthsRgb.data, depthsRgb.shape, frameIndex);

        const maxDepth = parseFloat(this.ui.maxDepth.value) || 10.0;

        let validPointCount = 0;

        for (let i = 0; i < numPoints; i++) {
            const xPix = i % width;
            const yPix = Math.floor(i / width);

            const d0 = depthFrame[i * 3];
            const d1 = depthFrame[i * 3 + 1];
            const depthEncoded = d0 | (d1 << 8);
            const depthValue = (depthEncoded / ((1 << 16) - 1)) *
                (this.config.depthRange[1] - this.config.depthRange[0]) +
                this.config.depthRange[0];

            if (depthValue === 0 || depthValue > maxDepth) {
                continue;
            }

            const X = ((xPix - cx) * depthValue) / fx;
            const Y = ((yPix - cy) * depthValue) / fy;
            const Z = depthValue;

            const tx = transform.m11 * X + transform.m12 * Y + transform.m13 * Z + transform.m14;
            const ty = transform.m21 * X + transform.m22 * Y + transform.m23 * Z + transform.m24;
            const tz = transform.m31 * X + transform.m32 * Y + transform.m33 * Z + transform.m34;

            const index = validPointCount * 3;
            positions[index] = tx;
            positions[index + 1] = -ty;
            positions[index + 2] = -tz;

            colors[index] = rgbFrame[i * 3] / 255;
            colors[index + 1] = rgbFrame[i * 3 + 1] / 255;
            colors[index + 2] = rgbFrame[i * 3 + 2] / 255;

            validPointCount++;
        }

        pointCloud.geometry.setDrawRange(0, validPointCount);
        pointCloud.geometry.attributes.position.needsUpdate = true;
        pointCloud.geometry.attributes.color.needsUpdate = true;
    }

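    // In updateHistoryPointCloud, depth is decoded from the first two bytes of each depths_rgb
    // pixel as a 16-bit value and remapped into config.depthRange, then back-projected through
    // the intrinsics:
    //   depth = (d0 | d1 << 8) / 65535 * (far - near) + near
    //   X = (x - cx) * depth / fx,  Y = (y - cy) * depth / fy,  Z = depth
    // e.g. with hypothetical d0 = 0x34, d1 = 0x12 and depthRange = [0, 10], the decoded depth is
    // 4660 / 65535 * 10 ≈ 0.71 m. Points with zero depth or beyond the max-depth slider are
    // skipped, and setDrawRange hides the unused tail of the preallocated buffers.
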
    createHistoryTrajectories(frameIndex) {
        if (!this.data.trajectories) return;

        const trajectoryData = this.data.trajectories.data;
        const [totalFrames, numTrajectories] = this.data.trajectories.shape;
        const palette = this.createColorPalette(numTrajectories);

        const historyTrajectoryGroup = new THREE.Group();

        for (let i = 0; i < numTrajectories; i++) {
            const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
            const sphereGeometry = new THREE.SphereGeometry(ballSize, 16, 16);
            const sphereMaterial = new THREE.MeshBasicMaterial({
                color: palette[i],
                transparent: true,
                opacity: 0.3 // Transparent for history
            });
            const positionMarker = new THREE.Mesh(sphereGeometry, sphereMaterial);

            const currentOffset = (frameIndex * numTrajectories + i) * 3;
            positionMarker.position.set(
                trajectoryData[currentOffset],
                -trajectoryData[currentOffset + 1],
                -trajectoryData[currentOffset + 2]
            );

            historyTrajectoryGroup.add(positionMarker);
        }

        this.scene.add(historyTrajectoryGroup);
        this.historyTrajectories.push(historyTrajectoryGroup);
    }

    clearHistory() {
        // Clear history point clouds
        this.historyPointClouds.forEach(pointCloud => {
            if (pointCloud.geometry) pointCloud.geometry.dispose();
            if (pointCloud.material) pointCloud.material.dispose();
            this.scene.remove(pointCloud);
        });
        this.historyPointClouds = [];

        // Clear history trajectories
        this.historyTrajectories.forEach(trajectoryGroup => {
            trajectoryGroup.children.forEach(child => {
                if (child.geometry) child.geometry.dispose();
                if (child.material) child.material.dispose();
            });
            this.scene.remove(trajectoryGroup);
        });
        this.historyTrajectories = [];

        this.historyFrames = [];
    }

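    // clearHistory disposes each geometry and material before removing the objects from the
    // scene; three.js does not free GPU buffers on remove() alone, so skipping dispose() here
    // would leave stale GPU allocations behind as history frames are rebuilt.
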
    arraysEqual(a, b) {
        if (a.length !== b.length) return false;
        for (let i = 0; i < a.length; i++) {
            if (a[i] !== b[i]) return false;
        }
        return true;
    }

    toggleBackground() {
        const isWhiteBackground = this.ui.whiteBackground.checked;

        if (isWhiteBackground) {
            // Switch to white background
            document.body.style.backgroundColor = '#ffffff';
            this.scene.background = new THREE.Color(0xffffff);

            // Update UI elements for white background
            document.documentElement.style.setProperty('--bg', '#ffffff');
            document.documentElement.style.setProperty('--text', '#333333');
            document.documentElement.style.setProperty('--text-secondary', '#666666');
            document.documentElement.style.setProperty('--border', '#cccccc');
            document.documentElement.style.setProperty('--surface', '#f5f5f5');
            document.documentElement.style.setProperty('--shadow', 'rgba(0, 0, 0, 0.1)');
            document.documentElement.style.setProperty('--shadow-hover', 'rgba(0, 0, 0, 0.2)');

            // Update status bar and control panel backgrounds
            this.ui.statusBar.style.background = 'rgba(245, 245, 245, 0.9)';
            this.ui.statusBar.style.color = '#333333';

            const controlPanel = document.getElementById('control-panel');
            if (controlPanel) {
                controlPanel.style.background = 'rgba(245, 245, 245, 0.95)';
            }

            const settingsPanel = document.getElementById('settings-panel');
            if (settingsPanel) {
                settingsPanel.style.background = 'rgba(245, 245, 245, 0.98)';
            }

        } else {
            // Switch back to dark background
            document.body.style.backgroundColor = '#1a1a1a';
            this.scene.background = new THREE.Color(0x1a1a1a);

            // Restore original dark theme variables
            document.documentElement.style.setProperty('--bg', '#1a1a1a');
            document.documentElement.style.setProperty('--text', '#e0e0e0');
            document.documentElement.style.setProperty('--text-secondary', '#a0a0a0');
            document.documentElement.style.setProperty('--border', '#444444');
            document.documentElement.style.setProperty('--surface', '#2c2c2c');
            document.documentElement.style.setProperty('--shadow', 'rgba(0, 0, 0, 0.2)');
            document.documentElement.style.setProperty('--shadow-hover', 'rgba(0, 0, 0, 0.3)');

            // Restore original UI backgrounds
            this.ui.statusBar.style.background = 'rgba(30, 30, 30, 0.9)';
            this.ui.statusBar.style.color = '#e0e0e0';

            const controlPanel = document.getElementById('control-panel');
            if (controlPanel) {
                controlPanel.style.background = 'rgba(44, 44, 44, 0.95)';
            }

            const settingsPanel = document.getElementById('settings-panel');
            if (settingsPanel) {
                settingsPanel.style.background = 'rgba(44, 44, 44, 0.98)';
            }
        }

        // Show status message
        this.ui.statusBar.textContent = isWhiteBackground ? "Switched to white background" : "Switched to dark background";
        this.ui.statusBar.classList.remove('hidden');

        setTimeout(() => {
            this.ui.statusBar.classList.add('hidden');
        }, 2000);
    }

    resetSettings() {
        if (!this.defaultSettings) return;

        this.applyDefaultSettings();

        // Reset background to dark theme
        if (this.ui.whiteBackground) {
            this.ui.whiteBackground.checked = false;
            this.toggleBackground();
        }

        this.updatePointCloudSettings();
        this.updateTrajectorySettings();
        this.updateFrustumDimensions();

        // Clear history when resetting settings
        this.clearHistory();

        this.ui.statusBar.textContent = "Settings reset to defaults";
        this.ui.statusBar.classList.remove('hidden');

        setTimeout(() => {
            this.ui.statusBar.classList.add('hidden');
        }, 3000);
    }
}

window.addEventListener('DOMContentLoaded', () => {
    new PointCloudVisualizer();
});
</script>
</body>
</html>