bokyeong1015 committed on
Commit
5c762ce
·
0 Parent(s):

first commit

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__/
2
+ checkpoints/BK-SDM-Small_iter50000
3
+ checkpoints/BK-SDM-Small_iter50000.tar.gz
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pathlib import Path
3
+ from demo import SdmCompressionDemo
4
+
5
if __name__ == "__main__":
    # Entry point of the Gradio demo: builds the side-by-side comparison UI
    # (original vs. compressed Stable Diffusion) and launches the server.
    import os  # local import: only the launch step at the bottom needs it

    servicer = SdmCompressionDemo()
    example_list = servicer.get_example_list()

    with gr.Blocks(theme='nota-ai/theme') as demo:
        gr.Markdown(Path('docs/header.md').read_text(encoding='utf-8'))
        gr.Markdown(Path('docs/description.md').read_text(encoding='utf-8'))
        with gr.Row():
            # Left column: prompt, generate buttons, advanced settings, examples.
            with gr.Column(variant='panel', scale=30):

                text = gr.Textbox(label="Input Prompt", max_lines=5, placeholder="Enter your prompt")

                with gr.Row().style(equal_height=True):
                    generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
                    generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")

                with gr.Accordion("Advanced Settings", open=False):
                    negative = gr.Textbox(label='Negative Prompt',
                                          placeholder='Enter aspects to remove (e.g., low quality)')
                    with gr.Row():
                        guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, minimum=4, maximum=11, step=0.5)
                        steps = gr.Slider(label="Denoising Steps", value=25, minimum=10, maximum=75, step=5)
                    seed = gr.Slider(0, 999999, label='Random Seed', value=1234, step=1)

                with gr.Tab("Example Prompts"):
                    examples = gr.Examples(examples=example_list, inputs=[text])

            with gr.Column(variant='panel', scale=35):
                # Define original model output components
                gr.Markdown('<h2 align="center">Original Stable Diffusion 1.4</h2>')
                original_model_output = gr.Image(label="Original Model")
                with gr.Row().style(equal_height=True):
                    original_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
                    original_model_error = gr.Markdown()

            with gr.Column(variant='panel', scale=35):
                # Define compressed model output components
                gr.Markdown('<h2 align="center">Compressed Stable Diffusion (Ours)</h2>')
                compressed_model_output = gr.Image(label="Compressed Model")
                with gr.Row().style(equal_height=True):
                    compressed_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
                    compressed_model_error = gr.Markdown()

        inputs = [text, negative, guidance_scale, steps, seed]

        # Click the generate button for original model
        # (submitting the prompt textbox triggers the same handler).
        original_model_outputs = [original_model_output, original_model_error, original_model_test_time]
        text.submit(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)
        generate_original_button.click(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)

        # Click the generate button for compressed model
        # (submitting the prompt textbox triggers the same handler).
        compressed_model_outputs = [compressed_model_output, compressed_model_error, compressed_model_test_time]
        text.submit(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)
        generate_compressed_button.click(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)

        gr.Markdown(Path('docs/footer.md').read_text(encoding='utf-8'))

    demo.queue(concurrency_count=1)

    # Credentials must never live in source control. Read them from the
    # environment instead; a public share link is only created when the demo
    # is protected by a login, otherwise fall back to a plain local launch.
    auth_username = os.environ.get("DEMO_AUTH_USERNAME")
    auth_password = os.environ.get("DEMO_AUTH_PASSWORD")
    if auth_username and auth_password:
        demo.launch(share=True, auth=(auth_username, auth_password))
    else:
        demo.launch()
64
+
checkpoints/.gitkeep ADDED
File without changes
demo.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import StableDiffusionPipeline, UNet2DConditionModel
2
+ import torch
3
+ import copy
4
+
5
+ import time
6
+
7
# Hugging Face Hub id of the uncompressed baseline pipeline.
ORIGINAL_CHECKPOINT_ID = "CompVis/stable-diffusion-v1-4"
# Local directory holding the compressed U-Net (loaded from its "unet" subfolder).
COMPRESSED_UNET_PATH = "checkpoints/BK-SDM-Small_iter50000"

