# LLM_demo/tts_api.py
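"""TTS demo API: wraps the LLaSA (LLM + XCodec2 codec) and ChatTTS backends behind a
single TTSapi class and, when run as a script, serves a Gradio voice-cloning demo."""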
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model
import numpy as np
import ChatTTS
import re
DEFAULT_TTS_MODEL_NAME = "HKUSTAudio/LLasa-1B"
DEMO_EXAMPLES = [
["太乙真人.wav", "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"],
["邓紫棋.wav", "特别大的不同,因为以前在香港是过年的时候,我们可能见到的亲戚都是爸爸那边的亲戚"],
["雷军.wav", "这是个好问题,我把来龙去脉给你简单讲,就是这个社会对小米有很多的误解,有很多的误解,呃,也能理解啊,就是小米这个模式呢"],
["Taylor Swift.wav", "It's actually uh, it's a concept record, but it's my first directly autobiographical album in a while because the last album that I put out was, uh, a rework."]
]
class TTSapi:
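    """Unified wrapper over two TTS backends: LLaSA (with the XCodec2 speech codec) and ChatTTS."""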
def __init__(self,
model_name=DEFAULT_TTS_MODEL_NAME,
codec_model_name="HKUST-Audio/xcodec2",
device=torch.device("cuda:0")):
self.reload(model_name, codec_model_name, device)
def reload(self,
model_name=DEFAULT_TTS_MODEL_NAME,
codec_model_name="HKUST-Audio/xcodec2",
device=torch.device("cuda:0")):
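        """(Re)load the requested TTS backend and remember its sampling rate and model name."""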
if 'llasa' in model_name.lower():
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name)
self.model.eval().to(device)
self.codec_model = XCodec2Model.from_pretrained(codec_model_name)
self.codec_model.eval().to(device)
self.device = device
self.codec_model_name = codec_model_name
self.sr = 16000
elif 'chattts' in model_name.lower():
self.model = ChatTTS.Chat()
            self.model.load(compile=False)  # Set to True for better performance, but it significantly slows down model loading
self.sr = 24000
self.punctuation = r'[,,.。??!!~~;;]'
else:
raise ValueError(f'不支持的TTS模型:{model_name}')
self.model_name = model_name
def ids_to_speech_tokens(self, speech_ids):
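        """Convert integer codec IDs to LLaSA speech tokens, e.g. 123 -> '<|s_123|>'."""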
speech_tokens_str = []
for speech_id in speech_ids:
speech_tokens_str.append(f"<|s_{speech_id}|>")
return speech_tokens_str
def extract_speech_ids(self, speech_tokens_str):
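        """Parse '<|s_123|>'-style speech tokens back into integer codec IDs."""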
speech_ids = []
for token_str in speech_tokens_str:
if token_str.startswith('<|s_') and token_str.endswith('|>'):
num_str = token_str[4:-2]
num = int(num_str)
speech_ids.append(num)
else:
print(f"Unexpected token: {token_str}")
return speech_ids
def forward(self, input_text, speech_prompt=None, save_path='wavs/generated/gen.wav'):
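        """Synthesize `input_text` into a waveform, optionally cloning the voice of `speech_prompt`
        (LLaSA only). The result is written to `save_path` and returned as a float numpy array."""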
#TTS start!
with torch.no_grad():
if 'chattts' in self.model_name.lower():
# rand_spk = chat.sample_random_speaker()
# print(rand_spk) # save it for later timbre recovery
# params_infer_code = ChatTTS.Chat.InferCodeParams(
# spk_emb = rand_spk, # add sampled speaker
# temperature = .3, # using custom temperature
# top_P = 0.7, # top P decode
# top_K = 20, # top K decode
# )
break_num = max(min(len(re.split(self.punctuation, input_text)), 7), 2)
params_refine_text = ChatTTS.Chat.RefineTextParams(
prompt=f'[oral_2][laugh_0][break_{break_num}]',
)
wavs = self.model.infer([input_text],
params_refine_text=params_refine_text,
)
gen_wav_save = wavs[0]
sf.write(save_path, gen_wav_save, 24000)
else:
if speech_prompt:
                    # only 16 kHz prompt speech is supported!
                    prompt_wav, sr = sf.read(speech_prompt)  # load the reference audio
prompt_wav = torch.from_numpy(prompt_wav).float().unsqueeze(0)
# Encode the prompt wav
vq_code_prompt = self.codec_model.encode_code(input_waveform=prompt_wav)
print("Prompt Vq Code Shape:", vq_code_prompt.shape )
vq_code_prompt = vq_code_prompt[0,0,:]
# Convert int 12345 to token <|s_12345|>
speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt)
else:
speech_ids_prefix = ''
formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
                # Tokenize the text (and the speech prefix)
chat = [
{"role": "user", "content": "Convert the text to speech:" + formatted_text},
{"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
]
input_ids = self.tokenizer.apply_chat_template(
chat,
tokenize=True,
return_tensors='pt',
continue_final_message=True
)
input_ids = input_ids.to(self.device)
speech_end_id = self.tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
# Generate the speech autoregressively
outputs = self.model.generate(
input_ids,
max_length=2048, # We trained our model with a max length of 2048
                    eos_token_id=speech_end_id,
do_sample=True,
top_p=1, # Adjusts the diversity of generated content
temperature=1, # Controls randomness in output
)
# Extract the speech tokens
generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
speech_tokens = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# Convert token <|s_23456|> to int 23456
speech_tokens = self.extract_speech_ids(speech_tokens)
speech_tokens = torch.tensor(speech_tokens).to(self.device).unsqueeze(0).unsqueeze(0)
# Decode the speech tokens to speech waveform
gen_wav = self.codec_model.decode_code(speech_tokens)
                # keep only the newly generated part (drop the reconstructed prompt)
if speech_prompt:
gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]
gen_wav_save = gen_wav[0, 0, :].cpu().numpy()
sf.write(save_path, gen_wav_save, 16000)
# gen_wav_save = np.clip(gen_wav_save, -1, 1)
# gen_wav_save = (gen_wav_save * 32767).astype(np.int16)
return gen_wav_save
if __name__ == '__main__':
# Llasa-8B shows better text understanding ability.
# input_text = " He shouted, 'Everyone, please gather 'round! Here's the plan: 1) Set-up at 9:15 a.m.; 2) Lunch at 12:00 p.m. (please RSVP!); 3) Playing — e.g., games, music, etc. — from 1:15 to 4:45; and 4) Clean-up at 5 p.m.'"
# prompt_text ="对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"
# input_text = prompt_text + '嘻嘻,臭宝儿你真可爱,我好喜欢你呀。'
# save_root = 'wavs/generated/'
# save_path = save_root + 'test.wav'
# speech_ref = 'wavs/ref/太乙真人.wav'
# # speech_ref = None
# # 帘外雨潺潺,春意阑珊。罗衾不耐五更寒。梦里不知身是客,一晌贪欢。独自莫凭栏,无限江山。别时容易见时难。流水落花春去也,天上人间。
# llasa_tts = TTSapi()
# gen = llasa_tts.forward(input_text, speech_prompt=speech_ref, save_path=save_path)
# print(gen.shape)
import gradio as gr
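    # Pre-load the default TTS model at startup so the first request does not pay the loading cost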
synthesiser = TTSapi()
TTS_LOADED = True
def predict(config):
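        """Gradio callback: synthesize speech for the current UI config and return (sample_rate, waveform)."""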
global TTS_LOADED, synthesiser
print(f"待合成文本:{config['msg']}")
print(f"选中TTS模型:{config['tts_model']}")
print(f"参考音频路径:{config['ref_audio']}")
print(f"参考音频文本:{config['ref_audio_transcribe']}")
text = config['msg']
try:
            if not text:
                audio_output = np.array([0], dtype=np.int16)
                print("输入为空,无法合成语音")
else:
if not TTS_LOADED:
print('TTS模型首次加载...')
gr.Info("初次加载TTS模型,请稍候..", duration=63)
synthesiser = TTSapi(model_name=config['tts_model'])#, device="cuda:2")
TTS_LOADED = True
print('加载完毕...')
                # Check whether the currently loaded model is the selected one; reload if not
if config['tts_model'] != synthesiser.model_name:
print(f'当前TTS模型{synthesiser.model_name}非所选,重新加载')
synthesiser.reload(model_name=config['tts_model'])
                # If a reference audio is provided, prepend its transcript to the synthesis text as a prefix
if config['ref_audio']:
prompt_text = config['ref_audio_transcribe']
if prompt_text is None:
# prompt_text = ...
                        raise NotImplementedError('暂时必须提供文本') # TODO: consider adding an ASR model later to transcribe the reference audio
text = prompt_text + text
audio_output = synthesiser.forward(text, speech_prompt=config['ref_audio'])
        except Exception as e:
            print('!!!!!!!!')
            print(e)
            print('!!!!!!!!')
            audio_output = np.array([0], dtype=np.int16)  # fall back to silence so the return below never fails
return (synthesiser.sr if synthesiser else 16000, audio_output)
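    # Build the Gradio UI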
with gr.Blocks(title="TTS Demo", theme=gr.themes.Soft(font=["sans-serif", "Arial"])) as demo:
gr.Markdown("""
# Personalized TTS Demo
## 使用步骤
* 上传你想要合成的目标说话人的语音,10s左右即可,并在下面填入对应的文本。或直接点击下方示例
* 输入你想要合成的文字,点击合成语音按钮,稍等片刻即可
""")
with gr.Row():
with gr.Column():
                # TTS model selection
tts_model = gr.Dropdown(
label="选择TTS模型",
choices=["ChatTTS", "HKUSTAudio/LLasa-1B", "HKUSTAudio/LLasa-3B", "HKUSTAudio/LLasa-8B"],
value=DEFAULT_TTS_MODEL_NAME,
interactive=True,
                    visible=False  # hidden for the product demo; model selection is not exposed for now
)
                # Reference audio upload
ref_audio = gr.Audio(
label="上传参考音频",
type="filepath",
interactive=True
)
ref_audio_transcribe = gr.Textbox(label="参考音频对应文本", visible=True)
                # Create the example entries
examples = gr.Examples(
examples=DEMO_EXAMPLES,
inputs=[ref_audio, ref_audio_transcribe],
fn=predict
)
with gr.Column():
audio_player = gr.Audio(
label="听听我声音~",
type="numpy",
interactive=False
)
msg = gr.Textbox(label="输入文本", placeholder="请输入想要合成的文本")
submit_btn = gr.Button("合成语音", variant="primary")
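        # Keep the latest values of all UI inputs in a single config dict that predict() consumes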
current_config = gr.State({
"msg": None,
"tts_model": DEFAULT_TTS_MODEL_NAME,
"ref_audio": None,
"ref_audio_transcribe": None
})
gr.on(triggers=[msg.change, tts_model.change, ref_audio.change,
ref_audio_transcribe.change],
fn=lambda text, model, audio, ref_text: {"msg": text, "tts_model": model, "ref_audio": audio,
"ref_audio_transcribe": ref_text},
inputs=[msg, tts_model, ref_audio, ref_audio_transcribe],
outputs=current_config
)
submit_btn.click(
predict,
[current_config],
[audio_player],
queue=False
)
demo.launch(share=False, server_name='0.0.0.0', server_port=7863, inbrowser=True)