feat: updates
Browse files- .gitignore +175 -0
- README.md +68 -6
- app.py +266 -0
- experiments/gmm_dataset.py +190 -0
- experiments/gmm_fitting.py +157 -0
- experiments/test.py +69 -0
- requirements.txt +11 -0
.gitignore
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.npz
|
2 |
+
# Byte-compiled / optimized / DLL files
|
3 |
+
__pycache__/
|
4 |
+
*.py[cod]
|
5 |
+
*$py.class
|
6 |
+
|
7 |
+
# C extensions
|
8 |
+
*.so
|
9 |
+
|
10 |
+
# Distribution / packaging
|
11 |
+
.Python
|
12 |
+
build/
|
13 |
+
develop-eggs/
|
14 |
+
dist/
|
15 |
+
downloads/
|
16 |
+
eggs/
|
17 |
+
.eggs/
|
18 |
+
lib/
|
19 |
+
lib64/
|
20 |
+
parts/
|
21 |
+
sdist/
|
22 |
+
var/
|
23 |
+
wheels/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
cover/
|
54 |
+
|
55 |
+
# Translations
|
56 |
+
*.mo
|
57 |
+
*.pot
|
58 |
+
|
59 |
+
# Django stuff:
|
60 |
+
*.log
|
61 |
+
local_settings.py
|
62 |
+
db.sqlite3
|
63 |
+
db.sqlite3-journal
|
64 |
+
|
65 |
+
# Flask stuff:
|
66 |
+
instance/
|
67 |
+
.webassets-cache
|
68 |
+
|
69 |
+
# Scrapy stuff:
|
70 |
+
.scrapy
|
71 |
+
|
72 |
+
# Sphinx documentation
|
73 |
+
docs/_build/
|
74 |
+
|
75 |
+
# PyBuilder
|
76 |
+
.pybuilder/
|
77 |
+
target/
|
78 |
+
|
79 |
+
# Jupyter Notebook
|
80 |
+
.ipynb_checkpoints
|
81 |
+
|
82 |
+
# IPython
|
83 |
+
profile_default/
|
84 |
+
ipython_config.py
|
85 |
+
|
86 |
+
# pyenv
|
87 |
+
# For a library or package, you might want to ignore these files since the code is
|
88 |
+
# intended to run in multiple environments; otherwise, check them in:
|
89 |
+
# .python-version
|
90 |
+
|
91 |
+
# pipenv
|
92 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
93 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
94 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
95 |
+
# install all needed dependencies.
|
96 |
+
#Pipfile.lock
|
97 |
+
|
98 |
+
# UV
|
99 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
100 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
101 |
+
# commonly ignored for libraries.
|
102 |
+
#uv.lock
|
103 |
+
|
104 |
+
# poetry
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
106 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
107 |
+
# commonly ignored for libraries.
|
108 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
109 |
+
#poetry.lock
|
110 |
+
|
111 |
+
# pdm
|
112 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
113 |
+
#pdm.lock
|
114 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
115 |
+
# in version control.
|
116 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
117 |
+
.pdm.toml
|
118 |
+
.pdm-python
|
119 |
+
.pdm-build/
|
120 |
+
|
121 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
122 |
+
__pypackages__/
|
123 |
+
|
124 |
+
# Celery stuff
|
125 |
+
celerybeat-schedule
|
126 |
+
celerybeat.pid
|
127 |
+
|
128 |
+
# SageMath parsed files
|
129 |
+
*.sage.py
|
130 |
+
|
131 |
+
# Environments
|
132 |
+
.env
|
133 |
+
.venv
|
134 |
+
env/
|
135 |
+
venv/
|
136 |
+
ENV/
|
137 |
+
env.bak/
|
138 |
+
venv.bak/
|
139 |
+
|
140 |
+
# Spyder project settings
|
141 |
+
.spyderproject
|
142 |
+
.spyproject
|
143 |
+
|
144 |
+
# Rope project settings
|
145 |
+
.ropeproject
|
146 |
+
|
147 |
+
# mkdocs documentation
|
148 |
+
/site
|
149 |
+
|
150 |
+
# mypy
|
151 |
+
.mypy_cache/
|
152 |
+
.dmypy.json
|
153 |
+
dmypy.json
|
154 |
+
|
155 |
+
# Pyre type checker
|
156 |
+
.pyre/
|
157 |
+
|
158 |
+
# pytype static type analyzer
|
159 |
+
.pytype/
|
160 |
+
|
161 |
+
# Cython debug symbols
|
162 |
+
cython_debug/
|
163 |
+
|
164 |
+
# PyCharm
|
165 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
166 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
167 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
168 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
169 |
+
#.idea/
|
170 |
+
|
171 |
+
# Ruff stuff:
|
172 |
+
.ruff_cache/
|
173 |
+
|
174 |
+
# PyPI configuration file
|
175 |
+
.pypirc
|
README.md
CHANGED
@@ -1,14 +1,76 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
-
sdk_version: 1.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
-
short_description: Interactive visualization of Generalized Gaussian Mixture
|
12 |
---
|
13 |
|
14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: Generalized Gaussian Mixture Visualization
|
3 |
+
emoji: 🔄
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: blue
|
6 |
sdk: streamlit
|
7 |
+
sdk_version: 1.32.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
+
short_description: 'Interactive visualization of Generalized Gaussian Mixture Distribution.'
|
12 |
---
|
13 |
|
14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
15 |
+
|
16 |
+
# 广义高斯混合分布可视化
|
17 |
+
|
18 |
+
## 可视化思路
|
19 |
+
|
20 |
+
1. 页面布局:
|
21 |
+
```plaintext
|
22 |
+
+-----------------+----------------------+
|
23 |
+
| 参数侧边栏 | 主显示区域 |
|
24 |
+
| - 形状参数p | +--------+--------+ |
|
25 |
+
| - 分量数K | | | | |
|
26 |
+
| - 分量参数 | | 3D | 等高线 | |
|
27 |
+
| | | Surface | Plot | |
|
28 |
+
+-----------------+ | | | |
|
29 |
+
+--------+--------+ |
|
30 |
+
| 参数说明 |
|
31 |
+
+----------------+ |
|
32 |
+
```
|
33 |
+
|
34 |
+
2. 图表配置:
|
35 |
+
- 左图:3D曲面图 (Surface Plot)
|
36 |
+
- X轴:第一维坐标
|
37 |
+
- Y轴:第二维坐标
|
38 |
+
- Z轴:概率密度值
|
39 |
+
- 使用viridis配色方案
|
40 |
+
|
41 |
+
- 右图:等高线图 (Contour Plot)
|
42 |
+
- X轴:第一维坐标
|
43 |
+
- Y轴:第二维坐标
|
44 |
+
- 颜色:概率密度值
|
45 |
+
- 标记分量中心点
|
46 |
+
|
47 |
+
3. Plotly配置要点:
|
48 |
+
```python
|
49 |
+
# 子图布局
|
50 |
+
specs=[[{'type': 'surface'}, {'type': 'contour'}]]
|
51 |
+
|
52 |
+
# 坐标轴配置
|
53 |
+
scene=dict( # 3D图的坐标轴
|
54 |
+
xaxis_title='X',
|
55 |
+
yaxis_title='Y',
|
56 |
+
zaxis_title='Density'
|
57 |
+
)
|
58 |
+
xaxis=dict(title='X'), # 2D图X轴
|
59 |
+
yaxis=dict(title='Y') # 2D图Y轴
|
60 |
+
```
|
61 |
+
|
62 |
+
## 数据处理流程
|
63 |
+
|
64 |
+
1. 参数处理
|
65 |
+
- 基本参数:p(形状), K(分量数)
|
66 |
+
- 每个分量:中心点、尺度、权重
|
67 |
+
- 参数改变时实时更新
|
68 |
+
|
69 |
+
2. 数据生成
|
70 |
+
- 使用meshgrid生成网格点
|
71 |
+
- 计算每个点的概率密度
|
72 |
+
- 重塑数据以适配plotly格式
|
73 |
+
|
74 |
+
3. 交互更新
|
75 |
+
- 参数变化触发重新计算
|
76 |
+
- 动态更新图表和说明
|
app.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
+
from pathlib import Path
|
4 |
+
from experiments.gmm_dataset import GeneralizedGaussianMixture
|
5 |
+
import plotly.graph_objects as go
|
6 |
+
from plotly.subplots import make_subplots
|
7 |
+
from typing import List, Tuple
|
8 |
+
|
9 |
+
def init_session_state():
    """Populate ``st.session_state`` with a default for every key the app reads.

    Keys that already exist are left untouched, so values chosen by the user
    survive Streamlit reruns. The defaults describe a three-component mixture
    with shape parameter ``p = 2``.
    """
    defaults = {
        'prev_K': 3,
        'p': 2.0,
        'centers': np.array([[-2, -2], [0, 0], [2, 2]], dtype=np.float64),
        'scales': np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]], dtype=np.float64),
        'weights': np.ones(3, dtype=np.float64) / 3,
        'sample_points': None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
