import math
import re
from pathlib import Path

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model
import numpy as np
import ChatTTS

DEFAULT_TTS_MODEL_NAME = "HKUSTAudio/LLasa-1B"
DEFAULT_CODEC_MODEL_NAME = "HKUST-Audio/xcodec2"
# Sample rate (Hz) used for Llasa/XCodec2 audio: prompts are brought to it and output is written at it.
LLASA_SAMPLE_RATE = 16000
# Sample rate (Hz) at which ChatTTS output is written and reported.
CHATTTS_SAMPLE_RATE = 24000

# [reference wav path, transcript of that wav] pairs shown as clickable demo examples.
DEMO_EXAMPLES = [
    ["太乙真人.wav", "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"],
    ["邓紫棋.wav", "特别大的不同,因为以前在香港是过年的时候,我们可能见到的亲戚都是爸爸那边的亲戚"],
    ["雷军.wav", "这是个好问题,我把来龙去脉给你简单讲,就是这个社会对小米有很多的误解,有很多的误解,呃,也能理解啊,就是小米这个模式呢"],
    ["Taylor Swift.wav", "It's actually uh, it's a concept record, but it's my first directly autobiographical album in a while because the last album that I put out was, uh, a rework."]
]


class TTSapi:
    """Wrapper around two TTS backends selected by model name.

    Names containing "llasa" load a Hugging Face causal LM plus the XCodec2
    codec (optional voice cloning from a reference recording); names containing
    "chattts" load ``ChatTTS.Chat``. ``self.sr`` is the output sample rate of
    the active backend and ``self.model_name`` the name it was loaded from.
    """

    def __init__(self, model_name=DEFAULT_TTS_MODEL_NAME, codec_model_name=DEFAULT_CODEC_MODEL_NAME,
                 device=torch.device("cuda:0")):
        self.reload(model_name, codec_model_name, device)

    def reload(self, model_name=DEFAULT_TTS_MODEL_NAME, codec_model_name=None, device=None):
        """(Re)load the backend selected by ``model_name``.

        Args:
            model_name: Llasa checkpoint id (contains "llasa") or a name containing "chattts".
            codec_model_name: XCodec2 checkpoint for Llasa. ``None`` reuses the codec of the
                previous load (``DEFAULT_CODEC_MODEL_NAME`` on the first load).
            device: Device for the Llasa backend. ``None`` reuses the previous device
                (``cuda:0`` on the first load). Not used by the ChatTTS branch.

        Raises:
            ValueError: if ``model_name`` matches no supported backend; the currently
                loaded model is left untouched in that case.
        """
        lowered = model_name.lower()
        is_llasa = 'llasa' in lowered
        if not is_llasa and 'chattts' not in lowered:
            raise ValueError(f'不支持的TTS模型:{model_name}')
        if codec_model_name is None:
            codec_model_name = getattr(self, 'codec_model_name', DEFAULT_CODEC_MODEL_NAME)
        if device is None:
            device = getattr(self, 'device', torch.device("cuda:0"))

        # Drop the previous backend first so two models never sit in GPU memory at once.
        self._release_models()

        if is_llasa:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
            self.model.eval().to(device)
            self.codec_model = XCodec2Model.from_pretrained(codec_model_name)
            self.codec_model.eval().to(device)
            self.sr = LLASA_SAMPLE_RATE
        else:
            self.model = ChatTTS.Chat()
            # compile=True gives better performance but significantly slows down loading.
            self.model.load(compile=False)
            self.sr = CHATTTS_SAMPLE_RATE
            # Punctuation used to count clauses when choosing the ChatTTS [break_N] level.
            self.punctuation = r'[,,.。??!!~~;;]'
        # Remembered for both backends so a later argument-less reload() reuses them.
        self.device = device
        self.codec_model_name = codec_model_name
        # Set last: if loading above raised, model_name stays None and callers that
        # compare it against the requested name will retry the load.
        self.model_name = model_name

    def _release_models(self):
        """Drop references to any loaded backend and return cached CUDA memory."""
        self.model = None
        self.tokenizer = None
        self.codec_model = None
        self.model_name = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def ids_to_speech_tokens(self, speech_ids):
        """Map integer codec ids to Llasa speech tokens, e.g. ``12345 -> "<|s_12345|>"``.

        Accepts any iterable of ints or a 1-D integer tensor.
        """
        if isinstance(speech_ids, torch.Tensor):
            speech_ids = speech_ids.tolist()
        return [f"<|s_{speech_id}|>" for speech_id in speech_ids]

    def extract_speech_ids(self, speech_tokens_str):
        """Inverse of ``ids_to_speech_tokens``: parse ``"<|s_23456|>" -> 23456``.

        Strings that are not speech tokens are reported on stdout and skipped.
        """
        speech_ids = []
        for token_str in speech_tokens_str:
            if token_str.startswith('<|s_') and token_str.endswith('|>'):
                speech_ids.append(int(token_str[4:-2]))
            else:
                print(f"Unexpected token: {token_str}")
        return speech_ids

    def _load_prompt_wav(self, path):
        """Read a reference recording as a float32 tensor of shape (1, samples) at 16 kHz.

        The codec path only supports 16 kHz input fed as a single (1, samples) channel,
        so multi-channel files are averaged to mono and other rates are resampled.
        """
        wav, sr = sf.read(path)
        if wav.ndim > 1:
            wav = wav.mean(axis=1)
        if sr != LLASA_SAMPLE_RATE:
            from scipy.signal import resample_poly  # only needed for non-16 kHz prompts
            g = math.gcd(LLASA_SAMPLE_RATE, sr)
            wav = resample_poly(wav, LLASA_SAMPLE_RATE // g, sr // g)
        return torch.from_numpy(np.ascontiguousarray(wav)).float().unsqueeze(0)

    def forward(self, input_text, speech_prompt=None, save_path='wavs/generated/gen.wav'):
        """Synthesize ``input_text``, write it to ``save_path`` and return the waveform.

        Args:
            input_text: Text to speak. For Llasa voice cloning the caller prepends the
                transcript of ``speech_prompt`` to this text (see ``predict``).
            speech_prompt: Optional path to a reference recording. Only the Llasa
                backend uses it; ChatTTS ignores it.
            save_path: Output wav path; missing parent directories are created.

        Returns:
            Numpy waveform at ``self.sr``. For Llasa with a prompt, only the audio
            generated after the prompt is returned.

        Raises:
            RuntimeError: if no model is loaded or Llasa produced no speech tokens.
        """
        if getattr(self, 'model', None) is None:
            raise RuntimeError('No TTS model is loaded; call reload() first')
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        with torch.no_grad():
            if 'chattts' in self.model_name.lower():
                gen_wav_save = self._forward_chattts(input_text)
            else:
                gen_wav_save = self._forward_llasa(input_text, speech_prompt)
        sf.write(save_path, gen_wav_save, self.sr)
        return gen_wav_save

    def _forward_chattts(self, input_text):
        """Run ChatTTS on ``input_text`` and return the first generated waveform."""
        # Pick the [break_N] refine-text level from the clause count, clamped to [2, 7].
        break_num = max(min(len(re.split(self.punctuation, input_text)), 7), 2)
        params_refine_text = ChatTTS.Chat.RefineTextParams(
            prompt=f'[oral_2][laugh_0][break_{break_num}]',
        )
        # To pin a voice, also pass params_infer_code=ChatTTS.Chat.InferCodeParams(
        #     spk_emb=self.model.sample_random_speaker(), temperature=.3, top_P=0.7, top_K=20)
        # and save the sampled speaker for later timbre recovery.
        wavs = self.model.infer([input_text], params_refine_text=params_refine_text)
        return wavs[0]

    def _forward_llasa(self, input_text, speech_prompt):
        """Generate speech tokens with Llasa and decode them with XCodec2."""
        if speech_prompt:
            prompt_wav = self._load_prompt_wav(speech_prompt)
            # Encode the prompt audio into codec ids; they become the assistant-turn prefix.
            vq_code_prompt = self.codec_model.encode_code(input_waveform=prompt_wav)
            print("Prompt Vq Code Shape:", vq_code_prompt.shape)
            speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt[0, 0, :])
        else:
            speech_ids_prefix = []

        formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
        chat = [
            {"role": "user", "content": "Convert the text to speech:" + formatted_text},
            {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)},
        ]
        # continue_final_message=True keeps the assistant turn open so generation
        # continues right after the (optional) prompt speech tokens.
        input_ids = self.tokenizer.apply_chat_template(
            chat, tokenize=True, return_tensors='pt', continue_final_message=True
        ).to(self.device)
        speech_end_id = self.tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

        outputs = self.model.generate(
            input_ids,
            max_length=2048,  # the model was trained with a max length of 2048
            eos_token_id=speech_end_id,
            do_sample=True,
            top_p=1,  # diversity of generated content
            temperature=1,  # randomness of the output
        )
        # Keep the prompt speech tokens plus everything generated after the input
        # (assumes each <|s_N|> prefix string is exactly one LM token).
        generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):]
        # Strip the end-of-speech token only when it was actually emitted; if
        # max_length was hit, the last token is real speech and must be kept.
        if generated_ids.numel() > 0 and generated_ids[-1].item() == speech_end_id:
            generated_ids = generated_ids[:-1]
        speech_tokens = self.extract_speech_ids(
            self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        )
        if not speech_tokens:
            raise RuntimeError('Llasa generated no speech tokens')

        codes = torch.tensor(speech_tokens).to(self.device).unsqueeze(0).unsqueeze(0)
        gen_wav = self.codec_model.decode_code(codes)
        if speech_prompt:
            # Return only the newly generated part, not the re-decoded prompt audio.
            gen_wav = gen_wav[:, :, prompt_wav.shape[1]:]
        return gen_wav[0, 0, :].cpu().numpy()


