File size: 3,523 Bytes
4e65175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# coding=utf-8
"""Distilled Audio State-Space Model (DASS) configuration"""

from typing import Any, Dict, List, Optional

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


# Module-level logger, obtained through the `transformers.utils.logging` helper and keyed on this module's name.
logger = logging.get_logger(__name__)

class DASSConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DASSModel`]. It is used to instantiate a DASS
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the
    [DASS-small](https://github.com/Saurabhbhati/DASS/) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        patch_size (`int`, *optional*, defaults to 4):
            The size (resolution) of each patch.
        embed_dim (`int`, *optional*, defaults to 96):
            Dimensionality of patch embedding.
        depths (`List[int]`, *optional*, defaults to `[2, 2, 8, 2]`):
            Depth of each layer in the DASS encoder. The configuration stores its own copy of the sequence.
        dims (`List[int]`, *optional*, defaults to `[96, 192, 384, 768]`):
            Dimensionality of each layer in the DASS encoder. The configuration stores its own copy of the sequence.
        drop_path_rate (`float`, *optional*, defaults to 0.2):
            Stochastic depth rate.
        num_classes (`int`, *optional*, defaults to 527):
            Number of classes for classification.
        max_length (`int`, *optional*, defaults to 1024):
            Temporal dimension of the spectrograms.
        num_mel_bins (`int`, *optional*, defaults to 128):
            Frequency dimension of the spectrograms (number of Mel-frequency bins).
        use_checkpoint (`bool`, *optional*, defaults to `False`):
            Whether to use checkpointing to save memory.
        kwargs:
            Additional keyword arguments, forwarded unchanged to [`PretrainedConfig`].

    Example:

    ```python
    >>> from transformers import DASSConfig, DASSModel

    >>> # Initializing a DASS small style configuration
    >>> configuration = DASSConfig()

    >>> # Initializing a model (with random weights) from the DASS small style configuration
    >>> model = DASSModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "dass"

    def __init__(
        self,
        patch_size: int = 4,
        embed_dim: int = 96,
        depths: Optional[List[int]] = None,
        dims: Optional[List[int]] = None,
        drop_path_rate: float = 0.2,
        num_classes: int = 527,
        max_length: int = 1024,
        num_mel_bins: int = 128,
        use_checkpoint: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # `None` sentinels replace the former list-literal defaults: a list literal in the signature is created once
        # and shared by every instance, so an in-place edit of `config.depths` would silently change the defaults of
        # all later configurations. Each instance now gets its own list (caller-supplied sequences are copied too).
        self.depths = [2, 2, 8, 2] if depths is None else list(depths)
        self.dims = [96, 192, 384, 768] if dims is None else list(dims)
        self.drop_path_rate = drop_path_rate
        self.num_classes = num_classes
        self.max_length = max_length
        self.num_mel_bins = num_mel_bins
        self.use_checkpoint = use_checkpoint

    # Overwritten from the parent class: DASS is not compatible with `generate`, but has a config parameter sharing the
    # same name (`max_length`). Sharing the same name triggers checks regarding the config -> generation_config
    # generative parameters deprecation cycle; overwriting this function prevents this from happening.
    def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
        """Return an empty dict so that `max_length` is never reported as a non-default generation parameter."""
        return {}


# Public API of this module: only the configuration class is exported.
__all__ = ["DASSConfig"]