import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T
import matplotlib.pyplot as plt
import os
from typing import List, Tuple

from config import LOGS_DIR

# Some utils:
def load_audio_files(file_paths: List[str]) -> List[Tuple[torch.Tensor, int]]:
    """
    Load multiple audio files and verify they share the same length and sample rate.

    Args:
        file_paths: List of paths to audio files

    Returns:
        List of tuples containing audio data and sample rate
    """
    audio_data = []
    for path in file_paths:
        # Load audio file
        waveform, sample_rate = torchaudio.load(path)
        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        audio_data.append((waveform.squeeze(), sample_rate))

    # Verify all audio files have the same length and sample rate
    lengths = [len(audio) for audio, _ in audio_data]
    sample_rates = [sr for _, sr in audio_data]
    if len(set(lengths)) > 1:
        raise ValueError(f"Audio files have different lengths: {lengths}")
    if len(set(sample_rates)) > 1:
        raise ValueError(f"Audio files have different sample rates: {sample_rates}")

    return audio_data

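# Minimal usage sketch (the .wav paths are illustrative, not files shipped
# with this repo); kept commented out so importing the module stays
# side-effect-free:
#
#   tracks = load_audio_files(["vocals.wav", "drums.wav"])
#   waveform, sr = tracks[0]  # mono 1-D float tensor and its sample rate
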
def normalize_audio_volumes(audio_data: List[Tuple[torch.Tensor, int]]) -> List[Tuple[torch.Tensor, int]]:
    """
    Normalize the volume of each audio file to have the same energy level.

    Args:
        audio_data: List of tuples containing audio data and sample rate

    Returns:
        List of tuples containing normalized audio data and sample rate
    """
    normalized_data = []

    # Calculate RMS (root mean square) for each audio
    rms_values = []
    for audio, sr in audio_data:
        # Energy is the mean squared amplitude; RMS is its square root
        energy = torch.mean(audio ** 2)
        rms = torch.sqrt(energy)
        rms_values.append(rms)

    # Find the target RMS (we use the median to avoid outliers)
    target_rms = torch.median(torch.stack(rms_values))

    # Normalize each audio to the target RMS
    for (audio, sr), rms in zip(audio_data, rms_values):
        if rms > 0:  # Avoid division by zero
            # Calculate and apply the scaling factor
            scaling_factor = target_rms / rms
            normalized_audio = audio * scaling_factor
        else:
            normalized_audio = audio
        normalized_data.append((normalized_audio, sr))

    return normalized_data

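# A quick sanity check of the normalization above on synthetic sines
# (values are illustrative): after normalization both signals share the
# median RMS of the two inputs:
#
#   t = torch.linspace(0, 2 * torch.pi * 100, 16000)
#   quiet, loud = 0.1 * torch.sin(t), 0.8 * torch.sin(t)
#   out = normalize_audio_volumes([(quiet, 16000), (loud, 16000)])
#   # torch.sqrt(torch.mean(out[0][0] ** 2)) == torch.sqrt(torch.mean(out[1][0] ** 2))
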
def plot_energy_comparison(
    original_metrics: List[dict],
    normalized_metrics: List[dict],
    file_names: List[str],
    output_path: str = "./logs/energy_comparison.png"
) -> None:
    """
    Plot a comparison of energy metrics before and after normalization.

    Args:
        original_metrics: List of dictionaries containing metrics for original audio
        normalized_metrics: List of dictionaries containing metrics for normalized audio
        file_names: List of audio file names
        output_path: Path to save the plot
    """
    fig, axs = plt.subplots(2, 2, figsize=(14, 10))

    # Extract metrics
    orig_rms = [m['rms'] for m in original_metrics]
    norm_rms = [m['rms'] for m in normalized_metrics]
    orig_peak = [m['peak'] for m in original_metrics]
    norm_peak = [m['peak'] for m in normalized_metrics]
    orig_dr = [m['dynamic_range_db'] for m in original_metrics]
    norm_dr = [m['dynamic_range_db'] for m in normalized_metrics]
    orig_cf = [m['crest_factor'] for m in original_metrics]
    norm_cf = [m['crest_factor'] for m in normalized_metrics]

    # Prepare x-axis
    x = np.arange(len(file_names))
    width = 0.35

    # Plot RMS (volume)
    axs[0, 0].bar(x - width / 2, orig_rms, width, label='Original')
    axs[0, 0].bar(x + width / 2, norm_rms, width, label='Normalized')
    axs[0, 0].set_title('RMS Energy (Volume)')
    axs[0, 0].set_xticks(x)
    axs[0, 0].set_xticklabels(file_names, rotation=45, ha='right')
    axs[0, 0].set_ylabel('RMS Value')
    axs[0, 0].legend()

    # Plot peak amplitude
    axs[0, 1].bar(x - width / 2, orig_peak, width, label='Original')
    axs[0, 1].bar(x + width / 2, norm_peak, width, label='Normalized')
    axs[0, 1].set_title('Peak Amplitude')
    axs[0, 1].set_xticks(x)
    axs[0, 1].set_xticklabels(file_names, rotation=45, ha='right')
    axs[0, 1].set_ylabel('Peak Value')
    axs[0, 1].legend()

    # Plot dynamic range
    axs[1, 0].bar(x - width / 2, orig_dr, width, label='Original')
    axs[1, 0].bar(x + width / 2, norm_dr, width, label='Normalized')
    axs[1, 0].set_title('Dynamic Range (dB)')
    axs[1, 0].set_xticks(x)
    axs[1, 0].set_xticklabels(file_names, rotation=45, ha='right')
    axs[1, 0].set_ylabel('dB')
    axs[1, 0].legend()

    # Plot crest factor
    axs[1, 1].bar(x - width / 2, orig_cf, width, label='Original')
    axs[1, 1].bar(x + width / 2, norm_cf, width, label='Normalized')
    axs[1, 1].set_title('Crest Factor (Peak-to-RMS Ratio)')
    axs[1, 1].set_xticks(x)
    axs[1, 1].set_xticklabels(file_names, rotation=45, ha='right')
    axs[1, 1].set_ylabel('Ratio')
    axs[1, 1].legend()

    plt.tight_layout()

    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    # Save the plot
    plt.savefig(output_path)
    plt.close()

def calculate_audio_metrics(audio_data: List[Tuple[torch.Tensor, int]]) -> List[dict]:
    """
    Calculate various audio metrics for each audio file.

    Args:
        audio_data: List of tuples containing audio data and sample rate

    Returns:
        List of dictionaries containing metrics
    """
    metrics = []
    for audio, sr in audio_data:
        # Calculate RMS (root mean square)
        energy = torch.mean(audio ** 2)
        rms = torch.sqrt(energy)

        # Calculate peak amplitude
        peak = torch.max(torch.abs(audio))

        # Calculate dynamic range as the peak-to-smallest-non-zero-sample
        # ratio in dB; guard against an all-zero signal, where audio != 0
        # selects an empty tensor and min() would raise
        non_zero = torch.abs(audio[audio != 0])
        if non_zero.numel() > 0:
            min_non_zero = torch.min(non_zero)
            dynamic_range_db = 20 * torch.log10(peak / min_non_zero)
        else:
            dynamic_range_db = torch.tensor(float('inf'))

        # Calculate crest factor (peak-to-RMS ratio)
        crest_factor = peak / rms if rms > 0 else torch.tensor(float('inf'))

        metrics.append({
            'rms': rms.item(),
            'peak': peak.item(),
            'dynamic_range_db': dynamic_range_db.item() if not torch.isinf(dynamic_range_db) else float('inf'),
            'crest_factor': crest_factor.item() if not torch.isinf(crest_factor) else float('inf')
        })

    return metrics

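# Shape of the result, for reference (numbers below are illustrative, not
# computed outputs):
#
#   calculate_audio_metrics([(torch.rand(16000) - 0.5, 16000)])
#   # -> [{'rms': 0.29, 'peak': 0.5, 'dynamic_range_db': 84.3, 'crest_factor': 1.73}]
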
def create_weighted_composite(
    audio_data: List[Tuple[torch.Tensor, int]],
    weights: List[float]
) -> torch.Tensor:
    """
    Create a weighted composite of multiple audio files.

    Args:
        audio_data: List of tuples containing audio data and sample rate
        weights: List of weights for each audio file

    Returns:
        Weighted composite audio data
    """
    if len(audio_data) != len(weights):
        raise ValueError("Number of audio files and weights must match")

    # Normalize weights to sum to 1
    weights = torch.tensor(weights) / sum(weights)

    # Initialize composite audio with zeros
    composite = torch.zeros_like(audio_data[0][0])

    # Add weighted audio data
    for (audio, _), weight in zip(audio_data, weights):
        composite += audio * weight

    # Normalize to prevent clipping
    max_val = torch.max(torch.abs(composite))
    if max_val > 1.0:
        composite = composite / max_val

    return composite

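# Illustrative mix (`tracks` stands for a list returned by load_audio_files):
# because the weights are re-normalized to sum to 1, passing [2.0, 1.0, 1.0]
# is equivalent to passing [0.5, 0.25, 0.25]:
#
#   mix = create_weighted_composite(tracks, [2.0, 1.0, 1.0])
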
def create_melspectrograms(
    audio_data: List[Tuple[torch.Tensor, int]],
    composite: torch.Tensor,
    sr: int
) -> List[torch.Tensor]:
    """
    Create melspectrograms for individual audio files and the composite.

    Args:
        audio_data: List of tuples containing audio data and sample rate
        composite: Composite audio data
        sr: Sample rate

    Returns:
        List of melspectrogram data
    """
    specs = []

    # Create mel spectrogram transform
    mel_transform = T.MelSpectrogram(
        sample_rate=sr,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        f_max=8000
    )

    # Generate spectrograms for individual audio files
    for audio, _ in audio_data:
        melspec = mel_transform(audio)
        specs.append(melspec)

    # Generate spectrogram for composite audio
    composite_melspec = mel_transform(composite)
    specs.append(composite_melspec)

    return specs

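# Shape sketch (assuming torchaudio's default center padding, and with
# `tracks` / `mix` as hypothetical inputs): a 16000-sample clip with
# hop_length=512 and n_mels=128 yields a (128, 32) tensor per entry, and
# the returned list has len(audio_data) + 1 entries, composite last:
#
#   specs = create_melspectrograms(tracks, mix, 16000)
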
def plot_melspectrograms(
    specs: List[torch.Tensor],
    sr: int,
    file_names: List[str],
    weights: List[float],
    output_path: str = "melspectrograms.png"
) -> None:
    """
    Plot melspectrograms for individual audio files and the composite.

    Args:
        specs: List of melspectrogram data
        sr: Sample rate
        file_names: List of audio file names
        weights: List of weights for each audio file
        output_path: Path to save the plot
    """
    fig, axs = plt.subplots(len(specs), 1, figsize=(12, 4 * len(specs)))

    # Create labels for the plots
    labels = [f"{name} (weight: {weight:.2f})" for name, weight in zip(file_names, weights)]
    labels.append("Composite.wav")

    # Convert to dB scale (similar to librosa's power_to_db)
    def power_to_db(spec):
        return 10 * torch.log10(spec + 1e-10)

    # Plot each melspectrogram
    for i, (spec, label) in enumerate(zip(specs, labels)):
        spec_db = power_to_db(spec).numpy().squeeze()
        # With a single subplot, plt.subplots returns a bare Axes
        ax = axs if len(specs) == 1 else axs[i]
        # The mel axis is nonlinear and capped at f_max, so index it by
        # mel bin rather than labeling it as if it spanned 0..sr/2 in Hz
        ax.imshow(
            spec_db,
            aspect='auto',
            origin='lower',
            interpolation='none',
            extent=[0, spec_db.shape[1], 0, spec_db.shape[0]]
        )
        ax.set_title(label)
        ax.set_ylabel('Mel Bin')
        ax.set_xlabel('Time Frames')
        # No colorbar as requested

    plt.tight_layout()

    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    # Save the plot
    plt.savefig(output_path, dpi=300)
    plt.close()

def compose_audio(
    file_paths: List[str],
    weights: List[float],
    output_audio_path: str = os.path.join(LOGS_DIR, "composite.wav"),
    output_plot_path: str = os.path.join(LOGS_DIR, "plot/melspectrograms.png"),
    energy_plot_path: str = os.path.join(LOGS_DIR, "plot/energy_comparison.png")
) -> None:
    """
    Main function to process audio files and create visualizations.

    Args:
        file_paths: List of paths to audio files (e.g., four stems to mix)
        weights: List of weights for each audio file
        output_audio_path: Path to save the composite audio
        output_plot_path: Path to save the melspectrogram plot
        energy_plot_path: Path to save the energy comparison plot
    """
    # Load audio files
    audio_data = load_audio_files(file_paths)

    # Calculate metrics for original audio
    print("Calculating metrics for original audio...")
    original_metrics = calculate_audio_metrics(audio_data)

    # Normalize audio volumes to the same energy level
    print("Normalizing audio volumes...")
    normalized_audio_data = normalize_audio_volumes(audio_data)

    # Calculate metrics for normalized audio
    print("Calculating metrics for normalized audio...")
    normalized_metrics = calculate_audio_metrics(normalized_audio_data)

    # Print energy comparison
    print("\nAudio Energy Comparison (RMS values):")
    print("-" * 50)
    print(f"{'File':<20} {'Original':<15} {'Normalized':<15} {'Scaling Factor':<15}")
    print("-" * 50)
    file_names = [os.path.basename(path) for path in file_paths]
    for i, file_name in enumerate(file_names):
        orig_rms = original_metrics[i]['rms']
        norm_rms = normalized_metrics[i]['rms']
        scaling = norm_rms / orig_rms if orig_rms > 0 else float('inf')
        print(f"{file_name[:20]:<20} {orig_rms:<15.6f} {norm_rms:<15.6f} {scaling:<15.6f}")

    # Create energy comparison plot
    print("\nCreating energy comparison plot...")
    plot_energy_comparison(original_metrics, normalized_metrics, file_names, energy_plot_path)

    # Get sample rate (load_audio_files guarantees all files share one)
    sr = normalized_audio_data[0][1]

    # Create weighted composite
    print("\nCreating weighted composite...")
    composite = create_weighted_composite(normalized_audio_data, weights)

    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(output_audio_path) or '.', exist_ok=True)

    # Save composite audio
    print("Saving composite audio...")
    torchaudio.save(output_audio_path, composite.unsqueeze(0), sr)

    # Create melspectrograms for normalized audio (not original)
    print("Creating melspectrograms for normalized audio...")
    specs = create_melspectrograms(normalized_audio_data, composite, sr)

    # Plot melspectrograms
    print("Plotting melspectrograms...")
    plot_melspectrograms(specs, sr, file_names, weights, output_plot_path)

    print(f"\nComposite audio saved to {output_audio_path}")
    print(f"Melspectrograms saved to {output_plot_path}")
    print(f"Energy comparison saved to {energy_plot_path}")

# if __name__ == "__main__":
#     import argparse
#     parser = argparse.ArgumentParser(description="Mix audio files with weights and create melspectrograms")
#     parser.add_argument("--files", nargs="+", required=True, help="Paths to audio files")
#     parser.add_argument("--weights", nargs="+", type=float, required=True, help="Weights for each audio file")
#     parser.add_argument("--output-audio", default="./logs/composite.wav", help="Path to save the composite audio")
#     parser.add_argument("--output-plot", default="./logs/melspectrograms.png", help="Path to save the melspectrogram plot")
#     args = parser.parse_args()
#     os.makedirs("./logs", exist_ok=True)
#     compose_audio(args.files, args.weights, args.output_audio, args.output_plot)

# Example usage:
# python audio_mixer.py --files audio1.wav audio2.wav audio3.wav audio4.wav --weights 0.4 0.3 0.2 0.1
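
# Equivalent call from Python (paths are illustrative):
#
#   compose_audio(
#       ["audio1.wav", "audio2.wav", "audio3.wav", "audio4.wav"],
#       [0.4, 0.3, 0.2, 0.1],
#   )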