File size: 62,624 Bytes
8a94acf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 |
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import einsum, nn
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
ModelOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import AutoModelForCausalLM
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_peft_available,
logging,
replace_return_docstrings,
)
from .configuration_granite_speech import (
GraniteSpeechConfig,
GraniteSpeechEncoderConfig,
GraniteSpeechProjectorConfig,
)
# Module-level logger from transformers' logging utilities (used e.g. for the
# gradient-checkpointing / use_cache warning in the Q-Former encoder).
logger = logging.get_logger(__name__)
# Config class name for documentation purposes; not referenced in the visible part of
# this file — presumably consumed by docstring decorators further down (confirm).
_CONFIG_FOR_DOC = "GraniteSpeechConfig"
@dataclass
class GraniteSpeechCausalLMOutputWithPast(ModelOutput):
    """
    Base class for GraniteSpeech causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    # Defaults to None like every other field so ModelOutput can omit it.
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
### Projector
# Currently, we copy the QFormer code directly to avoid depending on Blip2;
# it would be better to create the model from config, similar to the LLM,
# but to do this, we will need to register the QFormer model into an automodel,
# which should involve pulling it out into its own dir so that it is accessible
# under transformers.models.X.
# Adapted from transformers.models.blip_2.modeling_blip_2.Blip2QFormerMultiHeadAttention with Blip2->GraniteSpeech
# (the attention-gradient hook is only registered when autograd can actually populate it).
class GraniteSpeechQFormerMultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention used by the Q-Former layers.

    Queries are always projected from ``hidden_states`` (``config.hidden_size`` features).
    Keys/values are projected from ``config.encoder_hidden_size`` features when the module is
    built with ``is_cross_attention=True``, and from ``config.hidden_size`` features otherwise.
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
                % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
            self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        # When True (and in cross-attention), forward() stores the attention map and, if
        # autograd is active, its gradient for later inspection.
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        """Backward hook target: remember the gradient of the cross-attention probabilities."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the gradient stored by ``save_attn_gradients`` (AttributeError if none was stored yet)."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Remember the most recent cross-attention probability map."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the map stored by ``save_attention_map`` (AttributeError if none was stored yet)."""
        return self.attention_map

    def transpose_for_scores(self, x):
        """Split the last dim into heads: ``(..., seq, all_head_size)`` -> ``(b, heads, seq, head_size)``."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Attend from ``hidden_states`` to itself (self-attention) or to ``encoder_hidden_states``.

        Args:
            hidden_states: query-side input.
            attention_mask: additive mask added to the raw scores (self-attention only; it is
                replaced by ``encoder_attention_mask`` when ``encoder_hidden_states`` is given).
            head_mask: multiplicative mask applied to the attention probabilities.
            encoder_hidden_states: if given, keys/values are computed from it (cross-attention).
            encoder_attention_mask: additive mask used for cross-attention.
            past_key_value: cached ``(key, value)`` concatenated in front of the new self-attention keys/values.
            output_attentions: also return the attention probabilities.

        Returns:
            ``(context_layer, attention_probs, past_key_value)`` if ``output_attentions`` else
            ``(context_layer, past_key_value)``, where ``context_layer`` is shaped
            ``(batch, query_len, all_head_size)`` and ``past_key_value`` is the ``(key, value)``
            pair used by this call.
        """
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            # NOTE(review): the distance matrix is built from the query length only, so this
            # path assumes key length == query length (no cache, no cross-attention).
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            else:  # "relative_key_query"
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # The mask is additive (0 to keep, large negative to drop), precomputed by the model.
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = attention_scores.softmax(dim=-1)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            # Tensor.register_hook raises on tensors that do not require grad (e.g. under
            # torch.no_grad()), so only hook when a backward pass can populate the gradient.
            if attention_probs.requires_grad:
                attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # Merge heads back: (b, heads, seq, head_size) -> (b, seq, all_head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        outputs = outputs + (past_key_value,)
        return outputs
