DeepLearning101 commited on
Commit
76875bc
·
verified ·
1 Parent(s): 35fea36

Update DPTNet_eval/DPTNet_quant_sep.py

Browse files
Files changed (1) hide show
  1. DPTNet_eval/DPTNet_quant_sep.py +10 -2
DPTNet_eval/DPTNet_quant_sep.py CHANGED
@@ -49,13 +49,21 @@ def load_dpt_model():
49
  token=speech_sep_token
50
  )
51
 
52
- # 👇 原本邏輯完全不變
53
  conf_filterbank, conf_masknet = get_conf()
54
  model_class = getattr(asteroid_test, "DPTNet")
55
  model = model_class(**conf_filterbank, **conf_masknet)
56
  model = torch.quantization.quantize_dynamic(model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8)
57
 
58
- state_dict = torch.load(model_path, map_location="cpu")
 
 
 
 
 
 
 
 
 
59
  model.load_state_dict(state_dict)
60
  model.eval()
61
  return model
 
49
  token=speech_sep_token
50
  )
51
 
 
52
  conf_filterbank, conf_masknet = get_conf()
53
  model_class = getattr(asteroid_test, "DPTNet")
54
  model = model_class(**conf_filterbank, **conf_masknet)
55
  model = torch.quantization.quantize_dynamic(model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8)
56
 
57
+ try:
58
+ state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
59
+ except pickle.UnpicklingError as e:
60
+ raise RuntimeError(
61
+ "模型載入失敗!請確認:\n"
62
+ "1. 模型來源是否可信\n"
63
+ "2. 是否為舊版 PyTorch 儲存的模型\n"
64
+ "3. 嘗試鎖定 PyTorch 版本為 2.5.x"
65
+ ) from e
66
+
67
  model.load_state_dict(state_dict)
68
  model.eval()
69
  return model