Alex Ergasti commited on
Commit
0e0633b
·
1 Parent(s): 9837429

Update model

Browse files
Files changed (1) hide show
  1. models.py +3 -13
models.py CHANGED
@@ -1,14 +1,3 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
-
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- # --------------------------------------------------------
7
- # References:
8
- # GLIDE: https://github.com/openai/glide-text2im
9
- # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
- # --------------------------------------------------------
11
-
12
  import torch
13
  import torch.nn as nn
14
  import numpy as np
@@ -16,6 +5,8 @@ import math
16
  from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
17
  import einops
18
 
 
 
19
  import torch.utils.checkpoint as checkpoint
20
 
21
  from transformers import PreTrainedModel
@@ -371,7 +362,7 @@ class FinalLayer(nn.Module):
371
  return x
372
 
373
 
374
- class FLAV(nn.Module):
375
  """
376
  Diffusion model with a Transformer backbone.
377
  """
@@ -748,4 +739,3 @@ FLAV_models = {
748
  'FLAV-B/1' : FLAV_B_1, 'FLAV-B/2': FLAV_B_2, 'FLAV-B/4': FLAV_B_4, 'FLAV-B/8': FLAV_B_8,
749
  'FLAV-S/2' : FLAV_S_2, 'FLAV-S/4': FLAV_S_4, 'FLAV-S/8': FLAV_S_8,
750
  }
751
-
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import numpy as np
 
5
  from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
6
  import einops
7
 
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+
10
  import torch.utils.checkpoint as checkpoint
11
 
12
  from transformers import PreTrainedModel
 
362
  return x
363
 
364
 
365
+ class FLAV(nn.Module, PyTorchModelHubMixin):
366
  """
367
  Diffusion model with a Transformer backbone.
368
  """
 
739
  'FLAV-B/1' : FLAV_B_1, 'FLAV-B/2': FLAV_B_2, 'FLAV-B/4': FLAV_B_4, 'FLAV-B/8': FLAV_B_8,
740
  'FLAV-S/2' : FLAV_S_2, 'FLAV-S/4': FLAV_S_4, 'FLAV-S/8': FLAV_S_8,
741
  }