|
import numpy as np |
|
from pathlib import Path |
|
from scipy.special import gamma |
|
from typing import Optional, Tuple, Dict, List, Union |
|
import torch |
|
import os |
|
|
|
class GeneralizedGaussianMixture:
    r"""Dataset generator for a mixture of generalized Gaussian distributions.

    Every component ``k`` factorizes over the ``D`` dimensions. Along one
    dimension its density is

        P_{\theta_k}(x_i) = \eta_k \exp(-s_k d_k(x_i))
                          = \frac{p}{2\alpha_k \Gamma(1/p)}
                            \exp(-|\frac{x_i-c_k}{\alpha_k}|^p)

    where ``c_k`` is the center, ``\alpha_k`` the scale and ``p`` the shape
    exponent (``p = 2`` is Gaussian-shaped, ``p = 1`` is Laplace-shaped).
    """

    def __init__(self,
                 D: int = 2,
                 K: int = 3,
                 p: float = 2.0,
                 centers: Optional[np.ndarray] = None,
                 scales: Optional[np.ndarray] = None,
                 weights: Optional[np.ndarray] = None,
                 seed: int = 42):
        """Initialize the mixture.

        Args:
            D: Data dimensionality (>= 1).
            K: Number of mixture components (>= 1).
            p: Shape exponent of the generalized Gaussian (> 0).
            centers: Component centers, shape (K, D). Drawn as
                ``randn(K, D) * 2`` when omitted.
            scales: Per-dimension scale parameters, shape (K, D), strictly
                positive. Drawn uniformly from [0.1, 0.5) when omitted.
            weights: Mixing weights, shape (K,), non-negative with a positive
                sum; they are normalized to sum to 1. Drawn from a flat
                Dirichlet when omitted.
            seed: Random seed.

        Raises:
            ValueError: If ``D``, ``K`` or ``p`` is out of range, or a supplied
                array has the wrong shape or invalid values.
        """
        if D < 1 or K < 1:
            raise ValueError(f"D and K must be >= 1, got D={D}, K={K}")
        if p <= 0:
            raise ValueError(f"p must be positive, got {p}")

        self.D = D
        self.K = K
        self.p = p
        self.seed = seed
        # Seeds NumPy's *global* RNG: every later np.random draw made by this
        # class (and by anything else in the process) follows from this seed.
        np.random.seed(seed)

        # The draw order (centers, scales, weights) is kept fixed so that a
        # given seed always yields the same default parameters.
        if centers is None:
            self.centers = np.random.randn(K, D) * 2
        else:
            self.centers = self._validated(centers, (K, D), "centers")

        if scales is None:
            self.scales = np.random.uniform(0.1, 0.5, size=(K, D))
        else:
            self.scales = self._validated(scales, (K, D), "scales")
            if np.any(self.scales <= 0):
                raise ValueError("scales must be strictly positive")

        if weights is None:
            self.weights = np.random.dirichlet(np.ones(K))
        else:
            raw = self._validated(weights, (K,), "weights")
            total = raw.sum()
            if np.any(raw < 0) or total <= 0:
                raise ValueError(
                    "weights must be non-negative with a positive sum")
            self.weights = raw / total

    @staticmethod
    def _validated(values, shape: Tuple[int, ...], name: str) -> np.ndarray:
        """Return ``values`` as a float array after checking its exact shape."""
        arr = np.asarray(values, dtype=float)
        if arr.shape != shape:
            raise ValueError(
                f"{name} must have shape {shape}, got {arr.shape}")
        return arr

    def component_pdf(self, x: np.ndarray, k: int) -> np.ndarray:
        """Evaluate the density of the k-th component.

        Args:
            x: Input points, shape (N, D).
            k: Component index.

        Returns:
            Density values, shape (N,).
        """
        x = np.asarray(x, dtype=float)

        # Per-dimension normalization constant p / (2 * alpha * Gamma(1/p)).
        norm_const = self.p / (2 * self.scales[k] * gamma(1 / self.p))

        # Standardized absolute deviation from the center, per dimension.
        z = np.abs(x - self.centers[k]) / self.scales[k]
        exp_term = np.exp(-np.sum(z ** self.p, axis=1))

        # Dimensions are independent, so the joint constant is the product.
        return np.prod(norm_const) * exp_term

    def pdf(self, x: np.ndarray) -> np.ndarray:
        """Evaluate the density of the full mixture.

        Args:
            x: Input points, shape (N, D).

        Returns:
            Density values, shape (N,).
        """
        density = np.zeros(len(x))
        for k in range(self.K):
            density += self.weights[k] * self.component_pdf(x, k)
        return density

    def generate_component_samples(self, n: int, k: int) -> np.ndarray:
        """Draw samples from the k-th component.

        If ``z`` has density proportional to ``exp(-|z|**p)``, then
        ``|z|**p`` follows a Gamma(1/p, 1) distribution. A standardized
        sample is therefore a Gamma draw raised to ``1/p`` with a random
        sign, which is then scaled and shifted per dimension.

        Args:
            n: Number of samples.
            k: Component index.

        Returns:
            Samples, shape (n, D).
        """
        shape = (n, self.D)
        inv_p = 1.0 / self.p
        magnitude = np.random.gamma(inv_p, 1.0, size=shape) ** inv_p
        sign = np.where(np.random.uniform(size=shape) < 0.5, -1.0, 1.0)
        return self.centers[k] + self.scales[k] * sign * magnitude

    def generate_samples(self, N: int) -> Tuple[np.ndarray, np.ndarray]:
        """Draw samples from the mixture.

        Args:
            N: Total number of samples.

        Returns:
            X: Generated points, shape (N, D).
            y: Mixture density evaluated at each point, shape (N,).
        """
        # Split the N samples across components according to the weights.
        n_samples = np.random.multinomial(N, self.weights)

        samples = [
            self.generate_component_samples(n_samples[k], k)
            for k in range(self.K)
        ]

        # Stack and shuffle so the rows are not grouped by component.
        X = np.vstack(samples)
        idx = np.random.permutation(N)
        X = X[idx]

        y = self.pdf(X)
        return X, y

    def save_dataset(self,
                     save_dir: Union[str, Path],
                     name: str = 'gmm_dataset',
                     N: int = 1000) -> None:
        """Generate a dataset and save it together with the mixture parameters.

        Writes ``<save_dir>/<name>.npz`` containing the arrays ``X``, ``y``,
        ``centers``, ``scales``, ``weights`` and the scalars ``D``, ``K``,
        ``p``. The directory is created if it does not exist.

        Args:
            save_dir: Target directory.
            name: Dataset name (file stem).
            N: Number of samples to generate (default 1000).
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        X, y = self.generate_samples(N=N)
        np.savez(str(save_path / f'{name}.npz'),
                 X=X, y=y,
                 centers=self.centers,
                 scales=self.scales,
                 weights=self.weights,
                 D=self.D,
                 K=self.K,
                 p=self.p)

    @classmethod
    def load_dataset(cls, file_path: Union[str, Path]) -> "GeneralizedGaussianMixture":
        """Rebuild a mixture from a file written by ``save_dataset``.

        Only the parameters (``D``, ``K``, ``p``, ``centers``, ``scales``,
        ``weights``) are read; the stored samples are not loaded and the
        default seed is used.

        Args:
            file_path: Path to the ``.npz`` file.

        Returns:
            The reconstructed mixture object.
        """
        # The archive keeps a file handle open; the context manager closes it
        # once the arrays have been read into memory.
        with np.load(str(file_path)) as data:
            return cls(
                D=int(data['D']),
                K=int(data['K']),
                p=float(data['p']),
                centers=data['centers'],
                scales=data['scales'],
                weights=data['weights']
            )
|
|
|
def test_gmm_dataset():
    """Smoke-test the mixture generator: sampling, saving and reloading.

    Writes ``test_data/gmm_dataset.npz`` relative to the current working
    directory and raises ``AssertionError`` if any check fails.
    """
    # Build a mixture with fully specified parameters.
    gmm = GeneralizedGaussianMixture(
        D=2,
        K=3,
        p=2.0,
        centers=np.array([[-2, -2], [0, 0], [2, 2]]),
        scales=np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]]),
        weights=np.array([0.3, 0.4, 0.3])
    )

    # Sampling must return one row and one density value per requested sample.
    X, y = gmm.generate_samples(1000)
    assert X.shape == (1000, 2)
    assert y.shape == (1000,)
    assert np.all(np.isfinite(y))
    assert np.all(y >= 0)

    # Save the dataset, then rebuild the mixture from the file.
    gmm.save_dataset('test_data')
    loaded_gmm = GeneralizedGaussianMixture.load_dataset('test_data/gmm_dataset.npz')

    # Every parameter must survive the round trip.
    assert loaded_gmm.D == gmm.D
    assert loaded_gmm.K == gmm.K
    assert loaded_gmm.p == gmm.p
    assert np.allclose(gmm.centers, loaded_gmm.centers)
    assert np.allclose(gmm.scales, loaded_gmm.scales)
    assert np.allclose(gmm.weights, loaded_gmm.weights)

    print("GMM数据集测试通过!")
|
|
|
# Script entry point: run the smoke test when this file is executed directly.
if __name__ == "__main__":
    test_gmm_dataset()