Maharshi Gor
Enhance model provider detection and add repository management script. Added support for multi-step agent.
973519b
import json
import logging
from typing import Any

import gradio as gr
import numpy as np
import pandas as pd
from datasets import Dataset

from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
from submission import submit
from workflows.qb.multi_step_agent import MultiStepTossupAgent
from workflows.qb.simple_agent import SimpleTossupAgent
from workflows.structs import ModelStep, Workflow

from .plotting import (
    create_answer_html,
    create_pyplot,
    create_scatter_pyplot,
    create_tokens_html,
    evaluate_buzz,
    update_plot,
)

# TODO: Error handling on run tossup and evaluate tossup and show correct messages
# TODO: ^^ Same for Bonus


def add_model_scores(model_outputs: list[dict], clean_answers: list[str], run_indices: list[int]) -> list[dict]:
    """Add evaluation scores and token positions to the model outputs."""
    for output, run_idx in zip(model_outputs, run_indices):
        output["score"] = evaluate_buzz(output["answer"], clean_answers)
        output["token_position"] = run_idx + 1
    return model_outputs


def prepare_buzz_evals(
    run_indices: list[int], model_outputs: list[dict]
) -> list[tuple[int, tuple[float, bool, int]]]:
    """Pair each run's token index with its (confidence, buzz, score) evaluation point."""
    if not run_indices:
        logging.warning("No run indices provided, returning empty results")
        return []
    eval_points = []
    for run_idx, output in zip(run_indices, model_outputs):
        eval_point = (output["confidence"], output["buzz"], output["score"])
        eval_points.append((int(run_idx), eval_point))
    return eval_points
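
# Illustrative example of the pairing above (values are made up; this assumes
# "buzz" is a bool and "score" is 0/1 as produced by evaluate_buzz):
#   prepare_buzz_evals(
#       [10, 25],
#       [{"confidence": 0.42, "buzz": False, "score": 0},
#        {"confidence": 0.91, "buzz": True, "score": 1}],
#   )
#   -> [(10, (0.42, False, 0)), (25, (0.91, True, 1))]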


def initialize_eval_interface(example: dict, model_outputs: list[dict]):
    """Initialize the evaluation view (tokens HTML, confidence plot, and JSON state) for an example."""
    tokens = example["question"].split()
    run_indices = example["run_indices"]
    answer = example["answer_primary"]
    try:
        eval_points = prepare_buzz_evals(run_indices, model_outputs)

        if not tokens:
            return "<div>No tokens found in the provided text.</div>", None, "{}"
        highlighted_index = next((int(i) for i, (_, buzz, _) in eval_points if buzz == 1), -1)

        html_content = create_tokens_html(tokens, eval_points, answer)
        plot_data = create_pyplot(tokens, eval_points, highlighted_index)

        # Store tokens and eval points as JSON so the plot can be updated later
        state = json.dumps({"tokens": tokens, "values": eval_points})

        return html_content, plot_data, state
    except Exception as e:
        logging.error(f"Error initializing interface: {e}", exc_info=True)
        return f"<div>Error initializing interface: {str(e)}</div>", None, "{}"
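
# Note: json.dumps serializes the (index, (confidence, buzz, score)) tuples as nested
# lists, so the stored state looks roughly like
#   {"tokens": ["This", "composer", ...], "values": [[10, [0.42, false, 0]], ...]}
# (tokens and numbers here are illustrative). This state is later read back by
# update_plot via the hidden-index textbox listener in TossupInterface.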


def process_tossup_results(results: list[dict], top_k_mode: bool = False) -> pd.DataFrame:
    """Process results from tossup mode into a table of per-run predictions."""
    if top_k_mode:
        raise ValueError("Top-k mode is not supported for tossup mode")
    # Create DataFrame with one row per question run
    return pd.DataFrame(
        [
            {
                "Token Position": r["token_position"],
                "Correct?": "✅" if r["score"] == 1 else "❌",
                "Confidence": r["confidence"],
                "Prediction": r["answer"],
            }
            for r in results
        ]
    )
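
# The resulting table looks like the following (illustrative values):
#   Token Position | Correct? | Confidence | Prediction
#   15             | ❌       | 0.34       | Wolfgang Amadeus Mozart
#   38             | ✅       | 0.87       | Ludwig van Beethoven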


def validate_workflow(workflow: Workflow):
    """
    Validate that a workflow is properly configured for the tossup task.

    Args:
        workflow (Workflow): The workflow to validate

    Raises:
        ValueError: If the workflow is not properly configured
    """
    if not workflow.steps:
        raise ValueError("Workflow must have at least one step")

    # Ensure all steps are properly configured
    for step in workflow.steps.values():
        validate_model_step(step)

    # Check that the workflow has the correct input/output structure
    input_vars = set(workflow.inputs)
    if "question" not in input_vars:
        raise ValueError("Workflow must have 'question' as an input")

    output_vars = set(workflow.outputs)
    if not any("answer" in out_var for out_var in output_vars):
        raise ValueError("Workflow must produce an 'answer' as output")
    if not any("confidence" in out_var for out_var in output_vars):
        raise ValueError("Workflow must produce a 'confidence' score as output")


def validate_model_step(model_step: ModelStep):
    """
    Validate that a model step is properly configured for the tossup task.

    Args:
        model_step (ModelStep): The model step to validate

    Raises:
        ValueError: If the model step is not properly configured
    """
    # Check required fields
    if not model_step.model or not model_step.provider:
        raise ValueError("Model step must have both model and provider specified")

    if model_step.call_type != "llm":
        raise ValueError("Model step must have call_type 'llm'")

    # Validate temperature for LLM steps
    if model_step.temperature is None:
        raise ValueError("Temperature must be specified for LLM model steps")
    if not (0.0 <= model_step.temperature <= 1.0):
        raise ValueError(f"Temperature must be between 0.0 and 1.0, got {model_step.temperature}")

    # Validate input fields
    input_field_names = {field.name for field in model_step.input_fields}
    if "question" not in input_field_names:
        raise ValueError("Model step must have a 'question' input field")

    # Validate output fields
    output_field_names = {field.name for field in model_step.output_fields}
    if "answer" not in output_field_names:
        raise ValueError("Model step must have an 'answer' output field")
    if "confidence" not in output_field_names:
        raise ValueError("Model step must have a 'confidence' output field")

    # Validate that the 'confidence' output field is of type float
    for field in model_step.output_fields:
        if field.name == "confidence" and field.type != "float":
            raise ValueError("The 'confidence' output field must be of type 'float'")
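
# Taken together, the two validators above accept a tossup workflow that, roughly:
#   - lists "question" among workflow.inputs and exposes "answer" and "confidence"
#     variables among workflow.outputs,
#   - consists only of LLM steps (call_type == "llm") with a model, a provider, and
#     a temperature in [0.0, 1.0], and
#   - gives every step a "question" input field plus "answer" and "confidence"
#     output fields, with "confidence" typed as "float".
# This is a summary of the checks above, not an exhaustive schema for the Workflow
# and ModelStep types defined in workflows.structs.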


class TossupInterface:
    """Gradio interface for the Tossup mode."""

    def __init__(self, app: gr.Blocks, dataset: Dataset, model_options: dict, defaults: dict):
        """Initialize the Tossup interface."""
        logging.info(f"Initializing Tossup interface with dataset size: {len(dataset)}")
        self.ds = dataset
        self.model_options = model_options
        self.app = app
        self.defaults = defaults
        self.output_state = gr.State(value="{}")
        self.render()

    def _render_model_interface(self, workflow: Workflow, simple: bool = True):
        """Render the model interface."""
        self.pipeline_interface = PipelineInterface(
            workflow,
            simple=simple,
            model_options=list(self.model_options.keys()),
        )
        with gr.Row():
            self.buzz_t_slider = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                value=self.defaults["buzz_threshold"],
                step=0.01,
                label="Buzz Threshold",
            )
            self.early_stop_checkbox = gr.Checkbox(
                value=self.defaults["early_stop"],
                label="Early Stop",
                info="Stop early if already buzzed",
            )
        self.run_btn = gr.Button("Run Tossup", variant="primary")

    def _render_qb_interface(self):
        """Render the quizbowl interface."""
        with gr.Row():
            self.qid_selector = gr.Number(
                label="Question ID", value=1, precision=0, minimum=1, maximum=len(self.ds), show_label=True, scale=0
            )
            self.answer_display = gr.Textbox(
                label="Primary Answer", elem_id="answer-display", elem_classes="answer-box", interactive=False, scale=1
            )
            self.clean_answer_display = gr.Textbox(
                label="Acceptable Answers",
                elem_id="answer-display-2",
                elem_classes="answer-box",
                interactive=False,
                scale=2,
            )
        # self.answer_display = gr.HTML(label="Answer", elem_id="answer-display")
        self.question_display = gr.HTML(label="Question", elem_id="question-display")
        with gr.Row():
            self.confidence_plot = gr.Plot(
                label="Buzz Confidence",
                format="webp",
            )
            self.results_table = gr.DataFrame(
                label="Model Outputs",
                value=pd.DataFrame(columns=["Token Position", "Correct?", "Confidence", "Prediction"]),
            )
        with gr.Row():
            self.eval_btn = gr.Button("Evaluate")

        with gr.Accordion("Model Submission", elem_classes="model-submission-accordion", open=True):
            with gr.Row():
                self.model_name_input = gr.Textbox(label="Model Name")
                self.description_input = gr.Textbox(label="Description")
            with gr.Row():
                gr.LoginButton()
                self.submit_btn = gr.Button("Submit")
            self.submit_status = gr.HTML(label="Submission Status")

    def render(self):
        """Create the Gradio interface."""
        self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")
        workflow = self.defaults["init_workflow"]

        with gr.Row():
            # Model Panel
            with gr.Column(scale=1):
                self._render_model_interface(workflow, simple=self.defaults["simple_workflow"])
            # Quizbowl Panel
            with gr.Column(scale=1):
                self._render_qb_interface()

        self._setup_event_listeners()

    def get_full_question(self, question_id: int) -> str:
        """Get the full question text for a given question ID."""
        try:
            question_id = int(question_id - 1)
            if not self.ds or question_id < 0 or question_id >= len(self.ds):
                return "Invalid question ID or dataset not loaded"

            question_data = self.ds[question_id]
            # Get the full question text and its primary answer
            full_question = question_data["question"]
            gold_label = question_data["answer_primary"]

            return f"Question: {full_question}\n\nCorrect Answer: {gold_label}"
        except Exception as e:
            return f"Error loading question: {str(e)}"

    def validate_workflow(self, pipeline_state: PipelineState):
        """Validate the workflow."""
        try:
            validate_workflow(pipeline_state.workflow)
        except Exception as e:
            raise gr.Error(f"Error validating workflow: {str(e)}")

    def get_new_question_html(self, question_id: int):
        """Get the question HTML, primary answer, and acceptable answers for a new question."""
        example = self.ds[question_id - 1]
        question = example["question"]
        gold_label = example["answer_primary"]
        marker_indices = example["run_indices"]
        tokens = question.split()
        question_html = create_tokens_html(tokens, [], gold_label, marker_indices)

        # Keep only short acceptable answers for display
        clean_answers = [a for a in example["clean_answers"] if len(a.split()) <= 6]
        clean_answers = ", ".join(clean_answers)
        return question_html, gold_label, clean_answers

    def get_model_outputs(self, example: dict, pipeline_state: PipelineState, buzz_threshold: float, early_stop: bool):
        """Run the pipeline on each question run of the example and score the outputs."""
        question_runs = []
        tokens = example["question"].split()
        for run_idx in example["run_indices"]:
            question_runs.append(" ".join(tokens[: run_idx + 1]))

        workflow = pipeline_state.workflow
        if len(workflow.steps) > 1:
            agent = MultiStepTossupAgent(workflow, buzz_threshold)
        else:
            agent = SimpleTossupAgent(workflow, buzz_threshold)
        outputs = list(agent.run(question_runs, early_stop=early_stop))
        outputs = add_model_scores(outputs, example["clean_answers"], example["run_indices"])
        return outputs
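
    # Each output dict produced by the agents is expected to carry at least
    # "answer", "confidence", and "buzz"; add_model_scores then adds "score" and
    # "token_position", which the results table and the confidence plot rely on.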

    def run_tossup(
        self,
        question_id: int,
        pipeline_state: PipelineState,
        buzz_threshold: float,
        early_stop: bool = True,
    ) -> tuple[str, Any, Any, Any]:
        """Run the agent in tossup mode on the selected question."""
        try:
            # Validate inputs
            question_id = int(question_id - 1)
            if not self.ds or question_id < 0 or question_id >= len(self.ds):
                return "Invalid question ID or dataset not loaded", None, None, None

            example = self.ds[question_id]
            outputs = self.get_model_outputs(example, pipeline_state, buzz_threshold, early_stop)

            # Process results and prepare visualization data
            tokens_html, plot_data, output_state = initialize_eval_interface(example, outputs)
            df = process_tossup_results(outputs)

            return (
                tokens_html,
                gr.update(value=plot_data, label=f"Buzz Confidence on Question {question_id + 1}"),
                gr.update(value=output_state),
                gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}"),
            )
        except Exception as e:
            import traceback

            error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
            return error_msg, None, None, None

    def evaluate_tossups(
        self, pipeline_state: PipelineState, buzz_threshold: float, progress: gr.Progress = gr.Progress()
    ):
        """Evaluate the tossup pipeline over the full sample set."""
        try:
            # Validate inputs
            if not self.ds or not self.ds.num_rows:
                logging.error("No dataset loaded for evaluation")
                return None, None

            buzz_counts = 0
            correct_buzzes = 0
            token_positions = []
            correctness = []
            for example in progress.tqdm(self.ds, desc="Evaluating tossup questions"):
                model_outputs = self.get_model_outputs(example, pipeline_state, buzz_threshold, early_stop=True)
                if model_outputs[-1]["buzz"]:
                    buzz_counts += 1
                    if model_outputs[-1]["score"] == 1:
                        correct_buzzes += 1
                token_positions.append(model_outputs[-1]["token_position"])
                correctness.append(model_outputs[-1]["score"])

            # Guard against division by zero when the model never buzzes
            buzz_accuracy = correct_buzzes / buzz_counts if buzz_counts else 0.0
            df = pd.DataFrame(
                [
                    {
                        "Avg Buzz Position": f"{np.mean(token_positions):.2f}",
                        "Buzz Accuracy": f"{buzz_accuracy:.2%}",
                        "Total Score": f"{correct_buzzes}/{len(self.ds)}",
                    }
                ]
            )
            plot_data = create_scatter_pyplot(token_positions, correctness)
            return (
                gr.update(value=df, label="Scores on Sample Set"),
                gr.update(value=plot_data, label="Buzz Positions on Sample Set"),
            )
        except Exception:
            import traceback

            logging.error(f"Error evaluating tossups: {traceback.format_exc()}")
            return None, None

    def submit_model(
        self, model_name: str, description: str, pipeline_state: PipelineState, profile: gr.OAuthProfile = None
    ):
        """Submit the model workflow for the tossup task."""
        return submit.submit_model(model_name, description, pipeline_state.workflow, "tossup", profile)

    def _setup_event_listeners(self):
        # Refresh the displayed question whenever the app loads or the question ID changes
        gr.on(
            triggers=[self.app.load, self.qid_selector.change],
            fn=self.get_new_question_html,
            inputs=[self.qid_selector],
            outputs=[self.question_display, self.answer_display, self.clean_answer_display],
        )

        # Validate the workflow first, then run the tossup on success
        self.run_btn.click(
            self.pipeline_interface.validate_workflow,
            inputs=[self.pipeline_interface.pipeline_state],
            outputs=[self.pipeline_interface.pipeline_state],
        ).success(
            self.run_tossup,
            inputs=[
                self.qid_selector,
                self.pipeline_interface.pipeline_state,
                self.buzz_t_slider,
                self.early_stop_checkbox,
            ],
            outputs=[
                self.question_display,
                self.confidence_plot,
                self.output_state,
                self.results_table,
            ],
        )

        self.eval_btn.click(
            fn=self.evaluate_tossups,
            inputs=[self.pipeline_interface.pipeline_state, self.buzz_t_slider],
            outputs=[self.results_table, self.confidence_plot],
        )

        self.submit_btn.click(
            fn=self.submit_model,
            inputs=[
                self.model_name_input,
                self.description_input,
                self.pipeline_interface.pipeline_state,
            ],
            outputs=[self.submit_status],
        )

        self.hidden_input.change(
            fn=update_plot,
            inputs=[self.hidden_input, self.output_state],
            outputs=[self.confidence_plot],
        )