JichenHu committed on
Commit
d021baf
·
verified ·
1 Parent(s): 1f93e83

Upload app.py

Files changed (1)
  1. app.py +450 -0
app.py ADDED
@@ -0,0 +1,450 @@
+ # Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # --------------------------------------------------------------------------
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
+ # More information about the method can be found at https://marigoldmonodepth.github.io
+ # --------------------------------------------------------------------------
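+ # Gradio demo entry point for "Dereflection Any Image" (DAI): it loads the DAI
+ # diffusion pipeline from the Hugging Face Hub and serves a single-image
+ # glass-reflection-removal demo (the video path further below is currently disabled).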
+ from __future__ import annotations
+
+ import functools
+ import os
+ import tempfile
+
+ import gradio as gr
+ import imageio
+ import numpy as np
+ import spaces
+ import torch
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ from PIL import Image
+ from gradio_imageslider import ImageSlider
+ from tqdm import tqdm
+
+ from pathlib import Path
+ import gradio
+ from gradio.utils import get_cache_folder
+ from DAI.pipeline_all import DAIPipeline
+
+ from diffusers import (
+     AutoencoderKL,
+     UNet2DConditionModel,
+ )
+
+ from transformers import CLIPTextModel, AutoTokenizer
+
+ from DAI.controlnetvae import ControlNetVAEModel
+
+ from DAI.decoder import CustomAutoencoderKL
+
+
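+ # Thin wrapper around gradio.helpers.Examples that lets the demo point each
+ # examples gallery at its own cache directory (e.g. "examples_image" below).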
+ class Examples(gradio.helpers.Examples):
+     def __init__(self, *args, directory_name=None, **kwargs):
+         super().__init__(*args, **kwargs, _initiated_directly=False)
+         if directory_name is not None:
+             self.cached_folder = get_cache_folder() / directory_name
+             self.cached_file = Path(self.cached_folder) / "log.csv"
+         self.create()
+
+
+ default_seed = 2024
+ default_batch_size = 1
+
+ default_image_processing_resolution = 2048
+ default_video_out_max_frames = 60
+
+ def process_image_check(path_input):
+     if path_input is None:
+         raise gr.Error(
+             "Missing image in the first pane: upload a file or use one from the gallery below."
+         )
+
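+ # Resize the shorter image side to `resolution` and round both sides to
+ # multiples of 64, presumably to match the downsampling stride of the diffusion
+ # backbone. Note: this helper is currently unused in this app, since
+ # processing_resolution is passed as None below.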
+ def resize_image(input_image, resolution):
+     # Ensure input_image is a PIL Image object
+     if not isinstance(input_image, Image.Image):
+         raise ValueError("input_image should be a PIL Image object")
+
+     # Convert image to numpy array
+     input_image_np = np.asarray(input_image)
+
+     # Get image dimensions
+     H, W, C = input_image_np.shape
+     H = float(H)
+     W = float(W)
+
+     # Calculate the scaling factor
+     k = float(resolution) / min(H, W)
+
+     # Determine new dimensions
+     H *= k
+     W *= k
+     H = int(np.round(H / 64.0)) * 64
+     W = int(np.round(W / 64.0)) * 64
+
+     # Resize the image using PIL's resize method
+     img = input_image.resize((W, H), Image.Resampling.LANCZOS)
+
+     return img
+
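+ # Run the de-reflection pipeline on a single uploaded image and yield the
+ # (input image, output PNG path) pair expected by the ImageSlider component.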
+ def process_image(
+     pipe,
+     vae_2,
+     path_input,
+ ):
+     name_base, name_ext = os.path.splitext(os.path.basename(path_input))
+     print(f"Processing image {name_base}{name_ext}")
+
+     path_output_dir = tempfile.mkdtemp()
+     path_out_png = os.path.join(path_output_dir, f"{name_base}_delight.png")
+     input_image = Image.open(path_input)
+     # resolution = 0
+     # if max(input_image.size) < 768:
+     #     resolution = None
+     resolution = None
+
+     pipe_out = pipe(
+         image=input_image,
+         prompt="remove glass reflection",
+         vae_2=vae_2,
+         processing_resolution=resolution,
+     )
+
+     processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
+     processed_frame = (processed_frame[0] * 255).astype(np.uint8)
+     processed_frame = Image.fromarray(processed_frame)
+     processed_frame.save(path_out_png)
+     yield [input_image, path_out_png]
+
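+ # Same as process_image, but streamed over video frames: frames are subsampled
+ # to roughly target_fps, capped at out_max_frames, and written to an MP4 while
+ # intermediate (input, output) pairs are yielded for live visualization.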
+ def process_video(
+     pipe,
+     vae_2,
+     path_input,
+     out_max_frames=default_video_out_max_frames,
+     target_fps=10,
+     progress=gr.Progress(),
+ ):
+     if path_input is None:
+         raise gr.Error(
+             "Missing video in the first pane: upload a file or use one from the gallery below."
+         )
+
+     name_base, name_ext = os.path.splitext(os.path.basename(path_input))
+     print(f"Processing video {name_base}{name_ext}")
+
+     path_output_dir = tempfile.mkdtemp()
+     path_out_vis = os.path.join(path_output_dir, f"{name_base}_delight.mp4")
+
+     init_latents = None
+     reader, writer = None, None
+     try:
+         reader = imageio.get_reader(path_input)
+
+         meta_data = reader.get_meta_data()
+         fps = meta_data["fps"]
+         size = meta_data["size"]
+         duration_sec = meta_data["duration"]
+
+         writer = imageio.get_writer(path_out_vis, fps=target_fps)
+
+         out_frame_id = 0
+         pbar = tqdm(desc="Processing Video", total=duration_sec)
+
+         for frame_id, frame in enumerate(reader):
+             if frame_id % (fps // target_fps) != 0:
+                 continue
+             else:
+                 out_frame_id += 1
+                 pbar.update(1)
+             if out_frame_id > out_max_frames:
+                 break
+
+             frame_pil = Image.fromarray(frame)
+
+             resolution = None
+
+             pipe_out = pipe(
+                 image=frame_pil,
+                 prompt="remove glass reflection",
+                 vae_2=vae_2,
+                 processing_resolution=resolution,
+             )
+
+             if init_latents is None:
+                 init_latents = pipe_out.gaus_noise
+             # Map the [-1, 1] prediction to uint8 RGB before writing, mirroring process_image.
+             processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
+             processed_frame = (processed_frame[0] * 255).astype(np.uint8)
+             writer.append_data(processed_frame)
+
+             yield (
+                 [frame_pil, processed_frame],
+                 None,
+             )
+     finally:
+
+         if writer is not None:
+             writer.close()
+
+         if reader is not None:
+             reader.close()
+
+     yield (
+         [frame_pil, processed_frame],
+         [path_out_vis,]
+     )
+
+
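+ # Build and launch the Gradio UI. The image/video workers are wrapped with
+ # spaces.GPU so that, on a ZeroGPU Space, a GPU is attached only for the
+ # duration of each call; the video worker requests up to 120 seconds.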
+ def run_demo_server(pipe, vae_2):
+     process_pipe_image = spaces.GPU(functools.partial(process_image, pipe, vae_2))
+     process_pipe_video = spaces.GPU(
+         functools.partial(process_video, pipe, vae_2), duration=120
+     )
+
+     gradio_theme = gr.themes.Default()
+
+     with gr.Blocks(
+         theme=gradio_theme,
+         title="Dereflection Any Image",
+         css="""
+             #download {
+                 height: 118px;
+             }
+             .slider .inner {
+                 width: 5px;
+                 background: #FFF;
+             }
+             .viewport {
+                 aspect-ratio: 4/3;
+             }
+             .tabs button.selected {
+                 font-size: 20px !important;
+                 color: crimson !important;
+             }
+             h1 {
+                 text-align: center;
+                 display: block;
+             }
+             h2 {
+                 text-align: center;
+                 display: block;
+             }
+             h3 {
+                 text-align: center;
+                 display: block;
+             }
+             .md_feedback li {
+                 margin-bottom: 0px !important;
+             }
+         """,
+         head="""
+             <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
+             <script>
+                 window.dataLayer = window.dataLayer || [];
+                 function gtag() {dataLayer.push(arguments);}
+                 gtag('js', new Date());
+                 gtag('config', 'G-1FWSVCGZTG');
+             </script>
+         """,
+     ) as demo:
+         gr.Markdown(
+             """
+             # Dereflection Any Image
+             <p align="center">
+             """
+         )
+
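+         # UI: a single "Image" tab with an input image, run/reset buttons, and an
+         # ImageSlider comparing input and output; the original video tab is kept
+         # below as commented-out code.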
+         with gr.Tabs(elem_classes=["tabs"]):
+             with gr.Tab("Image"):
+                 with gr.Row():
+                     with gr.Column():
+                         image_input = gr.Image(
+                             label="Input Image",
+                             type="filepath",
+                         )
+                         with gr.Row():
+                             image_submit_btn = gr.Button(
+                                 value="remove reflection", variant="primary"
+                             )
+                             image_reset_btn = gr.Button(value="Reset")
+                     with gr.Column():
+                         image_output_slider = ImageSlider(
+                             label="outputs",
+                             type="filepath",
+                             show_download_button=True,
+                             show_share_button=True,
+                             interactive=False,
+                             elem_classes="slider",
+                             # position=0.25,
+                         )
+
+                 Examples(
+                     fn=process_pipe_image,
+                     examples=sorted([
+                         os.path.join("files", "image", name)
+                         for name in os.listdir(os.path.join("files", "image"))
+                     ]),
+                     inputs=[image_input],
+                     outputs=[image_output_slider],
+                     cache_examples=False,
+                     directory_name="examples_image",
+                 )
+
+             # with gr.Tab("Video"):
+             #     with gr.Row():
+             #         with gr.Column():
+             #             video_input = gr.Video(
+             #                 label="Input Video",
+             #                 sources=["upload", "webcam"],
+             #             )
+             #             with gr.Row():
+             #                 video_submit_btn = gr.Button(
+             #                     value="Remove reflection", variant="primary"
+             #                 )
+             #                 video_reset_btn = gr.Button(value="Reset")
+             #         with gr.Column():
+             #             processed_frames = ImageSlider(
+             #                 label="Realtime Visualization",
+             #                 type="filepath",
+             #                 show_download_button=True,
+             #                 show_share_button=True,
+             #                 interactive=False,
+             #                 elem_classes="slider",
+             #                 # position=0.25,
+             #             )
+             #             video_output_files = gr.Files(
+             #                 label="outputs",
+             #                 elem_id="download",
+             #                 interactive=False,
+             #             )
+             #     Examples(
+             #         fn=process_pipe_video,
+             #         examples=sorted([
+             #             os.path.join("files", "video", name)
+             #             for name in os.listdir(os.path.join("files", "video"))
+             #         ]),
+             #         inputs=[video_input],
+             #         outputs=[processed_frames, video_output_files],
+             #         directory_name="examples_video",
+             #         cache_examples=False,
+             #     )
+
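+         # Event wiring for the Image tab: validate that an input was provided,
+         # then run the GPU-wrapped pipeline call on success.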
+         ### Image tab
+         image_submit_btn.click(
+             fn=process_image_check,
+             inputs=image_input,
+             outputs=None,
+             preprocess=False,
+             queue=False,
+         ).success(
+             fn=process_pipe_image,
+             inputs=[
+                 image_input,
+             ],
+             outputs=[image_output_slider],
+             concurrency_limit=1,
+         )
+
+         image_reset_btn.click(
+             fn=lambda: (
+                 None,
+                 None,
+             ),
+             inputs=[],
+             outputs=[
+                 image_input,
+                 image_output_slider,
+             ],
+             queue=False,
+         )
+
+         ### Video tab
+
+         # video_submit_btn.click(
+         #     fn=process_pipe_video,
+         #     inputs=[video_input],
+         #     outputs=[processed_frames, video_output_files],
+         #     concurrency_limit=1,
+         # )
+
+         # video_reset_btn.click(
+         #     fn=lambda: (None, None, None),
+         #     inputs=[],
+         #     outputs=[video_input, processed_frames, video_output_files],
+         #     concurrency_limit=1,
+         # )
+
+         ### Server launch
+
+         demo.queue(
+             api_open=False,
+         ).launch(
+             server_name="0.0.0.0",
+             server_port=7860,
+         )
+
+
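+ # Load the DAI components (ControlNet-VAE, UNet, custom decoder VAE) together
+ # with the base AutoencoderKL, CLIP text encoder and tokenizer from the
+ # "JichenHu/dereflection-any-image-v0" Hub repository, assemble the DAIPipeline,
+ # and start the demo server.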
+ def main():
+     os.system("pip freeze")
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     weight_dtype = torch.float32
+     model_dir = "./weights"
+     pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0"
+     revision = None
+     variant = None
+     # Load the model
+     # normal
+     controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path + "/controlnet", torch_dtype=weight_dtype).to(device)
+     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path + "/unet", torch_dtype=weight_dtype).to(device)
+     vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path + "/vae_2", torch_dtype=weight_dtype).to(device)
+
+     # Load other components of the pipeline
+     vae = AutoencoderKL.from_pretrained(
+         pretrained_model_name_or_path, subfolder="vae", revision=revision, variant=variant
+     ).to(device)
+
+     text_encoder = CLIPTextModel.from_pretrained(
+         pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, variant=variant
+     ).to(device)
+     tokenizer = AutoTokenizer.from_pretrained(
+         pretrained_model_name_or_path,
+         subfolder="tokenizer",
+         revision=revision,
+         use_fast=False,
+     )
+     pipe = DAIPipeline(
+         vae=vae,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         unet=unet,
+         controlnet=controlnet,
+         safety_checker=None,
+         scheduler=None,
+         feature_extractor=None,
+         t_start=0,
+     ).to(device)
+
+     try:
+         import xformers
+         pipe.enable_xformers_memory_efficient_attention()
+     except Exception:
+         pass  # run without xformers
+
+     run_demo_server(pipe, vae_2)
+
+
+ if __name__ == "__main__":
+     main()