import time
import streamlit as st
import numpy as np
from pathlib import Path
from experiments.gmm_dataset import GeneralizedGaussianMixture
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from typing import List, Tuple
import torch
import os
import sys
import matplotlib.pyplot as plt
# Workaround for the Streamlit file watcher crashing when it inspects torch.classes
torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__ or "")]
# Add pykan to path
pykan_path = Path(__file__).parent.parent / 'third_party' / 'pykan'
sys.path.append(str(pykan_path))
# Import KAN related modules
from kan import KAN # type: ignore
from kan.utils import create_dataset, ex_round # type: ignore
# Set torch dtype
torch.set_default_dtype(torch.float64)
def show_kan_prediction(model, device, samples, placeholder, phase_name):
"""显示KAN的预测结果"""
# 生成网格数据
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
xy = np.column_stack((X.ravel(), Y.ravel()))
# Predict with the KAN
grid_points = torch.from_numpy(xy).to(device)
with torch.no_grad():
Z_kan = model(grid_points).cpu().numpy().reshape(X.shape)
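# Note: the KAN was trained to regress pdf values, so its raw output only
# approximates a density (it is not constrained to be non-negative or to integrate to 1).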
# Create the predicted probability density figure
fig_kan = make_subplots(
rows=1, cols=2,
specs=[[{'type': 'surface'}, {'type': 'contour'}]],
subplot_titles=('KAN-Predicted 3D Probability Density Surface', 'KAN-Predicted Contour Plot')
)
# 3D Surface
surface_kan = go.Surface(
x=X, y=Y, z=Z_kan,
colorscale='viridis',
showscale=True,
colorbar=dict(x=0.45)
)
fig_kan.add_trace(surface_kan, row=1, col=1)
# Contour Plot
contour_kan = go.Contour(
x=x, y=y, z=Z_kan,
colorscale='viridis',
showscale=True,
colorbar=dict(x=1.0),
contours=dict(
showlabels=True,
labelfont=dict(size=12)
)
)
fig_kan.add_trace(contour_kan, row=1, col=2)
# Add sample points
if samples is not None:
samples = samples.cpu().numpy() if torch.is_tensor(samples) else samples
fig_kan.add_trace(
go.Scatter(
x=samples[:, 0], y=samples[:, 1],
mode='markers',
marker=dict(
size=8,
color='yellow',
line=dict(color='black', width=1)
),
name='Training points'
),
row=1, col=2
)
# Update layout
fig_kan.update_layout(
title='KAN Predicted Distribution',
showlegend=True,
width=1200,
height=600,
scene=dict(
xaxis_title='X',
yaxis_title='Y',
zaxis_title='Density'
)
)
# Update the axes of the 2D plot
fig_kan.update_xaxes(title_text='X', row=1, col=2)
fig_kan.update_yaxes(title_text='Y', row=1, col=2)
# Render the figure in the placeholder
placeholder.plotly_chart(fig_kan,
use_container_width=False,
key=f"kan_plot_{phase_name}_{time.time()}")
def create_gmm_plot(dataset, centers, K, samples=None):
"""创建GMM分布的可视化图形"""
# 生成网格数据
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
xy = np.column_stack((X.ravel(), Y.ravel()))
# Compute the probability density
Z = dataset.pdf(xy).reshape(X.shape)
# Create 2D and 3D visualizations
fig = make_subplots(
rows=1, cols=2,
specs=[[{'type': 'surface'}, {'type': 'contour'}]],
subplot_titles=('3D Probability Density Surface', 'Contour Plot with Component Centers')
)
# 3D Surface
surface = go.Surface(
x=X, y=Y, z=Z,
colorscale='viridis',
showscale=True,
colorbar=dict(x=0.45)
)
fig.add_trace(surface, row=1, col=1)
# Contour Plot
contour = go.Contour(
x=x, y=y, z=Z,
colorscale='viridis',
showscale=True,
colorbar=dict(x=1.0),
contours=dict(
showlabels=True,
labelfont=dict(size=12)
)
)
fig.add_trace(contour, row=1, col=2)
# Add component centers
fig.add_trace(
go.Scatter(
x=centers[:K, 0], y=centers[:K, 1],
mode='markers+text',
marker=dict(size=10, color='red'),
text=[f'C{i+1}' for i in range(K)],
textposition="top center",
name='Component centers'
),
row=1, col=2
)
# Add sample points (if any)
if samples is not None:
fig.add_trace(
go.Scatter(
x=samples[:, 0], y=samples[:, 1],
mode='markers+text',
marker=dict(
size=8,
color='yellow',
line=dict(color='black', width=1)
),
text=[f'S{i+1}' for i in range(len(samples))],
textposition="bottom center",
name='Sample points'
),
row=1, col=2
)
# Update layout
fig.update_layout(
title='Generalized Gaussian Mixture Distribution',
showlegend=True,
width=1200,
height=600,
scene=dict(
xaxis_title='X',
yaxis_title='Y',
zaxis_title='Density'
)
)
# Update the axes of the 2D plot
fig.update_xaxes(title_text='X', row=1, col=2)
fig.update_yaxes(title_text='Y', row=1, col=2)
return fig
def train_kan(samples, gmm_dataset, device='cuda'):
"""训练KAN网络"""
if torch.cuda.is_available() and device == 'cuda':
device = torch.device('cuda')
else:
device = torch.device('cpu')
st.info(f"使用设备: {device} 训练网络")
# 转换采样点为tensor
samples = torch.from_numpy(samples).to(device)
# 计算标签(概率密度值)
labels = torch.from_numpy(gmm_dataset.pdf(samples.cpu().numpy())).reshape(-1, 1).to(device)
# 创建KAN模型
model = KAN(width=[2,5,1], grid=3, k=3, seed=42, device=device)
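# width=[2, 5, 1]: 2 inputs -> 5 hidden nodes -> 1 output; in pykan, `grid` is the
# number of spline grid intervals and `k` the spline order (k=3 gives cubic splines).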
# Create training and test datasets
train_size = int(0.8 * samples.shape[0])
train_dataset = {
'train_input': samples[:train_size],
'train_label': labels[:train_size],
'test_input': samples[train_size:],
'test_label': labels[train_size:]
}
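# The dict above uses pykan's expected dataset format (the same keys that
# kan.utils.create_dataset produces): train_input, train_label, test_input, test_label.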
# Create the training progress display widgets
# st.write("Network predicted distribution:")
st.write("Network architecture:")
kan_network_arch_placeholder = st.empty()
progress_container = st.container()
# total_steps = 100
total_steps = 50
steps_per_update = 10
def calculate_error(model, x, y):
"""计算预测误差"""
with torch.no_grad():
pred = model(x)
return torch.mean((pred - y) ** 2).item()
def train_phase(phase_name, steps, lamb=None, show_plot=True):
with progress_container:
progress_bar = st.progress(0)
status_text = st.empty()
for step in range(0, steps, steps_per_update):
# Train for a few steps
if lamb is not None:
model.fit(train_dataset, opt="LBFGS", steps=steps_per_update, lamb=lamb)
else:
model.fit(train_dataset, opt="LBFGS", steps=steps_per_update)
# Update progress and error
progress = (step + steps_per_update) / steps
progress_bar.progress(progress)
# Compute the current errors
train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
test_error = calculate_error(model, train_dataset['test_input'], train_dataset['test_label'])
# Show the progress and errors as a table
status_text.markdown(f"""
### {phase_name}
| Item | Value |
|:---|:---|
| Progress | {progress:.0%} |
| Training error | {train_error:.8f} |
| Test error | {test_error:.8f} |
""")
# Update the visualization (every 5 update blocks)
# if step % (steps_per_update * 5) == 0 or step + steps_per_update >= steps:
# # Update the prediction plot
# show_kan_prediction(model, device, samples, kan_plot_placeholder, phase_name)
# Update the network architecture plot (optional)
if show_plot:
try:
model.plot()
kan_fig = plt.gcf()
# if isinstance(kan_fig, tuple):
# kan_fig = kan_fig[0]  # if it is a tuple, take the first element
# if kan_fig is not None:
kan_network_arch_placeholder.pyplot(kan_fig, use_container_width=False)
# plt.close('all')  # make sure all figures are closed
except Exception as e:
if step == 0:  # only show the warning on the first failure
st.warning(f"Note: failed to render the network architecture plot ({str(e)})")
# Update the progress and the prediction plot
show_kan_prediction(model, device, samples, kan_distribution_plot_placeholder, phase_name)
with progress_container:
st.markdown("#### 训练过程")
error_text = st.empty()
# Phase 1: initial training with regularization
with st.spinner("Adjusting parameters..."):
train_phase("Phase 1: Regularized Training", total_steps, lamb=0.001, show_plot=True)
# Pruning phase
with st.spinner("Pruning and optimizing the network..."):
model = model.prune()
progress_container.info("Network pruning complete")
with st.spinner("参数调整中..."):
train_phase("第二阶段: 剪枝适应性训练", total_steps, show_plot=True)
with st.spinner("正在进行网格精细化..."):
model = model.refine(10)
progress_container.info("网格精细化完成")
with st.spinner("参数调整中..."):
train_phase("第三阶段: 网格适应性训练", total_steps, show_plot=True)
with st.spinner("符号简化中..."):
# model = model.prune()
# progress_container.info("网络剪枝完成")
lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
model.auto_symbolic(lib=lib)
# model.auto_symbolic()
progress_container.info("符号简化完成")
with st.spinner("参数调整中..."):
train_phase("第四阶段:符号适应性训练", total_steps, show_plot=True)
from kan.utils import ex_round
from sympy import latex
s = ex_round(model.symbolic_formula()[0][0], 4)
st.write("Network formula:")
st.latex(latex(s))
# Show the final errors
train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
test_error = calculate_error(model, train_dataset['test_input'], train_dataset['test_label'])
error_text.markdown(f"""
#### Training Results
- Training error: {train_error:.6f}
- Test error: {test_error:.6f}
""")
progress_container.success("🎉 Training complete!")
return model
def init_session_state():
"""初始化session state"""
if 'prev_K' not in st.session_state:
st.session_state.prev_K = 3
if 'p' not in st.session_state:
st.session_state.p = 2.0
if 'centers' not in st.session_state:
st.session_state.centers = np.array([[-2, -2], [0, 0], [2, 2]], dtype=np.float64)
if 'scales' not in st.session_state:
st.session_state.scales = np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]], dtype=np.float64)
if 'weights' not in st.session_state:
st.session_state.weights = np.ones(3, dtype=np.float64) / 3
if 'sample_points' not in st.session_state:
st.session_state.sample_points = None
if 'kan_model' not in st.session_state:
st.session_state.kan_model = None
def create_default_parameters(K: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""创建默认参数"""
# 在[-3, 3]范围内均匀生成K个中心点
x = np.linspace(-3, 3, K)
y = np.linspace(-3, 3, K)
centers = np.column_stack((x, y))
# Default scales and weights
scales = np.ones((K, 2), dtype=np.float64) * 3
weights = np.random.random(size=K).astype(np.float64)
weights /= weights.sum()  # normalize the weights
return centers, scales, weights
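# For example, create_default_parameters(3) places the centers on the diagonal at
# (-3, -3), (0, 0) and (3, 3), uses a scale of 3 in both directions, and draws
# random weights that are normalized to sum to 1.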
def generate_latex_formula(p: float, K: int, centers: np.ndarray,
scales: np.ndarray, weights: np.ndarray) -> str:
"""生成LaTeX公式"""
formula = r"P(x) = \sum_{k=1}^{" + str(K) + r"} \pi_k P_{\theta_k}(x) \\"
formula += r"P_{\theta_k}(x) = \eta_k \exp(-s_k d_k(x)) = \frac{p}{2\alpha_k \Gamma(1/p) }\exp(-\frac{|x-c_k|^p}{\alpha_k^p})= \frac{p}{2\alpha_k \Gamma(1/p) }\exp(-|\frac{x-c_k}{\alpha_k}|^p) \\"
formula += r"\text{where: }"
for k in range(K):
c = centers[k]
s = scales[k]
w = weights[k]
component = f"P_{{\\theta_{k+1}}}(x) = \\frac{{{p:.1f}}}{{2\\alpha_{k+1} \\Gamma(1/{p:.1f})}}\\exp(-|\\frac{{x-({c[0]:.1f}, {c[1]:.1f})}}{{{s[0]:.1f}, {s[1]:.1f}}}|^{{{p:.1f}}}) \\\\"
formula += component
formula += f"\\pi_{k+1} = {w:.2f} \\\\"
return formula
st.set_page_config(page_title="GMM Distribution Visualization", layout="wide")
st.title("广义高斯混合分布可视化")
# 初始化session state
init_session_state()
# Sidebar parameter settings
with st.sidebar:
st.header("Distribution Parameters")
# Basic distribution parameters
st.session_state.p = st.slider("Shape parameter (p)", 0.1, 5.0, st.session_state.p, 0.1,
help="p=1: Laplace distribution, p=2: Gaussian distribution, p→∞: uniform distribution")
K = st.slider("Number of components (K)", 1, 5, st.session_state.prev_K)
# Reinitialize the parameters when K changes
if K != st.session_state.prev_K:
centers, scales, weights = create_default_parameters(K)
st.session_state.centers = centers
st.session_state.scales = scales
st.session_state.weights = weights
st.session_state.prev_K = K
# Advanced parameter settings
st.subheader("Advanced Settings")
show_advanced = st.checkbox("Show component parameters", value=False)
if show_advanced:
# Set parameters for each component
centers_list: List[List[float]] = []
scales_list: List[List[float]] = []
weights_list: List[float] = []
for k in range(K):
st.write(f"分量 {k+1}")
col1, col2 = st.columns(2)
with col1:
cx = st.number_input(f"中心X_{k+1}", -5.0, 5.0, float(st.session_state.centers[k][0]), 0.1)
cy = st.number_input(f"中心Y_{k+1}", -5.0, 5.0, float(st.session_state.centers[k][1]), 0.1)
with col2:
sx = st.number_input(f"尺度X_{k+1}", 0.1, 3.0, float(st.session_state.scales[k][0]), 0.1)
sy = st.number_input(f"尺度Y_{k+1}", 0.1, 3.0, float(st.session_state.scales[k][1]), 0.1)
w = st.slider(f"权重_{k+1}", 0.0, 1.0, float(st.session_state.weights[k]), 0.1)
centers_list.append([cx, cy])
scales_list.append([sx, sy])
weights_list.append(w)
centers = np.array(centers_list, dtype=np.float64)
scales = np.array(scales_list, dtype=np.float64)
weights = np.array(weights_list, dtype=np.float64)
weights = weights / weights.sum()
st.session_state.centers = centers
st.session_state.scales = scales
st.session_state.weights = weights
else:
centers = st.session_state.centers
scales = st.session_state.scales
weights = st.session_state.weights
# Sampling settings
st.subheader("Sampling Settings")
n_samples = st.slider("Number of samples", 5, 1000, 100)
if st.button("Resample"):
# Create a GMM dataset for sampling
gmm = GeneralizedGaussianMixture(
D=2,
K=K,
p=st.session_state.p,
centers=centers[:K],
scales=scales[:K],
weights=weights[:K]
)
# Draw sample points from the GMM
samples, _ = gmm.generate_samples(n_samples)
st.session_state.sample_points = samples
st.session_state.kan_model = None  # reset the KAN model
# Create the GMM dataset
dataset = GeneralizedGaussianMixture(
D=2,
K=K,
p=st.session_state.p,
centers=centers[:K],
scales=scales[:K],
weights=weights[:K]
)
# Generate grid data
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
xy = np.column_stack((X.ravel(), Y.ravel()))
# Compute the probability density
Z = dataset.pdf(xy).reshape(X.shape)
# Create 2D and 3D visualizations
fig = make_subplots(
rows=1, cols=2,
specs=[[{'type': 'surface'}, {'type': 'contour'}]],
subplot_titles=('3D Probability Density Surface', 'Contour Plot with Component Centers')
)
# 3D Surface
surface = go.Surface(
x=X, y=Y, z=Z,
colorscale='viridis',
showscale=True,
colorbar=dict(x=0.45)
)
fig.add_trace(surface, row=1, col=1)
# Contour Plot with component centers
contour = go.Contour(
x=x, y=y, z=Z,
colorscale='viridis',
showscale=True,
colorbar=dict(x=1.0),
contours=dict(
showlabels=True,
labelfont=dict(size=12)
)
)
fig.add_trace(contour, row=1, col=2)
# Add component centers
fig.add_trace(
go.Scatter(
x=centers[:K, 0], y=centers[:K, 1],
mode='markers+text',
marker=dict(size=10, color='red'),
text=[f'C{i+1}' for i in range(K)],
textposition="top center",
name='Component centers'
),
row=1, col=2
)
# Add sample points (if any)
if st.session_state.sample_points is not None:
samples = st.session_state.sample_points
# Compute the probability density of each sample point
probs = dataset.pdf(samples)
# Compute each sample's posterior probability of belonging to each component
posteriors = []
for sample in samples:
component_probs = [
weights[k] * np.exp(-np.sum((np.abs(sample - centers[k]) / scales[k])**st.session_state.p))
for k in range(K)
]
total = sum(component_probs)
posteriors.append([p/total for p in component_probs])
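# Note: the per-component normalization constants are omitted above, so when the
# scales differ between components these posteriors are only approximate.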
# Add the sample points to the figure
fig.add_trace(
go.Scatter(
x=samples[:, 0], y=samples[:, 1],
mode='markers+text',
marker=dict(
size=8,
color='yellow',
line=dict(color='black', width=1)
),
text=[f'S{i+1}' for i in range(len(samples))],
textposition="bottom center",
name='Sample points'
),
row=1, col=2
)
# Update layout
fig.update_layout(
title='Generalized Gaussian Mixture Distribution',
showlegend=True,
width=1200,
height=600,
scene=dict(
xaxis_title='X',
yaxis_title='Y',
zaxis_title='Density'
)
)
# Update the axes of the 2D plot
fig.update_xaxes(title_text='X', row=1, col=2)
fig.update_yaxes(title_text='Y', row=1, col=2)
# Show the main GMM figure
st.plotly_chart(fig, use_container_width=False)
# KAN training and prediction section
if st.session_state.sample_points is not None:
st.markdown("---")
st.subheader("KAN Training and Prediction")
kan_distribution_plot_placeholder = st.empty()
# Training control buttons
col1, col2, col3 = st.columns([1, 2, 1])
with col1:
if st.button("拟合KAN", use_container_width=False):
with st.spinner('训练KAN网络中...'):
st.session_state.kan_model = train_kan(st.session_state.sample_points, dataset)
st.balloons()
with col3:
if st.session_state.kan_model is not None:
if st.button("清除KAN结果", use_container_width=False):
st.session_state.kan_model = None
st.rerun()
# Show the KAN prediction results
# if st.session_state.kan_model is not None:
# st.subheader("KAN Prediction Results")
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# kan_plot_placeholder = st.empty()
# show_kan_prediction(st.session_state.kan_model, device,
# st.session_state.sample_points, kan_plot_placeholder, "display results")
st.markdown("---")
# Show sample point information
if st.session_state.sample_points is not None:
# Recompute the probability densities and posterior probabilities of the samples
samples = st.session_state.sample_points
probs = dataset.pdf(samples)
posteriors = []
for sample in samples:
component_probs = [
weights[k] * np.exp(-np.sum((np.abs(sample - centers[k]) / scales[k])**st.session_state.p))
for k in range(K)
]
total = sum(component_probs)
posteriors.append([p/total for p in component_probs])
with st.expander("采样点信息"):
# 创建数据列表
point_data = []
for i, (sample, prob, post) in enumerate(zip(samples, probs, posteriors)):
row = {
'Sample': f'S{i+1}',
'X': f'{sample[0]:.2f}',
'Y': f'{sample[1]:.2f}',
'Density': f'{prob:.4f}'
}
# Add the posterior probability for each component
for k in range(K):
row[f'Component {k+1} posterior'] = f'{post[k]:.4f}'
point_data.append(row)
# Show as a dataframe
st.dataframe(point_data)
# Parameter explanation
with st.expander("Distribution Parameter Notes"):
st.markdown("""
- **形状参数 (p)**:控制分布的形状
- p = 1: 拉普拉斯分布
- p = 2: 高斯分布
- p → ∞: 均匀分布
- **分量参数**:每个分量由以下参数确定
- 中心 (μ): 峰值位置,通过X和Y坐标确定
- 尺度 (α): 分布的展宽程度,X和Y方向可不同
- 权重 (π): 混合系数,所有分量权重和为1
""")
# Show the mathematical formula for the current parameters
with st.expander("Probability Density Function"):
st.latex(generate_latex_formula(st.session_state.p, K, centers[:K], scales[:K], weights[:K]))