积极的屁孩 committed on
Commit 4a1664c · 1 Parent(s): abbd236

debug audio saving format

Files changed (1)
  1. app.py +123 -8
app.py CHANGED
@@ -234,9 +234,38 @@ def vevo_style(content_wav, style_wav):
     temp_style_path = "wav/temp_style.wav"
     output_path = "wav/output_vevostyle.wav"
 
+    # Validate and normalize the audio inputs
+    if content_wav is None or style_wav is None:
+        raise ValueError("Please upload audio files")
+
+    # The Gradio audio component returns a (sample_rate, data) or (data, sample_rate) tuple
+    if isinstance(content_wav, tuple) and len(content_wav) == 2:
+        # Ensure the order is (data, sample_rate)
+        if isinstance(content_wav[0], np.ndarray):
+            content_data, content_sr = content_wav
+        else:
+            content_sr, content_data = content_wav
+        content_tensor = torch.FloatTensor(content_data)
+        if content_tensor.ndim == 1:
+            content_tensor = content_tensor.unsqueeze(0)  # add a channel dimension
+    else:
+        raise ValueError("Invalid content audio format")
+
+    if isinstance(style_wav, tuple) and len(style_wav) == 2:
+        # Ensure the order is (data, sample_rate)
+        if isinstance(style_wav[0], np.ndarray):
+            style_data, style_sr = style_wav
+        else:
+            style_sr, style_data = style_wav
+        style_tensor = torch.FloatTensor(style_data)
+        if style_tensor.ndim == 1:
+            style_tensor = style_tensor.unsqueeze(0)  # add a channel dimension
+    else:
+        raise ValueError("Invalid style audio format")
+
     # Save the uploaded audio
-    torchaudio.save(temp_content_path, content_wav[0], content_wav[1])
-    torchaudio.save(temp_style_path, style_wav[0], style_wav[1])
+    torchaudio.save(temp_content_path, content_tensor, content_sr)
+    torchaudio.save(temp_style_path, style_tensor, style_sr)
 
     # Get the pipeline
     pipeline = get_pipeline("style")
@@ -259,9 +288,38 @@ def vevo_timbre(content_wav, reference_wav):
     temp_reference_path = "wav/temp_reference.wav"
     output_path = "wav/output_vevotimbre.wav"
 
+    # Validate and normalize the audio inputs
+    if content_wav is None or reference_wav is None:
+        raise ValueError("Please upload audio files")
+
+    # The Gradio audio component returns a (sample_rate, data) or (data, sample_rate) tuple
+    if isinstance(content_wav, tuple) and len(content_wav) == 2:
+        # Ensure the order is (data, sample_rate)
+        if isinstance(content_wav[0], np.ndarray):
+            content_data, content_sr = content_wav
+        else:
+            content_sr, content_data = content_wav
+        content_tensor = torch.FloatTensor(content_data)
+        if content_tensor.ndim == 1:
+            content_tensor = content_tensor.unsqueeze(0)  # add a channel dimension
+    else:
+        raise ValueError("Invalid content audio format")
+
+    if isinstance(reference_wav, tuple) and len(reference_wav) == 2:
+        # Ensure the order is (data, sample_rate)
+        if isinstance(reference_wav[0], np.ndarray):
+            reference_data, reference_sr = reference_wav
+        else:
+            reference_sr, reference_data = reference_wav
+        reference_tensor = torch.FloatTensor(reference_data)
+        if reference_tensor.ndim == 1:
+            reference_tensor = reference_tensor.unsqueeze(0)  # add a channel dimension
+    else:
+        raise ValueError("Invalid reference audio format")
+
     # Save the uploaded audio
-    torchaudio.save(temp_content_path, content_wav[0], content_wav[1])
-    torchaudio.save(temp_reference_path, reference_wav[0], reference_wav[1])
+    torchaudio.save(temp_content_path, content_tensor, content_sr)
+    torchaudio.save(temp_reference_path, reference_tensor, reference_sr)
 
     # Get the pipeline
     pipeline = get_pipeline("timbre")
@@ -283,9 +341,38 @@ def vevo_voice(content_wav, reference_wav):
     temp_reference_path = "wav/temp_reference.wav"
    output_path = "wav/output_vevovoice.wav"
 
+    # Validate and normalize the audio inputs
+    if content_wav is None or reference_wav is None:
+        raise ValueError("Please upload audio files")
+
+    # The Gradio audio component returns a (sample_rate, data) or (data, sample_rate) tuple
+    if isinstance(content_wav, tuple) and len(content_wav) == 2:
+        # Ensure the order is (data, sample_rate)
+        if isinstance(content_wav[0], np.ndarray):
+            content_data, content_sr = content_wav
+        else:
+            content_sr, content_data = content_wav
+        content_tensor = torch.FloatTensor(content_data)
+        if content_tensor.ndim == 1:
+            content_tensor = content_tensor.unsqueeze(0)  # add a channel dimension
+    else:
+        raise ValueError("Invalid content audio format")
+
+    if isinstance(reference_wav, tuple) and len(reference_wav) == 2:
+        # Ensure the order is (data, sample_rate)
+        if isinstance(reference_wav[0], np.ndarray):
+            reference_data, reference_sr = reference_wav
+        else:
+            reference_sr, reference_data = reference_wav
+        reference_tensor = torch.FloatTensor(reference_data)
+        if reference_tensor.ndim == 1:
+            reference_tensor = reference_tensor.unsqueeze(0)  # add a channel dimension
+    else:
+        raise ValueError("Invalid reference audio format")
+
     # Save the uploaded audio
-    torchaudio.save(temp_content_path, content_wav[0], content_wav[1])
-    torchaudio.save(temp_reference_path, reference_wav[0], reference_wav[1])
+    torchaudio.save(temp_content_path, content_tensor, content_sr)
+    torchaudio.save(temp_reference_path, reference_tensor, reference_sr)
 
     # Get the pipeline
     pipeline = get_pipeline("voice")
@@ -308,11 +395,39 @@ def vevo_tts(text, ref_wav, timbre_ref_wav=None, src_language="en", ref_language
     temp_timbre_path = "wav/temp_timbre.wav"
     output_path = "wav/output_vevotts.wav"
 
+    # Validate and normalize the audio inputs
+    if ref_wav is None:
+        raise ValueError("Please upload a reference audio file")
+
+    # The Gradio audio component returns a (sample_rate, data) or (data, sample_rate) tuple
+    if isinstance(ref_wav, tuple) and len(ref_wav) == 2:
+        # Ensure the order is (data, sample_rate)
+        if isinstance(ref_wav[0], np.ndarray):
+            ref_data, ref_sr = ref_wav
+        else:
+            ref_sr, ref_data = ref_wav
+        ref_tensor = torch.FloatTensor(ref_data)
+        if ref_tensor.ndim == 1:
+            ref_tensor = ref_tensor.unsqueeze(0)  # add a channel dimension
+    else:
+        raise ValueError("Invalid reference audio format")
+
     # Save the uploaded audio
-    torchaudio.save(temp_ref_path, ref_wav[0], ref_wav[1])
+    torchaudio.save(temp_ref_path, ref_tensor, ref_sr)
 
     if timbre_ref_wav is not None:
-        torchaudio.save(temp_timbre_path, timbre_ref_wav[0], timbre_ref_wav[1])
+        if isinstance(timbre_ref_wav, tuple) and len(timbre_ref_wav) == 2:
+            # Ensure the order is (data, sample_rate)
+            if isinstance(timbre_ref_wav[0], np.ndarray):
+                timbre_data, timbre_sr = timbre_ref_wav
+            else:
+                timbre_sr, timbre_data = timbre_ref_wav
+            timbre_tensor = torch.FloatTensor(timbre_data)
+            if timbre_tensor.ndim == 1:
+                timbre_tensor = timbre_tensor.unsqueeze(0)  # add a channel dimension
+            torchaudio.save(temp_timbre_path, timbre_tensor, timbre_sr)
+        else:
+            raise ValueError("Invalid timbre reference audio format")
     else:
         temp_timbre_path = temp_ref_path
 
 
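For a quick sanity check of the conversion, the helper can be fed a synthetic tuple in the shape Gradio typically returns (sample rate first, int16 samples). The tone parameters and output path below are arbitrary; this is illustration only, not code from the repository.

```python
# Standalone check of the hypothetical helper with a fake "Gradio-style" upload.
import numpy as np
import torchaudio

sr = 16000
fake_upload = (sr, (np.sin(2 * np.pi * 440 * np.arange(sr) / sr) * 32767).astype(np.int16))

waveform, sample_rate = gradio_audio_to_tensor(fake_upload)
torchaudio.save("temp_check.wav", waveform, sample_rate)  # arbitrary output path
```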