|
from typing import Any, Dict, List |
|
|
|
import pandas as pd |
|
|
|
|
|
def _create_confidence_plot_data(results: List[Dict], top_k_mode: bool = False) -> pd.DataFrame: |
|
"""Create a DataFrame for the confidence plot.""" |
|
if not top_k_mode: |
|
return pd.DataFrame( |
|
{ |
|
"position": [r["position"] for r in results], |
|
"confidence": [r["confidence"] for r in results], |
|
"answer": [r["answer"] for r in results], |
|
} |
|
) |
|
|
|
|
|
return _create_top_k_plot_data(results) |
|
|
|
|
|
def _create_top_k_plot_data(results: List[Dict]) -> pd.DataFrame: |
|
"""Create plot data for top-k mode.""" |
|
|
|
top_answers = set() |
|
for r in results: |
|
for g in r.get("guesses", [])[:3]: |
|
if g.get("answer"): |
|
top_answers.add(g.get("answer")) |
|
|
|
top_answers = list(top_answers)[:5] |
|
|
|
|
|
all_data = [] |
|
for position_idx, result in enumerate(results): |
|
position = result["position"] |
|
for answer in top_answers: |
|
confidence = 0 |
|
for guess in result.get("guesses", []): |
|
if guess.get("answer") == answer: |
|
confidence = guess.get("confidence", 0) |
|
break |
|
all_data.append({"position": position, "confidence": confidence, "answer": answer}) |
|
|
|
return pd.DataFrame(all_data) |
|
|
|
|
|
def _create_top_k_dataframe(results: List[Dict]) -> pd.DataFrame: |
|
"""Create a DataFrame for top-k results.""" |
|
df_rows = [] |
|
for result in results: |
|
position = result["position"] |
|
for i, guess in enumerate(result.get("guesses", [])): |
|
df_rows.append( |
|
{ |
|
"position": position, |
|
"answer": guess.get("answer", ""), |
|
"confidence": guess.get("confidence", 0), |
|
"rank": i + 1, |
|
} |
|
) |
|
return pd.DataFrame(df_rows) |
|
|
|
|
|
def _format_buzz_result(buzzed: bool, results: List[Dict], gold_label: str, top_k_mode: bool) -> tuple[str, str, bool]: |
|
"""Format the result text based on whether the agent buzzed.""" |
|
if not buzzed: |
|
return f"Did not buzz. Correct answer was: {gold_label}", "No buzz", False |
|
|
|
buzz_position = next(i for i, r in enumerate(results) if r.get("buzz", False)) |
|
buzz_result = results[buzz_position] |
|
|
|
if top_k_mode: |
|
|
|
top_answers = [g.get("answer", "").lower() for g in buzz_result.get("guesses", [])] |
|
correct = gold_label.lower() in [a.lower() for a in top_answers] |
|
final_answer = top_answers[0] if top_answers else "No answer" |
|
else: |
|
|
|
final_answer = buzz_result["answer"] |
|
correct = final_answer.lower() == gold_label.lower() |
|
|
|
result_text = f"BUZZED at position {buzz_position + 1} with answer: {final_answer}\n" |
|
result_text += f"Correct answer: {gold_label}\n" |
|
result_text += f"Result: {'CORRECT' if correct else 'INCORRECT'}" |
|
|
|
return result_text, final_answer, correct |
|
|