Spaces:
Running
Running

Refactor model structure: update import paths from 'app.modelz' to 'app.models' across multiple files for consistency, remove obsolete 'modelz' directory, and adjust Dockerfile and migration script to reflect these changes, enhancing clarity and organization in the codebase.
c27f115
"""Advanced voice enhancement and consistency system for CSM-1B.""" | |
import os | |
import torch | |
import torchaudio | |
import numpy as np | |
import soundfile as sf | |
from typing import Dict, List, Optional, Tuple | |
import logging | |
from dataclasses import dataclass | |
from scipy import signal | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Storage locations for generated reference audio and serialized profiles.
VOICE_REFERENCES_DIR = "/app/voice_references"
VOICE_PROFILES_DIR = "/app/voice_profiles"

# Create both storage directories at import time (a no-op if they already exist).
for _storage_dir in (VOICE_REFERENCES_DIR, VOICE_PROFILES_DIR):
    os.makedirs(_storage_dir, exist_ok=True)
del _storage_dir
@dataclass
class VoiceProfile:
    """Detailed voice profile with acoustic characteristics.

    Declared as a dataclass so the annotated fields below become keyword
    arguments of the generated ``__init__``; the module-level profile table
    constructs instances that way.
    """

    name: str
    speaker_id: int
    # Acoustic parameters
    pitch_range: Tuple[float, float]  # Min/max pitch in Hz
    intensity_range: Tuple[float, float]  # Min/max intensity (volume)
    spectral_tilt: float  # Brightness vs. darkness
    prosody_pattern: str  # Pattern of intonation and rhythm
    speech_rate: float  # Relative speech rate (1.0 = normal)
    formant_shift: float  # Formant frequency shift (1.0 = no shift)
    # Reference audio
    reference_segments: List[torch.Tensor]
    # Normalization parameters
    target_rms: float = 0.2
    target_peak: float = 0.95

    def get_enhancement_params(self) -> Dict:
        """Get parameters for enhancing generated audio.

        Returns:
            Dict with the keys ``target_rms``, ``target_peak``,
            ``pitch_range``, ``formant_shift``, ``speech_rate`` and
            ``spectral_tilt``, each holding this profile's value.
        """
        return {
            "target_rms": self.target_rms,
            "target_peak": self.target_peak,
            "pitch_range": self.pitch_range,
            "formant_shift": self.formant_shift,
            "speech_rate": self.speech_rate,
            "spectral_tilt": self.spectral_tilt,
        }
