Michael Hu
enhance logging
91223c9
raw
history blame
17.3 kB
import os
import logging
import time
import soundfile as sf
logger = logging.getLogger(__name__)

# Module-level flags recording which TTS back ends could be imported.
# TTSEngine reads them when choosing an engine.
KOKORO_AVAILABLE = False
DIA_AVAILABLE = False

# Try to import Kokoro first
try:
    from kokoro import KPipeline
    KOKORO_AVAILABLE = True
    logger.info("Kokoro TTS engine is available")
except AttributeError as e:
    # An AttributeError mentioning EspeakWrapper.set_data_path is treated as
    # "Kokoro unavailable"; any other AttributeError is unexpected and propagates.
    if "EspeakWrapper" in str(e) and "set_data_path" in str(e):
        logger.warning("Kokoro import failed due to EspeakWrapper.set_data_path issue")
    else:
        logger.error(f"Kokoro import failed with unexpected error: {str(e)}")
        raise
except ImportError:
    logger.warning("Kokoro TTS engine is not available")

# Try to import Dia as fallback. Dia is only probed here when Kokoro could not
# be imported; any failure leaves DIA_AVAILABLE False (dummy TTS is used then).
if not KOKORO_AVAILABLE:
    try:
        logger.info("Attempting to import Dia TTS engine as fallback")
        try:
            # Dia needs PyTorch; check it first so the log names the real cause.
            import torch
            logger.info("PyTorch is available for Dia TTS")
        except ImportError as torch_err:
            logger.error(f"PyTorch dependency for Dia TTS is missing: {str(torch_err)}")
            # Carry the missing module's name so the outer handler can report it.
            raise ImportError(
                f"PyTorch dependency required for Dia TTS: {str(torch_err)}",
                name=torch_err.name,
            ) from torch_err

        # Try to import the Dia module
        try:
            from utils.tts_dia import _get_model as get_dia_model
            logger.info("Successfully imported Dia TTS module")
            # Only confirms the name resolved; the model itself is loaded lazily later.
            logger.info("Verifying Dia model can be accessed")
            model_info = get_dia_model.__module__
            logger.info(f"Dia model module: {model_info}")
            DIA_AVAILABLE = True
            logger.info("Dia TTS engine is available as fallback")
        except ImportError as module_err:
            logger.error(f"Failed to import Dia TTS module: {str(module_err)}")
            # ImportError.name / .path identify the module that failed to load.
            logger.error(
                f"Failed module: {module_err.name or 'unknown'}, "
                f"path: {module_err.path or 'unknown'}"
            )
            raise
        except AttributeError as attr_err:
            logger.error(f"Dia TTS module attribute error: {str(attr_err)}")
            logger.error("This may indicate the module exists but has incorrect structure")
            raise
    except ImportError as e:
        logger.error(f"Dia TTS engine is not available due to import error: {str(e)}")
        logger.error(f"Import path attempted: {e.name or 'unknown'}")
        logger.warning("Will use dummy TTS implementation as fallback")
    except Exception as e:
        logger.error(f"Unexpected error initializing Dia TTS: {str(e)}")
        logger.error(f"Error type: {type(e).__name__}")
        logger.error("Will use dummy TTS implementation as fallback")
