soiz1 committed
Commit 538cd32 · verified · 1 Parent(s): a77e655

Update app.py

Files changed (1)
  1. app.py +3 -681
app.py CHANGED
@@ -1,683 +1,3 @@
- import os
- import spaces
- import gradio as gr
- import torch
- import torchaudio
- import librosa
- from modules.commons import build_model, load_checkpoint, recursive_munch
- import yaml
- from hf_utils import load_custom_model_from_hf
- import numpy as np
- from pydub import AudioSegment
-
- # Load model and configuration
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
-                                                                  "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
-                                                                  "config_dit_mel_seed_uvit_whisper_small_wavenet.yml")
- # dit_checkpoint_path = "E:/DiT_epoch_00018_step_801000.pth"
- # dit_config_path = "configs/config_dit_mel_seed_uvit_whisper_small_encoder_wavenet.yml"
- config = yaml.safe_load(open(dit_config_path, 'r'))
- model_params = recursive_munch(config['model_params'])
- model = build_model(model_params, stage='DiT')
- hop_length = config['preprocess_params']['spect_params']['hop_length']
- sr = config['preprocess_params']['sr']
-
- # Load checkpoints
- model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
-                                  load_only_params=True, ignore_modules=[], is_distributed=False)
- for key in model:
-     model[key].eval()
-     model[key].to(device)
- model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
-
- # Load additional modules
- from modules.campplus.DTDNN import CAMPPlus
-
- campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
- campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
- campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
- campplus_model.eval()
- campplus_model.to(device)
-
- from modules.bigvgan import bigvgan
-
- bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
-
- # remove weight norm in the model and set to eval mode
- bigvgan_model.remove_weight_norm()
- bigvgan_model = bigvgan_model.eval().to(device)
-
- ckpt_path, config_path = load_custom_model_from_hf("Plachta/FAcodec", 'pytorch_model.bin', 'config.yml')
-
- codec_config = yaml.safe_load(open(config_path))
- codec_model_params = recursive_munch(codec_config['model_params'])
- codec_encoder = build_model(codec_model_params, stage="codec")
-
- ckpt_params = torch.load(ckpt_path, map_location="cpu")
-
- for key in codec_encoder:
-     codec_encoder[key].load_state_dict(ckpt_params[key], strict=False)
- _ = [codec_encoder[key].eval() for key in codec_encoder]
- _ = [codec_encoder[key].to(device) for key in codec_encoder]
-
- # whisper
- from transformers import AutoFeatureExtractor, WhisperModel
-
- whisper_name = model_params.speech_tokenizer.whisper_name if hasattr(model_params.speech_tokenizer,
-                                                                      'whisper_name') else "openai/whisper-small"
- whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(device)
- del whisper_model.decoder
- whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)
-
- # Generate mel spectrograms
- mel_fn_args = {
-     "n_fft": config['preprocess_params']['spect_params']['n_fft'],
-     "win_size": config['preprocess_params']['spect_params']['win_length'],
-     "hop_size": config['preprocess_params']['spect_params']['hop_length'],
-     "num_mels": config['preprocess_params']['spect_params']['n_mels'],
-     "sampling_rate": sr,
-     "fmin": 0,
-     "fmax": None,
-     "center": False
- }
- from modules.audio import mel_spectrogram
-
- to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
-
- # f0 conditioned model
- dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
-                                                                  "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
-                                                                  "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml")
-
- config = yaml.safe_load(open(dit_config_path, 'r'))
- model_params = recursive_munch(config['model_params'])
- model_f0 = build_model(model_params, stage='DiT')
- hop_length = config['preprocess_params']['spect_params']['hop_length']
- sr = config['preprocess_params']['sr']
-
- # Load checkpoints
- model_f0, _, _, _ = load_checkpoint(model_f0, None, dit_checkpoint_path,
-                                     load_only_params=True, ignore_modules=[], is_distributed=False)
- for key in model_f0:
-     model_f0[key].eval()
-     model_f0[key].to(device)
- model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
-
- # f0 extractor
- from modules.rmvpe import RMVPE
-
- model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
- rmvpe = RMVPE(model_path, is_half=False, device=device)
-
- mel_fn_args_f0 = {
-     "n_fft": config['preprocess_params']['spect_params']['n_fft'],
-     "win_size": config['preprocess_params']['spect_params']['win_length'],
-     "hop_size": config['preprocess_params']['spect_params']['hop_length'],
-     "num_mels": config['preprocess_params']['spect_params']['n_mels'],
-     "sampling_rate": sr,
-     "fmin": 0,
-     "fmax": None,
-     "center": False
- }
- to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
- bigvgan_44k_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x', use_cuda_kernel=False)
-
- # remove weight norm in the model and set to eval mode
- bigvgan_44k_model.remove_weight_norm()
- bigvgan_44k_model = bigvgan_44k_model.eval().to(device)
-
- def adjust_f0_semitones(f0_sequence, n_semitones):
-     factor = 2 ** (n_semitones / 12)
-     return f0_sequence * factor
-
- def crossfade(chunk1, chunk2, overlap):
-     fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
-     fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
-     chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
-     return chunk2
-
- # streaming and chunk processing related params
- bitrate = "320k"
- overlap_frame_len = 16
- @spaces.GPU
- @torch.no_grad()
- @torch.inference_mode()
- def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
-     inference_module = model if not f0_condition else model_f0
-     mel_fn = to_mel if not f0_condition else to_mel_f0
-     bigvgan_fn = bigvgan_model if not f0_condition else bigvgan_44k_model
-     sr = 22050 if not f0_condition else 44100
-     hop_length = 256 if not f0_condition else 512
-     max_context_window = sr // hop_length * 30
-     overlap_wave_len = overlap_frame_len * hop_length
-     # Load audio
-     source_audio = librosa.load(source, sr=sr)[0]
-     ref_audio = librosa.load(target, sr=sr)[0]
-
-     # Process audio
-     source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
-     ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(device)
-
-     # Resample
-     ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
-     converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
-     # if source audio less than 30 seconds, whisper can handle in one forward
-     if converted_waves_16k.size(-1) <= 16000 * 30:
-         alt_inputs = whisper_feature_extractor([converted_waves_16k.squeeze(0).cpu().numpy()],
-                                                return_tensors="pt",
-                                                return_attention_mask=True,
-                                                sampling_rate=16000)
-         alt_input_features = whisper_model._mask_input_features(
-             alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
-         alt_outputs = whisper_model.encoder(
-             alt_input_features.to(whisper_model.encoder.dtype),
-             head_mask=None,
-             output_attentions=False,
-             output_hidden_states=False,
-             return_dict=True,
-         )
-         S_alt = alt_outputs.last_hidden_state.to(torch.float32)
-         S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
-     else:
-         overlapping_time = 5  # 5 seconds
-         S_alt_list = []
-         buffer = None
-         traversed_time = 0
-         while traversed_time < converted_waves_16k.size(-1):
-             if buffer is None:  # first chunk
-                 chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
-             else:
-                 chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]], dim=-1)
-             alt_inputs = whisper_feature_extractor([chunk.squeeze(0).cpu().numpy()],
-                                                    return_tensors="pt",
-                                                    return_attention_mask=True,
-                                                    sampling_rate=16000)
-             alt_input_features = whisper_model._mask_input_features(
-                 alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
-             alt_outputs = whisper_model.encoder(
-                 alt_input_features.to(whisper_model.encoder.dtype),
-                 head_mask=None,
-                 output_attentions=False,
-                 output_hidden_states=False,
-                 return_dict=True,
-             )
-             S_alt = alt_outputs.last_hidden_state.to(torch.float32)
-             S_alt = S_alt[:, :chunk.size(-1) // 320 + 1]
-             if traversed_time == 0:
-                 S_alt_list.append(S_alt)
-             else:
-                 S_alt_list.append(S_alt[:, 50 * overlapping_time:])
-             buffer = chunk[:, -16000 * overlapping_time:]
-             traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
-         S_alt = torch.cat(S_alt_list, dim=1)
-
-     ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
-     ori_inputs = whisper_feature_extractor([ori_waves_16k.squeeze(0).cpu().numpy()],
-                                            return_tensors="pt",
-                                            return_attention_mask=True)
-     ori_input_features = whisper_model._mask_input_features(
-         ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
-     with torch.no_grad():
-         ori_outputs = whisper_model.encoder(
-             ori_input_features.to(whisper_model.encoder.dtype),
-             head_mask=None,
-             output_attentions=False,
-             output_hidden_states=False,
-             return_dict=True,
-         )
-     S_ori = ori_outputs.last_hidden_state.to(torch.float32)
-     S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
-
-     mel = mel_fn(source_audio.to(device).float())
-     mel2 = mel_fn(ref_audio.to(device).float())
-
-     target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
-     target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
-
-     feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
-                                               num_mel_bins=80,
-                                               dither=0,
-                                               sample_frequency=16000)
-     feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
-     style2 = campplus_model(feat2.unsqueeze(0))
-
-     if f0_condition:
-         F0_ori = rmvpe.infer_from_audio(ref_waves_16k[0], thred=0.5)
-         F0_alt = rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.5)
-
-         F0_ori = torch.from_numpy(F0_ori).to(device)[None]
-         F0_alt = torch.from_numpy(F0_alt).to(device)[None]
-
-         voiced_F0_ori = F0_ori[F0_ori > 1]
-         voiced_F0_alt = F0_alt[F0_alt > 1]
-
-         log_f0_alt = torch.log(F0_alt + 1e-5)
-         voiced_log_f0_ori = torch.log(voiced_F0_ori + 1e-5)
-         voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
-         median_log_f0_ori = torch.median(voiced_log_f0_ori)
-         median_log_f0_alt = torch.median(voiced_log_f0_alt)
-
-         # shift alt log f0 level to ori log f0 level
-         shifted_log_f0_alt = log_f0_alt.clone()
-         if auto_f0_adjust:
-             shifted_log_f0_alt[F0_alt > 1] = log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori
-         shifted_f0_alt = torch.exp(shifted_log_f0_alt)
-         if pitch_shift != 0:
-             shifted_f0_alt[F0_alt > 1] = adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], pitch_shift)
-     else:
-         F0_ori = None
-         F0_alt = None
-         shifted_f0_alt = None
-
-     # Length regulation
-     cond, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt)
-     prompt_condition, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori)
-
-     max_source_window = max_context_window - mel2.size(2)
-     # split source condition (cond) into chunks
-     processed_frames = 0
-     generated_wave_chunks = []
-     # generate chunk by chunk and stream the output
-     while processed_frames < cond.size(1):
-         chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
-         is_last_chunk = processed_frames + max_source_window >= cond.size(1)
-         cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
-         with torch.autocast(device_type='cuda', dtype=torch.float16):
-             # Voice Conversion
-             vc_target = inference_module.cfm.inference(cat_condition,
-                                                        torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-                                                        mel2, style2, None, diffusion_steps,
-                                                        inference_cfg_rate=inference_cfg_rate)
-             vc_target = vc_target[:, :, mel2.size(-1):]
-         vc_wave = bigvgan_fn(vc_target.float())[0]
-         if processed_frames == 0:
-             if is_last_chunk:
-                 output_wave = vc_wave[0].cpu().numpy()
-                 generated_wave_chunks.append(output_wave)
-                 output_wave = (output_wave * 32768.0).astype(np.int16)
-                 mp3_bytes = AudioSegment(
-                     output_wave.tobytes(), frame_rate=sr,
-                     sample_width=output_wave.dtype.itemsize, channels=1
-                 ).export(format="mp3", bitrate=bitrate).read()
-                 yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
-                 break
-             output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
-             generated_wave_chunks.append(output_wave)
-             previous_chunk = vc_wave[0, -overlap_wave_len:]
-             processed_frames += vc_target.size(2) - overlap_frame_len
-             output_wave = (output_wave * 32768.0).astype(np.int16)
-             mp3_bytes = AudioSegment(
-                 output_wave.tobytes(), frame_rate=sr,
-                 sample_width=output_wave.dtype.itemsize, channels=1
-             ).export(format="mp3", bitrate=bitrate).read()
-             yield mp3_bytes, None
-         elif is_last_chunk:
-             output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
-             generated_wave_chunks.append(output_wave)
-             processed_frames += vc_target.size(2) - overlap_frame_len
-             output_wave = (output_wave * 32768.0).astype(np.int16)
-             mp3_bytes = AudioSegment(
-                 output_wave.tobytes(), frame_rate=sr,
-                 sample_width=output_wave.dtype.itemsize, channels=1
-             ).export(format="mp3", bitrate=bitrate).read()
-             yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
-             break
-         else:
-             output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
-             generated_wave_chunks.append(output_wave)
-             previous_chunk = vc_wave[0, -overlap_wave_len:]
-             processed_frames += vc_target.size(2) - overlap_frame_len
-             output_wave = (output_wave * 32768.0).astype(np.int16)
-             mp3_bytes = AudioSegment(
-                 output_wave.tobytes(), frame_rate=sr,
-                 sample_width=output_wave.dtype.itemsize, channels=1
-             ).export(format="mp3", bitrate=bitrate).read()
-             yield mp3_bytes, None
-
  import gradio as gr
 
  gallery_data = {"sikokumetan": {"webp": "default/sikokumetan.webp", "mp3": "default/sikokumetan.mp3"}}
@@ -686,7 +6,9 @@ def update_reference(evt: gr.SelectData):
      selected_image = evt.value
      for key, value in gallery_data.items():
          if value["webp"] == selected_image:
+             print(f"選択された画像: {selected_image}, 対応するMP3: {value['mp3']}")
              return value["mp3"]
+     print("対応するMP3が見つかりませんでした。")
      return ""
 
  if __name__ == "__main__":
@@ -724,4 +46,4 @@ if __name__ == "__main__":
 
      gallery.select(update_reference, outputs=inputs[1])
 
-     interface.launch()
+     interface.launch()