dangtr0408 committed
Commit 5763d4c · 1 Parent(s): 9165825

Optimize gradio

Files changed (1):
  1. app.py +23 -31
app.py CHANGED
@@ -18,20 +18,17 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 config_path = os.path.join(repo_dir, "Models", "config.yml")
 models_path = os.path.join(repo_dir, "Models", "model.pth")
 model = StyleTTS2(config_path, models_path).eval().to(device)
-session_uploaded_files = {}
 voice_path = os.path.join(repo_dir, "reference_audio")
 eg_voices = [os.path.join(voice_path,"vn_1.wav"), os.path.join(voice_path,"vn_2.wav")]
 eg_texts = [
+    "Chỉ với khoảng 90 triệu tham số, [en-us]{StyleTTS2-lite} có thể dễ dàng tạo giọng nói với tốc độ cao.",
     "[id_1] Với [en-us]{StyleTTS2-lite} bạn có thể sử dụng [en-us]{language tag} để mô hình chắc chắn đọc bằng tiếng Anh, [id_2]cũng như sử dụng [en-us]{speaker tag} để chuyển đổi nhanh giữa các giọng đọc.",
-    "[id_1]Chỉ với khoảng 90 triệu tham số, [id_2][en-us]{StyleTTS2-lite} có thể dễ dàng tạo giọng nói với tốc độ cao.",
 ]
 
 
 # Core inference function
-def main(text_prompt, denoise, avg_style, stabilize):
+def main(reference_paths, text_prompt, denoise, avg_style, stabilize):
     try:
-        global session_uploaded_files
-        reference_paths = [file for file in session_uploaded_files.values()]
         speakers = {}
         for i, path in enumerate(reference_paths, 1):
             speaker_id = f"id_{i}"
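The key change in this hunk: main no longer reads a module-level session_uploaded_files dict but receives reference_paths as a parameter, so each request only sees the files its own caller passed in. A minimal sketch of the same pattern, with hypothetical component and function names (not from the commit):

import gradio as gr

# Sketch only: component values arrive as handler arguments, so no
# module-level state is shared between concurrent user sessions.
def synthesize(reference_paths, text):
    # reference_paths holds the gr.Files value for this session alone.
    return f"{len(reference_paths or [])} reference file(s); text={text!r}"

with gr.Blocks() as demo:
    files = gr.Files(label="Reference audios")
    text = gr.Textbox(label="Text")
    out = gr.Textbox(label="Result")
    gr.Button("Run").click(synthesize, inputs=[files, text], outputs=[out])

demo.launch()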
@@ -42,11 +39,10 @@ def main(text_prompt, denoise, avg_style, stabilize):
             }
 
         with torch.no_grad():
-            r = model.generate(text_prompt, speakers, avg_style, stabilize, denoise, 20, "[id_1]")
-
-            r = r / np.abs(r).max()
+            r = model.generate(text_prompt, speakers, avg_style, stabilize, denoise, 18, "[id_1]") #Should seperate style computation process to style caching.
+            r = r / np.abs(r).max()
+
             sf.write("output.wav", r, samplerate=24000)
-
         return "output.wav", "Audio generated successfully!"
 
     except Exception as e:
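Both sides peak-normalize the waveform before writing: r / np.abs(r).max() rescales so the loudest sample sits at ±1.0, which keeps output.wav from clipping. A standalone sketch of that step (the eps guard against all-silent audio is an addition, not in the commit):

import numpy as np
import soundfile as sf

def peak_normalize(r: np.ndarray, eps: float = 1e-9) -> np.ndarray:
    # Same math as the diff's r / np.abs(r).max(); the eps floor is an
    # added safety net so an all-zero buffer cannot divide by zero.
    return r / max(np.abs(r).max(), eps)

# Usage: write 24 kHz audio the way app.py does.
r = peak_normalize(np.sin(np.linspace(0, 2 * np.pi * 440, 24000)) * 0.1)
sf.write("output.wav", r, samplerate=24000)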
@@ -55,29 +51,24 @@ def main(text_prompt, denoise, avg_style, stabilize):
 
 def on_file_upload(file_list):
     if not file_list:
-        return "No file uploaded yet."
+        return None, "No file uploaded yet."
 
-    global session_uploaded_files
+    unique_files = {}
     for file_path in file_list:
         file_name = os.path.basename(file_path)
-        session_uploaded_files[file_name] = file_path #update and remove duplicate
+        unique_files[file_name] = file_path #update and remove duplicate
 
     uploaded_infos = []
-    uploaded_file_names = list(session_uploaded_files.keys())
-    for i in range(len(session_uploaded_files)):
-        uploaded_infos.append(f"[id_{i}]: {uploaded_file_names[i]}")
+    uploaded_file_names = list(unique_files.keys())
+    for i in range(len(uploaded_file_names)):
+        uploaded_infos.append(f"[id_{i+1}]: {uploaded_file_names[i]}")
 
     summary = "\n".join(uploaded_infos)
-    return f"Current reference audios:\n{summary}"
+    return list(unique_files.values()), f"Current reference audios:\n{summary}"
 
-def gen_example(text_prompt):
-    on_file_upload(eg_voices)
-    output, status = main(text_prompt, 0.6, True, True)
-    #Reset
-    on_file_upload(None)
-    global session_uploaded_files
-    session_uploaded_files = {}
-    return output, status
+def gen_example(reference_paths, text_prompt):
+    output, status = main(reference_paths, text_prompt, 0.6, True, True)
+    return output, eg_voices, status
 
 
 # Gradio UI
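on_file_upload now keeps its bookkeeping in a local unique_files dict keyed by basename, so re-uploading a file with the same name replaces the old entry instead of accumulating global state, and the deduplicated paths are returned to the UI. The idea in isolation, as a self-contained sketch with illustrative names:

import os

def dedupe_by_basename(paths):
    # Later paths with the same filename overwrite earlier ones; dict
    # insertion order keeps the id_1, id_2, ... numbering stable.
    unique = {os.path.basename(p): p for p in paths}
    labels = [f"[id_{i}]: {name}" for i, name in enumerate(unique, 1)]
    return list(unique.values()), labels

paths, labels = dedupe_by_basename(["/tmp/a.wav", "/tmp/b.wav", "/cache/a.wav"])
# paths  -> ['/cache/a.wav', '/tmp/b.wav']   (the second a.wav wins)
# labels -> ['[id_1]: a.wav', '[id_2]: b.wav']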
@@ -112,14 +103,15 @@ with gr.Blocks() as demo:
     status = gr.Textbox(label="Status", interactive=False, lines=3)
 
     reference_audios.change(
-        on_file_upload,
-        inputs=[reference_audios],
-        outputs=[status]
+        on_file_upload,
+        inputs=[reference_audios],
+        outputs=[reference_audios, status]
     )
 
     gen_button.click(
         fn=main,
         inputs=[
+            reference_audios,
             text_prompt,
             denoise,
             avg_style,
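The rewired handler lists reference_audios among its own outputs, so the deduplicated list that on_file_upload returns is written straight back into the upload widget. A minimal sketch of a component consuming its own event output, with hypothetical names:

import gradio as gr

# Sketch only: a component may appear in the outputs of its own .change
# handler, so the cleaned-up value replaces the widget's current value,
# as the diff does with reference_audios.
def clean_files(files):
    files = files or []
    deduped = list(dict.fromkeys(files))  # stand-in for the app's dedupe
    return deduped, f"{len(deduped)} reference file(s)"

with gr.Blocks() as demo:
    uploads = gr.Files(label="Reference audios")
    status = gr.Textbox(label="Status", interactive=False)
    uploads.change(clean_files, inputs=[uploads], outputs=[uploads, status])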
@@ -129,12 +121,12 @@ with gr.Blocks() as demo:
     )
 
     gr.Examples(
-        examples=[eg_texts[0], eg_texts[1]],
-        inputs=[text_prompt],
-        outputs=[synthesized_audio, status],
+        examples=[[eg_voices, eg_texts[0]], [eg_voices, eg_texts[1]]],
+        inputs=[reference_audios, text_prompt],
+        outputs=[synthesized_audio, reference_audios, status],
         fn=gen_example,
         cache_examples=False,
-        label="Creation Examples",
+        label="Examples",
         run_on_click=True
     )
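Each example row now carries both the reference voices and the prompt, and because run_on_click=True with cache_examples=False, clicking a row calls gen_example live instead of replaying a cached result. A stripped-down sketch of that gr.Examples wiring (hypothetical names; the wav paths are placeholders that must exist locally):

import gradio as gr

def gen_example(files, text):
    # Stand-in for the app's gen_example: synthesize, then echo the
    # example's reference files back into the Files component.
    return f"synthesized: {text!r}", files

with gr.Blocks() as demo:
    files = gr.Files(label="Reference audios")
    text = gr.Textbox(label="Text prompt")
    out = gr.Textbox(label="Output")
    gr.Examples(
        examples=[[["vn_1.wav", "vn_2.wav"], "example prompt"]],
        inputs=[files, text],
        outputs=[out, files],
        fn=gen_example,
        cache_examples=False,  # run live, no precomputed outputs
        run_on_click=True,     # clicking the row invokes fn
    )

demo.launch()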