Spaces:
Running
on
Zero
Running
on
Zero
wangshuai6
committed on
Commit
·
d7edbd1
1
Parent(s):
52d009c
app demo
Browse files
app.py
CHANGED
@@ -63,14 +63,14 @@ def load_model(weight_dict, denosier):
|
|
63 |
|
64 |
|
65 |
class Pipeline:
|
66 |
-
def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution):
|
67 |
self.vae = vae
|
68 |
self.denoiser = denoiser
|
69 |
self.conditioner = conditioner
|
70 |
self.diffusion_sampler = diffusion_sampler
|
71 |
self.resolution = resolution
|
|
|
72 |
|
73 |
-
@spaces.GPU
|
74 |
@torch.no_grad()
|
75 |
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
76 |
def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
|
@@ -83,7 +83,7 @@ class Pipeline:
|
|
83 |
generator = torch.Generator(device="cuda").manual_seed(seed)
|
84 |
xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
|
85 |
with torch.no_grad():
|
86 |
-
condition, uncondition = conditioner([y,]*num_images)
|
87 |
# Sample images:
|
88 |
samples = diffusion_sampler(denoiser, xT, condition, uncondition)
|
89 |
samples = vae.decode(samples)
|
@@ -136,7 +136,15 @@ if __name__ == "__main__":
|
|
136 |
vae = vae.cuda()
|
137 |
denoiser.eval()
|
138 |
|
139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
with gr.Blocks() as demo:
|
142 |
gr.Markdown("DDT")
|
@@ -144,12 +152,14 @@ if __name__ == "__main__":
|
|
144 |
with gr.Column(scale=1):
|
145 |
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
|
146 |
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
|
147 |
-
num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=
|
148 |
-
label = gr.
|
149 |
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
|
150 |
state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
|
151 |
-
guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min",
|
152 |
-
|
|
|
|
|
153 |
timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0)
|
154 |
with gr.Column(scale=2):
|
155 |
btn = gr.Button("Generate")
|
@@ -167,4 +177,4 @@ if __name__ == "__main__":
|
|
167 |
guidance_interval_max,
|
168 |
timeshift
|
169 |
], outputs=[output])
|
170 |
-
demo.launch(
|
|
|
63 |
|
64 |
|
65 |
class Pipeline:
|
66 |
+
def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution, classlabels2ids):
|
67 |
self.vae = vae
|
68 |
self.denoiser = denoiser
|
69 |
self.conditioner = conditioner
|
70 |
self.diffusion_sampler = diffusion_sampler
|
71 |
self.resolution = resolution
|
72 |
+
self.classlabels2ids = classlabels2ids
|
73 |
|
|
|
74 |
@torch.no_grad()
|
75 |
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
76 |
def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
|
|
|
83 |
generator = torch.Generator(device="cuda").manual_seed(seed)
|
84 |
xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
|
85 |
with torch.no_grad():
|
86 |
+
condition, uncondition = conditioner([self.classlabels2ids[y],]*num_images)
|
87 |
# Sample images:
|
88 |
samples = diffusion_sampler(denoiser, xT, condition, uncondition)
|
89 |
samples = vae.decode(samples)
|
|
|
136 |
vae = vae.cuda()
|
137 |
denoiser.eval()
|
138 |
|
139 |
+
# read imagenet classlabels
|
140 |
+
with open("imagenet_classlabels.txt", "r") as f:
|
141 |
+
classlabels = f.readlines()
|
142 |
+
classlabels = [label.strip() for label in classlabels]
|
143 |
+
|
144 |
+
classlabels2id = {label: i for i, label in enumerate(classlabels)}
|
145 |
+
id2classlabels = {i: label for i, label in enumerate(classlabels)}
|
146 |
+
|
147 |
+
pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution, classlabels2id)
|
148 |
|
149 |
with gr.Blocks() as demo:
|
150 |
gr.Markdown("DDT")
|
|
|
152 |
with gr.Column(scale=1):
|
153 |
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
|
154 |
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
|
155 |
+
num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=4)
|
156 |
+
label = gr.Dropdown(choices=classlabels, value=id2classlabels[948], label="label")
|
157 |
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
|
158 |
state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
|
159 |
+
guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min",
|
160 |
+
value=0.0)
|
161 |
+
guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max",
|
162 |
+
value=1.0)
|
163 |
timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0)
|
164 |
with gr.Column(scale=2):
|
165 |
btn = gr.Button("Generate")
|
|
|
177 |
guidance_interval_max,
|
178 |
timeshift
|
179 |
], outputs=[output])
|
180 |
+
demo.launch()
|
imagenet_classlabels.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/diffusion/stateful_flow_matching/sharing_sampling.py
CHANGED
@@ -109,7 +109,7 @@ class EulerSampler(BaseSampler):
|
|
109 |
timesteps.reverse()
|
110 |
|
111 |
print("recompute timesteps solved by DP: ", timesteps)
|
112 |
-
return timesteps[:-1]
|
113 |
|
114 |
def _impl_sampling(self, net, noise, condition, uncondition):
|
115 |
"""
|
|
|
109 |
timesteps.reverse()
|
110 |
|
111 |
print("recompute timesteps solved by DP: ", timesteps)
|
112 |
+
return timesteps[:-1][:self.num_recompute_timesteps]
|
113 |
|
114 |
def _impl_sampling(self, net, noise, condition, uncondition):
|
115 |
"""
|