Spaces:
Running
on
Zero
Running
on
Zero
import librosa | |
import torch | |
import json | |
import random | |
import math | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from datetime import datetime | |
from torch import lerp | |
from torch.nn import ReflectionPad1d | |
import torch.nn.functional as F | |
def load_json(fname): | |
with open(fname, "r") as f: | |
data = json.load(f) | |
return data | |
# def plot_spectrogram(fbank, filename=None, title=None, ylabel="freq_bin", ax=None): | |
# r""" | |
# Params: `fbank`: (`n_mel_bins`, `n_frames`) | |
# """ | |
# if fbank.ndim > 2: | |
# fbank = fbank.detach().cpu().squeeze() | |
# else: | |
# fbank = fbank.detach().cpu() | |
# if ax is None: | |
# _, ax = plt.subplots(1, 1) | |
# if title is not None: | |
# ax.set_title(title) | |
# ax.set_ylabel(ylabel) | |
# ax.imshow(fbank, origin="lower", aspect="auto", interpolation="nearest") | |
# if filename is not None: | |
# ax.figure.savefig(filename) | |
# return ax | |
def plot_spectrogram(fbank, filename=None, title=None, ylabel=None, auto_amp=False, figsize=(16, 9)): | |
r""" | |
Params: `fbank`: (`n_mel_bins`, `n_frames`) | |
""" | |
if fbank.ndim > 2: | |
fbank = fbank.detach().cpu().squeeze() | |
else: | |
fbank = fbank.detach().cpu() | |
fig, ax = plt.subplots(1, 1, figsize=figsize) | |
fbank = fbank.numpy() | |
if auto_amp: | |
img=librosa.display.specshow(fbank, ax=ax) | |
else: | |
img=librosa.display.specshow(fbank, ax=ax, vmin=-10, vmax=0) # x_axis='time', y_axis='mel', | |
if title is not None: | |
ax.set_title(title) | |
if ylabel is not None: | |
ax.set_ylabel(ylabel) | |
# fig.colorbar(img, ax=ax, format="%+2.f dB") | |
plt.tight_layout() | |
plt.subplots_adjust(left=0, right=1, top=1, bottom=0) # Adjust subplots to fill the figure | |
if filename is not None: | |
ax.figure.savefig(filename) | |
return ax | |
def get_current_time(out_format="%Y-%m-%d %H:%M:%S"): | |
current_time = datetime.now() | |
formatted_time = current_time.strftime(out_format) | |
return formatted_time | |
def get_box_boundry(mask: torch.Tensor): | |
r"""Get the box boundy of masked region.""" | |
ws, hs = torch.nonzero(mask, as_tuple=True) | |
w_l, w_r = torch.min(ws), torch.max(ws) | |
h_b, h_t = torch.min(hs), torch.max(hs) | |
return (w_l, w_r), (h_b, h_t) | |
def get_neibor_with_mask(matrix, mask, reverse=False): | |
assert matrix.shape == mask.shape | |
# Pad the unmasked region using reflection if applicable | |
if reverse: | |
mask = ~mask.bool() | |
(w_l, w_r), (h_b, h_t) = get_box_boundry(mask) | |
mask_w_cntr = (w_r + w_l) // 2 | |
pad_l_fn = ReflectionPad1d((0, mask_w_cntr - w_l)) | |
matrix_l_cur = pad_l_fn(matrix[: w_l + 1, :].permute(1, 0)).permute(1, 0) | |
pad_r_fn = ReflectionPad1d((w_r - mask_w_cntr - 1)) | |
import ipdb | |
ipdb.set_trace() | |
matrix_r_cur = pad_r_fn(matrix[w_r + 1 :, :].permute(1, 0)).permute(1, 0) | |
# import ipdb; ipdb.set_trace() | |
matrix_cur = torch.cat([matrix_l_cur, matrix_r_cur], dim=0) | |
return matrix[mask] + matrix_cur[~mask] | |
# def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): | |
# ''' | |
# Spherical linear interpolation | |
# Args: | |
# t (float/np.ndarray): Float value between 0.0 and 1.0 | |
# v0 (np.ndarray): Starting vector | |
# v1 (np.ndarray): Final vector | |
# DOT_THRESHOLD (float): Threshold for considering the two vectors as | |
# colineal. Not recommended to alter this. | |
# Returns: | |
# v2 (np.ndarray): Interpolation vector between v0 and v1 | |
# ''' | |
# is_tensor = False | |
# if not isinstance(v0,np.ndarray): | |
# is_tensor = True | |
# device = v0.device | |
# v0 = v0.detach().cpu().numpy() | |
# if not isinstance(v1,np.ndarray): | |
# is_tensor = True | |
# device = v1.device # overwrite if v0 is also Tensor | |
# v1 = v1.detach().cpu().numpy() | |
# # Copy the vectors to reuse them later | |
# v0_copy = np.copy(v0) | |
# v1_copy = np.copy(v1) | |
# # Normalize the vectors to get the directions and angles | |
# v0 = v0 / np.linalg.norm(v0) | |
# v1 = v1 / np.linalg.norm(v1) | |
# # Dot product with the normalized vectors (can't use np.dot in W) | |
# dot = np.sum(v0 * v1) | |
# # If absolute value of dot product is almost 1, vectors are ~colineal, so use lerp | |
# if np.abs(dot) > DOT_THRESHOLD: | |
# return lerp(t, v0_copy, v1_copy) | |
# # Calculate initial angle between v0 and v1 | |
# theta_0 = np.arccos(dot) | |
# sin_theta_0 = np.sin(theta_0) | |
# # Angle at timestep t | |
# theta_t = theta_0 * t | |
# sin_theta_t = np.sin(theta_t) | |
# s0 = np.sin(theta_0 - theta_t) / sin_theta_0 | |
# s1 = sin_theta_t / sin_theta_0 | |
# v2 = s0*v0_copy + s1*v1_copy | |
if is_tensor: | |
res = torch.from_numpy(v2).to(device) | |
else: | |
res = v2 | |
return res | |
def normalize_along_channel(in_feat, eps=1e-10): | |
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True)) | |
return in_feat / (norm_factor + eps) | |
# def extract_and_fill(spectrum, a, b, sr, hop_length): | |
# """ | |
# Extract a 1-second segment from (a, b) and fill the rest of the segment using repeat or reflection. | |
# Parameters: | |
# spectrum (Tensor): The input spectrum tensor. | |
# a (float): The start time of the region with energy. | |
# b (float): The end time of the region with energy. | |
# sr (int): The sample rate of the spectrum. | |
# hop_length (int) | |
# Returns: | |
# Tensor: The processed spectrum tensor. | |
# """ | |
# n_frames = spectrum.size(1) | |
# n_frames_per_sec = sr // hop_length | |
# mask = (spectrum!=0).float() | |
# # Convert time to samples | |
# a_frame = math.floor(a * sr / hop_length) | |
# b_frame = math.ceil(b * sr / hop_length) | |
# assert a_frame < n_frames and b_frame < n_frames | |
# duration = b_frame - a_frame | |
# # If the energy region is shorter than 1 second, adjust | |
# extract_duration = duration // 2 if duration <= n_frames_per_sec else n_frames_per_sec | |
# padding = duration - extract_duration | |
# start_frame = random.randint(a_frame, b_frame-extract_duration) | |
# segment = spectrum[:, start_frame:start_frame+extract_duration, :] | |
# segment = segment.repeat(1, n_frames//extract_duration+1, 1)[:, :n_frames, :] | |
# segment *= mask | |
# return segment | |
def extract_and_fill(spec, stt_frame, end_frame, tgt_length): | |
""" | |
Extract a region with <= `tgt_length` from (`stt_frame`, `end_frame`) and fill the rest of the spec by repeating the extracted region. | |
Param: | |
spec: Tensor: input spectrogram, shape = (C,T,F). | |
a: float: The start time of the region with energy. | |
b: float: The end time of the region with energy. | |
Returns: | |
Tensor: the processed spectrum tensor. | |
""" | |
assert (spec.ndim == 3 or spec.ndim == 4), "Format the input `spec` with the shape = (C, T, F) or (B,C,T,F)." | |
total_length = spec.size(-2) | |
assert stt_frame < total_length and end_frame < total_length | |
duration = end_frame - stt_frame | |
mask = (spec != 0).float() | |
# If the energy region is shorter than 1 second, adjust | |
extract_duration = duration // 2 if duration <= tgt_length else tgt_length | |
start_frame = random.randint(stt_frame, end_frame - extract_duration) | |
if spec.ndim == 3: | |
segment = spec[:, start_frame : start_frame + extract_duration, :] | |
segment = segment.repeat(1, total_length // extract_duration + 1, 1)[ | |
:, :total_length, : | |
] | |
else: | |
segment = spec[:, :, start_frame : start_frame + extract_duration, :] | |
segment = segment.repeat(1, 1, total_length // extract_duration + 1, 1)[ | |
:, :, :total_length, : | |
] | |
segment *= mask | |
return segment | |
def fill_with_neighbor(spec, stt_frame, end_frame, neighbor_length): | |
""" | |
Fill a region from (`stt_frame`, `end_frame`) with neighbor of `neighbor_length` | |
Param: | |
spec: Tensor: input spectrogram, shape = (C,T,F). | |
stt_frame: int: The start frame of the region with energy. | |
end_frame: int: The end frame of the region with energy. | |
neighbor_length: int: selected length of neighbor | |
Returns: | |
Tensor: the processed spectrum tensor. | |
""" | |
assert spec.ndim == 3, "Format the input `spec` with the shape = (C, T, F)." | |
total_length = spec.size(1) | |
assert stt_frame < total_length and end_frame < total_length | |
duration = end_frame - stt_frame | |
mask = torch.zeros_like(spec) | |
mask[:, stt_frame : end_frame + 1, :] = 1 | |
left_duration = min(math.ceil(neighbor_length / 2), stt_frame) | |
right_duration = min(neighbor_length - left_duration, total_length - end_frame - 1) | |
if left_duration + right_duration < 1: | |
print("Warning: cannot find effect positive part!") | |
return torch.randn_like(segment) | |
left_segment = spec[:, stt_frame - left_duration : stt_frame, :] | |
right_segment = spec[:, end_frame + 1 : end_frame + right_duration + 1, :] | |
segment = torch.cat([left_segment, right_segment], dim=1) | |
segment = segment.repeat( | |
1, total_length // (left_duration + right_duration) + 1, 1 | |
)[:, :total_length, :] | |
segment = segment * mask + spec * (1 - mask) | |
return segment | |
# def slerp(t, A, B, eps=1e-8): | |
# """ | |
# Spherical Linear Interpolation (SLERP) between points A and B on a sphere. | |
# """ | |
# A = A / (torch.norm(A, p=2) + eps) | |
# B = B / (torch.norm(B, p=2) + eps) | |
# dot_product = torch.sum(A * B) | |
# dot_product = torch.clamp(dot_product, -1.0, 1.0) | |
# theta = torch.acos(dot_product) | |
# if torch.abs(theta) < 1e-10: | |
# return (1 - t) * A + t * B | |
# sin_theta = torch.sin(theta) | |
# A_factor = torch.sin((1 - t) * theta) / sin_theta | |
# B_factor = torch.sin(t * theta) / sin_theta | |
# return A_factor * A + B_factor * B | |
def lerp(t, v0, v1): | |
""" | |
Linear interpolation in PyTorch. | |
Args: | |
t (float/torch.Tensor): Float value between 0.0 and 1.0 | |
v0 (torch.Tensor): Starting vector | |
v1 (torch.Tensor): Final vector | |
Returns: | |
v2 (torch.Tensor): Interpolation vector between v0 and v1 | |
""" | |
return (1 - t) * v0 + t * v1 | |
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): | |
""" | |
Spherical linear interpolation in PyTorch. | |
Args: | |
t (float/torch.Tensor): Float value between 0.0 and 1.0 | |
v0 (torch.Tensor): Starting vector | |
v1 (torch.Tensor): Final vector | |
DOT_THRESHOLD (float): Threshold for considering the two vectors as collinear. Not recommended to alter this. | |
Returns: | |
v2 (torch.Tensor): Interpolation vector between v0 and v1 | |
""" | |
device = v0.device | |
# Normalize the vectors to get the directions and angles | |
v0_norm = v0 / torch.norm(v0) | |
v1_norm = v1 / torch.norm(v1) | |
# Dot product with the normalized vectors | |
dot = torch.sum(v0_norm * v1_norm) | |
# If absolute value of dot product is almost 1, vectors are ~collinear, so use lerp | |
if torch.abs(dot) > DOT_THRESHOLD: | |
return lerp(t, v0, v1) | |
# Calculate initial angle between v0 and v1 | |
theta_0 = torch.acos(dot) | |
sin_theta_0 = torch.sin(theta_0) | |
# Angle at timestep t | |
theta_t = theta_0 * t | |
sin_theta_t = torch.sin(theta_t) | |
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0 | |
s1 = sin_theta_t / sin_theta_0 | |
v2 = s0 * v0 + s1 * v1 | |
return v2 | |
def geodesic_distance(X, Y): | |
""" | |
Compute the geodesic distance between two points X and Y on a sphere. | |
""" | |
dot_product = torch.sum(X * Y) | |
dot_product = torch.clamp(dot_product, -1.0, 1.0) | |
return torch.acos(dot_product) | |
def optimize_neighborhood_points( | |
A, | |
B, | |
M, | |
t, | |
learning_rate=1e-4, | |
iterations=100, | |
enable_penalty=False, | |
enable_tangent_proj=True, | |
): | |
""" | |
Optimize the neighborhood points A_e and B_e to minimize the distance between | |
the SLERP interpolation and the given interpolation point M. | |
""" | |
# Initialize perturbations | |
epsilon_A = torch.zeros_like(A, requires_grad=True) | |
epsilon_B = torch.zeros_like(B, requires_grad=True) | |
optimizer = torch.optim.SGD([epsilon_A, epsilon_B], lr=learning_rate) # Adam | |
for i in range(iterations): | |
optimizer.zero_grad() | |
# Compute current neighborhood points | |
A_e = A + epsilon_A | |
B_e = B + epsilon_B | |
# Compute the SLERP interpolation | |
P = slerp(t, A_e, B_e) | |
# Compute the distance | |
dist = geodesic_distance(M, P) | |
if enable_penalty: | |
orthogonality_penalty = torch.sum(A_e * B_e) ** 2 | |
dist += orthogonality_penalty | |
# Backpropagation | |
dist.backward() | |
if enable_tangent_proj: | |
with torch.no_grad(): | |
epsilon_A.grad = project_onto_tangent_space(epsilon_A.grad, A_e) | |
epsilon_B.grad = project_onto_tangent_space(epsilon_B.grad, B_e) | |
# Clip gradients to prevent large updates | |
torch.nn.utils.clip_grad_norm_([epsilon_A, epsilon_B], max_norm=1.0) | |
# Check gradients for NaNs | |
if torch.isnan(epsilon_A.grad).any() or torch.isnan(epsilon_B.grad).any(): | |
print(f"NaN encountered in gradients at iteration {i}") | |
break | |
# Update perturbations | |
optimizer.step() | |
return A + epsilon_A.detach(), B + epsilon_B.detach() | |
# def optimize_neighborhood_points(A, B, M, t, learning_rate=1e-4, iterations=100, enable_penalty=False, eps=1e-8): | |
# """ | |
# Optimize the neighborhood points A_e and B_e to minimize the distance between | |
# the SLERP interpolation and the given interpolation point M. | |
# """ | |
# # Initialize perturbations | |
# epsilon_A = torch.zeros_like(A, requires_grad=True) | |
# epsilon_B = torch.zeros_like(B, requires_grad=True) | |
# optimizer = torch.optim.SGD([epsilon_A, epsilon_B], lr=learning_rate) | |
# for _ in range(iterations): | |
# optimizer.zero_grad() | |
# # Compute current neighborhood points | |
# A_e = A + epsilon_A | |
# B_e = B + epsilon_B | |
# # # Normalize to ensure they are on the unit sphere | |
# # A_e = A_e / (torch.norm(A_e, p=2) + eps) | |
# # B_e = B_e / (torch.norm(B_e, p=2) + eps) | |
# # Compute the SLERP interpolation | |
# P = slerp(t, A_e, B_e) | |
# # Compute the distance | |
# dist = geodesic_distance(M, P) | |
# if enable_penalty: | |
# orthogonality_penalty = torch.sum(A_e * B_e) ** 2 | |
# dist += orthogonality_penalty | |
# # Backpropagation | |
# dist.backward() | |
# # Update perturbations | |
# optimizer.step() | |
# return A + epsilon_A.detach(), B + epsilon_B.detach() | |
# def optimize_neighborhood_points(A, B, M, t, learning_rate=2e-5, iterations=100, enable_penalty=False): | |
# """ | |
# [Deprecated] this method tends to NaN | |
# Optimize the neighborhood points A_e and B_e to minimize the distance between | |
# the SLERP interpolation and the given interpolation point M. | |
# """ | |
# # Initialize perturbations | |
# A_e = A.clone().detach().requires_grad_(True) | |
# B_e = B.clone().detach().requires_grad_(True) | |
# optimizer = torch.optim.SGD([A_e, B_e], lr=learning_rate) | |
# for _ in range(iterations): | |
# optimizer.zero_grad() | |
# # Compute the SLERP interpolation | |
# P = slerp(t, A_e, B_e) | |
# # Compute the distance | |
# dist = geodesic_distance(M, P) | |
# if enable_penalty: | |
# orthogonality_penalty = torch.sum(A_e * B_e) ** 2 | |
# dist += orthogonality_penalty | |
# # Backpropagation | |
# dist.backward() | |
# with torch.no_grad(): | |
# A_e.grad = project_onto_tangent_space(A_e.grad, A_e) | |
# B_e.grad = project_onto_tangent_space(B_e.grad, B_e) | |
# # Update perturbations | |
# optimizer.step() | |
# return A_e.detach().requires_grad_(False), B_e.detach().requires_grad_(False) | |
def project_onto_tangent_space(g, h, eps=1e-8): | |
""" | |
Projects vector g onto the tangent space of vector h. | |
Args: | |
g (torch.Tensor): The vector to be projected. | |
h (torch.Tensor): The vector whose tangent space g is projected onto. | |
Returns: | |
torch.Tensor: The projection of g onto the tangent space of h. | |
""" | |
g = torch.tensor(g) | |
h = torch.tensor(h) | |
# Compute the dot product g . h | |
dot_product = torch.sum(g * h) | |
# Compute the squared norm of h, h . h | |
h_norm_squared = torch.sum(h * h) + eps | |
# Calculate the projection scalar | |
proj_scalar = dot_product / h_norm_squared | |
# Compute the component of g in the direction of h | |
g_para = proj_scalar * h | |
# Compute the projection of g onto the tangent space of h | |
g_ortho = g - g_para | |
return g_ortho | |
def label2caption(label, background_sound=None, template="{} can be heard"): | |
r"""This is a helper function converting list of labels to captions.""" | |
if background_sound is None: | |
return [template.format(", ".join(l)) for l in label] | |
if isinstance(background_sound, str): | |
background_sound = [[background_sound]] * len(label) | |
assert len(label) == len( | |
background_sound | |
), "the number of `background_sound` should match the number of `label`." | |
caption = [] | |
for l, bg in zip(label, background_sound): | |
cap = template.format(", ".join(l)) | |
cap += " with the background sounds of {}".format(", ".join(bg)) | |
caption.append(cap) | |
return caption | |
def load_json(fname): | |
with open(fname, "r") as f: | |
data = json.load(f) | |
return data | |
def identity_projection(g, *args, **kwargs): | |
return g | |
def convert_float_to_int(data): | |
data *= 32768 | |
data = np.nan_to_num(data, nan=0.0, posinf=32767, neginf=-32768) | |
data = np.clip(data, -32768, 32767) | |
return data | |
def get_edit_mask(mask, dx, dy, resize_scale_x, resize_scale_y): | |
_mask = ( | |
F.interpolate( | |
mask.unsqueeze(0).unsqueeze(0), | |
( | |
int(mask.shape[-2] * resize_scale_y), | |
int(mask.shape[-1] * resize_scale_x), | |
), | |
) | |
> 0.5 | |
) | |
_mask = torch.roll( | |
_mask, | |
(int(dy * resize_scale_y), int(dx * resize_scale_x)), | |
(-2, -1), | |
) | |
if resize_scale_x != 1 or resize_scale_y != 1: | |
mask_res = torch.zeros(1, 1, mask.shape[-2], mask.shape[-1]).to(mask.device) | |
pad_x = (mask_res.shape[-1] - _mask.shape[-1]) // 2 | |
pad_y = (mask_res.shape[-2] - _mask.shape[-2]) // 2 | |
px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0) | |
px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0) | |
mask_res[:,:,py_tmp:py_tmp+_mask.shape[-2],px_tmp:px_tmp+_mask.shape[-1]] = _mask[ | |
:,:,py_tar:py_tar+mask_res.shape[-2],px_tar:px_tar+mask_res.shape[-1]] | |
# # Binary mask | |
# mask_res = mask_res > 0.5 | |
# else: | |
# mask_res = _mask > 0.5 | |
else: | |
mask_res = _mask | |
return mask_res.squeeze() # (y,x) | |
if __name__ == "__main__": | |
# import torch | |
# spec = torch.rand(1024, 64) | |
# # import ipdb; ipdb.set_trace() | |
# plot_spectrogram(spec.permute(1,0),'test.png') | |
# m = torch.rand(4,4) | |
# mask = [[0,0,0,0],[0,1,1,0],[0,1,1,0],[0,0,0,0]] | |
# mask = torch.tensor(mask).bool() | |
# print(m) | |
# print(get_neibor_with_mask(m, mask)) | |
# audio = torch.zeros(1,1024,64) | |
# audio[:,250:750,:]=torch.rand(1,500,64) | |
# res=extract_and_fill(audio, a=5,b=7.5, sr=16000, hop_length=160) | |
# # import ipdb; ipdb.set_trace() | |
# Try SLERP | |
# A = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) # Point on the unit sphere | |
# B = torch.tensor([0.0, 1.0, 0.0, 0.0], dtype=torch.float32) # Another point on the unit sphere | |
# t = 0.5 # Interpolation parameter (0 <= t <= 1) | |
# M = torch.tensor([0.7, 0.0, 1.2, 0.0], dtype=torch.float32) # slerp(t, A, B) # Given interpolation point | |
# A_e, B_e = optimize_neighborhood_points(A, B, M, t, enable_penalty=True) | |
# print("Optimized A_e:", A_e) | |
# print("Optimized B_e:", B_e) | |
# spec = torch.arange(36).view(6,6)[None,...] | |
# res = fill_with_neighbor(spec, 2, 4, 2) | |
# Example usage | |
g = torch.tensor([1.0, 2.0, 3.0]) | |
h = torch.tensor([4.0, 5.0, 6.0]) | |
g_ortho = project_onto_tangent_space(g, h) | |
print(g_ortho) | |
import ipdb | |
ipdb.set_trace() | |