3v324v23 committed on
Commit 7e6ee0b · 1 Parent(s): c2bb5f2

Add application file

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. customs/customsf +0 -0
  2. data/__init__.py +0 -3
  3. data/__pycache__/__init__.cpython-311.pyc +0 -0
  4. data/__pycache__/collation.cpython-311.pyc +0 -0
  5. data/__pycache__/input_strategies.cpython-311.pyc +0 -0
  6. data/__pycache__/tokenizer.cpython-311.pyc +0 -0
  7. data/collation.py +0 -120
  8. data/datamodule.py +0 -419
  9. data/dataset.py +0 -242
  10. data/fbank.py +0 -212
  11. data/input_strategies.py +0 -159
  12. data/tokenizer.py +0 -126
  13. macros.py +0 -44
  14. models/__init__.py +0 -136
  15. models/__pycache__/__init__.cpython-311.pyc +0 -0
  16. models/__pycache__/macros.cpython-311.pyc +0 -0
  17. models/__pycache__/transformer.cpython-311.pyc +0 -0
  18. models/__pycache__/vallex.cpython-311.pyc +0 -0
  19. models/__pycache__/visualizer.cpython-311.pyc +0 -0
  20. models/macros.py +0 -11
  21. models/transformer.py +0 -394
  22. models/vallex.py +0 -853
  23. models/visualizer.py +0 -106
  24. modules/__init__.py +0 -0
  25. modules/__pycache__/__init__.cpython-311.pyc +0 -0
  26. modules/__pycache__/activation.cpython-311.pyc +0 -0
  27. modules/__pycache__/embedding.cpython-311.pyc +0 -0
  28. modules/__pycache__/scaling.cpython-311.pyc +0 -0
  29. modules/__pycache__/transformer.cpython-311.pyc +0 -0
  30. modules/activation.py +0 -612
  31. modules/embedding.py +0 -97
  32. modules/optim.py +0 -1105
  33. modules/scaling.py +0 -1401
  34. modules/scheduler.py +0 -78
  35. modules/transformer.py +0 -683
  36. prompts/promptsf +0 -0
  37. utils/__init__.py +0 -15
  38. utils/__pycache__/__init__.cpython-311.pyc +0 -0
  39. utils/__pycache__/generation.cpython-311.pyc +0 -0
  40. utils/__pycache__/prompt_making.cpython-311.pyc +0 -0
  41. utils/__pycache__/sentence_cutter.cpython-311.pyc +0 -0
  42. utils/__pycache__/symbol_table.cpython-311.pyc +0 -0
  43. utils/download.py +0 -49
  44. utils/g2p/__init__.py +0 -72
  45. utils/g2p/__pycache__/__init__.cpython-311.pyc +0 -0
  46. utils/g2p/__pycache__/cleaners.cpython-311.pyc +0 -0
  47. utils/g2p/__pycache__/english.cpython-311.pyc +0 -0
  48. utils/g2p/__pycache__/japanese.cpython-311.pyc +0 -0
  49. utils/g2p/__pycache__/mandarin.cpython-311.pyc +0 -0
  50. utils/g2p/__pycache__/symbols.cpython-311.pyc +0 -0
customs/customsf DELETED
File without changes
data/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- # from .datamodule import *
2
- # from .tokenizer import *
3
- from .collation import *
 
data/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (208 Bytes)
 
data/__pycache__/collation.cpython-311.pyc DELETED
Binary file (7.2 kB)
 
data/__pycache__/input_strategies.cpython-311.pyc DELETED
Binary file (1.8 kB)
 
data/__pycache__/tokenizer.cpython-311.pyc DELETED
Binary file (6.77 kB)
 
data/collation.py DELETED
@@ -1,120 +0,0 @@
1
- from pathlib import Path
2
- from typing import List, Tuple
3
-
4
- import numpy as np
5
- import torch
6
-
7
- from utils import SymbolTable
8
-
9
-
10
- class TextTokenCollater:
11
- """Collate list of text tokens
12
-
13
- Map sentences to integers. Sentences are padded to equal length.
14
- Beginning and end-of-sequence symbols can be added.
15
-
16
- Example:
17
- >>> token_collater = TextTokenCollater(text_tokens)
18
- >>> tokens_batch, tokens_lens = token_collater(text)
19
-
20
- Returns:
21
- tokens_batch: IntTensor of shape (B, L)
22
- B: batch dimension, number of input sentences
23
- L: length of the longest sentence
24
- tokens_lens: IntTensor of shape (B,)
25
- Length of each sentence after adding <eos> and <bos>
26
- but before padding.
27
- """
28
-
29
- def __init__(
30
- self,
31
- text_tokens: List[str],
32
- add_eos: bool = True,
33
- add_bos: bool = True,
34
- pad_symbol: str = "<pad>",
35
- bos_symbol: str = "<bos>",
36
- eos_symbol: str = "<eos>",
37
- ):
38
- self.pad_symbol = pad_symbol
39
-
40
- self.add_eos = add_eos
41
- self.add_bos = add_bos
42
-
43
- self.bos_symbol = bos_symbol
44
- self.eos_symbol = eos_symbol
45
-
46
- unique_tokens = (
47
- [pad_symbol]
48
- + ([bos_symbol] if add_bos else [])
49
- + ([eos_symbol] if add_eos else [])
50
- + sorted(text_tokens)
51
- )
52
-
53
- self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
54
- self.idx2token = [token for token in unique_tokens]
55
-
56
- def index(
57
- self, tokens_list: List[str]
58
- ) -> Tuple[torch.Tensor, torch.Tensor]:
59
- seqs, seq_lens = [], []
60
- for tokens in tokens_list:
61
- assert (
62
- all([True if s in self.token2idx else False for s in tokens])
63
- is True
64
- )
65
- seq = (
66
- ([self.bos_symbol] if self.add_bos else [])
67
- + list(tokens)
68
- + ([self.eos_symbol] if self.add_eos else [])
69
- )
70
- seqs.append(seq)
71
- seq_lens.append(len(seq))
72
-
73
- max_len = max(seq_lens)
74
- for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
75
- seq.extend([self.pad_symbol] * (max_len - seq_len))
76
-
77
- tokens = torch.from_numpy(
78
- np.array(
79
- [[self.token2idx[token] for token in seq] for seq in seqs],
80
- dtype=np.int64,
81
- )
82
- )
83
- tokens_lens = torch.IntTensor(seq_lens)
84
-
85
- return tokens, tokens_lens
86
-
87
- def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
88
- tokens_seqs = [[p for p in text] for text in texts]
89
- max_len = len(max(tokens_seqs, key=len))
90
-
91
- seqs = [
92
- ([self.bos_symbol] if self.add_bos else [])
93
- + list(seq)
94
- + ([self.eos_symbol] if self.add_eos else [])
95
- + [self.pad_symbol] * (max_len - len(seq))
96
- for seq in tokens_seqs
97
- ]
98
-
99
- tokens_batch = torch.from_numpy(
100
- np.array(
101
- [seq for seq in seqs],
102
- dtype=np.int64,
103
- )
104
- )
105
-
106
- tokens_lens = torch.IntTensor(
107
- [
108
- len(seq) + int(self.add_eos) + int(self.add_bos)
109
- for seq in tokens_seqs
110
- ]
111
- )
112
-
113
- return tokens_batch, tokens_lens
114
-
115
-
116
- def get_text_token_collater() -> TextTokenCollater:
117
- collater = TextTokenCollater(
118
- ['0'], add_bos=False, add_eos=False
119
- )
120
- return collater
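
For orientation, here is a minimal usage sketch of the collater deleted above. It assumes the deleted `data/` and `utils/` packages are still importable; the three-symbol vocabulary is purely illustrative (the repository's own `get_text_token_collater` hard-codes `['0']`).

```python
from data.collation import TextTokenCollater

# Illustrative vocabulary; any list of token strings works.
collater = TextTokenCollater(["a", "b", "c"], add_bos=True, add_eos=True)

# index() maps character tokens to ids, adds <bos>/<eos>, and pads to the
# longest sequence in the batch.
tokens, tokens_lens = collater.index(["ab", "abc"])
print(tokens.shape)    # torch.Size([2, 5])  -> (B, L)
print(tokens_lens)     # tensor([4, 5], dtype=torch.int32)
```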
 
data/datamodule.py DELETED
@@ -1,419 +0,0 @@
1
- # Copyright 2023 (authors: Feiteng Li)
2
- #
3
- # See ../../../../LICENSE for clarification regarding multiple authors
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
-
18
- import argparse
19
- import inspect
20
- import logging
21
- from functools import lru_cache
22
- from pathlib import Path
23
- from typing import Any, Dict, Optional
24
-
25
- import torch
26
- # from icefall.utils import str2bool
27
- # from lhotse import CutSet, load_manifest_lazy
28
- # from lhotse.dataset import (
29
- # CutConcatenate,
30
- # DynamicBucketingSampler,
31
- # PrecomputedFeatures,
32
- # SingleCutSampler,
33
- # SpecAugment,
34
- # )
35
- # from lhotse.dataset.input_strategies import OnTheFlyFeatures
36
- # from lhotse.utils import fix_random_seed
37
- from torch.utils.data import DataLoader
38
-
39
- from data.collation import get_text_token_collater
40
- # from data.dataset import SpeechSynthesisDataset
41
- from data.fbank import get_fbank_extractor
42
- from data.input_strategies import PromptedPrecomputedFeatures
43
-
44
- # PrecomputedFeatures = PrecomputedFeatures
45
-
46
-
47
- class _SeedWorkers:
48
- def __init__(self, seed: int):
49
- self.seed = seed
50
-
51
- def __call__(self, worker_id: int):
52
- fix_random_seed(self.seed + worker_id)
53
-
54
-
55
- def _get_input_strategy(input_strategy, dataset, cuts):
56
- if input_strategy == "PromptedPrecomputedFeatures":
57
- return PromptedPrecomputedFeatures(dataset, cuts)
58
-
59
- return eval(input_strategy)()
60
-
61
-
62
- class TtsDataModule:
63
- """
64
- DataModule for VALL-E TTS experiments.
65
- It assumes there is always one train and valid dataloader.
66
-
67
- It contains all the common data pipeline modules used in TTS
68
- experiments, e.g.:
69
- - dynamic batch size,
70
- - bucketing samplers,
71
- - cut concatenation[not used & tested yet],
72
- - augmentation[not used & tested yet],
73
- - on-the-fly feature extraction[not used & tested yet]
74
-
75
- This class should be derived for specific corpora used in TTS tasks.
76
- """
77
-
78
- def __init__(self, args: argparse.Namespace):
79
- self.args = args
80
-
81
- @classmethod
82
- def add_arguments(cls, parser: argparse.ArgumentParser):
83
- group = parser.add_argument_group(
84
- title="TTS data related options",
85
- description="These options are used for the preparation of "
86
- "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
87
- "effective batch sizes, sampling strategies, applied data "
88
- "augmentations, etc.",
89
- )
90
- group.add_argument(
91
- "--manifest-dir",
92
- type=Path,
93
- default=Path("data/tokenized"),
94
- help="Path to directory with train/valid/test cuts.",
95
- )
96
- group.add_argument(
97
- "--max-duration",
98
- type=int,
99
- default=40.0,
100
- help="Maximum pooled recordings duration (seconds) in a "
101
- "single batch. You can reduce it if it causes CUDA OOM.",
102
- )
103
- group.add_argument(
104
- "--bucketing-sampler",
105
- type=str2bool,
106
- default=True,
107
- help="When enabled, the batches will come from buckets of "
108
- "similar duration (saves padding frames).",
109
- )
110
- group.add_argument(
111
- "--num-buckets",
112
- type=int,
113
- default=10,
114
- help="The number of buckets for the DynamicBucketingSampler"
115
- "(you might want to increase it for larger datasets).",
116
- )
117
- group.add_argument(
118
- "--concatenate-cuts",
119
- type=str2bool,
120
- default=False,
121
- help="When enabled, utterances (cuts) will be concatenated "
122
- "to minimize the amount of padding.",
123
- )
124
- group.add_argument(
125
- "--duration-factor",
126
- type=float,
127
- default=1.0,
128
- help="Determines the maximum duration of a concatenated cut "
129
- "relative to the duration of the longest cut in a batch.",
130
- )
131
- group.add_argument(
132
- "--gap",
133
- type=float,
134
- default=0.1,
135
- help="The amount of padding (in seconds) inserted between "
136
- "concatenated cuts. This padding is filled with noise when "
137
- "noise augmentation is used.",
138
- )
139
- group.add_argument(
140
- "--on-the-fly-feats",
141
- type=str2bool,
142
- default=False,
143
- help="When enabled, use on-the-fly cut mixing and feature "
144
- "extraction. Will drop existing precomputed feature manifests "
145
- "if available.",
146
- )
147
- group.add_argument(
148
- "--shuffle",
149
- type=str2bool,
150
- default=True,
151
- help="When enabled (=default), the examples will be "
152
- "shuffled for each epoch.",
153
- )
154
- group.add_argument(
155
- "--drop-last",
156
- type=str2bool,
157
- default=False,
158
- help="Whether to drop last batch. Used by sampler.",
159
- )
160
- group.add_argument(
161
- "--return-cuts",
162
- type=str2bool,
163
- default=True,
164
- help="When enabled, each batch will have the "
165
- "field: batch['supervisions']['cut'] with the cuts that "
166
- "were used to construct it.",
167
- )
168
-
169
- group.add_argument(
170
- "--num-workers",
171
- type=int,
172
- default=8,
173
- help="The number of training dataloader workers that "
174
- "collect the batches.",
175
- )
176
-
177
- group.add_argument(
178
- "--enable-spec-aug",
179
- type=str2bool,
180
- default=False,
181
- help="When enabled, use SpecAugment for training dataset.",
182
- )
183
-
184
- group.add_argument(
185
- "--spec-aug-time-warp-factor",
186
- type=int,
187
- default=80,
188
- help="Used only when --enable-spec-aug is True. "
189
- "It specifies the factor for time warping in SpecAugment. "
190
- "Larger values mean more warping. "
191
- "A value less than 1 means to disable time warp.",
192
- )
193
-
194
- group.add_argument(
195
- "--input-strategy",
196
- type=str,
197
- default="PrecomputedFeatures",
198
- help="AudioSamples or PrecomputedFeatures or PromptedPrecomputedFeatures",
199
- )
200
-
201
- group.add_argument(
202
- "--dataset",
203
- type=str,
204
- default="ljspeech",
205
- help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.",
206
- )
207
-
208
- parser.add_argument(
209
- "--text-tokens",
210
- type=str,
211
- default="data/tokenized/unique_text_tokens.k2symbols",
212
- help="Path to the unique text tokens file",
213
- )
214
-
215
- parser.add_argument(
216
- "--sampling-rate",
217
- type=int,
218
- default=24000,
219
- help="""Audio sampling rate.""",
220
- )
221
-
222
- def train_dataloaders(
223
- self,
224
- cuts_train: CutSet,
225
- sampler_state_dict: Optional[Dict[str, Any]] = None,
226
- ) -> DataLoader:
227
- """
228
- Args:
229
- cuts_train:
230
- CutSet for training.
231
- sampler_state_dict:
232
- The state dict for the training sampler.
233
- """
234
- transforms = []
235
-
236
- if self.args.concatenate_cuts:
237
- logging.info(
238
- f"Using cut concatenation with duration factor "
239
- f"{self.args.duration_factor} and gap {self.args.gap}."
240
- )
241
- # Cut concatenation should be the first transform in the list,
242
- # so that if we e.g. mix noise in, it will fill the gaps between
243
- # different utterances.
244
- transforms = [
245
- CutConcatenate(
246
- duration_factor=self.args.duration_factor, gap=self.args.gap
247
- )
248
- ] + transforms
249
-
250
- input_transforms = []
251
- if self.args.enable_spec_aug:
252
- logging.info("Enable SpecAugment")
253
- logging.info(
254
- f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
255
- )
256
- # Set the value of num_frame_masks according to Lhotse's version.
257
- # In different Lhotse's versions, the default of num_frame_masks is
258
- # different.
259
- num_frame_masks = 10
260
- num_frame_masks_parameter = inspect.signature(
261
- SpecAugment.__init__
262
- ).parameters["num_frame_masks"]
263
- if num_frame_masks_parameter.default == 1:
264
- num_frame_masks = 2
265
- logging.info(f"Num frame mask: {num_frame_masks}")
266
- input_transforms.append(
267
- SpecAugment(
268
- time_warp_factor=self.args.spec_aug_time_warp_factor,
269
- num_frame_masks=num_frame_masks,
270
- features_mask_size=27,
271
- num_feature_masks=2,
272
- frames_mask_size=100,
273
- )
274
- )
275
- else:
276
- logging.info("Disable SpecAugment")
277
-
278
- logging.info("About to create train dataset")
279
- if self.args.on_the_fly_feats:
280
- # NOTE: the PerturbSpeed transform should be added only if we
281
- # remove it from data prep stage.
282
- # Add on-the-fly speed perturbation; since originally it would
283
- # have increased epoch size by 3, we will apply prob 2/3 and use
284
- # 3x more epochs.
285
- # Speed perturbation probably should come first before
286
- # concatenation, but in principle the transforms order doesn't have
287
- # to be strict (e.g. could be randomized)
288
- # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
289
- # Drop feats to be on the safe side.
290
- train = SpeechSynthesisDataset(
291
- get_text_token_collater(self.args.text_tokens),
292
- cut_transforms=transforms,
293
- feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
294
- feature_transforms=input_transforms,
295
- )
296
- else:
297
- train = SpeechSynthesisDataset(
298
- get_text_token_collater(self.args.text_tokens),
299
- feature_input_strategy=_get_input_strategy(
300
- self.args.input_strategy, self.args.dataset, cuts_train
301
- ),
302
- cut_transforms=transforms,
303
- feature_transforms=input_transforms,
304
- )
305
-
306
- if self.args.bucketing_sampler:
307
- logging.info("Using DynamicBucketingSampler")
308
- train_sampler = DynamicBucketingSampler(
309
- cuts_train,
310
- max_duration=self.args.max_duration,
311
- shuffle=self.args.shuffle,
312
- num_buckets=self.args.num_buckets,
313
- drop_last=self.args.drop_last,
314
- )
315
- else:
316
- logging.info(
317
- "Using SingleCutSampler and sort by duraton(ascending=True)."
318
- )
319
- cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True)
320
- train_sampler = SingleCutSampler(
321
- cuts_train,
322
- max_duration=self.args.max_duration,
323
- shuffle=self.args.shuffle,
324
- )
325
- logging.info("About to create train dataloader")
326
-
327
- if sampler_state_dict is not None:
328
- logging.info("Loading sampler state dict")
329
- train_sampler.load_state_dict(sampler_state_dict)
330
-
331
- # 'seed' is derived from the current random state, which will have
332
- # previously been set in the main process.
333
- seed = torch.randint(0, 100000, ()).item()
334
- worker_init_fn = _SeedWorkers(seed)
335
-
336
- train_dl = DataLoader(
337
- train,
338
- sampler=train_sampler,
339
- batch_size=None,
340
- num_workers=self.args.num_workers,
341
- persistent_workers=False,
342
- worker_init_fn=worker_init_fn,
343
- )
344
-
345
- return train_dl
346
-
347
- def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
348
- logging.info("About to create dev dataset")
349
- if self.args.on_the_fly_feats:
350
- validate = SpeechSynthesisDataset(
351
- get_text_token_collater(self.args.text_tokens),
352
- feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
353
- cut_transforms=[],
354
- )
355
- else:
356
- validate = SpeechSynthesisDataset(
357
- get_text_token_collater(self.args.text_tokens),
358
- feature_input_strategy=_get_input_strategy(
359
- self.args.input_strategy, self.args.dataset, cuts_valid
360
- ),
361
- cut_transforms=[],
362
- )
363
- valid_sampler = DynamicBucketingSampler(
364
- cuts_valid,
365
- max_duration=self.args.max_duration,
366
- shuffle=False,
367
- )
368
- logging.info("About to create dev dataloader")
369
- valid_dl = DataLoader(
370
- validate,
371
- sampler=valid_sampler,
372
- batch_size=None,
373
- num_workers=4,
374
- persistent_workers=False,
375
- )
376
-
377
- return valid_dl
378
-
379
- def test_dataloaders(self, cuts: CutSet) -> DataLoader:
380
- logging.debug("About to create test dataset")
381
- test = SpeechSynthesisDataset(
382
- get_text_token_collater(self.args.text_tokens),
383
- feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor())
384
- if self.args.on_the_fly_feats
385
- else _get_input_strategy(
386
- self.args.input_strategy, self.args.dataset, cuts
387
- ),
388
- cut_transforms=[],
389
- )
390
- sampler = DynamicBucketingSampler(
391
- cuts,
392
- max_duration=self.args.max_duration,
393
- shuffle=False,
394
- )
395
- logging.debug("About to create test dataloader")
396
- test_dl = DataLoader(
397
- test,
398
- batch_size=None,
399
- sampler=sampler,
400
- num_workers=self.args.num_workers,
401
- )
402
- return test_dl
403
-
404
- @lru_cache()
405
- def train_cuts(self) -> CutSet:
406
- logging.info("About to get train cuts")
407
- return load_manifest_lazy(
408
- self.args.manifest_dir / "cuts_train.jsonl.gz"
409
- )
410
-
411
- @lru_cache()
412
- def dev_cuts(self) -> CutSet:
413
- logging.info("About to get dev cuts")
414
- return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
415
-
416
- @lru_cache()
417
- def test_cuts(self) -> CutSet:
418
- logging.info("About to get test cuts")
419
- return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
 
data/dataset.py DELETED
@@ -1,242 +0,0 @@
1
- # Copyright 2023 (authors: Feiteng Li)
2
- #
3
- # See ../../../../LICENSE for clarification regarding multiple authors
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- """
18
- modified from lhotse.dataset.speech_synthesis.py
19
- """
20
-
21
- import torch
22
- import math
23
- import h5py
24
- from tokenizers import Tokenizer
25
- from typing import Union, List
26
- import numpy as np
27
- from tqdm import tqdm
28
-
29
- _pad = '_'
30
- _punctuation = ',.!?-~…'
31
- _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
32
- symbols = [_pad] + list(_punctuation) + list(_letters)
33
-
34
- language_dict = {
35
- 'en': 0,
36
- 'zh': 1,
37
- 'ja': 2,
38
- }
39
- def seq2phone(tokens: Union[List, np.ndarray]):
40
- """
41
- Convert tokenized phoneme ID sequence back to phoneme string
42
- :param tokens: phoneme tokens
43
- :return: recovered phoneme sequence
44
- """
45
- phones = "".join([symbols[i] for i in tokens])
46
- return phones
47
-
48
- class DynamicBatchSampler(torch.utils.data.Sampler):
49
- def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0, max_size=1000,
50
- max_tokens=None, max_sentences=None, drop_last=False):
51
- """
52
-
53
- :param sampler:
54
- :param num_tokens_fn: function that returns the length of the sample at a given idx
55
- :param num_buckets: number of buckets; samples of similar length are bucketed so they end up in the same batch
56
- :param min_size: minimum sample length; shorter samples are filtered out (this value is also used to build the buckets)
57
- :param max_size: maximum sample length
58
- :param max_sentences: batch size; the final batch size is controlled jointly by max_sentences and max_tokens
59
- """
60
- super(DynamicBatchSampler, self).__init__(sampler)
61
- self.sampler = sampler
62
- self.num_tokens_fn = num_tokens_fn
63
- self.num_buckets = num_buckets
64
-
65
- self.min_size = min_size
66
- self.max_size = max_size
67
-
68
- assert max_size <= max_tokens, "max_size should be smaller than max tokens"
69
- assert max_tokens is not None or max_sentences is not None, \
70
- "max_tokens and max_sentences should not be null at the same time, please specify one parameter at least"
71
- self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
72
- self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
73
- self.drop_last = drop_last
74
-
75
- def set_epoch(self, epoch):
76
- self.sampler.set_epoch(epoch)
77
- def is_batch_full(self, num_tokens, batch):
78
- if len(batch) == 0:
79
- return False
80
- if len(batch) == self.max_sentences:
81
- return True
82
- if num_tokens > self.max_tokens:
83
- return True
84
- return False
85
-
86
- def __iter__(self):
87
- buckets = [[] for _ in range(self.num_buckets)]
88
- sample_len = [0] * self.num_buckets
89
-
90
- for idx in self.sampler:
91
- idx_length = self.num_tokens_fn(idx)
92
- if not (self.min_size <= idx_length <= self.max_size):
93
- print("sentence at index {} of size {} exceeds max_tokens, the sentence is ignored".format(idx, idx_length))
94
- continue
95
-
96
- index_buckets = math.floor((idx_length - self.min_size) / (self.max_size - self.min_size + 1)
97
- * self.num_buckets)
98
- sample_len[index_buckets] = max(sample_len[index_buckets], idx_length)
99
-
100
- num_tokens = (len(buckets[index_buckets]) + 1) * sample_len[index_buckets]
101
- if self.is_batch_full(num_tokens, buckets[index_buckets]):
102
- # yield this batch
103
- yield buckets[index_buckets]
104
- buckets[index_buckets] = []
105
- sample_len[index_buckets] = 0
106
-
107
- buckets[index_buckets].append(idx)
108
-
109
- # process left-over
110
- leftover_batch = []
111
- leftover_sample_len = 0
112
- leftover = [idx for bucket in buckets for idx in bucket]
113
- for idx in leftover:
114
- idx_length = self.num_tokens_fn(idx)
115
- leftover_sample_len = max(leftover_sample_len, idx_length)
116
- num_tokens = (len(leftover_batch) + 1) * leftover_sample_len
117
- if self.is_batch_full(num_tokens, leftover_batch):
118
- yield leftover_batch
119
- leftover_batch = []
120
- leftover_sample_len = 0
121
- leftover_batch.append(idx)
122
-
123
- if len(leftover_batch) > 0 and not self.drop_last:
124
- yield leftover_batch
125
-
126
- def __len__(self):
127
- # we do not know the exact batch size, so do not call len(dataloader)
128
- pass
129
-
130
-
131
- class AudioDataset(torch.utils.data.Dataset):
132
- def __init__(self, h5_path, ann_path, tokenizer_path):
133
- self.h5_path = h5_path
134
- with open(ann_path, 'r', encoding='utf-8') as f:
135
- lines = f.readlines()
136
- ls = [l.split("|") for l in lines]
137
- ls_T = list(zip(*ls))
138
- del ls_T[-1]
139
- self.h5_paths, self.durations, self.langs, self.texts = \
140
- list(ls_T[0]), list(ls_T[1]), list(ls_T[2]), list(ls_T[3])
141
- self.durations = [float(dur) for dur in self.durations]
142
- self.tokenizer = Tokenizer.from_file(tokenizer_path)
143
-
144
- self._archive = None
145
-
146
- def __len__(self):
147
- return len(self.h5_paths)
148
-
149
- def get_dur(self, idx):
150
- return self.durations[idx]
151
-
152
- @property
153
- def archive(self):
154
- if self._archive is None: # lazy loading here!
155
- self._archive = h5py.File(self.h5_path, "r")
156
- return self._archive
157
- def __getitem__(self, idx):
158
- archive = self.archive
159
- h5_path = self.h5_paths[idx]
160
- sub = archive[h5_path]
161
- audio_tokens = sub['audio'][()]
162
- phone_tokens = sub['text'][()]
163
- dur = self.durations[idx]
164
- lang = self.langs[idx]
165
- text = self.texts[idx]
166
- # tokenization should be done within dataloader
167
- phones = seq2phone(phone_tokens)
168
- phones = phones.replace(" ", "_")
169
- if not len(phones):
170
- cptpho_tokens = self.tokenizer.encode(text).ids
171
- else:
172
- cptpho_tokens = self.tokenizer.encode(phones).ids
173
- assert len(cptpho_tokens)
174
- return {
175
- 'utt_id': h5_path,
176
- 'text': text,
177
- 'audio': None,
178
- 'audio_lens': None,
179
- 'audio_features': audio_tokens,
180
- 'audio_features_lens': len(audio_tokens.T),
181
- 'text_tokens': np.array(cptpho_tokens),
182
- 'text_tokens_lens': len(cptpho_tokens),
183
- 'language': language_dict[lang],
184
- }
185
-
186
- def collate(batch):
187
- utt_id_s = [b['utt_id'] for b in batch]
188
- text_s = [b['text'] for b in batch]
189
-
190
- audio_s = [b['audio'] for b in batch]
191
- audio_lens_s = [b['audio_lens'] for b in batch]
192
-
193
- audio_features_lens_s = [b['audio_features_lens'] for b in batch]
194
- # create an empty tensor with maximum audio feature length
195
- audio_features_s = torch.zeros([len(batch), max(audio_features_lens_s), 8], dtype=torch.int64) - 1 # audio pad with -1
196
-
197
- text_tokens_lens_s = [b['text_tokens_lens'] for b in batch]
198
- # create an empty tensor with maximum text tokens length
199
- text_tokens_s = torch.zeros([len(batch), max(text_tokens_lens_s)], dtype=torch.int64) + 3 # [PAD] token id 3
200
-
201
- language_s = [b['language'] for b in batch]
202
-
203
- for i, b in enumerate(batch):
204
- audio_features = b['audio_features']
205
- audio_features_lens = b['audio_features_lens']
206
- audio_features_s[i, :audio_features_lens, :] = torch.LongTensor(audio_features.T)
207
-
208
- text_tokens = b['text_tokens']
209
- text_tokens_lens = b['text_tokens_lens']
210
- text_tokens_s[i, :text_tokens_lens] = torch.LongTensor(text_tokens)
211
-
212
- batch = {
213
- 'utt_id': utt_id_s,
214
- 'text': text_s,
215
- 'audio': audio_s,
216
- 'audio_lens': audio_lens_s,
217
- 'audio_features': audio_features_s,
218
- 'audio_features_lens': torch.LongTensor(np.array(audio_features_lens_s)),
219
- 'text_tokens': text_tokens_s,
220
- 'text_tokens_lens': torch.LongTensor(np.array(text_tokens_lens_s)),
221
- 'languages': torch.LongTensor(np.array(language_s)),
222
- }
223
- return batch
224
-
225
- def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0, num_workers=0, num_buckets=10, max_duration=120):
226
- train_dataset = AudioDataset(h5_path=f"{data_dir}/audio_sum.hdf5",
227
- ann_path=f"{data_dir}/audio_ann_sum.txt",
228
- tokenizer_path=f"{data_dir}/bpe_69.json")
229
- ran_sampler = torch.utils.data.distributed.DistributedSampler(
230
- train_dataset,
231
- num_replicas=n_gpus,
232
- rank=rank,
233
- shuffle=True,
234
- )
235
- dynamic_sampler = DynamicBatchSampler(ran_sampler, train_dataset.get_dur, num_buckets=num_buckets, max_size=20,
236
- max_tokens=max_duration)
237
-
238
-
239
- train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=num_workers, collate_fn=collate,
240
- batch_sampler=dynamic_sampler)
241
-
242
- return train_loader
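
A hedged sketch of the HDF5-backed pipeline above, run on a single GPU; the data directory is a placeholder and must contain `audio_sum.hdf5`, `audio_ann_sum.txt`, and `bpe_69.json` in the layout `AudioDataset` expects.

```python
from data.dataset import create_dataloader

train_loader = create_dataloader(
    data_dir="/path/to/mix",   # placeholder; not a path from this repository
    n_gpus=1,
    rank=0,
    num_workers=0,
    max_duration=120,
)

for batch in train_loader:
    # collate() pads audio codes with -1 and text tokens with [PAD] id 3.
    print(batch["audio_features"].shape)   # (B, T_max, 8) EnCodec codebooks
    print(batch["text_tokens"].shape)      # (B, L_max)
    break
```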
 
data/fbank.py DELETED
@@ -1,212 +0,0 @@
1
- # Copyright 2023 (authors: Feiteng Li)
2
- #
3
- # See ../../../../LICENSE for clarification regarding multiple authors
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
-
18
- from dataclasses import asdict, dataclass
19
- from typing import Any, Dict, Optional, Union
20
-
21
- import numpy as np
22
- import torch
23
- # from lhotse.features.base import FeatureExtractor
24
- # from lhotse.utils import EPSILON, Seconds, compute_num_frames
25
- from librosa.filters import mel as librosa_mel_fn
26
-
27
-
28
- @dataclass
29
- class BigVGANFbankConfig:
30
- # Spectrogram-related part
31
- # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them
32
- frame_length: Seconds = 1024 / 24000.0
33
- frame_shift: Seconds = 256 / 24000.0
34
- remove_dc_offset: bool = True
35
- round_to_power_of_two: bool = True
36
-
37
- # Fbank-related part
38
- low_freq: float = 0.0
39
- high_freq: float = 12000.0
40
- num_mel_bins: int = 100
41
- use_energy: bool = False
42
-
43
- def to_dict(self) -> Dict[str, Any]:
44
- return asdict(self)
45
-
46
- @staticmethod
47
- def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig":
48
- return BigVGANFbankConfig(**data)
49
-
50
-
51
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
52
- return torch.log(torch.clamp(x, min=clip_val) * C)
53
-
54
-
55
- def spectral_normalize_torch(magnitudes):
56
- output = dynamic_range_compression_torch(magnitudes)
57
- return output
58
-
59
-
60
- # https://github.com/NVIDIA/BigVGAN
61
- # bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz
62
- class BigVGANFbank(FeatureExtractor):
63
- name = "fbank"
64
- config_type = BigVGANFbankConfig
65
-
66
- def __init__(self, config: Optional[Any] = None):
67
- super(BigVGANFbank, self).__init__(config)
68
- sampling_rate = 24000
69
- self.mel_basis = torch.from_numpy(
70
- librosa_mel_fn(
71
- sampling_rate,
72
- 1024,
73
- self.config.num_mel_bins,
74
- self.config.low_freq,
75
- self.config.high_freq,
76
- ).astype(np.float32)
77
- )
78
- self.hann_window = torch.hann_window(1024)
79
-
80
- def _feature_fn(self, samples, **kwargs):
81
- win_length, n_fft = 1024, 1024
82
- hop_size = 256
83
- if True:
84
- sampling_rate = 24000
85
- duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
86
- expected_num_frames = compute_num_frames(
87
- duration=duration,
88
- frame_shift=self.frame_shift,
89
- sampling_rate=sampling_rate,
90
- )
91
- pad_size = (
92
- (expected_num_frames - 1) * hop_size
93
- + win_length
94
- - samples.shape[-1]
95
- )
96
- assert pad_size >= 0
97
-
98
- y = torch.nn.functional.pad(
99
- samples,
100
- (0, pad_size),
101
- mode="constant",
102
- )
103
- else:
104
- y = torch.nn.functional.pad(
105
- samples,
106
- (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
107
- mode="reflect",
108
- )
109
-
110
- y = y.squeeze(1)
111
-
112
- # complex tensor as default, then use view_as_real for future pytorch compatibility
113
- spec = torch.stft(
114
- y,
115
- n_fft,
116
- hop_length=hop_size,
117
- win_length=win_length,
118
- window=self.hann_window,
119
- center=False,
120
- pad_mode="reflect",
121
- normalized=False,
122
- onesided=True,
123
- return_complex=True,
124
- )
125
- spec = torch.view_as_real(spec)
126
- spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
127
-
128
- spec = torch.matmul(self.mel_basis, spec)
129
- spec = spectral_normalize_torch(spec)
130
-
131
- return spec.transpose(2, 1).squeeze(0)
132
-
133
- def extract(
134
- self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
135
- ) -> np.ndarray:
136
- assert sampling_rate == 24000
137
- params = asdict(self.config)
138
- params.update({"sample_frequency": sampling_rate, "snip_edges": False})
139
- params["frame_shift"] *= 1000.0
140
- params["frame_length"] *= 1000.0
141
- if not isinstance(samples, torch.Tensor):
142
- samples = torch.from_numpy(samples)
143
- # Torchaudio Kaldi feature extractors expect the channel dimension to be first.
144
- if len(samples.shape) == 1:
145
- samples = samples.unsqueeze(0)
146
- features = self._feature_fn(samples, **params).to(torch.float32)
147
- return features.numpy()
148
-
149
- @property
150
- def frame_shift(self) -> Seconds:
151
- return self.config.frame_shift
152
-
153
- def feature_dim(self, sampling_rate: int) -> int:
154
- return self.config.num_mel_bins
155
-
156
- @staticmethod
157
- def mix(
158
- features_a: np.ndarray,
159
- features_b: np.ndarray,
160
- energy_scaling_factor_b: float,
161
- ) -> np.ndarray:
162
- return np.log(
163
- np.maximum(
164
- # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
165
- EPSILON,
166
- np.exp(features_a)
167
- + energy_scaling_factor_b * np.exp(features_b),
168
- )
169
- )
170
-
171
- @staticmethod
172
- def compute_energy(features: np.ndarray) -> float:
173
- return float(np.sum(np.exp(features)))
174
-
175
-
176
- def get_fbank_extractor() -> BigVGANFbank:
177
- return BigVGANFbank(BigVGANFbankConfig())
178
-
179
-
180
- if __name__ == "__main__":
181
- extractor = BigVGANFbank(BigVGANFbankConfig())
182
-
183
- samples = torch.from_numpy(np.random.random([1000]).astype(np.float32))
184
- samples = torch.clip(samples, -1.0, 1.0)
185
- fbank = extractor.extract(samples, 24000.0)
186
- print(f"fbank {fbank.shape}")
187
-
188
- from scipy.io.wavfile import read
189
-
190
- MAX_WAV_VALUE = 32768.0
191
-
192
- sampling_rate, samples = read(
193
- "egs/libritts/prompts/5639_40744_000000_000002.wav"
194
- )
195
- print(f"samples: [{samples.min()}, {samples.max()}]")
196
- fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000)
197
- print(f"fbank {fbank.shape}")
198
-
199
- import matplotlib.pyplot as plt
200
-
201
- _ = plt.figure(figsize=(18, 10))
202
- plt.imshow(
203
- X=fbank.transpose(1, 0),
204
- cmap=plt.get_cmap("jet"),
205
- aspect="auto",
206
- interpolation="nearest",
207
- )
208
- plt.gca().invert_yaxis()
209
- plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png")
210
- plt.close()
211
-
212
- print("fbank test PASS!")
 
data/input_strategies.py DELETED
@@ -1,159 +0,0 @@
1
- import random
2
- from collections import defaultdict
3
- from concurrent.futures import ThreadPoolExecutor
4
- from typing import Tuple, Type
5
-
6
- # from lhotse import CutSet
7
- # from lhotse.dataset.collation import collate_features
8
- # from lhotse.dataset.input_strategies import (
9
- # ExecutorType,
10
- # PrecomputedFeatures,
11
- # _get_executor,
12
- # )
13
- # from lhotse.utils import fastcopy
14
-
15
-
16
- class PromptedFeatures:
17
- def __init__(self, prompts, features):
18
- self.prompts = prompts
19
- self.features = features
20
-
21
- def to(self, device):
22
- return PromptedFeatures(
23
- self.prompts.to(device), self.features.to(device)
24
- )
25
-
26
- def sum(self):
27
- return self.features.sum()
28
-
29
- @property
30
- def ndim(self):
31
- return self.features.ndim
32
-
33
- @property
34
- def data(self):
35
- return (self.prompts, self.features)
36
-
37
-
38
- # class PromptedPrecomputedFeatures(PrecomputedFeatures):
39
- # """
40
- # :class:`InputStrategy` that reads pre-computed features, whose manifests
41
- # are attached to cuts, from disk.
42
- #
43
- # It automatically pads the feature matrices with pre or post feature.
44
- #
45
- # .. automethod:: __call__
46
- # """
47
- #
48
- # def __init__(
49
- # self,
50
- # dataset: str,
51
- # cuts: CutSet,
52
- # num_workers: int = 0,
53
- # executor_type: Type[ExecutorType] = ThreadPoolExecutor,
54
- # ) -> None:
55
- # super(PromptedPrecomputedFeatures, self).__init__(
56
- # num_workers, executor_type
57
- # )
58
- #
59
- # self.utt2neighbors = defaultdict(lambda: [])
60
- #
61
- # if dataset.lower() == "libritts":
62
- # # 909_131041_000013_000002
63
- # # 909_131041_000013_000003
64
- # speaker2utts = defaultdict(lambda: [])
65
- #
66
- # utt2cut = {}
67
- # for cut in cuts:
68
- # speaker = cut.supervisions[0].speaker
69
- # speaker2utts[speaker].append(cut.id)
70
- # utt2cut[cut.id] = cut
71
- #
72
- # for spk in speaker2utts:
73
- # uttids = sorted(speaker2utts[spk])
74
- # # Using the property of sorted keys to find previous utterance
75
- # # The keys have the structure speaker_book_x_y, e.g. 1089_134691_000004_000001
76
- # if len(uttids) == 1:
77
- # self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
78
- # continue
79
- #
80
- # utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
81
- # utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
82
- #
83
- # for utt in utt2prevutt:
84
- # self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
85
- #
86
- # for utt in utt2postutt:
87
- # self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
88
- # elif dataset.lower() == "ljspeech":
89
- # utt2cut = {}
90
- # uttids = []
91
- # for cut in cuts:
92
- # uttids.append(cut.id)
93
- # utt2cut[cut.id] = cut
94
- #
95
- # if len(uttids) == 1:
96
- # self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
97
- # else:
98
- # # Using the property of sorted keys to find previous utterance
99
- # # The keys have the structure: LJ001-0010
100
- # utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
101
- # utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
102
- #
103
- # for utt in utt2postutt:
104
- # postutt = utt2postutt[utt]
105
- # if utt[:5] == postutt[:5]:
106
- # self.utt2neighbors[utt].append(utt2cut[postutt])
107
- #
108
- # for utt in utt2prevutt:
109
- # prevutt = utt2prevutt[utt]
110
- # if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]:
111
- # self.utt2neighbors[utt].append(utt2cut[prevutt])
112
- # else:
113
- # raise ValueError
114
- #
115
- # def __call__(
116
- # self, cuts: CutSet
117
- # ) -> Tuple[PromptedFeatures, PromptedFeatures]:
118
- # """
119
- # Reads the pre-computed features from disk/other storage.
120
- # The returned shape is``(B, T, F) => (batch_size, num_frames, num_features)``.
121
- #
122
- # :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding.
123
- # """
124
- # features, features_lens = collate_features(
125
- # cuts,
126
- # executor=_get_executor(
127
- # self.num_workers, executor_type=self._executor_type
128
- # ),
129
- # )
130
- #
131
- # prompts_cuts = []
132
- # for k, cut in enumerate(cuts):
133
- # prompts_cut = random.choice(self.utt2neighbors[cut.id])
134
- # prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
135
- #
136
- # mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
137
- # # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate(
138
- # # max_duration=mini_duration,
139
- # # offset_type="random",
140
- # # preserve_id=True,
141
- # # )
142
- # prompts_cuts = CutSet(
143
- # cuts={k: cut for k, cut in enumerate(prompts_cuts)}
144
- # ).truncate(
145
- # max_duration=mini_duration,
146
- # offset_type="random",
147
- # preserve_id=False,
148
- # )
149
- #
150
- # prompts, prompts_lens = collate_features(
151
- # prompts_cuts,
152
- # executor=_get_executor(
153
- # self.num_workers, executor_type=self._executor_type
154
- # ),
155
- # )
156
- #
157
- # return PromptedFeatures(prompts, features), PromptedFeatures(
158
- # prompts_lens, features_lens
159
- # )
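
Only `PromptedFeatures` survives un-commented in this file; it is a thin container pairing prompt features with target features, e.g.:

```python
import torch

from data.input_strategies import PromptedFeatures

prompts = torch.randn(2, 75, 100)     # (B, T_prompt, F) prompt features
features = torch.randn(2, 300, 100)   # (B, T_target, F) target features

pf = PromptedFeatures(prompts, features).to("cpu")
print(pf.ndim)             # 3, taken from the target features
prompts, features = pf.data
```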
 
data/tokenizer.py DELETED
@@ -1,126 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Copyright 2023 (authors: Feiteng Li)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import re
17
- from dataclasses import asdict, dataclass
18
- from typing import Any, Dict, List, Optional, Pattern, Union
19
-
20
- import numpy as np
21
- import torch
22
- import torchaudio
23
- from encodec import EncodecModel
24
- from encodec.utils import convert_audio
25
-
26
- try:
27
- from pypinyin import Style, pinyin
28
- from pypinyin.style._utils import get_finals, get_initials
29
- except Exception:
30
- pass
31
-
32
-
33
- def remove_encodec_weight_norm(model):
34
- from encodec.modules import SConv1d
35
- from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
36
- from torch.nn.utils import remove_weight_norm
37
-
38
- encoder = model.encoder.model
39
- for key in encoder._modules:
40
- if isinstance(encoder._modules[key], SEANetResnetBlock):
41
- remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
42
- block_modules = encoder._modules[key].block._modules
43
- for skey in block_modules:
44
- if isinstance(block_modules[skey], SConv1d):
45
- remove_weight_norm(block_modules[skey].conv.conv)
46
- elif isinstance(encoder._modules[key], SConv1d):
47
- remove_weight_norm(encoder._modules[key].conv.conv)
48
-
49
- decoder = model.decoder.model
50
- for key in decoder._modules:
51
- if isinstance(decoder._modules[key], SEANetResnetBlock):
52
- remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
53
- block_modules = decoder._modules[key].block._modules
54
- for skey in block_modules:
55
- if isinstance(block_modules[skey], SConv1d):
56
- remove_weight_norm(block_modules[skey].conv.conv)
57
- elif isinstance(decoder._modules[key], SConvTranspose1d):
58
- remove_weight_norm(decoder._modules[key].convtr.convtr)
59
- elif isinstance(decoder._modules[key], SConv1d):
60
- remove_weight_norm(decoder._modules[key].conv.conv)
61
-
62
-
63
- class AudioTokenizer:
64
- """EnCodec audio."""
65
-
66
- def __init__(
67
- self,
68
- device: Any = None,
69
- ) -> None:
70
- # Instantiate a pretrained EnCodec model
71
- model = EncodecModel.encodec_model_24khz()
72
- model.set_target_bandwidth(6.0)
73
- remove_encodec_weight_norm(model)
74
-
75
- if not device:
76
- device = torch.device("cpu")
77
- if torch.cuda.is_available():
78
- device = torch.device("cuda:0")
79
- if torch.backends.mps.is_available():
80
- device = torch.device("mps")
81
-
82
- self._device = device
83
-
84
- self.codec = model.to(device)
85
- self.sample_rate = model.sample_rate
86
- self.channels = model.channels
87
-
88
- @property
89
- def device(self):
90
- return self._device
91
-
92
- def encode(self, wav: torch.Tensor) -> torch.Tensor:
93
- return self.codec.encode(wav.to(self.device))
94
-
95
- def decode(self, frames: torch.Tensor) -> torch.Tensor:
96
- return self.codec.decode(frames)
97
-
98
-
99
- def tokenize_audio(tokenizer: AudioTokenizer, audio):
100
- # Load and pre-process the audio waveform
101
- if isinstance(audio, str):
102
- wav, sr = torchaudio.load(audio)
103
- else:
104
- wav, sr = audio
105
- wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
106
- wav = wav.unsqueeze(0)
107
-
108
- # Extract discrete codes from EnCodec
109
- with torch.no_grad():
110
- encoded_frames = tokenizer.encode(wav)
111
- return encoded_frames
112
-
113
-
114
- if __name__ == "__main__":
115
- model = EncodecModel.encodec_model_24khz()
116
- model.set_target_bandwidth(6.0)
117
-
118
- samples = torch.from_numpy(np.random.random([4, 1, 1600])).type(
119
- torch.float32
120
- )
121
- codes_raw = model.encode(samples)
122
-
123
- remove_encodec_weight_norm(model)
124
- codes_norm = model.encode(samples)
125
-
126
- assert torch.allclose(codes_raw[0][0], codes_norm[0][0])
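
A hedged sketch of tokenizing a prompt recording with the classes above (`prompt.wav` is a placeholder path; the pretrained EnCodec checkpoint is downloaded on first use):

```python
from data.tokenizer import AudioTokenizer, tokenize_audio

tokenizer = AudioTokenizer()                              # picks cuda / mps / cpu
encoded_frames = tokenize_audio(tokenizer, "prompt.wav")  # placeholder input file

# EnCodec returns a list of (codes, scale) pairs; at 6 kbps the 24 kHz model
# produces codes of shape (1, 8, T) -- 8 residual quantizer streams.
codes = encoded_frames[0][0]
print(codes.shape)
```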
 
macros.py DELETED
@@ -1,44 +0,0 @@
1
- NUM_LAYERS = 12
2
- NUM_HEAD = 16
3
- N_DIM = 1024
4
- PREFIX_MODE = 1
5
- NUM_QUANTIZERS = 8
6
- SAMPLE_RATE = 24000
7
-
8
- lang2token = {
9
- 'zh': "[ZH]",
10
- 'ja': "[JA]",
11
- "en": "[EN]",
12
- "AR": "[AR]",
13
- 'mix': "",
14
- }
15
-
16
- lang2code = {
17
- 'zh': 0,
18
- 'ja': 1,
19
- "en": 2,
20
- "ar": 3,
21
- }
22
-
23
- token2lang = {
24
- '[ZH]': "zh",
25
- '[JA]': "ja",
26
- "[EN]": "en",
27
- "[AR]": "ar",
28
- "": "mix"
29
- }
30
-
31
- code2lang = {
32
- 0: 'zh',
33
- 1: 'ja',
34
- 2: "en",
35
- 3: "ar",
36
- }
37
-
38
- langdropdown2token = {
39
- 'English': "[EN]",
40
- '中文': "[ZH]",
41
- '日本語': "[JA]",
42
- 'عربي':"[AR]",
43
- 'Mix': "",
44
- }
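
One way these dictionaries compose (illustrative only; the surrounding-tag convention below is not prescribed by this file):

```python
from macros import lang2code, lang2token, token2lang

lang = "en"
text = "Hello world."
tagged = lang2token[lang] + text + lang2token[lang]   # "[EN]Hello world.[EN]"
lang_id = lang2code[token2lang[lang2token[lang]]]     # 2
print(tagged, lang_id)
```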
 
models/__init__.py DELETED
@@ -1,136 +0,0 @@
1
- import argparse
2
-
3
- import torch.nn as nn
4
- # from icefall.utils import AttributeDict, str2bool
5
-
6
- from .macros import (
7
- NUM_AUDIO_TOKENS,
8
- NUM_MEL_BINS,
9
- NUM_SPEAKER_CLASSES,
10
- NUM_TEXT_TOKENS,
11
- SPEAKER_EMBEDDING_DIM,
12
- )
13
- from .transformer import Transformer
14
- from .vallex import VALLE, VALLF
15
- from .visualizer import visualize
16
-
17
-
18
- def add_model_arguments(parser: argparse.ArgumentParser):
19
- parser.add_argument(
20
- "--model-name",
21
- type=str,
22
- default="VALL-E",
23
- help="VALL-E, VALL-F, Transformer.",
24
- )
25
- parser.add_argument(
26
- "--decoder-dim",
27
- type=int,
28
- default=1024,
29
- help="Embedding dimension in the decoder model.",
30
- )
31
- parser.add_argument(
32
- "--nhead",
33
- type=int,
34
- default=16,
35
- help="Number of attention heads in the Decoder layers.",
36
- )
37
- parser.add_argument(
38
- "--num-decoder-layers",
39
- type=int,
40
- default=12,
41
- help="Number of Decoder layers.",
42
- )
43
- parser.add_argument(
44
- "--scale-factor",
45
- type=float,
46
- default=1.0,
47
- help="Model scale factor which will be assigned different meanings in different models.",
48
- )
49
- parser.add_argument(
50
- "--norm-first",
51
- type=bool,
52
- default=True,
53
- help="Pre or Post Normalization.",
54
- )
55
- parser.add_argument(
56
- "--add-prenet",
57
- type=bool,
58
- default=False,
59
- help="Whether add PreNet after Inputs.",
60
- )
61
-
62
- # VALL-E & F
63
- parser.add_argument(
64
- "--prefix-mode",
65
- type=int,
66
- default=1,
67
- help="The mode for how to prefix VALL-E NAR Decoder, "
68
- "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
69
- )
70
- parser.add_argument(
71
- "--share-embedding",
72
- type=bool,
73
- default=True,
74
- help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
75
- )
76
- parser.add_argument(
77
- "--prepend-bos",
78
- type=bool,
79
- default=False,
80
- help="Whether prepend <BOS> to the acoustic tokens -> AR Decoder inputs.",
81
- )
82
- parser.add_argument(
83
- "--num-quantizers",
84
- type=int,
85
- default=8,
86
- help="Number of Audio/Semantic quantization layers.",
87
- )
88
-
89
- # Transformer
90
- parser.add_argument(
91
- "--scaling-xformers",
92
- type=bool,
93
- default=False,
94
- help="Apply Reworked Conformer scaling on Transformers.",
95
- )
96
-
97
-
98
- def get_model(params) -> nn.Module:
99
- if params.model_name.lower() in ["vall-f", "vallf"]:
100
- model = VALLF(
101
- params.decoder_dim,
102
- params.nhead,
103
- params.num_decoder_layers,
104
- norm_first=params.norm_first,
105
- add_prenet=params.add_prenet,
106
- prefix_mode=params.prefix_mode,
107
- share_embedding=params.share_embedding,
108
- nar_scale_factor=params.scale_factor,
109
- prepend_bos=params.prepend_bos,
110
- num_quantizers=params.num_quantizers,
111
- )
112
- elif params.model_name.lower() in ["vall-e", "valle"]:
113
- model = VALLE(
114
- params.decoder_dim,
115
- params.nhead,
116
- params.num_decoder_layers,
117
- norm_first=params.norm_first,
118
- add_prenet=params.add_prenet,
119
- prefix_mode=params.prefix_mode,
120
- share_embedding=params.share_embedding,
121
- nar_scale_factor=params.scale_factor,
122
- prepend_bos=params.prepend_bos,
123
- num_quantizers=params.num_quantizers,
124
- )
125
- else:
126
- assert params.model_name in ["Transformer"]
127
- model = Transformer(
128
- params.decoder_dim,
129
- params.nhead,
130
- params.num_decoder_layers,
131
- norm_first=params.norm_first,
132
- add_prenet=params.add_prenet,
133
- scaling_xformers=params.scaling_xformers,
134
- )
135
-
136
- return model
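
A hedged sketch of building a model from these flags (assumes the `models.vallex` / `models.transformer` definitions deleted alongside this file, and the `modules` package they depend on, remain importable):

```python
import argparse

from models import add_model_arguments, get_model

parser = argparse.ArgumentParser()
add_model_arguments(parser)
params = parser.parse_args([])   # defaults: VALL-E, 12 layers, 16 heads, d_model=1024

model = get_model(params)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
```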
 
models/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (4.4 kB)
 
models/__pycache__/macros.cpython-311.pyc DELETED
Binary file (335 Bytes)
 
models/__pycache__/transformer.cpython-311.pyc DELETED
Binary file (15.1 kB)
 
models/__pycache__/vallex.cpython-311.pyc DELETED
Binary file (37.6 kB)
 
models/__pycache__/visualizer.cpython-311.pyc DELETED
Binary file (5.17 kB)
 
models/macros.py DELETED
@@ -1,11 +0,0 @@
1
- # Text
2
- NUM_TEXT_TOKENS = 2048
3
-
4
- # Audio
5
- NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins
6
- NUM_MEL_BINS = 100 # BigVGAN bigvgan_24khz_100band
7
-
8
-
9
- # Speaker
10
- NUM_SPEAKER_CLASSES = 4096
11
- SPEAKER_EMBEDDING_DIM = 64
 
models/transformer.py DELETED
@@ -1,394 +0,0 @@
1
- # Copyright 2023 (authors: Feiteng Li)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from functools import partial
16
- from typing import Any, Dict, List, Tuple, Union
17
-
18
- import torch
19
- import torch.nn as nn
20
- import torch.nn.functional as F
21
- # from icefall.utils import make_pad_mask
22
- # from torchmetrics.classification import BinaryAccuracy
23
-
24
- from models.vallex import Transpose
25
- from modules.embedding import SinePositionalEmbedding, TokenEmbedding
26
- from modules.scaling import BalancedDoubleSwish, ScaledLinear
27
- from modules.transformer import (
28
- BalancedBasicNorm,
29
- IdentityNorm,
30
- TransformerDecoderLayer,
31
- TransformerEncoder,
32
- TransformerEncoderLayer,
33
- )
34
-
35
- from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
36
- from .visualizer import visualize
37
-
38
- IdentityNorm = IdentityNorm
39
-
40
-
41
- class Transformer(nn.Module):
42
- """It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding)
43
- Neural Speech Synthesis with Transformer Network
44
- https://arxiv.org/abs/1809.08895
45
- """
46
-
47
- def __init__(
48
- self,
49
- d_model: int,
50
- nhead: int,
51
- num_layers: int,
52
- norm_first: bool = True,
53
- add_prenet: bool = False,
54
- scaling_xformers: bool = False,
55
- ):
56
- """
57
- Args:
58
- d_model:
59
- The number of expected features in the input (required).
60
- nhead:
61
- The number of heads in the multiheadattention models (required).
62
- num_layers:
63
- The number of sub-decoder-layers in the decoder (required).
64
- """
65
- super().__init__()
66
- self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
67
-
68
- if add_prenet:
69
- self.encoder_prenet = nn.Sequential(
70
- Transpose(),
71
- nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
72
- nn.BatchNorm1d(d_model),
73
- nn.ReLU(),
74
- nn.Dropout(0.5),
75
- nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
76
- nn.BatchNorm1d(d_model),
77
- nn.ReLU(),
78
- nn.Dropout(0.5),
79
- nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
80
- nn.BatchNorm1d(d_model),
81
- nn.ReLU(),
82
- nn.Dropout(0.5),
83
- Transpose(),
84
- nn.Linear(d_model, d_model),
85
- )
86
-
87
- self.decoder_prenet = nn.Sequential(
88
- nn.Linear(NUM_MEL_BINS, 256),
89
- nn.ReLU(),
90
- nn.Dropout(0.5),
91
- nn.Linear(256, 256),
92
- nn.ReLU(),
93
- nn.Dropout(0.5),
94
- nn.Linear(256, d_model),
95
- )
96
-
97
- assert scaling_xformers is False # TODO: update this block
98
- else:
99
- self.encoder_prenet = nn.Identity()
100
- if scaling_xformers:
101
- self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
102
- else:
103
- self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)
104
-
105
- self.encoder_position = SinePositionalEmbedding(
106
- d_model,
107
- dropout=0.1,
108
- scale=False,
109
- )
110
- self.decoder_position = SinePositionalEmbedding(
111
- d_model, dropout=0.1, scale=False
112
- )
113
-
114
- if scaling_xformers:
115
- self.encoder = TransformerEncoder(
116
- TransformerEncoderLayer(
117
- d_model,
118
- nhead,
119
- dim_feedforward=d_model * 4,
120
- dropout=0.1,
121
- batch_first=True,
122
- norm_first=norm_first,
123
- linear1_self_attention_cls=ScaledLinear,
124
- linear2_self_attention_cls=partial(
125
- ScaledLinear, initial_scale=0.01
126
- ),
127
- linear1_feedforward_cls=ScaledLinear,
128
- linear2_feedforward_cls=partial(
129
- ScaledLinear, initial_scale=0.01
130
- ),
131
- activation=partial(
132
- BalancedDoubleSwish,
133
- channel_dim=-1,
134
- max_abs=10.0,
135
- min_prob=0.25,
136
- ),
137
- layer_norm_cls=IdentityNorm,
138
- ),
139
- num_layers=num_layers,
140
- norm=BalancedBasicNorm(d_model) if norm_first else None,
141
- )
142
-
143
- self.decoder = nn.TransformerDecoder(
144
- TransformerDecoderLayer(
145
- d_model,
146
- nhead,
147
- dim_feedforward=d_model * 4,
148
- dropout=0.1,
149
- batch_first=True,
150
- norm_first=norm_first,
151
- linear1_self_attention_cls=ScaledLinear,
152
- linear2_self_attention_cls=partial(
153
- ScaledLinear, initial_scale=0.01
154
- ),
155
- linear1_feedforward_cls=ScaledLinear,
156
- linear2_feedforward_cls=partial(
157
- ScaledLinear, initial_scale=0.01
158
- ),
159
- activation=partial(
160
- BalancedDoubleSwish,
161
- channel_dim=-1,
162
- max_abs=10.0,
163
- min_prob=0.25,
164
- ),
165
- layer_norm_cls=IdentityNorm,
166
- ),
167
- num_layers=num_layers,
168
- norm=BalancedBasicNorm(d_model) if norm_first else None,
169
- )
170
-
171
- self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
172
- self.stop_layer = nn.Linear(d_model, 1)
173
- else:
174
- self.encoder = nn.TransformerEncoder(
175
- nn.TransformerEncoderLayer(
176
- d_model,
177
- nhead,
178
- dim_feedforward=d_model * 4,
179
- activation=F.relu,
180
- dropout=0.1,
181
- batch_first=True,
182
- norm_first=norm_first,
183
- ),
184
- num_layers=num_layers,
185
- norm=nn.LayerNorm(d_model) if norm_first else None,
186
- )
187
-
188
- self.decoder = nn.TransformerDecoder(
189
- nn.TransformerDecoderLayer(
190
- d_model,
191
- nhead,
192
- dim_feedforward=d_model * 4,
193
- activation=F.relu,
194
- dropout=0.1,
195
- batch_first=True,
196
- norm_first=norm_first,
197
- ),
198
- num_layers=num_layers,
199
- norm=nn.LayerNorm(d_model) if norm_first else None,
200
- )
201
-
202
- self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
203
- self.stop_layer = nn.Linear(d_model, 1)
204
-
205
- self.stop_accuracy_metric = BinaryAccuracy(
206
- threshold=0.5, multidim_average="global"
207
- )
208
-
209
- # self.apply(self._init_weights)
210
-
211
- # def _init_weights(self, module):
212
- # if isinstance(module, (nn.Linear)):
213
- # module.weight.data.normal_(mean=0.0, std=0.02)
214
- # if isinstance(module, nn.Linear) and module.bias is not None:
215
- # module.bias.data.zero_()
216
- # elif isinstance(module, nn.LayerNorm):
217
- # module.bias.data.zero_()
218
- # module.weight.data.fill_(1.0)
219
- # elif isinstance(module, nn.Embedding):
220
- # module.weight.data.normal_(mean=0.0, std=0.02)
221
-
222
- def forward(
223
- self,
224
- x: torch.Tensor,
225
- x_lens: torch.Tensor,
226
- y: torch.Tensor,
227
- y_lens: torch.Tensor,
228
- reduction: str = "sum",
229
- train_stage: int = 0,
230
- **kwargs,
231
- ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
232
- """
233
- Args:
234
- x:
235
- A 2-D tensor of shape (N, S).
236
- x_lens:
237
- A 1-D tensor of shape (N,). It contains the number of tokens in `x`
238
- before padding.
239
- y:
240
- A 3-D tensor of shape (N, T, NUM_MEL_BINS).
241
- y_lens:
242
- A 1-D tensor of shape (N,). It contains the number of frames in `y`
243
- before padding.
244
- train_stage:
245
- Not used in this model.
246
- Returns:
247
- Return the encoder output and predicted features, the total MSE plus weighted stop loss, and a metrics dict.
248
- """
249
- del train_stage
250
-
251
- assert x.ndim == 2, x.shape
252
- assert x_lens.ndim == 1, x_lens.shape
253
- assert y.ndim == 3, y.shape
254
- assert y_lens.ndim == 1, y_lens.shape
255
-
256
- assert torch.all(x_lens > 0)
257
-
258
- # NOTE: x has been padded in TextTokenCollater
259
- x_mask = make_pad_mask(x_lens).to(x.device)
260
-
261
- x = self.text_embedding(x)
262
- x = self.encoder_prenet(x)
263
- x = self.encoder_position(x)
264
- x = self.encoder(x, src_key_padding_mask=x_mask)
265
-
266
- total_loss, metrics = 0.0, {}
267
-
268
- y_mask = make_pad_mask(y_lens).to(y.device)
269
- y_mask_float = y_mask.type(torch.float32)
270
- data_mask = 1.0 - y_mask_float.unsqueeze(-1)
271
-
272
- # Training
273
- # AR Decoder
274
- def pad_y(y):
275
- y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
276
- # inputs, targets
277
- return y[:, :-1], y[:, 1:]
278
-
279
- y, targets = pad_y(y * data_mask) # mask padding as zeros
280
-
281
- y_emb = self.decoder_prenet(y)
282
- y_pos = self.decoder_position(y_emb)
283
-
284
- y_len = y_lens.max()
285
- tgt_mask = torch.triu(
286
- torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
287
- diagonal=1,
288
- )
289
- y_dec = self.decoder(
290
- y_pos,
291
- x,
292
- tgt_mask=tgt_mask,
293
- memory_key_padding_mask=x_mask,
294
- )
295
-
296
- predict = self.predict_layer(y_dec)
297
- # loss
298
- total_loss = F.mse_loss(predict, targets, reduction=reduction)
299
-
300
- logits = self.stop_layer(y_dec).squeeze(-1)
301
- stop_loss = F.binary_cross_entropy_with_logits(
302
- logits,
303
- y_mask_float.detach(),
304
- weight=1.0 + y_mask_float.detach() * 4.0,
305
- reduction=reduction,
306
- )
307
- metrics["stop_loss"] = stop_loss.detach()
308
-
309
- stop_accuracy = self.stop_accuracy_metric(
310
- (torch.sigmoid(logits) >= 0.5).type(torch.int64),
311
- y_mask.type(torch.int64),
312
- )
313
- # icefall MetricsTracker.norm_items()
314
- metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
315
- torch.float32
316
- )
317
-
318
- return ((x, predict), total_loss + 100.0 * stop_loss, metrics)
319
-
320
- def inference(
321
- self,
322
- x: torch.Tensor,
323
- x_lens: torch.Tensor,
324
- y: Any = None,
325
- **kwargs,
326
- ) -> torch.Tensor:
327
- """
328
- Args:
329
- x:
330
- A 2-D tensor of shape (1, S).
331
- x_lens:
332
- A 1-D tensor of shape (1,). It contains the number of tokens in `x`
333
- before padding.
334
- Returns:
335
- Return the predicted feature frames (the initial zero frame is stripped).
336
- """
337
- assert x.ndim == 2, x.shape
338
- assert x_lens.ndim == 1, x_lens.shape
339
-
340
- assert torch.all(x_lens > 0)
341
-
342
- x_mask = make_pad_mask(x_lens).to(x.device)
343
-
344
- x = self.text_embedding(x)
345
- x = self.encoder_prenet(x)
346
- x = self.encoder_position(x)
347
- x = self.encoder(x, src_key_padding_mask=x_mask)
348
-
349
- x_mask = make_pad_mask(x_lens).to(x.device)
350
-
351
- # AR Decoder
352
- # TODO: manage decoder steps to avoid repetitive computation
353
- y = torch.zeros(
354
- [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
355
- )
356
- while True:
357
- y_emb = self.decoder_prenet(y)
358
- y_pos = self.decoder_position(y_emb)
359
-
360
- tgt_mask = torch.triu(
361
- torch.ones(
362
- y.shape[1], y.shape[1], device=y.device, dtype=torch.bool
363
- ),
364
- diagonal=1,
365
- )
366
-
367
- y_dec = self.decoder(
368
- y_pos,
369
- x,
370
- tgt_mask=tgt_mask,
371
- memory_mask=None,
372
- memory_key_padding_mask=x_mask,
373
- )
374
- predict = self.predict_layer(y_dec[:, -1:])
375
-
376
- logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5
377
- if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
378
- print(
379
- f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]"
380
- )
381
- break
382
-
383
- y = torch.concat([y, predict], dim=1)
384
-
385
- return y[:, 1:]
386
-
387
- def visualize(
388
- self,
389
- predicts: Tuple[torch.Tensor],
390
- batch: Dict[str, Union[List, torch.Tensor]],
391
- output_dir: str,
392
- limit: int = 4,
393
- ) -> None:
394
- visualize(predicts, batch, output_dir, limit=limit)
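
A minimal, self-contained sketch (not part of the deleted file) of the stop-token rule used in Transformer.inference() above: one frame is appended per step, and decoding halts once the stop logit goes positive (equivalent to sigmoid > 0.5) or a length cap is hit. The small Linear modules and the NUM_MEL_BINS value are hypothetical stand-ins for the real decoder stack and macros.

import torch
import torch.nn as nn

torch.manual_seed(0)
NUM_MEL_BINS = 100                                  # assumed value; the real one lives in models/macros.py
d_model = 32
decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)   # stand-in prenet
fake_decoder = nn.Linear(d_model, d_model)          # stand-in for the Transformer decoder
predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
stop_layer = nn.Linear(d_model, 1)

y = torch.zeros(1, 1, NUM_MEL_BINS)                 # all-zero start frame, as in inference()
max_len = 50
while y.shape[1] < max_len:
    h = fake_decoder(decoder_prenet(y))
    next_frame = predict_layer(h[:, -1:])
    if (stop_layer(h[:, -1:]) > 0).all():           # logit > 0 <=> sigmoid(logit) > 0.5
        break
    y = torch.cat([y, next_frame], dim=1)
print(y[:, 1:].shape)                               # generated frames, start frame stripped as in the original
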
models/vallex.py DELETED
@@ -1,853 +0,0 @@
1
- # Copyright 2023 (authors: Feiteng Li)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import random
16
- from typing import Dict, Iterator, List, Tuple, Union
17
-
18
- import numpy as np
19
- import torch
20
- import torch.nn as nn
21
- import torch.nn.functional as F
22
- # from icefall.utils import make_pad_mask
23
- # from torchmetrics.classification import MulticlassAccuracy
24
-
25
- from data.input_strategies import PromptedFeatures
26
- from modules.embedding import SinePositionalEmbedding, TokenEmbedding
27
- from modules.transformer import (
28
- AdaptiveLayerNorm,
29
- LayerNorm,
30
- TransformerDecoderLayer,
31
- TransformerEncoder,
32
- TransformerEncoderLayer,
33
- )
34
-
35
- from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
36
- from .visualizer import visualize
37
-
38
-
39
- class Transpose(nn.Identity):
40
- """(N, T, D) -> (N, D, T)"""
41
-
42
- def forward(self, input: torch.Tensor) -> torch.Tensor:
43
- return input.transpose(1, 2)
44
-
45
-
46
- # NOTE: There are two ways to implement the model
47
- # 1) [VALL-F] standard TransformerDecoder, use x as memory
48
- # 2) [VALL-E] modified TransformerDecoder like GPT-x(e.g. causal TransformerEncoder),
49
- # use x as the prefix of decoder inputs
50
- class VALLF(nn.Module):
51
- """It implements https://arxiv.org/abs/2301.02111
52
- "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
53
- """
54
-
55
- def __init__(
56
- self,
57
- d_model: int,
58
- nhead: int,
59
- num_layers: int,
60
- norm_first: bool = True,
61
- add_prenet: bool = False,
62
- decoder_cls: Union[
63
- nn.TransformerDecoder, nn.TransformerEncoder
64
- ] = nn.TransformerDecoder,
65
- decoder_layer_cls: Union[
66
- TransformerDecoderLayer, TransformerEncoderLayer
67
- ] = TransformerDecoderLayer,
68
- prefix_mode: int = 0,
69
- share_embedding: bool = True,
70
- nar_scale_factor: float = 1.0,
71
- prepend_bos: bool = True,
72
- num_quantizers: int = 8,
73
- ):
74
- """
75
- Args:
76
- d_model:
77
- The number of expected features in the input (required).
78
- nhead:
79
- The number of heads in the multiheadattention models (required).
80
- num_layers:
81
- The number of sub-decoder-layers in the decoder (required).
82
- """
83
- super().__init__()
84
- nar_d_model = int(d_model * nar_scale_factor)
85
-
86
- self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
87
- self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)
88
-
89
- # ID NUM_AUDIO_TOKENS -> PAD
90
- # ID NUM_AUDIO_TOKENS + 1 -> BOS
91
- self.ar_audio_prepend_bos = prepend_bos
92
- self.ar_audio_embedding = TokenEmbedding(
93
- d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
94
- )
95
-
96
- # PreNet
97
- if add_prenet:
98
- self.ar_text_prenet = nn.Sequential(
99
- Transpose(),
100
- nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
101
- nn.BatchNorm1d(d_model),
102
- nn.ReLU(),
103
- nn.Dropout(0.5),
104
- nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
105
- nn.BatchNorm1d(d_model),
106
- nn.ReLU(),
107
- nn.Dropout(0.5),
108
- nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
109
- nn.BatchNorm1d(d_model),
110
- nn.ReLU(),
111
- nn.Dropout(0.5),
112
- Transpose(),
113
- nn.Linear(d_model, d_model),
114
- )
115
-
116
- self.ar_audio_prenet = nn.Sequential(
117
- nn.Linear(d_model, 256),
118
- nn.ReLU(),
119
- nn.Dropout(0.25),
120
- nn.Linear(256, 256),
121
- nn.ReLU(),
122
- nn.Dropout(0.25),
123
- nn.Linear(256, d_model),
124
- )
125
- else:
126
- self.ar_text_prenet = nn.Identity()
127
- self.ar_audio_prenet = nn.Identity()
128
-
129
- self.ar_text_position = SinePositionalEmbedding(
130
- d_model,
131
- dropout=0.1,
132
- scale=False,
133
- alpha=True,
134
- )
135
- self.ar_audio_position = SinePositionalEmbedding(
136
- d_model,
137
- dropout=0.1,
138
- scale=False,
139
- alpha=True,
140
- )
141
-
142
- self.ar_decoder = decoder_cls(
143
- decoder_layer_cls(
144
- d_model,
145
- nhead,
146
- dim_feedforward=d_model * 4,
147
- dropout=0.1,
148
- batch_first=True,
149
- norm_first=norm_first,
150
- ),
151
- num_layers=num_layers,
152
- norm=LayerNorm(d_model) if norm_first else None,
153
- )
154
- self.ar_predict_layer = nn.Linear(
155
- d_model, NUM_AUDIO_TOKENS + 1, bias=False
156
- )
157
-
158
- self.rng = random.Random(0)
159
- self.num_heads = nhead
160
- self.prefix_mode = prefix_mode
161
- self.num_quantizers = num_quantizers
162
-
163
- assert num_quantizers >= 1
164
- if num_quantizers > 1:
165
- self.nar_audio_embeddings = nn.ModuleList(
166
- [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
167
- + [
168
- TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
169
- for i in range(num_quantizers - 1)
170
- ]
171
- ) # W_a
172
-
173
- # PreNet
174
- if add_prenet:
175
- self.nar_text_prenet = nn.Sequential(
176
- Transpose(),
177
- nn.Conv1d(
178
- nar_d_model, nar_d_model, kernel_size=5, padding="same"
179
- ),
180
- nn.BatchNorm1d(nar_d_model),
181
- nn.ReLU(),
182
- nn.Dropout(0.5),
183
- nn.Conv1d(
184
- nar_d_model, nar_d_model, kernel_size=5, padding="same"
185
- ),
186
- nn.BatchNorm1d(nar_d_model),
187
- nn.ReLU(),
188
- nn.Dropout(0.5),
189
- nn.Conv1d(
190
- nar_d_model, nar_d_model, kernel_size=5, padding="same"
191
- ),
192
- nn.BatchNorm1d(nar_d_model),
193
- nn.ReLU(),
194
- nn.Dropout(0.5),
195
- Transpose(),
196
- nn.Linear(nar_d_model, nar_d_model),
197
- )
198
- self.nar_audio_prenet = nn.Sequential(
199
- nn.Linear(nar_d_model, 256),
200
- nn.ReLU(),
201
- nn.Dropout(0.25),
202
- nn.Linear(256, 256),
203
- nn.ReLU(),
204
- nn.Dropout(0.25),
205
- nn.Linear(256, nar_d_model),
206
- )
207
- else:
208
- self.nar_text_prenet = nn.Identity()
209
- self.nar_audio_prenet = nn.Identity()
210
-
211
- self.nar_text_position = SinePositionalEmbedding(
212
- nar_d_model,
213
- dropout=0.0,
214
- scale=False,
215
- alpha=False,
216
- )
217
- self.nar_audio_position = SinePositionalEmbedding(
218
- nar_d_model,
219
- dropout=0.1,
220
- scale=False,
221
- alpha=False,
222
- )
223
-
224
- self.nar_decoder = decoder_cls(
225
- decoder_layer_cls(
226
- nar_d_model,
227
- int(nhead * nar_scale_factor),
228
- dim_feedforward=nar_d_model * 4,
229
- dropout=0.1,
230
- batch_first=True,
231
- norm_first=norm_first,
232
- adaptive_layer_norm=True,
233
- ),
234
- num_layers=int(num_layers * nar_scale_factor),
235
- norm=AdaptiveLayerNorm(
236
- nar_d_model, norm=nn.LayerNorm(nar_d_model)
237
- )
238
- if norm_first
239
- else None,
240
- )
241
- self.nar_predict_layers = nn.ModuleList(
242
- [
243
- nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
244
- for i in range(num_quantizers - 1)
245
- ]
246
- )
247
- self.nar_stage_embeddings = nn.ModuleList(
248
- [
249
- TokenEmbedding(nar_d_model, 1)
250
- for i in range(num_quantizers - 1)
251
- ]
252
- )
253
-
254
- if share_embedding:
255
- # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa
256
- # NOTE(Feiteng): In the experiment, this undermines accuracy
257
- # self.ar_predict_layer.weight = self.ar_audio_embedding.weight
258
-
259
- # We also share the parameters of the acoustic embedding layer and the output prediction layer,
260
- # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
261
- for j in range(0, num_quantizers - 2):
262
- self.nar_predict_layers[
263
- j
264
- ].weight = self.nar_audio_embeddings[j + 2].weight
265
-
266
- def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
267
- assert stage > 0
268
- if stage == 1:
269
- for name, param in self.named_parameters():
270
- if name.startswith("ar_"):
271
- print(f" AR parameter: {name}")
272
- yield param
273
-
274
- if stage == 2:
275
- for name, param in self.named_parameters():
276
- if name.startswith("nar_"):
277
- print(f"NAR parameter: {name}")
278
- yield param
279
-
280
- def stage_named_parameters(
281
- self, stage: int = 1
282
- ) -> Iterator[Tuple[str, nn.Parameter]]:
283
- assert stage > 0
284
- if stage == 1:
285
- for pair in self.named_parameters():
286
- if pair[0].startswith("ar_"):
287
- yield pair
288
-
289
- if stage == 2:
290
- for pair in self.named_parameters():
291
- if pair[0].startswith("nar_"):
292
- yield pair
293
-
294
- def pad_y_eos(self, y, y_mask_int, eos_id):
295
- targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
296
- y_mask_int, (0, 1), value=1
297
- )
298
- # inputs, targets
299
- if self.ar_audio_prepend_bos:
300
- return (
301
- F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1),
302
- targets,
303
- )
304
-
305
- return targets[:, :-1], targets[:, 1:]
306
-
307
- def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes, prefix_mode):
308
- # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
309
- # from the same utterance.
310
- # We implement this differently.
311
- if prefix_mode == 0:
312
- # no prefix
313
- prefix_len = 0
314
- y_emb = self.nar_audio_embeddings[0](y)
315
- for j in range(1, nar_stage):
316
- # Formula (4) (5)
317
- y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
318
- elif prefix_mode == 1:
319
- # prefix at beginning
320
- int_low = (0.25 * y_lens.min()).type(torch.int64).item()
321
- prefix_len = torch.randint(0, int_low * 2, size=()).item()
322
- prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames
323
-
324
- y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
325
- y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
326
- for j in range(1, self.num_quantizers):
327
- y_prompts += self.nar_audio_embeddings[j](
328
- codes[:, :prefix_len, j]
329
- )
330
- if j < nar_stage:
331
- y_emb += self.nar_audio_embeddings[j](
332
- codes[:, prefix_len:, j]
333
- )
334
- y_emb = torch.concat([y_prompts, y_emb], axis=1)
335
- elif prefix_mode in [2, 4]:
336
- if prefix_mode == 2:
337
- # random prefix
338
- prefix_len = min(225, int(0.25 * y_lens.min().item()))
339
-
340
- y_prompts_codes = []
341
- for b in range(codes.shape[0]):
342
- start = self.rng.randint(0, y_lens[b].item() - prefix_len)
343
- y_prompts_codes.append(
344
- torch.clone(codes[b, start : start + prefix_len])
345
- )
346
- codes[
347
- b, start : start + prefix_len, nar_stage
348
- ] = NUM_AUDIO_TOKENS
349
- y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
350
- else:
351
- prefix_len = y_prompts_codes.shape[1]
352
-
353
- y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
354
- y_emb = self.nar_audio_embeddings[0](y)
355
- for j in range(1, self.num_quantizers):
356
- y_prompts += self.nar_audio_embeddings[j](
357
- y_prompts_codes[..., j]
358
- )
359
- if j < nar_stage:
360
- y_emb += self.nar_audio_embeddings[j](codes[..., j])
361
- y_emb = torch.concat([y_prompts, y_emb], axis=1)
362
- else:
363
- raise ValueError
364
-
365
- return y_emb, prefix_len
366
-
367
- def forward(
368
- self,
369
- x: torch.Tensor,
370
- x_lens: torch.Tensor,
371
- y: Union[torch.Tensor, PromptedFeatures],
372
- y_lens: Union[torch.Tensor, PromptedFeatures],
373
- reduction: str = "sum",
374
- train_stage: int = 0,
375
- **kwargs,
376
- ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
377
- raise NotImplementedError
378
-
379
- def inference(
380
- self,
381
- x: torch.Tensor,
382
- x_lens: torch.Tensor,
383
- y: torch.Tensor,
384
- enroll_x_lens: Union[torch.Tensor, None] = None,
385
- top_k: int = -100,
386
- temperature: float = 1.0,
387
- ) -> torch.Tensor:
388
- raise NotImplementedError
389
-
390
- def visualize(
391
- self,
392
- predicts: Tuple[torch.Tensor],
393
- batch: Dict[str, Union[List, torch.Tensor]],
394
- output_dir: str,
395
- limit: int = 4,
396
- ) -> None:
397
- raise NotImplementedError
398
-
399
-
400
- class VALLE(VALLF):
401
- """It implements https://arxiv.org/abs/2301.02111
402
- "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
403
- """
404
-
405
- def __init__(
406
- self,
407
- d_model: int,
408
- nhead: int,
409
- num_layers: int,
410
- norm_first: bool = True,
411
- add_prenet: bool = False,
412
- prefix_mode: int = 0,
413
- share_embedding: bool = True,
414
- nar_scale_factor: float = 1.0,
415
- **kwargs,
416
- ):
417
- """
418
- Args:
419
- d_model:
420
- The number of expected features in the input (required).
421
- nhead:
422
- The number of heads in the multiheadattention models (required).
423
- num_layers:
424
- The number of sub-decoder-layers in the decoder (required).
425
- """
426
- super(VALLE, self).__init__(
427
- d_model,
428
- nhead,
429
- num_layers,
430
- norm_first=norm_first,
431
- add_prenet=add_prenet,
432
- decoder_cls=TransformerEncoder,
433
- decoder_layer_cls=TransformerEncoderLayer,
434
- prefix_mode=prefix_mode,
435
- share_embedding=share_embedding,
436
- nar_scale_factor=nar_scale_factor,
437
- **kwargs,
438
- )
439
- self.language_ID = {
440
- 'en': 0,
441
- 'zh': 1,
442
- 'ja': 2,
443
- }
444
- self.ar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
445
- self.nar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
446
-
447
- def forward(
448
- self,
449
- x: torch.Tensor,
450
- x_lens: torch.Tensor,
451
- y: Union[torch.Tensor, PromptedFeatures],
452
- y_lens: Union[torch.Tensor, PromptedFeatures],
453
- reduction: str = "sum",
454
- train_stage: int = 0,
455
- **kwargs,
456
- ):
457
- raise NotImplementedError
458
- def inference(
459
- self,
460
- x: torch.Tensor,
461
- x_lens: torch.Tensor,
462
- y: torch.Tensor,
463
- enroll_x_lens: torch.Tensor,
464
- top_k: int = -100,
465
- temperature: float = 1.0,
466
- prompt_language: str = None,
467
- text_language: str = None,
468
- best_of: int = 1,
469
- length_penalty: float = 1.0,
470
- return_worst: bool = False,
471
- ) -> torch.Tensor:
472
- """
473
- Args:
474
- x:
475
- A 2-D tensor of shape (1, S).
476
- x_lens:
477
- A 1-D tensor of shape (1,). It contains the number of tokens in `x`
478
- before padding.
479
- y:
480
- A 3-D tensor of shape (1, T, 8).
481
- top_k: (`optional`) int
482
- The number of highest probability tokens to keep for top-k-filtering. Default to -100.
483
- temperature: (`optional`) float
484
- The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
485
- Returns:
486
- Return the predicted audio code matrix.
487
- """
488
- assert x.ndim == 2, x.shape
489
- assert x_lens.ndim == 1, x_lens.shape
490
- assert y.ndim == 3, y.shape
491
- assert y.shape[0] == 1, y.shape
492
-
493
- assert torch.all(x_lens > 0)
494
-
495
- # NOTE: x has been padded in TextTokenCollater
496
- text = x
497
- x = self.ar_text_embedding(text)
498
- # Add language embedding
499
- prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
500
- if isinstance(text_language, str):
501
- text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
502
- elif isinstance(text_language, List):
503
- text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
504
- x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
505
- x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
506
- x = self.ar_text_prenet(x)
507
- x = self.ar_text_position(x)
508
-
509
- text_len = x_lens.max()
510
- prompts = y
511
- prefix_len = y.shape[1]
512
-
513
- # AR Decoder
514
- # TODO: manage decoder steps to avoid repetitive computation
515
- y = prompts[..., 0]
516
- if self.ar_audio_prepend_bos:
517
- y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)
518
-
519
- x_len = x_lens.max()
520
- x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
521
-
522
- kv_cache = None
523
- use_kv_caching = True
524
-
525
- sum_logprobs = torch.zeros(best_of, device=y.device) # implement batch decoding here
526
- x = x.repeat(best_of, 1, 1)
527
- y = y.repeat(best_of, 1)
528
- while True:
529
- y_emb = self.ar_audio_embedding(y)
530
- y_emb = self.ar_audio_prenet(y_emb)
531
- y_pos = self.ar_audio_position(y_emb)
532
- xy_pos = torch.concat([x, y_pos], dim=1)
533
-
534
- y_len = y.shape[1]
535
- x_attn_mask_pad = F.pad(
536
- x_attn_mask,
537
- (0, y_len),
538
- value=True,
539
- )
540
- y_attn_mask = F.pad(
541
- torch.triu(
542
- torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
543
- ),
544
- (x_len, 0),
545
- value=False,
546
- )
547
- xy_attn_mask = torch.concat(
548
- [x_attn_mask_pad, y_attn_mask], dim=0
549
- ).to(y.device)
550
-
551
-
552
- if use_kv_caching and kv_cache is not None:
553
- xy_pos = xy_pos[:, [-1]]
554
- else:
555
- pass
556
-
557
- xy_dec, kv_cache = self.ar_decoder.infer(
558
- xy_pos,
559
- mask=xy_attn_mask,
560
- past_kv=kv_cache,
561
- use_cache=use_kv_caching,
562
- )
563
- # xy_dec, _ = self.ar_decoder(
564
- # (xy_pos, None),
565
- # mask=xy_attn_mask,
566
- # )
567
-
568
- logits = self.ar_predict_layer(xy_dec[:, -1])
569
- samples, current_logprobs = topk_sampling(
570
- logits, top_k=top_k, top_p=1, temperature=temperature
571
- )
572
- sum_logprobs += current_logprobs * (y[:, -1] != NUM_AUDIO_TOKENS)
573
- samples[y[:, -1] == NUM_AUDIO_TOKENS] = NUM_AUDIO_TOKENS
574
- completed = (samples[:, -1] == NUM_AUDIO_TOKENS).all()
575
- if (
576
- completed
577
- or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
578
- ):
579
- if prompts.shape[1] == y.shape[1]:
580
- raise SyntaxError(
581
- "well trained model shouldn't reach here."
582
- )
583
- lengths = torch.sum(y != NUM_AUDIO_TOKENS, dim=1)
584
- avg_logprobs = sum_logprobs / lengths ** length_penalty
585
- # choose the best beam according to sum_logprobs
586
- best_beam = y[torch.argmax(avg_logprobs), :]
587
- worst_beam = y[torch.argmin(avg_logprobs), :]
588
- # strip all eos tokens
589
- best_beam = best_beam[best_beam != NUM_AUDIO_TOKENS]
590
- worst_beam = worst_beam[worst_beam != NUM_AUDIO_TOKENS]
591
- if return_worst:
592
- y = worst_beam.unsqueeze(0)
593
- else:
594
- y = best_beam.unsqueeze(0)
595
- print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
596
- break
597
-
598
- y = torch.concat([y, samples], dim=1)
599
-
600
- codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
601
- if self.num_quantizers == 1:
602
- return torch.stack(codes, dim=-1)
603
-
604
- # Non-AR Decoders
605
- y_emb = self.nar_audio_embeddings[0](
606
- y[:, int(self.ar_audio_prepend_bos) :]
607
- )
608
-
609
- if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes
610
- enrolled_len = enroll_x_lens.max().item()
611
- # SOS + Synthesis Text + EOS
612
- text = torch.concat(
613
- [
614
- text[:, :1],
615
- text[:, enrolled_len - 1 :],
616
- ],
617
- dim=1,
618
- )
619
- text_len = text_len - (enrolled_len - 2)
620
- assert text.shape[0] == 1
621
-
622
- x = self.nar_text_embedding(text)
623
- # Add language embedding
624
- prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
625
- if isinstance(text_language, str):
626
- text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
627
- elif isinstance(text_language, List):
628
- text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
629
- x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
630
- x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
631
- x = self.nar_text_prenet(x)
632
- x = self.nar_text_position(x)
633
-
634
- if self.prefix_mode == 0:
635
- for i, (predict_layer, embedding_layer) in enumerate(
636
- zip(
637
- self.nar_predict_layers,
638
- self.nar_audio_embeddings[1:],
639
- )
640
- ):
641
- y_pos = self.nar_audio_prenet(y_emb)
642
- y_pos = self.nar_audio_position(y_pos)
643
- xy_pos = torch.concat([x, y_pos], dim=1)
644
-
645
- xy_dec, _ = self.nar_decoder(
646
- (xy_pos, self.nar_stage_embeddings[i].weight)
647
- )
648
- logits = predict_layer(xy_dec[:, text_len + prefix_len :])
649
-
650
- samples = torch.argmax(logits, dim=-1)
651
- codes.append(samples)
652
-
653
- if i < self.num_quantizers - 2:
654
- y_emb[:, :prefix_len] += embedding_layer(
655
- prompts[..., i + 1]
656
- )
657
- y_emb[:, prefix_len:] += embedding_layer(samples)
658
- else:
659
- for j in range(1, self.num_quantizers):
660
- y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
661
- prompts[..., j]
662
- )
663
-
664
- for i, (predict_layer, embedding_layer) in enumerate(
665
- zip(
666
- self.nar_predict_layers,
667
- self.nar_audio_embeddings[1:],
668
- )
669
- ):
670
- y_pos = self.nar_audio_prenet(y_emb)
671
- y_pos = self.nar_audio_position(y_pos)
672
- xy_pos = torch.concat([x, y_pos], dim=1)
673
-
674
- xy_dec, _ = self.nar_decoder(
675
- (xy_pos, self.nar_stage_embeddings[i].weight)
676
- )
677
- logits = predict_layer(xy_dec[:, text_len + prefix_len :])
678
-
679
- samples = torch.argmax(logits, dim=-1)
680
- codes.append(samples)
681
-
682
- if i < self.num_quantizers - 2:
683
- y_emb[:, prefix_len:] += embedding_layer(samples)
684
-
685
- assert len(codes) == self.num_quantizers
686
- return torch.stack(codes, dim=-1)
687
-
688
- def continual(
689
- self,
690
- x: torch.Tensor,
691
- x_lens: torch.Tensor,
692
- y: torch.Tensor,
693
- ) -> torch.Tensor:
694
- """
695
- Args:
696
- x:
697
- A 2-D tensor of shape (1, S).
698
- x_lens:
699
- A 1-D tensor of shape (1,). It contains the number of tokens in `x`
700
- before padding.
701
- y:
702
- A 3-D tensor of shape (1, T, 8).
703
- Returns:
704
- Return the predicted audio code matrix.
705
- """
706
- assert x.ndim == 2, x.shape
707
- assert x_lens.ndim == 1, x_lens.shape
708
- assert y.ndim == 3, y.shape
709
- assert y.shape[0] == 1, y.shape
710
-
711
- assert torch.all(x_lens > 0)
712
- assert self.num_quantizers == 8
713
-
714
- # NOTE: x has been padded in TextTokenCollater
715
- text = x
716
- x = self.ar_text_embedding(text)
717
- x = self.ar_text_prenet(x)
718
- x = self.ar_text_position(x)
719
-
720
- text_len = x_lens.max()
721
-
722
- prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
723
-
724
- # AR Decoder
725
- prompts = y[:, :prefix_len]
726
-
727
- codes = [y[:, prefix_len:, 0]]
728
- # Non-AR Decoders
729
- x = self.nar_text_embedding(text)
730
- x = self.nar_text_prenet(x)
731
- x = self.nar_text_position(x)
732
-
733
- y_emb = self.nar_audio_embeddings[0](y[..., 0])
734
-
735
- if self.prefix_mode == 0:
736
- for i, (predict_layer, embedding_layer) in enumerate(
737
- zip(
738
- self.nar_predict_layers,
739
- self.nar_audio_embeddings[1:],
740
- )
741
- ):
742
- y_pos = self.nar_audio_position(y_emb)
743
- y_pos = self.nar_audio_prenet(y_pos)
744
- xy_pos = torch.concat([x, y_pos], dim=1)
745
-
746
- xy_dec, _ = self.nar_decoder(
747
- (xy_pos, self.nar_stage_embeddings[i].weight)
748
- )
749
- logits = predict_layer(xy_dec[:, text_len + prefix_len :])
750
-
751
- samples = torch.argmax(logits, dim=-1)
752
- codes.append(samples)
753
-
754
- if i < 6:
755
- y_emb[:, :prefix_len] += embedding_layer(
756
- prompts[..., i + 1]
757
- )
758
- y_emb[:, prefix_len:] += embedding_layer(samples)
759
- else:
760
- for j in range(1, 8):
761
- y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
762
- prompts[..., j]
763
- )
764
-
765
- for i, (predict_layer, embedding_layer) in enumerate(
766
- zip(
767
- self.nar_predict_layers,
768
- self.nar_audio_embeddings[1:],
769
- )
770
- ):
771
- y_pos = self.nar_audio_prenet(y_emb)
772
- y_pos = self.nar_audio_position(y_pos)
773
- xy_pos = torch.concat([x, y_pos], dim=1)
774
-
775
- xy_dec, _ = self.nar_decoder(
776
- (xy_pos, self.nar_stage_embeddings[i].weight)
777
- )
778
- logits = predict_layer(xy_dec[:, text_len + prefix_len :])
779
-
780
- samples = torch.argmax(logits, dim=-1)
781
- codes.append(samples)
782
-
783
- if i < 6:
784
- y_emb[:, prefix_len:] += embedding_layer(samples)
785
-
786
- assert len(codes) == 8
787
- return torch.stack(codes, dim=-1)
788
-
789
-
790
- # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
791
- def top_k_top_p_filtering(
792
- logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
793
- ):
794
- """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
795
- Args:
796
- logits: logits distribution shape (batch size, vocabulary size)
797
- if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
798
- if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
799
- Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
800
- Make sure we keep at least min_tokens_to_keep per batch example in the output
801
- From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
802
- """
803
- if top_k > 0:
804
- top_k = min(
805
- max(top_k, min_tokens_to_keep), logits.size(-1)
806
- ) # Safety check
807
- # Remove all tokens with a probability less than the last token of the top-k
808
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
809
- logits[indices_to_remove] = filter_value
810
-
811
- if top_p < 1.0:
812
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
813
- cumulative_probs = torch.cumsum(
814
- F.softmax(sorted_logits, dim=-1), dim=-1
815
- )
816
-
817
- # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
818
- sorted_indices_to_remove = cumulative_probs > top_p
819
- if min_tokens_to_keep > 1:
820
- # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
821
- sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
822
- # Shift the indices to the right to keep also the first token above the threshold
823
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
824
- ..., :-1
825
- ].clone()
826
- sorted_indices_to_remove[..., 0] = 0
827
-
828
- # scatter sorted tensors to original indexing
829
- indices_to_remove = sorted_indices_to_remove.scatter(
830
- 1, sorted_indices, sorted_indices_to_remove
831
- )
832
- logits[indices_to_remove] = filter_value
833
- return logits
834
-
835
-
836
- def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
837
- # temperature: (`optional`) float
838
- # The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
839
- # top_k: (`optional`) int
840
- # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
841
- # top_p: (`optional`) float
842
- # The cumulative probability mass of the highest-probability vocabulary tokens kept for nucleus sampling. Must be between 0 and 1. Default to 1.
843
-
844
- # Temperature (higher temperature => more likely to sample low probability tokens)
845
- if temperature != 1.0:
846
- logits = logits / temperature
847
- # Top-p/top-k filtering
848
- logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
849
- # Sample
850
- token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
851
- logprobs = F.log_softmax(logits.float(), dim=-1)
852
- current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)]
853
- return token, current_logprobs
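
A standalone sketch (not part of the deleted file) of the sampling step the AR decoder performs at each position: keep the top-k logits, renormalise, and draw one token per row, recording its log-probability for the length-penalised beam selection above. It mirrors top_k_top_p_filtering plus topk_sampling but omits the nucleus (top-p) path; the vocabulary size of 1025 (1024 audio codes plus EOS) is an assumption.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 1025)                                     # (batch, vocab); 1024 codes + EOS (assumed)
top_k = 10
kth_best = torch.topk(logits, top_k).values[..., -1, None]        # value of the k-th largest logit per row
filtered = logits.masked_fill(logits < kth_best, float("-inf"))   # drop everything below the k-th logit
probs = F.softmax(filtered, dim=-1)
token = torch.multinomial(probs, num_samples=1)                   # (2, 1) sampled ids
logprob = F.log_softmax(filtered, dim=-1).gather(1, token).squeeze(1)
print(token.squeeze(1), logprob)                                  # sampled codes and their log-probabilities
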
models/visualizer.py DELETED
@@ -1,106 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Copyright 2023 (authors: Feiteng Li)
3
- #
4
- # See ../../../../LICENSE for clarification regarding multiple authors
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
-
18
-
19
- from typing import Dict, List, Tuple, Union
20
-
21
- import matplotlib.pyplot as plt
22
- import numpy as np
23
- import torch
24
-
25
-
26
- def visualize(
27
- predicts: Tuple[torch.Tensor],
28
- batch: Dict[str, Union[List, torch.Tensor]],
29
- output_dir: str,
30
- limit: int = 4,
31
- ) -> None:
32
- text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
33
- text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
34
- audio_features = batch["audio_features"].to("cpu").detach().numpy()
35
- audio_features_lens = (
36
- batch["audio_features_lens"].to("cpu").detach().numpy()
37
- )
38
- assert text_tokens.ndim == 2
39
-
40
- utt_ids, texts = batch["utt_id"], batch["text"]
41
-
42
- encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
43
- decoder_outputs = predicts[1]
44
- if isinstance(decoder_outputs, list):
45
- decoder_outputs = decoder_outputs[-1]
46
- decoder_outputs = (
47
- decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
48
- )
49
-
50
- vmin, vmax = 0, 1024 # Encodec
51
- if decoder_outputs.dtype == np.float32:
52
- vmin, vmax = -6, 0 # Fbank
53
-
54
- num_figures = 3
55
- for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
56
- _ = plt.figure(figsize=(14, 8 * num_figures))
57
-
58
- S = text_tokens_lens[b]
59
- T = audio_features_lens[b]
60
-
61
- # encoder
62
- plt.subplot(num_figures, 1, 1)
63
- plt.title(f"Text: {text}")
64
- plt.imshow(
65
- X=np.transpose(encoder_outputs[b]),
66
- cmap=plt.get_cmap("jet"),
67
- aspect="auto",
68
- interpolation="nearest",
69
- )
70
- plt.gca().invert_yaxis()
71
- plt.axvline(x=S - 0.4, linewidth=2, color="r")
72
- plt.xlabel("Encoder Output")
73
- plt.colorbar()
74
-
75
- # decoder
76
- plt.subplot(num_figures, 1, 2)
77
- plt.imshow(
78
- X=np.transpose(decoder_outputs[b]),
79
- cmap=plt.get_cmap("jet"),
80
- aspect="auto",
81
- interpolation="nearest",
82
- vmin=vmin,
83
- vmax=vmax,
84
- )
85
- plt.gca().invert_yaxis()
86
- plt.axvline(x=T - 0.4, linewidth=2, color="r")
87
- plt.xlabel("Decoder Output")
88
- plt.colorbar()
89
-
90
- # target
91
- plt.subplot(num_figures, 1, 3)
92
- plt.imshow(
93
- X=np.transpose(audio_features[b]),
94
- cmap=plt.get_cmap("jet"),
95
- aspect="auto",
96
- interpolation="nearest",
97
- vmin=vmin,
98
- vmax=vmax,
99
- )
100
- plt.gca().invert_yaxis()
101
- plt.axvline(x=T - 0.4, linewidth=2, color="r")
102
- plt.xlabel("Decoder Target")
103
- plt.colorbar()
104
-
105
- plt.savefig(f"{output_dir}/{utt_id}.png")
106
- plt.close()
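
A hedged sketch (not part of the deleted file) of the tensor shapes visualize() above expects; every value is a random placeholder and the output directory is hypothetical.

import torch

B, S, T, D = 2, 12, 40, 100                                # batch, text length, frame count, feature dim (assumed)
batch = {
    "utt_id": ["utt0", "utt1"],
    "text": ["hello", "world"],
    "text_tokens": torch.randint(0, 50, (B, S)),
    "text_tokens_lens": torch.tensor([10, 12]),
    "audio_features": torch.rand(B, T, D) * -6.0,          # float features land in the Fbank vmin/vmax range
    "audio_features_lens": torch.tensor([35, 40]),
}
predicts = (torch.rand(B, S, 32), torch.rand(B, T, D))     # (encoder_outputs, decoder_outputs)
# visualize(predicts, batch, output_dir="/tmp/vis", limit=2)  # would write one PNG per utterance
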
modules/__init__.py DELETED
File without changes
modules/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (175 Bytes)
 
modules/__pycache__/activation.cpython-311.pyc DELETED
Binary file (27.5 kB)
 
modules/__pycache__/embedding.cpython-311.pyc DELETED
Binary file (6.15 kB)
 
modules/__pycache__/scaling.cpython-311.pyc DELETED
Binary file (69 kB)
 
modules/__pycache__/transformer.cpython-311.pyc DELETED
Binary file (28.2 kB)
 
modules/activation.py DELETED
@@ -1,612 +0,0 @@
1
- from typing import Optional, Tuple, List
2
- import math
3
-
4
- import torch
5
- from torch import Tensor
6
- from torch.nn import Linear, Module
7
- from torch.nn import functional as F
8
- from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
9
- from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
10
- from torch.nn.parameter import Parameter
11
-
12
- def _in_projection_packed(
13
- q: Tensor,
14
- k: Tensor,
15
- v: Tensor,
16
- w: Tensor,
17
- b: Optional[Tensor] = None,
18
- ) -> List[Tensor]:
19
- r"""
20
- Performs the in-projection step of the attention operation, using packed weights.
21
- Output is a triple containing projection tensors for query, key and value.
22
-
23
- Args:
24
- q, k, v: query, key and value tensors to be projected. For self-attention,
25
- these are typically the same tensor; for encoder-decoder attention,
26
- k and v are typically the same tensor. (We take advantage of these
27
- identities for performance if they are present.) Regardless, q, k and v
28
- must share a common embedding dimension; otherwise their shapes may vary.
29
- w: projection weights for q, k and v, packed into a single tensor. Weights
30
- are packed along dimension 0, in q, k, v order.
31
- b: optional projection biases for q, k and v, packed into a single tensor
32
- in q, k, v order.
33
-
34
- Shape:
35
- Inputs:
36
- - q: :math:`(..., E)` where E is the embedding dimension
37
- - k: :math:`(..., E)` where E is the embedding dimension
38
- - v: :math:`(..., E)` where E is the embedding dimension
39
- - w: :math:`(E * 3, E)` where E is the embedding dimension
40
- - b: :math:`E * 3` where E is the embedding dimension
41
-
42
- Output:
43
- - in output list :math:`[q', k', v']`, each output tensor will have the
44
- same shape as the corresponding input tensor.
45
- """
46
- E = q.size(-1)
47
- if k is v:
48
- if q is k:
49
- # self-attention
50
- return F.linear(q, w, b).chunk(3, dim=-1)
51
- else:
52
- # encoder-decoder attention
53
- w_q, w_kv = w.split([E, E * 2])
54
- if b is None:
55
- b_q = b_kv = None
56
- else:
57
- b_q, b_kv = b.split([E, E * 2])
58
- return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
59
- else:
60
- w_q, w_k, w_v = w.chunk(3)
61
- if b is None:
62
- b_q = b_k = b_v = None
63
- else:
64
- b_q, b_k, b_v = b.chunk(3)
65
- return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
66
-
67
- def _scaled_dot_product_attention(
68
- q: Tensor,
69
- k: Tensor,
70
- v: Tensor,
71
- attn_mask: Optional[Tensor] = None,
72
- dropout_p: float = 0.0,
73
- ) -> Tuple[Tensor, Tensor]:
74
- r"""
75
- Computes scaled dot product attention on query, key and value tensors, using
76
- an optional attention mask if passed, and applying dropout if a probability
77
- greater than 0.0 is specified.
78
- Returns a tensor pair containing attended values and attention weights.
79
-
80
- Args:
81
- q, k, v: query, key and value tensors. See Shape section for shape details.
82
- attn_mask: optional tensor containing mask values to be added to calculated
83
- attention. May be 2D or 3D; see Shape section for details.
84
- dropout_p: dropout probability. If greater than 0.0, dropout is applied.
85
-
86
- Shape:
87
- - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
88
- and E is embedding dimension.
89
- - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
90
- and E is embedding dimension.
91
- - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
92
- and E is embedding dimension.
93
- - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
94
- shape :math:`(Nt, Ns)`.
95
-
96
- - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
97
- have shape :math:`(B, Nt, Ns)`
98
- """
99
- B, Nt, E = q.shape
100
- q = q / math.sqrt(E)
101
- # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
102
- if attn_mask is not None:
103
- attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
104
- else:
105
- attn = torch.bmm(q, k.transpose(-2, -1))
106
-
107
- attn = F.softmax(attn, dim=-1)
108
- if dropout_p > 0.0:
109
- attn = F.dropout(attn, p=dropout_p)
110
- # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
111
- output = torch.bmm(attn, v)
112
- return output, attn
113
-
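# Illustrative aside (not from the deleted file): a quick shape check for the
# _scaled_dot_product_attention helper above. q is (B, Nt, E), k/v are (B, Ns, E);
# the attention weights come out (B, Nt, Ns) and the attended values (B, Nt, E).
import math
import torch
import torch.nn.functional as F

B, Nt, Ns, E = 2, 3, 5, 8
q, k, v = torch.randn(B, Nt, E), torch.randn(B, Ns, E), torch.randn(B, Ns, E)
attn = F.softmax((q / math.sqrt(E)) @ k.transpose(-2, -1), dim=-1)  # (B, Nt, Ns)
out = attn @ v                                                      # (B, Nt, E)
print(out.shape, attn.shape)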
114
- def multi_head_attention_forward(
115
- x,
116
- ipw,
117
- ipb,
118
- opw,
119
- opb,
120
- n_head,
121
- attn_mask,
122
- past_kv=None,
123
- use_cache=False,
124
- ):
125
- # x = x.transpose(1, 0)
126
- # tgt_len, bsz, embed_dim = x.shape
127
- # head_dim = embed_dim // n_head
128
- # q, k, v = _in_projection_packed(x, x, x, ipw, ipb)
129
- # q = q.contiguous().view(tgt_len, bsz * n_head, head_dim).transpose(0, 1)
130
- # k = k.contiguous().view(k.shape[0], bsz * n_head, head_dim).transpose(0, 1)
131
- # v = v.contiguous().view(v.shape[0], bsz * n_head, head_dim).transpose(0, 1)
132
-
133
- # new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
134
- # new_attn_mask.masked_fill_(attn_mask, float("-inf"))
135
- # attn_mask = new_attn_mask
136
- #
137
- # attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, 0.0)
138
- # attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
139
- # attn_output = torch._C._nn.linear(attn_output, opw, opb)
140
- # attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
141
-
142
- B, T, C = x.size()
143
-
144
- q, k, v = torch._C._nn.linear(x, ipw, ipb).chunk(3, dim=-1)
145
- k = k.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
146
- q = q.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
147
- v = v.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
148
- if past_kv is not None:
149
- past_key = past_kv[0]
150
- past_value = past_kv[1]
151
- k = torch.cat((past_key, k), dim=-2)
152
- v = torch.cat((past_value, v), dim=-2)
153
-
154
- FULL_T = k.shape[-2]
155
-
156
- if use_cache is True:
157
- present = (k, v)
158
- else:
159
- present = None
160
-
161
- att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
162
- att = att.masked_fill(attn_mask[FULL_T - T:FULL_T, :FULL_T], float('-inf'))
163
- att = F.softmax(att, dim=-1)
164
- y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
165
- y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
166
- y = torch._C._nn.linear(y, opw, opb)
167
- return (y, present)
168
-
169
-
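# Illustrative aside (not from the deleted file): a minimal sketch of the KV-cache
# pattern used by multi_head_attention_forward above. Each decoding step projects
# only the newest token and appends its keys/values to the cache along the time
# dimension; the (k, v) pair returned as `present` becomes `past_kv` next step.
import torch

B, n_head, hs = 1, 4, 16
past_kv = (torch.zeros(B, n_head, 5, hs), torch.zeros(B, n_head, 5, hs))  # 5 cached steps
new_k = torch.randn(B, n_head, 1, hs)                                     # current step only
new_v = torch.randn(B, n_head, 1, hs)
k = torch.cat((past_kv[0], new_k), dim=-2)                                # (B, n_head, 6, hs)
v = torch.cat((past_kv[1], new_v), dim=-2)
present = (k, v)
print(k.shape, v.shape)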
170
- class MultiheadAttention(Module):
171
- r"""Allows the model to jointly attend to information
172
- from different representation subspaces as described in the paper:
173
- `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
174
-
175
- Multi-Head Attention is defined as:
176
-
177
- .. math::
178
- \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
179
-
180
- where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
181
-
182
- ``forward()`` will use a special optimized implementation if all of the following
183
- conditions are met:
184
-
185
- - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
186
- restriction will be loosened in the future.)
187
- - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
188
- - training is disabled (using ``.eval()``)
189
- - dropout is 0
190
- - ``add_bias_kv`` is ``False``
191
- - ``add_zero_attn`` is ``False``
192
- - ``batch_first`` is ``True`` and the input is batched
193
- - ``kdim`` and ``vdim`` are equal to ``embed_dim``
194
- - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
195
- - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
196
- nor ``attn_mask`` is passed
197
-
198
- If the optimized implementation is in use, a
199
- `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
200
- ``query``/``key``/``value`` to represent padding more efficiently than using a
201
- padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
202
- will be returned, and an additional speedup proportional to the fraction of the input
203
- that is padding can be expected.
204
-
205
- Args:
206
- embed_dim: Total dimension of the model.
207
- num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
208
- across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
209
- dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
210
- bias: If specified, adds bias to input / output projection layers. Default: ``True``.
211
- add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
212
- add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
213
- Default: ``False``.
214
- kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
215
- vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
216
- batch_first: If ``True``, then the input and output tensors are provided
217
- as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
218
-
219
- Examples::
220
-
221
- >>> # xdoctest: +SKIP
222
- >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
223
- >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
224
-
225
- """
226
- __constants__ = ["batch_first"]
227
- bias_k: Optional[torch.Tensor]
228
- bias_v: Optional[torch.Tensor]
229
-
230
- def __init__(
231
- self,
232
- embed_dim,
233
- num_heads,
234
- dropout=0.0,
235
- bias=True,
236
- add_bias_kv=False,
237
- add_zero_attn=False,
238
- kdim=None,
239
- vdim=None,
240
- batch_first=False,
241
- linear1_cls=Linear,
242
- linear2_cls=Linear,
243
- device=None,
244
- dtype=None,
245
- ) -> None:
246
- factory_kwargs = {"device": device, "dtype": dtype}
247
- super(MultiheadAttention, self).__init__()
248
- self.embed_dim = embed_dim
249
- self.kdim = kdim if kdim is not None else embed_dim
250
- self.vdim = vdim if vdim is not None else embed_dim
251
- self._qkv_same_embed_dim = (
252
- self.kdim == embed_dim and self.vdim == embed_dim
253
- )
254
-
255
- self.num_heads = num_heads
256
- self.dropout = dropout
257
- self.batch_first = batch_first
258
- self.head_dim = embed_dim // num_heads
259
- assert (
260
- self.head_dim * num_heads == self.embed_dim
261
- ), "embed_dim must be divisible by num_heads"
262
-
263
- if add_bias_kv:
264
- self.bias_k = Parameter(
265
- torch.empty((1, 1, embed_dim), **factory_kwargs)
266
- )
267
- self.bias_v = Parameter(
268
- torch.empty((1, 1, embed_dim), **factory_kwargs)
269
- )
270
- else:
271
- self.bias_k = self.bias_v = None
272
-
273
- if linear1_cls == Linear:
274
- if not self._qkv_same_embed_dim:
275
- self.q_proj_weight = Parameter(
276
- torch.empty((embed_dim, embed_dim), **factory_kwargs)
277
- )
278
- self.k_proj_weight = Parameter(
279
- torch.empty((embed_dim, self.kdim), **factory_kwargs)
280
- )
281
- self.v_proj_weight = Parameter(
282
- torch.empty((embed_dim, self.vdim), **factory_kwargs)
283
- )
284
- self.register_parameter("in_proj_weight", None)
285
- else:
286
- self.in_proj_weight = Parameter(
287
- torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
288
- )
289
- self.register_parameter("q_proj_weight", None)
290
- self.register_parameter("k_proj_weight", None)
291
- self.register_parameter("v_proj_weight", None)
292
-
293
- if bias:
294
- self.in_proj_bias = Parameter(
295
- torch.empty(3 * embed_dim, **factory_kwargs)
296
- )
297
- else:
298
- self.register_parameter("in_proj_bias", None)
299
- self.out_proj = NonDynamicallyQuantizableLinear(
300
- embed_dim, embed_dim, bias=bias, **factory_kwargs
301
- )
302
-
303
- self._reset_parameters()
304
- else:
305
- if not self._qkv_same_embed_dim:
306
- raise NotImplementedError
307
- else:
308
- self.in_proj_linear = linear1_cls(
309
- embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
310
- )
311
- self.in_proj_weight = self.in_proj_linear.weight
312
-
313
- self.register_parameter("q_proj_weight", None)
314
- self.register_parameter("k_proj_weight", None)
315
- self.register_parameter("v_proj_weight", None)
316
-
317
- if bias:
318
- self.in_proj_bias = self.in_proj_linear.bias
319
- else:
320
- self.register_parameter("in_proj_bias", None)
321
-
322
- self.out_proj = linear2_cls(
323
- embed_dim, embed_dim, bias=bias, **factory_kwargs
324
- )
325
-
326
- if self.bias_k is not None:
327
- xavier_normal_(self.bias_k)
328
- if self.bias_v is not None:
329
- xavier_normal_(self.bias_v)
330
-
331
- self.add_zero_attn = add_zero_attn
332
-
333
- def _reset_parameters(self):
334
- if self._qkv_same_embed_dim:
335
- xavier_uniform_(self.in_proj_weight)
336
- else:
337
- xavier_uniform_(self.q_proj_weight)
338
- xavier_uniform_(self.k_proj_weight)
339
- xavier_uniform_(self.v_proj_weight)
340
-
341
- if self.in_proj_bias is not None:
342
- constant_(self.in_proj_bias, 0.0)
343
- constant_(self.out_proj.bias, 0.0)
344
-
345
- if self.bias_k is not None:
346
- xavier_normal_(self.bias_k)
347
- if self.bias_v is not None:
348
- xavier_normal_(self.bias_v)
349
-
350
- def __setstate__(self, state):
351
- # Support loading old MultiheadAttention checkpoints generated by v1.1.0
352
- if "_qkv_same_embed_dim" not in state:
353
- state["_qkv_same_embed_dim"] = True
354
-
355
- super(MultiheadAttention, self).__setstate__(state)
356
-
357
- def forward(
358
- self,
359
- query: Tensor,
360
- key: Tensor,
361
- value: Tensor,
362
- key_padding_mask: Optional[Tensor] = None,
363
- need_weights: bool = True,
364
- attn_mask: Optional[Tensor] = None,
365
- average_attn_weights: bool = True,
366
- ) -> Tuple[Tensor, Optional[Tensor]]:
367
- r"""
368
- Args:
369
- query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
370
- or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
371
- :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
372
- Queries are compared against key-value pairs to produce the output.
373
- See "Attention Is All You Need" for more details.
374
- key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
375
- or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
376
- :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
377
- See "Attention Is All You Need" for more details.
378
- value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
379
- ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
380
- sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
381
- See "Attention Is All You Need" for more details.
382
- key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
383
- to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
384
- Binary and byte masks are supported.
385
- For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
386
- the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
387
- need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
388
- Default: ``True``.
389
- attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
390
- :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
391
- :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
392
- broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
393
- Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
394
- corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
395
- corresponding position is not allowed to attend. For a float mask, the mask values will be added to
396
- the attention weight.
397
- average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
398
- heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
399
- effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
400
-
401
- Outputs:
402
- - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
403
- :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
404
- where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
405
- embedding dimension ``embed_dim``.
406
- - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
407
- returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
408
- :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
409
- :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
410
- head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
411
-
412
- .. note::
413
- `batch_first` argument is ignored for unbatched inputs.
414
- """
415
- is_batched = query.dim() == 3
416
- if key_padding_mask is not None:
417
- _kpm_dtype = key_padding_mask.dtype
418
- if _kpm_dtype != torch.bool and not torch.is_floating_point(
419
- key_padding_mask
420
- ):
421
- raise AssertionError(
422
- "only bool and floating types of key_padding_mask are supported"
423
- )
424
- why_not_fast_path = ""
425
- if not is_batched:
426
- why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
427
- elif query is not key or key is not value:
428
- # When lifting this restriction, don't forget to either
429
- # enforce that the dtypes all match or test cases where
430
- # they don't!
431
- why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
432
- elif (
433
- self.in_proj_bias is not None
434
- and query.dtype != self.in_proj_bias.dtype
435
- ):
436
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
437
- elif (
438
- self.in_proj_weight is not None
439
- and query.dtype != self.in_proj_weight.dtype
440
- ):
441
- # this case will fail anyway, but at least they'll get a useful error message.
442
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
443
- elif self.training:
444
- why_not_fast_path = "training is enabled"
445
- elif not self.batch_first:
446
- why_not_fast_path = "batch_first was not True"
447
- elif self.bias_k is not None:
448
- why_not_fast_path = "self.bias_k was not None"
449
- elif self.bias_v is not None:
450
- why_not_fast_path = "self.bias_v was not None"
451
- elif self.dropout:
452
- why_not_fast_path = f"dropout was {self.dropout}, required zero"
453
- elif self.add_zero_attn:
454
- why_not_fast_path = "add_zero_attn was enabled"
455
- elif not self._qkv_same_embed_dim:
456
- why_not_fast_path = "_qkv_same_embed_dim was not True"
457
- elif attn_mask is not None:
458
- why_not_fast_path = "attn_mask was not None"
459
- elif query.is_nested and key_padding_mask is not None:
460
- why_not_fast_path = (
461
- "key_padding_mask is not supported with NestedTensor input"
462
- )
463
- elif self.num_heads % 2 == 1:
464
- why_not_fast_path = "num_heads is odd"
465
- elif torch.is_autocast_enabled():
466
- why_not_fast_path = "autocast is enabled"
467
-
468
- if not why_not_fast_path:
469
- tensor_args = (
470
- query,
471
- key,
472
- value,
473
- self.in_proj_weight,
474
- self.in_proj_bias,
475
- self.out_proj.weight,
476
- self.out_proj.bias,
477
- )
478
- # We have to use list comprehensions below because TorchScript does not support
479
- # generator expressions.
480
- if torch.overrides.has_torch_function(tensor_args):
481
- why_not_fast_path = "some Tensor argument has_torch_function"
482
- elif not all(
483
- [
484
- (x is None or x.is_cuda or "cpu" in str(x.device))
485
- for x in tensor_args
486
- ]
487
- ):
488
- why_not_fast_path = (
489
- "some Tensor argument is neither CUDA nor CPU"
490
- )
491
- elif torch.is_grad_enabled() and any(
492
- [x is not None and x.requires_grad for x in tensor_args]
493
- ):
494
- why_not_fast_path = (
495
- "grad is enabled and at least one of query or the "
496
- "input/output projection weights or biases requires_grad"
497
- )
498
- if not why_not_fast_path:
499
- return torch._native_multi_head_attention(
500
- query,
501
- key,
502
- value,
503
- self.embed_dim,
504
- self.num_heads,
505
- self.in_proj_weight,
506
- self.in_proj_bias,
507
- self.out_proj.weight,
508
- self.out_proj.bias,
509
- key_padding_mask
510
- if key_padding_mask is not None
511
- else attn_mask,
512
- need_weights,
513
- average_attn_weights,
514
- 1
515
- if key_padding_mask is not None
516
- else 0
517
- if attn_mask is not None
518
- else None,
519
- )
520
-
521
- any_nested = query.is_nested or key.is_nested or value.is_nested
522
- assert not any_nested, (
523
- "MultiheadAttention does not support NestedTensor outside of its fast path. "
524
- + f"The fast path was not hit because {why_not_fast_path}"
525
- )
526
-
527
- if self.batch_first and is_batched:
528
- # make sure that the transpose op does not affect the "is" property
529
- if key is value:
530
- if query is key:
531
- query = key = value = query.transpose(1, 0)
532
- else:
533
- query, key = [x.transpose(1, 0) for x in (query, key)]
534
- value = key
535
- else:
536
- query, key, value = [
537
- x.transpose(1, 0) for x in (query, key, value)
538
- ]
539
-
540
- if not self._qkv_same_embed_dim:
541
- attn_output, attn_output_weights = F.multi_head_attention_forward(
542
- query,
543
- key,
544
- value,
545
- self.embed_dim,
546
- self.num_heads,
547
- self.in_proj_weight,
548
- self.in_proj_bias,
549
- self.bias_k,
550
- self.bias_v,
551
- self.add_zero_attn,
552
- self.dropout,
553
- self.out_proj.weight,
554
- self.out_proj.bias,
555
- training=self.training,
556
- key_padding_mask=key_padding_mask,
557
- need_weights=need_weights,
558
- attn_mask=attn_mask,
559
- use_separate_proj_weight=True,
560
- q_proj_weight=self.q_proj_weight,
561
- k_proj_weight=self.k_proj_weight,
562
- v_proj_weight=self.v_proj_weight,
563
- average_attn_weights=average_attn_weights,
564
- )
565
- else:
566
- attn_output, attn_output_weights = F.multi_head_attention_forward(
567
- query,
568
- key,
569
- value,
570
- self.embed_dim,
571
- self.num_heads,
572
- self.in_proj_weight,
573
- self.in_proj_bias,
574
- self.bias_k,
575
- self.bias_v,
576
- self.add_zero_attn,
577
- self.dropout,
578
- self.out_proj.weight,
579
- self.out_proj.bias,
580
- training=self.training,
581
- key_padding_mask=key_padding_mask,
582
- need_weights=need_weights,
583
- attn_mask=attn_mask,
584
- average_attn_weights=average_attn_weights,
585
- )
586
- if self.batch_first and is_batched:
587
- return attn_output.transpose(1, 0), attn_output_weights
588
- else:
589
- return attn_output, attn_output_weights
590
-
591
- def infer(self,
592
- x: Tensor,
593
- key_padding_mask: Optional[Tensor] = None,
594
- need_weights: bool = True,
595
- attn_mask: Optional[Tensor] = None,
596
- average_attn_weights: bool = True,
597
- past_kv = None,
598
- use_cache = False
599
- ):
600
- # x = x.transpose(1, 0)
601
- y, kv = multi_head_attention_forward(
602
- x=x,
603
- ipw=self.in_proj_weight,
604
- ipb=self.in_proj_bias,
605
- opw=self.out_proj.weight,
606
- opb=self.out_proj.bias,
607
- n_head=self.num_heads,
608
- attn_mask=attn_mask,
609
- past_kv=past_kv,
610
- use_cache=use_cache,
611
- )
612
- return (y, kv)
 
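The infer() method above threads a (key, value) cache through the attention call so that, during autoregressive decoding, only the newest position has to be projected and attended. The multi_head_attention_forward helper it delegates to (with the ipw/ipb/opw/opb arguments) is project-local and not shown in this diff, so the sketch below is only a minimal, self-contained illustration of the same KV-cache idea in plain PyTorch; the function name step_with_cache and the omission of the output projection and biases are simplifications for illustration, not code from this repository.

import torch
import torch.nn.functional as F

def step_with_cache(x_new, w_qkv, n_head, past_kv=None):
    # x_new: (B, 1, E) -- only the newly generated position.
    # w_qkv: (3*E, E) packed query/key/value projection, analogous to in_proj_weight above.
    B, T_new, E = x_new.shape
    q, k, v = (x_new @ w_qkv.t()).chunk(3, dim=-1)        # each (B, 1, E)
    if past_kv is not None:
        past_k, past_v = past_kv
        k = torch.cat([past_k, k], dim=1)                 # (B, T_past + 1, E)
        v = torch.cat([past_v, v], dim=1)

    def split_heads(t):                                   # -> (B, n_head, T, E // n_head)
        return t.view(B, -1, n_head, E // n_head).transpose(1, 2)

    scores = split_heads(q) @ split_heads(k).transpose(-2, -1) / (E // n_head) ** 0.5
    attn = F.softmax(scores, dim=-1)                      # new query attends to all cached keys
    y = (attn @ split_heads(v)).transpose(1, 2).reshape(B, T_new, E)
    return y, (k, v)                                      # (k, v) is the updated cache

Each call appends exactly one position to the cached keys and values, so the per-step cost grows linearly with the sequence length instead of re-running attention over the full prefix.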
modules/embedding.py DELETED
@@ -1,97 +0,0 @@
1
- # Copyright 2023 (authors: Feiteng Li)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import math
16
-
17
- import torch
18
- import torch.nn as nn
19
-
20
-
21
- class TokenEmbedding(nn.Module):
22
- def __init__(
23
- self,
24
- dim_model: int,
25
- vocab_size: int,
26
- dropout: float = 0.0,
27
- ):
28
- super().__init__()
29
-
30
- self.vocab_size = vocab_size
31
- self.dim_model = dim_model
32
-
33
- self.dropout = torch.nn.Dropout(p=dropout)
34
- self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
35
-
36
- @property
37
- def weight(self) -> torch.Tensor:
38
- return self.word_embeddings.weight
39
-
40
- def embedding(self, index: int) -> torch.Tensor:
41
- return self.word_embeddings.weight[index : index + 1]
42
-
43
- def forward(self, x: torch.Tensor):
44
- X = self.word_embeddings(x)
45
- X = self.dropout(X)
46
-
47
- return X
48
-
49
-
50
- class SinePositionalEmbedding(nn.Module):
51
- def __init__(
52
- self,
53
- dim_model: int,
54
- dropout: float = 0.0,
55
- scale: bool = False,
56
- alpha: bool = False,
57
- ):
58
- super().__init__()
59
- self.dim_model = dim_model
60
- self.x_scale = math.sqrt(dim_model) if scale else 1.0
61
- self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
62
- self.dropout = torch.nn.Dropout(p=dropout)
63
-
64
- self.reverse = False
65
- self.pe = None
66
- self.extend_pe(torch.tensor(0.0).expand(1, 4000))
67
-
68
- def extend_pe(self, x):
69
- """Reset the positional encodings."""
70
- if self.pe is not None:
71
- if self.pe.size(1) >= x.size(1):
72
- if self.pe.dtype != x.dtype or self.pe.device != x.device:
73
- self.pe = self.pe.to(dtype=x.dtype, device=x.device)
74
- return
75
- pe = torch.zeros(x.size(1), self.dim_model)
76
- if self.reverse:
77
- position = torch.arange(
78
- x.size(1) - 1, -1, -1.0, dtype=torch.float32
79
- ).unsqueeze(1)
80
- else:
81
- position = torch.arange(
82
- 0, x.size(1), dtype=torch.float32
83
- ).unsqueeze(1)
84
- div_term = torch.exp(
85
- torch.arange(0, self.dim_model, 2, dtype=torch.float32)
86
- * -(math.log(10000.0) / self.dim_model)
87
- )
88
- pe[:, 0::2] = torch.sin(position * div_term)
89
- pe[:, 1::2] = torch.cos(position * div_term)
90
- pe = pe.unsqueeze(0)
91
- self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
92
-
93
- def forward(self, x: torch.Tensor) -> torch.Tensor:
94
- self.extend_pe(x)
95
- output = x.unsqueeze(-1) if x.ndim == 2 else x
96
- output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
97
- return self.dropout(output)
 
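The two modules deleted above compose in the obvious way: TokenEmbedding maps integer token ids to dense vectors, and SinePositionalEmbedding adds sinusoidal position information on top (scaled by sqrt(dim_model) when scale=True, and weighted by a learnable scalar when alpha=True). A minimal usage sketch, assuming the classes remain importable from their pre-deletion location modules.embedding; the hyperparameter values are arbitrary examples:

import torch
from modules.embedding import TokenEmbedding, SinePositionalEmbedding

emb = TokenEmbedding(dim_model=512, vocab_size=1024, dropout=0.1)
pos = SinePositionalEmbedding(dim_model=512, dropout=0.1, scale=True, alpha=True)

tokens = torch.randint(0, 1024, (2, 100))   # (batch, seq_len) integer token ids
x = emb(tokens)                             # (2, 100, 512) token embeddings
x = pos(x)                                  # same shape, with positional encoding added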
modules/optim.py DELETED
@@ -1,1105 +0,0 @@
1
- # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
- #
3
- # See ../LICENSE for clarification regarding multiple authors
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- import contextlib
18
- import logging
19
- import random
20
- from collections import defaultdict
21
- from typing import List, Optional, Tuple, Union
22
-
23
- import torch
24
- from lhotse.utils import fix_random_seed
25
- from torch import Tensor
26
- from torch.optim import Optimizer
27
-
28
-
29
- class BatchedOptimizer(Optimizer):
30
- """
31
- This class adds to class Optimizer the capability to optimize parameters in batches:
32
- it will stack the parameters and their grads for you so the optimizer can work
33
- on tensors with an extra leading dimension. This is intended for speed with GPUs,
34
- as it reduces the number of kernels launched in the optimizer.
35
-
36
- Args:
37
- params:
38
- """
39
-
40
- def __init__(self, params, defaults):
41
- super(BatchedOptimizer, self).__init__(params, defaults)
42
-
43
- @contextlib.contextmanager
44
- def batched_params(self, param_group, group_params_names):
45
- """
46
- This function returns (technically, yields) a list of
47
-        tuples (p, state), where
48
- p is a `fake` parameter that is stacked (over axis 0) from real parameters
49
- that share the same shape, and its gradient is also stacked;
50
- `state` is the state corresponding to this batch of parameters
51
- (it will be physically located in the "state" for one of the real
52
- parameters, the last one that has any particular shape and dtype).
53
-
54
- This function is decorated as a context manager so that it can
55
- write parameters back to their "real" locations.
56
-
57
- The idea is, instead of doing:
58
- <code>
59
- for p in group["params"]:
60
- state = self.state[p]
61
- ...
62
- </code>
63
- you can do:
64
- <code>
65
- with self.batched_params(group["params"]) as batches:
66
- for p, state, p_names in batches:
67
- ...
68
- </code>
69
-
70
- Args:
71
- group: a parameter group, which is a list of parameters; should be
72
- one of self.param_groups.
73
- group_params_names: name for each parameter in group,
74
- which is List[str].
75
- """
76
- batches = defaultdict(
77
- list
78
- ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
79
- batches_names = defaultdict(
80
- list
81
- ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
82
-
83
- assert len(param_group) == len(group_params_names)
84
- for p, named_p in zip(param_group, group_params_names):
85
- key = (str(p.dtype), *p.shape)
86
- batches[key].append(p)
87
- batches_names[key].append(named_p)
88
-
89
- batches_names_keys = list(batches_names.keys())
90
- sorted_idx = sorted(
91
- range(len(batches_names)), key=lambda i: batches_names_keys[i]
92
- )
93
- batches_names = [
94
- batches_names[batches_names_keys[idx]] for idx in sorted_idx
95
- ]
96
- batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
97
-
98
- stacked_params_dict = dict()
99
-
100
- # turn batches into a list, in deterministic order.
101
- # tuples will contain tuples of (stacked_param, state, stacked_params_names),
102
- # one for each batch in `batches`.
103
- tuples = []
104
-
105
- for batch, batch_names in zip(batches, batches_names):
106
- p = batch[0]
107
- # we arbitrarily store the state in the
108
- # state corresponding to the 1st parameter in the
109
- # group. class Optimizer will take care of saving/loading state.
110
- state = self.state[p]
111
- p_stacked = torch.stack(batch)
112
- grad = torch.stack(
113
- [
114
- torch.zeros_like(p) if p.grad is None else p.grad
115
- for p in batch
116
- ]
117
- )
118
- p_stacked.grad = grad
119
- stacked_params_dict[key] = p_stacked
120
- tuples.append((p_stacked, state, batch_names))
121
-
122
- yield tuples # <-- calling code will do the actual optimization here!
123
-
124
- for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
125
- for i, p in enumerate(batch): # batch is list of Parameter
126
- p.copy_(stacked_params[i])
127
-
128
-
129
- class ScaledAdam(BatchedOptimizer):
130
- """
131
- Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
132
- proportional to the norm of that parameter; and also learn the scale of the parameter,
133
- in log space, subject to upper and lower limits (as if we had factored each parameter as
134
- param = underlying_param * log_scale.exp())
135
-
136
-
137
- Args:
138
- params: The parameters or param_groups to optimize (like other Optimizer subclasses)
139
- lr: The learning rate. We will typically use a learning rate schedule that starts
140
- at 0.03 and decreases over time, i.e. much higher than other common
141
- optimizers.
142
- clipping_scale: (e.g. 2.0)
143
- A scale for gradient-clipping: if specified, the normalized gradients
144
- over the whole model will be clipped to have 2-norm equal to
145
- `clipping_scale` times the median 2-norm over the most recent period
146
- of `clipping_update_period` minibatches. By "normalized gradients",
147
- we mean after multiplying by the rms parameter value for this tensor
148
- [for non-scalars]; this is appropriate because our update is scaled
149
- by this quantity.
150
- betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
151
-              Must satisfy 0 < beta1 <= beta2 < 1.
152
- scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
153
-                   scale of each parameter tensor and scalar parameters of the model.
154
- If each parameter were decomposed
155
- as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
156
-                   would be the scaling factor on the learning rate of p_scale.
157
- eps: A general-purpose epsilon to prevent division by zero
158
- param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
159
- learning the scale on the parameters (we'll constrain the rms of each non-scalar
160
- parameter tensor to be >= this value)
161
- param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
162
- learning the scale on the parameters (we'll constrain the rms of each non-scalar
163
- parameter tensor to be <= this value)
164
- scalar_max: Maximum absolute value for scalar parameters (applicable if your
165
- model has any parameters with numel() == 1).
166
- size_update_period: The periodicity, in steps, with which we update the size (scale)
167
- of the parameter tensor. This is provided to save a little time
168
- in the update.
169
- clipping_update_period: if clipping_scale is specified, this is the period
170
- """
171
-
172
- def __init__(
173
- self,
174
- params,
175
- lr=3e-02,
176
- clipping_scale=None,
177
- betas=(0.9, 0.98),
178
- scalar_lr_scale=0.1,
179
- eps=1.0e-08,
180
- param_min_rms=1.0e-05,
181
- param_max_rms=3.0,
182
- scalar_max=10.0,
183
- size_update_period=4,
184
- clipping_update_period=100,
185
- parameters_names=None,
186
- show_dominant_parameters=True,
187
- ):
188
-
189
- assert parameters_names is not None, (
190
- "Please prepare parameters_names,"
191
- "which is a List[List[str]]. Each List[str] is for a group"
192
- "and each str is for a parameter"
193
- )
194
- defaults = dict(
195
- lr=lr,
196
- clipping_scale=clipping_scale,
197
- betas=betas,
198
- scalar_lr_scale=scalar_lr_scale,
199
- eps=eps,
200
- param_min_rms=param_min_rms,
201
- param_max_rms=param_max_rms,
202
- scalar_max=scalar_max,
203
- size_update_period=size_update_period,
204
- clipping_update_period=clipping_update_period,
205
- )
206
-
207
- super(ScaledAdam, self).__init__(params, defaults)
208
- assert len(self.param_groups) == len(parameters_names)
209
- self.parameters_names = parameters_names
210
- self.show_dominant_parameters = show_dominant_parameters
211
-
212
- def __setstate__(self, state):
213
- super(ScaledAdam, self).__setstate__(state)
214
-
215
- @torch.no_grad()
216
- def step(self, closure=None):
217
- """Performs a single optimization step.
218
-
219
- Arguments:
220
- closure (callable, optional): A closure that reevaluates the model
221
- and returns the loss.
222
- """
223
- loss = None
224
- if closure is not None:
225
- with torch.enable_grad():
226
- loss = closure()
227
-
228
- batch = True
229
-
230
- for group, group_params_names in zip(
231
- self.param_groups, self.parameters_names
232
- ):
233
-
234
- with self.batched_params(
235
- group["params"], group_params_names
236
- ) as batches:
237
-
238
- # batches is list of pairs (stacked_param, state). stacked_param is like
239
- # a regular parameter, and will have a .grad, but the 1st dim corresponds to
240
- # a stacking dim, it is not a real dim.
241
-
242
- if (
243
- len(batches[0][1]) == 0
244
- ): # if len(first state) == 0: not yet initialized
245
- clipping_scale = 1
246
- else:
247
- clipping_scale = self._get_clipping_scale(group, batches)
248
-
249
- for p, state, _ in batches:
250
- # Perform optimization step.
251
- # grad is not going to be None, we handled that when creating the batches.
252
- grad = p.grad
253
- if grad.is_sparse:
254
- raise RuntimeError(
255
- "ScaledAdam optimizer does not support sparse gradients"
256
- )
257
- # State initialization
258
- if len(state) == 0:
259
- self._init_state(group, p, state)
260
-
261
- self._step_one_batch(group, p, state, clipping_scale)
262
-
263
- return loss
264
-
265
- def _init_state(self, group: dict, p: Tensor, state: dict):
266
- """
267
- Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
268
- is actually the batch dimension, corresponding to batched-together
269
- parameters of a given shape.
270
-
271
-
272
- Args:
273
- group: Dict to look up configuration values.
274
- p: The parameter that we are initializing the state for
275
- state: Dict from string to whatever state we are initializing
276
- """
277
- size_update_period = group["size_update_period"]
278
-
279
- state["step"] = 0
280
-
281
- kwargs = {"device": p.device, "dtype": p.dtype}
282
-
283
- # 'delta' implements conventional momentum. There are
284
- # several different kinds of update going on, so rather than
285
- # compute "exp_avg" like in Adam, we store and decay a
286
- # parameter-change "delta", which combines all forms of
287
- # update. this is equivalent to how it's done in Adam,
288
- # except for the first few steps.
289
- state["delta"] = torch.zeros_like(
290
- p, memory_format=torch.preserve_format
291
- )
292
-
293
- batch_size = p.shape[0]
294
- numel = p.numel() // batch_size
295
- numel = p.numel()
296
-
297
- if numel > 1:
298
- # "param_rms" just periodically records the scalar root-mean-square value of
299
- # the parameter tensor.
300
- # it has a shape like (batch_size, 1, 1, 1, 1)
301
- param_rms = (
302
- (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
303
- )
304
- state["param_rms"] = param_rms
305
-
306
- state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
307
- state["scale_grads"] = torch.zeros(
308
- size_update_period, *param_rms.shape, **kwargs
309
- )
310
-
311
- # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
312
- state["exp_avg_sq"] = torch.zeros_like(
313
- p, memory_format=torch.preserve_format
314
- )
315
-
316
- def _get_clipping_scale(
317
- self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
318
- ) -> float:
319
- """
320
- Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
321
- by this amount before applying the rest of the update.
322
-
323
- Args:
324
- group: the parameter group, an item in self.param_groups
325
- tuples: a list of tuples of (param, state, param_names)
326
- where param is a batched set of parameters,
327
- with a .grad (1st dim is batch dim)
328
- and state is the state-dict where optimization parameters are kept.
329
- param_names is a List[str] while each str is name for a parameter
330
- in batched set of parameters "param".
331
- """
332
- assert len(tuples) >= 1
333
- clipping_scale = group["clipping_scale"]
334
- (first_p, first_state, _) = tuples[0]
335
- step = first_state["step"]
336
- if clipping_scale is None or step == 0:
337
- # no clipping. return early on step == 0 because the other
338
- # parameters' state won't have been initialized yet.
339
- return 1.0
340
- clipping_update_period = group["clipping_update_period"]
341
-
342
- tot_sumsq = torch.tensor(0.0, device=first_p.device)
343
- for (p, state, param_names) in tuples:
344
- grad = p.grad
345
- if grad.is_sparse:
346
- raise RuntimeError(
347
- "ScaledAdam optimizer does not support sparse gradients"
348
- )
349
- if p.numel() == p.shape[0]: # a batch of scalars
350
- tot_sumsq += (
351
- grad ** 2
352
- ).sum() # sum() to change shape [1] to []
353
- else:
354
- tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
355
-
356
- tot_norm = tot_sumsq.sqrt()
357
- if "model_norms" not in first_state:
358
- first_state["model_norms"] = torch.zeros(
359
- clipping_update_period, device=p.device
360
- )
361
- first_state["model_norms"][step % clipping_update_period] = tot_norm
362
-
363
- if step % clipping_update_period == 0:
364
- # Print some stats.
365
- # We don't reach here if step == 0 because we would have returned
366
- # above.
367
- sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
368
- quartiles = []
369
- for n in range(0, 5):
370
- index = min(
371
- clipping_update_period - 1,
372
- (clipping_update_period // 4) * n,
373
- )
374
- quartiles.append(sorted_norms[index].item())
375
-
376
- median = quartiles[2]
377
- threshold = clipping_scale * median
378
- first_state["model_norm_threshold"] = threshold
379
- percent_clipped = (
380
- first_state["num_clipped"] * 100.0 / clipping_update_period
381
- if "num_clipped" in first_state
382
- else 0.0
383
- )
384
- first_state["num_clipped"] = 0
385
- quartiles = " ".join(["%.3e" % x for x in quartiles])
386
- logging.info(
387
- f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
388
- f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
389
- )
390
-
391
- if step < clipping_update_period:
392
- return 1.0 # We have not yet estimated a norm to clip to.
393
- else:
394
- try:
395
- model_norm_threshold = first_state["model_norm_threshold"]
396
- except KeyError:
397
- logging.info(
398
- "Warning: model_norm_threshold not in state: possibly "
399
- "you changed config when restarting, adding clipping_scale option?"
400
- )
401
- return 1.0
402
- ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
403
- if ans < 1.0:
404
- first_state["num_clipped"] += 1
405
- if ans < 0.1:
406
- logging.warn(
407
- f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
408
- )
409
- if self.show_dominant_parameters:
410
- assert p.shape[0] == len(param_names)
411
- self._show_gradient_dominating_parameter(tuples, tot_sumsq)
412
- return ans
413
-
414
- def _show_gradient_dominating_parameter(
415
- self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
416
- ):
417
- """
418
-        Show information about the parameter which dominates tot_sumsq.
419
-
420
- Args:
421
- tuples: a list of tuples of (param, state, param_names)
422
- where param is a batched set of parameters,
423
- with a .grad (1st dim is batch dim)
424
- and state is the state-dict where optimization parameters are kept.
425
- param_names is a List[str] while each str is name for a parameter
426
- in batched set of parameters "param".
427
-            tot_sumsq: sumsq of all parameters. Though it could be calculated
428
- from tuples, we still pass it to save some time.
429
- """
430
- all_sumsq_orig = {}
431
- for (p, state, batch_param_names) in tuples:
432
- # p is a stacked batch parameters.
433
- batch_grad = p.grad
434
- if p.numel() == p.shape[0]: # a batch of scalars
435
- batch_sumsq_orig = batch_grad ** 2
436
-                # Dummy values used by the following `zip` statement.
437
- batch_rms_orig = torch.ones(p.shape[0])
438
- else:
439
- batch_rms_orig = state["param_rms"]
440
- batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
441
- dim=list(range(1, batch_grad.ndim))
442
- )
443
-
444
- for name, sumsq_orig, rms, grad in zip(
445
- batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
446
- ):
447
-
448
- proportion_orig = sumsq_orig / tot_sumsq
449
- all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
450
-
451
- assert torch.isclose(
452
- sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
453
- torch.tensor(1.0),
454
- )
455
- sorted_by_proportion = {
456
- k: v
457
- for k, v in sorted(
458
- all_sumsq_orig.items(),
459
- key=lambda item: item[1][0],
460
- reverse=True,
461
- )
462
- }
463
- dominant_param_name = next(iter(sorted_by_proportion))
464
- (
465
- dominant_proportion,
466
- dominant_sumsq,
467
- dominant_rms,
468
- dominant_grad,
469
- ) = sorted_by_proportion[dominant_param_name]
470
- logging.info(
471
-            f"Parameter dominating tot_sumsq {dominant_param_name}"
472
- f" with proportion {dominant_proportion:.2f},"
473
- f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
474
- f"={dominant_sumsq:.3e},"
475
- f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
476
- f" orig_rms_sq={(dominant_rms**2).item():.3e}"
477
- )
478
-
479
- def _step_one_batch(
480
- self, group: dict, p: Tensor, state: dict, clipping_scale: float
481
- ):
482
- """
483
- Do the step for one parameter, which is actually going to be a batch of
484
- `real` parameters, with dim 0 as the batch dim.
485
- Args:
486
- group: dict to look up configuration values
487
- p: parameter to update (actually multiple parameters stacked together
488
- as a batch)
489
- state: state-dict for p, to look up the optimizer state
490
- """
491
- lr = group["lr"]
492
- size_update_period = group["size_update_period"]
493
- beta1 = group["betas"][0]
494
-
495
- grad = p.grad
496
- if clipping_scale != 1.0:
497
- grad = grad * clipping_scale
498
- step = state["step"]
499
- delta = state["delta"]
500
-
501
- delta.mul_(beta1)
502
- batch_size = p.shape[0]
503
- numel = p.numel() // batch_size
504
- if numel > 1:
505
- # Update the size/scale of p, and set param_rms
506
- scale_grads = state["scale_grads"]
507
- scale_grads[step % size_update_period] = (p * grad).sum(
508
- dim=list(range(1, p.ndim)), keepdim=True
509
- )
510
- if step % size_update_period == size_update_period - 1:
511
- param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
512
- param_rms.copy_(
513
- (p ** 2)
514
- .mean(dim=list(range(1, p.ndim)), keepdim=True)
515
- .sqrt()
516
- )
517
- if step > 0:
518
- # self._size_update() learns the overall scale on the
519
- # parameter, by shrinking or expanding it.
520
- self._size_update(group, scale_grads, p, state)
521
-
522
- if numel == 1:
523
- # For parameters with 1 element we just use regular Adam.
524
- # Updates delta.
525
- self._step_scalar(group, p, state)
526
- else:
527
- self._step(group, p, state)
528
-
529
- state["step"] = step + 1
530
-
531
- def _size_update(
532
- self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
533
- ) -> None:
534
- """
535
- Called only where p.numel() > 1, this updates the scale of the parameter.
536
- If we imagine: p = underlying_param * scale.exp(), and we are doing
537
- gradient descent on underlying param and on scale, this function does the update
538
- on `scale`.
539
-
540
- Args:
541
- group: dict to look up configuration values
542
- scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
543
- grads w.r.t. the scales.
544
- p: The parameter to update
545
- state: The state-dict of p
546
- """
547
-
548
- param_rms = state["param_rms"]
549
- beta1, beta2 = group["betas"]
550
- size_lr = group["lr"] * group["scalar_lr_scale"]
551
- param_min_rms = group["param_min_rms"]
552
- param_max_rms = group["param_max_rms"]
553
- eps = group["eps"]
554
- step = state["step"]
555
- batch_size = p.shape[0]
556
-
557
- size_update_period = scale_grads.shape[0]
558
- # correct beta2 for the size update period: we will have
559
- # faster decay at this level.
560
- beta2_corr = beta2 ** size_update_period
561
-
562
- scale_exp_avg_sq = state[
563
- "scale_exp_avg_sq"
564
- ] # shape: (batch_size, 1, 1, ..)
565
- scale_exp_avg_sq.mul_(beta2_corr).add_(
566
- (scale_grads ** 2).mean(
567
- dim=0
568
- ), # mean over dim `size_update_period`
569
- alpha=1 - beta2_corr,
570
- ) # shape is (batch_size, 1, 1, ...)
571
-
572
- # The 1st time we reach here is when size_step == 1.
573
- size_step = (step + 1) // size_update_period
574
- bias_correction2 = 1 - beta2_corr ** size_step
575
- # we don't bother with bias_correction1; this will help prevent divergence
576
- # at the start of training.
577
-
578
- denom = scale_exp_avg_sq.sqrt() + eps
579
-
580
- scale_step = (
581
- -size_lr
582
- * (bias_correction2 ** 0.5)
583
- * scale_grads.sum(dim=0)
584
- / denom
585
- )
586
-
587
- is_too_small = param_rms < param_min_rms
588
- is_too_large = param_rms > param_max_rms
589
-
590
- # when the param gets too small, just don't shrink it any further.
591
- scale_step.masked_fill_(is_too_small, 0.0)
592
- # when it gets too large, stop it from getting any larger.
593
- scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
594
- delta = state["delta"]
595
- # the factor of (1-beta1) relates to momentum.
596
- delta.add_(p * scale_step, alpha=(1 - beta1))
597
-
598
- def _step(self, group: dict, p: Tensor, state: dict):
599
- """
600
- This function does the core update of self.step(), in the case where the members of
601
- the batch have more than 1 element.
602
-
603
- Args:
604
- group: A dict which will be used to look up configuration values
605
- p: The parameter to be updated
606
- grad: The grad of p
607
- state: The state-dict corresponding to parameter p
608
-
609
- This function modifies p.
610
- """
611
- grad = p.grad
612
- lr = group["lr"]
613
- beta1, beta2 = group["betas"]
614
- eps = group["eps"]
615
- param_min_rms = group["param_min_rms"]
616
- step = state["step"]
617
-
618
- exp_avg_sq = state["exp_avg_sq"]
619
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
620
-
621
- this_step = state["step"] - (
622
- state["zero_step"] if "zero_step" in state else 0
623
- )
624
- bias_correction2 = 1 - beta2 ** (this_step + 1)
625
- if bias_correction2 < 0.99:
626
- # note: not in-place.
627
- exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
628
-
629
- denom = exp_avg_sq.sqrt()
630
- denom += eps
631
- grad = grad / denom
632
-
633
- alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
634
-
635
- delta = state["delta"]
636
- delta.add_(grad * alpha)
637
- p.add_(delta)
638
-
639
- def _step_scalar(self, group: dict, p: Tensor, state: dict):
640
- """
641
- A simplified form of the core update for scalar tensors, where we cannot get a good
642
- estimate of the parameter rms.
643
- """
644
- beta1, beta2 = group["betas"]
645
- scalar_max = group["scalar_max"]
646
- eps = group["eps"]
647
- lr = group["lr"] * group["scalar_lr_scale"]
648
- grad = p.grad
649
-
650
- exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
651
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
652
-
653
- # bias_correction2 is like in Adam. Don't bother with bias_correction1;
654
- # slower update at the start will help stability anyway.
655
- bias_correction2 = 1 - beta2 ** (state["step"] + 1)
656
- denom = (exp_avg_sq / bias_correction2).sqrt() + eps
657
-
658
- delta = state["delta"]
659
- delta.add_(grad / denom, alpha=-lr * (1 - beta1))
660
- p.clamp_(min=-scalar_max, max=scalar_max)
661
- p.add_(delta)
662
-
663
-
664
- class LRScheduler(object):
665
- """
666
- Base-class for learning rate schedulers where the learning-rate depends on both the
667
- batch and the epoch.
668
- """
669
-
670
- def __init__(self, optimizer: Optimizer, verbose: bool = False):
671
- # Attach optimizer
672
- if not isinstance(optimizer, Optimizer):
673
- raise TypeError(
674
- "{} is not an Optimizer".format(type(optimizer).__name__)
675
- )
676
- self.optimizer = optimizer
677
- self.verbose = verbose
678
-
679
- for group in optimizer.param_groups:
680
- group.setdefault("base_lr", group["lr"])
681
-
682
- self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
683
-
684
- self.epoch = 0
685
- self.batch = 0
686
-
687
- def state_dict(self):
688
- """Returns the state of the scheduler as a :class:`dict`.
689
-
690
- It contains an entry for every variable in self.__dict__ which
691
- is not the optimizer.
692
- """
693
- return {
694
- "base_lrs": self.base_lrs,
695
- "epoch": self.epoch,
696
- "batch": self.batch,
697
- }
698
-
699
- def load_state_dict(self, state_dict):
700
- """Loads the schedulers state.
701
-
702
- Args:
703
- state_dict (dict): scheduler state. Should be an object returned
704
- from a call to :meth:`state_dict`.
705
- """
706
- self.__dict__.update(state_dict)
707
-
708
- def get_last_lr(self) -> List[float]:
709
- """Return last computed learning rate by current scheduler. Will be a list of float."""
710
- return self._last_lr
711
-
712
- def get_lr(self):
713
- # Compute list of learning rates from self.epoch and self.batch and
714
- # self.base_lrs; this must be overloaded by the user.
715
- # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
716
- raise NotImplementedError
717
-
718
- def step_batch(self, batch: Optional[int] = None) -> None:
719
- # Step the batch index, or just set it. If `batch` is specified, it
720
- # must be the batch index from the start of training, i.e. summed over
721
- # all epochs.
722
- # You can call this in any order; if you don't provide 'batch', it should
723
- # of course be called once per batch.
724
- if batch is not None:
725
- self.batch = batch
726
- else:
727
- self.batch = self.batch + 1
728
- self._set_lrs()
729
-
730
- def step_epoch(self, epoch: Optional[int] = None):
731
- # Step the epoch index, or just set it. If you provide the 'epoch' arg,
732
- # you should call this at the start of the epoch; if you don't provide the 'epoch'
733
- # arg, you should call it at the end of the epoch.
734
- if epoch is not None:
735
- self.epoch = epoch
736
- else:
737
- self.epoch = self.epoch + 1
738
- self._set_lrs()
739
-
740
- def _set_lrs(self):
741
- values = self.get_lr()
742
- assert len(values) == len(self.optimizer.param_groups)
743
-
744
- for i, data in enumerate(zip(self.optimizer.param_groups, values)):
745
- param_group, lr = data
746
- param_group["lr"] = lr
747
- self.print_lr(self.verbose, i, lr)
748
- self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
749
-
750
- def print_lr(self, is_verbose, group, lr):
751
- """Display the current learning rate."""
752
- if is_verbose:
753
- logging.info(
754
- f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
755
- f" of group {group} to {lr:.4e}."
756
- )
757
-
758
-
759
- class Eden(LRScheduler):
760
- """
761
- Eden scheduler.
762
- The basic formula (before warmup) is:
763
- lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
764
- (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
765
-     where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
766
- and then stays constant at 1.
767
-
768
-
769
- E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
770
-
771
- Args:
772
- optimizer: the optimizer to change the learning rates on
773
- lr_batches: the number of batches after which we start significantly
774
- decreasing the learning rate, suggest 5000.
775
- lr_epochs: the number of epochs after which we start significantly
776
- decreasing the learning rate, suggest 6 if you plan to do e.g.
777
- 20 to 40 epochs, but may need smaller number if dataset is huge
778
- and you will do few epochs.
779
- """
780
-
781
- def __init__(
782
- self,
783
- optimizer: Optimizer,
784
- lr_batches: Union[int, float],
785
- lr_epochs: Union[int, float],
786
- warmup_batches: Union[int, float] = 500.0,
787
- verbose: bool = False,
788
- ):
789
- super(Eden, self).__init__(optimizer, verbose)
790
- self.lr_batches = lr_batches
791
- self.lr_epochs = lr_epochs
792
- self.warmup_batches = warmup_batches
793
-
794
- def get_lr(self):
795
- factor = (
796
- (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
797
- ) ** -0.25 * (
798
- ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
799
- ** -0.25
800
- )
801
- warmup_factor = (
802
- 1.0
803
- if self.batch >= self.warmup_batches
804
- else 0.5 + 0.5 * (self.batch / self.warmup_batches)
805
- )
806
-
807
- return [x * factor * warmup_factor for x in self.base_lrs]
808
-
809
-
810
- def _test_eden():
811
- m = torch.nn.Linear(100, 100)
812
- optim = ScaledAdam(m.parameters(), lr=0.03)
813
-
814
- scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
815
-
816
- for epoch in range(10):
817
- scheduler.step_epoch(epoch) # sets epoch to `epoch`
818
-
819
- for step in range(20):
820
- x = torch.randn(200, 100).detach()
821
- x.requires_grad = True
822
- y = m(x)
823
- dy = torch.randn(200, 100).detach()
824
- f = (y * dy).sum()
825
- f.backward()
826
-
827
- optim.step()
828
- scheduler.step_batch()
829
- optim.zero_grad()
830
-
831
- logging.info(f"last lr = {scheduler.get_last_lr()}")
832
- logging.info(f"state dict = {scheduler.state_dict()}")
833
-
834
-
835
- # This is included mostly as a baseline for ScaledAdam.
836
- class Eve(Optimizer):
837
- """
838
- Implements Eve algorithm. This is a modified version of AdamW with a special
839
- way of setting the weight-decay / shrinkage-factor, which is designed to make the
840
- rms of the parameters approach a particular target_rms (default: 0.1). This is
841
- for use with networks with 'scaled' versions of modules (see scaling.py), which
842
- will be close to invariant to the absolute scale on the parameter matrix.
843
-
844
- The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
845
- The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
846
- Eve is unpublished so far.
847
-
848
- Arguments:
849
- params (iterable): iterable of parameters to optimize or dicts defining
850
- parameter groups
851
- lr (float, optional): learning rate (default: 1e-3)
852
- betas (Tuple[float, float], optional): coefficients used for computing
853
- running averages of gradient and its square (default: (0.9, 0.999))
854
- eps (float, optional): term added to the denominator to improve
855
- numerical stability (default: 1e-8)
856
- weight_decay (float, optional): weight decay coefficient (default: 3e-4;
857
- this value means that the weight would decay significantly after
858
- about 3k minibatches. Is not multiplied by learning rate, but
859
- is conditional on RMS-value of parameter being > target_rms.
860
- target_rms (float, optional): target root-mean-square value of
861
- parameters, if they fall below this we will stop applying weight decay.
862
-
863
-
864
- .. _Adam: A Method for Stochastic Optimization:
865
- https://arxiv.org/abs/1412.6980
866
- .. _Decoupled Weight Decay Regularization:
867
- https://arxiv.org/abs/1711.05101
868
- .. _On the Convergence of Adam and Beyond:
869
- https://openreview.net/forum?id=ryQu7f-RZ
870
- """
871
-
872
- def __init__(
873
- self,
874
- params,
875
- lr=1e-3,
876
- betas=(0.9, 0.98),
877
- eps=1e-8,
878
- weight_decay=1e-3,
879
- target_rms=0.1,
880
- ):
881
- if not 0.0 <= lr:
882
- raise ValueError("Invalid learning rate: {}".format(lr))
883
- if not 0.0 <= eps:
884
- raise ValueError("Invalid epsilon value: {}".format(eps))
885
- if not 0.0 <= betas[0] < 1.0:
886
- raise ValueError(
887
- "Invalid beta parameter at index 0: {}".format(betas[0])
888
- )
889
- if not 0.0 <= betas[1] < 1.0:
890
- raise ValueError(
891
- "Invalid beta parameter at index 1: {}".format(betas[1])
892
- )
893
- if not 0 <= weight_decay <= 0.1:
894
- raise ValueError(
895
- "Invalid weight_decay value: {}".format(weight_decay)
896
- )
897
- if not 0 < target_rms <= 10.0:
898
- raise ValueError("Invalid target_rms value: {}".format(target_rms))
899
- defaults = dict(
900
- lr=lr,
901
- betas=betas,
902
- eps=eps,
903
- weight_decay=weight_decay,
904
- target_rms=target_rms,
905
- )
906
- super(Eve, self).__init__(params, defaults)
907
-
908
- def __setstate__(self, state):
909
- super(Eve, self).__setstate__(state)
910
-
911
- @torch.no_grad()
912
- def step(self, closure=None):
913
- """Performs a single optimization step.
914
-
915
- Arguments:
916
- closure (callable, optional): A closure that reevaluates the model
917
- and returns the loss.
918
- """
919
- loss = None
920
- if closure is not None:
921
- with torch.enable_grad():
922
- loss = closure()
923
-
924
- for group in self.param_groups:
925
- for p in group["params"]:
926
- if p.grad is None:
927
- continue
928
-
929
- # Perform optimization step
930
- grad = p.grad
931
- if grad.is_sparse:
932
- raise RuntimeError(
933
- "AdamW does not support sparse gradients"
934
- )
935
-
936
- state = self.state[p]
937
-
938
- # State initialization
939
- if len(state) == 0:
940
- state["step"] = 0
941
- # Exponential moving average of gradient values
942
- state["exp_avg"] = torch.zeros_like(
943
- p, memory_format=torch.preserve_format
944
- )
945
- # Exponential moving average of squared gradient values
946
- state["exp_avg_sq"] = torch.zeros_like(
947
- p, memory_format=torch.preserve_format
948
- )
949
-
950
- exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
951
-
952
- beta1, beta2 = group["betas"]
953
-
954
- state["step"] += 1
955
- bias_correction1 = 1 - beta1 ** state["step"]
956
- bias_correction2 = 1 - beta2 ** state["step"]
957
-
958
- # Decay the first and second moment running average coefficient
959
- exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
960
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
961
- denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
962
- group["eps"]
963
- )
964
-
965
- step_size = group["lr"] / bias_correction1
966
- target_rms = group["target_rms"]
967
- weight_decay = group["weight_decay"]
968
-
969
- if p.numel() > 1:
970
- # avoid applying this weight-decay on "scaling factors"
971
- # (which are scalar).
972
- is_above_target_rms = p.norm() > (
973
- target_rms * (p.numel() ** 0.5)
974
- )
975
- p.mul_(1 - (weight_decay * is_above_target_rms))
976
-
977
- p.addcdiv_(exp_avg, denom, value=-step_size)
978
-
979
- # if random.random() < 0.0005:
980
- # step = (exp_avg / denom) * step_size
981
- # logging.info(
982
- # f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
983
- # )
984
-
985
- return loss
986
-
987
-
988
- def _test_scaled_adam(hidden_dim: int):
989
- import timeit
990
-
991
- from scaling import ScaledLinear
992
-
993
- E = 100
994
- B = 4
995
- T = 2
996
- logging.info("in test_eve_cain")
997
- # device = torch.device('cuda')
998
- device = torch.device("cpu")
999
- dtype = torch.float32
1000
-
1001
- fix_random_seed(42)
1002
- # these input_magnitudes and output_magnitudes are to test that
1003
- # Abel is working as we expect and is able to adjust scales of
1004
- # different dims differently.
1005
- input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
1006
- output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
1007
-
1008
- for iter in [1, 0]:
1009
- fix_random_seed(42)
1010
- Linear = torch.nn.Linear if iter == 0 else ScaledLinear
1011
-
1012
- m = torch.nn.Sequential(
1013
- Linear(E, hidden_dim),
1014
- torch.nn.PReLU(),
1015
- Linear(hidden_dim, hidden_dim),
1016
- torch.nn.PReLU(),
1017
- Linear(hidden_dim, E),
1018
- ).to(device)
1019
-
1020
- train_pairs = [
1021
- (
1022
- 100.0
1023
- * torch.randn(B, T, E, device=device, dtype=dtype)
1024
- * input_magnitudes,
1025
- torch.randn(B, T, E, device=device, dtype=dtype)
1026
- * output_magnitudes,
1027
- )
1028
- for _ in range(20)
1029
- ]
1030
-
1031
- if iter == 0:
1032
- optim = Eve(m.parameters(), lr=0.003)
1033
- elif iter == 1:
1034
- optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
1035
- scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
1036
-
1037
- start = timeit.default_timer()
1038
- avg_loss = 0.0
1039
- for epoch in range(180):
1040
- scheduler.step_epoch()
1041
- # if epoch == 100 and iter in [2,3]:
1042
- # optim.reset_speedup() # check it doesn't crash.
1043
-
1044
- # if epoch == 130:
1045
- # opts = diagnostics.TensorDiagnosticOptions(
1046
- # 2 ** 22
1047
- # ) # allow 4 megabytes per sub-module
1048
- # diagnostic = diagnostics.attach_diagnostics(m, opts)
1049
-
1050
- for n, (x, y) in enumerate(train_pairs):
1051
- y_out = m(x)
1052
- loss = ((y_out - y) ** 2).mean() * 100.0
1053
- if epoch == 0 and n == 0:
1054
- avg_loss = loss.item()
1055
- else:
1056
- avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
1057
- if n == 0 and epoch % 5 == 0:
1058
- # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
1059
- # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
1060
- # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
1061
- # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
1062
- # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
1063
- # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
1064
- # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
1065
- # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
1066
- lr = scheduler.get_last_lr()[0]
1067
- logging.info(
1068
- f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
1069
- ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
1070
- loss.log().backward()
1071
- optim.step()
1072
- optim.zero_grad()
1073
- scheduler.step_batch()
1074
-
1075
- # diagnostic.print_diagnostics()
1076
-
1077
- stop = timeit.default_timer()
1078
- logging.info(f"Iter={iter}, Time taken: {stop - start}")
1079
-
1080
- logging.info(f"last lr = {scheduler.get_last_lr()}")
1081
- # logging.info("state dict = ", scheduler.state_dict())
1082
- # logging.info("optim state_dict = ", optim.state_dict())
1083
- logging.info(f"input_magnitudes = {input_magnitudes}")
1084
- logging.info(f"output_magnitudes = {output_magnitudes}")
1085
-
1086
-
1087
- if __name__ == "__main__":
1088
- torch.set_num_threads(1)
1089
- torch.set_num_interop_threads(1)
1090
- logging.getLogger().setLevel(logging.INFO)
1091
- import subprocess
1092
-
1093
- s = subprocess.check_output(
1094
- "git status -uno .; git log -1; git diff HEAD .", shell=True
1095
- )
1096
- logging.info(s)
1097
- import sys
1098
-
1099
- if len(sys.argv) > 1:
1100
- hidden_dim = int(sys.argv[1])
1101
- else:
1102
- hidden_dim = 200
1103
-
1104
- _test_scaled_adam(hidden_dim)
1105
- _test_eden()
 
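The Eden docstring above states the schedule in prose; the sketch below simply re-evaluates that factor numerically to show the shape of the schedule. It mirrors Eden.get_lr() from the deleted file, and the default arguments are the values suggested in the docstring (lr_batches=5000, lr_epochs=6, warmup_batches=500), not values read from any config in this repository.

def eden_factor(batch, epoch, lr_batches=5000, lr_epochs=6, warmup_batches=500):
    # Mirrors Eden.get_lr() above: power-law decay in both the batch and the epoch
    # index, times a warmup that ramps linearly from 0.5 to 1.0 over warmup_batches.
    decay = (((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25
             * ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25)
    warmup = 1.0 if batch >= warmup_batches else 0.5 + 0.5 * batch / warmup_batches
    return decay * warmup

# With the ScaledAdam default lr of 0.03 defined above:
#   0.03 * eden_factor(0, 0)       ~= 0.0150  (start of warmup)
#   0.03 * eden_factor(5000, 2)    ~= 0.0246  (decay starting to bite)
#   0.03 * eden_factor(50000, 20)  ~= 0.0051  (late in training)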
modules/scaling.py DELETED
@@ -1,1401 +0,0 @@
1
- # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
- #
3
- # See ../../../../LICENSE for clarification regarding multiple authors
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
-
18
- import collections
19
- import logging
20
- import random
21
- import math
22
- from functools import reduce
23
- from itertools import repeat
24
- from typing import Optional, Tuple, Union
25
-
26
- import torch
27
- import torch.nn as nn
28
- import torch.nn.functional as F
29
- from torch import Tensor
30
- from torch.nn import Embedding as ScaledEmbedding
31
-
32
- from utils import Transpose
33
-
34
-
35
- class ActivationBalancerFunction(torch.autograd.Function):
36
- @staticmethod
37
- def forward(
38
- ctx,
39
- x: Tensor,
40
- scale_factor: Tensor,
41
- sign_factor: Optional[Tensor],
42
- channel_dim: int,
43
- ) -> Tensor:
44
- if channel_dim < 0:
45
- channel_dim += x.ndim
46
- ctx.channel_dim = channel_dim
47
- xgt0 = x > 0
48
- if sign_factor is None:
49
- ctx.save_for_backward(xgt0, scale_factor)
50
- else:
51
- ctx.save_for_backward(xgt0, scale_factor, sign_factor)
52
- return x
53
-
54
- @staticmethod
55
- def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
56
- if len(ctx.saved_tensors) == 3:
57
- xgt0, scale_factor, sign_factor = ctx.saved_tensors
58
- for _ in range(ctx.channel_dim, x_grad.ndim - 1):
59
- scale_factor = scale_factor.unsqueeze(-1)
60
- sign_factor = sign_factor.unsqueeze(-1)
61
- factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
62
- else:
63
- xgt0, scale_factor = ctx.saved_tensors
64
- for _ in range(ctx.channel_dim, x_grad.ndim - 1):
65
- scale_factor = scale_factor.unsqueeze(-1)
66
- factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
67
- neg_delta_grad = x_grad.abs() * factor
68
- return (
69
- x_grad - neg_delta_grad,
70
- None,
71
- None,
72
- None,
73
- )
74
-
75
-
76
- def _compute_scale_factor(
77
- x: Tensor,
78
- channel_dim: int,
79
- min_abs: float,
80
- max_abs: float,
81
- gain_factor: float,
82
- max_factor: float,
83
- ) -> Tensor:
84
- if channel_dim < 0:
85
- channel_dim += x.ndim
86
- sum_dims = [d for d in range(x.ndim) if d != channel_dim]
87
- x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
88
-
89
- if min_abs == 0.0:
90
- below_threshold = 0.0
91
- else:
92
- # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
93
- # x_abs_mean < min_abs.
94
- below_threshold = (
95
- (min_abs - x_abs_mean) * (gain_factor / min_abs)
96
- ).clamp(min=0, max=max_factor)
97
-
98
- above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
99
- min=0, max=max_factor
100
- )
101
-
102
- return below_threshold - above_threshold
103
-
104
-
105
- def _compute_sign_factor(
106
- x: Tensor,
107
- channel_dim: int,
108
- min_positive: float,
109
- max_positive: float,
110
- gain_factor: float,
111
- max_factor: float,
112
- ) -> Tensor:
113
- if channel_dim < 0:
114
- channel_dim += x.ndim
115
- sum_dims = [d for d in range(x.ndim) if d != channel_dim]
116
- proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
117
- if min_positive == 0.0:
118
- factor1 = 0.0
119
- else:
120
- # 0 if proportion_positive >= min_positive, else can be
121
- # as large as max_factor.
122
- factor1 = (
123
- (min_positive - proportion_positive) * (gain_factor / min_positive)
124
- ).clamp_(min=0, max=max_factor)
125
-
126
- if max_positive == 1.0:
127
- factor2 = 0.0
128
- else:
129
- # 0 if self.proportion_positive <= max_positive, else can be
130
- # as large as -max_factor.
131
- factor2 = (
132
- (proportion_positive - max_positive)
133
- * (gain_factor / (1.0 - max_positive))
134
- ).clamp_(min=0, max=max_factor)
135
- sign_factor = factor1 - factor2
136
- # require min_positive != 0 or max_positive != 1:
137
- assert not isinstance(sign_factor, float)
138
- return sign_factor
139
-
140
-
141
- class ActivationScaleBalancerFunction(torch.autograd.Function):
142
- """
143
- This object is used in class ActivationBalancer when the user specified
144
- min_positive=0, max_positive=1, so there are no constraints on the signs
145
- of the activations and only the absolute value has a constraint.
146
- """
147
-
148
- @staticmethod
149
- def forward(
150
- ctx,
151
- x: Tensor,
152
- sign_factor: Tensor,
153
- scale_factor: Tensor,
154
- channel_dim: int,
155
- ) -> Tensor:
156
- if channel_dim < 0:
157
- channel_dim += x.ndim
158
- ctx.channel_dim = channel_dim
159
- xgt0 = x > 0
160
- ctx.save_for_backward(xgt0, sign_factor, scale_factor)
161
- return x
162
-
163
- @staticmethod
164
- def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
165
- xgt0, sign_factor, scale_factor = ctx.saved_tensors
166
- for _ in range(ctx.channel_dim, x_grad.ndim - 1):
167
- sign_factor = sign_factor.unsqueeze(-1)
168
- scale_factor = scale_factor.unsqueeze(-1)
169
-
170
- factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
171
- neg_delta_grad = x_grad.abs() * factor
172
- return (
173
- x_grad - neg_delta_grad,
174
- None,
175
- None,
176
- None,
177
- )
178
-
179
-
180
- class RandomClampFunction(torch.autograd.Function):
181
- @staticmethod
182
- def forward(
183
- ctx,
184
- x: Tensor,
185
- min: Optional[float],
186
- max: Optional[float],
187
- prob: float,
188
- reflect: float,
189
- ) -> Tensor:
190
- x_clamped = torch.clamp(x, min=min, max=max)
191
- mask = torch.rand_like(x) < prob
192
- ans = torch.where(mask, x_clamped, x)
193
- if x.requires_grad:
194
- ctx.save_for_backward(ans == x)
195
- ctx.reflect = reflect
196
- if reflect != 0.0:
197
- ans = ans * (1.0 + reflect) - (x * reflect)
198
- return ans
199
-
200
- @staticmethod
201
- def backward(
202
- ctx, ans_grad: Tensor
203
- ) -> Tuple[Tensor, None, None, None, None]:
204
- (is_same,) = ctx.saved_tensors
205
- x_grad = ans_grad * is_same.to(ans_grad.dtype)
206
- reflect = ctx.reflect
207
- if reflect != 0.0:
208
- x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
209
- return x_grad, None, None, None, None
210
-
211
-
212
- def random_clamp(
213
- x: Tensor,
214
- min: Optional[float] = None,
215
- max: Optional[float] = None,
216
- prob: float = 0.5,
217
- reflect: float = 0.0,
218
- ):
219
- return RandomClampFunction.apply(x, min, max, prob, reflect)
220
-
221
-
222
- def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
223
- """
224
- A randomized way of casting a floating point value to half precision.
225
- """
226
- if x.dtype == torch.float16:
227
- return x
228
- x_abs = x.abs()
229
- is_too_small = x_abs < min_abs
230
- # for elements where is_too_small is true, random_val will contain +-min_abs with
231
- # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
232
- # for those elements].
233
- random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
234
- return torch.where(is_too_small, random_val, x).to(torch.float16)
235
-
236
-
237
- class RandomGradFunction(torch.autograd.Function):
238
- """
239
- Does nothing in forward pass; in backward pass, gets rid of very small grads using
240
- randomized approach that preserves expectations (intended to reduce roundoff).
241
- """
242
-
243
- @staticmethod
244
- def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
245
- ctx.min_abs = min_abs
246
- return x
247
-
248
- @staticmethod
249
- def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
250
- if ans_grad.dtype == torch.float16:
251
- return (
252
- random_cast_to_half(
253
- ans_grad.to(torch.float32), min_abs=ctx.min_abs
254
- ),
255
- None,
256
- )
257
- else:
258
- return ans_grad, None
259
-
260
-
261
- class RandomGrad(torch.nn.Module):
262
- """
263
- Gets rid of very small gradients using an expectation-preserving method, intended to increase
264
- accuracy of training when using amp (automatic mixed precision)
265
- """
266
-
267
- def __init__(self, min_abs: float = 5.0e-06):
268
- super(RandomGrad, self).__init__()
269
- self.min_abs = min_abs
270
-
271
- def forward(self, x: Tensor):
272
- if (
273
- torch.jit.is_scripting()
274
- or not self.training
275
- or torch.jit.is_tracing()
276
- ):
277
- return x
278
- else:
279
- return RandomGradFunction.apply(x, self.min_abs)
280
-
281
-
282
- class SoftmaxFunction(torch.autograd.Function):
283
- """
284
- Tries to handle half-precision derivatives in a randomized way that should
285
- be more accurate for training than the default behavior.
286
- """
287
-
288
- @staticmethod
289
- def forward(ctx, x: Tensor, dim: int):
290
- ans = x.softmax(dim=dim)
291
- # if x dtype is float16, x.softmax() returns a float32 because
292
- # (presumably) that op does not support float16, and autocast
293
- # is enabled.
294
- if torch.is_autocast_enabled():
295
- ans = ans.to(torch.float16)
296
- ctx.save_for_backward(ans)
297
- ctx.x_dtype = x.dtype
298
- ctx.dim = dim
299
- return ans
300
-
301
- @staticmethod
302
- def backward(ctx, ans_grad: Tensor):
303
- (ans,) = ctx.saved_tensors
304
- with torch.cuda.amp.autocast(enabled=False):
305
- ans_grad = ans_grad.to(torch.float32)
306
- ans = ans.to(torch.float32)
307
- x_grad = ans_grad * ans
308
- x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
309
- return x_grad, None
310
-
311
-
312
- def softmax(x: Tensor, dim: int):
313
- if torch.jit.is_scripting() or torch.jit.is_tracing():
314
- return x.softmax(dim)
315
-
316
- return SoftmaxFunction.apply(x, dim)
317
-
318
-
319
- class MaxEigLimiterFunction(torch.autograd.Function):
320
- @staticmethod
321
- def forward(
322
- ctx,
323
- x: Tensor,
324
- coeffs: Tensor,
325
- direction: Tensor,
326
- channel_dim: int,
327
- grad_scale: float,
328
- ) -> Tensor:
329
- ctx.channel_dim = channel_dim
330
- ctx.grad_scale = grad_scale
331
- ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
332
- return x
333
-
334
- @staticmethod
335
- def backward(ctx, x_grad, *args):
336
- with torch.enable_grad():
337
- (x_orig, coeffs, new_direction) = ctx.saved_tensors
338
- x_orig.requires_grad = True
339
- num_channels = x_orig.shape[ctx.channel_dim]
340
- x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
341
- new_direction.requires_grad = False
342
- x = x - x.mean(dim=0)
343
- x_var = (x ** 2).mean()
344
- x_residual = x - coeffs * new_direction
345
- x_residual_var = (x_residual ** 2).mean()
346
- # `variance_proportion` is the proportion of the variance accounted for
347
- # by the top eigen-direction. This is to be minimized.
348
- variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
349
- variance_proportion.backward()
350
- x_orig_grad = x_orig.grad
351
- x_extra_grad = (
352
- x_orig.grad
353
- * ctx.grad_scale
354
- * x_grad.norm()
355
- / (x_orig_grad.norm() + 1.0e-20)
356
- )
357
- return x_grad + x_extra_grad.detach(), None, None, None, None
358
-
359
-
360
- class BasicNorm(torch.nn.Module):
361
- """
362
- This is intended to be a simpler, and hopefully cheaper, replacement for
363
- LayerNorm. The observation this is based on, is that Transformer-type
364
- networks, especially with pre-norm, sometimes seem to set one of the
365
- feature dimensions to a large constant value (e.g. 50), which "defeats"
366
- the LayerNorm because the output magnitude is then not strongly dependent
367
- on the other (useful) features. Presumably the weight and bias of the
368
- LayerNorm are required to allow it to do this.
369
-
370
- So the idea is to introduce this large constant value as an explicit
371
- parameter, that takes the role of the "eps" in LayerNorm, so the network
372
- doesn't have to do this trick. We make the "eps" learnable.
373
-
374
- Args:
375
- num_channels: the number of channels, e.g. 512.
376
- channel_dim: the axis/dimension corresponding to the channel,
377
- interprted as an offset from the input's ndim if negative.
378
- this is NOT the num_channels; it should typically be one of
379
- {-2, -1, 0, 1, 2, 3}.
380
- eps: the initial "epsilon" that we add as ballast in:
381
- scale = ((input_vec**2).mean() + epsilon)**-0.5
382
- Note: our epsilon is actually large, but we keep the name
383
- to indicate the connection with conventional LayerNorm.
384
- learn_eps: if true, we learn epsilon; if false, we keep it
385
- at the initial value.
386
- eps_min: float
387
- eps_max: float
388
- """
389
-
390
- def __init__(
391
- self,
392
- num_channels: int,
393
- channel_dim: int = -1, # CAUTION: see documentation.
394
- eps: float = 0.25,
395
- learn_eps: bool = True,
396
- eps_min: float = -3.0,
397
- eps_max: float = 3.0,
398
- ) -> None:
399
- super(BasicNorm, self).__init__()
400
- self.num_channels = num_channels
401
- self.channel_dim = channel_dim
402
- if learn_eps:
403
- self.eps = nn.Parameter(torch.tensor(eps).log().detach())
404
- else:
405
- self.register_buffer("eps", torch.tensor(eps).log().detach())
406
- self.eps_min = eps_min
407
- self.eps_max = eps_max
408
-
409
- def forward(self, x: Tensor) -> Tensor:
410
- assert x.shape[self.channel_dim] == self.num_channels
411
- eps = self.eps
412
- if self.training and random.random() < 0.25:
413
- # with probability 0.25, in training mode, clamp eps between the min
414
- # and max; this will encourage it to learn parameters within the
415
- # allowed range by making parameters that are outside the allowed
416
- # range noisy.
417
-
418
- # gradients to allow the parameter to get back into the allowed
419
- # region if it happens to exit it.
420
- eps = eps.clamp(min=self.eps_min, max=self.eps_max)
421
- scales = (
422
- torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
423
- ) ** -0.5
424
- return x * scales
425
-
426
-
427
- def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
428
- """
429
- Behaves like a constructor of a modified version of nn.Linear
430
- that gives an easy way to set the default initial parameter scale.
431
-
432
- Args:
433
- Accepts the standard args and kwargs that nn.Linear accepts
434
- e.g. in_features, out_features, bias=False.
435
-
436
- initial_scale: you can override this if you want to increase
437
- or decrease the initial magnitude of the module's output
438
- (affects the initialization of weight_scale and bias_scale).
439
- Another option, if you want to do something like this, is
440
- to re-initialize the parameters.
441
- """
442
- ans = nn.Linear(*args, **kwargs)
443
- with torch.no_grad():
444
- ans.weight[:] *= initial_scale
445
- if ans.bias is not None:
446
- torch.nn.init.uniform_(
447
- ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
448
- )
449
- return ans
450
-
451
-
452
- def ScaledConv1d(
453
- *args,
454
- initial_scale: float = 1.0,
455
- kernel_size: int = 3,
456
- padding: str = "same",
457
- **kwargs,
458
- ) -> nn.Conv1d:
459
- """
460
- Behaves like a constructor of a modified version of nn.Conv1d
461
- that gives an easy way to set the default initial parameter scale.
462
-
463
- Args:
464
- Accepts the standard args and kwargs that nn.Conv1d accepts
465
- e.g. in_channels, out_channels, bias=False.
466
-
467
- initial_scale: you can override this if you want to increase
468
- or decrease the initial magnitude of the module's output
469
- (affects the initialization of weight_scale and bias_scale).
470
- Another option, if you want to do something like this, is
471
- to re-initialize the parameters.
472
- """
473
- ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
474
- with torch.no_grad():
475
- ans.weight[:] *= initial_scale
476
- if ans.bias is not None:
477
- torch.nn.init.uniform_(
478
- ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
479
- )
480
- return ans
481
-
482
-
483
- def TransposeScaledConv1d(
484
- *args,
485
- initial_scale: float = 1.0,
486
- kernel_size: int = 3,
487
- padding: str = "same",
488
- **kwargs,
489
- ) -> nn.Sequential:
490
- """
491
- Transpose -> ScaledConv1d
492
- """
493
- return nn.Sequential(
494
- Transpose(),
495
- ScaledConv1d(
496
- *args,
497
- initial_scale=initial_scale,
498
- kernel_size=kernel_size,
499
- padding=padding,
500
- **kwargs,
501
- ),
502
- )
503
-
504
-
505
- def ScaledConv1dTranspose(
506
- *args,
507
- initial_scale: float = 1.0,
508
- kernel_size: int = 3,
509
- padding: str = "same",
510
- **kwargs,
511
- ) -> nn.Sequential:
512
- """
513
- ScaledConv1d -> Transpose
514
- """
515
- return nn.Sequential(
516
- ScaledConv1d(
517
- *args,
518
- initial_scale=initial_scale,
519
- kernel_size=kernel_size,
520
- padding=padding,
521
- **kwargs,
522
- ),
523
- Transpose(),
524
- )
525
-
526
-
527
- def TransposeConv1d(
528
- *args, kernel_size: int = 3, padding: str = "same", **kwargs
529
- ) -> nn.Sequential:
530
- """
531
- Transpose -> Conv1d
532
- """
533
- return nn.Sequential(
534
- Transpose(),
535
- nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
536
- )
537
-
538
-
539
- def Conv1dTranspose(
540
- *args, kernel_size: int = 3, padding: str = "same", **kwargs
541
- ) -> nn.Sequential:
542
- """
543
- Conv1d -> Transpose
544
- """
545
- return nn.Sequential(
546
- nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
547
- Transpose(),
548
- )
549
-
550
-
551
- class SRLinear(nn.Linear):
552
- """https://arxiv.org/abs/2303.06296
553
- Stabilizing Transformer Training by Preventing Attention Entropy Collapse
554
- """
555
-
556
- def __init__(self, in_features, out_features, bias=True, **kwargs):
557
- super().__init__(in_features, out_features, bias=bias, **kwargs)
558
- self.register_buffer(
559
- "u", nn.functional.normalize(torch.randn(in_features), dim=0)
560
- )
561
- with torch.no_grad():
562
- sigma = self.get_sigma()
563
- self.register_buffer("spectral_norm", sigma)
564
- self.sigma = nn.Parameter(torch.ones(1))
565
-
566
- def get_sigma(self):
567
- with torch.no_grad():
568
- u = self.u
569
- v = self.weight.mv(u)
570
- v = nn.functional.normalize(v, dim=0)
571
- u = self.weight.T.mv(v)
572
- u = nn.functional.normalize(u, dim=0)
573
- self.u.data.copy_(u)
574
- return torch.einsum("c,cd,d->", v, self.weight, u)
575
-
576
- def get_weight(self):
577
- sigma = self.get_sigma()
578
- if self.training:
579
- self.spectral_norm.data.copy_(sigma)
580
- weight = (self.sigma / sigma) * self.weight
581
- return weight
582
-
583
- def forward(self, x):
584
- return nn.functional.linear(x, self.get_weight(), self.bias)
585
-
586
-
587
- class SRConv1d(SRLinear):
588
- def __init__(
589
- self,
590
- in_features,
591
- out_features,
592
- kernel_size,
593
- stride: int = 1,
594
- padding: str = "same",
595
- bias: bool = True,
596
- **kwargs,
597
- ):
598
- in_features = in_features * kernel_size
599
- super().__init__(in_features, out_features, bias=bias, **kwargs)
600
- nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
601
- self.kernel_size = kernel_size
602
- self.stride = stride
603
- self.padding = padding
604
-
605
- def forward(self, x):
606
- in_features = self.in_features // self.kernel_size
607
- weight = self.get_weight().view(
608
- self.out_features, in_features, self.kernel_size
609
- )
610
- return nn.functional.conv1d(
611
- x, weight, bias=self.bias, stride=self.stride, padding=self.padding
612
- )
613
-
614
-
615
- def TransposeSRConv1d(
616
- *args, kernel_size: int = 3, padding: str = "same", **kwargs
617
- ) -> nn.Sequential:
618
- """
619
- Transpose -> SRConv1d
620
- """
621
- return nn.Sequential(
622
- Transpose(),
623
- SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
624
- )
625
-
626
-
627
- def SRConv1dTranspose(
628
- *args, kernel_size: int = 3, padding: str = "same", **kwargs
629
- ) -> nn.Sequential:
630
- """
631
- SRConv1d -> Transpose
632
- """
633
- return nn.Sequential(
634
- SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
635
- Transpose(),
636
- )
637
-
638
-
639
- class ActivationBalancer(torch.nn.Module):
640
- """
641
- Modifies the backpropped derivatives of a function to try to encourage, for
642
- each channel, that it is positive at least a proportion `threshold` of the
643
- time. It does this by multiplying negative derivative values by up to
644
- (1+max_factor), and positive derivative values by up to (1-max_factor),
645
- interpolated from 1 at the threshold to those extremal values when none
646
- of the inputs are positive.
647
-
648
- Args:
649
- num_channels: the number of channels
650
- channel_dim: the dimension/axis corresponding to the channel, e.g.
651
- -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
652
- min_positive: the minimum, per channel, of the proportion of the time
653
- that (x > 0), below which we start to modify the derivatives.
654
- max_positive: the maximum, per channel, of the proportion of the time
655
- that (x > 0), above which we start to modify the derivatives.
656
- max_factor: the maximum factor by which we modify the derivatives for
657
- either the sign constraint or the magnitude constraint;
658
- e.g. with max_factor=0.02, the derivatives would be multiplied by
659
- values in the range [0.98..1.02].
660
- sign_gain_factor: determines the 'gain' with which we increase the
661
- change in gradient once the constraints on min_positive and max_positive
662
- are violated.
663
- scale_gain_factor: determines the 'gain' with which we increase the
664
- change in gradient once the constraints on min_abs and max_abs
665
- are violated.
666
- min_abs: the minimum average-absolute-value difference from the mean
667
- value per channel, which we allow, before we start to modify
668
- the derivatives to prevent this.
669
- max_abs: the maximum average-absolute-value difference from the mean
670
- value per channel, which we allow, before we start to modify
671
- the derivatives to prevent this.
672
- min_prob: determines the minimum probability with which we modify the
673
- gradients for the {min,max}_positive and {min,max}_abs constraints,
674
- on each forward(). This is done randomly to prevent all layers
675
- from doing it at the same time. Early in training we may use
676
- higher probabilities than this; it will decay to this value.
677
- """
678
-
679
- def __init__(
680
- self,
681
- num_channels: int,
682
- channel_dim: int,
683
- min_positive: float = 0.05,
684
- max_positive: float = 0.95,
685
- max_factor: float = 0.04,
686
- sign_gain_factor: float = 0.01,
687
- scale_gain_factor: float = 0.02,
688
- min_abs: float = 0.2,
689
- max_abs: float = 100.0,
690
- min_prob: float = 0.1,
691
- ):
692
- super(ActivationBalancer, self).__init__()
693
- self.num_channels = num_channels
694
- self.channel_dim = channel_dim
695
- self.min_positive = min_positive
696
- self.max_positive = max_positive
697
- self.max_factor = max_factor
698
- self.min_abs = min_abs
699
- self.max_abs = max_abs
700
- self.min_prob = min_prob
701
- self.sign_gain_factor = sign_gain_factor
702
- self.scale_gain_factor = scale_gain_factor
703
-
704
- # count measures how many times the forward() function has been called.
705
- # We occasionally sync this to a tensor called `count`, that exists to
706
- # make sure it is synced to disk when we load and save the model.
707
- self.cpu_count = 0
708
- self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
709
-
710
- def forward(self, x: Tensor) -> Tensor:
711
- if (
712
- torch.jit.is_scripting()
713
- or not x.requires_grad
714
- or torch.jit.is_tracing()
715
- ):
716
- return _no_op(x)
717
-
718
- count = self.cpu_count
719
- self.cpu_count += 1
720
-
721
- if random.random() < 0.01:
722
- # Occasionally sync self.cpu_count with self.count.
723
- # count affects the decay of 'prob'. don't do this on every iter,
724
- # because syncing with the GPU is slow.
725
- self.cpu_count = max(self.cpu_count, self.count.item())
726
- self.count.fill_(self.cpu_count)
727
-
728
- # the prob of doing some work exponentially decreases from 0.5 till it hits
729
- # a floor at min_prob (==0.1, by default)
730
- prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
731
-
732
- if random.random() < prob:
733
- sign_gain_factor = 0.5
734
- if self.min_positive != 0.0 or self.max_positive != 1.0:
735
- sign_factor = _compute_sign_factor(
736
- x,
737
- self.channel_dim,
738
- self.min_positive,
739
- self.max_positive,
740
- gain_factor=self.sign_gain_factor / prob,
741
- max_factor=self.max_factor,
742
- )
743
- else:
744
- sign_factor = None
745
-
746
- scale_factor = _compute_scale_factor(
747
- x.detach(),
748
- self.channel_dim,
749
- min_abs=self.min_abs,
750
- max_abs=self.max_abs,
751
- gain_factor=self.scale_gain_factor / prob,
752
- max_factor=self.max_factor,
753
- )
754
- return ActivationBalancerFunction.apply(
755
- x,
756
- scale_factor,
757
- sign_factor,
758
- self.channel_dim,
759
- )
760
- else:
761
- return _no_op(x)
762
-
763
-
764
- def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
765
- """
766
- Returns x unmodified, but in backprop will put a penalty for the excess of
767
- the absolute values of elements of x over the limit "limit". E.g. if
768
- limit == 10.0, then if x has any values over 10 it will get a penalty.
769
-
770
- Caution: the value of this penalty will be affected by grad scaling used
771
- in automatic mixed precision training. For the purposes we use this for,
772
- it shouldn't really matter, or may even be helpful; we just use this
773
- to disallow really implausible values of scores to be given to softmax.
774
- """
775
- x_sign = x.sign()
776
- over_limit = (x.abs() - limit) > 0
777
- # The following is a memory efficient way to penalize the absolute values of
778
- # x that's over the limit. (The memory efficiency comes when you think
779
- # about which items torch needs to cache for the autograd, and which ones it
780
- # can throw away). The numerical value of aux_loss as computed here will
781
- # actually be larger than it should be, by limit * over_limit.sum(), but it
782
- # has the same derivative as the real aux_loss which is penalty * (x.abs() -
783
- # limit).relu().
784
- aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
785
- # note: we don't do sum() here on aux_loss, but it's as if we had done
786
- # sum() due to how with_loss() works.
787
- x = with_loss(x, aux_loss)
788
- # you must use x for something, or this will be ineffective.
789
- return x
790
-
791
-
792
- def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
793
- if x.ndim == 2:
794
- return x.diag()
795
- else:
796
- (batch, dim, dim) = x.shape
797
- x = x.reshape(batch, dim * dim)
798
- x = x[:, :: dim + 1]
799
- assert x.shape == (batch, dim)
800
- return x
801
-
802
-
803
- def _whitening_metric(x: Tensor, num_groups: int):
804
- """
805
- Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
806
- the centered feature covariance are the same within each group's covariance matrix
807
- and also between groups.
808
- Args:
809
- x: a Tensor of shape (*, num_channels)
810
- num_groups: the number of groups of channels, a number >=1 that divides num_channels
811
- Returns:
812
- Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
813
- greater than 1.0 otherwise.
814
- """
815
- assert x.dtype != torch.float16
816
- x = x.reshape(-1, x.shape[-1])
817
- (num_frames, num_channels) = x.shape
818
- assert num_channels % num_groups == 0
819
- channels_per_group = num_channels // num_groups
820
- x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
821
- # x now has shape (num_groups, num_frames, channels_per_group)
822
- # subtract the mean so we use the centered, not uncentered, covariance.
823
- # My experience has been that when we "mess with the gradients" like this,
824
- # it's better not do anything that tries to move the mean around, because
825
- # that can easily cause instability.
826
- x = x - x.mean(dim=1, keepdim=True)
827
- # x_covar: (num_groups, channels_per_group, channels_per_group)
828
- x_covar = torch.matmul(x.transpose(1, 2), x)
829
- x_covar_mean_diag = _diag(x_covar).mean()
830
- # the following expression is what we'd get if we took the matrix product
831
- # of each covariance and measured the mean of its trace, i.e.
832
- # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
833
- x_covarsq_mean_diag = (x_covar ** 2).sum() / (
834
- num_groups * channels_per_group
835
- )
836
- # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
837
- metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
838
- return metric
839
-
840
-
841
- class WhiteningPenaltyFunction(torch.autograd.Function):
842
- @staticmethod
843
- def forward(
844
- ctx,
845
- x: Tensor,
846
- num_groups: int,
847
- whitening_limit: float,
848
- grad_scale: float,
849
- ) -> Tensor:
850
- ctx.save_for_backward(x)
851
- ctx.num_groups = num_groups
852
- ctx.whitening_limit = whitening_limit
853
- ctx.grad_scale = grad_scale
854
- return x
855
-
856
- @staticmethod
857
- def backward(ctx, x_grad: Tensor):
858
- (x_orig,) = ctx.saved_tensors
859
- with torch.enable_grad():
860
- with torch.cuda.amp.autocast(enabled=False):
861
- x_detached = x_orig.to(torch.float32).detach()
862
- x_detached.requires_grad = True
863
-
864
- metric = _whitening_metric(x_detached, ctx.num_groups)
865
-
866
- if random.random() < 0.005 or __name__ == "__main__":
867
- logging.info(
868
- f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
869
- f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
870
- )
871
-
872
- (metric - ctx.whitening_limit).relu().backward()
873
- penalty_grad = x_detached.grad
874
- scale = ctx.grad_scale * (
875
- x_grad.to(torch.float32).norm()
876
- / (penalty_grad.norm() + 1.0e-20)
877
- )
878
- penalty_grad = penalty_grad * scale
879
- return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
880
-
881
-
882
- class Whiten(nn.Module):
883
- def __init__(
884
- self,
885
- num_groups: int,
886
- whitening_limit: float,
887
- prob: Union[float, Tuple[float, float]],
888
- grad_scale: float,
889
- ):
890
- """
891
- Args:
892
- num_groups: the number of groups to divide the channel dim into before
893
- whitening. We will attempt to make the feature covariance
894
- within each group, after mean subtraction, as "white" as possible,
895
- while having the same trace across all groups.
896
- whitening_limit: a value greater than 1.0, that dictates how much
897
- freedom we have to violate the constraints. 1.0 would mean perfectly
898
- white, with exactly the same trace across groups; larger values
899
- give more freedom. E.g. 2.0.
900
- prob: the probability with which we apply the gradient modification
901
- (also affects the grad scale). May be supplied as a float,
902
- or as a pair (min_prob, max_prob)
903
-
904
- grad_scale: determines the scale on the gradient term from this object,
905
- relative to the rest of the gradient on the attention weights.
906
- E.g. 0.02 (you may want to use smaller values than this if prob is large)
907
- """
908
- super(Whiten, self).__init__()
909
- assert num_groups >= 1
910
- assert whitening_limit >= 1
911
- assert grad_scale >= 0
912
- self.num_groups = num_groups
913
- self.whitening_limit = whitening_limit
914
- if isinstance(prob, float):
915
- assert 0 < prob <= 1
916
- self.prob = prob
917
- else:
918
- (self.min_prob, self.max_prob) = prob
919
- assert 0 < self.min_prob < self.max_prob <= 1
920
- self.prob = self.max_prob
921
-
922
- self.grad_scale = grad_scale
923
-
924
- def forward(self, x: Tensor) -> Tensor:
925
- """
926
- In the forward pass, this function just returns the input unmodified.
927
- In the backward pass, it will modify the gradients to ensure that the
928
- distribution in each group has close to (lambda times I) as the covariance
929
- after mean subtraction, with the same lambda across groups.
930
- For whitening_limit > 1, there will be more freedom to violate this
931
- constraint.
932
-
933
- Args:
934
- x: the input of shape (*, num_channels)
935
-
936
- Returns:
937
- x, unmodified. You should make sure
938
- you use the returned value, or the graph will be freed
939
- and nothing will happen in backprop.
940
- """
941
- if (
942
- not x.requires_grad
943
- or random.random() > self.prob
944
- or self.grad_scale == 0
945
- ):
946
- return _no_op(x)
947
- else:
948
- if hasattr(self, "min_prob") and random.random() < 0.25:
949
- # occasionally switch between min_prob and max_prob, based on whether
950
- # we are above or below the threshold.
951
- if (
952
- _whitening_metric(x.to(torch.float32), self.num_groups)
953
- > self.whitening_limit
954
- ):
955
- # there would be a change to the grad.
956
- self.prob = self.max_prob
957
- else:
958
- self.prob = self.min_prob
959
-
960
- return WhiteningPenaltyFunction.apply(
961
- x, self.num_groups, self.whitening_limit, self.grad_scale
962
- )
963
-
964
-
965
- class WithLoss(torch.autograd.Function):
966
- @staticmethod
967
- def forward(ctx, x: Tensor, y: Tensor):
968
- ctx.y_shape = y.shape
969
- return x
970
-
971
- @staticmethod
972
- def backward(ctx, ans_grad: Tensor):
973
- return ans_grad, torch.ones(
974
- ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
975
- )
976
-
977
-
978
- def with_loss(x, y):
979
- if torch.jit.is_scripting() or torch.jit.is_tracing():
980
- return x
981
- # returns x but adds y.sum() to the loss function.
982
- return WithLoss.apply(x, y)
983
-
984
-
985
- def _no_op(x: Tensor) -> Tensor:
986
- if torch.jit.is_scripting() or torch.jit.is_tracing():
987
- return x
988
- else:
989
- # a no-op function that will have a node in the autograd graph,
990
- # to avoid certain bugs relating to backward hooks
991
- return x.chunk(1, dim=-1)[0]
992
-
993
-
994
- class Identity(torch.nn.Module):
995
- def __init__(self):
996
- super(Identity, self).__init__()
997
-
998
- def forward(self, x):
999
- return _no_op(x)
1000
-
1001
-
1002
- class MaxEig(torch.nn.Module):
1003
- """
1004
- Modifies the backpropped derivatives of a function to try to discourage
1005
- that any given direction in activation space accounts for more than
1006
- a specified proportion of the covariance (e.g. 0.2).
1007
-
1008
-
1009
- Args:
1010
- num_channels: the number of channels
1011
- channel_dim: the dimension/axis corresponding to the channel, e.g.
1012
- -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
1013
- max_var_per_eig: the maximum proportion of the variance of the
1014
- features/channels, after mean subtraction, that can come from
1015
- any given eigenvalue.
1016
- min_prob: the minimum probability with which we apply this during any invocation
1017
- of forward(), assuming last time we applied the constraint it was
1018
- not active; supplied for speed.
1019
- scale: determines the scale with which we modify the gradients, relative
1020
- to the existing / unmodified gradients
1021
- """
1022
-
1023
- def __init__(
1024
- self,
1025
- num_channels: int,
1026
- channel_dim: int,
1027
- max_var_per_eig: float = 0.2,
1028
- min_prob: float = 0.01,
1029
- scale: float = 0.01,
1030
- ):
1031
- super(MaxEig, self).__init__()
1032
- self.num_channels = num_channels
1033
- self.channel_dim = channel_dim
1034
- self.scale = scale
1035
- assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
1036
- self.max_var_per_eig = max_var_per_eig
1037
-
1038
- # we figure out the dominant direction using the power method: starting with
1039
- # a random vector, keep multiplying by the covariance and renormalizing.
1040
- with torch.no_grad():
1041
- # arbitrary.. would use randn() but want to leave the rest of the model's
1042
- # random parameters unchanged for comparison
1043
- direction = torch.arange(num_channels).to(torch.float)
1044
- direction = direction / direction.norm()
1045
- self.register_buffer("max_eig_direction", direction)
1046
-
1047
- self.min_prob = min_prob
1048
- # cur_prob is the current probability we'll use to apply the ActivationBalancer.
1049
- # We'll regress this towards prob, each time we try to apply it and it is not
1050
- # active.
1051
- self.cur_prob = 1.0
1052
-
1053
- def forward(self, x: Tensor) -> Tensor:
1054
- if (
1055
- torch.jit.is_scripting()
1056
- or self.max_var_per_eig <= 0
1057
- or random.random() > self.cur_prob
1058
- or torch.jit.is_tracing()
1059
- ):
1060
- return _no_op(x)
1061
-
1062
- with torch.cuda.amp.autocast(enabled=False):
1063
- eps = 1.0e-20
1064
- orig_x = x
1065
- x = x.to(torch.float32)
1066
- with torch.no_grad():
1067
- x = x.transpose(self.channel_dim, -1).reshape(
1068
- -1, self.num_channels
1069
- )
1070
- x = x - x.mean(dim=0)
1071
- new_direction, coeffs = self._find_direction_coeffs(
1072
- x, self.max_eig_direction
1073
- )
1074
- x_var = (x ** 2).mean()
1075
- x_residual = x - coeffs * new_direction
1076
- x_residual_var = (x_residual ** 2).mean()
1077
-
1078
- # `variance_proportion` is the proportion of the variance accounted for
1079
- # by the top eigen-direction.
1080
- variance_proportion = (x_var - x_residual_var) / (
1081
- x_var + 1.0e-20
1082
- )
1083
-
1084
- # ensure new direction is nonzero even if x == 0, by including `direction`.
1085
- self._set_direction(
1086
- 0.1 * self.max_eig_direction + new_direction
1087
- )
1088
-
1089
- if random.random() < 0.01 or __name__ == "__main__":
1090
- logging.info(
1091
- f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
1092
- )
1093
-
1094
- if variance_proportion >= self.max_var_per_eig:
1095
- # The constraint is active. Note, we should quite rarely
1096
- # reach here, only near the beginning of training if we are
1097
- # starting to diverge, should this constraint be active.
1098
- cur_prob = self.cur_prob
1099
- self.cur_prob = (
1100
- 1.0 # next time, do the update with probability 1.0.
1101
- )
1102
- return MaxEigLimiterFunction.apply(
1103
- orig_x, coeffs, new_direction, self.channel_dim, self.scale
1104
- )
1105
- else:
1106
- # let self.cur_prob exponentially approach self.min_prob, as
1107
- # long as the constraint is inactive.
1108
- self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
1109
- return orig_x
1110
-
1111
- def _set_direction(self, direction: Tensor):
1112
- """
1113
- Sets self.max_eig_direction to a normalized version of `direction`
1114
- """
1115
- direction = direction.detach()
1116
- direction = direction / direction.norm()
1117
- direction_sum = direction.sum().item()
1118
- if direction_sum - direction_sum == 0: # no inf/nan
1119
- self.max_eig_direction[:] = direction
1120
- else:
1121
- logging.info(
1122
- f"Warning: sum of direction in MaxEig is {direction_sum}, "
1123
- "num_channels={self.num_channels}, channel_dim={self.channel_dim}"
1124
- )
1125
-
1126
- def _find_direction_coeffs(
1127
- self, x: Tensor, prev_direction: Tensor
1128
- ) -> Tuple[Tensor, Tensor]:
1129
- """
1130
- Figure out (an approximation to) the proportion of the variance of a set of
1131
- feature vectors that can be attributed to the top eigen-direction.
1132
- Args:
1133
- x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
1134
- prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
1135
- of the top eigen-direction, or a random direction if this is the first
1136
- iteration. Does not have to be normalized, but should be nonzero.
1137
-
1138
- Returns: (cur_direction, coeffs), where:
1139
- cur_direction: a Tensor of shape (num_channels,) that is the current
1140
- estimate of the top eigen-direction.
1141
- coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
1142
- approximately minimizes, (x - coeffs * cur_direction).norm()
1143
- """
1144
- (num_frames, num_channels) = x.shape
1145
- assert num_channels > 1 and num_frames > 1
1146
- assert prev_direction.shape == (num_channels,)
1147
- # `coeffs` are the coefficients of `prev_direction` in x; they
1148
- # actually represent the coeffs up to a constant positive factor.
1149
- coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
1150
- cur_direction = (x * coeffs).sum(dim=0) / (
1151
- (coeffs ** 2).sum() + 1.0e-20
1152
- )
1153
- return cur_direction, coeffs
1154
-
1155
-
1156
- class DoubleSwishFunction(torch.autograd.Function):
1157
- """
1158
- double_swish(x) = x * torch.sigmoid(x-1)
1159
- This is a definition, originally motivated by its close numerical
1160
- similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
1161
-
1162
- Memory-efficient derivative computation:
1163
- double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
1164
- double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
1165
- Now, s'(x) = s(x) * (1-s(x)).
1166
- double_swish'(x) = x * s'(x) + s(x).
1167
- = x * s(x) * (1-s(x)) + s(x).
1168
- = double_swish(x) * (1-s(x)) + s(x)
1169
- ... so we just need to remember s(x) but not x itself.
1170
- """
1171
-
1172
- @staticmethod
1173
- def forward(ctx, x: Tensor) -> Tensor:
1174
- requires_grad = x.requires_grad
1175
- x_dtype = x.dtype
1176
- if x.dtype == torch.float16:
1177
- x = x.to(torch.float32)
1178
-
1179
- s = torch.sigmoid(x - 1.0)
1180
- y = x * s
1181
-
1182
- if requires_grad:
1183
- deriv = y * (1 - s) + s
1184
- # notes on derivative of x * sigmoid(x - 1):
1185
- # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
1186
- # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound.
1187
- # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
1188
- # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
1189
- # floors), should be expectation-preserving.
1190
- floor = -0.043637
1191
- ceil = 1.2
1192
- d_scaled = (deriv - floor) * (
1193
- 255.0 / (ceil - floor)
1194
- ) + torch.rand_like(deriv)
1195
- if __name__ == "__main__":
1196
- # for self-testing only.
1197
- assert d_scaled.min() >= 0.0
1198
- assert d_scaled.max() < 256.0
1199
- d_int = d_scaled.to(torch.uint8)
1200
- ctx.save_for_backward(d_int)
1201
- if x.dtype == torch.float16 or torch.is_autocast_enabled():
1202
- y = y.to(torch.float16)
1203
- return y
1204
-
1205
- @staticmethod
1206
- def backward(ctx, y_grad: Tensor) -> Tensor:
1207
- (d,) = ctx.saved_tensors
1208
- # the same constants as used in forward pass.
1209
- floor = -0.043637
1210
- ceil = 1.2
1211
- d = d * ((ceil - floor) / 255.0) + floor
1212
- return y_grad * d
1213
-
1214
-
1215
- class DoubleSwish(torch.nn.Module):
1216
- def forward(self, x: Tensor) -> Tensor:
1217
- """Return double-swish activation function which is an approximation to Swish(Swish(x)),
1218
- that we approximate closely with x * sigmoid(x-1).
1219
- """
1220
- if torch.jit.is_scripting() or torch.jit.is_tracing():
1221
- return x * torch.sigmoid(x - 1.0)
1222
- return DoubleSwishFunction.apply(x)
1223
-
1224
-
1225
- def BalancedDoubleSwish(
1226
- d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
1227
- ) -> nn.Sequential:
1228
- """
1229
- ActivationBalancer -> DoubleSwish
1230
- """
1231
- balancer = ActivationBalancer(
1232
- d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
1233
- )
1234
- return nn.Sequential(
1235
- balancer,
1236
- DoubleSwish(),
1237
- )
1238
-
1239
-
1240
- def _test_max_eig():
1241
- for proportion in [0.1, 0.5, 10.0]:
1242
- logging.info(f"proportion = {proportion}")
1243
- x = torch.randn(100, 128)
1244
- direction = torch.randn(128)
1245
- coeffs = torch.randn(100, 1)
1246
- x += proportion * direction * coeffs
1247
-
1248
- x.requires_grad = True
1249
-
1250
- num_channels = 128
1251
- m = MaxEig(
1252
- num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig
1253
- ) # grad_scale
1254
-
1255
- for _ in range(4):
1256
- y = m(x)
1257
-
1258
- y_grad = torch.randn_like(x)
1259
- y.backward(gradient=y_grad)
1260
-
1261
- if proportion < 0.2:
1262
- assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
1263
- elif proportion > 1.0:
1264
- assert not torch.allclose(x.grad, y_grad)
1265
-
1266
-
1267
- def _test_whiten():
1268
- for proportion in [0.1, 0.5, 10.0]:
1269
- logging.info(f"_test_whiten(): proportion = {proportion}")
1270
- x = torch.randn(100, 128)
1271
- direction = torch.randn(128)
1272
- coeffs = torch.randn(100, 1)
1273
- x += proportion * direction * coeffs
1274
-
1275
- x.requires_grad = True
1276
-
1277
- num_channels = 128
1278
- m = Whiten(
1279
- 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
1280
- ) # grad_scale
1281
-
1282
- for _ in range(4):
1283
- y = m(x)
1284
-
1285
- y_grad = torch.randn_like(x)
1286
- y.backward(gradient=y_grad)
1287
-
1288
- if proportion < 0.2:
1289
- assert torch.allclose(x.grad, y_grad)
1290
- elif proportion > 1.0:
1291
- assert not torch.allclose(x.grad, y_grad)
1292
-
1293
-
1294
- def _test_activation_balancer_sign():
1295
- probs = torch.arange(0, 1, 0.01)
1296
- N = 1000
1297
- x = 1.0 * (
1298
- (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
1299
- )
1300
- x = x.detach()
1301
- x.requires_grad = True
1302
- m = ActivationBalancer(
1303
- probs.numel(),
1304
- channel_dim=0,
1305
- min_positive=0.05,
1306
- max_positive=0.95,
1307
- max_factor=0.2,
1308
- min_abs=0.0,
1309
- )
1310
-
1311
- y_grad = torch.sign(torch.randn(probs.numel(), N))
1312
-
1313
- y = m(x)
1314
- y.backward(gradient=y_grad)
1315
- print("_test_activation_balancer_sign: x = ", x)
1316
- print("_test_activation_balancer_sign: y grad = ", y_grad)
1317
- print("_test_activation_balancer_sign: x grad = ", x.grad)
1318
-
1319
-
1320
- def _test_activation_balancer_magnitude():
1321
- magnitudes = torch.arange(0, 1, 0.01)
1322
- N = 1000
1323
- x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
1324
- -1
1325
- )
1326
- x = x.detach()
1327
- x.requires_grad = True
1328
- m = ActivationBalancer(
1329
- magnitudes.numel(),
1330
- channel_dim=0,
1331
- min_positive=0.0,
1332
- max_positive=1.0,
1333
- max_factor=0.2,
1334
- min_abs=0.2,
1335
- max_abs=0.8,
1336
- min_prob=1.0,
1337
- )
1338
-
1339
- y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
1340
-
1341
- y = m(x)
1342
- y.backward(gradient=y_grad)
1343
- print("_test_activation_balancer_magnitude: x = ", x)
1344
- print("_test_activation_balancer_magnitude: y grad = ", y_grad)
1345
- print("_test_activation_balancer_magnitude: x grad = ", x.grad)
1346
-
1347
-
1348
- def _test_basic_norm():
1349
- num_channels = 128
1350
- m = BasicNorm(num_channels=num_channels, channel_dim=1)
1351
-
1352
- x = torch.randn(500, num_channels)
1353
-
1354
- y = m(x)
1355
-
1356
- assert y.shape == x.shape
1357
- x_rms = (x ** 2).mean().sqrt()
1358
- y_rms = (y ** 2).mean().sqrt()
1359
- print("x rms = ", x_rms)
1360
- print("y rms = ", y_rms)
1361
- assert y_rms < x_rms
1362
- assert y_rms > 0.5 * x_rms
1363
-
1364
-
1365
- def _test_double_swish_deriv():
1366
- x = torch.randn(10, 12, dtype=torch.double) * 3.0
1367
- x.requires_grad = True
1368
- m = DoubleSwish()
1369
-
1370
- tol = (1.2 - (-0.043637)) / 255.0
1371
- torch.autograd.gradcheck(m, x, atol=tol)
1372
-
1373
- # for self-test.
1374
- x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1375
- x.requires_grad = True
1376
- y = m(x)
1377
-
1378
-
1379
- def _test_softmax():
1380
- a = torch.randn(2, 10, dtype=torch.float64)
1381
- b = a.clone()
1382
- a.requires_grad = True
1383
- b.requires_grad = True
1384
- a.softmax(dim=1)[:, 0].sum().backward()
1385
- print("a grad = ", a.grad)
1386
- softmax(b, dim=1)[:, 0].sum().backward()
1387
- print("b grad = ", b.grad)
1388
- assert torch.allclose(a.grad, b.grad)
1389
-
1390
-
1391
- if __name__ == "__main__":
1392
- logging.getLogger().setLevel(logging.INFO)
1393
- torch.set_num_threads(1)
1394
- torch.set_num_interop_threads(1)
1395
- _test_softmax()
1396
- _test_whiten()
1397
- _test_max_eig()
1398
- _test_activation_balancer_sign()
1399
- _test_activation_balancer_magnitude()
1400
- _test_basic_norm()
1401
- _test_double_swish_deriv()
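
A minimal sketch (not taken from this repository) of how the main pieces of scaling.py compose into a feed-forward block. It assumes the repo root is on PYTHONPATH so the file is importable as modules.scaling; the layer sizes are illustrative.

import torch
import torch.nn as nn
from modules.scaling import BalancedDoubleSwish, BasicNorm, DoubleSwish, ScaledLinear

d_model, d_ff = 256, 1024
feed_forward = nn.Sequential(
    ScaledLinear(d_model, d_ff),                      # nn.Linear with rescaled initial weights
    BalancedDoubleSwish(d_ff),                        # ActivationBalancer -> DoubleSwish
    ScaledLinear(d_ff, d_model, initial_scale=0.25),
    BasicNorm(d_model),                               # learnable-eps stand-in for LayerNorm
)

x = torch.randn(8, 100, d_model, requires_grad=True)
y = feed_forward(x)
y.sum().backward()  # ActivationBalancer only modifies gradients, so it acts in the backward pass

# DoubleSwish evaluates x * sigmoid(x - 1); only its derivative is quantized to save memory.
z = torch.randn(4, d_model)
assert torch.allclose(DoubleSwish()(z), z * torch.sigmoid(z - 1.0), atol=1e-6)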
 
modules/scheduler.py DELETED
@@ -1,78 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Copyright 2023 (authors: Feiteng Li)
3
- #
4
- # See ../../../../LICENSE for clarification regarding multiple authors
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
-
18
-
19
- import torch
20
-
21
- from modules.optim import Eden
22
-
23
-
24
- def calc_lr(step, dim_embed, warmup_steps):
25
- return dim_embed ** (-0.5) * min(
26
- step ** (-0.5), step * warmup_steps ** (-1.5)
27
- )
28
-
29
-
30
- class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
31
- def __init__(
32
- self,
33
- base_lr: float,
34
- optimizer: torch.optim.Optimizer,
35
- dim_embed: int,
36
- warmup_steps: int,
37
- last_epoch: int = -1,
38
- verbose: bool = False,
39
- ) -> None:
40
-
41
- self.dim_embed = dim_embed
42
- self.base_lr = base_lr
43
- self.warmup_steps = warmup_steps
44
- self.num_param_groups = len(optimizer.param_groups)
45
-
46
- super().__init__(optimizer, last_epoch, verbose)
47
-
48
- def get_lr(self) -> float:
49
- lr = self.base_lr * calc_lr(
50
- self._step_count, self.dim_embed, self.warmup_steps
51
- )
52
- return [lr] * self.num_param_groups
53
-
54
- def set_step(self, step: int):
55
- self._step_count = step
56
-
57
-
58
- def get_scheduler(params, optimizer):
59
- if params.scheduler_name.lower() == "eden":
60
- scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
61
- elif params.scheduler_name.lower() == "noam":
62
- scheduler = NoamScheduler(
63
- params.base_lr,
64
- optimizer,
65
- params.decoder_dim,
66
- warmup_steps=params.warmup_steps,
67
- )
68
- # scheduler.set_step(params.start_batch or params.batch_idx_train)
69
- elif params.scheduler_name.lower() == "cosine":
70
- scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
71
- optimizer,
72
- params.warmup_steps,
73
- eta_min=params.base_lr,
74
- )
75
- else:
76
- raise NotImplementedError(f"{params.scheduler_name}")
77
-
78
- return scheduler
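
A minimal usage sketch of NoamScheduler (not part of this file; the model, optimizer, and hyperparameter values are illustrative assumptions). get_scheduler() above would normally choose between this, Eden, and CosineAnnealingLR based on params.scheduler_name.

import torch
from modules.scheduler import NoamScheduler

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)  # lr is overwritten by the scheduler
scheduler = NoamScheduler(base_lr=0.05, optimizer=optimizer, dim_embed=1024, warmup_steps=200)

for step in range(1, 1001):
    loss = model(torch.randn(16, 1024)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # the schedule is base_lr * calc_lr(scheduler._step_count, dim_embed, warmup_steps)
    scheduler.step()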
 
modules/transformer.py DELETED
@@ -1,683 +0,0 @@
1
- import copy
2
- import numbers
3
- from functools import partial
4
- from typing import Any, Callable, List, Optional, Tuple, Union
5
-
6
- import torch
7
- from torch import Tensor, nn
8
- from torch.nn import functional as F
9
-
10
- from .activation import MultiheadAttention
11
- from .scaling import ActivationBalancer, BalancedDoubleSwish
12
- from .scaling import BasicNorm as _BasicNorm
13
-
14
- _shape_t = Union[int, List[int], torch.Size]
15
-
16
-
17
- class LayerNorm(nn.Module):
18
- __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
19
- normalized_shape: Tuple[int, ...]
20
- eps: float
21
- elementwise_affine: bool
22
-
23
- def __init__(
24
- self,
25
- normalized_shape: _shape_t,
26
- eps: float = 1e-5,
27
- elementwise_affine: bool = True,
28
- device=None,
29
- dtype=None,
30
- ) -> None:
31
- factory_kwargs = {"device": device, "dtype": dtype}
32
- super(LayerNorm, self).__init__()
33
- if isinstance(normalized_shape, numbers.Integral):
34
- # mypy error: incompatible types in assignment
35
- normalized_shape = (normalized_shape,) # type: ignore[assignment]
36
- self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
37
- self.eps = eps
38
- self.elementwise_affine = elementwise_affine
39
- if self.elementwise_affine:
40
- self.weight = nn.Parameter(
41
- torch.empty(self.normalized_shape, **factory_kwargs)
42
- )
43
- self.bias = nn.Parameter(
44
- torch.empty(self.normalized_shape, **factory_kwargs)
45
- )
46
- else:
47
- self.register_parameter("weight", None)
48
- self.register_parameter("bias", None)
49
-
50
- self.reset_parameters()
51
-
52
- def reset_parameters(self) -> None:
53
- if self.elementwise_affine:
54
- nn.init.ones_(self.weight)
55
- nn.init.zeros_(self.bias)
56
-
57
- def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
58
- if isinstance(input, tuple):
59
- input, embedding = input
60
- return (
61
- F.layer_norm(
62
- input,
63
- self.normalized_shape,
64
- self.weight,
65
- self.bias,
66
- self.eps,
67
- ),
68
- embedding,
69
- )
70
-
71
- assert embedding is None
72
- return F.layer_norm(
73
- input, self.normalized_shape, self.weight, self.bias, self.eps
74
- )
75
-
76
- def extra_repr(self) -> str:
77
- return (
78
- "{normalized_shape}, eps={eps}, "
79
- "elementwise_affine={elementwise_affine}".format(**self.__dict__)
80
- )
81
-
82
-
83
- class AdaptiveLayerNorm(nn.Module):
84
- r"""Adaptive Layer Normalization"""
85
-
86
- def __init__(self, d_model, norm) -> None:
87
- super(AdaptiveLayerNorm, self).__init__()
88
- self.project_layer = nn.Linear(d_model, 2 * d_model)
89
- self.norm = norm
90
- self.d_model = d_model
91
- self.eps = self.norm.eps
92
-
93
- def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
94
- if isinstance(input, tuple):
95
- input, embedding = input
96
- weight, bias = torch.split(
97
- self.project_layer(embedding),
98
- split_size_or_sections=self.d_model,
99
- dim=-1,
100
- )
101
- return (weight * self.norm(input) + bias, embedding)
102
-
103
- weight, bias = torch.split(
104
- self.project_layer(embedding),
105
- split_size_or_sections=self.d_model,
106
- dim=-1,
107
- )
108
- return weight * self.norm(input) + bias
109
-
110
-
- class BasicNorm(_BasicNorm):
-     def __init__(
-         self,
-         d_model: int,
-         eps: float = 1e-5,
-         device=None,
-         dtype=None,
-     ):
-         super(BasicNorm, self).__init__(d_model, eps=eps)
-
-     def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
-         if isinstance(input, tuple):
-             input, embedding = input
-             return (
-                 super(BasicNorm, self).forward(input),
-                 embedding,
-             )
-
-         assert embedding is None
-         return super(BasicNorm, self).forward(input)
-
-
- class BalancedBasicNorm(nn.Module):
-     def __init__(
-         self,
-         d_model: int,
-         eps: float = 1e-5,
-         device=None,
-         dtype=None,
-     ):
-         super(BalancedBasicNorm, self).__init__()
-         self.balancer = ActivationBalancer(
-             d_model,
-             channel_dim=-1,
-             min_positive=0.45,
-             max_positive=0.55,
-             max_abs=6.0,
-         )
-         self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
-
-     def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
-         if isinstance(input, tuple):
-             input, embedding = input
-             return self.norm((self.balancer(input), embedding))
-
-         assert embedding is None
-         return self.norm(self.balancer(input))
-
-
- class IdentityNorm(nn.Module):
-     def __init__(
-         self,
-         d_model: int,
-         eps: float = 1e-5,
-         device=None,
-         dtype=None,
-     ) -> None:
-         super(IdentityNorm, self).__init__()
-
-     def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
-         if isinstance(input, tuple):
-             return input
-
-         assert embedding is None
-         return input
-
-
- class TransformerEncoderLayer(nn.Module):
-     __constants__ = ["batch_first", "norm_first"]
-
-     def __init__(
-         self,
-         d_model: int,
-         nhead: int,
-         dim_feedforward: int = 2048,
-         dropout: float = 0.1,
-         activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
-         batch_first: bool = False,
-         norm_first: bool = False,
-         device=None,
-         dtype=None,
-         linear1_self_attention_cls: nn.Module = nn.Linear,
-         linear2_self_attention_cls: nn.Module = nn.Linear,
-         linear1_feedforward_cls: nn.Module = nn.Linear,
-         linear2_feedforward_cls: nn.Module = nn.Linear,
-         layer_norm_cls: nn.Module = LayerNorm,
-         layer_norm_eps: float = 1e-5,
-         adaptive_layer_norm=False,
-     ) -> None:
-         factory_kwargs = {"device": device, "dtype": dtype}
-         super(TransformerEncoderLayer, self).__init__()
-         self.self_attn = MultiheadAttention(
-             d_model,
-             nhead,
-             dropout=dropout,
-             batch_first=batch_first,
-             linear1_cls=linear1_self_attention_cls,
-             linear2_cls=linear2_self_attention_cls,
-             **factory_kwargs,
-         )
-
-         # Implementation of Feedforward model
-         self.linear1 = linear1_feedforward_cls(
-             d_model, dim_feedforward, **factory_kwargs
-         )
-         self.dropout = nn.Dropout(dropout)
-         self.linear2 = linear2_feedforward_cls(
-             dim_feedforward, d_model, **factory_kwargs
-         )
-
-         self.norm_first = norm_first
-         self.dropout1 = nn.Dropout(dropout)
-         self.dropout2 = nn.Dropout(dropout)
-
-         # Legacy string support for activation function.
-         if isinstance(activation, str):
-             activation = _get_activation_fn(activation)
-         elif isinstance(activation, partial):
-             activation = activation(d_model)
-         elif activation == BalancedDoubleSwish:
-             activation = BalancedDoubleSwish(d_model)
-
-         # # We can't test self.activation in forward() in TorchScript,
-         # # so stash some information about it instead.
-         # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
-         #     self.activation_relu_or_gelu = 1
-         # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
-         #     self.activation_relu_or_gelu = 2
-         # else:
-         #     self.activation_relu_or_gelu = 0
-         self.activation = activation
-
-         norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
-         if layer_norm_cls == IdentityNorm:
-             norm2 = BalancedBasicNorm(
-                 d_model, eps=layer_norm_eps, **factory_kwargs
-             )
-         else:
-             norm2 = layer_norm_cls(
-                 d_model, eps=layer_norm_eps, **factory_kwargs
-             )
-
-         if adaptive_layer_norm:
-             self.norm1 = AdaptiveLayerNorm(d_model, norm1)
-             self.norm2 = AdaptiveLayerNorm(d_model, norm2)
-         else:
-             self.norm1 = norm1
-             self.norm2 = norm2
-
-     def __setstate__(self, state):
-         super(TransformerEncoderLayer, self).__setstate__(state)
-         if not hasattr(self, "activation"):
-             self.activation = F.relu
-
-     def forward(
-         self,
-         src: Tensor,
-         src_mask: Optional[Tensor] = None,
-         src_key_padding_mask: Optional[Tensor] = None,
-     ) -> Tensor:
-         r"""Pass the input through the encoder layer.
-
-         Args:
-             src: the sequence to the encoder layer (required).
-             src_mask: the mask for the src sequence (optional).
-             src_key_padding_mask: the mask for the src keys per batch (optional).
-
-         Shape:
-             see the docs in Transformer class.
-         """
-         x, stage_embedding = src, None
-         is_src_tuple = False
-         if isinstance(src, tuple):
-             x, stage_embedding = src
-             is_src_tuple = True
-
-         if src_key_padding_mask is not None:
-             _skpm_dtype = src_key_padding_mask.dtype
-             if _skpm_dtype != torch.bool and not torch.is_floating_point(
-                 src_key_padding_mask
-             ):
-                 raise AssertionError(
-                     "only bool and floating types of key_padding_mask are supported"
-                 )
-
-         if self.norm_first:
-             x = x + self._sa_block(
-                 self.norm1(x, stage_embedding),
-                 src_mask,
-                 src_key_padding_mask,
-             )
-             x = x + self._ff_block(self.norm2(x, stage_embedding))
-         else:
-             x = self.norm1(
-                 x + self._sa_block(x, src_mask, src_key_padding_mask),
-                 stage_embedding,
-             )
-             x = self.norm2(x + self._ff_block(x), stage_embedding)
-
-         if is_src_tuple:
-             return (x, stage_embedding)
-         return x
-
-     def infer(
-         self,
-         src: Tensor,
-         src_mask: Optional[Tensor] = None,
-         src_key_padding_mask: Optional[Tensor] = None,
-         past_kv: Optional[Tensor] = None,
-         use_cache: bool = False,
-     ):
-         x, stage_embedding = src, None
-         is_src_tuple = False
-         if isinstance(src, tuple):
-             x, stage_embedding = src
-             is_src_tuple = True
-
-         if src_key_padding_mask is not None:
-             _skpm_dtype = src_key_padding_mask.dtype
-             if _skpm_dtype != torch.bool and not torch.is_floating_point(
-                 src_key_padding_mask
-             ):
-                 raise AssertionError(
-                     "only bool and floating types of key_padding_mask are supported"
-                 )
-
-         if self.norm_first:
-             x_attn_out, kv = self.self_attn.infer(
-                 self.norm1(x, stage_embedding),
-                 attn_mask=src_mask,
-                 key_padding_mask=src_key_padding_mask,
-                 need_weights=False,
-                 past_kv=past_kv,
-                 use_cache=use_cache,
-             )
-             x = x + x_attn_out
-             x = x + self._ff_block(self.norm2(x, stage_embedding))
-
-         if is_src_tuple:
-             return (x, stage_embedding)
-         return (x, kv)
-
-     # self-attention block
-     def _sa_block(
-         self,
-         x: Tensor,
-         attn_mask: Optional[Tensor],
-         key_padding_mask: Optional[Tensor],
-     ) -> Tensor:
-         x = self.self_attn(
-             x,
-             x,
-             x,
-             attn_mask=attn_mask,
-             key_padding_mask=key_padding_mask,
-             need_weights=False,
-         )[0]
-         return self.dropout1(x)
-
-     # feed forward block
-     def _ff_block(self, x: Tensor) -> Tensor:
-         x = self.linear2(self.dropout(self.activation(self.linear1(x))))
-         return self.dropout2(x)
-
-
- class TransformerEncoder(nn.Module):
-     r"""TransformerEncoder is a stack of N encoder layers. Users can build the
-     BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
-
-     Args:
-         encoder_layer: an instance of the TransformerEncoderLayer() class (required).
-         num_layers: the number of sub-encoder-layers in the encoder (required).
-         norm: the layer normalization component (optional).
-         enable_nested_tensor: if True, input will automatically convert to nested tensor
-             (and convert back on output). This will improve the overall performance of
-             TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
-
-     Examples::
-         >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
-         >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
-         >>> src = torch.rand(10, 32, 512)
-         >>> out = transformer_encoder(src)
-     """
-     __constants__ = ["norm"]
-
-     def __init__(self, encoder_layer, num_layers, norm=None):
-         super(TransformerEncoder, self).__init__()
-         self.layers = _get_clones(encoder_layer, num_layers)
-         self.num_layers = num_layers
-         self.norm = norm
-
-     def forward(
-         self,
-         src: Tensor,
-         mask: Optional[Tensor] = None,
-         src_key_padding_mask: Optional[Tensor] = None,
-         return_layer_states: bool = False,
-     ) -> Tensor:
-         r"""Pass the input through the encoder layers in turn.
-
-         Args:
-             src: the sequence to the encoder (required).
-             mask: the mask for the src sequence (optional).
-             src_key_padding_mask: the mask for the src keys per batch (optional).
-             return_layer_states: return layers' state (optional).
-
-         Shape:
-             see the docs in Transformer class.
-         """
-         if return_layer_states:
-             layer_states = []  # layers' output
-             output = src
-             for mod in self.layers:
-                 output = mod(
-                     output,
-                     src_mask=mask,
-                     src_key_padding_mask=src_key_padding_mask,
-                 )
-                 layer_states.append(output[0])
-
-             if self.norm is not None:
-                 output = self.norm(output)
-
-             return layer_states, output
-
-         output = src
-         for mod in self.layers:
-             output = mod(
-                 output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
-             )
-
-         if self.norm is not None:
-             output = self.norm(output)
-
-         return output
-
-     def infer(
-         self,
-         src: Tensor,
-         mask: Optional[Tensor] = None,
-         src_key_padding_mask: Optional[Tensor] = None,
-         return_layer_states: bool = False,
-         past_kv: Optional[Tensor] = None,
-         use_cache: bool = False,
-     ):
-         if past_kv is None:
-             past_length = 0
-             past_kv = tuple([None] * self.num_layers)
-         else:
-             past_length = past_kv[0][0].size(-2)
-         new_kv = () if use_cache else None
-         output = src
-         for mod, past_layer_kv in zip(self.layers, past_kv):
-             output, kv = mod.infer(
-                 output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past_kv=past_layer_kv, use_cache=use_cache
-             )
-             if use_cache:
-                 new_kv = new_kv + (kv,)
-
-         if self.norm is not None:
-             output = self.norm(output)
-
-         return output, new_kv
-
-
- class TransformerDecoderLayer(nn.Module):
-     __constants__ = ["batch_first", "norm_first"]
-
-     def __init__(
-         self,
-         d_model: int,
-         nhead: int,
-         dim_feedforward: int = 2048,
-         dropout: float = 0.1,
-         activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
-         linear1_self_attention_cls: nn.Module = nn.Linear,
-         linear2_self_attention_cls: nn.Module = nn.Linear,
-         linear1_feedforward_cls: nn.Module = nn.Linear,
-         linear2_feedforward_cls: nn.Module = nn.Linear,
-         batch_first: bool = False,
-         norm_first: bool = False,
-         device=None,
-         dtype=None,
-         layer_norm_cls: nn.Module = LayerNorm,
-         layer_norm_eps: float = 1e-5,
-         adaptive_layer_norm=False,
-     ) -> None:
-         factory_kwargs = {"device": device, "dtype": dtype}
-         super(TransformerDecoderLayer, self).__init__()
-         self.self_attn = MultiheadAttention(
-             d_model,
-             nhead,
-             dropout=dropout,
-             batch_first=batch_first,
-             linear1_cls=linear1_self_attention_cls,
-             linear2_cls=linear2_self_attention_cls,
-             **factory_kwargs,
-         )
-         self.multihead_attn = MultiheadAttention(
-             d_model,
-             nhead,
-             dropout=dropout,
-             batch_first=batch_first,
-             linear1_cls=linear1_self_attention_cls,
-             linear2_cls=linear2_self_attention_cls,
-             **factory_kwargs,
-         )
-         # Implementation of Feedforward model
-         self.linear1 = linear1_feedforward_cls(
-             d_model, dim_feedforward, **factory_kwargs
-         )
-         self.dropout = nn.Dropout(dropout)
-         self.linear2 = linear2_feedforward_cls(
-             dim_feedforward, d_model, **factory_kwargs
-         )
-
-         self.norm_first = norm_first
-         self.dropout1 = nn.Dropout(dropout)
-         self.dropout2 = nn.Dropout(dropout)
-         self.dropout3 = nn.Dropout(dropout)
-
-         # Legacy string support for activation function.
-         if isinstance(activation, str):
-             self.activation = _get_activation_fn(activation)
-         elif isinstance(activation, partial):
-             self.activation = activation(d_model)
-         elif activation == BalancedDoubleSwish:
-             self.activation = BalancedDoubleSwish(d_model)
-         else:
-             self.activation = activation
-
-         if adaptive_layer_norm:
-             norm1 = layer_norm_cls(
-                 d_model, eps=layer_norm_eps, **factory_kwargs
-             )
-             norm2 = layer_norm_cls(
-                 d_model, eps=layer_norm_eps, **factory_kwargs
-             )
-             norm3 = layer_norm_cls(
-                 d_model, eps=layer_norm_eps, **factory_kwargs
-             )
-
-             self.norm1 = AdaptiveLayerNorm(d_model, norm1)
-             self.norm2 = AdaptiveLayerNorm(d_model, norm2)
-             self.norm3 = AdaptiveLayerNorm(d_model, norm3)
-         else:
-             self.norm1 = layer_norm_cls(
-                 d_model, eps=layer_norm_eps, **factory_kwargs
-             )
-             self.norm2 = layer_norm_cls(
-                 d_model, eps=layer_norm_eps, **factory_kwargs
-             )
-             if layer_norm_cls == IdentityNorm:
-                 self.norm3 = BalancedBasicNorm(
-                     d_model, eps=layer_norm_eps, **factory_kwargs
-                 )
-             else:
-                 self.norm3 = layer_norm_cls(
-                     d_model, eps=layer_norm_eps, **factory_kwargs
-                 )
-
-     def forward(
-         self,
-         tgt: Tensor,
-         memory: Tensor,
-         tgt_mask: Optional[Tensor] = None,
-         memory_mask: Optional[Tensor] = None,
-         tgt_key_padding_mask: Optional[Tensor] = None,
-         memory_key_padding_mask: Optional[Tensor] = None,
-     ) -> Tensor:
-         r"""Pass the inputs (and mask) through the decoder layer.
-
-         Args:
-             tgt: the sequence to the decoder layer (required).
-             memory: the sequence from the last layer of the encoder (required).
-             tgt_mask: the mask for the tgt sequence (optional).
-             memory_mask: the mask for the memory sequence (optional).
-             tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
-             memory_key_padding_mask: the mask for the memory keys per batch (optional).
-
-         Shape:
-             see the docs in Transformer class.
-         """
-         tgt_is_tuple = False
-         if isinstance(tgt, tuple):
-             x, stage_embedding = tgt
-             tgt_is_tuple = True
-         else:
-             x, stage_embedding = tgt, None
-
-         if self.norm_first:
-             x = x + self._sa_block(
-                 self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
-             )
-             x = x + self._mha_block(
-                 self.norm2(x, stage_embedding),
-                 memory,
-                 memory_mask,
-                 memory_key_padding_mask,
-             )
-             x = x + self._ff_block(self.norm3(x, stage_embedding))
-         else:
-             x = self.norm1(
-                 x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
-                 stage_embedding,
-             )
-             x = self.norm2(
-                 x
-                 + self._mha_block(
-                     x, memory, memory_mask, memory_key_padding_mask
-                 ),
-                 stage_embedding,
-             )
-             x = self.norm3(x + self._ff_block(x), stage_embedding)
-
-         if tgt_is_tuple:
-             return (x, stage_embedding)
-         return x
-
-     # self-attention block
-     def _sa_block(
-         self,
-         x: Tensor,
-         attn_mask: Optional[Tensor],
-         key_padding_mask: Optional[Tensor],
-     ) -> Tensor:
-         x = self.self_attn(
-             x,
-             x,
-             x,
-             attn_mask=attn_mask,
-             key_padding_mask=key_padding_mask,
-             need_weights=False,
-         )[0]
-         return self.dropout1(x)
-
-     # multihead attention block
-     def _mha_block(
-         self,
-         x: Tensor,
-         mem: Tensor,
-         attn_mask: Optional[Tensor],
-         key_padding_mask: Optional[Tensor],
-     ) -> Tensor:
-         x = self.multihead_attn(
-             x,
-             mem,
-             mem,
-             attn_mask=attn_mask,
-             key_padding_mask=key_padding_mask,
-             need_weights=False,
-         )[0]
-         return self.dropout2(x)
-
-     # feed forward block
-     def _ff_block(self, x: Tensor) -> Tensor:
-         x = self.linear2(self.dropout(self.activation(self.linear1(x))))
-         return self.dropout3(x)
-
-
- def _get_clones(module, N):
-     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
-
-
- def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
-     if activation == "relu":
-         return F.relu
-     elif activation == "gelu":
-         return F.gelu
-
-     raise RuntimeError(
-         "activation should be relu/gelu, not {}".format(activation)
-     )
 
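Note: the layers above thread an optional stage embedding through every normalization by passing a `(x, stage_embedding)` tuple in place of a plain tensor. A minimal usage sketch of that convention follows; it assumes the deleted file is importable as `modules.transformer` and that `modules.activation` / `modules.scaling` (also removed in this commit) are present, and all shapes and hyperparameters are illustrative only.

import torch
from modules.transformer import (
    AdaptiveLayerNorm,
    LayerNorm,
    TransformerEncoder,
    TransformerEncoderLayer,
)

d_model, nhead = 256, 4
layer = TransformerEncoderLayer(
    d_model,
    nhead,
    dim_feedforward=1024,
    batch_first=True,
    norm_first=True,
    adaptive_layer_norm=True,  # wraps norm1/norm2 in AdaptiveLayerNorm
)
encoder = TransformerEncoder(
    layer,
    num_layers=2,
    norm=AdaptiveLayerNorm(d_model, LayerNorm(d_model)),
)

x = torch.randn(8, 100, d_model)              # (batch, time, d_model)
stage_embedding = torch.randn(8, 1, d_model)  # conditions every AdaptiveLayerNorm
y, _ = encoder((x, stage_embedding))          # tuple in, tuple out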
prompts/promptsf DELETED
File without changes
utils/__init__.py DELETED
@@ -1,15 +0,0 @@
- import torch
- import torch.nn as nn
- # from icefall.utils import make_pad_mask
-
- from .symbol_table import SymbolTable
-
- # make_pad_mask = make_pad_mask
- SymbolTable = SymbolTable
-
-
- class Transpose(nn.Identity):
-     """(N, T, D) -> (N, D, T)"""
-
-     def forward(self, input: torch.Tensor) -> torch.Tensor:
-         return input.transpose(1, 2)
 
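A quick sketch of the `Transpose` helper above (illustrative shapes only): it swaps the time and feature axes so `(N, T, D)` features can feed layers that expect `(N, D, T)`, e.g. `nn.Conv1d`.

import torch
from utils import Transpose

x = torch.randn(4, 100, 80)                   # (N, T, D)
assert Transpose()(x).shape == (4, 80, 100)   # (N, D, T)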
utils/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (915 Bytes)
 
utils/__pycache__/generation.cpython-311.pyc DELETED
Binary file (15.1 kB)
 
utils/__pycache__/prompt_making.cpython-311.pyc DELETED
Binary file (7 kB)
 
utils/__pycache__/sentence_cutter.cpython-311.pyc DELETED
Binary file (3.5 kB)
 
utils/__pycache__/symbol_table.cpython-311.pyc DELETED
Binary file (12.8 kB)
 
utils/download.py DELETED
@@ -1,49 +0,0 @@
- import sys
- import requests
-
-
- def download_file_from_google_drive(id, destination):
-     URL = "https://docs.google.com/uc?export=download&confirm=1"
-
-     session = requests.Session()
-
-     response = session.get(URL, params={"id": id}, stream=True)
-     token = get_confirm_token(response)
-
-     if token:
-         params = {"id": id, "confirm": token}
-         response = session.get(URL, params=params, stream=True)
-
-     save_response_content(response, destination)
-
-
- def get_confirm_token(response):
-     for key, value in response.cookies.items():
-         if key.startswith("download_warning"):
-             return value
-
-     return None
-
-
- def save_response_content(response, destination):
-     CHUNK_SIZE = 32768
-
-     with open(destination, "wb") as f:
-         for chunk in response.iter_content(CHUNK_SIZE):
-             if chunk:  # filter out keep-alive new chunks
-                 f.write(chunk)
-
-
- def main():
-     if len(sys.argv) >= 3:
-         file_id = sys.argv[1]
-         destination = sys.argv[2]
-     else:
-         file_id = "TAKE_ID_FROM_SHAREABLE_LINK"
-         destination = "DESTINATION_FILE_ON_YOUR_DISK"
-     print(f"download {file_id} to {destination}")
-     download_file_from_google_drive(file_id, destination)
-
-
- if __name__ == "__main__":
-     main()
 
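Usage sketch for the helper above; the file id and destination path are placeholders, not real values. The same thing can be done from a shell via `python utils/download.py <file_id> <destination>`.

from utils.download import download_file_from_google_drive

# Both arguments below are illustrative placeholders.
download_file_from_google_drive("GOOGLE_DRIVE_FILE_ID", "checkpoints/model.pt")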
utils/g2p/__init__.py DELETED
@@ -1,72 +0,0 @@
- """ from https://github.com/keithito/tacotron """
- import utils.g2p.cleaners
- from utils.g2p.symbols import symbols
- from tokenizers import Tokenizer
-
- # Mappings from symbol to numeric ID and vice versa:
- _symbol_to_id = {s: i for i, s in enumerate(symbols)}
- _id_to_symbol = {i: s for i, s in enumerate(symbols)}
-
-
- class PhonemeBpeTokenizer:
-     def __init__(self, tokenizer_path="./utils/g2p/bpe_1024.json"):
-         self.tokenizer = Tokenizer.from_file(tokenizer_path)
-
-     def tokenize(self, text):
-         # 1. convert text to phoneme
-         phonemes, langs = _clean_text(text, ['cje_cleaners'])
-         # 2. replace blank space " " with "_"
-         phonemes = phonemes.replace(" ", "_")
-         # 3. tokenize phonemes
-         phoneme_tokens = self.tokenizer.encode(phonemes).ids
-         assert(len(phoneme_tokens) == len(langs))
-         if not len(phoneme_tokens):
-             raise ValueError("Empty text is given")
-         return phoneme_tokens, langs
-
- def text_to_sequence(text, cleaner_names):
-     '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
-     Args:
-         text: string to convert to a sequence
-         cleaner_names: names of the cleaner functions to run the text through
-     Returns:
-         List of integers corresponding to the symbols in the text
-     '''
-     sequence = []
-     symbol_to_id = {s: i for i, s in enumerate(symbols)}
-     clean_text = _clean_text(text, cleaner_names)
-     for symbol in clean_text:
-         if symbol not in symbol_to_id.keys():
-             continue
-         symbol_id = symbol_to_id[symbol]
-         sequence += [symbol_id]
-     return sequence
-
-
- def cleaned_text_to_sequence(cleaned_text):
-     '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
-     Args:
-         text: string to convert to a sequence
-     Returns:
-         List of integers corresponding to the symbols in the text
-     '''
-     sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
-     return sequence
-
-
- def sequence_to_text(sequence):
-     '''Converts a sequence of IDs back to a string'''
-     result = ''
-     for symbol_id in sequence:
-         s = _id_to_symbol[symbol_id]
-         result += s
-     return result
-
-
- def _clean_text(text, cleaner_names):
-     for name in cleaner_names:
-         cleaner = getattr(utils.g2p.cleaners, name)
-         if not cleaner:
-             raise Exception('Unknown cleaner: %s' % name)
-         text, langs = cleaner(text)
-     return text, langs
 
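A minimal sketch of how the phoneme BPE tokenizer above is used; it assumes the default ./utils/g2p/bpe_1024.json vocabulary and the cje_cleaners pipeline shipped with the repository are on disk.

from utils.g2p import PhonemeBpeTokenizer

tokenizer = PhonemeBpeTokenizer()  # loads ./utils/g2p/bpe_1024.json
phoneme_tokens, langs = tokenizer.tokenize("Hello world")
# phoneme_tokens: BPE ids over the cleaned phoneme string
# langs: one language tag per token, returned by the cleaner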
utils/g2p/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (4.49 kB)
 
utils/g2p/__pycache__/cleaners.cpython-311.pyc DELETED
Binary file (4.66 kB)
 
utils/g2p/__pycache__/english.cpython-311.pyc DELETED
Binary file (8.53 kB)
 
utils/g2p/__pycache__/japanese.cpython-311.pyc DELETED
Binary file (8.34 kB)
 
utils/g2p/__pycache__/mandarin.cpython-311.pyc DELETED
Binary file (9.61 kB)
 
utils/g2p/__pycache__/symbols.cpython-311.pyc DELETED
Binary file (1.5 kB)