Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import pytorch_lightning as pl | |
import torch.utils.data as data | |
from utils import voc_single_track | |
from model import Seq2SeqTranscriber | |
from dataset import CustomDataset | |
from utils import token_seg_list_to_midi, unpack_sequence | |
from eval import evaluate_midi | |
class LitTranscriber(pl.LightningModule): | |
def __init__( | |
self, | |
transcriber_args: dict, # Seq2SeqTranscriber モジュールの初期化に必要な引数を含む辞書. 例えばn_melsやn_fft, hop_lengthなど | |
lr: float, # 学習率(Learning Rate)。 | |
lr_decay: float = 1.0, # 学習率の減衰率。デフォルト値は 1.0(減衰なし)。 | |
lr_decay_interval: int = 1, # 学習率の減衰間隔(エポック単位)。 | |
): | |
super().__init__() | |
self.save_hyperparameters() # 渡された引数をハイパーパラメータとして保存(PyTorch Lightning機能)。 | |
self.voc_dict = voc_single_track # voc_dict をインスタンス変数に保存。 | |
self.n_voc = self.voc_dict["n_voc"] # トークンの総数(n_voc)も保存。 | |
# 渡された引数 transcriber_args を展開して Seq2SeqTranscriber モジュールを初期化。トークン辞書(voc_dict)を追加で渡す。 | |
self.transcriber = Seq2SeqTranscriber( | |
**transcriber_args, voc_dict=self.voc_dict | |
) | |
# 引数から渡された学習率や減衰関連の設定を保存。 | |
self.lr = lr | |
self.lr_decay = lr_decay | |
self.lr_decay_interval = lr_decay_interval | |
# 推論時に Seq2SeqTranscriber の推論機能(infer メソッド)を呼び出す関数 | |
def forward(self, y: torch.Tensor): | |
transcriber_infer = self.transcriber.infer(y) # 入力波形(y)を Seq2SeqTranscriber の infer メソッドに渡して推論を実行。 | |
return transcriber_infer #推論結果として生成されたトークン列を返します。 | |
# モデルのトレーニング中に1バッチの処理を実行する関数 | |
def training_step(self, batch, batch_idx): | |
y, t = batch # 入力波形(y)と正解トークン列(t)を分割。 | |
tf_out = self.transcriber(y, t) # 入力データを Seq2SeqTranscriber に渡してtranscriberのforwardで損失値(loss)を取得。出力は、損失値やロジット(各トークンのスコア)を含むオブジェクト。 | |
loss = tf_out.loss # T5モデルの出力に含まれる損失値(loss)。学習時の損失計算は、CrossEntropyLoss に基づいています。 | |
t = t.detach() # 正解トークン(t)を計算グラフから切り離して、後続の計算が逆伝播に影響しないようにする。 | |
mask = t != self.voc_dict["pad"] # 正解トークン列のうち、パディングトークン(pad)を無視するためのマスクを作成。 | |
# モデルの出力(logits)から最も高いスコアのトークン(argmax(-1))を取得し、マスクされた正解トークン(t[mask])と比較。 | |
# 正しく予測されたトークンの数を、マスクされたトークン全体の数で割って精度を計算。 | |
accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum() | |
#損失値(loss)と精度(accr)を PyTorch Lightning のログ機能を使って記録。 | |
self.log("train_loss", loss) | |
self.log("train_accr", accr) | |
return loss # 計算された損失を返します。この値はPyTorch Lightningがバックプロパゲーションを行うために使用します。 | |
# モデルの検証時に1バッチの処理を実行する関数 | |
def validation_step(self, batch, batch_idx): | |
assert not self.transcriber.training # モデルがトレーニングモードでないことを確認 | |
y, t = batch | |
tf_out = self.transcriber(y, t) | |
loss = tf_out.loss | |
t = t.detach() | |
mask = t != self.voc_dict["pad"] | |
accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum() | |
self.log("vali_loss", loss) | |
self.log("vali_accr", accr) | |
return loss | |
def test_step(self, batch, batch_idx): | |
y, ref_midi = batch # 音声データ(y)と参照MIDI(ref_midi)を分割 | |
# テストデータセットでは、通常1つのサンプルを1バッチとして処理するため、インデックス0の要素を取り出します。 | |
y = y[0] | |
ref_midi = ref_midi[0] | |
with torch.no_grad(): # 推論中は勾配計算が不要なため、この文脈を使ってメモリ消費を削減。 | |
est_tokens = self.forward(y) # 入力音声 y を用いてモデルに推論を実行し、MIDIトークン列(est_tokens)を生成。モデルの forward メソッドを呼び出して推論を実行。この時forward内では分割されて実行される, 形状: (セグメント数, 最大トークン長) | |
unpadded_tokens = unpack_sequence(est_tokens.cpu().numpy()) # 推論結果(est_tokens)からパディングを取り除き、有効なトークン列だけを取得。GPU上のテンソルをCPUに移動してNumPy配列に変換 | |
unpadded_tokens = [t[1:] for t in unpadded_tokens] # 各トークン列から開始トークン(<sos>など)を除外。 | |
est_midi = token_seg_list_to_midi(unpadded_tokens) # パディングを除去したトークン列(unpadded_tokens)をMIDI形式に変換。 | |
dict_eval = evaluate_midi(est_midi, ref_midi) # 参照MIDI(ref_midi)と推論MIDI(est_midi)を比較し、性能を評価。 | |
# 評価結果をPyTorch Lightningのログ機能で記録。 | |
dict_log = {} | |
for key in dict_eval: | |
dict_log["test/" + key] = dict_eval[key] | |
self.log_dict(dict_log, batch_size=1) | |
def train_dataloader(self): | |
# 事前に定義されたデータセットクラス。音声(波形データ)とMIDIファイルをペアとしてロードします。 | |
dset = CustomDataset( | |
split="train", # データセットの分割 | |
sample_rate=16000, | |
apply_pedal=True, | |
whole_song=False | |
) | |
# データローダーの作成 | |
return data.DataLoader( | |
dataset=dset, # 使用するデータセットとして、初期化済みの Maestro を指定。 | |
collate_fn=dset.collate_batch, # バッチ処理時のデータ整形関数。 | |
batch_size=64, # 1バッチに含めるサンプル数。 | |
shuffle=True, # データセットをランダムにシャッフル。 | |
pin_memory=True, # データをピン留め(ページロック)してメモリに保存。 | |
num_workers=32, # データローディングに使用するプロセス数。 | |
) | |
def val_dataloader(self): | |
# 事前に定義されたデータセットクラス。音声(波形データ)とMIDIファイルをペアとしてロードします。 | |
dset = CustomDataset( | |
split="valid", # データセットの分割 | |
sample_rate=16000, | |
apply_pedal=True, | |
whole_song=False | |
) | |
# データローダーの作成 | |
return data.DataLoader( | |
dataset=dset, # 使用するデータセットとして、初期化済みの Maestro を指定。 | |
collate_fn=dset.collate_batch, # バッチ処理時のデータ整形関数。 | |
batch_size=32, # 1バッチに含めるサンプル数。 | |
shuffle=True, # データセットをランダムにシャッフル。 | |
pin_memory=True, # データをピン留め(ページロック)してメモリに保存。 | |
num_workers=16, # データローディングに使用するプロセス数。 | |
) | |
def test_dataloader(self): | |
dset = CustomDataset( | |
split="test", # データセットの分割 | |
sample_rate=16000, | |
apply_pedal=True, | |
whole_song=True | |
) | |
return data.DataLoader( | |
dataset=dset, | |
collate_fn=dset.collate_wholesong, | |
batch_size=1, | |
shuffle=False, # テストデータの順序を固定。理由: 再現性を確保するため。 | |
pin_memory=True, | |
) | |
# モデルの学習時に、損失関数の値を最小化するようにパラメータを更新するオプティマイザ(最適化手法)を指定する関数 | |
def configure_optimizers(self): | |
return torch.optim.AdamW(self.parameters(), lr=self.lr) | |
# Adamオプティマイザ の改良版であり、L2正則化の代わりに ウェイトデカイ(Weight Decay) を導入したオプティマイザ。 | |
# self.parameters(): モデルの学習可能なすべてのパラメータ(重みとバイアス)を取得。 | |
from lightning.pytorch.callbacks import ModelCheckpoint | |
from lightning.pytorch.loggers import TensorBoardLogger | |
# チェックポイント保存の設定 | |
if __name__ == "__main__": | |
checkpoint_callback = ModelCheckpoint( | |
dirpath="/content/drive/MyDrive/B4/Humtrans/m2m100/checkpoints", # 保存先 | |
filename="epoch{epoch:02d}-vali_loss{vali_loss:.2f}", # ファイル名フォーマット | |
save_top_k=3, # 上位3つのチェックポイントを保存 | |
monitor="vali_loss", # 訓練データの損失を監視 | |
mode="min", # 損失が小さいほど良いモデル | |
every_n_epochs=1, # 1エポックごとに保存 | |
) | |
# ロガー設定 | |
logger = TensorBoardLogger( | |
save_dir="/content/drive/MyDrive/B4/Humtrans/m2m100/logs", | |
name="transcription" | |
) | |
# トレーナーの設定 | |
trainer = pl.Trainer( | |
logger=logger, | |
callbacks=[checkpoint_callback], # チェックポイントコールバックを追加 | |
enable_checkpointing=True, | |
accelerator="gpu", | |
devices=1, | |
max_epochs=400, | |
) | |
# モデルの引数 | |
args = { | |
"n_mels": 128, | |
"sample_rate": 16000, | |
"n_fft": 1024, | |
"hop_length": 128, | |
} | |
lightning_module = LitTranscriber( | |
transcriber_args=args, | |
lr=1e-4, # 0.0001 | |
lr_decay=0.99, # 各エポックの度に0.99倍して減衰させる | |
) | |
# 学習の開始 | |
trainer.fit(lightning_module) | |