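# NOTE(review): the lines below look like file-viewer page residue captured
# with this source (author, branch, commit, viewer buttons, file size). They
# are kept as comments so the module parses.
# 0xrushi
# rest
# ef09716
# raw
# history blame contribute delete
# 8.29 kB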
import codecs
import os
import re
from datetime import datetime
from importlib.resources import files
from pathlib import Path
import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
mel_spec_type,
target_rms,
cross_fade_duration,
nfe_step,
cfg_strength,
sway_sampling_coef,
speed,
fix_duration,
device,
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
)
# ── USER CONFIG ────────────────────────────────────────────────────────────────
# Path/flag settings left empty, None or False here are filled in from the TOML
# config loaded below (each is resolved as `value or config.get(...)`).

# Config file and model selection.
config_path = "infer/examples/basic/basic.toml"
model = "F5TTS_v1_Base"
model_cfg_path = None  # e.g. "path/to/your/model.yaml"; None -> config entry / packaged default
ckpt_file = ""  # empty -> fetch the checkpoint through cached_path (HF cache)
vocab_file = ""  # empty -> default vocabulary

# Reference voice: an audio clip plus its transcript.
ref_audio = "data/15sec.wav"
ref_text = (
    "Fuck your phone. Stop texting all the time. "
    "Look up from your phone and breathe. Release yourself."
)

# Text to synthesize. A non-empty gen_file replaces gen_text with the file's contents.
gen_text = "I am not feeling it. This is it. There is no reconceptualizing."
gen_file = ""

# Output location. The default filename is timestamped once, at import time.
output_dir = "tests"
output_file = f"infer_cli_{datetime.now():%Y%m%d_%H%M%S}.wav"

# Feature toggles.
save_chunk = False
remove_silence = False
load_vocoder_from_local = False
vocoder_name = None  # "vocos", "bigvgan", or None -> config entry / utils_infer default
# ────────────────────────────────────────────────────────────────────────────────
# Load the TOML config. The with-block closes the file handle; the previous
# bare open() left it open until garbage collection.
with open(config_path, "rb") as _config_file:
    config = tomli.load(_config_file)

# Resolve parameters.
# Path/flag settings: a truthy value from the USER CONFIG section wins, and
# anything falsy falls back to the TOML entry.
model_cfg_path = model_cfg_path or config.get("model_cfg", None)
ckpt_file = ckpt_file or config.get("ckpt_file", "")
vocab_file = vocab_file or config.get("vocab_file", "")
gen_file = gen_file or config.get("gen_file", "")
save_chunk = save_chunk or config.get("save_chunk", False)
remove_silence = remove_silence or config.get("remove_silence", False)
load_vocoder_from_local = load_vocoder_from_local or config.get("load_vocoder_from_local", False)
# When neither source names a vocoder, use utils_infer's mel_spec_type default.
vocoder_name = vocoder_name or config.get("vocoder_name", mel_spec_type)

# Inference parameters: the TOML value wins when present; otherwise keep the
# defaults imported from f5_tts.infer.utils_infer.
target_rms = config.get("target_rms", target_rms)
cross_fade_duration = config.get("cross_fade_duration", cross_fade_duration)
nfe_step = config.get("nfe_step", nfe_step)
cfg_strength = config.get("cfg_strength", cfg_strength)
sway_sampling_coef = config.get("sway_sampling_coef", sway_sampling_coef)
speed = config.get("speed", speed)
fix_duration = config.get("fix_duration", fix_duration)
device = config.get("device", device)
# Paths containing "infer/examples/" refer to files shipped inside the f5_tts
# package, so resolve them against the installed package location rather than
# the current working directory.
if "infer/examples/" in ref_audio:
    ref_audio = str(files("f5_tts").joinpath(ref_audio))
if gen_file and "infer/examples/" in gen_file:
    gen_file = str(files("f5_tts").joinpath(gen_file))
if "voices" in config:
    for v in config["voices"].values():
        if "infer/examples/" in v.get("ref_audio", ""):
            v["ref_audio"] = str(files("f5_tts").joinpath(v["ref_audio"]))

# When a gen_file is configured, its contents replace gen_text.
# - newline="" leaves line endings untranslated, matching the previous
#   codecs.open() read, which always used binary mode underneath.
# - The with-block closes the handle; codecs.open(...).read() leaked it.
if gen_file:
    with open(gen_file, "r", encoding="utf-8", newline="") as _gen_fh:
        gen_text = _gen_fh.read()
# Output paths for the script-level run. When chunk saving is enabled,
# per-segment WAVs go to "<output_dir>/<output stem>_chunks", which is created
# eagerly here (generate_tts writes into this module-level chunk_dir).
wave_path = Path(output_dir).joinpath(output_file)
if save_chunk:
    chunk_dir = Path(output_dir).joinpath(wave_path.stem + "_chunks")
    chunk_dir.mkdir(exist_ok=True, parents=True)
# Load the vocoder. Local checkpoint directories are known only for "vocos" and
# "bigvgan"; any other name yields None. The local path is passed to
# load_vocoder together with the is_local flag.
vocoder_local_path = {
    "vocos": "../checkpoints/vocos-mel-24khz",
    "bigvgan": "../checkpoints/bigvgan_v2_24khz_100band_256x",
}.get(vocoder_name)

vocoder = load_vocoder(
    vocoder_name=vocoder_name,
    is_local=load_vocoder_from_local,
    local_path=vocoder_local_path,
    device=device,
)
# Load the TTS model:
#   1. read its architecture config (explicit model_cfg_path, or the packaged
#      configs/<model>.yaml);
#   2. resolve the backbone class;
#   3. download the checkpoint when none was supplied.
model_cfg = OmegaConf.load(
    model_cfg_path or str(files("f5_tts").joinpath(f"configs/{model}.yaml"))
)
ModelClass = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type

# The vocoder was chosen before the model config was read, and mel_spec_type
# was previously ignored. Surface a mismatch instead: the model's
# mel-spectrogram type presumably has to match the vocoder for usable audio.
# F5TTS_Base is exempt because a checkpoint is selected per vocoder below.
if model != "F5TTS_Base" and vocoder_name != mel_spec_type:
    print(
        f"Warning: vocoder '{vocoder_name}' does not match mel_spec_type "
        f"'{mel_spec_type}' of model {model}; output audio may be unusable."
    )

# Default checkpoint location on the SWivid HF org. Older models override the
# repo name, training step and/or file format.
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
if model == "F5TTS_Base":
    if vocoder_name == "vocos":
        ckpt_step = 1200000
    else:
        model = "F5TTS_Base_bigvgan"
        ckpt_type = "pt"
elif model == "E2TTS_Base":
    repo_name, ckpt_step = "E2-TTS", 1200000

if not ckpt_file:
    ckpt_file = str(
        cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}")
    )

print(f"Loading model {model} checkpoint…")
ema_model = load_model(
    ModelClass,
    model_cfg.model.arch,
    ckpt_file,
    mel_spec_type=vocoder_name,
    vocab_file=vocab_file,
    device=device,
)
# Splits text immediately before every "[name]" voice tag (zero-width split).
_VOICE_SPLIT_RE = re.compile(r"(?=\[\w+\])")
# Matches a "[name]" voice tag and captures the name.
_VOICE_TAG_RE = re.compile(r"\[(\w+)\]")
# Characters replaced in per-chunk filenames: spaces, and path separators that
# would otherwise make the name point into a subdirectory.
_CHUNK_NAME_UNSAFE_RE = re.compile(r"[ /\\]")


def _prepare_voices(ref_audio, ref_text):
    """Build the voice table: "main" first, then any voices from config["voices"].

    Each config voice dict is copied before its reference is passed through
    preprocess_ref_audio_text. Previously the config dicts were rewritten in
    place, so every later generate_tts() call re-preprocessed references that
    had already been processed. A config voice named "main" is replaced by the
    ref_audio/ref_text passed in.
    """
    voices = {"main": {"ref_audio": ref_audio, "ref_text": ref_text}}
    for name, voice in config.get("voices", {}).items():
        if name != "main":
            voices[name] = dict(voice)
    for voice in voices.values():
        voice["ref_audio"], voice["ref_text"] = preprocess_ref_audio_text(
            voice["ref_audio"], voice["ref_text"]
        )
    return voices


def _split_voice_chunks(text, voices):
    """Split text into (voice_name, chunk_text) pairs on "[voice]" tags.

    Untagged leading text and text under an unknown tag are assigned to "main".
    Chunks that are empty once the tag is removed are skipped.
    """
    pairs = []
    for chunk in _VOICE_SPLIT_RE.split(text):
        txt = chunk.strip()
        if not txt:
            continue
        m = _VOICE_TAG_RE.match(txt)
        if m:
            voice = m.group(1)
            txt = _VOICE_TAG_RE.sub("", txt).strip()
        else:
            voice = "main"
        if voice not in voices:
            print(f"Unknown voice '{voice}', using main.")
            voice = "main"
        if not txt:
            # A bare tag with no text after it: nothing to synthesize.
            continue
        pairs.append((voice, txt))
    return pairs


def generate_tts(input_text, output_dir="tests", output_file=None, ref_audio=ref_audio, ref_text=None):
    """
    Generate text-to-speech audio from input text.

    input_text may contain "[name]" tags that switch to voices defined in the
    config's "voices" table. Untagged text uses the reference voice.

    Args:
        input_text (str): Text to convert to speech
        output_dir (str): Directory to save the output file (default: "tests")
        output_file (str): Output filename (default: auto-generated based on timestamp)
        ref_audio (str): Reference audio file (default: the module-level ref_audio)
        ref_text (str): Reference transcript (default: predefined text)

    Returns:
        str: Path to the generated audio file

    Raises:
        ValueError: If input_text contains no text to synthesize.
    """
    if ref_text is None:
        ref_text = (
            "Fuck your phone. Stop texting all the time. "
            "Look up from your phone and breathe. Release yourself."
        )
    if output_file is None:
        output_file = f"infer_cli_{datetime.now():%Y%m%d_%H%M%S}.wav"

    voices = _prepare_voices(ref_audio, ref_text)
    pieces = _split_voice_chunks(input_text, voices)
    # With no chunks there is no sample rate to write with. Previously this
    # path crashed with UnboundLocalError on `sr` at sf.write.
    if not pieces:
        raise ValueError("input_text contains no text to synthesize")

    segments = []
    for index, (voice, txt) in enumerate(pieces):
        seg, sr, _ = infer_process(
            voices[voice]["ref_audio"],
            voices[voice]["ref_text"],
            txt,
            ema_model,
            vocoder,
            mel_spec_type=vocoder_name,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=device,
        )
        segments.append(seg)
        if save_chunk:
            # chunk_dir is the module-level directory created when save_chunk is on.
            name = _CHUNK_NAME_UNSAFE_RE.sub("_", txt[:200])
            sf.write(str(chunk_dir / f"{index}_{name}.wav"), seg, sr)

    # Join the segments end to end and write a single file; sr is the rate
    # reported by the last infer_process call.
    final = np.concatenate(segments)
    os.makedirs(output_dir, exist_ok=True)
    wave_path = Path(output_dir) / output_file
    sf.write(str(wave_path), final, sr)
    if remove_silence:
        remove_silence_for_generated_wav(str(wave_path))
    print(f"Written output to {wave_path}")
    return str(wave_path)
if __name__ == "__main__":
    # Smoke run: synthesize one short sentence with the default reference voice.
    generated_file = generate_tts("This is a test of the TTS system.")
    print(f"Generated file: {generated_file}")