Spaces:
Running
on
Zero
Running
on
Zero
wangshuai6
commited on
Commit
·
323e5b5
1
Parent(s):
d7edbd1
app demo
Browse files
app.py
CHANGED
@@ -70,7 +70,7 @@ class Pipeline:
|
|
70 |
self.diffusion_sampler = diffusion_sampler
|
71 |
self.resolution = resolution
|
72 |
self.classlabels2ids = classlabels2ids
|
73 |
-
|
74 |
@torch.no_grad()
|
75 |
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
76 |
def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
|
|
|
70 |
self.diffusion_sampler = diffusion_sampler
|
71 |
self.resolution = resolution
|
72 |
self.classlabels2ids = classlabels2ids
|
73 |
+
@spaces.GPU
|
74 |
@torch.no_grad()
|
75 |
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
76 |
def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
|
src/diffusion/stateful_flow_matching/sharing_sampling.py
CHANGED
@@ -56,7 +56,6 @@ class EulerSampler(BaseSampler):
|
|
56 |
logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
|
57 |
|
58 |
# init recompute
|
59 |
-
self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
|
60 |
self.recompute_timesteps = list(range(self.num_steps))
|
61 |
|
62 |
def sharing_dp(self, net, noise, condition, uncondition):
|
@@ -143,6 +142,7 @@ class EulerSampler(BaseSampler):
|
|
143 |
return x, pooled_state_list
|
144 |
|
145 |
def __call__(self, net, noise, condition, uncondition):
|
|
|
146 |
if len(self.recompute_timesteps) != self.num_recompute_timesteps:
|
147 |
self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
|
148 |
denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
|
|
|
56 |
logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
|
57 |
|
58 |
# init recompute
|
|
|
59 |
self.recompute_timesteps = list(range(self.num_steps))
|
60 |
|
61 |
def sharing_dp(self, net, noise, condition, uncondition):
|
|
|
142 |
return x, pooled_state_list
|
143 |
|
144 |
def __call__(self, net, noise, condition, uncondition):
|
145 |
+
self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
|
146 |
if len(self.recompute_timesteps) != self.num_recompute_timesteps:
|
147 |
self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
|
148 |
denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
|