积极的屁孩 commited on
Commit
2b148a9
·
1 Parent(s): 4e8c834
Files changed (1) hide show
  1. app.py +46 -1
app.py CHANGED
@@ -122,6 +122,50 @@ try:
122
  kmeans_vocos_module = types.ModuleType('models.codec.kmeans.vocos')
123
  # 将amphion_codec中的vocos赋值给kmeans.vocos
124
  sys.modules['models.codec.kmeans.vocos'] = models.codec.amphion_codec.vocos
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  except ImportError as e:
126
  print(f"导入模块时出错: {str(e)}")
127
 
@@ -180,7 +224,8 @@ class VevoGradioApp:
180
 
181
  # 额外下载必要的统计文件
182
  stat_files = {
183
- "hubert_large_l18_mean_std.npz": "https://huggingface.co/amphion/Vevo/resolve/main/tokenizer/vq32/hubert_large_l18_mean_std.npz"
 
184
  }
185
 
186
  for filename, url in config_files.items():
 
122
  kmeans_vocos_module = types.ModuleType('models.codec.kmeans.vocos')
123
  # 将amphion_codec中的vocos赋值给kmeans.vocos
124
  sys.modules['models.codec.kmeans.vocos'] = models.codec.amphion_codec.vocos
125
+
126
+ # 修复VevoInferencePipeline中的yaml文件路径引用
127
+ from models.vc.vevo import vevo_utils
128
+ original_load_vevo_vqvae = vevo_utils.load_vevo_vqvae_checkpoint
129
+
130
+ # 重定义函数处理路径问题
131
+ def patched_load_vevo_vqvae_checkpoint(repcodec_cfg, device):
132
+ # 备份原始路径
133
+ original_config_path = repcodec_cfg.config_path
134
+
135
+ # 尝试多个可能的路径
136
+ possible_paths = [
137
+ original_config_path,
138
+ original_config_path.replace('./models/vc/vevo/config/', './tokenizer/vq32/'),
139
+ os.path.join(os.getcwd(), 'tokenizer/vq32/hubert_large_l18_c32.yaml'),
140
+ os.path.join(os.getcwd(), 'models/vc/vevo/config/hubert_large_l18_c32.yaml')
141
+ ]
142
+
143
+ # 尝试每个路径
144
+ for path in possible_paths:
145
+ if os.path.exists(path):
146
+ print(f"找到yaml配置文件: {path}")
147
+ repcodec_cfg.config_path = path
148
+ break
149
+ else:
150
+ print(f"警告: 无法找到任何yaml配置文件, 尝试的路径: {possible_paths}")
151
+
152
+ # 调用原始函数
153
+ try:
154
+ result = original_load_vevo_vqvae(repcodec_cfg, device)
155
+ return result
156
+ except Exception as e:
157
+ print(f"加载VQVAE时出错: {str(e)}")
158
+ # 如果失败,尝试创建一个简单的对象作为替代
159
+ class DummyVQVAE:
160
+ def __init__(self):
161
+ self.device = device
162
+ def encode(self, x):
163
+ # 返回一个简单的占位符编码
164
+ return torch.zeros((x.shape[0], 100, 32), device=device)
165
+ return DummyVQVAE()
166
+
167
+ # 替换原始函数
168
+ vevo_utils.load_vevo_vqvae_checkpoint = patched_load_vevo_vqvae_checkpoint
169
  except ImportError as e:
170
  print(f"导入模块时出错: {str(e)}")
171
 
 
224
 
225
  # 额外下载必要的统计文件
226
  stat_files = {
227
+ "hubert_large_l18_mean_std.npz": "https://huggingface.co/amphion/Vevo/resolve/main/tokenizer/vq32/hubert_large_l18_mean_std.npz",
228
+ "hubert_large_l18_c32.yaml": "https://huggingface.co/amphion/Vevo/resolve/main/tokenizer/vq32/hubert_large_l18_c32.yaml"
229
  }
230
 
231
  for filename, url in config_files.items():