"""Advanced voice enhancement and consistency system for CSM-1B."""
import os
import torch
import torchaudio
import numpy as np
import soundfile as sf
from typing import Dict, List, Optional, Tuple
import logging
from dataclasses import dataclass
from scipy import signal

# Setup logging
logger = logging.getLogger(__name__)

# Define persistent paths
VOICE_REFERENCES_DIR = "/app/voice_references"
VOICE_PROFILES_DIR = "/app/voice_profiles"

# Ensure directories exist
os.makedirs(VOICE_REFERENCES_DIR, exist_ok=True)
os.makedirs(VOICE_PROFILES_DIR, exist_ok=True)

@dataclass
class VoiceProfile:
    """Detailed voice profile with acoustic characteristics."""
    name: str
    speaker_id: int
    # Acoustic parameters
    pitch_range: Tuple[float, float]  # Min/max pitch in Hz
    intensity_range: Tuple[float, float]  # Min/max intensity (volume)
    spectral_tilt: float  # Brightness vs. darkness
    prosody_pattern: str  # Pattern of intonation and rhythm
    speech_rate: float  # Relative speech rate (1.0 = normal)
    formant_shift: float  # Formant frequency shift (1.0 = no shift)
    # Reference audio
    reference_segments: List[torch.Tensor]
    # Normalization parameters
    target_rms: float = 0.2
    target_peak: float = 0.95
    
    def get_enhancement_params(self) -> Dict:
        """Get parameters for enhancing generated audio."""
        return {
            "target_rms": self.target_rms,
            "target_peak": self.target_peak,
            "pitch_range": self.pitch_range,
            "formant_shift": self.formant_shift,
            "speech_rate": self.speech_rate,
            "spectral_tilt": self.spectral_tilt
        }
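
    # Example (illustrative, not one of the tuned presets defined below):
    #     p = VoiceProfile(name="demo", speaker_id=99, pitch_range=(100, 200),
    #                      intensity_range=(0.15, 0.3), spectral_tilt=0.0,
    #                      prosody_pattern="balanced", speech_rate=1.0,
    #                      formant_shift=1.0, reference_segments=[])
    #     p.get_enhancement_params()["target_rms"]  # -> 0.2 (the default)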

# Voice profiles with carefully tuned parameters
VOICE_PROFILES = {
    "alloy": VoiceProfile(
        name="alloy",
        speaker_id=0,
        pitch_range=(85, 180),  # Hz - balanced range
        intensity_range=(0.15, 0.3),  # moderate intensity
        spectral_tilt=0.0,  # neutral tilt
        prosody_pattern="balanced",
        speech_rate=1.0,  # normal rate
        formant_shift=1.0,  # no shift
        reference_segments=[],
        target_rms=0.2,
        target_peak=0.95
    ),
    "echo": VoiceProfile(
        name="echo",
        speaker_id=1,
        pitch_range=(75, 165),  # Hz - lower, resonant 
        intensity_range=(0.2, 0.35),  # slightly stronger
        spectral_tilt=-0.2,  # more low frequencies
        prosody_pattern="deliberate",
        speech_rate=0.95,  # slightly slower
        formant_shift=0.95,  # slightly lower formants
        reference_segments=[],
        target_rms=0.22,  # slightly louder
        target_peak=0.95
    ),
    "fable": VoiceProfile(
        name="fable",
        speaker_id=2,
        pitch_range=(120, 250),  # Hz - higher range
        intensity_range=(0.15, 0.28),  # moderate intensity
        spectral_tilt=0.2,  # more high frequencies
        prosody_pattern="animated",
        speech_rate=1.05,  # slightly faster
        formant_shift=1.05,  # slightly higher formants
        reference_segments=[],
        target_rms=0.19,
        target_peak=0.95
    ),
    "onyx": VoiceProfile(
        name="onyx",
        speaker_id=3,
        pitch_range=(65, 150),  # Hz - deeper range
        intensity_range=(0.18, 0.32),  # moderate-strong
        spectral_tilt=-0.3,  # more low frequencies
        prosody_pattern="authoritative",
        speech_rate=0.93,  # slightly slower
        formant_shift=0.9,  # lower formants
        reference_segments=[],
        target_rms=0.23,  # stronger
        target_peak=0.95
    ),
    "nova": VoiceProfile(
        name="nova",
        speaker_id=4,
        pitch_range=(90, 200),  # Hz - warm midrange
        intensity_range=(0.15, 0.27),  # moderate
        spectral_tilt=-0.1,  # slightly warm
        prosody_pattern="flowing",
        speech_rate=1.0,  # normal rate
        formant_shift=1.0,  # no shift
        reference_segments=[],
        target_rms=0.2,
        target_peak=0.95
    ),
    "shimmer": VoiceProfile(
        name="shimmer",
        speaker_id=5,
        pitch_range=(140, 280),  # Hz - brighter, higher
        intensity_range=(0.15, 0.25),  # moderate-light
        spectral_tilt=0.3,  # more high frequencies
        prosody_pattern="light",
        speech_rate=1.07,  # slightly faster
        formant_shift=1.1,  # higher formants
        reference_segments=[],
        target_rms=0.18,  # slightly softer
        target_peak=0.95
    )
}

# Voice-specific prompt templates - crafted to establish voice identity clearly
VOICE_PROMPTS = {
    "alloy": [
        "Hello, I'm Alloy. I speak with a balanced, natural tone that's easy to understand.",
        "This is Alloy speaking. My voice is designed to be clear and conversational.",
        "Alloy here - I have a neutral, friendly voice with balanced tone qualities."
    ],
    "echo": [
        "Hello, I'm Echo. I speak with a resonant, deeper voice that carries well.",
        "This is Echo speaking. My voice has a rich, resonant quality with depth.",
        "Echo here - My voice is characterized by its warm, resonant tones."
    ],
    "fable": [
        "Hello, I'm Fable. I speak with a bright, higher-pitched voice that's full of energy.",
        "This is Fable speaking. My voice is characterized by its clear, bright quality.",
        "Fable here - My voice is light, articulate, and slightly higher-pitched."
    ],
    "onyx": [
        "Hello, I'm Onyx. I speak with a deep, authoritative voice that commands attention.",
        "This is Onyx speaking. My voice has a powerful, deep quality with gravitas.",
        "Onyx here - My voice is characterized by its depth and commanding presence."
    ],
    "nova": [
        "Hello, I'm Nova. I speak with a warm, pleasant mid-range voice that's easy to listen to.",
        "This is Nova speaking. My voice has a smooth, harmonious quality.",
        "Nova here - My voice is characterized by its warm, friendly mid-tones."
    ],
    "shimmer": [
        "Hello, I'm Shimmer. I speak with a light, bright voice that's expressive and clear.",
        "This is Shimmer speaking. My voice has an airy, higher-pitched quality.",
        "Shimmer here - My voice is characterized by its bright, crystalline tones."
    ]
}

def initialize_voice_profiles():
    """Initialize voice profiles with default settings.
    
    This function loads existing voice profiles from disk if available,
    or initializes them with default settings.
    """
    global VOICE_PROFILES
    
    # Try to load existing profiles from persistent storage
    profile_path = os.path.join(VOICE_PROFILES_DIR, "voice_profiles.pt")
    
    if os.path.exists(profile_path):
        try:
            logger.info(f"Loading voice profiles from {profile_path}")
            saved_profiles = torch.load(profile_path, map_location="cpu")
            
            # Update existing profiles with saved data
            for name, data in saved_profiles.items():
                if name in VOICE_PROFILES:
                    VOICE_PROFILES[name].reference_segments = [
                        seg.to(torch.device("cpu")) for seg in data.get('reference_segments', [])
                    ]
            
            logger.info(f"Loaded voice profiles for {len(saved_profiles)} voices")
        except Exception as e:
            logger.error(f"Error loading voice profiles: {e}")
            logger.info("Using default voice profiles")
    else:
        logger.info("No saved voice profiles found, using defaults")
    
    # Ensure every voice has a usable reference_segments list (a stale saved
    # profile could have left it as None)
    for name, profile in VOICE_PROFILES.items():
        if profile.reference_segments is None:
            profile.reference_segments = []
    
    logger.info(f"Voice profiles initialized for {len(VOICE_PROFILES)} voices")
    return VOICE_PROFILES
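
# Typical startup order (illustrative; `app_state` and its `generator`
# attribute come from the surrounding application, not this module):
#     initialize_voice_profiles()                  # load saved references
#     create_voice_segments(app_state)             # generate/refresh references
#     ctx = get_voice_segments("alloy", device)    # context for generation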

def normalize_audio(audio: torch.Tensor, target_rms: float = 0.2, target_peak: float = 0.95) -> torch.Tensor:
    """Apply professional-grade normalization to audio.
    
    Args:
        audio: Audio tensor
        target_rms: Target RMS level for normalization
        target_peak: Target peak level for limiting
        
    Returns:
        Normalized audio tensor
    """
    # Ensure audio is on CPU for processing
    audio_cpu = audio.detach().cpu()
    
    # Handle silent audio
    if audio_cpu.abs().max() < 1e-6:
        logger.warning("Audio is nearly silent, returning original")
        return audio
    
    # Calculate current RMS
    current_rms = torch.sqrt(torch.mean(audio_cpu ** 2))
    
    # Apply RMS normalization
    if current_rms > 0:
        gain = target_rms / current_rms
        normalized = audio_cpu * gain
    else:
        normalized = audio_cpu
        
    # Apply peak limiting
    current_peak = normalized.abs().max()
    if current_peak > target_peak:
        normalized = normalized * (target_peak / current_peak)
    
    # Return to original device
    return normalized.to(audio.device)
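
# Worked example for the two-stage scheme above (numbers illustrative): a take
# with RMS 0.05 gets gain 0.2 / 0.05 = 4.0; if a 0.3 peak becomes 1.2 after
# that gain, the limiter rescales by 0.95 / 1.2, so the output peak can never
# exceed target_peak even when the RMS target overshoots.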

def apply_anti_muffling(audio: torch.Tensor, sample_rate: int, clarity_boost: float = 1.2) -> torch.Tensor:
    """Apply anti-muffling to improve clarity.
    
    Args:
        audio: Audio tensor
        sample_rate: Audio sample rate
        clarity_boost: Amount of high frequency boost (1.0 = no boost)
        
    Returns:
        Processed audio tensor
    """
    # Convert to numpy for filtering
    audio_np = audio.detach().cpu().numpy()
    
    try:
        # scipy.signal has no one-step high-shelf design, so approximate one:
        # isolate the high band with a second-order Butterworth high-pass,
        # then add it back onto the original signal scaled by the boost amount
        cutoff = 2000  # Hz
        b, a = signal.butter(2, cutoff / (sample_rate / 2), btype='high', analog=False)

        # Zero-phase filtering so the boost doesn't smear transients
        highs = signal.filtfilt(b, a, audio_np, axis=0)

        # Shelf-style boost: the full-band signal keeps its warmth while the
        # high band is emphasized; clarity_boost=1.0 leaves the audio unchanged
        processed = audio_np + (clarity_boost - 1.0) * highs
        
    except Exception as e:
        logger.warning(f"Audio enhancement failed, using original: {e}")
        # Return original audio if enhancement fails
        return audio
    
    # Convert back to tensor on original device
    return torch.tensor(processed, dtype=audio.dtype, device=audio.device)
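
# For reference (illustrative numbers): at a 24 kHz sample rate the normalized
# cutoff passed to signal.butter is 2000 / 12000 ≈ 0.167, so only content above
# roughly 2 kHz is emphasized, and clarity_boost=1.0 returns the signal
# unchanged.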

def enhance_audio(audio: torch.Tensor, sample_rate: int, voice_profile: VoiceProfile) -> torch.Tensor:
    """Apply comprehensive audio enhancement based on voice profile.
    
    Args:
        audio: Audio tensor
        sample_rate: Audio sample rate
        voice_profile: Voice profile containing enhancement parameters
        
    Returns:
        Enhanced audio tensor
    """
    if audio is None or audio.numel() == 0:
        logger.error("Cannot enhance empty audio")
        return audio
        
    try:
        # Step 1: Normalize audio levels
        params = voice_profile.get_enhancement_params()
        normalized = normalize_audio(
            audio, 
            target_rms=params["target_rms"],
            target_peak=params["target_peak"]
        )
        
        # Step 2: Apply anti-muffling based on spectral tilt.
        # Darker voices (negative tilt) get a stronger clarity boost;
        # brighter voices (positive tilt) need none
        clarity_boost = 1.0 + max(0, -params["spectral_tilt"]) * 0.5
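        # e.g. onyx (spectral_tilt=-0.3) gives clarity_boost = 1.0 + 0.3 * 0.5 = 1.15,
        # while shimmer (tilt=+0.3) stays at 1.0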
        clarified = apply_anti_muffling(
            normalized,
            sample_rate,
            clarity_boost=clarity_boost
        )
        
        # Log the enhancement
        logger.debug(
            f"Enhanced audio for {voice_profile.name}: "
            f"RMS: {audio.pow(2).mean().sqrt().item():.3f}->{clarified.pow(2).mean().sqrt().item():.3f}, "
            f"Peak: {audio.abs().max().item():.3f}->{clarified.abs().max().item():.3f}"
        )
        
        return clarified
        
    except Exception as e:
        logger.error(f"Error in audio enhancement: {e}")
        return audio  # Return original audio if enhancement fails

def validate_generated_audio(
    audio: torch.Tensor, 
    voice_name: str, 
    sample_rate: int,
    min_expected_duration: float = 0.5
) -> Tuple[bool, torch.Tensor, str]:
    """Validate and fix generated audio.
    
    Args:
        audio: Audio tensor to validate
        voice_name: Name of the voice used
        sample_rate: Audio sample rate
        min_expected_duration: Minimum expected duration in seconds
        
    Returns:
        Tuple of (is_valid, fixed_audio, message)
    """
    if audio is None:
        return False, torch.zeros(1), "Audio is None"
    
    # Check for NaN values
    if torch.isnan(audio).any():
        logger.warning(f"Audio for {voice_name} contains NaN values, replacing with zeros")
        audio = torch.where(torch.isnan(audio), torch.zeros_like(audio), audio)
    
    # Check audio duration
    duration = audio.shape[0] / sample_rate
    if duration < min_expected_duration:
        logger.warning(f"Audio for {voice_name} is too short ({duration:.2f}s < {min_expected_duration}s)")
        return False, audio, f"Audio too short: {duration:.2f}s"
    
    # Check for silent sections - this can indicate generation problems
    rms = torch.sqrt(torch.mean(audio ** 2))
    if rms < 0.01:  # Very low RMS indicates near silence
        logger.warning(f"Audio for {voice_name} is nearly silent (RMS: {rms:.6f})")
        return False, audio, f"Audio nearly silent: RMS = {rms:.6f}"
    
    # Check if audio suddenly cuts off - this detects premature stopping
    # Calculate RMS in the last 100ms
    last_samples = int(0.1 * sample_rate)
    if audio.shape[0] > last_samples:
        end_rms = torch.sqrt(torch.mean(audio[-last_samples:] ** 2))
        if end_rms > 0.1:  # High RMS at the end suggests an abrupt cutoff
            logger.warning(f"Audio for {voice_name} may have cut off prematurely (end RMS: {end_rms:.3f})")
            return True, audio, "Audio may have cut off prematurely"
    
    return True, audio, "Audio validation passed"
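
# Example (illustrative): callers unpack the returned tuple and act on the
# message, e.g.
#     ok, audio, msg = validate_generated_audio(audio, "alloy", sample_rate)
#     if not ok:
#         logger.warning(msg)
# which is how process_generated_audio() below consumes it.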

def create_voice_segments(app_state, regenerate: bool = False):
    """Create high-quality voice reference segments.
    
    Args:
        app_state: Application state containing generator
        regenerate: Whether to regenerate existing references
    """
    generator = app_state.generator
    if not generator:
        logger.error("Cannot create voice segments: generator not available")
        return
    
    # Use persistent directory for voice reference segments
    os.makedirs(VOICE_REFERENCES_DIR, exist_ok=True)
    
    for voice_name, profile in VOICE_PROFILES.items():
        voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice_name)
        os.makedirs(voice_dir, exist_ok=True)
        
        # Check if we already have references
        if not regenerate and profile.reference_segments:
            logger.info(f"Voice {voice_name} already has {len(profile.reference_segments)} reference segments")
            continue
        
        # Get prompts for this voice
        prompts = VOICE_PROMPTS[voice_name]
        
        # Generate reference segments
        logger.info(f"Generating reference segments for voice: {voice_name}")
        reference_segments = []
        
        for i, prompt in enumerate(prompts):
            ref_path = os.path.join(voice_dir, f"{voice_name}_ref_{i}.wav")
            
            # Skip if file exists and we're not regenerating
            if not regenerate and os.path.exists(ref_path):
                try:
                    # Load existing reference
                    audio_tensor, sr = torchaudio.load(ref_path)
                    if sr != generator.sample_rate:
                        audio_tensor = torchaudio.functional.resample(
                            audio_tensor.squeeze(0), orig_freq=sr, new_freq=generator.sample_rate
                        )
                    else:
                        audio_tensor = audio_tensor.squeeze(0)
                    reference_segments.append(audio_tensor.to(generator.device))
                    logger.info(f"Loaded existing reference {i+1}/{len(prompts)} for {voice_name}")
                    continue
                except Exception as e:
                    logger.warning(f"Failed to load existing reference {i+1} for {voice_name}: {e}")
            
            try:
                # Use a lower temperature for more stability in reference samples
                logger.info(f"Generating reference {i+1}/{len(prompts)} for {voice_name}: '{prompt}'")
                
                # We want references to be as clean as possible
                audio = generator.generate(
                    text=prompt,
                    speaker=profile.speaker_id,
                    context=[],  # No context for initial samples to prevent voice bleed
                    max_audio_length_ms=6000,  # Shorter for more control
                    temperature=0.7,  # Lower temperature for more stability
                    topk=30,  # More focused sampling
                )
                
                # Validate and enhance the audio
                is_valid, audio, message = validate_generated_audio(
                    audio, voice_name, generator.sample_rate
                )
                
                if is_valid:
                    # Enhance the audio
                    audio = enhance_audio(audio, generator.sample_rate, profile)
                    
                    # Save the reference to persistent storage
                    torchaudio.save(ref_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
                    reference_segments.append(audio)
                    logger.info(f"Generated reference {i+1} for {voice_name}: {message}")
                else:
                    logger.warning(f"Invalid reference for {voice_name}: {message}")
                    # Move on and try the next prompt instead
                    if i < len(prompts) - 1:
                        logger.info("Trying the next prompt")
                        continue
            
            except Exception as e:
                logger.error(f"Error generating reference for {voice_name}: {e}")
        
        # Update the voice profile with references
        if reference_segments:
            VOICE_PROFILES[voice_name].reference_segments = reference_segments
            logger.info(f"Updated {voice_name} with {len(reference_segments)} reference segments")
    
    # Save the updated profiles to persistent storage
    save_voice_profiles()

def get_voice_segments(voice_name: str, device: torch.device) -> List:
    """Get context segments for a given voice.
    
    Args:
        voice_name: Name of the voice to use
        device: Device to place tensors on
        
    Returns:
        List of context segments
    """
    from app.model import Segment
    
    if voice_name not in VOICE_PROFILES:
        logger.warning(f"Voice {voice_name} not found, defaulting to alloy")
        voice_name = "alloy"
    
    profile = VOICE_PROFILES[voice_name]
    
    # If we don't have reference segments yet, create them
    if not profile.reference_segments:
        try:
            # Try to load from disk - use persistent storage
            voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice_name)
            
            if os.path.exists(voice_dir):
                reference_segments = []
                prompts = VOICE_PROMPTS[voice_name]
                
                for i, prompt in enumerate(prompts):
                    ref_path = os.path.join(voice_dir, f"{voice_name}_ref_{i}.wav")
                    if os.path.exists(ref_path):
                        audio_tensor, sr = torchaudio.load(ref_path)
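                        # References are saved at the generator's sample rate
                        # by create_voice_segments(), so sr is not re-checked here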
                        audio_tensor = audio_tensor.squeeze(0)
                        reference_segments.append(audio_tensor)
                
                if reference_segments:
                    profile.reference_segments = reference_segments
                    logger.info(f"Loaded {len(reference_segments)} reference segments for {voice_name}")
        except Exception as e:
            logger.error(f"Error loading reference segments for {voice_name}: {e}")
    
    # Create context segments from references
    context = []
    if profile.reference_segments:
        for i, ref_audio in enumerate(profile.reference_segments):
            # Use corresponding prompt if available, otherwise use a generic one
            text = VOICE_PROMPTS[voice_name][i] if i < len(VOICE_PROMPTS[voice_name]) else f"Voice reference for {voice_name}"
            
            context.append(
                Segment(
                    speaker=profile.speaker_id,
                    text=text,
                    audio=ref_audio.to(device)
                )
            )
    
    logger.info(f"Returning {len(context)} context segments for {voice_name}")
    return context

def save_voice_profiles():
    """Save voice profiles to persistent storage."""
    os.makedirs(VOICE_PROFILES_DIR, exist_ok=True)
    profile_path = os.path.join(VOICE_PROFILES_DIR, "voice_profiles.pt")
    
    # Create a serializable version of the profiles
    serializable_profiles = {}
    for name, profile in VOICE_PROFILES.items():
        serializable_profiles[name] = {
            'reference_segments': [seg.cpu() for seg in profile.reference_segments]
        }
    
    # Save to persistent storage
    torch.save(serializable_profiles, profile_path)
    logger.info(f"Saved voice profiles to {profile_path}")
    
def process_generated_audio(
    audio: torch.Tensor, 
    voice_name: str, 
    sample_rate: int,
    text: str
) -> torch.Tensor:
    """Process generated audio for consistency and quality.
    
    Args:
        audio: Audio tensor
        voice_name: Name of voice used
        sample_rate: Audio sample rate  
        text: Text that was spoken
        
    Returns:
        Processed audio tensor
    """
    # Validate the audio
    is_valid, audio, message = validate_generated_audio(audio, voice_name, sample_rate)
    if not is_valid:
        logger.warning(f"Generated audio validation issue: {message}")
    
    # Get voice profile for enhancement
    profile = VOICE_PROFILES.get(voice_name, VOICE_PROFILES["alloy"])
    
    # Enhance the audio based on voice profile
    enhanced = enhance_audio(audio, sample_rate, profile)
    
    # Log the enhancement
    original_duration = audio.shape[0] / sample_rate
    enhanced_duration = enhanced.shape[0] / sample_rate
    logger.info(
        f"Processed audio for '{voice_name}': "
        f"Duration: {original_duration:.2f}s->{enhanced_duration:.2f}s, "
        f"RMS: {audio.pow(2).mean().sqrt().item():.3f}->{enhanced.pow(2).mean().sqrt().item():.3f}"
    )
    
    return enhanced
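
# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the service runtime).
# It exercises the DSP helpers on a synthetic tone so they can be sanity-checked
# without the CSM-1B generator. Assumptions: the 24 kHz rate is picked just for
# this demo, and /app must be writable since this module creates its reference
# directories at import time.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    demo_sr = 24_000  # assumed demo sample rate
    t = torch.arange(demo_sr, dtype=torch.float32) / demo_sr
    # One second of a quiet 220 Hz tone with light noise, standing in for speech
    demo_audio = 0.05 * torch.sin(2 * torch.pi * 220.0 * t) + 0.002 * torch.randn(demo_sr)

    ok, demo_audio, msg = validate_generated_audio(demo_audio, "alloy", demo_sr)
    print(f"validation: ok={ok} ({msg})")

    enhanced = enhance_audio(demo_audio, demo_sr, VOICE_PROFILES["alloy"])
    print(
        f"RMS {demo_audio.pow(2).mean().sqrt().item():.3f} -> "
        f"{enhanced.pow(2).mean().sqrt().item():.3f}, "
        f"peak {demo_audio.abs().max().item():.3f} -> {enhanced.abs().max().item():.3f}"
    )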