File size: 4,083 Bytes
2c564da
60b8847
 
 
2210c38
2c564da
 
 
60b8847
 
 
 
 
 
 
 
 
2c564da
 
 
 
 
 
 
 
 
 
60b8847
2c564da
 
 
 
 
 
 
 
 
 
 
 
 
60b8847
2c564da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60b8847
2c564da
 
 
 
 
 
 
 
 
60b8847
2c564da
 
 
 
 
 
 
 
2210c38
2c564da
2210c38
2c564da
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# tts.py
import os
import torch
import torchaudio
import spaces
import numpy as np
from typing import AsyncGenerator, Generator, Optional, Protocol, Tuple, Union
from numpy.typing import NDArray
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_audio

# Create cache/output directory.
# Runs at import time; the path is relative to the current working directory,
# and exist_ok=True makes it a no-op when the directory is already there.
os.makedirs("outputs", exist_ok=True)

# Create a global TTS model instance.
# Starts as None; TortoiseTTSModel._initialize_model assigns the loaded
# TextToSpeech object here the first time a TortoiseTTSModel is constructed,
# so later instances reuse it instead of loading the model again.
tts_model = None

# Define TTSOptions for compatibility with FastRTC
class TortoiseOptions:
    """Per-request voice settings passed to the Tortoise TTS wrapper.

    Attributes are plain and public:
      voice_preset     -- preset identifier; "random" unless overridden.
      voice_file_path  -- optional path to an audio file to use as the voice;
                          None when no custom voice is supplied.
    """

    def __init__(
        self,
        voice_preset: str = "random",
        voice_file_path: Optional[str] = None,
    ) -> None:
        # Store both settings verbatim; no validation happens here.
        self.voice_preset, self.voice_file_path = voice_preset, voice_file_path

# The main Tortoise TTS wrapper class implementing FastRTC's TTSModel protocol
class TortoiseTTSModel:
    """Tortoise-TTS wrapper exposing ``tts``, ``stream_tts`` and ``stream_tts_sync``.

    The underlying ``TextToSpeech`` model is created once and shared through
    the module-level ``tts_model`` global; every instance keeps a reference
    to it in ``self.tts_model``.
    """

    # Sample rate reported alongside every generated waveform.
    OUTPUT_SAMPLE_RATE = 24000
    # Sample rate at which a custom reference voice clip is loaded.
    VOICE_SAMPLE_RATE = 22050
    # Number of samples per chunk yielded by the streaming methods.
    STREAM_CHUNK_SIZE = 4000
    # Keys that ``TextToSpeech.tts_with_preset`` accepts for ``preset``.
    # These select generation quality/speed, not a voice.
    QUALITY_PRESETS = frozenset({"ultra_fast", "fast", "standard", "high_quality"})

    def __init__(self):
        """Load the shared model on first construction and bind it to ``self``."""
        global tts_model
        if tts_model is None:
            self._initialize_model()
        self.tts_model = tts_model

    @spaces.GPU
    def _initialize_model(self):
        """Create the ``TextToSpeech`` model and store it in the module global.

        DeepSpeed is requested only when CUDA is available.
        """
        # NOTE(review): if the @spaces.GPU decorator runs this call in a
        # separate worker process, the global assigned here may not be visible
        # to the caller -- confirm on the target runtime.
        global tts_model
        print("Initializing Tortoise-TTS model...")
        tts_model = TextToSpeech(use_deepspeed=torch.cuda.is_available())
        print(f"Model initialized. Using device: {next(tts_model.autoregressive.parameters()).device}")

    @classmethod
    def _resolve_voice(cls, options):
        """Translate ``options`` into keyword arguments for ``tts_with_preset``.

        Resolution order:
          1. An existing ``voice_file_path`` is loaded as the reference clip.
          2. A ``voice_preset`` naming a quality preset is forwarded as
             ``preset`` (this is the only kind of value the original
             ``preset=voice_preset`` call could succeed with).
          3. ``"random"`` or ``None`` yields no voice arguments, which leaves
             the voice selection to Tortoise.
          4. Any other value is treated as a Tortoise voice name.
        """
        path = options.voice_file_path
        if path and os.path.exists(path):
            print(f"Loading voice from {path}")
            # load_audio returns a single tensor (not a tuple), so it must not
            # be unpacked; tts expects a list of reference clips.
            return {"voice_samples": [load_audio(path, cls.VOICE_SAMPLE_RATE)]}

        preset = options.voice_preset
        if preset in cls.QUALITY_PRESETS:
            return {"preset": preset}
        if preset is None or preset == "random":
            return {}

        # Local import: only needed for named voices, and keeps the file's
        # top-level import block untouched.
        from tortoise.utils.audio import load_voice

        voice_samples, conditioning_latents = load_voice(preset)
        return {
            "voice_samples": voice_samples,
            "conditioning_latents": conditioning_latents,
        }

    @classmethod
    def _iter_chunks(cls, sample_rate, audio_array):
        """Yield ``(sample_rate, chunk)`` pairs of at most STREAM_CHUNK_SIZE samples.

        Slices along the LAST axis so that both 1-D audio and audio with a
        leading channel axis are split by time. The generated array is
        presumably shaped (1, num_samples) -- TODO confirm -- in which case
        iterating over ``len(audio_array)`` would emit one chunk holding the
        whole clip.
        """
        step = cls.STREAM_CHUNK_SIZE
        for start in range(0, audio_array.shape[-1], step):
            yield sample_rate, audio_array[..., start:start + step]

    @spaces.GPU
    def _generate_speech(self, text, options=None):
        """Synthesize ``text`` and return ``(sample_rate, float32 ndarray)``.

        ``options`` defaults to a fresh ``TortoiseOptions()``. Any exception
        is logged with ``print`` and re-raised unchanged.
        """
        options = options or TortoiseOptions()

        try:
            voice_kwargs = self._resolve_voice(options)

            # Generate speech
            print(f"Generating speech for text: {text[:50]}...")
            gen = self.tts_model.tts_with_preset(text, **voice_kwargs)

            # Drop the leading axis, move to host memory, and hand back float32.
            waveform = gen.squeeze(0).cpu().numpy().astype(np.float32)
            return self.OUTPUT_SAMPLE_RATE, waveform

        except Exception as e:
            print(f"Error generating speech: {str(e)}")
            raise

    def tts(self, text: str, options: Optional[TortoiseOptions] = None) -> Tuple[int, NDArray[np.float32]]:
        """Generate speech audio from text in a single call"""
        return self._generate_speech(text, options)

    async def stream_tts(self, text: str, options: Optional[TortoiseOptions] = None) -> AsyncGenerator[Tuple[int, NDArray[np.float32]], None]:
        """Stream speech audio asynchronously in chunks.

        The full clip is generated first (synchronously, inside this
        coroutine) and then yielded piece by piece.
        """
        sample_rate, audio_array = self._generate_speech(text, options)
        for piece in self._iter_chunks(sample_rate, audio_array):
            yield piece

    def stream_tts_sync(self, text: str, options: Optional[TortoiseOptions] = None) -> Generator[Tuple[int, NDArray[np.float32]], None, None]:
        """Stream speech audio synchronously in chunks.

        The full clip is generated first and then yielded piece by piece.
        """
        sample_rate, audio_array = self._generate_speech(text, options)
        yield from self._iter_chunks(sample_rate, audio_array)

