"""Streamlit app: visualise a generalized Gaussian mixture and fit a KAN to it.

The page lets the user configure a 2-D generalized Gaussian mixture (GMM),
draw samples from it, inspect per-sample densities and component posteriors,
and train a KAN network on the sampled densities.
"""
import time
import streamlit as st
import numpy as np
from pathlib import Path
from experiments.gmm_dataset import GeneralizedGaussianMixture
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from typing import List, Tuple
import torch
import os
import sys
import matplotlib.pyplot as plt

# Set torch path.
# NOTE(review): presumably a workaround for Streamlit's module watcher choking
# on ``torch.classes`` -- confirm before removing.
torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__ or "")]

# Add the vendored pykan checkout to the import path.
pykan_path = Path(__file__).parent.parent / 'third_party' / 'pykan'
sys.path.append(str(pykan_path))

# Import KAN related modules (must come after the sys.path tweak above).
from kan import KAN  # type: ignore
from kan.utils import create_dataset, ex_round  # type: ignore

# All tensors created below default to float64.
torch.set_default_dtype(torch.float64)

# Plotting grid shared by every density figure: GRID_POINTS x GRID_POINTS
# evaluation points over [GRID_MIN, GRID_MAX]^2.
GRID_MIN = -5
GRID_MAX = 5
GRID_POINTS = 100


def _make_grid():
    """Return ``(x, y, X, Y, xy)`` for the shared square plotting grid.

    ``x``/``y`` are the 1-D axes, ``X``/``Y`` the meshgrid arrays and ``xy``
    the flattened ``(GRID_POINTS**2, 2)`` array of grid coordinates.
    """
    x = np.linspace(GRID_MIN, GRID_MAX, GRID_POINTS)
    y = np.linspace(GRID_MIN, GRID_MAX, GRID_POINTS)
    X, Y = np.meshgrid(x, y)
    xy = np.column_stack((X.ravel(), Y.ravel()))
    return x, y, X, Y, xy


def _build_density_figure(x, y, X, Y, Z, subplot_titles, title):
    """Build the two-panel (3-D surface + contour) figure for a density ``Z``.

    Scatter overlays (centers, samples) are added by the callers afterwards;
    they all go into the contour panel at ``row=1, col=2``.
    """
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'surface'}, {'type': 'contour'}]],
        subplot_titles=subplot_titles
    )

    # 3D surface (left panel)
    fig.add_trace(
        go.Surface(
            x=X, y=Y, z=Z,
            colorscale='viridis',
            showscale=True,
            colorbar=dict(x=0.45)
        ),
        row=1, col=1
    )

    # Contour plot (right panel)
    fig.add_trace(
        go.Contour(
            x=x, y=y, z=Z,
            colorscale='viridis',
            showscale=True,
            colorbar=dict(x=1.0),
            contours=dict(
                showlabels=True,
                labelfont=dict(size=12)
            )
        ),
        row=1, col=2
    )

    fig.update_layout(
        title=title,
        showlegend=True,
        width=1200,
        height=600,
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='密度'
        )
    )

    # Axis titles of the 2-D contour panel
    fig.update_xaxes(title_text='X', row=1, col=2)
    fig.update_yaxes(title_text='Y', row=1, col=2)
    return fig


def _component_posteriors(samples, centers, scales, weights, p):
    """Return an ``(N, K)`` array of per-component posterior weights.

    For each sample the unnormalised responsibility of component ``k`` is
    ``weights[k] * exp(-sum_d |(x_d - c_kd) / s_kd| ** p)``; rows are then
    normalised to sum to one.

    The absolute value is essential: without it a negative offset raised to a
    fractional ``p`` yields NaN, and an odd ``p`` (e.g. the Laplace case
    ``p = 1``) flips the sign of the exponent.

    NOTE(review): as in the original code, the per-component normalising
    constant of the density is not included here; confirm against
    ``GeneralizedGaussianMixture.pdf`` whether it should be when the scales
    differ between components.

    Rows for which every responsibility is zero are returned as NaN.
    """
    samples = np.asarray(samples, dtype=np.float64)
    centers = np.asarray(centers, dtype=np.float64)
    scales = np.asarray(scales, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)

    # (N, K, D) scaled absolute offsets, raised to the shape parameter
    scaled = np.abs((samples[:, None, :] - centers[None, :, :]) / scales[None, :, :]) ** p
    exponent = -scaled.sum(axis=2)
    # Subtracting the row maximum leaves the normalised result unchanged but
    # prevents every component from underflowing to zero for far-away samples.
    exponent = exponent - exponent.max(axis=1, keepdims=True)
    unnormalised = weights[None, :] * np.exp(exponent)
    total = unnormalised.sum(axis=1, keepdims=True)
    return np.divide(
        unnormalised, total,
        out=np.full_like(unnormalised, np.nan),
        where=total > 0
    )


