alethanhson committed on
Commit 63c4f82 · 1 Parent(s): 4c46478
Files changed (4)
  1. generator.py +190 -0
  2. models.py +203 -0
  3. requirements.txt +13 -0
  4. watermarking.py +79 -0
generator.py ADDED
@@ -0,0 +1,190 @@
+ from dataclasses import dataclass
+ from typing import List, Tuple
+
+ import torch
+ import torchaudio
+ from huggingface_hub import hf_hub_download
+ from models import Model
+ from moshi.models import loaders
+ from tokenizers.processors import TemplateProcessing
+ from transformers import AutoTokenizer
+ from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
+
+
+ @dataclass
+ class Segment:
+     speaker: int
+     text: str
+     # (num_samples,), sample_rate = 24_000
+     audio: torch.Tensor
+
+
+ def load_llama3_tokenizer():
+     """
+     https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
+     """
+     tokenizer_name = "meta-llama/Llama-3.2-1B"
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+     bos = tokenizer.bos_token
+     eos = tokenizer.eos_token
+     tokenizer._tokenizer.post_processor = TemplateProcessing(
+         single=f"{bos}:0 $A:0 {eos}:0",
+         pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
+         special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
+     )
+
+     return tokenizer
+
+
+ class Generator:
+     def __init__(
+         self,
+         model: Model,
+     ):
+         self._model = model
+         self._model.setup_caches(1)
+
+         self._text_tokenizer = load_llama3_tokenizer()
+
+         device = next(model.parameters()).device
+         mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
+         mimi = loaders.get_mimi(mimi_weight, device=device)
+         mimi.set_num_codebooks(32)
+         self._audio_tokenizer = mimi
+
+         self._watermarker = load_watermarker(device=device)
+
+         self.sample_rate = mimi.sample_rate
+         self.device = device
+
+     def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         frame_tokens = []
+         frame_masks = []
+
+         text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
+         text_frame = torch.zeros(len(text_tokens), 33).long()
+         text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
+         text_frame[:, -1] = torch.tensor(text_tokens)
+         text_frame_mask[:, -1] = True
+
+         frame_tokens.append(text_frame.to(self.device))
+         frame_masks.append(text_frame_mask.to(self.device))
+
+         return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
+
+     def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         frame_tokens = []
+         frame_masks = []
+
+         # (K, T)
+         audio = audio.to(self.device)
+         audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
+         # add EOS frame
+         eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
+         audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
+
+         audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
+         audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
+         audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
+         audio_frame_mask[:, :-1] = True
+
+         frame_tokens.append(audio_frame)
+         frame_masks.append(audio_frame_mask)
+
+         return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
+
+     def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Returns:
+             (seq_len, 33), (seq_len, 33)
+         """
+         text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
+         audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
+
+         return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
+
+     @torch.inference_mode()
+     def generate(
+         self,
+         text: str,
+         speaker: int,
+         context: List[Segment],
+         max_audio_length_ms: float = 90_000,
+         temperature: float = 0.9,
+         topk: int = 50,
+     ) -> torch.Tensor:
+         self._model.reset_caches()
+
+         max_audio_frames = int(max_audio_length_ms / 80)
+         tokens, tokens_mask = [], []
+         for segment in context:
+             segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
+             tokens.append(segment_tokens)
+             tokens_mask.append(segment_tokens_mask)
+
+         gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
+         tokens.append(gen_segment_tokens)
+         tokens_mask.append(gen_segment_tokens_mask)
+
+         prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
+         prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
+
+         samples = []
+         curr_tokens = prompt_tokens.unsqueeze(0)
+         curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
+         curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
+
+         max_seq_len = 2048 - max_audio_frames
+         if curr_tokens.size(1) >= max_seq_len:
+             raise ValueError(f"Inputs too long, must be below max_seq_len - max_audio_frames: {max_seq_len}")
+
+         for _ in range(max_audio_frames):
+             sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
+             if torch.all(sample == 0):
+                 break  # eos
+
+             samples.append(sample)
+
+             curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
+             curr_tokens_mask = torch.cat(
+                 [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
+             ).unsqueeze(1)
+             curr_pos = curr_pos[:, -1:] + 1
+
+         audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
+
+         # This applies an imperceptible watermark to identify audio as AI-generated.
+         # Watermarking ensures transparency, dissuades misuse, and enables traceability.
+         # Please be a responsible AI citizen and keep the watermarking in place.
+         # If using CSM 1B in another application, use your own private key and keep it secret.
+         audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
+         audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
+
+         return audio
+
+
+ def load_csm_1b(device: str = "cuda") -> Generator:
+     """
+     Load the CSM-1B model with its configuration and wrap it in a Generator.
+     """
+     # Model.from_pretrained (via PyTorchModelHubMixin) reads the model config
+     # from the repo, so no separate config object needs to be constructed here.
+     model = Model.from_pretrained("sesame/csm-1b")
+     model.to(device=device, dtype=torch.bfloat16)
+
+     # Generator derives its device from the model's parameters.
+     return Generator(model)
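A minimal usage sketch of the Generator API above, assuming a CUDA device, access to the gated meta-llama/Llama-3.2-1B tokenizer and the sesame/csm-1b weights; the prompt and output file paths are placeholders:

import torchaudio

from generator import Segment, load_csm_1b

# Load the model onto the GPU (weights are downloaded on first run).
generator = load_csm_1b(device="cuda")

# Optional conversational context: prior turns as (speaker, text, audio) segments.
prompt_audio, sr = torchaudio.load("prompt.wav")  # placeholder path
prompt_audio = torchaudio.functional.resample(
    prompt_audio.mean(dim=0), orig_freq=sr, new_freq=generator.sample_rate
)
context = [Segment(speaker=0, text="Hello there.", audio=prompt_audio)]

# Generate up to 10 seconds of speech for speaker 1 and save it.
audio = generator.generate(
    text="Hi! Nice to meet you.",
    speaker=1,
    context=context,
    max_audio_length_ms=10_000,
)
torchaudio.save("output.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)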
models.py ADDED
@@ -0,0 +1,203 @@
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ import torchtune
+ from huggingface_hub import PyTorchModelHubMixin
+ from torchtune.models import llama3_2
+
+
+ def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
+     return llama3_2.llama3_2(
+         vocab_size=128_256,
+         num_layers=16,
+         num_heads=32,
+         num_kv_heads=8,
+         embed_dim=2048,
+         max_seq_len=2048,
+         intermediate_dim=8192,
+         attn_dropout=0.0,
+         norm_eps=1e-5,
+         rope_base=500_000,
+         scale_factor=32,
+     )
+
+
+ def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
+     return llama3_2.llama3_2(
+         vocab_size=128_256,
+         num_layers=4,
+         num_heads=8,
+         num_kv_heads=2,
+         embed_dim=1024,
+         max_seq_len=2048,
+         intermediate_dim=8192,
+         attn_dropout=0.0,
+         norm_eps=1e-5,
+         rope_base=500_000,
+         scale_factor=32,
+     )
+
+
+ FLAVORS = {
+     "llama-1B": llama3_2_1B,
+     "llama-100M": llama3_2_100M,
+ }
+
+
+ def _prepare_transformer(model):
+     embed_dim = model.tok_embeddings.embedding_dim
+     model.tok_embeddings = nn.Identity()
+     model.output = nn.Identity()
+     return model, embed_dim
+
+
+ def _create_causal_mask(seq_len: int, device: torch.device):
+     return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
+
+
+ def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
+     """
+     Args:
+         mask: (max_seq_len, max_seq_len)
+         input_pos: (batch_size, seq_len)
+
+     Returns:
+         (batch_size, seq_len, max_seq_len)
+     """
+     r = mask[input_pos, :]
+     return r
+
+
+ def _multinomial_sample_one_no_sync(probs):  # Does multinomial sampling without a cuda synchronization
+     q = torch.empty_like(probs).exponential_(1)
+     return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+ def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
+     logits = logits / temperature
+
+     filter_value: float = -float("Inf")
+     indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
+     scores_processed = logits.masked_fill(indices_to_remove, filter_value)
+     scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
+     probs = torch.nn.functional.softmax(scores_processed, dim=-1)
+
+     sample_token = _multinomial_sample_one_no_sync(probs)
+     return sample_token
+
+
+ @dataclass
+ class ModelArgs:
+     backbone_flavor: str
+     decoder_flavor: str
+     text_vocab_size: int
+     audio_vocab_size: int
+     audio_num_codebooks: int
+
+
+ class Model(
+     nn.Module,
+     PyTorchModelHubMixin,
+     repo_url="https://github.com/SesameAILabs/csm",
+     pipeline_tag="text-to-speech",
+     license="apache-2.0",
+ ):
+     def __init__(self, config: ModelArgs):
+         super().__init__()
+         self.config = config
+
+         self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
+         self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
+
+         self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
+         self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)
+
+         self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
+         self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
+         self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))
+
+     def setup_caches(self, max_batch_size: int) -> torch.Tensor:
+         """Setup KV caches and return a causal mask."""
+         dtype = next(self.parameters()).dtype
+         device = next(self.parameters()).device
+
+         with device:
+             self.backbone.setup_caches(max_batch_size, dtype)
+             self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
+
+         self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
+         self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
+
+     def generate_frame(
+         self,
+         tokens: torch.Tensor,
+         tokens_mask: torch.Tensor,
+         input_pos: torch.Tensor,
+         temperature: float,
+         topk: int,
+     ) -> torch.Tensor:
+         """
+         Args:
+             tokens: (batch_size, seq_len, audio_num_codebooks+1)
+             tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
+             input_pos: (batch_size, seq_len) positions for each token
+
+         Returns:
+             (batch_size, audio_num_codebooks) sampled tokens
+         """
+         dtype = next(self.parameters()).dtype
+         b, s, _ = tokens.size()
+
+         assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
+         curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
+         embeds = self._embed_tokens(tokens)
+         masked_embeds = embeds * tokens_mask.unsqueeze(-1)
+         h = masked_embeds.sum(dim=2)
+         h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
+
+         last_h = h[:, -1, :]
+         c0_logits = self.codebook0_head(last_h)
+         c0_sample = sample_topk(c0_logits, topk, temperature)
+         c0_embed = self._embed_audio(0, c0_sample)
+
+         curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
+         curr_sample = c0_sample.clone()
+         curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
+
+         # Decoder caches must be reset every frame.
+         self.decoder.reset_caches()
+         for i in range(1, self.config.audio_num_codebooks):
+             curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
+             decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
+                 dtype=dtype
+             )
+             ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
+             ci_sample = sample_topk(ci_logits, topk, temperature)
+             ci_embed = self._embed_audio(i, ci_sample)
+
+             curr_h = ci_embed
+             curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
+             curr_pos = curr_pos[:, -1:] + 1
+
+         return curr_sample
+
+     def reset_caches(self):
+         self.backbone.reset_caches()
+         self.decoder.reset_caches()
+
+     def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
+         return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
+
+     def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
+         text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
+
+         audio_tokens = tokens[:, :, :-1] + (
+             self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
+         )
+         audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
+             tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
+         )
+
+         return torch.cat([audio_embeds, text_embeds], dim=-2)
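To illustrate how the backbone and decoder flavors wire together, here is a rough sketch that builds the Model directly from ModelArgs and samples a single frame. It assumes a CUDA device, and the vocabulary sizes are illustrative assumptions rather than the values stored in sesame/csm-1b's config; with untrained weights the output is only useful for checking shapes:

import torch

from models import Model, ModelArgs

# Illustrative hyperparameters (assumed, not read from sesame/csm-1b).
config = ModelArgs(
    backbone_flavor="llama-1B",    # 16-layer backbone over interleaved text/audio frames
    decoder_flavor="llama-100M",   # small decoder that fills in codebooks 1..N-1 within a frame
    text_vocab_size=128_256,
    audio_vocab_size=2051,
    audio_num_codebooks=32,
)

model = Model(config).to("cuda", dtype=torch.bfloat16)
model.setup_caches(max_batch_size=1)

# One decoding step: a single prompt position holding 32 audio code slots + 1 text slot.
tokens = torch.zeros(1, 1, 33, dtype=torch.long, device="cuda")
tokens_mask = torch.ones(1, 1, 33, dtype=torch.bool, device="cuda")
input_pos = torch.zeros(1, 1, dtype=torch.long, device="cuda")

frame = model.generate_frame(tokens, tokens_mask, input_pos, temperature=0.9, topk=50)
print(frame.shape)  # (1, 32): one sampled token per audio codebook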
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ torch==2.4.0
+ torchaudio==2.4.0
+ tokenizers==0.21.0
+ transformers==4.49.0
+ huggingface_hub==0.28.1
+ moshi==0.2.2
+ torchtune==0.4.0
+ torchao==0.9.0
+ silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
+ fastapi
+ uvicorn[standard]
+ python-multipart==0.0.9
+ pydantic==2.6.1
watermarking.py ADDED
@@ -0,0 +1,79 @@
+ import argparse
+
+ import silentcipher
+ import torch
+ import torchaudio
+
+ # This watermark key is public, it is not secure.
+ # If using CSM 1B in another application, use a new private key and keep it secret.
+ CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201]
+
+
+ def cli_check_audio() -> None:
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--audio_path", type=str, required=True)
+     args = parser.parse_args()
+
+     check_audio_from_file(args.audio_path)
+
+
+ def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
+     model = silentcipher.get_model(
+         model_type="44.1k",
+         device=device,
+     )
+     return model
+
+
+ @torch.inference_mode()
+ def watermark(
+     watermarker: silentcipher.server.Model,
+     audio_array: torch.Tensor,
+     sample_rate: int,
+     watermark_key: list[int],
+ ) -> tuple[torch.Tensor, int]:
+     audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
+     encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)
+
+     output_sample_rate = min(44100, sample_rate)
+     encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
+     return encoded, output_sample_rate
+
+
+ @torch.inference_mode()
+ def verify(
+     watermarker: silentcipher.server.Model,
+     watermarked_audio: torch.Tensor,
+     sample_rate: int,
+     watermark_key: list[int],
+ ) -> bool:
+     watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
+     result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)
+
+     is_watermarked = result["status"]
+     if is_watermarked:
+         is_csm_watermarked = result["messages"][0] == watermark_key
+     else:
+         is_csm_watermarked = False
+
+     return is_watermarked and is_csm_watermarked
+
+
+ def check_audio_from_file(audio_path: str) -> None:
+     watermarker = load_watermarker(device="cuda")
+
+     audio_array, sample_rate = load_audio(audio_path)
+     is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK)
+
+     outcome = "Watermarked" if is_watermarked else "Not watermarked"
+     print(f"{outcome}: {audio_path}")
+
+
+ def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
+     audio_array, sample_rate = torchaudio.load(audio_path)
+     audio_array = audio_array.mean(dim=0)
+     return audio_array, int(sample_rate)
+
+
+ if __name__ == "__main__":
+     cli_check_audio()
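A small round-trip sketch for the watermarking helpers above, assuming a CUDA device and silentcipher installed as pinned in requirements.txt; the input file path is a placeholder:

import torchaudio

from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, verify, watermark

watermarker = load_watermarker(device="cuda")

# Embed the (public) CSM key into a mono clip, then confirm it decodes.
audio, sample_rate = torchaudio.load("clip.wav")  # placeholder path
audio = audio.mean(dim=0).to("cuda")              # the watermarker expects mono audio
wm_audio, wm_sample_rate = watermark(watermarker, audio, sample_rate, CSM_1B_GH_WATERMARK)

print(verify(watermarker, wm_audio, wm_sample_rate, CSM_1B_GH_WATERMARK))  # expected: True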