Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,322 Bytes
1034391 4aa0f34 1034391 4aa0f34 1034391 4aa0f34 1034391 4aa0f34 1034391 4aa0f34 1034391 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
"""Configuration management module for the Dia model.
This module provides comprehensive configuration management for the Dia model,
utilizing Pydantic for validation. It defines configurations for data processing,
model architecture (encoder and decoder), and training settings.
Key components:
- DataConfig: Parameters for data loading and preprocessing.
- EncoderConfig: Architecture details for the encoder module.
- DecoderConfig: Architecture details for the decoder module.
- ModelConfig: Combined model architecture settings.
- TrainingConfig: Empty placeholder retained for backwards compatibility.
- DiaConfig: Master configuration combining all components.
"""
import os
from typing import Annotated
from pydantic import BaseModel, BeforeValidator, Field
def _round_up_to_multiple_of_128(value: int) -> int:
    """Return ``value`` rounded up to the nearest multiple of 128."""
    return (value + 127) // 128 * 128


class DataConfig(BaseModel, frozen=True):
    """Immutable settings for data loading and preprocessing.

    Attributes:
        text_length: Maximum length of text sequences. The supplied value is
            rounded up to a multiple of 128 before validation.
        audio_length: Maximum length of audio sequences. The supplied value is
            rounded up to a multiple of 128 before validation.
        channels: Number of audio channels (defaults to 9).
        text_pad_value: Value used for padding text sequences (defaults to 0).
        audio_eos_value: Value marking the end of audio sequences
            (defaults to 1024).
        audio_pad_value: Value used for padding audio sequences
            (defaults to 1025).
        audio_bos_value: Value marking the beginning of audio sequences
            (defaults to 1026).
        delay_pattern: Non-negative delay values, one entry per list element
            (defaults to ``[0, 8, 9, 10, 11, 12, 13, 14, 15]``).
    """

    # Sequence lengths: normalised by the before-validator, then checked to be
    # positive multiples of 128.
    text_length: Annotated[
        int, BeforeValidator(_round_up_to_multiple_of_128)
    ] = Field(gt=0, multiple_of=128)
    audio_length: Annotated[
        int, BeforeValidator(_round_up_to_multiple_of_128)
    ] = Field(gt=0, multiple_of=128)

    channels: int = Field(default=9, gt=0, multiple_of=1)

    # Special token values.
    text_pad_value: int = Field(default=0)
    audio_eos_value: int = Field(default=1024)
    audio_pad_value: int = Field(default=1025)
    audio_bos_value: int = Field(default=1026)

    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(
        default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
    )

    def __hash__(self) -> int:
        """Return a hash computed from every field of the config.

        ``delay_pattern`` is a list, so it is converted to a tuple to make it
        hashable.
        """
        field_values = (
            self.text_length,
            self.audio_length,
            self.channels,
            self.text_pad_value,
            self.audio_pad_value,
            self.audio_bos_value,
            self.audio_eos_value,
            tuple(self.delay_pattern),
        )
        return hash(field_values)
class EncoderConfig(BaseModel, frozen=True):
    """Immutable architecture settings for the Dia encoder.

    Every field is required and must be a strictly positive integer.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        n_head: Number of attention heads.
        head_dim: Dimension per attention head.
    """

    # Overall depth and width.
    n_layer: int = Field(gt=0)
    n_embd: int = Field(gt=0)
    n_hidden: int = Field(gt=0)

    # Attention geometry.
    n_head: int = Field(gt=0)
    head_dim: int = Field(gt=0)