# Built-in voices and their tuned parameters. Every entry is keyed by the
# profile's own ``name``, and insertion order follows the tuple below.
VOICE_PROFILES = {
    profile.name: profile
    for profile in (
        VoiceProfile(
            name="alloy",
            speaker_id=0,
            pitch_range=(85, 180),  # Hz - balanced range
            intensity_range=(0.15, 0.3),  # moderate intensity
            spectral_tilt=0.0,  # neutral tilt
            prosody_pattern="balanced",
            speech_rate=1.0,  # normal rate
            formant_shift=1.0,  # no shift
            reference_segments=[],
            target_rms=0.2,
            target_peak=0.95,
        ),
        VoiceProfile(
            name="echo",
            speaker_id=1,
            pitch_range=(75, 165),  # Hz - lower, resonant
            intensity_range=(0.2, 0.35),  # slightly stronger
            spectral_tilt=-0.2,  # more low frequencies
            prosody_pattern="deliberate",
            speech_rate=0.95,  # slightly slower
            formant_shift=0.95,  # slightly lower formants
            reference_segments=[],
            target_rms=0.22,  # slightly louder
            target_peak=0.95,
        ),
        VoiceProfile(
            name="fable",
            speaker_id=2,
            pitch_range=(120, 250),  # Hz - higher range
            intensity_range=(0.15, 0.28),  # moderate intensity
            spectral_tilt=0.2,  # more high frequencies
            prosody_pattern="animated",
            speech_rate=1.05,  # slightly faster
            formant_shift=1.05,  # slightly higher formants
            reference_segments=[],
            target_rms=0.19,
            target_peak=0.95,
        ),
        VoiceProfile(
            name="onyx",
            speaker_id=3,
            pitch_range=(65, 150),  # Hz - deeper range
            intensity_range=(0.18, 0.32),  # moderate-strong
            spectral_tilt=-0.3,  # more low frequencies
            prosody_pattern="authoritative",
            speech_rate=0.93,  # slightly slower
            formant_shift=0.9,  # lower formants
            reference_segments=[],
            target_rms=0.23,  # stronger
            target_peak=0.95,
        ),
        VoiceProfile(
            name="nova",
            speaker_id=4,
            pitch_range=(90, 200),  # Hz - warm midrange
            intensity_range=(0.15, 0.27),  # moderate
            spectral_tilt=-0.1,  # slightly warm
            prosody_pattern="flowing",
            speech_rate=1.0,  # normal rate
            formant_shift=1.0,  # no shift
            reference_segments=[],
            target_rms=0.2,
            target_peak=0.95,
        ),
        VoiceProfile(
            name="shimmer",
            speaker_id=5,
            pitch_range=(140, 280),  # Hz - brighter, higher
            intensity_range=(0.15, 0.25),  # moderate-light
            spectral_tilt=0.3,  # more high frequencies
            prosody_pattern="light",
            speech_rate=1.07,  # slightly faster
            formant_shift=1.1,  # higher formants
            reference_segments=[],
            target_rms=0.18,  # slightly softer
            target_peak=0.95,
        ),
    )
}
# Self-introduction prompts, three per voice. They are the texts synthesized
# when reference segments are generated, and the texts paired with those
# segments when context is built.
VOICE_PROMPTS = dict(
    alloy=[
        "Hello, I'm Alloy. I speak with a balanced, natural tone that's easy to understand.",
        "This is Alloy speaking. My voice is designed to be clear and conversational.",
        "Alloy here - I have a neutral, friendly voice with balanced tone qualities.",
    ],
    echo=[
        "Hello, I'm Echo. I speak with a resonant, deeper voice that carries well.",
        "This is Echo speaking. My voice has a rich, resonant quality with depth.",
        "Echo here - My voice is characterized by its warm, resonant tones.",
    ],
    fable=[
        "Hello, I'm Fable. I speak with a bright, higher-pitched voice that's full of energy.",
        "This is Fable speaking. My voice is characterized by its clear, bright quality.",
        "Fable here - My voice is light, articulate, and slightly higher-pitched.",
    ],
    onyx=[
        "Hello, I'm Onyx. I speak with a deep, authoritative voice that commands attention.",
        "This is Onyx speaking. My voice has a powerful, deep quality with gravitas.",
        "Onyx here - My voice is characterized by its depth and commanding presence.",
    ],
    nova=[
        "Hello, I'm Nova. I speak with a warm, pleasant mid-range voice that's easy to listen to.",
        "This is Nova speaking. My voice has a smooth, harmonious quality.",
        "Nova here - My voice is characterized by its warm, friendly mid-tones.",
    ],
    shimmer=[
        "Hello, I'm Shimmer. I speak with a light, bright voice that's expressive and clear.",
        "This is Shimmer speaking. My voice has an airy, higher-pitched quality.",
        "Shimmer here - My voice is characterized by its bright, crystalline tones.",
    ],
)
def initialize_voice_profiles():
    """Initialize voice profiles, restoring saved reference segments if present.

    Looks for ``voice_profiles.pt`` inside ``VOICE_PROFILES_DIR``. When the
    file exists, the ``reference_segments`` stored for each known voice are
    copied onto the matching entry of ``VOICE_PROFILES`` as CPU tensors.
    Entries for voices not present in ``VOICE_PROFILES`` are ignored. Any
    error while loading is logged and the in-memory defaults are kept.

    Returns:
        The module-level ``VOICE_PROFILES`` mapping (mutated in place).
    """
    # Try to load existing profiles from persistent storage
    profile_path = os.path.join(VOICE_PROFILES_DIR, "voice_profiles.pt")
    if os.path.exists(profile_path):
        try:
            logger.info(f"Loading voice profiles from {profile_path}")
            # map_location="cpu" lets a file that contains accelerator tensors
            # be read on a CPU-only host; without it the load itself fails
            # before the per-segment .to(cpu) below can run.
            # NOTE(security): torch.load is pickle-based; only load profile
            # files that this application wrote itself.
            saved_profiles = torch.load(profile_path, map_location="cpu")
            # Update existing profiles with saved data
            for name, data in saved_profiles.items():
                if name in VOICE_PROFILES:
                    VOICE_PROFILES[name].reference_segments = [
                        seg.to(torch.device("cpu")) for seg in data.get('reference_segments', [])
                    ]
            logger.info(f"Loaded voice profiles for {len(saved_profiles)} voices")
        except Exception as e:
            # Best effort: a corrupt or incompatible file must not stop startup.
            logger.error(f"Error loading voice profiles: {e}")
            logger.info("Using default voice profiles")
    else:
        logger.info("No saved voice profiles found, using defaults")
    # Ensure all voices have at least empty reference segments
    for profile in VOICE_PROFILES.values():
        if getattr(profile, 'reference_segments', None) is None:
            profile.reference_segments = []
    logger.info(f"Voice profiles initialized for {len(VOICE_PROFILES)} voices")
    return VOICE_PROFILES