def show_kan_prediction(model, device, samples, placeholder, phase_name):
    """Render the KAN's predicted density into ``placeholder``.

    Args:
        model: Callable network; it is evaluated on the plotting grid.
        device: Torch device the grid tensor is moved to before evaluation.
        samples: Training points (tensor or array with X/Y in columns 0/1),
            or ``None`` to skip the scatter overlay.
        placeholder: Streamlit element whose ``plotly_chart`` is called.
        phase_name: Training-phase label; used in the chart key.
    """
    x, y, X, Y, xy = _make_grid()

    # Evaluate the network on the grid
    grid_points = torch.from_numpy(xy).to(device)
    with torch.no_grad():
        Z_kan = model(grid_points).cpu().numpy().reshape(X.shape)

    fig_kan = _build_density_figure(
        x, y, X, Y, Z_kan,
        subplot_titles=('KAN预测的3D概率密度曲面', 'KAN预测的等高线图'),
        title='KAN预测分布'
    )

    # Overlay the training points on the contour panel
    if samples is not None:
        samples = samples.cpu().numpy() if torch.is_tensor(samples) else samples
        fig_kan.add_trace(
            go.Scatter(
                x=samples[:, 0],
                y=samples[:, 1],
                mode='markers',
                marker=dict(
                    size=8,
                    color='yellow',
                    line=dict(color='black', width=1)
                ),
                name='训练点'
            ),
            row=1, col=2
        )

    # The timestamp keeps the element key unique across repeated renders.
    placeholder.plotly_chart(fig_kan, use_container_width=False, key=f"kan_plot_{phase_name}_{time.time()}")


def create_gmm_plot(dataset, centers, K, samples=None):
    """Create the surface + contour figure of the mixture density.

    Args:
        dataset: Object exposing ``pdf(points)`` for an ``(N, 2)`` array.
        centers: Component centers; the first ``K`` rows are plotted.
        K: Number of mixture components.
        samples: Optional ``(N, 2)`` array of sampled points to overlay.

    Returns:
        The assembled plotly figure.
    """
    x, y, X, Y, xy = _make_grid()

    # Evaluate the mixture density on the grid
    Z = dataset.pdf(xy).reshape(X.shape)

    fig = _build_density_figure(
        x, y, X, Y, Z,
        subplot_titles=('3D概率密度曲面', '等高线图与分量中心'),
        title='广义高斯混合分布'
    )

    # Component centers
    fig.add_trace(
        go.Scatter(
            x=centers[:K, 0], y=centers[:K, 1],
            mode='markers+text',
            marker=dict(size=10, color='red'),
            text=[f'C{i+1}' for i in range(K)],
            textposition="top center",
            name='分量中心'
        ),
        row=1, col=2
    )

    # Sampled points, if any
    if samples is not None:
        fig.add_trace(
            go.Scatter(
                x=samples[:, 0], y=samples[:, 1],
                mode='markers+text',
                marker=dict(
                    size=8,
                    color='yellow',
                    line=dict(color='black', width=1)
                ),
                text=[f'S{i+1}' for i in range(len(samples))],
                textposition="bottom center",
                name='采样点'
            ),
            row=1, col=2
        )

    return fig


