File size: 10,547 Bytes
c094356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.utils.data as data

from utils import voc_single_track
from model import Seq2SeqTranscriber
from dataset import CustomDataset
from utils import token_seg_list_to_midi, unpack_sequence    
from eval import evaluate_midi

class LitTranscriber(pl.LightningModule):
    def __init__(
        self,
        transcriber_args: dict, # Seq2SeqTranscriber モジュールの初期化に必要な引数を含む辞書. 例えばn_melsやn_fft, hop_lengthなど
        lr: float, # 学習率(Learning Rate)。
        lr_decay: float = 1.0, # 学習率の減衰率。デフォルト値は 1.0(減衰なし)。
        lr_decay_interval: int = 1, # 学習率の減衰間隔(エポック単位)。
    ):
        super().__init__()
        self.save_hyperparameters() # 渡された引数をハイパーパラメータとして保存(PyTorch Lightning機能)。
        self.voc_dict = voc_single_track # voc_dict をインスタンス変数に保存。
        self.n_voc = self.voc_dict["n_voc"] # トークンの総数(n_voc)も保存。

        # 渡された引数 transcriber_args を展開して Seq2SeqTranscriber モジュールを初期化。トークン辞書(voc_dict)を追加で渡す。
        self.transcriber = Seq2SeqTranscriber(
            **transcriber_args, voc_dict=self.voc_dict
        )
        # 引数から渡された学習率や減衰関連の設定を保存。
        self.lr = lr
        self.lr_decay = lr_decay
        self.lr_decay_interval = lr_decay_interval

    # 推論時に Seq2SeqTranscriber の推論機能(infer メソッド)を呼び出す関数
    def forward(self, y: torch.Tensor):
        transcriber_infer = self.transcriber.infer(y) # 入力波形(y)を Seq2SeqTranscriber の infer メソッドに渡して推論を実行。
        return transcriber_infer #推論結果として生成されたトークン列を返します。

    # モデルのトレーニング中に1バッチの処理を実行する関数
    def training_step(self, batch, batch_idx):
        y, t = batch # 入力波形(y)と正解トークン列(t)を分割。
        tf_out = self.transcriber(y, t) # 入力データを Seq2SeqTranscriber に渡してtranscriberのforwardで損失値(loss)を取得。出力は、損失値やロジット(各トークンのスコア)を含むオブジェクト。
        loss = tf_out.loss # T5モデルの出力に含まれる損失値(loss)。学習時の損失計算は、CrossEntropyLoss に基づいています。
        t = t.detach() # 正解トークン(t)を計算グラフから切り離して、後続の計算が逆伝播に影響しないようにする。
        mask = t != self.voc_dict["pad"] # 正解トークン列のうち、パディングトークン(pad)を無視するためのマスクを作成。

        # モデルの出力(logits)から最も高いスコアのトークン(argmax(-1))を取得し、マスクされた正解トークン(t[mask])と比較。
        # 正しく予測されたトークンの数を、マスクされたトークン全体の数で割って精度を計算。
        accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum()
        #損失値(loss)と精度(accr)を PyTorch Lightning のログ機能を使って記録。
        self.log("train_loss", loss)
        self.log("train_accr", accr)
        return loss # 計算された損失を返します。この値はPyTorch Lightningがバックプロパゲーションを行うために使用します。

    # モデルの検証時に1バッチの処理を実行する関数
    def validation_step(self, batch, batch_idx):
        assert not self.transcriber.training # モデルがトレーニングモードでないことを確認
        y, t = batch
        tf_out = self.transcriber(y, t)
        loss = tf_out.loss
        t = t.detach()
        mask = t != self.voc_dict["pad"]
        accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum()
        self.log("vali_loss", loss)
        self.log("vali_accr", accr)
        return loss

    def test_step(self, batch, batch_idx):
        y, ref_midi = batch # 音声データ(y)と参照MIDI(ref_midi)を分割
        # テストデータセットでは、通常1つのサンプルを1バッチとして処理するため、インデックス0の要素を取り出します。
        y = y[0]
        ref_midi = ref_midi[0]
        with torch.no_grad(): # 推論中は勾配計算が不要なため、この文脈を使ってメモリ消費を削減。
            est_tokens = self.forward(y) # 入力音声 y を用いてモデルに推論を実行し、MIDIトークン列(est_tokens)を生成。モデルの forward メソッドを呼び出して推論を実行。この時forward内では分割されて実行される, 形状: (セグメント数, 最大トークン長)
            unpadded_tokens = unpack_sequence(est_tokens.cpu().numpy()) # 推論結果(est_tokens)からパディングを取り除き、有効なトークン列だけを取得。GPU上のテンソルをCPUに移動してNumPy配列に変換
            unpadded_tokens = [t[1:] for t in unpadded_tokens] # 各トークン列から開始トークン(<sos>など)を除外。
            est_midi = token_seg_list_to_midi(unpadded_tokens) # パディングを除去したトークン列(unpadded_tokens)をMIDI形式に変換。
        dict_eval = evaluate_midi(est_midi, ref_midi) # 参照MIDI(ref_midi)と推論MIDI(est_midi)を比較し、性能を評価。
        # 評価結果をPyTorch Lightningのログ機能で記録。
        dict_log = {}
        for key in dict_eval:
            dict_log["test/" + key] = dict_eval[key]
        self.log_dict(dict_log, batch_size=1)

    def train_dataloader(self):
        # 事前に定義されたデータセットクラス。音声(波形データ)とMIDIファイルをペアとしてロードします。
        dset = CustomDataset(
            split="train",                      # データセットの分割
            sample_rate=16000,
            apply_pedal=True,
            whole_song=False
        )
        # データローダーの作成
        return data.DataLoader(
            dataset=dset, # 使用するデータセットとして、初期化済みの Maestro を指定。
            collate_fn=dset.collate_batch, # バッチ処理時のデータ整形関数。
            batch_size=64, # 1バッチに含めるサンプル数。
            shuffle=True, # データセットをランダムにシャッフル。
            pin_memory=True, # データをピン留め(ページロック)してメモリに保存。
            num_workers=32, # データローディングに使用するプロセス数。
        )

    def val_dataloader(self):
        # 事前に定義されたデータセットクラス。音声(波形データ)とMIDIファイルをペアとしてロードします。
        dset = CustomDataset(
            split="valid",                      # データセットの分割
            sample_rate=16000,
            apply_pedal=True,
            whole_song=False
        )
        # データローダーの作成
        return data.DataLoader(
            dataset=dset, # 使用するデータセットとして、初期化済みの Maestro を指定。
            collate_fn=dset.collate_batch, # バッチ処理時のデータ整形関数。
            batch_size=32, # 1バッチに含めるサンプル数。
            shuffle=True, # データセットをランダムにシャッフル。
            pin_memory=True, # データをピン留め(ページロック)してメモリに保存。
            num_workers=16, # データローディングに使用するプロセス数。
        )
    def test_dataloader(self):
        dset = CustomDataset(
            split="test",                      # データセットの分割
            sample_rate=16000,
            apply_pedal=True,
            whole_song=True
        )
        return data.DataLoader(
            dataset=dset,
            collate_fn=dset.collate_wholesong,
            batch_size=1,
            shuffle=False, # テストデータの順序を固定。理由: 再現性を確保するため。
            pin_memory=True,
        )

    # モデルの学習時に、損失関数の値を最小化するようにパラメータを更新するオプティマイザ(最適化手法)を指定する関数
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)
    # Adamオプティマイザ の改良版であり、L2正則化の代わりに ウェイトデカイ(Weight Decay) を導入したオプティマイザ。
    # self.parameters(): モデルの学習可能なすべてのパラメータ(重みとバイアス)を取得。