# Use the GPU when one is available; otherwise fall back to the CPU so the
# demo still starts on machines without CUDA instead of crashing at load time.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+
13
class SdmCompressionDemo:
    """Backend of the demo: holds the original and the compressed pipelines.

    The compressed pipeline is a copy of the original one whose U-Net is
    replaced by the checkpoint found at ``COMPRESSED_UNET_PATH``. The
    ``infer_*`` methods are the callbacks wired to the Gradio UI; each returns
    ``(image, message, inference_time)``.
    """

    def __init__(self) -> None:
        """Load both pipelines and move them to the GPU when ``DEVICE`` is CUDA."""
        self.device = DEVICE
        # Half precision on GPU, full precision on CPU.
        self.torch_dtype = torch.float16 if 'cuda' in self.device else torch.float32

        self.pipe_original = StableDiffusionPipeline.from_pretrained(ORIGINAL_CHECKPOINT_ID,
                                                                     torch_dtype=self.torch_dtype)
        # Copy every component of the original pipeline, then swap in the
        # compressed U-Net.
        self.pipe_compressed = copy.deepcopy(self.pipe_original)
        self.pipe_compressed.unet = UNet2DConditionModel.from_pretrained(COMPRESSED_UNET_PATH,
                                                                         subfolder="unet",
                                                                         torch_dtype=self.torch_dtype)
        if 'cuda' in self.device:
            self.pipe_original = self.pipe_original.to(self.device)
            self.pipe_compressed = self.pipe_compressed.to(self.device)
        self.device_msg = 'Tested on GPU.' if 'cuda' in self.device else 'Tested on CPU.'

    def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
        """Run ``pipe`` once and time it.

        Args:
            pipe: Callable pipeline to run (original or compressed).
            text: Input prompt.
            negative: Negative prompt, passed as ``negative_prompt``.
            guidance_scale: Guidance scale forwarded to the pipeline.
            steps: Number of denoising steps; cast to ``int``.
            seed: Random seed; cast to ``int``.

        Returns:
            Tuple ``(image, nsfw_detected, elapsed)`` where ``image`` is the
            first generated image, ``nsfw_detected`` is the flag reported for
            it, and ``elapsed`` is the wall-clock time in seconds formatted
            with two decimals (a string).
        """
        # UI sliders may deliver numbers as floats; the seed and the step
        # count must be integers.
        generator = torch.Generator(self.device).manual_seed(int(seed))
        start = time.time()
        result = pipe(text, negative_prompt=negative, generator=generator,
                      guidance_scale=guidance_scale, num_inference_steps=int(steps))
        test_time = time.time() - start

        image = result.images[0]
        # The flag list can be missing/None (e.g. when no safety checker is
        # attached to the pipeline); treat that as "nothing detected".
        nsfw_flags = result.nsfw_content_detected
        nsfw_detected = nsfw_flags[0] if nsfw_flags else False
        print(f"text {text} | Processed time: {test_time} sec | nsfw_flag {nsfw_detected}")
        print(f"negative {negative} | guidance_scale {guidance_scale} | steps {steps} ")
        print("===========")

        return image, nsfw_detected, format(test_time, ".2f")

    def error_msg(self, nsfw_detected):
        """Return the status line shown under an output image.

        It is the device message, followed by an explanation of black images
        when ``nsfw_detected`` is truthy.
        """
        if nsfw_detected:
            return self.device_msg+" Black images are returned when potential harmful content is detected. Try different prompts or seeds."
        return self.device_msg

    def check_invalid_input(self, text):
        """Return ``True`` when the prompt is missing, empty or only whitespace.

        Always returns a ``bool`` (``False`` for a usable prompt).
        """
        return not text or not text.strip()

    def _run_inference(self, pipe_attr, tag, text, negative, guidance_scale, steps, seed):
        """Shared body of the two ``infer_*`` callbacks.

        ``pipe_attr`` is the name of the pipeline attribute to use; it is
        resolved only after the prompt has been validated. ``tag`` is the
        model label printed in the log line.
        """
        print(f"=== {tag} model --- seed {seed}")
        if self.check_invalid_input(text):
            return None, "Please enter the input prompt.", None
        output_image, nsfw_detected, test_time = self.generate_image(
            getattr(self, pipe_attr), text, negative, guidance_scale, steps, seed)

        return output_image, self.error_msg(nsfw_detected), test_time

    def infer_original_model(self, text, negative, guidance_scale, steps, seed):
        """Generate with the original pipeline; returns ``(image, message, time)``."""
        return self._run_inference('pipe_original', 'ORIG',
                                   text, negative, guidance_scale, steps, seed)

    def infer_compressed_model(self, text, negative, guidance_scale, steps, seed):
        """Generate with the compressed pipeline; returns ``(image, message, time)``."""
        return self._run_inference('pipe_compressed', 'COMPRESSED',
                                   text, negative, guidance_scale, steps, seed)

    def get_example_list(self):
        """Return the example prompts offered in the UI."""
        return [
            'a tropical bird sitting on a branch of a tree',
            'many decorative umbrellas hanging up',
            'an orange cat staring off with pretty eyes',
            'beautiful woman face with fancy makeup',
            'a decorated living room with a stylish feel',
            'a black vase holding a bouquet of roses',
            'very elegant bedroom featuring natural wood',
            'buffet-style food including cake and cheese',
            'a tall castle sitting under a cloudy sky',
            'closeup of a brown bear sitting in a grassy area',
            'a large basket with many fresh vegetables',
            'house being built with lots of wood',
            'a close up of a pizza with several toppings',
            'a golden vase with many different flows',
            'a statue of a lion face attached to brick wall',
            'something that looks particularly interesting',
            'table filled with a variety of different dishes',
            'a cinematic view of a large snowy peak',
            'a grand city in the year 2100, hyper realistic',
            'a blue eyed baby girl looking at the camera',
        ]
96
+
97
+
98
+
99
+
100
+
docs/description.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ This demo showcases a compressed Stable Diffusion model (SDM) for general-purpose text-to-image synthesis. Our lightest model (**BK-SDM-Small**) achieves **36% reduced** parameters and latency. This model is built by (i) removing several residual and attention blocks from the U-Net of SDM and (ii) distillation pretraining on only 0.22M LAION pairs (fewer than 0.1% of the full training set). Despite very limited training resources, our model can imitate the original SDM by benefiting from transferred knowledge.
2
+
3
+ <!-- <center>
4
+ <img src="docs/fig_model.png" width="70%">
5
+ </center> -->
6
+
docs/fig_model.png ADDED
docs/footer.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ <p align="center">
2
+ <a href="https://netspresso.ai/"><img src="https://huggingface.co/spaces/nota-ai/theme/resolve/main/docs/logo/nota_favicon_800x800.png" width="96px" height="96px"></a>
3
+ </p>
4
+
5
+ <br/>
docs/header.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # <center>Lightweight Text-to-Image Generation Demo</center>
2
+
3
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==1.13.1
2
+ gradio==3.31.0
3
+ diffusers==0.15.0.dev0