wangshuai6 committed on
Commit
d7edbd1
·
1 Parent(s): 52d009c
app.py CHANGED
@@ -63,14 +63,14 @@ def load_model(weight_dict, denosier):
63
 
64
 
65
  class Pipeline:
66
- def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution):
67
  self.vae = vae
68
  self.denoiser = denoiser
69
  self.conditioner = conditioner
70
  self.diffusion_sampler = diffusion_sampler
71
  self.resolution = resolution
 
72
 
73
- @spaces.GPU
74
  @torch.no_grad()
75
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
76
  def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
@@ -83,7 +83,7 @@ class Pipeline:
83
  generator = torch.Generator(device="cuda").manual_seed(seed)
84
  xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
85
  with torch.no_grad():
86
- condition, uncondition = conditioner([y,]*num_images)
87
  # Sample images:
88
  samples = diffusion_sampler(denoiser, xT, condition, uncondition)
89
  samples = vae.decode(samples)
@@ -136,7 +136,15 @@ if __name__ == "__main__":
136
  vae = vae.cuda()
137
  denoiser.eval()
138
 
139
- pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution)
 
 
 
 
 
 
 
 
140
 
141
  with gr.Blocks() as demo:
142
  gr.Markdown("DDT")
@@ -144,12 +152,14 @@ if __name__ == "__main__":
144
  with gr.Column(scale=1):
145
  num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
146
  guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
147
- num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8)
148
- label = gr.Slider(minimum=0, maximum=999, step=1, label="label", value=948)
149
  seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
150
  state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
151
- guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0)
152
- guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", value=1.0)
 
 
153
  timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0)
154
  with gr.Column(scale=2):
155
  btn = gr.Button("Generate")
@@ -167,4 +177,4 @@ if __name__ == "__main__":
167
  guidance_interval_max,
168
  timeshift
169
  ], outputs=[output])
170
- demo.launch(server_name="0.0.0.0", server_port=7861)
 
63
 
64
 
65
  class Pipeline:
66
+ def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution, classlabels2ids):
67
  self.vae = vae
68
  self.denoiser = denoiser
69
  self.conditioner = conditioner
70
  self.diffusion_sampler = diffusion_sampler
71
  self.resolution = resolution
72
+ self.classlabels2ids = classlabels2ids
73
 
 
74
  @torch.no_grad()
75
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
76
  def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
 
83
  generator = torch.Generator(device="cuda").manual_seed(seed)
84
  xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
85
  with torch.no_grad():
86
+ condition, uncondition = conditioner([self.classlabels2ids[y],]*num_images)
87
  # Sample images:
88
  samples = diffusion_sampler(denoiser, xT, condition, uncondition)
89
  samples = vae.decode(samples)
 
136
  vae = vae.cuda()
137
  denoiser.eval()
138
 
139
+ # read imagenet classlabels
140
+ with open("imagenet_classlabels.txt", "r") as f:
141
+ classlabels = f.readlines()
142
+ classlabels = [label.strip() for label in classlabels]
143
+
144
+ classlabels2id = {label: i for i, label in enumerate(classlabels)}
145
+ id2classlabels = {i: label for i, label in enumerate(classlabels)}
146
+
147
+ pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution, classlabels2id)
148
 
149
  with gr.Blocks() as demo:
150
  gr.Markdown("DDT")
 
152
  with gr.Column(scale=1):
153
  num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
154
  guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
155
+ num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=4)
156
+ label = gr.Dropdown(choices=classlabels, value=id2classlabels[948], label="label")
157
  seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
158
  state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
159
+ guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min",
160
+ value=0.0)
161
+ guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max",
162
+ value=1.0)
163
  timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0)
164
  with gr.Column(scale=2):
165
  btn = gr.Button("Generate")
 
177
  guidance_interval_max,
178
  timeshift
179
  ], outputs=[output])
180
+ demo.launch()
imagenet_classlabels.txt ADDED
The diff for this file is too large to render. See raw diff
 
src/diffusion/stateful_flow_matching/sharing_sampling.py CHANGED
@@ -109,7 +109,7 @@ class EulerSampler(BaseSampler):
109
  timesteps.reverse()
110
 
111
  print("recompute timesteps solved by DP: ", timesteps)
112
- return timesteps[:-1]
113
 
114
  def _impl_sampling(self, net, noise, condition, uncondition):
115
  """
 
109
  timesteps.reverse()
110
 
111
  print("recompute timesteps solved by DP: ", timesteps)
112
+ return timesteps[:-1][:self.num_recompute_timesteps]
113
 
114
  def _impl_sampling(self, net, noise, condition, uncondition):
115
  """