# Repository file: DASS_small_AudioSet_47.2 / configuration_dass.py
# Author: saurabhati — commit 4e65175 ("Upload DASSForAudioClassification", verified)
# coding=utf-8
"""Distilled Audio State-Space Model (DASS) configuration"""
from typing import Any, Dict
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger from `transformers.utils.logging`, named after this module.
# NOTE(review): not referenced anywhere in the code visible in this file.
logger = logging.get_logger(name=__name__)
class DASSConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DASSModel`]. It is used to instantiate a DASS
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the
    [DASS-small](https://github.com/Saurabhbhati/DASS/) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        patch_size (`int`, *optional*, defaults to 4):
            The size (resolution) of each patch.
        embed_dim (`int`, *optional*, defaults to 96):
            Dimensionality of patch embedding.
        depths (`list(int)`, *optional*, defaults to `[2, 2, 8, 2]`):
            Depth of each layer in the DASS encoder. The sequence is copied into a new list, so mutating the
            argument afterwards does not affect this configuration (and vice versa).
        dims (`list(int)`, *optional*, defaults to `[96, 192, 384, 768]`):
            Dimensionality of each layer in the DASS encoder. Copied into a new list, like `depths`.
        drop_path_rate (`float`, *optional*, defaults to 0.2):
            Stochastic depth rate.
        num_classes (`int`, *optional*, defaults to 527):
            Number of classes for classification.
        max_length (`int`, *optional*, defaults to 1024):
            Temporal dimension of the spectrograms.
        num_mel_bins (`int`, *optional*, defaults to 128):
            Frequency dimension of the spectrograms (number of Mel-frequency bins).
        use_checkpoint (`bool`, *optional*, defaults to `False`):
            Whether to use checkpointing to save memory.
        kwargs:
            Forwarded unchanged to [`PretrainedConfig`]'s constructor.

    Example:

    ```python
    >>> from transformers import DASSConfig, DASSModel

    >>> # Initializing a DASS small style configuration
    >>> configuration = DASSConfig()

    >>> # Initializing a model (with random weights) from the DASS small style configuration
    >>> model = DASSModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "dass"

    def __init__(
        self,
        patch_size: int = 4,
        embed_dim: int = 96,
        depths: list = [2, 2, 8, 2],
        dims: list = [96, 192, 384, 768],
        drop_path_rate: float = 0.2,
        num_classes: int = 527,
        max_length: int = 1024,
        num_mel_bins: int = 128,
        use_checkpoint: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # The list defaults above are created once, when the function is defined, and are shared by every call.
        # Storing them by reference would let an in-place edit on one config (e.g. `config.depths[2] = 4`)
        # silently change the defaults of every config created afterwards. Copy them instead; `None` is passed
        # through unchanged so explicitly passing it behaves as before.
        self.depths = list(depths) if depths is not None else None
        self.dims = list(dims) if dims is not None else None
        self.drop_path_rate = drop_path_rate
        self.num_classes = num_classes
        self.max_length = max_length
        self.num_mel_bins = num_mel_bins
        self.use_checkpoint = use_checkpoint

    # Overridden from the parent class: DASS does not support `generate`, but it has a config attribute with the
    # same name as a generation parameter (`max_length`). That shared name would trigger the parent's
    # config -> generation_config deprecation checks for generative parameters; returning an empty dict here
    # prevents this from happening.
    def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
        """Report no non-default generation parameters, because this config is never used for generation."""
        return {}
# Public API of this module: only the configuration class. The name is taken from the class object so the
# exported string always matches the class (value: ["DASSConfig"]).
__all__ = [DASSConfig.__name__]