# tts.py
import os
import torch
import torchaudio
import spaces
import numpy as np
from typing import AsyncGenerator, Generator, Optional, Tuple
from numpy.typing import NDArray
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_audio, load_voices

# Create the cache/output directory
os.makedirs("outputs", exist_ok=True)

# Global TTS model instance, initialized lazily on first use
tts_model = None


# Options object passed to the model, mirroring FastRTC's TTSOptions shape
class TortoiseOptions:
    def __init__(self, voice_preset="random", voice_file_path=None):
        self.voice_preset = voice_preset
        self.voice_file_path = voice_file_path
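
# Illustrative only (the WAV path below is a hypothetical example): the two
# ways to pick a voice are a bundled Tortoise voice name, or a reference
# clip for voice cloning.
#   opts_named = TortoiseOptions(voice_preset="random")
#   opts_clone = TortoiseOptions(voice_file_path="voices/my_voice.wav")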


# The main Tortoise-TTS wrapper class implementing FastRTC's TTSModel protocol
class TortoiseTTSModel:
    def __init__(self):
        global tts_model
        if tts_model is None:
            self._initialize_model()
        self.tts_model = tts_model

    @spaces.GPU
    def _initialize_model(self):
        global tts_model
        print("Initializing Tortoise-TTS model...")
        tts_model = TextToSpeech(use_deepspeed=torch.cuda.is_available())
        print(f"Model initialized. Using device: {next(tts_model.autoregressive.parameters()).device}")

    @spaces.GPU
    def _generate_speech(self, text, options=None):
        options = options or TortoiseOptions()
        try:
            voice_samples = None
            conditioning_latents = None
            if options.voice_file_path and os.path.exists(options.voice_file_path):
                # Clone from a reference clip; Tortoise expects a list of
                # 22.05 kHz audio tensors
                print(f"Loading voice from {options.voice_file_path}")
                voice_samples = [load_audio(options.voice_file_path, 22050)]
            else:
                # Treat voice_preset as a Tortoise voice name; for "random",
                # load_voices returns (None, None), which tells the model to
                # sample a random voice
                voice_samples, conditioning_latents = load_voices([options.voice_preset])

            # Generate speech ("fast" is a quality preset, distinct from the
            # voice selection above)
            print(f"Generating speech for text: {text[:50]}...")
            gen = self.tts_model.tts_with_preset(
                text,
                voice_samples=voice_samples,
                conditioning_latents=conditioning_latents,
                preset="fast",
            )

            # Return (sample_rate, mono float32 waveform); Tortoise outputs
            # audio at 24 kHz
            return 24000, gen.squeeze().cpu().numpy().astype(np.float32)
        except Exception as e:
            print(f"Error generating speech: {e}")
            raise

    def tts(self, text: str, options: Optional[TortoiseOptions] = None) -> Tuple[int, NDArray[np.float32]]:
        """Generate speech audio from text in a single call."""
        return self._generate_speech(text, options)

    async def stream_tts(self, text: str, options: Optional[TortoiseOptions] = None) -> AsyncGenerator[Tuple[int, NDArray[np.float32]], None]:
        """Stream speech audio asynchronously in chunks."""
        for sample_rate, chunk in self.stream_tts_sync(text, options):
            yield sample_rate, chunk

    def stream_tts_sync(self, text: str, options: Optional[TortoiseOptions] = None) -> Generator[Tuple[int, NDArray[np.float32]], None, None]:
        """Stream speech audio synchronously in chunks."""
        sample_rate, audio_array = self._generate_speech(text, options)
        # Split the 1-D waveform into fixed-size chunks for streaming
        chunk_size = 4000  # adjust chunk size as needed
        for i in range(0, len(audio_array), chunk_size):
            yield sample_rate, audio_array[i:i + chunk_size]


# Create a singleton instance for easy import
tortoise_tts = TortoiseTTSModel()
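
# Sketch of wiring this into FastRTC (assumptions: fastrtc's Stream /
# ReplyOnPause API; `respond` and its STT/LLM internals are hypothetical
# placeholders, not part of this module):
#
#   from fastrtc import Stream, ReplyOnPause
#
#   def respond(audio):
#       text = "..."  # transcribe `audio` and produce a reply here
#       yield from tortoise_tts.stream_tts_sync(text)
#
#   stream = Stream(ReplyOnPause(respond), modality="audio", mode="send-receive")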


# Legacy function kept for backward compatibility
async def generate_speech(text, voice_preset="random", voice_file_path=None):
    options = TortoiseOptions(voice_preset, voice_file_path)
    sample_rate, audio_array = tortoise_tts.tts(text, options)
    audio_tensor = torch.from_numpy(audio_array)
    # Write the audio to disk so the returned path points at a real file
    output_path = f"outputs/tts_output_{hash(text) % 10000}.wav"
    torchaudio.save(output_path, audio_tensor.unsqueeze(0), sample_rate)
    return output_path, (sample_rate, audio_tensor)