hayaton0005 commited on
Commit
c094356
·
verified ·
1 Parent(s): 93cea17

Upload 11 files

Browse files
Files changed (11) hide show
  1. app.py +25 -0
  2. config.py +27 -0
  3. dataset.py +208 -0
  4. eval.py +34 -0
  5. infer.py +65 -0
  6. main.py +6 -0
  7. model.pth +3 -0
  8. model.py +65 -0
  9. requirements.txt +29 -0
  10. train.py +188 -0
  11. utils.py +345 -0
app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ import os
5
+ from infer import infer_midi_from_wav
6
+
7
+ def predict_midi(wav_file):
8
+ # 一時ファイルに保存
9
+ input_path = "input.wav"
10
+ wav_file.save(input_path)
11
+
12
+ midi_path = infer_midi_from_wav(input_path)
13
+
14
+ return midi_path # Gradioはファイルのパスを返すとダウンロードリンクとして表示される
15
+
16
+ iface = gr.Interface(
17
+ fn=predict_midi,
18
+ inputs=gr.Audio(source="upload", type="file"),
19
+ outputs=gr.File(label="Download MIDI"),
20
+ title="🎵 My Melody: 鼻歌→MIDI変換",
21
+ description="鼻歌をアップロードすると、MIDIファイルがダウンロードできます。"
22
+ )
23
+
24
+ if __name__ == "__main__":
25
+ iface.launch()
config.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lightning.pytorch as pl
2
+
3
+ SR = 16000
4
+ AUDIO_SEGMENT_SEC = 2.0
5
+ SEGMENT_N_FRAMES = 200
6
+ FRAME_STEP_SIZE_SEC = 0.01
7
+ FRAME_PER_SEC = 100
8
+
9
+ pl.seed_everything(1234)
10
+
11
+ """# MIDI-like token"""
12
+
13
+ # MIDI-like token definitions
14
+
15
+ N_NOTE = 128
16
+ N_TIME = 205
17
+ N_SPECIAL = 3
18
+
19
+ voc_single_track = {
20
+ "pad": 0,
21
+ "eos": 1,
22
+ "endtie": 2,
23
+ "note": N_SPECIAL, # 3から始まるという意味
24
+ "onset": N_SPECIAL+N_NOTE,
25
+ "time": N_SPECIAL+N_NOTE+2,
26
+ "n_voc": N_SPECIAL+N_NOTE+2+N_TIME+3,
27
+ "keylist": ["pad", "eos", "endtie", "note", "onset", "time"]}
dataset.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ import torchaudio
6
+ import tqdm
7
+
8
+ import torch.utils.data as data
9
+
10
+ from utils import MIDITokenExtractor
11
+ from config import voc_single_track
12
+ from config import FRAME_PER_SEC, FRAME_STEP_SIZE_SEC, AUDIO_SEGMENT_SEC, SEGMENT_N_FRAMES
13
+
14
+ """# Dataset
15
+ Uses MAESTRO v3.0.0 dataset.
16
+ """
17
+
18
+ class AMTDatasetBase(data.Dataset):
19
+ def __init__(
20
+ self,
21
+ flist_audio, # オーディオファイルのパスをリスト形式で渡す
22
+ flist_midi, # MIDIファイルのパスをリスト形式で渡す
23
+ sample_rate, # オーディオファイルのサンプリングレートを指定。全てのオーディオがこれにリサンプリングされる。
24
+ voc_dict, # トークン定義を渡す
25
+ apply_pedal=True,
26
+ whole_song=False,
27
+ ):
28
+ super().__init__()
29
+ self.midi_filelist = flist_midi
30
+ self.audio_filelist = flist_audio
31
+ self.audio_metalist = [torchaudio.info(f) for f in flist_audio] # 各オーディオファイルのメタ情報(サンプルレート、フレーム数など)を収集します。
32
+ self.voc_dict = voc_dict
33
+ # 各MIDIファイルを MIDITokenExtractor を使ってトークン化し、その結果をリストとして保持します。
34
+ self.midi_list = [
35
+ MIDITokenExtractor(f, voc_dict, apply_pedal)
36
+ for f in tqdm.tqdm(self.midi_filelist, desc="load dataset")
37
+ ]
38
+ self.sample_rate = sample_rate
39
+ self.whole_song = whole_song
40
+
41
+ def __len__(self):
42
+ return len(self.audio_filelist)
43
+
44
+ def __getitem__(self, index):
45
+ """
46
+ Return a pair of (audio, tokens) for the given index.
47
+ On the training stage, return a random segment from the song.
48
+ On the test stage, return the audio and MIDI of the whole song.
49
+ """
50
+ if not self.whole_song:
51
+ return self.getitem_segment(index)
52
+ else:
53
+ return self.getitem_wholesong(index)
54
+
55
+ def getitem_segment(self, index, start_pos=None): # 対象ファイルを指定するindexとセグメントの開始位置(フレーム単位)。Noneの場合はランダムに選択
56
+ metadata = self.audio_metalist[index]
57
+ num_frames = metadata.num_frames # オーディオの全体の「サンプル数」。
58
+ sample_rate = metadata.sample_rate
59
+ duration_y = round(num_frames / float(sample_rate) * FRAME_PER_SEC) # オーディオ全体の長さをフレーム単位に変換
60
+ midi_item = self.midi_list[index]
61
+
62
+ # セグメントの開始位置と終了位置(フレーム単位)を決定。
63
+ if start_pos is None: # np.random.randint を使用して、オーディオ全体からランダムに開始位置を選択。
64
+ segment_start = np.random.randint(duration_y - SEGMENT_N_FRAMES)
65
+ else: # start_pos が指定されている場合
66
+ segment_start = start_pos
67
+ segment_end = segment_start + SEGMENT_N_FRAMES
68
+ # オーディオセグメントのサンプル単位の開始位置
69
+ segment_start_sample = round(
70
+ segment_start * FRAME_STEP_SIZE_SEC * sample_rate
71
+ )
72
+
73
+ # セグメント範囲(segment_start ~ segment_end)に対応するMIDIトークン列を抽出。
74
+ segment_tokens = midi_item.get_segment_tokens(segment_start, segment_end)
75
+ segment_tokens = torch.from_numpy(segment_tokens).long() # NumPy配列をPyTorchテンソルに変換。long()でテンソルのデータ型を64ビット整数(long)に設定。
76
+
77
+ # 指定されたセグメント範囲のオーディオデータを読み込む。
78
+ # frame_offset から始まる範囲を num_frames サンプル分読み込む。
79
+ y_segment, _ = torchaudio.load(
80
+ self.audio_filelist[index],
81
+ frame_offset=segment_start_sample,
82
+ num_frames=round(AUDIO_SEGMENT_SEC * sample_rate),
83
+ )
84
+ y_segment = y_segment.mean(0) # オーディオが複数チャンネルの場合(例: ステレオ)、チャンネルを平均してモノラルに変換。
85
+
86
+ # サンプルレートのリサンプリング
87
+ # オーディオデータのサンプルレートが self.sample_rate と異なる場合、指定されたサンプルレートにリサンプリング。
88
+ if sample_rate != self.sample_rate:
89
+ y_segment = torchaudio.functional.resample(
90
+ y_segment,
91
+ sample_rate,
92
+ self.sample_rate,
93
+ resampling_method="kaiser_window", # Kaiserウィンドウによるリサンプリングアルゴリズムを適用。
94
+ )
95
+ return y_segment, segment_tokens
96
+
97
+ def getitem_wholesong(self, index):
98
+ """
99
+ Return a pair of (audio, midi) for the given index.
100
+ """
101
+ y, sr = torchaudio.load(self.audio_filelist[index]) # 読み込まれた波形データ(テンソル形���)。形状は (チャンネル数, サンプル数)。
102
+ y = y.mean(0) # モノラル化
103
+ # サンプルレートのリサンプリング
104
+ if sr != self.sample_rate:
105
+ y = torchaudio.functional.resample(
106
+ y, sr, self.sample_rate,
107
+ resampling_method="kaiser_window"
108
+ )
109
+ midi = self.midi_list[index].pm
110
+ return y, midi
111
+
112
+ # collateはバッチにまとめる役割の関数
113
+ def collate_wholesong(self, batch): # batch: データセットから取り出された複数のデータ(オーディオとMIDIのペア)のリスト。
114
+ # b[0]で各データペアの0番目の要素、つまりオーディオデータを取り出す。
115
+ # torch.stack([...], dim=0): 複数のテンソルを新しい次元(バッチ次元)で結合。
116
+ # 出力: テンソルの形状は (バッチサイズ, サンプル数)。
117
+ batch_audio = torch.stack([b[0] for b in batch], dim=0)
118
+ midi = [b[1] for b in batch] # バッチ内の各曲のMIDIデータをリストとしてまとめる。
119
+ return batch_audio, midi # テンソル, リスト
120
+
121
+ def collate_batch(self, batch): # データセットから取り出されたセグメント化されたオーディオテンソルとセグメント化されたMIDIトークン列のリスト。
122
+ # b[0]で各データペアの0番目の要素、つまりオーディオデータを取り出す。
123
+ # torch.stack([...], dim=0): 複数のテンソルを新しい次元(バッチ次元)で結合。
124
+ # 出力: テンソルの形状は (バッチサイズ, サンプル数)。
125
+ batch_audio = torch.stack([b[0] for b in batch], dim=0)
126
+ batch_tokens = [b[1] for b in batch] # バッチ内の各セグメントのトークン列をテンソル?リスト形式で取得。
127
+
128
+ # バッチ内のMIDIトークン列の長さを揃えるためにパディング
129
+ # torch.nn.utils.rnn.pad_sequence は、異なる長さのシーケンス(テンソルリスト)をパディングして同じ長さに揃えるためのPyTorchユーティリティ(すべてのテンソルは同じ次元数である必要があります(長さ以外は一致)。)
130
+ # batch_first = True: パディング後のテンソル形状を (バッチサイズ, 最大長さ) に設定
131
+ batch_tokens_pad = torch.nn.utils.rnn.pad_sequence(
132
+ batch_tokens, batch_first=True, padding_value=self.voc_dict["pad"]
133
+ )
134
+ return batch_audio, batch_tokens_pad # テンソル, テンソル (バッチサイズ, サンプル数), (バッチサイズ, 最大トークンの長さ)
135
+
136
+
137
+ class CustomDataset(AMTDatasetBase):
138
+ def __init__(
139
+ self,
140
+ midi_root: str = "/content/drive/MyDrive/B4/Humtrans/midi",
141
+ wav_root: str = "/content/wav_rms",
142
+ split: str = "train",
143
+ sample_rate: int = 16000,
144
+ apply_pedal: bool = True,
145
+ whole_song: bool = False,
146
+ ):
147
+ """
148
+ MIDIとWAVのペアをロードするデータセットクラス
149
+
150
+ Args:
151
+ midi_root (str): MIDIファイルが保存されているルートフォルダ
152
+ wav_root (str): WAVファイルが保存されているフォルダ
153
+ split (str): 使用するデータセットの分割 ('train', 'valid', 'test')
154
+ sample_rate (int): サンプルレート
155
+ apply_pedal (bool): ペダルの適用
156
+ whole_song (bool): 曲全体をロードするか
157
+ """
158
+ # MIDIフォルダのパスを設定
159
+ self.midi_root = f"/content/filtered_{split}_midi"
160
+ self.wav_root = wav_root
161
+ self.sample_rate = sample_rate
162
+ self.split = split
163
+
164
+ # MIDIとWAVのペアを見つける
165
+ flist_midi, flist_audio = self._get_paired_files()
166
+
167
+ # 親クラスのコンストラクタを呼び出し
168
+ super().__init__(
169
+ flist_audio,
170
+ flist_midi,
171
+ sample_rate,
172
+ voc_dict=voc_single_track,
173
+ apply_pedal=apply_pedal,
174
+ whole_song=whole_song,
175
+ )
176
+
177
+
178
+ def _get_paired_files(self):
179
+ """
180
+ MIDIフォルダとWAVフォルダからペアとなるファイルリストを作成する
181
+
182
+ Returns:
183
+ flist_midi (list): 対応するMIDIファイルのリスト
184
+ flist_audio (list): 対応するWAVファイルのリスト
185
+ """
186
+ flist_midi = []
187
+ flist_audio = []
188
+
189
+ # MIDIフォルダからMIDIファイルを取得
190
+ midi_files = [f for f in os.listdir(self.midi_root) if f.endswith(".mid")]
191
+
192
+ for midi_file in midi_files:
193
+ # MIDIファイルのパスを構築
194
+ midi_path = os.path.join(self.midi_root, midi_file)
195
+
196
+ # WAVファイルのパスを構築 (拡張子を変更)
197
+ wav_file = os.path.splitext(midi_file)[0] + ".wav"
198
+ wav_path = os.path.join(self.wav_root, wav_file)
199
+
200
+ # WAVファイルが存在するか確認
201
+ if os.path.exists(wav_path):
202
+ flist_midi.append(midi_path)
203
+ flist_audio.append(wav_path)
204
+ else:
205
+ print(f"対応するWAVファイルが見つかりません: {midi_file}")
206
+
207
+ print(f"{self.split}データセット: {len(flist_midi)} ペアのMIDI-WAVが見つかりました。")
208
+ return flist_midi, flist_audio
eval.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pretty_midi as pm
3
+ import mir_eval
4
+
5
+ """# Evaluation function"""
6
+
7
+ def extract_midi(midi: pm.PrettyMIDI, program=0): # MIDIデータを読み込んだ PrettyMIDI オブジェクト, MIDIチャンネル(楽器番号)を指定。
8
+ intervals = [] # 音符ごとの開始時間と終了時間のペアを格納したNumPy配列
9
+ pitches = [] # 音符ごとの音高(MIDIノート番号)のNumPy配列。
10
+ pm_notes = midi.instruments[program].notes # programで指定された対象楽器に含まれる全てのノート情報を取得。
11
+ """
12
+ 例;
13
+ instruments = [
14
+ Instrument 0 (Piano): [Note(start=0.5, end=1.0, pitch=60), ...],
15
+ Instrument 1 (Violin): [Note(start=1.0, end=1.5, pitch=62), ...]
16
+ ]
17
+ """
18
+ # ノートを順番に処理
19
+ for note in pm_notes:
20
+ intervals.append((note.start, note.end)) # 音符の開始・終了時間のペアを intervals に追加。
21
+ pitches.append(note.pitch) # 音符の音高を pitches に追加。
22
+
23
+ return np.array(intervals), np.array(pitches) # intervals: 2D配列(各行が1つの音符の開始・終了時間を表す。), pitches: 1D配列(各要素が1つの音符の音高(ピッチ)を表す。)
24
+
25
+
26
+ def evaluate_midi(est_midi: pm.PrettyMIDI, ref_midi: pm.PrettyMIDI, program=0):
27
+ est_intervals, est_pitches = extract_midi(est_midi, program)
28
+ ref_intervals, ref_pitches = extract_midi(ref_midi, program)
29
+
30
+ # mir_eval ライブラリの transcription モジュールを使って、音符の一致度を評価します。
31
+ dict_eval = mir_eval.transcription.evaluate(
32
+ ref_intervals, ref_pitches, est_intervals, est_pitches, onset_tolerance=0.05)
33
+
34
+ return dict_eval # dict_eval: 評価結果の辞書。
infer.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ import os
4
+ import soundfile as sf
5
+ import librosa
6
+
7
+ from utils import unpack_sequence, token_seg_list_to_midi
8
+ from train import LitTranscriber
9
+ from utils import rms_normalize_wav
10
+
11
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) # backend/src を指す
12
+ PTH_PATH = os.path.join(BASE_DIR, "model.pth") # ✅ .pth に変更
13
+
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ def load_model():
17
+ args = {
18
+ "n_mels": 128,
19
+ "sample_rate": 16000,
20
+ "n_fft": 1024,
21
+ "hop_length": 128,
22
+ }
23
+ model = LitTranscriber(transcriber_args=args, lr=1e-4, lr_decay=0.99)
24
+ state_dict = torch.load(PTH_PATH, map_location=device) # ✅ .pthをロード
25
+ model.load_state_dict(state_dict)
26
+ #model.to(device) # ✅ デバイスに転送
27
+ model.eval()
28
+ return model
29
+
30
+
31
+
32
+ def convert_to_pcm_wav(input_path, output_path):
33
+ # librosaで読み込み(自動的にPCM形式に変換される)
34
+ y, sr = librosa.load(input_path, sr=16000, mono=True)
35
+ sf.write(output_path, y, sr)
36
+
37
+
38
+ def infer_midi_from_wav(input_wav_path: str) -> str:
39
+ model = load_model()
40
+
41
+ converted_path = os.path.join(BASE_DIR, "converted_input.wav")
42
+ convert_to_pcm_wav(input_wav_path, converted_path)
43
+
44
+ normalized_path = os.path.join(BASE_DIR, "tmp_normalized.wav")
45
+ rms_normalize_wav(converted_path, normalized_path, target_rms=0.1)
46
+
47
+ waveform, sr = torchaudio.load(normalized_path)
48
+ waveform = waveform.mean(0).to(device)
49
+
50
+ if sr != model.transcriber.sr:
51
+ waveform = torchaudio.functional.resample(
52
+ waveform, sr, model.transcriber.sr
53
+ ).to(device)
54
+
55
+ with torch.no_grad():
56
+ output_tokens = model(waveform)
57
+
58
+ unpadded_tokens = unpack_sequence(output_tokens.cpu().numpy())
59
+ unpadded_tokens = [t[1:] for t in unpadded_tokens]
60
+ est_midi = token_seg_list_to_midi(unpadded_tokens)
61
+
62
+ midi_path = os.path.join(BASE_DIR, "output.mid")
63
+ est_midi.write(midi_path)
64
+ print(f"MIDI saved at: {midi_path}")
65
+ return midi_path
main.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # main.py
2
+ from infer import infer_midi_from_wav
3
+
4
+ input_wav = "/Users/yagihayato/Downloads/my-melody-app/backend/data/humming.wav"
5
+ midi_path = infer_midi_from_wav(input_wav)
6
+ print("推論完了: ", midi_path)
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02c7d22a827f8ae20f7271c9d8ae4d07c082a73a6663ea1379e7361ac2cea759
3
+ size 45979110
model.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from utils import split_audio_into_segments
4
+ from utils import LogMelspec
5
+ from utils import LogCQT
6
+ from transformers import T5Config, T5ForConditionalGeneration
7
+
8
+
9
+ class Seq2SeqTranscriber(nn.Module):
10
+ def __init__(
11
+ self, n_mels: int, sample_rate: int, n_fft: int, hop_length: int, voc_dict: dict
12
+ ):
13
+ super().__init__()
14
+ self.infer_max_len = 200 # 推論時の最大シーケンス長。
15
+ self.voc_dict = voc_dict # トークン辞書と
16
+ self.n_voc_token = voc_dict["n_voc"] # トークンの数を保持。
17
+ self.t5config = T5Config.from_pretrained("google/t5-v1_1-small") # Googleの事前学習済み T5 モデル(小型バージョン)の設定をロード。
18
+ # カスタム設定を T5 の設定に追加:
19
+ custom_configs = {
20
+ "vocab_size": self.n_voc_token, # トークン辞書のサイズ。
21
+ "pad_token_id": voc_dict["pad"], # パディングトークンID。
22
+ "d_model": 96, # モデルの隠れ次元数(ここではメルバンド数に設定)。
23
+ }
24
+
25
+ for k, v in custom_configs.items():
26
+ self.t5config.__setattr__(k, v)
27
+
28
+ self.transformer = T5ForConditionalGeneration(self.t5config) # カスタム設定を適用した T5 モデルをロード。
29
+ # self.melspec = LogMelspec(sample_rate, n_fft, n_mels, hop_length) # LogMelspec クラスを使用して、音声波形を対数メルスペクトログラムに変換するモジュールを作成。
30
+ # CQT モデルインスタンス作成
31
+ self.log_cqt = LogCQT(16000, 84, 128, 12)
32
+ self.sr = sample_rate # サンプルレートをインスタンス変数として保存。
33
+
34
+ # モデルの学習時に呼び出され、損失値を計算します。
35
+ def forward(self, wav, labels):
36
+ # spec = self.melspec(wav).transpose(-1, -2) # 音声波形(wav)をメルスペクトログラム(spec)に変換。LogMelspec クラスのforwardを実行
37
+ spec = self.log_cqt(wav).transpose(-1, -2)
38
+ # ※ .transpose(-1, -2): T5 モデルは通常 [バッチ, 時間ステップ, 次元] の形状を期待するため、周波数軸(メルバンド)と時間軸を入れ替えます。
39
+ # T5 モデルのフォワードパス
40
+ print("sepc.shape: ", spec.shape) # (1, n_bins, time)
41
+ outs = self.transformer.forward(
42
+ inputs_embeds=spec, return_dict=True, labels=labels
43
+ )
44
+ return outs # outs は辞書形式で損失値や出力トークン列を含む。
45
+
46
+ # 入力音声波形(wav)から推定トークン列を生成する関数
47
+ def infer(self, wav):
48
+ """
49
+ Infer the transcription of a single audio file.
50
+ The input audio file is split into segments of 2 seconds
51
+ before passing to the transformer.
52
+ """
53
+ wav_segs = split_audio_into_segments(wav, self.sr) # 音声波形を固定長(例: 2秒)に分割。
54
+ #spec = self.melspec(wav_segs).transpose(-1, -2) # 各セグメントをメルスペクトログラムに変換。
55
+ spec = self.log_cqt(wav_segs).transpose(-1, -2)
56
+ # generate: T5 モデルの推論モードを使用して、トークン列を生成。
57
+ outs = self.transformer.generate(
58
+ inputs_embeds=spec,
59
+ max_length=self.infer_max_len, # 推論時の最大出力長。
60
+ num_beams=5, # ビームサーチを無効化し、単純なグリーディーサーチ。
61
+ do_sample=False, # サンプリングを無効化。
62
+ return_dict_in_generate=False,
63
+ )
64
+ return outs #推論結果として生成されたトークン列を返します。 #形状: (セグメント数, 最大トークン長)
65
+
requirements.txt ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Flask関連
2
+ flask
3
+ flask-cors
4
+
5
+ # 推論・学習用ライブラリ
6
+ torch
7
+ torchaudio
8
+ pytorch-lightning
9
+ lightning
10
+
11
+
12
+ # データ前処理
13
+ numpy
14
+ pandas
15
+ soundfile
16
+ tqdm
17
+ librosa
18
+ soundfile
19
+ gradio
20
+
21
+
22
+ # MIDI・音楽処理
23
+ pretty_midi
24
+ mir_eval
25
+
26
+ # モデル構成・特徴抽出
27
+ nnAudio
28
+ transformers
29
+
train.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import pytorch_lightning as pl
4
+ import torch.utils.data as data
5
+
6
+ from utils import voc_single_track
7
+ from model import Seq2SeqTranscriber
8
+ from dataset import CustomDataset
9
+ from utils import token_seg_list_to_midi, unpack_sequence
10
+ from eval import evaluate_midi
11
+
12
+ class LitTranscriber(pl.LightningModule):
13
+ def __init__(
14
+ self,
15
+ transcriber_args: dict, # Seq2SeqTranscriber モジュールの初期化に必要な引数を含む辞書. 例えばn_melsやn_fft, hop_lengthなど
16
+ lr: float, # 学習率(Learning Rate)。
17
+ lr_decay: float = 1.0, # 学習率の減衰率。デフォルト値は 1.0(減衰なし)。
18
+ lr_decay_interval: int = 1, # 学習率の減衰間隔(エポック単位)。
19
+ ):
20
+ super().__init__()
21
+ self.save_hyperparameters() # 渡された引数をハイパーパラメータとして保存(PyTorch Lightning機能)。
22
+ self.voc_dict = voc_single_track # voc_dict をインスタンス変数に保存。
23
+ self.n_voc = self.voc_dict["n_voc"] # トークンの総数(n_voc)も保存。
24
+
25
+ # 渡された引数 transcriber_args を展開して Seq2SeqTranscriber モジュールを初期化。トークン辞書(voc_dict)を追加で渡す。
26
+ self.transcriber = Seq2SeqTranscriber(
27
+ **transcriber_args, voc_dict=self.voc_dict
28
+ )
29
+ # 引数から渡された学習率や減衰関連の設定を保存。
30
+ self.lr = lr
31
+ self.lr_decay = lr_decay
32
+ self.lr_decay_interval = lr_decay_interval
33
+
34
+ # 推論時に Seq2SeqTranscriber の推論機能(infer メソッド)を呼び出す関数
35
+ def forward(self, y: torch.Tensor):
36
+ transcriber_infer = self.transcriber.infer(y) # 入力波形(y)を Seq2SeqTranscriber の infer メソッドに渡して推論を実行。
37
+ return transcriber_infer #推論結果として生成されたトークン列を返します。
38
+
39
+ # モデルのトレーニング中に1バッチの処理を実行する関数
40
+ def training_step(self, batch, batch_idx):
41
+ y, t = batch # 入力波形(y)と正解トークン列(t)を分割。
42
+ tf_out = self.transcriber(y, t) # 入力データを Seq2SeqTranscriber に渡してtranscriberのforwardで損失値(loss)を取得。出力は、損失値やロジット(各トークンのスコア)を含むオブジェクト。
43
+ loss = tf_out.loss # T5モデルの出力に含まれる損失値(loss)。学習時の損失計算は、CrossEntropyLoss に基づいています。
44
+ t = t.detach() # 正解トークン(t)を計算グラフから切り離して、後続の計算が逆伝播に影響しないようにする。
45
+ mask = t != self.voc_dict["pad"] # 正解トークン列のうち、パディングトークン(pad)を無視するためのマスクを作成。
46
+
47
+ # モデルの出力(logits)から最も高いスコアのトークン(argmax(-1))を取得し、マスクされた正解トークン(t[mask])と比較。
48
+ # 正しく予測されたトークンの数を、マスクされたトークン全体の数で割って精度を計算。
49
+ accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum()
50
+ #損失値(loss)と精度(accr)を PyTorch Lightning のログ機能を使って記録。
51
+ self.log("train_loss", loss)
52
+ self.log("train_accr", accr)
53
+ return loss # 計算された損失を返します。この値はPyTorch Lightningがバックプロパゲーションを行うために使用します。
54
+
55
+ # モデルの検証時に1バッチの処理を実行する関数
56
+ def validation_step(self, batch, batch_idx):
57
+ assert not self.transcriber.training # モデルがトレーニングモードでないことを確認
58
+ y, t = batch
59
+ tf_out = self.transcriber(y, t)
60
+ loss = tf_out.loss
61
+ t = t.detach()
62
+ mask = t != self.voc_dict["pad"]
63
+ accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum()
64
+ self.log("vali_loss", loss)
65
+ self.log("vali_accr", accr)
66
+ return loss
67
+
68
+ def test_step(self, batch, batch_idx):
69
+ y, ref_midi = batch # 音声データ(y)と参照MIDI(ref_midi)を分割
70
+ # テストデータセットでは、通常1つのサンプルを1バッチとして処理するため、インデックス0の要素を取り出します。
71
+ y = y[0]
72
+ ref_midi = ref_midi[0]
73
+ with torch.no_grad(): # 推論中は勾配計算が不要なため、この文脈を使ってメモリ消費を削減。
74
+ est_tokens = self.forward(y) # 入力音声 y を用いてモデルに推論を実行し、MIDIトークン列(est_tokens)を生成。モデルの forward メソッドを呼び出して推論を実行。この時forward内では分割されて実行される, 形状: (セグメント数, 最大トークン長)
75
+ unpadded_tokens = unpack_sequence(est_tokens.cpu().numpy()) # 推論結果(est_tokens)からパディングを取り除き、有効なトークン列だけを取得。GPU上のテンソルをCPUに移動してNumPy配列に変換
76
+ unpadded_tokens = [t[1:] for t in unpadded_tokens] # 各トークン列から開始トークン(<sos>など)を除外。
77
+ est_midi = token_seg_list_to_midi(unpadded_tokens) # パディングを除去したトークン列(unpadded_tokens)をMIDI形式に変換。
78
+ dict_eval = evaluate_midi(est_midi, ref_midi) # 参照MIDI(ref_midi)と推論MIDI(est_midi)を比較し、性能を評価。
79
+ # 評価結果をPyTorch Lightningのログ機能で記録。
80
+ dict_log = {}
81
+ for key in dict_eval:
82
+ dict_log["test/" + key] = dict_eval[key]
83
+ self.log_dict(dict_log, batch_size=1)
84
+
85
+ def train_dataloader(self):
86
+ # 事前に定義されたデータセットクラス。音声(波形データ)とMIDIファイルをペアとしてロードします。
87
+ dset = CustomDataset(
88
+ split="train", # データセットの分割
89
+ sample_rate=16000,
90
+ apply_pedal=True,
91
+ whole_song=False
92
+ )
93
+ # データローダーの作成
94
+ return data.DataLoader(
95
+ dataset=dset, # 使用するデータセットとして、初期化済みの Maestro を指定。
96
+ collate_fn=dset.collate_batch, # バッチ処理時のデータ整形関数。
97
+ batch_size=64, # 1バッチに含めるサンプル数。
98
+ shuffle=True, # データセットをランダムにシャッフル。
99
+ pin_memory=True, # データをピン留め(ページロック)してメモリに保存。
100
+ num_workers=32, # データローディングに使用するプロセス数。
101
+ )
102
+
103
+ def val_dataloader(self):
104
+ # 事前に定義されたデータセットクラス。音声(波形データ)とMIDIファイルをペアとしてロードします。
105
+ dset = CustomDataset(
106
+ split="valid", # データセットの分割
107
+ sample_rate=16000,
108
+ apply_pedal=True,
109
+ whole_song=False
110
+ )
111
+ # データローダーの作成
112
+ return data.DataLoader(
113
+ dataset=dset, # 使用するデータセットとして、初期化済みの Maestro を指定。
114
+ collate_fn=dset.collate_batch, # バッチ処理時のデータ整形関数。
115
+ batch_size=32, # 1バッチに含めるサンプル数。
116
+ shuffle=True, # データセットをランダムにシャッフル。
117
+ pin_memory=True, # データをピン留め(ページロック)してメモリに保存。
118
+ num_workers=16, # データローディングに使用するプロセス数。
119
+ )
120
+ def test_dataloader(self):
121
+ dset = CustomDataset(
122
+ split="test", # データセットの分割
123
+ sample_rate=16000,
124
+ apply_pedal=True,
125
+ whole_song=True
126
+ )
127
+ return data.DataLoader(
128
+ dataset=dset,
129
+ collate_fn=dset.collate_wholesong,
130
+ batch_size=1,
131
+ shuffle=False, # テストデータの順序を固定。理由: 再現性を確保するため。
132
+ pin_memory=True,
133
+ )
134
+
135
+ # モデルの学習時に、損失関数の値を最小化するようにパラメータを更新するオプティマイザ(最適化手法)を指定する関数
136
+ def configure_optimizers(self):
137
+ return torch.optim.AdamW(self.parameters(), lr=self.lr)
138
+ # Adamオプティマイザ の改良版であり、L2正則化の代わりに ウェイトデカイ(Weight Decay) を導入したオプティマイザ。
139
+ # self.parameters(): モデルの学習可能なすべてのパラメータ(重みとバイアス)を取得。
140
+
141
+ from lightning.pytorch.callbacks import ModelCheckpoint
142
+ from lightning.pytorch.loggers import TensorBoardLogger
143
+
144
+ # チェックポイント保存の設定
145
+ if __name__ == "__main__":
146
+ checkpoint_callback = ModelCheckpoint(
147
+ dirpath="/content/drive/MyDrive/B4/Humtrans/m2m100/checkpoints", # 保存先
148
+ filename="epoch{epoch:02d}-vali_loss{vali_loss:.2f}", # ファイル名フォーマット
149
+ save_top_k=3, # 上位3つのチェックポイントを保存
150
+ monitor="vali_loss", # 訓練データの損失を監視
151
+ mode="min", # 損失が小さいほど良いモデル
152
+ every_n_epochs=1, # 1エポックごとに保存
153
+ )
154
+
155
+ # ロガー設定
156
+ logger = TensorBoardLogger(
157
+ save_dir="/content/drive/MyDrive/B4/Humtrans/m2m100/logs",
158
+ name="transcription"
159
+ )
160
+
161
+ # トレーナーの設定
162
+ trainer = pl.Trainer(
163
+ logger=logger,
164
+ callbacks=[checkpoint_callback], # チェックポイントコールバックを追加
165
+ enable_checkpointing=True,
166
+ accelerator="gpu",
167
+ devices=1,
168
+ max_epochs=400,
169
+ )
170
+
171
+ # モデルの引数
172
+ args = {
173
+ "n_mels": 128,
174
+ "sample_rate": 16000,
175
+ "n_fft": 1024,
176
+ "hop_length": 128,
177
+ }
178
+
179
+ lightning_module = LitTranscriber(
180
+ transcriber_args=args,
181
+ lr=1e-4, # 0.0001
182
+ lr_decay=0.99, # 各エポックの度に0.99倍して減衰させる
183
+ )
184
+
185
+ # 学習の開始
186
+ trainer.fit(lightning_module)
187
+
188
+
utils.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from operator import attrgetter
2
+ import dataclasses
3
+ import numpy as np
4
+ import pretty_midi as pm
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+ from nnAudio.features import CQT
10
+ import soundfile as sf
11
+
12
+ from config import FRAME_PER_SEC, FRAME_STEP_SIZE_SEC, AUDIO_SEGMENT_SEC
13
+ from config import voc_single_track
14
+
15
+ @dataclasses.dataclass
16
+ class Event:
17
+ prog: int
18
+ onset: bool
19
+ pitch: int
20
+
21
+ class MIDITokenExtractor:
22
+ """
23
+ ・MIDIデータ(音符、タイミング、ペダル情報など)を抽出し、トークン列に変換する。
24
+ ・セグメント単位でMIDIデータを分割し、各セグメントをトークンとして表現。
25
+ """
26
+ def __init__(self, midi_path, voc_dict, apply_pedal=True):
27
+ """
28
+ ・MIDIデータを読み込み、必要に応じてサステインペダルを適用する初期化処理を行う。
29
+ """
30
+ self.pm = pm.PrettyMIDI(midi_path) # MIDIファイルをPrettyMIDIで読み込む
31
+ if apply_pedal:
32
+ self.pm_apply_pedal(self.pm) # サステインペダル処理を適用
33
+ self.voc_dict = voc_dict # トークンの定義辞書
34
+ self.multi_track = "instrument" in voc_dict # マルチトラック対応のフラグ
35
+
36
+ def pm_apply_pedal(self, pm: pm.PrettyMIDI, program=0):
37
+ """
38
+ Apply sustain pedal by stretching the notes in the pm object.
39
+ """
40
+ # 1: Record the onset positions of each notes as a dictionary
41
+ onset_dict = dict()
42
+ for note in pm.instruments[program].notes:
43
+ if note.pitch in onset_dict:
44
+ onset_dict[note.pitch].append(note.start)
45
+ else:
46
+ onset_dict[note.pitch] = [note.start]
47
+ for k in onset_dict.keys():
48
+ onset_dict[k] = np.sort(onset_dict[k])
49
+
50
+ # 2: Record the pedal on/off state of each time frame
51
+ arr_pedal = np.zeros(
52
+ round(pm.get_end_time()*FRAME_PER_SEC)+100, dtype=bool)
53
+ pedal_on_time = -1
54
+ list_pedaloff_time = []
55
+ for cc in pm.instruments[program].control_changes:
56
+ if cc.number == 64:
57
+ if (cc.value > 0) and (pedal_on_time < 0):
58
+ pedal_on_time = round(cc.time*FRAME_PER_SEC)
59
+ elif (cc.value == 0) and (pedal_on_time >= 0):
60
+ pedal_off_time = round(cc.time*FRAME_PER_SEC)
61
+ arr_pedal[pedal_on_time:pedal_off_time] = True
62
+ list_pedaloff_time.append(cc.time)
63
+ pedal_on_time = -1
64
+ list_pedaloff_time = np.sort(list_pedaloff_time)
65
+
66
+ # 3: Stretch the notes (modify note.end)
67
+ for note in pm.instruments[program].notes:
68
+ # 3-1: Determine whether sustain pedal is on at note.end. If not, do nothing.
69
+ # 3-2: Find the next note onset time and next pedal off time after note.end.
70
+ # 3-3: Extend note.end till the minimum of next_onset and next_pedaloff.
71
+ note_off_frame = round(note.end*FRAME_PER_SEC)
72
+ pitch = note.pitch
73
+ if arr_pedal[note_off_frame]:
74
+ next_onset = np.argwhere(onset_dict[pitch] > note.end)
75
+ next_onset = np.inf if len(
76
+ next_onset) == 0 else onset_dict[pitch][next_onset[0, 0]]
77
+ next_pedaloff = np.argwhere(list_pedaloff_time > note.end)
78
+ next_pedaloff = np.inf if len(
79
+ next_pedaloff) == 0 else list_pedaloff_time[next_pedaloff[0, 0]]
80
+ new_noteoff_time = max(note.end, min(next_onset, next_pedaloff))
81
+ new_noteoff_time = min(new_noteoff_time, pm.get_end_time())
82
+ note.end = new_noteoff_time
83
+
84
+ def get_segment_tokens(self, start, end):
85
+ """
86
+ Transform a segment of the MIDI file into a sequence of tokens.
87
+ """
88
+ dict_event = dict() # a dictionary that maps time to a list of events.
89
+
90
+ def append_to_dict_event(time, item):
91
+ if time in dict_event:
92
+ dict_event[time].append(item)
93
+ else:
94
+ dict_event[time] = [item]
95
+
96
+ list_events = [] # events section
97
+ list_tie_section = [] # tie section
98
+
99
+ for instrument in self.pm.instruments:
100
+ prog = instrument.program
101
+ for note in instrument.notes:
102
+ note_end = round(note.end * FRAME_PER_SEC) # 音符終了時刻(フレーム単位)
103
+ note_start = round(note.start * FRAME_PER_SEC) # 音符開始時刻(フレーム単位)
104
+ if (note_end < start) or (note_start >= end):
105
+ # セグメント外の音符は無視
106
+ continue
107
+ if (note_start < start) and (note_end >= start):
108
+ # If the note starts before the segment, but ends in the segment
109
+ # it is added to the tie section.
110
+ # セグメント開始時より前に始まり、セグメント内で終了する音符(ピッチ)をタイセクションに追加
111
+ list_tie_section.append(self.voc_dict["note"] + note.pitch)
112
+ if note_end < end:
113
+ # セグメント内で終了する場合、イベントに終了時刻を記録
114
+ append_to_dict_event(
115
+ note_end - start, Event(prog, False, note.pitch)
116
+ )
117
+ continue
118
+ assert note_start >= start
119
+ # セグメント内で開始
120
+ append_to_dict_event(note_start - start, Event(prog, True, note.pitch))
121
+ if note_end < end:
122
+ # セグメント内で終了
123
+ append_to_dict_event(
124
+ note_end - start, Event(prog, False, note.pitch)
125
+ )
126
+
127
+ cur_onset = None
128
+ cur_prog = -1
129
+
130
+ for time in sorted(dict_event.keys()): # 現在の相対時間(time)をトークン化し、イベント列に追加
131
+ list_events.append(self.voc_dict["time"] + time) # self.voc_dict["time"]はtimeトークンの開始ID(133)。これに相対時間を足す
132
+ for event in sorted(dict_event[time], key=attrgetter("pitch", "onset")): # 同一時間内でピッチを昇順に、onset→offsetの順になるようにソートする。
133
+ if cur_onset != event.onset:
134
+ cur_onset = event.onset # オンセットオフセットが変わった場合に更新
135
+ list_events.append(self.voc_dict["onset"] + int(event.onset)) # オンセットオフセットトークンを追加
136
+ list_events.append(self.voc_dict["note"] + event.pitch) # 音符トークンを追加
137
+
138
+ # Concatenate tie section, endtie token, and event section
139
+ list_tie_section.append(self.voc_dict["endtie"]) # ID:2を追加
140
+ list_events.append(self.voc_dict["eos"]) # ID: 1を追加
141
+ tokens = np.concatenate((list_tie_section, list_events)).astype(int)
142
+ return tokens
143
+
144
+
145
+ """## Detokenizer
146
+ Transforms a list of MIDI-like token sequences into a MIDI file.
147
+ """
148
+
149
+ def parse_id(voc_dict: dict, id: int):
150
+ """
151
+ トークンIDを解析し、トークンの種類名とその相対IDを返す関数。
152
+ """
153
+ keys = voc_dict["keylist"] # トークンの種類リストを取得
154
+ token_name = keys[0] # デフォルトの種類を "pad" に設定
155
+
156
+ # トークンの種類を特定する
157
+ for k in keys:
158
+ if id < voc_dict[k]: # 現在の種類の開始位置より小さい場合、前の種類が該当
159
+ break
160
+ token_name = k # 現在の種類名を更新
161
+
162
+ # 該当種類内での相対IDを計算
163
+ token_id = id - voc_dict[token_name]
164
+
165
+ return token_name, token_id # 種類名と相対IDを返す
166
+
167
+
168
+
169
+ def to_second(n):
170
+ """
171
+ フレーム数を秒単位に変換する関数。
172
+ """
173
+ return n * FRAME_STEP_SIZE_SEC
174
+
175
+
176
+
177
+ def find_note(list, n):
178
+ """
179
+ タプルリストの最初の要素から指定された値を検索し、そのインデックスを返す関数。
180
+ """
181
+ li_elem = [a for a, _ in list] # 最初の要素だけを抽出したリスト
182
+ try:
183
+ idx = li_elem.index(n) # n が存在する場合、そのインデックスを取得
184
+ except ValueError:
185
+ return -1 # 存在しない場合は -1 を返す
186
+ return idx # 見つかった場合、そのインデックスを返す
187
+
188
+
189
+
190
+ def token_seg_list_to_midi(token_seg_list: list):
191
+ """
192
+ トークン列リストをMIDIファイルに変換する関数。
193
+ """
194
+ # MIDIデータと楽器の初期化
195
+ midi_data = pm.PrettyMIDI()
196
+ piano_program = pm.instrument_name_to_program("Acoustic Grand Piano")
197
+ piano = pm.Instrument(program=piano_program)
198
+
199
+ list_onset = [] # 開始時刻を記録するリスト ※次のセグメント処理を行うときにlist_onsetには、最終音とタイの可能性がある音が記録されている
200
+ cur_time = 0 # 前回のセグメントの終了時間
201
+
202
+ # トークンセグメントごとの処理
203
+ for token_seg in token_seg_list:
204
+ list_tie = [] # タイ結合された音符
205
+ cur_relative_time = -1 # セグメント内の相対時間
206
+ cur_onset = -1 # 現在のオンセット状態
207
+ tie_end = False # タイ結合が終了したかどうか
208
+
209
+ for token in token_seg:
210
+ # トークンを解析
211
+ token_name, token_id = parse_id(voc_single_track, token)
212
+
213
+ if token_name == "note":
214
+ # 音符処理
215
+ if not tie_end: # タイ結合の場合
216
+ list_tie.append(token_id) # タイのnote番号(相対ID)を追加
217
+ elif cur_onset == 1:
218
+ list_onset.append((token_id, cur_time + cur_relative_time)) # 開始時刻を記録
219
+ elif cur_onset == 0:
220
+ # 終了処理
221
+ i = find_note(list_onset, token_id)
222
+ if i >= 0:
223
+ start = list_onset[i][1]
224
+ end = cur_time + cur_relative_time
225
+ if start < end: # 開始時刻 < 終了時刻の場合のみ追加
226
+ new_note = pm.Note(100, token_id, start, end)
227
+ piano.notes.append(new_note)
228
+ list_onset.pop(i)
229
+
230
+ elif token_name == "onset":
231
+ # オンセット/オフセットの更新
232
+ if tie_end:
233
+ if token_id == 1:
234
+ cur_onset = 1 # 開始
235
+ elif token_id == 0:
236
+ cur_onset = 0 # 終了
237
+
238
+ elif token_name == "time":
239
+ # 相対時間の更新
240
+ if tie_end:
241
+ cur_relative_time = to_second(token_id)
242
+
243
+ elif token_name == "endtie":
244
+ # タイ結合終了処理
245
+ tie_end = True
246
+ for note, start in list_onset: # list_onsetには前回のセグメントで未処理の最終音が含まれる
247
+ if note not in list_tie: # list_onsetにあるにも関わらず、list_tieにない場合
248
+ if start < cur_time:
249
+ new_note = pm.Note(100, note, start, cur_time) # 前回のセグメントの終了時をendとする(end=cur_time)
250
+ piano.notes.append(new_note)
251
+ list_onset.remove((note, start))
252
+
253
+ # 現在の時間を更新
254
+ cur_time += AUDIO_SEGMENT_SEC
255
+
256
+ # 楽器をMIDIデータに追加
257
+ midi_data.instruments.append(piano)
258
+ return midi_data
259
+
260
+
261
+ # 固定長のセグメントに分割する関数
262
+ def split_audio_into_segments(y: torch.Tensor, sr: int): # オーディオデータ(テンソル), オーディオのサンプルレート
263
+ audio_segment_samples = round(AUDIO_SEGMENT_SEC * sr) # 1セグメントの長さをサンプル数で計算
264
+ pad_size = audio_segment_samples - (y.shape[-1] % audio_segment_samples) # セグメントサイズできっちり分割できるようにpadするサイズを計算
265
+
266
+ y = F.pad(y, (0, pad_size)) # padを追加
267
+ assert (y.shape[-1] % audio_segment_samples) == 0 # 割り切れない場合assertをする
268
+ n_chunks = y.shape[-1] // audio_segment_samples # セグメント数を計算
269
+ # 固定長のセグメントに分割
270
+ y_segments = torch.chunk(y, chunks=n_chunks, dim=-1) # torch.chunk: テンソルを指定した数(n_chunks)に分割、dim=-1: サンプル次元(最後の次元)で分割
271
+ return torch.stack(y_segments, dim=0) # 分割したセグメントを1つのテンソルにまとめる。dim=0: セグメント数をバッチ次元として結合。
272
+ # 形状: (セグメント数, セグメント長)。
273
+
274
+ # 推論時に作られたpadされたseqから有効な部分を抽出する関数
275
+ def unpack_sequence(x: torch.Tensor, eos_id: int=1):
276
+ seqs = [] # 各シーケンスを切り出して保存するリストを初期化。
277
+ max_length = x.shape[-1] # シーケンスの最大長を取得。whileループの範囲をチェックするために使用。
278
+ for seq in x: # テンソル x の各シーケンス(行)を処理
279
+ # eosトークンを探す
280
+ start_pos = 0
281
+ pos = 0
282
+ while (pos < max_length) and (seq[pos] != eos_id): # 現在地が最大長を超えない&現在地がeosではない場合
283
+ pos += 1
284
+ # ループ終了後:pos には、終了トークン(eos_id)の位置(またはシーケンスの末尾)が格納されます。
285
+ end_pos = pos+1
286
+ seqs.append(seq[start_pos:end_pos]) #開始位置(start_pos)から終了位置(end_pos)までの部分を切り出し、リスト seqs に追加。
287
+ return seqs
288
+
289
+ class LogMelspec(nn.Module):
290
+ def __init__(self, sample_rate, n_fft, n_mels, hop_length):
291
+ super().__init__()
292
+ self.melspec = torchaudio.transforms.MelSpectrogram(
293
+ sample_rate=sample_rate,
294
+ n_fft=n_fft, # FFT(高速フーリエ変換)に使用するポイント数
295
+ hop_length=hop_length, # ストライド(フレームのシフト幅(サンプル単位))。通常、n_fft // 4 などの値を設定
296
+ f_min=20.0, # メルフィルタバンクの最小周波数。20Hz 以上が推奨(人間の聴覚範囲)。
297
+ n_mels=n_mels, # メルスペクトログラムの周波数軸方向の次元数(メルバンド数)。
298
+ mel_scale="slaney", # メル尺度を計算する方法."slaney": より音響的に意味のあるスケール
299
+ norm="slaney", # メルフィルタバンクの正規化方法. "slaney" を指定するとフィルタバンクがエネルギーで正規化
300
+ power=1, # 出力スペクトログラムのエネルギース���ール, 1: 振幅スペクトログラム。
301
+ )
302
+ self.eps = 1e-5 # 対数計算時のゼロ除算エラーを防ぐための閾値。メルスペクトログラムの値が 1e-5 未満の場合、この値に置き換える。
303
+
304
+ def forward(self, x): # 入力: モノラル: (バッチサイズ, サンプル数) or ステレオ: (バッチサイズ, チャンネル数, サンプル数)
305
+ spec = self.melspec(x) # 入力波形 x からメルスペクトログラムを計算, 出力:メルスペクトログラム: (バッチサイズ, メルバンド数, フレーム数)
306
+ safe_spec = torch.clamp(spec, min=self.eps) # メルスペクトログラムの最小値を self.eps に制限。値が非常に小さい場合でも、対数計算が可能。
307
+ log_spec = torch.log(safe_spec) #メルスペクトログラムを対数スケールに変換。
308
+ return log_spec # (バッチサイズ, メルバンド数, フレーム数) のテンソル。各値は対数スケールのメルスペクトログラム。
309
+
310
+
311
+ class LogCQT(nn.Module):
312
+ def __init__(self, sample_rate, n_bins, hop_length, bins_per_octave):
313
+ super().__init__()
314
+ self.cqt = CQT(
315
+ sr=sample_rate,
316
+ hop_length=hop_length,
317
+ fmin=32.7, # 低周波数の最小値 (通常は32.7Hz, C1)
318
+ fmax=8000, # 最高周波数
319
+ n_bins=n_bins,
320
+ bins_per_octave=bins_per_octave,
321
+ verbose=False
322
+ ).to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # GPUに載せる
323
+ self.eps = 1e-5 # ゼロ除算を防ぐ閾値
324
+
325
+ def forward(self, x):
326
+ x = x.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # GPUに送る
327
+ cqt_spec = self.cqt(x) # (B, n_bins, time)
328
+ safe_spec = torch.clamp(cqt_spec, min=self.eps) # 小さな値をカット
329
+ log_cqt = torch.log(safe_spec) # 対数変換
330
+ return log_cqt # (B, n_bins, time)
331
+
332
+ def normalize_rms_torch(audio_tensor, target_rms=0.1):
333
+ rms = torch.sqrt(torch.mean(audio_tensor**2)).item()
334
+ if rms < 1e-6:
335
+ print("音が非常に小さいため、最小値でスケーリングします")
336
+ rms = 1e-6
337
+ scaling_factor = target_rms / rms
338
+ return audio_tensor * scaling_factor
339
+
340
+ def rms_normalize_wav(input_path, output_path, target_rms=0.1):
341
+ waveform, sr = torchaudio.load(input_path)
342
+ waveform = waveform.mean(0, keepdim=True) # モノラル化
343
+ normalized = normalize_rms_torch(waveform, target_rms)
344
+ torchaudio.save(output_path, normalized, sample_rate=sr)
345
+