class TTSEngine:
    """Speech synthesizer using Kokoro, falling back to Dia, then to a dummy tone.

    Attributes:
        engine_type (str): "kokoro", "dia" or "dummy"; fixed at construction.
        pipeline: The Kokoro ``KPipeline`` when ``engine_type == "kokoro"``,
            otherwise ``None``.
    """

    def __init__(self, lang_code='z'):
        """Initialize TTS Engine with Kokoro or Dia as fallback.

        Args:
            lang_code (str): Language code ('a' for US English, 'b' for British English,
                'j' for Japanese, 'z' for Mandarin Chinese).
                Note: lang_code is only used for Kokoro, not for Dia.
        """
        logger.info("Initializing TTS Engine")
        logger.info(f"Available engines - Kokoro: {KOKORO_AVAILABLE}, Dia: {DIA_AVAILABLE}")
        self.engine_type = None
        self.pipeline = None

        if KOKORO_AVAILABLE:
            logger.info(f"Using Kokoro as primary TTS engine with language code: {lang_code}")
            try:
                self.pipeline = KPipeline(lang_code=lang_code)
                self.engine_type = "kokoro"
                logger.info("TTS engine successfully initialized with Kokoro")
            except Exception as kokoro_err:
                logger.error(f"Failed to initialize Kokoro pipeline: {str(kokoro_err)}")
                logger.error(f"Error type: {type(kokoro_err).__name__}")
                logger.info("Will try to fall back to Dia TTS engine")

        # Try Dia if Kokoro is not available or failed to initialize.
        # DIA_AVAILABLE is only probed at import time when Kokoro could NOT be
        # imported, so if Kokoro imported but its pipeline failed above, Dia has
        # never been checked and is probed here instead.
        if self.engine_type is None and (
            DIA_AVAILABLE or (KOKORO_AVAILABLE and self._dia_importable())
        ):
            logger.info("Using Dia as fallback TTS engine")
            # Nothing to build here: the Dia model is lazy-loaded when needed.
            self.pipeline = None
            self.engine_type = "dia"
            logger.info("TTS engine initialized with Dia (lazy loading)")

        # Use dummy if no TTS engines are available
        if self.engine_type is None:
            logger.warning("Using dummy TTS implementation as no TTS engines are available")
            logger.warning("Check logs above for specific errors that prevented Kokoro or Dia initialization")
            self.pipeline = None
            self.engine_type = "dummy"

    @staticmethod
    def _dia_importable():
        """Return True if PyTorch and the Dia module (utils.tts_dia) can be imported.

        Any failure is logged and reported as False, mirroring the module-level
        probe, which also treats every Dia import error as "unavailable".
        """
        try:
            import torch  # noqa: F401  (Dia dependency)
            from utils.tts_dia import _get_model  # noqa: F401
        except Exception as dia_err:
            logger.error(f"Dia TTS fallback is not available: {str(dia_err)}")
            return False
        return True

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
        """Generate speech from text using the selected TTS engine.

        Args:
            text (str): Input text to synthesize.
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.).
                Note: voice parameter is only used for Kokoro, not for Dia.
            speed (float): Speech speed multiplier (0.5 to 2.0).
                Note: speed parameter is only used for Kokoro, not for Dia.

        Returns:
            str: Path to the generated audio file. For Dia this is the path
            returned by ``utils.tts_dia.generate_speech``. If Dia fails, or
            Kokoro yields no audio, a dummy tone file is written instead.

        Raises:
            Exception: Errors from Kokoro synthesis or file writing are logged
            and re-raised.
        """
        import uuid

        logger.info(f"Generating speech for text length: {len(text)}")
        try:
            # Create output directory if it doesn't exist
            os.makedirs("temp/outputs", exist_ok=True)
            # The timestamp alone has one-second resolution, so two requests in
            # the same second would overwrite each other; the random suffix
            # keeps every output path unique.
            output_path = f"temp/outputs/output_{int(time.time())}_{uuid.uuid4().hex}.wav"

            if self.engine_type == "kokoro":
                import numpy as np

                # The pipeline yields one (_, _, audio) result per text segment;
                # keep all of them so multi-segment text is not truncated.
                segments = [
                    np.asarray(audio)
                    for _, _, audio in self.pipeline(text, voice=voice, speed=speed)
                    if audio is not None
                ]
                if not segments:
                    logger.warning("Kokoro produced no audio; falling back to dummy audio generation")
                    return self._generate_dummy_audio(output_path)
                logger.info(f"Saving Kokoro audio to {output_path}")
                sf.write(output_path, np.concatenate(segments), 24000)
            elif self.engine_type == "dia":
                try:
                    logger.info("Attempting to use Dia TTS for speech generation")
                    # Import here to avoid circular imports
                    logger.info("Importing Dia speech generation module")
                    try:
                        from utils.tts_dia import generate_speech as dia_generate_speech
                    except ImportError as import_err:
                        logger.error(f"Failed to import Dia speech generation function: {str(import_err)}")
                        # ImportError.name identifies the module that failed to load.
                        logger.error(f"Missing module: {import_err.name or 'unknown'}")
                        raise
                    logger.info("Successfully imported Dia speech generation function")
                    logger.info("Calling Dia's generate_speech function")
                    output_path = dia_generate_speech(text)
                    logger.info(f"Generated audio with Dia: {output_path}")
                except ImportError as import_err:
                    logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")
                    logger.error("Falling back to dummy audio generation")
                    return self._generate_dummy_audio(output_path)
                except Exception as dia_error:
                    logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
                    logger.error(f"Error type: {type(dia_error).__name__}")
                    logger.error("Falling back to dummy audio generation")
                    return self._generate_dummy_audio(output_path)
            else:
                # Generate dummy audio as fallback
                return self._generate_dummy_audio(output_path)

            logger.info(f"Audio generation complete: {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
            raise

    def _generate_dummy_audio(self, output_path):
        """Write a 3-second 440 Hz sine tone (amplitude 0.3, 24 kHz) to a file.

        Args:
            output_path (str): Path to save the dummy audio file.

        Returns:
            str: ``output_path``, after the file has been written.
        """
        import numpy as np
        sample_rate = 24000
        duration = 3.0  # seconds
        t = np.linspace(0, duration, int(sample_rate * duration), False)
        tone = np.sin(2 * np.pi * 440 * t) * 0.3
        logger.info(f"Saving dummy audio to {output_path}")
        sf.write(output_path, tone, sample_rate)
        logger.info(f"Dummy audio generation complete: {output_path}")
        return output_path

    def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0):
        """Generate speech from text and yield each segment.

        Args:
            text (str): Input text to synthesize.
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.).
            speed (float): Speech speed multiplier (0.5 to 2.0).

        Yields:
            tuple: (sample_rate, audio_data) pairs. Kokoro yields one pair per
            segment; Dia yields the whole utterance as a single pair; the dummy
            engine (and any Dia failure) yields three short sine-tone chunks.
        """
        try:
            if self.engine_type == "kokoro":
                generator = self.pipeline(text, voice=voice, speed=speed)
                for _, _, audio in generator:
                    yield 24000, audio
            elif self.engine_type == "dia":
                # Dia doesn't support streaming natively, so we generate the full
                # audio and then yield it as a single chunk
                try:
                    logger.info("Attempting to use Dia TTS for speech streaming")
                    # Import here to avoid circular imports
                    logger.info("Importing required modules for Dia streaming")
                    try:
                        import torch
                    except ImportError as torch_err:
                        logger.error(f"PyTorch import failed for Dia streaming: {str(torch_err)}")
                        raise
                    logger.info("PyTorch successfully imported for Dia streaming")

                    # Separate from the torch import so a Dia module failure is
                    # not also reported as a PyTorch failure.
                    try:
                        from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE
                    except ImportError as import_err:
                        logger.error(f"Failed to import Dia model for streaming: {str(import_err)}")
                        logger.error(f"Missing module: {import_err.name or 'unknown'}")
                        raise
                    logger.info("Successfully imported Dia model and sample rate")

                    logger.info("Getting Dia model instance")
                    try:
                        model = _get_model()
                        logger.info("Successfully obtained Dia model instance")
                    except Exception as model_err:
                        logger.error(f"Failed to get Dia model instance: {str(model_err)}")
                        logger.error(f"Error type: {type(model_err).__name__}")
                        raise

                    logger.info("Generating audio with Dia model")
                    with torch.inference_mode():
                        output_audio_np = model.generate(
                            text,
                            max_tokens=None,
                            cfg_scale=3.0,
                            temperature=1.3,
                            top_p=0.95,
                            cfg_filter_top_k=35,
                            use_torch_compile=False,
                            verbose=False,
                        )

                    if output_audio_np is not None:
                        logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
                        yield DEFAULT_SAMPLE_RATE, output_audio_np
                    else:
                        logger.warning("Dia model returned None for audio output")
                        logger.warning("Falling back to dummy audio stream")
                        yield from self._generate_dummy_audio_stream()
                except ImportError as import_err:
                    logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
                    logger.error("Falling back to dummy audio stream")
                    yield from self._generate_dummy_audio_stream()
                except Exception as dia_error:
                    logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
                    logger.error(f"Error type: {type(dia_error).__name__}")
                    logger.error("Falling back to dummy audio stream")
                    yield from self._generate_dummy_audio_stream()
            else:
                # Generate dummy audio chunks as fallback
                yield from self._generate_dummy_audio_stream()
        except Exception as e:
            logger.error(f"TTS streaming failed: {str(e)}", exc_info=True)
            raise

    def _generate_dummy_audio_stream(self):
        """Yield three 1-second sine-tone chunks at 440, 660 and 880 Hz (24 kHz).

        Yields:
            tuple: (sample_rate, audio_data) pairs for each dummy segment.
        """
        import numpy as np
        sample_rate = 24000
        duration = 1.0  # seconds per chunk
        for i in range(3):
            t = np.linspace(0, duration, int(sample_rate * duration), False)
            freq = 440 + (i * 220)  # Different frequency for each chunk
            tone = np.sin(2 * np.pi * freq * t) * 0.3
            yield sample_rate, tone
