Dixing Xu
commited on
:recycle: Refactor webui to live render results
Browse files- aide/webui/app.py +76 -109
aide/webui/app.py
CHANGED
@@ -100,8 +100,6 @@ class WebUI:
|
|
100 |
input_col, results_col = st.columns([1, 3])
|
101 |
with input_col:
|
102 |
self.render_input_section(results_col)
|
103 |
-
with results_col:
|
104 |
-
self.render_results_section()
|
105 |
|
106 |
def render_sidebar(self):
|
107 |
"""
|
@@ -273,17 +271,46 @@ class WebUI:
|
|
273 |
return None
|
274 |
|
275 |
experiment = self.initialize_experiment(input_dir, goal_text, eval_text)
|
276 |
-
|
|
|
|
|
|
|
|
|
277 |
|
278 |
for step in range(num_steps):
|
279 |
st.session_state.current_step = step + 1
|
280 |
progress = (step + 1) / num_steps
|
281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
experiment.run(steps=1)
|
283 |
|
284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
|
286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
|
288 |
except Exception as e:
|
289 |
st.session_state.is_running = False
|
@@ -355,70 +382,6 @@ class WebUI:
|
|
355 |
experiment = Experiment(data_dir=str(input_dir), goal=goal_text, eval=eval_text)
|
356 |
return experiment
|
357 |
|
358 |
-
@staticmethod
|
359 |
-
def create_results_placeholders(results_col, experiment):
|
360 |
-
"""
|
361 |
-
Create placeholders in the results column for dynamic content.
|
362 |
-
|
363 |
-
Args:
|
364 |
-
results_col (st.delta_generator.DeltaGenerator): The results column.
|
365 |
-
experiment (Experiment): The Experiment object.
|
366 |
-
|
367 |
-
Returns:
|
368 |
-
dict: Dictionary of placeholders.
|
369 |
-
"""
|
370 |
-
with results_col:
|
371 |
-
status_placeholder = st.empty()
|
372 |
-
step_placeholder = st.empty()
|
373 |
-
config_title_placeholder = st.empty()
|
374 |
-
config_placeholder = st.empty()
|
375 |
-
progress_placeholder = st.empty()
|
376 |
-
|
377 |
-
step_placeholder.markdown(
|
378 |
-
f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
|
379 |
-
)
|
380 |
-
config_title_placeholder.markdown("### 📋 Configuration")
|
381 |
-
config_placeholder.code(OmegaConf.to_yaml(experiment.cfg), language="yaml")
|
382 |
-
progress_placeholder.progress(0)
|
383 |
-
|
384 |
-
placeholders = {
|
385 |
-
"status": status_placeholder,
|
386 |
-
"step": step_placeholder,
|
387 |
-
"config_title": config_title_placeholder,
|
388 |
-
"config": config_placeholder,
|
389 |
-
"progress": progress_placeholder,
|
390 |
-
}
|
391 |
-
return placeholders
|
392 |
-
|
393 |
-
@staticmethod
|
394 |
-
def update_results_placeholders(placeholders, progress):
|
395 |
-
"""
|
396 |
-
Update the placeholders with the current progress.
|
397 |
-
|
398 |
-
Args:
|
399 |
-
placeholders (dict): Dictionary of placeholders.
|
400 |
-
progress (float): Current progress value.
|
401 |
-
"""
|
402 |
-
placeholders["step"].markdown(
|
403 |
-
f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
|
404 |
-
)
|
405 |
-
placeholders["progress"].progress(progress)
|
406 |
-
|
407 |
-
@staticmethod
|
408 |
-
def clear_run_state(placeholders):
|
409 |
-
"""
|
410 |
-
Clear the running state and placeholders after the experiment.
|
411 |
-
|
412 |
-
Args:
|
413 |
-
placeholders (dict): Dictionary of placeholders.
|
414 |
-
"""
|
415 |
-
st.session_state.is_running = False
|
416 |
-
placeholders["status"].empty()
|
417 |
-
placeholders["step"].empty()
|
418 |
-
placeholders["config_title"].empty()
|
419 |
-
placeholders["config"].empty()
|
420 |
-
placeholders["progress"].empty()
|
421 |
-
|
422 |
@staticmethod
|
423 |
def collect_results(experiment):
|
424 |
"""
|
@@ -454,41 +417,6 @@ class WebUI:
|
|
454 |
}
|
455 |
return results
|
456 |
|
457 |
-
def render_results_section(self):
|
458 |
-
"""
|
459 |
-
Render the results section with tabs for different outputs.
|
460 |
-
"""
|
461 |
-
st.header("Results")
|
462 |
-
if st.session_state.get("results"):
|
463 |
-
results = st.session_state.results
|
464 |
-
|
465 |
-
tabs = st.tabs(
|
466 |
-
[
|
467 |
-
"Tree Visualization",
|
468 |
-
"Best Solution",
|
469 |
-
"Config",
|
470 |
-
"Journal",
|
471 |
-
"Validation Plot",
|
472 |
-
]
|
473 |
-
)
|
474 |
-
|
475 |
-
with tabs[0]:
|
476 |
-
self.render_tree_visualization(results)
|
477 |
-
with tabs[1]:
|
478 |
-
self.render_best_solution(results)
|
479 |
-
with tabs[2]:
|
480 |
-
self.render_config(results)
|
481 |
-
with tabs[3]:
|
482 |
-
self.render_journal(results)
|
483 |
-
with tabs[4]:
|
484 |
-
# Display best score before the plot
|
485 |
-
best_metric = self.get_best_metric(results)
|
486 |
-
if best_metric is not None:
|
487 |
-
st.metric("Best Validation Score", f"{best_metric:.4f}")
|
488 |
-
self.render_validation_plot(results)
|
489 |
-
else:
|
490 |
-
st.info("No results to display. Please run an experiment.")
|
491 |
-
|
492 |
@staticmethod
|
493 |
def render_tree_visualization(results):
|
494 |
"""
|
@@ -576,9 +504,13 @@ class WebUI:
|
|
576 |
return None
|
577 |
|
578 |
@staticmethod
|
579 |
-
def render_validation_plot(results):
|
580 |
"""
|
581 |
Render the validation score plot.
|
|
|
|
|
|
|
|
|
582 |
"""
|
583 |
try:
|
584 |
journal_data = json.loads(results["journal"])
|
@@ -619,12 +551,47 @@ class WebUI:
|
|
619 |
paper_bgcolor="rgba(0,0,0,0)",
|
620 |
)
|
621 |
|
622 |
-
|
|
|
623 |
else:
|
624 |
-
st.info("No validation metrics available to plot
|
625 |
|
626 |
except (json.JSONDecodeError, KeyError):
|
627 |
-
st.error("Could not parse validation metrics data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
628 |
|
629 |
|
630 |
if __name__ == "__main__":
|
|
|
100 |
input_col, results_col = st.columns([1, 3])
|
101 |
with input_col:
|
102 |
self.render_input_section(results_col)
|
|
|
|
|
103 |
|
104 |
def render_sidebar(self):
|
105 |
"""
|
|
|
271 |
return None
|
272 |
|
273 |
experiment = self.initialize_experiment(input_dir, goal_text, eval_text)
|
274 |
+
|
275 |
+
# Create separate placeholders for progress and config
|
276 |
+
progress_placeholder = results_col.empty()
|
277 |
+
config_placeholder = results_col.empty()
|
278 |
+
results_placeholder = results_col.empty()
|
279 |
|
280 |
for step in range(num_steps):
|
281 |
st.session_state.current_step = step + 1
|
282 |
progress = (step + 1) / num_steps
|
283 |
+
|
284 |
+
# Update progress
|
285 |
+
with progress_placeholder.container():
|
286 |
+
st.markdown(
|
287 |
+
f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
|
288 |
+
)
|
289 |
+
st.progress(progress)
|
290 |
+
|
291 |
+
# Show config only for first step
|
292 |
+
if step == 0:
|
293 |
+
with config_placeholder.container():
|
294 |
+
st.markdown("### 📋 Configuration")
|
295 |
+
st.code(OmegaConf.to_yaml(experiment.cfg), language="yaml")
|
296 |
+
|
297 |
experiment.run(steps=1)
|
298 |
|
299 |
+
# Show results
|
300 |
+
with results_placeholder.container():
|
301 |
+
self.render_live_results(experiment)
|
302 |
+
|
303 |
+
# Clear config after first step
|
304 |
+
if step == 0:
|
305 |
+
config_placeholder.empty()
|
306 |
|
307 |
+
# Clear progress after all steps
|
308 |
+
progress_placeholder.empty()
|
309 |
+
|
310 |
+
# Update session state
|
311 |
+
st.session_state.is_running = False
|
312 |
+
st.session_state.results = self.collect_results(experiment)
|
313 |
+
return st.session_state.results
|
314 |
|
315 |
except Exception as e:
|
316 |
st.session_state.is_running = False
|
|
|
382 |
experiment = Experiment(data_dir=str(input_dir), goal=goal_text, eval=eval_text)
|
383 |
return experiment
|
384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
@staticmethod
|
386 |
def collect_results(experiment):
|
387 |
"""
|
|
|
417 |
}
|
418 |
return results
|
419 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
@staticmethod
|
421 |
def render_tree_visualization(results):
|
422 |
"""
|
|
|
504 |
return None
|
505 |
|
506 |
@staticmethod
|
507 |
+
def render_validation_plot(results, step):
|
508 |
"""
|
509 |
Render the validation score plot.
|
510 |
+
|
511 |
+
Args:
|
512 |
+
results (dict): The results dictionary
|
513 |
+
step (int): Current step number for unique key generation
|
514 |
"""
|
515 |
try:
|
516 |
journal_data = json.loads(results["journal"])
|
|
|
551 |
paper_bgcolor="rgba(0,0,0,0)",
|
552 |
)
|
553 |
|
554 |
+
# Only keep the key for plotly_chart
|
555 |
+
st.plotly_chart(fig, use_container_width=True, key=f"plot_{step}")
|
556 |
else:
|
557 |
+
st.info("No validation metrics available to plot")
|
558 |
|
559 |
except (json.JSONDecodeError, KeyError):
|
560 |
+
st.error("Could not parse validation metrics data")
|
561 |
+
|
562 |
+
def render_live_results(self, experiment):
|
563 |
+
"""
|
564 |
+
Render live results.
|
565 |
+
|
566 |
+
Args:
|
567 |
+
experiment (Experiment): The Experiment object
|
568 |
+
"""
|
569 |
+
results = self.collect_results(experiment)
|
570 |
+
|
571 |
+
# Create tabs for different result views
|
572 |
+
tabs = st.tabs(
|
573 |
+
[
|
574 |
+
"Tree Visualization",
|
575 |
+
"Best Solution",
|
576 |
+
"Config",
|
577 |
+
"Journal",
|
578 |
+
"Validation Plot",
|
579 |
+
]
|
580 |
+
)
|
581 |
+
|
582 |
+
with tabs[0]:
|
583 |
+
self.render_tree_visualization(results)
|
584 |
+
with tabs[1]:
|
585 |
+
self.render_best_solution(results)
|
586 |
+
with tabs[2]:
|
587 |
+
self.render_config(results)
|
588 |
+
with tabs[3]:
|
589 |
+
self.render_journal(results)
|
590 |
+
with tabs[4]:
|
591 |
+
best_metric = self.get_best_metric(results)
|
592 |
+
if best_metric is not None:
|
593 |
+
st.metric("Best Validation Score", f"{best_metric:.4f}")
|
594 |
+
self.render_validation_plot(results, step=st.session_state.current_step)
|
595 |
|
596 |
|
597 |
if __name__ == "__main__":
|