# Adapted from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->GraniteSpeechQFormer
class GraniteSpeechQFormerSelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> LayerNorm over the residual sum."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_states`` and normalise it together with the residual ``input_tensor``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
# Adapted from transformers.models.blip_2.modeling_blip_2.Blip2QFormerAttention with Blip2->GraniteSpeech
class GraniteSpeechQFormerAttention(nn.Module):
    """Attention sub-layer: multi-head attention followed by the residual output projection."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.attention = GraniteSpeechQFormerMultiHeadAttention(config, is_cross_attention)
        self.output = GraniteSpeechQFormerSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove ``heads`` from the q/k/v projections and the output projection.

        Heads already pruned earlier are accounted for via ``self.pruned_heads``.
        """
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        attn = self.attention
        # q/k/v lose output rows; the output projection loses the matching input columns.
        for proj_name in ("query", "key", "value"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the head bookkeeping consistent with the smaller projections.
        attn.num_attention_heads -= len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run attention, then project its context with ``hidden_states`` as the residual.

        Returns the projected output followed by whatever extra items the inner attention
        produced (attention probabilities if requested, then the key/value pair).
        """
        context, *extras = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        projected = self.output(context, hidden_states)
        return (projected, *extras)
# Adapted from transformers.models.bert.modeling_bert.BertIntermediate with Bert->GraniteSpeechQFormer
class GraniteSpeechQFormerIntermediate(nn.Module):
    """Feed-forward expansion: ``hidden_size -> intermediate_size`` followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # ``hidden_act`` is either an ACT2FN key or a ready-made callable.
        act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the dense expansion and then the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
# Adapted from transformers.models.bert.modeling_bert.BertOutput with Bert->GraniteSpeechQFormer
class GraniteSpeechQFormerOutput(nn.Module):
    """Feed-forward contraction: ``intermediate_size -> hidden_size``, dropout, residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Contract ``hidden_states`` and normalise it together with the residual ``input_tensor``."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
# Adapted from transformers.models.blip_2.modeling_blip_2.Blip2QFormerLayer with Blip2->GraniteSpeech
class GraniteSpeechQFormerLayer(nn.Module):
    """One Q-Former block.

    Self-attention runs over every position. The first ``query_length`` positions (the query
    tokens) additionally get cross-attention over the encoder states on layers where
    ``layer_idx % config.cross_attention_frequency == 0`` and go through the query feed-forward
    (``intermediate_query``/``output_query``). Any remaining (text) positions use the text
    feed-forward (``intermediate``/``output``), which only exists when
    ``config.use_qformer_text_input`` is set.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = GraniteSpeechQFormerAttention(config)
        self.layer_idx = layer_idx

        self.has_cross_attention = layer_idx % config.cross_attention_frequency == 0
        if self.has_cross_attention:
            self.crossattention = GraniteSpeechQFormerAttention(config, is_cross_attention=True)

        if config.use_qformer_text_input:
            self.intermediate = GraniteSpeechQFormerIntermediate(config)
            self.output = GraniteSpeechQFormerOutput(config)

        self.intermediate_query = GraniteSpeechQFormerIntermediate(config)
        self.output_query = GraniteSpeechQFormerOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Run the block.

        Returns:
            ``(layer_output, [self_attn_probs], [cross_attn_probs], present_key_value)``; the
            bracketed entries only appear when ``output_attentions`` is set (and, for the cross
            entry, when cross-attention actually ran). ``present_key_value`` is the
            self-attention key/value pair.

        Raises:
            ValueError: on a cross-attention layer with ``query_length > 0`` but no
                ``encoder_hidden_states``.
        """
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        cached_self_kv = None if past_key_value is None else past_key_value[:2]
        self_results = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=cached_self_kv,
        )
        attn_out = self_results[0]
        extras = self_results[1:-1]
        present_key_value = self_results[-1]

        if query_length <= 0:
            # No query tokens: everything goes through the text feed-forward.
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attn_out
            )
            return (layer_output,) + extras + (present_key_value,)

        query_part = attn_out[:, :query_length, :]
        if self.has_cross_attention:
            if encoder_hidden_states is None:
                raise ValueError("encoder_hidden_states must be given for cross-attention layers")
            cross_results = self.crossattention(
                query_part,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions=output_attentions,
            )
            query_part = cross_results[0]
            # Cross-attention probabilities (if requested); its key/value pair is discarded.
            extras = extras + cross_results[1:-1]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk_query, self.chunk_size_feed_forward, self.seq_len_dim, query_part
        )
        if attn_out.shape[1] > query_length:
            text_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attn_out[:, query_length:, :],
            )
            layer_output = torch.cat([layer_output, text_output], dim=1)

        return (layer_output,) + extras + (present_key_value,)

    def feed_forward_chunk(self, attention_output):
        """Text feed-forward: expand, contract, residual LayerNorm."""
        return self.output(self.intermediate(attention_output), attention_output)

    def feed_forward_chunk_query(self, attention_output):
        """Query-token feed-forward: expand, contract, residual LayerNorm."""
        return self.output_query(self.intermediate_query(attention_output), attention_output)
# Adapted from transformers.models.blip_2.modeling_blip_2.Blip2QFormerEncoder with Blip2->GraniteSpeech
# (gradient checkpointing honours `self.gradient_checkpointing` and forwards every layer argument).
class GraniteSpeechQFormerEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` Q-Former layers.

    In training mode, each layer is run through ``self._gradient_checkpointing_func`` when
    ``self.gradient_checkpointing`` is set or the legacy ``config.gradient_checkpointing`` flag is
    truthy. ``_gradient_checkpointing_func`` is not defined here; it is expected to be installed
    externally (presumably by ``PreTrainedModel.gradient_checkpointing_enable()`` — confirm).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [GraniteSpeechQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run all layers in order.

        Returns:
            A ``BaseModelOutputWithPastAndCrossAttentions`` (or, if ``return_dict`` is False, the
            tuple of its non-None fields in the order: last hidden state, cache, hidden states,
            self-attentions, cross-attentions).
        """
        use_checkpointing = self.training and (
            self.gradient_checkpointing or getattr(self.config, "gradient_checkpointing", False)
        )
        # Decide on caching before allocating the cache so a disabled cache is reported as None.
        if use_checkpointing and use_cache:
            logger.warning(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            # The same positional arguments go to the layer with or without checkpointing;
            # dropping `query_length` would silently skip cross-attention and use the wrong FFN.
            layer_args = (
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
                query_length,
            )
            if use_checkpointing:
                layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, *layer_args)
            else:
                layer_outputs = layer_module(*layer_args)

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # Layers only run (and report) cross-attention when there are query tokens.
                if layer_module.has_cross_attention and query_length > 0:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
class GraniteSpeechEncoderProjectorPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GraniteSpeechProjectorConfig
    base_model_prefix = "qformer"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "GraniteSpeechQFormerMultiHeadAttention",
        "T5Block",
        "OPTDecoderLayer",
    ]
    _skip_keys_device_placement = "past_key_values"
    # _keep_in_fp32_modules = ["query_tokens"]

    def _init_weights(self, module):
        """Initialize ``module``'s weights in place.

        Conv2d / Embedding / Linear weights are drawn from ``N(0, config.initializer_range)`` and
        their biases (when present) are zeroed; LayerNorm is reset to weight 1 and bias 0.
        """
        std = self.config.initializer_range
        if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=std)
            # nn.Embedding has no `bias` attribute; Linear/Conv2d may be built with bias=False.
            if getattr(module, "bias", None) is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # Guard the non-affine case (elementwise_affine=False leaves weight/bias as None).
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
class GraniteSpeechQFormerModel(GraniteSpeechEncoderProjectorPreTrainedModel):
    """
    Querying Transformer (Q-Former), used in GraniteSpeech.

    This variant has no token-embedding layer: callers pass ``query_embeds`` directly, which are
    layer-normalised, passed through dropout and then fed to the Q-Former encoder stack.
    """

    def __init__(self, config: GraniteSpeechProjectorConfig):
        super().__init__(config)
        self.config = config

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = GraniteSpeechQFormerEncoder(config)

        self.post_init()

    def get_input_embeddings(self):
        """Not supported: this Q-Former has no word-embedding layer.

        Raises:
            NotImplementedError: always (the BLIP-2 original returned ``self.embeddings.word_embeddings``,
                which does not exist here and used to fail with an opaque AttributeError).
        """
        raise NotImplementedError(
            f"{type(self).__name__} has no input embedding layer; pass `query_embeds` to forward() instead."
        )

    def set_input_embeddings(self, value):
        """Not supported: this Q-Former has no word-embedding layer.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{type(self).__name__} has no input embedding layer; pass `query_embeds` to forward() instead."
        )

    # Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerModel._prune_heads
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    # Adapted from transformers.models.blip_2.modeling_blip_2.Blip2QFormerModel.get_extended_attention_mask
    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int],
        device: torch.device,
        has_query: bool = False,
    ) -> torch.Tensor:
        """
        Makes broadcastable attention masks so that masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore. Either 2D
                `(batch_size, seq_length)` or 3D `(batch_size, from_seq_length, to_seq_length)`.
            input_shape (`Tuple[int]`):
                The shape of the input to the model (only used in the error message).
            device (`torch.device`):
                The device of the input to the model (not used by this implementation).
            has_query (`bool`):
                Not used by this implementation.

        Returns:
            `torch.Tensor`: an additive mask broadcastable over heads, cast to `self.dtype`, holding 0.0 for
            positions to attend to and -10000.0 for masked positions.

        Raises:
            ValueError: if `attention_mask` is neither 2D nor 3D.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    # Adapted from transformers.models.blip_2.modeling_blip_2.Blip2QFormerModel.forward
    def forward(
        self,
        query_embeds: torch.FloatTensor,
        query_length: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        query_embeds (`torch.FloatTensor` of shape `(batch_size, num_query_tokens, config.hidden_size)`):
            Query token embeddings; they are layer-normalised and passed through dropout before the encoder stack.
        query_length (`int`, *optional*):
            Number of leading positions treated as query tokens (these receive cross-attention and the query
            feed-forward). Defaults to `query_embeds.shape[1]`.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used as keys/values by the
            cross-attention layers (those with `layer_idx % config.cross_attention_frequency == 0`). A list of
            tensors is also accepted; its first element determines the default mask shape.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing cross-attention on the padding token indices of the encoder input. Mask values
            selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            Defaults to all ones when `encoder_hidden_states` is given.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
            shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
            value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
            used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
            value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
            `(batch_size, sequence_length)`.
        use_cache (`bool`, `optional`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # past_key_values_length
        # NOTE(review): reads `self.config.query_length`, which GraniteSpeechProjectorConfig may not define;
        # confirm before passing `past_key_values`.
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
        )

        query_length = (
            query_length if query_length is not None else query_embeds.shape[1] if query_embeds is not None else 0
        )

        embedding_output = self.layernorm(query_embeds)
        embedding_output = self.dropout(embedding_output)

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, list):
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        # The "pooled" output is simply the first query token's final hidden state.
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
# TODO (alex) - refactor GraniteSpeechQformer to be available under
# transformers.models.X, delete all of the code above, and
# create the model through AutoModel.
class GraniteSpeechEncoderProjectorQFormer(nn.Module):
    """Projects audio-encoder outputs into the LLM embedding space with a windowed QFormer.

    The encoder sequence is zero-padded and split into non-overlapping windows of
    ``window_size`` frames. Each window is summarized by ``window_size // downsample_rate``
    learned query vectors through the QFormer, and the result is linearly projected to
    ``llm_dim``.
    """

    def __init__(self, config: GraniteSpeechProjectorConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.ds_rate = config.downsample_rate
        self.window_size = config.window_size
        self.num_queries = self.window_size // self.ds_rate

        # Learned queries shared by every window, shape (1, num_queries, hidden_size).
        self.query = nn.Parameter(torch.zeros(1, self.num_queries, config.hidden_size))
        self.query.data.normal_(mean=0.0, std=1.0)

        # NOTE: It would be better to create this from config, similar to the LLM.
        # To do this, we need to register the QFormer model into an automodel, which
        # will require pulling it out into its own dir so that it's accessible under
        # transformers.models.X
        self.qformer = GraniteSpeechQFormerModel(config)
        self.linear = nn.Linear(config.hidden_size, config.llm_dim)

    def forward(self, x, atts):
        """Summarize ``x`` window by window and project it to the LLM width.

        Args:
            x: encoder output, unpacked below as ``(batch_size, seq_len, dim)``.
            atts: forwarded unchanged as the QFormer's ``encoder_attention_mask``.
                NOTE(review): ``x`` is reshaped to ``(batch_size * nblocks, window_size, dim)``
                before the QFormer call, so a non-None mask presumably has to match that
                windowed layout. Confirm with callers (``get_audio_features`` passes None).

        Returns:
            Tensor of shape ``(batch_size, nblocks * window_size // ds_rate, llm_dim)`` where
            ``nblocks = ceil(seq_len / window_size)``.
        """
        batch_size, seq_len, dim = x.size()

        # Right-pad the sequence axis with zeros so it splits evenly into windows.
        nblocks = math.ceil(seq_len / self.window_size)
        pad = nblocks * self.window_size - seq_len
        x = nn.functional.pad(x, (0, 0, 0, pad), "constant", 0)
        x = x.view(batch_size * nblocks, self.window_size, dim)

        # Pass the Parameter itself rather than `self.query.data`. `.data` is detached from
        # autograd, so the learned queries would never receive gradients during training.
        query_output = self.qformer(
            query_embeds=self.query,
            encoder_hidden_states=x,
            encoder_attention_mask=atts,
            return_dict=True,
        )
        query_proj = self.linear(
            query_output.last_hidden_state.view(batch_size, nblocks * self.window_size // self.ds_rate, -1)
        )
        return query_proj
### Encoder
class GraniteSpeechCTCModel(nn.Module):
    """Conformer audio encoder with an intermediate output branch.

    ``rnn_tr[0]`` projects ``input_dim`` features to ``hidden_dim``. It is followed by
    ``num_layers`` Conformer blocks. Right after block number ``num_layers // 2``, the hidden
    states are projected to ``output_dim`` by ``out`` and softmax-normalized. They are then
    mapped back to ``hidden_dim`` by ``out_mid`` and added to the running hidden states.
    """

    def __init__(self, config: GraniteSpeechEncoderConfig):
        super().__init__()
        # Built before the blocks so parameter creation order is unchanged.
        input_proj = nn.Linear(config.input_dim, config.hidden_dim, bias=True)
        conformer_blocks = [
            GraniteSpeechConformerBlock(
                dim=config.hidden_dim,
                dim_head=config.dim_head,
                heads=config.num_heads,
                ff_mult=config.feedforward_mult,
                conv_expansion_factor=config.conv_expansion_factor,
                conv_kernel_size=config.conv_kernel_size,
                context_size=config.context_size,  # attention context size
                attn_dropout=config.dropout,
                ff_dropout=config.dropout,
                conv_dropout=config.dropout,
            )
            for _ in range(config.num_layers)
        ]
        # Index 0 is the input projection; indices 1..num_layers are Conformer blocks.
        self.rnn_tr = nn.ModuleList([input_proj, *conformer_blocks])

        self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True)
        self.out_mid = nn.Linear(config.output_dim, config.hidden_dim, bias=True)

        self.context_size = config.context_size
        self.input_dim = config.input_dim
        self.num_layers = config.num_layers
        self.hidden_dim = config.hidden_dim
        self.output_dim = config.output_dim

    def forward(self, x: torch.Tensor):
        input_proj, *blocks = self.rnn_tr
        hidden = input_proj(x)
        mid_layer = self.num_layers // 2

        for layer_no, block in enumerate(blocks, start=1):
            hidden = block(hidden, self.context_size)
            if layer_no != mid_layer:
                continue
            # `self.out` gets a copy, so the in-place update below does not modify the
            # tensor that the Linear layer saved for its backward pass.
            mid_logits = self.out(hidden.clone())
            hidden += self.out_mid(mid_logits.softmax(dim=-1))

        return hidden
# NOTE: Conformer adapted from: https://github.com/lucidrains/conformer.git
class GraniteSpeechConformerPermute(nn.Module):
    """Reorders tensor axes so ``Tensor.permute`` can be used inside an ``nn.Sequential``."""

    def __init__(self, dims):
        super().__init__()
        # Axis order handed to ``Tensor.permute`` on every call.
        self.dims = dims

    def forward(self, x):
        return x.permute(self.dims)
class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
    """Depth-wise ``nn.Conv1d`` (``groups=chan_in``, no bias) preceded by explicit padding.

    Args:
        chan_in, chan_out, kernel_size: forwarded to ``nn.Conv1d``.
        padding: pad spec for ``nn.functional.pad`` applied to the last axis before the
            convolution, e.g. ``(left, right)``; it may be asymmetric.
    """

    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in, bias=False)

    def forward(self, x):
        padded = nn.functional.pad(x, self.padding)
        return self.conv(padded)
