jbilcke-hf committed on
Commit
2ff7eb5
·
verified ·
1 Parent(s): c5adc30

Update networks/generator.py

Browse files
Files changed (1) hide show
  1. networks/generator.py +31 -21
networks/generator.py CHANGED
@@ -17,6 +17,16 @@ class Generator(nn.Module):
17
  # encoder
18
  self.enc = Encoder(style_dim, motion_dim, scale)
19
  self.dec = Decoder(style_dim, motion_dim, scale)
 
 
 
 
 
 
 
 
 
 
20
 
21
  def get_alpha(self, x):
22
  return self.enc.enc_motion(x)
@@ -38,16 +48,11 @@ class Generator(nn.Module):
38
  enc_r2t_end = time.time()
39
  print(f"[Generator.edit_img] enc_r2t encoding took: {(enc_r2t_end - enc_r2t_start) * 1000:.2f} ms")
40
 
41
- # Alpha modification timing
42
  alpha_mod_start = time.time()
43
- alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
44
-
45
- # Current (creates tensor on CPU then moves to GPU)
46
- #alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
47
-
48
- # Optimized (create directly on GPU)
49
- alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.tensor(v_l, device='cuda', dtype=torch.float32).unsqueeze(0)
50
-
51
  alpha_mod_end = time.time()
52
  print(f"[Generator.edit_img] Alpha modification took: {(alpha_mod_end - alpha_mod_start) * 1000:.2f} ms")
53
 
@@ -66,13 +71,15 @@ class Generator(nn.Module):
66
  return img_recon
67
 
68
  def animate(self, img_source, vid_target, d_l, v_l):
69
-
70
  alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])
71
 
72
  vid_target_recon = []
73
  z_s2r, feat_rgb = self.enc.enc_2r(img_source)
74
  alpha_r2s = self.enc.enc_r2t(z_s2r)
75
- alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
 
 
 
76
 
77
  for i in tqdm(range(vid_target.size(1))):
78
  img_target = vid_target[:, i, :, :, :]
@@ -84,14 +91,16 @@ class Generator(nn.Module):
84
  return vid_target_recon
85
 
86
  def animate_batch(self, img_source, vid_target, d_l, v_l, chunk_size):
87
-
88
  b,t,c,h,w = vid_target.size()
89
  alpha_start = self.get_alpha(vid_target[:, 0, :, :, :]) # 1x40
90
 
91
  vid_target_recon = []
92
  z_s2r, feat_rgb = self.enc.enc_2r(img_source)
93
  alpha_r2s = self.enc.enc_r2t(z_s2r)
94
- alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
 
 
 
95
 
96
  bs = chunk_size
97
  chunks = t//bs
@@ -121,14 +130,16 @@ class Generator(nn.Module):
121
  return vid_target_recon # BCTHW
122
 
123
  def edit_vid(self, vid_target, d_l, v_l):
124
-
125
  img_source = vid_target[:, 0, :, :, :]
126
  alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])
127
 
128
  vid_target_recon = []
129
  z_s2r, feat_rgb = self.enc.enc_2r(img_source)
130
  alpha_r2s = self.enc.enc_r2t(z_s2r)
131
- alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
 
 
 
132
 
133
  for i in tqdm(range(vid_target.size(1))):
134
  img_target = vid_target[:, i, :, :, :]
@@ -140,7 +151,6 @@ class Generator(nn.Module):
140
  return vid_target_recon
141
 
142
  def edit_vid_batch(self, vid_target, d_l, v_l, chunk_size):
143
-
144
  b,t,c,h,w = vid_target.size()
145
  img_source = vid_target[:, 0, :, :, :]
146
  alpha_start = self.get_alpha(img_source) # 1x40
@@ -148,7 +158,10 @@ class Generator(nn.Module):
148
  vid_target_recon = []
149
  z_s2r, feat_rgb = self.enc.enc_2r(img_source)
150
  alpha_r2s = self.enc.enc_r2t(z_s2r)
151
- alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
 
 
 
152
 
153
  bs = chunk_size
154
  chunks = t//bs
@@ -177,9 +190,7 @@ class Generator(nn.Module):
177
 
178
  return vid_target_recon # BCTHW
179
 
180
-
181
  def interpolate_img(self, img_source, d_l, v_l):
182
-
183
  vid_target_recon = []
184
 
185
  step = 16
