|
""" |
|
Utility classes and functions for the GuardBench Leaderboard display. |
|
""" |
|
|
|
from dataclasses import dataclass, field, fields |
|
from enum import Enum, auto |
|
from typing import List, Optional |
|
|
|
|
|
class Mode(Enum):
    """Inference mode used when running a guard model (``CoT`` or ``Strict``)."""

    CoT = auto()
    Strict = auto()

    def __str__(self):
        """Render the member as its bare name, e.g. ``"CoT"``."""
        return self._name_
|
|
|
|
|
class ModelType(Enum):
    """Access category of a model listed on the leaderboard."""

    Unknown = auto()
    OpenSource = auto()
    ClosedSource = auto()
    API = auto()

    def to_str(self, separator: str = "-") -> str:
        """
        Render the model type as a human-readable label.

        Args:
            separator: Text inserted between the two words of the two-word
                labels (``OpenSource`` and ``ClosedSource``).

        Returns:
            ``"Open<separator>Source"``, ``"Closed<separator>Source"``,
            ``"API"``, or ``"Unknown"``.
        """
        two_word_prefixes = {
            ModelType.OpenSource: "Open",
            ModelType.ClosedSource: "Closed",
        }
        prefix = two_word_prefixes.get(self)
        if prefix is not None:
            return f"{prefix}{separator}Source"
        return "API" if self is ModelType.API else "Unknown"
|
|
|
class GuardModelType(str, Enum):
    """Guard model types for the leaderboard; each value is a lowercase identifier."""

    LLAMA_GUARD = "llama_guard"
    CLASSIFIER = "classifier"
    ATLA_SELENE = "atla_selene"
    OPENAI_MODERATION = "openai_moderation"
    LLM_REGEXP = "llm_regexp"
    LLM_SO = "llm_so"
    WHITECIRCLE_GUARD = "whitecircle_guard"

    def __str__(self):
        """Render the member as its upper-case name, e.g. ``"LLAMA_GUARD"``."""
        return self._name_
|
|
|
|
|
|
|
class Precision(Enum):
    """Numeric precision a model was run with (``NA`` / ``Unknown`` when not applicable)."""

    Unknown = auto()
    float16 = auto()
    bfloat16 = auto()
    float32 = auto()
    int8 = auto()
    int4 = auto()
    NA = auto()

    def __str__(self):
        """Render the member as its bare name, e.g. ``"bfloat16"``."""
        return self._name_
|
|
|
|
|
class WeightType(Enum):
    """Kind of weights a submission provides (full, delta, or adapter)."""

    Original = auto()
    Delta = auto()
    Adapter = auto()

    def __str__(self):
        """Render the member as its bare name, e.g. ``"Delta"``."""
        return self._name_
|
|
|
|
|
@dataclass
class ColumnInfo:
    """
    Metadata describing one leaderboard column.

    Attributes:
        name: Internal column identifier.
        display_name: Human-readable label for the column.
        type: Column kind; ``"text"`` unless set (metric columns use ``"number"``).
        hidden: Whether the column is flagged as hidden.
        never_hidden: Whether the column is flagged as always visible.
        displayed_by_default: Whether the column is part of the default view.
    """

    name: str
    display_name: str
    type: str = field(default="text")
    hidden: bool = field(default=False)
    never_hidden: bool = field(default=False)
    displayed_by_default: bool = field(default=True)


def _column(name, display_name, **options):
    """Return a dataclass field whose default is a freshly built ``ColumnInfo``."""
    # Each call closes over its own arguments, so every field gets its own
    # ColumnInfo instance rather than a shared one.
    return field(
        default_factory=lambda: ColumnInfo(name=name, display_name=display_name, **options)
    )


def _metric(name, display_name, shown=False):
    """Return a numeric column field, displayed by default only when ``shown``."""
    return _column(name, display_name, type="number", displayed_by_default=shown)