# Create a singleton instance for easy import.
# NOTE: this runs at import time, so importing the module executes
# TortoiseTTSModel.__init__, which initializes the model whenever the global
# tts_model is still None.
tortoise_tts = TortoiseTTSModel()

# Legacy function for backward compatibility
async def generate_speech(text, voice_preset="random", voice_file_path=None):
    """Synthesize ``text``, save it as a WAV file, and return path plus audio.

    Args:
        text: Text to speak.
        voice_preset: Forwarded to ``TortoiseOptions``.
        voice_file_path: Forwarded to ``TortoiseOptions``.

    Returns:
        ``(output_path, (sample_rate, audio_tensor))`` where ``output_path``
        is ``outputs/tts_output_<n>.wav`` and ``audio_tensor`` is the
        generated audio wrapped with ``torch.from_numpy``.

    Raises:
        Whatever ``tortoise_tts.tts`` or ``torchaudio.save`` raises.
    """
    # Local import keeps this function self-contained.
    import zlib

    options = TortoiseOptions(voice_preset, voice_file_path)
    sample_rate, audio_array = tortoise_tts.tts(text, options)
    audio_tensor = torch.from_numpy(audio_array)

    # hash() of a str is salted per interpreter process, so it would map the
    # same text to a different file name on every run; CRC32 is stable.
    digest = zlib.crc32(text.encode("utf-8", "replace")) % 10000
    output_path = f"outputs/tts_output_{digest}.wav"

    # Write the file so the returned path actually exists. torchaudio.save
    # wants a 2-D (channels, samples) tensor, so promote 1-D audio for the
    # write only; the tensor handed back to the caller is left as generated.
    wav = audio_tensor if audio_tensor.dim() > 1 else audio_tensor.unsqueeze(0)
    torchaudio.save(output_path, wav, sample_rate)

    return output_path, (sample_rate, audio_tensor)