import os import numpy as np import pandas as pd import torch import torchaudio import tqdm import torch.utils.data as data from utils import MIDITokenExtractor from config import voc_single_track from config import FRAME_PER_SEC, FRAME_STEP_SIZE_SEC, AUDIO_SEGMENT_SEC, SEGMENT_N_FRAMES """# Dataset Uses MAESTRO v3.0.0 dataset. """ class AMTDatasetBase(data.Dataset): def __init__( self, flist_audio, # オーディオファイルのパスをリスト形式で渡す flist_midi, # MIDIファイルのパスをリスト形式で渡す sample_rate, # オーディオファイルのサンプリングレートを指定。全てのオーディオがこれにリサンプリングされる。 voc_dict, # トークン定義を渡す apply_pedal=True, whole_song=False, ): super().__init__() self.midi_filelist = flist_midi self.audio_filelist = flist_audio self.audio_metalist = [torchaudio.info(f) for f in flist_audio] # 各オーディオファイルのメタ情報(サンプルレート、フレーム数など)を収集します。 self.voc_dict = voc_dict # 各MIDIファイルを MIDITokenExtractor を使ってトークン化し、その結果をリストとして保持します。 self.midi_list = [ MIDITokenExtractor(f, voc_dict, apply_pedal) for f in tqdm.tqdm(self.midi_filelist, desc="load dataset") ] self.sample_rate = sample_rate self.whole_song = whole_song def __len__(self): return len(self.audio_filelist) def __getitem__(self, index): """ Return a pair of (audio, tokens) for the given index. On the training stage, return a random segment from the song. On the test stage, return the audio and MIDI of the whole song. """ if not self.whole_song: return self.getitem_segment(index) else: return self.getitem_wholesong(index) def getitem_segment(self, index, start_pos=None): # 対象ファイルを指定するindexとセグメントの開始位置(フレーム単位)。Noneの場合はランダムに選択 metadata = self.audio_metalist[index] num_frames = metadata.num_frames # オーディオの全体の「サンプル数」。 sample_rate = metadata.sample_rate duration_y = round(num_frames / float(sample_rate) * FRAME_PER_SEC) # オーディオ全体の長さをフレーム単位に変換 midi_item = self.midi_list[index] # セグメントの開始位置と終了位置(フレーム単位)を決定。 if start_pos is None: # np.random.randint を使用して、オーディオ全体からランダムに開始位置を選択。 segment_start = np.random.randint(duration_y - SEGMENT_N_FRAMES) else: # start_pos が指定されている場合 segment_start = start_pos segment_end = segment_start + SEGMENT_N_FRAMES # オーディオセグメントのサンプル単位の開始位置 segment_start_sample = round( segment_start * FRAME_STEP_SIZE_SEC * sample_rate ) # セグメント範囲(segment_start ~ segment_end)に対応するMIDIトークン列を抽出。 segment_tokens = midi_item.get_segment_tokens(segment_start, segment_end) segment_tokens = torch.from_numpy(segment_tokens).long() # NumPy配列をPyTorchテンソルに変換。long()でテンソルのデータ型を64ビット整数(long)に設定。 # 指定されたセグメント範囲のオーディオデータを読み込む。 # frame_offset から始まる範囲を num_frames サンプル分読み込む。 y_segment, _ = torchaudio.load( self.audio_filelist[index], frame_offset=segment_start_sample, num_frames=round(AUDIO_SEGMENT_SEC * sample_rate), ) y_segment = y_segment.mean(0) # オーディオが複数チャンネルの場合(例: ステレオ)、チャンネルを平均してモノラルに変換。 # サンプルレートのリサンプリング # オーディオデータのサンプルレートが self.sample_rate と異なる場合、指定されたサンプルレートにリサンプリング。 if sample_rate != self.sample_rate: y_segment = torchaudio.functional.resample( y_segment, sample_rate, self.sample_rate, resampling_method="kaiser_window", # Kaiserウィンドウによるリサンプリングアルゴリズムを適用。 ) return y_segment, segment_tokens def getitem_wholesong(self, index): """ Return a pair of (audio, midi) for the given index. """ y, sr = torchaudio.load(self.audio_filelist[index]) # 読み込まれた波形データ(テンソル形式)。形状は (チャンネル数, サンプル数)。 y = y.mean(0) # モノラル化 # サンプルレートのリサンプリング if sr != self.sample_rate: y = torchaudio.functional.resample( y, sr, self.sample_rate, resampling_method="kaiser_window" ) midi = self.midi_list[index].pm return y, midi # collateはバッチにまとめる役割の関数 def collate_wholesong(self, batch): # batch: データセットから取り出された複数のデータ(オーディオとMIDIのペア)のリスト。 # b[0]で各データペアの0番目の要素、つまりオーディオデータを取り出す。 # torch.stack([...], dim=0): 複数のテンソルを新しい次元(バッチ次元)で結合。 # 出力: テンソルの形状は (バッチサイズ, サンプル数)。 batch_audio = torch.stack([b[0] for b in batch], dim=0) midi = [b[1] for b in batch] # バッチ内の各曲のMIDIデータをリストとしてまとめる。 return batch_audio, midi # テンソル, リスト def collate_batch(self, batch): # データセットから取り出されたセグメント化されたオーディオテンソルとセグメント化されたMIDIトークン列のリスト。 # b[0]で各データペアの0番目の要素、つまりオーディオデータを取り出す。 # torch.stack([...], dim=0): 複数のテンソルを新しい次元(バッチ次元)で結合。 # 出力: テンソルの形状は (バッチサイズ, サンプル数)。 batch_audio = torch.stack([b[0] for b in batch], dim=0) batch_tokens = [b[1] for b in batch] # バッチ内の各セグメントのトークン列をテンソル?リスト形式で取得。 # バッチ内のMIDIトークン列の長さを揃えるためにパディング # torch.nn.utils.rnn.pad_sequence は、異なる長さのシーケンス(テンソルリスト)をパディングして同じ長さに揃えるためのPyTorchユーティリティ(すべてのテンソルは同じ次元数である必要があります(長さ以外は一致)。) # batch_first = True: パディング後のテンソル形状を (バッチサイズ, 最大長さ) に設定 batch_tokens_pad = torch.nn.utils.rnn.pad_sequence( batch_tokens, batch_first=True, padding_value=self.voc_dict["pad"] ) return batch_audio, batch_tokens_pad # テンソル, テンソル (バッチサイズ, サンプル数), (バッチサイズ, 最大トークンの長さ) class CustomDataset(AMTDatasetBase): def __init__( self, midi_root: str = "/content/drive/MyDrive/B4/Humtrans/midi", wav_root: str = "/content/wav_rms", split: str = "train", sample_rate: int = 16000, apply_pedal: bool = True, whole_song: bool = False, ): """ MIDIとWAVのペアをロードするデータセットクラス Args: midi_root (str): MIDIファイルが保存されているルートフォルダ wav_root (str): WAVファイルが保存されているフォルダ split (str): 使用するデータセットの分割 ('train', 'valid', 'test') sample_rate (int): サンプルレート apply_pedal (bool): ペダルの適用 whole_song (bool): 曲全体をロードするか """ # MIDIフォルダのパスを設定 self.midi_root = f"/content/filtered_{split}_midi" self.wav_root = wav_root self.sample_rate = sample_rate self.split = split # MIDIとWAVのペアを見つける flist_midi, flist_audio = self._get_paired_files() # 親クラスのコンストラクタを呼び出し super().__init__( flist_audio, flist_midi, sample_rate, voc_dict=voc_single_track, apply_pedal=apply_pedal, whole_song=whole_song, ) def _get_paired_files(self): """ MIDIフォルダとWAVフォルダからペアとなるファイルリストを作成する Returns: flist_midi (list): 対応するMIDIファイルのリスト flist_audio (list): 対応するWAVファイルのリスト """ flist_midi = [] flist_audio = [] # MIDIフォルダからMIDIファイルを取得 midi_files = [f for f in os.listdir(self.midi_root) if f.endswith(".mid")] for midi_file in midi_files: # MIDIファイルのパスを構築 midi_path = os.path.join(self.midi_root, midi_file) # WAVファイルのパスを構築 (拡張子を変更) wav_file = os.path.splitext(midi_file)[0] + ".wav" wav_path = os.path.join(self.wav_root, wav_file) # WAVファイルが存在するか確認 if os.path.exists(wav_path): flist_midi.append(midi_path) flist_audio.append(wav_path) else: print(f"対応するWAVファイルが見つかりません: {midi_file}") print(f"{self.split}データセット: {len(flist_midi)} ペアのMIDI-WAVが見つかりました。") return flist_midi, flist_audio