|
23 |
+
|
24 |
+
def create_default_parameters(K: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build default centers, scales and weights for a K-component mixture.

    Args:
        K: Number of mixture components.

    Returns:
        A ``(centers, scales, weights)`` tuple where ``centers`` has shape
        ``(K, 2)`` and lies on the diagonal from (-3, -3) to (3, 3),
        ``scales`` has shape ``(K, 2)`` and is filled with 3.0, and
        ``weights`` has shape ``(K,)``, is drawn from the global NumPy RNG
        and is normalized to sum to 1.
    """
    # The same evenly spaced coordinates are used for both axes.
    diagonal = np.linspace(-3, 3, K)
    centers = np.stack([diagonal, diagonal], axis=1)

    scales = np.full((K, 2), 3.0, dtype=np.float64)

    raw_weights = np.random.random(size=K).astype(np.float64)
    weights = raw_weights / raw_weights.sum()
    return centers, scales, weights
|
36 |
+
|
37 |
+
def generate_latex_formula(p: float, K: int, centers: np.ndarray,
                           scales: np.ndarray, weights: np.ndarray) -> str:
    """Render the current mixture as a LaTeX string for ``st.latex``.

    The output starts with the generic mixture definition and the generic
    component density, followed by one concrete density line and one weight
    line per component.

    Args:
        p: Shape parameter, printed with one decimal.
        K: Number of components to print.
        centers: Component centers; rows ``0..K-1`` are read, two columns each.
        scales: Component scales; rows ``0..K-1`` are read, two columns each.
        weights: Mixture weights; entries ``0..K-1`` are read.

    Returns:
        The LaTeX source, lines separated by ``\\\\``.
    """
    formula = r"P(x) = \sum_{k=1}^{" + str(K) + r"} \pi_k P_{\theta_k}(x) \\"
    formula += r"P_{\theta_k}(x) = \eta_k \exp(-s_k d_k(x)) = \frac{p}{2\alpha_k \Gamma(1/p) }\exp(-\frac{|x-c_k|^p}{\alpha_k^p})= \frac{p}{2\alpha_k \Gamma(1/p) }\exp(-|\frac{x-c_k}{\alpha_k}|^p) \\"
    formula += r"\text{where: }"

    for k in range(K):
        c = centers[k]
        s = scales[k]
        w = weights[k]
        idx = k + 1
        # Subscripts are wrapped in braces: an unbraced "_10" would subscript
        # only the first digit in LaTeX.
        formula += (
            f"P_{{{idx}}}(x) = \\frac{{{p:.1f}}}{{2\\alpha_{{{idx}}} \\Gamma(1/{p:.1f})}}"
            f"\\exp(-|\\frac{{x-({c[0]:.1f}, {c[1]:.1f})}}{{{s[0]:.1f}, {s[1]:.1f}}}|^{{{p:.1f}}}) \\\\"
        )
        formula += f"\\pi_{{{idx}}} = {w:.2f} \\\\"

    return formula
|
53 |
+
|
54 |
+
st.set_page_config(page_title="GMM Distribution Visualization", layout="wide")
st.title("广义高斯混合分布可视化")

# Make sure every session-state key read below exists.
init_session_state()

# Sidebar: all user-adjustable parameters.
with st.sidebar:
    st.header("分布参数")

    # Global parameters: shape exponent p and number of components K.
    st.session_state.p = st.slider("形状参数 (p)", 0.1, 5.0, st.session_state.p, 0.1,
                                   help="p=1: 拉普拉斯分布, p=2: 高斯分布, p→∞: 均匀分布")
    K = st.slider("分量数 (K)", 1, 5, st.session_state.prev_K)

    # A different K invalidates the stored per-component arrays: rebuild them.
    if K != st.session_state.prev_K:
        centers, scales, weights = create_default_parameters(K)
        st.session_state.centers = centers
        st.session_state.scales = scales
        st.session_state.weights = weights
        st.session_state.prev_K = K

    # Optional per-component editing.
    st.subheader("高级设置")
    show_advanced = st.checkbox("显示分量参数", value=False)

    if show_advanced:
        centers_list: List[List[float]] = []
        scales_list: List[List[float]] = []
        weights_list: List[float] = []

        for k in range(K):
            st.write(f"分量 {k+1}")
            col1, col2 = st.columns(2)
            with col1:
                cx = st.number_input(f"中心X_{k+1}", -5.0, 5.0, float(st.session_state.centers[k][0]), 0.1)
                cy = st.number_input(f"中心Y_{k+1}", -5.0, 5.0, float(st.session_state.centers[k][1]), 0.1)
            with col2:
                sx = st.number_input(f"尺度X_{k+1}", 0.1, 3.0, float(st.session_state.scales[k][0]), 0.1)
                sy = st.number_input(f"尺度Y_{k+1}", 0.1, 3.0, float(st.session_state.scales[k][1]), 0.1)
            w = st.slider(f"权重_{k+1}", 0.0, 1.0, float(st.session_state.weights[k]), 0.1)

            centers_list.append([cx, cy])
            scales_list.append([sx, sy])
            weights_list.append(w)

        centers = np.array(centers_list, dtype=np.float64)
        scales = np.array(scales_list, dtype=np.float64)
        weights = np.array(weights_list, dtype=np.float64)
        total_weight = weights.sum()
        if total_weight > 0:
            weights = weights / total_weight
        else:
            # Every slider is at 0: normalizing would give 0/0 = NaN weights,
            # so fall back to a uniform mixture instead.
            weights = np.full(K, 1.0 / K, dtype=np.float64)

        st.session_state.centers = centers
        st.session_state.scales = scales
        st.session_state.weights = weights
    else:
        centers = st.session_state.centers
        scales = st.session_state.scales
        weights = st.session_state.weights

    # Sampling controls.
    st.subheader("采样设置")
    n_samples = st.slider("采样点数", 5, 20, 10)
    if st.button("重新采样"):
        # Use a fresh, unseeded generator: GeneralizedGaussianMixture reseeds
        # the *global* NumPy RNG on construction, which would otherwise make
        # every click return the same points.
        rng = np.random.default_rng()
        shape_p = st.session_state.p
        component_ids = rng.choice(K, size=n_samples, p=weights)
        # Exact generalized-Gaussian draw per coordinate: if
        # G ~ Gamma(1/p, 1) then sign * alpha * G**(1/p) has density
        # proportional to exp(-|x / alpha|**p).
        magnitudes = rng.gamma(1.0 / shape_p, 1.0, size=(n_samples, 2)) ** (1.0 / shape_p)
        signs = rng.choice([-1.0, 1.0], size=(n_samples, 2))
        st.session_state.sample_points = (
            centers[component_ids] + scales[component_ids] * signs * magnitudes
        )

# Mixture object used for every density evaluation below.
dataset = GeneralizedGaussianMixture(
    D=2,
    K=K,
    p=st.session_state.p,
    centers=centers[:K],
    scales=scales[:K],
    weights=weights[:K]
)

# Evaluation grid over [-5, 5] x [-5, 5].
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
xy = np.column_stack((X.ravel(), Y.ravel()))

# Mixture density on the grid, reshaped back to the meshgrid layout.
Z = dataset.pdf(xy).reshape(X.shape)

# Side-by-side 3D surface and 2D contour plot.
fig = make_subplots(
    rows=1, cols=2,
    specs=[[{'type': 'surface'}, {'type': 'contour'}]],
    subplot_titles=('3D概率密度曲面', '等高线图与分量中心')
)

# 3D surface
surface = go.Surface(
    x=X, y=Y, z=Z,
    colorscale='viridis',
    showscale=True,
    colorbar=dict(x=0.45)
)
fig.add_trace(surface, row=1, col=1)

# Contour plot
contour = go.Contour(
    x=x, y=y, z=Z,
    colorscale='viridis',
    showscale=True,
    colorbar=dict(x=1.0),
    contours=dict(
        showlabels=True,
        labelfont=dict(size=12)
    )
)
fig.add_trace(contour, row=1, col=2)

# Component centers on top of the contour plot.
fig.add_trace(
    go.Scatter(
        x=centers[:K, 0], y=centers[:K, 1],
        mode='markers+text',
        marker=dict(size=10, color='red'),
        text=[f'C{i+1}' for i in range(K)],
        textposition="top center",
        name='分量中心'
    ),
    row=1, col=2
)

# Sample points, if the user has drawn any.
if st.session_state.sample_points is not None:
    samples = st.session_state.sample_points
    # Mixture density at each sample.
    probs = dataset.pdf(samples)
    # Posterior responsibility of each component. The component densities
    # come from the dataset so that the absolute value and the per-component
    # normalization constant are both applied.
    weighted = np.column_stack(
        [dataset.weights[k] * dataset.component_pdf(samples, k) for k in range(K)]
    )
    totals = weighted.sum(axis=1, keepdims=True)
    # If every component underflows to 0 the posterior is undefined; report a
    # uniform posterior instead of NaN.
    posteriors = np.full_like(weighted, 1.0 / K)
    np.divide(weighted, totals, out=posteriors, where=totals > 0)

    fig.add_trace(
        go.Scatter(
            x=samples[:, 0], y=samples[:, 1],
            mode='markers+text',
            marker=dict(
                size=8,
                color='yellow',
                line=dict(color='black', width=1)
            ),
            text=[f'S{i+1}' for i in range(len(samples))],
            textposition="bottom center",
            name='采样点'
        ),
        row=1, col=2
    )

    # Per-sample density and posterior read-out.
    st.subheader("采样点信息")
    for i, (sample, prob, post) in enumerate(zip(samples, probs, posteriors)):
        st.write(f"样本点 S{i+1} ({sample[0]:.2f}, {sample[1]:.2f}):")
        st.write(f"- 概率密度: {prob:.4f}")
        st.write("- 后验概率:")
        for k in range(K):
            st.write(f" - 分量 {k+1}: {post[k]:.4f}")
        st.write("---")

# Figure-level layout; `scene` configures the axes of the 3D subplot.
fig.update_layout(
    title='广义高斯混合分布',
    showlegend=True,
    width=1200,
    height=600,
    scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='密度'
    )
)

# Axes of the 2D subplot. The range is pinned to the evaluation grid so that
# heavy-tailed samples (small p) cannot stretch the view away from the contours.
fig.update_xaxes(title_text='X', range=[-5, 5], row=1, col=2)
fig.update_yaxes(title_text='Y', range=[-5, 5], row=1, col=2)

st.plotly_chart(fig, use_container_width=True)

# Parameter explanation.
with st.expander("分布参数说明"):
    st.markdown("""
    - **形状参数 (p)**:控制分布的形状
      - p = 1: 拉普拉斯分布
      - p = 2: 高斯分布
      - p → ∞: 均匀分布
    - **分量参数**:每个分量由以下参数确定
      - 中心 (μ): 峰值位置,通过X和Y坐标确定
      - 尺度 (α): 分布的展宽程度,X和Y方向可不同
      - 权重 (π): 混合系数,所有分量权重和为1
    """)

    # Formula for the parameters currently in effect.
    st.latex(generate_latex_formula(st.session_state.p, K, centers[:K], scales[:K], weights[:K]))
|
experiments/gmm_dataset.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from pathlib import Path
|
3 |
+
from scipy.special import gamma
|
4 |
+
from typing import Optional, Tuple, Dict, List, Union
|
5 |
+
import torch
|
6 |
+
import os
|
7 |
+
|
8 |
+
class GeneralizedGaussianMixture:
    r"""Mixture of axis-aligned generalized Gaussian components.

    Each coordinate of component ``k`` has density

        P_{\theta_k}(x_i) = \frac{p}{2\alpha_k \Gamma(1/p)}\exp(-|\frac{x_i-c_k}{\alpha_k}|^p)

    and the D coordinates are independent, so the component density is the
    product over coordinates. ``p = 2`` gives a Gaussian shape and ``p = 1``
    a Laplace shape.
    """

    def __init__(self,
                 D: int = 2,
                 K: int = 3,
                 p: float = 2.0,
                 centers: Optional[np.ndarray] = None,
                 scales: Optional[np.ndarray] = None,
                 weights: Optional[np.ndarray] = None,
                 seed: int = 42):
        """Create the mixture.

        Args:
            D: Data dimension.
            K: Number of components.
            p: Shape exponent shared by all components.
            centers: Component centers, shape (K, D); random if omitted.
            scales: Component scales, shape (K, D); random if omitted.
            weights: Mixture weights, shape (K,); normalized to sum to 1.
                Random (Dirichlet) if omitted.
            seed: Seed of the instance's private random generator.
        """
        self.D = D
        self.K = K
        self.p = p
        self.seed = seed
        # Private generator: results are reproducible per instance without
        # reseeding NumPy's global RNG, which would silently make unrelated
        # np.random calls elsewhere in the process deterministic.
        self._rng = np.random.default_rng(seed)

        if centers is None:
            self.centers = self._rng.standard_normal((K, D)) * 2
        else:
            self.centers = np.asarray(centers, dtype=np.float64)

        if scales is None:
            self.scales = self._rng.uniform(0.1, 0.5, size=(K, D))
        else:
            self.scales = np.asarray(scales, dtype=np.float64)

        if weights is None:
            self.weights = self._rng.dirichlet(np.ones(K))
        else:
            given = np.asarray(weights, dtype=np.float64)
            self.weights = given / given.sum()  # make the weights sum to 1

    def component_pdf(self, x: np.ndarray, k: int) -> np.ndarray:
        """Evaluate the density of component ``k``.

        Args:
            x: Points, shape (N, D).
            k: Component index.

        Returns:
            Density values, shape (N,).
        """
        # One normalization constant per coordinate; their product normalizes
        # the D-dimensional product density.
        norm_const = self.p / (2 * self.scales[k] * gamma(1 / self.p))

        # |x_i - c_k| / alpha_k per coordinate, then summed after the power.
        z = np.abs(x - self.centers[k]) / self.scales[k]
        exp_term = np.exp(-np.sum(z ** self.p, axis=1))

        return np.prod(norm_const) * exp_term

    def pdf(self, x: np.ndarray) -> np.ndarray:
        """Evaluate the mixture density.

        Args:
            x: Points, shape (N, D).

        Returns:
            Density values, shape (N,).
        """
        density = np.zeros(len(x))
        for k in range(self.K):
            density += self.weights[k] * self.component_pdf(x, k)
        return density

    def generate_component_samples(self, n: int, k: int) -> np.ndarray:
        """Draw samples from component ``k``.

        Args:
            n: Number of samples.
            k: Component index.

        Returns:
            Samples, shape (n, D).
        """
        # Exact sampler: with G ~ Gamma(1/p, 1), the variable
        # sign * alpha * G**(1/p) has density proportional to
        # exp(-|x / alpha|**p). (Raising a uniform variate to 1/p, as done
        # previously, is bounded to [c - alpha, c + alpha] and does not
        # follow this density.)
        magnitude = self._rng.gamma(1.0 / self.p, 1.0, size=(n, self.D)) ** (1.0 / self.p)
        sign = self._rng.choice([-1.0, 1.0], size=(n, self.D))
        return self.centers[k] + self.scales[k] * sign * magnitude

    def generate_samples(self, N: int) -> Tuple[np.ndarray, np.ndarray]:
        """Draw samples from the mixture.

        Args:
            N: Total number of samples.

        Returns:
            X: Samples in random order, shape (N, D).
            y: Mixture density at each sample, shape (N,).
        """
        # Split N across components according to the mixture weights.
        counts = self._rng.multinomial(N, self.weights)

        per_component = [
            self.generate_component_samples(counts[k], k) for k in range(self.K)
        ]

        # Concatenate, then shuffle so the order carries no component label.
        X = np.vstack(per_component)
        X = X[self._rng.permutation(N)]

        y = self.pdf(X)

        return X, y

    def save_dataset(self, save_dir: Union[str, Path], name: str = 'gmm_dataset',
                     N: int = 1000) -> None:
        """Generate samples and store them with the mixture parameters.

        Args:
            save_dir: Target directory; created if missing.
            name: File stem; the file is written as ``<name>.npz``.
            N: Number of samples to generate and store.
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        X, y = self.generate_samples(N=N)
        np.savez(str(save_path / f'{name}.npz'),
                 X=X, y=y,
                 centers=self.centers,
                 scales=self.scales,
                 weights=self.weights,
                 D=self.D,
                 K=self.K,
                 p=self.p)

    @classmethod
    def load_dataset(cls, file_path: Union[str, Path]) -> "GeneralizedGaussianMixture":
        """Rebuild a mixture from a file written by :meth:`save_dataset`.

        Only the parameters (D, K, p, centers, scales, weights) are restored;
        the stored samples are not read and the seed is the default.

        Args:
            file_path: Path of the ``.npz`` file.

        Returns:
            A new mixture with the stored parameters.
        """
        # The context manager closes the underlying file handle; the arrays
        # are read while the archive is still open.
        with np.load(str(file_path)) as data:
            return cls(
                D=int(data['D']),
                K=int(data['K']),
                p=float(data['p']),
                centers=data['centers'],
                scales=data['scales'],
                weights=data['weights']
            )
|
160 |
+
|
161 |
+
def test_gmm_dataset():
    """Smoke-test sampling and the save/load round trip of the mixture.

    The dataset file is written to a temporary directory that is removed
    afterwards, so running the test leaves nothing behind in the working
    directory.
    """
    import tempfile  # local import: only this self-test needs it

    # 2-D mixture with three components.
    gmm = GeneralizedGaussianMixture(
        D=2,
        K=3,
        p=2.0,
        centers=np.array([[-2, -2], [0, 0], [2, 2]]),
        scales=np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]]),
        weights=np.array([0.3, 0.4, 0.3])
    )

    # Sampling returns one density value per generated point.
    X, y = gmm.generate_samples(1000)
    assert X.shape == (1000, 2)
    assert y.shape == (1000,)

    # Save, then load from the same file.
    with tempfile.TemporaryDirectory() as tmp_dir:
        gmm.save_dataset(tmp_dir)
        loaded_gmm = GeneralizedGaussianMixture.load_dataset(Path(tmp_dir) / 'gmm_dataset.npz')

    # The restored parameters must match the originals.
    assert np.allclose(gmm.centers, loaded_gmm.centers)
    assert np.allclose(gmm.scales, loaded_gmm.scales)
    assert np.allclose(gmm.weights, loaded_gmm.weights)
    assert (gmm.D, gmm.K, gmm.p) == (loaded_gmm.D, loaded_gmm.K, loaded_gmm.p)

    print("GMM数据集测试通过!")