class GraniteSpeechConformerScale(nn.Module):
    """Calls the wrapped ``fn`` and multiplies its output by the constant ``scale``."""

    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        result = self.fn(x, **kwargs)
        return result * self.scale
class GraniteSpeechConformerPreNorm(nn.Module):
    """Applies ``nn.LayerNorm(dim)`` to the input, then calls ``fn`` on the normalized tensor."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class GraniteSpeechConformerPreNormAttn(nn.Module):
    """LayerNorm(dim) followed by ``fn(normed, context_size, **kwargs)``.

    Same as the plain pre-norm wrapper, but also threads the attention ``context_size``
    through to ``fn``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, context_size, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, context_size, **kwargs)
class GraniteSpeechConformerAttention(nn.Module):
    """Block-local multi-head self-attention with Shaw-style relative position embeddings.

    The sequence is right-padded to a multiple of ``context_size`` and split into
    non-overlapping blocks; each position attends only to positions of its own block.
    A learned embedding indexed by the (query - key) offset is added to the attention logits.

    NOTE: the ``context_size`` constructor argument is not stored; the block size used at
    run time is the one passed to ``forward``. ``self.dim_head`` is kept for reference only.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        context_size=200,
        max_pos_emb=512,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head**-0.5
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.max_pos_emb = max_pos_emb
        # Relative offsets are shifted by +max_pos_emb to index this (2 * max_pos_emb + 1)-row table.
        self.rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, dim_head)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context_size):
        """
        Args:
            x: input unpacked below as ``(batch, seq_len, dim)``.
            context_size: block length; must satisfy ``0 < context_size <= max_pos_emb``.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.

        Raises:
            ValueError: if ``context_size`` is outside ``(0, max_pos_emb]``.
        """
        device, h, max_pos_emb = x.device, self.heads, self.max_pos_emb
        bs, n, d = x.shape
        # Explicit check rather than `assert`, which is removed under `python -O`.
        if not 0 < context_size <= max_pos_emb:
            raise ValueError(f"context_size must be in (0, {max_pos_emb}], got {context_size}")

        nb = math.ceil(n / context_size)
        nr = n % context_size
        if nr > 0:
            # right padding to reach block size
            x = torch.nn.functional.pad(x, (0, 0, 0, context_size - nr))

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
        # (bs, nb * c, h * dh) -> (bs, nb, h, c, dh)
        q, k, v = [t.reshape(bs, nb, context_size, h, -1).transpose(2, 3) for t in (q, k, v)]
        dots = einsum("b m h i d, b m h j d -> b m h i j", q, k) * self.scale

        # shaw's relative positional embedding
        seq = torch.arange(context_size, device=device)
        dist = seq.view(-1, 1) - seq.view(1, -1)
        dist = torch.clamp(dist, -context_size, context_size) + max_pos_emb
        rel_pos_emb = self.rel_pos_emb(dist).to(q)
        pos_attn = einsum("b m h c d, c r d -> b m h c r", q, rel_pos_emb) * self.scale
        dots = dots + pos_attn

        if nr > 0:
            # In the last (padded) block only the first nr queries/keys are real: mask every
            # logit outside the top-left nr x nr square so real queries ignore padded keys.
            mask = torch.ones(context_size, context_size, dtype=torch.bool, device=device)
            mask[:nr, :nr] = 0
            mask_value = -torch.finfo(dots.dtype).max
            dots[:, -1, :].masked_fill_(mask, mask_value)

        attn = dots.softmax(dim=-1)
        out = einsum("b m h i j, b m h j d -> b m h i d", attn, v)
        out = out.transpose(2, 3).reshape(bs, x.shape[1], -1)
        # Drop the padded positions before the output projection.
        out = self.to_out(out[:, :n, :])
        return self.dropout(out)
class GraniteSpeechConformerFeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> SiLU -> Dropout -> Linear -> Dropout.

    The hidden width is ``dim * mult``; input and output width are both ``dim``.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden_dim = dim * mult
        # Layer order fixes the checkpoint keys (net.0.* and net.3.*), so keep it stable.
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
class GraniteSpeechConformerConvModule(nn.Module):
    """Conformer convolution module for 3-D ``(batch, seq, dim)`` inputs.

    Pipeline: LayerNorm -> swap to (batch, dim, seq) -> pointwise Conv1d to ``2 * inner_dim``
    -> GLU -> depth-wise Conv1d -> BatchNorm1d (Identity when ``causal``) -> SiLU ->
    pointwise Conv1d back to ``dim`` -> swap back to (batch, seq, dim) -> Dropout,
    where ``inner_dim = dim * expansion_factor``.
    """

    def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0):
        super().__init__()
        inner_dim = dim * expansion_factor
        if causal:
            # Left-only padding: output step t only sees inputs <= t.
            padding = (kernel_size - 1, 0)
        else:
            padding = self.calc_same_padding(kernel_size)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            GraniteSpeechConformerPermute(dims=(0, 2, 1)),
            nn.Conv1d(dim, inner_dim * 2, 1),  # GLU below halves this back to inner_dim
            nn.GLU(dim=1),
            GraniteSpeechConformerDepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=padding),
            nn.Identity() if causal else nn.BatchNorm1d(inner_dim),
            nn.SiLU(),
            nn.Conv1d(inner_dim, dim, 1),
            GraniteSpeechConformerPermute(dims=(0, 2, 1)),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

    @staticmethod
    def calc_same_padding(kernel_size: int):
        """Return ``(left, right)`` padding that keeps length unchanged for a stride-1 conv.

        Odd kernels pad symmetrically; even kernels pad one less on the right.
        """
        left = kernel_size // 2
        right = left if kernel_size % 2 else left - 1
        return (left, right)
class GraniteSpeechConformerBlock(nn.Module):
    """One Conformer layer.

    The input passes through four residual branches in order:
    0.5 * FFN(LayerNorm(x)), block-local attention on LayerNorm(x), the convolution module,
    and a second 0.5 * FFN(LayerNorm(x)). A final LayerNorm follows.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=2,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        context_size=-1,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
    ):
        super().__init__()
        # Submodules are created and registered in the order ff1, attn, conv, ff2, post_norm.
        # This keeps parameter order and random initialization identical.
        first_ff = GraniteSpeechConformerFeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
        self.ff1 = GraniteSpeechConformerScale(0.5, GraniteSpeechConformerPreNorm(dim, first_ff))

        attention = GraniteSpeechConformerAttention(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            dropout=attn_dropout,
            context_size=context_size,
        )
        self.attn = GraniteSpeechConformerPreNormAttn(dim, attention)

        self.conv = GraniteSpeechConformerConvModule(
            dim=dim,
            causal=False,
            expansion_factor=conv_expansion_factor,
            kernel_size=conv_kernel_size,
            dropout=conv_dropout,
        )

        second_ff = GraniteSpeechConformerFeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
        self.ff2 = GraniteSpeechConformerScale(0.5, GraniteSpeechConformerPreNorm(dim, second_ff))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, context_size):
        hidden = self.ff1(x) + x
        hidden = self.attn(hidden, context_size) + hidden
        hidden = self.conv(hidden) + hidden
        hidden = self.ff2(hidden) + hidden
        return self.post_norm(hidden)
