Upload 12 files
- src/f5_tts/eval/src_f5_tts_eval_README.md +49 -0
- src/f5_tts/eval/src_f5_tts_eval_ecapa_tdnn.py +330 -0
- src/f5_tts/eval/src_f5_tts_eval_eval_infer_batch.py +207 -0
- src/f5_tts/eval/src_f5_tts_eval_eval_infer_batch.sh +13 -0
- src/f5_tts/eval/src_f5_tts_eval_eval_librispeech_test_clean.py +73 -0
- src/f5_tts/eval/src_f5_tts_eval_eval_seedtts_testset.py +75 -0
- src/f5_tts/eval/src_f5_tts_eval_utils_eval.py +405 -0
- src/f5_tts/infer/src_f5_tts_infer_README.md +116 -0
- src/f5_tts/infer/src_f5_tts_infer_infer_cli.py +220 -0
- src/f5_tts/infer/src_f5_tts_infer_speech_edit.py +191 -0
- src/f5_tts/infer/src_f5_tts_infer_utils_infer.py +492 -0
- src/f5_tts/src_f5_tts_api.py +151 -0
src/f5_tts/eval/src_f5_tts_eval_README.md
ADDED
@@ -0,0 +1,49 @@
# Evaluation

Install packages for evaluation:

```bash
pip install -e .[eval]
```

## Generating Samples for Evaluation

### Prepare Test Datasets

1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
3. Unzip the downloaded datasets and place them in the `data/` directory.
4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`.
5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst` (line format sketched below).
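
Each line of the filtered `.lst` subset is tab-separated; a minimal sketch of how it is parsed, mirroring `get_librispeech_test_clean_metainfo` in `src/f5_tts/eval/utils_eval.py` (shown further below):

```python
# One cross-sentence pair per line: a reference utterance (prompt audio/text)
# and a target utterance (text to synthesize), each with its duration.
with open("data/librispeech_pc_test_clean_cross_sentence.lst") as f:
    for line in f:
        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
```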

### Batch Inference for Test Set

To run batch inference for evaluations, execute the following commands:

```bash
# batch inference for evaluations
accelerate config  # if not set before
bash src/f5_tts/eval/eval_infer_batch.sh
```

## Objective Evaluation on Generated Results

### Download Evaluation Model Checkpoints

1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).

Then update the following scripts with the paths where you placed the evaluation model checkpoints, for example as sketched below.
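
The relevant path variables are taken from `eval_seedtts_testset.py` / `eval_librispeech_test_clean.py` below; the directories shown here are only examples of where you might keep the checkpoints:

```python
# Set local = True in the eval scripts to pick these up instead of auto-downloading.
asr_ckpt_dir = "../checkpoints/funasr"  # Paraformer-zh dir (zh WER)
# asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"  # Faster-Whisper (en WER)
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"  # WavLM-large finetuned ckpt (SIM)
```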

### Objective Evaluation

Update the path to your batch-inference results, then run the WER / SIM evaluations:
```bash
# Evaluation for Seed-TTS test set
python src/f5_tts/eval/eval_seedtts_testset.py

# Evaluation for LibriSpeech-PC test-clean (cross-sentence)
python src/f5_tts/eval/eval_librispeech_test_clean.py
```
src/f5_tts/eval/src_f5_tts_eval_ecapa_tdnn.py
ADDED
@@ -0,0 +1,330 @@
# just for speaker similarity evaluation, third-party code

# From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN

import os
import torch
import torch.nn as nn
import torch.nn.functional as F


""" Res2Conv1d + BatchNorm1d + ReLU
"""


class Res2Conv1dReluBn(nn.Module):
    """
    in_channels == out_channels == channels
    """

    def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
        super().__init__()
        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
        self.scale = scale
        self.width = channels // scale
        self.nums = scale if scale == 1 else scale - 1

        self.convs = []
        self.bns = []
        for i in range(self.nums):
            self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
            self.bns.append(nn.BatchNorm1d(self.width))
        self.convs = nn.ModuleList(self.convs)
        self.bns = nn.ModuleList(self.bns)

    def forward(self, x):
        out = []
        spx = torch.split(x, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            # Order: conv -> relu -> bn
            sp = self.convs[i](sp)
            sp = self.bns[i](F.relu(sp))
            out.append(sp)
        if self.scale != 1:
            out.append(spx[self.nums])
        out = torch.cat(out, dim=1)

        return out


""" Conv1d + BatchNorm1d + ReLU
"""


class Conv1dReluBn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x)))


""" The SE connection of 1D case.
"""


class SE_Connect(nn.Module):
    def __init__(self, channels, se_bottleneck_dim=128):
        super().__init__()
        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
        self.linear2 = nn.Linear(se_bottleneck_dim, channels)

    def forward(self, x):
        out = x.mean(dim=2)
        out = F.relu(self.linear1(out))
        out = torch.sigmoid(self.linear2(out))
        out = x * out.unsqueeze(2)

        return out


""" SE-Res2Block of the ECAPA-TDNN architecture.
"""

# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
#     return nn.Sequential(
#         Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
#         Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
#         Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
#         SE_Connect(channels)
#     )


class SE_Res2Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
        super().__init__()
        self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
        self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x):
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)

        x = self.Conv1dReluBn1(x)
        x = self.Res2Conv1dReluBn(x)
        x = self.Conv1dReluBn2(x)
        x = self.SE_Connect(x)

        return x + residual


""" Attentive weighted mean and standard deviation pooling.
"""


class AttentiveStatsPool(nn.Module):
    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x_in))
        # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)


class ECAPA_TDNN(nn.Module):
    def __init__(
        self,
        feat_dim=80,
        channels=512,
        emb_dim=192,
        global_context_att=False,
        feat_type="wavlm_large",
        sr=16000,
        feature_selection="hidden_states",
        update_extract=False,
        config_path=None,
    ):
        super().__init__()

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        try:
            local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
            self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
        except:  # noqa: E722
            self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)

        if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
            self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
        ):
            self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
        if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
            self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
        ):
            self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False

        self.feat_num = self.get_feat_num()
        self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        if feat_type != "fbank" and feat_type != "mfcc":
            freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
            for name, param in self.feature_extract.named_parameters():
                for freeze_val in freeze_list:
                    if freeze_val in name:
                        param.requires_grad = False
                        break

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # self.channels = [channels] * 4 + [channels * 3]
        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = SE_Res2Block(
            self.channels[0],
            self.channels[1],
            kernel_size=3,
            stride=1,
            padding=2,
            dilation=2,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer3 = SE_Res2Block(
            self.channels[1],
            self.channels[2],
            kernel_size=3,
            stride=1,
            padding=3,
            dilation=3,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer4 = SE_Res2Block(
            self.channels[2],
            self.channels[3],
            kernel_size=3,
            stride=1,
            padding=4,
            dilation=4,
            scale=8,
            se_bottleneck_dim=128,
        )

        # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
        cat_channels = channels * 3
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(
            self.channels[-1], attention_channels=128, global_context_att=global_context_att
        )
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)

    def get_feat_num(self):
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features[self.feature_selection]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        else:
            return 1

    def get_feat(self, x):
        if self.update_extract:
            x = self.feature_extract([sample for sample in x])
        else:
            with torch.no_grad():
                if self.feat_type == "fbank" or self.feat_type == "mfcc":
                    x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
                else:
                    x = self.feature_extract([sample for sample in x])

        if self.feat_type == "fbank":
            x = x.log()

        if self.feat_type != "fbank" and self.feat_type != "mfcc":
            x = x[self.feature_selection]
            if isinstance(x, (list, tuple)):
                x = torch.stack(x, dim=0)
            else:
                x = x.unsqueeze(0)
            norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            x = (norm_weights * x).sum(dim=0)
            x = torch.transpose(x, 1, 2) + 1e-6

        x = self.instance_norm(x)
        return x

    def forward(self, x):
        x = self.get_feat(x)

        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        out = torch.cat([out2, out3, out4], dim=1)
        out = F.relu(self.conv(out))
        out = self.bn(self.pooling(out))
        out = self.linear(out)

        return out


def ECAPA_TDNN_SMALL(
    feat_dim,
    emb_dim=256,
    feat_type="wavlm_large",
    sr=16000,
    feature_selection="hidden_states",
    update_extract=False,
    config_path=None,
):
    return ECAPA_TDNN(
        feat_dim=feat_dim,
        channels=512,
        emb_dim=emb_dim,
        feat_type=feat_type,
        sr=sr,
        feature_selection=feature_selection,
        update_extract=update_extract,
        config_path=config_path,
    )
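
This speaker-verification model is consumed by `run_sim` in `utils_eval.py` (further below). A minimal standalone sketch of that usage, assuming a local `wavlm_large_finetune.pth` checkpoint and placeholder wav paths (both are hypothetical names, point them at your own files):

```python
import torch
import torch.nn.functional as F
import torchaudio

from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL

# Load the WavLM-based speaker embedding model (checkpoint path is a placeholder).
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
state_dict = torch.load("wavlm_large_finetune.pth", weights_only=True, map_location="cpu")
model.load_state_dict(state_dict["model"], strict=False)
model.eval()

# Load two utterances and resample to the 16 kHz rate the extractor expects.
wavs = []
for path in ["generated.wav", "reference.wav"]:  # placeholder file names
    wav, sr = torchaudio.load(path)
    wavs.append(torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(wav))

with torch.no_grad():
    emb1, emb2 = model(wavs[0]), model(wavs[1])

# Cosine similarity in (-1, 1); higher means more similar speakers.
print(F.cosine_similarity(emb1, emb2)[0].item())
```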
src/f5_tts/eval/src_f5_tts_eval_eval_infer_batch.py
ADDED
@@ -0,0 +1,207 @@
import os
import sys

sys.path.append(os.getcwd())

import argparse
import time
from importlib.resources import files

import torch
import torchaudio
from accelerate import Accelerator
from tqdm import tqdm

from f5_tts.eval.utils_eval import (
    get_inference_prompt,
    get_librispeech_test_clean_metainfo,
    get_seedtts_testset_metainfo,
)
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model.utils import get_tokenizer

accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"


# --------------------- Dataset Settings -------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
target_rms = 0.1


tokenizer = "pinyin"
rel_path = str(files("f5_tts").joinpath("../../"))


def main():
    # ---------------------- infer setting ---------------------- #

    parser = argparse.ArgumentParser(description="batch inference")

    parser.add_argument("-s", "--seed", default=None, type=int)
    parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
    parser.add_argument("-n", "--expname", required=True)
    parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
    parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])

    parser.add_argument("-nfe", "--nfestep", default=32, type=int)
    parser.add_argument("-o", "--odemethod", default="euler")
    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)

    parser.add_argument("-t", "--testset", required=True)

    args = parser.parse_args()

    seed = args.seed
    dataset_name = args.dataset
    exp_name = args.expname
    ckpt_step = args.ckptstep
    ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
    mel_spec_type = args.mel_spec_type

    nfe_step = args.nfestep
    ode_method = args.odemethod
    sway_sampling_coef = args.swaysampling

    testset = args.testset

    infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
    cfg_strength = 2.0
    speed = 1.0
    use_truth_duration = False
    no_ref_audio = False

    if exp_name == "F5TTS_Base":
        model_cls = DiT
        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

    elif exp_name == "E2TTS_Base":
        model_cls = UNetT
        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

    if testset == "ls_pc_test_clean":
        metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
        librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
        metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)

    elif testset == "seedtts_test_zh":
        metalst = rel_path + "/data/seedtts_testset/zh/meta.lst"
        metainfo = get_seedtts_testset_metainfo(metalst)

    elif testset == "seedtts_test_en":
        metalst = rel_path + "/data/seedtts_testset/en/meta.lst"
        metainfo = get_seedtts_testset_metainfo(metalst)

    # path to save generated wavs
    output_dir = (
        f"{rel_path}/"
        f"results/{exp_name}_{ckpt_step}/{testset}/"
        f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
        f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
        f"_cfg{cfg_strength}_speed{speed}"
        f"{'_gt-dur' if use_truth_duration else ''}"
        f"{'_no-ref-audio' if no_ref_audio else ''}"
    )

    # -------------------------------------------------#

    use_ema = True

    prompts_all = get_inference_prompt(
        metainfo,
        speed=speed,
        tokenizer=tokenizer,
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
        mel_spec_type=mel_spec_type,
        target_rms=target_rms,
        use_truth_duration=use_truth_duration,
        infer_batch_size=infer_batch_size,
    )

    # Vocoder model
    local = False
    if mel_spec_type == "vocos":
        vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    elif mel_spec_type == "bigvgan":
        vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
    vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)

    # Tokenizer
    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

    # Model
    model = CFM(
        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type=mel_spec_type,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
    model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

    if not os.path.exists(output_dir) and accelerator.is_main_process:
        os.makedirs(output_dir)

    # start batch inference
    accelerator.wait_for_everyone()
    start = time.time()

    with accelerator.split_between_processes(prompts_all) as prompts:
        for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
            utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
            ref_mels = ref_mels.to(device)
            ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
            total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)

            # Inference
            with torch.inference_mode():
                generated, _ = model.sample(
                    cond=ref_mels,
                    text=final_text_list,
                    duration=total_mel_lens,
                    lens=ref_mel_lens,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                    no_ref_audio=no_ref_audio,
                    seed=seed,
                )
                # Final result
                for i, gen in enumerate(generated):
                    gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
                    gen_mel_spec = gen.permute(0, 2, 1)
                    if mel_spec_type == "vocos":
                        generated_wave = vocoder.decode(gen_mel_spec)
                    elif mel_spec_type == "bigvgan":
                        generated_wave = vocoder(gen_mel_spec)

                    if ref_rms_list[i] < target_rms:
                        generated_wave = generated_wave * ref_rms_list[i] / target_rms
                    torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        timediff = time.time() - start
        print(f"Done batch inference in {timediff / 60 :.2f} minutes.")


if __name__ == "__main__":
    main()
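
For orientation, with the dataset settings above (24 kHz sample rate, hop length 256) durations map to mel-frame counts at 24000 / 256 = 93.75 frames per second; a quick sanity check of the numbers used for prompt bucketing in `get_inference_prompt` (the 3 s / 40 s bounds are its defaults):

```python
target_sample_rate = 24000
hop_length = 256

frames_per_second = target_sample_rate / hop_length  # 93.75
# A 3 s prompt occupies 3 * 24000 // 256 = 281 mel frames,
# and the 40 s upper bound corresponds to 3750 frames.
print(frames_per_second, 3 * target_sample_rate // hop_length, 40 * target_sample_rate // hop_length)
```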
src/f5_tts/eval/src_f5_tts_eval_eval_infer_batch.sh
ADDED
@@ -0,0 +1,13 @@
#!/bin/bash

# e.g. F5-TTS, 16 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16

# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0

# etc.
src/f5_tts/eval/src_f5_tts_eval_eval_librispeech_test_clean.py
ADDED
@@ -0,0 +1,73 @@
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)

import sys
import os

sys.path.append(os.getcwd())

import multiprocessing as mp
from importlib.resources import files

import numpy as np

from f5_tts.eval.utils_eval import (
    get_librispeech_test,
    run_asr_wer,
    run_sim,
)

rel_path = str(files("f5_tts").joinpath("../../"))


eval_task = "wer"  # sim | wer
lang = "en"
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
gen_wav_dir = "PATH_TO_GENERATED"  # generated wavs

gpus = [0, 1, 2, 3, 4, 5, 6, 7]
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)

## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
## leading to a low similarity for the ground truth in some cases.
# test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True)  # eval ground truth

local = False
if local:  # use local custom checkpoint dir
    asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
else:
    asr_ckpt_dir = ""  # auto download to cache dir

wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"


# --------------------------- WER ---------------------------

if eval_task == "wer":
    wers = []

    with mp.Pool(processes=len(gpus)) as pool:
        args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
        results = pool.map(run_asr_wer, args)
        for wers_ in results:
            wers.extend(wers_)

    wer = round(np.mean(wers) * 100, 3)
    print(f"\nTotal {len(wers)} samples")
    print(f"WER : {wer}%")


# --------------------------- SIM ---------------------------

if eval_task == "sim":
    sim_list = []

    with mp.Pool(processes=len(gpus)) as pool:
        args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
        results = pool.map(run_sim, args)
        for sim_ in results:
            sim_list.extend(sim_)

    sim = round(sum(sim_list) / len(sim_list), 3)
    print(f"\nTotal {len(sim_list)} samples")
    print(f"SIM : {sim}")
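
Before scoring, `run_asr_wer` (defined in `utils_eval.py`, shown further below) normalizes reference and hypothesis text and then calls `jiwer`. Roughly, for the English case (the sentences here are made-up examples):

```python
import string
from jiwer import compute_measures

truth = "The content, subtitle or transcription of reference audio."
hypo = "the content subtitle or transcription of reference audio"

# Strip punctuation, collapse double spaces, lowercase (English branch);
# the Chinese branch instead splits the text into individual characters.
for x in string.punctuation:
    truth = truth.replace(x, "")
    hypo = hypo.replace(x, "")
truth = truth.replace("  ", " ").lower()
hypo = hypo.replace("  ", " ").lower()

print(compute_measures(truth, hypo)["wer"])  # 0.0 for this pair
```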
src/f5_tts/eval/src_f5_tts_eval_eval_seedtts_testset.py
ADDED
@@ -0,0 +1,75 @@
# Evaluate with Seed-TTS testset

import sys
import os

sys.path.append(os.getcwd())

import multiprocessing as mp
from importlib.resources import files

import numpy as np

from f5_tts.eval.utils_eval import (
    get_seed_tts_test,
    run_asr_wer,
    run_sim,
)

rel_path = str(files("f5_tts").joinpath("../../"))


eval_task = "wer"  # sim | wer
lang = "zh"  # zh | en
metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset
# gen_wav_dir = rel_path + f"/data/seedtts_testset/{lang}/wavs"  # ground truth wavs
gen_wav_dir = "PATH_TO_GENERATED"  # generated wavs


# NOTE. paraformer-zh result will differ slightly with the number of gpus, because the per-worker batch size differs
# zh 1.254 seems a result of 4 workers wer_seed_tts
gpus = [0, 1, 2, 3, 4, 5, 6, 7]
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)

local = False
if local:  # use local custom checkpoint dir
    if lang == "zh":
        asr_ckpt_dir = "../checkpoints/funasr"  # paraformer-zh dir under funasr
    elif lang == "en":
        asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
else:
    asr_ckpt_dir = ""  # auto download to cache dir

wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"


# --------------------------- WER ---------------------------

if eval_task == "wer":
    wers = []

    with mp.Pool(processes=len(gpus)) as pool:
        args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
        results = pool.map(run_asr_wer, args)
        for wers_ in results:
            wers.extend(wers_)

    wer = round(np.mean(wers) * 100, 3)
    print(f"\nTotal {len(wers)} samples")
    print(f"WER : {wer}%")


# --------------------------- SIM ---------------------------

if eval_task == "sim":
    sim_list = []

    with mp.Pool(processes=len(gpus)) as pool:
        args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
        results = pool.map(run_sim, args)
        for sim_ in results:
            sim_list.extend(sim_)

    sim = round(sum(sim_list) / len(sim_list), 3)
    print(f"\nTotal {len(sim_list)} samples")
    print(f"SIM : {sim}")
src/f5_tts/eval/src_f5_tts_eval_utils_eval.py
ADDED
@@ -0,0 +1,405 @@
import math
import os
import random
import string

import torch
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm

from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import convert_char_to_pinyin


# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
def get_seedtts_testset_metainfo(metalst):
    f = open(metalst)
    lines = f.readlines()
    f.close()
    metainfo = []
    for line in lines:
        if len(line.strip().split("|")) == 5:
            utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
        elif len(line.strip().split("|")) == 4:
            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
            gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
        metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
    return metainfo


# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
    f = open(metalst)
    lines = f.readlines()
    f.close()
    metainfo = []
    for line in lines:
        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")

        # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
        ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")

        # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
        gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
        gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")

        metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))

    return metainfo


# padded to max length mel batch
def padded_mel_batch(ref_mels):
    max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
    padded_ref_mels = []
    for mel in ref_mels:
        padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
        padded_ref_mels.append(padded_ref_mel)
    padded_ref_mels = torch.stack(padded_ref_mels)
    padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
    return padded_ref_mels


# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav


def get_inference_prompt(
    metainfo,
    speed=1.0,
    tokenizer="pinyin",
    polyphone=True,
    target_sample_rate=24000,
    n_fft=1024,
    win_length=1024,
    n_mel_channels=100,
    hop_length=256,
    mel_spec_type="vocos",
    target_rms=0.1,
    use_truth_duration=False,
    infer_batch_size=1,
    num_buckets=200,
    min_secs=3,
    max_secs=40,
):
    prompts_all = []

    min_tokens = min_secs * target_sample_rate // hop_length
    max_tokens = max_secs * target_sample_rate // hop_length

    batch_accum = [0] * num_buckets
    utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
        [[] for _ in range(num_buckets)] for _ in range(6)
    )

    mel_spectrogram = MelSpec(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    )

    for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
        # Audio
        ref_audio, ref_sr = torchaudio.load(prompt_wav)
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
        if ref_rms < target_rms:
            ref_audio = ref_audio * target_rms / ref_rms
        assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio)

        # Text
        if len(prompt_text[-1].encode("utf-8")) == 1:
            prompt_text = prompt_text + " "
        text = [prompt_text + gt_text]
        if tokenizer == "pinyin":
            text_list = convert_char_to_pinyin(text, polyphone=polyphone)
        else:
            text_list = text

        # Duration, mel frame length
        ref_mel_len = ref_audio.shape[-1] // hop_length
        if use_truth_duration:
            gt_audio, gt_sr = torchaudio.load(gt_wav)
            if gt_sr != target_sample_rate:
                resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
                gt_audio = resampler(gt_audio)
            total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)

            # # test vocoder resynthesis
            # ref_audio = gt_audio
        else:
            ref_text_len = len(prompt_text.encode("utf-8"))
            gen_text_len = len(gt_text.encode("utf-8"))
            total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)

        # to mel spectrogram
        ref_mel = mel_spectrogram(ref_audio)
        ref_mel = ref_mel.squeeze(0)

        # deal with batch
        assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
        assert (
            min_tokens <= total_mel_len <= max_tokens
        ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
        bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)

        utts[bucket_i].append(utt)
        ref_rms_list[bucket_i].append(ref_rms)
        ref_mels[bucket_i].append(ref_mel)
        ref_mel_lens[bucket_i].append(ref_mel_len)
        total_mel_lens[bucket_i].append(total_mel_len)
        final_text_list[bucket_i].extend(text_list)

        batch_accum[bucket_i] += total_mel_len

        if batch_accum[bucket_i] >= infer_batch_size:
            # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )
            batch_accum[bucket_i] = 0
            (
                utts[bucket_i],
                ref_rms_list[bucket_i],
                ref_mels[bucket_i],
                ref_mel_lens[bucket_i],
                total_mel_lens[bucket_i],
                final_text_list[bucket_i],
            ) = [], [], [], [], [], []

    # add residual
    for bucket_i, bucket_frames in enumerate(batch_accum):
        if bucket_frames > 0:
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )
    # shuffle, so as not to leave only easy (short) work for the last workers
    random.seed(666)
    random.shuffle(prompts_all)

    return prompts_all


# get wav_res_ref_text of seed-tts test metalst
# https://github.com/BytedanceSpeech/seed-tts-eval


def get_seed_tts_test(metalst, gen_wav_dir, gpus):
    f = open(metalst)
    lines = f.readlines()
    f.close()

    test_set_ = []
    for line in tqdm(lines):
        if len(line.strip().split("|")) == 5:
            utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
        elif len(line.strip().split("|")) == 4:
            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")

        if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
            continue
        gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)

        test_set_.append((gen_wav, prompt_wav, gt_text))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))

    return test_set


# get librispeech test-clean cross sentence test


def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
    f = open(metalst)
    lines = f.readlines()
    f.close()

    test_set_ = []
    for line in tqdm(lines):
        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")

        if eval_ground_truth:
            gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
            gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
        else:
            if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
                raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
            gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")

        ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")

        test_set_.append((gen_wav, ref_wav, gen_txt))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))

    return test_set


# load asr model


def load_asr_model(lang, ckpt_dir=""):
    if lang == "zh":
        from funasr import AutoModel

        model = AutoModel(
            model=os.path.join(ckpt_dir, "paraformer-zh"),
            # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
            # punc_model = os.path.join(ckpt_dir, "ct-punc"),
            # spk_model = os.path.join(ckpt_dir, "cam++"),
            disable_update=True,
        )  # following seed-tts setting
    elif lang == "en":
        from faster_whisper import WhisperModel

        model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
        model = WhisperModel(model_size, device="cuda", compute_type="float16")
    return model


# WER Evaluation, the way Seed-TTS does


def run_asr_wer(args):
    rank, lang, test_set, ckpt_dir = args

    if lang == "zh":
        import zhconv

        torch.cuda.set_device(rank)
    elif lang == "en":
        os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    else:
        raise NotImplementedError(
            "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
        )

    asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)

    from zhon.hanzi import punctuation

    punctuation_all = punctuation + string.punctuation
    wers = []

    from jiwer import compute_measures

    for gen_wav, prompt_wav, truth in tqdm(test_set):
        if lang == "zh":
            res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
            hypo = res[0]["text"]
            hypo = zhconv.convert(hypo, "zh-cn")
        elif lang == "en":
            segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
            hypo = ""
            for segment in segments:
                hypo = hypo + " " + segment.text

        # raw_truth = truth
        # raw_hypo = hypo

        for x in punctuation_all:
            truth = truth.replace(x, "")
            hypo = hypo.replace(x, "")

        truth = truth.replace("  ", " ")
        hypo = hypo.replace("  ", " ")

        if lang == "zh":
            truth = " ".join([x for x in truth])
            hypo = " ".join([x for x in hypo])
        elif lang == "en":
            truth = truth.lower()
            hypo = hypo.lower()

        measures = compute_measures(truth, hypo)
        wer = measures["wer"]

        # ref_list = truth.split(" ")
        # subs = measures["substitutions"] / len(ref_list)
        # dele = measures["deletions"] / len(ref_list)
        # inse = measures["insertions"] / len(ref_list)

        wers.append(wer)

    return wers


# SIM Evaluation


def run_sim(args):
    rank, test_set, ckpt_dir = args
    device = f"cuda:{rank}"

    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
    state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict["model"], strict=False)

    use_gpu = True if torch.cuda.is_available() else False
    if use_gpu:
        model = model.cuda(device)
    model.eval()

    sim_list = []
    for wav1, wav2, truth in tqdm(test_set):
        wav1, sr1 = torchaudio.load(wav1)
        wav2, sr2 = torchaudio.load(wav2)

        resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
        resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
        wav1 = resample1(wav1)
        wav2 = resample2(wav2)

        if use_gpu:
            wav1 = wav1.cuda(device)
            wav2 = wav2.cuda(device)
        with torch.no_grad():
            emb1 = model(wav1)
            emb2 = model(wav2)

        sim = F.cosine_similarity(emb1, emb2)[0].item()
        # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
        sim_list.append(sim)

    return sim_list
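
When `use_truth_duration=False`, `get_inference_prompt` estimates the output length by scaling the reference mel length with the byte-length ratio of generated to reference text. A worked example with the default 24 kHz / 256-hop settings (the prompt duration and byte counts are hypothetical):

```python
hop_length, target_sample_rate, speed = 256, 24000, 1.0

ref_audio_secs = 4.0  # hypothetical 4 s reference prompt
ref_mel_len = int(ref_audio_secs * target_sample_rate) // hop_length  # 375 frames

ref_text_len = 50   # bytes of the reference transcript (hypothetical)
gen_text_len = 100  # bytes of the text to generate (hypothetical)

total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
print(total_mel_len)  # 375 + 750 = 1125 frames, i.e. ~12 s of total audio
```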
src/f5_tts/infer/src_f5_tts_infer_README.md
ADDED
@@ -0,0 +1,116 @@
# Inference

The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or will be automatically downloaded when running inference scripts.

Currently a single generation supports up to **30s**, which is the **total length** of prompt and output audio combined. However, you can provide `infer_cli` and `infer_gradio` with longer text; it will automatically be generated chunk by chunk. Long reference audio will be **clipped to ~15s**.

To avoid possible inference failures, make sure you have read through the following instructions.

- Use reference audio shorter than 15s and leave some silence (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of a word, leading to suboptimal generation.
- Uppercase letters will be uttered letter by letter, so use lowercase letters for normal words.
- Add some spaces (blank: " ") or punctuation (e.g. "," ".") to explicitly introduce pauses.
- Convert numbers to Chinese characters if you want them read in Chinese; otherwise they will be read in English.


## Gradio App

Currently supported features:

- Basic TTS with Chunk Inference
- Multi-Style / Multi-Speaker Generation
- Voice Chat powered by Qwen2.5-3B-Instruct

The CLI command `f5-tts_infer-gradio` is equivalent to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio app (web interface) for inference.

The script will load model checkpoints from Hugging Face. You can also manually download the files and update the path passed to `load_model()` in `infer_gradio.py`. Only the TTS model is loaded at startup; the ASR model is loaded for transcription if `ref_text` is not provided, and the LLM is loaded only when Voice Chat is used.

It can also be used as a component of a larger application:
```python
import gradio as gr
from f5_tts.infer.infer_gradio import app

with gr.Blocks() as main_app:
    gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app")

    # ... other Gradio components

    app.render()

main_app.launch()
```


## CLI Inference

The CLI command `f5-tts_infer-cli` is equivalent to `python src/f5_tts/infer/infer_cli.py`, a command line tool for inference.

The script will load model checkpoints from Hugging Face. You can also manually download files and use `--ckpt_file` to specify the model you want to load, or directly update the path in `infer_cli.py`.

To use a custom `vocab.txt`, pass it with `--vocab_file`.

Basically you can run inference with flags:
```bash
# Leaving --ref_text "" will have the ASR model transcribe the reference audio (extra GPU memory usage)
f5-tts_infer-cli \
--model "F5-TTS" \
--ref_audio "ref_audio.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."

# Choose Vocoder
f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
```
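
Under the hood, `--load_vocoder_from_local` builds on the `load_vocoder` helper from `f5_tts.infer.utils_infer`, the same call used in `eval_infer_batch.py` above. A minimal sketch, assuming the vocoder has already been downloaded to the default local path mentioned in the flag's help text:

```python
from f5_tts.infer.utils_infer import load_vocoder

# Local path matches the default in the --load_vocoder_from_local help text; adjust as needed.
vocoder = load_vocoder(
    vocoder_name="vocos",  # or "bigvgan"
    is_local=True,
    local_path="../checkpoints/charactr/vocos-mel-24khz",
)
```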

A `.toml` file helps with more flexible usage:

```bash
f5-tts_infer-cli -c custom.toml
```

For example, you can use a `.toml` file to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`:

```toml
# F5-TTS | E2-TTS
model = "F5-TTS"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If an empty "", transcribes the reference audio automatically.
ref_text = "Some call me nature, others call me mother nature."
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
# File with text to generate. Ignores the text above.
gen_file = ""
remove_silence = false
output_dir = "tests"
```

You can also leverage a `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`.

```toml
# F5-TTS | E2-TTS
model = "F5-TTS"
ref_audio = "infer/examples/multi/main.flac"
# If an empty "", transcribes the reference audio automatically.
ref_text = ""
gen_text = ""
# File with text to generate. Ignores the text above.
gen_file = "infer/examples/multi/story.txt"
remove_silence = true
output_dir = "tests"

[voices.town]
ref_audio = "infer/examples/multi/town.flac"
ref_text = ""

[voices.country]
ref_audio = "infer/examples/multi/country.flac"
ref_text = ""
```
Mark the voice with `[main]` `[town]` `[country]` wherever you want to switch voices, refer to `src/f5_tts/infer/examples/multi/story.txt`.

## Speech Editing

To test speech editing capabilities, use the following command:

```bash
python src/f5_tts/infer/speech_edit.py
```
src/f5_tts/infer/src_f5_tts_infer_infer_cli.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import codecs
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
from importlib.resources import files
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import soundfile as sf
|
10 |
+
import tomli
|
11 |
+
from cached_path import cached_path
|
12 |
+
|
13 |
+
from f5_tts.infer.utils_infer import (
|
14 |
+
infer_process,
|
15 |
+
load_model,
|
16 |
+
load_vocoder,
|
17 |
+
preprocess_ref_audio_text,
|
18 |
+
remove_silence_for_generated_wav,
|
19 |
+
)
|
20 |
+
from f5_tts.model import DiT, UNetT
|
21 |
+
|
22 |
+
parser = argparse.ArgumentParser(
|
23 |
+
prog="python3 infer-cli.py",
|
24 |
+
description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
|
25 |
+
epilog="Specify options above to override one or more settings from config.",
|
26 |
+
)
|
27 |
+
parser.add_argument(
|
28 |
+
"-c",
|
29 |
+
"--config",
|
30 |
+
help="Configuration file. Default=infer/examples/basic/basic.toml",
|
31 |
+
default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
|
32 |
+
)
|
33 |
+
parser.add_argument(
|
34 |
+
"-m",
|
35 |
+
"--model",
|
36 |
+
help="F5-TTS | E2-TTS",
|
37 |
+
)
|
38 |
+
parser.add_argument(
|
39 |
+
"-p",
|
40 |
+
"--ckpt_file",
|
41 |
+
help="The Checkpoint .pt",
|
42 |
+
)
|
43 |
+
parser.add_argument(
|
44 |
+
"-v",
|
45 |
+
"--vocab_file",
|
46 |
+
help="The vocab .txt",
|
47 |
+
)
|
48 |
+
parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
|
49 |
+
parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
|
50 |
+
parser.add_argument(
|
51 |
+
"-t",
|
52 |
+
"--gen_text",
|
53 |
+
type=str,
|
54 |
+
help="Text to generate.",
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"-f",
|
58 |
+
"--gen_file",
|
59 |
+
type=str,
|
60 |
+
help="File with text to generate. Ignores --text",
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"-o",
|
64 |
+
"--output_dir",
|
65 |
+
type=str,
|
66 |
+
help="Path to output folder..",
|
67 |
+
)
|
68 |
+
parser.add_argument(
|
69 |
+
"--remove_silence",
|
70 |
+
help="Remove silence.",
|
71 |
+
)
|
72 |
+
parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
|
73 |
+
parser.add_argument(
|
74 |
+
"--load_vocoder_from_local",
|
75 |
+
action="store_true",
|
76 |
+
help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
|
77 |
+
)
|
78 |
+
parser.add_argument(
|
79 |
+
"--speed",
|
80 |
+
type=float,
|
81 |
+
default=1.0,
|
82 |
+
help="Adjust the speed of the audio generation (default: 1.0)",
|
83 |
+
)
|
84 |
+
args = parser.parse_args()
|
85 |
+
|
86 |
+
config = tomli.load(open(args.config, "rb"))
|
87 |
+
|
88 |
+
ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
|
89 |
+
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
|
90 |
+
gen_text = args.gen_text if args.gen_text else config["gen_text"]
|
91 |
+
gen_file = args.gen_file if args.gen_file else config["gen_file"]
|
92 |
+
|
93 |
+
# patches for pip pkg user
|
94 |
+
if "infer/examples/" in ref_audio:
|
95 |
+
ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}"))
|
96 |
+
if "infer/examples/" in gen_file:
|
97 |
+
gen_file = str(files("f5_tts").joinpath(f"{gen_file}"))
|
98 |
+
if "voices" in config:
|
99 |
+
for voice in config["voices"]:
|
100 |
+
voice_ref_audio = config["voices"][voice]["ref_audio"]
|
101 |
+
if "infer/examples/" in voice_ref_audio:
|
102 |
+
config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
|
103 |
+
|
104 |
+
if gen_file:
|
105 |
+
gen_text = codecs.open(gen_file, "r", "utf-8").read()
|
106 |
+
output_dir = args.output_dir if args.output_dir else config["output_dir"]
|
107 |
+
model = args.model if args.model else config["model"]
|
108 |
+
ckpt_file = args.ckpt_file if args.ckpt_file else ""
|
109 |
+
vocab_file = args.vocab_file if args.vocab_file else ""
|
110 |
+
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
111 |
+
speed = args.speed
|
112 |
+
wave_path = Path(output_dir) / "infer_cli_out.wav"
|
113 |
+
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
|
114 |
+
if args.vocoder_name == "vocos":
|
115 |
+
vocoder_local_path = "../checkpoints/vocos-mel-24khz"
|
116 |
+
elif args.vocoder_name == "bigvgan":
|
117 |
+
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
|
118 |
+
mel_spec_type = args.vocoder_name
|
119 |
+
|
120 |
+
vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
|
121 |
+
|
122 |
+
|
123 |
+
# load models
|
124 |
+
if model == "F5-TTS":
|
125 |
+
model_cls = DiT
|
126 |
+
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
127 |
+
if ckpt_file == "":
|
128 |
+
if args.vocoder_name == "vocos":
|
129 |
+
repo_name = "F5-TTS"
|
130 |
+
exp_name = "F5TTS_Base"
|
131 |
+
ckpt_step = 1200000
|
132 |
+
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
133 |
+
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
134 |
+
elif args.vocoder_name == "bigvgan":
|
135 |
+
repo_name = "F5-TTS"
|
136 |
+
exp_name = "F5TTS_Base_bigvgan"
|
137 |
+
ckpt_step = 1250000
|
138 |
+
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
|
139 |
+
|
140 |
+
elif model == "E2-TTS":
|
141 |
+
model_cls = UNetT
|
142 |
+
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
143 |
+
if ckpt_file == "":
|
144 |
+
repo_name = "E2-TTS"
|
145 |
+
exp_name = "E2TTS_Base"
|
146 |
+
ckpt_step = 1200000
|
147 |
+
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
148 |
+
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
149 |
+
elif args.vocoder_name == "bigvgan": # TODO: need to test
|
150 |
+
repo_name = "F5-TTS"
|
151 |
+
exp_name = "F5TTS_Base_bigvgan"
|
152 |
+
ckpt_step = 1250000
|
153 |
+
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
|
154 |
+
|
155 |
+
|
156 |
+
print(f"Using {model}...")
|
157 |
+
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=args.vocoder_name, vocab_file=vocab_file)
|
158 |
+
|
159 |
+
|
160 |
+
def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
|
161 |
+
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
|
162 |
+
if "voices" not in config:
|
163 |
+
voices = {"main": main_voice}
|
164 |
+
else:
|
165 |
+
voices = config["voices"]
|
166 |
+
voices["main"] = main_voice
|
167 |
+
for voice in voices:
|
168 |
+
voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
|
169 |
+
voices[voice]["ref_audio"], voices[voice]["ref_text"]
|
170 |
+
)
|
171 |
+
print("Voice:", voice)
|
172 |
+
print("Ref_audio:", voices[voice]["ref_audio"])
|
173 |
+
print("Ref_text:", voices[voice]["ref_text"])
|
174 |
+
|
175 |
+
generated_audio_segments = []
|
176 |
+
reg1 = r"(?=\[\w+\])"
|
177 |
+
chunks = re.split(reg1, text_gen)
|
178 |
+
reg2 = r"\[(\w+)\]"
|
179 |
+
for text in chunks:
|
180 |
+
if not text.strip():
|
181 |
+
continue
|
182 |
+
match = re.match(reg2, text)
|
183 |
+
if match:
|
184 |
+
voice = match[1]
|
185 |
+
else:
|
186 |
+
print("No voice tag found, using main.")
|
187 |
+
voice = "main"
|
188 |
+
if voice not in voices:
|
189 |
+
print(f"Voice {voice} not found, using main.")
|
190 |
+
voice = "main"
|
191 |
+
text = re.sub(reg2, "", text)
|
192 |
+
gen_text = text.strip()
|
193 |
+
ref_audio = voices[voice]["ref_audio"]
|
194 |
+
ref_text = voices[voice]["ref_text"]
|
195 |
+
print(f"Voice: {voice}")
|
196 |
+
audio, final_sample_rate, spectrogram = infer_process(
|
197 |
+
ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
|
198 |
+
)
|
199 |
+
generated_audio_segments.append(audio)
|
200 |
+
|
201 |
+
if generated_audio_segments:
|
202 |
+
final_wave = np.concatenate(generated_audio_segments)
|
203 |
+
|
204 |
+
if not os.path.exists(output_dir):
|
205 |
+
os.makedirs(output_dir)
|
206 |
+
|
207 |
+
with open(wave_path, "wb") as f:
|
208 |
+
sf.write(f.name, final_wave, final_sample_rate)
|
209 |
+
# Remove silence
|
210 |
+
if remove_silence:
|
211 |
+
remove_silence_for_generated_wav(f.name)
|
212 |
+
print(f.name)
|
213 |
+
|
214 |
+
|
215 |
+
def main():
|
216 |
+
main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)
|
217 |
+
|
218 |
+
|
219 |
+
if __name__ == "__main__":
|
220 |
+
main()
|
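The multi-voice handling in `main_process` above hinges on two regexes: the generation text is split on the lookahead `(?=\[\w+\])` so each chunk starts with its tag, and the tag itself is matched with `\[(\w+)\]`. A standalone sketch of that behavior (not part of the CLI; the text is illustrative only):

```python
import re

# Illustrative multi-voice text; untagged or unknown-tagged chunks fall back to "main".
text_gen = "[main] Once upon a time... [town] Hello there, traveler! [main] And so it went."

chunks = re.split(r"(?=\[\w+\])", text_gen)  # zero-width split keeps each tag with its chunk
for chunk in chunks:
    if not chunk.strip():
        continue
    match = re.match(r"\[(\w+)\]", chunk)
    voice = match[1] if match else "main"           # no tag -> main voice
    text = re.sub(r"\[(\w+)\]", "", chunk).strip()  # strip the tag before synthesis
    print(f"{voice}: {text}")
# main: Once upon a time...
# town: Hello there, traveler!
# main: And so it went.
```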
src/f5_tts/infer/src_f5_tts_infer_speech_edit.py
ADDED
@@ -0,0 +1,191 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torchaudio
|
6 |
+
|
7 |
+
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
|
8 |
+
from f5_tts.model import CFM, DiT, UNetT
|
9 |
+
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
10 |
+
|
11 |
+
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
12 |
+
|
13 |
+
|
14 |
+
# --------------------- Dataset Settings -------------------- #
|
15 |
+
|
16 |
+
target_sample_rate = 24000
|
17 |
+
n_mel_channels = 100
|
18 |
+
hop_length = 256
|
19 |
+
win_length = 1024
|
20 |
+
n_fft = 1024
|
21 |
+
mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
|
22 |
+
target_rms = 0.1
|
23 |
+
|
24 |
+
tokenizer = "pinyin"
|
25 |
+
dataset_name = "Emilia_ZH_EN"
|
26 |
+
|
27 |
+
|
28 |
+
# ---------------------- infer setting ---------------------- #
|
29 |
+
|
30 |
+
seed = None # int | None
|
31 |
+
|
32 |
+
exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
|
33 |
+
ckpt_step = 1200000
|
34 |
+
|
35 |
+
nfe_step = 32 # 16, 32
|
36 |
+
cfg_strength = 2.0
|
37 |
+
ode_method = "euler" # euler | midpoint
|
38 |
+
sway_sampling_coef = -1.0
|
39 |
+
speed = 1.0
|
40 |
+
|
41 |
+
if exp_name == "F5TTS_Base":
|
42 |
+
model_cls = DiT
|
43 |
+
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
44 |
+
|
45 |
+
elif exp_name == "E2TTS_Base":
|
46 |
+
model_cls = UNetT
|
47 |
+
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
48 |
+
|
49 |
+
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
|
50 |
+
output_dir = "tests"
|
51 |
+
|
52 |
+
# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
|
53 |
+
# pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
|
54 |
+
# [write the origin_text into a file, e.g. tests/test_edit.txt]
|
55 |
+
# ctc-forced-aligner --audio_path "src/f5_tts/infer/examples/basic/basic_ref_en.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
|
56 |
+
# [result will be saved at same path of audio file]
|
57 |
+
# [--language "zho" for Chinese, "eng" for English]
|
58 |
+
# [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
|
59 |
+
|
60 |
+
audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav"
|
61 |
+
origin_text = "Some call me nature, others call me mother nature."
|
62 |
+
target_text = "Some call me optimist, others call me realist."
|
63 |
+
parts_to_edit = [
|
64 |
+
[1.42, 2.44],
|
65 |
+
[4.04, 4.9],
|
66 |
+
] # start_ends of "nature" & "mother nature", in seconds
|
67 |
+
fix_duration = [
|
68 |
+
1.2,
|
69 |
+
1,
|
70 |
+
] # fix duration for "optimist" & "realist", in seconds
|
71 |
+
|
72 |
+
# audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
|
73 |
+
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
|
74 |
+
# target_text = "对,那就是你,万人敬仰的太白金星。"
|
75 |
+
# parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
|
76 |
+
# fix_duration = None # use origin text duration
|
77 |
+
|
78 |
+
|
79 |
+
# -------------------------------------------------#
|
80 |
+
|
81 |
+
use_ema = True
|
82 |
+
|
83 |
+
if not os.path.exists(output_dir):
|
84 |
+
os.makedirs(output_dir)
|
85 |
+
|
86 |
+
# Vocoder model
|
87 |
+
local = False
|
88 |
+
if mel_spec_type == "vocos":
|
89 |
+
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
90 |
+
elif mel_spec_type == "bigvgan":
|
91 |
+
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
|
92 |
+
vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
|
93 |
+
|
94 |
+
# Tokenizer
|
95 |
+
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
96 |
+
|
97 |
+
# Model
|
98 |
+
model = CFM(
|
99 |
+
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
100 |
+
mel_spec_kwargs=dict(
|
101 |
+
n_fft=n_fft,
|
102 |
+
hop_length=hop_length,
|
103 |
+
win_length=win_length,
|
104 |
+
n_mel_channels=n_mel_channels,
|
105 |
+
target_sample_rate=target_sample_rate,
|
106 |
+
mel_spec_type=mel_spec_type,
|
107 |
+
),
|
108 |
+
odeint_kwargs=dict(
|
109 |
+
method=ode_method,
|
110 |
+
),
|
111 |
+
vocab_char_map=vocab_char_map,
|
112 |
+
).to(device)
|
113 |
+
|
114 |
+
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
115 |
+
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
116 |
+
|
117 |
+
# Audio
|
118 |
+
audio, sr = torchaudio.load(audio_to_edit)
|
119 |
+
if audio.shape[0] > 1:
|
120 |
+
audio = torch.mean(audio, dim=0, keepdim=True)
|
121 |
+
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
122 |
+
if rms < target_rms:
|
123 |
+
audio = audio * target_rms / rms
|
124 |
+
if sr != target_sample_rate:
|
125 |
+
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
|
126 |
+
audio = resampler(audio)
|
127 |
+
offset = 0
|
128 |
+
audio_ = torch.zeros(1, 0)
|
129 |
+
edit_mask = torch.zeros(1, 0, dtype=torch.bool)
|
130 |
+
for part in parts_to_edit:
|
131 |
+
start, end = part
|
132 |
+
part_dur = end - start if fix_duration is None else fix_duration.pop(0)
|
133 |
+
part_dur = part_dur * target_sample_rate
|
134 |
+
start = start * target_sample_rate
|
135 |
+
audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
|
136 |
+
edit_mask = torch.cat(
|
137 |
+
(
|
138 |
+
edit_mask,
|
139 |
+
torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
|
140 |
+
torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
|
141 |
+
),
|
142 |
+
dim=-1,
|
143 |
+
)
|
144 |
+
offset = end * target_sample_rate
|
145 |
+
# audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
|
146 |
+
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
|
147 |
+
audio = audio.to(device)
|
148 |
+
edit_mask = edit_mask.to(device)
|
149 |
+
|
150 |
+
# Text
|
151 |
+
text_list = [target_text]
|
152 |
+
if tokenizer == "pinyin":
|
153 |
+
final_text_list = convert_char_to_pinyin(text_list)
|
154 |
+
else:
|
155 |
+
final_text_list = [text_list]
|
156 |
+
print(f"text : {text_list}")
|
157 |
+
print(f"pinyin: {final_text_list}")
|
158 |
+
|
159 |
+
# Duration
|
160 |
+
ref_audio_len = 0
|
161 |
+
duration = audio.shape[-1] // hop_length
|
162 |
+
|
163 |
+
# Inference
|
164 |
+
with torch.inference_mode():
|
165 |
+
generated, trajectory = model.sample(
|
166 |
+
cond=audio,
|
167 |
+
text=final_text_list,
|
168 |
+
duration=duration,
|
169 |
+
steps=nfe_step,
|
170 |
+
cfg_strength=cfg_strength,
|
171 |
+
sway_sampling_coef=sway_sampling_coef,
|
172 |
+
seed=seed,
|
173 |
+
edit_mask=edit_mask,
|
174 |
+
)
|
175 |
+
print(f"Generated mel: {generated.shape}")
|
176 |
+
|
177 |
+
# Final result
|
178 |
+
generated = generated.to(torch.float32)
|
179 |
+
generated = generated[:, ref_audio_len:, :]
|
180 |
+
gen_mel_spec = generated.permute(0, 2, 1)
|
181 |
+
if mel_spec_type == "vocos":
|
182 |
+
generated_wave = vocoder.decode(gen_mel_spec)
|
183 |
+
elif mel_spec_type == "bigvgan":
|
184 |
+
generated_wave = vocoder(gen_mel_spec)
|
185 |
+
|
186 |
+
if rms < target_rms:
|
187 |
+
generated_wave = generated_wave * rms / target_rms
|
188 |
+
|
189 |
+
save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
|
190 |
+
torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
|
191 |
+
print(f"Generated wav: {generated_wave.shape}")
|
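Each edit span above is specified in seconds; internally it is converted to raw samples (multiply by the sample rate) and to mel frames (divide by the hop length) when the frame-level `edit_mask` is built. A small sketch of that arithmetic, reusing the constants defined in this file:

```python
# Sketch of the seconds -> samples -> mel-frames mapping used when building edit_mask.
target_sample_rate = 24000  # as set above
hop_length = 256            # as set above

def span_to_units(start_s, end_s):
    """Return (num_samples, num_mel_frames) covered by a span given in seconds."""
    num_samples = round((end_s - start_s) * target_sample_rate)
    num_frames = round(num_samples / hop_length)
    return num_samples, num_frames

# First edited span [1.42, 2.44] ("nature") covers ~1.02 s of audio
print(span_to_units(1.42, 2.44))  # (24480, 96)
```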
src/f5_tts/infer/src_f5_tts_infer_utils_infer.py
ADDED
@@ -0,0 +1,492 @@
1 |
+
# A unified script for inference process
|
2 |
+
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/")
|
7 |
+
|
8 |
+
import hashlib
|
9 |
+
import re
|
10 |
+
import tempfile
|
11 |
+
from importlib.resources import files
|
12 |
+
|
13 |
+
import matplotlib
|
14 |
+
|
15 |
+
matplotlib.use("Agg")
|
16 |
+
|
17 |
+
import matplotlib.pylab as plt
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import torchaudio
|
21 |
+
import tqdm
|
22 |
+
from pydub import AudioSegment, silence
|
23 |
+
from transformers import pipeline
|
24 |
+
from vocos import Vocos
|
25 |
+
|
26 |
+
from f5_tts.model import CFM
|
27 |
+
from f5_tts.model.utils import (
|
28 |
+
get_tokenizer,
|
29 |
+
convert_char_to_pinyin,
|
30 |
+
)
|
31 |
+
|
32 |
+
_ref_audio_cache = {}
|
33 |
+
|
34 |
+
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
35 |
+
|
36 |
+
# -----------------------------------------
|
37 |
+
|
38 |
+
target_sample_rate = 24000
|
39 |
+
n_mel_channels = 100
|
40 |
+
hop_length = 256
|
41 |
+
win_length = 1024
|
42 |
+
n_fft = 1024
|
43 |
+
mel_spec_type = "vocos"
|
44 |
+
target_rms = 0.1
|
45 |
+
cross_fade_duration = 0.15
|
46 |
+
ode_method = "euler"
|
47 |
+
nfe_step = 32 # 16, 32
|
48 |
+
cfg_strength = 2.0
|
49 |
+
sway_sampling_coef = -1.0
|
50 |
+
speed = 1.0
|
51 |
+
fix_duration = None
|
52 |
+
|
53 |
+
# -----------------------------------------
|
54 |
+
|
55 |
+
|
56 |
+
# chunk text into smaller pieces
|
57 |
+
|
58 |
+
|
59 |
+
def chunk_text(text, max_chars=135):
|
60 |
+
"""
|
61 |
+
Splits the input text into chunks, each with a maximum number of characters.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
text (str): The text to be split.
|
65 |
+
max_chars (int): The maximum number of characters per chunk.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
List[str]: A list of text chunks.
|
69 |
+
"""
|
70 |
+
chunks = []
|
71 |
+
current_chunk = ""
|
72 |
+
# Split the text into sentences based on punctuation followed by whitespace
|
73 |
+
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
|
74 |
+
|
75 |
+
for sentence in sentences:
|
76 |
+
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
|
77 |
+
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
78 |
+
else:
|
79 |
+
if current_chunk:
|
80 |
+
chunks.append(current_chunk.strip())
|
81 |
+
current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
82 |
+
|
83 |
+
if current_chunk:
|
84 |
+
chunks.append(current_chunk.strip())
|
85 |
+
|
86 |
+
return chunks
|
87 |
+
|
88 |
+
|
89 |
+
# load vocoder
|
90 |
+
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device):
|
91 |
+
if vocoder_name == "vocos":
|
92 |
+
if is_local:
|
93 |
+
print(f"Load vocos from local path {local_path}")
|
94 |
+
vocoder = Vocos.from_hparams(f"{local_path}/config.yaml")
|
95 |
+
state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location="cpu")
|
96 |
+
vocoder.load_state_dict(state_dict)
|
97 |
+
vocoder = vocoder.eval().to(device)
|
98 |
+
else:
|
99 |
+
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
|
100 |
+
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
|
101 |
+
elif vocoder_name == "bigvgan":
|
102 |
+
try:
|
103 |
+
from third_party.BigVGAN import bigvgan
|
104 |
+
except ImportError:
|
105 |
+
print("You need to follow the README to init submodule and change the BigVGAN source code.")
|
106 |
+
if is_local:
|
107 |
+
"""download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
|
108 |
+
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
109 |
+
else:
|
110 |
+
vocoder = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
|
111 |
+
|
112 |
+
vocoder.remove_weight_norm()
|
113 |
+
vocoder = vocoder.eval().to(device)
|
114 |
+
return vocoder
|
115 |
+
|
116 |
+
|
117 |
+
# load asr pipeline
|
118 |
+
|
119 |
+
asr_pipe = None
|
120 |
+
|
121 |
+
|
122 |
+
def initialize_asr_pipeline(device=device, dtype=None):
|
123 |
+
if dtype is None:
|
124 |
+
dtype = (
|
125 |
+
torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
|
126 |
+
)
|
127 |
+
global asr_pipe
|
128 |
+
asr_pipe = pipeline(
|
129 |
+
"automatic-speech-recognition",
|
130 |
+
model="openai/whisper-large-v3-turbo",
|
131 |
+
torch_dtype=dtype,
|
132 |
+
device=device,
|
133 |
+
)
|
134 |
+
|
135 |
+
|
136 |
+
# load model checkpoint for inference
|
137 |
+
|
138 |
+
|
139 |
+
def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
|
140 |
+
if dtype is None:
|
141 |
+
dtype = (
|
142 |
+
torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
|
143 |
+
)
|
144 |
+
model = model.to(dtype)
|
145 |
+
|
146 |
+
ckpt_type = ckpt_path.split(".")[-1]
|
147 |
+
if ckpt_type == "safetensors":
|
148 |
+
from safetensors.torch import load_file
|
149 |
+
|
150 |
+
checkpoint = load_file(ckpt_path)
|
151 |
+
else:
|
152 |
+
checkpoint = torch.load(ckpt_path, weights_only=True)
|
153 |
+
|
154 |
+
if use_ema:
|
155 |
+
if ckpt_type == "safetensors":
|
156 |
+
checkpoint = {"ema_model_state_dict": checkpoint}
|
157 |
+
checkpoint["model_state_dict"] = {
|
158 |
+
k.replace("ema_model.", ""): v
|
159 |
+
for k, v in checkpoint["ema_model_state_dict"].items()
|
160 |
+
if k not in ["initted", "step"]
|
161 |
+
}
|
162 |
+
|
163 |
+
# patch for backward compatibility, 305e3ea
|
164 |
+
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
|
165 |
+
if key in checkpoint["model_state_dict"]:
|
166 |
+
del checkpoint["model_state_dict"][key]
|
167 |
+
|
168 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
169 |
+
else:
|
170 |
+
if ckpt_type == "safetensors":
|
171 |
+
checkpoint = {"model_state_dict": checkpoint}
|
172 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
173 |
+
|
174 |
+
return model.to(device)
|
175 |
+
|
176 |
+
|
177 |
+
# load model for inference
|
178 |
+
|
179 |
+
|
180 |
+
def load_model(
|
181 |
+
model_cls,
|
182 |
+
model_cfg,
|
183 |
+
ckpt_path,
|
184 |
+
mel_spec_type=mel_spec_type,
|
185 |
+
vocab_file="",
|
186 |
+
ode_method=ode_method,
|
187 |
+
use_ema=True,
|
188 |
+
device=device,
|
189 |
+
):
|
190 |
+
if vocab_file == "":
|
191 |
+
vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
|
192 |
+
tokenizer = "custom"
|
193 |
+
|
194 |
+
print("\nvocab : ", vocab_file)
|
195 |
+
print("tokenizer : ", tokenizer)
|
196 |
+
print("model : ", ckpt_path, "\n")
|
197 |
+
|
198 |
+
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
|
199 |
+
model = CFM(
|
200 |
+
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
201 |
+
mel_spec_kwargs=dict(
|
202 |
+
n_fft=n_fft,
|
203 |
+
hop_length=hop_length,
|
204 |
+
win_length=win_length,
|
205 |
+
n_mel_channels=n_mel_channels,
|
206 |
+
target_sample_rate=target_sample_rate,
|
207 |
+
mel_spec_type=mel_spec_type,
|
208 |
+
),
|
209 |
+
odeint_kwargs=dict(
|
210 |
+
method=ode_method,
|
211 |
+
),
|
212 |
+
vocab_char_map=vocab_char_map,
|
213 |
+
).to(device)
|
214 |
+
|
215 |
+
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
216 |
+
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
217 |
+
|
218 |
+
return model
|
219 |
+
|
220 |
+
|
221 |
+
# preprocess reference audio and text
|
222 |
+
|
223 |
+
|
224 |
+
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
|
225 |
+
show_info("Converting audio...")
|
226 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
227 |
+
aseg = AudioSegment.from_file(ref_audio_orig)
|
228 |
+
|
229 |
+
if clip_short:
|
230 |
+
# 1. try to find long silence for clipping
|
231 |
+
non_silent_segs = silence.split_on_silence(
|
232 |
+
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
|
233 |
+
)
|
234 |
+
non_silent_wave = AudioSegment.silent(duration=0)
|
235 |
+
for non_silent_seg in non_silent_segs:
|
236 |
+
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
|
237 |
+
show_info("Audio is over 15s, clipping short. (1)")
|
238 |
+
break
|
239 |
+
non_silent_wave += non_silent_seg
|
240 |
+
|
241 |
+
# 2. try to find short silence for clipping if 1. failed
|
242 |
+
if len(non_silent_wave) > 15000:
|
243 |
+
non_silent_segs = silence.split_on_silence(
|
244 |
+
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000
|
245 |
+
)
|
246 |
+
non_silent_wave = AudioSegment.silent(duration=0)
|
247 |
+
for non_silent_seg in non_silent_segs:
|
248 |
+
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
|
249 |
+
show_info("Audio is over 15s, clipping short. (2)")
|
250 |
+
break
|
251 |
+
non_silent_wave += non_silent_seg
|
252 |
+
|
253 |
+
aseg = non_silent_wave
|
254 |
+
|
255 |
+
# 3. if no proper silence found for clipping
|
256 |
+
if len(aseg) > 15000:
|
257 |
+
aseg = aseg[:15000]
|
258 |
+
show_info("Audio is over 15s, clipping short. (3)")
|
259 |
+
|
260 |
+
aseg.export(f.name, format="wav")
|
261 |
+
ref_audio = f.name
|
262 |
+
|
263 |
+
# Compute a hash of the reference audio file
|
264 |
+
with open(ref_audio, "rb") as audio_file:
|
265 |
+
audio_data = audio_file.read()
|
266 |
+
audio_hash = hashlib.md5(audio_data).hexdigest()
|
267 |
+
|
268 |
+
global _ref_audio_cache
|
269 |
+
if audio_hash in _ref_audio_cache:
|
270 |
+
# Use cached reference text
|
271 |
+
show_info("Using cached reference text...")
|
272 |
+
ref_text = _ref_audio_cache[audio_hash]
|
273 |
+
else:
|
274 |
+
if not ref_text.strip():
|
275 |
+
global asr_pipe
|
276 |
+
if asr_pipe is None:
|
277 |
+
initialize_asr_pipeline(device=device)
|
278 |
+
show_info("No reference text provided, transcribing reference audio...")
|
279 |
+
ref_text = asr_pipe(
|
280 |
+
ref_audio,
|
281 |
+
chunk_length_s=30,
|
282 |
+
batch_size=128,
|
283 |
+
generate_kwargs={"task": "transcribe"},
|
284 |
+
return_timestamps=False,
|
285 |
+
)["text"].strip()
|
286 |
+
show_info("Finished transcription")
|
287 |
+
else:
|
288 |
+
show_info("Using custom reference text...")
|
289 |
+
# Cache the transcribed text
|
290 |
+
_ref_audio_cache[audio_hash] = ref_text
|
291 |
+
|
292 |
+
# Ensure ref_text ends with a proper sentence-ending punctuation
|
293 |
+
if not ref_text.endswith(". ") and not ref_text.endswith("。"):
|
294 |
+
if ref_text.endswith("."):
|
295 |
+
ref_text += " "
|
296 |
+
else:
|
297 |
+
ref_text += ". "
|
298 |
+
|
299 |
+
return ref_audio, ref_text
|
300 |
+
|
301 |
+
|
302 |
+
# infer process: chunk text -> infer batches [i.e. infer_batch_process()]
|
303 |
+
|
304 |
+
|
305 |
+
def infer_process(
|
306 |
+
ref_audio,
|
307 |
+
ref_text,
|
308 |
+
gen_text,
|
309 |
+
model_obj,
|
310 |
+
vocoder,
|
311 |
+
mel_spec_type=mel_spec_type,
|
312 |
+
show_info=print,
|
313 |
+
progress=tqdm,
|
314 |
+
target_rms=target_rms,
|
315 |
+
cross_fade_duration=cross_fade_duration,
|
316 |
+
nfe_step=nfe_step,
|
317 |
+
cfg_strength=cfg_strength,
|
318 |
+
sway_sampling_coef=sway_sampling_coef,
|
319 |
+
speed=speed,
|
320 |
+
fix_duration=fix_duration,
|
321 |
+
device=device,
|
322 |
+
):
|
323 |
+
# Split the input text into batches
|
324 |
+
audio, sr = torchaudio.load(ref_audio)
|
325 |
+
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
|
326 |
+
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
327 |
+
for i, gen_text in enumerate(gen_text_batches):
|
328 |
+
print(f"gen_text {i}", gen_text)
|
329 |
+
|
330 |
+
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
|
331 |
+
return infer_batch_process(
|
332 |
+
(audio, sr),
|
333 |
+
ref_text,
|
334 |
+
gen_text_batches,
|
335 |
+
model_obj,
|
336 |
+
vocoder,
|
337 |
+
mel_spec_type=mel_spec_type,
|
338 |
+
progress=progress,
|
339 |
+
target_rms=target_rms,
|
340 |
+
cross_fade_duration=cross_fade_duration,
|
341 |
+
nfe_step=nfe_step,
|
342 |
+
cfg_strength=cfg_strength,
|
343 |
+
sway_sampling_coef=sway_sampling_coef,
|
344 |
+
speed=speed,
|
345 |
+
fix_duration=fix_duration,
|
346 |
+
device=device,
|
347 |
+
)
|
348 |
+
|
349 |
+
|
350 |
+
# infer batches
|
351 |
+
|
352 |
+
|
353 |
+
def infer_batch_process(
|
354 |
+
ref_audio,
|
355 |
+
ref_text,
|
356 |
+
gen_text_batches,
|
357 |
+
model_obj,
|
358 |
+
vocoder,
|
359 |
+
mel_spec_type="vocos",
|
360 |
+
progress=tqdm,
|
361 |
+
target_rms=0.1,
|
362 |
+
cross_fade_duration=0.15,
|
363 |
+
nfe_step=32,
|
364 |
+
cfg_strength=2.0,
|
365 |
+
sway_sampling_coef=-1,
|
366 |
+
speed=1,
|
367 |
+
fix_duration=None,
|
368 |
+
device=None,
|
369 |
+
):
|
370 |
+
audio, sr = ref_audio
|
371 |
+
if audio.shape[0] > 1:
|
372 |
+
audio = torch.mean(audio, dim=0, keepdim=True)
|
373 |
+
|
374 |
+
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
375 |
+
if rms < target_rms:
|
376 |
+
audio = audio * target_rms / rms
|
377 |
+
if sr != target_sample_rate:
|
378 |
+
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
|
379 |
+
audio = resampler(audio)
|
380 |
+
audio = audio.to(device)
|
381 |
+
|
382 |
+
generated_waves = []
|
383 |
+
spectrograms = []
|
384 |
+
|
385 |
+
if len(ref_text[-1].encode("utf-8")) == 1:
|
386 |
+
ref_text = ref_text + " "
|
387 |
+
for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
|
388 |
+
# Prepare the text
|
389 |
+
text_list = [ref_text + gen_text]
|
390 |
+
final_text_list = convert_char_to_pinyin(text_list)
|
391 |
+
|
392 |
+
ref_audio_len = audio.shape[-1] // hop_length
|
393 |
+
if fix_duration is not None:
|
394 |
+
duration = int(fix_duration * target_sample_rate / hop_length)
|
395 |
+
else:
|
396 |
+
# Calculate duration
|
397 |
+
ref_text_len = len(ref_text.encode("utf-8"))
|
398 |
+
gen_text_len = len(gen_text.encode("utf-8"))
|
399 |
+
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
400 |
+
|
401 |
+
# inference
|
402 |
+
with torch.inference_mode():
|
403 |
+
generated, _ = model_obj.sample(
|
404 |
+
cond=audio,
|
405 |
+
text=final_text_list,
|
406 |
+
duration=duration,
|
407 |
+
steps=nfe_step,
|
408 |
+
cfg_strength=cfg_strength,
|
409 |
+
sway_sampling_coef=sway_sampling_coef,
|
410 |
+
)
|
411 |
+
|
412 |
+
generated = generated.to(torch.float32)
|
413 |
+
generated = generated[:, ref_audio_len:, :]
|
414 |
+
generated_mel_spec = generated.permute(0, 2, 1)
|
415 |
+
if mel_spec_type == "vocos":
|
416 |
+
generated_wave = vocoder.decode(generated_mel_spec)
|
417 |
+
elif mel_spec_type == "bigvgan":
|
418 |
+
generated_wave = vocoder(generated_mel_spec)
|
419 |
+
if rms < target_rms:
|
420 |
+
generated_wave = generated_wave * rms / target_rms
|
421 |
+
|
422 |
+
# wav -> numpy
|
423 |
+
generated_wave = generated_wave.squeeze().cpu().numpy()
|
424 |
+
|
425 |
+
generated_waves.append(generated_wave)
|
426 |
+
spectrograms.append(generated_mel_spec[0].cpu().numpy())
|
427 |
+
|
428 |
+
# Combine all generated waves with cross-fading
|
429 |
+
if cross_fade_duration <= 0:
|
430 |
+
# Simply concatenate
|
431 |
+
final_wave = np.concatenate(generated_waves)
|
432 |
+
else:
|
433 |
+
final_wave = generated_waves[0]
|
434 |
+
for i in range(1, len(generated_waves)):
|
435 |
+
prev_wave = final_wave
|
436 |
+
next_wave = generated_waves[i]
|
437 |
+
|
438 |
+
# Calculate cross-fade samples, ensuring it does not exceed wave lengths
|
439 |
+
cross_fade_samples = int(cross_fade_duration * target_sample_rate)
|
440 |
+
cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
|
441 |
+
|
442 |
+
if cross_fade_samples <= 0:
|
443 |
+
# No overlap possible, concatenate
|
444 |
+
final_wave = np.concatenate([prev_wave, next_wave])
|
445 |
+
continue
|
446 |
+
|
447 |
+
# Overlapping parts
|
448 |
+
prev_overlap = prev_wave[-cross_fade_samples:]
|
449 |
+
next_overlap = next_wave[:cross_fade_samples]
|
450 |
+
|
451 |
+
# Fade out and fade in
|
452 |
+
fade_out = np.linspace(1, 0, cross_fade_samples)
|
453 |
+
fade_in = np.linspace(0, 1, cross_fade_samples)
|
454 |
+
|
455 |
+
# Cross-faded overlap
|
456 |
+
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
|
457 |
+
|
458 |
+
# Combine
|
459 |
+
new_wave = np.concatenate(
|
460 |
+
[prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
|
461 |
+
)
|
462 |
+
|
463 |
+
final_wave = new_wave
|
464 |
+
|
465 |
+
# Create a combined spectrogram
|
466 |
+
combined_spectrogram = np.concatenate(spectrograms, axis=1)
|
467 |
+
|
468 |
+
return final_wave, target_sample_rate, combined_spectrogram
|
469 |
+
|
470 |
+
|
471 |
+
# remove silence from generated wav
|
472 |
+
|
473 |
+
|
474 |
+
def remove_silence_for_generated_wav(filename):
|
475 |
+
aseg = AudioSegment.from_file(filename)
|
476 |
+
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
|
477 |
+
non_silent_wave = AudioSegment.silent(duration=0)
|
478 |
+
for non_silent_seg in non_silent_segs:
|
479 |
+
non_silent_wave += non_silent_seg
|
480 |
+
aseg = non_silent_wave
|
481 |
+
aseg.export(filename, format="wav")
|
482 |
+
|
483 |
+
|
484 |
+
# save spectrogram
|
485 |
+
|
486 |
+
|
487 |
+
def save_spectrogram(spectrogram, path):
|
488 |
+
plt.figure(figsize=(12, 4))
|
489 |
+
plt.imshow(spectrogram, origin="lower", aspect="auto")
|
490 |
+
plt.colorbar()
|
491 |
+
plt.savefig(path)
|
492 |
+
plt.close()
|
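Two byte-length heuristics in this module are easy to miss: `infer_process` caps each text chunk relative to the reference speaking rate, and `infer_batch_process` estimates the output mel length from the ref/gen text byte ratio. A worked sketch of both, using made-up but representative numbers:

```python
# Worked example of the two heuristics above (illustrative numbers only).
target_sample_rate = 24000
hop_length = 256
speed = 1.0

# Suppose the reference clip is 6 s long and its transcript is 90 bytes of UTF-8.
ref_audio_sec = 6.0
ref_text_len = 90  # len(ref_text.encode("utf-8"))

# infer_process: max characters per generated chunk, scaled by the remaining
# budget of a ~25 s total clip at the reference speaking rate.
max_chars = int(ref_text_len / ref_audio_sec * (25 - ref_audio_sec))
print(max_chars)  # 285

# infer_batch_process: estimated total mel frames for a 120-byte generated chunk.
ref_audio_len = int(ref_audio_sec * target_sample_rate) // hop_length  # 562 ref frames
gen_text_len = 120
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
print(duration)  # 1311 frames ~= 6 s (reference) + ~8 s (generated) of audio
```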
src/f5_tts/src_f5_tts_api.py
ADDED
@@ -0,0 +1,151 @@
1 |
+
import random
|
2 |
+
import sys
|
3 |
+
from importlib.resources import files
|
4 |
+
|
5 |
+
import soundfile as sf
|
6 |
+
import torch
|
7 |
+
import tqdm
|
8 |
+
from cached_path import cached_path
|
9 |
+
|
10 |
+
from f5_tts.infer.utils_infer import (
|
11 |
+
hop_length,
|
12 |
+
infer_process,
|
13 |
+
load_model,
|
14 |
+
load_vocoder,
|
15 |
+
preprocess_ref_audio_text,
|
16 |
+
remove_silence_for_generated_wav,
|
17 |
+
save_spectrogram,
|
18 |
+
target_sample_rate,
|
19 |
+
)
|
20 |
+
from f5_tts.model import DiT, UNetT
|
21 |
+
from f5_tts.model.utils import seed_everything
|
22 |
+
|
23 |
+
|
24 |
+
class F5TTS:
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
model_type="F5-TTS",
|
28 |
+
ckpt_file="",
|
29 |
+
vocab_file="",
|
30 |
+
ode_method="euler",
|
31 |
+
use_ema=True,
|
32 |
+
vocoder_name="vocos",
|
33 |
+
local_path=None,
|
34 |
+
device=None,
|
35 |
+
):
|
36 |
+
# Initialize parameters
|
37 |
+
self.final_wave = None
|
38 |
+
self.target_sample_rate = target_sample_rate
|
39 |
+
self.hop_length = hop_length
|
40 |
+
self.seed = -1
|
41 |
+
self.mel_spec_type = vocoder_name
|
42 |
+
|
43 |
+
# Set device
|
44 |
+
self.device = device or (
|
45 |
+
"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
46 |
+
)
|
47 |
+
|
48 |
+
# Load models
|
49 |
+
self.load_vocoder_model(vocoder_name, local_path)
|
50 |
+
self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema)
|
51 |
+
|
52 |
+
def load_vocoder_model(self, vocoder_name, local_path):
|
53 |
+
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
|
54 |
+
|
55 |
+
def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema):
|
56 |
+
if model_type == "F5-TTS":
|
57 |
+
if not ckpt_file:
|
58 |
+
if mel_spec_type == "vocos":
|
59 |
+
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
|
60 |
+
elif mel_spec_type == "bigvgan":
|
61 |
+
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"))
|
62 |
+
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
63 |
+
model_cls = DiT
|
64 |
+
elif model_type == "E2-TTS":
|
65 |
+
if not ckpt_file:
|
66 |
+
ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
|
67 |
+
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
68 |
+
model_cls = UNetT
|
69 |
+
else:
|
70 |
+
raise ValueError(f"Unknown model type: {model_type}")
|
71 |
+
|
72 |
+
self.ema_model = load_model(
|
73 |
+
model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
|
74 |
+
)
|
75 |
+
|
76 |
+
def export_wav(self, wav, file_wave, remove_silence=False):
|
77 |
+
sf.write(file_wave, wav, self.target_sample_rate)
|
78 |
+
|
79 |
+
if remove_silence:
|
80 |
+
remove_silence_for_generated_wav(file_wave)
|
81 |
+
|
82 |
+
def export_spectrogram(self, spect, file_spect):
|
83 |
+
save_spectrogram(spect, file_spect)
|
84 |
+
|
85 |
+
def infer(
|
86 |
+
self,
|
87 |
+
ref_file,
|
88 |
+
ref_text,
|
89 |
+
gen_text,
|
90 |
+
show_info=print,
|
91 |
+
progress=tqdm,
|
92 |
+
target_rms=0.1,
|
93 |
+
cross_fade_duration=0.15,
|
94 |
+
sway_sampling_coef=-1,
|
95 |
+
cfg_strength=2,
|
96 |
+
nfe_step=32,
|
97 |
+
speed=1.0,
|
98 |
+
fix_duration=None,
|
99 |
+
remove_silence=False,
|
100 |
+
file_wave=None,
|
101 |
+
file_spect=None,
|
102 |
+
seed=-1,
|
103 |
+
):
|
104 |
+
if seed == -1:
|
105 |
+
seed = random.randint(0, sys.maxsize)
|
106 |
+
seed_everything(seed)
|
107 |
+
self.seed = seed
|
108 |
+
|
109 |
+
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
|
110 |
+
|
111 |
+
wav, sr, spect = infer_process(
|
112 |
+
ref_file,
|
113 |
+
ref_text,
|
114 |
+
gen_text,
|
115 |
+
self.ema_model,
|
116 |
+
self.vocoder,
|
117 |
+
self.mel_spec_type,
|
118 |
+
show_info=show_info,
|
119 |
+
progress=progress,
|
120 |
+
target_rms=target_rms,
|
121 |
+
cross_fade_duration=cross_fade_duration,
|
122 |
+
nfe_step=nfe_step,
|
123 |
+
cfg_strength=cfg_strength,
|
124 |
+
sway_sampling_coef=sway_sampling_coef,
|
125 |
+
speed=speed,
|
126 |
+
fix_duration=fix_duration,
|
127 |
+
device=self.device,
|
128 |
+
)
|
129 |
+
|
130 |
+
if file_wave is not None:
|
131 |
+
self.export_wav(wav, file_wave, remove_silence)
|
132 |
+
|
133 |
+
if file_spect is not None:
|
134 |
+
self.export_spectrogram(spect, file_spect)
|
135 |
+
|
136 |
+
return wav, sr, spect
|
137 |
+
|
138 |
+
|
139 |
+
if __name__ == "__main__":
|
140 |
+
f5tts = F5TTS()
|
141 |
+
|
142 |
+
wav, sr, spect = f5tts.infer(
|
143 |
+
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
|
144 |
+
ref_text="some call me nature, others call me mother nature.",
|
145 |
+
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
|
146 |
+
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
|
147 |
+
file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
|
148 |
+
seed=-1, # random seed = -1
|
149 |
+
)
|
150 |
+
|
151 |
+
print("seed :", f5tts.seed)
|
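Beyond the file-based `__main__` example above, the API also works purely in memory: `infer` returns the waveform, sample rate, and mel spectrogram, and a fixed non-negative `seed` makes runs reproducible. A brief usage sketch, assuming the package is installed so the class is importable as `f5_tts.api.F5TTS` (the reference path is a placeholder):

```python
# In-memory usage sketch for the F5TTS class.
import soundfile as sf
from f5_tts.api import F5TTS

tts = F5TTS()  # downloads the default F5-TTS checkpoint and vocos vocoder on first use
wav, sr, _spect = tts.infer(
    ref_file="my_reference.wav",   # placeholder path to a short reference clip
    ref_text="",                   # empty -> the reference audio is transcribed automatically
    gen_text="Hello from the in-memory API.",
    seed=1234,                     # fixed seed for reproducible output
)
sf.write("out.wav", wav, sr)
print("used seed:", tts.seed)
```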