File size: 9,600 Bytes
aed7669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52ff713
aed7669
 
 
52ff713
aed7669
 
 
 
 
 
 
 
 
 
 
 
920bbd1
 
bd5bff8
 
52ff713
bd5bff8
 
 
 
 
 
 
 
 
 
52ff713
d2fa748
920bbd1
 
d2fa748
bd5bff8
52ff713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aed7669
 
 
 
52ff713
aed7669
 
 
52ff713
d53afd7
 
 
 
52ff713
9449bbc
 
8ed5109
9449bbc
 
52ff713
 
 
9f6c7b8
d2fa748
 
8ed5109
d2fa748
 
52ff713
9f6c7b8
 
 
52ff713
08fd631
 
52ff713
08fd631
 
52ff713
c043acf
d2fa748
ccd246f
d53afd7
9449bbc
d2fa748
 
ccd246f
d53afd7
9449bbc
d2fa748
9449bbc
 
4bb5afa
52ff713
d53afd7
9f6c7b8
08fd631
9f6c7b8
52ff713
4bb5afa
 
52ff713
4bb5afa
52ff713
4bb5afa
9f6c7b8
52ff713
4bb5afa
 
52ff713
 
4bb5afa
261cee6
b95ef88
 
261cee6
920bbd1
261cee6
b95ef88
261cee6
b95ef88
261cee6
b95ef88
52ff713
261cee6
b95ef88
52ff713
261cee6
4bb5afa
 
261cee6
1ad7c66
261cee6
b95ef88
261cee6
1ad7c66
52ff713
 
 
 
 
 
 
 
 
 
39e7b94
613e785
 
52ff713
 
 
 
 
 
 
39e7b94
613e785
 
52ff713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39e7b94
613e785
 
52ff713
 
 
 
 
 
 
 
261cee6
b95ef88
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import os
import json
import sqlite3
import glob
import pandas as pd
import gradio as gr
from datetime import datetime
from typing import Dict, List

# Directory (relative to the current working directory) that holds the
# per-agent SQLite result files; it is scanned for "*.db" files.
db_dir = "results/"

