# Maharshi Gor
# Major update:
# f10a835
# raw
# history blame
# 15.5 kB
import json
from typing import Any
import gradio as gr
import pandas as pd
from datasets import Dataset
from loguru import logger
from app_configs import CONFIGS, UNSELECTED_PIPELINE_NAME
from components import commons
from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
from components.typed_dicts import PipelineStateDict
from display.formatting import styled_error
from submission import submit
from workflows import factory
from workflows.qb_agents import QuizBowlBonusAgent
from . import populate, validation
from .plotting import create_bonus_confidence_plot, create_bonus_html
from .utils import evaluate_prediction
def process_bonus_results(results: list[dict]) -> pd.DataFrame:
    """Build the per-part results table for a bonus question.

    Every entry of ``results`` is read for the keys ``part_number``, ``score``,
    ``confidence``, ``answer`` and ``explanation``. One row is produced per
    entry, in input order.

    Args:
        results: Per-part model outputs.

    Returns:
        A DataFrame with the columns ``Part``, ``Correct?``, ``Confidence``,
        ``Prediction`` and ``Explanation`` (empty when ``results`` is empty).
    """
    table_rows = []
    for part_result in results:
        table_rows.append(
            {
                "Part": f"Part {part_result['part_number']}",
                # A part is marked correct only when its score equals 1.
                "Correct?": "✅" if part_result["score"] == 1 else "❌",
                "Confidence": part_result["confidence"],
                "Prediction": part_result["answer"],
                "Explanation": part_result["explanation"],
            }
        )
    return pd.DataFrame(table_rows)
def initialize_eval_interface(example: dict, model_outputs: list[dict]):
    """Prepare the display artefacts for one bonus question and its outputs.

    Args:
        example: Question record; its ``leadin`` and ``parts`` entries are read.
        model_outputs: Per-part model outputs for the question.

    Returns:
        A 3-tuple ``(html_content, plot_data, state)``: the HTML produced by
        ``create_bonus_html``, the value produced by
        ``create_bonus_confidence_plot``, and a JSON string holding the parts
        and the outputs. If anything raises, the exception is logged and the
        fallback ``(error <div>, empty DataFrame, "{}")`` is returned instead.
    """
    try:
        leadin = example["leadin"]
        parts = example["parts"]
        question_html = create_bonus_html(leadin, parts)
        # Confidence plot data for the parts/outputs pair.
        confidence_plot = create_bonus_confidence_plot(parts, model_outputs)
        # Serialized state kept alongside the rendered question.
        serialized_state = json.dumps({"parts": parts, "outputs": model_outputs})
    except Exception as e:
        logger.exception(f"Error initializing interface: {e.args}")
        return f"<div>Error initializing interface: {str(e)}</div>", pd.DataFrame(), "{}"
    return question_html, confidence_plot, serialized_state
