Spaces:
Running
on
Zero
Running
on
Zero
try: | |
import gudhi | |
except ImportError as e: | |
import six | |
error = e.__class__( | |
"You are likely missing your GUDHI installation, " | |
"you should visit http://gudhi.gforge.inria.fr/python/latest/installation.html " | |
"for further instructions.\nIf you use conda, you can use\nconda install -c conda-forge gudhi" | |
) | |
six.raise_from(error, e) | |
import numpy as np | |
from scipy.spatial.distance import cdist # , pdist, squareform | |
import matplotlib.pyplot as plt | |
def relative(I_1, alpha_max, i_max=100):
    """
    Compute the relative living times (RLT) of the number of holes for a
    collection of persistence intervals (referred to as formulas (2) and
    (3) in the original documentation). Typically called on the output of
    gudhi's persistence_intervals_in_dimension.

    Args:
        I_1: iterable of [birth, death] pairs,
            e.g. [[0, 1], [0, 2], [0, np.inf]].
        alpha_max: float, the maximal persistence value; also used as the
            death time of intervals that never die.
        i_max: int, upper bound (exclusive) on the hole count to track.

    Returns:
        An array of size (i_max, ); entry k is the fraction of
        [0, alpha_max] during which exactly k intervals were alive.
        Time spent with i_max or more intervals alive is not recorded.
    """
    # Intervals that persist forever are cut off at alpha_max.
    bars = [
        [bar[0], alpha_max] if np.isinf(bar[1]) else list(bar)
        for bar in I_1
    ]

    rlt = np.zeros(i_max)

    # Without any interval, zero holes were observed the whole time.
    if len(bars) == 0:
        rlt[0] = 1.0
        return rlt

    # The hole count can only change at an interval endpoint; adding
    # [0, alpha_max] makes sure the whole range is covered.
    endpoints = np.sort(np.unique(np.array(bars + [[0, alpha_max]]).flatten()))
    bars = np.array(bars)

    for left, right in zip(endpoints[:-1], endpoints[1:]):
        # The count is constant between consecutive endpoints, so probing
        # the midpoint of the segment is enough.
        probe = (left + right) / 2
        alive = 0
        for birth, death in bars:
            if birth <= probe < death:
                alive += 1
        if alive < i_max:
            rlt[alive] += right - left

    return rlt / alpha_max
def lmrk_table(W, L):
    """
    Build the nearest-landmark table that is passed to
    gudhi.WitnessComplex.

    Args:
        W: 2d array of size w x d, containing witnesses
        L: 2d array of size l x d, containing landmarks

    Returns:
        A pair (D, max_val). D is a 3d array of size w x l x 2 where
        D[i, :, :] is [idx_i, dists_i]: dists_i are the distances from
        the i-th witness to every landmark in ascending order and idx_i
        are the matching landmark indices, e.g.
            D[i, :, :] = [[0, 0.1], [1, 0.2], [3, 0.3], [2, 0.4]]
        max_val is the largest witness-to-landmark distance.
    """
    dists = cdist(W, L)
    # Per witness (row), landmark indices ordered from nearest to farthest.
    order = np.argsort(dists, axis=1)
    # Reorder each row of distances according to that ranking.
    ranked = np.take_along_axis(dists, order, axis=1)
    return np.dstack((order, ranked)), dists.max()
def random_landmarks(X, L_0=32):
    """
    Pick L_0 rows of X at random to serve as landmarks.

    Row indices are drawn with np.random.choice using its defaults, i.e.
    uniformly and with replacement, so a row may be selected more than
    once.

    Args:
        X: array whose first axis enumerates the points.
        L_0: int, number of rows to draw.

    Returns:
        The array X[idx] holding the L_0 selected rows.
    """
    picks = np.random.choice(X.shape[0], L_0)
    return X[picks]
def witness(X, gamma=1.0 / 128, L_0=64):
    """
    Compute the dimension-1 persistence intervals of the dataset X
    using a witness complex built on randomly chosen landmarks.

    Args:
        X: 2d array representing the dataset; every point acts as a
            witness.
        gamma: factor applied to the largest witness-to-landmark distance
            to obtain the maximal persistence value.
        L_0: int, number of landmarks to use.

    Returns:
        A pair (intervals, alpha_max): the persistence intervals in
        dimension 1 and the maximal persistence value.
    """
    landmarks = random_landmarks(X, L_0)
    # The full dataset is used as the witness set.
    table, farthest = lmrk_table(X, landmarks)
    complex_ = gudhi.WitnessComplex(table)
    alpha_max = farthest * gamma
    tree = complex_.create_simplex_tree(
        max_alpha_square=alpha_max, limit_dimension=2
    )
    # Run the persistence computation before asking for intervals; the
    # original author notes that this call seems to modify the tree.
    tree.persistence(homology_coeff_field=2)
    intervals = tree.persistence_intervals_in_dimension(1)
    return intervals, alpha_max
def fancy_plot(y, color="C0", label="", alpha=0.3):
    """
    Draw y as unit-width translucent bars with a thick step outline on
    top, on the current matplotlib axes (intended for visualizing MRLT).

    Args:
        y: array of bar heights; bar k is centred at x = k.
        color: matplotlib color used for bars, bar edges and the outline.
        label: legend label attached to the outline.
        alpha: opacity of the bars.
    """
    n = y.shape[0]
    centers = np.arange(n)

    # Outline x-coordinates: the left and right edge of every bar,
    # interleaved, with one extra point at each end of the histogram.
    edges = np.column_stack((centers - 0.5, centers + 0.5)).flatten()
    outline_x = np.concatenate(([-0.5], edges, [n - 0.5]))

    # Outline y-coordinates: each height twice (once per bar edge); the
    # two extra end points stay at zero so the outline starts and ends
    # on the baseline.
    outline_y = np.zeros(2 * n + 2)
    outline_y[1:-1] = np.repeat(y, 2)

    plt.bar(centers, y, width=1, alpha=alpha, color=color, edgecolor=color)
    plt.plot(outline_x, outline_y, c=color, label=label, lw=3)