# Initialize TTS engine with cache decorator if using Streamlit
def get_tts_engine(lang_code='a'):
    """Get or create TTS engine instance.

    When Streamlit is importable the engine is cached with
    ``st.cache_resource``, keyed on ``lang_code``. Without Streamlit, a new
    ``TTSEngine`` is constructed on every call.

    Args:
        lang_code (str): Language code for the pipeline.

    Returns:
        TTSEngine: Initialized TTS engine instance.
    """
    logger.info(f"Requesting TTS engine with language code: {lang_code}")
    # Only the Streamlit import is guarded: an ImportError raised while building
    # the engine must not be mistaken for "Streamlit not available".
    try:
        import streamlit as st
    except ImportError:
        logger.info("Streamlit not available, creating direct TTS engine instance")
        engine = TTSEngine(lang_code)
        logger.info(f"Direct TTS engine created with type: {engine.engine_type}")
        return engine

    logger.info("Streamlit detected, using cached TTS engine")

    # lang_code is passed as an argument (not captured from the enclosing scope)
    # because st.cache_resource keys its cache on the function's arguments; a
    # closure value would make every call return the first engine ever built.
    @st.cache_resource
    def _get_engine(lang_code):
        logger.info("Creating cached TTS engine instance")
        engine = TTSEngine(lang_code)
        logger.info(f"Cached TTS engine created with type: {engine.engine_type}")
        return engine

    engine = _get_engine(lang_code)
    logger.info(f"Retrieved TTS engine from cache with type: {engine.engine_type}")
    return engine
