2catycm commited on
Commit
06aa092
·
1 Parent(s): 86e568b

feat: 修复多个bug

Browse files
Files changed (1) hide show
  1. app.py +80 -42
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import streamlit as st
2
  import numpy as np
3
  from pathlib import Path
@@ -24,7 +25,7 @@ from kan.utils import create_dataset, ex_round # type: ignore
24
  # Set torch dtype
25
  torch.set_default_dtype(torch.float64)
26
 
27
- def show_kan_prediction(model, device, samples, placeholder):
28
  """显示KAN的预测结果"""
29
  # 生成网格数据
30
  x = np.linspace(-5, 5, 100)
@@ -84,7 +85,7 @@ def show_kan_prediction(model, device, samples, placeholder):
84
 
85
  # 更新布局
86
  fig_kan.update_layout(
87
- title='KAN预测结果',
88
  showlegend=True,
89
  width=1200,
90
  height=600,
@@ -100,7 +101,10 @@ def show_kan_prediction(model, device, samples, placeholder):
100
  fig_kan.update_yaxes(title_text='Y', row=1, col=2)
101
 
102
  # 使用占位符显示图形
103
- placeholder.plotly_chart(fig_kan, use_container_width=True)
 
 
 
104
 
105
  def create_gmm_plot(dataset, centers, K, samples=None):
106
  """创建GMM分布的可视化图形"""
@@ -198,6 +202,7 @@ def train_kan(samples, gmm_dataset, device='cuda'):
198
  device = torch.device('cuda')
199
  else:
200
  device = torch.device('cpu')
 
201
 
202
  # 转换采样点为tensor
203
  samples = torch.from_numpy(samples).to(device)
@@ -216,13 +221,17 @@ def train_kan(samples, gmm_dataset, device='cuda'):
216
  }
217
 
218
  # 创建训练进度显示组件
219
- st.write("网络结构:")
220
- kan_fig_placeholder = st.empty()
221
- st.write("预测结果:")
222
- kan_plot_placeholder = st.empty()
 
 
 
223
  progress_container = st.container()
224
 
225
- total_steps = 100
 
226
  steps_per_update = 10
227
 
228
  def calculate_error(model, x, y):
@@ -260,44 +269,69 @@ def train_kan(samples, gmm_dataset, device='cuda'):
260
  | 测试误差 | {test_error:.8f} |
261
  """)
262
 
263
- # 更新进度和预测结果
264
- show_kan_prediction(model, device, samples, kan_plot_placeholder)
265
-
266
  # 更新可视化(每5步更新一次)
267
- if step % (steps_per_update * 5) == 0 or step + steps_per_update >= steps:
268
- # 更新预测结果
269
- show_kan_prediction(model, device, samples, kan_plot_placeholder)
270
 
271
  # 更新网络结构图(可选)
272
- if show_plot:
273
- try:
274
- kan_fig = model.plot()
275
- if isinstance(kan_fig, tuple):
276
- kan_fig = kan_fig[0] # 如果是元组,取第一个元素
277
- if kan_fig is not None:
278
- kan_fig_placeholder.pyplot(kan_fig)
279
- plt.close('all') # 确保关闭所有图形
280
- except Exception as e:
281
- if step == 0: # 只在第一次出错时显示警告
282
- st.warning(f"注意:网络结构图显示失败 ({str(e)})")
 
 
283
 
 
 
 
284
  with progress_container:
285
  st.markdown("#### 训练过程")
286
  error_text = st.empty()
287
 
288
  # 第一阶段训练
289
  # 第一阶段:初始训练
290
- with st.spinner("初始训练阶段..."):
291
- train_phase("第一阶段", total_steps, lamb=0.001, show_plot=False) # 第一阶段不显示网络图
292
 
293
  # 剪枝阶段
294
  with st.spinner("正在进行网络剪枝优化..."):
295
  model = model.prune()
296
  progress_container.info("网络剪枝完成")
297
 
298
- # 第二阶段:精调
299
- with st.spinner("最终调优阶段..."):
300
- train_phase("第二阶段", total_steps, show_plot=True) # 第二阶段显示网络图
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
  # 显示最终误差
303
  train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
@@ -307,7 +341,7 @@ def train_kan(samples, gmm_dataset, device='cuda'):
307
  - 训练集误差: {train_error:.6f}
308
  - 测试集误差: {test_error:.6f}
309
  """)
310
-
311
  progress_container.success("🎉 训练完成!")
312
  return model
