wangshuai6 commited on
Commit
323e5b5
·
1 Parent(s): d7edbd1
app.py CHANGED
@@ -70,7 +70,7 @@ class Pipeline:
70
  self.diffusion_sampler = diffusion_sampler
71
  self.resolution = resolution
72
  self.classlabels2ids = classlabels2ids
73
-
74
  @torch.no_grad()
75
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
76
  def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
 
70
  self.diffusion_sampler = diffusion_sampler
71
  self.resolution = resolution
72
  self.classlabels2ids = classlabels2ids
73
+ @spaces.GPU
74
  @torch.no_grad()
75
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
76
  def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
src/diffusion/stateful_flow_matching/sharing_sampling.py CHANGED
@@ -56,7 +56,6 @@ class EulerSampler(BaseSampler):
56
  logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
57
 
58
  # init recompute
59
- self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
60
  self.recompute_timesteps = list(range(self.num_steps))
61
 
62
  def sharing_dp(self, net, noise, condition, uncondition):
@@ -143,6 +142,7 @@ class EulerSampler(BaseSampler):
143
  return x, pooled_state_list
144
 
145
  def __call__(self, net, noise, condition, uncondition):
 
146
  if len(self.recompute_timesteps) != self.num_recompute_timesteps:
147
  self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
148
  denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
 
56
  logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
57
 
58
  # init recompute
 
59
  self.recompute_timesteps = list(range(self.num_steps))
60
 
61
  def sharing_dp(self, net, noise, condition, uncondition):
 
142
  return x, pooled_state_list
143
 
144
  def __call__(self, net, noise, condition, uncondition):
145
+ self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
146
  if len(self.recompute_timesteps) != self.num_recompute_timesteps:
147
  self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
148
  denoised, _ = self._impl_sampling(net, noise, condition, uncondition)