积极的屁孩 commited on
Commit
e48a9d8
·
1 Parent(s): 6efd082

adjust frequency

Browse files
Files changed (1) hide show
  1. app.py +51 -23
app.py CHANGED
@@ -234,23 +234,34 @@ def vevo_style(content_wav, style_wav):
234
  temp_style_path = "wav/temp_style.wav"
235
  output_path = "wav/output_vevostyle.wav"
236
 
237
- # 检查并正确处理音频数据
238
  if content_wav is None or style_wav is None:
239
  raise ValueError("请上传音频文件")
240
 
241
- # Gradio音频组件返回(sample_rate, data)元组或(data, sample_rate)元组
242
  if isinstance(content_wav, tuple) and len(content_wav) == 2:
243
- # 确保正确的顺序 (data, sample_rate)
244
  if isinstance(content_wav[0], np.ndarray):
245
  content_data, content_sr = content_wav
246
  else:
247
  content_sr, content_data = content_wav
248
- content_tensor = torch.FloatTensor(content_data)
249
- if content_tensor.ndim == 1:
250
- content_tensor = content_tensor.unsqueeze(0) # 添加通道维度
 
 
 
 
 
 
 
 
 
 
 
 
251
  else:
252
  raise ValueError("内容音频格式不正确")
253
-
254
  if isinstance(style_wav, tuple) and len(style_wav) == 2:
255
  # 确保正确的顺序 (data, sample_rate)
256
  if isinstance(style_wav[0], np.ndarray):
@@ -263,25 +274,42 @@ def vevo_style(content_wav, style_wav):
263
  else:
264
  raise ValueError("风格音频格式不正确")
265
 
266
- # 保存上传的音频
 
 
 
 
267
  torchaudio.save(temp_content_path, content_tensor, content_sr)
268
  torchaudio.save(temp_style_path, style_tensor, style_sr)
269
 
270
- # 获取管道
271
- pipeline = get_pipeline("style")
272
-
273
- # 推理
274
- gen_audio = pipeline.inference_ar_and_fm(
275
- src_wav_path=temp_content_path,
276
- src_text=None,
277
- style_ref_wav_path=temp_style_path,
278
- timbre_ref_wav_path=temp_content_path,
279
- )
280
-
281
- # 保存生成的音频
282
- save_audio(gen_audio, output_path=output_path)
283
-
284
- return output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
  def vevo_timbre(content_wav, reference_wav):
287
  temp_content_path = "wav/temp_content.wav"
 
234
  temp_style_path = "wav/temp_style.wav"
235
  output_path = "wav/output_vevostyle.wav"
236
 
237
+ # 检查并处理音频数据
238
  if content_wav is None or style_wav is None:
239
  raise ValueError("请上传音频文件")
240
 
241
+ # 处理音频格式
242
  if isinstance(content_wav, tuple) and len(content_wav) == 2:
 
243
  if isinstance(content_wav[0], np.ndarray):
244
  content_data, content_sr = content_wav
245
  else:
246
  content_sr, content_data = content_wav
247
+
248
+ # 确保是单声道
249
+ if len(content_data.shape) > 1 and content_data.shape[1] > 1:
250
+ content_data = np.mean(content_data, axis=1)
251
+
252
+ # 重采样到24kHz
253
+ if content_sr != 24000:
254
+ content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
255
+ content_tensor = torchaudio.functional.resample(content_tensor, content_sr, 24000)
256
+ content_sr = 24000
257
+ else:
258
+ content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
259
+
260
+ # 归一化音量
261
+ content_tensor = content_tensor / (torch.max(torch.abs(content_tensor)) + 1e-6) * 0.95
262
  else:
263
  raise ValueError("内容音频格式不正确")
264
+
265
  if isinstance(style_wav, tuple) and len(style_wav) == 2:
266
  # 确保正确的顺序 (data, sample_rate)
267
  if isinstance(style_wav[0], np.ndarray):
 
274
  else:
275
  raise ValueError("风格音频格式不正确")
276
 
277
+ # 打印debug信息
278
+ print(f"内容音频形状: {content_tensor.shape}, 采样率: {content_sr}")
279
+ print(f"风格音频形状: {style_tensor.shape}, 采样率: {style_sr}")
280
+
281
+ # 保存音频
282
  torchaudio.save(temp_content_path, content_tensor, content_sr)
283
  torchaudio.save(temp_style_path, style_tensor, style_sr)
284
 
285
+ try:
286
+ # 获取管道
287
+ pipeline = get_pipeline("style")
288
+
289
+ # 推理
290
+ gen_audio = pipeline.inference_ar_and_fm(
291
+ src_wav_path=temp_content_path,
292
+ src_text=None,
293
+ style_ref_wav_path=temp_style_path,
294
+ timbre_ref_wav_path=temp_content_path,
295
+ )
296
+
297
+ # 检查生成音频是否为数值异常
298
+ if torch.isnan(gen_audio).any() or torch.isinf(gen_audio).any():
299
+ print("警告:生成的音频包含NaN或Inf值")
300
+ gen_audio = torch.nan_to_num(gen_audio, nan=0.0, posinf=0.95, neginf=-0.95)
301
+
302
+ print(f"生成音频形状: {gen_audio.shape}, 最大值: {torch.max(gen_audio)}, 最小值: {torch.min(gen_audio)}")
303
+
304
+ # 保存生成的音频
305
+ save_audio(gen_audio, output_path=output_path)
306
+
307
+ return output_path
308
+ except Exception as e:
309
+ print(f"处理过程中出错: {e}")
310
+ import traceback
311
+ traceback.print_exc()
312
+ raise e
313
 
314
  def vevo_timbre(content_wav, reference_wav):
315
  temp_content_path = "wav/temp_content.wav"