|
"""Configuration management module for the Dia model. |
|
|
|
This module provides comprehensive configuration management for the Dia model, |
|
utilizing Pydantic for validation. It defines configurations for data processing, |
|
model architecture (encoder and decoder), and training settings. |
|
|
|
Key components: |
|
- DataConfig: Parameters for data loading and preprocessing. |
|
- EncoderConfig: Architecture details for the encoder module. |
|
- DecoderConfig: Architecture details for the decoder module. |
|
- ModelConfig: Combined model architecture settings. |
|
- TrainingConfig: Placeholder for training hyperparameters and settings (currently declares no fields).
|
- DiaConfig: Master configuration combining all components. |
|
""" |
|
|
|
import os |
|
from typing import Annotated |
|
|
|
from pydantic import BaseModel, BeforeValidator, Field |
|
|
|
|
|
def _round_up_to_multiple_of_128(value):
    """Round ``value`` up to the next multiple of 128 before Pydantic validates it.

    Any input that supports the arithmetic is rounded exactly as before.
    Inputs that do not (e.g. ``None`` or a string) are returned unchanged, so
    Pydantic's own ``int`` validation rejects or coerces them. Letting the
    arithmetic raise would surface a bare ``TypeError``, which Pydantic v2
    validators do not convert into a ``ValidationError``.
    """
    try:
        return (value + 127) // 128 * 128
    except TypeError:
        return value


class DataConfig(BaseModel, frozen=True):
    """Configuration for data loading and preprocessing.

    Attributes:
        text_length: Maximum length of text sequences. Numeric inputs are rounded
            up to the next multiple of 128; the result must be positive.
        audio_length: Maximum length of audio sequences. Numeric inputs are rounded
            up to the next multiple of 128; the result must be positive.
        channels: Number of audio channels (positive, default 9).
        text_pad_value: Value used for padding text sequences (default 0).
        audio_eos_value: Value representing the end of audio sequences (default 1024).
        audio_pad_value: Value used for padding audio sequences (default 1025).
        audio_bos_value: Value representing the beginning of audio sequences (default 1026).
        delay_pattern: Non-negative delay values, one per audio channel.
    """

    text_length: Annotated[int, BeforeValidator(_round_up_to_multiple_of_128)] = Field(
        gt=0, multiple_of=128
    )
    audio_length: Annotated[int, BeforeValidator(_round_up_to_multiple_of_128)] = Field(
        gt=0, multiple_of=128
    )
    channels: int = Field(default=9, gt=0, multiple_of=1)
    text_pad_value: int = Field(default=0)
    audio_eos_value: int = Field(default=1024)
    audio_pad_value: int = Field(default=1025)
    audio_bos_value: int = Field(default=1026)
    # default_factory gives every instance its own list rather than a shared one.
    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(
        default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
    )

    def __hash__(self) -> int:
        """Generate a hash based on all fields of the config.

        ``delay_pattern`` is converted to a tuple because lists are unhashable.
        """
        return hash(
            (
                self.text_length,
                self.audio_length,
                self.channels,
                self.text_pad_value,
                self.audio_pad_value,
                self.audio_bos_value,
                self.audio_eos_value,
                tuple(self.delay_pattern),
            )
        )
|
|
|
|
|
class EncoderConfig(BaseModel, frozen=True):
    """Architecture hyperparameters for the encoder component of the Dia model.

    Every field is required and must be a strictly positive integer.

    Attributes:
        n_layer: How many transformer layers the encoder stacks.
        n_embd: Width of the embedding vectors.
        n_hidden: Inner width of each MLP layer.
        n_head: How many attention heads each layer uses.
        head_dim: Width of a single attention head.
    """

    n_layer: Annotated[int, Field(gt=0)]
    n_embd: Annotated[int, Field(gt=0)]
    n_hidden: Annotated[int, Field(gt=0)]
    n_head: Annotated[int, Field(gt=0)]
    head_dim: Annotated[int, Field(gt=0)]
