lalalic committed on
Commit
84202be
·
verified ·
1 Parent(s): b283951

Update xtts.py

Browse files
Files changed (1) hide show
  1. xtts.py +47 -23
xtts.py CHANGED
@@ -14,7 +14,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
14
  # def upload_bytes(bytes, ext=".wav"):
15
  # return bytes
16
 
17
- from qili import upload_bytes
18
  # if __name__ == "__main__":
19
  # app = Flask(__name__)
20
  # else:
@@ -32,29 +32,35 @@ if not os.path.exists(sample_root):
32
  default_sample=f'{os.path.dirname(os.path.abspath(__file__))}/sample.wav', f'{sample_root}/sample.pt'
33
  ffmpeg=f'{os.path.dirname(os.path.abspath(__file__))}/ffmpeg'
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def predict(text, sample=None, language="zh"):
36
  global tts
37
  global model
38
  try:
39
- if tts is None:
40
- model_dir=os.environ.get("MODEL_DIR")
41
- model_path=model_dir
42
- config_path=f'{model_dir}/config.json'
43
- vocoder_config_path=f'{model_dir}/vocab.json'
44
- model_name="tts_models/multilingual/multi-dataset/xtts_v2"
45
- logging.info(f"loading model {model_name} ...")
46
- tts = TTS(
47
- # model_name,
48
- model_path=model_path,
49
- config_path=config_path,
50
- vocoder_config_path=vocoder_config_path,
51
- progress_bar=True
52
- )
53
- model=tts.synthesizer.tts_model
54
- #hack to use cache
55
- model.__get_conditioning_latents=model.get_conditioning_latents
56
- model.get_conditioning_latents=get_conditioning_latents
57
- logging.info("model is ready")
58
  text= re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)",r"\1 \2\2",text)