def normalize_audio(audio: torch.Tensor, target_rms: float = 0.2, target_peak: float = 0.95) -> torch.Tensor:
    """Scale audio to a target RMS level, then cap its peak.

    Args:
        audio: Audio tensor
        target_rms: Target RMS level for normalization
        target_peak: Maximum absolute sample value allowed after RMS scaling

    Returns:
        Normalized audio tensor on the same device as ``audio``. Empty or
        nearly silent input is returned unchanged.
    """
    # An empty tensor has no peak or RMS; abs().max() would raise on it.
    if audio.numel() == 0:
        return audio
    # Ensure audio is on CPU for processing
    audio_cpu = audio.detach().cpu()
    # Handle silent audio: scaling it up would only amplify noise
    if audio_cpu.abs().max() < 1e-6:
        logger.warning("Audio is nearly silent, returning original")
        return audio
    # Calculate current RMS
    current_rms = torch.sqrt(torch.mean(audio_cpu ** 2))
    # Apply RMS normalization (the comparison is False for NaN, which leaves
    # such audio unscaled)
    if current_rms > 0:
        gain = target_rms / current_rms
        normalized = audio_cpu * gain
    else:
        normalized = audio_cpu
    # Apply peak limiting: uniform down-scaling, only when the peak is too high
    current_peak = normalized.abs().max()
    if current_peak > target_peak:
        normalized = normalized * (target_peak / current_peak)
    # Return to original device
    return normalized.to(audio.device)
def apply_anti_muffling(audio: torch.Tensor, sample_rate: int, clarity_boost: float = 1.2) -> torch.Tensor:
    """Emphasize content above 2 kHz to make speech sound clearer.

    The signal is run through a zero-phase second-order Butterworth
    high-pass filter, scaled by ``clarity_boost``, and blended 70/30 with
    the unfiltered signal.

    Args:
        audio: Audio tensor
        sample_rate: Audio sample rate
        clarity_boost: Gain applied to the high-passed signal (1.0 = no boost)

    Returns:
        Processed audio tensor with the dtype and device of ``audio``; the
        original ``audio`` if filtering raises.
    """
    samples = audio.detach().cpu().numpy()
    try:
        highpass_hz = 2000  # Hz
        nyquist = sample_rate / 2
        # butter() expects the cutoff as a fraction of the Nyquist frequency.
        numerator, denominator = signal.butter(
            2, highpass_hz / nyquist, btype='high', analog=False
        )
        # filtfilt filters forward and backward, so no phase shift is added.
        brightened = signal.filtfilt(numerator, denominator, samples, axis=0) * clarity_boost
        # Keep 30% of the untouched signal to retain some warmth.
        wet = 0.7
        blended = wet * brightened + (1 - wet) * samples
    except Exception as e:
        logger.warning(f"Audio enhancement failed, using original: {e}")
        return audio
    return torch.tensor(blended, dtype=audio.dtype, device=audio.device)
