"""Configuration management module for the Dia model. | |
This module provides comprehensive configuration management for the Dia model, | |
utilizing Pydantic for validation. It defines configurations for data processing, | |
model architecture (encoder and decoder), and training settings. | |
Key components: | |
- DataConfig: Parameters for data loading and preprocessing. | |
- EncoderConfig: Architecture details for the encoder module. | |
- DecoderConfig: Architecture details for the decoder module. | |
- ModelConfig: Combined model architecture settings. | |
- TrainingConfig: Training hyperparameters and settings. | |
- DiaConfig: Master configuration combining all components. | |
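
Example (a minimal sketch; the dimensions below are illustrative placeholders,
not the settings of any released Dia checkpoint):

    config = DiaConfig(
        model=ModelConfig(
            encoder=EncoderConfig(n_layer=12, n_embd=1024, n_hidden=4096, n_head=16, head_dim=64),
            decoder=DecoderConfig(
                n_layer=18, n_embd=2048, n_hidden=8192,
                gqa_query_heads=16, kv_heads=4, gqa_head_dim=128,
                cross_query_heads=16, cross_head_dim=128,
            ),
        ),
        training=TrainingConfig(),
        data=DataConfig(text_length=1024, audio_length=3072),
    )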
""" | |
import os | |
from typing import Annotated | |
from pydantic import BaseModel, BeforeValidator, Field | |
class DataConfig(BaseModel, frozen=True): | |
"""Configuration for data loading and preprocessing. | |
Attributes: | |
text_length: Maximum length of text sequences (must be multiple of 128). | |
audio_length: Maximum length of audio sequences (must be multiple of 128). | |
channels: Number of audio channels. | |
text_pad_value: Value used for padding text sequences. | |
audio_eos_value: Value representing the end of audio sequences. | |
audio_bos_value: Value representing the beginning of audio sequences. | |
audio_pad_value: Value used for padding audio sequences. | |
delay_pattern: List of delay values for each audio channel. | |
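
    Example (lengths are rounded up to the next multiple of 128 before validation):

        cfg = DataConfig(text_length=1000, audio_length=1500)
        assert cfg.text_length == 1024 and cfg.audio_length == 1536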
""" | |
text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128) | |
audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128) | |
channels: int = Field(default=9, gt=0, multiple_of=1) | |
text_pad_value: int = Field(default=0) | |
audio_eos_value: int = Field(default=1024) | |
audio_pad_value: int = Field(default=1025) | |
audio_bos_value: int = Field(default=1026) | |
delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]) | |

    def __hash__(self) -> int:
        """Generate a hash based on all fields of the config."""
        return hash(
            (
                self.text_length,
                self.audio_length,
                self.channels,
                self.text_pad_value,
                self.audio_pad_value,
                self.audio_bos_value,
                self.audio_eos_value,
                tuple(self.delay_pattern),
            )
        )


class EncoderConfig(BaseModel, frozen=True):
    """Configuration for the encoder component of the Dia model.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        n_head: Number of attention heads.
        head_dim: Dimension per attention head.
        mlp_activations: List of activation functions for the MLP layers.
        use_pre_norm: Whether to use pre-normalization (LayerNorm before attention/MLP).
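
    Example (placeholder sizes; all dimension fields are required, while
    mlp_activations and use_pre_norm fall back to the defaults below):

        enc = EncoderConfig(n_layer=12, n_embd=1024, n_hidden=4096, n_head=16, head_dim=64)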
""" | |
n_layer: int = Field(gt=0) | |
n_embd: int = Field(gt=0) | |
n_hidden: int = Field(gt=0) | |
n_head: int = Field(gt=0) | |
head_dim: int = Field(gt=0) | |
mlp_activations: list[str] = Field(default=["silu", "linear"]) | |
use_pre_norm: bool = Field(default=False) | |


class DecoderConfig(BaseModel, frozen=True):
    """Configuration for the decoder component of the Dia model.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        gqa_query_heads: Number of query heads for grouped-query self-attention.
        kv_heads: Number of key/value heads for grouped-query self-attention.
        gqa_head_dim: Dimension per query head for grouped-query self-attention.
        cross_query_heads: Number of query heads for cross-attention.
        cross_head_dim: Dimension per cross-attention head.
        mlp_activations: List of activation functions for the MLP layers.
        use_pre_norm: Whether to use pre-normalization.
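
    Example (placeholder sizes; with grouped-query attention the query heads are
    grouped over the key/value heads, so gqa_query_heads is typically a multiple
    of kv_heads, here 16 query heads sharing 4 key/value heads):

        dec = DecoderConfig(
            n_layer=18, n_embd=2048, n_hidden=8192,
            gqa_query_heads=16, kv_heads=4, gqa_head_dim=128,
            cross_query_heads=16, cross_head_dim=128,
        )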
""" | |
n_layer: int = Field(gt=0) | |
n_embd: int = Field(gt=0) | |
n_hidden: int = Field(gt=0) | |
gqa_query_heads: int = Field(gt=0) | |
kv_heads: int = Field(gt=0) | |
gqa_head_dim: int = Field(gt=0) | |
cross_query_heads: int = Field(gt=0) | |
cross_head_dim: int = Field(gt=0) | |
mlp_activations: list[str] = Field(default=["silu", "linear"]) | |
use_pre_norm: bool = Field(default=False) | |


class ModelConfig(BaseModel, frozen=True):
    """Main configuration container for the Dia model architecture.

    Attributes:
        encoder: Configuration for the encoder component.
        decoder: Configuration for the decoder component.
        src_vocab_size: Size of the source (text) vocabulary.
        tgt_vocab_size: Size of the target (audio code) vocabulary.
        dropout: Dropout probability applied within the model.
        normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
        weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
        rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
        rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
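
    Example (using enc and dec from the sketches above; the default tgt_vocab_size
    of 1028 presumably leaves room for 1024 audio codes plus the special
    EOS/PAD/BOS values defined in DataConfig):

        model_cfg = ModelConfig(encoder=enc, decoder=dec)
        assert model_cfg.src_vocab_size == 128 and model_cfg.tgt_vocab_size == 1028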
""" | |
encoder: EncoderConfig | |
decoder: DecoderConfig | |
src_vocab_size: int = Field(default=128, gt=0) | |
tgt_vocab_size: int = Field(default=1028, gt=0) | |
dropout: float = Field(default=0.0, ge=0.0, lt=1.0) | |
normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0) | |
weight_dtype: str = Field(default="float32", description="Weight precision") | |
rope_min_timescale: int = Field(default=1, description="Timescale For global Attention") | |
rope_max_timescale: int = Field(default=10_000, description="Timescale For global Attention") | |


class TrainingConfig(BaseModel, frozen=True):
    """Training process configuration and hyperparameters.

    Note: This configuration currently only includes precision settings.
    Other training parameters (like batch size, learning rate, optimizer settings)
    are assumed to be handled externally.

    Attributes:
        dtype: Data type for activations during training (e.g., "bfloat16", "float32").
        logits_dot_in_fp32: Whether to compute the final logits dot product in fp32 for stability.
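
    Example (the default keeps activations in bfloat16; override for full fp32):

        train_cfg = TrainingConfig(dtype="float32", logits_dot_in_fp32=True)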
""" | |
dtype: str = Field(default="bfloat16", description="Activation precision") | |
logits_dot_in_fp32: bool = Field(default=False) | |
class DiaConfig(BaseModel, frozen=True): | |
"""Master configuration for the Dia model. | |
Combines all sub-configurations into a single validated object. | |
Attributes: | |
version: Configuration version string. | |
model: Model architecture configuration. | |
training: Training process configuration (precision settings). | |
data: Data loading and processing configuration. | |
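
    Example (round-tripping a config through JSON; "config" is a DiaConfig built
    as in the module docstring and the path is hypothetical):

        config.save("checkpoints/dia/config.json")
        restored = DiaConfig.load("checkpoints/dia/config.json")
        assert restored == config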
""" | |
version: str = Field(default="1.0") | |
model: ModelConfig | |
training: TrainingConfig | |
data: DataConfig | |
def save(self, path: str) -> None: | |
"""Save the current configuration instance to a JSON file. | |
Ensures the parent directory exists and the file has a .json extension. | |
Args: | |
path: The target file path to save the configuration. | |
Raises: | |
ValueError: If the path is not a file with a .json extension. | |
""" | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
config_json = self.model_dump_json(indent=2) | |
with open(path, "w") as f: | |
f.write(config_json) | |

    @classmethod
    def load(cls, path: str) -> "DiaConfig | None":
        """Load and validate a Dia configuration from a JSON file.

        Args:
            path: The path to the configuration file.

        Returns:
            A validated DiaConfig instance if the file exists and is valid,
            otherwise None if the file is not found.

        Raises:
            pydantic.ValidationError: If the JSON content fails validation against the DiaConfig schema.
        """
        try:
            with open(path, "r") as f:
                content = f.read()
            return cls.model_validate_json(content)
        except FileNotFoundError:
            return None