Spaces:
Build error
Build error
积极的屁孩
commited on
Commit
·
e48a9d8
1
Parent(s):
6efd082
adjust frequency
Browse files
app.py
CHANGED
@@ -234,23 +234,34 @@ def vevo_style(content_wav, style_wav):
|
|
234 |
temp_style_path = "wav/temp_style.wav"
|
235 |
output_path = "wav/output_vevostyle.wav"
|
236 |
|
237 |
-
#
|
238 |
if content_wav is None or style_wav is None:
|
239 |
raise ValueError("请上传音频文件")
|
240 |
|
241 |
-
#
|
242 |
if isinstance(content_wav, tuple) and len(content_wav) == 2:
|
243 |
-
# 确保正确的顺序 (data, sample_rate)
|
244 |
if isinstance(content_wav[0], np.ndarray):
|
245 |
content_data, content_sr = content_wav
|
246 |
else:
|
247 |
content_sr, content_data = content_wav
|
248 |
-
|
249 |
-
|
250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
else:
|
252 |
raise ValueError("内容音频格式不正确")
|
253 |
-
|
254 |
if isinstance(style_wav, tuple) and len(style_wav) == 2:
|
255 |
# 确保正确的顺序 (data, sample_rate)
|
256 |
if isinstance(style_wav[0], np.ndarray):
|
@@ -263,25 +274,42 @@ def vevo_style(content_wav, style_wav):
|
|
263 |
else:
|
264 |
raise ValueError("风格音频格式不正确")
|
265 |
|
266 |
-
#
|
|
|
|
|
|
|
|
|
267 |
torchaudio.save(temp_content_path, content_tensor, content_sr)
|
268 |
torchaudio.save(temp_style_path, style_tensor, style_sr)
|
269 |
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
|
286 |
def vevo_timbre(content_wav, reference_wav):
|
287 |
temp_content_path = "wav/temp_content.wav"
|
|
|
234 |
temp_style_path = "wav/temp_style.wav"
|
235 |
output_path = "wav/output_vevostyle.wav"
|
236 |
|
237 |
+
# 检查并处理音频数据
|
238 |
if content_wav is None or style_wav is None:
|
239 |
raise ValueError("请上传音频文件")
|
240 |
|
241 |
+
# 处理音频格式
|
242 |
if isinstance(content_wav, tuple) and len(content_wav) == 2:
|
|
|
243 |
if isinstance(content_wav[0], np.ndarray):
|
244 |
content_data, content_sr = content_wav
|
245 |
else:
|
246 |
content_sr, content_data = content_wav
|
247 |
+
|
248 |
+
# 确保是单声道
|
249 |
+
if len(content_data.shape) > 1 and content_data.shape[1] > 1:
|
250 |
+
content_data = np.mean(content_data, axis=1)
|
251 |
+
|
252 |
+
# 重采样到24kHz
|
253 |
+
if content_sr != 24000:
|
254 |
+
content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
|
255 |
+
content_tensor = torchaudio.functional.resample(content_tensor, content_sr, 24000)
|
256 |
+
content_sr = 24000
|
257 |
+
else:
|
258 |
+
content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
|
259 |
+
|
260 |
+
# 归一化音量
|
261 |
+
content_tensor = content_tensor / (torch.max(torch.abs(content_tensor)) + 1e-6) * 0.95
|
262 |
else:
|
263 |
raise ValueError("内容音频格式不正确")
|
264 |
+
|
265 |
if isinstance(style_wav, tuple) and len(style_wav) == 2:
|
266 |
# 确保正确的顺序 (data, sample_rate)
|
267 |
if isinstance(style_wav[0], np.ndarray):
|
|
|
274 |
else:
|
275 |
raise ValueError("风格音频格式不正确")
|
276 |
|
277 |
+
# 打印debug信息
|
278 |
+
print(f"内容音频形状: {content_tensor.shape}, 采样率: {content_sr}")
|
279 |
+
print(f"风格音频形状: {style_tensor.shape}, 采样率: {style_sr}")
|
280 |
+
|
281 |
+
# 保存音频
|
282 |
torchaudio.save(temp_content_path, content_tensor, content_sr)
|
283 |
torchaudio.save(temp_style_path, style_tensor, style_sr)
|
284 |
|
285 |
+
try:
|
286 |
+
# 获取管道
|
287 |
+
pipeline = get_pipeline("style")
|
288 |
+
|
289 |
+
# 推理
|
290 |
+
gen_audio = pipeline.inference_ar_and_fm(
|
291 |
+
src_wav_path=temp_content_path,
|
292 |
+
src_text=None,
|
293 |
+
style_ref_wav_path=temp_style_path,
|
294 |
+
timbre_ref_wav_path=temp_content_path,
|
295 |
+
)
|
296 |
+
|
297 |
+
# 检查生成音频是否为数值异常
|
298 |
+
if torch.isnan(gen_audio).any() or torch.isinf(gen_audio).any():
|
299 |
+
print("警告:生成的音频包含NaN或Inf值")
|
300 |
+
gen_audio = torch.nan_to_num(gen_audio, nan=0.0, posinf=0.95, neginf=-0.95)
|
301 |
+
|
302 |
+
print(f"生成音频形状: {gen_audio.shape}, 最大值: {torch.max(gen_audio)}, 最小值: {torch.min(gen_audio)}")
|
303 |
+
|
304 |
+
# 保存生成的音频
|
305 |
+
save_audio(gen_audio, output_path=output_path)
|
306 |
+
|
307 |
+
return output_path
|
308 |
+
except Exception as e:
|
309 |
+
print(f"处理过程中出错: {e}")
|
310 |
+
import traceback
|
311 |
+
traceback.print_exc()
|
312 |
+
raise e
|
313 |
|
314 |
def vevo_timbre(content_wav, reference_wav):
|
315 |
temp_content_path = "wav/temp_content.wav"
|