if __name__ == '__main__':
    test_gmm_dataset()
|
experiments/gmm_fitting.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from sklearn.neural_network import MLPRegressor
|
4 |
+
from pathlib import Path
|
5 |
+
import sys
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import shutil
|
9 |
+
from typing import Any, Optional
|
10 |
+
|
11 |
+
# 添加pykan到Python路径
|
12 |
+
repo_root = Path(__file__).parent.parent.parent
|
13 |
+
sys.path.append(str(repo_root / 'pykan'))
|
14 |
+
|
15 |
+
from kan import *
|
16 |
+
# 针对gmm_dataset的导入,尝试不同的导入路径
|
17 |
+
try:
|
18 |
+
from .gmm_dataset import GeneralizedGaussianMixture
|
19 |
+
except ImportError:
|
20 |
+
from gmm_dataset import GeneralizedGaussianMixture
|
21 |
+
|
22 |
+
def _to_jsonable(value: Any) -> Any:
    """Recursively convert NumPy / PyTorch values into JSON-serializable builtins."""
    if isinstance(value, dict):
        return {str(key): _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().tolist()
    if isinstance(value, (np.ndarray, np.generic)):
        return value.tolist()
    return value


def train_and_evaluate(dataset: GeneralizedGaussianMixture,
                       save_dir: Path,
                       kan_config: Optional[dict[str, Any]] = None,
                       random_state: int = 42) -> dict[str, Any]:
    """Fit a KAN and an MLP regressor to the mixture density and compare them.

    Samples are drawn from ``dataset``, both models are trained to regress the
    density value from the point coordinates, and the following files are
    written to ``save_dir``: ``data_<seed>.npz``, ``predictions_<seed>.npz``
    and ``evaluation_<seed>.json``.

    Args:
        dataset: Mixture that supplies samples and the ground-truth density.
        save_dir: Output directory; created if missing.
        kan_config: Keyword arguments for ``KAN``; defaults to
            ``width=[dataset.D, 5, 1], grid=5, k=3``.
        random_state: Seed passed to both models and used in the file names.

    Returns:
        Dict with ``random_state``, ``kan_test_rmse``, ``mlp_test_rmse`` and
        ``training_history`` (the value returned by ``KAN.fit`` converted to
        plain Python types).
    """
    save_dir.mkdir(parents=True, exist_ok=True)

    # Training and test samples with their density values.
    X_train, y_train = dataset.generate_samples(N=1000)
    X_test, y_test = dataset.generate_samples(N=200)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.set_default_dtype(torch.float64)  # run everything in double precision

    def to_tensor(array: np.ndarray) -> torch.Tensor:
        # torch.FloatTensor would always produce float32 regardless of the
        # default dtype set above; build float64 tensors explicitly so the
        # inputs match the model's parameters.
        return torch.as_tensor(array, dtype=torch.float64, device=device)

    train_data = {
        'train_input': to_tensor(X_train),
        'train_label': to_tensor(y_train).reshape(-1, 1),
        'test_input': to_tensor(X_test),
        'test_label': to_tensor(y_test).reshape(-1, 1)
    }

    # Keep the exact data used for this run.
    np.savez(save_dir / f'data_{random_state}.npz',
             X_train=X_train, y_train=y_train,
             X_test=X_test, y_test=y_test)

    # KAN
    if kan_config is None:
        kan_config = {
            'width': [dataset.D, 5, 1],
            'grid': 5,
            'k': 3
        }

    # The device is passed to KAN as a string.
    kan_model = KAN(**kan_config, seed=random_state, device=str(device))
    kan_model = kan_model.to(device)
    results = kan_model.fit(train_data, opt="LBFGS", steps=50, lamb=0.001)

    # MLP baseline
    mlp = MLPRegressor(
        hidden_layer_sizes=(10, 5),
        max_iter=1000,
        random_state=random_state
    )
    mlp.fit(X_train, y_train)

    # Square evaluation grid spanning the range of the training data
    # (two coordinates per point, i.e. this part assumes a 2-D dataset).
    grid_x = np.linspace(X_train.min(), X_train.max(), 100)
    grid_y = np.linspace(X_train.min(), X_train.max(), 100)
    XX, YY = np.meshgrid(grid_x, grid_y)
    grid_points = np.column_stack((XX.ravel(), YY.ravel()))

    # All KAN inference runs without autograd: .numpy() raises on tensors
    # that still require grad.
    with torch.no_grad():
        kan_pred = kan_model(to_tensor(grid_points)).cpu().numpy()
        kan_test_pred = kan_model(train_data['test_input']).cpu().numpy()
    mlp_pred = mlp.predict(grid_points)
    true_density = dataset.pdf(grid_points)

    # Test-set RMSE of both models.
    y_test_col = y_test.reshape(-1, 1)
    kan_test_rmse = np.sqrt(np.mean((kan_test_pred - y_test_col) ** 2))
    mlp_test_rmse = np.sqrt(np.mean((mlp.predict(X_test).reshape(-1, 1) - y_test_col) ** 2))

    evaluation = {
        'random_state': random_state,
        'kan_test_rmse': float(kan_test_rmse),
        'mlp_test_rmse': float(mlp_test_rmse),
        # NOTE(review): KAN.fit presumably returns NumPy values, which
        # json.dump rejects; convert defensively. Confirm against pykan.
        'training_history': _to_jsonable(results)
    }

    # Predictions on the grid.
    np.savez(save_dir / f'predictions_{random_state}.npz',
             grid_points=grid_points,
             kan_pred=kan_pred,
             mlp_pred=mlp_pred,
             true_density=true_density)

    # Evaluation summary.
    with open(save_dir / f'evaluation_{random_state}.json', 'w', encoding='utf-8') as f:
        json.dump(evaluation, f)

    return evaluation
|
104 |
+
|
105 |
+
def run_experiments(save_dir: Path, n_experiments: int = 5,
                    base_seed: int = 42) -> dict[str, float]:
    """Repeat the KAN-vs-MLP comparison with consecutive seeds.

    Experiment ``i`` uses seed ``base_seed + i`` and writes its files to
    ``save_dir/<seed>/``. The per-run results are collected in
    ``save_dir/all_results.json`` and their summary in
    ``save_dir/statistics.json``.

    Args:
        save_dir: Root output directory; created if missing.
        n_experiments: Number of runs; must be at least 1.
        base_seed: Seed of the first run.

    Returns:
        Mean and standard deviation of the test RMSE of both models, under
        the keys ``kan_mean_rmse``, ``kan_std_rmse``, ``mlp_mean_rmse`` and
        ``mlp_std_rmse``.

    Raises:
        ValueError: If ``n_experiments`` is smaller than 1.
    """
    if n_experiments < 1:
        # With no runs the mean/std below would silently become NaN.
        raise ValueError("n_experiments must be at least 1")

    save_dir.mkdir(parents=True, exist_ok=True)

    all_results = []

    for i in range(n_experiments):
        print(f"Running experiment {i+1}/{n_experiments}")
        random_state = base_seed + i

        # Same mixture every run; only the seed changes.
        dataset = GeneralizedGaussianMixture(
            D=2,
            K=3,
            p=2.0,
            centers=np.array([[-2, -2], [0, 0], [2, 2]]),
            scales=np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]]),
            weights=np.array([0.3, 0.4, 0.3]),
            seed=random_state
        )

        result = train_and_evaluate(dataset, save_dir / str(random_state), random_state=random_state)
        all_results.append(result)

    # Per-run results.
    with open(save_dir / 'all_results.json', 'w', encoding='utf-8') as f:
        json.dump(all_results, f)

    # Summary statistics across runs.
    kan_rmses = [r['kan_test_rmse'] for r in all_results]
    mlp_rmses = [r['mlp_test_rmse'] for r in all_results]

    statistics = {
        'kan_mean_rmse': float(np.mean(kan_rmses)),
        'kan_std_rmse': float(np.std(kan_rmses)),
        'mlp_mean_rmse': float(np.mean(mlp_rmses)),
        'mlp_std_rmse': float(np.std(mlp_rmses)),
    }

    with open(save_dir / 'statistics.json', 'w', encoding='utf-8') as f:
        json.dump(statistics, f)

    return statistics


