"""Text-to-speech generation backed by the Dia model.

The public entry point is :func:`generate_speech`. Whenever the Dia pipeline
cannot produce an audio file, the call is delegated to the dummy
implementation in ``utils.tts_dummy``.
"""
import logging
import os
import time
import uuid
from pathlib import Path
from typing import Optional

import numpy as np
import soundfile as sf
import torch

from dia.model import Dia

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_SAMPLE_RATE = 44100
DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"

# Directory (relative to the current working directory) that receives the
# generated WAV files.
_OUTPUT_DIR = "temp/outputs"

# Playback speed applied to the generated audio (values < 1 slow it down).
_SPEED_FACTOR = 0.94

# Global model instance (lazy loaded)
_model = None


def _get_model() -> Dia:
    """Return the shared Dia model, loading it on first use.

    The model is loaded once from ``DEFAULT_MODEL_NAME`` and cached in the
    module-level ``_model`` variable; later calls return the cached instance.

    Returns:
        The loaded ``Dia`` model instance.

    Raises:
        ImportError: If loading fails because of a missing dependency.
        FileNotFoundError: If loading fails because a file cannot be found.
        Exception: Any other error raised while loading is logged and
            re-raised unchanged.
    """
    global _model
    if _model is not None:
        return _model

    logger.info("Loading Dia model...")
    try:
        # Log the runtime environment to make load failures diagnosable.
        logger.info(f"PyTorch version: {torch.__version__}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            logger.info(f"CUDA version: {torch.version.cuda}")
            logger.info(f"GPU device: {torch.cuda.get_device_name(0)}")

        logger.info(f"Attempting to load model from: {DEFAULT_MODEL_NAME}")
        logger.info("Initializing Dia model...")
        _model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")

        logger.info("Dia model loaded successfully")
        logger.info(f"Model type: {type(_model).__name__}")

        # PyTorch modules expose parameters(); the Dia wrapper might not.
        # next(..., None) keeps an empty parameter list from raising
        # StopIteration after the model has already been cached.
        first_param = None
        if hasattr(_model, 'parameters'):
            first_param = next(iter(_model.parameters()), None)
        if first_param is not None:
            logger.info(f"Model device: {first_param.device}")
        else:
            logger.info("Model device: Device information not available for Dia model")
    except ImportError as import_err:
        logger.error(f"Import error loading Dia model: {import_err}")
        logger.error("This may indicate missing dependencies")
        raise
    except FileNotFoundError as file_err:
        logger.error(f"File not found error loading Dia model: {file_err}")
        logger.error("Model path may be incorrect or inaccessible")
        raise
    except Exception as e:
        logger.error(f"Error loading Dia model: {e}", exc_info=True)
        logger.error(f"Error type: {type(e).__name__}")
        logger.error("This may indicate incompatible versions or missing CUDA support")
        raise

    return _model


def _fallback_to_dummy(text: str, language: str, reason: str) -> str:
    """Log why Dia could not be used and delegate to the dummy TTS backend."""
    logger.info(f"Falling back to dummy TTS due to {reason}")
    # Imported here rather than at module level, so the dummy backend is only
    # imported when a fallback actually happens.
    from utils.tts_dummy import generate_speech as dummy_generate_speech
    return dummy_generate_speech(text, language)


def _apply_speed_adjustment(audio: np.ndarray, speed_factor: float) -> np.ndarray:
    """Stretch ``audio`` by linear interpolation to play at ``speed_factor`` speed.

    If resampling is unnecessary or fails, the input is returned unchanged.
    """
    original_len = len(audio)
    target_len = int(original_len / speed_factor)
    logger.info(f"Applying speed adjustment factor: {speed_factor}")

    if target_len == original_len or target_len <= 0:
        return audio

    try:
        x_original = np.arange(original_len)
        x_resampled = np.linspace(0, original_len - 1, target_len)
        resampled = np.interp(x_resampled, x_original, audio)
    except Exception as resample_err:
        # Best effort: unadjusted audio is still a usable result.
        logger.error(f"Error during audio resampling: {resample_err}")
        logger.warning("Using original audio without resampling")
        return audio

    logger.info(
        f"Resampled audio from {original_len} to {target_len} samples "
        f"for {speed_factor:.2f}x speed"
    )
    return resampled


def generate_speech(text: str, language: str = "zh") -> str:
    """Public interface for TTS generation using Dia model

    The audio is generated with the Dia model, slowed down slightly, and
    written as a WAV file under ``temp/outputs``. If any step fails
    (directory creation, model loading, generation, empty output, or saving),
    the call is delegated to ``utils.tts_dummy.generate_speech`` and its
    result is returned instead.

    Args:
        text (str): Input text to synthesize
        language (str): Language code (not used in Dia model, kept for API
            compatibility; forwarded to the dummy backend on fallback)

    Returns:
        str: Path to the generated audio file
    """
    logger.info(f"Generating speech for text length: {len(text)}")
    logger.info(f"Text content (first 50 chars): {text[:50]}...")

    # Create output directory if it doesn't exist
    output_dir = _OUTPUT_DIR
    logger.info(f"Ensuring output directory exists: {output_dir}")
    try:
        os.makedirs(output_dir, exist_ok=True)
    except PermissionError as perm_err:
        logger.error(f"Permission error creating output directory: {perm_err}")
        return _fallback_to_dummy(text, language, "directory creation error")
    except Exception as dir_err:
        logger.error(f"Error creating output directory: {dir_err}")
        return _fallback_to_dummy(text, language, "directory creation error")
    logger.info(f"Output directory ready: {output_dir}")

    # Generate unique output path. A second-resolution timestamp alone
    # collides when two requests arrive within the same second, so a random
    # suffix is appended.
    timestamp = int(time.time())
    output_path = f"{output_dir}/output_{timestamp}_{uuid.uuid4().hex[:8]}.wav"
    logger.info(f"Output will be saved to: {output_path}")

    # Get the model
    logger.info("Retrieving Dia model instance")
    try:
        model = _get_model()
    except Exception as model_err:
        logger.error(f"Failed to get Dia model: {model_err}")
        logger.error(f"Error type: {type(model_err).__name__}")
        return _fallback_to_dummy(text, language, "model loading error")
    logger.info("Successfully retrieved Dia model instance")

    # Generate audio
    logger.info("Starting audio generation with Dia model")
    start_time = time.time()
    try:
        with torch.inference_mode():
            logger.info("Calling model.generate() with inference_mode")
            output_audio_np = model.generate(
                text,
                max_tokens=None,  # Use default from model config
                cfg_scale=3.0,
                temperature=1.3,
                top_p=0.95,
                cfg_filter_top_k=35,
                use_torch_compile=False,  # Keep False for stability
                verbose=False,
            )
            logger.info("Model.generate() completed")
    except RuntimeError as rt_err:
        logger.error(f"Runtime error during generation: {rt_err}")
        if "CUDA out of memory" in str(rt_err):
            logger.error("CUDA out of memory error - consider reducing batch size or model size")
        return _fallback_to_dummy(text, language, "runtime error during generation")
    except Exception as gen_err:
        logger.error(f"Error during audio generation: {gen_err}")
        logger.error(f"Error type: {type(gen_err).__name__}")
        return _fallback_to_dummy(text, language, "error during generation")

    generation_time = time.time() - start_time
    logger.info(f"Generation finished in {generation_time:.2f} seconds")

    # No usable audio: fall back instead of returning a path to a file that
    # was never written.
    if output_audio_np is None:
        logger.warning("Generation produced no output (None returned from model)")
        logger.warning("This may indicate a model configuration issue or empty input text")
        return _fallback_to_dummy(text, language, "empty generation output")
    if np.size(output_audio_np) == 0:
        # min()/max() below would raise on a zero-size array.
        logger.warning("Generation produced an empty audio array")
        return _fallback_to_dummy(text, language, "empty generation output")

    logger.info(
        f"Generated audio array shape: {output_audio_np.shape}, "
        f"dtype: {output_audio_np.dtype}"
    )
    logger.info(
        f"Audio stats - min: {output_audio_np.min():.4f}, "
        f"max: {output_audio_np.max():.4f}, mean: {output_audio_np.mean():.4f}"
    )

    # Apply a slight slowdown for better quality (0.94x speed)
    output_audio_np = _apply_speed_adjustment(output_audio_np, _SPEED_FACTOR)

    # Save the audio file
    logger.info(f"Saving audio to file: {output_path}")
    try:
        sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
    except Exception as save_err:
        logger.error(f"Error saving audio file: {save_err}")
        logger.error(f"Error type: {type(save_err).__name__}")
        return _fallback_to_dummy(text, language, "error saving audio file")
    logger.info(f"Audio successfully saved to {output_path}")

    return output_path