mode viz fix
Browse files
app.py
CHANGED
@@ -805,9 +805,17 @@ with demo:
|
|
805 |
interactive=True,
|
806 |
visible=False
|
807 |
)
|
808 |
-
|
809 |
-
|
810 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
811 |
multiselect=True,
|
812 |
interactive=True
|
813 |
)
|
@@ -830,18 +838,83 @@ with demo:
|
|
830 |
plot_output = gr.Plot()
|
831 |
|
832 |
# Update visualization when any selector changes
|
833 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
834 |
control.change(
|
835 |
-
fn=lambda
|
836 |
-
inputs=[
|
837 |
outputs=plot_output
|
838 |
)
|
839 |
|
840 |
-
# Update
|
841 |
viz_version_selector.change(
|
842 |
-
fn=
|
843 |
inputs=[viz_version_selector],
|
844 |
-
outputs=[
|
845 |
)
|
846 |
|
847 |
# with gr.TabItem("About", elem_id="guardbench-about-tab", id=2):
|
|
|
805 |
interactive=True,
|
806 |
visible=False
|
807 |
)
|
808 |
+
# New: Mode selector
|
809 |
+
def get_model_mode_choices(version):
    """Return the sorted dropdown labels for every model/mode pair of a version.

    Args:
        version: Benchmark version forwarded to ``get_leaderboard_df``.

    Returns:
        A sorted list of unique strings formatted as ``"<model_name> [<mode>]"``;
        an empty list when the leaderboard for ``version`` has no rows.
    """
    df = get_leaderboard_df(version=version)
    if df.empty:
        return []
    # Build labels column-wise (no per-row Series from iterrows) and de-duplicate
    # on the formatted label so the dropdown never receives duplicate choices.
    labels = {
        f"{model_name} [{mode}]"
        for model_name, mode in zip(df["model_name"], df["mode"])
    }
    return sorted(labels)
|
815 |
+
|
816 |
+
# Multi-select dropdown listing "<model_name> [<mode>]" labels for the
# current benchmark version.
model_mode_selector = gr.Dropdown(
    label="Select Model(s) [Mode] to Compare",
    choices=get_model_mode_choices(CURRENT_VERSION),
    interactive=True,
    multiselect=True,
)
|
|
|
838 |
plot_output = gr.Plot()
|
839 |
|
840 |
# Update visualization when any selector changes
|
841 |
+
def update_visualization_with_mode(selected_model_modes, selected_category, selected_metric, version):
    """Build a radar chart comparing the selected "model [mode]" entries.

    Args:
        selected_model_modes: Labels formatted as ``"<model_name> [<mode>]"``.
        selected_category: ``"All Results"`` to read the overall leaderboard;
            any other value is passed to ``get_category_leaderboard_df``.
        selected_metric: Text used to pick metric columns (every column whose
            name contains it is plotted).
        version: Benchmark version forwarded to the dataframe loaders.

    Returns:
        A plotly ``Figure``. It is a bare, unstyled figure when nothing is
        selected, when the dataframe is empty, or when no label can be parsed.
        When no column matches the metric the styled figure has no traces.
    """
    if not selected_model_modes:
        return go.Figure()

    if selected_category == "All Results":
        df = get_leaderboard_df(version=version)
    else:
        df = get_category_leaderboard_df(selected_category, version=version)
    if df.empty:
        return go.Figure()

    # Parse "<model_name> [<mode>]" labels back into (model_name, mode) pairs.
    # Labels without the " [" separator are skipped instead of raising on unpack.
    selected_pairs = []
    for label in selected_model_modes:
        name, separator, mode = label.rpartition(" [")
        if not separator:
            continue
        selected_pairs.append((name.strip(), mode.strip("] ")))
    if not selected_pairs:
        return go.Figure()

    # Modes are compared as strings everywhere, because the labels were built
    # from the string form of the mode column.
    wanted = set(selected_pairs)
    mask = [pair in wanted for pair in zip(df['model_name'], df['mode'].astype(str))]
    filtered_df = df[mask]
    filtered_modes = filtered_df['mode'].astype(str)

    metric_cols = [col for col in filtered_df.columns if selected_metric in col]
    fig = go.Figure()
    colors = ['#8FCCCC', '#C2A4B6', '#98B4A6', '#B68F7C']

    # Without any matching metric column there is nothing to plot; skipping the
    # traces avoids indexing into empty value/category lists.
    if metric_cols:
        categories = [col.replace(f'_{selected_metric}', '') for col in metric_cols]
        # Repeat the first axis so the polygon is closed.
        categories.append(categories[0])
        for idx, (model_name, mode) in enumerate(selected_pairs):
            model_data = filtered_df[(filtered_df['model_name'] == model_name) & (filtered_modes == mode)]
            if model_data.empty:
                continue
            values = model_data[metric_cols].values[0].tolist()
            values.append(values[0])
            fig.add_trace(go.Scatterpolar(
                r=values,
                theta=categories,
                name=f"{model_name} [{mode}]",
                line_color=colors[idx % len(colors)],
                fill='toself'
            ))

    fig.update_layout(
        paper_bgcolor='#000000',
        plot_bgcolor='#000000',
        font={'color': '#ffffff'},
        title={
            'text': f'{selected_category} - {selected_metric.upper()} Score Comparison',
            'font': {'color': '#ffffff', 'size': 24}
        },
        polar=dict(
            bgcolor='#000000',
            radialaxis=dict(
                visible=True,
                range=[0, 1],
                gridcolor='#333333',
                linecolor='#333333',
                tickfont={'color': '#ffffff'},
            ),
            angularaxis=dict(
                gridcolor='#333333',
                linecolor='#333333',
                tickfont={'color': '#ffffff'},
            )
        ),
        height=600,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99,
            bgcolor='rgba(0,0,0,0.5)',
            font={'color': '#ffffff'}
        )
    )
    return fig
|
904 |
+
|
905 |
+
# Redraw the plot whenever any of the four selectors changes.
def _redraw_plot(model_modes, category_label, metric, version):
    """Map the category label through CATEGORY_REVERSE_MAP and rebuild the figure."""
    category = CATEGORY_REVERSE_MAP.get(category_label, category_label)
    return update_visualization_with_mode(model_modes, category, metric, version)

for control in (viz_version_selector, model_mode_selector, category_selector, metric_selector):
    control.change(
        fn=_redraw_plot,
        inputs=[model_mode_selector, category_selector, metric_selector, viz_version_selector],
        outputs=plot_output,
    )
|
912 |
|
913 |
+
# Update model_mode_selector choices when version changes.
# A plain list returned to a Dropdown output is interpreted as its *value*,
# so the handler returns an update object that replaces the choices and
# clears the selection (previous selections may not exist in the new version).
viz_version_selector.change(
    fn=lambda version: gr.update(choices=get_model_mode_choices(version), value=[]),
    inputs=[viz_version_selector],
    outputs=[model_mode_selector]
)
|
919 |
|
920 |
# with gr.TabItem("About", elem_id="guardbench-about-tab", id=2):
|