import os
import time
import logging
import torch
import numpy as np
import soundfile as sf

from dia.model import Dia

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_SAMPLE_RATE = 44100
DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"

# Global model instance (lazy loaded)
_model = None
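# Note: this lazy-load cache is not thread-safe; if two threads call
# _get_model() concurrently on first use, the model could be loaded twice.
# A threading.Lock around the load would guard against that if this module
# is ever used from multiple threads.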


def _get_model() -> Dia:
    """Lazy-load the Dia model to avoid loading it until needed"""
    global _model
    if _model is None:
        logger.info("Loading Dia model...")
        try:
            # Check if torch is available with correct version
            logger.info(f"PyTorch version: {torch.__version__}")
            logger.info(f"CUDA available: {torch.cuda.is_available()}")
            if torch.cuda.is_available():
                logger.info(f"CUDA version: {torch.version.cuda}")
                logger.info(f"GPU device: {torch.cuda.get_device_name(0)}")
            
            # Check if model path exists
            logger.info(f"Attempting to load model from: {DEFAULT_MODEL_NAME}")
            
            # Load the model with detailed logging
            logger.info("Initializing Dia model...")
            _model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")
            
            # Log model details
            logger.info(f"Dia model loaded successfully")
            logger.info(f"Model type: {type(_model).__name__}")
            # Check if model has parameters method (PyTorch models do, but Dia might not)
            if hasattr(_model, 'parameters'):
                logger.info(f"Model device: {next(_model.parameters()).device}")
            else:
                logger.info("Model device: Device information not available for Dia model")
        except ImportError as import_err:
            logger.error(f"Import error loading Dia model: {import_err}")
            logger.error(f"This may indicate missing dependencies")
            raise
        except FileNotFoundError as file_err:
            logger.error(f"File not found error loading Dia model: {file_err}")
            logger.error(f"Model path may be incorrect or inaccessible")
            raise
        except Exception as e:
            logger.error(f"Error loading Dia model: {e}", exc_info=True)
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"This may indicate incompatible versions or missing CUDA support")
            raise
    return _model
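

# A convenience sketch (an addition, not part of the original module): dropping
# the cached model lets GPU memory be reclaimed, e.g. between long idle periods.
def _unload_model() -> None:
    """Release the cached Dia model and clear the CUDA cache when available."""
    global _model
    if _model is not None:
        _model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Dia model unloaded and CUDA cache cleared")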


def generate_speech(text: str, language: str = "zh") -> str:
    """Public interface for TTS generation using Dia model
    
    Args:
        text (str): Input text to synthesize
        language (str): Language code (not used in Dia model, kept for API compatibility)
        
    Returns:
        str: Path to the generated audio file
    """
    logger.info(f"Generating speech for text length: {len(text)}")
    logger.info(f"Text content (first 50 chars): {text[:50]}...")
    
    # Create output directory if it doesn't exist
    output_dir = "temp/outputs"
    logger.info(f"Ensuring output directory exists: {output_dir}")
    try:
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Output directory ready: {output_dir}")
    except PermissionError as perm_err:
        logger.error(f"Permission error creating output directory: {perm_err}")
        # Fall back to dummy TTS
        logger.info("Falling back to dummy TTS due to directory creation error")
        from utils.tts_dummy import generate_speech as dummy_generate_speech
        return dummy_generate_speech(text, language)
    except Exception as dir_err:
        logger.error(f"Error creating output directory: {dir_err}")
        # Fall back to dummy TTS
        logger.info("Falling back to dummy TTS due to directory creation error")
        from utils.tts_dummy import generate_speech as dummy_generate_speech
        return dummy_generate_speech(text, language)
    
    # Generate unique output path
    timestamp = int(time.time())
    output_path = f"{output_dir}/output_{timestamp}.wav"
    logger.info(f"Output will be saved to: {output_path}")
    
    # Get the model
    logger.info("Retrieving Dia model instance")
    try:
        model = _get_model()
        logger.info("Successfully retrieved Dia model instance")
    except Exception as model_err:
        logger.error(f"Failed to get Dia model: {model_err}")
        logger.error(f"Error type: {type(model_err).__name__}")
        # Fall back to dummy TTS
        logger.info("Falling back to dummy TTS due to model loading error")
        from utils.tts_dummy import generate_speech as dummy_generate_speech
        return dummy_generate_speech(text, language)
    
    # Generate audio
    logger.info("Starting audio generation with Dia model")
    start_time = time.time()
    
    try:
        with torch.inference_mode():
            logger.info("Calling model.generate() with inference_mode")
            output_audio_np = model.generate(
                text,
                max_tokens=None,  # Use default from model config
                cfg_scale=3.0,
                temperature=1.3,
                top_p=0.95,
                cfg_filter_top_k=35,
                use_torch_compile=False,  # Keep False for stability
                verbose=False
            )
            logger.info("Model.generate() completed")
    except RuntimeError as rt_err:
        logger.error(f"Runtime error during generation: {rt_err}")
        if "CUDA out of memory" in str(rt_err):
            logger.error("CUDA out of memory error - consider reducing batch size or model size")
        # Fall back to dummy TTS
        logger.info("Falling back to dummy TTS due to runtime error during generation")
        from utils.tts_dummy import generate_speech as dummy_generate_speech
        return dummy_generate_speech(text, language)
    except Exception as gen_err:
        logger.error(f"Error during audio generation: {gen_err}")
        logger.error(f"Error type: {type(gen_err).__name__}")
        # Fall back to dummy TTS
        logger.info("Falling back to dummy TTS due to error during generation")
        from utils.tts_dummy import generate_speech as dummy_generate_speech
        return dummy_generate_speech(text, language)
    
    end_time = time.time()
    generation_time = end_time - start_time
    logger.info(f"Generation finished in {generation_time:.2f} seconds")
    
    # Process the output
    if output_audio_np is not None:
        logger.info(f"Generated audio array shape: {output_audio_np.shape}, dtype: {output_audio_np.dtype}")
        logger.info(f"Audio stats - min: {output_audio_np.min():.4f}, max: {output_audio_np.max():.4f}, mean: {output_audio_np.mean():.4f}")
        
        # Apply a slight slowdown for better quality (0.94x speed)
        speed_factor = 0.94
        original_len = len(output_audio_np)
        target_len = int(original_len / speed_factor)
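        # Note: np.interp performs 1-D linear interpolation, so this speed
        # adjustment assumes the model returns mono audio; multi-channel
        # output would need per-channel resampling.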
        
        logger.info(f"Applying speed adjustment factor: {speed_factor}")
        if target_len != original_len and target_len > 0:
            try:
                x_original = np.arange(original_len)
                x_resampled = np.linspace(0, original_len - 1, target_len)
                output_audio_np = np.interp(x_resampled, x_original, output_audio_np)
                logger.info(f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed")
            except Exception as resample_err:
                logger.error(f"Error during audio resampling: {resample_err}")
                logger.warning("Using original audio without resampling")
        
        # Save the audio file
        logger.info(f"Saving audio to file: {output_path}")
        try:
            sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
            logger.info(f"Audio successfully saved to {output_path}")
        except Exception as save_err:
            logger.error(f"Error saving audio file: {save_err}")
            logger.error(f"Error type: {type(save_err).__name__}")
            # Fall back to dummy TTS
            logger.info("Falling back to dummy TTS due to error saving audio file")
            from utils.tts_dummy import generate_speech as dummy_generate_speech
            return dummy_generate_speech(text, language)
        
        return output_path
    else:
        logger.warning("Generation produced no output (None returned from model)")
        logger.warning("This may indicate a model configuration issue or empty input text")
        # Fall back to dummy TTS so the caller receives a path to a file that
        # actually exists, rather than a dummy path that was never written
        logger.info("Falling back to dummy TTS due to empty generation output")
        from utils.tts_dummy import generate_speech as dummy_generate_speech
        return dummy_generate_speech(text, language)
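

# Example usage (a minimal sketch): running the module directly synthesizes a
# short sample. The [S1] speaker tag follows Dia's dialogue text format; the
# sentence below is illustrative and not part of the original module.
if __name__ == "__main__":
    sample_text = "[S1] Hello! This is a quick test of the Dia text-to-speech pipeline."
    audio_path = generate_speech(sample_text)
    print(f"Audio written to: {audio_path}")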