JichenHu commited on
Commit
fb89072
·
verified ·
1 Parent(s): 4589910

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -79
app.py CHANGED
@@ -1,100 +1,110 @@
1
- import gradio as gr
2
- from PIL import Image
3
- from DAI.pipeline_all import DAIPipeline
4
  import os
5
- import tempfile
6
  import numpy as np
 
 
 
 
7
 
8
- from diffusers import (
9
- AutoencoderKL,
10
- UNet2DConditionModel,
11
- )
12
-
13
- from transformers import CLIPTextModel, AutoTokenizer
14
 
 
 
15
  from DAI.controlnetvae import ControlNetVAEModel
16
-
17
  from DAI.decoder import CustomAutoencoderKL
 
 
18
 
19
- def process_image(pipe, vae_2, image):
20
- # Save the input image to a temporary file
21
- temp_input_path = tempfile.mktemp(suffix=".png")
22
- image.save(temp_input_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- name_base, name_ext = os.path.splitext(os.path.basename(temp_input_path))
25
- print(f"Processing image {name_base}{name_ext}")
26
 
27
- path_output_dir = tempfile.mkdtemp()
28
- path_out_png = os.path.join(path_output_dir, f"{name_base}_delight.png")
29
- resolution = None
30
 
 
31
  pipe_out = pipe(
32
- image=image,
33
  prompt="remove glass reflection",
34
  vae_2=vae_2,
35
- processing_resolution=resolution,
36
  )
37
 
 
38
  processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
39
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
40
  processed_frame = Image.fromarray(processed_frame)
41
- processed_frame.save(path_out_png)
42
 
43
- return processed_frame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  if __name__ == "__main__":
46
- pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0"
47
- pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
48
- revision = None
49
- variant = None
50
-
51
- # Load the model
52
- controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet")
53
- unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
54
- vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2")
55
-
56
- vae = AutoencoderKL.from_pretrained(
57
- pretrained_model_name_or_path2, subfolder="vae", revision=revision, variant=variant
58
- )
59
-
60
- text_encoder = CLIPTextModel.from_pretrained(
61
- pretrained_model_name_or_path2, subfolder="text_encoder", revision=revision, variant=variant
62
- )
63
- tokenizer = AutoTokenizer.from_pretrained(
64
- pretrained_model_name_or_path2,
65
- subfolder="tokenizer",
66
- revision=revision,
67
- use_fast=False,
68
- )
69
- pipe = DAIPipeline(
70
- vae=vae,
71
- text_encoder=text_encoder,
72
- tokenizer=tokenizer,
73
- unet=unet,
74
- controlnet=controlnet,
75
- safety_checker=None,
76
- scheduler=None,
77
- feature_extractor=None,
78
- t_start=0,
79
- )
80
-
81
- # Cache example images in memory
82
- example_images_dir = "files/image"
83
- example_images = []
84
- for i in range(1, 9):
85
- image_path = os.path.join(example_images_dir, f"{i}.png")
86
- if os.path.exists(image_path):
87
- example_images.append([Image.open(image_path)])
88
-
89
- # Create a Gradio interface
90
- interface = gr.Interface(
91
- fn=lambda image: process_image(pipe, vae_2, image),
92
- inputs=gr.Image(type="pil"),
93
- outputs=gr.Image(type="pil"),
94
- title="Dereflection Any Image",
95
- description="Upload an image to remove glass reflections.",
96
- examples=example_images,
97
- )
98
-
99
- interface.launch()
100
-
 
1
+ import spaces # 必须放在最前面
 
 
2
  import os
 
3
  import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ import gradio as gr
7
+ from gradio_imageslider import ImageSlider
8
 
9
+ # 延迟 CUDA 初始化
10
+ weight_dtype = torch.float32
 
 
 
 
11
 
12
+ # 加载模型组件
13
+ from DAI.pipeline_all import DAIPipeline
14
  from DAI.controlnetvae import ControlNetVAEModel
 
15
  from DAI.decoder import CustomAutoencoderKL
16
+ from diffusers import AutoencoderKL, UNet2DConditionModel
17
+ from transformers import CLIPTextModel, AutoTokenizer
18
 
19
+ pretrained_model_name_or_path = "sjtu-deepvision/dereflection-any-image-v0"
20
+ pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
21
+
22
+ # 加载模型
23
+ controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype)
24
+ unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype)
25
+ vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype)
26
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path2, subfolder="vae")
27
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path2, subfolder="text_encoder")
28
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path2, subfolder="tokenizer", use_fast=False)
29
+
30
+ # 创建推理管道
31
+ pipe = DAIPipeline(
32
+ vae=vae,
33
+ text_encoder=text_encoder,
34
+ tokenizer=tokenizer,
35
+ unet=unet,
36
+ controlnet=controlnet,
37
+ safety_checker=None,
38
+ scheduler=None,
39
+ feature_extractor=None,
40
+ t_start=0,
41
+ )
42
 
 
 
43
 
44
+ def process_image(input_image):
45
+ # Gradio 输入转换为 PIL 图像
46
+ input_image = Image.fromarray(input_image)
47
 
48
+ # 处理图像
49
  pipe_out = pipe(
50
+ image=input_image,
51
  prompt="remove glass reflection",
52
  vae_2=vae_2,
53
+ processing_resolution=None,
54
  )
55
 
56
+ # 将输出转换为图像
57
  processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
58
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
59
  processed_frame = Image.fromarray(processed_frame)
 
60
 
61
+ # 返回输入图像和处理后的图像
62
+ return input_image, processed_frame
63
+
64
+ # 创建 Gradio 界面
65
+ def create_gradio_interface():
66
+ # 示例图像
67
+ example_images = [
68
+ os.path.join("files", "image", f"{i}.png") for i in range(1, 9)
69
+ ]
70
+
71
+ with gr.Blocks() as demo:
72
+ gr.Markdown("# Dereflection Any Image")
73
+ with gr.Row():
74
+ with gr.Column():
75
+ input_image = gr.Image(label="Input Image", type="numpy")
76
+ submit_btn = gr.Button("Remove Reflection", variant="primary")
77
+ with gr.Column():
78
+ # 使用 ImageSlider 显示前后对比
79
+ output_slider = ImageSlider(
80
+ label="Before & After",
81
+ show_download_button=True,
82
+ show_share_button=True,
83
+ )
84
+
85
+ # 添加示例
86
+ gr.Examples(
87
+ examples=example_images,
88
+ inputs=input_image,
89
+ outputs=output_slider,
90
+ fn=process_image,
91
+ cache_examples=False, # 缓存结果以加快加载速度
92
+ label="Example Images",
93
+ )
94
+
95
+ # 绑定按钮点击事件
96
+ submit_btn.click(
97
+ fn=process_image,
98
+ inputs=input_image,
99
+ outputs=output_slider,
100
+ )
101
+
102
+ return demo
103
+
104
+ # 主函数
105
+ def main():
106
+ demo = create_gradio_interface()
107
+ demo.launch(server_name="0.0.0.0", server_port=7860)
108
 
109
  if __name__ == "__main__":
110
+ main()