RSHVR committed on
Commit
2c564da
·
verified ·
1 Parent(s): 8d98b9d

Update tts.py

Browse files
Files changed (1) hide show
  1. tts.py +82 -39
tts.py CHANGED
@@ -1,7 +1,11 @@
 
1
  import os
2
  import torch
3
  import torchaudio
4
  import spaces
 
 
 
5
  from tortoise.api import TextToSpeech
6
  from tortoise.utils.audio import load_audio
7
 
@@ -11,48 +15,87 @@ os.makedirs("outputs", exist_ok=True)
11
  # Module-wide Tortoise-TTS model shared by all callers; stays None until it is first initialized
12
  tts_model = None
13
 
14
- # Synchronous function with GPU decorator
15
@spaces.GPU
def _generate_speech_gpu(text, voice_preset="random", voice_file_path=None):
    """Synthesize ``text`` with Tortoise-TTS, save it as a WAV file and return it.

    Args:
        text: Text to speak.
        voice_preset: Either a Tortoise quality preset ("ultra_fast", "fast",
            "standard", "high_quality") or a voice name understood by
            ``tortoise.utils.audio.load_voice`` (e.g. "random").
        voice_file_path: Optional path to a reference clip. When the file
            exists it is used as the voice and any voice name is ignored.

    Returns:
        ``(output_filename, (24000, audio))`` where ``audio`` is the generated
        CPU tensor that was also written to ``output_filename``.

    Raises:
        Exception: whatever model loading, generation or saving raises; the
            error is printed and re-raised unchanged.
    """
    global tts_model

    # tts_with_preset() only accepts these keys as `preset`; anything else
    # (including "random" or None) raises KeyError inside Tortoise.
    quality_presets = ("ultra_fast", "fast", "standard", "high_quality")

    try:
        # Initialize the model if not already initialized
        if tts_model is None:
            print("Initializing Tortoise-TTS model...")
            tts_model = TextToSpeech(use_deepspeed=torch.cuda.is_available())
            print(f"Model initialized. Using device: {next(tts_model.autoregressive.parameters()).device}")

        # `voice_preset` historically carried either meaning, so accept both:
        # a quality preset keeps the random voice, anything else is a voice name.
        if voice_preset in quality_presets:
            quality, voice_name = voice_preset, "random"
        else:
            quality, voice_name = "fast", voice_preset or "random"

        # Process voice sample if provided
        if voice_file_path and os.path.exists(voice_file_path):
            print(f"Loading voice from {voice_file_path}")
            # load_audio returns one tensor (not a tuple), so do not unpack it.
            voice_samples = [load_audio(voice_file_path, 22050)]
            conditioning_latents = None
        else:
            # Local import: only needed on this branch and keeps the block self-contained.
            from tortoise.utils.audio import load_voice

            voice_samples, conditioning_latents = load_voice(voice_name)

        # Generate speech
        print(f"Generating speech for text: {text[:50]}...")
        output_filename = f"outputs/tts_output_{hash(text) % 10000}.wav"

        gen = tts_model.tts_with_preset(
            text,
            voice_samples=voice_samples,
            conditioning_latents=conditioning_latents,
            preset=quality,
        )

        # Move to CPU once and reuse the tensor for both saving and returning.
        audio = gen.squeeze(0).cpu()

        # Save the generated audio
        torchaudio.save(output_filename, audio, 24000)
        print(f"Speech generated and saved to {output_filename}")

        # Return the filename and audio data
        return output_filename, (24000, audio)

    except Exception as e:
        print(f"Error generating speech: {str(e)}")
        raise
 
 
 
 
 
54
 
55
- # Async wrapper that calls the GPU function
56
async def generate_speech(text, voice_preset="random", voice_file_path=None):
    """Async entry point that delegates to the GPU-decorated synchronous generator.

    The call is made directly (not awaited or offloaded), and its return value
    is handed back to the caller unchanged.
    """
    result = _generate_speech_gpu(text, voice_preset, voice_file_path)
    return result
 
 
1
+ # tts.py
2
  import os
3
  import torch
4
  import torchaudio
5
  import spaces
6
+ import numpy as np
7
+ from typing import AsyncGenerator, Generator, Optional, Protocol, Tuple, Union
8
+ from numpy.typing import NDArray
9
  from tortoise.api import TextToSpeech
10
  from tortoise.utils.audio import load_audio
11
 
 
15
  # Module-wide Tortoise-TTS model shared by all callers; stays None until it is first initialized
16
  tts_model = None
17
 
