yonishafir committed
Commit eda03a7 · verified · 1 Parent(s): af794f4

Delete ip_adapter
ip_adapter/attention_processor.py DELETED
@@ -1,447 +0,0 @@
-# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-try:
-    import xformers
-    import xformers.ops
-    xformers_available = True
-except Exception as e:
-    xformers_available = False
-
-class RegionControler(object):
-    def __init__(self) -> None:
-        self.prompt_image_conditioning = []
-region_control = RegionControler()
-
-class AttnProcessor(nn.Module):
-    r"""
-    Default processor for performing attention-related computations.
-    """
-    def __init__(
-        self,
-        hidden_size=None,
-        cross_attention_dim=None,
-    ):
-        super().__init__()
-
-    def forward(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        query = attn.head_to_batch_dim(query)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
-class IPAttnProcessor(nn.Module):
-    r"""
-    Attention processor for IP-Adapater.
-    Args:
-        hidden_size (`int`):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`):
-            The number of channels in the `encoder_hidden_states`.
-        scale (`float`, defaults to 1.0):
-            the weight scale of image prompt.
-        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
-            The context length of the image features.
-    """
-
-    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
-        super().__init__()
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.scale = scale
-        self.num_tokens = num_tokens
-
-        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
-    def forward(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        else:
-            # get encoder_hidden_states, ip_hidden_states
-            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
-            encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
-            if attn.norm_cross:
-                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        query = attn.head_to_batch_dim(query)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        if xformers_available:
-            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
-        else:
-            attention_probs = attn.get_attention_scores(query, key, attention_mask)
-            hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # for ip-adapter
-        ip_key = self.to_k_ip(ip_hidden_states)
-        ip_value = self.to_v_ip(ip_hidden_states)
-
-        ip_key = attn.head_to_batch_dim(ip_key)
-        ip_value = attn.head_to_batch_dim(ip_value)
-
-        if xformers_available:
-            ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
-        else:
-            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
-
-        # region control
-        if len(region_control.prompt_image_conditioning) == 1:
-            region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
-            if region_mask is not None:
-                h, w = region_mask.shape[:2]
-                ratio = (h * w / query.shape[1]) ** 0.5
-                mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
-            else:
-                mask = torch.ones_like(ip_hidden_states)
-            ip_hidden_states = ip_hidden_states * mask
-
-        hidden_states = hidden_states + self.scale * ip_hidden_states
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
-    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
-        # TODO attention_mask
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
-        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
-        return hidden_states
-
-
-class AttnProcessor2_0(torch.nn.Module):
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-    def __init__(
-        self,
-        hidden_size=None,
-        cross_attention_dim=None,
-    ):
-        super().__init__()
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def forward(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-class IPAttnProcessor2_0(torch.nn.Module):
-    r"""
-    Attention processor for IP-Adapater for PyTorch 2.0.
-    Args:
-        hidden_size (`int`):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`):
-            The number of channels in the `encoder_hidden_states`.
-        scale (`float`, defaults to 1.0):
-            the weight scale of image prompt.
-        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
-            The context length of the image features.
-    """
-
-    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.scale = scale
-        self.num_tokens = num_tokens
-
-        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
-    def forward(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        else:
-            # get encoder_hidden_states, ip_hidden_states
-            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
-            encoder_hidden_states, ip_hidden_states = (
-                encoder_hidden_states[:, :end_pos, :],
-                encoder_hidden_states[:, end_pos:, :],
-            )
-            if attn.norm_cross:
-                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # for ip-adapter
-        ip_key = self.to_k_ip(ip_hidden_states)
-        ip_value = self.to_v_ip(ip_hidden_states)
-
-        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        ip_hidden_states = F.scaled_dot_product_attention(
-            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-        )
-        with torch.no_grad():
-            self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
-            #print(self.attn_map.shape)
-
-        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-        # region control
-        if len(region_control.prompt_image_conditioning) == 1:
-            region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
-            if region_mask is not None:
-                query = query.reshape([-1, query.shape[-2], query.shape[-1]])
-                h, w = region_mask.shape[:2]
-                ratio = (h * w / query.shape[1]) ** 0.5
-                mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
-            else:
-                mask = torch.ones_like(ip_hidden_states)
-            ip_hidden_states = ip_hidden_states * mask
-
-        hidden_states = hidden_states + self.scale * ip_hidden_states
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
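
For reference, here is a minimal sketch (not part of this commit) of how attention processors like the ones deleted above are typically registered on a diffusers UNet: self-attention layers keep the plain processor, while cross-attention layers get the IP variant, whose extra `to_k_ip`/`to_v_ip` projections attend to the trailing `num_tokens` image tokens appended to `encoder_hidden_states`. The pretrained model id and the `scale`/`num_tokens` values below are illustrative assumptions, and the import path refers to the module as it existed before this deletion.

```python
# Hypothetical wiring sketch; the model id and hyperparameters are placeholders.
from diffusers import UNet2DConditionModel

from ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention layers are named "...attn1.processor" and have no cross-attention dim.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor2_0()
    else:
        attn_procs[name] = IPAttnProcessor2_0(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            scale=1.0,     # weight of the image prompt
            num_tokens=4,  # 16 when using the "plus" image projection
        )

unet.set_attn_processor(attn_procs)
```
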
ip_adapter/resampler.py DELETED
@@ -1,121 +0,0 @@
-# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
-import math
-
-import torch
-import torch.nn as nn
-
-
-# FFN
-def FeedForward(dim, mult=4):
-    inner_dim = int(dim * mult)
-    return nn.Sequential(
-        nn.LayerNorm(dim),
-        nn.Linear(dim, inner_dim, bias=False),
-        nn.GELU(),
-        nn.Linear(inner_dim, dim, bias=False),
-    )
-
-
-def reshape_tensor(x, heads):
-    bs, length, width = x.shape
-    #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
-    x = x.view(bs, length, heads, -1)
-    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
-    x = x.transpose(1, 2)
-    # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
-    x = x.reshape(bs, heads, length, -1)
-    return x
-
-
-class PerceiverAttention(nn.Module):
-    def __init__(self, *, dim, dim_head=64, heads=8):
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.dim_head = dim_head
-        self.heads = heads
-        inner_dim = dim_head * heads
-
-        self.norm1 = nn.LayerNorm(dim)
-        self.norm2 = nn.LayerNorm(dim)
-
-        self.to_q = nn.Linear(dim, inner_dim, bias=False)
-        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
-        self.to_out = nn.Linear(inner_dim, dim, bias=False)
-
-
-    def forward(self, x, latents):
-        """
-        Args:
-            x (torch.Tensor): image features
-                shape (b, n1, D)
-            latent (torch.Tensor): latent features
-                shape (b, n2, D)
-        """
-        x = self.norm1(x)
-        latents = self.norm2(latents)
-
-        b, l, _ = latents.shape
-
-        q = self.to_q(latents)
-        kv_input = torch.cat((x, latents), dim=-2)
-        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
-
-        q = reshape_tensor(q, self.heads)
-        k = reshape_tensor(k, self.heads)
-        v = reshape_tensor(v, self.heads)
-
-        # attention
-        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
-        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        out = weight @ v
-
-        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
-
-        return self.to_out(out)
-
-
-class Resampler(nn.Module):
-    def __init__(
-        self,
-        dim=1024,
-        depth=8,
-        dim_head=64,
-        heads=16,
-        num_queries=8,
-        embedding_dim=768,
-        output_dim=1024,
-        ff_mult=4,
-    ):
-        super().__init__()
-
-        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
-
-        self.proj_in = nn.Linear(embedding_dim, dim)
-
-        self.proj_out = nn.Linear(dim, output_dim)
-        self.norm_out = nn.LayerNorm(output_dim)
-
-        self.layers = nn.ModuleList([])
-        for _ in range(depth):
-            self.layers.append(
-                nn.ModuleList(
-                    [
-                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
-                        FeedForward(dim=dim, mult=ff_mult),
-                    ]
-                )
-            )
-
-    def forward(self, x):
-
-        latents = self.latents.repeat(x.size(0), 1, 1)
-
-        x = self.proj_in(x)
-
-        for attn, ff in self.layers:
-            latents = attn(x, latents) + latents
-            latents = ff(latents) + latents
-
-        latents = self.proj_out(latents)
-        return self.norm_out(latents)
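
For reference, a minimal sketch (not part of this commit) of how the Resampler deleted above is typically driven: a fixed set of learned latent queries cross-attends to image-encoder patch features and is projected to the width expected by the UNet's cross-attention layers. All dimensions below are illustrative assumptions, and the import path refers to the module as it existed before this deletion.

```python
# Hypothetical usage sketch; every dimension here is a placeholder.
import torch

from ip_adapter.resampler import Resampler

image_proj_model = Resampler(
    dim=1280,            # internal width of the resampler
    depth=4,
    dim_head=64,
    heads=20,
    num_queries=16,      # number of learned queries -> number of image tokens
    embedding_dim=1280,  # width of the image-encoder hidden states fed in
    output_dim=2048,     # cross-attention width of the receiving UNet
    ff_mult=4,
)

# e.g. patch-level hidden states from a CLIP vision tower: (batch, num_patches, embedding_dim)
clip_patch_embeds = torch.randn(2, 257, 1280)
image_tokens = image_proj_model(clip_patch_embeds)
print(image_tokens.shape)  # torch.Size([2, 16, 2048])
```
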
ip_adapter/utils.py DELETED
@@ -1,5 +0,0 @@
-import torch.nn.functional as F
-
-
-def is_torch2_available():
-    return hasattr(F, "scaled_dot_product_attention")
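
For reference, a minimal sketch (not part of this commit) of how this helper is commonly used: select the PyTorch 2.x SDPA-based processors when `F.scaled_dot_product_attention` exists, and fall back to the xformers/vanilla variants otherwise. The import paths mirror the module layout as it existed before this deletion.

```python
# Hypothetical usage sketch mirroring the pre-deletion module layout.
from ip_adapter.utils import is_torch2_available

if is_torch2_available():
    from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
else:
    from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
```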