def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
    """Synthesize ``text`` to an audio file via the shared TTS engine.

    Obtains the engine from ``get_tts_engine()`` (default language code) and
    delegates to its ``generate_speech``. On failure, the error is logged along
    with the file and line of the innermost traceback frame, then re-raised
    unchanged.

    Args:
        text (str): Input text to synthesize.
        voice (str): Voice ID to use.
        speed (float): Speech speed multiplier.

    Returns:
        str: Path to the generated audio file.
    """
    logger.info(f"Public generate_speech called with text length: {len(text)}, voice: {voice}, speed: {speed}")
    try:
        logger.info("Getting TTS engine instance")
        tts = get_tts_engine()
        logger.info(f"Using TTS engine type: {tts.engine_type}")

        logger.info("Calling engine.generate_speech")
        audio_path = tts.generate_speech(text, voice, speed)
        logger.info(f"Speech generation complete, output path: {audio_path}")
        return audio_path
    except Exception as err:
        logger.error(f"Error in public generate_speech function: {err!s}", exc_info=True)
        logger.error(f"Error type: {type(err).__name__}")
        if hasattr(err, '__traceback__'):
            # The last traceback entry is the frame where the error was raised.
            innermost = err.__traceback__
            while innermost.tb_next:
                innermost = innermost.tb_next
            code_obj = innermost.tb_frame.f_code
            logger.error(f"Error occurred in file: {code_obj.co_filename}, line {innermost.tb_lineno}")
        raise