Spaces:
Running
Running
File size: 10,547 Bytes
c094356 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.utils.data as data
from utils import voc_single_track
from model import Seq2SeqTranscriber
from dataset import CustomDataset
from utils import token_seg_list_to_midi, unpack_sequence
from eval import evaluate_midi
class LitTranscriber(pl.LightningModule):
def __init__(
self,
transcriber_args: dict, # Seq2SeqTranscriber モジュールの初期化に必要な引数を含む辞書. 例えばn_melsやn_fft, hop_lengthなど
lr: float, # 学習率(Learning Rate)。
lr_decay: float = 1.0, # 学習率の減衰率。デフォルト値は 1.0(減衰なし)。
lr_decay_interval: int = 1, # 学習率の減衰間隔(エポック単位)。
):
super().__init__()
self.save_hyperparameters() # 渡された引数をハイパーパラメータとして保存(PyTorch Lightning機能)。
self.voc_dict = voc_single_track # voc_dict をインスタンス変数に保存。
self.n_voc = self.voc_dict["n_voc"] # トークンの総数(n_voc)も保存。
# 渡された引数 transcriber_args を展開して Seq2SeqTranscriber モジュールを初期化。トークン辞書(voc_dict)を追加で渡す。
self.transcriber = Seq2SeqTranscriber(
**transcriber_args, voc_dict=self.voc_dict
)
# 引数から渡された学習率や減衰関連の設定を保存。
self.lr = lr
self.lr_decay = lr_decay
self.lr_decay_interval = lr_decay_interval
# 推論時に Seq2SeqTranscriber の推論機能(infer メソッド)を呼び出す関数
def forward(self, y: torch.Tensor):
transcriber_infer = self.transcriber.infer(y) # 入力波形(y)を Seq2SeqTranscriber の infer メソッドに渡して推論を実行。
return transcriber_infer #推論結果として生成されたトークン列を返します。
# モデルのトレーニング中に1バッチの処理を実行する関数
def training_step(self, batch, batch_idx):
y, t = batch # 入力波形(y)と正解トークン列(t)を分割。
tf_out = self.transcriber(y, t) # 入力データを Seq2SeqTranscriber に渡してtranscriberのforwardで損失値(loss)を取得。出力は、損失値やロジット(各トークンのスコア)を含むオブジェクト。
loss = tf_out.loss # T5モデルの出力に含まれる損失値(loss)。学習時の損失計算は、CrossEntropyLoss に基づいています。
t = t.detach() # 正解トークン(t)を計算グラフから切り離して、後続の計算が逆伝播に影響しないようにする。
mask = t != self.voc_dict["pad"] # 正解トークン列のうち、パディングトークン(pad)を無視するためのマスクを作成。
# モデルの出力(logits)から最も高いスコアのトークン(argmax(-1))を取得し、マスクされた正解トークン(t[mask])と比較。
# 正しく予測されたトークンの数を、マスクされたトークン全体の数で割って精度を計算。
accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum()
#損失値(loss)と精度(accr)を PyTorch Lightning のログ機能を使って記録。
self.log("train_loss", loss)
self.log("train_accr", accr)
return loss # 計算された損失を返します。この値はPyTorch Lightningがバックプロパゲーションを行うために使用します。
# モデルの検証時に1バッチの処理を実行する関数
def validation_step(self, batch, batch_idx):
assert not self.transcriber.training # モデルがトレーニングモードでないことを確認
y, t = batch
tf_out = self.transcriber(y, t)
loss = tf_out.loss
t = t.detach()
mask = t != self.voc_dict["pad"]
accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum()
self.log("vali_loss", loss)
self.log("vali_accr", accr)
return loss
def test_step(self, batch, batch_idx):
y, ref_midi = batch # 音声データ(y)と参照MIDI(ref_midi)を分割
# テストデータセットでは、通常1つのサンプルを1バッチとして処理するため、インデックス0の要素を取り出します。
y = y[0]
ref_midi = ref_midi[0]
with torch.no_grad(): # 推論中は勾配計算が不要なため、この文脈を使ってメモリ消費を削減。
est_tokens = self.forward(y) # 入力音声 y を用いてモデルに推論を実行し、MIDIトークン列(est_tokens)を生成。モデルの forward メソッドを呼び出して推論を実行。この時forward内では分割されて実行される, 形状: (セグメント数, 最大トークン長)
unpadded_tokens = unpack_sequence(est_tokens.cpu().numpy()) # 推論結果(est_tokens)からパディングを取り除き、有効なトークン列だけを取得。GPU上のテンソルをCPUに移動してNumPy配列に変換
unpadded_tokens = [t[1:] for t in unpadded_tokens] # 各トークン列から開始トークン(<sos>など)を除外。
est_midi = token_seg_list_to_midi(unpadded_tokens) # パディングを除去したトークン列(unpadded_tokens)をMIDI形式に変換。
dict_eval = evaluate_midi(est_midi, ref_midi) # 参照MIDI(ref_midi)と推論MIDI(est_midi)を比較し、性能を評価。
# 評価結果をPyTorch Lightningのログ機能で記録。
dict_log = {}
for key in dict_eval:
dict_log["test/" + key] = dict_eval[key]
self.log_dict(dict_log, batch_size=1)
def train_dataloader(self):
# 事前に定義されたデータセットクラス。音声(波形データ)とMIDIファイルをペアとしてロードします。
dset = CustomDataset(
split="train", # データセットの分割
sample_rate=16000,
apply_pedal=True,
whole_song=False
)
# データローダーの作成
return data.DataLoader(
dataset=dset, # 使用するデータセットとして、初期化済みの Maestro を指定。
collate_fn=dset.collate_batch, # バッチ処理時のデータ整形関数。
batch_size=64, # 1バッチに含めるサンプル数。
shuffle=True, # データセットをランダムにシャッフル。
pin_memory=True, # データをピン留め(ページロック)してメモリに保存。
num_workers=32, # データローディングに使用するプロセス数。
)
def val_dataloader(self):
# 事前に定義されたデータセットクラス。音声(波形データ)とMIDIファイルをペアとしてロードします。
dset = CustomDataset(
split="valid", # データセットの分割
sample_rate=16000,
apply_pedal=True,
whole_song=False
)
# データローダーの作成
return data.DataLoader(
dataset=dset, # 使用するデータセットとして、初期化済みの Maestro を指定。
collate_fn=dset.collate_batch, # バッチ処理時のデータ整形関数。
batch_size=32, # 1バッチに含めるサンプル数。
shuffle=True, # データセットをランダムにシャッフル。
pin_memory=True, # データをピン留め(ページロック)してメモリに保存。
num_workers=16, # データローディングに使用するプロセス数。
)
def test_dataloader(self):
dset = CustomDataset(
split="test", # データセットの分割
sample_rate=16000,
apply_pedal=True,
whole_song=True
)
return data.DataLoader(
dataset=dset,
collate_fn=dset.collate_wholesong,
batch_size=1,
shuffle=False, # テストデータの順序を固定。理由: 再現性を確保するため。
pin_memory=True,
)
# モデルの学習時に、損失関数の値を最小化するようにパラメータを更新するオプティマイザ(最適化手法)を指定する関数
def configure_optimizers(self):
return torch.optim.AdamW(self.parameters(), lr=self.lr)
# Adamオプティマイザ の改良版であり、L2正則化の代わりに ウェイトデカイ(Weight Decay) を導入したオプティマイザ。
# self.parameters(): モデルの学習可能なすべてのパラメータ(重みとバイアス)を取得。
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
# チェックポイント保存の設定
if __name__ == "__main__":
checkpoint_callback = ModelCheckpoint(
dirpath="/content/drive/MyDrive/B4/Humtrans/m2m100/checkpoints", # 保存先
filename="epoch{epoch:02d}-vali_loss{vali_loss:.2f}", # ファイル名フォーマット
save_top_k=3, # 上位3つのチェックポイントを保存
monitor="vali_loss", # 訓練データの損失を監視
mode="min", # 損失が小さいほど良いモデル
every_n_epochs=1, # 1エポックごとに保存
)
# ロガー設定
logger = TensorBoardLogger(
save_dir="/content/drive/MyDrive/B4/Humtrans/m2m100/logs",
name="transcription"
)
# トレーナーの設定
trainer = pl.Trainer(
logger=logger,
callbacks=[checkpoint_callback], # チェックポイントコールバックを追加
enable_checkpointing=True,
accelerator="gpu",
devices=1,
max_epochs=400,
)
# モデルの引数
args = {
"n_mels": 128,
"sample_rate": 16000,
"n_fft": 1024,
"hop_length": 128,
}
lightning_module = LitTranscriber(
transcriber_args=args,
lr=1e-4, # 0.0001
lr_decay=0.99, # 各エポックの度に0.99倍して減衰させる
)
# 学習の開始
trainer.fit(lightning_module)
|