vedant-jumle commited on
Commit
62f01f5
Β·
1 Parent(s): 94ef592

Initial CosAE release

Browse files
Files changed (3) hide show
  1. README.md +105 -0
  2. config.json +40 -0
  3. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CosAE: Convolutional Harmonic Autoencoder
2
+ CosAE is a PyTorch implementation of the Convolutional Harmonic Autoencoder (CosAE). It encodes images into learnable harmonic representations (amplitudes and phases), constructs spatial cosine bases via a Harmonic Construction Module, and decodes back to RGB images. This repository provides the core model code, a Jupyter notebook for training and evaluation, and pretrained weights in SafeTensors format.
3
+
4
+ ## Features
5
+ - Convolutional encoder with residual blocks and optional global attention
6
+ - Harmonic Construction Module (HCM) learning per-channel frequencies
7
+ - Decoder with upsampling and optional global attention for reconstruction
8
+ - Supports raw RGB input or augmented FFT channels (RGB + FFT)
9
+ - Pretrained model weights available (SafeTensors)
10
+ - Training and tracking via Jupyter notebook and Weights & Biases
11
+
12
+ ## Installation
13
+ ### Requirements
14
+ ```bash
15
+ pip install -r requirements.txt
16
+ ```
17
+
18
+ ## Quickstart
19
+ ### Inference with Pretrained Model
20
+ ```python
21
+ import torch
22
+ from cosae import CosAEModel
23
+
24
+ # Load pretrained CosAE (expects 9-channel input: 3 RGB + 6 FFT)
25
+ model = CosAEModel.from_pretrained("model/model-final")
26
+ model.eval()
27
+
28
+ # Dummy input: batch of 1, 9 channels, 256Γ—256
29
+ x = torch.randn(1, 9, 256, 256)
30
+ with torch.no_grad():
31
+ recon = model(x) # recon.shape == [1, 3, 256, 256]
32
+ ```
33
+
34
+ ### Preparing Raw Images
35
+ ```python
36
+ from PIL import Image
37
+ from torchvision import transforms
38
+
39
+ transform = transforms.Compose([
40
+ transforms.Resize((256, 256)),
41
+ transforms.ToTensor(),
42
+ ])
43
+ img = Image.open("path/to/image.png").convert("RGB")
44
+ tensor = transform(img).unsqueeze(0) # [1, 3, 256, 256]
45
+ ```
46
+
47
+ ### Generating FFT Channels (Optional)
48
+ To use FFT-augmented input (in_channels=9), compute the 2D FFT per RGB channel and stack real/imaginary parts:
49
+ ```python
50
+ import torch
51
+ fft = torch.fft.rfft2(tensor, norm="ortho") # complex tensor [1,3,H,W/2+1]
52
+ real, imag = fft.real, fft.imag
53
+ # Optionally pad or reshape to match full spatial dims
54
+ x9 = torch.cat([tensor, real, imag], dim=1) # [1,9,H,W]
55
+ ```
56
+
57
+ ## Creating a Model from Scratch
58
+ Use the `CosAEConfig` and `CosAEModel` classes to instantiate a model with custom settings:
59
+ ```python
60
+ from cosae.config import CosAEConfig
61
+ from cosae.cosae import CosAEModel
62
+
63
+ # 1) Define a configuration (example uses RGB input only)
64
+ config = CosAEConfig(
65
+ in_channels=3,
66
+ hidden_dims=[64, 128, 256, 512],
67
+ downsample_strides=[2, 2, 2, 2],
68
+ use_encoder_attention=True,
69
+ encoder_attention_heads=8,
70
+ encoder_attention_layers=1,
71
+ bottleneck_channels=256,
72
+ basis_size=16,
73
+ decoder_hidden_dim=256,
74
+ decoder_upsample_strides=[2],
75
+ use_decoder_attention=False,
76
+ )
77
+
78
+ # 2) Instantiate the model
79
+ model = CosAEModel(config)
80
+
81
+ # 3) (Optional) Save the model and config for later reuse
82
+ model.save_pretrained("./my_cosae_model") # creates model-final.safetensors and config.json
83
+ ```
84
+
85
+ ## Training and Evaluation
86
+ A full training and evaluation pipeline is provided in `cosine-ae.ipynb`. Launch Jupyter to run experiments and track metrics with Weights & Biases:
87
+ ```bash
88
+ jupyter lab cosine-ae.ipynb
89
+ ```
90
+
91
+ ## Repository Structure
92
+ ```
93
+ .
94
+ β”œβ”€β”€ cosae/ # Core model implementation (config, encoder, HCM, decoder)
95
+ β”œβ”€β”€ model/ # Pretrained weights and config (SafeTensors)
96
+ β”œβ”€β”€ cosine-ae.ipynb # Notebook: training, evaluation, and demos
97
+ β”œβ”€β”€ LICENSE # MIT License
98
+ └── README.md # Project overview and usage
99
+ ```
100
+
101
+ ## Contributing
102
+ Contributions, issues, and feature requests are welcome. Feel free to fork the repository and submit pull requests.
103
+
104
+ ## License
105
+ This project is released under the MIT License. See [LICENSE](LICENSE) for details.
config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation": "gelu",
3
+ "architectures": [
4
+ "CosAEModel"
5
+ ],
6
+ "basis_size": 16,
7
+ "bottleneck_channels": 256,
8
+ "decoder_attention_heads": 8,
9
+ "decoder_attention_layers": 0,
10
+ "decoder_hidden_dim": 256,
11
+ "decoder_upsample_strides": [
12
+ 2
13
+ ],
14
+ "downsample_strides": [
15
+ 2,
16
+ 2,
17
+ 2,
18
+ 2
19
+ ],
20
+ "encoder_attention_heads": 8,
21
+ "encoder_attention_layers": 1,
22
+ "hidden_dims": [
23
+ 64,
24
+ 128,
25
+ 256,
26
+ 512
27
+ ],
28
+ "image_size": [
29
+ 256,
30
+ 256
31
+ ],
32
+ "in_channels": 9,
33
+ "model_type": "cosae",
34
+ "norm_type": "gn",
35
+ "num_res_blocks": 2,
36
+ "torch_dtype": "float32",
37
+ "transformers_version": "4.51.3",
38
+ "use_decoder_attention": false,
39
+ "use_encoder_attention": true
40
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52bfdf9130a8688360fad5fba4763aca935a358c6221e8e0b2f009fdc7930f2f
3
+ size 72900732