18
+ # Define TTSOptions for compatibility with FastRTC
19
+ class TortoiseOptions:
20
+ def __init__(self, voice_preset="random", voice_file_path=None):
21
+ self.voice_preset = voice_preset
22
+ self.voice_file_path = voice_file_path
23
+
24
+ # The main Tortoise TTS wrapper class implementing FastRTC's TTSModel protocol
25
class TortoiseTTSModel:
    """Tortoise-TTS wrapper exposing ``tts``, ``stream_tts`` and ``stream_tts_sync``.

    The underlying ``TextToSpeech`` model is kept in the module-level
    ``tts_model`` global so it is built at most once per process.

    NOTE(review): ``_initialize_model`` assigns that global from inside a
    ``@spaces.GPU`` function. If the decorator runs the call in a separate
    worker process (as ZeroGPU is understood to do), the assignment would not
    be visible here and ``self.tts_model`` would stay ``None`` - confirm on
    the target hardware.
    """

    # Number of samples per streamed chunk; override on the class or instance to tune.
    chunk_size = 4000

    # The only keys TextToSpeech.tts_with_preset() accepts as `preset`.
    quality_presets = ("ultra_fast", "fast", "standard", "high_quality")
    default_quality_preset = "fast"

    def __init__(self):
        global tts_model
        if tts_model is None:
            self._initialize_model()
        self.tts_model = tts_model

    @spaces.GPU
    def _initialize_model(self):
        """Build the shared ``TextToSpeech`` model and store it in ``tts_model``."""
        global tts_model
        print("Initializing Tortoise-TTS model...")
        tts_model = TextToSpeech(use_deepspeed=torch.cuda.is_available())
        print(f"Model initialized. Using device: {next(tts_model.autoregressive.parameters()).device}")

    def _resolve_voice(self, options):
        """Translate ``options`` into ``(voice_samples, conditioning_latents, quality_preset)``."""
        name = options.voice_preset
        # `voice_preset` may be a quality preset or a voice name; passing a
        # voice name such as "random" (or None) as `preset` raises KeyError
        # inside Tortoise, so split the two meanings here.
        if name in self.quality_presets:
            quality, voice_name = name, "random"
        else:
            quality, voice_name = self.default_quality_preset, name or "random"

        path = options.voice_file_path
        if path and os.path.exists(path):
            print(f"Loading voice from {path}")
            # load_audio returns one tensor (not a tuple), so do not unpack it.
            return [load_audio(path, 22050)], None, quality

        # Local import: only needed on this branch and keeps the block self-contained.
        from tortoise.utils.audio import load_voice

        voice_samples, conditioning_latents = load_voice(voice_name)
        return voice_samples, conditioning_latents, quality

    @spaces.GPU
    def _generate_speech(self, text, options=None):
        """Run one full synthesis and return ``(24000, float32 numpy audio)``.

        Errors are printed and re-raised unchanged.
        """
        options = options or TortoiseOptions()

        try:
            # Process voice sample / voice name
            voice_samples, conditioning_latents, quality = self._resolve_voice(options)

            # Generate speech
            print(f"Generating speech for text: {text[:50]}...")

            gen = self.tts_model.tts_with_preset(
                text,
                voice_samples=voice_samples,
                conditioning_latents=conditioning_latents,
                preset=quality,
            )

            # Return the audio data with sample rate
            return 24000, gen.squeeze(0).cpu().numpy().astype(np.float32)

        except Exception as e:
            print(f"Error generating speech: {str(e)}")
            raise

    def _iter_chunks(self, sample_rate, audio_array):
        """Yield ``(sample_rate, chunk)`` pairs of at most ``chunk_size`` samples."""
        # Slice the LAST axis: slicing axis 0 of a (1, N) array would return the
        # whole clip as a single chunk. Works for both (N,) and (1, N) layouts.
        total_samples = audio_array.shape[-1]
        for start in range(0, total_samples, self.chunk_size):
            yield sample_rate, audio_array[..., start:start + self.chunk_size]

    def tts(self, text: str, options: Optional[TortoiseOptions] = None) -> Tuple[int, NDArray[np.float32]]:
        """Generate speech audio from text in a single call"""
        return self._generate_speech(text, options)

    async def stream_tts(self, text: str, options: Optional[TortoiseOptions] = None) -> AsyncGenerator[Tuple[int, NDArray[np.float32]], None]:
        """Stream speech audio asynchronously in chunks

        The whole clip is synthesized first (a blocking call) and then split,
        so the first chunk is only available once generation has finished.
        """
        sample_rate, audio_array = self._generate_speech(text, options)

        # Split audio into chunks for streaming
        for item in self._iter_chunks(sample_rate, audio_array):
            yield item

    def stream_tts_sync(self, text: str, options: Optional[TortoiseOptions] = None) -> Generator[Tuple[int, NDArray[np.float32]], None, None]:
        """Stream speech audio synchronously in chunks

        The whole clip is synthesized first and then split into chunks.
        """
        sample_rate, audio_array = self._generate_speech(text, options)

        # Split audio into chunks for streaming
        yield from self._iter_chunks(sample_rate, audio_array)
93
+
94
# Module-level singleton for easy import; constructing it initializes the
# shared model if that has not happened yet.
tortoise_tts = TortoiseTTSModel()


# Legacy function for backward compatibility
async def generate_speech(text, voice_preset="random", voice_file_path=None):
    """Generate speech, write it to a WAV file and return it (legacy API).

    Args:
        text: Text to speak.
        voice_preset: Preset name forwarded through ``TortoiseOptions``.
        voice_file_path: Optional path to a reference voice clip.

    Returns:
        ``(output_filename, (sample_rate, audio_tensor))`` - the same shape of
        result the previous implementation returned.
    """
    options = TortoiseOptions(voice_preset, voice_file_path)
    sample_rate, audio_array = tortoise_tts.tts(text, options)
    audio_tensor = torch.from_numpy(audio_array)

    output_filename = f"outputs/tts_output_{hash(text) % 10000}.wav"

    # The returned path must point at a real file, as it did before the
    # refactor. torchaudio.save expects a (channels, samples) tensor, so a
    # 1-D result gets a channel axis for saving only.
    to_save = audio_tensor if audio_tensor.dim() > 1 else audio_tensor.unsqueeze(0)
    torchaudio.save(output_filename, to_save, sample_rate)
    print(f"Speech generated and saved to {output_filename}")

    return output_filename, (sample_rate, audio_tensor)