Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
9165825
1
Parent(s):
d4ece3a
Add examples, better UI
Browse files
app.py
CHANGED
@@ -5,84 +5,137 @@ import sys
|
|
5 |
import soundfile as sf
|
6 |
import numpy as np
|
7 |
import torch
|
|
|
8 |
|
9 |
repo_url = "https://huggingface.co/dangtr0408/StyleTTS2-lite-vi"
|
10 |
repo_dir = "StyleTTS2-lite-vi"
|
11 |
-
|
12 |
if not os.path.exists(repo_dir):
|
13 |
subprocess.run(["git", "clone", repo_url, repo_dir])
|
14 |
-
|
15 |
-
# Clone repo and load model
|
16 |
sys.path.append(os.path.abspath(repo_dir))
|
17 |
from inference import StyleTTS2
|
18 |
|
19 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
20 |
-
|
21 |
config_path = os.path.join(repo_dir, "Models", "config.yml")
|
22 |
models_path = os.path.join(repo_dir, "Models", "model.pth")
|
23 |
model = StyleTTS2(config_path, models_path).eval().to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
# Core inference function
|
26 |
-
def
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
|
30 |
-
for
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
#custom-box textarea {
|
52 |
-
min-height: 250px !important;
|
53 |
-
height: 100% !important;
|
54 |
-
}
|
55 |
-
"""
|
56 |
|
57 |
# Gradio UI
|
58 |
-
with gr.Blocks(
|
59 |
-
gr.
|
60 |
-
gr.Markdown(
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
inputs=[
|
78 |
text_prompt,
|
79 |
-
reference_audios,
|
80 |
-
n_merge,
|
81 |
denoise,
|
82 |
avg_style,
|
83 |
stabilize
|
84 |
],
|
85 |
-
outputs=synthesized_audio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
)
|
87 |
|
88 |
demo.launch()
|
|
|
5 |
import soundfile as sf
import numpy as np
import torch
import traceback

# Fetch the model repository (weights + inference code) on first run.
repo_url = "https://huggingface.co/dangtr0408/StyleTTS2-lite-vi"
repo_dir = "StyleTTS2-lite-vi"
if not os.path.exists(repo_dir):
    # check=True: fail loudly here if the clone fails, instead of with a
    # confusing ImportError at `from inference import StyleTTS2` below.
    subprocess.run(["git", "clone", repo_url, repo_dir], check=True)

# Make the cloned repo importable, then load the TTS model.
sys.path.append(os.path.abspath(repo_dir))
from inference import StyleTTS2

device = 'cuda' if torch.cuda.is_available() else 'cpu'
config_path = os.path.join(repo_dir, "Models", "config.yml")
models_path = os.path.join(repo_dir, "Models", "model.pth")
model = StyleTTS2(config_path, models_path).eval().to(device)

# Session state shared by the UI callbacks: uploaded file name -> local path.
session_uploaded_files = {}

# Bundled example voices/texts used by the gr.Examples widget.
voice_path = os.path.join(repo_dir, "reference_audio")
eg_voices = [os.path.join(voice_path, "vn_1.wav"), os.path.join(voice_path, "vn_2.wav")]
eg_texts = [
    "[id_1] Với [en-us]{StyleTTS2-lite} bạn có thể sử dụng [en-us]{language tag} để mô hình chắc chắn đọc bằng tiếng Anh, [id_2]cũng như sử dụng [en-us]{speaker tag} để chuyển đổi nhanh giữa các giọng đọc.",
    "[id_1]Chỉ với khoảng 90 triệu tham số, [id_2][en-us]{StyleTTS2-lite} có thể dễ dàng tạo giọng nói với tốc độ cao.",
]
|
28 |
+
|
29 |
|
30 |
# Core inference function
|
31 |
+
def main(text_prompt, denoise, avg_style, stabilize):
    """Synthesize `text_prompt` using the session's uploaded reference voices.

    Args:
        text_prompt: Text to speak; may contain [id_N] speaker tags and
            [lang]{...} language tags understood by the model.
        denoise: Denoise strength in [0, 1].
        avg_style: Whether to use averaged styles.
        stabilize: Whether to stabilize speaking speed.

    Returns:
        ("output.wav", success_message) on success, or (None, traceback_text)
        on failure so the UI shows the error instead of crashing.
    """
    try:
        # Uploaded references become speakers [id_1], [id_2], ... in upload
        # order — 1-based, matching the ids shown by on_file_upload().
        # (Reading the module-level dict needs no `global` declaration.)
        speakers = {
            f"id_{i}": {"path": path, "lang": "vi", "speed": 1.0}
            for i, path in enumerate(session_uploaded_files.values(), 1)
        }

        with torch.no_grad():
            r = model.generate(text_prompt, speakers, avg_style, stabilize, denoise, 20, "[id_1]")

        # Peak-normalize; guard against an all-zero (silent) output so we
        # don't divide by zero and write NaNs to the wav file.
        peak = np.abs(r).max()
        if peak > 0:
            r = r / peak
        sf.write("output.wav", r, samplerate=24000)

        return "output.wav", "Audio generated successfully!"

    except Exception:
        # Surface the full traceback through the UI status box.
        return None, traceback.format_exc()
|
55 |
+
|
56 |
+
def on_file_upload(file_list):
    """Record newly uploaded reference audios and return a status summary.

    Args:
        file_list: The gr.File value — a list of local file paths, or
            None/empty when nothing is uploaded.

    Returns:
        A human-readable status string listing the current reference audios
        with their speaker ids.
    """
    if not file_list:
        return "No file uploaded yet."

    # Key by basename so re-uploading the same file replaces the previous
    # entry instead of duplicating it. Mutating the dict in place needs no
    # `global` declaration.
    for file_path in file_list:
        session_uploaded_files[os.path.basename(file_path)] = file_path

    # Ids are 1-based to match the [id_1], [id_2], ... speaker tags assigned
    # by main() — the previous 0-based listing was off by one.
    uploaded_infos = [
        f"[id_{i}]: {name}"
        for i, name in enumerate(session_uploaded_files, 1)
    ]
    summary = "\n".join(uploaded_infos)
    return f"Current reference audios:\n{summary}"
|
72 |
+
|
73 |
+
def gen_example(text_prompt):
    """Run one gr.Examples row: synthesize `text_prompt` with the bundled voices.

    Temporarily registers the example reference audios, generates with the
    default settings (denoise=0.6, avg_style=True, stabilize=True), then
    clears the session state so the examples don't leak into a user session.
    """
    global session_uploaded_files
    on_file_upload(eg_voices)
    output, status = main(text_prompt, 0.6, True, True)
    # Reset session state. (The old on_file_upload(None) call was a no-op —
    # it only returned an ignored string; reassigning the dict is the reset.)
    session_uploaded_files = {}
    return output, status
|
81 |
+
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
# Gradio UI
|
84 |
+
# Gradio UI — component creation order determines on-page layout, so the
# structure below is kept in its original order.
with gr.Blocks() as demo:
    gr.HTML("<h1 style='text-align: center;'>StyleTTS2‑Lite Demo</h1>")
    gr.Markdown(
        "Download the local inference package from Hugging Face: "
        "[StyleTTS2‑Lite (Vietnamese)]"
        "(https://huggingface.co/dangtr0408/StyleTTS2-lite-vi/)."
    )
    gr.Markdown(
        "Please specify a language tag in your inputs if the word is not Vietnamese, e.g., [en-us]{ } for English. For more information, see "
        "[eSpeakNG docs]"
        "(https://github.com/espeak-ng/espeak-ng/blob/master/docs/languages.md)"
    )

    # Row 1: text input beside the generation options.
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            text_prompt = gr.Textbox(label="Text Prompt", placeholder="Enter your text here...", lines=4)
        with gr.Column(scale=1):
            avg_style = gr.Checkbox(label="Use Average Styles", value=True)
            stabilize = gr.Checkbox(label="Stabilize Speaking Speed", value=True)
            denoise = gr.Slider(0.0, 1.0, step=0.1, value=0.6, label="Denoise Strength")

    # Row 2: reference-audio uploads beside the synthesized output.
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            reference_audios = gr.File(label="Reference Audios", file_types=[".wav", ".mp3"], file_count="multiple", height=150)
            gen_button = gr.Button("Generate")
        with gr.Column(scale=1):
            # Label fixed: this component holds the *generated* audio.
            synthesized_audio = gr.Audio(label="Generated Audio", type="filepath")

    status = gr.Textbox(label="Status", interactive=False, lines=3)

    # Keep the session's reference list in sync with the file widget.
    reference_audios.change(
        on_file_upload,
        inputs=[reference_audios],
        outputs=[status]
    )

    gen_button.click(
        fn=main,
        inputs=[
            text_prompt,
            denoise,
            avg_style,
            stabilize
        ],
        outputs=[synthesized_audio, status]
    )

    # run_on_click with cache_examples=False so each example synthesizes live.
    gr.Examples(
        examples=[eg_texts[0], eg_texts[1]],
        inputs=[text_prompt],
        outputs=[synthesized_audio, status],
        fn=gen_example,
        cache_examples=False,
        label="Creation Examples",
        run_on_click=True
    )

demo.launch()
|