from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

# チェックポイント保存の設定
if __name__ == "__main__":
    checkpoint_callback = ModelCheckpoint(
        dirpath="/content/drive/MyDrive/B4/Humtrans/m2m100/checkpoints",  # 保存先
        filename="epoch{epoch:02d}-vali_loss{vali_loss:.2f}",  # ファイル名フォーマット
        save_top_k=3,  # 上位3つのチェックポイントを保存
        monitor="vali_loss",  # 訓練データの損失を監視
        mode="min",  # 損失が小さいほど良いモデル
        every_n_epochs=1,  # 1エポックごとに保存
    )

    # ロガー設定
    logger = TensorBoardLogger(
        save_dir="/content/drive/MyDrive/B4/Humtrans/m2m100/logs",
        name="transcription"
    )

    # トレーナーの設定
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[checkpoint_callback],  # チェックポイントコールバックを追加
        enable_checkpointing=True,
        accelerator="gpu",
        devices=1,
        max_epochs=400,
    )

    # モデルの引数
    args = {
        "n_mels": 128,
        "sample_rate": 16000,
        "n_fft": 1024,
        "hop_length": 128,
    }

    lightning_module = LitTranscriber(
        transcriber_args=args,
        lr=1e-4, # 0.0001
        lr_decay=0.99, # 各エポックの度に0.99倍して減衰させる
    )

    # 学習の開始
    trainer.fit(lightning_module)