File size: 4,539 Bytes
8a94acf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING, AutoConfig
class GraniteSpeechEncoderConfig(PretrainedConfig):
    """Settings container for the Granite speech encoder.

    Every named argument is stored unchanged on the instance under the same
    name. Whatever remains in ``kwargs`` is forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): the descriptions below are inferred from the parameter
    names alone; confirm them against the encoder implementation that
    consumes this config.

    Args:
        input_dim: Presumably the size of one input feature frame.
        num_layers: Presumably the number of encoder layers.
        hidden_dim: Presumably the encoder's hidden width.
        feedforward_mult: Presumably the feed-forward expansion multiplier.
        num_heads: Presumably the number of attention heads.
        dim_head: Presumably the width of a single attention head.
        output_dim: Presumably the size of the encoder's output projection.
        context_size: Presumably the attention context/block length.
        dropout: Presumably the dropout probability.
        conv_kernel_size: Presumably the convolution kernel width.
        conv_expansion_factor: Presumably the convolution channel expansion.
        **kwargs: Passed through to ``PretrainedConfig``.
    """

    model_type = "granite_speech_encoder"

    def __init__(
        self,
        input_dim=160,
        num_layers=10,
        hidden_dim=1024,
        feedforward_mult=4,
        num_heads=8,
        dim_head=128,
        output_dim=42,
        context_size=200,
        dropout=0.1,
        conv_kernel_size=15,
        conv_expansion_factor=2,
        **kwargs,
    ):
        # Base-class setup runs first; the encoder-specific fields follow.
        super().__init__(**kwargs)

        # Layer stack and widths.
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.feedforward_mult = feedforward_mult

        # Attention-related values.
        self.num_heads = num_heads
        self.dim_head = dim_head

        # Output size, context length and dropout.
        self.output_dim = output_dim
        self.context_size = context_size
        self.dropout = dropout

        # Convolution-related values.
        self.conv_kernel_size = conv_kernel_size
        self.conv_expansion_factor = conv_expansion_factor
# Adapted from transformers.models.blip_2.configuration_blip_2.Blip2QFormerConfig
class GraniteSpeechProjectorConfig(PretrainedConfig):
    """Settings container for the Granite speech projector (Q-Former style).

    Every named argument is stored unchanged on the instance under the same
    name. ``pad_token_id`` is not stored here directly; it is handed to
    ``PretrainedConfig.__init__`` together with the remaining ``kwargs``.

    NOTE(review): the descriptions below are inferred from the parameter
    names alone; confirm them against the projector implementation that
    consumes this config.

    Args:
        llm_dim: Presumably the embedding width of the language model.
        downsample_rate: Presumably the temporal downsampling factor.
        window_size: Presumably the number of frames per projector window.
        hidden_size: Presumably the projector's hidden width.
        num_attention_heads: Presumably the number of attention heads.
        intermediate_size: Presumably the feed-forward inner width.
        num_hidden_layers: Presumably the number of projector layers.
        encoder_hidden_size: Presumably the width of the incoming encoder
            states.
        cross_attention_frequency: Presumably how often a cross-attention
            layer is inserted.
        max_position_embeddings: Presumably the maximum sequence length.
        hidden_act: Presumably the activation function name.
        hidden_dropout_prob: Presumably dropout on hidden states.
        attention_probs_dropout_prob: Presumably dropout on attention
            probabilities.
        initializer_range: Presumably the std-dev used for weight init.
        layer_norm_eps: Presumably the layer-norm epsilon.
        pad_token_id: Forwarded to ``PretrainedConfig``.
        position_embedding_type: Presumably the position-embedding scheme.
        use_qformer_text_input: Presumably whether text is fed to the
            Q-Former.
        **kwargs: Passed through to ``PretrainedConfig``.
    """

    model_type = "granite_speech_qformer"

    def __init__(
        self,
        llm_dim=4096,
        downsample_rate=5,
        window_size=15,
        hidden_size=1024,
        num_attention_heads=16,
        intermediate_size=4096,
        num_hidden_layers=2,
        encoder_hidden_size=1024,
        cross_attention_frequency=1,
        max_position_embeddings=2048,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_qformer_text_input=False,
        **kwargs,
    ):
        # The padding id is owned by the base class, so it travels with kwargs.
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # Fields are recorded in the same order as the signature.
        self.llm_dim = llm_dim
        self.downsample_rate = downsample_rate
        self.window_size = window_size

        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.encoder_hidden_size = encoder_hidden_size
        self.cross_attention_frequency = cross_attention_frequency
        self.max_position_embeddings = max_position_embeddings

        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        self.position_embedding_type = position_embedding_type
        self.use_qformer_text_input = use_qformer_text_input