if __name__ == '__main__':
    # Results go to experiments/results, next to this file.
    results_dir = Path(__file__).parent / 'results'
    stats = run_experiments(results_dir)
    print("\nExperiment Statistics:")
    print(f"KAN Test RMSE: {stats['kan_mean_rmse']:.4f} ± {stats['kan_std_rmse']:.4f}")
    print(f"MLP Test RMSE: {stats['mlp_mean_rmse']:.4f} ± {stats['mlp_std_rmse']:.4f}")
|
experiments/test.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import plotly.graph_objects as go
|
2 |
+
|
3 |
+
# 示例数据(请替换为实际数据)
|
4 |
+
# Placeholder data (replace with real measurements). The three lists are
# parallel: one recovery rate and one error-bar half-width per method.
methods = [
    'RSRM', 'PSRN', 'NGGP', 'PySR', 'BMS', 'uDSR', 'AIF',
    'DGSR', 'E2E', 'SymINDy', 'PhySO', 'TPSR', 'SPL',
    'DEAP', 'SINDy', 'NSRS', 'gplearn', 'SNIP', 'KAN', 'EQL',
]
recovery_rates = [
    85, 78, 92, 88, 76, 83, 95, 81, 89, 77, 84, 86, 80,
    79, 82, 87, 75, 88, 90, 84,
]  # recovery rate in percent
errors = [3, 4, 2, 3, 5, 2, 1, 3, 2, 4, 3, 2, 3, 4, 2, 3, 5, 2, 3, 2]  # error range

