from typing import Any

import gradio as gr
import numpy as np
import pandas as pd
from datasets import Dataset
from loguru import logger

from app_configs import CONFIGS, UNSELECTED_PIPELINE_NAME
from components import commons
from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
from components.typed_dicts import PipelineStateDict
from display.formatting import styled_error
from shared.workflows import factory
from shared.workflows.qb_agents import QuizBowlBonusAgent
from shared.workflows.runners import run_and_eval_bonus_dataset, run_and_evaluate_bonus
from submission import submit

from . import populate, validation
from .plotting import create_bonus_confidence_plot, create_bonus_html
from .utils import create_error_message
from .validation import UserInputWorkflowValidator


def process_bonus_results(results: list[dict]) -> pd.DataFrame:
    """Process results from bonus mode and prepare visualization data."""
    return pd.DataFrame(
        [
            {
                "Part": f"Part {r['number']}",
                "Correct?": "✅" if r["correct"] == 1 else "❌",
                "Confidence": r["confidence"],
                "Prediction": r["guess"],
                "Explanation": r["explanation"],
            }
            for r in results
        ]
    )


def initialize_eval_interface(example: dict, part_outputs: list[dict], input_vars: list[str]):
    """Initialize the interface with example text."""
    try:
        html_content = create_bonus_html(example["leadin"], example["parts"])

        # Create confidence plot data
        plot_data = create_bonus_confidence_plot(example["parts"], part_outputs)

        # Store the question parts and model outputs for later evaluation
        state = {"parts": example["parts"], "outputs": part_outputs}

        # Prepare per-part step outputs, dropping the model's input variables
        step_outputs = {}
        for i, output in enumerate(part_outputs):
            key = f"part {i + 1}"
            step_outputs[key] = {k: v for k, v in output["step_outputs"].items() if k not in input_vars}
            if output["logprob"] is not None:
                step_outputs[key]["output_probability"] = float(np.exp(output["logprob"]))
        return html_content, plot_data, state, step_outputs
    except Exception as e:
        error_msg = f"Error initializing interface: {str(e)}"
        logger.exception(error_msg)
        return styled_error(error_msg), pd.DataFrame(), {}, {}
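
# Illustrative shape of one entry in `part_outputs`, inferred from the key
# accesses in the two helpers above; the concrete values are hypothetical:
#
#     {
#         "number": 1,                  # 1-based part number
#         "correct": 1,                 # 1 if the guess was judged correct, else 0
#         "confidence": 0.87,
#         "guess": "Example Answer",
#         "explanation": "...",
#         "logprob": -0.14,             # may be None
#         "step_outputs": {...},        # intermediate values keyed by variable name
#     }
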
class BonusInterface:
    """Gradio interface for the Bonus mode."""

    def __init__(self, app: gr.Blocks, browser_state: dict, dataset: Dataset, model_options: dict, defaults: dict):
        """Initialize the Bonus interface."""
        logger.info(f"Initializing Bonus interface with dataset size: {len(dataset)}")
        self.browser_state = browser_state
        self.ds = dataset
        self.model_options = model_options
        self.app = app
        self.defaults = defaults
        self.output_state = gr.State(value={})
        self.render()

    # ------------------------------------- LOAD PIPELINE STATE FROM BROWSER STATE ------------------------------------

    def load_default_workflow(self):
        workflow = self.defaults["init_workflow"]
        pipeline_state_dict = PipelineState.from_workflow(workflow).model_dump()
        return pipeline_state_dict, {}

    def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
        try:
            state_dict = browser_state["bonus"].get("pipeline_state", {})
            if state_dict:
                pipeline_state = PipelineState.model_validate(state_dict)
                pipeline_state_dict = pipeline_state.model_dump()
                output_state = browser_state["bonus"].get("output_state", {})
            else:
                pipeline_state_dict, output_state = self.load_default_workflow()
        except Exception as e:
            logger.warning(f"Error loading presaved pipeline state: {e}")
            pipeline_state_dict, output_state = self.load_default_workflow()
        return browser_state, not pipeline_change, pipeline_state_dict, output_state
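
    # Browser-state layout this mode reads and writes (illustrative): the nested
    # dicts hold a serialized PipelineState and the last output state.
    #
    #     browser_state = {
    #         "bonus": {
    #             "pipeline_state": {...},  # PipelineState.model_dump()
    #             "output_state": {...},
    #         },
    #     }
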
    # ------------------------------------------ INTERFACE RENDER FUNCTIONS -------------------------------------------

    def _render_pipeline_interface(self, pipeline_state: PipelineState):
        """Render the model pipeline interface."""
        with gr.Row(elem_classes="bonus-header-row form-inline"):
            self.pipeline_selector = commons.get_pipeline_selector([])
            self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
        self.import_error_display = gr.HTML(label="Import Error", elem_id="import-error-display", visible=False)
        logger.info(f"Rendering {self.__class__.__name__} with pipeline state: {pipeline_state}")
        self.pipeline_interface = PipelineInterface(
            self.app,
            pipeline_state.workflow,
            ui_state=pipeline_state.ui_state,
            model_options=list(self.model_options.keys()),
            config=self.defaults,
            validator=UserInputWorkflowValidator("bonus"),
        )

    def _render_qb_interface(self):
        """Render the quizbowl interface."""
        with gr.Row(elem_classes="bonus-header-row form-inline"):
            self.qid_selector = commons.get_qid_selector(len(self.ds))
            self.run_btn = gr.Button("Run on Bonus Question", variant="secondary")
        self.question_display = gr.HTML(label="Question", elem_id="bonus-question-display")
        self.error_display = gr.HTML(label="Error", elem_id="bonus-error-display", visible=False)
        self.results_table = gr.DataFrame(
            label="Model Outputs",
            value=pd.DataFrame(columns=["Part", "Correct?", "Confidence", "Prediction", "Explanation"]),
            visible=False,
        )
        self.model_outputs_display = gr.JSON(label="Step Outputs", value="{}", show_indices=True, visible=False)
        with gr.Row():
            self.eval_btn = gr.Button("Evaluate", variant="primary")
        self.model_name_input, self.description_input, self.submit_btn, self.submit_status = (
            commons.get_model_submission_accordion(self.app)
        )

    def render(self):
        """Create the Gradio interface."""
        self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")
        workflow = factory.create_empty_bonus_workflow()
        pipeline_state = PipelineState.from_workflow(workflow)

        with gr.Row():
            # Model panel
            with gr.Column(scale=1):
                self._render_pipeline_interface(pipeline_state)
            # Question panel
            with gr.Column(scale=1):
                self._render_qb_interface()

        self._setup_event_listeners()

    def validate_workflow(self, state_dict: PipelineStateDict):
        """Validate the workflow."""
        try:
            pipeline_state = PipelineState(**state_dict)
            validation.validate_workflow(
                pipeline_state.workflow,
                required_input_vars=CONFIGS["bonus"]["required_input_vars"],
                required_output_vars=CONFIGS["bonus"]["required_output_vars"],
            )
        except Exception as e:
            raise gr.Error(f"Error validating workflow: {str(e)}") from e

    def get_new_question_html(self, question_id: int):
        """Get the HTML for a new question."""
        if question_id is None:
            logger.error("Question ID is None. Setting to 1")
            question_id = 1
        try:
            question_id = int(question_id) - 1
            if not self.ds or question_id < 0 or question_id >= len(self.ds):
                return "Invalid question ID or dataset not loaded"
            example = self.ds[question_id]
            leadin = example["leadin"]
            parts = example["parts"]
            return create_bonus_html(leadin, parts)
        except Exception as e:
            return f"Error loading question: {str(e)}"

    def get_pipeline_names(self, profile: gr.OAuthProfile | None):
        names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("bonus", profile)
        return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)

    def load_pipeline(
        self, model_name: str, pipeline_change: bool, profile: gr.OAuthProfile | None
    ) -> tuple[str, bool, PipelineStateDict, dict]:
        try:
            workflow = populate.load_workflow("bonus", model_name, profile)
            if workflow is None:
                logger.warning(f"Could not load workflow for {model_name}")
                return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=False)
            pipeline_state_dict = PipelineState.from_workflow(workflow).model_dump()
            return UNSELECTED_PIPELINE_NAME, not pipeline_change, pipeline_state_dict, gr.update(visible=False)
        except Exception as e:
            error_msg = styled_error(f"Error loading pipeline: {str(e)}")
            return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)

    # ------------------------------------- Agent Functions -----------------------------------------------------------

    def single_run(
        self,
        question_id: int,
        state_dict: PipelineStateDict,
    ) -> tuple[Any, Any, Any, Any, Any]:
        """Run the agent on a single bonus question and update the interface.

        Returns:
            tuple: Contains the following components:
                - question_display: HTML display content of the question
                - output_state: Updated state with question parts and outputs
                - results_table: DataFrame with model predictions and scores
                - model_outputs_display: Detailed step outputs from the model
                - error_display: Any error messages (if applicable)
        """
        try:
            pipeline_state = validation.validate_bonus_workflow(state_dict)
            question_id = int(question_id) - 1
            if not self.ds or question_id < 0 or question_id >= len(self.ds):
                raise gr.Error("Invalid question ID or dataset not loaded")
            example = self.ds[question_id]
            agent = QuizBowlBonusAgent(pipeline_state.workflow)
            model_output = run_and_evaluate_bonus(agent, example, return_extras=True)
            part_outputs = model_output["part_outputs"]

            # Process results and prepare visualization data
            html_content, plot_data, output_state, step_outputs = initialize_eval_interface(
                example, part_outputs, pipeline_state.workflow.inputs
            )
            df = process_bonus_results(part_outputs)

            return (
                html_content,
                gr.update(value=output_state),
                gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}", visible=True),
                gr.update(value=step_outputs, label=f"Step Outputs for Question {question_id + 1}", visible=True),
                gr.update(visible=False),
            )
        except Exception as e:
            error_msg = styled_error(create_error_message(e))
            logger.exception(f"Error running bonus: {e}")
            return (
                gr.skip(),
                gr.skip(),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=True, value=error_msg),
            )

    def evaluate(self, state_dict: PipelineStateDict, progress: gr.Progress = gr.Progress()):
        """Evaluate the agent on the full set of bonus questions."""
        try:
            pipeline_state = validation.validate_bonus_workflow(state_dict)
            if not self.ds or not self.ds.num_rows:
                return gr.skip(), gr.skip(), gr.update(visible=True, value=styled_error("No dataset loaded"))
            agent = QuizBowlBonusAgent(pipeline_state.workflow)
            model_outputs = run_and_eval_bonus_dataset(
                agent, self.ds, num_workers=2, return_extras=True, tqdm_provider=progress.tqdm
            )

            # Aggregate per-part and per-question accuracy
            n_parts_correct = 0
            total_parts = 0
            n_questions_correct = 0
            for model_output in model_outputs:
                part_outputs = model_output["part_outputs"]
                parts_correct = sum(output["correct"] for output in part_outputs)
                n_parts_correct += parts_correct
                total_parts += len(part_outputs)
                # A question counts as correct only if every one of its parts is correct
                n_questions_correct += int(parts_correct == len(part_outputs))

            p_accuracy = n_parts_correct / total_parts
            q_accuracy = n_questions_correct / len(self.ds)
            df = pd.DataFrame(
                [
                    {
                        "Question Accuracy": f"{q_accuracy:.2%}",
                        "Part Accuracy": f"{p_accuracy:.2%}",
                        "Questions Evaluated": len(self.ds),
                    }
                ]
            )
            return (
                gr.update(value=df, label="Scores on Sample Set", visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
            )
        except Exception as e:
            error_msg = styled_error(create_error_message(e))
            logger.exception(f"Error evaluating bonus: {e}")
            return gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)
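
    # Worked example for the aggregation above (hypothetical numbers): with two
    # 3-part questions scoring 3/3 and 2/3 parts correct, part accuracy is
    # 5/6 ≈ 83.33% while question accuracy is 1/2 = 50.00%, since a question
    # only counts as correct when all of its parts are correct.
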
    def submit_model(
        self,
        model_name: str,
        description: str,
        state_dict: PipelineStateDict,
        profile: gr.OAuthProfile | None = None,
    ):
        """Submit the model output."""
        pipeline_state = PipelineState(**state_dict)
        return submit.submit_model(model_name, description, pipeline_state.workflow, "bonus", profile)

    @property
    def pipeline_state(self):
        return self.pipeline_interface.pipeline_state

    # ------------------------------------- Event Listeners -----------------------------------------------------------

    def _setup_event_listeners(self):
        # Load the question display on app load and whenever the question ID changes
        gr.on(
            triggers=[self.app.load, self.qid_selector.change],
            fn=self.get_new_question_html,
            inputs=[self.qid_selector],
            outputs=[self.question_display],
        )
        gr.on(
            triggers=[self.app.load],
            fn=self.get_pipeline_names,
            outputs=[self.pipeline_selector],
        )

        pipeline_change = self.pipeline_interface.pipeline_change
        gr.on(
            triggers=[self.app.load],
            fn=self.load_presaved_pipeline_state,
            inputs=[self.browser_state, pipeline_change],
            outputs=[self.browser_state, pipeline_change, self.pipeline_state, self.output_state],
        )

        self.load_btn.click(
            fn=self.load_pipeline,
            inputs=[self.pipeline_selector, pipeline_change],
            outputs=[self.pipeline_selector, pipeline_change, self.pipeline_state, self.import_error_display],
        )

        self.pipeline_interface.add_triggers_for_pipeline_export([self.pipeline_state.change], self.pipeline_state)

        self.run_btn.click(
            self.single_run,
            inputs=[self.qid_selector, self.pipeline_state],
            outputs=[
                self.question_display,
                self.output_state,
                self.results_table,
                self.model_outputs_display,
                self.error_display,
            ],
        )

        self.eval_btn.click(
            fn=self.evaluate,
            inputs=[self.pipeline_state],
            outputs=[self.results_table, self.model_outputs_display, self.error_display],
        )

        self.submit_btn.click(
            fn=self.submit_model,
            inputs=[self.model_name_input, self.description_input, self.pipeline_state],
            outputs=[self.submit_status],
        )
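
# Minimal wiring sketch (illustrative only; `bonus_dataset`, `model_options`,
# and `defaults` are hypothetical placeholders, and `gr.BrowserState` is
# assumed to be the browser-state component used by the host app):
#
#     with gr.Blocks() as demo:
#         browser_state = gr.BrowserState({"bonus": {}})
#         BonusInterface(demo, browser_state, bonus_dataset, model_options, defaults)
#     demo.launch()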