class DecoderConfig(BaseModel, frozen=True):
    """Immutable architecture settings for the Dia decoder.

    Every field is required and must be a strictly positive integer.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        gqa_query_heads: Number of query heads for grouped-query
            self-attention.
        kv_heads: Number of key/value heads for grouped-query self-attention.
        gqa_head_dim: Dimension per query head for grouped-query
            self-attention.
        cross_query_heads: Number of query heads for cross-attention.
        cross_head_dim: Dimension per cross-attention head.
    """

    # Overall depth and width.
    n_layer: int = Field(gt=0)
    n_embd: int = Field(gt=0)
    n_hidden: int = Field(gt=0)

    # Grouped-query self-attention.
    gqa_query_heads: int = Field(gt=0)
    kv_heads: int = Field(gt=0)
    gqa_head_dim: int = Field(gt=0)

    # Cross-attention.
    cross_query_heads: int = Field(gt=0)
    cross_head_dim: int = Field(gt=0)
class ModelConfig(BaseModel, frozen=True):
    """Immutable container for the Dia model architecture settings.

    Attributes:
        encoder: Configuration for the encoder component (required).
        decoder: Configuration for the decoder component (required).
        src_vocab_size: Size of the source (text) vocabulary; positive,
            defaults to 128.
        tgt_vocab_size: Size of the target (audio code) vocabulary; positive,
            defaults to 1028.
        dropout: Dropout probability in the range [0.0, 1.0); defaults to 0.0.
        normalization_layer_epsilon: Non-negative epsilon for normalization
            layers; defaults to 1e-5.
        weight_dtype: Name of the weight data type as a string; defaults to
            "float32".
        rope_min_timescale: Minimum timescale for Rotary Positional Embeddings
            (RoPE); defaults to 1.
        rope_max_timescale: Maximum timescale for Rotary Positional Embeddings
            (RoPE); defaults to 10,000.
    """

    # Sub-module architectures.
    encoder: EncoderConfig
    decoder: DecoderConfig

    # Vocabulary sizes.
    src_vocab_size: int = Field(default=128, gt=0)
    tgt_vocab_size: int = Field(default=1028, gt=0)

    # Regularisation and numerics.
    dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
    normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
    weight_dtype: str = Field(default="float32", description="Weight precision")

    # RoPE timescale bounds.
    rope_min_timescale: int = Field(default=1, description="Timescale For global Attention")
    rope_max_timescale: int = Field(default=10_000, description="Timescale For global Attention")
class TrainingConfig(BaseModel, frozen=True):
    """Empty, immutable placeholder model that declares no fields."""
class DiaConfig(BaseModel, frozen=True):
    """Master configuration for the Dia model.

    Combines all sub-configurations into a single validated, immutable object.

    Attributes:
        version: Configuration version string (defaults to "1.0").
        model: Model architecture configuration.
        training: Training configuration, kept only for backwards
            compatibility.
        data: Data loading and processing configuration.
    """

    version: str = Field(default="1.0")
    model: ModelConfig
    # TODO: remove training. This is just for backwards compatibility.
    training: TrainingConfig
    data: DataConfig

    def save(self, path: str) -> None:
        """Save the current configuration instance to a JSON file.

        The parent directory of ``path`` is created if it does not exist.
        The path is used exactly as given; its extension is not checked.

        Args:
            path: The target file path to save the configuration.

        Raises:
            OSError: If the parent directory cannot be created or the file
                cannot be written.
        """
        parent_dir = os.path.dirname(path)
        # A bare filename has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create directories when there is one.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        config_json = self.model_dump_json(indent=2)
        with open(path, "w", encoding="utf-8") as f:
            f.write(config_json)

    @classmethod
    def load(cls, path: str) -> "DiaConfig | None":
        """Load and validate a Dia configuration from a JSON file.

        The path is used exactly as given; its extension is not checked.

        Args:
            path: The path to the configuration file.

        Returns:
            A validated DiaConfig instance, or None if the file is not found.

        Raises:
            pydantic.ValidationError: If the JSON content fails validation
                against the DiaConfig schema.
        """
        # Only the file read is guarded: a missing file yields None, while
        # validation errors propagate to the caller.
        try:
            with open(path, "r", encoding="utf-8") as f:
                content = f.read()
        except FileNotFoundError:
            return None
        return cls.model_validate_json(content)
|