humtrans / train.py
hayaton0005's picture
Upload 11 files
c094356 verified
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.utils.data as data
from utils import voc_single_track
from model import Seq2SeqTranscriber
from dataset import CustomDataset
from utils import token_seg_list_to_midi, unpack_sequence
from eval import evaluate_midi
class LitTranscriber(pl.LightningModule):
def __init__(
self,
transcriber_args: dict, # Seq2SeqTranscriber モジュールの初期化に必要な引数を含む辞書. 例えばn_melsやn_fft, hop_lengthなど
lr: float, # 学習率(Learning Rate)。
lr_decay: float = 1.0, # 学習率の減衰率。デフォルト値は 1.0(減衰なし)。
lr_decay_interval: int = 1, # 学習率の減衰間隔(エポック単位)。
):
super().__init__()
self.save_hyperparameters() # 渡された引数をハイパーパラメータとして保存(PyTorch Lightning機能)。
self.voc_dict = voc_single_track # voc_dict をインスタンス変数に保存。
self.n_voc = self.voc_dict["n_voc"] # トークンの総数(n_voc)も保存。
# 渡された引数 transcriber_args を展開して Seq2SeqTranscriber モジュールを初期化。トークン辞書(voc_dict)を追加で渡す。
self.transcriber = Seq2SeqTranscriber(
**transcriber_args, voc_dict=self.voc_dict
)
# 引数から渡された学習率や減衰関連の設定を保存。
self.lr = lr
self.lr_decay = lr_decay
self.lr_decay_interval = lr_decay_interval
# 推論時に Seq2SeqTranscriber の推論機能(infer メソッド)を呼び出す関数
def forward(self, y: torch.Tensor):
transcriber_infer = self.transcriber.infer(y) # 入力波形(y)を Seq2SeqTranscriber の infer メソッドに渡して推論を実行。
return transcriber_infer #推論結果として生成されたトークン列を返します。
# モデルのトレーニング中に1バッチの処理を実行する関数
def training_step(self, batch, batch_idx):
y, t = batch # 入力波形(y)と正解トークン列(t)を分割。
tf_out = self.transcriber(y, t) # 入力データを Seq2SeqTranscriber に渡してtranscriberのforwardで損失値(loss)を取得。出力は、損失値やロジット(各トークンのスコア)を含むオブジェクト。
loss = tf_out.loss # T5モデルの出力に含まれる損失値(loss)。学習時の損失計算は、CrossEntropyLoss に基づいています。
t = t.detach() # 正解トークン(t)を計算グラフから切り離して、後続の計算が逆伝播に影響しないようにする。
mask = t != self.voc_dict["pad"] # 正解トークン列のうち、パディングトークン(pad)を無視するためのマスクを作成。
# モデルの出力(logits)から最も高いスコアのトークン(argmax(-1))を取得し、マスクされた正解トークン(t[mask])と比較。
# 正しく予測されたトークンの数を、マスクされたトークン全体の数で割って精度を計算。
accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum()
#損失値(loss)と精度(accr)を PyTorch Lightning のログ機能を使って記録。
self.log("train_loss", loss)
self.log("train_accr", accr)
return loss # 計算された損失を返します。この値はPyTorch Lightningがバックプロパゲーションを行うために使用します。
# モデルの検証時に1バッチの処理を実行する関数
def validation_step(self, batch, batch_idx):
assert not self.transcriber.training # モデルがトレーニングモードでないことを確認
y, t = batch
tf_out = self.transcriber(y, t)
loss = tf_out.loss
t = t.detach()
mask = t != self.voc_dict["pad"]
accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum()
self.log("vali_loss", loss)
self.log("vali_accr", accr)
return loss
def test_step(self, batch, batch_idx):
y, ref_midi = batch # 音声データ(y)と参照MIDI(ref_midi)を分割
# テストデータセットでは、通常1つのサンプルを1バッチとして処理するため、インデックス0の要素を取り出します。
y = y[0]
ref_midi = ref_midi[0]
with torch.no_grad(): # 推論中は勾配計算が不要なため、この文脈を使ってメモリ消費を削減。
est_tokens = self.forward(y) # 入力音声 y を用いてモデルに推論を実行し、MIDIトークン列(est_tokens)を生成。モデルの forward メソッドを呼び出して推論を実行。この時forward内では分割されて実行される, 形状: (セグメント数, 最大トークン長)
unpadded_tokens = unpack_sequence(est_tokens.cpu().numpy()) # 推論結果(est_tokens)からパディングを取り除き、有効なトークン列だけを取得。GPU上のテンソルをCPUに移動してNumPy配列に変換
unpadded_tokens = [t[1:] for t in unpadded_tokens] # 各トークン列から開始トークン(<sos>など)を除外。
est_midi = token_seg_list_to_midi(unpadded_tokens) # パディングを除去したトークン列(unpadded_tokens)をMIDI形式に変換。
dict_eval = evaluate_midi(est_midi, ref_midi) # 参照MIDI(ref_midi)と推論MIDI(est_midi)を比較し、性能を評価。
# 評価結果をPyTorch Lightningのログ機能で記録。
dict_log = {}
for key in dict_eval:
dict_log["test/" + key] = dict_eval[key]
self.log_dict(dict_log, batch_size=1)
def train_dataloader(self):
# 事前に定義されたデータセットクラス。音声(波形データ)とMIDIファイルをペアとしてロードします。
dset = CustomDataset(
split="train", # データセットの分割
sample_rate=16000,
apply_pedal=True,
whole_song=False
)
# データローダーの作成
return data.DataLoader(
dataset=dset, # 使用するデータセットとして、初期化済みの Maestro を指定。
collate_fn=dset.collate_batch, # バッチ処理時のデータ整形関数。
batch_size=64, # 1バッチに含めるサンプル数。
shuffle=True, # データセットをランダムにシャッフル。
pin_memory=True, # データをピン留め(ページロック)してメモリに保存。
num_workers=32, # データローディングに使用するプロセス数。
)
def val_dataloader(self):
# 事前に定義されたデータセットクラス。音声(波形データ)とMIDIファイルをペアとしてロードします。
dset = CustomDataset(
split="valid", # データセットの分割
sample_rate=16000,
apply_pedal=True,
whole_song=False
)
# データローダーの作成
return data.DataLoader(
dataset=dset, # 使用するデータセットとして、初期化済みの Maestro を指定。
collate_fn=dset.collate_batch, # バッチ処理時のデータ整形関数。
batch_size=32, # 1バッチに含めるサンプル数。
shuffle=True, # データセットをランダムにシャッフル。
pin_memory=True, # データをピン留め(ページロック)してメモリに保存。
num_workers=16, # データローディングに使用するプロセス数。
)
def test_dataloader(self):
dset = CustomDataset(
split="test", # データセットの分割
sample_rate=16000,
apply_pedal=True,
whole_song=True
)
return data.DataLoader(
dataset=dset,
collate_fn=dset.collate_wholesong,
batch_size=1,
shuffle=False, # テストデータの順序を固定。理由: 再現性を確保するため。
pin_memory=True,
)
# モデルの学習時に、損失関数の値を最小化するようにパラメータを更新するオプティマイザ(最適化手法)を指定する関数
def configure_optimizers(self):
return torch.optim.AdamW(self.parameters(), lr=self.lr)
# Adamオプティマイザ の改良版であり、L2正則化の代わりに ウェイトデカイ(Weight Decay) を導入したオプティマイザ。
# self.parameters(): モデルの学習可能なすべてのパラメータ(重みとバイアス)を取得。
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
# チェックポイント保存の設定
if __name__ == "__main__":
checkpoint_callback = ModelCheckpoint(
dirpath="/content/drive/MyDrive/B4/Humtrans/m2m100/checkpoints", # 保存先
filename="epoch{epoch:02d}-vali_loss{vali_loss:.2f}", # ファイル名フォーマット
save_top_k=3, # 上位3つのチェックポイントを保存
monitor="vali_loss", # 訓練データの損失を監視
mode="min", # 損失が小さいほど良いモデル
every_n_epochs=1, # 1エポックごとに保存
)
# ロガー設定
logger = TensorBoardLogger(
save_dir="/content/drive/MyDrive/B4/Humtrans/m2m100/logs",
name="transcription"
)
# トレーナーの設定
trainer = pl.Trainer(
logger=logger,
callbacks=[checkpoint_callback], # チェックポイントコールバックを追加
enable_checkpointing=True,
accelerator="gpu",
devices=1,
max_epochs=400,
)
# モデルの引数
args = {
"n_mels": 128,
"sample_rate": 16000,
"n_fft": 1024,
"hop_length": 128,
}
lightning_module = LitTranscriber(
transcriber_args=args,
lr=1e-4, # 0.0001
lr_decay=0.99, # 各エポックの度に0.99倍して減衰させる
)
# 学習の開始
trainer.fit(lightning_module)