2catycm committed on
Commit
78e4509
·
1 Parent(s): f7825d3

feat: updates

Files changed (7)
  1. .gitignore +175 -0
  2. README.md +68 -6
  3. app.py +266 -0
  4. experiments/gmm_dataset.py +190 -0
  5. experiments/gmm_fitting.py +157 -0
  6. experiments/test.py +69 -0
  7. requirements.txt +11 -0
.gitignore ADDED
@@ -0,0 +1,175 @@
+ *.npz
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
README.md CHANGED
@@ -1,14 +1,76 @@
  ---
- title: VisualizationForGeneralizedGaussianMixture
- emoji: 📈
- colorFrom: blue
- colorTo: pink
  sdk: streamlit
- sdk_version: 1.44.0
  app_file: app.py
  pinned: false
  license: apache-2.0
- short_description: Interactive visualization of Generalized Gaussian Mixture
  ---
 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: Generalized Gaussian Mixture Visualization
+ emoji: 🔄
+ colorFrom: indigo
+ colorTo: blue
  sdk: streamlit
+ sdk_version: 1.32.0
  app_file: app.py
  pinned: false
  license: apache-2.0
+ short_description: 'Interactive visualization of Generalized Gaussian Mixture Distribution.'
  ---
 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # Generalized Gaussian Mixture Visualization
+
+ ## Visualization Design
+
+ 1. Page layout:
+ ```plaintext
+ +------------------+-----------------------+
+ | Parameter        | Main display area     |
+ | sidebar          | +---------+---------+ |
+ | - shape param p  | |   3D    | Contour | |
+ | - components K   | | surface |  plot   | |
+ | - component      | +---------+---------+ |
+ |   parameters     | |  Parameter notes  | |
+ |                  | +-------------------+ |
+ +------------------+-----------------------+
+ ```
+
+ 2. Chart configuration:
+    - Left: 3D surface plot
+      - X axis: first dimension coordinate
+      - Y axis: second dimension coordinate
+      - Z axis: probability density
+      - viridis colorscale
+    - Right: contour plot
+      - X axis: first dimension coordinate
+      - Y axis: second dimension coordinate
+      - Color: probability density
+      - Component centers marked
+
+ 3. Key Plotly settings:
+ ```python
+ # Subplot layout
+ specs=[[{'type': 'surface'}, {'type': 'contour'}]]
+
+ # Axis configuration
+ scene=dict(  # axes of the 3D plot
+     xaxis_title='X',
+     yaxis_title='Y',
+     zaxis_title='Density'
+ )
+ xaxis=dict(title='X'),  # X axis of the 2D plot
+ yaxis=dict(title='Y')   # Y axis of the 2D plot
+ ```
+
+ ## Data Processing Pipeline
+
+ 1. Parameter handling
+    - Basic parameters: p (shape), K (number of components)
+    - Per component: center, scale, weight
+    - Updates in real time when a parameter changes
+
+ 2. Data generation (see the sketch after this diff)
+    - Generate grid points with meshgrid
+    - Compute the probability density at every point
+    - Reshape the data into the format Plotly expects
+
+ 3. Interactive updates
+    - Parameter changes trigger recomputation
+    - Charts and notes update dynamically
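For reference, a minimal sketch of the meshgrid-based density evaluation outlined in step 2 of the pipeline above, using the `GeneralizedGaussianMixture` class added in `experiments/gmm_dataset.py` (the parameter values shown are the app's defaults):

```python
import numpy as np
from experiments.gmm_dataset import GeneralizedGaussianMixture

# Build a mixture with the app's default parameters.
gmm = GeneralizedGaussianMixture(
    D=2, K=3, p=2.0,
    centers=np.array([[-2, -2], [0, 0], [2, 2]], dtype=np.float64),
    scales=np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]], dtype=np.float64),
    weights=np.array([0.3, 0.4, 0.3]),
)

# Generate grid points with meshgrid, evaluate the density at each point,
# and reshape back to the grid shape that Plotly's surface/contour traces expect.
x = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, x)
Z = gmm.pdf(np.column_stack((X.ravel(), Y.ravel()))).reshape(X.shape)
```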
app.py ADDED
@@ -0,0 +1,266 @@
+ import streamlit as st
+ import numpy as np
+ from pathlib import Path
+ from experiments.gmm_dataset import GeneralizedGaussianMixture
+ import plotly.graph_objects as go
+ from plotly.subplots import make_subplots
+ from typing import List, Tuple
+
+ def init_session_state():
+     """Initialize the session state"""
+     if 'prev_K' not in st.session_state:
+         st.session_state.prev_K = 3
+     if 'p' not in st.session_state:
+         st.session_state.p = 2.0
+     if 'centers' not in st.session_state:
+         st.session_state.centers = np.array([[-2, -2], [0, 0], [2, 2]], dtype=np.float64)
+     if 'scales' not in st.session_state:
+         st.session_state.scales = np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]], dtype=np.float64)
+     if 'weights' not in st.session_state:
+         st.session_state.weights = np.ones(3, dtype=np.float64) / 3
+     if 'sample_points' not in st.session_state:
+         st.session_state.sample_points = None
+
+ def create_default_parameters(K: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+     """Create default parameters"""
+     # Spread K centers evenly over [-3, 3]
+     x = np.linspace(-3, 3, K)
+     y = np.linspace(-3, 3, K)
+     centers = np.column_stack((x, y))
+
+     # Default scales and weights
+     scales = np.ones((K, 2), dtype=np.float64) * 3
+     weights = np.random.random(size=K).astype(np.float64)
+     weights /= weights.sum()  # normalize the weights
+     return centers, scales, weights
+
+ def generate_latex_formula(p: float, K: int, centers: np.ndarray,
+                            scales: np.ndarray, weights: np.ndarray) -> str:
+     """Generate the LaTeX formula"""
+     formula = r"P(x) = \sum_{k=1}^{" + str(K) + r"} \pi_k P_{\theta_k}(x) \\"
+     formula += r"P_{\theta_k}(x) = \eta_k \exp(-s_k d_k(x)) = \frac{p}{2\alpha_k \Gamma(1/p) }\exp(-\frac{|x-c_k|^p}{\alpha_k^p})= \frac{p}{2\alpha_k \Gamma(1/p) }\exp(-|\frac{x-c_k}{\alpha_k}|^p) \\"
+     formula += r"\text{where: }"
+
+     for k in range(K):
+         c = centers[k]
+         s = scales[k]
+         w = weights[k]
+         component = f"P_{k+1}(x) = \\frac{{{p:.1f}}}{{2\\alpha_{k+1} \\Gamma(1/{p:.1f})}}\\exp(-|\\frac{{x-({c[0]:.1f}, {c[1]:.1f})}}{{{s[0]:.1f}, {s[1]:.1f}}}|^{{{p:.1f}}}) \\\\"
+         formula += component
+         formula += f"\\pi_{k+1} = {w:.2f} \\\\"
+
+     return formula
+
+ st.set_page_config(page_title="GMM Distribution Visualization", layout="wide")
+ st.title("Generalized Gaussian Mixture Visualization")
+
+ # Initialize the session state
+ init_session_state()
+
+ # Sidebar parameter settings
+ with st.sidebar:
+     st.header("Distribution parameters")
+
+     # Basic distribution parameters
+     st.session_state.p = st.slider("Shape parameter (p)", 0.1, 5.0, st.session_state.p, 0.1,
+                                    help="p=1: Laplace distribution, p=2: Gaussian distribution, p→∞: uniform distribution")
+     K = st.slider("Number of components (K)", 1, 5, st.session_state.prev_K)
+
+     # Re-initialize the parameters when K changes
+     if K != st.session_state.prev_K:
+         centers, scales, weights = create_default_parameters(K)
+         st.session_state.centers = centers
+         st.session_state.scales = scales
+         st.session_state.weights = weights
+         st.session_state.prev_K = K
+
+     # Advanced settings
+     st.subheader("Advanced settings")
+     show_advanced = st.checkbox("Show component parameters", value=False)
+
+     if show_advanced:
+         # Set the parameters of each component
+         centers_list: List[List[float]] = []
+         scales_list: List[List[float]] = []
+         weights_list: List[float] = []
+
+         for k in range(K):
+             st.write(f"Component {k+1}")
+             col1, col2 = st.columns(2)
+             with col1:
+                 cx = st.number_input(f"Center X_{k+1}", -5.0, 5.0, float(st.session_state.centers[k][0]), 0.1)
+                 cy = st.number_input(f"Center Y_{k+1}", -5.0, 5.0, float(st.session_state.centers[k][1]), 0.1)
+             with col2:
+                 sx = st.number_input(f"Scale X_{k+1}", 0.1, 3.0, float(st.session_state.scales[k][0]), 0.1)
+                 sy = st.number_input(f"Scale Y_{k+1}", 0.1, 3.0, float(st.session_state.scales[k][1]), 0.1)
+             w = st.slider(f"Weight_{k+1}", 0.0, 1.0, float(st.session_state.weights[k]), 0.1)
+
+             centers_list.append([cx, cy])
+             scales_list.append([sx, sy])
+             weights_list.append(w)
+
+         centers = np.array(centers_list, dtype=np.float64)
+         scales = np.array(scales_list, dtype=np.float64)
+         weights = np.array(weights_list, dtype=np.float64)
+         weights = weights / weights.sum()
+
+         st.session_state.centers = centers
+         st.session_state.scales = scales
+         st.session_state.weights = weights
+     else:
+         centers = st.session_state.centers
+         scales = st.session_state.scales
+         weights = st.session_state.weights
+
+     # Sampling settings
+     st.subheader("Sampling settings")
+     n_samples = st.slider("Number of sample points", 5, 20, 10)
+     if st.button("Resample"):
+         # Draw random samples
+         samples = []
+         for _ in range(n_samples):
+             # Pick a component
+             k = np.random.choice(K, p=weights)
+             # Draw from the chosen component (a Gaussian approximation of the
+             # generalized Gaussian component, regardless of p)
+             sample = np.random.normal(centers[k], scales[k], size=2)
+             samples.append(sample)
+         st.session_state.sample_points = np.array(samples)
+
+ # Create the GMM dataset
+ dataset = GeneralizedGaussianMixture(
+     D=2,
+     K=K,
+     p=st.session_state.p,
+     centers=centers[:K],
+     scales=scales[:K],
+     weights=weights[:K]
+ )
+
+ # Generate the grid data
+ x = np.linspace(-5, 5, 100)
+ y = np.linspace(-5, 5, 100)
+ X, Y = np.meshgrid(x, y)
+ xy = np.column_stack((X.ravel(), Y.ravel()))
+
+ # Compute the probability density
+ Z = dataset.pdf(xy).reshape(X.shape)
+
+ # Create the 2D and 3D visualizations
+ fig = make_subplots(
+     rows=1, cols=2,
+     specs=[[{'type': 'surface'}, {'type': 'contour'}]],
+     subplot_titles=('3D density surface', 'Contours and component centers')
+ )
+
+ # 3D surface
+ surface = go.Surface(
+     x=X, y=Y, z=Z,
+     colorscale='viridis',
+     showscale=True,
+     colorbar=dict(x=0.45)
+ )
+ fig.add_trace(surface, row=1, col=1)
+
+ # Contour plot with component centers
+ contour = go.Contour(
+     x=x, y=y, z=Z,
+     colorscale='viridis',
+     showscale=True,
+     colorbar=dict(x=1.0),
+     contours=dict(
+         showlabels=True,
+         labelfont=dict(size=12)
+     )
+ )
+ fig.add_trace(contour, row=1, col=2)
+
+ # Add the component centers
+ fig.add_trace(
+     go.Scatter(
+         x=centers[:K, 0], y=centers[:K, 1],
+         mode='markers+text',
+         marker=dict(size=10, color='red'),
+         text=[f'C{i+1}' for i in range(K)],
+         textposition="top center",
+         name='Component centers'
+     ),
+     row=1, col=2
+ )
+
+ # Add the sample points (if any)
+ if st.session_state.sample_points is not None:
+     samples = st.session_state.sample_points
+     # Probability density at each sample point
+     probs = dataset.pdf(samples)
+     # Posterior probability of each component for each sample point;
+     # use the normalized component densities so that components with
+     # different scales are weighted correctly
+     posteriors = []
+     for sample in samples:
+         component_probs = [
+             weights[k] * dataset.component_pdf(sample.reshape(1, -1), k)[0]
+             for k in range(K)
+         ]
+         total = sum(component_probs)
+         posteriors.append([cp / total for cp in component_probs])
+
+     # Add the sample points to the chart
+     fig.add_trace(
+         go.Scatter(
+             x=samples[:, 0], y=samples[:, 1],
+             mode='markers+text',
+             marker=dict(
+                 size=8,
+                 color='yellow',
+                 line=dict(color='black', width=1)
+             ),
+             text=[f'S{i+1}' for i in range(len(samples))],
+             textposition="bottom center",
+             name='Sample points'
+         ),
+         row=1, col=2
+     )
+
+     # Show the probability information for each sample point
+     st.subheader("Sample point information")
+     for i, (sample, prob, post) in enumerate(zip(samples, probs, posteriors)):
+         st.write(f"Sample S{i+1} ({sample[0]:.2f}, {sample[1]:.2f}):")
+         st.write(f"- Probability density: {prob:.4f}")
+         st.write("- Posterior probabilities:")
+         for k in range(K):
+             st.write(f"  - Component {k+1}: {post[k]:.4f}")
+         st.write("---")
+
+ # Update the layout
+ fig.update_layout(
+     title='Generalized Gaussian Mixture Distribution',
+     showlegend=True,
+     width=1200,
+     height=600,
+     scene=dict(
+         xaxis_title='X',
+         yaxis_title='Y',
+         zaxis_title='Density'
+     )
+ )
+
+ # Update the axes of the 2D plot
+ fig.update_xaxes(title_text='X', row=1, col=2)
+ fig.update_yaxes(title_text='Y', row=1, col=2)
+
+ # Show the figure
+ st.plotly_chart(fig, use_container_width=True)
+
+ # Parameter explanation
+ with st.expander("Distribution parameter notes"):
+     st.markdown("""
+     - **Shape parameter (p)**: controls the shape of the distribution
+         - p = 1: Laplace distribution
+         - p = 2: Gaussian distribution
+         - p → ∞: uniform distribution
+     - **Component parameters**: each component is determined by
+         - Center (μ): location of the peak, given by its X and Y coordinates
+         - Scale (α): spread of the distribution; may differ in X and Y
+         - Weight (π): mixing coefficient; the weights sum to 1
+     """)
+
+ # Show the mathematical formula for the current parameters
+ st.latex(generate_latex_formula(st.session_state.p, K, centers[:K], scales[:K], weights[:K]))
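The per-sample posteriors reported above are the standard mixture responsibilities: with normalized component densities $P_{\theta_k}$ and weights $\pi_k$, the probability that sample $x$ came from component $k$ is

$$\gamma_k(x) = \frac{\pi_k\, P_{\theta_k}(x)}{\sum_{j=1}^{K} \pi_j\, P_{\theta_j}(x)},$$

which is what the `posteriors` loop computes via `dataset.component_pdf`.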
experiments/gmm_dataset.py ADDED
@@ -0,0 +1,190 @@
+ import numpy as np
+ from pathlib import Path
+ from scipy.special import gamma
+ from typing import Optional, Tuple, Dict, List, Union
+ import torch
+ import os
+
+ class GeneralizedGaussianMixture:
+     r"""Generalized Gaussian mixture dataset generator
+     P_{\theta_k}(x_i) = \eta_k \exp(-s_k d_k(x_i)) = \frac{p}{2\alpha_k \Gamma(1/p)}\exp(-|\frac{x_i-c_k}{\alpha_k}|^p)
+     """
+
+     def __init__(self,
+                  D: int = 2,  # dimensionality
+                  K: int = 3,  # number of components
+                  p: float = 2.0,  # exponent; p=2 is the standard Gaussian
+                  centers: Optional[np.ndarray] = None,  # component centers
+                  scales: Optional[np.ndarray] = None,  # scale parameters
+                  weights: Optional[np.ndarray] = None,  # mixture weights
+                  seed: int = 42):  # random seed
+         """Initialize the GMM dataset generator
+         Args:
+             D: data dimensionality
+             K: number of components
+             p: exponent parameter controlling the shape of the distribution
+             centers: component centers, shape (K, D)
+             scales: scale parameters, shape (K, D)
+             weights: mixture weights, shape (K,)
+             seed: random seed
+         """
+         self.D = D
+         self.K = K
+         self.p = p
+         self.seed = seed
+         np.random.seed(seed)
+
+         # Initialize the distribution parameters
+         if centers is None:
+             self.centers = np.random.randn(K, D) * 2
+         else:
+             self.centers = centers
+
+         if scales is None:
+             self.scales = np.random.uniform(0.1, 0.5, size=(K, D))
+         else:
+             self.scales = scales
+
+         if weights is None:
+             self.weights = np.random.dirichlet(np.ones(K))
+         else:
+             self.weights = weights / weights.sum()  # make sure the weights sum to 1
+
+     def component_pdf(self, x: np.ndarray, k: int) -> np.ndarray:
+         """Probability density of the k-th component
+         Args:
+             x: input data points, shape (N, D)
+             k: component index
+         Returns:
+             density values, shape (N,)
+         """
+         # Normalization constant
+         norm_const = self.p / (2 * self.scales[k] * gamma(1/self.p))
+
+         # Compute |x_i - c_k|^p / α_k^p
+         z = np.abs(x - self.centers[k]) / self.scales[k]
+         exp_term = np.exp(-np.sum(z**self.p, axis=1))
+
+         return np.prod(norm_const) * exp_term
+
+     def pdf(self, x: np.ndarray) -> np.ndarray:
+         """Probability density of the mixture
+         Args:
+             x: input data points, shape (N, D)
+         Returns:
+             density values, shape (N,)
+         """
+         density = np.zeros(len(x))
+         for k in range(self.K):
+             density += self.weights[k] * self.component_pdf(x, k)
+         return density
+
+     def generate_component_samples(self, n: int, k: int) -> np.ndarray:
+         """Draw samples from the k-th component
+         Args:
+             n: number of samples
+             k: component index
+         Returns:
+             sample points, shape (n, D)
+         """
+         # Exact inverse-transform-style sampling for the exponential-power
+         # density: if g ~ Gamma(1/p, 1) and s is a random sign, then
+         # s * g^(1/p) has density proportional to exp(-|z|^p)
+         g = np.random.gamma(shape=1/self.p, scale=1.0, size=(n, self.D))
+         signs = np.random.choice([-1.0, 1.0], size=(n, self.D))
+         samples = self.centers[k] + self.scales[k] * signs * g ** (1/self.p)
+         return samples
+
+     def generate_samples(self, N: int) -> Tuple[np.ndarray, np.ndarray]:
+         """Generate samples from the mixture
+         Args:
+             N: total number of samples
+         Returns:
+             X: generated data points, shape (N, D)
+             y: corresponding density values, shape (N,)
+         """
+         # Decide how many samples each component gets from the mixture weights
+         n_samples = np.random.multinomial(N, self.weights)
+
+         # Draw samples from each component
+         samples = []
+         for k in range(self.K):
+             x = self.generate_component_samples(n_samples[k], k)
+             samples.append(x)
+
+         # Stack and shuffle the samples
+         X = np.vstack(samples)
+         idx = np.random.permutation(N)
+         X = X[idx]
+
+         # Compute the probability density
+         y = self.pdf(X)
+
+         return X, y
+
+     def save_dataset(self, save_dir: Union[str, Path], name: str = 'gmm_dataset') -> None:
+         """Save the dataset to disk
+         Args:
+             save_dir: output directory
+             name: dataset name
+         """
+         save_path = Path(save_dir)
+         save_path.mkdir(parents=True, exist_ok=True)
+
+         # Generate and save the data
+         X, y = self.generate_samples(N=1000)
+         np.savez(str(save_path / f'{name}.npz'),
+                  X=X, y=y,
+                  centers=self.centers,
+                  scales=self.scales,
+                  weights=self.weights,
+                  D=self.D,
+                  K=self.K,
+                  p=self.p)
+
+     @classmethod
+     def load_dataset(cls, file_path: Union[str, Path]) -> "GeneralizedGaussianMixture":
+         """Load a dataset from disk
+         Args:
+             file_path: path to the data file
+         Returns:
+             the loaded GMM object
+         """
+         data = np.load(str(file_path))
+         return cls(
+             D=int(data['D']),
+             K=int(data['K']),
+             p=float(data['p']),
+             centers=data['centers'],
+             scales=data['scales'],
+             weights=data['weights']
+         )
+
+ def test_gmm_dataset():
+     """Test the GMM dataset generator"""
+     # Create a 2D GMM dataset
+     gmm = GeneralizedGaussianMixture(
+         D=2,
+         K=3,
+         p=2.0,
+         centers=np.array([[-2, -2], [0, 0], [2, 2]]),
+         scales=np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]]),
+         weights=np.array([0.3, 0.4, 0.3])
+     )
+
+     # Generate samples
+     X, y = gmm.generate_samples(1000)
+
+     # Save the dataset
+     gmm.save_dataset('test_data')
+
+     # Load the dataset
+     loaded_gmm = GeneralizedGaussianMixture.load_dataset('test_data/gmm_dataset.npz')
+
+     # Check that the saved and loaded parameters match
+     assert np.allclose(gmm.centers, loaded_gmm.centers)
+     assert np.allclose(gmm.scales, loaded_gmm.scales)
+     assert np.allclose(gmm.weights, loaded_gmm.weights)
+
+     print("GMM dataset test passed!")
+
+ if __name__ == '__main__':
+     test_gmm_dataset()
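A note on the Gamma-based draw in `generate_component_samples` (the standard exact sampler for this family, replacing the earlier uniform-power heuristic, which did not match the stated density): if $g \sim \mathrm{Gamma}(1/p,\,1)$ and $s$ is an independent random sign, a change of variables gives $z = s\, g^{1/p}$ the density

$$f(z) = \frac{p}{2\,\Gamma(1/p)} \exp(-|z|^p),$$

so $x = c_k + \alpha_k z$, applied per dimension, follows the component density in the class docstring.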
experiments/gmm_fitting.py ADDED
@@ -0,0 +1,157 @@
+ import numpy as np
+ import torch
+ from sklearn.neural_network import MLPRegressor
+ from pathlib import Path
+ import sys
+ import json
+ import os
+ import shutil
+ from typing import Any, Optional
+
+ # Add pykan to the Python path
+ repo_root = Path(__file__).parent.parent.parent
+ sys.path.append(str(repo_root / 'pykan'))
+
+ from kan import *
+ # Import gmm_dataset; try both package-relative and script-style paths
+ try:
+     from .gmm_dataset import GeneralizedGaussianMixture
+ except ImportError:
+     from gmm_dataset import GeneralizedGaussianMixture
+
+ def train_and_evaluate(dataset: GeneralizedGaussianMixture,
+                        save_dir: Path,
+                        kan_config: Optional[dict[str, Any]] = None,
+                        random_state: int = 42) -> dict[str, Any]:
+     """Train and evaluate the different models"""
+     save_dir.mkdir(parents=True, exist_ok=True)
+
+     # Generate the training and test data
+     X_train, y_train = dataset.generate_samples(N=1000)
+     X_test, y_test = dataset.generate_samples(N=200)
+
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     torch.set_default_dtype(torch.float64)  # use double precision
+
+     # Convert the data to PyTorch format (from_numpy keeps the arrays in
+     # float64, matching the default dtype set above)
+     train_data = {
+         'train_input': torch.from_numpy(X_train).to(device),
+         'train_label': torch.from_numpy(y_train).reshape(-1, 1).to(device),
+         'test_input': torch.from_numpy(X_test).to(device),
+         'test_label': torch.from_numpy(y_test).reshape(-1, 1).to(device)
+     }
+
+     # Save the training data
+     np.savez(save_dir / f'data_{random_state}.npz',
+              X_train=X_train, y_train=y_train,
+              X_test=X_test, y_test=y_test)
+
+     # Train the KAN
+     if kan_config is None:
+         kan_config = {
+             'width': [dataset.D, 5, 1],
+             'grid': 5,
+             'k': 3
+         }
+
+     # Make sure the device argument is a string
+     kan_model = KAN(**kan_config, seed=random_state, device=str(device))
+     kan_model = kan_model.to(device)  # make sure the model is on the right device
+     results = kan_model.fit(train_data, opt="LBFGS", steps=50, lamb=0.001)
+
+     # Train the MLP
+     mlp = MLPRegressor(
+         hidden_layer_sizes=(10, 5),
+         max_iter=1000,
+         random_state=random_state
+     )
+     mlp.fit(X_train, y_train)
+
+     # Compute and save the predictions
+     grid_x = np.linspace(X_train.min(), X_train.max(), 100)
+     grid_y = np.linspace(X_train.min(), X_train.max(), 100)
+     XX, YY = np.meshgrid(grid_x, grid_y)
+     grid_points = np.column_stack((XX.ravel(), YY.ravel()))
+
+     with torch.no_grad():
+         kan_pred = kan_model(torch.from_numpy(grid_points).to(device)).cpu().numpy()
+         mlp_pred = mlp.predict(grid_points)
+         true_density = dataset.pdf(grid_points)
+
+     # Test-set RMSE (detach before converting to NumPy)
+     kan_test_rmse = np.sqrt(np.mean((kan_model(train_data['test_input']).detach().cpu().numpy() - y_test.reshape(-1, 1))**2))
+     mlp_test_rmse = np.sqrt(np.mean((mlp.predict(X_test).reshape(-1, 1) - y_test.reshape(-1, 1))**2))
+
+     evaluation = {
+         'random_state': random_state,
+         'kan_test_rmse': float(kan_test_rmse),
+         'mlp_test_rmse': float(mlp_test_rmse),
+         'training_history': results
+     }
+
+     # Save the predictions
+     np.savez(save_dir / f'predictions_{random_state}.npz',
+              grid_points=grid_points,
+              kan_pred=kan_pred,
+              mlp_pred=mlp_pred,
+              true_density=true_density)
+
+     # Save the evaluation results
+     with open(save_dir / f'evaluation_{random_state}.json', 'w') as f:
+         json.dump(evaluation, f)
+
+     return evaluation
+
+ def run_experiments(save_dir: Path, n_experiments: int = 5) -> dict[str, float]:
+     """Run several randomized experiments"""
+     save_dir.mkdir(parents=True, exist_ok=True)
+
+     all_results = []
+     base_seed = 42
+
+     for i in range(n_experiments):
+         print(f"Running experiment {i+1}/{n_experiments}")
+         random_state = base_seed + i
+
+         # Create the dataset
+         dataset = GeneralizedGaussianMixture(
+             D=2,
+             K=3,
+             p=2.0,
+             centers=np.array([[-2, -2], [0, 0], [2, 2]]),
+             scales=np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]]),
+             weights=np.array([0.3, 0.4, 0.3]),
+             seed=random_state
+         )
+
+         # Train and evaluate
+         result = train_and_evaluate(dataset, save_dir / str(random_state), random_state=random_state)
+         all_results.append(result)
+
+     # Save all the results
+     with open(save_dir / 'all_results.json', 'w') as f:
+         json.dump(all_results, f)
+
+     # Compute summary statistics
+     kan_rmses = [r['kan_test_rmse'] for r in all_results]
+     mlp_rmses = [r['mlp_test_rmse'] for r in all_results]
+
+     statistics = {
+         'kan_mean_rmse': float(np.mean(kan_rmses)),
+         'kan_std_rmse': float(np.std(kan_rmses)),
+         'mlp_mean_rmse': float(np.mean(mlp_rmses)),
+         'mlp_std_rmse': float(np.std(mlp_rmses)),
+     }
+
+     with open(save_dir / 'statistics.json', 'w') as f:
+         json.dump(statistics, f)
+
+     return statistics
+
+ if __name__ == '__main__':
+     # Use a relative path; results go to experiments/results
+     results_dir = Path(__file__).parent / 'results'
+     stats = run_experiments(results_dir)
+     print("\nExperiment Statistics:")
+     print(f"KAN Test RMSE: {stats['kan_mean_rmse']:.4f} ± {stats['kan_std_rmse']:.4f}")
+     print(f"MLP Test RMSE: {stats['mlp_mean_rmse']:.4f} ± {stats['mlp_std_rmse']:.4f}")
experiments/test.py ADDED
@@ -0,0 +1,69 @@
+ import plotly.graph_objects as go
+
+ # Example data (replace with real data)
+ methods = ['RSRM', 'PSRN', 'NGGP', 'PySR', 'BMS', 'uDSR', 'AIF',
+            'DGSR', 'E2E', 'SymINDy', 'PhySO', 'TPSR', 'SPL',
+            'DEAP', 'SINDy', 'NSRS', 'gplearn', 'SNIP', 'KAN', 'EQL']
+ recovery_rates = [85, 78, 92, 88, 76, 83, 95, 81, 89, 77, 84, 86, 80,
+                   79, 82, 87, 75, 88, 90, 84]  # recovery rates in percent
+ errors = [3, 4, 2, 3, 5, 2, 1, 3, 2, 4, 3, 2, 3, 4, 2, 3, 5, 2, 3, 2]  # error bar extents
+
+ # Create the figure object
+ fig = go.Figure()
+
+ # Add the data points with error bars
+ fig.add_trace(go.Scatter(
+     x=recovery_rates,
+     y=methods,
+     mode='markers',
+     error_x=dict(
+         type='data',
+         array=errors,
+         visible=True,
+         color='#FF5733',
+         thickness=2,
+         width=10
+     ),
+     marker=dict(
+         size=12,
+         color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
+                '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
+                '#aec7e8', '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5',
+                '#c49c94', '#f7b6d2', '#c7c7c7', '#dbdb8d', '#9edae5'],
+         opacity=0.8
+     )
+ ))
+
+ # Configure the layout
+ fig.update_layout(
+     title='Recovery rate comparison across methods',
+     xaxis=dict(
+         title='Recovery rate (%)',
+         range=[0, 100],
+         dtick=20,
+         title_standoff=25
+     ),
+     yaxis=dict(
+         title='Methods',
+         title_font=dict(size=14),
+         tickfont=dict(size=12),
+         autorange="reversed"  # show the first method at the top
+     ),
+     hovermode='closest',
+     width=1000,
+     height=600,
+     showlegend=False
+ )
+
+ # Add an annotation (optional)
+ fig.add_annotation(
+     x=0,
+     y=0.95,
+     xref='paper',
+     yref='paper',
+     text='Zhihu @x66ccff',
+     showarrow=False,
+     font=dict(size=10)
+ )
+
+ fig.show()
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ streamlit>=1.32.0
+ numpy>=1.21.0
+ pandas>=1.3.0
+ plotly>=5.18.0
+ scipy>=1.7.0
+ torch>=1.9.0
+ scikit-learn>=1.0.0
+ ipython>=8.0.0
+ ipywidgets>=7.0.0
+ nbformat>=5.0.0
+ sympy>=1.8