@dataclass
class GuardBenchColumn:
    """
    All columns of the GuardBench leaderboard.

    Each attribute holds a ``ColumnInfo`` whose ``name`` equals the attribute
    name; declaration order is the order reported by ``dataclasses.fields``.
    """

    # --- Model metadata ------------------------------------------------------
    model_name: ColumnInfo = _column("model_name", "Model", never_hidden=True)
    mode: ColumnInfo = _column("mode", "Mode")
    model_type: ColumnInfo = _column("model_type", "Access_Type")
    submission_date: ColumnInfo = _column(
        "submission_date", "Submission_Date", displayed_by_default=False
    )
    version: ColumnInfo = _column("version", "Version", displayed_by_default=False)
    guard_model_type: ColumnInfo = _column(
        "guard_model_type", "Type", displayed_by_default=False
    )
    base_model: ColumnInfo = _column("base_model", "Base Model", displayed_by_default=False)
    revision: ColumnInfo = _column("revision", "Revision", displayed_by_default=False)
    precision: ColumnInfo = _column("precision", "Precision", displayed_by_default=False)
    weight_type: ColumnInfo = _column("weight_type", "Weight Type", displayed_by_default=False)

    # --- Default prompts -----------------------------------------------------
    default_prompts_f1_binary: ColumnInfo = _metric(
        "default_prompts_f1_binary", "Default_Prompts_F1_Binary"
    )
    default_prompts_f1: ColumnInfo = _metric("default_prompts_f1", "Default_Prompts_F1")
    default_prompts_recall_binary: ColumnInfo = _metric(
        "default_prompts_recall_binary", "Default_Prompts_Recall"
    )
    default_prompts_precision_binary: ColumnInfo = _metric(
        "default_prompts_precision_binary", "Default_Prompts_Precision"
    )
    default_prompts_error_ratio: ColumnInfo = _metric(
        "default_prompts_error_ratio", "Default_Prompts_Error_Ratio"
    )
    default_prompts_avg_runtime_ms: ColumnInfo = _metric(
        "default_prompts_avg_runtime_ms", "Default_Prompts_Avg_Runtime_ms"
    )

    # --- Jailbreaked prompts -------------------------------------------------
    jailbreaked_prompts_f1_binary: ColumnInfo = _metric(
        "jailbreaked_prompts_f1_binary", "Jailbreaked_Prompts_F1_Binary"
    )
    jailbreaked_prompts_f1: ColumnInfo = _metric(
        "jailbreaked_prompts_f1", "Jailbreaked_Prompts_F1"
    )
    jailbreaked_prompts_recall_binary: ColumnInfo = _metric(
        "jailbreaked_prompts_recall_binary", "Jailbreaked_Prompts_Recall"
    )
    jailbreaked_prompts_precision_binary: ColumnInfo = _metric(
        "jailbreaked_prompts_precision_binary", "Jailbreaked_Prompts_Precision"
    )
    jailbreaked_prompts_error_ratio: ColumnInfo = _metric(
        "jailbreaked_prompts_error_ratio", "Jailbreaked_Prompts_Error_Ratio"
    )
    jailbreaked_prompts_avg_runtime_ms: ColumnInfo = _metric(
        "jailbreaked_prompts_avg_runtime_ms", "Jailbreaked_Prompts_Avg_Runtime_ms"
    )

    # --- Default answers -----------------------------------------------------
    default_answers_f1_binary: ColumnInfo = _metric(
        "default_answers_f1_binary", "Default_Answers_F1_Binary"
    )
    default_answers_f1: ColumnInfo = _metric("default_answers_f1", "Default_Answers_F1")
    default_answers_recall_binary: ColumnInfo = _metric(
        "default_answers_recall_binary", "Default_Answers_Recall"
    )
    default_answers_precision_binary: ColumnInfo = _metric(
        "default_answers_precision_binary", "Default_Answers_Precision"
    )
    default_answers_error_ratio: ColumnInfo = _metric(
        "default_answers_error_ratio", "Default_Answers_Error_Ratio"
    )
    default_answers_avg_runtime_ms: ColumnInfo = _metric(
        "default_answers_avg_runtime_ms", "Default_Answers_Avg_Runtime_ms"
    )

    # --- Jailbreaked answers -------------------------------------------------
    jailbreaked_answers_f1_binary: ColumnInfo = _metric(
        "jailbreaked_answers_f1_binary", "Jailbreaked_Answers_F1_Binary"
    )
    jailbreaked_answers_f1: ColumnInfo = _metric(
        "jailbreaked_answers_f1", "Jailbreaked_Answers_F1"
    )
    jailbreaked_answers_recall_binary: ColumnInfo = _metric(
        "jailbreaked_answers_recall_binary", "Jailbreaked_Answers_Recall"
    )
    jailbreaked_answers_precision_binary: ColumnInfo = _metric(
        "jailbreaked_answers_precision_binary", "Jailbreaked_Answers_Precision"
    )
    jailbreaked_answers_error_ratio: ColumnInfo = _metric(
        "jailbreaked_answers_error_ratio", "Jailbreaked_Answers_Error_Ratio"
    )
    jailbreaked_answers_avg_runtime_ms: ColumnInfo = _metric(
        "jailbreaked_answers_avg_runtime_ms", "Jailbreaked_Answers_Avg_Runtime_ms"
    )

    # --- Aggregate scores (mostly part of the default view) ------------------
    integral_score: ColumnInfo = _metric("integral_score", "Integral_Score", shown=True)
    macro_accuracy: ColumnInfo = _metric("macro_accuracy", "Macro_Accuracy", shown=True)
    macro_recall: ColumnInfo = _metric("macro_recall", "Macro_Recall", shown=True)
    macro_precision: ColumnInfo = _metric("macro_precision", "Macro Precision")
    micro_avg_error_ratio: ColumnInfo = _metric(
        "micro_avg_error_ratio", "Micro_Error", shown=True
    )
    micro_avg_runtime_ms: ColumnInfo = _metric(
        "micro_avg_runtime_ms", "Micro_Avg_time_ms", shown=True
    )
    total_evals_count: ColumnInfo = _metric("total_evals_count", "Total_Count", shown=True)
