JichenHu committed on
Commit
b1afd66
·
verified ·
1 Parent(s): d476f89

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -12
app.py CHANGED
@@ -1,18 +1,90 @@
1
  import gradio as gr
2
  from PIL import Image
 
 
 
 
3
 
4
- def process_image(image):
5
- # Simply return the input image without any processing
6
- return image
7
-
8
- # Create a Gradio interface
9
- interface = gr.Interface(
10
- fn=process_image,
11
- inputs=gr.Image(type="pil"),
12
- outputs=gr.Image(type="pil"),
13
- title="Simple Image Echo App",
14
- description="Upload an image and get the same image as output."
15
  )
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  if __name__ == "__main__":
18
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from PIL import Image
3
+ from DAI.pipeline_all import DAIPipeline
4
+ import os
5
+ import tempfile
6
+ import numpy as np
7
 
8
+ from diffusers import (
9
+ AutoencoderKL,
10
+ UNet2DConditionModel,
 
 
 
 
 
 
 
 
11
  )
12
 
13
+ from transformers import CLIPTextModel, AutoTokenizer
14
+
15
+ from DAI.controlnetvae import ControlNetVAEModel
16
+
17
+ from DAI.decoder import CustomAutoencoderKL
18
+
def process_image(pipe, vae_2, image):
    """Run the dereflection pipeline on one image and return the result.

    Args:
        pipe: Callable pipeline. It is invoked with the keyword arguments
            ``image``, ``prompt``, ``vae_2`` and ``processing_resolution``
            and must return an object exposing a ``prediction`` attribute.
        vae_2: Secondary decoder forwarded unchanged to ``pipe``.
        image: Input image; must provide a ``save(path)`` method (the
            Gradio input in this file is configured with ``type="pil"``).

    Returns:
        A ``PIL.Image.Image`` built from the first element of the
        pipeline prediction, rescaled to 8-bit.
    """
    # All scratch files live in one directory that is removed when the block
    # exits. The previous implementation used the deprecated, race-prone
    # tempfile.mktemp() and never deleted either the input file or the
    # mkdtemp() output directory, so a long-running server leaked files on
    # every request.
    with tempfile.TemporaryDirectory() as work_dir:
        # mkstemp() creates the file atomically; only the path is needed,
        # so the descriptor is closed right away.
        fd, temp_input_path = tempfile.mkstemp(suffix=".png", dir=work_dir)
        os.close(fd)
        image.save(temp_input_path)

        name_base, name_ext = os.path.splitext(os.path.basename(temp_input_path))
        print(f"Processing image {name_base}{name_ext}")

        path_out_png = os.path.join(work_dir, f"{name_base}_delight.png")
        resolution = None

        pipe_out = pipe(
            image=image,
            prompt="remove glass reflection",
            vae_2=vae_2,
            processing_resolution=resolution,
        )

        # Clamp to [-1, 1] and shift/scale into [0, 1].
        # NOTE(review): prediction is presumably a batched NumPy array with
        # a channel-last layout -- confirm against DAIPipeline.
        processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
        # Keep the first element only; astype truncates toward zero.
        processed_frame = (processed_frame[0] * 255).astype(np.uint8)
        processed_frame = Image.fromarray(processed_frame)
        # Still written as a PNG, as before, but nothing persists on disk
        # once the temporary directory is cleaned up.
        processed_frame.save(path_out_png)

    # The returned image is held in memory, so it remains valid after the
    # scratch directory has been deleted.
    return processed_frame
44
+
if __name__ == "__main__":
    # Hub repositories: the first holds the task-specific weights
    # (controlnet / unet / vae_2 subfolders), the second supplies the VAE,
    # text encoder and tokenizer.
    dereflection_repo = "JichenHu/dereflection-any-image-v0"
    base_repo = "stabilityai/stable-diffusion-2-1"
    revision = None
    variant = None

    # Task-specific components (loaded in the same order as before).
    controlnet = ControlNetVAEModel.from_pretrained(
        dereflection_repo, subfolder="controlnet"
    )
    unet = UNet2DConditionModel.from_pretrained(
        dereflection_repo, subfolder="unet"
    )
    vae_2 = CustomAutoencoderKL.from_pretrained(
        dereflection_repo, subfolder="vae_2"
    )

    # Components taken from the base repository.
    vae = AutoencoderKL.from_pretrained(
        base_repo,
        subfolder="vae",
        revision=revision,
        variant=variant,
    )
    text_encoder = CLIPTextModel.from_pretrained(
        base_repo,
        subfolder="text_encoder",
        revision=revision,
        variant=variant,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        base_repo,
        subfolder="tokenizer",
        revision=revision,
        use_fast=False,
    )

    # Assemble the pipeline. vae_2 is not a constructor argument; it is
    # handed to the pipeline on every call from process_image instead.
    pipe = DAIPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        scheduler=None,
        feature_extractor=None,
        t_start=0,
    )

    # Gradio front end: one PIL image in, one PIL image out. The lambda
    # binds the loaded pipeline and decoder so Gradio only supplies the image.
    image_input = gr.Image(type="pil")
    image_output = gr.Image(type="pil")
    interface = gr.Interface(
        fn=lambda image: process_image(pipe, vae_2, image),
        inputs=image_input,
        outputs=image_output,
        title="Image Dereflection App",
        description="Upload an image to remove glass reflections.",
    )

    interface.launch()