@@ -229,5 +240,4 @@ class Generator(nn.Module):
229
 
230
  vid_target_recon = torch.cat(vid_target_recon, dim=2) # BCTHW
231
 
232
- return vid_target_recon
233
-
 
17
  # encoder
18
  self.enc = Encoder(style_dim, motion_dim, scale)
19
  self.dec = Decoder(style_dim, motion_dim, scale)
20
+
21
+ # Lazily-filled caches (module device, reusable tensors); nothing is allocated here
22
+ self._device = None
23
+ self._cached_tensors = {}
24
+
25
+ @property
26
+ def device(self):
27
+ if self._device is None:
28
+ self._device = next(self.parameters()).device
29
+ return self._device
30
 
31
  def get_alpha(self, x):
32
  return self.enc.enc_motion(x)
 
48
  enc_r2t_end = time.time()
49
  print(f"[Generator.edit_img] enc_r2t encoding took: {(enc_r2t_end - enc_r2t_start) * 1000:.2f} ms")
50
 
51
+ # Alpha modification timing - OPTIMIZED
52
  alpha_mod_start = time.time()
53
+ # Create tensor directly on the same device as alpha_r2s
54
+ v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
55
+ alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
 
 
 
 
 
56
  alpha_mod_end = time.time()
57
  print(f"[Generator.edit_img] Alpha modification took: {(alpha_mod_end - alpha_mod_start) * 1000:.2f} ms")
58
 
 
71
  return img_recon
72
 
73
  def animate(self, img_source, vid_target, d_l, v_l):
 
74
  alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])
75
 
76
  vid_target_recon = []
77
  z_s2r, feat_rgb = self.enc.enc_2r(img_source)
78
  alpha_r2s = self.enc.enc_r2t(z_s2r)
79
+
80
+ # Optimized alpha modification
81
+ v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
82
+ alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
83
 
84
  for i in tqdm(range(vid_target.size(1))):
85
  img_target = vid_target[:, i, :, :, :]
 
91
  return vid_target_recon
92
 
93
  def animate_batch(self, img_source, vid_target, d_l, v_l, chunk_size):
 
94
  b,t,c,h,w = vid_target.size()
95
  alpha_start = self.get_alpha(vid_target[:, 0, :, :, :]) # 1x40
96
 
97
  vid_target_recon = []
98
  z_s2r, feat_rgb = self.enc.enc_2r(img_source)
99
  alpha_r2s = self.enc.enc_r2t(z_s2r)
100
+
101
+ # Optimized alpha modification
102
+ v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
103
+ alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
104
 
105
  bs = chunk_size
106
  chunks = t//bs
 
130
  return vid_target_recon # BCTHW
131
 
132
  def edit_vid(self, vid_target, d_l, v_l):
 
133
  img_source = vid_target[:, 0, :, :, :]
134
  alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])
135
 
136
  vid_target_recon = []
137
  z_s2r, feat_rgb = self.enc.enc_2r(img_source)
138
  alpha_r2s = self.enc.enc_r2t(z_s2r)
139
+
140
+ # Optimized alpha modification
141
+ v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
142
+ alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
143
 
144
  for i in tqdm(range(vid_target.size(1))):
145
  img_target = vid_target[:, i, :, :, :]
 
151
  return vid_target_recon
152
 
153
  def edit_vid_batch(self, vid_target, d_l, v_l, chunk_size):
 
154
  b,t,c,h,w = vid_target.size()
155
  img_source = vid_target[:, 0, :, :, :]
156
  alpha_start = self.get_alpha(img_source) # 1x40
 
158
  vid_target_recon = []
159
  z_s2r, feat_rgb = self.enc.enc_2r(img_source)
160
  alpha_r2s = self.enc.enc_r2t(z_s2r)
161
+
162
+ # Optimized alpha modification
163
+ v_l_tensor = torch.tensor(v_l, device=alpha_r2s.device, dtype=alpha_r2s.dtype).unsqueeze(0)
164
+ alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + v_l_tensor
165
 
166
  bs = chunk_size
167
  chunks = t//bs
 
190
 
191
  return vid_target_recon # BCTHW
192
 
 
193
  def interpolate_img(self, img_source, d_l, v_l):
 
194
  vid_target_recon = []
195
 
196
  step = 16
 
240
 
241
  vid_target_recon = torch.cat(vid_target_recon, dim=2) # BCTHW
242
 
243
+ return vid_target_recon