Commit 5c762ce0 · first commit
Parent(s): (none)
Files changed:
- .gitignore +3 -0
- app.py +64 -0
- checkpoints/.gitkeep +0 -0
- demo.py +100 -0
- docs/description.md +6 -0
- docs/fig_model.png +0 -0
- docs/footer.md +5 -0
- docs/header.md +3 -0
- requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+__pycache__/
+checkpoints/BK-SDM-Small_iter50000
+checkpoints/BK-SDM-Small_iter50000.tar.gz
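The two ignored checkpoint entries imply the compressed U-Net archive is fetched and unpacked locally rather than committed. A hypothetical fetch helper, as a sketch only; `CHECKPOINT_URL` is a placeholder, since the real download location is not part of this commit:

```python
# Hypothetical helper to fetch and unpack the compressed U-Net checkpoint
# that .gitignore keeps out of version control. CHECKPOINT_URL is a
# placeholder; the actual download location is not part of this commit.
import tarfile
import urllib.request
from pathlib import Path

CHECKPOINT_URL = "https://example.com/BK-SDM-Small_iter50000.tar.gz"  # placeholder
archive = Path("checkpoints/BK-SDM-Small_iter50000.tar.gz")

if not archive.exists():
    urllib.request.urlretrieve(CHECKPOINT_URL, str(archive))
with tarfile.open(archive) as tar:
    tar.extractall(path="checkpoints/")  # yields checkpoints/BK-SDM-Small_iter50000/
```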
app.py
ADDED
@@ -0,0 +1,64 @@
+import gradio as gr
+from pathlib import Path
+from demo import SdmCompressionDemo
+
+if __name__ == "__main__":
+    servicer = SdmCompressionDemo()
+    example_list = servicer.get_example_list()
+
+    with gr.Blocks(theme='nota-ai/theme') as demo:
+        gr.Markdown(Path('docs/header.md').read_text())
+        gr.Markdown(Path('docs/description.md').read_text())
+        with gr.Row():
+            with gr.Column(variant='panel', scale=30):
+
+                text = gr.Textbox(label="Input Prompt", max_lines=5, placeholder="Enter your prompt")
+
+                with gr.Row().style(equal_height=True):
+                    generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
+                    generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")
+
+                with gr.Accordion("Advanced Settings", open=False):
+                    negative = gr.Textbox(label='Negative Prompt', placeholder='Enter aspects to remove (e.g., "low quality")')
+                    with gr.Row():
+                        guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, minimum=4, maximum=11, step=0.5)
+                        steps = gr.Slider(label="Denoising Steps", value=25, minimum=10, maximum=75, step=5)
+                    seed = gr.Slider(0, 999999, label='Random Seed', value=1234, step=1)
+
+                with gr.Tab("Example Prompts"):
+                    examples = gr.Examples(examples=example_list, inputs=[text])
+
+            with gr.Column(variant='panel', scale=35):
+                # Output components for the original model
+                gr.Markdown('<h2 align="center">Original Stable Diffusion 1.4</h2>')
+                original_model_output = gr.Image(label="Original Model")
+                with gr.Row().style(equal_height=True):
+                    original_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
+                    original_model_error = gr.Markdown()
+
+            with gr.Column(variant='panel', scale=35):
+                # Output components for the compressed model
+                gr.Markdown('<h2 align="center">Compressed Stable Diffusion (Ours)</h2>')
+                compressed_model_output = gr.Image(label="Compressed Model")
+                with gr.Row().style(equal_height=True):
+                    compressed_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
+                    compressed_model_error = gr.Markdown()
+
+        inputs = [text, negative, guidance_scale, steps, seed]
+
+        # Wire the original model (text.submit is bound for both models, so pressing Enter runs both)
+        original_model_outputs = [original_model_output, original_model_error, original_model_test_time]
+        text.submit(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)
+        generate_original_button.click(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)
+
+        # Wire the compressed model
+        compressed_model_outputs = [compressed_model_output, compressed_model_error, compressed_model_test_time]
+        text.submit(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)
+        generate_compressed_button.click(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)
+
+        gr.Markdown(Path('docs/footer.md').read_text())
+
+    demo.queue(concurrency_count=1)
+    # demo.launch()
+    demo.launch(share=True, auth=("test", "testasdf@@19"))
+
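Note that app.py launches with hard-coded Basic-Auth credentials. A sketch of an alternative ending for the script that reads them from the environment instead; `DEMO_USER` and `DEMO_PASS` are hypothetical variable names, not part of this commit:

```python
# Sketch: read demo credentials from the environment instead of hard-coding
# them in app.py. DEMO_USER / DEMO_PASS are hypothetical variable names.
import os

user = os.environ.get("DEMO_USER")
password = os.environ.get("DEMO_PASS")
auth = (user, password) if user and password else None  # None disables auth

demo.queue(concurrency_count=1)
demo.launch(share=True, auth=auth)
```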
checkpoints/.gitkeep
ADDED
(empty file; keeps the checkpoints/ directory in version control)
demo.py
ADDED
@@ -0,0 +1,100 @@
+from diffusers import StableDiffusionPipeline, UNet2DConditionModel
+import torch
+import copy
+
+import time
+
+ORIGINAL_CHECKPOINT_ID = "CompVis/stable-diffusion-v1-4"
+COMPRESSED_UNET_PATH = "checkpoints/BK-SDM-Small_iter50000"
+
+DEVICE = 'cuda'
+# DEVICE = 'cpu'
+
+class SdmCompressionDemo:
+    def __init__(self) -> None:
+        self.device = DEVICE
+        self.torch_dtype = torch.float16 if 'cuda' in self.device else torch.float32
+
+        # Load the original SD-v1.4 pipeline, then clone it and swap in the
+        # compressed (block-removed) U-Net.
+        self.pipe_original = StableDiffusionPipeline.from_pretrained(ORIGINAL_CHECKPOINT_ID,
+                                                                     torch_dtype=self.torch_dtype)
+        self.pipe_compressed = copy.deepcopy(self.pipe_original)
+        self.pipe_compressed.unet = UNet2DConditionModel.from_pretrained(COMPRESSED_UNET_PATH,
+                                                                         subfolder="unet",
+                                                                         torch_dtype=self.torch_dtype)
+        if 'cuda' in self.device:
+            self.pipe_original = self.pipe_original.to(self.device)
+            self.pipe_compressed = self.pipe_compressed.to(self.device)
+        self.device_msg = 'Tested on GPU.' if 'cuda' in self.device else 'Tested on CPU.'
+
+    def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
+        generator = torch.Generator(self.device).manual_seed(seed)
+        start = time.time()
+        result = pipe(text, negative_prompt=negative, generator=generator,
+                      guidance_scale=guidance_scale, num_inference_steps=steps)
+        test_time = time.time() - start
+
+        image = result.images[0]
+        nsfw_detected = result.nsfw_content_detected[0]
+        print(f"text {text} | Processed time: {test_time} sec | nsfw_flag {nsfw_detected}")
+        print(f"negative {negative} | guidance_scale {guidance_scale} | steps {steps}")
+        print("===========")
+
+        return image, nsfw_detected, format(test_time, ".2f")
+
+    def error_msg(self, nsfw_detected):
+        if nsfw_detected:
+            return self.device_msg + " Black images are returned when potentially harmful content is detected. Try different prompts or seeds."
+        else:
+            return self.device_msg
+
+    def check_invalid_input(self, text):
+        # An empty prompt is invalid.
+        return text == ''
+
+    def infer_original_model(self, text, negative, guidance_scale, steps, seed):
+        print(f"=== ORIG model --- seed {seed}")
+        if self.check_invalid_input(text):
+            return None, "Please enter the input prompt.", None
+        output_image, nsfw_detected, test_time = self.generate_image(self.pipe_original,
+                                                                     text, negative, guidance_scale, steps, seed)
+
+        return output_image, self.error_msg(nsfw_detected), test_time
+
+    def infer_compressed_model(self, text, negative, guidance_scale, steps, seed):
+        print(f"=== COMPRESSED model --- seed {seed}")
+        if self.check_invalid_input(text):
+            return None, "Please enter the input prompt.", None
+        output_image, nsfw_detected, test_time = self.generate_image(self.pipe_compressed,
+                                                                     text, negative, guidance_scale, steps, seed)
+
+        return output_image, self.error_msg(nsfw_detected), test_time
+
+    def get_example_list(self):
+        return [
+            'a tropical bird sitting on a branch of a tree',
+            'many decorative umbrellas hanging up',
+            'an orange cat staring off with pretty eyes',
+            'beautiful woman face with fancy makeup',
+            'a decorated living room with a stylish feel',
+            'a black vase holding a bouquet of roses',
+            'very elegant bedroom featuring natural wood',
+            'buffet-style food including cake and cheese',
+            'a tall castle sitting under a cloudy sky',
+            'closeup of a brown bear sitting in a grassy area',
+            'a large basket with many fresh vegetables',
+            'house being built with lots of wood',
+            'a close up of a pizza with several toppings',
+            'a golden vase with many different flowers',
+            'a statue of a lion face attached to brick wall',
+            'something that looks particularly interesting',
+            'table filled with a variety of different dishes',
+            'a cinematic view of a large snowy peak',
+            'a grand city in the year 2100, hyper realistic',
+            'a blue eyed baby girl looking at the camera',
+        ]
+
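For reference, a minimal sketch of driving SdmCompressionDemo from a script instead of through the Gradio UI, assuming the compressed checkpoint is in place under checkpoints/ and a CUDA GPU is available:

```python
from demo import SdmCompressionDemo

servicer = SdmCompressionDemo()
image, msg, secs = servicer.infer_compressed_model(
    text="a tropical bird sitting on a branch of a tree",
    negative="",
    guidance_scale=7.5,
    steps=25,
    seed=1234,
)
if image is not None:
    image.save("sample.png")  # the pipeline returns a PIL image
print(msg, secs)
```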
docs/description.md
ADDED
@@ -0,0 +1,6 @@
+This demo showcases a compressed Stable Diffusion model (SDM) for general-purpose text-to-image synthesis. Our lightest model (**BK-SDM-Small**) achieves a **36% reduction** in parameters and latency. It is built by (i) removing several residual and attention blocks from the U-Net of SDM and (ii) distillation pretraining on only 0.22M LAION pairs (fewer than 0.1% of the full training set). Despite very limited training resources, our model can imitate the original SDM by benefiting from transferred knowledge.
+
+<!-- <center>
+<img src="docs/fig_model.png" width="70%">
+</center> -->
+
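The reduction claim can be spot-checked against the two pipelines loaded in demo.py; a sketch is below. Note that the quoted 36% figure is for the whole model, so a U-Net-only count will differ:

```python
from demo import SdmCompressionDemo

def count_params(module):
    """Total parameter count of a torch module."""
    return sum(p.numel() for p in module.parameters())

servicer = SdmCompressionDemo()
orig = count_params(servicer.pipe_original.unet)
comp = count_params(servicer.pipe_compressed.unet)
print(f"U-Net parameters: {orig:,} -> {comp:,} ({100 * (1 - comp / orig):.1f}% fewer)")
```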
docs/fig_model.png
ADDED
(binary image file: model architecture figure)
docs/footer.md
ADDED
@@ -0,0 +1,5 @@
+<p align="center">
+<a href="https://netspresso.ai/"><img src="https://huggingface.co/spaces/nota-ai/theme/resolve/main/docs/logo/nota_favicon_800x800.png" width="96px" height="96px"></a>
+</p>
+
+<br/>
docs/header.md
ADDED
@@ -0,0 +1,3 @@
+# <center>Lightweight Text-to-Image Generation Demo</center>
+
+
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+torch==1.13.1
+gradio==3.31.0
+diffusers==0.15.0.dev0
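Note that diffusers==0.15.0.dev0 is a development build, which is typically installed from the diffusers GitHub source rather than from PyPI. A quick post-install sanity check of the pins:

```python
# Verify the pinned versions after `pip install -r requirements.txt`
# (diffusers dev builds usually come from source).
import diffusers
import gradio
import torch

print(torch.__version__)      # expect 1.13.1 (possibly with a +cuXXX suffix)
print(gradio.__version__)     # expect 3.31.0
print(diffusers.__version__)  # expect 0.15.0.dev0
```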