feat: 修复多个bug
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import streamlit as st
|
2 |
import numpy as np
|
3 |
from pathlib import Path
|
@@ -24,7 +25,7 @@ from kan.utils import create_dataset, ex_round # type: ignore
|
|
24 |
# Set torch dtype
|
25 |
torch.set_default_dtype(torch.float64)
|
26 |
|
27 |
-
def show_kan_prediction(model, device, samples, placeholder):
|
28 |
"""显示KAN的预测结果"""
|
29 |
# 生成网格数据
|
30 |
x = np.linspace(-5, 5, 100)
|
@@ -84,7 +85,7 @@ def show_kan_prediction(model, device, samples, placeholder):
|
|
84 |
|
85 |
# 更新布局
|
86 |
fig_kan.update_layout(
|
87 |
-
title='KAN
|
88 |
showlegend=True,
|
89 |
width=1200,
|
90 |
height=600,
|
@@ -100,7 +101,10 @@ def show_kan_prediction(model, device, samples, placeholder):
|
|
100 |
fig_kan.update_yaxes(title_text='Y', row=1, col=2)
|
101 |
|
102 |
# 使用占位符显示图形
|
103 |
-
|
|
|
|
|
|
|
104 |
|
105 |
def create_gmm_plot(dataset, centers, K, samples=None):
|
106 |
"""创建GMM分布的可视化图形"""
|
@@ -198,6 +202,7 @@ def train_kan(samples, gmm_dataset, device='cuda'):
|
|
198 |
device = torch.device('cuda')
|
199 |
else:
|
200 |
device = torch.device('cpu')
|
|
|
201 |
|
202 |
# 转换采样点为tensor
|
203 |
samples = torch.from_numpy(samples).to(device)
|
@@ -216,13 +221,17 @@ def train_kan(samples, gmm_dataset, device='cuda'):
|
|
216 |
}
|
217 |
|
218 |
# 创建训练进度显示组件
|
219 |
-
st.write("
|
220 |
-
|
221 |
-
|
222 |
-
|
|
|
|
|
|
|
223 |
progress_container = st.container()
|
224 |
|
225 |
-
total_steps = 100
|
|
|
226 |
steps_per_update = 10
|
227 |
|
228 |
def calculate_error(model, x, y):
|
@@ -260,44 +269,69 @@ def train_kan(samples, gmm_dataset, device='cuda'):
|
|
260 |
| 测试误差 | {test_error:.8f} |
|
261 |
""")
|
262 |
|
263 |
-
|
264 |
-
show_kan_prediction(model, device, samples, kan_plot_placeholder)
|
265 |
-
|
266 |
# 更新可视化(每5步更新一次)
|
267 |
-
if step % (steps_per_update * 5) == 0 or step + steps_per_update >= steps:
|
268 |
-
|
269 |
-
|
270 |
|
271 |
# 更新网络结构图(可选)
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
|
|
|
|
283 |
|
|
|
|
|
|
|
284 |
with progress_container:
|
285 |
st.markdown("#### 训练过程")
|
286 |
error_text = st.empty()
|
287 |
|
288 |
# 第一阶段训练
|
289 |
# 第一阶段:初始训练
|
290 |
-
with st.spinner("
|
291 |
-
train_phase("
|
292 |
|
293 |
# 剪枝阶段
|
294 |
with st.spinner("正在进行网络剪枝优化..."):
|
295 |
model = model.prune()
|
296 |
progress_container.info("网络剪枝完成")
|
297 |
|
298 |
-
|
299 |
-
|
300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
|
302 |
# 显示最终误差
|
303 |
train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
|
@@ -307,7 +341,7 @@ def train_kan(samples, gmm_dataset, device='cuda'):
|
|
307 |
- 训练集误差: {train_error:.6f}
|
308 |
- 测试集误差: {test_error:.6f}
|
309 |
""")
|
310 |
-
|
311 |
progress_container.success("🎉 训练完成!")
|
312 |
return model
|
313 |
|
@@ -421,7 +455,7 @@ with st.sidebar:
|
|
421 |
|
422 |
# 采样设置
|
423 |
st.subheader("采样设置")
|
424 |
-
n_samples = st.slider("采样点数", 5,
|
425 |
if st.button("重新采样"):
|
426 |
# 创建GMM数据集进行采样
|
427 |
gmm = GeneralizedGaussianMixture(
|
@@ -548,34 +582,37 @@ fig.update_xaxes(title_text='X', row=1, col=2)
|
|
548 |
fig.update_yaxes(title_text='Y', row=1, col=2)
|
549 |
|
550 |
# 显示GMM主图
|
551 |
-
st.plotly_chart(fig, use_container_width=
|
|
|
552 |
|
553 |
# KAN网络训练和预测部分
|
554 |
if st.session_state.sample_points is not None:
|
555 |
st.markdown("---")
|
556 |
st.subheader("KAN网络训练与预测")
|
|
|
|
|
557 |
|
558 |
# 训练控制按钮
|
559 |
col1, col2, col3 = st.columns([1, 2, 1])
|
560 |
with col1:
|
561 |
-
if st.button("拟合KAN", use_container_width=
|
562 |
with st.spinner('训练KAN网络中...'):
|
563 |
st.session_state.kan_model = train_kan(st.session_state.sample_points, dataset)
|
564 |
st.balloons()
|
565 |
|
566 |
with col3:
|
567 |
if st.session_state.kan_model is not None:
|
568 |
-
if st.button("清除KAN结果", use_container_width=
|
569 |
st.session_state.kan_model = None
|
570 |
st.rerun()
|
571 |
|
572 |
# 显示KAN预测结果
|
573 |
-
if st.session_state.kan_model is not None:
|
574 |
-
st.subheader("KAN预测结果")
|
575 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
576 |
-
kan_plot_placeholder = st.empty()
|
577 |
-
show_kan_prediction(st.session_state.kan_model, device,
|
578 |
-
|
579 |
|
580 |
st.markdown("---")
|
581 |
|
@@ -625,4 +662,5 @@ with st.expander("分布参数说明"):
|
|
625 |
""")
|
626 |
|
627 |
# 显示当前参数的数学公式
|
628 |
-
st.
|
|
|
|
1 |
+
import time
|
2 |
import streamlit as st
|
3 |
import numpy as np
|
4 |
from pathlib import Path
|
|
|
25 |
# Set torch dtype
|
26 |
torch.set_default_dtype(torch.float64)
|
27 |
|
28 |
+
def show_kan_prediction(model, device, samples, placeholder, phase_name):
|
29 |
"""显示KAN的预测结果"""
|
30 |
# 生成网格数据
|
31 |
x = np.linspace(-5, 5, 100)
|
|
|
85 |
|
86 |
# 更新布局
|
87 |
fig_kan.update_layout(
|
88 |
+
title='KAN预测分布',
|
89 |
showlegend=True,
|
90 |
width=1200,
|
91 |
height=600,
|
|
|
101 |
fig_kan.update_yaxes(title_text='Y', row=1, col=2)
|
102 |
|
103 |
# 使用占位符显示图形
|
104 |
+
|
105 |
+
placeholder.plotly_chart(fig_kan,
|
106 |
+
use_container_width=False,
|
107 |
+
key=f"kan_plot_{phase_name}_{time.time()}")
|
108 |
|
109 |
def create_gmm_plot(dataset, centers, K, samples=None):
|
110 |
"""创建GMM分布的可视化图形"""
|
|
|
202 |
device = torch.device('cuda')
|
203 |
else:
|
204 |
device = torch.device('cpu')
|
205 |
+
st.info(f"使用设备: {device} 训练网络")
|
206 |
|
207 |
# 转换采样点为tensor
|
208 |
samples = torch.from_numpy(samples).to(device)
|
|
|
221 |
}
|
222 |
|
223 |
# 创建训练进度显示组件
|
224 |
+
# st.write("网络预测分布:")
|
225 |
+
|
226 |
+
|
227 |
+
st.write("网络图形结构:")
|
228 |
+
kan_network_arch_placeholder = st.empty()
|
229 |
+
|
230 |
+
|
231 |
progress_container = st.container()
|
232 |
|
233 |
+
# total_steps = 100
|
234 |
+
total_steps = 50
|
235 |
steps_per_update = 10
|
236 |
|
237 |
def calculate_error(model, x, y):
|
|
|
269 |
| 测试误差 | {test_error:.8f} |
|
270 |
""")
|
271 |
|
272 |
+
|
|
|
|
|
273 |
# 更新可视化(每5步更新一次)
|
274 |
+
# if step % (steps_per_update * 5) == 0 or step + steps_per_update >= steps:
|
275 |
+
# # 更新预测结果
|
276 |
+
# show_kan_prediction(model, device, samples, kan_plot_placeholder, phase_name)
|
277 |
|
278 |
# 更新网络结构图(可选)
|
279 |
+
if show_plot:
|
280 |
+
try:
|
281 |
+
model.plot()
|
282 |
+
kan_fig = plt.gcf()
|
283 |
+
# if isinstance(kan_fig, tuple):
|
284 |
+
# kan_fig = kan_fig[0] # 如果是元组,取第一个元素
|
285 |
+
# if kan_fig is not None:
|
286 |
+
kan_network_arch_placeholder.pyplot(kan_fig, use_container_width=False)
|
287 |
+
# plt.close('all') # 确保关闭所有图形
|
288 |
+
except Exception as e:
|
289 |
+
if step == 0: # 只在第一次出错时显示警告
|
290 |
+
st.warning(f"注意:网络结构图显示失败 ({str(e)})")
|
291 |
+
|
292 |
|
293 |
+
# 更新进度和预测结果
|
294 |
+
show_kan_prediction(model, device, samples, kan_distribution_plot_placeholder, phase_name)
|
295 |
+
|
296 |
with progress_container:
|
297 |
st.markdown("#### 训练过程")
|
298 |
error_text = st.empty()
|
299 |
|
300 |
# 第一阶段训练
|
301 |
# 第一阶段:初始训练
|
302 |
+
with st.spinner("参数调整中..."):
|
303 |
+
train_phase("第一阶段: 正则化训练", total_steps, lamb=0.001, show_plot=True)
|
304 |
|
305 |
# 剪枝阶段
|
306 |
with st.spinner("正在进行网络剪枝优化..."):
|
307 |
model = model.prune()
|
308 |
progress_container.info("网络剪枝完成")
|
309 |
|
310 |
+
with st.spinner("参数调整中..."):
|
311 |
+
train_phase("第二阶段: 剪枝适应性训练", total_steps, show_plot=True)
|
312 |
+
|
313 |
+
with st.spinner("正在进行网格精细化..."):
|
314 |
+
model = model.refine(10)
|
315 |
+
progress_container.info("网格精细化完成")
|
316 |
+
|
317 |
+
with st.spinner("参数调整中..."):
|
318 |
+
train_phase("第三阶段: 网格适应性训练", total_steps, show_plot=True)
|
319 |
+
|
320 |
+
with st.spinner("符号简化中..."):
|
321 |
+
# model = model.prune()
|
322 |
+
# progress_container.info("网络剪枝完成")
|
323 |
+
model.auto_symbolic()
|
324 |
+
progress_container.info("符号简化完成")
|
325 |
+
|
326 |
+
with st.spinner("参数调整中..."):
|
327 |
+
train_phase("第四阶段:符号适应性训练", total_steps, show_plot=True)
|
328 |
+
|
329 |
+
from kan.utils import ex_round
|
330 |
+
from sympy import latex
|
331 |
+
s= ex_round(model.symbolic_formula()[0][0],4)
|
332 |
+
|
333 |
+
st.write("网络公式:")
|
334 |
+
st.latex(latex(s))
|
335 |
|
336 |
# 显示最终误差
|
337 |
train_error = calculate_error(model, train_dataset['train_input'], train_dataset['train_label'])
|
|
|
341 |
- 训练集误差: {train_error:.6f}
|
342 |
- 测试集误差: {test_error:.6f}
|
343 |
""")
|
344 |
+
|
345 |
progress_container.success("🎉 训练完成!")
|
346 |
return model
|
347 |
|
|
|
455 |
|
456 |
# 采样设置
|
457 |
st.subheader("采样设置")
|
458 |
+
n_samples = st.slider("采样点数", 5, 1000, 100)
|
459 |
if st.button("重新采样"):
|
460 |
# 创建GMM数据集进行采样
|
461 |
gmm = GeneralizedGaussianMixture(
|
|
|
582 |
fig.update_yaxes(title_text='Y', row=1, col=2)
|
583 |
|
584 |
# 显示GMM主图
|
585 |
+
st.plotly_chart(fig, use_container_width=False)
|
586 |
+
|
587 |
|
588 |
# KAN网络训练和预测部分
|
589 |
if st.session_state.sample_points is not None:
|
590 |
st.markdown("---")
|
591 |
st.subheader("KAN网络训练与预测")
|
592 |
+
|
593 |
+
kan_distribution_plot_placeholder = st.empty()
|
594 |
|
595 |
# 训练控制按钮
|
596 |
col1, col2, col3 = st.columns([1, 2, 1])
|
597 |
with col1:
|
598 |
+
if st.button("拟合KAN", use_container_width=False):
|
599 |
with st.spinner('训练KAN网络中...'):
|
600 |
st.session_state.kan_model = train_kan(st.session_state.sample_points, dataset)
|
601 |
st.balloons()
|
602 |
|
603 |
with col3:
|
604 |
if st.session_state.kan_model is not None:
|
605 |
+
if st.button("清除KAN结果", use_container_width=False):
|
606 |
st.session_state.kan_model = None
|
607 |
st.rerun()
|
608 |
|
609 |
# 显示KAN预测结果
|
610 |
+
# if st.session_state.kan_model is not None:
|
611 |
+
# st.subheader("KAN预测结果")
|
612 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
613 |
+
# kan_plot_placeholder = st.empty()
|
614 |
+
# show_kan_prediction(st.session_state.kan_model, device,
|
615 |
+
# st.session_state.sample_points, kan_plot_placeholder, "显示结果")
|
616 |
|
617 |
st.markdown("---")
|
618 |
|
|
|
662 |
""")
|
663 |
|
664 |
# 显示当前参数的数学公式
|
665 |
+
with st.expander("分布概率密度函数公式"):
|
666 |
+
st.latex(generate_latex_formula(st.session_state.p, K, centers[:K], scales[:K], weights[:K]))
|