Spaces:
Running
Running
File size: 17,652 Bytes
c094356 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 |
from operator import attrgetter
import dataclasses
import numpy as np
import pretty_midi as pm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from nnAudio.features import CQT
import soundfile as sf
from config import FRAME_PER_SEC, FRAME_STEP_SIZE_SEC, AUDIO_SEGMENT_SEC
from config import voc_single_track
@dataclasses.dataclass
class Event:
prog: int
onset: bool
pitch: int
class MIDITokenExtractor:
"""
・MIDIデータ(音符、タイミング、ペダル情報など)を抽出し、トークン列に変換する。
・セグメント単位でMIDIデータを分割し、各セグメントをトークンとして表現。
"""
def __init__(self, midi_path, voc_dict, apply_pedal=True):
"""
・MIDIデータを読み込み、必要に応じてサステインペダルを適用する初期化処理を行う。
"""
self.pm = pm.PrettyMIDI(midi_path) # MIDIファイルをPrettyMIDIで読み込む
if apply_pedal:
self.pm_apply_pedal(self.pm) # サステインペダル処理を適用
self.voc_dict = voc_dict # トークンの定義辞書
self.multi_track = "instrument" in voc_dict # マルチトラック対応のフラグ
def pm_apply_pedal(self, pm: pm.PrettyMIDI, program=0):
"""
Apply sustain pedal by stretching the notes in the pm object.
"""
# 1: Record the onset positions of each notes as a dictionary
onset_dict = dict()
for note in pm.instruments[program].notes:
if note.pitch in onset_dict:
onset_dict[note.pitch].append(note.start)
else:
onset_dict[note.pitch] = [note.start]
for k in onset_dict.keys():
onset_dict[k] = np.sort(onset_dict[k])
# 2: Record the pedal on/off state of each time frame
arr_pedal = np.zeros(
round(pm.get_end_time()*FRAME_PER_SEC)+100, dtype=bool)
pedal_on_time = -1
list_pedaloff_time = []
for cc in pm.instruments[program].control_changes:
if cc.number == 64:
if (cc.value > 0) and (pedal_on_time < 0):
pedal_on_time = round(cc.time*FRAME_PER_SEC)
elif (cc.value == 0) and (pedal_on_time >= 0):
pedal_off_time = round(cc.time*FRAME_PER_SEC)
arr_pedal[pedal_on_time:pedal_off_time] = True
list_pedaloff_time.append(cc.time)
pedal_on_time = -1
list_pedaloff_time = np.sort(list_pedaloff_time)
# 3: Stretch the notes (modify note.end)
for note in pm.instruments[program].notes:
# 3-1: Determine whether sustain pedal is on at note.end. If not, do nothing.
# 3-2: Find the next note onset time and next pedal off time after note.end.
# 3-3: Extend note.end till the minimum of next_onset and next_pedaloff.
note_off_frame = round(note.end*FRAME_PER_SEC)
pitch = note.pitch
if arr_pedal[note_off_frame]:
next_onset = np.argwhere(onset_dict[pitch] > note.end)
next_onset = np.inf if len(
next_onset) == 0 else onset_dict[pitch][next_onset[0, 0]]
next_pedaloff = np.argwhere(list_pedaloff_time > note.end)
next_pedaloff = np.inf if len(
next_pedaloff) == 0 else list_pedaloff_time[next_pedaloff[0, 0]]
new_noteoff_time = max(note.end, min(next_onset, next_pedaloff))
new_noteoff_time = min(new_noteoff_time, pm.get_end_time())
note.end = new_noteoff_time
def get_segment_tokens(self, start, end):
"""
Transform a segment of the MIDI file into a sequence of tokens.
"""
dict_event = dict() # a dictionary that maps time to a list of events.
def append_to_dict_event(time, item):
if time in dict_event:
dict_event[time].append(item)
else:
dict_event[time] = [item]
list_events = [] # events section
list_tie_section = [] # tie section
for instrument in self.pm.instruments:
prog = instrument.program
for note in instrument.notes:
note_end = round(note.end * FRAME_PER_SEC) # 音符終了時刻(フレーム単位)
note_start = round(note.start * FRAME_PER_SEC) # 音符開始時刻(フレーム単位)
if (note_end < start) or (note_start >= end):
# セグメント外の音符は無視
continue
if (note_start < start) and (note_end >= start):
# If the note starts before the segment, but ends in the segment
# it is added to the tie section.
# セグメント開始時より前に始まり、セグメント内で終了する音符(ピッチ)をタイセクションに追加
list_tie_section.append(self.voc_dict["note"] + note.pitch)
if note_end < end:
# セグメント内で終了する場合、イベントに終了時刻を記録
append_to_dict_event(
note_end - start, Event(prog, False, note.pitch)
)
continue
assert note_start >= start
# セグメント内で開始
append_to_dict_event(note_start - start, Event(prog, True, note.pitch))
if note_end < end:
# セグメント内で終了
append_to_dict_event(
note_end - start, Event(prog, False, note.pitch)
)
cur_onset = None
cur_prog = -1
for time in sorted(dict_event.keys()): # 現在の相対時間(time)をトークン化し、イベント列に追加
list_events.append(self.voc_dict["time"] + time) # self.voc_dict["time"]はtimeトークンの開始ID(133)。これに相対時間を足す
for event in sorted(dict_event[time], key=attrgetter("pitch", "onset")): # 同一時間内でピッチを昇順に、onset→offsetの順になるようにソートする。
if cur_onset != event.onset:
cur_onset = event.onset # オンセットオフセットが変わった場合に更新
list_events.append(self.voc_dict["onset"] + int(event.onset)) # オンセットオフセットトークンを追加
list_events.append(self.voc_dict["note"] + event.pitch) # 音符トークンを追加
# Concatenate tie section, endtie token, and event section
list_tie_section.append(self.voc_dict["endtie"]) # ID:2を追加
list_events.append(self.voc_dict["eos"]) # ID: 1を追加
tokens = np.concatenate((list_tie_section, list_events)).astype(int)
return tokens
"""## Detokenizer
Transforms a list of MIDI-like token sequences into a MIDI file.
"""
def parse_id(voc_dict: dict, id: int):
"""
トークンIDを解析し、トークンの種類名とその相対IDを返す関数。
"""
keys = voc_dict["keylist"] # トークンの種類リストを取得
token_name = keys[0] # デフォルトの種類を "pad" に設定
# トークンの種類を特定する
for k in keys:
if id < voc_dict[k]: # 現在の種類の開始位置より小さい場合、前の種類が該当
break
token_name = k # 現在の種類名を更新
# 該当種類内での相対IDを計算
token_id = id - voc_dict[token_name]
return token_name, token_id # 種類名と相対IDを返す
def to_second(n):
"""
フレーム数を秒単位に変換する関数。
"""
return n * FRAME_STEP_SIZE_SEC
def find_note(list, n):
"""
タプルリストの最初の要素から指定された値を検索し、そのインデックスを返す関数。
"""
li_elem = [a for a, _ in list] # 最初の要素だけを抽出したリスト
try:
idx = li_elem.index(n) # n が存在する場合、そのインデックスを取得
except ValueError:
return -1 # 存在しない場合は -1 を返す
return idx # 見つかった場合、そのインデックスを返す
def token_seg_list_to_midi(token_seg_list: list):
"""
トークン列リストをMIDIファイルに変換する関数。
"""
# MIDIデータと楽器の初期化
midi_data = pm.PrettyMIDI()
piano_program = pm.instrument_name_to_program("Acoustic Grand Piano")
piano = pm.Instrument(program=piano_program)
list_onset = [] # 開始時刻を記録するリスト ※次のセグメント処理を行うときにlist_onsetには、最終音とタイの可能性がある音が記録されている
cur_time = 0 # 前回のセグメントの終了時間
# トークンセグメントごとの処理
for token_seg in token_seg_list:
list_tie = [] # タイ結合された音符
cur_relative_time = -1 # セグメント内の相対時間
cur_onset = -1 # 現在のオンセット状態
tie_end = False # タイ結合が終了したかどうか
for token in token_seg:
# トークンを解析
token_name, token_id = parse_id(voc_single_track, token)
if token_name == "note":
# 音符処理
if not tie_end: # タイ結合の場合
list_tie.append(token_id) # タイのnote番号(相対ID)を追加
elif cur_onset == 1:
list_onset.append((token_id, cur_time + cur_relative_time)) # 開始時刻を記録
elif cur_onset == 0:
# 終了処理
i = find_note(list_onset, token_id)
if i >= 0:
start = list_onset[i][1]
end = cur_time + cur_relative_time
if start < end: # 開始時刻 < 終了時刻の場合のみ追加
new_note = pm.Note(100, token_id, start, end)
piano.notes.append(new_note)
list_onset.pop(i)
elif token_name == "onset":
# オンセット/オフセットの更新
if tie_end:
if token_id == 1:
cur_onset = 1 # 開始
elif token_id == 0:
cur_onset = 0 # 終了
elif token_name == "time":
# 相対時間の更新
if tie_end:
cur_relative_time = to_second(token_id)
elif token_name == "endtie":
# タイ結合終了処理
tie_end = True
for note, start in list_onset: # list_onsetには前回のセグメントで未処理の最終音が含まれる
if note not in list_tie: # list_onsetにあるにも関わらず、list_tieにない場合
if start < cur_time:
new_note = pm.Note(100, note, start, cur_time) # 前回のセグメントの終了時をendとする(end=cur_time)
piano.notes.append(new_note)
list_onset.remove((note, start))
# 現在の時間を更新
cur_time += AUDIO_SEGMENT_SEC
# 楽器をMIDIデータに追加
midi_data.instruments.append(piano)
return midi_data
# 固定長のセグメントに分割する関数
def split_audio_into_segments(y: torch.Tensor, sr: int): # オーディオデータ(テンソル), オーディオのサンプルレート
audio_segment_samples = round(AUDIO_SEGMENT_SEC * sr) # 1セグメントの長さをサンプル数で計算
pad_size = audio_segment_samples - (y.shape[-1] % audio_segment_samples) # セグメントサイズできっちり分割できるようにpadするサイズを計算
y = F.pad(y, (0, pad_size)) # padを追加
assert (y.shape[-1] % audio_segment_samples) == 0 # 割り切れない場合assertをする
n_chunks = y.shape[-1] // audio_segment_samples # セグメント数を計算
# 固定長のセグメントに分割
y_segments = torch.chunk(y, chunks=n_chunks, dim=-1) # torch.chunk: テンソルを指定した数(n_chunks)に分割、dim=-1: サンプル次元(最後の次元)で分割
return torch.stack(y_segments, dim=0) # 分割したセグメントを1つのテンソルにまとめる。dim=0: セグメント数をバッチ次元として結合。
# 形状: (セグメント数, セグメント長)。
# 推論時に作られたpadされたseqから有効な部分を抽出する関数
def unpack_sequence(x: torch.Tensor, eos_id: int=1):
seqs = [] # 各シーケンスを切り出して保存するリストを初期化。
max_length = x.shape[-1] # シーケンスの最大長を取得。whileループの範囲をチェックするために使用。
for seq in x: # テンソル x の各シーケンス(行)を処理
# eosトークンを探す
start_pos = 0
pos = 0
while (pos < max_length) and (seq[pos] != eos_id): # 現在地が最大長を超えない&現在地がeosではない場合
pos += 1
# ループ終了後:pos には、終了トークン(eos_id)の位置(またはシーケンスの末尾)が格納されます。
end_pos = pos+1
seqs.append(seq[start_pos:end_pos]) #開始位置(start_pos)から終了位置(end_pos)までの部分を切り出し、リスト seqs に追加。
return seqs
class LogMelspec(nn.Module):
def __init__(self, sample_rate, n_fft, n_mels, hop_length):
super().__init__()
self.melspec = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft, # FFT(高速フーリエ変換)に使用するポイント数
hop_length=hop_length, # ストライド(フレームのシフト幅(サンプル単位))。通常、n_fft // 4 などの値を設定
f_min=20.0, # メルフィルタバンクの最小周波数。20Hz 以上が推奨(人間の聴覚範囲)。
n_mels=n_mels, # メルスペクトログラムの周波数軸方向の次元数(メルバンド数)。
mel_scale="slaney", # メル尺度を計算する方法."slaney": より音響的に意味のあるスケール
norm="slaney", # メルフィルタバンクの正規化方法. "slaney" を指定するとフィルタバンクがエネルギーで正規化
power=1, # 出力スペクトログラムのエネルギースケール, 1: 振幅スペクトログラム。
)
self.eps = 1e-5 # 対数計算時のゼロ除算エラーを防ぐための閾値。メルスペクトログラムの値が 1e-5 未満の場合、この値に置き換える。
def forward(self, x): # 入力: モノラル: (バッチサイズ, サンプル数) or ステレオ: (バッチサイズ, チャンネル数, サンプル数)
spec = self.melspec(x) # 入力波形 x からメルスペクトログラムを計算, 出力:メルスペクトログラム: (バッチサイズ, メルバンド数, フレーム数)
safe_spec = torch.clamp(spec, min=self.eps) # メルスペクトログラムの最小値を self.eps に制限。値が非常に小さい場合でも、対数計算が可能。
log_spec = torch.log(safe_spec) #メルスペクトログラムを対数スケールに変換。
return log_spec # (バッチサイズ, メルバンド数, フレーム数) のテンソル。各値は対数スケールのメルスペクトログラム。
class LogCQT(nn.Module):
def __init__(self, sample_rate, n_bins, hop_length, bins_per_octave):
super().__init__()
self.cqt = CQT(
sr=sample_rate,
hop_length=hop_length,
fmin=32.7, # 低周波数の最小値 (通常は32.7Hz, C1)
fmax=8000, # 最高周波数
n_bins=n_bins,
bins_per_octave=bins_per_octave,
verbose=False
).to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # GPUに載せる
self.eps = 1e-5 # ゼロ除算を防ぐ閾値
def forward(self, x):
x = x.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # GPUに送る
cqt_spec = self.cqt(x) # (B, n_bins, time)
safe_spec = torch.clamp(cqt_spec, min=self.eps) # 小さな値をカット
log_cqt = torch.log(safe_spec) # 対数変換
return log_cqt # (B, n_bins, time)
def normalize_rms_torch(audio_tensor, target_rms=0.1):
rms = torch.sqrt(torch.mean(audio_tensor**2)).item()
if rms < 1e-6:
print("音が非常に小さいため、最小値でスケーリングします")
rms = 1e-6
scaling_factor = target_rms / rms
return audio_tensor * scaling_factor
def rms_normalize_wav(input_path, output_path, target_rms=0.1):
waveform, sr = torchaudio.load(input_path)
waveform = waveform.mean(0, keepdim=True) # モノラル化
normalized = normalize_rms_torch(waveform, target_rms)
torchaudio.save(output_path, normalized, sample_rate=sr)
|