File size: 3,385 Bytes
0d998a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
Gradio interface for plotting SAE feature activations as a board heatmap, annotated with the policy's best legal move.
"""

import chess
import gradio as gr
import uuid

from lczerolens.encodings import encode_move

from src import constants, global_variables, visualisation


def compute_features_fn(
    features,
    model_output,
    file_id,
    root_fen,
    traj_fen,
    feature_index
):
    """Run the generator on a position and render the selected SAE feature.

    The incoming ``features`` and ``model_output`` arguments (Gradio state) are
    not used: both are recomputed from ``root_fen`` / ``traj_fen`` via
    ``global_variables.generator.generate``.

    Returns:
        The tuple produced by ``render_feature_index`` (features, model output,
        file id, SVG path, colorbar figure), followed by a ``"WDL: ..."`` string
        built from the model output's ``"wdl"`` entry.
    """
    # generate() yields exactly three values; the middle one is not needed here.
    fresh_output, _, sae_output = global_variables.generator.generate(
        root_fen=root_fen, traj_fen=traj_fen
    )
    rendered = render_feature_index(
        sae_output["f"],
        fresh_output,
        file_id,
        feature_index,
        traj_fen,
    )
    info_text = f"WDL: {fresh_output.get('wdl')}"
    return (*rendered, info_text)


def render_feature_index(
    features,
    model_output,
    file_id,
    feature_index,
    traj_fen,
):
    """Render one SAE feature as a board heatmap and save the board as SVG.

    The board for ``traj_fen`` is drawn with the chosen feature column as a
    heatmap. An arrow marks the legal move with the highest policy logit.

    Args:
        features: Feature tensor indexed as ``features[:, feature_index]``; the
            selected column is reshaped to 64 squares (assumes 64 rows, one per
            square — TODO confirm against the SAE output).
        model_output: Mapping whose ``"policy"`` entry is indexed as
            ``[1, move_index]``.
        file_id: Identifier used for the SVG filename. A fresh UUID4 string is
            generated when ``None``.
        feature_index: Column of ``features`` to plot. Coerced to ``int``
            because UI sliders may deliver floats.
        traj_fen: FEN of the position to draw.

    Returns:
        Tuple ``(features, model_output, file_id, svg_path, fig)``, where
        ``svg_path`` is the written SVG file and ``fig`` is the figure returned
        by ``visualisation.render_heatmap``.
    """
    if file_id is None:
        file_id = str(uuid.uuid4())
    board = chess.Board(traj_fen)

    pixel_features = features[:, int(feature_index)]
    if board.turn:
        heatmap = pixel_features.view(64)
    else:
        # Black to move: flip the rank axis. Presumably the features are encoded
        # from the side-to-move's perspective — confirm against the encoder.
        heatmap = pixel_features.view(8, 8).flip(0).view(64)

    # NOTE(review): row 1 of the policy batch is assumed to correspond to
    # traj_fen (the board drawn here) — confirm against generator.generate.
    policy = model_output["policy"]

    def _move_logit(move):
        move_index = encode_move(move, (board.turn, not board.turn))
        return policy[1, move_index].item()

    # True argmax over legal moves. Ties keep the first move in iteration order.
    # None means no legal moves (checkmate/stalemate).
    best_legal_move = max(board.legal_moves, key=_move_logit, default=None)
    arrows = (
        []
        if best_legal_move is None
        else [(best_legal_move.from_square, best_legal_move.to_square)]
    )

    svg_board, fig = visualisation.render_heatmap(
        board,
        heatmap,
        arrows=arrows,
    )
    svg_path = f"{constants.FIGURES_FOLER}/{file_id}.svg"
    with open(svg_path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return (
        features,
        model_output,
        file_id,
        svg_path,
        fig,
    )

# Layout: FEN inputs and controls in the left column, rendered board on the right.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            root_fen = gr.Textbox(
                label="Root FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            traj_fen = gr.Textbox(
                label="Trajectory FEN",
                lines=1,
                max_lines=1,
                value="rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1",
            )
            compute_features = gr.Button("Compute features")

            with gr.Group():
                with gr.Row():
                    # The slider's maximum is inclusive. Assuming the SAE exposes
                    # N_FEATURES columns, valid indices are 0 .. N_FEATURES - 1.
                    feature_index = gr.Slider(
                        label="Feature index",
                        minimum=0,
                        maximum=constants.N_FEATURES - 1,
                        step=1,
                        value=0,
                    )

            with gr.Group():
                with gr.Row():
                    game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="")
                with gr.Row():
                    colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            board_image = gr.Image(label="Board")

    # Values carried between callbacks: last features, last model output, and
    # the SVG file id (reused so re-renders overwrite the same file).
    features = gr.State(None)
    model_output = gr.State(None)
    file_id = gr.State(None)
    # NOTE(review): only the button triggers a render; moving the slider alone
    # does not refresh the board.
    compute_features.click(
        compute_features_fn,
        inputs=[features, model_output, file_id, root_fen, traj_fen, feature_index],
        outputs=[features, model_output, file_id, board_image, colorbar, game_info],
    )