apsys committed on
Commit
6938bdc
·
1 Parent(s): a7eca29

mode viz fix

Browse files
Files changed (1) hide show
  1. app.py +82 -9
app.py CHANGED
@@ -805,9 +805,17 @@ with demo:
805
  interactive=True,
806
  visible=False
807
  )
808
- model_selector = gr.Dropdown(
809
- choices=update_model_choices(CURRENT_VERSION),
810
- label="Select Models to Compare",
 
 
 
 
 
 
 
 
811
  multiselect=True,
812
  interactive=True
813
  )
@@ -830,18 +838,83 @@ with demo:
830
  plot_output = gr.Plot()
831
 
832
  # Update visualization when any selector changes
833
- for control in [viz_version_selector, model_selector, category_selector, metric_selector]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
834
  control.change(
835
- fn=lambda sm, sc, s_metric, v: update_visualization(sm, CATEGORY_REVERSE_MAP.get(sc, sc), s_metric, v),
836
- inputs=[model_selector, category_selector, metric_selector, viz_version_selector],
837
  outputs=plot_output
838
  )
839
 
840
- # Update model choices when version changes
841
  viz_version_selector.change(
842
- fn=update_model_choices,
843
  inputs=[viz_version_selector],
844
- outputs=[model_selector]
845
  )
846
 
847
  # with gr.TabItem("About", elem_id="guardbench-about-tab", id=2):
 
805
  interactive=True,
806
  visible=False
807
  )
808
+ # New: Mode selector
809
def get_model_mode_choices(version):
    """Build the dropdown labels for every model/mode pair of a leaderboard version.

    Args:
        version: Benchmark version passed through to ``get_leaderboard_df``.

    Returns:
        Sorted list of ``"<model_name> [<mode>]"`` strings, one per distinct
        (model_name, mode) pair; an empty list when the leaderboard has no rows.
    """
    leaderboard = get_leaderboard_df(version=version)
    if leaderboard.empty:
        return []

    # One label per distinct (model_name, mode) combination.
    unique_pairs = leaderboard.drop_duplicates(subset=["model_name", "mode"])
    labels = []
    for _, entry in unique_pairs.iterrows():
        labels.append(f"{entry['model_name']} [{entry['mode']}]")
    labels.sort()
    return labels
815
+
816
# Multi-select dropdown offering the "<model> [<mode>]" labels of CURRENT_VERSION.
model_mode_selector = gr.Dropdown(
    label="Select Model(s) [Mode] to Compare",
    choices=get_model_mode_choices(CURRENT_VERSION),
    interactive=True,
    multiselect=True,
)
 
838
  plot_output = gr.Plot()
839
 
840
  # Update visualization when any selector changes
841
def update_visualization_with_mode(selected_model_modes, selected_category, selected_metric, version):
    """Draw a radar chart comparing the selected model/mode pairs on one metric.

    Args:
        selected_model_modes: Labels of the form ``"<model_name> [<mode>]"`` as
            offered by the model/mode dropdown.
        selected_category: ``"All Results"`` to use the full leaderboard, otherwise
            a category name forwarded to ``get_category_leaderboard_df``.
        selected_metric: Metric name; every column whose name contains it becomes
            one axis of the chart.
        version: Benchmark version forwarded to the leaderboard loaders.

    Returns:
        A ``go.Figure``. It is a bare, empty figure when nothing is selected, no
        metric is given, the leaderboard has no rows, or no column matches
        ``selected_metric``. Selected labels with no matching row are skipped.
    """
    if not selected_model_modes or not selected_metric:
        return go.Figure()

    if selected_category == "All Results":
        df = get_leaderboard_df(version=version)
    else:
        df = get_category_leaderboard_df(selected_category, version=version)
    if df.empty:
        return go.Figure()

    metric_cols = [col for col in df.columns if selected_metric in col]
    if not metric_cols:
        # Nothing to plot; indexing the first value below would otherwise fail.
        return go.Figure()

    # Rebuild every row's dropdown label and match on it directly. Parsing the
    # label back into (name, mode) is fragile: it breaks on labels without
    # " [", on whitespace/brackets inside the name or mode, and on modes that
    # are not stored as strings.
    first_row_for_label = {}
    for position, (name, mode) in enumerate(zip(df["model_name"], df["mode"])):
        # setdefault keeps the first row when a pair occurs more than once.
        first_row_for_label.setdefault(f"{name} [{mode}]", position)

    metric_frame = df[metric_cols]
    # Axis names are the metric columns with the "_<metric>" part removed; the
    # first entry is repeated so the polygon closes.
    categories = [col.replace(f'_{selected_metric}', '') for col in metric_cols]
    categories = categories + [categories[0]]

    fig = go.Figure()
    colors = ['#8FCCCC', '#C2A4B6', '#98B4A6', '#B68F7C']
    for idx, label in enumerate(selected_model_modes):
        position = first_row_for_label.get(label)
        if position is None:
            continue
        values = metric_frame.iloc[position].tolist()
        values = values + [values[0]]
        fig.add_trace(go.Scatterpolar(
            r=values,
            theta=categories,
            name=label,
            # The colour follows the selection order and cycles through the palette.
            line_color=colors[idx % len(colors)],
            fill='toself'
        ))

    fig.update_layout(
        paper_bgcolor='#000000',
        plot_bgcolor='#000000',
        font={'color': '#ffffff'},
        title={
            'text': f'{selected_category} - {selected_metric.upper()} Score Comparison',
            'font': {'color': '#ffffff', 'size': 24}
        },
        polar=dict(
            bgcolor='#000000',
            radialaxis=dict(
                visible=True,
                range=[0, 1],
                gridcolor='#333333',
                linecolor='#333333',
                tickfont={'color': '#ffffff'},
            ),
            angularaxis=dict(
                gridcolor='#333333',
                linecolor='#333333',
                tickfont={'color': '#ffffff'},
            )
        ),
        height=600,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99,
            bgcolor='rgba(0,0,0,0.5)',
            font={'color': '#ffffff'}
        )
    )
    return fig
904
+
905
# Redraw the radar chart whenever any of the selectors changes.
for control in [viz_version_selector, model_mode_selector, category_selector, metric_selector]:
    control.change(
        fn=lambda smm, sc, s_metric, v: update_visualization_with_mode(smm, CATEGORY_REVERSE_MAP.get(sc, sc), s_metric, v),
        inputs=[model_mode_selector, category_selector, metric_selector, viz_version_selector],
        outputs=plot_output
    )

# Refresh the dropdown's choices when the version changes. A handler's return
# value for a Dropdown output is taken as the component's *value*, so the new
# choices have to be sent through gr.update(). The selection is cleared because
# labels picked for the previous version may not exist in the new one.
viz_version_selector.change(
    fn=lambda v: gr.update(choices=get_model_mode_choices(v), value=[]),
    inputs=[viz_version_selector],
    outputs=[model_mode_selector]
)
919
 
920
  # with gr.TabItem("About", elem_id="guardbench-about-tab", id=2):