def enhance_audio(audio: torch.Tensor, sample_rate: int, voice_profile: VoiceProfile) -> torch.Tensor:
    """Normalize and clarify audio using a voice profile's parameters.

    Args:
        audio: Audio tensor
        sample_rate: Audio sample rate
        voice_profile: Voice profile supplying the enhancement parameters

    Returns:
        Enhanced audio tensor. The input is returned untouched when it is
        ``None``, empty, or when any enhancement step raises.
    """
    nothing_to_process = audio is None or audio.numel() == 0
    if nothing_to_process:
        logger.error("Cannot enhance empty audio")
        return audio

    try:
        settings = voice_profile.get_enhancement_params()

        # Stage 1: bring loudness to the profile's RMS/peak targets.
        leveled = normalize_audio(
            audio,
            target_rms=settings["target_rms"],
            target_peak=settings["target_peak"],
        )

        # Stage 2: only darker voices (negative tilt) get extra boost;
        # neutral and bright voices stay at 1.0.
        darkness = max(0, -settings["spectral_tilt"])
        brightened = apply_anti_muffling(
            leveled,
            sample_rate,
            clarity_boost=1.0 + darkness * 0.5,
        )

        # Before/after statistics for the debug log.
        label = voice_profile.name
        rms_before = audio.pow(2).mean().sqrt().item()
        rms_after = brightened.pow(2).mean().sqrt().item()
        peak_before = audio.abs().max().item()
        peak_after = brightened.abs().max().item()
        logger.debug(
            f"Enhanced audio for {label}: "
            f"RMS: {rms_before:.3f}->{rms_after:.3f}, "
            f"Peak: {peak_before:.3f}->{peak_after:.3f}"
        )
        return brightened
    except Exception as e:
        logger.error(f"Error in audio enhancement: {e}")
        return audio  # Fall back to the unprocessed input
def validate_generated_audio(
    audio: torch.Tensor,
    voice_name: str,
    sample_rate: int,
    min_expected_duration: float = 0.5
) -> Tuple[bool, torch.Tensor, str]:
    """Run sanity checks on generated audio and repair NaN samples.

    Checks, in order: missing audio, NaN samples (replaced with zeros),
    minimum duration, overall loudness, and loudness of the final 100 ms.
    A loud ending is only reported; the audio is still treated as valid.

    Args:
        audio: Audio tensor to validate
        voice_name: Name of the voice used (for log messages)
        sample_rate: Audio sample rate
        min_expected_duration: Minimum expected duration in seconds

    Returns:
        Tuple of (is_valid, fixed_audio, message)
    """
    if audio is None:
        return False, torch.zeros(1), "Audio is None"

    # Zero out NaN samples so the remaining checks operate on real numbers.
    nan_mask = torch.isnan(audio)
    if nan_mask.any():
        logger.warning(f"Audio for {voice_name} contains NaN values, replacing with zeros")
        audio = torch.where(nan_mask, torch.zeros_like(audio), audio)

    # Duration is derived from the length of the first dimension.
    duration = audio.shape[0] / sample_rate
    if duration < min_expected_duration:
        logger.warning(f"Audio for {voice_name} is too short ({duration:.2f}s < {min_expected_duration}s)")
        return False, audio, f"Audio too short: {duration:.2f}s"

    # A very low overall RMS means the clip is effectively silent.
    rms = torch.sqrt(torch.mean(audio ** 2))
    if rms < 0.01:
        logger.warning(f"Audio for {voice_name} is nearly silent (RMS: {rms:.6f})")
        return False, audio, f"Audio nearly silent: RMS = {rms:.6f}"

    # A loud final 100 ms suggests generation stopped mid-utterance.
    tail_length = int(0.1 * sample_rate)
    if audio.shape[0] > tail_length:
        end_rms = torch.sqrt(torch.mean(audio[-tail_length:] ** 2))
        if end_rms > 0.1:
            logger.warning(f"Audio for {voice_name} may have cut off prematurely (end RMS: {end_rms:.3f})")
            return True, audio, "Audio may have cut off prematurely"

    return True, audio, "Audio validation passed"