def train_kan(samples, gmm_dataset, device='cuda', plot_placeholder=None):
    """Train a KAN on the mixture density at ``samples`` and return it.

    Training runs in four phases (regularised fit, fit after pruning, fit
    after grid refinement, fit after symbolic substitution), reporting
    progress, errors, the network diagram and the predicted density through
    Streamlit as it goes.

    Args:
        samples: ``(N, 2)`` numpy array of training inputs.
        gmm_dataset: Object exposing ``pdf(points)``; provides the labels.
        device: ``'cuda'`` to use the GPU when available, anything else
            for CPU.
        plot_placeholder: Streamlit element receiving the predicted-density
            chart. When omitted a new ``st.empty()`` is created at the
            current position instead of relying on a module-level global.

    Returns:
        The trained model.
    """
    if torch.cuda.is_available() and device == 'cuda':
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    st.info(f"使用设备: {device} 训练网络")

    if plot_placeholder is None:
        plot_placeholder = st.empty()

    # Inputs and labels (density values) as float64 tensors, matching the
    # default dtype configured at module level.
    samples_np = np.asarray(samples, dtype=np.float64)
    labels_np = np.asarray(gmm_dataset.pdf(samples_np), dtype=np.float64)
    samples = torch.from_numpy(samples_np).to(device)
    labels = torch.from_numpy(labels_np).reshape(-1, 1).to(device)

    model = KAN(width=[2,5,1], grid=3, k=3, seed=42, device=device)

    # 80/20 train/test split in sample order
    train_size = int(0.8 * samples.shape[0])
    train_dataset = {
        'train_input': samples[:train_size],
        'train_label': labels[:train_size],
        'test_input': samples[train_size:],
        'test_label': labels[train_size:]
    }

    # Progress display widgets
    st.write("网络图形结构:")
    kan_network_arch_placeholder = st.empty()
    progress_container = st.container()

    total_steps = 50
    steps_per_update = 10

    def calculate_error(model, x, y):
        """Return the mean squared error of ``model`` on ``(x, y)``."""
        with torch.no_grad():
            pred = model(x)
            return torch.mean((pred - y) ** 2).item()

    def train_phase(phase_name, steps, lamb=None, show_plot=True):
        """Run one training phase in chunks of ``steps_per_update`` steps."""
        with progress_container:
            progress_bar = st.progress(0)
            status_text = st.empty()

        for step in range(0, steps, steps_per_update):
            # Train a chunk of steps
            if lamb is not None:
                model.fit(train_dataset, opt="LBFGS", steps=steps_per_update, lamb=lamb)
            else:
                model.fit(train_dataset, opt="LBFGS", steps=steps_per_update)

            # Clamp: the last chunk overshoots when ``steps`` is not a
            # multiple of ``steps_per_update`` and st.progress rejects > 1.0.
            progress = min(1.0, (step + steps_per_update) / steps)
            progress_bar.progress(progress)

            train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
            test_error = calculate_error(model, train_dataset['test_input'], train_dataset['test_label'])

            # Progress and errors as a markdown table
            status_text.markdown(f"""
            ### {phase_name}
            | 项目 | 值 |
            |:---|:---|
            | 进度 | {progress:.0%} |
            | 训练误差 | {train_error:.8f} |
            | 测试误差 | {test_error:.8f} |
            """)

            # Network structure diagram (optional)
            if show_plot:
                kan_fig = None
                try:
                    model.plot()
                    kan_fig = plt.gcf()
                    kan_network_arch_placeholder.pyplot(kan_fig, use_container_width=False)
                except Exception as e:
                    # Best effort: warn only on the first failure of a phase
                    if step == 0:
                        st.warning(f"注意:网络结构图显示失败 ({str(e)})")
                finally:
                    # Each update creates a new figure; close it once it has
                    # been rendered so figures do not pile up in memory.
                    if kan_fig is not None:
                        plt.close(kan_fig)

            # Refresh the predicted-density chart
            show_kan_prediction(model, device, samples, plot_placeholder, phase_name)

    with progress_container:
        st.markdown("#### 训练过程")
        error_text = st.empty()

    # Phase 1: initial training with regularisation
    with st.spinner("参数调整中..."):
        train_phase("第一阶段: 正则化训练", total_steps, lamb=0.001, show_plot=True)

    # Pruning
    with st.spinner("正在进行网络剪枝优化..."):
        model = model.prune()
        progress_container.info("网络剪枝完成")

    with st.spinner("参数调整中..."):
        train_phase("第二阶段: 剪枝适应性训练", total_steps, show_plot=True)

    # Grid refinement
    with st.spinner("正在进行网格精细化..."):
        model = model.refine(10)
        progress_container.info("网格精细化完成")

    with st.spinner("参数调整中..."):
        train_phase("第三阶段: 网格适应性训练", total_steps, show_plot=True)

    # Symbolic substitution
    with st.spinner("符号简化中..."):
        lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
        model.auto_symbolic(lib=lib)
        progress_container.info("符号简化完成")

    with st.spinner("参数调整中..."):
        train_phase("第四阶段:符号适应性训练", total_steps, show_plot=True)

    # sympy is only needed here, so it is imported locally
    from sympy import latex
    s = ex_round(model.symbolic_formula()[0][0], 4)
    st.write("网络公式:")
    st.latex(latex(s))

    # Final errors
    train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
    test_error = calculate_error(model, train_dataset['test_input'], train_dataset['test_label'])
    error_text.markdown(f"""
    #### 训练结果
    - 训练集误差: {train_error:.6f}
    - 测试集误差: {test_error:.6f}
    """)

    progress_container.success("🎉 训练完成!")

    return model


