Dixing (Dex) Xu
committed on
:sparkles: Add validation plot and score for webui (#28)
Browse files
* :sparkles: Add validation plot and score for webui
* Add validation plot
* Add validation score
* Update style.css
* :art: Put the best validation score under the tab
* :rotating_light: update lint and example text
- aide/webui/app.py +112 -15
- aide/webui/style.css +1 -1
aide/webui/app.py
CHANGED
@@ -158,24 +158,35 @@ class WebUI:
|
|
158 |
Returns:
|
159 |
list: List of uploaded or example files.
|
160 |
"""
|
161 |
-
if
|
162 |
-
|
163 |
-
|
164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
if st.session_state.get("example_files"):
|
167 |
st.info("Example files loaded! Click 'Run AIDE' to proceed.")
|
168 |
with st.expander("View Loaded Files", expanded=False):
|
169 |
for file in st.session_state.example_files:
|
170 |
st.text(f"📄 {file['name']}")
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
"Upload Data Files",
|
175 |
-
accept_multiple_files=True,
|
176 |
-
type=["csv", "txt", "json", "md"],
|
177 |
-
)
|
178 |
-
return uploaded_files
|
179 |
|
180 |
def handle_user_inputs(self):
|
181 |
"""
|
@@ -187,12 +198,12 @@ class WebUI:
|
|
187 |
goal_text = st.text_area(
|
188 |
"Goal",
|
189 |
value=st.session_state.get("goal", ""),
|
190 |
-
placeholder="Example: Predict house
|
191 |
)
|
192 |
eval_text = st.text_area(
|
193 |
"Evaluation Criteria",
|
194 |
value=st.session_state.get("eval", ""),
|
195 |
-
placeholder="Example: Use RMSE metric",
|
196 |
)
|
197 |
num_steps = st.slider(
|
198 |
"Number of Steps",
|
@@ -450,7 +461,16 @@ class WebUI:
|
|
450 |
st.header("Results")
|
451 |
if st.session_state.get("results"):
|
452 |
results = st.session_state.results
|
453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
454 |
|
455 |
with tabs[0]:
|
456 |
self.render_tree_visualization(results)
|
@@ -460,6 +480,12 @@ class WebUI:
|
|
460 |
self.render_config(results)
|
461 |
with tabs[3]:
|
462 |
self.render_journal(results)
|
|
|
|
|
|
|
|
|
|
|
|
|
463 |
else:
|
464 |
st.info("No results to display. Please run an experiment.")
|
465 |
|
@@ -529,6 +555,77 @@ class WebUI:
|
|
529 |
else:
|
530 |
st.info("No journal available.")
|
531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
532 |
|
533 |
if __name__ == "__main__":
|
534 |
app = WebUI()
|
|
|
158 |
Returns:
|
159 |
list: List of uploaded or example files.
|
160 |
"""
|
161 |
+
# Only show file uploader if no example files are loaded
|
162 |
+
if not st.session_state.get("example_files"):
|
163 |
+
uploaded_files = st.file_uploader(
|
164 |
+
"Upload Data Files",
|
165 |
+
accept_multiple_files=True,
|
166 |
+
type=["csv", "txt", "json", "md"],
|
167 |
+
label_visibility="collapsed",
|
168 |
+
)
|
169 |
+
|
170 |
+
if uploaded_files:
|
171 |
+
st.session_state.pop(
|
172 |
+
"example_files", None
|
173 |
+
) # Remove example files if any
|
174 |
+
return uploaded_files
|
175 |
+
|
176 |
+
# Only show example button if no files are uploaded
|
177 |
+
if st.button(
|
178 |
+
"Load Example Experiment", type="primary", use_container_width=True
|
179 |
+
):
|
180 |
+
st.session_state.example_files = self.load_example_files()
|
181 |
|
182 |
if st.session_state.get("example_files"):
|
183 |
st.info("Example files loaded! Click 'Run AIDE' to proceed.")
|
184 |
with st.expander("View Loaded Files", expanded=False):
|
185 |
for file in st.session_state.example_files:
|
186 |
st.text(f"📄 {file['name']}")
|
187 |
+
return st.session_state.example_files
|
188 |
+
|
189 |
+
return [] # Return empty list if no files are uploaded or loaded
|
|
|
|
|
|
|
|
|
|
|
190 |
|
191 |
def handle_user_inputs(self):
|
192 |
"""
|
|
|
198 |
goal_text = st.text_area(
|
199 |
"Goal",
|
200 |
value=st.session_state.get("goal", ""),
|
201 |
+
placeholder="Example: Predict the sales price for each house",
|
202 |
)
|
203 |
eval_text = st.text_area(
|
204 |
"Evaluation Criteria",
|
205 |
value=st.session_state.get("eval", ""),
|
206 |
+
placeholder="Example: Use the RMSE metric between the logarithm of the predicted and observed values.",
|
207 |
)
|
208 |
num_steps = st.slider(
|
209 |
"Number of Steps",
|
|
|
461 |
st.header("Results")
|
462 |
if st.session_state.get("results"):
|
463 |
results = st.session_state.results
|
464 |
+
|
465 |
+
tabs = st.tabs(
|
466 |
+
[
|
467 |
+
"Tree Visualization",
|
468 |
+
"Best Solution",
|
469 |
+
"Config",
|
470 |
+
"Journal",
|
471 |
+
"Validation Plot",
|
472 |
+
]
|
473 |
+
)
|
474 |
|
475 |
with tabs[0]:
|
476 |
self.render_tree_visualization(results)
|
|
|
480 |
self.render_config(results)
|
481 |
with tabs[3]:
|
482 |
self.render_journal(results)
|
483 |
+
with tabs[4]:
|
484 |
+
# Display best score before the plot
|
485 |
+
best_metric = self.get_best_metric(results)
|
486 |
+
if best_metric is not None:
|
487 |
+
st.metric("Best Validation Score", f"{best_metric:.4f}")
|
488 |
+
self.render_validation_plot(results)
|
489 |
else:
|
490 |
st.info("No results to display. Please run an experiment.")
|
491 |
|
|
|
555 |
else:
|
556 |
st.info("No journal available.")
|
557 |
|
558 |
+
@staticmethod
|
559 |
+
def get_best_metric(results):
|
560 |
+
"""
|
561 |
+
Extract the best validation metric from results.
|
562 |
+
"""
|
563 |
+
try:
|
564 |
+
journal_data = json.loads(results["journal"])
|
565 |
+
metrics = []
|
566 |
+
for node in journal_data:
|
567 |
+
if node["metric"] is not None:
|
568 |
+
try:
|
569 |
+
# Convert string metric to float
|
570 |
+
metric_value = float(node["metric"])
|
571 |
+
metrics.append(metric_value)
|
572 |
+
except (ValueError, TypeError):
|
573 |
+
continue
|
574 |
+
return max(metrics) if metrics else None
|
575 |
+
except (json.JSONDecodeError, KeyError):
|
576 |
+
return None
|
577 |
+
|
578 |
+
@staticmethod
|
579 |
+
def render_validation_plot(results):
|
580 |
+
"""
|
581 |
+
Render the validation score plot.
|
582 |
+
"""
|
583 |
+
try:
|
584 |
+
journal_data = json.loads(results["journal"])
|
585 |
+
steps = []
|
586 |
+
metrics = []
|
587 |
+
|
588 |
+
for node in journal_data:
|
589 |
+
if node["metric"] is not None and node["metric"].lower() != "none":
|
590 |
+
try:
|
591 |
+
metric_value = float(node["metric"])
|
592 |
+
steps.append(node["step"])
|
593 |
+
metrics.append(metric_value)
|
594 |
+
except (ValueError, TypeError):
|
595 |
+
continue
|
596 |
+
|
597 |
+
if metrics:
|
598 |
+
import plotly.graph_objects as go
|
599 |
+
|
600 |
+
fig = go.Figure()
|
601 |
+
fig.add_trace(
|
602 |
+
go.Scatter(
|
603 |
+
x=steps,
|
604 |
+
y=metrics,
|
605 |
+
mode="lines+markers",
|
606 |
+
name="Validation Score",
|
607 |
+
line=dict(color="#F04370"),
|
608 |
+
marker=dict(color="#F04370"),
|
609 |
+
)
|
610 |
+
)
|
611 |
+
|
612 |
+
fig.update_layout(
|
613 |
+
title="Validation Score Progress",
|
614 |
+
xaxis_title="Step",
|
615 |
+
yaxis_title="Validation Score",
|
616 |
+
template="plotly_white",
|
617 |
+
hovermode="x unified",
|
618 |
+
plot_bgcolor="rgba(0,0,0,0)",
|
619 |
+
paper_bgcolor="rgba(0,0,0,0)",
|
620 |
+
)
|
621 |
+
|
622 |
+
st.plotly_chart(fig, use_container_width=True)
|
623 |
+
else:
|
624 |
+
st.info("No validation metrics available to plot.")
|
625 |
+
|
626 |
+
except (json.JSONDecodeError, KeyError):
|
627 |
+
st.error("Could not parse validation metrics data.")
|
628 |
+
|
629 |
|
630 |
if __name__ == "__main__":
|
631 |
app = WebUI()
|
aide/webui/style.css
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
/* Main colors */
|
2 |
:root {
|
3 |
--background: #F2F0E7;
|
4 |
-
--background-shaded: #
|
5 |
--card: #FFFFFF;
|
6 |
--primary: #0D0F18;
|
7 |
--accent: #F04370;
|
|
|
1 |
/* Main colors */
|
2 |
:root {
|
3 |
--background: #F2F0E7;
|
4 |
+
--background-shaded: #FFFFFF;
|
5 |
--card: #FFFFFF;
|
6 |
--primary: #0D0F18;
|
7 |
--accent: #F04370;
|