|
|
|
|
|
|
|
GUARDBENCH_COLUMN = GuardBenchColumn()

# Field names of every column, in declaration order.
COLS = [column_field.name for column_field in fields(GUARDBENCH_COLUMN)]

# Names of the columns flagged ``displayed_by_default``.
DISPLAY_COLS = [
    info.name
    for info in (getattr(GUARDBENCH_COLUMN, attr) for attr in COLS)
    if info.displayed_by_default
]
|
|
|
|
|
def reorder_display_cols(cols: Optional[List[str]] = None) -> List[str]:
    """
    Return display columns with ``mode`` placed right after ``model_name``.

    Args:
        cols: Column names to reorder. Defaults to the module-level
            ``DISPLAY_COLS``.

    Returns:
        A new list in which ``mode`` immediately follows ``model_name`` when
        both are present; otherwise an unchanged copy. The input list is
        never mutated.
    """
    # Work on a copy so neither the caller's list nor DISPLAY_COLS is
    # modified in place as a side effect.
    reordered = list(DISPLAY_COLS if cols is None else cols)
    if 'model_name' in reordered and 'mode' in reordered:
        reordered.remove('mode')
        # Look up model_name only after removing mode, since the removal can
        # shift its index.
        reordered.insert(reordered.index('model_name') + 1, 'mode')
    return reordered
|
# Put ``mode`` directly after ``model_name`` in the default view.
DISPLAY_COLS = reorder_display_cols()

# Names of numeric (metric) columns.
METRIC_COLS = [
    info.name
    for info in (getattr(GUARDBENCH_COLUMN, f.name) for f in fields(GUARDBENCH_COLUMN))
    if info.type == "number"
]

# Names of columns flagged ``hidden``.
HIDDEN_COLS = [
    info.name
    for info in (getattr(GUARDBENCH_COLUMN, f.name) for f in fields(GUARDBENCH_COLUMN))
    if info.hidden
]

# Names of columns flagged ``never_hidden``.
NEVER_HIDDEN_COLS = [
    info.name
    for info in (getattr(GUARDBENCH_COLUMN, f.name) for f in fields(GUARDBENCH_COLUMN))
    if info.never_hidden
]
|
|
|
|
|
# Category labels; the final entry covers safe prompts. Several labels contain
# non-ASCII dash characters, which must be kept byte-exact.
CATEGORIES = [
    "Political Corruption and Legal Evasion",
    "Financial Fraud and Unethical Business",
    "AI Manipulation and Jailbreaking",
    "Child Exploitation and Abuse",
    "Hate Speech, Extremism, and Discrimination",
    "Labor Exploitation and Human Trafficking",
    "Manipulation, Deception, and Misinformation",
    "Environmental and Industrial Harm",
    "Academic Dishonesty and Cheating",
    "Self–Harm and Suicidal Ideation",
    "Animal Cruelty and Exploitation",
    "Criminal, Violent, and Terrorist Activity",
    "Drug– and Substance–Related Activities",
    "Sexual Content and Violence",
    "Weapon, Explosives, and Hazardous Materials",
    "Cybercrime, Hacking, and Digital Exploits",
    "Creative Content Involving Illicit Themes",
    "Safe Prompts",
]

# Test splits: {default, jailbreaked} x {prompts, answers}, prompts first.
TEST_TYPES = [
    f"{variant}_{target}"
    for target in ("prompts", "answers")
    for variant in ("default", "jailbreaked")
]

# Metric suffixes.
METRICS = [
    "f1_binary",
    "recall_binary",
    "precision_binary",
    "error_ratio",
    "avg_runtime_ms",
    "accuracy",
]
|
|
|
def get_all_column_choices():
    """
    Get the optional column choices for the multiselect dropdown.

    Despite the function name, not every column is returned: any column
    whose ``displayed_by_default`` flag is set (as reported by
    ``get_default_visible_columns``) is excluded.

    Returns:
        List of ``(column_name, display_name)`` tuples, in field declaration
        order, for every column that is not displayed by default.
    """
    # A set gives O(1) membership checks in the loop below.
    default_visible = set(get_default_visible_columns())

    column_choices = []
    for f in fields(GUARDBENCH_COLUMN):
        column_info = getattr(GUARDBENCH_COLUMN, f.name)
        if column_info.name not in default_visible:
            column_choices.append((column_info.name, column_info.display_name))

    return column_choices
|
|
|
def get_default_visible_columns():
    """
    List the columns that make up the default view.

    Returns:
        Names of all columns whose ``displayed_by_default`` flag is set, in
        field declaration order (no ``mode``/``model_name`` reordering).
    """
    column_infos = (
        getattr(GUARDBENCH_COLUMN, column_field.name)
        for column_field in fields(GUARDBENCH_COLUMN)
    )
    return [info.name for info in column_infos if info.displayed_by_default]
|
|