class GraniteSpeechConfig(PretrainedConfig):
    """Top-level configuration that bundles the three Granite speech sub-configs.

    The constructor normalises each sub-config argument into a config object:

    * ``text_config``: a ``dict`` is turned into the class registered in
      ``CONFIG_MAPPING`` under its ``"model_type"`` entry (``"granite"`` when
      that key is absent); ``None`` yields a default ``"granite"`` config; any
      other value is stored as given.
    * ``projector_config``: a ``dict`` is expanded into
      ``GraniteSpeechProjectorConfig``; ``None`` yields the default projector
      config; any other value is stored as given.
    * ``encoder_config``: a ``GraniteSpeechEncoderConfig`` instance is stored
      as given; ``None`` yields the default encoder config; anything else is
      expanded with ``**`` into ``GraniteSpeechEncoderConfig`` and therefore
      must be a mapping.

    Args:
        encoder_config: Encoder settings (instance, mapping, or ``None``).
        text_config: Text-model settings (config object, ``dict``, or
            ``None``). A ``dict`` passed here is left unmodified.
        projector_config: Projector settings (config object, ``dict``, or
            ``None``).
        audio_token_index: Stored as-is; presumably the token id that marks
            audio positions in the text input (confirm against the model).
        initializer_range: Stored as-is; presumably the std-dev used for
            weight initialisation (confirm against the model).
        has_lora_adapter: Stored as-is; presumably signals that the
            checkpoint ships a LoRA adapter (confirm against the model).
        **kwargs: Passed through to ``PretrainedConfig``.

    Raises:
        KeyError: If ``text_config`` is a ``dict`` whose ``"model_type"`` is
            not a key of ``CONFIG_MAPPING``.
    """

    model_type = "granite_speech"
    sub_configs = {
        "text_config": AutoConfig,
        "encoder_config": GraniteSpeechEncoderConfig,
        "projector_config": GraniteSpeechProjectorConfig,
    }

    def __init__(
        self,
        encoder_config=None,
        text_config=None,
        projector_config=None,
        audio_token_index=49155,
        initializer_range=0.02,
        has_lora_adapter=True,
        **kwargs,
    ):
        if isinstance(text_config, dict):
            # Work on a shallow copy so the caller's dict is never modified;
            # "granite" is only a fallback and is overridden by an explicit
            # "model_type" entry in the supplied dict.
            text_config = {"model_type": "granite", **text_config}
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["granite"]()

        if isinstance(projector_config, dict):
            # TODO - In the future, we should make this generic.
            projector_config = GraniteSpeechProjectorConfig(**projector_config)
        elif projector_config is None:
            projector_config = GraniteSpeechProjectorConfig()

        if not isinstance(encoder_config, GraniteSpeechEncoderConfig):
            encoder_config = {} if encoder_config is None else encoder_config
            encoder_config = GraniteSpeechEncoderConfig(**encoder_config)

        self.text_config = text_config
        self.encoder_config = encoder_config
        self.projector_config = projector_config

        self.audio_token_index = audio_token_index
        self.initializer_range = initializer_range
        self.has_lora_adapter = has_lora_adapter

        # The base initialiser runs last, after the sub-configs are attached.
        super().__init__(**kwargs)
# Public names exported by this module.
__all__ = [
    "GraniteSpeechEncoderConfig",
    "GraniteSpeechProjectorConfig",
    "GraniteSpeechConfig",
]
|