# JinhuaL1ANG's picture
# v1
# 9a6dac6
try:
import gudhi
except ImportError as e:
import six
error = e.__class__(
"You are likely missing your GUDHI installation, "
"you should visit http://gudhi.gforge.inria.fr/python/latest/installation.html "
"for further instructions.\nIf you use conda, you can use\nconda install -c conda-forge gudhi"
)
six.raise_from(error, e)
import numpy as np
from scipy.spatial.distance import cdist # , pdist, squareform
import matplotlib.pyplot as plt
def relative(I_1, alpha_max, i_max=100):
    """
    Turn a collection of persistence intervals into Relative Living Times
    (RLT), following formulas (2) and (3). The usual input is the result of
    gudhi's persistence_intervals_in_dimension.

    Args:
      I_1: list of [birth, death] intervals,
        e.g. [[0, 1], [0, 2], [0, np.inf]].
      alpha_max: float, the maximal persistence value; it replaces an
        infinite death value and is the normalisation constant.
      i_max: int, upper bound on the value of beta_1 to compute.
    Returns
      An array of size (i_max, ). Entry k is the total length of the
      segments on which exactly k intervals are alive, divided by
      alpha_max. Segments on which i_max or more intervals are alive
      are not recorded.
    """
    # An interval that never dies is cut off at alpha_max.
    clipped = [
        [bar[0], alpha_max] if np.isinf(bar[1]) else list(bar)
        for bar in I_1
    ]

    # Without any interval, zero holes are observed all the time.
    if not clipped:
        empty_rlt = np.zeros(i_max)
        empty_rlt[0] = 1.0
        return empty_rlt

    bars = np.array(clipped)
    births = bars[:, 0]
    deaths = bars[:, 1]

    # The number of living intervals can only change at an interval
    # endpoint; 0 and alpha_max are added as the ends of the full range.
    # np.unique flattens its input and returns the values sorted.
    breakpoints = np.unique(np.array(clipped + [[0, alpha_max]]))

    rlt = np.zeros(i_max)
    for left, right in zip(breakpoints[:-1], breakpoints[1:]):
        # The count is constant on the segment, so probe its middle.
        mid = (left + right) / 2
        alive = int(np.count_nonzero((births <= mid) & (mid < deaths)))
        if alive < i_max:
            rlt[alive] += right - left
    return rlt / alpha_max
def lmrk_table(W, L):
    """
    Build the nearest-landmark table that is passed to
    gudhi.WitnessComplex.

    Args:
      W: 2d array of size w x d, containing witnesses
      L: 2d array of size l x d containing landmarks
    Returns
      A pair (D, max_val). max_val is the largest distance between a
      witness and a landmark. D is a 3d array of size w x l x 2 in which
      D[i, k, :] is [index, distance] of the k-th closest landmark to the
      i-th witness, so every D[i] is ordered by increasing distance, e.g.,
      D[i, :, :] = [[0, 0.1], [1, 0.2], [3, 0.3], [2, 0.4]]
    """
    dists = cdist(W, L)
    # For each witness (row), landmark indices from nearest to farthest.
    order = np.argsort(dists)
    sorted_dists = np.take_along_axis(dists, order, axis=1)
    table = np.dstack([order, sorted_dists])
    return table, np.max(dists)
def random_landmarks(X, L_0=32):
    """
    Pick L_0 rows of X at random.

    The row indices are drawn by np.random.choice with its default
    settings, i.e. with replacement, so the same row of X can be
    returned more than once.

    Args:
      X: array whose first axis enumerates the points.
      L_0: int, number of rows to draw.
    Returns
      The selected rows of X, in the order they were drawn.
    """
    chosen_rows = np.random.choice(X.shape[0], L_0)
    return X[chosen_rows]
def witness(X, gamma=1.0 / 128, L_0=64):
    """
    Compute the dimension-1 persistence intervals of the dataset X with
    a witness complex.

    Every point of X is used as a witness; the landmarks are L_0 points
    picked from X by random_landmarks.

    Args:
      X: 2d array representing the dataset.
      gamma: factor multiplied with the largest witness-landmark
        distance to obtain the maximal persistence value.
      L_0: int, number of landmarks to use.
    Returns
      A pair (intervals, alpha_max): the persistence intervals in
      dimension 1 and the maximal persistence value.
    """
    landmarks = random_landmarks(X, L_0)
    table, farthest = lmrk_table(X, landmarks)
    witness_complex = gudhi.WitnessComplex(table)
    alpha_max = farthest * gamma
    tree = witness_complex.create_simplex_tree(
        max_alpha_square=alpha_max, limit_dimension=2
    )
    # The return value is not used: the call is made for its effect on
    # the tree, which seems to be modified in place before the intervals
    # are queried below.
    tree.persistence(homology_coeff_field=2)
    intervals = tree.persistence_intervals_in_dimension(1)
    return intervals, alpha_max
def fancy_plot(y, color="C0", label="", alpha=0.3):
    """
    A function for a nice visualization of MRLT: translucent unit-width
    bars centred on 0, 1, ..., n - 1 with a thick step-shaped outline
    drawn over them.

    Args:
      y: array of bar heights (assumed to be 1d -- TODO confirm with
        callers); its first axis gives the number of bars.
      color: matplotlib color used for the bars, their edges and the
        outline.
      label: legend label attached to the outline.
      alpha: opacity of the bars.
    """
    n = y.shape[0]
    centers = np.arange(n)

    # Left and right edge of every bar, interleaved:
    # [-0.5, 0.5, 0.5, 1.5, ...].
    edges = np.column_stack((centers - 0.5, centers + 0.5)).ravel()
    # One extra point at each end lets the outline start and finish
    # at height 0.
    outline_x = np.concatenate(([-0.5], edges, [n - 0.5]))

    # Every height appears twice, once per edge of its bar.
    doubled = np.repeat(y, 2)
    outline_y = np.zeros(doubled.shape[0] + 2)
    outline_y[1:-1] = doubled

    plt.bar(centers, y, width=1, alpha=alpha, color=color, edgecolor=color)
    plt.plot(outline_x, outline_y, c=color, label=label, lw=3)