| import torch.nn as nn | |
| from transformers import PretrainedConfig | |
class cceVAEConfig(PretrainedConfig):
    """Configuration container for the ``cceVAE`` model type.

    Each constructor argument is stored, unmodified, as an instance attribute
    of the same name. Any extra keyword arguments are handed to
    ``PretrainedConfig.__init__``.

    Args:
        d: Defaults to ``2``. NOTE(review): presumably the spatial
            dimensionality of the data -- confirm against the model code.
        input_size: Defaults to ``(1, 256, 256)``. NOTE(review): looks like a
            (channels, height, width) shape -- confirm.
        z_dim: Defaults to ``1024``. NOTE(review): presumably the latent
            size -- confirm.
        fmap_sizes: Defaults to ``(16, 64, 256, 1024)``. NOTE(review):
            presumably per-stage feature-map widths -- confirm.
        to_1x1: Defaults to ``True``; meaning is defined by the consumer of
            this config.
        conv_params: Optional settings stored as-is; defaults to ``None``.
        tconv_params: Optional settings stored as-is; defaults to ``None``.
        normalization_op: Optional value stored as-is; defaults to ``None``.
        normalization_params: Optional settings stored as-is; defaults to
            ``None``.
        activation_op: Defaults to ``"prelu"``.
        activation_params: Optional settings stored as-is; defaults to
            ``None``.
        block_op: Optional value stored as-is; defaults to ``None``.
        block_params: Optional settings stored as-is; defaults to ``None``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "cceVAE"

    def __init__(
        self,
        d=2,
        input_size=(1, 256, 256),
        z_dim=1024,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_params=None,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op="prelu",
        activation_params=None,
        block_op=None,
        block_params=None,
        **kwargs,
    ):
        # Size / shape hyper-parameters.
        self.d = d
        self.input_size = input_size
        self.z_dim = z_dim
        self.fmap_sizes = fmap_sizes
        self.to_1x1 = to_1x1

        # Operation selectors.
        self.normalization_op = normalization_op
        self.activation_op = activation_op
        self.block_op = block_op

        # Per-operation parameter bundles.
        self.conv_params = conv_params
        self.tconv_params = tconv_params
        self.normalization_params = normalization_params
        self.activation_params = activation_params
        self.block_params = block_params

        # Own fields are set first; the base class consumes whatever is left.
        super().__init__(**kwargs)