def init_session_state():
    """Populate ``st.session_state`` with defaults for any missing key."""
    if 'prev_K' not in st.session_state:
        st.session_state.prev_K = 3
    if 'p' not in st.session_state:
        st.session_state.p = 2.0
    if 'centers' not in st.session_state:
        st.session_state.centers = np.array([[-2, -2], [0, 0], [2, 2]], dtype=np.float64)
    if 'scales' not in st.session_state:
        st.session_state.scales = np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]], dtype=np.float64)
    if 'weights' not in st.session_state:
        st.session_state.weights = np.ones(3, dtype=np.float64) / 3
    if 'sample_points' not in st.session_state:
        st.session_state.sample_points = None
    if 'kan_model' not in st.session_state:
        st.session_state.kan_model = None


def create_default_parameters(K: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Create default ``(centers, scales, weights)`` for ``K`` components.

    Centers are spread evenly along the diagonal of ``[-3, 3]^2``, every
    scale is 3 and the weights are random values normalised to sum to one.
    """
    x = np.linspace(-3, 3, K)
    y = np.linspace(-3, 3, K)
    centers = np.column_stack((x, y))

    scales = np.ones((K, 2), dtype=np.float64) * 3
    weights = np.random.random(size=K).astype(np.float64)
    weights /= weights.sum()  # normalise
    return centers, scales, weights


def generate_latex_formula(p: float, K: int, centers: np.ndarray,
                           scales: np.ndarray, weights: np.ndarray) -> str:
    """Return the LaTeX source describing the mixture and each component.

    The general formula comes first, followed by one line per component with
    its center, scales and weight substituted in (one decimal for ``p``,
    centers and scales; two decimals for weights).
    """
    formula = r"P(x) = \sum_{k=1}^{" + str(K) + r"} \pi_k P_{\theta_k}(x) \\"
    formula += r"P_{\theta_k}(x) = \eta_k \exp(-s_k d_k(x)) = \frac{p}{2\alpha_k \Gamma(1/p) }\exp(-\frac{|x-c_k|^p}{\alpha_k^p})= \frac{p}{2\alpha_k \Gamma(1/p) }\exp(-|\frac{x-c_k}{\alpha_k}|^p) \\"
    formula += r"\text{where: }"

    for k in range(K):
        c = centers[k]
        s = scales[k]
        w = weights[k]
        component = f"P_{{\\theta_{k+1}}}(x) = \\frac{{{p:.1f}}}{{2\\alpha_{k+1} \\Gamma(1/{p:.1f})}}\\exp(-|\\frac{{x-({c[0]:.1f}, {c[1]:.1f})}}{{{s[0]:.1f}, {s[1]:.1f}}}|^{{{p:.1f}}}) \\\\"
        formula += component
        formula += f"\\pi_{k+1} = {w:.2f} \\\\"

    return formula


# ---------------------------------------------------------------------------
# Page script
# ---------------------------------------------------------------------------

st.set_page_config(page_title="GMM Distribution Visualization", layout="wide")
st.title("广义高斯混合分布可视化")

init_session_state()

# Sidebar: distribution parameters and sampling controls
with st.sidebar:
    st.header("分布参数")

    # Basic distribution parameters
    st.session_state.p = st.slider("形状参数 (p)", 0.1, 5.0, st.session_state.p, 0.1,
                                   help="p=1: 拉普拉斯分布, p=2: 高斯分布, p→∞: 均匀分布")
    K = st.slider("分量数 (K)", 1, 5, st.session_state.prev_K)

    # Re-initialise the component parameters whenever K changes
    if K != st.session_state.prev_K:
        centers, scales, weights = create_default_parameters(K)
        st.session_state.centers = centers
        st.session_state.scales = scales
        st.session_state.weights = weights
        st.session_state.prev_K = K

    # Advanced per-component settings
    st.subheader("高级设置")
    show_advanced = st.checkbox("显示分量参数", value=False)

    if show_advanced:
        centers_list: List[List[float]] = []
        scales_list: List[List[float]] = []
        weights_list: List[float] = []

        for k in range(K):
            st.write(f"分量 {k+1}")
            col1, col2 = st.columns(2)
            with col1:
                cx = st.number_input(f"中心X_{k+1}", -5.0, 5.0, float(st.session_state.centers[k][0]), 0.1)
                cy = st.number_input(f"中心Y_{k+1}", -5.0, 5.0, float(st.session_state.centers[k][1]), 0.1)
            with col2:
                sx = st.number_input(f"尺度X_{k+1}", 0.1, 3.0, float(st.session_state.scales[k][0]), 0.1)
                sy = st.number_input(f"尺度Y_{k+1}", 0.1, 3.0, float(st.session_state.scales[k][1]), 0.1)
            w = st.slider(f"权重_{k+1}", 0.0, 1.0, float(st.session_state.weights[k]), 0.1)

            centers_list.append([cx, cy])
            scales_list.append([sx, sy])
            weights_list.append(w)

        centers = np.array(centers_list, dtype=np.float64)
        scales = np.array(scales_list, dtype=np.float64)
        weights = np.array(weights_list, dtype=np.float64)

        # Every slider may be set to 0; dividing by that sum would turn all
        # weights into NaN, so fall back to uniform weights instead.
        weight_sum = weights.sum()
        if weight_sum > 0:
            weights = weights / weight_sum
        else:
            st.warning("所有权重均为0,已改用均匀权重")
            weights = np.full(K, 1.0 / K, dtype=np.float64)

        st.session_state.centers = centers
        st.session_state.scales = scales
        st.session_state.weights = weights
    else:
        centers = st.session_state.centers
        scales = st.session_state.scales
        weights = st.session_state.weights

    # The mixture for the current parameters; used for sampling, plotting
    # and as the label source for KAN training.
    dataset = GeneralizedGaussianMixture(
        D=2,
        K=K,
        p=st.session_state.p,
        centers=centers[:K],
        scales=scales[:K],
        weights=weights[:K]
    )

    # Sampling controls
    st.subheader("采样设置")
    n_samples = st.slider("采样点数", 5, 1000, 100)

    if st.button("重新采样"):
        samples, _ = dataset.generate_samples(n_samples)
        st.session_state.sample_points = samples
        st.session_state.kan_model = None  # a new sample set invalidates the model

# Main GMM figure
fig = create_gmm_plot(dataset, centers, K, st.session_state.sample_points)
st.plotly_chart(fig, use_container_width=False)

# KAN training and prediction
if st.session_state.sample_points is not None:
    st.markdown("---")
    st.subheader("KAN网络训练与预测")
    kan_distribution_plot_placeholder = st.empty()

    # Training controls
    col1, col2, col3 = st.columns([1, 2, 1])
    with col1:
        if st.button("拟合KAN", use_container_width=False):
            with st.spinner('训练KAN网络中...'):
                st.session_state.kan_model = train_kan(
                    st.session_state.sample_points,
                    dataset,
                    plot_placeholder=kan_distribution_plot_placeholder
                )
                st.balloons()

    with col3:
        if st.session_state.kan_model is not None:
            if st.button("清除KAN结果", use_container_width=False):
                st.session_state.kan_model = None
                st.rerun()

    st.markdown("---")

# Per-sample information table
if st.session_state.sample_points is not None:
    samples = st.session_state.sample_points
    probs = dataset.pdf(samples)
    posteriors = _component_posteriors(
        samples, centers[:K], scales[:K], weights[:K], st.session_state.p
    )

    with st.expander("采样点信息"):
        point_data = []
        for i, (sample, prob, post) in enumerate(zip(samples, probs, posteriors)):
            row = {
                '采样点': f'S{i+1}',
                'X坐标': f'{sample[0]:.2f}',
                'Y坐标': f'{sample[1]:.2f}',
                '概率密度': f'{prob:.4f}'
            }
            # Posterior probability of each component
            for k in range(K):
                row[f'分量{k+1}后验概率'] = f'{post[k]:.4f}'
            point_data.append(row)

        st.dataframe(point_data)

# Parameter explanation
with st.expander("分布参数说明"):
    st.markdown("""
    - **形状参数 (p)**:控制分布的形状
        - p = 1: 拉普拉斯分布
        - p = 2: 高斯分布
        - p → ∞: 均匀分布
    - **分量参数**:每个分量由以下参数确定
        - 中心 (μ): 峰值位置,通过X和Y坐标确定
        - 尺度 (α): 分布的展宽程度,X和Y方向可不同
        - 权重 (π): 混合系数,所有分量权重和为1
    """)

# Formula for the current parameters
with st.expander("分布概率密度函数公式"):
    st.latex(generate_latex_formula(st.session_state.p, K, centers[:K], scales[:K], weights[:K]))