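"""Gradio Tossup interface: run a pipeline on single tossup questions, evaluate it over the
sample dataset, and import/export/submit tossup pipelines."""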
import json
from typing import Any

import gradio as gr
import numpy as np
import pandas as pd
from datasets import Dataset
from loguru import logger

from app_configs import CONFIGS, UNSELECTED_PIPELINE_NAME
from components import commons
from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
from components.typed_dicts import TossupInterfaceDefaults, TossupPipelineStateDict
from display.formatting import styled_error
from shared.workflows import factory
from shared.workflows.metrics import evaluate_prediction
from shared.workflows.metrics.qb_metrics import prepare_tossup_results_df
from shared.workflows.qb_agents import QuizBowlTossupAgent, TossupResult
from shared.workflows.runners import run_and_eval_tossup_dataset, run_and_evaluate_tossup
from submission import submit

from . import populate, validation
from .plotting import (
    create_tossup_confidence_pyplot,
    create_tossup_eval_dashboard,
    create_tossup_eval_table,
    create_tossup_html,
)
from .utils import create_error_message
from .validation import UserInputWorkflowValidator


class ScoredTossupResult(TossupResult):
    """Result of a tossup question with evaluation score and buzz position."""

    correct: int  # Correctness score of the answer (1 if the guess matches a clean answer, else 0)
    token_position: int  # Token position in the question where the prediction was made


def add_model_scores(
    run_outputs: list[TossupResult], clean_answers: list[str], run_indices: list[int]
) -> list[ScoredTossupResult]:
    """Attach correctness scores and buzz token positions to the run outputs (mutated in place)."""
    for output in run_outputs:
        output["correct"] = evaluate_prediction(output["guess"], clean_answers)
        output["token_position"] = run_indices[output["run_idx"] - 1]
    return run_outputs


def prepare_buzz_evals(
    run_indices: list[int], model_outputs: list[dict]
) -> list[tuple[int, dict]]:
    """Pair each model output with the token position at which its run buzzed."""
    if not run_indices:
        logger.warning("No run indices provided, returning empty results")
        return []
    eval_points = []
    for o in model_outputs:
        token_position = run_indices[o["run_idx"] - 1]
        eval_points.append((token_position, o))
    return eval_points


def initialize_eval_interface(
    example: dict,
    run_outputs: list[dict],
    input_vars: list[str],
    confidence_threshold: float,
    prob_threshold: float | None = None,
):
    """Initialize the interface with example text."""
    try:
        tokens = example["question"].split()
        run_indices = example["run_indices"]
        answer = example["answer_primary"]
        clean_answers = example["clean_answers"]
        eval_points = [(o["token_position"], o) for o in run_outputs]

        if not tokens:
            error_msg = "No tokens found in the provided text."
            logger.error(error_msg)
            return styled_error(error_msg), pd.DataFrame(), {}, {}

        html_content = create_tossup_html(tokens, answer, clean_answers, run_indices, eval_points)
        plot_data = create_tossup_confidence_pyplot(tokens, run_outputs, confidence_threshold, prob_threshold)

        # Store tokens and buzz evaluation points for later use in the output state
        state = {"tokens": tokens, "values": eval_points}

        # Prepare step outputs for the model, keyed by "<position>:<token>" and
        # excluding the workflow's own input variables
        step_outputs = {}
        for output in run_outputs:
            tok_pos = output["token_position"]
            key = f"{tok_pos}:{tokens[tok_pos - 1]}"
            step_outputs[key] = {k: v for k, v in output["step_outputs"].items() if k not in input_vars}
            if output["logprob"] is not None:
                step_outputs[key]["output_probability"] = float(np.exp(output["logprob"]))
        return html_content, plot_data, state, step_outputs
    except Exception as e:
        error_msg = f"Error initializing interface: {str(e)}"
        logger.exception(error_msg)
        return styled_error(error_msg), pd.DataFrame(), {}, {}


def process_tossup_results(results: list[dict]) -> pd.DataFrame:
    """Process results from tossup mode and prepare visualization data."""
    data = []
    for r in results:
        entry = {
            "Token Position": r["token_position"],
            "Correct?": "✅" if r["correct"] == 1 else "❌",
            "Confidence": r["confidence"],
        }
        if r["logprob"] is not None:
            entry["Probability"] = f"{np.exp(r['logprob']):.3f}"
        entry["Prediction"] = r["guess"]
        data.append(entry)
    return pd.DataFrame(data)


class TossupInterface:
    """Gradio interface for the Tossup mode."""

    def __init__(
        self,
        app: gr.Blocks,
        browser_state: gr.BrowserState,
        dataset: Dataset,
        model_options: dict,
        defaults: TossupInterfaceDefaults,
    ):
        """Initialize the Tossup interface."""
        logger.info(f"Initializing Tossup interface with dataset size: {len(dataset)}")
        self.browser_state = browser_state
        self.ds = dataset
        self.model_options = model_options
        self.app = app
        self.defaults = defaults
        self.output_state = gr.State(value={})
        self.render()

    # ------------------------------------- LOAD PIPELINE STATE FROM BROWSER STATE ------------------------------------
    def load_default_workflow(self):
        workflow = self.defaults["init_workflow"]
        pipeline_state_dict = TossupPipelineState.from_workflow(workflow).model_dump()
        return pipeline_state_dict, {}
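
    # Restores the last-saved tossup pipeline from the browser state, falling back to the default
    # workflow if nothing is saved or validation fails. Flipping `pipeline_change` signals
    # listeners bound to it that the pipeline state was replaced.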
    def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
        try:
            state_dict = browser_state["tossup"].get("pipeline_state", {})
            if state_dict:
                pipeline_state = TossupPipelineState.model_validate(state_dict)
                pipeline_state_dict = pipeline_state.model_dump()
                output_state = browser_state["tossup"].get("output_state", {})
            else:
                pipeline_state_dict, output_state = self.load_default_workflow()
        except Exception as e:
            logger.warning(f"Error loading presaved pipeline state: {e}")
            pipeline_state_dict, output_state = self.load_default_workflow()
        return browser_state, not pipeline_change, pipeline_state_dict, output_state

    # ------------------------------------------ INTERFACE RENDER FUNCTIONS -------------------------------------------
    def _render_pipeline_interface(self, pipeline_state: TossupPipelineState):
        """Render the model interface."""
        with gr.Row(elem_classes="bonus-header-row form-inline"):
            self.pipeline_selector = commons.get_pipeline_selector([])
            self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
        self.import_error_display = gr.HTML(label="Import Error", elem_id="import-error-display", visible=False)
        self.pipeline_interface = TossupPipelineInterface(
            self.app,
            pipeline_state.workflow,
            ui_state=pipeline_state.ui_state,
            model_options=list(self.model_options.keys()),
            config=self.defaults,
            validator=UserInputWorkflowValidator("tossup"),
        )
    def _render_qb_interface(self):
        """Render the quizbowl interface."""
        with gr.Row(elem_classes="bonus-header-row form-inline"):
            self.qid_selector = commons.get_qid_selector(len(self.ds))
            self.early_stop_checkbox = gr.Checkbox(
                value=self.defaults["early_stop"],
                label="Early Stop",
                info="Stop if already buzzed",
                scale=0,
            )
            self.run_btn = gr.Button("Run on Tossup Question", variant="secondary")
        self.question_display = gr.HTML(label="Question", elem_id="tossup-question-display")
        self.error_display = gr.HTML(label="Error", elem_id="tossup-error-display", visible=False)
        with gr.Row():
            self.confidence_plot = gr.Plot(
                label="Buzz Confidence",
                format="webp",
            )
            self.model_outputs_display = gr.JSON(label="Model Outputs", value="{}", show_indices=True, visible=False)
        self.results_table = gr.DataFrame(
            label="Model Outputs",
            value=pd.DataFrame(columns=["Token Position", "Correct?", "Confidence", "Prediction"]),
            visible=False,
        )
        with gr.Row():
            self.eval_btn = gr.Button("Evaluate", variant="primary")
        self.model_name_input, self.description_input, self.submit_btn, self.submit_status = (
            commons.get_model_submission_accordion(self.app)
        )
    def render(self):
        """Create the Gradio interface."""
        workflow = factory.create_empty_tossup_workflow()
        pipeline_state = TossupPipelineState.from_workflow(workflow)
        self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")
        with gr.Row():
            # Model Panel
            with gr.Column(scale=1):
                self._render_pipeline_interface(pipeline_state)
            with gr.Column(scale=1):
                self._render_qb_interface()
        self._setup_event_listeners()

    # ------------------------------------- Component Updates Functions ---------------------------------------------
    def get_new_question_html(self, question_id: int) -> str:
        """Get the HTML for a new question."""
        if question_id is None:
            logger.error("Question ID is None. Setting to 1")
            question_id = 1
        try:
            example = self.ds[question_id - 1]
            question_tokens = example["question"].split()
            return create_tossup_html(
                question_tokens, example["answer_primary"], example["clean_answers"], example["run_indices"]
            )
        except Exception as e:
            logger.exception(f"Error loading question: {e}")
            return styled_error(f"Error loading question: {str(e)}")
    def get_pipeline_names(self, profile: gr.OAuthProfile | None):
        names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("tossup", profile)
        return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)
    def load_pipeline(
        self, model_name: str, pipeline_change: bool, profile: gr.OAuthProfile | None
    ) -> tuple[str, bool, TossupPipelineStateDict, dict]:
        try:
            workflow = populate.load_workflow("tossup", model_name, profile)
            if workflow is None:
                logger.warning(f"Could not load workflow for {model_name}")
                return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=False)
            pipeline_state_dict = TossupPipelineState.from_workflow(workflow).model_dump()
            return UNSELECTED_PIPELINE_NAME, not pipeline_change, pipeline_state_dict, gr.update(visible=True)
        except Exception as e:
            logger.exception(e)
            error_msg = styled_error(f"Error loading pipeline: {str(e)}")
            return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)

    # ------------------------------------- Agent Functions -----------------------------------------------------------
    def single_run(
        self,
        question_id: int,
        state_dict: TossupPipelineStateDict,
        early_stop: bool = True,
    ) -> tuple:
        """Run the agent in tossup mode with a system prompt.

        Returns:
            tuple: A tuple containing:
                - tokens_html (str): HTML representation of the tossup question with buzz indicators
                - output_state (gr.update): Update for the output state component
                - plot_data (gr.update): Update for the confidence plot with label and visibility
                - df (gr.update): Update for the dataframe component showing model outputs
                - step_outputs (gr.update): Update for the step outputs component
                - error_msg (gr.update): Update for the error message component (hidden if no errors)
        """
        try:
            pipeline_state = validation.validate_tossup_workflow(state_dict)
            workflow = pipeline_state.workflow

            # Validate inputs
            question_id = int(question_id - 1)
            if not self.ds or question_id < 0 or question_id >= len(self.ds):
                raise gr.Error("Invalid question ID or dataset not loaded")

            example = self.ds[question_id]
            outputs = run_and_evaluate_tossup(
                QuizBowlTossupAgent(workflow),
                example,
                return_extras=True,
                early_stop=early_stop,
            )
            run_outputs = outputs["run_outputs"]
            # Process results and prepare visualization data
            confidence_threshold = workflow.buzzer.confidence_threshold
            prob_threshold = workflow.buzzer.prob_threshold
            tokens_html, plot_data, output_state, step_outputs = initialize_eval_interface(
                example, run_outputs, workflow.inputs, confidence_threshold, prob_threshold
            )
            df = process_tossup_results(run_outputs)
            return (
                tokens_html,
                gr.update(value=output_state),
                gr.update(value=plot_data, label=f"Buzz Confidence on Question {question_id + 1}", show_label=True),
                gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}", visible=True),
                gr.update(value=step_outputs, label=f"Step Outputs for Question {question_id + 1}", visible=True),
                gr.update(visible=False),
            )
        except Exception as e:
            error_msg = styled_error(create_error_message(e))
            logger.exception(f"Error running tossup: {e}")
            return (
                gr.skip(),
                gr.skip(),
                gr.skip(),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=True, value=error_msg),
            )

    def evaluate(self, state_dict: TossupPipelineStateDict, progress: gr.Progress = gr.Progress()):
        """Evaluate the pipeline on the full set of tossup questions."""
        try:
            # Validate inputs; return four updates matching the listener's output components
            if not self.ds or not self.ds.num_rows:
                return (
                    gr.skip(),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=True, value=styled_error("No dataset loaded")),
                )
            pipeline_state = validation.validate_tossup_workflow(state_dict)
            agent = QuizBowlTossupAgent(pipeline_state.workflow)
            model_outputs = run_and_eval_tossup_dataset(
                agent, self.ds, return_extras=True, tqdm_provider=progress.tqdm, num_workers=2
            )
            eval_df = prepare_tossup_results_df(model_outputs, self.ds["run_indices"])
            plot_data = create_tossup_eval_dashboard(self.ds["run_indices"], eval_df)
            output_df = create_tossup_eval_table(eval_df)
            return (
                gr.update(value=plot_data, label="Buzz Positions on Sample Set", show_label=False),
                gr.update(value=output_df, label="(Mean) Metrics on Sample Set", visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
            )
        except Exception as e:
            error_msg = styled_error(create_error_message(e))
            logger.exception(f"Error evaluating tossups: {e}")
            return (
                gr.skip(),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=True, value=error_msg),
            )

    def submit_model(
        self,
        model_name: str,
        description: str,
        state_dict: TossupPipelineStateDict,
        profile: gr.OAuthProfile | None = None,
    ) -> str:
        """Submit the model output."""
        try:
            pipeline_state = validation.validate_tossup_workflow(state_dict)
            return submit.submit_model(model_name, description, pipeline_state.workflow, "tossup", profile)
        except Exception as e:
            logger.exception(f"Error submitting model: {e}")
            return styled_error(f"Error: {str(e)}")

    # Exposed as a property so event listeners can reference the underlying gr.State component directly.
    @property
    def pipeline_state(self):
        return self.pipeline_interface.pipeline_state

    # ------------------------------------- Event Listeners -----------------------------------------------------------
    def _setup_event_listeners(self):
        # Refresh the question display on app load and whenever the question ID changes.
        gr.on(
            triggers=[self.app.load, self.qid_selector.change],
            fn=self.get_new_question_html,
            inputs=[self.qid_selector],
            outputs=[self.question_display],
        )
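        # Populate the pipeline selector with the user's saved pipelines on app load.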
        gr.on(
            triggers=[self.app.load],
            fn=self.get_pipeline_names,
            outputs=[self.pipeline_selector],
        )
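        # Restore the pre-saved pipeline state from the browser once the app loads.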
        pipeline_change = self.pipeline_interface.pipeline_change
        gr.on(
            triggers=[self.app.load],
            fn=self.load_presaved_pipeline_state,
            inputs=[self.browser_state, pipeline_change],
            outputs=[self.browser_state, pipeline_change, self.pipeline_state, self.output_state],
        )
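        # Import the selected pre-saved pipeline into the editor, surfacing any import errors.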
        self.load_btn.click(
            fn=self.load_pipeline,
            inputs=[self.pipeline_selector, pipeline_change],
            outputs=[self.pipeline_selector, pipeline_change, self.pipeline_state, self.import_error_display],
        )
        self.pipeline_interface.add_triggers_for_pipeline_export([self.pipeline_state.change], self.pipeline_state)
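        # Run the agent on the selected tossup question.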
        self.run_btn.click(
            self.single_run,
            inputs=[
                self.qid_selector,
                self.pipeline_state,
                self.early_stop_checkbox,
            ],
            outputs=[
                self.question_display,
                self.output_state,
                self.confidence_plot,
                self.results_table,
                self.model_outputs_display,
                self.error_display,
            ],
        )
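        # Evaluate the pipeline over the full sample dataset.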
        self.eval_btn.click(
            fn=self.evaluate,
            inputs=[self.pipeline_state],
            outputs=[self.confidence_plot, self.results_table, self.model_outputs_display, self.error_display],
        )
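        # Submit the current pipeline under the given model name and description.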
        self.submit_btn.click(
            fn=self.submit_model,
            inputs=[
                self.model_name_input,
                self.description_input,
                self.pipeline_state,
            ],
            outputs=[self.submit_status],
        )
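

# ----------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; `dataset`, `model_options`, and `defaults` are
# assumed to be built by the app's entrypoint, which is not shown here):
#
#   with gr.Blocks() as demo:
#       browser_state = gr.BrowserState({"tossup": {}})
#       tossup_interface = TossupInterface(demo, browser_state, dataset, model_options, defaults)
#   demo.launch()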