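"""Gradio leaderboard app: loads grouped evaluation results, merges them with
model metadata, and serves them as tabbed Leaderboard views on the Space."""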
import gradio as gr
from apscheduler.schedulers.background import BackgroundScheduler
from gradio_leaderboard import Leaderboard, ColumnFilter, SelectColumns

from src.about import (
    CITATION_BUTTON_LABEL,
    CITATION_BUTTON_TEXT,
    INTRODUCTION_TEXT,
    LLM_BENCHMARKS_TEXT,
    TITLE,
)
from src.constants import ProblemTypes
from src.display.css_html_js import custom_css
from src.display.utils import (
    ModelInfoColumn,
    fields
)
from src.envs import API, EVAL_REQUESTS_PATH, EVAL_RESULTS_PATH, REPO_ID
from src.populate import get_model_info_df, get_merged_df
from src.utils import get_grouped_dfs, pivot_existed_df, rename_metrics, format_df


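# Restart this Space via the Hub API (wired to the scheduler job at the bottom,
# which is currently disabled).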
def restart_space():
    API.restart_space(repo_id=REPO_ID)



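# Load grouped evaluation results; only the problem-type and overall views are
# used, then the overall table is cleaned and ranked for display.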
grouped_dfs = get_grouped_dfs()
domain_df, overall_df = grouped_dfs[ProblemTypes.col_name], grouped_dfs['overall']
overall_df = rename_metrics(overall_df)
overall_df = format_df(overall_df)
overall_df = overall_df.sort_values(by=['rank'])
domain_df = pivot_existed_df(domain_df, tab_name=ProblemTypes.col_name)
# The frequency, term-length, and variate-type groupings are available in
# grouped_dfs but are currently disabled; pivot them here (as with domain_df
# above) to re-enable their tabs below.
model_info_df = get_model_info_df(EVAL_RESULTS_PATH, EVAL_REQUESTS_PATH)


def init_leaderboard(ori_dataframe, model_info_df, sort_val: str | None = None):
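    """Merge a results dataframe with model metadata and return a Leaderboard.

    If sort_val is given and present in the merged columns, the table is
    sorted by that column before display.
    """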
    if ori_dataframe is None or ori_dataframe.empty:
        raise ValueError("Leaderboard DataFrame is empty or None.")
    model_info_col_list = [
        c.name for c in fields(ModelInfoColumn)
        if c.displayed_by_default
        and c.name not in ['#Params (B)', 'available_on_hub', 'hub', 'Model sha', 'Hub License']
    ]
    col2type_dict = {c.name: c.type for c in fields(ModelInfoColumn)}
    default_selection_list = list(ori_dataframe.columns) + model_info_col_list
    merged_df = get_merged_df(ori_dataframe, model_info_df)
    new_cols = ['T'] + [col for col in merged_df.columns if col != 'T']
    merged_df = merged_df[new_cols]
    if sort_val:
        if sort_val in merged_df.columns:
            merged_df = merged_df.sort_values(by=[sort_val])
        else:
            print(f'Warning: cannot sort by {sort_val}')
    # Map each column to its display type; unknown (metric) columns default to 'number'.
    datatype_list = [col2type_dict.get(col, 'number') for col in merged_df.columns]
    return Leaderboard(
        value=merged_df,
        datatype=datatype_list,
        select_columns=SelectColumns(
            default_selection=default_selection_list,
            cant_deselect=[c.name for c in fields(ModelInfoColumn) if c.never_hidden],
            label="Select Columns to Display:",
        ),
        hide_columns=[c.name for c in fields(ModelInfoColumn) if c.hidden],
        search_columns=['model'],
        filter_columns=[
            ColumnFilter(ModelInfoColumn.model_type.name, type="checkboxgroup", label="Model types"),
        ],
        column_widths=[40, 150] + [180 for _ in range(len(merged_df.columns)-2)],
        interactive=False,
    )


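# Assemble the UI: title, introduction, leaderboard tabs, about page, and citation box.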
demo = gr.Blocks(css=custom_css)
with demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem('πŸ… Overall', elem_id="llm-benchmark-tab-table", id=5):
            leaderboard = init_leaderboard(overall_df, model_info_df, sort_val='Rank')
        with gr.TabItem("πŸ… By Domain", elem_id="llm-benchmark-tab-table", id=0):
            leaderboard = init_leaderboard(domain_df, model_info_df)
            print(f"FINAL Domain LEADERBOARD 1 {domain_df}")

        # Tabs for the frequency, term-length, and variate-type views are
        # disabled; restore them together with their dataframes above when needed.
        with gr.TabItem("πŸ“ About", elem_id="llm-benchmark-tab-table", id=4):
            gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")

    with gr.Row():
        with gr.Accordion("πŸ“™ Citation", open=False):
            citation_button = gr.Textbox(
                value=CITATION_BUTTON_TEXT,
                label=CITATION_BUTTON_LABEL,
                lines=20,
                elem_id="citation-button",
                show_copy_button=True,
            )

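# Background scheduler for periodic maintenance; the Space-restart job is
# currently disabled.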
scheduler = BackgroundScheduler()
# scheduler.add_job(restart_space, "interval", seconds=1800)
scheduler.start()
demo.queue(default_concurrency_limit=40).launch()