def find_or_download_db():
    """Return the paths of all SQLite result files found in ``db_dir``.

    The directory is created when it does not exist yet. Despite the
    function's name, nothing is downloaded here: the result files must
    already be present on disk.

    Returns:
        list[str]: Paths of every ``*.db`` file directly inside ``db_dir``.

    Raises:
        FileNotFoundError: If the random-agent baseline ``random_None.db``
            is not among the files found.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(db_dir, exist_ok=True)
    db_files = glob.glob(os.path.join(db_dir, "*.db"))

    # Compare base names instead of a hard-coded "results/..." string so the
    # check does not depend on the value of db_dir or on the path separator
    # glob happens to emit on the current platform.
    has_random_db = any(os.path.basename(path) == "random_None.db" for path in db_files)
    if not has_random_db:
        raise FileNotFoundError("Please upload results for the random agent in a file named 'random_None.db'.")

    return db_files

def extract_agent_info(filename: str):
    """Split a results filename into ``(agent_type, model_name)``.

    The directory part and every ``.db`` occurrence are removed, then the
    remainder is cut at the first underscore, e.g.
    ``results/llm_gpt_4.db`` -> ``("llm", "gpt_4")``. When the name contains
    no underscore, the model name is reported as ``"Unknown"``.
    """
    stem = os.path.basename(filename).replace(".db", "")
    kind, separator, model = stem.partition("_")
    if not separator:
        # No underscore at all: only the agent type is known.
        model = "Unknown"
    return kind, model

def get_available_games(include_aggregated=True) -> List[str]:
    """Collect the distinct game names recorded across all result databases.

    Each database returned by ``find_or_download_db`` is queried for the
    distinct ``game_name`` values of its ``moves`` table.

    Args:
        include_aggregated: When true, the pseudo-entry
            ``"Aggregated Performance"`` is placed first in the result.

    Returns:
        List[str]: Sorted game names, or ``["No Games Found"]`` when no
        database yielded a name, optionally preceded by
        ``"Aggregated Performance"``.
    """
    game_names = set()

    for db_file in find_or_download_db():
        conn = sqlite3.connect(db_file)
        try:
            rows = conn.execute("SELECT DISTINCT game_name FROM moves").fetchall()
        except sqlite3.Error:
            # Best effort: a database without a usable `moves` table simply
            # contributes no games. Only SQLite errors are tolerated, so
            # unrelated bugs are no longer silently swallowed.
            continue
        finally:
            conn.close()
        # NULL game names cannot be sorted together with strings; drop them.
        game_names.update(name for (name,) in rows if name is not None)

    game_list = sorted(game_names) if game_names else ["No Games Found"]
    if include_aggregated:
        game_list.insert(0, "Aggregated Performance")  # Ensure 'Aggregated Performance' is always first
    return game_list

def extract_illegal_moves_summary() -> pd.DataFrame:
    """Count the illegal moves recorded for each non-random agent.

    For every database returned by ``find_or_download_db`` whose agent type
    is not ``"random"``, the rows of its ``illegal_moves`` table are counted.
    A database where that query fails (e.g. the table is missing) is
    reported with a count of 0.

    Returns:
        pd.DataFrame: One row per agent with columns
        ``[agent_name, illegal_moves]``. The columns are present even when
        there are no rows, so consumers that look them up by name keep
        working.
    """
    summary = []
    for db_file in find_or_download_db():
        agent_type, model_name = extract_agent_info(db_file)
        if agent_type == "random":
            continue  # Skip the random agent from this analysis

        conn = sqlite3.connect(db_file)
        try:
            # COUNT(*) always yields exactly one row.
            count = int(conn.execute("SELECT COUNT(*) FROM illegal_moves").fetchone()[0])
        except sqlite3.Error:
            count = 0  # Table missing or unreadable: treat as "no illegal moves"
        finally:
            # Close the connection on every path, including unexpected errors.
            conn.close()
        summary.append({"agent_name": model_name, "illegal_moves": count})

    return pd.DataFrame(summary, columns=["agent_name", "illegal_moves"])

def extract_leaderboard_stats(game_name: str) -> pd.DataFrame:
    """Build the leaderboard table for one game or for all games combined.

    Every non-random agent database returned by ``find_or_download_db``
    contributes one row.

    Args:
        game_name: A value of the ``game_name`` column, or the sentinel
            ``"Aggregated Performance"`` to aggregate over all games.

    Returns:
        pd.DataFrame: One row per agent with the columns ``agent_name``,
        ``agent_type``, ``games_played``, ``total_rewards``,
        ``avg_generation_time (sec)`` and ``win vs_random (%)``. When no
        agent database is available, an empty frame carrying the display
        headers is returned instead.
    """
    aggregated = game_name == "Aggregated Performance"
    all_stats = []

    for db_file in find_or_download_db():
        agent_type, model_name = extract_agent_info(db_file)

        # Skip random agent rows (decided before a connection is opened).
        if agent_type == "random":
            continue

        conn = sqlite3.connect(db_file)
        try:
            if aggregated:
                query = "SELECT COUNT(DISTINCT episode) AS games_played, " \
                        "SUM(reward) AS total_rewards " \
                        "FROM game_results"
                df = pd.read_sql_query(query, conn)

                # Use avg_generation_time from a specific game (Kuhn Poker)
                game_query = "SELECT AVG(generation_time) FROM moves WHERE game_name = 'kuhn_poker'"
                avg_gen_time = conn.execute(game_query).fetchone()[0] or 0

                game_filter, filter_params = "", ()
            else:
                query = "SELECT COUNT(DISTINCT episode) AS games_played, " \
                        "SUM(reward) AS total_rewards " \
                        "FROM game_results WHERE game_name = ?"
                df = pd.read_sql_query(query, conn, params=(game_name,))

                # Fetch average generation time from moves table
                gen_time_query = "SELECT AVG(generation_time) FROM moves WHERE game_name = ?"
                avg_gen_time = conn.execute(gen_time_query, (game_name,)).fetchone()[0] or 0

                # Restrict the win-rate queries to the selected game as well,
                # so every statistic in the row refers to the same game.
                game_filter, filter_params = " AND gr.game_name = ?", (game_name,)

            # Total rewards are halved, as in the original implementation.
            # NOTE(review): presumably each result is stored twice - confirm.
            df["total_rewards"] = df["total_rewards"].fillna(0).astype(float) / 2

            # Ensure avg_gen_time has decimals
            avg_gen_time = round(avg_gen_time, 3)

            # Win rate against the random bot: episodes whose moves were
            # played against 'random_None', and those among them with a
            # positive reward.
            total_vs_random_query = (
                "SELECT COUNT(DISTINCT gr.episode) FROM game_results gr "
                "JOIN moves m ON gr.game_name = m.game_name AND gr.episode = m.episode "
                "WHERE m.opponent = 'random_None'" + game_filter
            )
            wins_vs_random_query = total_vs_random_query + " AND gr.reward > 0"
            wins_vs_random = conn.execute(wins_vs_random_query, filter_params).fetchone()[0] or 0
            total_vs_random = conn.execute(total_vs_random_query, filter_params).fetchone()[0] or 0
        finally:
            # Release the connection even when one of the queries raises.
            conn.close()

        vs_random_rate = (wins_vs_random / total_vs_random * 100) if total_vs_random > 0 else 0

        df.insert(0, "agent_name", model_name)  # Ensure agent_name is the first column
        df.insert(1, "agent_type", agent_type)  # Ensure agent_type is second column
        df["avg_generation_time (sec)"] = avg_gen_time
        df["win vs_random (%)"] = round(vs_random_rate, 2)

        all_stats.append(df)

    leaderboard_df = pd.concat(all_stats, ignore_index=True) if all_stats else pd.DataFrame()

    if leaderboard_df.empty:
        # NOTE(review): these display headers differ from the column names of
        # a populated frame ("games_played", "total_rewards", no "win-rate");
        # kept unchanged here - confirm which naming is intended.
        leaderboard_df = pd.DataFrame(columns=["agent_name", "agent_type", "# games", "total rewards", "avg_generation_time (sec)", "win-rate", "win vs_random (%)"])

    return leaderboard_df


##########################################################
# Gradio UI: four tabs (game arena, leaderboard, metrics dashboard and
# illegal-move analysis) built on the extraction helpers defined in this file.
with gr.Blocks() as interface:
    # Tab for playing games against LLMs
    with gr.Tab("Game Arena"):
        gr.Markdown("# Play Against LLMs\nChoose a game and an opponent to play!")
        # Dropdown to select a game, excluding 'Aggregated Performance'
        game_dropdown = gr.Dropdown(get_available_games(include_aggregated=False), label="Select a Game")
        # Dropdown to choose an opponent (Random Bot or LLM)
        opponent_dropdown = gr.Dropdown(["Random Bot", "LLM"], label="Choose Opponent")
        # Button to start the game
        play_button = gr.Button("Start Game")
        # Textbox to display the game log
        game_output = gr.Textbox(label="Game Log")

        # Event to start the game when the button is clicked; the handler only
        # echoes the selection into the log textbox.
        play_button.click(lambda game, opponent: f"Game {game} started against {opponent}", inputs=[game_dropdown, opponent_dropdown], outputs=[game_output])

    # Tab for leaderboard and performance tracking
    with gr.Tab("Leaderboard"):
        gr.Markdown("# LLM Model Leaderboard\nTrack performance across different games!")
        # Dropdown to select a game, including 'Aggregated Performance'
        leaderboard_game_dropdown = gr.Dropdown(choices=get_available_games(), label="Select Game", value="Aggregated Performance")
        # Table to display leaderboard statistics
        # NOTE(review): `value` is a static DataFrame here; `every` presumably
        # only triggers refreshes when `value` is a callable - confirm intent.
        leaderboard_table = gr.Dataframe(value=extract_leaderboard_stats("Aggregated Performance"), headers=["agent_name", "agent_type", "# games", "total rewards", "avg_generation_time (sec)", "win-rate", "win vs_random (%)"], every=5)
        # Update the leaderboard when a new game is selected
        leaderboard_game_dropdown.change(fn=extract_leaderboard_stats, inputs=[leaderboard_game_dropdown], outputs=[leaderboard_table])

    # Tab for visual insights and performance metrics
    with gr.Tab("Metrics Dashboard"):
        gr.Markdown("# 📊 Metrics Dashboard\nVisual summaries of LLM performance across games.")

        # Extract data for visualizations
        metrics_df = extract_leaderboard_stats("Aggregated Performance")

        with gr.Row():
            # Model on the x axis, metric on the y axis, matching the labels.
            gr.BarPlot(
                value=metrics_df,
                x="agent_name",
                y="win vs_random (%)",
                title="Win Rate vs Random Bot",
                x_label="LLM Model",
                y_label="Win Rate (%)"
            )

        with gr.Row():
            gr.BarPlot(
                value=metrics_df,
                x="agent_name",
                y="avg_generation_time (sec)",
                title="Average Generation Time",
                x_label="LLM Model",
                y_label="Time (sec)"
            )

        with gr.Row():
            gr.Dataframe(value=metrics_df, label="Performance Summary")

    # Tab for LLM reasoning and illegal move analysis
    with gr.Tab("Analysis of LLM Reasoning"):
        gr.Markdown("# 🧠 Analysis of LLM Reasoning\nInsights into move legality and decision behavior.")

        # Load illegal move stats using global function
        illegal_df = extract_illegal_moves_summary()

        with gr.Row():
            # Model on the x axis, count on the y axis, matching the labels.
            gr.BarPlot(
                value=illegal_df,
                x="agent_name",
                y="illegal_moves",
                title="Illegal Moves by Model",
                x_label="LLM Model",
                y_label="# of Illegal Moves"
            )

        with gr.Row():
            gr.Dataframe(value=illegal_df, label="Illegal Move Summary")

# Launch the Gradio interface once the Blocks context has been closed, so the
# layout is fully finalised before the server starts.
interface.launch()