Dixing (Dex) Xu committed on
Commit
f3092ac
·
unverified ·
1 Parent(s): 5c7fa16

:sparkles: Add validation plot and score for webui (#28)

Browse files

* :sparkles: Add validation plot and score for webui

* Add validation plot
* Add validation score
* Update style.css

* :art: Put the best validation score under the tab

* :rotating_light: update lint and example text

Files changed (2) hide show
  1. aide/webui/app.py +112 -15
  2. aide/webui/style.css +1 -1
aide/webui/app.py CHANGED
@@ -158,24 +158,35 @@ class WebUI:
158
  Returns:
159
  list: List of uploaded or example files.
160
  """
161
- if st.button(
162
- "Load Example Experiment", type="primary", use_container_width=True
163
- ):
164
- st.session_state.example_files = self.load_example_files()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  if st.session_state.get("example_files"):
167
  st.info("Example files loaded! Click 'Run AIDE' to proceed.")
168
  with st.expander("View Loaded Files", expanded=False):
169
  for file in st.session_state.example_files:
170
  st.text(f"📄 {file['name']}")
171
- uploaded_files = st.session_state.example_files
172
- else:
173
- uploaded_files = st.file_uploader(
174
- "Upload Data Files",
175
- accept_multiple_files=True,
176
- type=["csv", "txt", "json", "md"],
177
- )
178
- return uploaded_files
179
 
180
  def handle_user_inputs(self):
181
  """
@@ -187,12 +198,12 @@ class WebUI:
187
  goal_text = st.text_area(
188
  "Goal",
189
  value=st.session_state.get("goal", ""),
190
- placeholder="Example: Predict house prices",
191
  )
192
  eval_text = st.text_area(
193
  "Evaluation Criteria",
194
  value=st.session_state.get("eval", ""),
195
- placeholder="Example: Use RMSE metric",
196
  )
197
  num_steps = st.slider(
198
  "Number of Steps",
@@ -450,7 +461,16 @@ class WebUI:
450
  st.header("Results")
451
  if st.session_state.get("results"):
452
  results = st.session_state.results
453
- tabs = st.tabs(["Tree Visualization", "Best Solution", "Config", "Journal"])
 
 
 
 
 
 
 
 
 
454
 
455
  with tabs[0]:
456
  self.render_tree_visualization(results)
@@ -460,6 +480,12 @@ class WebUI:
460
  self.render_config(results)
461
  with tabs[3]:
462
  self.render_journal(results)
 
 
 
 
 
 
463
  else:
464
  st.info("No results to display. Please run an experiment.")
465
 
@@ -529,6 +555,77 @@ class WebUI:
529
  else:
530
  st.info("No journal available.")
531
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532
 
533
  if __name__ == "__main__":
534
  app = WebUI()
 
158
  Returns:
159
  list: List of uploaded or example files.
160
  """
161
+ # Only show file uploader if no example files are loaded
162
+ if not st.session_state.get("example_files"):
163
+ uploaded_files = st.file_uploader(
164
+ "Upload Data Files",
165
+ accept_multiple_files=True,
166
+ type=["csv", "txt", "json", "md"],
167
+ label_visibility="collapsed",
168
+ )
169
+
170
+ if uploaded_files:
171
+ st.session_state.pop(
172
+ "example_files", None
173
+ ) # Remove example files if any
174
+ return uploaded_files
175
+
176
+ # Only show example button if no files are uploaded
177
+ if st.button(
178
+ "Load Example Experiment", type="primary", use_container_width=True
179
+ ):
180
+ st.session_state.example_files = self.load_example_files()
181
 
182
  if st.session_state.get("example_files"):
183
  st.info("Example files loaded! Click 'Run AIDE' to proceed.")
184
  with st.expander("View Loaded Files", expanded=False):
185
  for file in st.session_state.example_files:
186
  st.text(f"📄 {file['name']}")
187
+ return st.session_state.example_files
188
+
189
+ return [] # Return empty list if no files are uploaded or loaded
 
 
 
 
 
190
 
191
  def handle_user_inputs(self):
192
  """
 
198
  goal_text = st.text_area(
199
  "Goal",
200
  value=st.session_state.get("goal", ""),
201
+ placeholder="Example: Predict the sales price for each house",
202
  )
203
  eval_text = st.text_area(
204
  "Evaluation Criteria",
205
  value=st.session_state.get("eval", ""),
206
+ placeholder="Example: Use the RMSE metric between the logarithm of the predicted and observed values.",
207
  )
208
  num_steps = st.slider(
209
  "Number of Steps",
 
461
  st.header("Results")
462
  if st.session_state.get("results"):
463
  results = st.session_state.results
464
+
465
+ tabs = st.tabs(
466
+ [
467
+ "Tree Visualization",
468
+ "Best Solution",
469
+ "Config",
470
+ "Journal",
471
+ "Validation Plot",
472
+ ]
473
+ )
474
 
475
  with tabs[0]:
476
  self.render_tree_visualization(results)
 
480
  self.render_config(results)
481
  with tabs[3]:
482
  self.render_journal(results)
483
+ with tabs[4]:
484
+ # Display best score before the plot
485
+ best_metric = self.get_best_metric(results)
486
+ if best_metric is not None:
487
+ st.metric("Best Validation Score", f"{best_metric:.4f}")
488
+ self.render_validation_plot(results)
489
  else:
490
  st.info("No results to display. Please run an experiment.")
491
 
 
555
  else:
556
  st.info("No journal available.")
557
 
558
@staticmethod
def get_best_metric(results):
    """
    Extract the best validation metric from results.

    Args:
        results (dict): Experiment results. ``results["journal"]`` is read
            as a JSON-encoded list of node dicts, each carrying a
            ``"metric"`` entry (a number, a numeric string, or None).

    Returns:
        float | None: The largest finite metric found, or None when the
        journal is missing, cannot be parsed, has an unexpected shape,
        or contains no usable metric.
    """
    # Function-scope import, mirroring the lazy plotly import used by
    # render_validation_plot in this class.
    import math

    try:
        journal_data = json.loads(results["journal"])
    except (json.JSONDecodeError, KeyError, TypeError):
        # TypeError covers a journal that is None / not a string, and a
        # `results` object that is not subscriptable.
        return None

    metrics = []
    try:
        for node in journal_data:
            raw_metric = node["metric"]
            if raw_metric is None:
                continue
            try:
                # Metrics may be serialized as strings (e.g. "0.93" or "None").
                metric_value = float(raw_metric)
            except (ValueError, TypeError):
                continue
            # NaN would make max() order-dependent and inf is not a
            # meaningful score, so only finite values are kept.
            if math.isfinite(metric_value):
                metrics.append(metric_value)
    except (KeyError, TypeError):
        # A node without "metric", or a journal that is not a list of
        # dicts, is treated the same as an unparseable journal.
        return None

    # NOTE(review): max() assumes a higher score is better. For
    # lower-is-better metrics such as RMSE this reports the worst value;
    # confirm whether the journal exposes the optimization direction.
    return max(metrics) if metrics else None
577
+
578
@staticmethod
def render_validation_plot(results):
    """
    Render the validation score plot.

    Reads ``results["journal"]`` as a JSON-encoded list of node dicts and
    plots each node's ``"metric"`` against its ``"step"``. Nodes whose
    metric is None, the string "none" (any case), or not convertible to
    float are skipped.

    Args:
        results (dict): Experiment results containing a ``"journal"`` entry.

    Returns:
        None. Output is written to the Streamlit page: a Plotly chart when
        at least one metric is usable, an info message when none is, and an
        error message when the journal cannot be parsed.
    """
    try:
        journal_data = json.loads(results["journal"])
        steps = []
        metrics = []

        for node in journal_data:
            raw_metric = node["metric"]
            # str() keeps the "none" sentinel check safe when the metric
            # was serialized as a JSON number instead of a string.
            if raw_metric is None or str(raw_metric).lower() == "none":
                continue
            try:
                metric_value = float(raw_metric)
            except (ValueError, TypeError):
                continue
            steps.append(node["step"])
            metrics.append(metric_value)

        if metrics:
            # Imported lazily so plotly is only loaded when there is
            # something to draw.
            import plotly.graph_objects as go

            fig = go.Figure()
            fig.add_trace(
                go.Scatter(
                    x=steps,
                    y=metrics,
                    mode="lines+markers",
                    name="Validation Score",
                    line=dict(color="#F04370"),
                    marker=dict(color="#F04370"),
                )
            )

            fig.update_layout(
                title="Validation Score Progress",
                xaxis_title="Step",
                yaxis_title="Validation Score",
                template="plotly_white",
                hovermode="x unified",
                # Transparent backgrounds let the page styling show through.
                plot_bgcolor="rgba(0,0,0,0)",
                paper_bgcolor="rgba(0,0,0,0)",
            )

            st.plotly_chart(fig, use_container_width=True)
        else:
            st.info("No validation metrics available to plot.")

    except (json.JSONDecodeError, KeyError, TypeError):
        # TypeError covers a journal that is None / not a string, or
        # nodes that are not dicts.
        st.error("Could not parse validation metrics data.")
628
+
629
 
630
  if __name__ == "__main__":
631
  app = WebUI()
aide/webui/style.css CHANGED
@@ -1,7 +1,7 @@
1
  /* Main colors */
2
  :root {
3
  --background: #F2F0E7;
4
- --background-shaded: #EBE8DD;
5
  --card: #FFFFFF;
6
  --primary: #0D0F18;
7
  --accent: #F04370;
 
1
  /* Main colors */
2
  :root {
3
  --background: #F2F0E7;
4
+ --background-shaded: #FFFFFF;
5
  --card: #FFFFFF;
6
  --primary: #0D0F18;
7
  --accent: #F04370;