if __name__ == '__main__':
    # Scripted usage without the UI (Llasa-8B shows better text understanding ability):
    #   tts = TTSapi()
    #   prompt_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"
    #   gen = tts.forward(prompt_text + '嘻嘻,臭宝儿你真可爱,我好喜欢你呀。',
    #                     speech_prompt='wavs/ref/太乙真人.wav',
    #                     save_path='wavs/generated/test.wav')
    #   print(gen.shape)
    import gradio as gr

    synthesiser = TTSapi()
    TTS_LOADED = True

    def predict(config):
        """Gradio callback: synthesize ``config['msg']``; return ``(sample_rate, waveform)``.

        ``config`` is the dict held in ``current_config`` with keys ``msg``,
        ``tts_model``, ``ref_audio`` (file path or None) and ``ref_audio_transcribe``.
        On empty input or any error a single silent sample is returned, and errors
        are printed and shown as a Gradio warning.
        """
        global TTS_LOADED, synthesiser
        print(f"待合成文本:{config['msg']}")
        print(f"选中TTS模型:{config['tts_model']}")
        print(f"参考音频路径:{config['ref_audio']}")
        print(f"参考音频文本:{config['ref_audio_transcribe']}")
        text = config['msg']
        # Fallback output, assigned up front so the return below always has a value.
        audio_output = np.array([0], dtype=np.int16)
        try:
            if not text:
                print("输入为空,无法合成语音")
            else:
                if not TTS_LOADED:
                    print('TTS模型首次加载...')
                    gr.Info("初次加载TTS模型,请稍候..", duration=63)
                    synthesiser = TTSapi(model_name=config['tts_model'])
                    TTS_LOADED = True
                    print('加载完毕...')
                # Reload if the active model is not the one selected in the UI.
                if config['tts_model'] != synthesiser.model_name:
                    print(f'当前TTS模型{synthesiser.model_name}非所选,重新加载')
                    synthesiser.reload(model_name=config['tts_model'])
                # Llasa voice cloning: the reference transcript is prepended to the text.
                # ChatTTS ignores the reference audio, so prepending would only make it
                # read the transcript aloud.
                if config['ref_audio'] and 'chattts' not in synthesiser.model_name.lower():
                    prompt_text = config['ref_audio_transcribe']
                    if not prompt_text:
                        # TODO: add an ASR model to transcribe the reference audio.
                        raise NotImplementedError('暂时必须提供文本')
                    text = prompt_text + text
                audio_output = synthesiser.forward(text, speech_prompt=config['ref_audio'])
        except Exception as e:
            print('!!!!!!!!')
            print(e)
            print('!!!!!!!!')
            gr.Warning(str(e))
        return (synthesiser.sr if synthesiser else 16000, audio_output)

    with gr.Blocks(title="TTS Demo", theme=gr.themes.Soft(font=["sans-serif", "Arial"])) as demo:
        gr.Markdown("""
        # Personalized TTS Demo
        ## 使用步骤
        * 上传你想要合成的目标说话人的语音,10s左右即可,并在下面填入对应的文本。或直接点击下方示例
        * 输入你想要合成的文字,点击合成语音按钮,稍等片刻即可
        """)
        with gr.Row():
            with gr.Column():
                # TTS model selector; hidden for the product demo.
                tts_model = gr.Dropdown(
                    label="选择TTS模型",
                    choices=["ChatTTS", "HKUSTAudio/LLasa-1B", "HKUSTAudio/LLasa-3B", "HKUSTAudio/LLasa-8B"],
                    value=DEFAULT_TTS_MODEL_NAME,
                    interactive=True,
                    visible=False,
                )
                # Reference audio upload and its transcript.
                ref_audio = gr.Audio(label="上传参考音频", type="filepath", interactive=True)
                ref_audio_transcribe = gr.Textbox(label="参考音频对应文本", visible=True)
                # Examples only fill the two inputs above. predict takes a single config
                # dict, so it must not be wired as the examples' fn.
                gr.Examples(examples=DEMO_EXAMPLES, inputs=[ref_audio, ref_audio_transcribe])
            with gr.Column():
                audio_player = gr.Audio(label="听听我声音~", type="numpy", interactive=False)
                msg = gr.Textbox(label="输入文本", placeholder="请输入想要合成的文本")
                submit_btn = gr.Button("合成语音", variant="primary")

        current_config = gr.State({
            "msg": None,
            "tts_model": DEFAULT_TTS_MODEL_NAME,
            "ref_audio": None,
            "ref_audio_transcribe": None,
        })
        # Keep current_config in sync with every input component.
        gr.on(
            triggers=[msg.change, tts_model.change, ref_audio.change, ref_audio_transcribe.change],
            fn=lambda text, model, audio, ref_text: {
                "msg": text, "tts_model": model, "ref_audio": audio, "ref_audio_transcribe": ref_text,
            },
            inputs=[msg, tts_model, ref_audio, ref_audio_transcribe],
            outputs=current_config,
        )
        # Queued (the default): requests run one at a time instead of racing on the
        # shared global synthesiser, and gr.Info / gr.Warning toasts can be shown.
        submit_btn.click(predict, [current_config], [audio_player])

    demo.launch(share=False, server_name='0.0.0.0', server_port=7863, inbrowser=True)