# One marker color per method.
marker_palette = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
    '#aec7e8', '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5',
    '#c49c94', '#f7b6d2', '#c7c7c7', '#dbdb8d', '#9edae5',
]

# Horizontal error bars and marker appearance for the single scatter trace.
error_bar_style = dict(
    type='data',
    array=errors,
    visible=True,
    color='#FF5733',
    thickness=2,
    width=10,
)
marker_style = dict(size=12, color=marker_palette, opacity=0.8)

fig = go.Figure()

# Recovery rate on the x axis, one row per method on the y axis.
fig.add_trace(
    go.Scatter(
        x=recovery_rates,
        y=methods,
        mode='markers',
        error_x=error_bar_style,
        marker=marker_style,
    )
)

x_axis_layout = dict(
    title='恢复率 (%)',
    range=[0, 100],
    dtick=20,
    title_standoff=25,
)
y_axis_layout = dict(
    title='Methods',
    title_font=dict(size=14),
    tickfont=dict(size=12),
    autorange="reversed",  # first method is drawn at the top
)

fig.update_layout(
    title='不同方法的恢复率比较',
    xaxis=x_axis_layout,
    yaxis=y_axis_layout,
    hovermode='closest',
    width=1000,
    height=600,
    showlegend=False,
)

# Attribution text, positioned in paper coordinates.
fig.add_annotation(
    x=0,
    y=0.95,
    xref='paper',
    yref='paper',
    text='知乎 @x66ccff',
    showarrow=False,
    font=dict(size=10),
)

fig.show()
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit>=1.32.0
|
2 |
+
numpy>=1.21.0
|
3 |
+
pandas>=1.3.0
|
4 |
+
plotly>=5.18.0
|
5 |
+
scipy>=1.7.0
|
6 |
+
torch>=1.9.0
|
7 |
+
scikit-learn>=1.0.0
|
8 |
+
ipython>=8.0.0
|
9 |
+
ipywidgets>=7.0.0
|
10 |
+
nbformat>=5.0.0
|
11 |
+
sympy>=1.8
|