2catycm commited on
Commit
a299d07
·
1 Parent(s): e48942a
Files changed (2) hide show
  1. app.py +4 -1
  2. requirements.txt +1 -1
app.py CHANGED
@@ -69,6 +69,7 @@ def show_kan_prediction(model, device, samples, placeholder, phase_name):
69
 
70
  # 添加采样点
71
  if samples is not None:
 
72
  fig_kan.add_trace(
73
  go.Scatter(
74
  x=samples[:, 0], y=samples[:, 1],
@@ -320,7 +321,9 @@ def train_kan(samples, gmm_dataset, device='cuda'):
320
  with st.spinner("符号简化中..."):
321
  # model = model.prune()
322
  # progress_container.info("网络剪枝完成")
323
- model.auto_symbolic()
 
 
324
  progress_container.info("符号简化完成")
325
 
326
  with st.spinner("参数调整中..."):
 
69
 
70
  # 添加采样点
71
  if samples is not None:
72
+ samples = samples.cpu().numpy() if torch.is_tensor(samples) else samples
73
  fig_kan.add_trace(
74
  go.Scatter(
75
  x=samples[:, 0], y=samples[:, 1],
 
321
  with st.spinner("符号简化中..."):
322
  # model = model.prune()
323
  # progress_container.info("网络剪枝完成")
324
+ lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
325
+ model.auto_symbolic(lib=lib)
326
+ # model.auto_symbolic()
327
  progress_container.info("符号简化完成")
328
 
329
  with st.spinner("参数调整中..."):
requirements.txt CHANGED
@@ -8,5 +8,5 @@ ipython>=8.0.0
8
  ipywidgets>=7.0.0
9
  nbformat>=5.0.0
10
  sympy>=1.8
11
- pykan
12
  matplotlib
 
8
  ipywidgets>=7.0.0
9
  nbformat>=5.0.0
10
  sympy>=1.8
11
+ git+https://github.com/2catycm/pykan.git
12
  matplotlib