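"""Gradio demo for Dereflection Any Image (DAI).

Loads the DAI ControlNet / UNet / custom-VAE stack on top of Stable
Diffusion 2.1 components and exposes a single-image "remove glass
reflection" interface, intended to run as a Hugging Face Space.
"""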
import spaces  # must be imported first, before any CUDA-touching modules
import os
import numpy as np
import torch
from PIL import Image
import gradio as gr

# Defer CUDA initialization: keep weights in float32 on the CPU
weight_dtype = torch.float32

# Import the model components
from DAI.pipeline_all import DAIPipeline
from DAI.controlnetvae import ControlNetVAEModel
from DAI.decoder import CustomAutoencoderKL
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, AutoTokenizer

pretrained_model_name_or_path = "sjtu-deepvision/dereflection-any-image-v0"
pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"

# Load weights: DAI-specific modules from the DAI checkpoint, base modules from SD 2.1
controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype)
unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype)
vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype)
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path2, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path2, subfolder="text_encoder")
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path2, subfolder="tokenizer", use_fast=False)

# Assemble the inference pipeline
pipe = DAIPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    controlnet=controlnet,
    safety_checker=None,
    scheduler=None,
    feature_extractor=None,
    t_start=0,
)
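
# Note: everything stays on CPU in float32 as written. A minimal sketch for
# moving the models to GPU when one is available (assuming DAIPipeline
# inherits DiffusionPipeline.to); on a ZeroGPU Space this would instead run
# inside a function decorated with @spaces.GPU:
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     pipe.to(device)
#     vae_2.to(device)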


def process_image(input_image):
    # Convert the Gradio numpy input to a PIL image
    input_image = Image.fromarray(input_image)

    # Run the dereflection pipeline
    pipe_out = pipe(
        image=input_image,
        prompt="remove glass reflection",
        vae_2=vae_2,
        processing_resolution=None,
    )

    # Map the prediction from [-1, 1] to an 8-bit image
    processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
    processed_frame = (processed_frame[0] * 255).astype(np.uint8)
    processed_frame = Image.fromarray(processed_frame)

    return processed_frame
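
# Standalone usage (sketch, no UI) -- assumes files/image/1.png exists:
#
#     frame = np.array(Image.open("files/image/1.png").convert("RGB"))
#     process_image(frame).save("output.png")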

# Build the Gradio interface
def create_gradio_interface():
    # Example images (files/image/1.png ... 8.png)
    example_images = [
        os.path.join("files", "image", f"{i}.png") for i in range(1, 9)
    ]

    with gr.Blocks() as demo:
        gr.Markdown("# Dereflection Any Image")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="numpy")
                submit_btn = gr.Button("Remove Reflection", variant="primary")
            with gr.Column():
                output_image = gr.Image(label="Processed Image")

        # Add clickable examples
        gr.Examples(
            examples=example_images,
            inputs=input_image,
            outputs=output_image,
            fn=process_image,
            cache_examples=False,  # do not precompute example outputs at startup
            label="Example Images",
        )

        # Wire the button click to the processing function
        submit_btn.click(
            fn=process_image,
            inputs=input_image,
            outputs=output_image,
        )

    return demo

# Entry point
def main():
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860)

if __name__ == "__main__":
    main()