Hemant0000 committed · verified
Commit 695dc3e · 1 Parent(s): e815215

Upload 12 files

src/f5_tts/eval/src_f5_tts_eval_README.md ADDED
@@ -0,0 +1,49 @@
# Evaluation

Install packages for evaluation:

```bash
pip install -e .[eval]
```

## Generating Samples for Evaluation

### Prepare Test Datasets

1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
3. Unzip the downloaded datasets and place them in the `data/` directory.
4. Update the path to the *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`.
5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`

### Batch Inference for Test Set

To run batch inference for evaluations, execute the following commands:

```bash
# batch inference for evaluations
accelerate config # if not set before
bash src/f5_tts/eval/eval_infer_batch.sh
```

## Objective Evaluation on Generated Results

### Download Evaluation Model Checkpoints

1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).

Then update the following scripts with the paths where you placed the evaluation model checkpoints.

### Objective Evaluation

Update the path to your batch-inferenced results, then run the WER / SIM evaluations:
```bash
# Evaluation for Seed-TTS test set
python src/f5_tts/eval/eval_seedtts_testset.py

# Evaluation for LibriSpeech-PC test-clean (cross-sentence)
python src/f5_tts/eval/eval_librispeech_test_clean.py
```
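
Both evaluation scripts are configured by editing module-level variables near the top of the file rather than via command-line flags. A minimal sketch of the fields typically updated before running (the names are taken from `eval_seedtts_testset.py` further down; the concrete values are placeholders):

```python
# inside src/f5_tts/eval/eval_seedtts_testset.py -- illustrative values only
eval_task = "wer"                  # "wer" or "sim"
lang = "zh"                        # "zh" or "en"
gen_wav_dir = "PATH_TO_GENERATED"  # folder with your batch-inferenced wavs
gpus = [0, 1, 2, 3]                # one worker process per listed GPU
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"  # needed for SIM
```
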
src/f5_tts/eval/src_f5_tts_eval_ecapa_tdnn.py ADDED
@@ -0,0 +1,330 @@
# just for speaker similarity evaluation, third-party code

# From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN

import os
import torch
import torch.nn as nn
import torch.nn.functional as F


""" Res2Conv1d + BatchNorm1d + ReLU
"""


class Res2Conv1dReluBn(nn.Module):
    """
    in_channels == out_channels == channels
    """

    def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
        super().__init__()
        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
        self.scale = scale
        self.width = channels // scale
        self.nums = scale if scale == 1 else scale - 1

        self.convs = []
        self.bns = []
        for i in range(self.nums):
            self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
            self.bns.append(nn.BatchNorm1d(self.width))
        self.convs = nn.ModuleList(self.convs)
        self.bns = nn.ModuleList(self.bns)

    def forward(self, x):
        out = []
        spx = torch.split(x, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            # Order: conv -> relu -> bn
            sp = self.convs[i](sp)
            sp = self.bns[i](F.relu(sp))
            out.append(sp)
        if self.scale != 1:
            out.append(spx[self.nums])
        out = torch.cat(out, dim=1)

        return out


""" Conv1d + BatchNorm1d + ReLU
"""


class Conv1dReluBn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x)))


""" The SE connection of 1D case.
"""


class SE_Connect(nn.Module):
    def __init__(self, channels, se_bottleneck_dim=128):
        super().__init__()
        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
        self.linear2 = nn.Linear(se_bottleneck_dim, channels)

    def forward(self, x):
        out = x.mean(dim=2)
        out = F.relu(self.linear1(out))
        out = torch.sigmoid(self.linear2(out))
        out = x * out.unsqueeze(2)

        return out


""" SE-Res2Block of the ECAPA-TDNN architecture.
"""

# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
#     return nn.Sequential(
#         Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
#         Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
#         Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
#         SE_Connect(channels)
#     )


class SE_Res2Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
        super().__init__()
        self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
        self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x):
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)

        x = self.Conv1dReluBn1(x)
        x = self.Res2Conv1dReluBn(x)
        x = self.Conv1dReluBn2(x)
        x = self.SE_Connect(x)

        return x + residual


""" Attentive weighted mean and standard deviation pooling.
"""


class AttentiveStatsPool(nn.Module):
    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x_in))
        # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)


class ECAPA_TDNN(nn.Module):
    def __init__(
        self,
        feat_dim=80,
        channels=512,
        emb_dim=192,
        global_context_att=False,
        feat_type="wavlm_large",
        sr=16000,
        feature_selection="hidden_states",
        update_extract=False,
        config_path=None,
    ):
        super().__init__()

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        try:
            local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
            self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
        except:  # noqa: E722
            self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)

        if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
            self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
        ):
            self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
        if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
            self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
        ):
            self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False

        self.feat_num = self.get_feat_num()
        self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        if feat_type != "fbank" and feat_type != "mfcc":
            freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
            for name, param in self.feature_extract.named_parameters():
                for freeze_val in freeze_list:
                    if freeze_val in name:
                        param.requires_grad = False
                        break

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # self.channels = [channels] * 4 + [channels * 3]
        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = SE_Res2Block(
            self.channels[0],
            self.channels[1],
            kernel_size=3,
            stride=1,
            padding=2,
            dilation=2,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer3 = SE_Res2Block(
            self.channels[1],
            self.channels[2],
            kernel_size=3,
            stride=1,
            padding=3,
            dilation=3,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer4 = SE_Res2Block(
            self.channels[2],
            self.channels[3],
            kernel_size=3,
            stride=1,
            padding=4,
            dilation=4,
            scale=8,
            se_bottleneck_dim=128,
        )

        # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
        cat_channels = channels * 3
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(
            self.channels[-1], attention_channels=128, global_context_att=global_context_att
        )
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)

    def get_feat_num(self):
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features[self.feature_selection]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        else:
            return 1

    def get_feat(self, x):
        if self.update_extract:
            x = self.feature_extract([sample for sample in x])
        else:
            with torch.no_grad():
                if self.feat_type == "fbank" or self.feat_type == "mfcc":
                    x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
                else:
                    x = self.feature_extract([sample for sample in x])

        if self.feat_type == "fbank":
            x = x.log()

        if self.feat_type != "fbank" and self.feat_type != "mfcc":
            x = x[self.feature_selection]
            if isinstance(x, (list, tuple)):
                x = torch.stack(x, dim=0)
            else:
                x = x.unsqueeze(0)
            norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            x = (norm_weights * x).sum(dim=0)
            x = torch.transpose(x, 1, 2) + 1e-6

        x = self.instance_norm(x)
        return x

    def forward(self, x):
        x = self.get_feat(x)

        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        out = torch.cat([out2, out3, out4], dim=1)
        out = F.relu(self.conv(out))
        out = self.bn(self.pooling(out))
        out = self.linear(out)

        return out


def ECAPA_TDNN_SMALL(
    feat_dim,
    emb_dim=256,
    feat_type="wavlm_large",
    sr=16000,
    feature_selection="hidden_states",
    update_extract=False,
    config_path=None,
):
    return ECAPA_TDNN(
        feat_dim=feat_dim,
        channels=512,
        emb_dim=emb_dim,
        feat_type=feat_type,
        sr=sr,
        feature_selection=feature_selection,
        update_extract=update_extract,
        config_path=config_path,
    )
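
For context, a minimal sketch of how this third-party embedding model is used for the SIM metric, mirroring `run_sim` in `utils_eval.py` below (the WavLM checkpoint path is a placeholder; both wavs are resampled to 16 kHz before embedding):

```python
import torch
import torch.nn.functional as F
import torchaudio

from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL

model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large")
state_dict = torch.load("wavlm_large_finetune.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict["model"], strict=False)
model.eval()


def embed(path):
    wav, sr = torchaudio.load(path)
    wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(wav)
    with torch.no_grad():
        return model(wav)  # speaker embedding


# cosine similarity in [-1, 1]; higher means the two utterances sound more like the same speaker
sim = F.cosine_similarity(embed("gen.wav"), embed("ref.wav"))[0].item()
```
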
src/f5_tts/eval/src_f5_tts_eval_eval_infer_batch.py ADDED
@@ -0,0 +1,207 @@
import os
import sys

sys.path.append(os.getcwd())

import argparse
import time
from importlib.resources import files

import torch
import torchaudio
from accelerate import Accelerator
from tqdm import tqdm

from f5_tts.eval.utils_eval import (
    get_inference_prompt,
    get_librispeech_test_clean_metainfo,
    get_seedtts_testset_metainfo,
)
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model.utils import get_tokenizer

accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"


# --------------------- Dataset Settings -------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
target_rms = 0.1


tokenizer = "pinyin"
rel_path = str(files("f5_tts").joinpath("../../"))


def main():
    # ---------------------- infer setting ---------------------- #

    parser = argparse.ArgumentParser(description="batch inference")

    parser.add_argument("-s", "--seed", default=None, type=int)
    parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
    parser.add_argument("-n", "--expname", required=True)
    parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
    parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])

    parser.add_argument("-nfe", "--nfestep", default=32, type=int)
    parser.add_argument("-o", "--odemethod", default="euler")
    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)

    parser.add_argument("-t", "--testset", required=True)

    args = parser.parse_args()

    seed = args.seed
    dataset_name = args.dataset
    exp_name = args.expname
    ckpt_step = args.ckptstep
    ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
    mel_spec_type = args.mel_spec_type

    nfe_step = args.nfestep
    ode_method = args.odemethod
    sway_sampling_coef = args.swaysampling

    testset = args.testset

    infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
    cfg_strength = 2.0
    speed = 1.0
    use_truth_duration = False
    no_ref_audio = False

    if exp_name == "F5TTS_Base":
        model_cls = DiT
        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

    elif exp_name == "E2TTS_Base":
        model_cls = UNetT
        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

    if testset == "ls_pc_test_clean":
        metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
        librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
        metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)

    elif testset == "seedtts_test_zh":
        metalst = rel_path + "/data/seedtts_testset/zh/meta.lst"
        metainfo = get_seedtts_testset_metainfo(metalst)

    elif testset == "seedtts_test_en":
        metalst = rel_path + "/data/seedtts_testset/en/meta.lst"
        metainfo = get_seedtts_testset_metainfo(metalst)

    # path to save generated wavs
    output_dir = (
        f"{rel_path}/"
        f"results/{exp_name}_{ckpt_step}/{testset}/"
        f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
        f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
        f"_cfg{cfg_strength}_speed{speed}"
        f"{'_gt-dur' if use_truth_duration else ''}"
        f"{'_no-ref-audio' if no_ref_audio else ''}"
    )

    # -------------------------------------------------#

    use_ema = True

    prompts_all = get_inference_prompt(
        metainfo,
        speed=speed,
        tokenizer=tokenizer,
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
        mel_spec_type=mel_spec_type,
        target_rms=target_rms,
        use_truth_duration=use_truth_duration,
        infer_batch_size=infer_batch_size,
    )

    # Vocoder model
    local = False
    if mel_spec_type == "vocos":
        vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    elif mel_spec_type == "bigvgan":
        vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
    vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)

    # Tokenizer
    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

    # Model
    model = CFM(
        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type=mel_spec_type,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
    model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

    if not os.path.exists(output_dir) and accelerator.is_main_process:
        os.makedirs(output_dir)

    # start batch inference
    accelerator.wait_for_everyone()
    start = time.time()

    with accelerator.split_between_processes(prompts_all) as prompts:
        for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
            utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
            ref_mels = ref_mels.to(device)
            ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
            total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)

            # Inference
            with torch.inference_mode():
                generated, _ = model.sample(
                    cond=ref_mels,
                    text=final_text_list,
                    duration=total_mel_lens,
                    lens=ref_mel_lens,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                    no_ref_audio=no_ref_audio,
                    seed=seed,
                )
                # Final result
                for i, gen in enumerate(generated):
                    gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
                    gen_mel_spec = gen.permute(0, 2, 1)
                    if mel_spec_type == "vocos":
                        generated_wave = vocoder.decode(gen_mel_spec)
                    elif mel_spec_type == "bigvgan":
                        generated_wave = vocoder(gen_mel_spec)

                    if ref_rms_list[i] < target_rms:
                        generated_wave = generated_wave * ref_rms_list[i] / target_rms
                    torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        timediff = time.time() - start
        print(f"Done batch inference in {timediff / 60 :.2f} minutes.")


if __name__ == "__main__":
    main()
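
For orientation, the `output_dir` f-string above resolves to a predictable folder name, and the evaluation scripts' `gen_wav_dir` should point at it. A small sketch of what the default Seed-TTS zh run from `eval_infer_batch.sh` produces (relative to the repo root; the values mirror the flags and defaults above):

```python
exp_name, ckpt_step, testset = "F5TTS_Base", 1200000, "seedtts_test_zh"
seed, ode_method, nfe_step, mel_spec_type = 0, "euler", 16, "vocos"
sway_sampling_coef, cfg_strength, speed = -1.0, 2.0, 1.0

print(
    f"results/{exp_name}_{ckpt_step}/{testset}/"
    f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
    f"_ss{sway_sampling_coef}_cfg{cfg_strength}_speed{speed}"
)
# results/F5TTS_Base_1200000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1.0_cfg2.0_speed1.0
```
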
src/f5_tts/eval/src_f5_tts_eval_eval_infer_batch.sh ADDED
@@ -0,0 +1,13 @@
#!/bin/bash

# e.g. F5-TTS, 16 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16

# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0

# etc.
src/f5_tts/eval/src_f5_tts_eval_eval_librispeech_test_clean.py ADDED
@@ -0,0 +1,73 @@
# Evaluate with LibriSpeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)

import sys
import os

sys.path.append(os.getcwd())

import multiprocessing as mp
from importlib.resources import files

import numpy as np

from f5_tts.eval.utils_eval import (
    get_librispeech_test,
    run_asr_wer,
    run_sim,
)

rel_path = str(files("f5_tts").joinpath("../../"))


eval_task = "wer"  # sim | wer
lang = "en"
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
gen_wav_dir = "PATH_TO_GENERATED"  # generated wavs

gpus = [0, 1, 2, 3, 4, 5, 6, 7]
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)

## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
## leading to a low similarity for the ground truth in some cases.
# test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True)  # eval ground truth

local = False
if local:  # use local custom checkpoint dir
    asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
else:
    asr_ckpt_dir = ""  # auto download to cache dir

wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"


# --------------------------- WER ---------------------------

if eval_task == "wer":
    wers = []

    with mp.Pool(processes=len(gpus)) as pool:
        args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
        results = pool.map(run_asr_wer, args)
        for wers_ in results:
            wers.extend(wers_)

    wer = round(np.mean(wers) * 100, 3)
    print(f"\nTotal {len(wers)} samples")
    print(f"WER : {wer}%")


# --------------------------- SIM ---------------------------

if eval_task == "sim":
    sim_list = []

    with mp.Pool(processes=len(gpus)) as pool:
        args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
        results = pool.map(run_sim, args)
        for sim_ in results:
            sim_list.extend(sim_)

    sim = round(sum(sim_list) / len(sim_list), 3)
    print(f"\nTotal {len(sim_list)} samples")
    print(f"SIM : {sim}")
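
Note that the reported figure is the mean of per-utterance WERs scaled to a percentage, not a corpus-level WER pooled over all words. A tiny worked example with invented per-utterance values:

```python
import numpy as np

wers = [0.0, 0.10, 0.05]              # per-utterance WERs as returned by run_asr_wer
print(round(np.mean(wers) * 100, 3))  # 5.0 -> reported as "WER : 5.0%"
```
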
src/f5_tts/eval/src_f5_tts_eval_eval_seedtts_testset.py ADDED
@@ -0,0 +1,75 @@
# Evaluate with Seed-TTS testset

import sys
import os

sys.path.append(os.getcwd())

import multiprocessing as mp
from importlib.resources import files

import numpy as np

from f5_tts.eval.utils_eval import (
    get_seed_tts_test,
    run_asr_wer,
    run_sim,
)

rel_path = str(files("f5_tts").joinpath("../../"))


eval_task = "wer"  # sim | wer
lang = "zh"  # zh | en
metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset
# gen_wav_dir = rel_path + f"/data/seedtts_testset/{lang}/wavs"  # ground truth wavs
gen_wav_dir = "PATH_TO_GENERATED"  # generated wavs


# NOTE: the paraformer-zh result will differ slightly with the number of GPUs, since the per-worker batch size changes
# zh 1.254 seems to be a result of 4 workers (wer_seed_tts)
gpus = [0, 1, 2, 3, 4, 5, 6, 7]
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)

local = False
if local:  # use local custom checkpoint dir
    if lang == "zh":
        asr_ckpt_dir = "../checkpoints/funasr"  # paraformer-zh dir under funasr
    elif lang == "en":
        asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
else:
    asr_ckpt_dir = ""  # auto download to cache dir

wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"


# --------------------------- WER ---------------------------

if eval_task == "wer":
    wers = []

    with mp.Pool(processes=len(gpus)) as pool:
        args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
        results = pool.map(run_asr_wer, args)
        for wers_ in results:
            wers.extend(wers_)

    wer = round(np.mean(wers) * 100, 3)
    print(f"\nTotal {len(wers)} samples")
    print(f"WER : {wer}%")


# --------------------------- SIM ---------------------------

if eval_task == "sim":
    sim_list = []

    with mp.Pool(processes=len(gpus)) as pool:
        args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
        results = pool.map(run_sim, args)
        for sim_ in results:
            sim_list.extend(sim_)

    sim = round(sum(sim_list) / len(sim_list), 3)
    print(f"\nTotal {len(sim_list)} samples")
    print(f"SIM : {sim}")
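
For `lang == "zh"`, `run_asr_wer` (in `utils_eval.py` below) strips punctuation and then splits both strings into single characters before calling `jiwer.compute_measures`, so the reported "WER" is effectively a character error rate. A small illustration with an invented sentence pair:

```python
from jiwer import compute_measures

truth = "今天天气很好"
hypo = "今天天气不好"

truth = " ".join([x for x in truth])  # "今 天 天 气 很 好"
hypo = " ".join([x for x in hypo])

print(compute_measures(truth, hypo)["wer"])  # 1 substitution over 6 characters ≈ 0.167
```
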
src/f5_tts/eval/src_f5_tts_eval_utils_eval.py ADDED
@@ -0,0 +1,405 @@
import math
import os
import random
import string

import torch
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm

from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import convert_char_to_pinyin


# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
def get_seedtts_testset_metainfo(metalst):
    f = open(metalst)
    lines = f.readlines()
    f.close()
    metainfo = []
    for line in lines:
        if len(line.strip().split("|")) == 5:
            utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
        elif len(line.strip().split("|")) == 4:
            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
            gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
        metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
    return metainfo


# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
    f = open(metalst)
    lines = f.readlines()
    f.close()
    metainfo = []
    for line in lines:
        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")

        # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
        ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")

        # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
        gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
        gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")

        metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))

    return metainfo


# padded to max length mel batch
def padded_mel_batch(ref_mels):
    max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
    padded_ref_mels = []
    for mel in ref_mels:
        padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
        padded_ref_mels.append(padded_ref_mel)
    padded_ref_mels = torch.stack(padded_ref_mels)
    padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
    return padded_ref_mels


# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav


def get_inference_prompt(
    metainfo,
    speed=1.0,
    tokenizer="pinyin",
    polyphone=True,
    target_sample_rate=24000,
    n_fft=1024,
    win_length=1024,
    n_mel_channels=100,
    hop_length=256,
    mel_spec_type="vocos",
    target_rms=0.1,
    use_truth_duration=False,
    infer_batch_size=1,
    num_buckets=200,
    min_secs=3,
    max_secs=40,
):
    prompts_all = []

    min_tokens = min_secs * target_sample_rate // hop_length
    max_tokens = max_secs * target_sample_rate // hop_length

    batch_accum = [0] * num_buckets
    utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
        [[] for _ in range(num_buckets)] for _ in range(6)
    )

    mel_spectrogram = MelSpec(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    )

    for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
        # Audio
        ref_audio, ref_sr = torchaudio.load(prompt_wav)
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
        if ref_rms < target_rms:
            ref_audio = ref_audio * target_rms / ref_rms
        assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio)

        # Text
        if len(prompt_text[-1].encode("utf-8")) == 1:
            prompt_text = prompt_text + " "
        text = [prompt_text + gt_text]
        if tokenizer == "pinyin":
            text_list = convert_char_to_pinyin(text, polyphone=polyphone)
        else:
            text_list = text

        # Duration, mel frame length
        ref_mel_len = ref_audio.shape[-1] // hop_length
        if use_truth_duration:
            gt_audio, gt_sr = torchaudio.load(gt_wav)
            if gt_sr != target_sample_rate:
                resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
                gt_audio = resampler(gt_audio)
            total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)

            # # test vocoder resynthesis
            # ref_audio = gt_audio
        else:
            ref_text_len = len(prompt_text.encode("utf-8"))
            gen_text_len = len(gt_text.encode("utf-8"))
            total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)

        # to mel spectrogram
        ref_mel = mel_spectrogram(ref_audio)
        ref_mel = ref_mel.squeeze(0)

        # deal with batch
        assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
        assert (
            min_tokens <= total_mel_len <= max_tokens
        ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
        bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)

        utts[bucket_i].append(utt)
        ref_rms_list[bucket_i].append(ref_rms)
        ref_mels[bucket_i].append(ref_mel)
        ref_mel_lens[bucket_i].append(ref_mel_len)
        total_mel_lens[bucket_i].append(total_mel_len)
        final_text_list[bucket_i].extend(text_list)

        batch_accum[bucket_i] += total_mel_len

        if batch_accum[bucket_i] >= infer_batch_size:
            # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )
            batch_accum[bucket_i] = 0
            (
                utts[bucket_i],
                ref_rms_list[bucket_i],
                ref_mels[bucket_i],
                ref_mel_lens[bucket_i],
                total_mel_lens[bucket_i],
                final_text_list[bucket_i],
            ) = [], [], [], [], [], []

    # add residual
    for bucket_i, bucket_frames in enumerate(batch_accum):
        if bucket_frames > 0:
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )
    # shuffle so the last workers are not left with only the easy (short) buckets
    random.seed(666)
    random.shuffle(prompts_all)

    return prompts_all


# get wav_res_ref_text of seed-tts test metalst
# https://github.com/BytedanceSpeech/seed-tts-eval


def get_seed_tts_test(metalst, gen_wav_dir, gpus):
    f = open(metalst)
    lines = f.readlines()
    f.close()

    test_set_ = []
    for line in tqdm(lines):
        if len(line.strip().split("|")) == 5:
            utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
        elif len(line.strip().split("|")) == 4:
            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")

        if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
            continue
        gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)

        test_set_.append((gen_wav, prompt_wav, gt_text))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))

    return test_set


# get librispeech test-clean cross sentence test


def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
    f = open(metalst)
    lines = f.readlines()
    f.close()

    test_set_ = []
    for line in tqdm(lines):
        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")

        if eval_ground_truth:
            gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
            gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
        else:
            if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
                raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
            gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")

        ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")

        test_set_.append((gen_wav, ref_wav, gen_txt))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))

    return test_set


# load asr model


def load_asr_model(lang, ckpt_dir=""):
    if lang == "zh":
        from funasr import AutoModel

        model = AutoModel(
            model=os.path.join(ckpt_dir, "paraformer-zh"),
            # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
            # punc_model = os.path.join(ckpt_dir, "ct-punc"),
            # spk_model = os.path.join(ckpt_dir, "cam++"),
            disable_update=True,
        )  # following seed-tts setting
    elif lang == "en":
        from faster_whisper import WhisperModel

        model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
        model = WhisperModel(model_size, device="cuda", compute_type="float16")
    return model


# WER Evaluation, the way Seed-TTS does


def run_asr_wer(args):
    rank, lang, test_set, ckpt_dir = args

    if lang == "zh":
        import zhconv

        torch.cuda.set_device(rank)
    elif lang == "en":
        os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    else:
        raise NotImplementedError(
            "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
        )

    asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)

    from zhon.hanzi import punctuation

    punctuation_all = punctuation + string.punctuation
    wers = []

    from jiwer import compute_measures

    for gen_wav, prompt_wav, truth in tqdm(test_set):
        if lang == "zh":
            res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
            hypo = res[0]["text"]
            hypo = zhconv.convert(hypo, "zh-cn")
        elif lang == "en":
            segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
            hypo = ""
            for segment in segments:
                hypo = hypo + " " + segment.text

        # raw_truth = truth
        # raw_hypo = hypo

        for x in punctuation_all:
            truth = truth.replace(x, "")
            hypo = hypo.replace(x, "")

        truth = truth.replace("  ", " ")
        hypo = hypo.replace("  ", " ")

        if lang == "zh":
            truth = " ".join([x for x in truth])
            hypo = " ".join([x for x in hypo])
        elif lang == "en":
            truth = truth.lower()
            hypo = hypo.lower()

        measures = compute_measures(truth, hypo)
        wer = measures["wer"]

        # ref_list = truth.split(" ")
        # subs = measures["substitutions"] / len(ref_list)
        # dele = measures["deletions"] / len(ref_list)
        # inse = measures["insertions"] / len(ref_list)

        wers.append(wer)

    return wers


# SIM Evaluation


def run_sim(args):
    rank, test_set, ckpt_dir = args
    device = f"cuda:{rank}"

    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
    state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict["model"], strict=False)

    use_gpu = True if torch.cuda.is_available() else False
    if use_gpu:
        model = model.cuda(device)
    model.eval()

    sim_list = []
    for wav1, wav2, truth in tqdm(test_set):
        wav1, sr1 = torchaudio.load(wav1)
        wav2, sr2 = torchaudio.load(wav2)

        resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
        resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
        wav1 = resample1(wav1)
        wav2 = resample2(wav2)

        if use_gpu:
            wav1 = wav1.cuda(device)
            wav2 = wav2.cuda(device)
        with torch.no_grad():
            emb1 = model(wav1)
            emb2 = model(wav2)

        sim = F.cosine_similarity(emb1, emb2)[0].item()
        # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
        sim_list.append(sim)

    return sim_list
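
As a quick sanity check of the duration rule in `get_inference_prompt` (with `use_truth_duration=False`, the target length is extrapolated from the prompt's mel-frames-per-text-byte rate), here is a small worked example with assumed prompt and target text lengths:

```python
hop_length, target_sample_rate, speed = 256, 24000, 1.0

ref_mel_len = (3 * target_sample_rate) // hop_length  # assume a 3 s prompt -> 281 frames
ref_text_len, gen_text_len = 40, 80                   # assumed UTF-8 byte lengths

total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
print(total_mel_len)                                    # 281 + 562 = 843 frames
print(total_mel_len * hop_length / target_sample_rate)  # ~9.0 s total (prompt + generated)
```
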
src/f5_tts/infer/src_f5_tts_infer_README.md ADDED
@@ -0,0 +1,116 @@
# Inference

The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or will be automatically downloaded when running inference scripts.

A single generation currently supports up to **30s**, which is the **total length** including both the prompt and the output audio. However, you can provide `infer_cli` and `infer_gradio` with longer text; chunk generation is handled automatically. Long reference audio will be **clipped to ~15s**.

To avoid possible inference failures, make sure you have read through the following instructions.

- Use a reference audio shorter than 15s and leave some silence (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of a word, leading to suboptimal generation.
- Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
- Add some spaces (blank: " ") or punctuation (e.g. "," ".") to explicitly introduce some pauses.
- Convert numbers to Chinese characters if you want them read in Chinese; otherwise they will be read in English.


## Gradio App

Currently supported features:

- Basic TTS with Chunk Inference
- Multi-Style / Multi-Speaker Generation
- Voice Chat powered by Qwen2.5-3B-Instruct

The cli command `f5-tts_infer-gradio` is equivalent to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio app (web interface) for inference.

The script will load model checkpoints from Hugging Face. You can also manually download files and update the path passed to `load_model()` in `infer_gradio.py`. Only the TTS model is loaded at startup; the ASR model is loaded to transcribe if `ref_text` is not provided, and the LLM is loaded if Voice Chat is used.

It can also be used as a component of a larger application.
```python
import gradio as gr
from f5_tts.infer.infer_gradio import app

with gr.Blocks() as main_app:
    gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app")

    # ... other Gradio components

    app.render()

main_app.launch()
```


## CLI Inference

The cli command `f5-tts_infer-cli` is equivalent to `python src/f5_tts/infer/infer_cli.py`, which is a command line tool for inference.

The script will load model checkpoints from Hugging Face. You can also manually download files and use `--ckpt_file` to specify the model you want to load, or directly update the path in `infer_cli.py`.

To use a different vocab, pass `--vocab_file` with the path to your own `vocab.txt` file.

Basically, you can run inference with flags:
```bash
# Leaving --ref_text "" will have the ASR model transcribe the reference audio (extra GPU memory usage)
f5-tts_infer-cli \
--model "F5-TTS" \
--ref_audio "ref_audio.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."

# Choose Vocoder
f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
```

A `.toml` file also allows more flexible usage.

```bash
f5-tts_infer-cli -c custom.toml
```

For example, you can use a `.toml` to pass in variables; refer to `src/f5_tts/infer/examples/basic/basic.toml`:

```toml
# F5-TTS | E2-TTS
model = "F5-TTS"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If an empty "", transcribes the reference audio automatically.
ref_text = "Some call me nature, others call me mother nature."
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
# File with text to generate. Ignores the text above.
gen_file = ""
remove_silence = false
output_dir = "tests"
```

You can also leverage a `.toml` file to do multi-style generation; refer to `src/f5_tts/infer/examples/multi/story.toml`.

```toml
# F5-TTS | E2-TTS
model = "F5-TTS"
ref_audio = "infer/examples/multi/main.flac"
# If an empty "", transcribes the reference audio automatically.
ref_text = ""
gen_text = ""
# File with text to generate. Ignores the text above.
gen_file = "infer/examples/multi/story.txt"
remove_silence = true
output_dir = "tests"

[voices.town]
ref_audio = "infer/examples/multi/town.flac"
ref_text = ""

[voices.country]
ref_audio = "infer/examples/multi/country.flac"
ref_text = ""
```
Mark the voice with `[main]`, `[town]`, or `[country]` whenever you want to change voice; refer to `src/f5_tts/infer/examples/multi/story.txt` and the parsing sketch below.
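
A minimal sketch of how the voice tags are parsed (the two regexes are the ones used in `infer_cli.py`; the tagged text here is an invented example):

```python
import re

text_gen = "[main] The story begins. [town] Hello from the town voice! [main] Back to the narrator."

for chunk in re.split(r"(?=\[\w+\])", text_gen):  # split right before each [voice] tag
    if not chunk.strip():
        continue
    match = re.match(r"\[(\w+)\]", chunk)
    voice = match[1] if match else "main"         # untagged text falls back to the main voice
    piece = re.sub(r"\[(\w+)\]", "", chunk).strip()
    print(voice, "->", piece)
```
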
## Speech Editing

To test speech editing capabilities, use the following command:

```bash
python src/f5_tts/infer/speech_edit.py
```
src/f5_tts/infer/src_f5_tts_infer_infer_cli.py ADDED
@@ -0,0 +1,220 @@
import argparse
import codecs
import os
import re
from importlib.resources import files
from pathlib import Path

import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path

from f5_tts.infer.utils_infer import (
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
)
from f5_tts.model import DiT, UNetT

parser = argparse.ArgumentParser(
    prog="python3 infer-cli.py",
    description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
    epilog="Specify options above to override one or more settings from config.",
)
parser.add_argument(
    "-c",
    "--config",
    help="Configuration file. Default=infer/examples/basic/basic.toml",
    default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
)
parser.add_argument(
    "-m",
    "--model",
    help="F5-TTS | E2-TTS",
)
parser.add_argument(
    "-p",
    "--ckpt_file",
    help="The Checkpoint .pt",
)
parser.add_argument(
    "-v",
    "--vocab_file",
    help="The vocab .txt",
)
parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
parser.add_argument(
    "-t",
    "--gen_text",
    type=str,
    help="Text to generate.",
)
parser.add_argument(
    "-f",
    "--gen_file",
    type=str,
    help="File with text to generate. Ignores --text",
)
parser.add_argument(
    "-o",
    "--output_dir",
    type=str,
    help="Path to output folder..",
)
parser.add_argument(
    "--remove_silence",
    help="Remove silence.",
)
parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
parser.add_argument(
    "--load_vocoder_from_local",
    action="store_true",
    help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
)
parser.add_argument(
    "--speed",
    type=float,
    default=1.0,
    help="Adjust the speed of the audio generation (default: 1.0)",
)
args = parser.parse_args()

config = tomli.load(open(args.config, "rb"))

ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
gen_text = args.gen_text if args.gen_text else config["gen_text"]
gen_file = args.gen_file if args.gen_file else config["gen_file"]

# patches for pip pkg user
if "infer/examples/" in ref_audio:
    ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}"))
if "infer/examples/" in gen_file:
    gen_file = str(files("f5_tts").joinpath(f"{gen_file}"))
if "voices" in config:
    for voice in config["voices"]:
        voice_ref_audio = config["voices"][voice]["ref_audio"]
        if "infer/examples/" in voice_ref_audio:
            config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))

if gen_file:
    gen_text = codecs.open(gen_file, "r", "utf-8").read()
output_dir = args.output_dir if args.output_dir else config["output_dir"]
model = args.model if args.model else config["model"]
ckpt_file = args.ckpt_file if args.ckpt_file else ""
vocab_file = args.vocab_file if args.vocab_file else ""
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
speed = args.speed
wave_path = Path(output_dir) / "infer_cli_out.wav"
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
if args.vocoder_name == "vocos":
    vocoder_local_path = "../checkpoints/vocos-mel-24khz"
elif args.vocoder_name == "bigvgan":
    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
mel_spec_type = args.vocoder_name

vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)


# load models
if model == "F5-TTS":
    model_cls = DiT
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    if ckpt_file == "":
        if args.vocoder_name == "vocos":
            repo_name = "F5-TTS"
            exp_name = "F5TTS_Base"
            ckpt_step = 1200000
            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
            # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
        elif args.vocoder_name == "bigvgan":
            repo_name = "F5-TTS"
            exp_name = "F5TTS_Base_bigvgan"
            ckpt_step = 1250000
            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))

elif model == "E2-TTS":
    model_cls = UNetT
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
    if ckpt_file == "":
        repo_name = "E2-TTS"
        exp_name = "E2TTS_Base"
        ckpt_step = 1200000
        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
        # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
    elif args.vocoder_name == "bigvgan":  # TODO: need to test
        repo_name = "F5-TTS"
        exp_name = "F5TTS_Base_bigvgan"
        ckpt_step = 1250000
        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))


print(f"Using {model}...")
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=args.vocoder_name, vocab_file=vocab_file)


def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    if "voices" not in config:
        voices = {"main": main_voice}
    else:
        voices = config["voices"]
        voices["main"] = main_voice
    for voice in voices:
        voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
            voices[voice]["ref_audio"], voices[voice]["ref_text"]
        )
        print("Voice:", voice)
        print("Ref_audio:", voices[voice]["ref_audio"])
        print("Ref_text:", voices[voice]["ref_text"])

    generated_audio_segments = []
    reg1 = r"(?=\[\w+\])"
    chunks = re.split(reg1, text_gen)
    reg2 = r"\[(\w+)\]"
    for text in chunks:
        if not text.strip():
            continue
        match = re.match(reg2, text)
        if match:
            voice = match[1]
        else:
            print("No voice tag found, using main.")
            voice = "main"
        if voice not in voices:
            print(f"Voice {voice} not found, using main.")
            voice = "main"
        text = re.sub(reg2, "", text)
        gen_text = text.strip()
        ref_audio = voices[voice]["ref_audio"]
        ref_text = voices[voice]["ref_text"]
        print(f"Voice: {voice}")
        audio, final_sample_rate, spectragram = infer_process(
            ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
        )
        generated_audio_segments.append(audio)

    if generated_audio_segments:
        final_wave = np.concatenate(generated_audio_segments)

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        with open(wave_path, "wb") as f:
            sf.write(f.name, final_wave, final_sample_rate)
            # Remove silence
            if remove_silence:
                remove_silence_for_generated_wav(f.name)
            print(f.name)


def main():
    main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)


if __name__ == "__main__":
    main()
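
One quirk in the flag/config merging above: `--ref_text` defaults to the sentinel string `"666"`, which lets the script tell "flag not given, fall back to the toml value" apart from an intentionally empty `--ref_text ""` (which requests automatic transcription). Condensed from the lines above:

```python
# "666" marks an omitted flag; an explicit "" passes through and triggers ASR transcription
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]  # empty/None falls back to the toml
```
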
src/f5_tts/infer/src_f5_tts_infer_speech_edit.py ADDED
@@ -0,0 +1,191 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torchaudio
6
+
7
+ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
8
+ from f5_tts.model import CFM, DiT, UNetT
9
+ from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
10
+
11
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
12
+
13
+
14
+ # --------------------- Dataset Settings -------------------- #
15
+
16
+ target_sample_rate = 24000
17
+ n_mel_channels = 100
18
+ hop_length = 256
19
+ win_length = 1024
20
+ n_fft = 1024
21
+ mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
22
+ target_rms = 0.1
23
+
24
+ tokenizer = "pinyin"
25
+ dataset_name = "Emilia_ZH_EN"
26
+
27
+
28
+ # ---------------------- infer setting ---------------------- #
29
+
30
+ seed = None # int | None
31
+
32
+ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
33
+ ckpt_step = 1200000
34
+
35
+ nfe_step = 32 # 16, 32
36
+ cfg_strength = 2.0
37
+ ode_method = "euler" # euler | midpoint
38
+ sway_sampling_coef = -1.0
39
+ speed = 1.0
40
+
41
+ if exp_name == "F5TTS_Base":
42
+ model_cls = DiT
43
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
44
+
45
+ elif exp_name == "E2TTS_Base":
46
+ model_cls = UNetT
47
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
48
+
49
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
50
+ output_dir = "tests"
51
+
52
+ # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
53
+ # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
54
+ # [write the origin_text into a file, e.g. tests/test_edit.txt]
55
+ # ctc-forced-aligner --audio_path "src/f5_tts/infer/examples/basic/basic_ref_en.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
56
+ # [result will be saved at same path of audio file]
57
+ # [--language "zho" for Chinese, "eng" for English]
58
+ # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
59
+
60
+ audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav"
61
+ origin_text = "Some call me nature, others call me mother nature."
62
+ target_text = "Some call me optimist, others call me realist."
63
+ parts_to_edit = [
64
+ [1.42, 2.44],
65
+ [4.04, 4.9],
66
+ ] # start/end times of "nature" & "mother nature", in seconds
67
+ fix_duration = [
68
+ 1.2,
69
+ 1,
70
+ ] # fix duration for "optimist" & "realist", in seconds
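+ # each fix_duration entry is consumed in order, one per span in parts_to_edit;
+ # set fix_duration = None to keep each edited span at its original length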
71
+
72
+ # audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
73
+ # origin_text = "对,这就是我,万人敬仰的太乙真人。"
74
+ # target_text = "对,那就是你,万人敬仰的太白金星。"
75
+ # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
76
+ # fix_duration = None # use origin text duration
77
+
78
+
79
+ # -------------------------------------------------#
80
+
81
+ use_ema = True
82
+
83
+ if not os.path.exists(output_dir):
84
+ os.makedirs(output_dir)
85
+
86
+ # Vocoder model
87
+ local = False
88
+ if mel_spec_type == "vocos":
89
+ vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
90
+ elif mel_spec_type == "bigvgan":
91
+ vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
92
+ vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
93
+
94
+ # Tokenizer
95
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
96
+
97
+ # Model
98
+ model = CFM(
99
+ transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
100
+ mel_spec_kwargs=dict(
101
+ n_fft=n_fft,
102
+ hop_length=hop_length,
103
+ win_length=win_length,
104
+ n_mel_channels=n_mel_channels,
105
+ target_sample_rate=target_sample_rate,
106
+ mel_spec_type=mel_spec_type,
107
+ ),
108
+ odeint_kwargs=dict(
109
+ method=ode_method,
110
+ ),
111
+ vocab_char_map=vocab_char_map,
112
+ ).to(device)
113
+
114
+ dtype = torch.float32 if mel_spec_type == "bigvgan" else None
115
+ model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
116
+
117
+ # Audio
118
+ audio, sr = torchaudio.load(audio_to_edit)
119
+ if audio.shape[0] > 1:
120
+ audio = torch.mean(audio, dim=0, keepdim=True)
121
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
122
+ if rms < target_rms:
123
+ audio = audio * target_rms / rms
124
+ if sr != target_sample_rate:
125
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
126
+ audio = resampler(audio)
127
+ offset = 0
128
+ audio_ = torch.zeros(1, 0)
129
+ edit_mask = torch.zeros(1, 0, dtype=torch.bool)
130
+ for part in parts_to_edit:
131
+ start, end = part
132
+ part_dur = end - start if fix_duration is None else fix_duration.pop(0)
133
+ part_dur = part_dur * target_sample_rate
134
+ start = start * target_sample_rate
135
+ audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
136
+ edit_mask = torch.cat(
137
+ (
138
+ edit_mask,
139
+ torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
140
+ torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
141
+ ),
142
+ dim=-1,
143
+ )
144
+ offset = end * target_sample_rate
145
+ # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
146
+ edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
147
+ audio = audio.to(device)
148
+ edit_mask = edit_mask.to(device)
149
+
150
+ # Text
151
+ text_list = [target_text]
152
+ if tokenizer == "pinyin":
153
+ final_text_list = convert_char_to_pinyin(text_list)
154
+ else:
155
+ final_text_list = [text_list]
156
+ print(f"text : {text_list}")
157
+ print(f"pinyin: {final_text_list}")
158
+
159
+ # Duration
160
+ ref_audio_len = 0
161
+ duration = audio.shape[-1] // hop_length
162
+
163
+ # Inference
164
+ with torch.inference_mode():
165
+ generated, trajectory = model.sample(
166
+ cond=audio,
167
+ text=final_text_list,
168
+ duration=duration,
169
+ steps=nfe_step,
170
+ cfg_strength=cfg_strength,
171
+ sway_sampling_coef=sway_sampling_coef,
172
+ seed=seed,
173
+ edit_mask=edit_mask,
174
+ )
175
+ print(f"Generated mel: {generated.shape}")
176
+
177
+ # Final result
178
+ generated = generated.to(torch.float32)
179
+ generated = generated[:, ref_audio_len:, :]
180
+ gen_mel_spec = generated.permute(0, 2, 1)
181
+ if mel_spec_type == "vocos":
182
+ generated_wave = vocoder.decode(gen_mel_spec)
183
+ elif mel_spec_type == "bigvgan":
184
+ generated_wave = vocoder(gen_mel_spec)
185
+
186
+ if rms < target_rms:
187
+ generated_wave = generated_wave * rms / target_rms
188
+
189
+ save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
190
+ torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
191
+ print(f"Generated wav: {generated_wave.shape}")
src/f5_tts/infer/src_f5_tts_infer_utils_infer.py ADDED
@@ -0,0 +1,492 @@
1
+ # A unified script for inference process
2
+ # Make adjustments inside functions, and consider both gradio and cli scripts if the function output format needs to change
3
+ import os
4
+ import sys
5
+
6
+ sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/")
7
+
8
+ import hashlib
9
+ import re
10
+ import tempfile
11
+ from importlib.resources import files
12
+
13
+ import matplotlib
14
+
15
+ matplotlib.use("Agg")
16
+
17
+ import matplotlib.pylab as plt
18
+ import numpy as np
19
+ import torch
20
+ import torchaudio
21
+ import tqdm
22
+ from pydub import AudioSegment, silence
23
+ from transformers import pipeline
24
+ from vocos import Vocos
25
+
26
+ from f5_tts.model import CFM
27
+ from f5_tts.model.utils import (
28
+ get_tokenizer,
29
+ convert_char_to_pinyin,
30
+ )
31
+
32
+ _ref_audio_cache = {}
33
+
34
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
35
+
36
+ # -----------------------------------------
37
+
38
+ target_sample_rate = 24000
39
+ n_mel_channels = 100
40
+ hop_length = 256
41
+ win_length = 1024
42
+ n_fft = 1024
43
+ mel_spec_type = "vocos"
44
+ target_rms = 0.1
45
+ cross_fade_duration = 0.15
46
+ ode_method = "euler"
47
+ nfe_step = 32 # 16, 32
48
+ cfg_strength = 2.0
49
+ sway_sampling_coef = -1.0
50
+ speed = 1.0
51
+ fix_duration = None
52
+
53
+ # -----------------------------------------
54
+
55
+
56
+ # chunk text into smaller pieces
57
+
58
+
59
+ def chunk_text(text, max_chars=135):
60
+ """
61
+ Splits the input text into chunks, each with a maximum number of characters.
62
+
63
+ Args:
64
+ text (str): The text to be split.
65
+ max_chars (int): The maximum number of characters per chunk.
66
+
67
+ Returns:
68
+ List[str]: A list of text chunks.
69
+ """
70
+ chunks = []
71
+ current_chunk = ""
72
+ # Split the text into sentences based on punctuation followed by whitespace
73
+ sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
74
+
75
+ for sentence in sentences:
76
+ if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
77
+ current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
78
+ else:
79
+ if current_chunk:
80
+ chunks.append(current_chunk.strip())
81
+ current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
82
+
83
+ if current_chunk:
84
+ chunks.append(current_chunk.strip())
85
+
86
+ return chunks
87
+
88
+
89
+ # load vocoder
90
+ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device):
91
+ if vocoder_name == "vocos":
92
+ if is_local:
93
+ print(f"Load vocos from local path {local_path}")
94
+ vocoder = Vocos.from_hparams(f"{local_path}/config.yaml")
95
+ state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location="cpu")
96
+ vocoder.load_state_dict(state_dict)
97
+ vocoder = vocoder.eval().to(device)
98
+ else:
99
+ print("Download Vocos from huggingface charactr/vocos-mel-24khz")
100
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
101
+ elif vocoder_name == "bigvgan":
102
+ try:
103
+ from third_party.BigVGAN import bigvgan
104
+ except ImportError:
105
+ print("You need to follow the README to init submodule and change the BigVGAN source code.")
106
+ if is_local:
107
+ """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
108
+ vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
109
+ else:
110
+ vocoder = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
111
+
112
+ vocoder.remove_weight_norm()
113
+ vocoder = vocoder.eval().to(device)
114
+ return vocoder
115
+
116
+
117
+ # load asr pipeline
118
+
119
+ asr_pipe = None
120
+
121
+
122
+ def initialize_asr_pipeline(device=device, dtype=None):
123
+ if dtype is None:
124
+ dtype = (
125
+ torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
126
+ )
127
+ global asr_pipe
128
+ asr_pipe = pipeline(
129
+ "automatic-speech-recognition",
130
+ model="openai/whisper-large-v3-turbo",
131
+ torch_dtype=dtype,
132
+ device=device,
133
+ )
134
+
135
+
136
+ # load model checkpoint for inference
137
+
138
+
139
+ def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
140
+ if dtype is None:
141
+ dtype = (
142
+ torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
143
+ )
144
+ model = model.to(dtype)
145
+
146
+ ckpt_type = ckpt_path.split(".")[-1]
147
+ if ckpt_type == "safetensors":
148
+ from safetensors.torch import load_file
149
+
150
+ checkpoint = load_file(ckpt_path)
151
+ else:
152
+ checkpoint = torch.load(ckpt_path, weights_only=True)
153
+
154
+ if use_ema:
155
+ if ckpt_type == "safetensors":
156
+ checkpoint = {"ema_model_state_dict": checkpoint}
157
+ checkpoint["model_state_dict"] = {
158
+ k.replace("ema_model.", ""): v
159
+ for k, v in checkpoint["ema_model_state_dict"].items()
160
+ if k not in ["initted", "step"]
161
+ }
162
+
163
+ # patch for backward compatibility, 305e3ea
164
+ for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
165
+ if key in checkpoint["model_state_dict"]:
166
+ del checkpoint["model_state_dict"][key]
167
+
168
+ model.load_state_dict(checkpoint["model_state_dict"])
169
+ else:
170
+ if ckpt_type == "safetensors":
171
+ checkpoint = {"model_state_dict": checkpoint}
172
+ model.load_state_dict(checkpoint["model_state_dict"])
173
+
174
+ return model.to(device)
175
+
176
+
177
+ # load model for inference
178
+
179
+
180
+ def load_model(
181
+ model_cls,
182
+ model_cfg,
183
+ ckpt_path,
184
+ mel_spec_type=mel_spec_type,
185
+ vocab_file="",
186
+ ode_method=ode_method,
187
+ use_ema=True,
188
+ device=device,
189
+ ):
190
+ if vocab_file == "":
191
+ vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
192
+ tokenizer = "custom"
193
+
194
+ print("\nvocab : ", vocab_file)
195
+ print("tokenizer : ", tokenizer)
196
+ print("model : ", ckpt_path, "\n")
197
+
198
+ vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
199
+ model = CFM(
200
+ transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
201
+ mel_spec_kwargs=dict(
202
+ n_fft=n_fft,
203
+ hop_length=hop_length,
204
+ win_length=win_length,
205
+ n_mel_channels=n_mel_channels,
206
+ target_sample_rate=target_sample_rate,
207
+ mel_spec_type=mel_spec_type,
208
+ ),
209
+ odeint_kwargs=dict(
210
+ method=ode_method,
211
+ ),
212
+ vocab_char_map=vocab_char_map,
213
+ ).to(device)
214
+
215
+ dtype = torch.float32 if mel_spec_type == "bigvgan" else None
216
+ model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
217
+
218
+ return model
219
+
220
+
221
+ # preprocess reference audio and text
222
+
223
+
224
+ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
225
+ show_info("Converting audio...")
226
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
227
+ aseg = AudioSegment.from_file(ref_audio_orig)
228
+
229
+ if clip_short:
230
+ # 1. try to find long silence for clipping
231
+ non_silent_segs = silence.split_on_silence(
232
+ aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
233
+ )
234
+ non_silent_wave = AudioSegment.silent(duration=0)
235
+ for non_silent_seg in non_silent_segs:
236
+ if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
237
+ show_info("Audio is over 15s, clipping short. (1)")
238
+ break
239
+ non_silent_wave += non_silent_seg
240
+
241
+ # 2. try to find short silence for clipping if 1. failed
242
+ if len(non_silent_wave) > 15000:
243
+ non_silent_segs = silence.split_on_silence(
244
+ aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000
245
+ )
246
+ non_silent_wave = AudioSegment.silent(duration=0)
247
+ for non_silent_seg in non_silent_segs:
248
+ if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
249
+ show_info("Audio is over 15s, clipping short. (2)")
250
+ break
251
+ non_silent_wave += non_silent_seg
252
+
253
+ aseg = non_silent_wave
254
+
255
+ # 3. if no proper silence found for clipping
256
+ if len(aseg) > 15000:
257
+ aseg = aseg[:15000]
258
+ show_info("Audio is over 15s, clipping short. (3)")
259
+
260
+ aseg.export(f.name, format="wav")
261
+ ref_audio = f.name
262
+
263
+ # Compute a hash of the reference audio file
264
+ with open(ref_audio, "rb") as audio_file:
265
+ audio_data = audio_file.read()
266
+ audio_hash = hashlib.md5(audio_data).hexdigest()
267
+
268
+ global _ref_audio_cache
269
+ if audio_hash in _ref_audio_cache:
270
+ # Use cached reference text
271
+ show_info("Using cached reference text...")
272
+ ref_text = _ref_audio_cache[audio_hash]
273
+ else:
274
+ if not ref_text.strip():
275
+ global asr_pipe
276
+ if asr_pipe is None:
277
+ initialize_asr_pipeline(device=device)
278
+ show_info("No reference text provided, transcribing reference audio...")
279
+ ref_text = asr_pipe(
280
+ ref_audio,
281
+ chunk_length_s=30,
282
+ batch_size=128,
283
+ generate_kwargs={"task": "transcribe"},
284
+ return_timestamps=False,
285
+ )["text"].strip()
286
+ show_info("Finished transcription")
287
+ else:
288
+ show_info("Using custom reference text...")
289
+ # Cache the transcribed text
290
+ _ref_audio_cache[audio_hash] = ref_text
291
+
292
+ # Ensure ref_text ends with a proper sentence-ending punctuation
293
+ if not ref_text.endswith(". ") and not ref_text.endswith("。"):
294
+ if ref_text.endswith("."):
295
+ ref_text += " "
296
+ else:
297
+ ref_text += ". "
298
+
299
+ return ref_audio, ref_text
300
+
301
+
302
+ # infer process: chunk text -> infer batches [i.e. infer_batch_process()]
303
+
304
+
305
+ def infer_process(
306
+ ref_audio,
307
+ ref_text,
308
+ gen_text,
309
+ model_obj,
310
+ vocoder,
311
+ mel_spec_type=mel_spec_type,
312
+ show_info=print,
313
+ progress=tqdm,
314
+ target_rms=target_rms,
315
+ cross_fade_duration=cross_fade_duration,
316
+ nfe_step=nfe_step,
317
+ cfg_strength=cfg_strength,
318
+ sway_sampling_coef=sway_sampling_coef,
319
+ speed=speed,
320
+ fix_duration=fix_duration,
321
+ device=device,
322
+ ):
323
+ # Split the input text into batches
324
+ audio, sr = torchaudio.load(ref_audio)
325
+ max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
326
+ gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
327
+ for i, gen_text in enumerate(gen_text_batches):
328
+ print(f"gen_text {i}", gen_text)
329
+
330
+ show_info(f"Generating audio in {len(gen_text_batches)} batches...")
331
+ return infer_batch_process(
332
+ (audio, sr),
333
+ ref_text,
334
+ gen_text_batches,
335
+ model_obj,
336
+ vocoder,
337
+ mel_spec_type=mel_spec_type,
338
+ progress=progress,
339
+ target_rms=target_rms,
340
+ cross_fade_duration=cross_fade_duration,
341
+ nfe_step=nfe_step,
342
+ cfg_strength=cfg_strength,
343
+ sway_sampling_coef=sway_sampling_coef,
344
+ speed=speed,
345
+ fix_duration=fix_duration,
346
+ device=device,
347
+ )
348
+
349
+
350
+ # infer batches
351
+
352
+
353
+ def infer_batch_process(
354
+ ref_audio,
355
+ ref_text,
356
+ gen_text_batches,
357
+ model_obj,
358
+ vocoder,
359
+ mel_spec_type="vocos",
360
+ progress=tqdm,
361
+ target_rms=0.1,
362
+ cross_fade_duration=0.15,
363
+ nfe_step=32,
364
+ cfg_strength=2.0,
365
+ sway_sampling_coef=-1,
366
+ speed=1,
367
+ fix_duration=None,
368
+ device=None,
369
+ ):
370
+ audio, sr = ref_audio
371
+ if audio.shape[0] > 1:
372
+ audio = torch.mean(audio, dim=0, keepdim=True)
373
+
374
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
375
+ if rms < target_rms:
376
+ audio = audio * target_rms / rms
377
+ if sr != target_sample_rate:
378
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
379
+ audio = resampler(audio)
380
+ audio = audio.to(device)
381
+
382
+ generated_waves = []
383
+ spectrograms = []
384
+
385
+ if len(ref_text[-1].encode("utf-8")) == 1:
386
+ ref_text = ref_text + " "
387
+ for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
388
+ # Prepare the text
389
+ text_list = [ref_text + gen_text]
390
+ final_text_list = convert_char_to_pinyin(text_list)
391
+
392
+ ref_audio_len = audio.shape[-1] // hop_length
393
+ if fix_duration is not None:
394
+ duration = int(fix_duration * target_sample_rate / hop_length)
395
+ else:
396
+ # Calculate duration
397
+ ref_text_len = len(ref_text.encode("utf-8"))
398
+ gen_text_len = len(gen_text.encode("utf-8"))
399
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
400
+
401
+ # inference
402
+ with torch.inference_mode():
403
+ generated, _ = model_obj.sample(
404
+ cond=audio,
405
+ text=final_text_list,
406
+ duration=duration,
407
+ steps=nfe_step,
408
+ cfg_strength=cfg_strength,
409
+ sway_sampling_coef=sway_sampling_coef,
410
+ )
411
+
412
+ generated = generated.to(torch.float32)
413
+ generated = generated[:, ref_audio_len:, :]
414
+ generated_mel_spec = generated.permute(0, 2, 1)
415
+ if mel_spec_type == "vocos":
416
+ generated_wave = vocoder.decode(generated_mel_spec)
417
+ elif mel_spec_type == "bigvgan":
418
+ generated_wave = vocoder(generated_mel_spec)
419
+ if rms < target_rms:
420
+ generated_wave = generated_wave * rms / target_rms
421
+
422
+ # wav -> numpy
423
+ generated_wave = generated_wave.squeeze().cpu().numpy()
424
+
425
+ generated_waves.append(generated_wave)
426
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
427
+
428
+ # Combine all generated waves with cross-fading
429
+ if cross_fade_duration <= 0:
430
+ # Simply concatenate
431
+ final_wave = np.concatenate(generated_waves)
432
+ else:
433
+ final_wave = generated_waves[0]
434
+ for i in range(1, len(generated_waves)):
435
+ prev_wave = final_wave
436
+ next_wave = generated_waves[i]
437
+
438
+ # Calculate cross-fade samples, ensuring it does not exceed wave lengths
439
+ cross_fade_samples = int(cross_fade_duration * target_sample_rate)
440
+ cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
441
+
442
+ if cross_fade_samples <= 0:
443
+ # No overlap possible, concatenate
444
+ final_wave = np.concatenate([prev_wave, next_wave])
445
+ continue
446
+
447
+ # Overlapping parts
448
+ prev_overlap = prev_wave[-cross_fade_samples:]
449
+ next_overlap = next_wave[:cross_fade_samples]
450
+
451
+ # Fade out and fade in
452
+ fade_out = np.linspace(1, 0, cross_fade_samples)
453
+ fade_in = np.linspace(0, 1, cross_fade_samples)
454
+
455
+ # Cross-faded overlap
456
+ cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
457
+
458
+ # Combine
459
+ new_wave = np.concatenate(
460
+ [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
461
+ )
462
+
463
+ final_wave = new_wave
464
+
465
+ # Create a combined spectrogram
466
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
467
+
468
+ return final_wave, target_sample_rate, combined_spectrogram
469
+
470
+
471
+ # remove silence from generated wav
472
+
473
+
474
+ def remove_silence_for_generated_wav(filename):
475
+ aseg = AudioSegment.from_file(filename)
476
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
477
+ non_silent_wave = AudioSegment.silent(duration=0)
478
+ for non_silent_seg in non_silent_segs:
479
+ non_silent_wave += non_silent_seg
480
+ aseg = non_silent_wave
481
+ aseg.export(filename, format="wav")
482
+
483
+
484
+ # save spectrogram
485
+
486
+
487
+ def save_spectrogram(spectrogram, path):
488
+ plt.figure(figsize=(12, 4))
489
+ plt.imshow(spectrogram, origin="lower", aspect="auto")
490
+ plt.colorbar()
491
+ plt.savefig(path)
492
+ plt.close()
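
A short usage sketch of the helpers above (not part of the uploaded files; the sample sentence is made up). `chunk_text` measures `max_chars` against UTF-8 byte length, and `load_vocoder` with default arguments pulls `charactr/vocos-mel-24khz` from the Hugging Face Hub:

```python
from f5_tts.infer.utils_infer import chunk_text, load_vocoder

# Split a long generation prompt into per-batch chunks.
batches = chunk_text(
    "Hello there. This is a slightly longer second sentence. And a third one to finish.",
    max_chars=40,
)
print(batches)

# Load the default Vocos vocoder (downloads the checkpoint on first use).
vocoder = load_vocoder(vocoder_name="vocos")
```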
src/f5_tts/src_f5_tts_api.py ADDED
@@ -0,0 +1,151 @@
1
+ import random
2
+ import sys
3
+ from importlib.resources import files
4
+
5
+ import soundfile as sf
6
+ import torch
7
+ import tqdm
8
+ from cached_path import cached_path
9
+
10
+ from f5_tts.infer.utils_infer import (
11
+ hop_length,
12
+ infer_process,
13
+ load_model,
14
+ load_vocoder,
15
+ preprocess_ref_audio_text,
16
+ remove_silence_for_generated_wav,
17
+ save_spectrogram,
18
+ target_sample_rate,
19
+ )
20
+ from f5_tts.model import DiT, UNetT
21
+ from f5_tts.model.utils import seed_everything
22
+
23
+
24
+ class F5TTS:
25
+ def __init__(
26
+ self,
27
+ model_type="F5-TTS",
28
+ ckpt_file="",
29
+ vocab_file="",
30
+ ode_method="euler",
31
+ use_ema=True,
32
+ vocoder_name="vocos",
33
+ local_path=None,
34
+ device=None,
35
+ ):
36
+ # Initialize parameters
37
+ self.final_wave = None
38
+ self.target_sample_rate = target_sample_rate
39
+ self.hop_length = hop_length
40
+ self.seed = -1
41
+ self.mel_spec_type = vocoder_name
42
+
43
+ # Set device
44
+ self.device = device or (
45
+ "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
46
+ )
47
+
48
+ # Load models
49
+ self.load_vocoder_model(vocoder_name, local_path)
50
+ self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema)
51
+
52
+ def load_vocoder_model(self, vocoder_name, local_path):
53
+ self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
54
+
55
+ def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema):
56
+ if model_type == "F5-TTS":
57
+ if not ckpt_file:
58
+ if mel_spec_type == "vocos":
59
+ ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
60
+ elif mel_spec_type == "bigvgan":
61
+ ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"))
62
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
63
+ model_cls = DiT
64
+ elif model_type == "E2-TTS":
65
+ if not ckpt_file:
66
+ ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
67
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
68
+ model_cls = UNetT
69
+ else:
70
+ raise ValueError(f"Unknown model type: {model_type}")
71
+
72
+ self.ema_model = load_model(
73
+ model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
74
+ )
75
+
76
+ def export_wav(self, wav, file_wave, remove_silence=False):
77
+ sf.write(file_wave, wav, self.target_sample_rate)
78
+
79
+ if remove_silence:
80
+ remove_silence_for_generated_wav(file_wave)
81
+
82
+ def export_spectrogram(self, spect, file_spect):
83
+ save_spectrogram(spect, file_spect)
84
+
85
+ def infer(
86
+ self,
87
+ ref_file,
88
+ ref_text,
89
+ gen_text,
90
+ show_info=print,
91
+ progress=tqdm,
92
+ target_rms=0.1,
93
+ cross_fade_duration=0.15,
94
+ sway_sampling_coef=-1,
95
+ cfg_strength=2,
96
+ nfe_step=32,
97
+ speed=1.0,
98
+ fix_duration=None,
99
+ remove_silence=False,
100
+ file_wave=None,
101
+ file_spect=None,
102
+ seed=-1,
103
+ ):
104
+ if seed == -1:
105
+ seed = random.randint(0, sys.maxsize)
106
+ seed_everything(seed)
107
+ self.seed = seed
108
+
109
+ ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
110
+
111
+ wav, sr, spect = infer_process(
112
+ ref_file,
113
+ ref_text,
114
+ gen_text,
115
+ self.ema_model,
116
+ self.vocoder,
117
+ self.mel_spec_type,
118
+ show_info=show_info,
119
+ progress=progress,
120
+ target_rms=target_rms,
121
+ cross_fade_duration=cross_fade_duration,
122
+ nfe_step=nfe_step,
123
+ cfg_strength=cfg_strength,
124
+ sway_sampling_coef=sway_sampling_coef,
125
+ speed=speed,
126
+ fix_duration=fix_duration,
127
+ device=self.device,
128
+ )
129
+
130
+ if file_wave is not None:
131
+ self.export_wav(wav, file_wave, remove_silence)
132
+
133
+ if file_spect is not None:
134
+ self.export_spectrogram(spect, file_spect)
135
+
136
+ return wav, sr, spect
137
+
138
+
139
+ if __name__ == "__main__":
140
+ f5tts = F5TTS()
141
+
142
+ wav, sr, spect = f5tts.infer(
143
+ ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
144
+ ref_text="some call me nature, others call me mother nature.",
145
+ gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
146
+ file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
147
+ file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
148
+ seed=-1, # random seed = -1
149
+ )
150
+
151
+ print("seed :", f5tts.seed)