Update DPTNet_eval/DPTNet_quant_sep.py
Browse files
DPTNet_eval/DPTNet_quant_sep.py
CHANGED
@@ -49,13 +49,21 @@ def load_dpt_model():
|
|
49 |
token=speech_sep_token
|
50 |
)
|
51 |
|
52 |
-
# 👇 原本邏輯完全不變
|
53 |
conf_filterbank, conf_masknet = get_conf()
|
54 |
model_class = getattr(asteroid_test, "DPTNet")
|
55 |
model = model_class(**conf_filterbank, **conf_masknet)
|
56 |
model = torch.quantization.quantize_dynamic(model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8)
|
57 |
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
model.load_state_dict(state_dict)
|
60 |
model.eval()
|
61 |
return model
|
|
|
49 |
token=speech_sep_token
|
50 |
)
|
51 |
|
|
|
52 |
conf_filterbank, conf_masknet = get_conf()
|
53 |
model_class = getattr(asteroid_test, "DPTNet")
|
54 |
model = model_class(**conf_filterbank, **conf_masknet)
|
55 |
model = torch.quantization.quantize_dynamic(model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8)
|
56 |
|
57 |
+
try:
|
58 |
+
state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
|
59 |
+
except pickle.UnpicklingError as e:
|
60 |
+
raise RuntimeError(
|
61 |
+
"模型載入失敗!請確認:\n"
|
62 |
+
"1. 模型來源是否可信\n"
|
63 |
+
"2. 是否為舊版 PyTorch 儲存的模型\n"
|
64 |
+
"3. 嘗試鎖定 PyTorch 版本為 2.5.x"
|
65 |
+
) from e
|
66 |
+
|
67 |
model.load_state_dict(state_dict)
|
68 |
model.eval()
|
69 |
return model
|