File size: 6,233 Bytes
78e4509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import numpy as np
from pathlib import Path
from scipy.special import gamma
from typing import Optional, Tuple, Dict, List, Union
import torch
import os

class GeneralizedGaussianMixture:
    r"""广义高斯混合分布数据集生成器
    P_{\theta_k}(x_i) = \eta_k \exp(-s_k d_k(x_i)) = \frac{p}{2\alpha_k \Gamma(1/p)}\exp(-|\frac{x_i-c_k}{\alpha_k}|^p)
    """
    
    def __init__(self, 
                 D: int = 2,           # 维度
                 K: int = 3,           # 聚类数量
                 p: float = 2.0,       # 幂次,p=2为标准高斯分布
                 centers: Optional[np.ndarray] = None,  # 聚类中心
                 scales: Optional[np.ndarray] = None,   # 尺度参数
                 weights: Optional[np.ndarray] = None,  # 混合权重
                 seed: int = 42):      # 随机种子
        """初始化GMM数据集生成器
        Args:
            D: 数据维度
            K: 聚类数量
            p: 幂次参数,控制分布的形状
            centers: 聚类中心,形状为(K, D)
            scales: 尺度参数,形状为(K, D)
            weights: 混合权重,形状为(K,)
            seed: 随机种子
        """
        self.D = D
        self.K = K
        self.p = p
        self.seed = seed
        np.random.seed(seed)
        
        # 初始化分布参数
        if centers is None:
            self.centers = np.random.randn(K, D) * 2
        else:
            self.centers = centers
            
        if scales is None:
            self.scales = np.random.uniform(0.1, 0.5, size=(K, D))
        else:
            self.scales = scales
            
        if weights is None:
            self.weights = np.random.dirichlet(np.ones(K))
        else:
            self.weights = weights / weights.sum()  # 确保权重和为1
        
    def component_pdf(self, x: np.ndarray, k: int) -> np.ndarray:
        """计算第k个分量的概率密度
        Args:
            x: 输入数据点,形状为(N, D)
            k: 分量索引
        Returns:
            概率密度值,形状为(N,)
        """
        # 计算归一化常数
        norm_const = self.p / (2 * self.scales[k] * gamma(1/self.p))
        
        # 计算|x_i - c_k|^p / α_k^p
        z = np.abs(x - self.centers[k]) / self.scales[k]
        exp_term = np.exp(-np.sum(z**self.p, axis=1))
        
        return np.prod(norm_const) * exp_term
    
    def pdf(self, x: np.ndarray) -> np.ndarray:
        """计算混合分布的概率密度
        Args:
            x: 输入数据点,形状为(N, D)
        Returns:
            概率密度值,形状为(N,)
        """
        density = np.zeros(len(x))
        for k in range(self.K):
            density += self.weights[k] * self.component_pdf(x, k)
        return density
    
    def generate_component_samples(self, n: int, k: int) -> np.ndarray:
        """从第k个分量生成样本
        Args:
            n: 样本数量
            k: 分量索引
        Returns:
            样本点,形状为(n, D)
        """
        # 使用幂指数分布的反变换采样
        u = np.random.uniform(-1, 1, size=(n, self.D))
        r = np.abs(u) ** (1/self.p)
        samples = self.centers[k] + self.scales[k] * np.sign(u) * r
        return samples
    
    def generate_samples(self, N: int) -> Tuple[np.ndarray, np.ndarray]:
        """生成混合分布的样本
        Args:
            N: 总样本数量
        Returns:
            X: 生成的数据点,形状为(N, D)
            y: 对应的概率密度值,形状为(N,)
        """
        # 根据混合权重确定每个分量的样本数量
        n_samples = np.random.multinomial(N, self.weights)
        
        # 从每个分量生成样本
        samples = []
        for k in range(self.K):
            x = self.generate_component_samples(n_samples[k], k)
            samples.append(x)
        
        # 合并并打乱样本
        X = np.vstack(samples)
        idx = np.random.permutation(N)
        X = X[idx]
        
        # 计算概率密度
        y = self.pdf(X)
        
        return X, y
    
    def save_dataset(self, save_dir: Union[str, Path], name: str = 'gmm_dataset') -> None:
        """保存数据集到文件
        Args:
            save_dir: 保存目录
            name: 数据集名称
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)
        
        # 生成并保存数据
        X, y = self.generate_samples(N=1000)
        np.savez(str(save_path / f'{name}.npz'),
                 X=X, y=y,
                 centers=self.centers,
                 scales=self.scales,
                 weights=self.weights,
                 D=self.D,
                 K=self.K,
                 p=self.p)
    
    @classmethod
    def load_dataset(cls, file_path: Union[str, Path]) -> "GeneralizedGaussianMixture":
        """从文件加载数据集
        Args:
            file_path: 数据文件路径
        Returns:
            加载的GMM对象
        """
        data = np.load(str(file_path))
        return cls(
            D=int(data['D']),
            K=int(data['K']),
            p=float(data['p']),
            centers=data['centers'],
            scales=data['scales'],
            weights=data['weights']
        )

def test_gmm_dataset():
    """测试GMM数据集生成器"""
    # 创建2D的GMM数据集
    gmm = GeneralizedGaussianMixture(
        D=2, 
        K=3, 
        p=2.0,
        centers=np.array([[-2, -2], [0, 0], [2, 2]]),
        scales=np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]]),
        weights=np.array([0.3, 0.4, 0.3])
    )
    
    # 生成样本
    X, y = gmm.generate_samples(1000)
    
    # 保存数据集
    gmm.save_dataset('test_data')
    
    # 加载数据集
    loaded_gmm = GeneralizedGaussianMixture.load_dataset('test_data/gmm_dataset.npz')
    
    # 验证保存和加载的参数是否一致
    assert np.allclose(gmm.centers, loaded_gmm.centers)
    assert np.allclose(gmm.scales, loaded_gmm.scales)
    assert np.allclose(gmm.weights, loaded_gmm.weights)
    
    print("GMM数据集测试通过!")

if __name__ == '__main__':
    test_gmm_dataset()