59
  wav = tts.tts(
60
  text,
@@ -143,7 +149,7 @@ def trim_sample_audio(speaker_wav):
143
  from flask import Flask, request
144
  app = Flask(__name__)
145
  @app.route("/tts")
146
- def convert():
147
  text = request.args.get('text')
148
  sample = request.args.get('sample')
149
  language = request.args.get('language')
@@ -158,11 +164,29 @@ def convert():
158
 
159
  # @app.get("/play")
160
  # def play(text: str=Query(None), sample: str=Query(None), language: str=Query('zh')):
161
- @app.route("/play")
162
- def play():
163
  url=convert()
164
  return playInHTML(url)
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  # import gradio as gr
167
  # demo=gr.Interface(predict, inputs=["text", "text"], outputs=gr.Audio())
168
  # app = gr.mount_gradio_app(app, demo, path="/")
 
14
  # def upload_bytes(bytes, ext=".wav"):
15
  # return bytes
16
 
17
+ from qili import upload, upload_bytes
18
  # if __name__ == "__main__":
19
  # app = Flask(__name__)
20
  # else:
 
32
  default_sample=f'{os.path.dirname(os.path.abspath(__file__))}/sample.wav', f'{sample_root}/sample.pt'
33
  ffmpeg=f'{os.path.dirname(os.path.abspath(__file__))}/ffmpeg'
34
 
35
def get_tts():
    """Load the TTS model on first use and return the shared instance.

    Reads the model directory from the ``MODEL_DIR`` environment variable,
    builds the ``TTS`` object from the files in that directory, and stores it
    in the module-level ``tts`` global (with its underlying model in
    ``model``). Later calls return the already-loaded instance.

    Returns:
        The loaded ``TTS`` instance (also available as the global ``tts``).

    Raises:
        RuntimeError: if ``MODEL_DIR`` is not set.
    """
    global tts
    global model
    # A plain ``tts is None`` check is not sufficient: this file also defines
    # a Flask view function named ``tts`` at module level, which rebinds the
    # global to a function object. Treat anything that is not a TTS instance
    # as "not loaded yet". Flask keeps its own reference to the view function,
    # so rebinding the global here does not affect the /tts route.
    if isinstance(tts, TTS):
        return tts

    model_dir = os.environ.get("MODEL_DIR")
    if not model_dir:
        # Without this check the paths below would silently become
        # "None/config.json" and fail later with a far less helpful error.
        raise RuntimeError(
            "MODEL_DIR environment variable is not set; cannot load the TTS model"
        )

    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    logging.info("loading model %s ...", model_name)
    tts = TTS(
        # model_name,
        model_path=model_dir,
        config_path=f'{model_dir}/config.json',
        vocoder_config_path=f'{model_dir}/vocab.json',
        progress_bar=True,
    )
    model = tts.synthesizer.tts_model
    # Hack to use the cache: keep the model's own method reachable under a
    # private attribute and substitute this file's get_conditioning_latents.
    model.__get_conditioning_latents = model.get_conditioning_latents
    model.get_conditioning_latents = get_conditioning_latents
    logging.info("model is ready")
    return tts
57
+
58
+
59
  def predict(text, sample=None, language="zh"):
60
  global tts
61
  global model
62
  try:
63
+ get_tts()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  text= re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)",r"\1 \2\2",text)
65
  wav = tts.tts(
66
  text,
 
149
  from flask import Flask, request
150
  app = Flask(__name__)
151
  @app.route("/tts")
152
+ def tts():
153
  text = request.args.get('text')
154
  sample = request.args.get('sample')
155
  language = request.args.get('language')
 
164
 
165
  # @app.get("/play")
166
  # def play(text: str=Query(None), sample: str=Query(None), language: str=Query('zh')):
167
@app.route("/tts/play")
def tts_play():
    """Run the ``/tts`` handler for the current request and wrap its result in an HTML player.

    Returns:
        Whatever ``playInHTML`` produces for the URL returned by the ``/tts``
        view function.
    """
    # The ``/tts`` view used to be called ``convert``; it is now registered
    # under the endpoint name "tts". Look it up through Flask's registry
    # rather than by global name, because the module-level name ``tts`` is
    # also used for the loaded model and may not refer to the view function.
    url = app.view_functions["tts"]()
    return playInHTML(url)
171
 
172
@app.route("/clone")
def clone():
    """Convert the voice of a source recording to that of a sample and upload the result.

    Query parameters:
        source: audio whose speech content is kept (passed as ``source_wav``).
        sample: audio providing the target voice (passed as ``target_wav``).

    Returns:
        The value returned by ``upload`` for the generated ``.wav`` file.
    """
    from flask import abort

    source = request.args.get('source')
    sample = request.args.get('sample')
    if not source or not sample:
        # Fail fast with a client error instead of passing None to the model.
        abort(400, description="both 'source' and 'sample' query parameters are required")

    get_tts()

    # tempfile.mktemp() takes no ``delete`` argument and returns a plain
    # string, so the previous call raised TypeError. mkstemp() creates the
    # file safely and returns (fd, path); only the path is needed here.
    fd, output = tempfile.mkstemp(suffix=".wav")
    os.close(fd)

    # NOTE(review): this relies on the module-level ``tts`` being the loaded
    # model. The /tts view function in this file is also named ``tts`` —
    # confirm the global is not shadowed by it when this runs.
    tts.voice_conversion_to_file(
        source_wav=source,
        target_wav=sample,
        file_path=output,
    )
    # NOTE(review): the temporary file is left on disk after upload, matching
    # the original intent (delete=False); consider removing it once upload()
    # is known to have finished reading it.
    return upload(output)
184
+
185
@app.route("/clone/play")
def clone_play():
    """Run ``clone`` for the current request and return its result wrapped by ``playInHTML``."""
    return playInHTML(clone())
189
+
190
  # import gradio as gr
191
  # demo=gr.Interface(predict, inputs=["text", "text"], outputs=gr.Audio())
192
  # app = gr.mount_gradio_app(app, demo, path="/")