Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,758 Bytes
9a6dac6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
try:
import gudhi
except ImportError as e:
import six
error = e.__class__(
"You are likely missing your GUDHI installation, "
"you should visit http://gudhi.gforge.inria.fr/python/latest/installation.html "
"for further instructions.\nIf you use conda, you can use\nconda install -c conda-forge gudhi"
)
six.raise_from(error, e)
import numpy as np
from scipy.spatial.distance import cdist # , pdist, squareform
import matplotlib.pyplot as plt
def relative(I_1, alpha_max, i_max=100):
    """
    Compute the RLT of a collection of intervals by formulas (2) and (3).

    This function will typically be called on the output of the gudhi
    persistence_intervals_in_dimension function.

    Args:
        I_1: list of intervals, e.g. [[0, 1], [0, 2], [0, np.inf]].
        alpha_max: float, the maximal persistence value. It replaces
            infinite interval ends and is the final divisor, so it must be
            nonzero whenever I_1 is not empty.
        i_max: int, upper bound on the value of beta_1 to compute.

    Returns:
        An array of size (i_max, ). Entry k is the total length of the
        segments on which exactly k intervals are alive, divided by
        alpha_max. Segments with i_max or more alive intervals are
        dropped. For an empty I_1 the result is 1.0 at index 0 and 0
        everywhere else.
    """
    # An interval that persisted up to np.inf is cut off at alpha_max.
    intervals = [
        [interval[0], alpha_max] if np.isinf(interval[1]) else list(interval)
        for interval in I_1
    ]

    rlt = np.zeros(i_max)

    # If there are no intervals in H1 then we always observed 0 holes.
    if not intervals:
        rlt[0] = 1.0
        return rlt

    # The value of beta_1 may change only at interval boundaries; 0 and
    # alpha_max are added so that the segments span the whole range.
    # np.unique flattens its input and returns the values already sorted.
    switch_points = np.unique(np.array(intervals + [[0, alpha_max]]))

    bounds = np.array(intervals)
    births = bounds[:, 0]
    deaths = bounds[:, 1]

    for left, right in zip(switch_points[:-1], switch_points[1:]):
        # beta_1 is constant on (left, right), so probing the midpoint
        # is enough to count the intervals alive on this segment.
        midpoint = (left + right) / 2
        alive = np.count_nonzero((births <= midpoint) & (midpoint < deaths))
        if alive < i_max:
            rlt[alive] += right - left

    return rlt / alpha_max
def lmrk_table(W, L):
"""
Helper function to construct an input for the gudhi.WitnessComplex
function.
Args:
W: 2d array of size w x d, containing witnesses
L: 2d array of size l x d containing landmarks
Returns
Return a 3d array D of size w x l x 2 and the maximal distance
between W and L.
D satisfies the property that D[i, :, :] is [idx_i, dists_i],
where dists_i are the sorted distances from the i-th witness to each
point in L and idx_i are the indices of the corresponding points
in L, e.g.,
D[i, :, :] = [[0, 0.1], [1, 0.2], [3, 0.3], [2, 0.4]]
"""
a = cdist(W, L)
max_val = np.max(a)
idx = np.argsort(a)
b = a[np.arange(np.shape(a)[0])[:, np.newaxis], idx]
return np.dstack([idx, b]), max_val
def random_landmarks(X, L_0=32):
    """
    Randomly sample L_0 points (rows) from X.

    Row indices are drawn with np.random.choice using its default
    settings, i.e. with replacement, so the same row may be picked more
    than once and L_0 may exceed the number of rows in X.

    Args:
        X: array whose first axis enumerates the points.
        L_0: int, number of points to draw.

    Returns:
        The array X[idx] of the L_0 selected rows.
    """
    chosen = np.random.choice(X.shape[0], L_0)
    return X[chosen]
def witness(X, gamma=1.0 / 128, L_0=64):
    """
    Compute the persistence intervals in dimension 1 for the dataset X
    using the witness complex.

    Args:
        X: 2d array representing the dataset.
        gamma: parameter determining the maximal persistence value; the
            value used is gamma times the largest witness-landmark
            distance.
        L_0: int, number of landmarks to use.

    Returns:
        A list of persistence intervals and the maximal persistence value.
    """
    landmarks = random_landmarks(X, L_0)

    # Every point of the dataset serves as a witness.
    table, farthest = lmrk_table(X, landmarks)
    alpha_max = farthest * gamma

    complex_ = gudhi.WitnessComplex(table)
    # NOTE(review): alpha_max is a scaled distance but is passed as
    # `max_alpha_square` -- confirm against the gudhi documentation that
    # this is the intended scale.
    tree = complex_.create_simplex_tree(
        max_alpha_square=alpha_max, limit_dimension=2
    )

    # The return value is discarded: the call is kept because it appears
    # to update the tree's internal state, which the query below reads.
    tree.persistence(homology_coeff_field=2)

    return tree.persistence_intervals_in_dimension(1), alpha_max
def fancy_plot(y, color="C0", label="", alpha=0.3):
    """
    A function for a nice visualization of MRLT.

    Draws y as translucent unit-width bars centred at 0, 1, ..., n - 1 and
    overlays a thick step outline that starts and ends at height 0.

    Args:
        y: array of bar heights; its first dimension gives the number of
            bars (assumed to be 1d -- TODO confirm against callers).
        color: matplotlib color used for the bars, their edges and the
            outline.
        label: legend label attached to the outline.
        alpha: opacity of the bars.
    """
    n = y.shape[0]
    centers = np.arange(n)

    # Outline abscissae: the left and right edge of every bar, interleaved,
    # framed by the outermost edges -0.5 and n - 0.5.
    outline_x = np.zeros(2 * n + 2)
    outline_x[1:-1:2] = centers - 0.5
    outline_x[2::2] = centers + 0.5
    outline_x[0] = -0.5
    outline_x[-1] = n - 0.5

    # Outline ordinates: every height repeated for both edges of its bar;
    # the first and last entries stay 0 so the outline touches the axis.
    outline_y = np.zeros(2 * n + 2)
    outline_y[1:-1:2] = y
    outline_y[2::2] = y

    plt.bar(centers, y, width=1, alpha=alpha, color=color, edgecolor=color)
    plt.plot(outline_x, outline_y, c=color, label=label, lw=3)
|