# File size: 6,233 Bytes
# 78e4509
import numpy as np
from pathlib import Path
from scipy.special import gamma
from typing import Optional, Tuple, Dict, List, Union
import torch
import os
class GeneralizedGaussianMixture:
    r"""Dataset generator for a mixture of generalized Gaussian distributions.

    Every one of the K mixture components is a product of D independent
    one-dimensional generalized Gaussians.  Along a single axis the density
    of component k is

    P_{\theta_k}(x_i) = \eta_k \exp(-s_k d_k(x_i)) = \frac{p}{2\alpha_k \Gamma(1/p)}\exp(-|\frac{x_i-c_k}{\alpha_k}|^p)

    where c_k is the center, \alpha_k the scale and p the shape exponent
    (p = 2 gives a Gaussian shape, p = 1 a Laplace shape).
    """

    def __init__(self,
                 D: int = 2,                            # dimensionality
                 K: int = 3,                            # number of components
                 p: float = 2.0,                        # shape exponent
                 centers: Optional[np.ndarray] = None,  # component centers
                 scales: Optional[np.ndarray] = None,   # per-axis scales
                 weights: Optional[np.ndarray] = None,  # mixing weights
                 seed: int = 42):                       # random seed
        """Initialise the mixture parameters.

        Note that this seeds NumPy's *global* random state with ``seed``;
        all later sampling methods draw from that global state.

        Args:
            D: Data dimensionality.
            K: Number of mixture components.
            p: Shape exponent of the distribution; must be positive.
            centers: Component centers, shape (K, D). Drawn as
                ``2 * randn(K, D)`` when omitted.
            scales: Scale parameters, shape (K, D). Drawn uniformly from
                [0.1, 0.5) when omitted.
            weights: Mixing weights, shape (K,). They are normalised to sum
                to one. Drawn from a flat Dirichlet when omitted.
            seed: Seed passed to ``np.random.seed``.

        Raises:
            ValueError: If ``p`` is not strictly positive.
        """
        if p <= 0:
            raise ValueError(f"p must be positive, got {p}")

        self.D = D
        self.K = K
        self.p = p
        self.seed = seed
        np.random.seed(seed)

        # Draw order (centers, scales, weights) is kept fixed so that the
        # default parameters for a given seed stay reproducible.
        if centers is None:
            self.centers = np.random.randn(K, D) * 2
        else:
            self.centers = np.asarray(centers)

        if scales is None:
            self.scales = np.random.uniform(0.1, 0.5, size=(K, D))
        else:
            self.scales = np.asarray(scales)

        if weights is None:
            self.weights = np.random.dirichlet(np.ones(K))
        else:
            # Accept plain sequences as well as arrays, then normalise.
            weights = np.asarray(weights, dtype=float)
            self.weights = weights / weights.sum()

    def component_pdf(self, x: np.ndarray, k: int) -> np.ndarray:
        """Evaluate the density of component ``k``.

        Args:
            x: Input points, shape (N, D).
            k: Component index.

        Returns:
            Density values, shape (N,).
        """
        x = np.asarray(x)
        # Per-axis normalisation constant p / (2 * alpha * Gamma(1/p)).
        norm_const = self.p / (2 * self.scales[k] * gamma(1 / self.p))
        # |x_i - c_k| / alpha_k for every axis.
        z = np.abs(x - self.centers[k]) / self.scales[k]
        exp_term = np.exp(-np.sum(z ** self.p, axis=1))
        # Axes are independent, so the joint constant is the product.
        return np.prod(norm_const) * exp_term

    def pdf(self, x: np.ndarray) -> np.ndarray:
        """Evaluate the density of the whole mixture.

        Args:
            x: Input points, shape (N, D).

        Returns:
            Density values, shape (N,).
        """
        density = np.zeros(len(x))
        for k in range(self.K):
            density += self.weights[k] * self.component_pdf(x, k)
        return density

    def generate_component_samples(self, n: int, k: int) -> np.ndarray:
        """Draw samples from component ``k``.

        Uses the Gamma transform: if z has density proportional to
        ``exp(-|z|^p)`` then ``|z|^p`` follows ``Gamma(shape=1/p, scale=1)``.
        So ``|z|`` is obtained as a Gamma draw raised to ``1/p`` and the
        sign is chosen uniformly.  This matches ``component_pdf``, whose
        support is unbounded.

        Args:
            n: Number of samples.
            k: Component index.

        Returns:
            Samples, shape (n, D).
        """
        shape = (n, self.D)
        inv_p = 1.0 / self.p
        magnitude = np.random.gamma(inv_p, 1.0, size=shape) ** inv_p
        sign = np.where(np.random.uniform(size=shape) < 0.5, -1.0, 1.0)
        return self.centers[k] + self.scales[k] * sign * magnitude

    def generate_samples(self, N: int) -> Tuple[np.ndarray, np.ndarray]:
        """Draw samples from the mixture.

        Args:
            N: Total number of samples.

        Returns:
            X: Generated points, shape (N, D), in shuffled order.
            y: Mixture density evaluated at each point, shape (N,).
        """
        # Split N across components according to the mixing weights.
        n_samples = np.random.multinomial(N, self.weights)
        samples = [
            self.generate_component_samples(n_samples[k], k)
            for k in range(self.K)
        ]
        # Stack and shuffle so the component order carries no information.
        X = np.vstack(samples)
        X = X[np.random.permutation(N)]
        y = self.pdf(X)
        return X, y

    def save_dataset(self,
                     save_dir: Union[str, Path],
                     name: str = 'gmm_dataset',
                     N: int = 1000) -> None:
        """Generate a dataset and write it to ``<save_dir>/<name>.npz``.

        The archive holds the samples ``X``, their densities ``y`` and the
        parameters ``centers``, ``scales``, ``weights``, ``D``, ``K``, ``p``.
        The seed is not stored.

        Args:
            save_dir: Target directory; created if it does not exist.
            name: Dataset name (file stem).
            N: Number of samples to generate.
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)
        X, y = self.generate_samples(N=N)
        np.savez(str(save_path / f'{name}.npz'),
                 X=X, y=y,
                 centers=self.centers,
                 scales=self.scales,
                 weights=self.weights,
                 D=self.D,
                 K=self.K,
                 p=self.p)

    @classmethod
    def load_dataset(cls, file_path: Union[str, Path]) -> "GeneralizedGaussianMixture":
        """Rebuild a generator from a file written by ``save_dataset``.

        Only the distribution parameters are restored; the stored samples
        are not attached to the returned object, and the default seed is
        used because the archive does not contain one.

        Args:
            file_path: Path to the ``.npz`` file.

        Returns:
            A generator with the stored parameters.
        """
        # The context manager closes the underlying archive file handle.
        with np.load(str(file_path)) as data:
            return cls(
                D=int(data['D']),
                K=int(data['K']),
                p=float(data['p']),
                centers=data['centers'],
                scales=data['scales'],
                weights=data['weights']
            )
def test_gmm_dataset():
    """Smoke-test the GMM dataset generator with a save/load round trip."""
    # A hand-picked 2-D mixture with three components.
    params = dict(
        D=2,
        K=3,
        p=2.0,
        centers=np.array([[-2, -2], [0, 0], [2, 2]]),
        scales=np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]]),
        weights=np.array([0.3, 0.4, 0.3]),
    )
    gmm = GeneralizedGaussianMixture(**params)

    # Draw one batch of samples; the result itself is not inspected.
    gmm.generate_samples(1000)

    # Write the dataset to disk, then rebuild a generator from that file.
    gmm.save_dataset('test_data')
    loaded_gmm = GeneralizedGaussianMixture.load_dataset('test_data/gmm_dataset.npz')

    # The restored parameters must match the ones we started with.
    for attr in ('centers', 'scales', 'weights'):
        assert np.allclose(getattr(gmm, attr), getattr(loaded_gmm, attr))

    print("GMM数据集测试通过!")
# Run the self-test when this file is executed as a script.
if __name__ == '__main__':
    test_gmm_dataset()