Alex Ergasti committed on
Commit 67d1f09 · 1 Parent(s): 3a3fb7b
Files changed (1):
  1. diffusion/rectified_flow.py +322 -0
diffusion/rectified_flow.py ADDED
@@ -0,0 +1,322 @@
import torch
from tqdm import tqdm
import numpy as np
import math


class RectifiedFlow():
    def __init__(self, num_timesteps, warmup_timesteps=10, noise_scale=1.0, init_type='gaussian', eps=1, sampling='logit', window_size=8):
        """
        eps: A `float` number. The smallest time step to sample from.
        """
        self.num_timesteps = num_timesteps

        self.warmup_timesteps = warmup_timesteps * num_timesteps
        self.T = 1000.
        self.noise_scale = noise_scale
        self.init_type = init_type
        self.eps = eps

        self.window_size = window_size

        self.sampling = sampling
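
    # Commentary on the two helpers below: `logit_normal` evaluates the
    # logit-normal density
    #     p(x; mu, sigma) = exp(-(logit(x) - mu)^2 / (2 sigma^2)) / (sigma sqrt(2 pi) x (1 - x)),
    # which `training_loss` uses (when sampling == 'logit') to weight the
    # per-frame losses, emphasizing intermediate noise levels over t ~ 0 and t ~ 1.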
    def logit(self, x):
        return torch.log(x / (1 - x))

    def logit_normal(self, x, mu=0, sigma=1):
        return 1 / (sigma * math.sqrt(2 * torch.pi) * x * (1 - x)) * torch.exp(-(self.logit(x) - mu) ** 2 / (2 * sigma ** 2))
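
    # Commentary on the timestep scheme in `training_loss`: each sample draws
    # one scalar offset tw, then either spreads it across the window as
    # (j + tw) / window_size ("rollout": a diagonal schedule where frame j is
    # noisier the further it sits in the window) or shifts all frames equally
    # as j / window_size + tw ("pre-rollout": the warm-up regime that brings a
    # window up from pure noise). Rollout mode is picked with probability 0.8.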
    def training_loss(self, model, v, a, model_kwargs):
        """
        v: [B, T, C, H, W]
        a: [B, T, N, F]
        """
        B, T = v.shape[:2]

        tw = torch.rand((v.shape[0], 1), device=v.device)

        window_indexes = torch.linspace(0, self.window_size - 1, steps=self.window_size, device=v.device).unsqueeze(0).repeat(B, 1)

        rollout = torch.bernoulli(torch.tensor(0.8).repeat(B).to(v.device)).bool()
        t_rollout = (window_indexes + tw) / self.window_size

        t_pre_rollout = window_indexes / self.window_size + tw

        t = torch.where(rollout.unsqueeze(1).repeat(1, self.window_size), t_rollout, t_pre_rollout)
        t = 1 - t  # swap 0 and 1, since 1 is full image and 0 is full noise

        t = torch.clamp(t, 0 + 1e-6, 1 - 1e-6)

        if self.sampling == 'logit':
            weights = self.logit_normal(t, mu=0, sigma=1)
        else:
            weights = torch.ones_like(t)

        B, T = t.shape

        v_z0 = self.get_z0(v).to(v.device)
        a_z0 = self.get_z0(a).to(a.device)

        t_video = t.view(B, T, 1, 1, 1).repeat(1, 1, v.shape[2], v.shape[3], v.shape[4])
        t_audio = t.view(B, T, 1, 1, 1).repeat(1, 1, a.shape[2], a.shape[3], a.shape[4])

        perturbed_video = t_video * v + (1 - t_video) * v_z0
        perturbed_audio = t_audio * a + (1 - t_audio) * a_z0

        t_rf = t * (self.T - self.eps) + self.eps
        score_v, score_a = model(perturbed_video, perturbed_audio, t_rf, **model_kwargs)

        # score_v = [B, T, C, H, W]
        # score_a = [B, T, N, F]
        target_video = v - v_z0  # direction of the flow
        target_audio = a - a_z0  # direction of the flow

        loss_video = torch.square(score_v - target_video)
        loss_audio = torch.square(score_a - target_audio)

        loss_video = torch.mean(loss_video, dim=[2, 3, 4])
        loss_audio = torch.mean(loss_audio, dim=[2, 3, 4])

        # weight the per-frame losses by the timestep-sampling density
        loss_video = loss_video * weights
        loss_video = torch.mean(loss_video)

        loss_audio = loss_audio * weights
        loss_audio = torch.mean(loss_audio)

        return {"loss": (loss_video + loss_audio)}
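
    # Commentary on `sample`: plain Euler integration of the learned flow.
    # The warm-up phase moves the initial window from pure noise onto the
    # diagonal schedule; after that the window rolls: every `num_timesteps`
    # Euler steps the frame at index 0 is fully denoised, so it is emitted
    # and a fresh noise frame is appended at the tail. The returned
    # `yield_frame` is a generator function, so callers can stream frames
    # for as long as they like.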
    def sample(self, model, v_z, a_z, model_kwargs, progress=True):
        B = v_z.shape[0]

        window_indexes = torch.linspace(0, self.window_size - 1, steps=self.window_size, device=v_z.device).unsqueeze(0).repeat(B, 1)

        # warm up with a different number of warmup timesteps to be more precise
        for i in tqdm(range(self.warmup_timesteps), disable=not progress):
            dt, t_partial, t_rf = self.calculate_prerolling_timestep(window_indexes, i)

            score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)

            v_z = v_z.detach().clone() + dt * score_v
            a_z = a_z.detach().clone() + dt * score_a

        v_f = v_z[:, 0]
        a_f = a_z[:, 0]

        v_z = torch.cat([v_z[:, 1:], torch.randn_like(v_z[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)
        a_z = torch.cat([a_z[:, 1:], torch.randn_like(a_z[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)

        def yield_frame():
            nonlocal v_z, a_z, window_indexes
            yield (v_f, a_f)

            dt = 1 / (self.num_timesteps * self.window_size)

            while True:
                for i in range(self.num_timesteps):
                    tw = (self.num_timesteps - i) / self.num_timesteps
                    t = (window_indexes + tw) / self.window_size
                    t = 1 - t

                    t_rf = t * (self.T - self.eps) + self.eps

                    score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)

                    v_z = v_z.detach().clone() + dt * score_v
                    a_z = a_z.detach().clone() + dt * score_a

                v = v_z[:, 0]
                a = a_z[:, 0]

                # remove the first element and append fresh noise at the tail
                v_noise = torch.randn_like(v_z[:, 0]).unsqueeze(1) * self.noise_scale
                a_noise = torch.randn_like(a_z[:, 0]).unsqueeze(1) * self.noise_scale

                v_z = torch.cat([v_z[:, 1:], v_noise], dim=1)
                a_z = torch.cat([a_z[:, 1:], a_noise], dim=1)

                yield (v, a)

        return yield_frame
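
    # Commentary on `sample_a2v` (audio-to-video): the audio latents are not
    # sampled but clamped to the straight path a_partial * t + noise * (1 - t)
    # toward the given audio `a`. The video update then subtracts the gradient
    # of the audio flow-matching error with respect to v_z, scaled by the
    # guidance strength `scale`; the ((t_partial + dt) != 1) factor appears to
    # switch guidance off for frames about to reach the clean endpoint t = 1.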
    def sample_a2v(self, model, v_z, a, model_kwargs, scale=1, progress=True):
        B = v_z.shape[0]
        window_indexes = torch.linspace(0, self.window_size - 1, steps=self.window_size, device=v_z.device).unsqueeze(0).repeat(B, 1)

        a_partial = a[:, :self.window_size]

        a_noise = torch.randn_like(a, device=v_z.device) * self.noise_scale
        a_noise_partial = a_noise[:, :self.window_size]

        with torch.enable_grad():
            # warm up with a different number of warmup timesteps to be more precise
            for i in tqdm(range(self.warmup_timesteps), disable=not progress):
                v_z = v_z.detach().requires_grad_(True)

                dt, t_partial, t_rf = self.calculate_prerolling_timestep(window_indexes, i)

                a_z = a_partial * t_partial + a_noise_partial * (1 - t_partial)

                score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)

                loss = torch.square((a_partial - a_noise_partial) - score_a)
                grad = torch.autograd.grad(loss.mean(), v_z)[0]

                v_z = v_z.detach() + dt * score_v - ((t_partial + dt) != 1) * dt * grad * scale

        v_f = v_z[:, 0].detach()
        v_z = torch.cat([v_z[:, 1:], torch.randn_like(v_z[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)

        def yield_frame():
            nonlocal v_z, a, a_noise, window_indexes
            yield v_f

            dt = 1 / (self.num_timesteps * self.window_size)

            while True:
                torch.cuda.empty_cache()

                a = a[:, 1:]
                a_noise = a_noise[:, 1:]
                if a.shape[1] < self.window_size:
                    a = torch.cat([a, torch.randn_like(a[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)
                    a_noise = torch.cat([a_noise, torch.randn_like(a[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)

                a_partial = a[:, :self.window_size]
                a_noise_partial = a_noise[:, :self.window_size]

                with torch.enable_grad():
                    for i in range(self.num_timesteps):
                        v_z = v_z.detach().requires_grad_(True)

                        tw = (self.num_timesteps - i) / self.num_timesteps
                        t = (window_indexes + tw) / self.window_size
                        t = 1 - t

                        t_partial = t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                        t_rf = t * (self.T - self.eps) + self.eps

                        a_z = a_partial * t_partial + torch.randn_like(a_partial, device=v_z.device) * self.noise_scale * (1 - t_partial)

                        score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)

                        loss = torch.square((a_partial - a_noise_partial) - score_a)
                        grad = torch.autograd.grad(loss.mean(), v_z)[0]

                        v_z = v_z.detach() + dt * score_v - ((t_partial + dt) != 1) * dt * grad * scale

                v = v_z[:, 0].detach()

                v_noise = torch.randn_like(v_z[:, 0]).unsqueeze(1) * self.noise_scale
                v_z = torch.cat([v_z[:, 1:], v_noise], dim=1)
                yield v

        return yield_frame
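
    # Commentary on `sample_v2a` (video-to-audio): the mirror of `sample_a2v`.
    # Video latents follow the straight path toward the given video `v`, and
    # the audio update is guided by the gradient of the video flow-matching
    # error with respect to a_z, with a stronger default guidance (scale=2).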
    def sample_v2a(self, model, v, a_z, model_kwargs, scale=2, progress=True):
        B = a_z.shape[0]
        window_indexes = torch.linspace(0, self.window_size - 1, steps=self.window_size, device=a_z.device).unsqueeze(0).repeat(B, 1)

        v_partial = v[:, :self.window_size]
        v_noise = torch.randn_like(v, device=a_z.device) * self.noise_scale
        v_noise_partial = v_noise[:, :self.window_size]

        with torch.enable_grad():
            # warm up with a different number of warmup timesteps to be more precise
            for i in tqdm(range(self.warmup_timesteps), disable=not progress):
                a_z = a_z.detach().requires_grad_(True)

                dt, t_partial, t_rf = self.calculate_prerolling_timestep(window_indexes, i)

                v_z = v_partial * t_partial + v_noise_partial * (1 - t_partial)

                score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)

                loss = torch.square((v_partial - v_noise_partial) - score_v)
                grad = torch.autograd.grad(loss.mean(), a_z)[0]

                a_z = a_z.detach() + dt * score_a - ((t_partial + dt) != 1) * dt * grad * scale

        a_f = a_z[:, 0].detach()
        a_z = torch.cat([a_z[:, 1:], torch.randn_like(a_z[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)

        def yield_frame():
            nonlocal v, v_noise, a_z, window_indexes
            yield a_f

            dt = 1 / (self.num_timesteps * self.window_size)
            while True:
                torch.cuda.empty_cache()
                v = v[:, 1:]
                v_noise = v_noise[:, 1:]

                if v.shape[1] < self.window_size:
                    v = torch.cat([v, torch.randn_like(v[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)
                    v_noise = torch.cat([v_noise, torch.randn_like(v[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)

                v_partial = v[:, :self.window_size]
                v_noise_partial = v_noise[:, :self.window_size]

                with torch.enable_grad():
                    for i in range(self.num_timesteps):
                        a_z = a_z.detach().requires_grad_(True)

                        tw = (self.num_timesteps - i) / self.num_timesteps
                        t = (window_indexes + tw) / self.window_size
                        t = 1 - t

                        t_partial = t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                        t_rf = t * (self.T - self.eps) + self.eps

                        v_z = v_partial * t_partial + v_noise_partial * (1 - t_partial)

                        score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)

                        loss = torch.square((v_partial - v_noise_partial) - score_v)
                        grad = torch.autograd.grad(loss.mean(), a_z)[0]

                        a_z = a_z.detach() + dt * score_a - ((t_partial + dt) != 1) * dt * grad * scale

                a = a_z[:, 0].detach()

                a_noise = torch.randn_like(a_z[:, 0]).unsqueeze(1) * self.noise_scale
                a_z = torch.cat([a_z[:, 1:], a_noise], dim=1)

                yield a

        return yield_frame
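
    # Commentary on `calculate_prerolling_timestep`: during warm-up the whole
    # window slides from pure noise (t = 0) toward the diagonal schedule. For
    # warm-up step i this returns the Euler step size dt = |t_future - t|
    # (shaped [B, window_size, 1, 1, 1] for broadcasting), the per-frame times
    # t_partial in the same broadcastable shape, and t_rf, the times rescaled
    # to [eps, T] as the model's conditioning expects.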
    def calculate_prerolling_timestep(self, window_indexes, i):
        tw = (self.warmup_timesteps - i) / self.warmup_timesteps
        tw_future = (self.warmup_timesteps - (i + 1)) / self.warmup_timesteps

        t = window_indexes / self.window_size + tw

        # timestep for the next iteration, to calculate dt
        t_future = window_indexes / self.window_size + tw_future

        # swap 0 with 1: 1 is full image, 0 is full noise
        t = 1 - t
        t_future = 1 - t_future

        t = torch.clamp(t, 0, 1)
        t_future = torch.clamp(t_future, 0, 1)

        dt = torch.abs(t_future - t).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)  # [B, window_size, 1, 1, 1]

        t_partial = t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        t_rf = t * (self.T - self.eps) + self.eps
        return dt, t_partial, t_rf

    def get_z0(self, batch, train=True):
        if self.init_type == 'gaussian':
            # standard gaussian
            return torch.randn(batch.shape) * self.noise_scale
        else:
            raise NotImplementedError("INITIALIZATION TYPE NOT IMPLEMENTED")
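

# Illustrative smoke test: a minimal sketch of how the class might be driven.
# It assumes a joint model with the call signature used above,
# model(v, a, t, **kwargs) -> (score_v, score_a); `DummyModel` and all shapes
# and hyperparameters below are hypothetical stand-ins. Note that although the
# training_loss docstring describes audio as [B, T, N, F], the code broadcasts
# t over a.shape[2:5], so 5D audio latents are used here.
if __name__ == "__main__":
    class DummyModel:
        def __call__(self, v, a, t, **kwargs):
            # predict a zero velocity field for both modalities
            return torch.zeros_like(v), torch.zeros_like(a)

    model = DummyModel()
    rf = RectifiedFlow(num_timesteps=2, warmup_timesteps=1, window_size=4)

    # one training step on random latents (window_size frames per clip)
    v = torch.randn(2, 4, 3, 8, 8)  # [B, T, C, H, W]
    a = torch.randn(2, 4, 1, 4, 4)  # [B, T, 1, N, F]
    print(rf.training_loss(model, v, a, model_kwargs={})["loss"])

    # streaming sampling: warm up, then pull the first frame pair from the generator
    frames = rf.sample(model, torch.randn(2, 4, 3, 8, 8), torch.randn(2, 4, 1, 4, 4), model_kwargs={}, progress=False)()
    v0, a0 = next(frames)
    print(v0.shape, a0.shape)  # torch.Size([2, 3, 8, 8]) torch.Size([2, 1, 4, 4])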