Spaces:
Sleeping
Sleeping
Upload 11 files
Browse files
app.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import torchaudio
|
4 |
+
import os
|
5 |
+
from infer import infer_midi_from_wav
|
6 |
+
|
7 |
+
def predict_midi(wav_file):
|
8 |
+
# 一時ファイルに保存
|
9 |
+
input_path = "input.wav"
|
10 |
+
wav_file.save(input_path)
|
11 |
+
|
12 |
+
midi_path = infer_midi_from_wav(input_path)
|
13 |
+
|
14 |
+
return midi_path # Gradioはファイルのパスを返すとダウンロードリンクとして表示される
|
15 |
+
|
16 |
+
iface = gr.Interface(
|
17 |
+
fn=predict_midi,
|
18 |
+
inputs=gr.Audio(source="upload", type="file"),
|
19 |
+
outputs=gr.File(label="Download MIDI"),
|
20 |
+
title="🎵 My Melody: 鼻歌→MIDI変換",
|
21 |
+
description="鼻歌をアップロードすると、MIDIファイルがダウンロードできます。"
|
22 |
+
)
|
23 |
+
|
24 |
+
if __name__ == "__main__":
|
25 |
+
iface.launch()
|
config.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import lightning.pytorch as pl
|
2 |
+
|
3 |
+
SR = 16000
|
4 |
+
AUDIO_SEGMENT_SEC = 2.0
|
5 |
+
SEGMENT_N_FRAMES = 200
|
6 |
+
FRAME_STEP_SIZE_SEC = 0.01
|
7 |
+
FRAME_PER_SEC = 100
|
8 |
+
|
9 |
+
pl.seed_everything(1234)
|
10 |
+
|
11 |
+
"""# MIDI-like token"""
|
12 |
+
|
13 |
+
# MIDI-like token definitions
|
14 |
+
|
15 |
+
N_NOTE = 128
|
16 |
+
N_TIME = 205
|
17 |
+
N_SPECIAL = 3
|
18 |
+
|
19 |
+
voc_single_track = {
|
20 |
+
"pad": 0,
|
21 |
+
"eos": 1,
|
22 |
+
"endtie": 2,
|
23 |
+
"note": N_SPECIAL, # 3から始まるという意味
|
24 |
+
"onset": N_SPECIAL+N_NOTE,
|
25 |
+
"time": N_SPECIAL+N_NOTE+2,
|
26 |
+
"n_voc": N_SPECIAL+N_NOTE+2+N_TIME+3,
|
27 |
+
"keylist": ["pad", "eos", "endtie", "note", "onset", "time"]}
|
dataset.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
import tqdm
|
7 |
+
|
8 |
+
import torch.utils.data as data
|
9 |
+
|
10 |
+
from utils import MIDITokenExtractor
|
11 |
+
from config import voc_single_track
|
12 |
+
from config import FRAME_PER_SEC, FRAME_STEP_SIZE_SEC, AUDIO_SEGMENT_SEC, SEGMENT_N_FRAMES
|
13 |
+
|
14 |
+
"""# Dataset
|
15 |
+
Uses MAESTRO v3.0.0 dataset.
|
16 |
+
"""
|
17 |
+
|
18 |
+
class AMTDatasetBase(data.Dataset):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
flist_audio, # オーディオファイルのパスをリスト形式で渡す
|
22 |
+
flist_midi, # MIDIファイルのパスをリスト形式で渡す
|
23 |
+
sample_rate, # オーディオファイルのサンプリングレートを指定。全てのオーディオがこれにリサンプリングされる。
|
24 |
+
voc_dict, # トークン定義を渡す
|
25 |
+
apply_pedal=True,
|
26 |
+
whole_song=False,
|
27 |
+
):
|
28 |
+
super().__init__()
|
29 |
+
self.midi_filelist = flist_midi
|
30 |
+
self.audio_filelist = flist_audio
|
31 |
+
self.audio_metalist = [torchaudio.info(f) for f in flist_audio] # 各オーディオファイルのメタ情報(サンプルレート、フレーム数など)を収集します。
|
32 |
+
self.voc_dict = voc_dict
|
33 |
+
# 各MIDIファイルを MIDITokenExtractor を使ってトークン化し、その結果をリストとして保持します。
|
34 |
+
self.midi_list = [
|
35 |
+
MIDITokenExtractor(f, voc_dict, apply_pedal)
|
36 |
+
for f in tqdm.tqdm(self.midi_filelist, desc="load dataset")
|
37 |
+
]
|
38 |
+
self.sample_rate = sample_rate
|
39 |
+
self.whole_song = whole_song
|
40 |
+
|
41 |
+
def __len__(self):
|
42 |
+
return len(self.audio_filelist)
|
43 |
+
|
44 |
+
def __getitem__(self, index):
|
45 |
+
"""
|
46 |
+
Return a pair of (audio, tokens) for the given index.
|
47 |
+
On the training stage, return a random segment from the song.
|
48 |
+
On the test stage, return the audio and MIDI of the whole song.
|
49 |
+
"""
|
50 |
+
if not self.whole_song:
|
51 |
+
return self.getitem_segment(index)
|
52 |
+
else:
|
53 |
+
return self.getitem_wholesong(index)
|
54 |
+
|
55 |
+
def getitem_segment(self, index, start_pos=None): # 対象ファイルを指定するindexとセグメントの開始位置(フレーム単位)。Noneの場合はランダムに選択
|
56 |
+
metadata = self.audio_metalist[index]
|
57 |
+
num_frames = metadata.num_frames # オーディオの全体の「サンプル数」。
|
58 |
+
sample_rate = metadata.sample_rate
|
59 |
+
duration_y = round(num_frames / float(sample_rate) * FRAME_PER_SEC) # オーディオ全体の長さをフレーム単位に変換
|
60 |
+
midi_item = self.midi_list[index]
|
61 |
+
|
62 |
+
# セグメントの開始位置と終了位置(フレーム単位)を決定。
|
63 |
+
if start_pos is None: # np.random.randint を使用して、オーディオ全体からランダムに開始位置を選択。
|
64 |
+
segment_start = np.random.randint(duration_y - SEGMENT_N_FRAMES)
|
65 |
+
else: # start_pos が指定されている場合
|
66 |
+
segment_start = start_pos
|
67 |
+
segment_end = segment_start + SEGMENT_N_FRAMES
|
68 |
+
# オーディオセグメントのサンプル単位の開始位置
|
69 |
+
segment_start_sample = round(
|
70 |
+
segment_start * FRAME_STEP_SIZE_SEC * sample_rate
|
71 |
+
)
|
72 |
+
|
73 |
+
# セグメント範囲(segment_start ~ segment_end)に対応するMIDIトークン列を抽出。
|
74 |
+
segment_tokens = midi_item.get_segment_tokens(segment_start, segment_end)
|
75 |
+
segment_tokens = torch.from_numpy(segment_tokens).long() # NumPy配列をPyTorchテンソルに変換。long()でテンソルのデータ型を64ビット整数(long)に設定。
|
76 |
+
|
77 |
+
# 指定されたセグメント範囲のオーディオデータを読み込む。
|
78 |
+
# frame_offset から始まる範囲を num_frames サンプル分読み込む。
|
79 |
+
y_segment, _ = torchaudio.load(
|
80 |
+
self.audio_filelist[index],
|
81 |
+
frame_offset=segment_start_sample,
|
82 |
+
num_frames=round(AUDIO_SEGMENT_SEC * sample_rate),
|
83 |
+
)
|
84 |
+
y_segment = y_segment.mean(0) # オーディオが複数チャンネルの場合(例: ステレオ)、チャンネルを平均してモノラルに変換。
|
85 |
+
|
86 |
+
# サンプルレートのリサンプリング
|
87 |
+
# オーディオデータのサンプルレートが self.sample_rate と異なる場合、指定されたサンプルレートにリサンプリング。
|
88 |
+
if sample_rate != self.sample_rate:
|
89 |
+
y_segment = torchaudio.functional.resample(
|
90 |
+
y_segment,
|
91 |
+
sample_rate,
|
92 |
+
self.sample_rate,
|
93 |
+
resampling_method="kaiser_window", # Kaiserウィンドウによるリサンプリングアルゴリズムを適用。
|
94 |
+
)
|
95 |
+
return y_segment, segment_tokens
|
96 |
+
|
97 |
+
def getitem_wholesong(self, index):
|
98 |
+
"""
|
99 |
+
Return a pair of (audio, midi) for the given index.
|
100 |
+
"""
|
101 |
+
y, sr = torchaudio.load(self.audio_filelist[index]) # 読み込まれた波形データ(テンソル形���)。形状は (チャンネル数, サンプル数)。
|
102 |
+
y = y.mean(0) # モノラル化
|
103 |
+
# サンプルレートのリサンプリング
|
104 |
+
if sr != self.sample_rate:
|
105 |
+
y = torchaudio.functional.resample(
|
106 |
+
y, sr, self.sample_rate,
|
107 |
+
resampling_method="kaiser_window"
|
108 |
+
)
|
109 |
+
midi = self.midi_list[index].pm
|
110 |
+
return y, midi
|
111 |
+
|
112 |
+
# collateはバッチにまとめる役割の関数
|
113 |
+
def collate_wholesong(self, batch): # batch: データセットから取り出された複数のデータ(オーディオとMIDIのペア)のリスト。
|
114 |
+
# b[0]で各データペアの0番目の要素、つまりオーディオデータを取り出す。
|
115 |
+
# torch.stack([...], dim=0): 複数のテンソルを新しい次元(バッチ次元)で結合。
|
116 |
+
# 出力: テンソルの形状は (バッチサイズ, サンプル数)。
|
117 |
+
batch_audio = torch.stack([b[0] for b in batch], dim=0)
|
118 |
+
midi = [b[1] for b in batch] # バッチ内の各曲のMIDIデータをリストとしてまとめる。
|
119 |
+
return batch_audio, midi # テンソル, リスト
|
120 |
+
|
121 |
+
def collate_batch(self, batch): # データセットから取り出されたセグメント化されたオーディオテンソルとセグメント化されたMIDIトークン列のリスト。
|
122 |
+
# b[0]で各データペアの0番目の要素、つまりオーディオデータを取り出す。
|
123 |
+
# torch.stack([...], dim=0): 複数のテンソルを新しい次元(バッチ次元)で結合。
|
124 |
+
# 出力: テンソルの形状は (バッチサイズ, サンプル数)。
|
125 |
+
batch_audio = torch.stack([b[0] for b in batch], dim=0)
|
126 |
+
batch_tokens = [b[1] for b in batch] # バッチ内の各セグメントのトークン列をテンソル?リスト形式で取得。
|
127 |
+
|
128 |
+
# バッチ内のMIDIトークン列の長さを揃えるためにパディング
|
129 |
+
# torch.nn.utils.rnn.pad_sequence は、異なる長さのシーケンス(テンソルリスト)をパディングして同じ長さに揃えるためのPyTorchユーティリティ(すべてのテンソルは同じ次元数である必要があります(長さ以外は一致)。)
|
130 |
+
# batch_first = True: パディング後のテンソル形状を (バッチサイズ, 最大長さ) に設定
|
131 |
+
batch_tokens_pad = torch.nn.utils.rnn.pad_sequence(
|
132 |
+
batch_tokens, batch_first=True, padding_value=self.voc_dict["pad"]
|
133 |
+
)
|
134 |
+
return batch_audio, batch_tokens_pad # テンソル, テンソル (バッチサイズ, サンプル数), (バッチサイズ, 最大トークンの長さ)
|
135 |
+
|
136 |
+
|
137 |
+
class CustomDataset(AMTDatasetBase):
|
138 |
+
def __init__(
|
139 |
+
self,
|
140 |
+
midi_root: str = "/content/drive/MyDrive/B4/Humtrans/midi",
|
141 |
+
wav_root: str = "/content/wav_rms",
|
142 |
+
split: str = "train",
|
143 |
+
sample_rate: int = 16000,
|
144 |
+
apply_pedal: bool = True,
|
145 |
+
whole_song: bool = False,
|
146 |
+
):
|
147 |
+
"""
|
148 |
+
MIDIとWAVのペアをロードするデータセットクラス
|
149 |
+
|
150 |
+
Args:
|
151 |
+
midi_root (str): MIDIファイルが保存されているルートフォルダ
|
152 |
+
wav_root (str): WAVファイルが保存されているフォルダ
|
153 |
+
split (str): 使用するデータセットの分割 ('train', 'valid', 'test')
|
154 |
+
sample_rate (int): サンプルレート
|
155 |
+
apply_pedal (bool): ペダルの適用
|
156 |
+
whole_song (bool): 曲全体をロードするか
|
157 |
+
"""
|
158 |
+
# MIDIフォルダのパスを設定
|
159 |
+
self.midi_root = f"/content/filtered_{split}_midi"
|
160 |
+
self.wav_root = wav_root
|
161 |
+
self.sample_rate = sample_rate
|
162 |
+
self.split = split
|
163 |
+
|
164 |
+
# MIDIとWAVのペアを見つける
|
165 |
+
flist_midi, flist_audio = self._get_paired_files()
|
166 |
+
|
167 |
+
# 親クラスのコンストラクタを呼び出し
|
168 |
+
super().__init__(
|
169 |
+
flist_audio,
|
170 |
+
flist_midi,
|
171 |
+
sample_rate,
|
172 |
+
voc_dict=voc_single_track,
|
173 |
+
apply_pedal=apply_pedal,
|
174 |
+
whole_song=whole_song,
|
175 |
+
)
|
176 |
+
|
177 |
+
|
178 |
+
def _get_paired_files(self):
|
179 |
+
"""
|
180 |
+
MIDIフォルダとWAVフォルダからペアとなるファイルリストを作成する
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
flist_midi (list): 対応するMIDIファイルのリスト
|
184 |
+
flist_audio (list): 対応するWAVファイルのリスト
|
185 |
+
"""
|
186 |
+
flist_midi = []
|
187 |
+
flist_audio = []
|
188 |
+
|
189 |
+
# MIDIフォルダからMIDIファイルを取得
|
190 |
+
midi_files = [f for f in os.listdir(self.midi_root) if f.endswith(".mid")]
|
191 |
+
|
192 |
+
for midi_file in midi_files:
|
193 |
+
# MIDIファイルのパスを構築
|
194 |
+
midi_path = os.path.join(self.midi_root, midi_file)
|
195 |
+
|
196 |
+
# WAVファイルのパスを構築 (拡張子を変更)
|
197 |
+
wav_file = os.path.splitext(midi_file)[0] + ".wav"
|
198 |
+
wav_path = os.path.join(self.wav_root, wav_file)
|
199 |
+
|
200 |
+
# WAVファイルが存在するか確認
|
201 |
+
if os.path.exists(wav_path):
|
202 |
+
flist_midi.append(midi_path)
|
203 |
+
flist_audio.append(wav_path)
|
204 |
+
else:
|
205 |
+
print(f"対応するWAVファイルが見つかりません: {midi_file}")
|
206 |
+
|
207 |
+
print(f"{self.split}データセット: {len(flist_midi)} ペアのMIDI-WAVが見つかりました。")
|
208 |
+
return flist_midi, flist_audio
|
eval.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pretty_midi as pm
|
3 |
+
import mir_eval
|
4 |
+
|
5 |
+
"""# Evaluation function"""
|
6 |
+
|
7 |
+
def extract_midi(midi: pm.PrettyMIDI, program=0): # MIDIデータを読み込んだ PrettyMIDI オブジェクト, MIDIチャンネル(楽器番号)を指定。
|
8 |
+
intervals = [] # 音符ごとの開始時間と終了時間のペアを格納したNumPy配列
|
9 |
+
pitches = [] # 音符ごとの音高(MIDIノート番号)のNumPy配列。
|
10 |
+
pm_notes = midi.instruments[program].notes # programで指定された対象楽器に含まれる全てのノート情報を取得。
|
11 |
+
"""
|
12 |
+
例;
|
13 |
+
instruments = [
|
14 |
+
Instrument 0 (Piano): [Note(start=0.5, end=1.0, pitch=60), ...],
|
15 |
+
Instrument 1 (Violin): [Note(start=1.0, end=1.5, pitch=62), ...]
|
16 |
+
]
|
17 |
+
"""
|
18 |
+
# ノートを順番に処理
|
19 |
+
for note in pm_notes:
|
20 |
+
intervals.append((note.start, note.end)) # 音符の開始・終了時間のペアを intervals に追加。
|
21 |
+
pitches.append(note.pitch) # 音符の音高を pitches に追加。
|
22 |
+
|
23 |
+
return np.array(intervals), np.array(pitches) # intervals: 2D配列(各行が1つの音符の開始・終了時間を表す。), pitches: 1D配列(各要素が1つの音符の音高(ピッチ)を表す。)
|
24 |
+
|
25 |
+
|
26 |
+
def evaluate_midi(est_midi: pm.PrettyMIDI, ref_midi: pm.PrettyMIDI, program=0):
|
27 |
+
est_intervals, est_pitches = extract_midi(est_midi, program)
|
28 |
+
ref_intervals, ref_pitches = extract_midi(ref_midi, program)
|
29 |
+
|
30 |
+
# mir_eval ライブラリの transcription モジュールを使って、音符の一致度を評価します。
|
31 |
+
dict_eval = mir_eval.transcription.evaluate(
|
32 |
+
ref_intervals, ref_pitches, est_intervals, est_pitches, onset_tolerance=0.05)
|
33 |
+
|
34 |
+
return dict_eval # dict_eval: 評価結果の辞書。
|
infer.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
import os
|
4 |
+
import soundfile as sf
|
5 |
+
import librosa
|
6 |
+
|
7 |
+
from utils import unpack_sequence, token_seg_list_to_midi
|
8 |
+
from train import LitTranscriber
|
9 |
+
from utils import rms_normalize_wav
|
10 |
+
|
11 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) # backend/src を指す
|
12 |
+
PTH_PATH = os.path.join(BASE_DIR, "model.pth") # ✅ .pth に変更
|
13 |
+
|
14 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
+
|
16 |
+
def load_model():
|
17 |
+
args = {
|
18 |
+
"n_mels": 128,
|
19 |
+
"sample_rate": 16000,
|
20 |
+
"n_fft": 1024,
|
21 |
+
"hop_length": 128,
|
22 |
+
}
|
23 |
+
model = LitTranscriber(transcriber_args=args, lr=1e-4, lr_decay=0.99)
|
24 |
+
state_dict = torch.load(PTH_PATH, map_location=device) # ✅ .pthをロード
|
25 |
+
model.load_state_dict(state_dict)
|
26 |
+
#model.to(device) # ✅ デバイスに転送
|
27 |
+
model.eval()
|
28 |
+
return model
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
def convert_to_pcm_wav(input_path, output_path):
|
33 |
+
# librosaで読み込み(自動的にPCM形式に変換される)
|
34 |
+
y, sr = librosa.load(input_path, sr=16000, mono=True)
|
35 |
+
sf.write(output_path, y, sr)
|
36 |
+
|
37 |
+
|
38 |
+
def infer_midi_from_wav(input_wav_path: str) -> str:
|
39 |
+
model = load_model()
|
40 |
+
|
41 |
+
converted_path = os.path.join(BASE_DIR, "converted_input.wav")
|
42 |
+
convert_to_pcm_wav(input_wav_path, converted_path)
|
43 |
+
|
44 |
+
normalized_path = os.path.join(BASE_DIR, "tmp_normalized.wav")
|
45 |
+
rms_normalize_wav(converted_path, normalized_path, target_rms=0.1)
|
46 |
+
|
47 |
+
waveform, sr = torchaudio.load(normalized_path)
|
48 |
+
waveform = waveform.mean(0).to(device)
|
49 |
+
|
50 |
+
if sr != model.transcriber.sr:
|
51 |
+
waveform = torchaudio.functional.resample(
|
52 |
+
waveform, sr, model.transcriber.sr
|
53 |
+
).to(device)
|
54 |
+
|
55 |
+
with torch.no_grad():
|
56 |
+
output_tokens = model(waveform)
|
57 |
+
|
58 |
+
unpadded_tokens = unpack_sequence(output_tokens.cpu().numpy())
|
59 |
+
unpadded_tokens = [t[1:] for t in unpadded_tokens]
|
60 |
+
est_midi = token_seg_list_to_midi(unpadded_tokens)
|
61 |
+
|
62 |
+
midi_path = os.path.join(BASE_DIR, "output.mid")
|
63 |
+
est_midi.write(midi_path)
|
64 |
+
print(f"MIDI saved at: {midi_path}")
|
65 |
+
return midi_path
|
main.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# main.py
|
2 |
+
from infer import infer_midi_from_wav
|
3 |
+
|
4 |
+
input_wav = "/Users/yagihayato/Downloads/my-melody-app/backend/data/humming.wav"
|
5 |
+
midi_path = infer_midi_from_wav(input_wav)
|
6 |
+
print("推論完了: ", midi_path)
|
model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:02c7d22a827f8ae20f7271c9d8ae4d07c082a73a6663ea1379e7361ac2cea759
|
3 |
+
size 45979110
|
model.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from utils import split_audio_into_segments
|
4 |
+
from utils import LogMelspec
|
5 |
+
from utils import LogCQT
|
6 |
+
from transformers import T5Config, T5ForConditionalGeneration
|
7 |
+
|
8 |
+
|
9 |
+
class Seq2SeqTranscriber(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self, n_mels: int, sample_rate: int, n_fft: int, hop_length: int, voc_dict: dict
|
12 |
+
):
|
13 |
+
super().__init__()
|
14 |
+
self.infer_max_len = 200 # 推論時の最大シーケンス長。
|
15 |
+
self.voc_dict = voc_dict # トークン辞書と
|
16 |
+
self.n_voc_token = voc_dict["n_voc"] # トークンの数を保持。
|
17 |
+
self.t5config = T5Config.from_pretrained("google/t5-v1_1-small") # Googleの事前学習済み T5 モデル(小型バージョン)の設定をロード。
|
18 |
+
# カスタム設定を T5 の設定に追加:
|
19 |
+
custom_configs = {
|
20 |
+
"vocab_size": self.n_voc_token, # トークン辞書のサイズ。
|
21 |
+
"pad_token_id": voc_dict["pad"], # パディングトークンID。
|
22 |
+
"d_model": 96, # モデルの隠れ次元数(ここではメルバンド数に設定)。
|
23 |
+
}
|
24 |
+
|
25 |
+
for k, v in custom_configs.items():
|
26 |
+
self.t5config.__setattr__(k, v)
|
27 |
+
|
28 |
+
self.transformer = T5ForConditionalGeneration(self.t5config) # カスタム設定を適用した T5 モデルをロード。
|
29 |
+
# self.melspec = LogMelspec(sample_rate, n_fft, n_mels, hop_length) # LogMelspec クラスを使用して、音声波形を対数メルスペクトログラムに変換するモジュールを作成。
|
30 |
+
# CQT モデルインスタンス作成
|
31 |
+
self.log_cqt = LogCQT(16000, 84, 128, 12)
|
32 |
+
self.sr = sample_rate # サンプルレートをインスタンス変数として保存。
|
33 |
+
|
34 |
+
# モデルの学習時に呼び出され、損失値を計算します。
|
35 |
+
def forward(self, wav, labels):
|
36 |
+
# spec = self.melspec(wav).transpose(-1, -2) # 音声波形(wav)をメルスペクトログラム(spec)に変換。LogMelspec クラスのforwardを実行
|
37 |
+
spec = self.log_cqt(wav).transpose(-1, -2)
|
38 |
+
# ※ .transpose(-1, -2): T5 モデルは通常 [バッチ, 時間ステップ, 次元] の形状を期待するため、周波数軸(メルバンド)と時間軸を入れ替えます。
|
39 |
+
# T5 モデルのフォワードパス
|
40 |
+
print("sepc.shape: ", spec.shape) # (1, n_bins, time)
|
41 |
+
outs = self.transformer.forward(
|
42 |
+
inputs_embeds=spec, return_dict=True, labels=labels
|
43 |
+
)
|
44 |
+
return outs # outs は辞書形式で損失値や出力トークン列を含む。
|
45 |
+
|
46 |
+
# 入力音声波形(wav)から推定トークン列を生成する関数
|
47 |
+
def infer(self, wav):
|
48 |
+
"""
|
49 |
+
Infer the transcription of a single audio file.
|
50 |
+
The input audio file is split into segments of 2 seconds
|
51 |
+
before passing to the transformer.
|
52 |
+
"""
|
53 |
+
wav_segs = split_audio_into_segments(wav, self.sr) # 音声波形を固定長(例: 2秒)に分割。
|
54 |
+
#spec = self.melspec(wav_segs).transpose(-1, -2) # 各セグメントをメルスペクトログラムに変換。
|
55 |
+
spec = self.log_cqt(wav_segs).transpose(-1, -2)
|
56 |
+
# generate: T5 モデルの推論モードを使用して、トークン列を生成。
|
57 |
+
outs = self.transformer.generate(
|
58 |
+
inputs_embeds=spec,
|
59 |
+
max_length=self.infer_max_len, # 推論時の最大出力長。
|
60 |
+
num_beams=5, # ビームサーチを無効化し、単純なグリーディーサーチ。
|
61 |
+
do_sample=False, # サンプリングを無効化。
|
62 |
+
return_dict_in_generate=False,
|
63 |
+
)
|
64 |
+
return outs #推論結果として生成されたトークン列を返します。 #形状: (セグメント数, 最大トークン長)
|
65 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Flask関連
|
2 |
+
flask
|
3 |
+
flask-cors
|
4 |
+
|
5 |
+
# 推論・学習用ライブラリ
|
6 |
+
torch
|
7 |
+
torchaudio
|
8 |
+
pytorch-lightning
|
9 |
+
lightning
|
10 |
+
|
11 |
+
|
12 |
+
# データ前処理
|
13 |
+
numpy
|
14 |
+
pandas
|
15 |
+
soundfile
|
16 |
+
tqdm
|
17 |
+
librosa
|
18 |
+
soundfile
|
19 |
+
gradio
|
20 |
+
|
21 |
+
|
22 |
+
# MIDI・音楽処理
|
23 |
+
pretty_midi
|
24 |
+
mir_eval
|
25 |
+
|
26 |
+
# モデル構成・特徴抽出
|
27 |
+
nnAudio
|
28 |
+
transformers
|
29 |
+
|
train.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
import torch.utils.data as data
|
5 |
+
|
6 |
+
from utils import voc_single_track
|
7 |
+
from model import Seq2SeqTranscriber
|
8 |
+
from dataset import CustomDataset
|
9 |
+
from utils import token_seg_list_to_midi, unpack_sequence
|
10 |
+
from eval import evaluate_midi
|
11 |
+
|
12 |
+
class LitTranscriber(pl.LightningModule):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
transcriber_args: dict, # Seq2SeqTranscriber モジュールの初期化に必要な引数を含む辞書. 例えばn_melsやn_fft, hop_lengthなど
|
16 |
+
lr: float, # 学習率(Learning Rate)。
|
17 |
+
lr_decay: float = 1.0, # 学習率の減衰率。デフォルト値は 1.0(減衰なし)。
|
18 |
+
lr_decay_interval: int = 1, # 学習率の減衰間隔(エポック単位)。
|
19 |
+
):
|
20 |
+
super().__init__()
|
21 |
+
self.save_hyperparameters() # 渡された引数をハイパーパラメータとして保存(PyTorch Lightning機能)。
|
22 |
+
self.voc_dict = voc_single_track # voc_dict をインスタンス変数に保存。
|
23 |
+
self.n_voc = self.voc_dict["n_voc"] # トークンの総数(n_voc)も保存。
|
24 |
+
|
25 |
+
# 渡された引数 transcriber_args を展開して Seq2SeqTranscriber モジュールを初期化。トークン辞書(voc_dict)を追加で渡す。
|
26 |
+
self.transcriber = Seq2SeqTranscriber(
|
27 |
+
**transcriber_args, voc_dict=self.voc_dict
|
28 |
+
)
|
29 |
+
# 引数から渡された学習率や減衰関連の設定を保存。
|
30 |
+
self.lr = lr
|
31 |
+
self.lr_decay = lr_decay
|
32 |
+
self.lr_decay_interval = lr_decay_interval
|
33 |
+
|
34 |
+
# 推論時に Seq2SeqTranscriber の推論機能(infer メソッド)を呼び出す関数
|
35 |
+
def forward(self, y: torch.Tensor):
|
36 |
+
transcriber_infer = self.transcriber.infer(y) # 入力波形(y)を Seq2SeqTranscriber の infer メソッドに渡して推論を実行。
|
37 |
+
return transcriber_infer #推論結果として生成されたトークン列を返します。
|
38 |
+
|
39 |
+
# モデルのトレーニング中に1バッチの処理を実行する関数
|
40 |
+
def training_step(self, batch, batch_idx):
|
41 |
+
y, t = batch # 入力波形(y)と正解トークン列(t)を分割。
|
42 |
+
tf_out = self.transcriber(y, t) # 入力データを Seq2SeqTranscriber に渡してtranscriberのforwardで損失値(loss)を取得。出力は、損失値やロジット(各トークンのスコア)を含むオブジェクト。
|
43 |
+
loss = tf_out.loss # T5モデルの出力に含まれる損失値(loss)。学習時の損失計算は、CrossEntropyLoss に基づいています。
|
44 |
+
t = t.detach() # 正解トークン(t)を計算グラフから切り離して、後続の計算が逆伝播に影響しないようにする。
|
45 |
+
mask = t != self.voc_dict["pad"] # 正解トークン列のうち、パディングトークン(pad)を無視するためのマスクを作成。
|
46 |
+
|
47 |
+
# モデルの出力(logits)から最も高いスコアのトークン(argmax(-1))を取得し、マスクされた正解トークン(t[mask])と比較。
|
48 |
+
# 正しく予測されたトークンの数を、マスクされたトークン全体の数で割って精度を計算。
|
49 |
+
accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum()
|
50 |
+
#損失値(loss)と精度(accr)を PyTorch Lightning のログ機能を使って記録。
|
51 |
+
self.log("train_loss", loss)
|
52 |
+
self.log("train_accr", accr)
|
53 |
+
return loss # 計算された損失を返します。この値はPyTorch Lightningがバックプロパゲーションを行うために使用します。
|
54 |
+
|
55 |
+
# モデルの検証時に1バッチの処理を実行する関数
|
56 |
+
def validation_step(self, batch, batch_idx):
|
57 |
+
assert not self.transcriber.training # モデルがトレーニングモードでないことを確認
|
58 |
+
y, t = batch
|
59 |
+
tf_out = self.transcriber(y, t)
|
60 |
+
loss = tf_out.loss
|
61 |
+
t = t.detach()
|
62 |
+
mask = t != self.voc_dict["pad"]
|
63 |
+
accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum()
|
64 |
+
self.log("vali_loss", loss)
|
65 |
+
self.log("vali_accr", accr)
|
66 |
+
return loss
|
67 |
+
|
68 |
+
def test_step(self, batch, batch_idx):
|
69 |
+
y, ref_midi = batch # 音声データ(y)と参照MIDI(ref_midi)を分割
|
70 |
+
# テストデータセットでは、通常1つのサンプルを1バッチとして処理するため、インデックス0の要素を取り出します。
|
71 |
+
y = y[0]
|
72 |
+
ref_midi = ref_midi[0]
|
73 |
+
with torch.no_grad(): # 推論中は勾配計算が不要なため、この文脈を使ってメモリ消費を削減。
|
74 |
+
est_tokens = self.forward(y) # 入力音声 y を用いてモデルに推論を実行し、MIDIトークン列(est_tokens)を生成。モデルの forward メソッドを呼び出して推論を実行。この時forward内では分割されて実行される, 形状: (セグメント数, 最大トークン長)
|
75 |
+
unpadded_tokens = unpack_sequence(est_tokens.cpu().numpy()) # 推論結果(est_tokens)からパディングを取り除き、有効なトークン列だけを取得。GPU上のテンソルをCPUに移動してNumPy配列に変換
|
76 |
+
unpadded_tokens = [t[1:] for t in unpadded_tokens] # 各トークン列から開始トークン(<sos>など)を除外。
|
77 |
+
est_midi = token_seg_list_to_midi(unpadded_tokens) # パディングを除去したトークン列(unpadded_tokens)をMIDI形式に変換。
|
78 |
+
dict_eval = evaluate_midi(est_midi, ref_midi) # 参照MIDI(ref_midi)と推論MIDI(est_midi)を比較し、性能を評価。
|
79 |
+
# 評価結果をPyTorch Lightningのログ機能で記録。
|
80 |
+
dict_log = {}
|
81 |
+
for key in dict_eval:
|
82 |
+
dict_log["test/" + key] = dict_eval[key]
|
83 |
+
self.log_dict(dict_log, batch_size=1)
|
84 |
+
|
85 |
+
def train_dataloader(self):
|
86 |
+
# 事前に定義されたデータセットクラス。音声(波形データ)とMIDIファイルをペアとしてロードします。
|
87 |
+
dset = CustomDataset(
|
88 |
+
split="train", # データセットの分割
|
89 |
+
sample_rate=16000,
|
90 |
+
apply_pedal=True,
|
91 |
+
whole_song=False
|
92 |
+
)
|
93 |
+
# データローダーの作成
|
94 |
+
return data.DataLoader(
|
95 |
+
dataset=dset, # 使用するデータセットとして、初期化済みの Maestro を指定。
|
96 |
+
collate_fn=dset.collate_batch, # バッチ処理時のデータ整形関数。
|
97 |
+
batch_size=64, # 1バッチに含めるサンプル数。
|
98 |
+
shuffle=True, # データセットをランダムにシャッフル。
|
99 |
+
pin_memory=True, # データをピン留め(ページロック)してメモリに保存。
|
100 |
+
num_workers=32, # データローディングに使用するプロセス数。
|
101 |
+
)
|
102 |
+
|
103 |
+
def val_dataloader(self):
|
104 |
+
# 事前に定義されたデータセットクラス。音声(波形データ)とMIDIファイルをペアとしてロードします。
|
105 |
+
dset = CustomDataset(
|
106 |
+
split="valid", # データセットの分割
|
107 |
+
sample_rate=16000,
|
108 |
+
apply_pedal=True,
|
109 |
+
whole_song=False
|
110 |
+
)
|
111 |
+
# データローダーの作成
|
112 |
+
return data.DataLoader(
|
113 |
+
dataset=dset, # 使用するデータセットとして、初期化済みの Maestro を指定。
|
114 |
+
collate_fn=dset.collate_batch, # バッチ処理時のデータ整形関数。
|
115 |
+
batch_size=32, # 1バッチに含めるサンプル数。
|
116 |
+
shuffle=True, # データセットをランダムにシャッフル。
|
117 |
+
pin_memory=True, # データをピン留め(ページロック)してメモリに保存。
|
118 |
+
num_workers=16, # データローディングに使用するプロセス数。
|
119 |
+
)
|
120 |
+
def test_dataloader(self):
|
121 |
+
dset = CustomDataset(
|
122 |
+
split="test", # データセットの分割
|
123 |
+
sample_rate=16000,
|
124 |
+
apply_pedal=True,
|
125 |
+
whole_song=True
|
126 |
+
)
|
127 |
+
return data.DataLoader(
|
128 |
+
dataset=dset,
|
129 |
+
collate_fn=dset.collate_wholesong,
|
130 |
+
batch_size=1,
|
131 |
+
shuffle=False, # テストデータの順序を固定。理由: 再現性を確保するため。
|
132 |
+
pin_memory=True,
|
133 |
+
)
|
134 |
+
|
135 |
+
# モデルの学習時に、損失関数の値を最小化するようにパラメータを更新するオプティマイザ(最適化手法)を指定する関数
|
136 |
+
def configure_optimizers(self):
|
137 |
+
return torch.optim.AdamW(self.parameters(), lr=self.lr)
|
138 |
+
# Adamオプティマイザ の改良版であり、L2正則化の代わりに ウェイトデカイ(Weight Decay) を導入したオプティマイザ。
|
139 |
+
# self.parameters(): モデルの学習可能なすべてのパラメータ(重みとバイアス)を取得。
|
140 |
+
|
141 |
+
from lightning.pytorch.callbacks import ModelCheckpoint
|
142 |
+
from lightning.pytorch.loggers import TensorBoardLogger
|
143 |
+
|
144 |
+
# チェックポイント保存の設定
|
145 |
+
if __name__ == "__main__":
|
146 |
+
checkpoint_callback = ModelCheckpoint(
|
147 |
+
dirpath="/content/drive/MyDrive/B4/Humtrans/m2m100/checkpoints", # 保存先
|
148 |
+
filename="epoch{epoch:02d}-vali_loss{vali_loss:.2f}", # ファイル名フォーマット
|
149 |
+
save_top_k=3, # 上位3つのチェックポイントを保存
|
150 |
+
monitor="vali_loss", # 訓練データの損失を監視
|
151 |
+
mode="min", # 損失が小さいほど良いモデル
|
152 |
+
every_n_epochs=1, # 1エポックごとに保存
|
153 |
+
)
|
154 |
+
|
155 |
+
# ロガー設定
|
156 |
+
logger = TensorBoardLogger(
|
157 |
+
save_dir="/content/drive/MyDrive/B4/Humtrans/m2m100/logs",
|
158 |
+
name="transcription"
|
159 |
+
)
|
160 |
+
|
161 |
+
# トレーナーの設定
|
162 |
+
trainer = pl.Trainer(
|
163 |
+
logger=logger,
|
164 |
+
callbacks=[checkpoint_callback], # チェックポイントコールバックを追加
|
165 |
+
enable_checkpointing=True,
|
166 |
+
accelerator="gpu",
|
167 |
+
devices=1,
|
168 |
+
max_epochs=400,
|
169 |
+
)
|
170 |
+
|
171 |
+
# モデルの引数
|
172 |
+
args = {
|
173 |
+
"n_mels": 128,
|
174 |
+
"sample_rate": 16000,
|
175 |
+
"n_fft": 1024,
|
176 |
+
"hop_length": 128,
|
177 |
+
}
|
178 |
+
|
179 |
+
lightning_module = LitTranscriber(
|
180 |
+
transcriber_args=args,
|
181 |
+
lr=1e-4, # 0.0001
|
182 |
+
lr_decay=0.99, # 各エポックの度に0.99倍して減衰させる
|
183 |
+
)
|
184 |
+
|
185 |
+
# 学習の開始
|
186 |
+
trainer.fit(lightning_module)
|
187 |
+
|
188 |
+
|
utils.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from operator import attrgetter
|
2 |
+
import dataclasses
|
3 |
+
import numpy as np
|
4 |
+
import pretty_midi as pm
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchaudio
|
9 |
+
from nnAudio.features import CQT
|
10 |
+
import soundfile as sf
|
11 |
+
|
12 |
+
from config import FRAME_PER_SEC, FRAME_STEP_SIZE_SEC, AUDIO_SEGMENT_SEC
|
13 |
+
from config import voc_single_track
|
14 |
+
|
15 |
+
@dataclasses.dataclass
|
16 |
+
class Event:
|
17 |
+
prog: int
|
18 |
+
onset: bool
|
19 |
+
pitch: int
|
20 |
+
|
21 |
+
class MIDITokenExtractor:
|
22 |
+
"""
|
23 |
+
・MIDIデータ(音符、タイミング、ペダル情報など)を抽出し、トークン列に変換する。
|
24 |
+
・セグメント単位でMIDIデータを分割し、各セグメントをトークンとして表現。
|
25 |
+
"""
|
26 |
+
def __init__(self, midi_path, voc_dict, apply_pedal=True):
|
27 |
+
"""
|
28 |
+
・MIDIデータを読み込み、必要に応じてサステインペダルを適用する初期化処理を行う。
|
29 |
+
"""
|
30 |
+
self.pm = pm.PrettyMIDI(midi_path) # MIDIファイルをPrettyMIDIで読み込む
|
31 |
+
if apply_pedal:
|
32 |
+
self.pm_apply_pedal(self.pm) # サステインペダル処理を適用
|
33 |
+
self.voc_dict = voc_dict # トークンの定義辞書
|
34 |
+
self.multi_track = "instrument" in voc_dict # マルチトラック対応のフラグ
|
35 |
+
|
36 |
+
def pm_apply_pedal(self, pm: pm.PrettyMIDI, program=0):
|
37 |
+
"""
|
38 |
+
Apply sustain pedal by stretching the notes in the pm object.
|
39 |
+
"""
|
40 |
+
# 1: Record the onset positions of each notes as a dictionary
|
41 |
+
onset_dict = dict()
|
42 |
+
for note in pm.instruments[program].notes:
|
43 |
+
if note.pitch in onset_dict:
|
44 |
+
onset_dict[note.pitch].append(note.start)
|
45 |
+
else:
|
46 |
+
onset_dict[note.pitch] = [note.start]
|
47 |
+
for k in onset_dict.keys():
|
48 |
+
onset_dict[k] = np.sort(onset_dict[k])
|
49 |
+
|
50 |
+
# 2: Record the pedal on/off state of each time frame
|
51 |
+
arr_pedal = np.zeros(
|
52 |
+
round(pm.get_end_time()*FRAME_PER_SEC)+100, dtype=bool)
|
53 |
+
pedal_on_time = -1
|
54 |
+
list_pedaloff_time = []
|
55 |
+
for cc in pm.instruments[program].control_changes:
|
56 |
+
if cc.number == 64:
|
57 |
+
if (cc.value > 0) and (pedal_on_time < 0):
|
58 |
+
pedal_on_time = round(cc.time*FRAME_PER_SEC)
|
59 |
+
elif (cc.value == 0) and (pedal_on_time >= 0):
|
60 |
+
pedal_off_time = round(cc.time*FRAME_PER_SEC)
|
61 |
+
arr_pedal[pedal_on_time:pedal_off_time] = True
|
62 |
+
list_pedaloff_time.append(cc.time)
|
63 |
+
pedal_on_time = -1
|
64 |
+
list_pedaloff_time = np.sort(list_pedaloff_time)
|
65 |
+
|
66 |
+
# 3: Stretch the notes (modify note.end)
|
67 |
+
for note in pm.instruments[program].notes:
|
68 |
+
# 3-1: Determine whether sustain pedal is on at note.end. If not, do nothing.
|
69 |
+
# 3-2: Find the next note onset time and next pedal off time after note.end.
|
70 |
+
# 3-3: Extend note.end till the minimum of next_onset and next_pedaloff.
|
71 |
+
note_off_frame = round(note.end*FRAME_PER_SEC)
|
72 |
+
pitch = note.pitch
|
73 |
+
if arr_pedal[note_off_frame]:
|
74 |
+
next_onset = np.argwhere(onset_dict[pitch] > note.end)
|
75 |
+
next_onset = np.inf if len(
|
76 |
+
next_onset) == 0 else onset_dict[pitch][next_onset[0, 0]]
|
77 |
+
next_pedaloff = np.argwhere(list_pedaloff_time > note.end)
|
78 |
+
next_pedaloff = np.inf if len(
|
79 |
+
next_pedaloff) == 0 else list_pedaloff_time[next_pedaloff[0, 0]]
|
80 |
+
new_noteoff_time = max(note.end, min(next_onset, next_pedaloff))
|
81 |
+
new_noteoff_time = min(new_noteoff_time, pm.get_end_time())
|
82 |
+
note.end = new_noteoff_time
|
83 |
+
|
84 |
+
def get_segment_tokens(self, start, end):
|
85 |
+
"""
|
86 |
+
Transform a segment of the MIDI file into a sequence of tokens.
|
87 |
+
"""
|
88 |
+
dict_event = dict() # a dictionary that maps time to a list of events.
|
89 |
+
|
90 |
+
def append_to_dict_event(time, item):
|
91 |
+
if time in dict_event:
|
92 |
+
dict_event[time].append(item)
|
93 |
+
else:
|
94 |
+
dict_event[time] = [item]
|
95 |
+
|
96 |
+
list_events = [] # events section
|
97 |
+
list_tie_section = [] # tie section
|
98 |
+
|
99 |
+
for instrument in self.pm.instruments:
|
100 |
+
prog = instrument.program
|
101 |
+
for note in instrument.notes:
|
102 |
+
note_end = round(note.end * FRAME_PER_SEC) # 音符終了時刻(フレーム単位)
|
103 |
+
note_start = round(note.start * FRAME_PER_SEC) # 音符開始時刻(フレーム単位)
|
104 |
+
if (note_end < start) or (note_start >= end):
|
105 |
+
# セグメント外の音符は無視
|
106 |
+
continue
|
107 |
+
if (note_start < start) and (note_end >= start):
|
108 |
+
# If the note starts before the segment, but ends in the segment
|
109 |
+
# it is added to the tie section.
|
110 |
+
# セグメント開始時より前に始まり、セグメント内で終了する音符(ピッチ)をタイセクションに追加
|
111 |
+
list_tie_section.append(self.voc_dict["note"] + note.pitch)
|
112 |
+
if note_end < end:
|
113 |
+
# セグメント内で終了する場合、イベントに終了時刻を記録
|
114 |
+
append_to_dict_event(
|
115 |
+
note_end - start, Event(prog, False, note.pitch)
|
116 |
+
)
|
117 |
+
continue
|
118 |
+
assert note_start >= start
|
119 |
+
# セグメント内で開始
|
120 |
+
append_to_dict_event(note_start - start, Event(prog, True, note.pitch))
|
121 |
+
if note_end < end:
|
122 |
+
# セグメント内で終了
|
123 |
+
append_to_dict_event(
|
124 |
+
note_end - start, Event(prog, False, note.pitch)
|
125 |
+
)
|
126 |
+
|
127 |
+
cur_onset = None
|
128 |
+
cur_prog = -1
|
129 |
+
|
130 |
+
for time in sorted(dict_event.keys()): # 現在の相対時間(time)をトークン化し、イベント列に追加
|
131 |
+
list_events.append(self.voc_dict["time"] + time) # self.voc_dict["time"]はtimeトークンの開始ID(133)。これに相対時間を足す
|
132 |
+
for event in sorted(dict_event[time], key=attrgetter("pitch", "onset")): # 同一時間内でピッチを昇順に、onset→offsetの順になるようにソートする。
|
133 |
+
if cur_onset != event.onset:
|
134 |
+
cur_onset = event.onset # オンセットオフセットが変わった場合に更新
|
135 |
+
list_events.append(self.voc_dict["onset"] + int(event.onset)) # オンセットオフセットトークンを追加
|
136 |
+
list_events.append(self.voc_dict["note"] + event.pitch) # 音符トークンを追加
|
137 |
+
|
138 |
+
# Concatenate tie section, endtie token, and event section
|
139 |
+
list_tie_section.append(self.voc_dict["endtie"]) # ID:2を追加
|
140 |
+
list_events.append(self.voc_dict["eos"]) # ID: 1を追加
|
141 |
+
tokens = np.concatenate((list_tie_section, list_events)).astype(int)
|
142 |
+
return tokens
|
143 |
+
|
144 |
+
|
145 |
+
"""## Detokenizer
|
146 |
+
Transforms a list of MIDI-like token sequences into a MIDI file.
|
147 |
+
"""
|
148 |
+
|
149 |
+
def parse_id(voc_dict: dict, id: int):
|
150 |
+
"""
|
151 |
+
トークンIDを解析し、トークンの種類名とその相対IDを返す関数。
|
152 |
+
"""
|
153 |
+
keys = voc_dict["keylist"] # トークンの種類リストを取得
|
154 |
+
token_name = keys[0] # デフォルトの種類を "pad" に設定
|
155 |
+
|
156 |
+
# トークンの種類を特定する
|
157 |
+
for k in keys:
|
158 |
+
if id < voc_dict[k]: # 現在の種類の開始位置より小さい場合、前の種類が該当
|
159 |
+
break
|
160 |
+
token_name = k # 現在の種類名を更新
|
161 |
+
|
162 |
+
# 該当種類内での相対IDを計算
|
163 |
+
token_id = id - voc_dict[token_name]
|
164 |
+
|
165 |
+
return token_name, token_id # 種類名と相対IDを返す
|
166 |
+
|
167 |
+
|
168 |
+
|
169 |
+
def to_second(n):
|
170 |
+
"""
|
171 |
+
フレーム数を秒単位に変換する関数。
|
172 |
+
"""
|
173 |
+
return n * FRAME_STEP_SIZE_SEC
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
def find_note(list, n):
|
178 |
+
"""
|
179 |
+
タプルリストの最初の要素から指定された値を検索し、そのインデックスを返す関数。
|
180 |
+
"""
|
181 |
+
li_elem = [a for a, _ in list] # 最初の要素だけを抽出したリスト
|
182 |
+
try:
|
183 |
+
idx = li_elem.index(n) # n が存在する場合、そのインデックスを取得
|
184 |
+
except ValueError:
|
185 |
+
return -1 # 存在しない場合は -1 を返す
|
186 |
+
return idx # 見つかった場合、そのインデックスを返す
|
187 |
+
|
188 |
+
|
189 |
+
|
190 |
+
def token_seg_list_to_midi(token_seg_list: list):
|
191 |
+
"""
|
192 |
+
トークン列リストをMIDIファイルに変換する関数。
|
193 |
+
"""
|
194 |
+
# MIDIデータと楽器の初期化
|
195 |
+
midi_data = pm.PrettyMIDI()
|
196 |
+
piano_program = pm.instrument_name_to_program("Acoustic Grand Piano")
|
197 |
+
piano = pm.Instrument(program=piano_program)
|
198 |
+
|
199 |
+
list_onset = [] # 開始時刻を記録するリスト ※次のセグメント処理を行うときにlist_onsetには、最終音とタイの可能性がある音が記録されている
|
200 |
+
cur_time = 0 # 前回のセグメントの終了時間
|
201 |
+
|
202 |
+
# トークンセグメントごとの処理
|
203 |
+
for token_seg in token_seg_list:
|
204 |
+
list_tie = [] # タイ結合された音符
|
205 |
+
cur_relative_time = -1 # セグメント内の相対時間
|
206 |
+
cur_onset = -1 # 現在のオンセット状態
|
207 |
+
tie_end = False # タイ結合が終了したかどうか
|
208 |
+
|
209 |
+
for token in token_seg:
|
210 |
+
# トークンを解析
|
211 |
+
token_name, token_id = parse_id(voc_single_track, token)
|
212 |
+
|
213 |
+
if token_name == "note":
|
214 |
+
# 音符処理
|
215 |
+
if not tie_end: # タイ結合の場合
|
216 |
+
list_tie.append(token_id) # タイのnote番号(相対ID)を追加
|
217 |
+
elif cur_onset == 1:
|
218 |
+
list_onset.append((token_id, cur_time + cur_relative_time)) # 開始時刻を記録
|
219 |
+
elif cur_onset == 0:
|
220 |
+
# 終了処理
|
221 |
+
i = find_note(list_onset, token_id)
|
222 |
+
if i >= 0:
|
223 |
+
start = list_onset[i][1]
|
224 |
+
end = cur_time + cur_relative_time
|
225 |
+
if start < end: # 開始時刻 < 終了時刻の場合のみ追加
|
226 |
+
new_note = pm.Note(100, token_id, start, end)
|
227 |
+
piano.notes.append(new_note)
|
228 |
+
list_onset.pop(i)
|
229 |
+
|
230 |
+
elif token_name == "onset":
|
231 |
+
# オンセット/オフセットの更新
|
232 |
+
if tie_end:
|
233 |
+
if token_id == 1:
|
234 |
+
cur_onset = 1 # 開始
|
235 |
+
elif token_id == 0:
|
236 |
+
cur_onset = 0 # 終了
|
237 |
+
|
238 |
+
elif token_name == "time":
|
239 |
+
# 相対時間の更新
|
240 |
+
if tie_end:
|
241 |
+
cur_relative_time = to_second(token_id)
|
242 |
+
|
243 |
+
elif token_name == "endtie":
|
244 |
+
# タイ結合終了処理
|
245 |
+
tie_end = True
|
246 |
+
for note, start in list_onset: # list_onsetには前回のセグメントで未処理の最終音が含まれる
|
247 |
+
if note not in list_tie: # list_onsetにあるにも関わらず、list_tieにない場合
|
248 |
+
if start < cur_time:
|
249 |
+
new_note = pm.Note(100, note, start, cur_time) # 前回のセグメントの終了時をendとする(end=cur_time)
|
250 |
+
piano.notes.append(new_note)
|
251 |
+
list_onset.remove((note, start))
|
252 |
+
|
253 |
+
# 現在の時間を更新
|
254 |
+
cur_time += AUDIO_SEGMENT_SEC
|
255 |
+
|
256 |
+
# 楽器をMIDIデータに追加
|
257 |
+
midi_data.instruments.append(piano)
|
258 |
+
return midi_data
|
259 |
+
|
260 |
+
|
261 |
+
# 固定長のセグメントに分割する関数
|
262 |
+
def split_audio_into_segments(y: torch.Tensor, sr: int): # オーディオデータ(テンソル), オーディオのサンプルレート
|
263 |
+
audio_segment_samples = round(AUDIO_SEGMENT_SEC * sr) # 1セグメントの長さをサンプル数で計算
|
264 |
+
pad_size = audio_segment_samples - (y.shape[-1] % audio_segment_samples) # セグメントサイズできっちり分割できるようにpadするサイズを計算
|
265 |
+
|
266 |
+
y = F.pad(y, (0, pad_size)) # padを追加
|
267 |
+
assert (y.shape[-1] % audio_segment_samples) == 0 # 割り切れない場合assertをする
|
268 |
+
n_chunks = y.shape[-1] // audio_segment_samples # セグメント数を計算
|
269 |
+
# 固定長のセグメントに分割
|
270 |
+
y_segments = torch.chunk(y, chunks=n_chunks, dim=-1) # torch.chunk: テンソルを指定した数(n_chunks)に分割、dim=-1: サンプル次元(最後の次元)で分割
|
271 |
+
return torch.stack(y_segments, dim=0) # 分割したセグメントを1つのテンソルにまとめる。dim=0: セグメント数をバッチ次元として結合。
|
272 |
+
# 形状: (セグメント数, セグメント長)。
|
273 |
+
|
274 |
+
# 推論時に作られたpadされたseqから有効な部分を抽出する関数
|
275 |
+
def unpack_sequence(x: torch.Tensor, eos_id: int=1):
|
276 |
+
seqs = [] # 各シーケンスを切り出して保存するリストを初期化。
|
277 |
+
max_length = x.shape[-1] # シーケンスの最大長を取得。whileループの範囲をチェックするために使用。
|
278 |
+
for seq in x: # テンソル x の各シーケンス(行)を処理
|
279 |
+
# eosトークンを探す
|
280 |
+
start_pos = 0
|
281 |
+
pos = 0
|
282 |
+
while (pos < max_length) and (seq[pos] != eos_id): # 現在地が最大長を超えない&現在地がeosではない場合
|
283 |
+
pos += 1
|
284 |
+
# ループ終了後:pos には、終了トークン(eos_id)の位置(またはシーケンスの末尾)が格納されます。
|
285 |
+
end_pos = pos+1
|
286 |
+
seqs.append(seq[start_pos:end_pos]) #開始位置(start_pos)から終了位置(end_pos)までの部分を切り出し、リスト seqs に追加。
|
287 |
+
return seqs
|
288 |
+
|
289 |
+
class LogMelspec(nn.Module):
|
290 |
+
def __init__(self, sample_rate, n_fft, n_mels, hop_length):
|
291 |
+
super().__init__()
|
292 |
+
self.melspec = torchaudio.transforms.MelSpectrogram(
|
293 |
+
sample_rate=sample_rate,
|
294 |
+
n_fft=n_fft, # FFT(高速フーリエ変換)に使用するポイント数
|
295 |
+
hop_length=hop_length, # ストライド(フレームのシフト幅(サンプル単位))。通常、n_fft // 4 などの値を設定
|
296 |
+
f_min=20.0, # メルフィルタバンクの最小周波数。20Hz 以上が推奨(人間の聴覚範囲)。
|
297 |
+
n_mels=n_mels, # メルスペクトログラムの周波数軸方向の次元数(メルバンド数)。
|
298 |
+
mel_scale="slaney", # メル尺度を計算する方法."slaney": より音響的に意味のあるスケール
|
299 |
+
norm="slaney", # メルフィルタバンクの正規化方法. "slaney" を指定するとフィルタバンクがエネルギーで正規化
|
300 |
+
power=1, # 出力スペクトログラムのエネルギース���ール, 1: 振幅スペクトログラム。
|
301 |
+
)
|
302 |
+
self.eps = 1e-5 # 対数計算時のゼロ除算エラーを防ぐための閾値。メルスペクトログラムの値が 1e-5 未満の場合、この値に置き換える。
|
303 |
+
|
304 |
+
def forward(self, x): # 入力: モノラル: (バッチサイズ, サンプル数) or ステレオ: (バッチサイズ, チャンネル数, サンプル数)
|
305 |
+
spec = self.melspec(x) # 入力波形 x からメルスペクトログラムを計算, 出力:メルスペクトログラム: (バッチサイズ, メルバンド数, フレーム数)
|
306 |
+
safe_spec = torch.clamp(spec, min=self.eps) # メルスペクトログラムの最小値を self.eps に制限。値が非常に小さい場合でも、対数計算が可能。
|
307 |
+
log_spec = torch.log(safe_spec) #メルスペクトログラムを対数スケールに変換。
|
308 |
+
return log_spec # (バッチサイズ, メルバンド数, フレーム数) のテンソル。各値は対数スケールのメルスペクトログラム。
|
309 |
+
|
310 |
+
|
311 |
+
class LogCQT(nn.Module):
|
312 |
+
def __init__(self, sample_rate, n_bins, hop_length, bins_per_octave):
|
313 |
+
super().__init__()
|
314 |
+
self.cqt = CQT(
|
315 |
+
sr=sample_rate,
|
316 |
+
hop_length=hop_length,
|
317 |
+
fmin=32.7, # 低周波数の最小値 (通常は32.7Hz, C1)
|
318 |
+
fmax=8000, # 最高周波数
|
319 |
+
n_bins=n_bins,
|
320 |
+
bins_per_octave=bins_per_octave,
|
321 |
+
verbose=False
|
322 |
+
).to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # GPUに載せる
|
323 |
+
self.eps = 1e-5 # ゼロ除算を防ぐ閾値
|
324 |
+
|
325 |
+
def forward(self, x):
|
326 |
+
x = x.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # GPUに送る
|
327 |
+
cqt_spec = self.cqt(x) # (B, n_bins, time)
|
328 |
+
safe_spec = torch.clamp(cqt_spec, min=self.eps) # 小さな値をカット
|
329 |
+
log_cqt = torch.log(safe_spec) # 対数変換
|
330 |
+
return log_cqt # (B, n_bins, time)
|
331 |
+
|
332 |
+
def normalize_rms_torch(audio_tensor, target_rms=0.1):
|
333 |
+
rms = torch.sqrt(torch.mean(audio_tensor**2)).item()
|
334 |
+
if rms < 1e-6:
|
335 |
+
print("音が非常に小さいため、最小値でスケーリングします")
|
336 |
+
rms = 1e-6
|
337 |
+
scaling_factor = target_rms / rms
|
338 |
+
return audio_tensor * scaling_factor
|
339 |
+
|
340 |
+
def rms_normalize_wav(input_path, output_path, target_rms=0.1):
|
341 |
+
waveform, sr = torchaudio.load(input_path)
|
342 |
+
waveform = waveform.mean(0, keepdim=True) # モノラル化
|
343 |
+
normalized = normalize_rms_torch(waveform, target_rms)
|
344 |
+
torchaudio.save(output_path, normalized, sample_rate=sr)
|
345 |
+
|