dangtr0408 committed
Commit 9165825 · 1 Parent(s): d4ece3a

Add examples, better UI

Files changed (1):
  1. app.py +108 -55
app.py CHANGED
@@ -5,84 +5,137 @@ import sys
 import soundfile as sf
 import numpy as np
 import torch
+import traceback
 
 repo_url = "https://huggingface.co/dangtr0408/StyleTTS2-lite-vi"
 repo_dir = "StyleTTS2-lite-vi"
-
 if not os.path.exists(repo_dir):
     subprocess.run(["git", "clone", repo_url, repo_dir])
-
-# Clone repo and load model
 sys.path.append(os.path.abspath(repo_dir))
 from inference import StyleTTS2
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
 config_path = os.path.join(repo_dir, "Models", "config.yml")
 models_path = os.path.join(repo_dir, "Models", "model.pth")
 model = StyleTTS2(config_path, models_path).eval().to(device)
+session_uploaded_files = {}
+voice_path = os.path.join(repo_dir, "reference_audio")
+eg_voices = [os.path.join(voice_path, "vn_1.wav"), os.path.join(voice_path, "vn_2.wav")]
+eg_texts = [
+    "[id_1] Với [en-us]{StyleTTS2-lite} bạn có thể sử dụng [en-us]{language tag} để mô hình chắc chắn đọc bằng tiếng Anh, [id_2]cũng như sử dụng [en-us]{speaker tag} để chuyển đổi nhanh giữa các giọng đọc.",
+    "[id_1]Chỉ với khoảng 90 triệu tham số, [id_2][en-us]{StyleTTS2-lite} có thể dễ dàng tạo giọng nói với tốc độ cao.",
+]
+
 
 # Core inference function
-def process_inputs(text_prompt, reference_audio_paths,
-                   n_merge, denoise, avg_style,stabilize):
-
-    speakers = {}
-    for i, path in enumerate(reference_audio_paths, 1):
-        speaker_id = f"id_{i}"
-        speakers[speaker_id] = {
-            "path": path,
-            "lang": "vi",
-            "speed": 1.0
-        }
-
-    with torch.no_grad():
-        r = model.generate(text_prompt, speakers, avg_style, stabilize, denoise, n_merge, "[id_1]")
-
-    r = r / np.abs(r).max()
-    sf.write("output.wav", r, samplerate=24000)
-    return "output.wav"
-
-custom_css = """
-#custom-box {
-    min-height: 300px !important;
-    display: flex;
-    align-items: center;
-}
-#custom-box textarea {
-    min-height: 250px !important;
-    height: 100% !important;
-}
-"""
+def main(text_prompt, denoise, avg_style, stabilize):
+    try:
+        global session_uploaded_files
+        reference_paths = [file for file in session_uploaded_files.values()]
+        speakers = {}
+        for i, path in enumerate(reference_paths, 1):
+            speaker_id = f"id_{i}"
+            speakers[speaker_id] = {
+                "path": path,
+                "lang": "vi",
+                "speed": 1.0
+            }
+
+        with torch.no_grad():
+            r = model.generate(text_prompt, speakers, avg_style, stabilize, denoise, 20, "[id_1]")
+
+        r = r / np.abs(r).max()
+        sf.write("output.wav", r, samplerate=24000)
+
+        return "output.wav", "Audio generated successfully!"
+
+    except Exception:
+        error_message = traceback.format_exc()
+        return None, error_message
+
+def on_file_upload(file_list):
+    if not file_list:
+        return "No file uploaded yet."
+
+    global session_uploaded_files
+    for file_path in file_list:
+        file_name = os.path.basename(file_path)
+        session_uploaded_files[file_name] = file_path  # update; re-uploads replace duplicates
+
+    uploaded_infos = []
+    uploaded_file_names = list(session_uploaded_files.keys())
+    for i in range(len(session_uploaded_files)):
+        uploaded_infos.append(f"[id_{i+1}]: {uploaded_file_names[i]}")  # speaker tags are 1-based
+
+    summary = "\n".join(uploaded_infos)
+    return f"Current reference audios:\n{summary}"
+
+def gen_example(text_prompt):
+    on_file_upload(eg_voices)
+    output, status = main(text_prompt, 0.6, True, True)
+    # Reset
+    on_file_upload(None)
+    global session_uploaded_files
+    session_uploaded_files = {}
+    return output, status
+
 
 # Gradio UI
-with gr.Blocks(css=custom_css) as demo:
-    gr.Markdown("## StyleTTS2-lite-vi Demo")
-    gr.Markdown("Upload a reference audio and input your text to synthesize speech with style control.")
-
-    with gr.Row():
-        text_prompt = gr.Textbox(label="Text Prompt", placeholder="Enter your text here...", elem_id="custom-box")
-        reference_audios = gr.File(label="Reference Audios", file_types=[".wav", ".mp3", ".flac"], file_count="multiple", elem_id="custom-box")
-    # Parameters
-    with gr.Accordion("Advanced Settings", open=False):
-        avg_style = gr.Checkbox(label="Use Average Styles", value=True)
-        stabilize = gr.Checkbox(label="Stabilize Speaking Speed", value=True)
-        denoise = gr.Slider(0.0, 1.0, value=0.6, label="Denoise Strength")
-        n_merge = gr.Slider(10, 30, value=16, label="Min Words to Merge")
-
-    submit_button = gr.Button("Synthesize")
-    synthesized_audio = gr.Audio(label="Synthesized Audio", type="filepath")
-
-    submit_button.click(
-        fn=process_inputs,
+with gr.Blocks() as demo:
+    gr.HTML("<h1 style='text-align: center;'>StyleTTS2-Lite Demo</h1>")
+    gr.Markdown(
+        "Download the local inference package from Hugging Face: "
+        "[StyleTTS2-Lite (Vietnamese)]"
+        "(https://huggingface.co/dangtr0408/StyleTTS2-lite-vi/)."
+    )
+    gr.Markdown(
+        "Please specify a language tag in your input if a word is not Vietnamese, "
+        "e.g., [en-us]{ } for English. For more information, see the "
+        "[eSpeakNG docs]"
+        "(https://github.com/espeak-ng/espeak-ng/blob/master/docs/languages.md)."
+    )
+
+    with gr.Row(equal_height=True):
+        with gr.Column(scale=1):
+            text_prompt = gr.Textbox(label="Text Prompt", placeholder="Enter your text here...", lines=4)
+        with gr.Column(scale=1):
+            avg_style = gr.Checkbox(label="Use Average Styles", value=True)
+            stabilize = gr.Checkbox(label="Stabilize Speaking Speed", value=True)
+            denoise = gr.Slider(0.0, 1.0, step=0.1, value=0.6, label="Denoise Strength")
+
+    with gr.Row(equal_height=True):
+        with gr.Column(scale=1):
+            reference_audios = gr.File(label="Reference Audios", file_types=[".wav", ".mp3"], file_count="multiple", height=150)
+            gen_button = gr.Button("Generate")
+        with gr.Column(scale=1):
+            synthesized_audio = gr.Audio(label="Generated Audio", type="filepath")
+
+    status = gr.Textbox(label="Status", interactive=False, lines=3)
+
+    reference_audios.change(
+        on_file_upload,
+        inputs=[reference_audios],
+        outputs=[status]
+    )
+
+    gen_button.click(
+        fn=main,
         inputs=[
             text_prompt,
-            reference_audios,
-            n_merge,
             denoise,
             avg_style,
             stabilize
         ],
-        outputs=synthesized_audio
+        outputs=[synthesized_audio, status]
+    )
+
+    gr.Examples(
+        examples=[eg_texts[0], eg_texts[1]],
+        inputs=[text_prompt],
+        outputs=[synthesized_audio, status],
+        fn=gen_example,
+        cache_examples=False,
+        label="Creation Examples",
+        run_on_click=True
     )
 
 demo.launch()
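
Note: for running the committed model outside of Gradio, the sketch below mirrors the calls app.py makes after this commit: the clone-relative paths, the StyleTTS2 constructor, and the generate(text, speakers, avg_style, stabilize, denoise, n_merge, default_tag) argument order are all taken from the diff above. It is a minimal sketch, not part of the commit; the reference file my_voice.wav and the prompt string are placeholders.

    # Minimal standalone sketch (not part of the commit). Assumes the repo was
    # already cloned next to this script, exactly as app.py does above.
    # "my_voice.wav" and the prompt text are placeholder values.
    import os
    import sys

    import numpy as np
    import soundfile as sf
    import torch

    repo_dir = "StyleTTS2-lite-vi"
    sys.path.append(os.path.abspath(repo_dir))
    from inference import StyleTTS2  # module shipped inside the cloned repo

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = StyleTTS2(os.path.join(repo_dir, "Models", "config.yml"),
                      os.path.join(repo_dir, "Models", "model.pth")).eval().to(device)

    # One reference recording per speaker tag; [id_1] is used as the default voice.
    speakers = {"id_1": {"path": "my_voice.wav", "lang": "vi", "speed": 1.0}}

    with torch.no_grad():
        # Same argument order as app.py: text, speakers, avg_style, stabilize,
        # denoise, n_merge, default speaker tag.
        wav = model.generate("[id_1]Xin chào.", speakers, True, True, 0.6, 20, "[id_1]")

    wav = wav / np.abs(wav).max()                  # peak-normalize, as app.py does
    sf.write("output.wav", wav, samplerate=24000)  # the demo writes 24 kHz audio

The literals 0.6 and 20 mirror the demo's default denoise strength and the n_merge value now hard-coded in main(); before this commit, n_merge was exposed as a UI slider.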