preston-cell committed
Commit 0286478 · verified · 1 Parent(s): 4c9d528

Create models.py

Files changed (1)
  1. models.py +203 -0
models.py ADDED
@@ -0,0 +1,203 @@
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ import torchtune
+ from huggingface_hub import PyTorchModelHubMixin
+ from torchtune.models import llama3_2
+
+
+ def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
+     return llama3_2.llama3_2(
+         vocab_size=128_256,
+         num_layers=16,
+         num_heads=32,
+         num_kv_heads=8,
+         embed_dim=2048,
+         max_seq_len=2048,
+         intermediate_dim=8192,
+         attn_dropout=0.0,
+         norm_eps=1e-5,
+         rope_base=500_000,
+         scale_factor=32,
+     )
+
+
+ def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
+     return llama3_2.llama3_2(
+         vocab_size=128_256,
+         num_layers=4,
+         num_heads=8,
+         num_kv_heads=2,
+         embed_dim=1024,
+         max_seq_len=2048,
+         intermediate_dim=8192,
+         attn_dropout=0.0,
+         norm_eps=1e-5,
+         rope_base=500_000,
+         scale_factor=32,
+     )
+
+
+ FLAVORS = {
+     "llama-1B": llama3_2_1B,
+     "llama-100M": llama3_2_100M,
+ }
+
+
+ def _prepare_transformer(model):
+     embed_dim = model.tok_embeddings.embedding_dim
+     model.tok_embeddings = nn.Identity()
+     model.output = nn.Identity()
+     return model, embed_dim
+
+
+ def _create_causal_mask(seq_len: int, device: torch.device):
+     return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
+
+
+ def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
+     """
+     Args:
+         mask: (max_seq_len, max_seq_len)
+         input_pos: (batch_size, seq_len)
+
+     Returns:
+         (batch_size, seq_len, max_seq_len)
+     """
+     r = mask[input_pos, :]
+     return r
+
+
+ def _multinomial_sample_one_no_sync(probs):  # Does multinomial sampling without a CUDA synchronization
+     q = torch.empty_like(probs).exponential_(1)
+     return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+ def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
+     logits = logits / temperature
+
+     filter_value: float = -float("Inf")
+     indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
+     scores_processed = logits.masked_fill(indices_to_remove, filter_value)
+     scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
+     probs = torch.nn.functional.softmax(scores_processed, dim=-1)
+
+     sample_token = _multinomial_sample_one_no_sync(probs)
+     return sample_token
+
+
+ @dataclass
+ class ModelArgs:
+     backbone_flavor: str
+     decoder_flavor: str
+     text_vocab_size: int
+     audio_vocab_size: int
+     audio_num_codebooks: int
+
+
+ class Model(
+     nn.Module,
+     PyTorchModelHubMixin,
+     repo_url="https://github.com/SesameAILabs/csm",
+     pipeline_tag="text-to-speech",
+     license="apache-2.0",
+ ):
+     def __init__(self, config: ModelArgs):
+         super().__init__()
+         self.config = config
+
+         self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
+         self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
+
+         self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
+         self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)
+
+         self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
+         self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
+         self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))
+
+     def setup_caches(self, max_batch_size: int) -> torch.Tensor:
+         """Setup KV caches and return a causal mask."""
+         dtype = next(self.parameters()).dtype
+         device = next(self.parameters()).device
+
+         with device:
+             self.backbone.setup_caches(max_batch_size, dtype)
+             self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
+
+         self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
+         self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
+
+     def generate_frame(
+         self,
+         tokens: torch.Tensor,
+         tokens_mask: torch.Tensor,
+         input_pos: torch.Tensor,
+         temperature: float,
+         topk: int,
+     ) -> torch.Tensor:
+         """
+         Args:
+             tokens: (batch_size, seq_len, audio_num_codebooks+1)
+             tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
+             input_pos: (batch_size, seq_len) positions for each token
+             mask: (batch_size, seq_len, max_seq_len)
+
+         Returns:
+             (batch_size, audio_num_codebooks) sampled tokens
+         """
+         dtype = next(self.parameters()).dtype
+         b, s, _ = tokens.size()
+
+         assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
+         curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
+         embeds = self._embed_tokens(tokens)
+         masked_embeds = embeds * tokens_mask.unsqueeze(-1)
+         h = masked_embeds.sum(dim=2)
+         h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
+
+         last_h = h[:, -1, :]
+         c0_logits = self.codebook0_head(last_h)
+         c0_sample = sample_topk(c0_logits, topk, temperature)
+         c0_embed = self._embed_audio(0, c0_sample)
+
+         curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
+         curr_sample = c0_sample.clone()
+         curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
+
+         # Decoder caches must be reset every frame.
+         self.decoder.reset_caches()
+         for i in range(1, self.config.audio_num_codebooks):
+             curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
+             decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
+                 dtype=dtype
+             )
+             ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
+             ci_sample = sample_topk(ci_logits, topk, temperature)
+             ci_embed = self._embed_audio(i, ci_sample)
+
+             curr_h = ci_embed
+             curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
+             curr_pos = curr_pos[:, -1:] + 1
+
+         return curr_sample
+
+     def reset_caches(self):
+         self.backbone.reset_caches()
+         self.decoder.reset_caches()
+
+     def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
+         return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
+
+     def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
+         text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
+
+         audio_tokens = tokens[:, :, :-1] + (
+             self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
+         )
+         audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
+             tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
+         )
+
+         return torch.cat([audio_embeds, text_embeds], dim=-2)