import math
from typing import Tuple

import torch
from torch import Tensor

__all__ = [
    "weighted_sum",
    "weighted_subtraction",
    "tensor_sum",
    "add_difference",
    "sum_twice",
    "triple_sum",
    "euclidean_add_difference",
    "multiply_difference",
    "top_k_tensor_sum",
    "similarity_add_difference",
    "distribution_crossover",
    "ties_add_difference",
]

EPSILON = 1e-10  # Small constant used to prevent division by zero


def weighted_sum(a: Tensor, b: Tensor, alpha: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Basic Merge:
    alpha 0 returns Primary Model
    alpha 1 returns Secondary Model

    Computes the linear interpolation ``(1 - alpha) * a + alpha * b``.
    """
    return (1 - alpha) * a + alpha * b


def weighted_subtraction(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    The inverse of a Weighted Sum Merge
    Returns Primary Model when alpha*beta = 0
    High values of alpha*beta are likely to break the merged model

    Computes ``(a - w * b) / (1 - w)`` with ``w = alpha * beta``. When ``w``
    is exactly 1 the denominator would be zero, so ``w`` is lowered by
    EPSILON first.
    """
    weight = alpha * beta
    # Guard every (alpha, beta) pair whose product is 1, not only (1.0, 1.0).
    if weight == 1.0:
        weight -= EPSILON
    return (a - weight * b) / (1 - weight)


def tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Takes a slice of Secondary Model and pastes it into Primary Model
    Alpha sets the width of the slice
    Beta sets the start point of the slice
    ie Alpha = 0.5 Beta = 0.25 is (ABBA) Alpha = 0.25 Beta = 0 is (BAAA)

    The slice is taken along the first dimension. When ``alpha + beta``
    exceeds 1 the slice wraps around: the result starts as a copy of ``b``
    and the complementary rows are restored from ``a``.
    """
    rows = a.shape[0]
    if alpha + beta <= 1:
        start = int(rows * beta)
        end = int(rows * (alpha + beta))
        merged = a.clone()
        merged[start:end] = b[start:end].clone()
    else:
        # Wrapped slice: keep b everywhere except the gap that belongs to a.
        start = int(rows * (alpha + beta - 1))
        end = int(rows * beta)
        merged = b.clone()
        merged[start:end] = a[start:end].clone()
    return merged


def add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Classic Add Difference Merge

    Computes ``a + alpha * (b - c)``.
    """
    return a + alpha * (b - c)


def sum_twice(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Stacked Basic Merge:
    Equivalent to Merging Primary and Secondary @ alpha
    Then merging the result with Tertiary @ beta
    """
    return (1 - beta) * ((1 - alpha) * a + alpha * b) + beta * c


def triple_sum(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Weights Secondary and Tertiary at alpha and beta respectively
    Fills in the rest with Primary
    Expect odd results if alpha + beta > 1 as Primary will be merged with a negative ratio
    """
    return (1 - alpha - beta) * a + alpha * b + beta * c


def euclidean_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Subtract Primary and Secondary from Tertiary
    Compare the remainders via Euclidean distance
    Add to Tertiary
    Note: Slow

    If the combined difference has zero norm (for example when ``a``, ``b``
    and ``c`` are identical) the added term is zero and ``c`` is returned
    unchanged instead of NaN.
    """
    a_diff = a.float() - c.float()
    b_diff = b.float() - c.float()
    # A zero difference normalises to 0/0; nan_to_num maps that back to 0.
    a_diff = torch.nan_to_num(a_diff / torch.linalg.norm(a_diff))
    b_diff = torch.nan_to_num(b_diff / torch.linalg.norm(b_diff))

    distance = (1 - alpha) * a_diff**2 + alpha * b_diff**2
    distance = torch.sqrt(distance)
    sum_diff = weighted_sum(a.float(), b.float(), alpha) - c.float()
    # The squares lost the sign; take it from the plain weighted difference.
    distance = torch.copysign(distance, sum_diff)

    target_norm = torch.linalg.norm(sum_diff)
    distance_norm = torch.linalg.norm(distance)
    # A zero norm means distance is all zeros; dividing by 1 keeps it zero
    # rather than producing NaN. A NaN norm is left alone so that invalid
    # inputs still surface as NaN.
    safe_norm = torch.where(distance_norm == 0, torch.ones_like(distance_norm), distance_norm)
    return c + distance / safe_norm * target_norm


def multiply_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Similar to Add Difference but with geometric mean instead of arithmetic mean

    The magnitude is ``|a - c| ** (1 - alpha) * |b - c| ** alpha``; its sign
    is taken from ``weighted_sum(a, b, beta) - c``. The result is cast back
    to the dtype of ``c`` before being added to it.
    """
    diff_a = torch.pow(torch.abs(a.float() - c), (1 - alpha))
    diff_b = torch.pow(torch.abs(b.float() - c), alpha)
    difference = torch.copysign(diff_a * diff_b, weighted_sum(a, b, beta) - c)
    return c + difference.to(c.dtype)


def top_k_tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Redistributes the largest weights of Secondary Model into Primary Model

    ``alpha`` and ``beta`` are passed to ``ratio_to_region`` as the width and
    offset of the selected region of the sorted values of ``a``. Selected
    values are re-ordered to follow the ranking of ``b``; all other
    positions keep the original value from ``a``.
    """
    a_flat = torch.flatten(a)
    a_dist = torch.msort(a_flat)
    b_indices = torch.argsort(torch.flatten(b), stable=True)
    # argsort of an argsort gives the rank of every element of b.
    redist_indices = torch.argsort(b_indices)

    start_i, end_i, region_is_inverted = ratio_to_region(alpha, beta, torch.numel(a))
    start_top_k = kth_abs_value(a_dist, start_i)
    end_top_k = kth_abs_value(a_dist, end_i)

    abs_dist = torch.abs(a_dist)
    indices_mask = (start_top_k < abs_dist) & (abs_dist <= end_top_k)
    if region_is_inverted:
        indices_mask = ~indices_mask
    indices_mask = torch.gather(indices_mask.float(), 0, redist_indices)

    a_redist = torch.gather(a_dist, 0, redist_indices)
    a_redist = (1 - indices_mask) * a_flat + indices_mask * a_redist
    return a_redist.reshape_as(a)


def kth_abs_value(a: Tensor, k: int) -> Tensor:
    """
    Return the k-th smallest absolute value of ``a`` (1-based).

    For ``k <= 0`` a tensor holding -1 is returned, which is below every
    absolute value.
    """
    if k <= 0:
        return torch.tensor(-1, device=a.device)
    return torch.kthvalue(torch.abs(a.float()), k)[0]


def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]:
    """
    Convert a (width, offset) ratio pair into an index region over n items.

    A negative width is flipped so the region ends where it would have
    started. Width is capped at 1 and offset is wrapped into [0, 1).

    Returns ``(start, end, inverted)``. When the region would run past the
    end, ``inverted`` is True and ``start``/``end`` delimit the complement.
    """
    if width < 0:
        offset += width
        width = -width
    width = min(width, 1)

    if offset < 0:
        offset = 1 + offset - int(offset)
    offset = math.fmod(offset, 1.0)

    if width + offset <= 1:
        inverted = False
        start = offset * n
        end = (width + offset) * n
    else:
        inverted = True
        start = (width + offset - 1) * n
        end = offset * n

    return round(start), round(end), inverted


def similarity_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Weighted Sum where A and B are similar and Add Difference where A and B are dissimilar

    Similarity is computed per element as ``(a * b / max(|a|, |b|)**2 + 1) / 2``
    and scaled by ``beta``; elements where this is NaN (both values zero)
    use ``beta`` directly.
    """
    threshold = torch.maximum(torch.abs(a), torch.abs(b))
    similarity = ((a * b / threshold**2) + 1) / 2
    similarity = torch.nan_to_num(similarity * beta, nan=beta)

    ab_diff = a + alpha * (b - c)
    ab_sum = (1 - alpha / 2) * a + (alpha / 2) * b
    return (1 - similarity) * ab_diff + similarity * ab_sum


def distribution_crossover(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    From the creator:
    It's Primary high-passed + Secondary low-passed. Takes the Fourier transform of the weights of
    Primary and Secondary when ordered with respect to Tertiary. Split the frequency domain
    using a linear function. Alpha is the split frequency and Beta is the inclination of the line.
    add everything under the line as the contribution of Primary and everything over the line as the contribution of Secondary

    Zero-dimensional tensors cannot be transformed and are blended as
    ``alpha * a + (1 - alpha) * b`` instead.
    """
    if a.shape == ():
        return alpha * a + (1 - alpha) * b

    # Order both models by the sorted order of the tertiary model.
    c_indices = torch.argsort(torch.flatten(c))
    a_dist = torch.gather(torch.flatten(a), 0, c_indices)
    b_dist = torch.gather(torch.flatten(b), 0, c_indices)

    a_dft = torch.fft.rfft(a_dist.float())
    b_dft = torch.fft.rfft(b_dist.float())

    # Normalised frequency of every bin, in [0, 1).
    dft_filter = torch.arange(0, torch.numel(a_dft), device=a_dft.device).float()
    dft_filter /= torch.numel(a_dft)
    if beta > EPSILON:
        # Linear ramp centred on alpha, clamped to [0, 1].
        dft_filter = (dft_filter - alpha) / beta + 1 / 2
        dft_filter = torch.clamp(dft_filter, 0.0, 1.0)
    else:
        # Beta is effectively zero: use a hard step at alpha.
        dft_filter = (dft_filter >= alpha).float()

    x_dft = (1 - dft_filter) * a_dft + dft_filter * b_dft
    x_dist = torch.fft.irfft(x_dft, a_dist.shape[0])
    # Undo the ordering by c to restore the original element positions.
    x_values = torch.gather(x_dist, 0, torch.argsort(c_indices))
    return x_values.reshape_as(a)


def ties_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    An implementation of arXiv:2306.01708

    The differences ``a - c`` and ``b - c`` are trimmed with
    ``filter_top_k(..., beta)``. Per element, only deltas whose sign matches
    the sign of the summed signs are averaged, and the average is added to
    ``c`` scaled by ``alpha``.
    """
    deltas = [filter_top_k(m - c, beta) for m in (a, b)]
    signs = torch.stack([torch.sign(delta) for delta in deltas], dim=0)

    final_sign = torch.sign(torch.sum(signs, dim=0))
    delta_filters = (signs == final_sign).float()

    res = torch.zeros_like(c)
    for delta_filter, delta in zip(delta_filters, deltas):
        res += delta_filter * delta

    param_count = torch.sum(delta_filters, dim=0)
    # Elements with no agreeing delta give 0/0; treat them as no change.
    return c + alpha * torch.nan_to_num(res / param_count)


def filter_top_k(a: Tensor, k: float) -> Tensor:
    """
    Zero out every element of ``a`` whose absolute value is below a threshold.

    The threshold is the ``int((1 - k) * numel)``-th smallest absolute
    value. That rank is clamped to ``[1, numel]`` so that ratios outside
    [0, 1] do not make ``torch.kthvalue`` fail.
    """
    numel = torch.numel(a)
    rank = min(max(int((1 - k) * numel), 1), numel)
    k_value, _ = torch.kthvalue(torch.abs(a.flatten()).float(), rank)
    top_k_filter = (torch.abs(a) >= k_value).float()
    return a * top_k_filter