import os

import librosa
import soundfile as sf
import torch
import torchaudio

from train import LitTranscriber
from utils import rms_normalize_wav, token_seg_list_to_midi, unpack_sequence

BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # points to backend/src
PTH_PATH = os.path.join(BASE_DIR, "model.pth")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model():
    args = {
        "n_mels": 128,
        "sample_rate": 16000,
        "n_fft": 1024,
        "hop_length": 128,
    }
    model = LitTranscriber(transcriber_args=args, lr=1e-4, lr_decay=0.99)
    state_dict = torch.load(PTH_PATH, map_location=device)  # load weights from the .pth checkpoint
    model.load_state_dict(state_dict)
    model.to(device)  # move the model to the same device as the input waveform
    model.eval()
    return model


def convert_to_pcm_wav(input_path, output_path):
    # Decode with librosa (resampled to 16 kHz mono, float PCM samples),
    # then write the result back out as a standard PCM WAV file.
    y, sr = librosa.load(input_path, sr=16000, mono=True)
    sf.write(output_path, y, sr)


def infer_midi_from_wav(input_wav_path: str) -> str:
    model = load_model()

    # Convert the input to 16 kHz mono PCM WAV, then RMS-normalize it.
    converted_path = os.path.join(BASE_DIR, "converted_input.wav")
    convert_to_pcm_wav(input_wav_path, converted_path)

    normalized_path = os.path.join(BASE_DIR, "tmp_normalized.wav")
    rms_normalize_wav(converted_path, normalized_path, target_rms=0.1)

    waveform, sr = torchaudio.load(normalized_path)
    waveform = waveform.mean(0).to(device)  # downmix channels to mono
    if sr != model.transcriber.sr:
        waveform = torchaudio.functional.resample(
            waveform, sr, model.transcriber.sr
        ).to(device)

    with torch.no_grad():
        output_tokens = model(waveform)

    # Strip padding, drop the leading special token from each segment,
    # and decode the token segments into a MIDI object.
    unpadded_tokens = unpack_sequence(output_tokens.cpu().numpy())
    unpadded_tokens = [t[1:] for t in unpadded_tokens]
    est_midi = token_seg_list_to_midi(unpadded_tokens)

    midi_path = os.path.join(BASE_DIR, "output.mid")
    est_midi.write(midi_path)
    print(f"MIDI saved at: {midi_path}")
    return midi_path
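
# Minimal usage sketch. Assumptions not in the original: "sample.wav" is a
# hypothetical test file placed next to this script, and the checkpoint
# exists at PTH_PATH. Adjust the path to a real recording before running.
if __name__ == "__main__":
    test_wav = os.path.join(BASE_DIR, "sample.wav")
    midi_file = infer_midi_from_wav(test_wav)
    print(f"Transcription written to: {midi_file}")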