def create_voice_segments(app_state, regenerate: bool = False):
    """Build reference audio segments for every voice in ``VOICE_PROFILES``.

    For each voice, each prompt in ``VOICE_PROMPTS`` is either loaded from an
    existing WAV file under ``VOICE_REFERENCES_DIR`` or synthesized with the
    generator, validated, enhanced and written to disk. Voices that already
    hold reference segments are skipped unless ``regenerate`` is set. All
    profiles are saved once the voices have been processed.

    Args:
        app_state: Application state; its ``generator`` attribute is used
        regenerate: Whether to regenerate existing references
    """
    generator = app_state.generator
    if not generator:
        logger.error("Cannot create voice segments: generator not available")
        return

    os.makedirs(VOICE_REFERENCES_DIR, exist_ok=True)
    reuse_existing = not regenerate

    for voice_name, profile in VOICE_PROFILES.items():
        voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice_name)
        os.makedirs(voice_dir, exist_ok=True)

        # Nothing to do when the profile is already populated.
        if reuse_existing and profile.reference_segments:
            logger.info(f"Voice {voice_name} already has {len(profile.reference_segments)} reference segments")
            continue

        prompts = VOICE_PROMPTS[voice_name]
        logger.info(f"Generating reference segments for voice: {voice_name}")
        reference_segments = []

        for i, prompt in enumerate(prompts):
            ref_path = os.path.join(voice_dir, f"{voice_name}_ref_{i}.wav")

            # Prefer a reference already on disk; if it fails to load, fall
            # through and synthesize a fresh one.
            if reuse_existing and os.path.exists(ref_path):
                try:
                    waveform, source_rate = torchaudio.load(ref_path)
                    waveform = waveform.squeeze(0)
                    if source_rate != generator.sample_rate:
                        waveform = torchaudio.functional.resample(
                            waveform, orig_freq=source_rate, new_freq=generator.sample_rate
                        )
                    reference_segments.append(waveform.to(generator.device))
                    logger.info(f"Loaded existing reference {i+1}/{len(prompts)} for {voice_name}")
                    continue
                except Exception as e:
                    logger.warning(f"Failed to load existing reference {i+1} for {voice_name}: {e}")

            try:
                logger.info(f"Generating reference {i+1}/{len(prompts)} for {voice_name}: '{prompt}'")
                audio = generator.generate(
                    text=prompt,
                    speaker=profile.speaker_id,
                    context=[],  # No context, to prevent voice bleed
                    max_audio_length_ms=6000,  # Shorter for more control
                    temperature=0.7,  # Lower temperature for more stability
                    topk=30,  # More focused sampling
                )
                is_valid, audio, message = validate_generated_audio(
                    audio, voice_name, generator.sample_rate
                )
                if not is_valid:
                    logger.warning(f"Invalid reference for {voice_name}: {message}")
                    # The loop simply moves on to the following prompt.
                    if i < len(prompts) - 1:
                        logger.info(f"Trying again with next prompt")
                    continue

                audio = enhance_audio(audio, generator.sample_rate, profile)
                # Persist the reference so later runs can load it from disk.
                torchaudio.save(ref_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
                reference_segments.append(audio)
                logger.info(f"Generated reference {i+1} for {voice_name}: {message}")
            except Exception as e:
                logger.error(f"Error generating reference for {voice_name}: {e}")

        # Only overwrite the profile when at least one segment was obtained.
        if reference_segments:
            VOICE_PROFILES[voice_name].reference_segments = reference_segments
            logger.info(f"Updated {voice_name} with {len(reference_segments)} reference segments")

    # Persist the updated profiles.
    save_voice_profiles()
def get_voice_segments(voice_name: str, device: torch.device) -> List:
    """Build the context segments for a voice from its reference audio.

    Unknown voice names fall back to ``"alloy"``. If the profile holds no
    reference segments, the function tries to load the voice's WAV files
    from ``VOICE_REFERENCES_DIR`` first; load errors are logged and result
    in an empty context.

    Args:
        voice_name: Name of the voice to use
        device: Device to place the segment audio on

    Returns:
        List of ``Segment`` objects, one per reference segment
    """
    from app.model import Segment

    if voice_name not in VOICE_PROFILES:
        logger.warning(f"Voice {voice_name} not found, defaulting to alloy")
        voice_name = "alloy"
    profile = VOICE_PROFILES[voice_name]

    # Lazily populate the profile from reference files saved on disk.
    if not profile.reference_segments:
        try:
            voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice_name)
            if os.path.exists(voice_dir):
                loaded = []
                # One candidate file per prompt index.
                for idx in range(len(VOICE_PROMPTS[voice_name])):
                    ref_path = os.path.join(voice_dir, f"{voice_name}_ref_{idx}.wav")
                    if os.path.exists(ref_path):
                        waveform, _ = torchaudio.load(ref_path)
                        loaded.append(waveform.squeeze(0))
                if loaded:
                    profile.reference_segments = loaded
                    logger.info(f"Loaded {len(loaded)} reference segments for {voice_name}")
        except Exception as e:
            logger.error(f"Error loading reference segments for {voice_name}: {e}")

    # Pair each reference with its prompt text, or a generic label when
    # there are more references than prompts.
    context = []
    if profile.reference_segments:
        prompts = VOICE_PROMPTS[voice_name]
        for idx, ref_audio in enumerate(profile.reference_segments):
            if idx < len(prompts):
                text = prompts[idx]
            else:
                text = f"Voice reference for {voice_name}"
            segment = Segment(
                speaker=profile.speaker_id,
                text=text,
                audio=ref_audio.to(device),
            )
            context.append(segment)

    logger.info(f"Returning {len(context)} context segments for {voice_name}")
    return context