313
 
@@ -421,7 +455,7 @@ with st.sidebar:
421
 
422
  # 采样设置
423
  st.subheader("采样设置")
424
- n_samples = st.slider("采样点数", 5, 20, 10)
425
  if st.button("重新采样"):
426
  # 创建GMM数据集进行采样
427
  gmm = GeneralizedGaussianMixture(
@@ -548,34 +582,37 @@ fig.update_xaxes(title_text='X', row=1, col=2)
548
  fig.update_yaxes(title_text='Y', row=1, col=2)
549
 
550
  # 显示GMM主图
551
- st.plotly_chart(fig, use_container_width=True)
 
552
 
553
  # KAN网络训练和预测部分
554
  if st.session_state.sample_points is not None:
555
  st.markdown("---")
556
  st.subheader("KAN网络训练与预测")
 
 
557
 
558
  # 训练控制按钮
559
  col1, col2, col3 = st.columns([1, 2, 1])
560
  with col1:
561
- if st.button("拟合KAN", use_container_width=True):
562
  with st.spinner('训练KAN网络中...'):
563
  st.session_state.kan_model = train_kan(st.session_state.sample_points, dataset)
564
  st.balloons()
565
 
566
  with col3:
567
  if st.session_state.kan_model is not None:
568
- if st.button("清除KAN结果", use_container_width=True):
569
  st.session_state.kan_model = None
570
  st.rerun()
571
 
572
  # 显示KAN预测结果
573
- if st.session_state.kan_model is not None:
574
- st.subheader("KAN预测结果")
575
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
576
- kan_plot_placeholder = st.empty()
577
- show_kan_prediction(st.session_state.kan_model, device,
578
- st.session_state.sample_points, kan_plot_placeholder)
579
 
580
  st.markdown("---")
581
 
@@ -625,4 +662,5 @@ with st.expander("分布参数说明"):
625
  """)
626
 
627
  # 显示当前参数的数学公式
628
- st.latex(generate_latex_formula(st.session_state.p, K, centers[:K], scales[:K], weights[:K]))
 
 
1
+ import time
2
  import streamlit as st
3
  import numpy as np
4
  from pathlib import Path
 
25
  # Set torch dtype
26
  torch.set_default_dtype(torch.float64)
27
 
28
+ def show_kan_prediction(model, device, samples, placeholder, phase_name):
29
  """显示KAN的预测结果"""
30
  # 生成网格数据
31
  x = np.linspace(-5, 5, 100)
 
85
 
86
  # 更新布局
87
  fig_kan.update_layout(
88
+ title='KAN预测分布',
89
  showlegend=True,
90
  width=1200,
91
  height=600,
 
101
  fig_kan.update_yaxes(title_text='Y', row=1, col=2)
102
 
103
  # 使用占位符显示图形
104
+
105
+ placeholder.plotly_chart(fig_kan,
106
+ use_container_width=False,
107
+ key=f"kan_plot_{phase_name}_{time.time()}")
108
 
109
  def create_gmm_plot(dataset, centers, K, samples=None):
110
  """创建GMM分布的可视化图形"""
 
202
  device = torch.device('cuda')
203
  else:
204
  device = torch.device('cpu')
205
+ st.info(f"使用设备: {device} 训练网络")
206
 
207
  # 转换采样点为tensor
208
  samples = torch.from_numpy(samples).to(device)
 
221
  }
222
 
223
  # 创建训练进度显示组件
224
+ # st.write("网络预测分布:")
225
+
226
+
227
+ st.write("网络图形结构:")
228
+ kan_network_arch_placeholder = st.empty()
229
+
230
+
231
  progress_container = st.container()
232
 
233
+ # total_steps = 100
234
+ total_steps = 50
235
  steps_per_update = 10
236
 
237
  def calculate_error(model, x, y):
 
269
  | 测试误差 | {test_error:.8f} |
270
  """)
271
 
272
+
 
 
273
  # 更新可视化(每5步更新一次)
274
+ # if step % (steps_per_update * 5) == 0 or step + steps_per_update >= steps:
275
+ # # 更新预测结果
276
+ # show_kan_prediction(model, device, samples, kan_plot_placeholder, phase_name)
277
 
278
  # 更新网络结构图(可选)
