from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model
import numpy as np
import ChatTTS
import re
import os

DEFAULT_TTS_MODEL_NAME = "HKUSTAudio/LLasa-1B"
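
# Reference-audio / transcript pairs offered as one-click examples in the UI.
# The transcripts must match the spoken audio, so they are kept verbatim.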
DEMO_EXAMPLES = [
    ["太乙真人.wav", "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"],
    ["邓紫棋.wav", "特别大的不同,因为以前在香港是过年的时候,我们可能见到的亲戚都是爸爸那边的亲戚"],
    ["雷军.wav", "这是个好问题,我把来龙去脉给你简单讲,就是这个社会对小米有很多的误解,有很多的误解,呃,也能理解啊,就是小米这个模式呢"],
    ["Taylor Swift.wav", "It's actually uh, it's a concept record, but it's my first directly autobiographical album in a while because the last album that I put out was, uh, a rework."]
]


class TTSapi:
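    """Unified wrapper around two TTS backends: LLaSA (LLM + xcodec2 codec) and ChatTTS."""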

    def __init__(self,
                 model_name=DEFAULT_TTS_MODEL_NAME,
                 codec_model_name="HKUST-Audio/xcodec2",
                 device=torch.device("cuda:0")):
        self.reload(model_name, codec_model_name, device)

    def reload(self,
               model_name=DEFAULT_TTS_MODEL_NAME,
               codec_model_name="HKUST-Audio/xcodec2",
               device=torch.device("cuda:0")):
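        """(Re)initialize the requested TTS backend; called from __init__ and on model switch."""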
        if 'llasa' in model_name.lower():
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
            self.model.eval().to(device)

            # xcodec2 turns waveforms into discrete speech tokens and back.
            self.codec_model = XCodec2Model.from_pretrained(codec_model_name)
            self.codec_model.eval().to(device)
            self.device = device
            self.codec_model_name = codec_model_name
            self.sr = 16000
        elif 'chattts' in model_name.lower():
            self.model = ChatTTS.Chat()
            self.model.load(compile=False)
            self.sr = 24000
            # Full-width and half-width punctuation, used to estimate break counts.
            self.punctuation = r'[,,.。??!!~~;;]'
        else:
            raise ValueError(f'Unsupported TTS model: {model_name}')

        self.model_name = model_name

    def ids_to_speech_tokens(self, speech_ids):
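        """Format integer codec IDs as LLaSA speech-token strings, e.g. 123 -> '<|s_123|>'."""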
        return [f"<|s_{speech_id}|>" for speech_id in speech_ids]

    def extract_speech_ids(self, speech_tokens_str):
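        """Parse '<|s_123|>'-style token strings back into integer codec IDs, skipping malformed tokens."""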
        speech_ids = []
        for token_str in speech_tokens_str:
            if token_str.startswith('<|s_') and token_str.endswith('|>'):
                num_str = token_str[4:-2]
                speech_ids.append(int(num_str))
            else:
                print(f"Unexpected token: {token_str}")
        return speech_ids

    def forward(self, input_text, speech_prompt=None, save_path='wavs/generated/gen.wav'):
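        """Synthesize `input_text`; optionally clone the voice in `speech_prompt` (LLaSA only).

        Writes the result to `save_path` and returns it as a 1-D numpy waveform.
        """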
        # Make sure the output directory exists before soundfile tries to write there.
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)

        with torch.no_grad():
            if 'chattts' in self.model_name.lower():
                # Estimate a pause count from the punctuation density of the input,
                # clamped to the [2, 7] range used by ChatTTS break tags.
                break_num = max(min(len(re.split(self.punctuation, input_text)), 7), 2)
                params_refine_text = ChatTTS.Chat.RefineTextParams(
                    prompt=f'[oral_2][laugh_0][break_{break_num}]',
                )
                wavs = self.model.infer([input_text],
                                        params_refine_text=params_refine_text,
                                        )
                gen_wav_save = wavs[0]
                sf.write(save_path, gen_wav_save, self.sr)
            else:
                if speech_prompt:
                    # xcodec2 operates on mono 16 kHz audio; the reference recording is
                    # assumed to already be in that format (the sr from sf.read is unused).
                    prompt_wav, sr = sf.read(speech_prompt)
                    prompt_wav = torch.from_numpy(prompt_wav).float().unsqueeze(0)

                    # Encode the reference audio into discrete codec tokens.
                    vq_code_prompt = self.codec_model.encode_code(input_waveform=prompt_wav)
                    print("Prompt VQ code shape:", vq_code_prompt.shape)

                    vq_code_prompt = vq_code_prompt[0, 0, :]

                    speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt)
                else:
                    speech_ids_prefix = []
                formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

                # Seed the assistant turn with the prompt's speech tokens so generation
                # continues in the reference speaker's voice.
                chat = [
                    {"role": "user", "content": "Convert the text to speech:" + formatted_text},
                    {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
                ]

                input_ids = self.tokenizer.apply_chat_template(
                    chat,
                    tokenize=True,
                    return_tensors='pt',
                    continue_final_message=True
                )
                input_ids = input_ids.to(self.device)
                speech_end_id = self.tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

                outputs = self.model.generate(
                    input_ids,
                    max_length=2048,
                    eos_token_id=speech_end_id,
                    do_sample=True,
                    top_p=1,
                    temperature=1,
                )

                # Keep the prompt-prefix tokens plus the newly generated ones; drop the
                # trailing <|SPEECH_GENERATION_END|> token.
                generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
                speech_tokens = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                speech_tokens = self.extract_speech_ids(speech_tokens)
                speech_tokens = torch.tensor(speech_tokens).to(self.device).unsqueeze(0).unsqueeze(0)

                # Decode the discrete speech tokens back into a waveform.
                gen_wav = self.codec_model.decode_code(speech_tokens)

                # Trim the reference audio off the front of the generated waveform.
                if speech_prompt:
                    gen_wav = gen_wav[:, :, prompt_wav.shape[1]:]

                gen_wav_save = gen_wav[0, 0, :].cpu().numpy()
                sf.write(save_path, gen_wav_save, self.sr)

        return gen_wav_save


if __name__ == '__main__':
    import gradio as gr

    synthesiser = TTSapi()
    TTS_LOADED = True

    def predict(config):
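        """Gradio callback: synthesize `config['msg']` and return (sample_rate, waveform)."""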
        global TTS_LOADED, synthesiser
        print(f"Text to synthesize: {config['msg']}")
        print(f"Selected TTS model: {config['tts_model']}")
        print(f"Reference audio path: {config['ref_audio']}")
        print(f"Reference transcript: {config['ref_audio_transcribe']}")
        text = config['msg']
        # Fallback output so the return below is always defined, even on error.
        audio_output = np.array([0], dtype=np.int16)
        try:
            if not text:
                print("Empty input; nothing to synthesize")
            else:
                if not TTS_LOADED:
                    print('Loading TTS model for the first time...')
                    gr.Info("Loading the TTS model for the first time, please wait...", duration=63)
                    synthesiser = TTSapi(model_name=config['tts_model'])
                    TTS_LOADED = True
                    print('Model loaded')

                if config['tts_model'] != synthesiser.model_name:
                    print(f"Loaded TTS model {synthesiser.model_name} differs from the selection; reloading")
                    synthesiser.reload(model_name=config['tts_model'])

                if config['ref_audio']:
                    prompt_text = config['ref_audio_transcribe']
                    if prompt_text is None:
                        # TODO: transcribe the reference audio automatically instead.
                        raise NotImplementedError('A transcript of the reference audio is currently required')
                    # Prepend the reference transcript so the generated speech continues
                    # seamlessly from the prompt audio (LLaSA-style voice cloning).
                    text = prompt_text + text

                audio_output = synthesiser.forward(text, speech_prompt=config['ref_audio'])

        except Exception as e:
            print('!!!!!!!!')
            print(e)
            print('!!!!!!!!')

        return (synthesiser.sr if synthesiser else 16000, audio_output)

    with gr.Blocks(title="TTS Demo", theme=gr.themes.Soft(font=["sans-serif", "Arial"])) as demo:
        gr.Markdown("""
        # Personalized TTS Demo
        ## How to use
        * Upload a recording (about 10 s) of the target speaker and fill in its transcript below, or just click one of the examples.
        * Enter the text you want to synthesize, click the "Synthesize speech" button, and wait a moment.
        """)
        with gr.Row():
            with gr.Column():
                tts_model = gr.Dropdown(
                    label="Select TTS model",
                    choices=["ChatTTS", "HKUSTAudio/LLasa-1B", "HKUSTAudio/LLasa-3B", "HKUSTAudio/LLasa-8B"],
                    value=DEFAULT_TTS_MODEL_NAME,
                    interactive=True,
                    visible=False
                )

                ref_audio = gr.Audio(
                    label="Upload reference audio",
                    type="filepath",
                    interactive=True
                )
                ref_audio_transcribe = gr.Textbox(label="Transcript of reference audio", visible=True)

                examples = gr.Examples(
                    examples=DEMO_EXAMPLES,
                    inputs=[ref_audio, ref_audio_transcribe],
                    fn=predict
                )

            with gr.Column():
                audio_player = gr.Audio(
                    label="Listen to my voice~",
                    type="numpy",
                    interactive=False
                )
                msg = gr.Textbox(label="Input text", placeholder="Enter the text you want to synthesize")
                submit_btn = gr.Button("Synthesize speech", variant="primary")

        # Snapshot of all UI inputs, refreshed whenever any of them changes.
        current_config = gr.State({
            "msg": None,
            "tts_model": DEFAULT_TTS_MODEL_NAME,
            "ref_audio": None,
            "ref_audio_transcribe": None
        })
        gr.on(triggers=[msg.change, tts_model.change, ref_audio.change,
                        ref_audio_transcribe.change],
              fn=lambda text, model, audio, ref_text: {"msg": text, "tts_model": model, "ref_audio": audio,
                                                       "ref_audio_transcribe": ref_text},
              inputs=[msg, tts_model, ref_audio, ref_audio_transcribe],
              outputs=current_config
              )
        submit_btn.click(
            predict,
            [current_config],
            [audio_player],
            queue=False
        )

    demo.launch(share=False, server_name='0.0.0.0', server_port=7863, inbrowser=True)