import os
import re
from datetime import datetime
from importlib.resources import files
from pathlib import Path

import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import (
    mel_spec_type,
    target_rms,
    cross_fade_duration,
    nfe_step,
    cfg_strength,
    sway_sampling_coef,
    speed,
    fix_duration,
    device,
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
)
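
# NOTE: besides helper functions, utils_infer also exports tunable defaults
# (mel_spec_type, target_rms, ..., device); those names are rebound below from
# the TOML config wherever the corresponding keys are present.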

# ── USER CONFIG ────────────────────────────────────────────────────────────────
config_path = "infer/examples/basic/basic.toml"
model = "F5TTS_v1_Base"
model_cfg_path = None  # e.g. "path/to/your/model.yaml", or leave None to use default from config
ckpt_file = ""  # leave blank to pull from the Hugging Face cache
vocab_file = ""  # leave blank to use default
ref_audio = "data/15sec.wav"
ref_text = (
    "Fuck your phone. Stop texting all the time. "
    "Look up from your phone and breathe. Release yourself."
)
gen_text = (
    "I am not feeling it. This is it. There is no reconceptualizing."
)
gen_file = ""  # if set, overrides gen_text with this file's contents
output_dir = "tests"
output_file = f"infer_cli_{datetime.now():%Y%m%d_%H%M%S}.wav"
save_chunk = False
remove_silence = False
load_vocoder_from_local = False
vocoder_name = None  # "vocos" or "bigvgan", or None to use default from config
# ────────────────────────────────────────────────────────────────────────────────

# load config (resolving packaged example paths the same way ref_audio is
# resolved below, so the default TOML works outside the package directory)
if "infer/examples/" in config_path:
    config_path = str(files("f5_tts").joinpath(config_path))
with open(config_path, "rb") as f:
    config = tomli.load(f)
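
# A sketch of the TOML shape this script reads -- only keys looked up via
# config.get() below matter here; the values shown are illustrative, not
# shipped defaults:
#
#   vocoder_name = "vocos"
#   remove_silence = false
#   nfe_step = 32
#   speed = 1.0
#
#   [voices.town]            # optional extra voices, selected via "[town]" tags
#   ref_audio = "path/to/town_voice.wav"
#   ref_text = ""
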
# resolve parameters (fall back to config defaults where applicable)
model_cfg_path = model_cfg_path or config.get("model_cfg", None)
ckpt_file = ckpt_file or config.get("ckpt_file", "")
vocab_file = vocab_file or config.get("vocab_file", "")
gen_file = gen_file or config.get("gen_file", "")
save_chunk = save_chunk or config.get("save_chunk", False)
remove_silence = remove_silence or config.get("remove_silence", False)
load_vocoder_from_local = load_vocoder_from_local or config.get("load_vocoder_from_local", False)
vocoder_name = vocoder_name or config.get("vocoder_name", mel_spec_type)
target_rms = config.get("target_rms", target_rms)
cross_fade_duration = config.get("cross_fade_duration", cross_fade_duration)
nfe_step = config.get("nfe_step", nfe_step)
cfg_strength = config.get("cfg_strength", cfg_strength)
sway_sampling_coef = config.get("sway_sampling_coef", sway_sampling_coef)
speed = config.get("speed", speed)
fix_duration = config.get("fix_duration", fix_duration)
device = config.get("device", device)
# if user pointed at example paths inside the package, fix them
if "infer/examples/" in ref_audio:
    ref_audio = str(files("f5_tts").joinpath(ref_audio))
if gen_file and "infer/examples/" in gen_file:
    gen_file = str(files("f5_tts").joinpath(gen_file))
if "voices" in config:
    for v in config["voices"].values():
        if "infer/examples/" in v.get("ref_audio", ""):
            v["ref_audio"] = str(files("f5_tts").joinpath(v["ref_audio"]))
# if using a gen_file, load its text (overrides gen_text)
if gen_file:
    gen_text = Path(gen_file).read_text(encoding="utf-8")
# output paths (including the optional chunk directory) are prepared per call
# inside generate_tts() below, so they always track that call's output file

# load vocoder; the local checkpoint paths are only consulted when
# load_vocoder_from_local is True, otherwise the vocoder is fetched remotely
if vocoder_name == "vocos":
    vocoder_local_path = "../checkpoints/vocos-mel-24khz"
elif vocoder_name == "bigvgan":
    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
else:
    vocoder_local_path = None
vocoder = load_vocoder(
    vocoder_name=vocoder_name,
    is_local=load_vocoder_from_local,
    local_path=vocoder_local_path,
    device=device,
)

# load TTS model
model_cfg = OmegaConf.load(
    model_cfg_path
    or str(files("f5_tts").joinpath(f"configs/{model}.yaml"))
)
ModelClass = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
# sanity check: the mel spec the model was trained with should match the vocoder
if model_cfg.model.mel_spec.mel_spec_type != vocoder_name:
    print(
        f"Warning: model expects mel_spec_type "
        f"'{model_cfg.model.mel_spec.mel_spec_type}' but vocoder_name is '{vocoder_name}'."
    )
# pick the HF repo and checkpoint matching the chosen model / vocoder combination
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
if model == "F5TTS_Base":
    if vocoder_name == "vocos":
        ckpt_step = 1200000
    else:
        model = "F5TTS_Base_bigvgan"
        ckpt_type = "pt"
elif model == "E2TTS_Base":
    repo_name, ckpt_step = "E2-TTS", 1200000
if not ckpt_file:
    ckpt_file = str(
        cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}")
    )
print(f"Loading model {model} checkpoint...")
ema_model = load_model(
    ModelClass,
    model_cfg.model.arch,
    ckpt_file,
    mel_spec_type=vocoder_name,
    vocab_file=vocab_file,
    device=device,
)

def generate_tts(input_text, output_dir="tests", output_file=None, ref_audio=ref_audio, ref_text=None):
    """
    Generate text-to-speech audio from input text.

    Args:
        input_text (str): Text to convert to speech.
        output_dir (str): Directory for the output file (default: "tests").
        output_file (str): Output filename (default: auto-generated from timestamp).
        ref_audio (str): Reference audio file (default: module-level "data/15sec.wav").
        ref_text (str): Reference transcript (default: module-level reference text).

    Returns:
        str: Path to the generated audio file.
    """
    if ref_text is None:
        ref_text = (
            "Fuck your phone. Stop texting all the time. "
            "Look up from your phone and breathe. Release yourself."
        )
    gen_text = input_text
    if output_file is None:
        output_file = f"infer_cli_{datetime.now():%Y%m%d_%H%M%S}.wav"
    # create the chunk directory here so it tracks this call's output filename
    if save_chunk:
        chunk_dir = Path(output_dir) / f"{Path(output_file).stem}_chunks"
        chunk_dir.mkdir(parents=True, exist_ok=True)
    # assemble voices dict; "main" always wins over a config entry of the same name
    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    voices = {"main": main_voice}
    if "voices" in config:
        voices.update(config["voices"])
        voices["main"] = main_voice
    # preprocess all references (clips/normalizes the audio; an empty ref_text
    # is filled in automatically by preprocess_ref_audio_text)
    for v in voices.values():
        v["ref_audio"], v["ref_text"] = preprocess_ref_audio_text(
            v["ref_audio"], v["ref_text"]
        )
    # break text into per-voice chunks; voice tags look like "[voice_name]"
    reg1 = r"(?=\[\w+\])"  # zero-width split: each tag starts a new chunk
    reg2 = r"\[(\w+)\]"  # captures the voice name from a chunk's leading tag
    chunks = re.split(reg1, gen_text)
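    # e.g. "[main] Hello. [town] Hi there." splits into
    # ["", "[main] Hello. ", "[town] Hi there."]; empty chunks are skipped
    # below, and untagged chunks default to the "main" voice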
    segments = []
    for chunk in chunks:
        txt = chunk.strip()
        if not txt:
            continue
        m = re.match(reg2, txt)
        if m:
            voice = m.group(1)
            txt = re.sub(reg2, "", txt).strip()
        else:
            voice = "main"
        if voice not in voices:
            print(f"Unknown voice '{voice}', using main.")
            voice = "main"
        seg, sr, _ = infer_process(
            voices[voice]["ref_audio"],
            voices[voice]["ref_text"],
            txt,
            ema_model,
            vocoder,
            mel_spec_type=vocoder_name,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=device,
        )
        segments.append(seg)
        if save_chunk:
            chunk_name = txt[:200].replace(" ", "_")
            sf.write(str(chunk_dir / f"{len(segments)-1}_{chunk_name}.wav"), seg, sr)
    # concatenate and write; fail loudly if nothing was generated (the sample
    # rate sr would be undefined otherwise)
    if not segments:
        raise RuntimeError("No audio generated: gen_text contained no non-empty chunks.")
    final = np.concatenate(segments)
    os.makedirs(output_dir, exist_ok=True)
    wave_path = Path(output_dir) / output_file
    sf.write(str(wave_path), final, sr)
    if remove_silence:
        remove_silence_for_generated_wav(str(wave_path))
    print(f"Written output to {wave_path}")
    return str(wave_path)

if __name__ == "__main__":
    test_text = "This is a test of the TTS system."
    generated_file = generate_tts(test_text)
    print(f"Generated file: {generated_file}")