279
+ if show_plot:
280
+ try:
281
+ model.plot()
282
+ kan_fig = plt.gcf()
283
+ # if isinstance(kan_fig, tuple):
284
+ # kan_fig = kan_fig[0] # 如果是元组,取第一个元素
285
+ # if kan_fig is not None:
286
+ kan_network_arch_placeholder.pyplot(kan_fig, use_container_width=False)
287
+ # plt.close('all') # 确保关闭所有图形
288
+ except Exception as e:
289
+ if step == 0: # 只在第一次出错时显示警告
290
+ st.warning(f"注意:网络结构图显示失败 ({str(e)})")
291
+
292
 
293
+ # 更新进度和预测结果
294
+ show_kan_prediction(model, device, samples, kan_distribution_plot_placeholder, phase_name)
295
+
296
  with progress_container:
297
  st.markdown("#### 训练过程")
298
  error_text = st.empty()
299
 
300
  # 第一阶段训练
301
  # 第一阶段:初始训练
302
+ with st.spinner("参数调整中..."):
303
+ train_phase("第一阶段: 正则化训练", total_steps, lamb=0.001, show_plot=True)
304
 
305
  # 剪枝阶段
306
  with st.spinner("正在进行网络剪枝优化..."):
307
  model = model.prune()
308
  progress_container.info("网络剪枝完成")
309
 
310
+ with st.spinner("参数调整中..."):
311
+ train_phase("第二阶段: 剪枝适应性训练", total_steps, show_plot=True)
312
+
313
+ with st.spinner("正在进行网格精细化..."):
314
+ model = model.refine(10)
315
+ progress_container.info("网格精细化完成")
316
+
317
+ with st.spinner("参数调整中..."):
318
+ train_phase("第三阶段: 网格适应性训练", total_steps, show_plot=True)
319
+
320
+ with st.spinner("符号简化中..."):
321
+ # model = model.prune()
322
+ # progress_container.info("网络剪枝完成")
323
+ model.auto_symbolic()
324
+ progress_container.info("符号简化完成")
325
+
326
+ with st.spinner("参数调整中..."):
327
+ train_phase("第四阶段:符号适应性训练", total_steps, show_plot=True)
328
+
329
+ from kan.utils import ex_round
330
+ from sympy import latex
331
+ s= ex_round(model.symbolic_formula()[0][0],4)
332
+
333
+ st.write("网络公式:")
334
+ st.latex(latex(s))
335
 
336
  # 显示最终误差
337
  train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
 
341
  - 训练集误差: {train_error:.6f}
342
  - 测试集误差: {test_error:.6f}
343
  """)
344
+
345
  progress_container.success("🎉 训练完成!")
346
  return model
347
 
 
455
 
456
  # 采样设置
457
  st.subheader("采样设置")
458
+ n_samples = st.slider("采样点数", 5, 1000, 100)
459
  if st.button("重新采样"):
460
  # 创建GMM数据集进行采样
461
  gmm = GeneralizedGaussianMixture(
 
582
  fig.update_yaxes(title_text='Y', row=1, col=2)
583
 
584
  # 显示GMM主图
585
+ st.plotly_chart(fig, use_container_width=False)
586
+
587
 
588
  # KAN网络训练和预测部分
589
  if st.session_state.sample_points is not None:
590
  st.markdown("---")
591
  st.subheader("KAN网络训练与预测")
592
+
593
+ kan_distribution_plot_placeholder = st.empty()
594
 
595
  # 训练控制按钮
596
  col1, col2, col3 = st.columns([1, 2, 1])
597
  with col1:
598
+ if st.button("拟合KAN", use_container_width=False):
599
  with st.spinner('训练KAN网络中...'):
600
  st.session_state.kan_model = train_kan(st.session_state.sample_points, dataset)
601
  st.balloons()
602
 
603
  with col3:
604
  if st.session_state.kan_model is not None:
605
+ if st.button("清除KAN结果", use_container_width=False):
606
  st.session_state.kan_model = None
607
  st.rerun()
608
 
609
  # 显示KAN预测结果
610
+ # if st.session_state.kan_model is not None:
611
+ # st.subheader("KAN预测结果")
612
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
613
+ # kan_plot_placeholder = st.empty()
614
+ # show_kan_prediction(st.session_state.kan_model, device,
615
+ # st.session_state.sample_points, kan_plot_placeholder, "显示结果")
616
 
617
  st.markdown("---")
618
 
 
662
  """)
663
 
664
  # 显示当前参数的数学公式
665
+ with st.expander("分布概率密度函数公式"):
666
+ st.latex(generate_latex_formula(st.session_state.p, K, centers[:K], scales[:K], weights[:K]))