GRANITE_SPEECH_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config (`GraniteSpeechConfig`):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
    "The bare Granite Speech Model outputting raw hidden-states without any specific head on top.",
    GRANITE_SPEECH_START_DOCSTRING,
)
class GraniteSpeechPreTrainedModel(PreTrainedModel):
    config_class = GraniteSpeechConfig
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the parameters of a single sub-module.

        - ``nn.Linear`` / ``nn.Conv1d``: weights ~ N(0, ``config.initializer_range``), biases zeroed.
        - ``nn.Embedding``: weights ~ N(0, ``config.initializer_range``), padding row zeroed.
        - ``nn.LayerNorm`` / ``nn.BatchNorm1d``: affine weight set to 1, bias to 0, which is the
          identity transform that matches their constructor defaults.
        - The QFormer projector's learned ``query``: N(0, 1), which matches its constructor.

        The last two cases make sure these parameters are still properly initialized when a
        module is materialized without them (e.g. weights missing from a checkpoint).
        """
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
            # Affine parameters may be disabled (elementwise_affine=False / affine=False / bias=False).
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, GraniteSpeechEncoderProjectorQFormer):
            module.query.data.normal_(mean=0.0, std=1.0)
# Argument documentation injected into `GraniteSpeechForConditionalGeneration.forward` via
# `add_start_docstrings_to_model_forward`.
GRANITE_SPEECH_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it. Positions holding the audio token (`config.audio_token_index`) are filled with the projected audio
            features when `input_features` is given.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        input_features (`torch.FloatTensor` of shape `(batch_size, audio seq len, mel feat dim)`):
            The tensors corresponding to the input audios. Input features can be obtained using
            [`AutoFeatureExtractor`]. See [`GraniteSpeechFeatureExtractor.__call__`] for details.
            [`GraniteSpeechProcessor`] uses [`GraniteSpeechFeatureExtractor`] for processing audio.
        input_features_mask (`torch.Tensor`, *optional*):
            Mask used to index the projected audio features. Only the selected features are scattered into the
            audio-token positions of `input_ids`; the others (e.g. padding) are dropped. If not provided, all
            projected audio features are used.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Pre-computed hidden-states (key and values in the attention blocks) of the language model, which can be
            used to speed up sequential decoding. They are forwarded unchanged to the language model.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix. Cannot be combined with `input_features`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
@add_start_docstrings(
    """The Granite Speech model, which consists of an audio encoder, projector, and language model.""",
    GRANITE_SPEECH_START_DOCSTRING,
)
class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin):
    def __init__(self, config: GraniteSpeechConfig):
        super().__init__(config)
        # NOTE: It doesn't matter when we initialize from config, but we should be careful
        # to make sure this does not pick up the adapter_config if in the future we use
        # from_pretrained or something similar, since that should be set by the composite
        # model; don't need to consider it twice
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)

        # Expose the language model's tied-weight keys under this model's `language_model.` prefix.
        if self.language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]

        self.encoder = GraniteSpeechCTCModel(config.encoder_config)
        self.projector = GraniteSpeechEncoderProjectorQFormer(config.projector_config)

        if config.has_lora_adapter and not is_peft_available():
            logger.warning(
                "Config indicates that a lora adapter should be present, but "
                "peft is not installed; this will cause the model to perform "
                "incorrectly when audio inputs are provided. Please install "
                "peft and reload the model!"
            )

        self.post_init()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def get_audio_features(self, input_features):
        """Run the audio encoder and then the projector (with no audio attention mask).

        Returns the projected audio embeddings that are merged into the LLM input embeddings.
        """
        encoder_embeds = self.encoder(input_features)
        projected_embeds = self.projector(encoder_embeds, None)
        return projected_embeds

    @add_start_docstrings_to_model_forward(GRANITE_SPEECH_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=GraniteSpeechCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        input_features: torch.FloatTensor = None,
        input_features_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **lm_kwargs,
    ) -> Union[Tuple[torch.Tensor], GraniteSpeechCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        TODO - add example for usage.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if input_features is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_features and inputs_embeds at the same time, and must specify either one"
            )

        if input_features is not None:
            # The checks above guarantee input_ids is given here. get_merged_audio_embeddings
            # performs the text-embedding lookup itself, so input_ids is embedded only once.
            if input_features.dtype != self.dtype:
                logger.warning(f"input features are casted to {self.dtype}")
                input_features = input_features.to(self.dtype)

            # Get the audio features from the encoder / projector
            audio_features = self.get_audio_features(input_features)

            # Merge the audio features into the LLM embeddings
            inputs_embeds = self.get_merged_audio_embeddings(
                input_ids=input_ids, audio_features=audio_features, input_features_mask=input_features_mask
            )
        elif inputs_embeds is None:
            # Text-only path: set all audio tokens to 0 index to avoid out of vocabulary
            # issues with the LLM embedding.
            is_audio_idx = input_ids == self.config.audio_token_index
            llm_input_ids = input_ids.clone()
            llm_input_ids[is_audio_idx] = 0
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **lm_kwargs,
        )
        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return GraniteSpeechCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        input_features=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model
        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        # If we're in cached decoding stage, input_features should be None because
        # input ids do not contain special audio token anymore. Otherwise we need
        # input feature values to be passed to the model. Without a cache_position the
        # step cannot be identified, so forward the features rather than crash indexing None.
        if cache_position is None or cache_position[0] == 0:
            model_inputs["input_features"] = input_features
        return model_inputs

    def get_merged_audio_embeddings(self, input_ids, audio_features, input_features_mask):
        """Embed ``input_ids`` and overwrite the audio-token positions with audio features.

        Audio-token ids are mapped to id 0 for the embedding lookup (avoiding out-of-vocabulary
        indices). The embeddings at those positions are then replaced, in order, by the
        elements of ``audio_features``. If ``input_features_mask`` is given, it first indexes
        ``audio_features`` so that only the selected features are used.

        TODO - This needs to be adapted to handle batches of variable length sequences
        and potentially labels.
        """
        is_audio_index = input_ids == self.config.audio_token_index
        llm_input_ids = torch.where(is_audio_index, 0, input_ids)
        inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids)  # [bsz, seq len, hidden size]

        # Mask the audio features into the text embeddings
        special_audio_mask = is_audio_index.unsqueeze(-1)
        audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
        if input_features_mask is not None:
            audio_features = audio_features[input_features_mask]
        inputs_embeds = inputs_embeds.masked_scatter(
            special_audio_mask,
            audio_features,
        )
        return inputs_embeds

    def generate(self, *args, **kwargs):
        """This model is expected to have a lora adapter, which is only
        enabled when considering audio inputs. As such, we override generate
        to conditionally enable / disable the lora adapter based on whether
        or not any input features were provided.
        """
        input_features = kwargs.pop("input_features", None)
        # `is_peft_available` must be called: the bare function object is always truthy.
        if is_peft_available() and self._hf_peft_config_loaded:
            if input_features is not None:
                self.enable_adapters()
            else:
                self.disable_adapters()
        return super().generate(*args, input_features=input_features, **kwargs)

    def save_pretrained(self, *args, **kwargs):
        # overwrite save_pretrained to first save the adapter if we have one
        # NOTE - this will use the base model path we are exporting in the lora
        # adapter, which may not necessarily be the best behavior, but for now
        # we keep this for portability, since using the local dir causes problems
        # if the model is loaded from outside of the current working dir.
        if is_peft_available() and self._hf_peft_config_loaded:
            super().save_pretrained(*args, **kwargs)

        # Then save the base model afterwards, with the adapter flag temporarily cleared.
        # Restore it (even on failure) so generate() keeps toggling the adapter on this
        # in-memory model after saving.
        prev_peft_config_loaded = self._hf_peft_config_loaded
        self._hf_peft_config_loaded = False
        try:
            super().save_pretrained(*args, **kwargs)
        finally:
            self._hf_peft_config_loaded = prev_peft_config_loaded
# Public names exported by `from <this module> import *`.
# NOTE(review): `GraniteSpeechEncoderProjectorPreTrainedModel` is not defined in the visible part
# of this file (the projector class here is `GraniteSpeechEncoderProjectorQFormer`). Confirm it
# exists elsewhere in the module; otherwise a star-import raises AttributeError.
__all__ = [
    "GraniteSpeechForConditionalGeneration",
    "GraniteSpeechPreTrainedModel",
    "GraniteSpeechEncoderProjectorPreTrainedModel",
    "GraniteSpeechQFormerModel",
]
|