toshas committed
Commit b837595 · 1 Parent(s): 10ef4da

initial commit

.gitignore ADDED
@@ -0,0 +1,5 @@
+ .idea
+ .DS_Store
+ __pycache__
+ gradio_cached_examples
+ Marigold
LICENSE.txt ADDED
@@ -0,0 +1,177 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
README.md CHANGED
@@ -1,13 +1,25 @@
  ---
- title: Marigold Normals
- emoji: 🌍
- colorFrom: yellow
+ title: Marigold Normals Estimation
+ emoji: 🏵️
+ colorFrom: blue
  colorTo: purple
  sdk: gradio
- sdk_version: 4.25.0
+ sdk_version: 4.21.0
  app_file: app.py
- pinned: false
- license: apache-2.0
+ pinned: true
+ license: cc-by-sa-4.0
+ hf_oauth: true
+ hf_oauth_expiration_minutes: 43200
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ This is a demo of Marigold, the state-of-the-art normals estimator for images in the wild.
+ Find out more in our CVPR 2024 paper titled ["Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation"](https://arxiv.org/abs/2312.02145).
+
+ ```
+ @InProceedings{ke2023repurposing,
+     title={Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation},
+     author={Bingxin Ke and Anton Obukhov and Shengyu Huang and Nando Metzger and Rodrigo Caye Daudt and Konrad Schindler},
+     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+     year={2024}
+ }
+ ```
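
For reference, a minimal sketch of driving the pipeline added in this commit locally, mirroring `main()` in `app.py` below. It assumes the pinned requirements are installed, the script is run from the repository root, and the `KevinQu7/marigold_normals` UNet checkpoint is accessible with your Hugging Face token; the input path points at one of the bundled example images.

```python
import torch
from PIL import Image
from diffusers import UNet2DConditionModel

from marigold_normals_estimation import MarigoldNormalsPipeline

# Assemble the normals pipeline the same way app.py does: the Marigold depth
# checkpoint provides the VAE/text encoder/scheduler, and the normals UNet is swapped in.
pipe = MarigoldNormalsPipeline.from_pretrained(
    "prs-eth/marigold-v1-0",
    unet=UNet2DConditionModel.from_pretrained(
        "KevinQu7/marigold_normals", subfolder="unet", use_auth_token=True
    ),
)
pipe = pipe.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# Run inference on a bundled example image and save the outputs.
out = pipe(
    Image.open("files/bee.jpg"),
    denoising_steps=10,
    ensemble_size=10,
    processing_res=768,
)
out.normals_colored.save("bee_normals_colored.png")  # colorized map (PIL image)
print(out.normals_np.shape)  # raw unit normals, [1, 3, H, W] in [-1, 1]
```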
app.py ADDED
@@ -0,0 +1,295 @@
+ # Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # --------------------------------------------------------------------------
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
+ # More information about the method can be found at https://marigoldmonodepth.github.io
+ # --------------------------------------------------------------------------
+
+
+ import functools
+ import os
+
+ import spaces
+ import gradio as gr
+ import numpy as np
+ import torch as torch
+ from PIL import Image
+ from diffusers import UNet2DConditionModel
+
+ from gradio_imageslider import ImageSlider
+ from huggingface_hub import login
+
+ from marigold_normals_estimation import MarigoldNormalsPipeline
+
+
+ def process(
+     pipe,
+     path_input,
+     ensemble_size,
+     denoise_steps,
+     processing_res,
+ ):
+     input_image = Image.open(path_input)
+
+     pipe_out = pipe(
+         input_image,
+         ensemble_size=ensemble_size,
+         denoising_steps=denoise_steps,
+         processing_res=processing_res,
+         batch_size=1 if processing_res == 0 else 0,  # TODO: do we abuse "batch size" notation here?
+         show_progress_bar=True,
+     )
+
+     normals_pred = pipe_out.normals_np
+     normals_colored = pipe_out.normals_colored
+
+     path_output_dir = os.path.splitext(path_input)[0] + "_output"
+     os.makedirs(path_output_dir, exist_ok=True)
+
+     name_base = os.path.splitext(os.path.basename(path_input))[0]
+     path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_normals_fp32.npy")
+     path_out_vis = os.path.join(path_output_dir, f"{name_base}_normals_colored.png")
+
+     np.save(path_out_fp32, normals_pred)
+     normals_colored.save(path_out_vis)
+
+     return (
+         [path_input, path_out_vis],  # TODO: should we unify and output rgb here in depth too?
+         [path_out_fp32, path_out_vis],  # TODO: reintroduce 16bit pngs if it supports 3 channels
+     )
+
+
+ def run_demo_server(pipe):
+     process_pipe = spaces.GPU(functools.partial(process, pipe), duration=120)
+     os.environ["GRADIO_ALLOW_FLAGGING"] = "never"
+
+     with gr.Blocks(
+         analytics_enabled=False,
+         title="Marigold Normals Estimation",
+         css="""
+             #download {
+                 height: 118px;
+             }
+             .slider .inner {
+                 width: 5px;
+                 background: #FFF;
+             }
+             .viewport {
+                 aspect-ratio: 4/3;
+             }
+             h1 {
+                 text-align: center;
+                 display: block;
+             }
+             h2 {
+                 text-align: center;
+                 display: block;
+             }
+             h3 {
+                 text-align: center;
+                 display: block;
+             }
+         """,
+     ) as demo:
+         gr.Markdown(
+             """
+             # Marigold Normals Estimation
+
+             <p align="center">
+             <a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+                 <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
+             </a>
+             <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+                 <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
+             </a>
+             <a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+                 <img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
+             </a>
+             <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+                 <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
+             </a>
+             </p>
+             """
+         )
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(
+                     label="Input Image",
+                     type="filepath",
+                 )
+                 with gr.Accordion("Advanced options", open=True):
+                     ensemble_size = gr.Slider(
+                         label="Ensemble size",
+                         minimum=1,
+                         maximum=20,
+                         step=1,
+                         value=10,
+                     )
+                     denoise_steps = gr.Slider(
+                         label="Number of denoising steps",
+                         minimum=10,
+                         maximum=20,
+                         step=1,
+                         value=10,
+                     )
+                     processing_res = gr.Radio(
+                         [
+                             ("Native", 0),
+                             ("Recommended", 768),
+                         ],
+                         label="Processing resolution",
+                         value=768,
+                     )
+                 with gr.Row():
+                     submit_btn = gr.Button(value="Compute Normals", variant="primary")
+                     clear_btn = gr.Button(value="Clear")
+             with gr.Column():
+                 output_slider = ImageSlider(
+                     label="Predicted normals",
+                     type="filepath",
+                     show_download_button=True,
+                     show_share_button=True,
+                     interactive=False,
+                     elem_classes="slider",
+                     position=0.25,
+                 )
+                 files = gr.Files(
+                     label="Output files",
+                     elem_id="download",
+                     interactive=False,
+                 )
+
+         blocks_settings = [ensemble_size, denoise_steps, processing_res]
+         map_id_to_default = {b._id: b.value for b in blocks_settings}
+
+         inputs = [
+             input_image,
+             ensemble_size,
+             denoise_steps,
+             processing_res,
+         ]
+         outputs = [
+             submit_btn,
+             input_image,
+             output_slider,
+             files,
+         ]
+
+         def submit_normals_fn(*args):
+             out = list(process_pipe(*args))
+             out = [gr.Button(interactive=False), gr.Image(interactive=False)] + out
+             return out
+
+         submit_btn.click(
+             fn=submit_normals_fn,
+             inputs=inputs,
+             outputs=outputs,
+             concurrency_limit=1,
+         )
+
+         gr.Examples(
+             fn=submit_normals_fn,
+             examples=[
+                 [
+                     "files/bee.jpg",
+                     10,  # ensemble_size
+                     10,  # denoise_steps
+                     768,  # processing_res
+                 ],
+                 [
+                     "files/cat.jpg",
+                     10,  # ensemble_size
+                     10,  # denoise_steps
+                     768,  # processing_res
+                 ],
+                 [
+                     "files/swings.jpg",
+                     10,  # ensemble_size
+                     10,  # denoise_steps
+                     768,  # processing_res
+                 ],
+                 [
+                     "files/einstein.jpg",
+                     10,  # ensemble_size
+                     10,  # denoise_steps
+                     768,  # processing_res
+                 ],
+             ],
+             inputs=inputs,
+             outputs=outputs,
+             cache_examples=False,
+         )
+
+         def clear_fn():
+             out = []
+             for b in blocks_settings:
+                 out.append(map_id_to_default[b._id])
+             out += [
+                 gr.Button(interactive=True),
+                 gr.Image(value=None, interactive=True),
+                 None, None,  # one value per remaining output: output_slider, files
+             ]
+             return out
+
+         clear_btn.click(
+             fn=clear_fn,
+             inputs=[],
+             outputs=blocks_settings + [
+                 submit_btn,
+                 input_image,
+                 output_slider,
+                 files,
+             ],
+         )
+
+     demo.queue(
+         api_open=False,
+     ).launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+     )
+
+
+ def main():
+     CHECKPOINT_DEPTH = "prs-eth/marigold-v1-0"
+     CHECKPOINT_NORMALS = "KevinQu7/marigold_normals"
+
+     if "HF_TOKEN_LOGIN" in os.environ:
+         login(token=os.environ["HF_TOKEN_LOGIN"])
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     pipe = MarigoldNormalsPipeline.from_pretrained(
+         CHECKPOINT_DEPTH,
+         unet=UNet2DConditionModel.from_pretrained(
+             CHECKPOINT_NORMALS,
+             subfolder='unet',
+             use_auth_token=True,
+         )
+     )
+     try:
+         import xformers
+
+         pipe.enable_xformers_memory_efficient_attention()
+     except Exception:
+         pass  # run without xformers
+
+     pipe = pipe.to(device)
+     run_demo_server(pipe)
+
+
+ if __name__ == "__main__":
+     main()
gradio_patches/examples.py ADDED
@@ -0,0 +1,13 @@
+ from pathlib import Path
+
+ import gradio
+ from gradio.utils import get_cache_folder
+
+
+ class Examples(gradio.helpers.Examples):
+     def __init__(self, *args, directory_name=None, **kwargs):
+         super().__init__(*args, **kwargs, _initiated_directly=False)
+         if directory_name is not None:
+             self.cached_folder = get_cache_folder() / directory_name
+             self.cached_file = Path(self.cached_folder) / "log.csv"
+         self.create()
gradio_patches/flagging.py ADDED
@@ -0,0 +1,165 @@
+ from __future__ import annotations
+
+ import datetime
+ import json
+ import time
+ import uuid
+ from collections import OrderedDict
+ from datetime import datetime, timezone
+ from pathlib import Path
+ from typing import Any
+
+ import gradio
+ import gradio as gr
+ import huggingface_hub
+ from gradio import FlaggingCallback
+ from gradio_client import utils as client_utils
+
+
+ class HuggingFaceDatasetSaver(gradio.HuggingFaceDatasetSaver):
+     def flag(
+         self,
+         flag_data: list[Any],
+         flag_option: str = "",
+         username: str | None = None,
+     ) -> int:
+         if self.separate_dirs:
+             # JSONL files to support dataset preview on the Hub
+             current_utc_time = datetime.now(timezone.utc)
+             iso_format_without_microseconds = current_utc_time.strftime(
+                 "%Y-%m-%dT%H:%M:%S"
+             )
+             milliseconds = int(current_utc_time.microsecond / 1000)
+             unique_id = f"{iso_format_without_microseconds}.{milliseconds:03}Z"
+             if username not in (None, ""):
+                 unique_id += f"_U_{username}"
+             else:
+                 unique_id += f"_{str(uuid.uuid4())[:8]}"
+             components_dir = self.dataset_dir / unique_id
+             data_file = components_dir / "metadata.jsonl"
+             path_in_repo = unique_id  # upload in sub folder (safer for concurrency)
+         else:
+             # Unique CSV file
+             components_dir = self.dataset_dir
+             data_file = components_dir / "data.csv"
+             path_in_repo = None  # upload at root level
+
+         return self._flag_in_dir(
+             data_file=data_file,
+             components_dir=components_dir,
+             path_in_repo=path_in_repo,
+             flag_data=flag_data,
+             flag_option=flag_option,
+             username=username or "",
+         )
+
+     def _deserialize_components(
+         self,
+         data_dir: Path,
+         flag_data: list[Any],
+         flag_option: str = "",
+         username: str = "",
+     ) -> tuple[dict[Any, Any], list[Any]]:
+         """Deserialize components and return the corresponding row for the flagged sample.
+
+         Images/audio are saved to disk as individual files.
+         """
+         # Components that can have a preview on dataset repos
+         file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
+
+         # Generate the row corresponding to the flagged sample
+         features = OrderedDict()
+         row = []
+         for component, sample in zip(self.components, flag_data):
+             # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
+             label = component.label or ""
+             save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
+             save_dir.mkdir(exist_ok=True, parents=True)
+             deserialized = component.flag(sample, save_dir)
+
+             # Base component .flag method returns JSON; extract path from it when it is FileData
+             if component.data_model:
+                 data = component.data_model.from_json(json.loads(deserialized))
+                 if component.data_model == gr.data_classes.FileData:
+                     deserialized = data.path
+
+             # Add deserialized object to row
+             features[label] = {"dtype": "string", "_type": "Value"}
+             try:
+                 deserialized_path = Path(deserialized)
+                 if not deserialized_path.exists():
+                     raise FileNotFoundError(f"File {deserialized} not found")
+                 row.append(str(deserialized_path.relative_to(self.dataset_dir)))
+             except (FileNotFoundError, TypeError, ValueError):
+                 deserialized = "" if deserialized is None else str(deserialized)
+                 row.append(deserialized)
+
+             # If component is eligible for a preview, add the URL of the file
+             # Be mindful that images and audio can be None
+             if isinstance(component, tuple(file_preview_types)):  # type: ignore
+                 for _component, _type in file_preview_types.items():
+                     if isinstance(component, _component):
+                         features[label + " file"] = {"_type": _type}
+                         break
+                 if deserialized:
+                     path_in_repo = str(  # returned filepath is absolute, we want it relative to compute URL
+                         Path(deserialized).relative_to(self.dataset_dir)
+                     ).replace("\\", "/")
+                     row.append(
+                         huggingface_hub.hf_hub_url(
+                             repo_id=self.dataset_id,
+                             filename=path_in_repo,
+                             repo_type="dataset",
+                         )
+                     )
+                 else:
+                     row.append("")
+         features["flag"] = {"dtype": "string", "_type": "Value"}
+         features["username"] = {"dtype": "string", "_type": "Value"}
+         row.append(flag_option)
+         row.append(username)
+         return features, row
+
+
+ class FlagMethod:
+     """
+     Helper class that contains the flagging options and calls the flagging method. Also
+     provides visual feedback to the user when flag is clicked.
+     """
+
+     def __init__(
+         self,
+         flagging_callback: FlaggingCallback,
+         label: str,
+         value: str,
+         visual_feedback: bool = True,
+     ):
+         self.flagging_callback = flagging_callback
+         self.label = label
+         self.value = value
+         self.__name__ = "Flag"
+         self.visual_feedback = visual_feedback
+
+     def __call__(
+         self,
+         request: gr.Request,
+         profile: gr.OAuthProfile | None,
+         *flag_data,
+     ):
+         username = None
+         if profile is not None:
+             username = profile.username
+         try:
+             self.flagging_callback.flag(
+                 list(flag_data), flag_option=self.value, username=username
+             )
+         except Exception as e:
+             print(f"Error while sharing: {e}")
+             if self.visual_feedback:
+                 return gr.Button(value="Sharing error", interactive=False)
+         if not self.visual_feedback:
+             return
+         time.sleep(0.8)  # to provide enough time for the user to observe button change
+         return gr.Button(value="Sharing complete", interactive=False)
marigold_normals_estimation.py ADDED
@@ -0,0 +1,500 @@
+ # Copyright 2024 Bingxin Ke, Anton Obukhov, ETH Zurich and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # --------------------------------------------------------------------------
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
+ # More information about the method can be found at https://marigoldmonodepth.github.io
+ # --------------------------------------------------------------------------
+
+
+ import math
+ from typing import Dict, Union
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torch.utils.data import DataLoader, TensorDataset
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ from diffusers import (
+     AutoencoderKL,
+     DDIMScheduler,
+     DiffusionPipeline,
+     UNet2DConditionModel,
+ )
+ from diffusers.utils import BaseOutput, check_min_version
+
+
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+ check_min_version("0.27.0.dev0")
+
+
+ class MarigoldNormalsOutput(BaseOutput):
+     """
+     Output class for Marigold monocular normals prediction pipeline.
+
+     Args:
+         normals_np (`np.ndarray`):
+             Predicted normals map, with normals values in the range of [-1, 1].
+         normals_colored (`None` or `PIL.Image.Image`):
+             Colorized normals map, with the shape of [3, H, W] and values in [0, 1].
+         normals_uncertainty (`None` or `np.ndarray`):
+             Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
+     """
+
+     normals_np: np.ndarray
+     normals_colored: Union[None, Image.Image]
+     normals_uncertainty: Union[None, np.ndarray]
+
+
+ class MarigoldNormalsPipeline(DiffusionPipeline):
+     """
+     Pipeline for monocular normals estimation using Marigold: https://marigoldmonodepth.github.io.
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     Args:
+         unet (`UNet2DConditionModel`):
+             Conditional U-Net to denoise the normals latent, conditioned on image latent.
+         vae (`AutoencoderKL`):
+             Variational Auto-Encoder (VAE) Model to encode and decode images and normals maps
+             to and from latent representations.
+         scheduler (`DDIMScheduler`):
+             A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+         text_encoder (`CLIPTextModel`):
+             Text-encoder, for empty text embedding.
+         tokenizer (`CLIPTokenizer`):
+             CLIP tokenizer.
+     """
+
+     latent_scale_factor = 0.18215
+
+     def __init__(
+         self,
+         unet: UNet2DConditionModel,
+         vae: AutoencoderKL,
+         scheduler: DDIMScheduler,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+     ):
+         super().__init__()
+
+         self.register_modules(
+             unet=unet,
+             vae=vae,
+             scheduler=scheduler,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+         )
+
+         self.empty_text_embed = None
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         input_image: Image,
+         denoising_steps: int = 10,
+         ensemble_size: int = 10,
+         processing_res: int = 768,
+         match_input_res: bool = True,
+         batch_size: int = 0,
+         save_memory: bool = False,
+         color_map: str = "Spectral",  # TODO change colorization api based on modality
+         show_progress_bar: bool = True,
+         ensemble_kwargs: Dict = None,
+     ) -> MarigoldNormalsOutput:
+         """
+         Function invoked when calling the pipeline.
+
+         Args:
+             input_image (`Image`):
+                 Input RGB (or gray-scale) image.
+             processing_res (`int`, *optional*, defaults to `768`):
+                 Maximum resolution of processing.
+                 If set to 0: will not resize at all.
+             match_input_res (`bool`, *optional*, defaults to `True`):
+                 Resize normals prediction to match input resolution.
+                 Only relevant when `processing_res` is not 0.
+             denoising_steps (`int`, *optional*, defaults to `10`):
+                 Number of diffusion denoising steps (DDIM) during inference.
+             ensemble_size (`int`, *optional*, defaults to `10`):
+                 Number of predictions to be ensembled.
+             batch_size (`int`, *optional*, defaults to `0`):
+                 Inference batch size, no bigger than `ensemble_size`.
+                 If set to 0, the script will automatically decide the proper batch size.
+             save_memory (`bool`, defaults to `False`):
+                 Extra steps to save memory at the cost of performance.
+             show_progress_bar (`bool`, *optional*, defaults to `True`):
+                 Display a progress bar of diffusion denoising.
+             color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized normals map generation):
+                 Colormap used to colorize the normals map.
+             ensemble_kwargs (`dict`, *optional*, defaults to `None`):
+                 Arguments for detailed ensembling settings.
+         Returns:
+             `MarigoldNormalsOutput`: Output class for Marigold monocular normals prediction pipeline, including:
+             - **normals_np** (`np.ndarray`) Predicted normals map, with normals values in the range of [-1, 1]
+             - **normals_colored** (`None` or `PIL.Image.Image`) Colorized normals map, with the shape of [3, H, W] and
+               values in [0, 1]. None if `color_map` is `None`
+             - **normals_uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
+               coming from ensembling. None if `ensemble_size = 1`
+         """
+
+         if not match_input_res:
+             assert processing_res is not None
+         assert processing_res >= 0
+         assert denoising_steps >= 1
+         assert ensemble_size >= 1
+
+         W, H = input_image.size
+
+         if processing_res > 0:
+             input_image = self.resize_max_res(
+                 input_image, max_edge_resolution=processing_res
+             )
+         input_image = input_image.convert("RGB")
+         image = np.asarray(input_image)
+
+         rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
+         rgb_norm = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
+         rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
+         rgb_norm = rgb_norm.to(self.device)
+         assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
+
+         duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
+         single_rgb_dataset = TensorDataset(duplicated_rgb)
+         if batch_size > 0:
+             _bs = batch_size
+         else:
+             _bs = self._find_batch_size(
+                 ensemble_size=ensemble_size,
+                 input_res=max(rgb_norm.shape[1:]),
+                 dtype=self.dtype,
+             )
+
+         single_rgb_loader = DataLoader(
+             single_rgb_dataset, batch_size=_bs, shuffle=False
+         )
+
+         pred = []
+         if show_progress_bar:
+             iterable = tqdm(
+                 single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
+             )
+         else:
+             iterable = single_rgb_loader
+         for batch in iterable:
+             (batched_img,) = batch
+             pred_raw = self.single_infer(
+                 rgb_in=batched_img,
+                 num_inference_steps=denoising_steps,
+                 show_pbar=show_progress_bar,
+             )
+             pred_raw = pred_raw.detach()
+             if save_memory:
+                 pred_raw = pred_raw.cpu()
+             pred.append(pred_raw)
+
+         pred = torch.concat(pred, dim=0)  # [B,3,H,W]
+         pred_uncert = None
+
+         if save_memory:
+             torch.cuda.empty_cache()
+
+         if ensemble_size > 1:
+             pred, pred_uncert = self.ensemble_normals(
+                 pred, **(ensemble_kwargs or {})
+             )  # [1,3,H,W], [1,H,W]
+
+         if match_input_res:
+             pred = torch.nn.functional.interpolate(
+                 pred, (H, W), mode="bilinear"
+             )  # [1,3,H,W]
+             norm = torch.norm(pred, dim=1, keepdim=True)  # [1,1,H,W]
+             pred /= norm.clamp(min=1e-6)
+
+             if pred_uncert is not None:
+                 pred_uncert = torch.nn.functional.interpolate(
+                     pred_uncert.unsqueeze(1), (H, W), mode="bilinear"
+                 ).squeeze(1)  # [1,H,W]
+
+         # TODO: make X-axis of normals configurable through abstraction
+         if color_map is not None:
+             colored = (pred.squeeze(0) + 1.0) * 0.5
+             colored = (colored * 255).to(torch.uint8)
+             colored = self.chw2hwc(colored).cpu().numpy()
+             colored_img = Image.fromarray(colored)
+         else:
+             colored_img = None
+
+         if pred_uncert is not None:
+             pred_uncert = pred_uncert.cpu().numpy()
+
+         pred = pred.cpu().numpy()  # TODO: np or torch?
+
+         out = MarigoldNormalsOutput(
+             normals_np=pred,
+             normals_colored=colored_img,
+             normals_uncertainty=pred_uncert,
+         )
+
+         return out
+
+     def _encode_empty_text(self):
+         """
+         Encode text embedding for empty prompt.
+         """
+         prompt = ""
+         text_inputs = self.tokenizer(
+             prompt,
+             padding="do_not_pad",
+             max_length=self.tokenizer.model_max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+         text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
+         self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
+
+     @torch.no_grad()
+     def single_infer(
+         self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool
+     ) -> torch.Tensor:
+         """
+         Perform an individual normals prediction without ensembling.
+
+         Args:
+             rgb_in (`torch.Tensor`):
+                 Input RGB image.
+             num_inference_steps (`int`):
+                 Number of diffusion denoising steps (DDIM) during inference.
+             show_pbar (`bool`):
+                 Display a progress bar of diffusion denoising.
+         Returns:
+             `torch.Tensor`: Predicted normals map.
+         """
+         device = rgb_in.device
+
+         # Set timesteps
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps  # [T]
+
+         # Encode image
+         rgb_latent = self._encode_rgb(rgb_in)
+
+         # Initialize prediction latent with noise
+         pred_latent = torch.randn(
+             rgb_latent.shape, device=device, dtype=self.dtype
+         )  # [B, 4, h, w]
+
+         # Batched empty text embedding
+         if self.empty_text_embed is None:
+             self._encode_empty_text()
+         batch_empty_text_embed = self.empty_text_embed.repeat(
+             (rgb_latent.shape[0], 1, 1)
+         )  # [B, 2, 1024]
+
+         # Denoising loop
+         if show_pbar:
+             iterable = tqdm(
+                 enumerate(timesteps),
+                 total=len(timesteps),
+                 leave=False,
+                 desc=" " * 4 + "Diffusion denoising",
+             )
+         else:
+             iterable = enumerate(timesteps)
+
+         for i, t in iterable:
+             unet_input = torch.cat(
+                 [rgb_latent, pred_latent], dim=1
+             )  # this order is important
+
+             # predict the noise residual
+             noise_pred = self.unet(
+                 unet_input, t, encoder_hidden_states=batch_empty_text_embed
+             ).sample  # [B, 4, h, w]
+
+             # compute the previous noisy sample x_t -> x_t-1
+             pred_latent = self.scheduler.step(noise_pred, t, pred_latent).prev_sample
+
+         # torch.cuda.empty_cache()  # TODO is it really needed here, even if memory saving?
+
+         pred_pixels = self._decode_pred(pred_latent)  # [B, 3, H, W]
+
+         return pred_pixels
+
+     def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
+         """
+         Encode RGB image into latent.
+
+         Args:
+             rgb_in (`torch.Tensor`):
+                 Input RGB image to be encoded.
+
+         Returns:
+             `torch.Tensor`: Image latent.
+         """
+         # encode
+         h = self.vae.encoder(rgb_in)
+         moments = self.vae.quant_conv(h)
+         mean, logvar = torch.chunk(moments, 2, dim=1)
+         # scale latent
+         rgb_latent = mean * self.latent_scale_factor
+         return rgb_latent
+
+     def _decode_pred(self, latent: torch.Tensor) -> torch.Tensor:
+         """
+         Decode normals latent into normals map.
+
+         Args:
+             latent (`torch.Tensor`):
+                 Prediction latent to be decoded [B, 4, h, w].
+
+         Returns:
+             `torch.Tensor`: Decoded prediction map [B, 3, H, W].
+         """
+         # decode latent
+         latent = latent / self.latent_scale_factor
+         latent = self.vae.post_quant_conv(latent)
+         pixels = self.vae.decoder(latent)
+
+         # clip prediction
+         pixels = torch.clip(pixels, -1.0, 1.0)
+
+         # renormalize prediction
+         norm = torch.norm(pixels, dim=1, keepdim=True)
+         pixels = pixels / norm.clamp(min=1e-6)
+
+         return pixels
+
+     @staticmethod
+     def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
+         """
+         Resize image to limit maximum edge length while keeping aspect ratio.
+
+         Args:
+             img (`Image.Image`):
+                 Image to be resized.
+             max_edge_resolution (`int`):
+                 Maximum edge length (pixel).
+
+         Returns:
+             `Image.Image`: Resized image.
+         """
+         original_width, original_height = img.size
+         downscale_factor = min(
+             max_edge_resolution / original_width, max_edge_resolution / original_height
+         )
+
+         new_width = int(original_width * downscale_factor)
+         new_height = int(original_height * downscale_factor)
+
+         resized_img = img.resize((new_width, new_height))
+         return resized_img
+
+     @staticmethod
+     def chw2hwc(chw):
+         assert 3 == len(chw.shape)
+         if isinstance(chw, torch.Tensor):
+             hwc = torch.permute(chw, (1, 2, 0))
+         elif isinstance(chw, np.ndarray):
+             hwc = np.moveaxis(chw, 0, -1)
+         return hwc
+
+     @staticmethod
+     def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
+         """
+         Automatically search for suitable operating batch size.
+
+         Args:
+             ensemble_size (`int`):
+                 Number of predictions to be ensembled.
+             input_res (`int`):
+                 Operating resolution of the input image.
+
+         Returns:
+             `int`: Operating batch size.
+         """
+         # Search table for suggested max. inference batch size
+         bs_search_table = [
+             # tested on A100-PCIE-80GB
+             {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
+             {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
+             # tested on A100-PCIE-40GB
+             {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
+             {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
+             {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
+             {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
+             # tested on RTX3090, RTX4090
+             {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
+             {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
+             {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
+             {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
+             {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
+             {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
+             # tested on GTX1080Ti
+             {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
+             {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
+             {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
+             {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
+             {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
+         ]
+
+         if not torch.cuda.is_available():
+             return 1
+
+         total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
+         filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
+         for settings in sorted(
+             filtered_bs_search_table,
+             key=lambda k: (k["res"], -k["total_vram"]),
+         ):
+             if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
+                 bs = settings["bs"]
+                 if bs > ensemble_size:
+                     bs = ensemble_size
+                 elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
+                     bs = math.ceil(ensemble_size / 2)
+                 return bs
+
+         return 1
+
+     @staticmethod
+     def ensemble_normals(pred_normals: torch.Tensor, reduction: str = "median"):
+         assert reduction in ("median", "mean")
+
+         B, C, H, W = pred_normals.shape
+         assert C == 3
+
+         mean_normals = pred_normals.mean(dim=0, keepdim=True)  # [1,3,H,W]
+         mean_normals_norm = mean_normals.norm(dim=1, keepdim=True)  # [1,1,H,W]
+         mean_normals /= mean_normals_norm.clip(min=1e-6)  # [1,3,H,W]
+
+         sim_cos = (mean_normals * pred_normals).sum(dim=1)  # [B,H,W]
+         sim_acos = sim_cos.arccos()  # [B,H,W]
+         sim_acos = sim_acos.mean(dim=0, keepdim=True) / math.pi  # [1,H,W]
+
+         if reduction == "mean":
+             return mean_normals, sim_acos  # [1,3,H,W], [1,H,W]
+
+         # Find the index of the closest normal vector for each pixel
+         closest_indices = sim_cos.argmax(dim=0, keepdim=True)  # [1,H,W]
+
+         closest_indices = closest_indices.unsqueeze(0).repeat(1, 3, 1, 1)  # [1,3,H,W]
+         closest_normals = torch.gather(pred_normals, 0, closest_indices)
+
+         return closest_normals, sim_acos  # [1,3,H,W], [1,H,W]
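
The `ensemble_normals` staticmethod above reduces a stack of per-pixel normal predictions to a single map plus an uncalibrated angular uncertainty. A minimal sketch of what it returns, on synthetic unit vectors (assuming the pinned requirements are installed so that the module imports cleanly from the repository root):

```python
import torch

from marigold_normals_estimation import MarigoldNormalsPipeline

# Ten noisy predictions of the same 4x4 normals map, normalized to unit length.
preds = torch.randn(10, 3, 4, 4)
preds = preds / preds.norm(dim=1, keepdim=True).clamp(min=1e-6)

# "median" reduction picks, per pixel, the prediction closest to the normalized mean;
# the second output is the mean angular deviation from that mean, scaled by 1/pi.
normals, uncertainty = MarigoldNormalsPipeline.ensemble_normals(preds, reduction="median")
print(normals.shape, uncertainty.shape)  # torch.Size([1, 3, 4, 4]) torch.Size([1, 4, 4])
```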
requirements.txt ADDED
@@ -0,0 +1,126 @@
+ accelerate==0.25.0
+ aiofiles==23.2.1
+ aiohttp==3.9.3
+ aiosignal==1.3.1
+ altair==5.3.0
+ annotated-types==0.6.0
+ anyio==4.3.0
+ async-timeout==4.0.3
+ attrs==23.2.0
+ Authlib==1.3.0
+ certifi==2024.2.2
+ cffi==1.16.0
+ charset-normalizer==3.3.2
+ click==8.0.4
+ cmake==3.29.0.1
+ contourpy==1.2.0
+ cryptography==42.0.5
+ cycler==0.12.1
+ dataclasses-json==0.6.4
+ datasets==2.18.0
+ Deprecated==1.2.14
+ diffusers==0.27.2
+ dill==0.3.8
+ exceptiongroup==1.2.0
+ fastapi==0.110.0
+ ffmpy==0.3.2
+ filelock==3.13.3
+ fonttools==4.50.0
+ frozenlist==1.4.1
+ fsspec==2024.2.0
+ gradio==4.21.0
+ gradio_client==0.12.0
+ gradio_imageslider==0.0.18
+ h11==0.14.0
+ httpcore==1.0.5
+ httpx==0.27.0
+ huggingface-hub==0.22.1
+ idna==3.6
+ imageio==2.34.0
+ imageio-ffmpeg==0.4.9
+ importlib_metadata==7.1.0
+ importlib_resources==6.4.0
+ itsdangerous==2.1.2
+ Jinja2==3.1.3
+ jsonschema==4.21.1
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ lit==18.1.2
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ marshmallow==3.21.1
+ matplotlib==3.8.2
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.0.5
+ multiprocess==0.70.16
+ mypy-extensions==1.0.0
+ networkx==3.2.1
+ numpy==1.26.4
+ nvidia-cublas-cu11==11.10.3.66
+ nvidia-cuda-cupti-cu11==11.7.101
+ nvidia-cuda-nvrtc-cu11==11.7.99
+ nvidia-cuda-runtime-cu11==11.7.99
+ nvidia-cudnn-cu11==8.5.0.96
+ nvidia-cufft-cu11==10.9.0.58
+ nvidia-curand-cu11==10.2.10.91
+ nvidia-cusolver-cu11==11.4.0.1
+ nvidia-cusparse-cu11==11.7.4.91
+ nvidia-nccl-cu11==2.14.3
+ nvidia-nvtx-cu11==11.7.91
+ orjson==3.10.0
+ packaging==24.0
+ pandas==2.2.1
+ pillow==10.2.0
+ protobuf==3.20.3
+ psutil==5.9.8
+ pyarrow==15.0.2
+ pyarrow-hotfix==0.6
+ pycparser==2.22
+ pydantic==2.6.4
+ pydantic_core==2.16.3
+ pydub==0.25.1
+ pygltflib==1.16.1
+ Pygments==2.17.2
+ pyparsing==3.1.2
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.9
+ pytz==2024.1
+ PyYAML==6.0.1
+ referencing==0.34.0
+ regex==2023.12.25
+ requests==2.31.0
+ rich==13.7.1
+ rpds-py==0.18.0
+ ruff==0.3.4
+ safetensors==0.4.2
+ scipy==1.11.4
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ spaces==0.25.0
+ starlette==0.36.3
+ sympy==1.12
+ tokenizers==0.15.2
+ tomlkit==0.12.0
+ toolz==0.12.1
+ torch==2.0.1
+ tqdm==4.66.2
+ transformers==4.36.1
+ trimesh==4.0.5
+ triton==2.0.0
+ typer==0.12.0
+ typer-cli==0.12.0
+ typer-slim==0.12.0
+ typing-inspect==0.9.0
+ typing_extensions==4.10.0
+ tzdata==2024.1
+ urllib3==2.2.1
+ uvicorn==0.29.0
+ websockets==11.0.3
+ wrapt==1.16.0
+ xformers==0.0.21
+ xxhash==3.4.1
+ yarl==1.9.4
+ zipp==3.18.1
requirements_min.txt ADDED
@@ -0,0 +1,16 @@
+ gradio==4.21.0
+ gradio-imageslider==0.0.18
+ pygltflib==1.16.1
+ trimesh==4.0.5
+ imageio
+ imageio-ffmpeg
+ Pillow
+
+ spaces==0.25.0
+ accelerate==0.25.0
+ diffusers==0.27.2
+ matplotlib==3.8.2
+ scipy==1.11.4
+ torch==2.0.1
+ transformers==4.36.1
+ xformers==0.0.21