|
|
|
|
|
class DecoderConfig(BaseModel, frozen=True):
    """Architecture hyperparameters for the decoder component of the Dia model.

    Every field is required and must be a strictly positive integer.

    Attributes:
        n_layer: How many transformer layers the decoder stacks.
        n_embd: Width of the embedding vectors.
        n_hidden: Inner width of each MLP layer.
        gqa_query_heads: Query-head count for grouped-query self-attention.
        kv_heads: Key/value-head count for grouped-query self-attention.
        gqa_head_dim: Width of each query head in grouped-query self-attention.
        cross_query_heads: Query-head count for cross-attention.
        cross_head_dim: Width of each cross-attention head.
    """

    n_layer: Annotated[int, Field(gt=0)]
    n_embd: Annotated[int, Field(gt=0)]
    n_hidden: Annotated[int, Field(gt=0)]
    gqa_query_heads: Annotated[int, Field(gt=0)]
    kv_heads: Annotated[int, Field(gt=0)]
    gqa_head_dim: Annotated[int, Field(gt=0)]
    cross_query_heads: Annotated[int, Field(gt=0)]
    cross_head_dim: Annotated[int, Field(gt=0)]
|
|
|
|
|
class ModelConfig(BaseModel, frozen=True):
    """Main configuration container for the Dia model architecture.

    Attributes:
        encoder: Configuration for the encoder component.
        decoder: Configuration for the decoder component.
        src_vocab_size: Size of the source (text) vocabulary (positive, default 128).
        tgt_vocab_size: Size of the target (audio code) vocabulary (positive, default 1028).
        dropout: Dropout probability applied within the model, in [0.0, 1.0).
        normalization_layer_epsilon: Non-negative epsilon for normalization layers
            (e.g., LayerNorm).
        weight_dtype: Data type name for model weights (e.g., "float32", "bfloat16").
            The string is not validated here.
        rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
        rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
    """

    encoder: EncoderConfig
    decoder: DecoderConfig
    src_vocab_size: int = Field(default=128, gt=0)
    tgt_vocab_size: int = Field(default=1028, gt=0)
    dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
    normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
    weight_dtype: str = Field(default="float32", description="Weight precision")
    # Each bound gets its own description; both previously shared a copy-pasted
    # "Timescale For global Attention" text in the generated JSON schema.
    rope_min_timescale: int = Field(
        default=1,
        description="Minimum timescale for Rotary Positional Embeddings (RoPE)",
    )
    rope_max_timescale: int = Field(
        default=10_000,
        description="Maximum timescale for Rotary Positional Embeddings (RoPE)",
    )
|
|
|
|
|
class TrainingConfig(BaseModel, frozen=True):
    """Placeholder for training-process settings.

    The model currently declares no fields, so it validates from an empty
    mapping (e.g. ``{}`` in JSON).
    """
|
|
|
|
|
class DiaConfig(BaseModel, frozen=True):
    """Master configuration for the Dia model.

    Combines all sub-configurations into a single validated object.

    Attributes:
        version: Configuration version string (default "1.0").
        model: Model architecture configuration.
        training: Training process configuration (currently declares no fields).
        data: Data loading and processing configuration.
    """

    version: str = Field(default="1.0")
    model: ModelConfig
    training: TrainingConfig
    data: DataConfig

    def save(self, path: str) -> None:
        """Save the current configuration instance to a JSON file.

        The parent directory is created if it does not already exist. The file
        is written as UTF-8 JSON with 2-space indentation. The file extension
        is not checked.

        Args:
            path: The target file path to save the configuration.

        Raises:
            OSError: If the parent directory cannot be created or the file
                cannot be written.
        """
        parent_dir = os.path.dirname(path)
        # A bare filename such as "config.json" has an empty dirname, and
        # os.makedirs("") raises FileNotFoundError, so only create real dirs.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        config_json = self.model_dump_json(indent=2)
        # Pin UTF-8 so the output does not depend on the platform locale.
        with open(path, "w", encoding="utf-8") as f:
            f.write(config_json)

    @classmethod
    def load(cls, path: str) -> "DiaConfig | None":
        """Load and validate a Dia configuration from a JSON file.

        Args:
            path: The path to the configuration file.

        Returns:
            A validated DiaConfig instance, or None if no file exists at ``path``.

        Raises:
            pydantic.ValidationError: If the JSON content fails validation against
                the DiaConfig schema.
            OSError: For I/O failures other than a missing file (e.g. permissions).
        """
        # Keep the try narrow: only a missing file is translated into None.
        try:
            with open(path, "r", encoding="utf-8") as f:
                content = f.read()
        except FileNotFoundError:
            return None
        return cls.model_validate_json(content)
|
|