# aai/tabs/tts/events.py
import re  # used only by the (currently disabled) language-tagging code below
import tempfile

import spaces
import gradio as gr
import torch
import torchaudio
import numpy as np
from df.enhance import enhance, load_audio, save_audio

from config import Config
from .load_models import *
from .modules.CosyVoice.cosyvoice.utils.file_utils import load_wav
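
# CosyVoice's released models synthesize speech at 22.05 kHz
# (the rate the original code hardcoded at each save call)
COSYVOICE_SR = 22050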

# Helper functions
def create_temp_file(suffix: str = ''):
    # delete=False so the file outlives this handle and can be handed to Gradio;
    # a suffix such as '.wav' lets torchaudio infer the audio format when saving
    return tempfile.NamedTemporaryFile(delete=False, suffix=suffix)

def assign_language_tags(text):
    # Language tagging is currently disabled: the text is returned unchanged.
    # The implementation below is kept for reference; it would prefix each run
    # of a language with <|zh|>/<|en|>/<|jp|>/<|yue|>/<|ko|> tags for
    # Chinese/English/Japanese/Cantonese/Korean, e.g.
    #   input:  你好 Hello こんにちは 你好 안녕하세요
    #   output: <|zh|>你好<|en|>Hello<|jp|>こんにちは<|yue|>你好<|ko|>안녕하세요
    return text
    # # Define language patterns
# patterns = {
# 'zh': r'[\u4e00-\u9fff]+', # Chinese characters
# 'en': r'[a-zA-Z]+', # English letters
# 'jp': r'[\u3040-\u30ff\u31f0-\u31ff]+', # Japanese characters
# 'ko': r'[\uac00-\ud7a3]+', # Korean characters
# }
# # Find all matches
# matches = []
# for lang, pattern in patterns.items():
# for match in re.finditer(pattern, text):
# matches.append((match.start(), match.end(), lang, match.group()))
# # Sort matches by start position
# matches.sort(key=lambda x: x[0])
# # Build the result string
# result = []
# last_end = 0
# zh_count = 0
# for start, end, lang, content in matches:
# if start > last_end:
# result.append(text[last_end:start])
# if lang == 'zh':
# zh_count += 1
# if zh_count > 1:
# lang = 'yue'
# result.append(f'<|{lang}|>{content}')
# last_end = end
# if last_end < len(text):
# result.append(text[last_end:])
# return ''.join(result)

def update_mode(mode, sft_speaker, speaker_audio, voice_instructions):
    # Show/hide the voice controls that apply to the selected TTS mode.
    # Each branch must set visibility explicitly, otherwise a control hidden
    # by a previous mode stays hidden after switching back.
    if mode == 'SFT':
        return (
            gr.update(  # sft_speaker
                visible=True,
            ),
gr.update( # speaker_audio,
visible=False,
),
gr.update( # voice_instructions,
visible=False,
),
)
elif mode == 'VC':
return (
gr.update( # sft_speaker,
visible=False,
),
gr.update( # speaker_audio,
visible=True,
),
gr.update( # voice_instructions,
visible=True,
),
)
elif mode == 'VC-CrossLingual':
return (
gr.update( # sft_speaker,
visible=False,
),
gr.update( # speaker_audio,
visible=True,
),
gr.update( # voice_instructions,
visible=False,
),
)
elif mode == 'Instruct':
return (
gr.update( # sft_speaker,
visible=True,
),
gr.update( # speaker_audio,
visible=False,
),
gr.update( # voice_instructions,
visible=True,
),
)
else:
raise gr.Error('Invalid mode')

@spaces.GPU(duration=10)
def clear_audio(audio: tuple[int, np.ndarray]):
    # Denoise the uploaded audio with DeepFilterNet.
    # Gradio's Audio component (type="numpy") yields a (sample_rate, samples)
    # tuple; mono input is assumed here.
    sr, samples = audio
    # Write the samples to a temporary wav file (np.save would produce an .npy
    # file that load_audio cannot read); int16 samples are written as 16-bit PCM
    audio_file = create_temp_file(suffix='.wav')
    torchaudio.save(audio_file.name, torch.from_numpy(samples).unsqueeze(0), sr)
    # Reload at DeepFilterNet's expected sample rate and enhance
    audio, _ = load_audio(audio_file.name, sr=df_state.sr())
    enhanced = enhance(df_model, df_state, audio)
    # Overwrite the temp file with the enhanced audio
    save_audio(audio_file.name, enhanced, df_state.sr())
    return gr.update(  # speaker_audio, output_audio
        value=audio_file.name,
    )

@spaces.GPU(duration=20)
def gen_audio(text, mode, sft_speaker=None, speaker_audio=None, voice_instructions=None):
    # Note: the original `mode == any([...])` compared the mode string against
    # the boolean True, so the voice-cloning branch never ran; `in` is the
    # intended membership test.
    if mode in ('VC', 'VC-CrossLingual'):
        if speaker_audio is None:
            raise gr.Error('Please upload an audio')
        # Gradio's Audio component yields a (sample_rate, samples) tuple; write
        # it to a temporary wav so CosyVoice's load_wav can load the uploaded
        # prompt (rather than the demo's hardcoded zero_shot_prompt.wav) and
        # resample it to 16 kHz
        sr, samples = speaker_audio
        speaker_audio_file = create_temp_file(suffix='.wav')
        torchaudio.save(speaker_audio_file.name, torch.from_numpy(samples).unsqueeze(0), sr)
        prompt_speech_16k = load_wav(speaker_audio_file.name, 16000)
    else:
        prompt_speech_16k = None
    # Assign language tags
    text = assign_language_tags(text)
    # Generate the speech chunks
    if mode == 'SFT':
        if not sft_speaker:
            raise gr.Error('Please select a speaker')
        chunks = cosyvoice_sft.inference_sft(
            tts_text=text,
            spk_id=sft_speaker,
        )
    elif mode == 'VC':
        # In VC mode the instructions field carries the transcript of the prompt audio
        chunks = cosyvoice.inference_zero_shot(
            tts_text=text,
            prompt_text=voice_instructions,
            prompt_speech_16k=prompt_speech_16k,
        )
    elif mode == 'VC-CrossLingual':
        chunks = cosyvoice.inference_cross_lingual(
            tts_text=text,
            prompt_speech_16k=prompt_speech_16k,
        )
    elif mode == 'Instruct':
        if not sft_speaker:
            raise gr.Error('Please select a speaker')
        if not voice_instructions:
            raise gr.Error('Please enter voice instructions')
        chunks = cosyvoice_instruct.inference_instruct(
            tts_text=text,
            spk_id=sft_speaker,
            instruct_text=voice_instructions,
        )
    else:
        raise gr.Error('Invalid mode')
    # CosyVoice yields speech chunk by chunk; concatenate along the time axis
    # and save a single wav (the original per-chunk `out_file.name.format(i)`
    # had no placeholder, so every chunk overwrote the same file)
    speech = torch.cat([chunk['tts_speech'] for chunk in chunks], dim=1)
    out_file = create_temp_file(suffix='.wav')
    torchaudio.save(out_file.name, speech, COSYVOICE_SR)
    return gr.update(  # output_audio
        value=out_file.name,
    )
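
# A minimal sketch of how these handlers might be wired up in audio_tab()
# (component names below are assumptions; the real ones live in the UI module):
#
#     mode.change(
#         update_mode,
#         inputs=[mode, sft_speaker, speaker_audio, voice_instructions],
#         outputs=[sft_speaker, speaker_audio, voice_instructions],
#     )
#     clear_btn.click(clear_audio, inputs=[speaker_audio], outputs=[speaker_audio])
#     generate_btn.click(
#         gen_audio,
#         inputs=[text, mode, sft_speaker, speaker_audio, voice_instructions],
#         outputs=[output_audio],
#     )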