class BonusInterface:
    """Gradio interface for the Bonus mode.

    Building an instance renders the pipeline editor and the question panel
    inside the current Gradio context and wires up all event listeners.
    """

    def __init__(self, app: gr.Blocks, browser_state: dict, dataset: Dataset, model_options: dict, defaults: dict):
        """Initialize the Bonus interface and render it immediately.

        Args:
            app: The enclosing ``gr.Blocks`` application.
            browser_state: Browser-persisted state; it is also passed as an
                event input/output in ``_setup_event_listeners``.
            dataset: Bonus questions; each record is read for ``leadin`` and ``parts``.
            model_options: Mapping whose keys are offered as model choices.
            defaults: Default configuration; ``init_workflow`` is read as a fallback workflow.
        """
        logger.info(f"Initializing Bonus interface with dataset size: {len(dataset)}")
        self.browser_state = browser_state
        self.ds = dataset
        self.model_options = model_options
        self.app = app
        self.defaults = defaults
        self.output_state = gr.State(value="{}")
        self.render()

    # ------------------------------------- LOAD PIPELINE STATE FROM BROWSER STATE -------------------------------------

    def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
        """Restore the pipeline and output state saved in the browser state.

        Falls back to the default ``init_workflow`` and an empty output state
        when the saved state is missing or fails validation.

        Returns:
            ``(browser_state, toggled pipeline_change, pipeline_state_dict, output_state)``.
        """
        logger.debug(f"Loading presaved pipeline state from browser state:\n{json.dumps(browser_state, indent=4)}")
        try:
            state_dict = browser_state["bonus"].get("pipeline_state", {})
            pipeline_state = PipelineState.model_validate(state_dict)
            pipeline_state_dict = pipeline_state.model_dump()
            output_state = browser_state["bonus"].get("output_state", "{}")
        except Exception as e:
            logger.warning(f"Error loading presaved pipeline state: {e}")
            output_state = "{}"
            workflow = self.defaults["init_workflow"]
            pipeline_state_dict = PipelineState.from_workflow(workflow).model_dump()
        # Flipping the flag is how a pipeline change is signalled to listeners.
        return browser_state, not pipeline_change, pipeline_state_dict, output_state

    # ------------------------------------------ INTERFACE RENDER FUNCTIONS -------------------------------------------

    def _render_pipeline_interface(self, pipeline_state: PipelineState):
        """Render the model interface: selector, import button and pipeline editor."""
        with gr.Row(elem_classes="bonus-header-row form-inline"):
            self.pipeline_selector = commons.get_pipeline_selector([])
            self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
        self.import_error_display = gr.HTML(label="Import Error", elem_id="import-error-display", visible=False)
        self.pipeline_interface = PipelineInterface(
            self.app,
            pipeline_state.workflow,
            ui_state=pipeline_state.ui_state,
            model_options=list(self.model_options.keys()),
            config=self.defaults,
        )

    def _render_qb_interface(self):
        """Render the quizbowl interface: question picker, outputs and submission widgets."""
        with gr.Row(elem_classes="bonus-header-row form-inline"):
            self.qid_selector = commons.get_qid_selector(len(self.ds))
            self.run_btn = gr.Button("Run on Bonus Question", variant="secondary")
        self.question_display = gr.HTML(label="Question", elem_id="bonus-question-display")
        self.error_display = gr.HTML(label="Error", elem_id="bonus-error-display", visible=False)
        # Output widgets start hidden; handlers reveal them once they have content.
        self.results_table = gr.DataFrame(
            label="Model Outputs",
            value=pd.DataFrame(columns=["Part", "Correct?", "Confidence", "Prediction", "Explanation"]),
            visible=False,
        )
        self.model_outputs_display = gr.JSON(label="Model Outputs", value="{}", show_indices=True, visible=False)
        with gr.Row():
            self.eval_btn = gr.Button("Evaluate", variant="primary")
        self.model_name_input, self.description_input, self.submit_btn, self.submit_status = (
            commons.get_model_submission_accordion(self.app)
        )

    def render(self):
        """Create the Gradio interface and attach the event listeners."""
        self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")
        # NOTE(review): the bonus interface is seeded with an empty *tossup* workflow;
        # confirm whether a bonus-specific factory should be used here instead.
        workflow = factory.create_empty_tossup_workflow()
        pipeline_state = PipelineState.from_workflow(workflow)

        with gr.Row():
            # Model Panel
            with gr.Column(scale=1):
                self._render_pipeline_interface(pipeline_state)

            with gr.Column(scale=1):
                self._render_qb_interface()

        self._setup_event_listeners()

    def validate_workflow(self, state_dict: PipelineStateDict):
        """Validate the workflow against the bonus input/output requirements.

        Raises:
            gr.Error: If the state cannot be parsed or the workflow is invalid.
        """
        try:
            pipeline_state = PipelineState(**state_dict)
            validation.validate_workflow(
                pipeline_state.workflow,
                required_input_vars=CONFIGS["bonus"]["required_input_vars"],
                required_output_vars=CONFIGS["bonus"]["required_output_vars"],
            )
        except Exception as e:
            raise gr.Error(f"Error validating workflow: {str(e)}") from e

    def get_new_question_html(self, question_id: int):
        """Get the HTML for a new question.

        Args:
            question_id: 1-based question number; ``None`` is treated as 1.

        Returns:
            The question HTML, or a plain error string when the ID is out of
            range, the dataset is empty, or loading fails.
        """
        if question_id is None:
            logger.error("Question ID is None. Setting to 1")
            question_id = 1
        try:
            # Convert the 1-based ID shown in the UI to a 0-based dataset index.
            question_id = int(question_id) - 1
            if not self.ds or question_id < 0 or question_id >= len(self.ds):
                return "Invalid question ID or dataset not loaded"
            example = self.ds[question_id]
            leadin = example["leadin"]
            parts = example["parts"]
            return create_bonus_html(leadin, parts)
        except Exception as e:
            return f"Error loading question: {str(e)}"

    def get_pipeline_names(self, profile: gr.OAuthProfile | None) -> dict[str, Any]:
        """Return a selector update listing the pipelines available to ``profile``.

        The "unselected" placeholder is always the first choice and the selected value.
        """
        names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("bonus", profile)
        return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)

    def load_pipeline(
        self, model_name: str, pipeline_change: bool, profile: gr.OAuthProfile | None
    ) -> tuple[str, bool, PipelineStateDict, dict]:
        """Import a saved pipeline into the editor.

        Returns:
            ``(selector value, pipeline_change flag, pipeline state dict, import-error display update)``.
            The selector is always reset to the "unselected" placeholder. The flag
            and state are left untouched (``gr.skip()``) when nothing was loaded.
        """
        try:
            workflow = populate.load_workflow("bonus", model_name, profile)
            if workflow is None:
                logger.warning(f"Could not load workflow for {model_name}")
                return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=False)
            pipeline_state_dict = PipelineState.from_workflow(workflow).model_dump()
            # Hide the error display on success so a message from an earlier
            # failed import does not linger on screen.
            return UNSELECTED_PIPELINE_NAME, not pipeline_change, pipeline_state_dict, gr.update(visible=False)
        except Exception as e:
            error_msg = styled_error(f"Error loading pipeline: {str(e)}")
            return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)

    # ------------------------------------- Agent Functions -----------------------------------------------------------

    def get_agent_outputs(self, example: dict, pipeline_state: PipelineState):
        """Run the agent on every part of ``example`` and score each answer.

        Returns:
            One output dict per part, in order, each extended with a 1-based
            ``part_number`` and a ``score`` from ``evaluate_prediction``.
        """
        outputs = []
        leadin = example["leadin"]
        agent = QuizBowlBonusAgent(pipeline_state.workflow)

        for i, part in enumerate(example["parts"]):
            # Run model for each part
            part_output = agent.run(leadin, part["part"])

            # Add part number and evaluate score
            part_output["part_number"] = i + 1
            part_output["score"] = evaluate_prediction(part_output["answer"], part["clean_answers"])
            outputs.append(part_output)

        return outputs

    def single_run(
        self,
        question_id: int,
        state_dict: PipelineStateDict,
    ) -> tuple[Any, Any, Any, Any, Any]:
        """Run the agent in bonus mode and updates the interface.

        Args:
            question_id: 1-based question number from the selector.
            state_dict: Serialized pipeline state to validate and run.

        Returns:
            tuple: Contains the following components:
                - question_display: HTML display content of the question
                - output_state: Updated state with question parts and outputs
                - results_table: DataFrame with model predictions and scores
                - model_outputs_display: Detailed step outputs from the model
                - error_display: Any error messages (if applicable)
        """
        try:
            pipeline_state = validation.validate_bonus_workflow(state_dict)
            # Convert before subtracting: int(question_id - 1) truncates toward
            # zero, so a fractional ID in (0, 1) would silently select question 1.
            question_id = int(question_id) - 1
            if not self.ds or question_id < 0 or question_id >= len(self.ds):
                raise gr.Error("Invalid question ID or dataset not loaded")

            example = self.ds[question_id]
            outputs = self.get_agent_outputs(example, pipeline_state)

            # Process results and prepare visualization data
            html_content, plot_data, output_state = initialize_eval_interface(example, outputs)
            df = process_bonus_results(outputs)
            step_outputs = [output["step_outputs"] for output in outputs]

            return (
                html_content,
                gr.update(value=output_state),
                gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}", visible=True),
                gr.update(value=step_outputs, label=f"Step Outputs for Question {question_id + 1}", visible=True),
                gr.update(visible=False),
            )
        except Exception as e:
            import traceback

            error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
            # Keep the question and state as they were; hide stale outputs and show the error.
            return (
                gr.skip(),
                gr.skip(),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=True, value=error_msg),
            )

    def evaluate(self, state_dict: PipelineStateDict, progress: gr.Progress = gr.Progress()):
        """Evaluate the pipeline on every bonus question in the dataset.

        Args:
            state_dict: Serialized pipeline state to validate and run.
            progress: Gradio progress tracker used to wrap the dataset iteration.

        Returns:
            Updates for ``(results_table, model_outputs_display, error_display)``.
            On success the table shows part accuracy, total score and question
            count. On failure the table and JSON display are left untouched and
            the error display is shown.
        """
        try:
            pipeline_state = validation.validate_bonus_workflow(state_dict)
            # Validate inputs
            if not self.ds or not self.ds.num_rows:
                # Report through the error display; a bare string is not a valid
                # value for the DataFrame component bound to the first output.
                return gr.skip(), gr.skip(), gr.update(visible=True, value=styled_error("No dataset loaded"))

            total_correct = 0
            total_parts = 0

            for example in progress.tqdm(self.ds, desc="Evaluating bonus questions"):
                for output in self.get_agent_outputs(example, pipeline_state):
                    total_parts += 1
                    if output["score"] == 1:
                        total_correct += 1

            # Guard against a dataset whose questions contain no parts.
            accuracy = total_correct / total_parts if total_parts else 0.0
            df = pd.DataFrame(
                [
                    {
                        "Part Accuracy": f"{accuracy:.2%}",
                        "Total Score": f"{total_correct}/{total_parts}",
                        "Questions Evaluated": len(self.ds),
                    }
                ]
            )

            return (
                # The table starts hidden, so it must be revealed explicitly here.
                gr.update(value=df, label="Scores on Sample Set", visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
            )
        except Exception as e:
            error_msg = styled_error(f"Error evaluating bonus: {e.args}")
            logger.exception(f"Error evaluating bonus: {e.args}")
            return gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)

    def submit_model(
        self,
        model_name: str,
        description: str,
        state_dict: PipelineStateDict,
        profile: gr.OAuthProfile = None,
    ):
        """Submit the current pipeline's workflow as a bonus model.

        NOTE(review): ``profile`` is hinted as a non-optional ``gr.OAuthProfile``,
        unlike the other handlers in this class; left unchanged in case it is
        deliberate (e.g. to require a logged-in user) — confirm.
        """
        pipeline_state = PipelineState(**state_dict)
        return submit.submit_model(model_name, description, pipeline_state.workflow, "bonus", profile)

    @property
    def pipeline_state(self):
        """The pipeline state object owned by the embedded ``PipelineInterface``."""
        return self.pipeline_interface.pipeline_state

    # ------------------------------------- Event Listeners -----------------------------------------------------------

    def _setup_event_listeners(self):
        """Wire app-load, selector and button events to their handlers."""
        # Show the selected question on app load and whenever the selector changes
        # (question IDs are 1-based).
        gr.on(
            triggers=[self.app.load, self.qid_selector.change],
            fn=self.get_new_question_html,
            inputs=[self.qid_selector],
            outputs=[self.question_display],
        )
        gr.on(
            triggers=[self.app.load],
            fn=self.get_pipeline_names,
            outputs=[self.pipeline_selector],
        )

        pipeline_change = self.pipeline_interface.pipeline_change

        gr.on(
            triggers=[self.app.load],
            fn=self.load_presaved_pipeline_state,
            inputs=[self.browser_state, pipeline_change],
            outputs=[self.browser_state, pipeline_change, self.pipeline_state, self.output_state],
        )

        self.load_btn.click(
            fn=self.load_pipeline,
            inputs=[self.pipeline_selector, pipeline_change],
            outputs=[self.pipeline_selector, pipeline_change, self.pipeline_state, self.import_error_display],
        )
        self.pipeline_interface.add_triggers_for_pipeline_export([self.pipeline_state.change], self.pipeline_state)

        self.run_btn.click(
            self.single_run,
            inputs=[
                self.qid_selector,
                self.pipeline_state,
            ],
            outputs=[
                self.question_display,
                self.output_state,
                self.results_table,
                self.model_outputs_display,
                self.error_display,
            ],
        )

        self.eval_btn.click(
            fn=self.evaluate,
            inputs=[self.pipeline_state],
            outputs=[self.results_table, self.model_outputs_display, self.error_display],
        )

        self.submit_btn.click(
            fn=self.submit_model,
            inputs=[
                self.model_name_input,
                self.description_input,
                self.pipeline_state,
            ],
            outputs=[self.submit_status],
        )