def save_voice_profiles():
    """Write every profile's reference segments to ``voice_profiles.pt``.

    Only the ``reference_segments`` of each profile are stored, moved to CPU
    first. The file is written inside ``VOICE_PROFILES_DIR``, which is
    created if missing.
    """
    os.makedirs(VOICE_PROFILES_DIR, exist_ok=True)
    profile_path = os.path.join(VOICE_PROFILES_DIR, "voice_profiles.pt")

    # Plain dicts of CPU tensors, keyed by voice name.
    serializable_profiles = {
        name: {'reference_segments': [seg.cpu() for seg in profile.reference_segments]}
        for name, profile in VOICE_PROFILES.items()
    }

    torch.save(serializable_profiles, profile_path)
    logger.info(f"Saved voice profiles to {profile_path}")
def process_generated_audio(
    audio: torch.Tensor,
    voice_name: str,
    sample_rate: int,
    text: str
) -> torch.Tensor:
    """Validate generated audio, then enhance it with the voice's profile.

    A failed validation is only logged; enhancement is still attempted on
    the audio returned by the validator. Voices without a profile use the
    ``"alloy"`` profile.

    Args:
        audio: Audio tensor
        voice_name: Name of voice used
        sample_rate: Audio sample rate
        text: Text that was spoken (not used by this function)

    Returns:
        Processed audio tensor
    """
    is_valid, audio, message = validate_generated_audio(audio, voice_name, sample_rate)
    if not is_valid:
        logger.warning(f"Generated audio validation issue: {message}")

    fallback_profile = VOICE_PROFILES["alloy"]
    profile = VOICE_PROFILES.get(voice_name, fallback_profile)
    enhanced = enhance_audio(audio, sample_rate, profile)

    # Summarize the before/after duration and loudness.
    original_duration = audio.shape[0] / sample_rate
    enhanced_duration = enhanced.shape[0] / sample_rate
    rms_before = audio.pow(2).mean().sqrt().item()
    rms_after = enhanced.pow(2).mean().sqrt().item()
    logger.info(
        f"Processed audio for '{voice_name}': "
        f"Duration: {original_duration:.2f}s->{enhanced_duration:.2f}s, "
        f"RMS: {rms_before:.3f}->{rms_after:.3f}"
    )
    return enhanced