DeepLearning101 commited on
Commit
67c1496
·
verified ·
1 Parent(s): d8be50a

Update DPTNet_eval/DPTNet_quant_sep.py

Browse files
Files changed (1) hide show
  1. DPTNet_eval/DPTNet_quant_sep.py +17 -19
DPTNet_eval/DPTNet_quant_sep.py CHANGED
@@ -46,49 +46,47 @@ def get_conf():
46
  def load_dpt_model():
47
  print('Load Separation Model...')
48
 
49
- # 從環境變數取得 Secret Token
50
  speech_sep_token = os.getenv("SpeechSeparation")
51
  if not speech_sep_token:
52
  raise EnvironmentError("環境變數 SpeechSeparation 未設定!")
53
 
54
- # 從 HF Hub 下載模型權重
55
  model_path = hf_hub_download(
56
  repo_id="DeepLearning101/speech-separation",
57
  filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
58
  token=speech_sep_token
59
  )
60
 
61
- # 取得模型參數
62
  conf_filterbank, conf_masknet = get_conf()
63
 
64
- # 建立模型架構(⚠️ 這邊要與訓練時完全一樣)
65
  try:
66
  model_class = getattr(asteroid_test, "DPTNet")
67
  model = model_class(**conf_filterbank, **conf_masknet)
68
  except Exception as e:
69
  raise RuntimeError("模型結構錯誤:請確認 asteroid_test.py 是否與訓練時相同") from e
70
 
71
- # 套用量化設定
72
- try:
73
- model = torch.quantization.quantize_dynamic(
74
- model,
75
- {torch.nn.LSTM, torch.nn.Linear},
76
- dtype=torch.qint8
77
- )
78
- except Exception as e:
79
- print("量化設定失敗:", e)
80
 
81
- # 載入權重(忽略不匹配的 keys)
82
  state_dict = torch.load(model_path, map_location="cpu")
83
  own_state = model.state_dict()
84
- filtered_state_dict = {
85
- k: v for k, v in state_dict.items() if k in own_state and v.shape == own_state[k].shape
86
- }
87
 
88
- # 忽略找不到的 keys,也不強制要求全部 match
 
 
 
 
 
 
 
 
 
 
 
89
  missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)
90
 
91
- # 印出警告訊息方便除錯
92
  if missing_keys:
93
  print("⚠️ Missing keys:", missing_keys)
94
  if unexpected_keys:
 
46
  def load_dpt_model():
47
  print('Load Separation Model...')
48
 
 
49
  speech_sep_token = os.getenv("SpeechSeparation")
50
  if not speech_sep_token:
51
  raise EnvironmentError("環境變數 SpeechSeparation 未設定!")
52
 
 
53
  model_path = hf_hub_download(
54
  repo_id="DeepLearning101/speech-separation",
55
  filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
56
  token=speech_sep_token
57
  )
58
 
 
59
  conf_filterbank, conf_masknet = get_conf()
60
 
 
61
  try:
62
  model_class = getattr(asteroid_test, "DPTNet")
63
  model = model_class(**conf_filterbank, **conf_masknet)
64
  except Exception as e:
65
  raise RuntimeError("模型結構錯誤:請確認 asteroid_test.py 是否與訓練時相同") from e
66
 
67
+ model = torch.quantization.quantize_dynamic(
68
+ model,
69
+ {torch.nn.LSTM, torch.nn.Linear},
70
+ dtype=torch.qint8
71
+ )
 
 
 
 
72
 
 
73
  state_dict = torch.load(model_path, map_location="cpu")
74
  own_state = model.state_dict()
 
 
 
75
 
76
+ # 只保留是 torch.Tensor 的 key-value pairs
77
+ filtered_state_dict = {}
78
+ for k, v in state_dict.items():
79
+ if k in own_state:
80
+ if isinstance(v, torch.Tensor) and isinstance(own_state[k], torch.Tensor):
81
+ if v.shape == own_state[k].shape:
82
+ filtered_state_dict[k] = v
83
+ else:
84
+ print(f"Skip '{k}': shape mismatch")
85
+ else:
86
+ print(f"Skip '{k}': not a tensor")
87
+
88
  missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)
89
 
 
90
  if missing_keys:
91
  print("⚠️ Missing keys:", missing_keys)
92
  if unexpected_keys: