Xmaster6y committed on
Commit
3333fb8
·
1 Parent(s): 5e4365f

new working demo

.gitignore ADDED
@@ -0,0 +1,140 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # Pickle files
+ *.pkl
+
+ # Various files
+ ignored
+ debug
+ *.zip
+ lc0
+ !bin/lc0
+ wandb
+ **/.DS_Store
+
+ *secret*
.pre-commit-config.yaml DELETED
@@ -1,18 +0,0 @@
- repos:
- - repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.5.0
- hooks:
- - id: check-added-large-files
- args: ['--maxkb=600']
- - id: check-yaml
- - id: check-json
- - id: check-toml
- - id: end-of-file-fixer
- - id: trailing-whitespace
- - id: check-docstring-first
- - repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.4.2
- hooks:
- - id: ruff
- args: [ --fix ]
- - id: ruff-format
README.md CHANGED
@@ -4,11 +4,11 @@ emoji: 🔬
  colorFrom: gray
  colorTo: green
  sdk: gradio
- sdk_version: 5.11.0
- app_file: app/main.py
- pinned: false
+ sdk_version: 5.28.0
+ app_file: main.py
+ pinned: true
  license: mit
  short_description: Demo lczerolens features
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ See the documentation [here](https://lczerolens.readthedocs.io/).
app/attention_interface.py DELETED
@@ -1,293 +0,0 @@
1
- """
2
- Gradio interface for plotting attention.
3
- """
4
-
5
- import copy
6
-
7
- import chess
8
- import gradio as gr
9
-
10
- from lczerolens.board import LczeroBoard
11
- from demo import constants, utils, visualisation
12
-
13
-
14
- def list_models():
15
- """
16
- List the models in the model directory.
17
- """
18
- models_info = utils.get_models_info(leela=False)
19
- return sorted([[model_info[0]] for model_info in models_info])
20
-
21
-
22
- def on_select_model_df(
23
- evt: gr.SelectData,
24
- ):
25
- """
26
- When a model is selected, update the statement.
27
- """
28
- return evt.value
29
-
30
-
31
- def compute_cache(
32
- board_fen,
33
- action_seq,
34
- model_name,
35
- attention_layer,
36
- attention_head,
37
- square,
38
- state_board_index,
39
- state_boards,
40
- state_cache,
41
- ):
42
- if model_name == "":
43
- gr.Warning("No model selected.")
44
- return None, None, None, state_boards, state_cache
45
-
46
- try:
47
- board = LczeroBoard(board_fen)
48
- except ValueError:
49
- board = LczeroBoard()
50
- gr.Warning("Invalid FEN, using starting position.")
51
- state_boards = [board.copy()]
52
- if action_seq:
53
- try:
54
- if action_seq.startswith("1."):
55
- for action in action_seq.split():
56
- if action.endswith("."):
57
- continue
58
- board.push_san(action)
59
- state_boards.append(board.copy())
60
- else:
61
- for action in action_seq.split():
62
- board.push_uci(action)
63
- state_boards.append(board.copy())
64
- except ValueError:
65
- gr.Warning(f"Invalid action {action} stopping before it.")
66
- try:
67
- wrapper, lens = utils.get_wrapper_lens_from_state(
68
- model_name,
69
- "activation",
70
- lens_name="attention",
71
- module_exp=r"encoder\d+/mha/QK/softmax",
72
- )
73
- except ValueError:
74
- gr.Warning("Could not load model.")
75
- return None, None, None, state_boards, state_cache
76
- state_cache = []
77
- for board in state_boards:
78
- attention_cache = copy.deepcopy(lens.analyse_board(board, wrapper))
79
- state_cache.append(attention_cache)
80
- return (
81
- *make_plot(
82
- attention_layer,
83
- attention_head,
84
- square,
85
- state_board_index,
86
- state_boards,
87
- state_cache,
88
- ),
89
- state_boards,
90
- state_cache,
91
- )
92
-
93
-
94
- def make_plot(
95
- attention_layer,
96
- attention_head,
97
- square,
98
- state_board_index,
99
- state_boards,
100
- state_cache,
101
- ):
102
- if state_cache == []:
103
- gr.Warning("No cache available.")
104
- return None, None, None
105
-
106
- board = state_boards[state_board_index]
107
- num_attention_layers = len(state_cache[state_board_index])
108
- if attention_layer > num_attention_layers:
109
- gr.Warning(
110
- f"Attention layer {attention_layer} does not exist, " f"using layer {num_attention_layers} instead."
111
- )
112
- attention_layer = num_attention_layers
113
-
114
- key = f"encoder{attention_layer-1}/mha/QK/softmax"
115
- try:
116
- attention_tensor = state_cache[state_board_index][key]
117
- except KeyError:
118
- gr.Warning(f"Combination {key} does not exist.")
119
- return None, None, None
120
- if attention_head > attention_tensor.shape[1]:
121
- gr.Warning(
122
- f"Attention head {attention_head} does not exist, " f"using head {attention_tensor.shape[1]+1} instead."
123
- )
124
- attention_head = attention_tensor.shape[1]
125
- try:
126
- square_index = chess.SQUARE_NAMES.index(square)
127
- except ValueError:
128
- gr.Warning(f"Invalid square {square}, using a1 instead.")
129
- square_index = 0
130
- square = "a1"
131
- if board.turn == chess.BLACK:
132
- square_index = chess.square_mirror(square_index)
133
-
134
- heatmap = attention_tensor[0, attention_head - 1, square_index]
135
- if board.turn == chess.BLACK:
136
- heatmap = heatmap.view(8, 8).flip(0).view(64)
137
- svg_board, fig = visualisation.render_heatmap(board, heatmap, square=square)
138
- with open(f"{constants.FIGURE_DIRECTORY}/attention.svg", "w") as f:
139
- f.write(svg_board)
140
- return f"{constants.FIGURE_DIRECTORY}/attention.svg", board.fen(), fig
141
-
142
-
143
- def previous_board(
144
- attention_layer,
145
- attention_head,
146
- square,
147
- state_board_index,
148
- state_boards,
149
- state_cache,
150
- ):
151
- state_board_index -= 1
152
- if state_board_index < 0:
153
- gr.Warning("Already at first board.")
154
- state_board_index = 0
155
- return (
156
- *make_plot(
157
- attention_layer,
158
- attention_head,
159
- square,
160
- state_board_index,
161
- state_boards,
162
- state_cache,
163
- ),
164
- state_board_index,
165
- )
166
-
167
-
168
- def next_board(
169
- attention_layer,
170
- attention_head,
171
- square,
172
- state_board_index,
173
- state_boards,
174
- state_cache,
175
- ):
176
- state_board_index += 1
177
- if state_board_index >= len(state_boards):
178
- gr.Warning("Already at last board.")
179
- state_board_index = len(state_boards) - 1
180
- return (
181
- *make_plot(
182
- attention_layer,
183
- attention_head,
184
- square,
185
- state_board_index,
186
- state_boards,
187
- state_cache,
188
- ),
189
- state_board_index,
190
- )
191
-
192
-
193
- with gr.Blocks() as interface:
194
- with gr.Row():
195
- with gr.Column(scale=2):
196
- model_df = gr.Dataframe(
197
- headers=["Available models"],
198
- datatype=["str"],
199
- interactive=False,
200
- type="array",
201
- value=list_models,
202
- )
203
- with gr.Column(scale=1):
204
- with gr.Row():
205
- model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
206
-
207
- model_df.select(
208
- on_select_model_df,
209
- None,
210
- model_name,
211
- )
212
-
213
- with gr.Row():
214
- with gr.Column():
215
- board_fen = gr.Textbox(
216
- label="Board starting FEN",
217
- lines=1,
218
- max_lines=1,
219
- value=chess.STARTING_FEN,
220
- )
221
- action_seq = gr.Textbox(
222
- label="Action sequence",
223
- lines=1,
224
- max_lines=1,
225
- value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
226
- )
227
- compute_cache_button = gr.Button("Compute cache")
228
-
229
- with gr.Group():
230
- with gr.Row():
231
- attention_layer = gr.Slider(
232
- label="Attention layer",
233
- minimum=1,
234
- maximum=24,
235
- step=1,
236
- value=1,
237
- )
238
- attention_head = gr.Slider(
239
- label="Attention head",
240
- minimum=1,
241
- maximum=24,
242
- step=1,
243
- value=1,
244
- )
245
- with gr.Row():
246
- square = gr.Textbox(
247
- label="Square",
248
- lines=1,
249
- max_lines=1,
250
- value="a1",
251
- scale=1,
252
- )
253
- with gr.Row():
254
- previous_board_button = gr.Button("Previous board")
255
- next_board_button = gr.Button("Next board")
256
- current_board_fen = gr.Textbox(
257
- label="Board FEN",
258
- lines=1,
259
- max_lines=1,
260
- )
261
- colorbar = gr.Plot(label="Colorbar")
262
- with gr.Column():
263
- image = gr.Image(label="Board")
264
-
265
- state_board_index = gr.State(0)
266
- state_boards = gr.State([])
267
- state_cache = gr.State([])
268
- base_inputs = [
269
- attention_layer,
270
- attention_head,
271
- square,
272
- state_board_index,
273
- state_boards,
274
- state_cache,
275
- ]
276
- outputs = [image, current_board_fen, colorbar]
277
-
278
- compute_cache_button.click(
279
- compute_cache,
280
- inputs=[board_fen, action_seq, model_name] + base_inputs,
281
- outputs=outputs + [state_boards, state_cache],
282
- )
283
-
284
- previous_board_button.click(
285
- previous_board,
286
- inputs=base_inputs,
287
- outputs=outputs + [state_board_index],
288
- )
289
- next_board_button.click(next_board, inputs=base_inputs, outputs=outputs + [state_board_index])
290
-
291
- attention_layer.change(make_plot, inputs=base_inputs, outputs=outputs)
292
- attention_head.change(make_plot, inputs=base_inputs, outputs=outputs)
293
- square.submit(make_plot, inputs=base_inputs, outputs=outputs)
app/backend_interface.py DELETED
@@ -1,208 +0,0 @@
1
- """
2
- Gradio interface for visualizing the policy of a model.
3
- """
4
-
5
- import chess
6
- import chess.svg
7
- import gradio as gr
8
- import torch
9
- from lczero.backends import Backend, GameState, Weights
10
-
11
- from demo import constants, utils, visualisation
12
- from lczerolens import move_encodings
13
- from lczerolens.model import lczero as lczero_utils
14
- from lczerolens.xai import PolicyLens
15
- from lczerolens.board import LczeroBoard
16
-
17
-
18
- def list_models():
19
- """
20
- List the models in the model directory.
21
- """
22
- models_info = utils.get_models_info(onnx=False)
23
- return sorted([[model_info[0]] for model_info in models_info])
24
-
25
-
26
- def on_select_model_df(
27
- evt: gr.SelectData,
28
- ):
29
- """
30
- When a model is selected, update the statement.
31
- """
32
- return evt.value
33
-
34
-
35
- def make_policy_plot(
36
- board_fen,
37
- action_seq,
38
- view,
39
- model_name,
40
- depth,
41
- use_softmax,
42
- aggregate_topk,
43
- render_bestk,
44
- only_legal,
45
- ):
46
- if model_name == "":
47
- gr.Warning(
48
- "Please select a model.",
49
- )
50
- return (
51
- None,
52
- None,
53
- "",
54
- )
55
- try:
56
- board = LczeroBoard(board_fen)
57
- except ValueError:
58
- board = LczeroBoard()
59
- gr.Warning("Invalid FEN, using starting position.")
60
- if action_seq:
61
- try:
62
- for action in action_seq.split():
63
- board.push_uci(action)
64
- except ValueError:
65
- gr.Warning("Invalid action sequence, using starting position.")
66
- board = LczeroBoard()
67
- lczero_weights = Weights(f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}")
68
- lczero_backend = Backend(lczero_weights)
69
- uci_moves = [move.uci() for move in board.move_stack]
70
- lczero_game = GameState(moves=uci_moves)
71
- policy, value = lczero_utils.prediction_from_backend(
72
- lczero_backend,
73
- lczero_game,
74
- softmax=use_softmax,
75
- only_legal=only_legal,
76
- illegal_value=0,
77
- )
78
- pickup_agg, dropoff_agg = PolicyLens.aggregate_policy(policy, int(aggregate_topk))
79
-
80
- if view == "from":
81
- if board.turn == chess.WHITE:
82
- heatmap = pickup_agg
83
- else:
84
- heatmap = pickup_agg.view(8, 8).flip(0).view(64)
85
- else:
86
- if board.turn == chess.WHITE:
87
- heatmap = dropoff_agg
88
- else:
89
- heatmap = dropoff_agg.view(8, 8).flip(0).view(64)
90
- us_them = (board.turn, not board.turn)
91
- if only_legal:
92
- legal_moves = [move_encodings.encode_move(move, us_them) for move in board.legal_moves]
93
- filtered_policy = torch.zeros(1858)
94
- filtered_policy[legal_moves] = policy[legal_moves]
95
- if (filtered_policy < 0).any():
96
- gr.Warning("Some legal moves have negative policy.")
97
- topk_moves = torch.topk(filtered_policy, render_bestk)
98
- else:
99
- topk_moves = torch.topk(policy, render_bestk)
100
- arrows = []
101
- for move_index in topk_moves.indices:
102
- move = move_encodings.decode_move(move_index, us_them)
103
- arrows.append((move.from_square, move.to_square))
104
- svg_board, fig = visualisation.render_heatmap(board, heatmap, arrows=arrows)
105
- with open(f"{constants.FIGURE_DIRECTORY}/policy.svg", "w") as f:
106
- f.write(svg_board)
107
- raw_policy, _ = lczero_utils.prediction_from_backend(
108
- lczero_backend,
109
- lczero_game,
110
- softmax=False,
111
- only_legal=False,
112
- illegal_value=0,
113
- )
114
- fig_dist = visualisation.render_policy_distribution(
115
- raw_policy,
116
- [move_encodings.encode_move(move, us_them) for move in board.legal_moves],
117
- )
118
- return (
119
- f"{constants.FIGURE_DIRECTORY}/policy.svg",
120
- fig,
121
- (f"Value: {value:.2f}"),
122
- fig_dist,
123
- )
124
-
125
-
126
- with gr.Blocks() as interface:
127
- with gr.Row():
128
- with gr.Column(scale=2):
129
- model_df = gr.Dataframe(
130
- headers=["Available models"],
131
- datatype=["str"],
132
- interactive=False,
133
- type="array",
134
- value=list_models,
135
- )
136
- with gr.Column(scale=1):
137
- with gr.Row():
138
- model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
139
-
140
- model_df.select(
141
- on_select_model_df,
142
- None,
143
- model_name,
144
- )
145
- with gr.Row():
146
- with gr.Column():
147
- board_fen = gr.Textbox(
148
- label="Board FEN",
149
- lines=1,
150
- max_lines=1,
151
- value=chess.STARTING_FEN,
152
- )
153
- action_seq = gr.Textbox(
154
- label="Action sequence",
155
- lines=1,
156
- max_lines=1,
157
- value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
158
- )
159
- with gr.Group():
160
- with gr.Row():
161
- depth = gr.Radio(label="Depth", choices=[0], value=0)
162
- use_softmax = gr.Checkbox(label="Use softmax", value=True)
163
- with gr.Row():
164
- aggregate_topk = gr.Slider(
165
- label="Aggregate top k",
166
- minimum=1,
167
- maximum=1858,
168
- step=1,
169
- value=1858,
170
- scale=3,
171
- )
172
- view = gr.Radio(
173
- label="View",
174
- choices=["from", "to"],
175
- value="from",
176
- scale=1,
177
- )
178
- with gr.Row():
179
- render_bestk = gr.Slider(
180
- label="Render best k",
181
- minimum=1,
182
- maximum=5,
183
- step=1,
184
- value=5,
185
- scale=3,
186
- )
187
- only_legal = gr.Checkbox(label="Only legal", value=True, scale=1)
188
-
189
- policy_button = gr.Button("Plot policy")
190
- colorbar = gr.Plot(label="Colorbar")
191
- game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="")
192
- with gr.Column():
193
- image = gr.Image(label="Board")
194
- density_plot = gr.Plot(label="Density")
195
-
196
- policy_inputs = [
197
- board_fen,
198
- action_seq,
199
- view,
200
- model_name,
201
- depth,
202
- use_softmax,
203
- aggregate_topk,
204
- render_bestk,
205
- only_legal,
206
- ]
207
- policy_outputs = [image, colorbar, game_info, density_plot]
208
- policy_button.click(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
app/constants.py DELETED
@@ -1,7 +0,0 @@
- """
- Constants for the demo.
- """
-
- MODEL_DIRECTORY = "demo/onnx_models"
- LEELA_MODEL_DIRECTORY = "demo/leela_models"
- FIGURE_DIRECTORY = "demo/figures"
app/convert_interface.py DELETED
@@ -1,201 +0,0 @@
1
- """
2
- Gradio interface for converting models.
3
- """
4
-
5
- import os
6
- import uuid
7
-
8
- import gradio as gr
9
-
10
- from demo import constants, utils
11
- from lczerolens.model import lczero as lczero_utils
12
-
13
-
14
- def list_models():
15
- """
16
- List the models in the model directory.
17
- """
18
- models_info = utils.get_models_info()
19
- return sorted([[model_info[0]] for model_info in models_info])
20
-
21
-
22
- def on_select_model_df(
23
- evt: gr.SelectData,
24
- ):
25
- """
26
- When a model is selected, update the statement.
27
- """
28
- return evt.value
29
-
30
-
31
- def convert_model(
32
- model_name: str,
33
- ):
34
- """
35
- Convert the model.
36
- """
37
- if model_name == "":
38
- gr.Warning(
39
- "Please select a model.",
40
- )
41
- return list_models(), ""
42
- if model_name.endswith(".onnx"):
43
- gr.Warning(
44
- "ONNX conversion not implemented.",
45
- )
46
- return list_models(), ""
47
- try:
48
- lczero_utils.convert_to_onnx(
49
- f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}",
50
- f"{constants.MODEL_DIRECTORY}/{model_name[:-6]}.onnx",
51
- )
52
- except RuntimeError:
53
- gr.Warning(
54
- f"Could not convert net at `{model_name}`.",
55
- )
56
- return list_models(), "Conversion failed"
57
- return list_models(), "Conversion successful"
58
-
59
-
60
- def upload_model(
61
- model_file: gr.File,
62
- ):
63
- """
64
- Convert the model.
65
- """
66
- if model_file is None:
67
- gr.Warning(
68
- "File not uploaded.",
69
- )
70
- return list_models()
71
- try:
72
- id = uuid.uuid4()
73
- tmp_file_path = f"{constants.LEELA_MODEL_DIRECTORY}/{id}"
74
- with open(
75
- tmp_file_path,
76
- "wb",
77
- ) as f:
78
- f.write(model_file)
79
- utils.save_model(tmp_file_path)
80
- except RuntimeError:
81
- gr.Warning(
82
- "Invalid file type.",
83
- )
84
- finally:
85
- if os.path.exists(tmp_file_path):
86
- os.remove(tmp_file_path)
87
- return list_models()
88
-
89
-
90
- def get_model_description(
91
- model_name: str,
92
- ):
93
- """
94
- Get the model description.
95
- """
96
- if model_name == "":
97
- gr.Warning(
98
- "Please select a model.",
99
- )
100
- return ""
101
- if model_name.endswith(".onnx"):
102
- gr.Warning(
103
- "ONNX description not implemented.",
104
- )
105
- return ""
106
- try:
107
- description = lczero_utils.describenet(
108
- f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}",
109
- )
110
- except RuntimeError:
111
- raise gr.Error(
112
- f"Could not describe net at `{model_name}`.",
113
- )
114
- return description
115
-
116
-
117
- def get_model_path(
118
- model_name: str,
119
- ):
120
- """
121
- Get the model path.
122
- """
123
- if model_name == "":
124
- gr.Warning(
125
- "Please select a model.",
126
- )
127
- return None
128
- if model_name.endswith(".onnx"):
129
- return f"{constants.MODEL_DIRECTORY}/{model_name}"
130
- else:
131
- return f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}"
132
-
133
-
134
- with gr.Blocks() as interface:
135
- model_file = gr.File(type="binary")
136
- upload_button = gr.Button(
137
- value="Upload",
138
- )
139
- with gr.Row():
140
- with gr.Column(scale=2):
141
- model_df = gr.Dataframe(
142
- headers=["Available models"],
143
- datatype=["str"],
144
- interactive=False,
145
- type="array",
146
- value=list_models,
147
- )
148
- with gr.Column(scale=1):
149
- with gr.Row():
150
- model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
151
- conversion_status = gr.Textbox(
152
- label="Conversion status",
153
- lines=1,
154
- interactive=False,
155
- )
156
-
157
- convert_button = gr.Button(
158
- value="Convert",
159
- )
160
- describe_button = gr.Button(
161
- value="Describe model",
162
- )
163
- model_description = gr.Textbox(
164
- label="Model description",
165
- lines=1,
166
- interactive=False,
167
- )
168
- download_button = gr.Button(
169
- value="Get download link",
170
- )
171
- download_file = gr.File(
172
- type="filepath",
173
- label="Download link",
174
- interactive=False,
175
- )
176
-
177
- model_df.select(
178
- on_select_model_df,
179
- None,
180
- model_name,
181
- )
182
- upload_button.click(
183
- upload_model,
184
- model_file,
185
- model_df,
186
- )
187
- convert_button.click(
188
- convert_model,
189
- model_name,
190
- [model_df, conversion_status],
191
- )
192
- describe_button.click(
193
- get_model_description,
194
- model_name,
195
- model_description,
196
- )
197
- download_button.click(
198
- get_model_path,
199
- model_name,
200
- download_file,
201
- )
app/crp_interface.py DELETED
@@ -1,281 +0,0 @@
1
- """
2
- Gradio interface for plotting policy.
3
- """
4
-
5
- import copy
6
-
7
- import chess
8
- import gradio as gr
9
-
10
- from demo import constants, utils, visualisation
11
- from lczerolens.board import LczeroBoard
12
-
13
-
14
- cache = None
15
- boards = None
16
- board_index = 0
17
-
18
-
19
- def list_models():
20
- """
21
- List the models in the model directory.
22
- """
23
- models_info = utils.get_models_info(leela=False)
24
- return sorted([[model_info[0]] for model_info in models_info])
25
-
26
-
27
- def on_select_model_df(
28
- evt: gr.SelectData,
29
- ):
30
- """
31
- When a model is selected, update the statement.
32
- """
33
- return evt.value
34
-
35
-
36
- def compute_cache(
37
- board_fen,
38
- action_seq,
39
- model_name,
40
- plane_index,
41
- history_index,
42
- ):
43
- global cache
44
- global boards
45
- if model_name == "":
46
- gr.Warning("No model selected.")
47
- return None, None, None, None, None
48
- try:
49
- board = LczeroBoard(board_fen)
50
- except ValueError:
51
- board = LczeroBoard()
52
- gr.Warning("Invalid FEN, using starting position.")
53
- boards = [board.copy()]
54
- if action_seq:
55
- try:
56
- if action_seq.startswith("1."):
57
- for action in action_seq.split():
58
- if action.endswith("."):
59
- continue
60
- board.push_san(action)
61
- boards.append(board.copy())
62
- else:
63
- for action in action_seq.split():
64
- board.push_uci(action)
65
- boards.append(board.copy())
66
- except ValueError:
67
- gr.Warning(f"Invalid action {action} stopping before it.")
68
- wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "crp")
69
- cache = []
70
- for board in boards:
71
- relevance = lens.compute_heatmap(board, wrapper)
72
- cache.append(copy.deepcopy(relevance))
73
- return (
74
- *make_plot(
75
- plane_index,
76
- ),
77
- *make_history_plot(
78
- history_index,
79
- ),
80
- )
81
-
82
-
83
- def make_plot(
84
- plane_index,
85
- ):
86
- global cache
87
- global boards
88
- global board_index
89
-
90
- if cache is None:
91
- gr.Warning("Cache not computed!")
92
- return None, None, None
93
-
94
- board = boards[board_index]
95
- relevance_tensor = cache[board_index]
96
- a_max = relevance_tensor.abs().max()
97
- if a_max != 0:
98
- relevance_tensor = relevance_tensor / a_max
99
- vmin = -1
100
- vmax = 1
101
- heatmap = relevance_tensor[plane_index - 1].view(64)
102
- if board.turn == chess.BLACK:
103
- heatmap = heatmap.view(8, 8).flip(0).view(64)
104
- svg_board, fig = visualisation.render_heatmap(board, heatmap, vmin=vmin, vmax=vmax)
105
- with open(f"{constants.FIGURE_DIRECTORY}/lrp.svg", "w") as f:
106
- f.write(svg_board)
107
- return f"{constants.FIGURE_DIRECTORY}/lrp.svg", board.fen(), fig
108
-
109
-
110
- def make_history_plot(
111
- history_index,
112
- ):
113
- global cache
114
- global boards
115
- global board_index
116
-
117
- if cache is None:
118
- gr.Warning("Cache not computed!")
119
- return None, None
120
-
121
- board = boards[board_index]
122
- relevance_tensor = cache[board_index]
123
- a_max = relevance_tensor.abs().max()
124
- if a_max != 0:
125
- relevance_tensor = relevance_tensor / a_max
126
- vmin = -1
127
- vmax = 1
128
- heatmap = relevance_tensor[13 * (history_index - 1) : 13 * history_index - 1].sum(dim=0).view(64)
129
- if board.turn == chess.BLACK:
130
- heatmap = heatmap.view(8, 8).flip(0).view(64)
131
- if board_index - history_index + 1 < 0:
132
- history_board = LczeroBoard(fen=None)
133
- else:
134
- history_board = boards[board_index - history_index + 1]
135
- svg_board, fig = visualisation.render_heatmap(history_board, heatmap, vmin=vmin, vmax=vmax)
136
- with open(f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", "w") as f:
137
- f.write(svg_board)
138
- return f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", fig
139
-
140
-
141
- def previous_board(
142
- plane_index,
143
- history_index,
144
- ):
145
- global board_index
146
- board_index -= 1
147
- if board_index < 0:
148
- gr.Warning("Already at first board.")
149
- board_index = 0
150
- return (
151
- *make_plot(
152
- plane_index,
153
- ),
154
- *make_history_plot(
155
- history_index,
156
- ),
157
- )
158
-
159
-
160
- def next_board(
161
- plane_index,
162
- history_index,
163
- ):
164
- global board_index
165
- board_index += 1
166
- if board_index >= len(boards):
167
- gr.Warning("Already at last board.")
168
- board_index = len(boards) - 1
169
- return (
170
- *make_plot(
171
- plane_index,
172
- ),
173
- *make_history_plot(
174
- history_index,
175
- ),
176
- )
177
-
178
-
179
- with gr.Blocks() as interface:
180
- with gr.Row():
181
- with gr.Column(scale=2):
182
- model_df = gr.Dataframe(
183
- headers=["Available models"],
184
- datatype=["str"],
185
- interactive=False,
186
- type="array",
187
- value=list_models,
188
- )
189
- with gr.Column(scale=1):
190
- with gr.Row():
191
- model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
192
-
193
- model_df.select(
194
- on_select_model_df,
195
- None,
196
- model_name,
197
- )
198
-
199
- with gr.Row():
200
- with gr.Column():
201
- board_fen = gr.Textbox(
202
- label="Board starting FEN",
203
- lines=1,
204
- max_lines=1,
205
- value=chess.STARTING_FEN,
206
- )
207
- action_seq = gr.Textbox(
208
- label="Action sequence",
209
- lines=1,
210
- max_lines=1,
211
- value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
212
- )
213
- compute_cache_button = gr.Button("Compute heatmaps")
214
-
215
- with gr.Group():
216
- with gr.Row():
217
- plane_index = gr.Slider(
218
- label="Plane index",
219
- minimum=1,
220
- maximum=112,
221
- step=1,
222
- value=1,
223
- )
224
- with gr.Row():
225
- previous_board_button = gr.Button("Previous board")
226
- next_board_button = gr.Button("Next board")
227
- current_board_fen = gr.Textbox(
228
- label="Board FEN",
229
- lines=1,
230
- max_lines=1,
231
- )
232
- colorbar = gr.Plot(label="Colorbar")
233
- with gr.Column():
234
- image = gr.Image(label="Board")
235
-
236
- with gr.Row():
237
- with gr.Column():
238
- with gr.Group():
239
- with gr.Row():
240
- histroy_index = gr.Slider(
241
- label="History index",
242
- minimum=1,
243
- maximum=8,
244
- step=1,
245
- value=1,
246
- )
247
- history_colorbar = gr.Plot(label="Colorbar")
248
- with gr.Column():
249
- history_image = gr.Image(label="Board")
250
-
251
- base_inputs = [
252
- plane_index,
253
- histroy_index,
254
- ]
255
- outputs = [
256
- image,
257
- current_board_fen,
258
- colorbar,
259
- history_image,
260
- history_colorbar,
261
- ]
262
-
263
- compute_cache_button.click(
264
- compute_cache,
265
- inputs=[board_fen, action_seq, model_name] + base_inputs,
266
- outputs=outputs,
267
- )
268
-
269
- previous_board_button.click(previous_board, inputs=base_inputs, outputs=outputs)
270
- next_board_button.click(next_board, inputs=base_inputs, outputs=outputs)
271
-
272
- plane_index.change(
273
- make_plot,
274
- inputs=plane_index,
275
- outputs=[image, current_board_fen, colorbar],
276
- )
277
- histroy_index.change(
278
- make_history_plot,
279
- inputs=histroy_index,
280
- outputs=[history_image, history_colorbar],
281
- )
app/encoding_interface.py DELETED
@@ -1,83 +0,0 @@
- """
- Gradio interface for plotting encodings.
- """
-
- import chess
- import gradio as gr
-
- from demo import constants, visualisation
- from lczerolens import board_encodings
- from lczerolens.board import LczeroBoard
-
-
- def make_encoding_plot(
- board_fen,
- action_seq,
- plane_index,
- color_flip,
- ):
- try:
- board = LczeroBoard(board_fen)
- except ValueError:
- board = LczeroBoard()
- gr.Warning("Invalid FEN, using starting position.")
- if action_seq:
- try:
- for action in action_seq.split():
- board.push_uci(action)
- except ValueError:
- gr.Warning("Invalid action sequence, using starting position.")
- board = LczeroBoard()
- board_tensor = board_encodings.board_to_input_tensor(board)
- heatmap = board_tensor[plane_index]
- if color_flip and board.turn == chess.BLACK:
- heatmap = heatmap.flip(0)
- svg_board, fig = visualisation.render_heatmap(board, heatmap.view(64), vmin=0.0, vmax=1.0)
- with open(f"{constants.FIGURE_DIRECTORY}/encoding.svg", "w") as f:
- f.write(svg_board)
- return f"{constants.FIGURE_DIRECTORY}/encoding.svg", fig
-
-
- with gr.Blocks() as interface:
- with gr.Row():
- with gr.Column():
- board_fen = gr.Textbox(
- label="Board starting FEN",
- lines=1,
- max_lines=1,
- value=chess.STARTING_FEN,
- )
- action_seq = gr.Textbox(
- label="Action sequence",
- lines=1,
- max_lines=1,
- value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
- )
- with gr.Group():
- with gr.Row():
- plane_index = gr.Slider(
- label="Plane index",
- minimum=0,
- maximum=111,
- step=1,
- value=0,
- scale=3,
- )
- color_flip = gr.Checkbox(label="Color flip", value=True, scale=1)
-
- colorbar = gr.Plot(label="Colorbar")
- with gr.Column():
- image = gr.Image(label="Board")
-
- policy_inputs = [
- board_fen,
- action_seq,
- plane_index,
- color_flip,
- ]
- policy_outputs = [image, colorbar]
- board_fen.submit(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
- action_seq.submit(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
- plane_index.change(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
- color_flip.change(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
- interface.load(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
app/lrp_interface.py DELETED
@@ -1,280 +0,0 @@
1
- """
2
- Gradio interface for plotting policy.
3
- """
4
-
5
- import copy
6
-
7
- import chess
8
- import gradio as gr
9
-
10
- from demo import constants, utils, visualisation
11
- from lczerolens.board import LczeroBoard
12
-
13
- cache = None
14
- boards = None
15
- board_index = 0
16
-
17
-
18
- def list_models():
19
- """
20
- List the models in the model directory.
21
- """
22
- models_info = utils.get_models_info(leela=False)
23
- return sorted([[model_info[0]] for model_info in models_info])
24
-
25
-
26
- def on_select_model_df(
27
- evt: gr.SelectData,
28
- ):
29
- """
30
- When a model is selected, update the statement.
31
- """
32
- return evt.value
33
-
34
-
35
- def compute_cache(
36
- board_fen,
37
- action_seq,
38
- model_name,
39
- plane_index,
40
- history_index,
41
- ):
42
- global cache
43
- global boards
44
- if model_name == "":
45
- gr.Warning("No model selected.")
46
- return None, None, None, None, None
47
- try:
48
- board = LczeroBoard(board_fen)
49
- except ValueError:
50
- board = LczeroBoard()
51
- gr.Warning("Invalid FEN, using starting position.")
52
- boards = [board.copy()]
53
- if action_seq:
54
- try:
55
- if action_seq.startswith("1."):
56
- for action in action_seq.split():
57
- if action.endswith("."):
58
- continue
59
- board.push_san(action)
60
- boards.append(board.copy())
61
- else:
62
- for action in action_seq.split():
63
- board.push_uci(action)
64
- boards.append(board.copy())
65
- except ValueError:
66
- gr.Warning(f"Invalid action {action} stopping before it.")
67
- wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "lrp")
68
- cache = []
69
- for board in boards:
70
- relevance = lens.compute_heatmap(board, wrapper)
71
- cache.append(copy.deepcopy(relevance))
72
- return (
73
- *make_plot(
74
- plane_index,
75
- ),
76
- *make_history_plot(
77
- history_index,
78
- ),
79
- )
80
-
81
-
82
- def make_plot(
83
- plane_index,
84
- ):
85
- global cache
86
- global boards
87
- global board_index
88
-
89
- if cache is None:
90
- gr.Warning("Cache not computed!")
91
- return None, None, None
92
-
93
- board = boards[board_index]
94
- relevance_tensor = cache[board_index]
95
- a_max = relevance_tensor.abs().max()
96
- if a_max != 0:
97
- relevance_tensor = relevance_tensor / a_max
98
- vmin = -1
99
- vmax = 1
100
- heatmap = relevance_tensor[plane_index - 1].view(64)
101
- if board.turn == chess.BLACK:
102
- heatmap = heatmap.view(8, 8).flip(0).view(64)
103
- svg_board, fig = visualisation.render_heatmap(board, heatmap, vmin=vmin, vmax=vmax)
104
- with open(f"{constants.FIGURE_DIRECTORY}/lrp.svg", "w") as f:
105
- f.write(svg_board)
106
- return f"{constants.FIGURE_DIRECTORY}/lrp.svg", board.fen(), fig
107
-
108
-
109
- def make_history_plot(
110
- history_index,
111
- ):
112
- global cache
113
- global boards
114
- global board_index
115
-
116
- if cache is None:
117
- gr.Warning("Cache not computed!")
118
- return None, None
119
-
120
- board = boards[board_index]
121
- relevance_tensor = cache[board_index]
122
- a_max = relevance_tensor.abs().max()
123
- if a_max != 0:
124
- relevance_tensor = relevance_tensor / a_max
125
- vmin = -1
126
- vmax = 1
127
- heatmap = relevance_tensor[13 * (history_index - 1) : 13 * history_index - 1].sum(dim=0).view(64)
128
- if board.turn == chess.BLACK:
129
- heatmap = heatmap.view(8, 8).flip(0).view(64)
130
- if board_index - history_index + 1 < 0:
131
- history_board = LczeroBoard(fen=None)
132
- else:
133
- history_board = boards[board_index - history_index + 1]
134
- svg_board, fig = visualisation.render_heatmap(history_board, heatmap, vmin=vmin, vmax=vmax)
135
- with open(f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", "w") as f:
136
- f.write(svg_board)
137
- return f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", fig
138
-
139
-
140
- def previous_board(
141
- plane_index,
142
- history_index,
143
- ):
144
- global board_index
145
- board_index -= 1
146
- if board_index < 0:
147
- gr.Warning("Already at first board.")
148
- board_index = 0
149
- return (
150
- *make_plot(
151
- plane_index,
152
- ),
153
- *make_history_plot(
154
- history_index,
155
- ),
156
- )
157
-
158
-
159
- def next_board(
160
- plane_index,
161
- history_index,
162
- ):
163
- global board_index
164
- board_index += 1
165
- if board_index >= len(boards):
166
- gr.Warning("Already at last board.")
167
- board_index = len(boards) - 1
168
- return (
169
- *make_plot(
170
- plane_index,
171
- ),
172
- *make_history_plot(
173
- history_index,
174
- ),
175
- )
176
-
177
-
178
- with gr.Blocks() as interface:
179
- with gr.Row():
180
- with gr.Column(scale=2):
181
- model_df = gr.Dataframe(
182
- headers=["Available models"],
183
- datatype=["str"],
184
- interactive=False,
185
- type="array",
186
- value=list_models,
187
- )
188
- with gr.Column(scale=1):
189
- with gr.Row():
190
- model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
191
-
192
- model_df.select(
193
- on_select_model_df,
194
- None,
195
- model_name,
196
- )
197
-
198
- with gr.Row():
199
- with gr.Column():
200
- board_fen = gr.Textbox(
201
- label="Board starting FEN",
202
- lines=1,
203
- max_lines=1,
204
- value=chess.STARTING_FEN,
205
- )
206
- action_seq = gr.Textbox(
207
- label="Action sequence",
208
- lines=1,
209
- max_lines=1,
210
- value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
211
- )
212
- compute_cache_button = gr.Button("Compute heatmaps")
213
-
214
- with gr.Group():
215
- with gr.Row():
216
- plane_index = gr.Slider(
217
- label="Plane index",
218
- minimum=1,
219
- maximum=112,
220
- step=1,
221
- value=1,
222
- )
223
- with gr.Row():
224
- previous_board_button = gr.Button("Previous board")
225
- next_board_button = gr.Button("Next board")
226
- current_board_fen = gr.Textbox(
227
- label="Board FEN",
228
- lines=1,
229
- max_lines=1,
230
- )
231
- colorbar = gr.Plot(label="Colorbar")
232
- with gr.Column():
233
- image = gr.Image(label="Board")
234
-
235
- with gr.Row():
236
- with gr.Column():
237
- with gr.Group():
238
- with gr.Row():
239
- histroy_index = gr.Slider(
240
- label="History index",
241
- minimum=1,
242
- maximum=8,
243
- step=1,
244
- value=1,
245
- )
246
- history_colorbar = gr.Plot(label="Colorbar")
247
- with gr.Column():
248
- history_image = gr.Image(label="Board")
249
-
250
- base_inputs = [
251
- plane_index,
252
- histroy_index,
253
- ]
254
- outputs = [
255
- image,
256
- current_board_fen,
257
- colorbar,
258
- history_image,
259
- history_colorbar,
260
- ]
261
-
262
- compute_cache_button.click(
263
- compute_cache,
264
- inputs=[board_fen, action_seq, model_name] + base_inputs,
265
- outputs=outputs,
266
- )
267
-
268
- previous_board_button.click(previous_board, inputs=base_inputs, outputs=outputs)
269
- next_board_button.click(next_board, inputs=base_inputs, outputs=outputs)
270
-
271
- plane_index.change(
272
- make_plot,
273
- inputs=plane_index,
274
- outputs=[image, current_board_fen, colorbar],
275
- )
276
- histroy_index.change(
277
- make_history_plot,
278
- inputs=histroy_index,
279
- outputs=[history_image, history_colorbar],
280
- )
app/main.py DELETED
@@ -1,50 +0,0 @@
- """
- Gradio demo for lczero-easy.
- """
-
- import gradio as gr
-
- from . import (
- attention_interface,
- backend_interface,
- board_interface,
- convert_interface,
- crp_interface,
- encoding_interface,
- lrp_interface,
- policy_interface,
- statistics_interface,
- )
-
- demo = gr.TabbedInterface(
- [
- crp_interface.interface,
- statistics_interface.interface,
- lrp_interface.interface,
- attention_interface.interface,
- policy_interface.interface,
- backend_interface.interface,
- encoding_interface.interface,
- board_interface.interface,
- convert_interface.interface,
- ],
- [
- "CRP",
- "Statistics",
- "LRP",
- "Attention",
- "Policy",
- "Backend",
- "Encoding",
- "Board",
- "Convert",
- ],
- title="LczeroLens Demo",
- analytics_enabled=False,
- )
-
- if __name__ == "__main__":
- demo.launch(
- server_port=8000,
- server_name="0.0.0.0",
- )
app/policy_interface.py DELETED
@@ -1,278 +0,0 @@
1
- """
2
- Gradio interface for visualizing the policy of a model.
3
- """
4
-
5
- import chess
6
- import chess.svg
7
- import gradio as gr
8
- import torch
9
-
10
- from demo import constants, utils, visualisation
11
- from lczerolens import move_encodings
12
- from lczerolens.board import LczeroBoard
13
- from lczerolens.xai import PolicyLens
14
-
15
-
16
- current_board = None
17
- current_raw_policy = None
18
- current_policy = None
19
- current_value = None
20
- current_outcome = None
21
-
22
-
23
- def list_models():
24
- """
25
- List the models in the model directory.
26
- """
27
- models_info = utils.get_models_info(leela=False)
28
- return sorted([[model_info[0]] for model_info in models_info])
29
-
30
-
31
- def on_select_model_df(
32
- evt: gr.SelectData,
33
- ):
34
- """
35
- When a model is selected, update the statement.
36
- """
37
- return evt.value
38
-
39
-
40
- def compute_policy(
41
- board_fen,
42
- action_seq,
43
- model_name,
44
- ):
45
- global current_board
46
- global current_policy
47
- global current_raw_policy
48
- global current_value
49
- global current_outcome
50
- if model_name == "":
51
- gr.Warning(
52
- "Please select a model.",
53
- )
54
- return (
55
- None,
56
- None,
57
- "",
58
- )
59
- try:
60
- board = LczeroBoard(board_fen)
61
- except ValueError:
62
- gr.Warning("Invalid FEN.")
63
- return (None, None, "", None)
64
- if action_seq:
65
- try:
66
- for action in action_seq.split():
67
- board.push_uci(action)
68
- except ValueError:
69
- gr.Warning("Invalid action sequence.")
70
- return (None, None, "", None)
71
- wrapper = utils.get_wrapper_from_state(model_name)
72
- (output,) = wrapper.predict(board)
73
- current_raw_policy = output["policy"][0]
74
- policy = torch.softmax(output["policy"][0], dim=-1)
75
-
76
- filtered_policy = torch.full((1858,), 0.0)
77
- legal_moves = [move_encodings.encode_move(move, (board.turn, not board.turn)) for move in board.legal_moves]
78
- filtered_policy[legal_moves] = policy[legal_moves]
79
- policy = filtered_policy
80
-
81
- current_board = board
82
- current_policy = policy
83
- current_value = output.get("value", None)
84
- current_outcome = output.get("wdl", None)
85
-
86
-
87
- def make_plot(
88
- view,
89
- aggregate_topk,
90
- move_to_play,
91
- ):
92
- global current_board
93
- global current_policy
94
- global current_raw_policy
95
- global current_value
96
- global current_outcome
97
-
98
- if current_board is None or current_policy is None:
99
- gr.Warning("Please compute a policy first.")
100
- return (None, None, "", None)
101
-
102
- pickup_agg, dropoff_agg = PolicyLens.aggregate_policy(current_policy, int(aggregate_topk))
103
-
104
- if view == "from":
105
- if current_board.turn == chess.WHITE:
106
- heatmap = pickup_agg
107
- else:
108
- heatmap = pickup_agg.view(8, 8).flip(0).view(64)
109
- else:
110
- if current_board.turn == chess.WHITE:
111
- heatmap = dropoff_agg
112
- else:
113
- heatmap = dropoff_agg.view(8, 8).flip(0).view(64)
114
- us_them = (current_board.turn, not current_board.turn)
115
- topk_moves = torch.topk(current_policy, 50)
116
- move = move_encodings.decode_move(topk_moves.indices[move_to_play - 1], us_them)
117
- arrows = [(move.from_square, move.to_square)]
118
- svg_board, fig = visualisation.render_heatmap(current_board, heatmap, arrows=arrows)
119
- with open(f"{constants.FIGURE_DIRECTORY}/policy.svg", "w") as f:
120
- f.write(svg_board)
121
- fig_dist = visualisation.render_policy_distribution(
122
- current_raw_policy,
123
- [move_encodings.encode_move(move, us_them) for move in current_board.legal_moves],
124
- )
125
- return (
126
- f"{constants.FIGURE_DIRECTORY}/policy.svg",
127
- fig,
128
- (f"Value: {current_value} - WDL: {current_outcome}"),
129
- fig_dist,
130
- )
131
-
132
-
133
- def make_policy_plot(
134
- board_fen,
135
- action_seq,
136
- view,
137
- model_name,
138
- aggregate_topk,
139
- move_to_play,
140
- ):
141
- compute_policy(
142
- board_fen,
143
- action_seq,
144
- model_name,
145
- )
146
- return make_plot(
147
- view,
148
- aggregate_topk,
149
- move_to_play,
150
- )
151
-
152
-
153
- def play_move(
154
- board_fen,
155
- action_seq,
156
- view,
157
- model_name,
158
- aggregate_topk,
159
- move_to_play,
160
- ):
161
- global current_board
162
- global current_policy
163
-
164
- move = move_encodings.decode_move(
165
- current_policy.topk(50).indices[move_to_play - 1],
166
- (current_board.turn, not current_board.turn),
167
- )
168
- current_board.push(move)
169
- action_seq = f"{action_seq} {move.uci()}"
170
- compute_policy(
171
- board_fen,
172
- action_seq,
173
- model_name,
174
- )
175
- return [
176
- *make_plot(
177
- view,
178
- aggregate_topk,
179
- 1,
180
- ),
181
- action_seq,
182
- 1,
183
- ]
184
-
185
-
186
- with gr.Blocks() as interface:
187
- with gr.Row():
188
- with gr.Column(scale=2):
189
- model_df = gr.Dataframe(
190
- headers=["Available models"],
191
- datatype=["str"],
192
- interactive=False,
193
- type="array",
194
- value=list_models,
195
- )
196
- with gr.Column(scale=1):
197
- with gr.Row():
198
- model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
199
- model_df.select(
200
- on_select_model_df,
201
- None,
202
- model_name,
203
- )
204
-
205
- with gr.Row():
206
- with gr.Column():
207
- board_fen = gr.Textbox(
208
- label="Board FEN",
209
- lines=1,
210
- max_lines=1,
211
- value=chess.STARTING_FEN,
212
- )
213
- action_seq = gr.Textbox(
214
- label="Action sequence",
215
- lines=1,
216
- value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
217
- )
218
- with gr.Group():
219
- with gr.Row():
220
- aggregate_topk = gr.Slider(
221
- label="Aggregate top k",
222
- minimum=1,
223
- maximum=1858,
224
- step=1,
225
- value=1858,
226
- scale=3,
227
- )
228
- view = gr.Radio(
229
- label="View",
230
- choices=["from", "to"],
231
- value="from",
232
- scale=1,
233
- )
234
- with gr.Row():
235
- move_to_play = gr.Slider(
236
- label="Move to play",
237
- minimum=1,
238
- maximum=50,
239
- step=1,
240
- value=1,
241
- scale=3,
242
- )
243
- play_button = gr.Button("Play")
244
-
245
- policy_button = gr.Button("Compute policy")
246
- colorbar = gr.Plot(label="Colorbar")
247
- game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="")
248
- with gr.Column():
249
- image = gr.Image(label="Board")
250
- density_plot = gr.Plot(label="Density")
251
-
252
- policy_inputs = [
253
- board_fen,
254
- action_seq,
255
- view,
256
- model_name,
257
- aggregate_topk,
258
- move_to_play,
259
- ]
260
- policy_outputs = [image, colorbar, game_info, density_plot]
261
- policy_button.click(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
262
- board_fen.submit(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
263
- action_seq.submit(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
264
-
265
- fast_inputs = [
266
- view,
267
- aggregate_topk,
268
- move_to_play,
269
- ]
270
- aggregate_topk.change(make_plot, inputs=fast_inputs, outputs=policy_outputs)
271
- view.change(make_plot, inputs=fast_inputs, outputs=policy_outputs)
272
- move_to_play.change(make_plot, inputs=fast_inputs, outputs=policy_outputs)
273
-
274
- play_button.click(
275
- play_move,
276
- inputs=policy_inputs,
277
- outputs=policy_outputs + [action_seq, move_to_play],
278
- )
app/state.py DELETED
@@ -1,18 +0,0 @@
- """
- Global state for the demo application.
- """
-
- from typing import Dict
-
- from lczerolens import Lens, ModelWrapper
-
- wrappers: Dict[str, ModelWrapper] = {}
-
- lenses: Dict[str, Dict[str, Lens]] = {
- "activation": {},
- "lrp": {},
- "crp": {},
- "policy": {},
- "probing": {},
- "patching": {},
- }
app/statistics_interface.py DELETED
@@ -1,189 +0,0 @@
1
- """
2
- Gradio interface for visualizing the policy of a model.
3
- """
4
-
5
- import gradio as gr
6
-
7
- from demo import utils, visualisation
8
- from lczerolens import GameDataset
9
- from lczerolens.xai import ConceptDataset, HasThreatConcept
10
-
11
- current_policy_statistics = None
12
- current_lrp_statistics = None
13
- current_probing_statistics = None
14
- dataset = GameDataset("assets/test_stockfish_10.jsonl")
15
- check_concept = HasThreatConcept("K", relative=True)
16
- unique_check_dataset = ConceptDataset.from_game_dataset(dataset)
17
- unique_check_dataset.set_concept(check_concept)
18
-
19
-
20
- def list_models():
21
- """
22
- List the models in the model directory.
23
- """
24
- models_info = utils.get_models_info(leela=False)
25
- return sorted([[model_info[0]] for model_info in models_info])
26
-
27
-
28
- def on_select_model_df(
29
- evt: gr.SelectData,
30
- ):
31
- """
32
- When a model is selected, update the statement.
33
- """
34
- return evt.value
35
-
36
-
37
- def compute_policy_statistics(
38
- model_name,
39
- ):
40
- global current_policy_statistics
41
- global dataset
42
-
43
- if model_name == "":
44
- gr.Warning(
45
- "Please select a model.",
46
- )
47
- return None
48
- wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "policy")
49
- current_policy_statistics = lens.analyse_dataset(dataset, wrapper, 10)
50
- return make_policy_plot()
51
-
52
-
53
- def make_policy_plot():
54
- global current_policy_statistics
55
-
56
- if current_policy_statistics is None:
57
- gr.Warning(
58
- "Please compute policy statistics first.",
59
- )
60
- return None
61
- else:
62
- return visualisation.render_policy_statistics(current_policy_statistics)
63
-
64
-
65
- def compute_lrp_statistics(
66
- model_name,
67
- ):
68
- global current_lrp_statistics
69
- global dataset
70
-
71
- if model_name == "":
72
- gr.Warning(
73
- "Please select a model.",
74
- )
75
- return None, None, None
76
- wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "lrp")
77
- current_lrp_statistics = lens.compute_statistics(dataset, wrapper, 10)
78
- return make_lrp_plot()
79
-
80
-
81
- def make_lrp_plot():
82
- global current_lrp_statistics
83
-
84
- if current_lrp_statistics is None:
85
- gr.Warning(
86
- "Please compute LRP statistics first.",
87
- )
88
- return None, None, None
89
- else:
90
- return visualisation.render_relevance_proportion(current_lrp_statistics)
91
-
92
-
93
- def compute_probing_statistics(
94
- model_name,
95
- ):
96
- global current_probing_statistics
97
- global check_concept
98
- global unique_check_dataset
99
-
100
- if model_name == "":
101
- gr.Warning(
102
- "Please select a model.",
103
- )
104
- return None
105
- wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "probing", concept=check_concept)
106
- current_probing_statistics = lens.compute_statistics(unique_check_dataset, wrapper, 10)
107
- return make_probing_plot()
108
-
109
-
110
- def make_probing_plot():
111
- global current_probing_statistics
112
-
113
- if current_probing_statistics is None:
114
- gr.Warning(
115
- "Please compute probing statistics first.",
116
- )
117
- return None
118
- else:
119
- return visualisation.render_probing_statistics(current_probing_statistics)
120
-
121
-
122
- with gr.Blocks() as interface:
123
- with gr.Row():
124
- with gr.Column(scale=2):
125
- model_df = gr.Dataframe(
126
- headers=["Available models"],
127
- datatype=["str"],
128
- interactive=False,
129
- type="array",
130
- value=list_models,
131
- )
132
- with gr.Column(scale=1):
133
- with gr.Row():
134
- model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
135
- model_df.select(
136
- on_select_model_df,
137
- None,
138
- model_name,
139
- )
140
-
141
- with gr.Row():
142
- with gr.Column():
143
- policy_plot = gr.Plot(label="Policy statistics")
144
- policy_compute_button = gr.Button(value="Compute policy statistics")
145
- policy_plot_button = gr.Button(value="Plot policy statistics")
146
-
147
- policy_compute_button.click(
148
- compute_policy_statistics,
149
- inputs=[model_name],
150
- outputs=[policy_plot],
151
- )
152
- policy_plot_button.click(make_policy_plot, outputs=[policy_plot])
153
-
154
- with gr.Column():
155
- lrp_plot_hist = gr.Plot(label="LRP history statistics")
156
-
157
- with gr.Row():
158
- with gr.Column():
159
- lrp_plot_planes = gr.Plot(label="LRP planes statistics")
160
-
161
- with gr.Column():
162
- lrp_plot_pieces = gr.Plot(label="LRP pieces statistics")
163
-
164
- with gr.Row():
165
- lrp_compute_button = gr.Button(value="Compute LRP statistics")
166
- with gr.Row():
167
- lrp_plot_button = gr.Button(value="Plot LRP statistics")
168
-
169
- lrp_compute_button.click(
170
- compute_lrp_statistics,
171
- inputs=[model_name],
172
- outputs=[lrp_plot_hist, lrp_plot_planes, lrp_plot_pieces],
173
- )
174
- lrp_plot_button.click(
175
- make_lrp_plot,
176
- outputs=[lrp_plot_hist, lrp_plot_planes, lrp_plot_pieces],
177
- )
178
-
179
- with gr.Column():
180
- probing_plot = gr.Plot(label="Probing statistics")
181
- probing_compute_button = gr.Button(value="Compute probing statistics")
182
- probing_plot_button = gr.Button(value="Plot probing statistics")
183
-
184
- probing_compute_button.click(
185
- compute_probing_statistics,
186
- inputs=[model_name],
187
- outputs=[probing_plot],
188
- )
189
- probing_plot_button.click(make_probing_plot, outputs=[probing_plot])
app/utils.py DELETED
@@ -1,121 +0,0 @@
1
- """
2
- Utils for the demo app.
3
- """
4
-
5
- import os
6
- import re
7
- import subprocess
8
-
9
- from demo import constants, state
10
- from lczerolens import Lens, LczeroModel
11
- from lczerolens.model import lczero as lczero_utils
12
-
13
-
14
- def get_models_info(onnx=True, leela=True):
15
- """
16
- Get the names of the models in the model directory.
17
- """
18
- model_df = []
19
- exp = r"(?P<n_filters>\d+)x(?P<n_blocks>\d+)"
20
- if onnx:
21
- for filename in os.listdir(constants.MODEL_DIRECTORY):
22
- if filename.endswith(".onnx"):
23
- match = re.search(exp, filename)
24
- if match is None:
25
- n_filters = -1
26
- n_blocks = -1
27
- else:
28
- n_filters = int(match.group("n_filters"))
29
- n_blocks = int(match.group("n_blocks"))
30
- model_df.append(
31
- [
32
- filename,
33
- "ONNX",
34
- n_blocks,
35
- n_filters,
36
- ]
37
- )
38
- if leela:
39
- for filename in os.listdir(constants.LEELA_MODEL_DIRECTORY):
40
- if filename.endswith(".pb.gz"):
41
- match = re.search(exp, filename)
42
- if match is None:
43
- n_filters = -1
44
- n_blocks = -1
45
- else:
46
- n_filters = int(match.group("n_filters"))
47
- n_blocks = int(match.group("n_blocks"))
48
- model_df.append(
49
- [
50
- filename,
51
- "LEELA",
52
- n_blocks,
53
- n_filters,
54
- ]
55
- )
56
- return model_df
57
-
58
-
59
- def save_model(tmp_file_path):
60
- """
61
- Save the model to the model directory.
62
- """
63
- popen = subprocess.Popen(
64
- ["file", tmp_file_path],
65
- stdout=subprocess.PIPE,
66
- stderr=subprocess.PIPE,
67
- )
68
- popen.wait()
69
- if popen.returncode != 0:
70
- raise RuntimeError
71
- file_desc = popen.stdout.read().decode("utf-8").split(tmp_file_path)[1].strip()
72
- rename_match = re.search(r"was\s\"(?P<name>.+)\"", file_desc)
73
- type_match = re.search(r"\:\s(?P<type>[a-zA-Z]+)", file_desc)
74
- if rename_match is None or type_match is None:
75
- raise RuntimeError
76
- model_name = rename_match.group("name")
77
- model_type = type_match.group("type")
78
- if model_type != "gzip":
79
- raise RuntimeError
80
- os.rename(
81
- tmp_file_path,
82
- f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz",
83
- )
84
- try:
85
- lczero_utils.describenet(
86
- f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz",
87
- )
88
- except RuntimeError:
89
- os.remove(f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz")
90
- raise RuntimeError
91
-
92
-
93
- def get_wrapper_from_state(model_name):
94
- """
95
- Get the model wrapper from the state.
96
- """
97
- if model_name in state.wrappers:
98
- return state.wrappers[model_name]
99
- else:
100
- wrapper = LczeroModel.from_path(f"{constants.MODEL_DIRECTORY}/{model_name}")
101
- state.wrappers[model_name] = wrapper
102
- return wrapper
103
-
104
-
105
- def get_wrapper_lens_from_state(model_name, lens_type, lens_name="lens", **kwargs):
106
- """
107
- Get the model wrapper and lens from the state.
108
- """
109
- if model_name in state.wrappers:
110
- wrapper = state.wrappers[model_name]
111
- else:
112
- wrapper = LczeroModel.from_path(f"{constants.MODEL_DIRECTORY}/{model_name}")
113
- state.wrappers[model_name] = wrapper
114
- if lens_name in state.lenses[lens_type]:
115
- lens = state.lenses[lens_type][lens_name]
116
- else:
117
- lens = Lens.from_name(lens_type, **kwargs)
118
- if not lens.is_compatible(wrapper):
119
- raise ValueError(f"Lens of type {lens_type} not compatible with model.")
120
- state.lenses[lens_type][lens_name] = lens
121
- return wrapper, lens
app/visualisation.py DELETED
@@ -1,303 +0,0 @@
1
- """
2
- Visualisation utils.
3
- """
4
-
5
- import chess
6
- import chess.svg
7
- import matplotlib
8
- import matplotlib.pyplot as plt
9
- import numpy as np
10
- import torch
11
- import torchviz
12
-
13
- from . import constants
14
-
15
- COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
16
- ALPHA = 1.0
17
-
18
-
19
- def render_heatmap(
20
- board,
21
- heatmap,
22
- square=None,
23
- vmin=None,
24
- vmax=None,
25
- arrows=None,
26
- normalise="none",
27
- ):
28
- """
29
- Render a heatmap on the board.
30
- """
31
- if normalise == "abs":
32
- a_max = heatmap.abs().max()
33
- if a_max != 0:
34
- heatmap = heatmap / a_max
35
- vmin = -1
36
- vmax = 1
37
- if vmin is None:
38
- vmin = heatmap.min()
39
- if vmax is None:
40
- vmax = heatmap.max()
41
- norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)
42
-
43
- color_dict = {}
44
- for square_index in range(64):
45
- color = COLOR_MAP(norm(heatmap[square_index]))
46
- color = (*color[:3], ALPHA)
47
- color_dict[square_index] = matplotlib.colors.to_hex(color, keep_alpha=True)
48
- fig = plt.figure(figsize=(6, 0.6))
49
- ax = plt.gca()
50
- ax.axis("off")
51
- fig.colorbar(
52
- matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP),
53
- ax=ax,
54
- orientation="horizontal",
55
- fraction=1.0,
56
- )
57
- if square is not None:
58
- try:
59
- check = chess.parse_square(square)
60
- except ValueError:
61
- check = None
62
- else:
63
- check = None
64
- if arrows is None:
65
- arrows = []
66
- plt.close()
67
- return (
68
- chess.svg.board(
69
- board,
70
- check=check,
71
- fill=color_dict,
72
- size=350,
73
- arrows=arrows,
74
- ),
75
- fig,
76
- )
77
-
78
-
79
- def render_architecture(model, name: str = "model", directory: str = ""):
80
- """
81
- Render the architecture of the model.
82
- """
83
- out = model(torch.zeros(1, 112, 8, 8))
84
- if len(out) == 2:
85
- policy, outcome_probs = out
86
- value = torch.zeros(outcome_probs.shape[0], 1)
87
- else:
88
- policy, outcome_probs, value = out
89
- torchviz.make_dot(policy, params=dict(list(model.named_parameters()))).render(
90
- f"{directory}/{name}_policy", format="svg"
91
- )
92
- torchviz.make_dot(outcome_probs, params=dict(list(model.named_parameters()))).render(
93
- f"{directory}/{name}_outcome_probs", format="svg"
94
- )
95
- torchviz.make_dot(value, params=dict(list(model.named_parameters()))).render(
96
- f"{directory}/{name}_value", format="svg"
97
- )
98
-
99
-
100
- def render_policy_distribution(
101
- policy,
102
- legal_moves,
103
- n_bins=20,
104
- ):
105
- """
106
- Render the policy distribution histogram.
107
- """
108
- legal_mask = torch.Tensor([move in legal_moves for move in range(1858)]).bool()
109
- fig = plt.figure(figsize=(6, 6))
110
- ax = plt.gca()
111
- _, bins = np.histogram(policy, bins=n_bins)
112
- ax.hist(
113
- policy[~legal_mask],
114
- bins=bins,
115
- alpha=0.5,
116
- density=True,
117
- label="Illegal moves",
118
- )
119
- ax.hist(
120
- policy[legal_mask],
121
- bins=bins,
122
- alpha=0.5,
123
- density=True,
124
- label="Legal moves",
125
- )
126
- plt.xlabel("Policy")
127
- plt.ylabel("Density")
128
- plt.legend()
129
- plt.yscale("log")
130
- return fig
131
-
132
-
133
- def render_policy_statistics(
134
- statistics,
135
- ):
136
- """
137
- Render the policy statistics.
138
- """
139
- fig = plt.figure(figsize=(6, 6))
140
- ax = plt.gca()
141
- move_indices = list(statistics["mean_legal_logits"].keys())
142
- legal_means_avg = [np.mean(statistics["mean_legal_logits"][move_idx]) for move_idx in move_indices]
143
- illegal_means_avg = [np.mean(statistics["mean_illegal_logits"][move_idx]) for move_idx in move_indices]
144
- legal_means_std = [np.std(statistics["mean_legal_logits"][move_idx]) for move_idx in move_indices]
145
- illegal_means_std = [np.std(statistics["mean_illegal_logits"][move_idx]) for move_idx in move_indices]
146
- ax.errorbar(
147
- move_indices,
148
- legal_means_avg,
149
- yerr=legal_means_std,
150
- label="Legal moves",
151
- )
152
- ax.errorbar(
153
- move_indices,
154
- illegal_means_avg,
155
- yerr=illegal_means_std,
156
- label="Illegal moves",
157
- )
158
- plt.xlabel("Move index")
159
- plt.ylabel("Mean policy logits")
160
- plt.legend()
161
- return fig
162
-
163
-
164
- def render_relevance_proportion(statistics, scaled=True):
165
- """
166
- Render the relevance proportion statistics.
167
- """
168
- norm = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=False)
169
- fig_hist = plt.figure(figsize=(6, 6))
170
- ax = plt.gca()
171
- move_indices = list(statistics["planes_relevance_proportion"].keys())
172
- for h in range(8):
173
- relevance_proportion_avg = [
174
- np.mean([rel[13 * h : 13 * (h + 1)].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
175
- for move_idx in move_indices
176
- ]
177
- relevance_proportion_std = [
178
- np.std([rel[13 * h : 13 * (h + 1)].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
179
- for move_idx in move_indices
180
- ]
181
- ax.errorbar(
182
- move_indices[h + 1 :],
183
- relevance_proportion_avg[h + 1 :],
184
- yerr=relevance_proportion_std[h + 1 :],
185
- label=f"History {h}",
186
- c=COLOR_MAP(norm(h / 9)),
187
- )
188
-
189
- relevance_proportion_avg = [
190
- np.mean([rel[104:108].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
191
- for move_idx in move_indices
192
- ]
193
- relevance_proportion_std = [
194
- np.std([rel[104:108].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
195
- for move_idx in move_indices
196
- ]
197
- ax.errorbar(
198
- move_indices,
199
- relevance_proportion_avg,
200
- yerr=relevance_proportion_std,
201
- label="Castling rights",
202
- c=COLOR_MAP(norm(8 / 9)),
203
- )
204
- relevance_proportion_avg = [
205
- np.mean([rel[108:].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
206
- for move_idx in move_indices
207
- ]
208
- relevance_proportion_std = [
209
- np.std([rel[108:].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
210
- for move_idx in move_indices
211
- ]
212
- ax.errorbar(
213
- move_indices,
214
- relevance_proportion_avg,
215
- yerr=relevance_proportion_std,
216
- label="Remaining planes",
217
- c=COLOR_MAP(norm(9 / 9)),
218
- )
219
- plt.xlabel("Move index")
220
- plt.ylabel("Absolute relevance proportion")
221
- plt.yscale("log")
222
- plt.legend()
223
-
224
- if scaled:
225
- stat_key = "planes_relevance_proportion_scaled"
226
- else:
227
- stat_key = "planes_relevance_proportion"
228
- fig_planes = plt.figure(figsize=(6, 6))
229
- ax = plt.gca()
230
- move_indices = list(statistics[stat_key].keys())
231
- for p in range(13):
232
- relevance_proportion_avg = [
233
- np.mean([rel[p].item() for rel in statistics[stat_key][move_idx]]) for move_idx in move_indices
234
- ]
235
- relevance_proportion_std = [
236
- np.std([rel[p].item() for rel in statistics[stat_key][move_idx]]) for move_idx in move_indices
237
- ]
238
- ax.errorbar(
239
- move_indices,
240
- relevance_proportion_avg,
241
- yerr=relevance_proportion_std,
242
- label=constants.PLANE_NAMES[p],
243
- c=COLOR_MAP(norm(p / 12)),
244
- )
245
-
246
- plt.xlabel("Move index")
247
- plt.ylabel("Absolute relevance proportion")
248
- plt.yscale("log")
249
- plt.legend()
250
-
251
- fig_pieces = plt.figure(figsize=(6, 6))
252
- ax = plt.gca()
253
- for p in range(1, 13):
254
- stat_key = f"configuration_relevance_proportion_threatened_piece{p}"
255
- n_attackers = list(statistics[stat_key].keys())
256
- relevance_proportion_avg = [
257
- np.mean(statistics[f"configuration_relevance_proportion_threatened_piece{p}"][n]) for n in n_attackers
258
- ]
259
- relevance_proportion_std = [np.std(statistics[stat_key][n]) for n in n_attackers]
260
- ax.errorbar(
261
- n_attackers,
262
- relevance_proportion_avg,
263
- yerr=relevance_proportion_std,
264
- label="PNBRQKpnbrqk"[p - 1],
265
- c=COLOR_MAP(norm(p / 12)),
266
- )
267
-
268
- plt.xlabel("Number of attackers")
269
- plt.ylabel("Absolute configuration relevance proportion")
270
- plt.yscale("log")
271
- plt.legend()
272
-
273
- return fig_hist, fig_planes, fig_pieces
274
-
275
-
276
- def render_probing_statistics(
277
- statistics,
278
- ):
279
- """
280
- Render the probing statistics.
281
- """
282
- fig = plt.figure(figsize=(6, 6))
283
- ax = plt.gca()
284
- n_blocks = len(statistics["metrics"])
285
- for metric in statistics["metrics"]["block0"]:
286
- avg = []
287
- std = []
288
- for block_idx in range(n_blocks):
289
- metrics = statistics["metrics"]
290
- block_data = metrics[f"block{block_idx}"]
291
- avg.append(np.mean(block_data[metric]))
292
- std.append(np.std(block_data[metric]))
293
- ax.errorbar(
294
- range(n_blocks),
295
- avg,
296
- yerr=std,
297
- label=metric,
298
- )
299
- plt.xlabel("Block index")
300
- plt.ylabel("Metric")
301
- plt.yscale("log")
302
- plt.legend()
303
- return fig
{app β†’ demo}/__init__.py RENAMED
File without changes
demo/constants.py ADDED
@@ -0,0 +1,18 @@
1
+ """
2
+ Constants for the demo.
3
+ """
4
+
5
+ import os
6
+
7
+ ONNX_MODEL_DIRECTORY = "demo/onnx-models"
8
+ LEELA_MODEL_DIRECTORY = "demo/leela-models"
9
+ FIGURE_DIRECTORY = "demo/figures"
10
+
11
+ ONNX_MODEL_NAMES = [
12
+ f for f in os.listdir(ONNX_MODEL_DIRECTORY)
13
+ if f.endswith(".onnx")
14
+ ]
15
+ LEELA_MODEL_NAMES = [
16
+ f for f in os.listdir(LEELA_MODEL_DIRECTORY)
17
+ if f.endswith(".pb.gz")
18
+ ]
{app β†’ demo}/figures/.gitignore RENAMED
File without changes
demo/interfaces/__init__.py ADDED
File without changes
demo/interfaces/activations.py ADDED
@@ -0,0 +1,157 @@
1
+ """
2
+ Gradio interface for plotting activations.
3
+ """
4
+
5
+ import chess
6
+ import chess.pgn
7
+ import io
8
+ import gradio as gr
9
+ import os
10
+ import torch
11
+
12
+ from lczerolens import LczeroBoard, LczeroModel, Lens
13
+
14
+ from .. import constants
15
+
16
+ def get_model(model_name: str):
17
+ return LczeroModel.from_onnx_path(os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name))
18
+
19
+ def get_activations(model: LczeroModel, board: LczeroBoard):
20
+ lens = Lens.from_name("activation", r"block\d/conv2/relu")
21
+ with torch.no_grad():
22
+ results = lens.analyse(model, board)
23
+ return [results[f"block{i}/conv2/relu_output"][0] for i in range(len(results))]
24
+
25
+ def get_board(game_pgn:str, board_fen:str):
26
+ if game_pgn:
27
+ try:
28
+ board = LczeroBoard()
29
+ pgn = io.StringIO(game_pgn)
30
+ game = chess.pgn.read_game(pgn)
31
+ for move in game.mainline_moves():
32
+ board.push(move)
33
+ except Exception as e:
34
+ print(e)
35
+ gr.Warning("Error parsing PGN, using starting position.")
36
+ board = LczeroBoard()
37
+ else:
38
+ try:
39
+ board = LczeroBoard(board_fen)
40
+ except Exception as e:
41
+ print(e)
42
+ gr.Warning("Invalid FEN, using starting position.")
43
+ board = LczeroBoard()
44
+ return board
45
+
46
+ def render_activations(board: LczeroBoard, activations, layer_index:int, channel_index:int):
47
+ if layer_index >= len(activations):
48
+ safe_layer_index = len(activations) - 1
49
+ gr.Warning(f"Layer index {layer_index} out of range, using last layer ({safe_layer_index}).")
50
+ else:
51
+ safe_layer_index = layer_index
52
+ if channel_index >= activations[safe_layer_index].shape[0]:
53
+ safe_channel_index = activations[safe_layer_index].shape[0] - 1
54
+ gr.Warning(f"Channel index {channel_index} out of range, using last channel ({safe_channel_index}).")
55
+ else:
56
+ safe_channel_index = channel_index
57
+ heatmap = activations[safe_layer_index][safe_channel_index].view(64)
58
+ board.render_heatmap(
59
+ heatmap,
60
+ save_to=f"{constants.FIGURE_DIRECTORY}/activations.svg",
61
+ )
62
+ return f"{constants.FIGURE_DIRECTORY}/activations_board.svg", f"{constants.FIGURE_DIRECTORY}/activations_colorbar.svg"
63
+
64
+ def initial_load(model_name: str, board_fen: str, game_pgn: str, layer_index: int, channel_index: int):
65
+ model = get_model(model_name)
66
+ board = get_board(game_pgn, board_fen)
67
+ activations = get_activations(model, board)
68
+ plots = render_activations(board, activations, layer_index, channel_index)
69
+ return model, board, activations, *plots
70
+
71
+ def on_board_change(model: LczeroModel, game_pgn: str, board_fen: str, layer_index: int, channel_index: int):
72
+ board = get_board(game_pgn, board_fen)
73
+ activations = get_activations(model, board)
74
+ plots = render_activations(board, activations, layer_index, channel_index)
75
+ return board, activations, *plots
76
+
77
+ def on_model_change(model_name: str, board: LczeroBoard, layer_index: int, channel_index: int):
78
+ model = get_model(model_name)
79
+ activations = get_activations(model, board)
80
+ plots = render_activations(board, activations, layer_index, channel_index)
81
+ return model, activations, *plots
82
+
83
+ with gr.Blocks() as interface:
84
+ with gr.Row():
85
+ with gr.Column():
86
+ with gr.Group():
87
+ gr.Markdown(
88
+ "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
89
+ )
90
+ game_pgn = gr.Textbox(
91
+ label="Game PGN",
92
+ lines=1,
93
+ value="",
94
+ )
95
+ board_fen = gr.Textbox(
96
+ label="Board FEN",
97
+ lines=1,
98
+ max_lines=1,
99
+ value=chess.STARTING_FEN,
100
+ )
101
+ with gr.Group():
102
+ model_name = gr.Dropdown(
103
+ label="Model",
104
+ choices=constants.ONNX_MODEL_NAMES,
105
+ )
106
+ layer_index = gr.Slider(
107
+ label="Layer index",
108
+ minimum=0,
109
+ maximum=19,
110
+ step=1,
111
+ value=0,
112
+ )
113
+ channel_index = gr.Slider(
114
+ label="Channel index",
115
+ minimum=0,
116
+ maximum=200,
117
+ step=1,
118
+ value=0,
119
+ )
120
+ with gr.Column():
121
+ image_board = gr.Image(label="Board", interactive=False)
122
+ colorbar = gr.Image(label="Colorbar", interactive=False)
123
+
124
+ model = gr.State(value=None)
125
+ board = gr.State(value=None)
126
+ activations = gr.State(value=None)
127
+
128
+ interface.load(
129
+ initial_load,
130
+ inputs=[model_name, game_pgn, board_fen, layer_index, channel_index],
131
+ outputs=[model, board, activations, image_board, colorbar],
132
+ )
133
+ game_pgn.submit(
134
+ on_board_change,
135
+ inputs=[model, game_pgn, board_fen, layer_index, channel_index],
136
+ outputs=[board, activations, image_board, colorbar],
137
+ )
138
+ board_fen.submit(
139
+ on_board_change,
140
+ inputs=[model, game_pgn, board_fen, layer_index, channel_index],
141
+ outputs=[board, activations, image_board, colorbar],
142
+ )
143
+ model_name.change(
144
+ on_model_change,
145
+ inputs=[model_name, board, layer_index, channel_index],
146
+ outputs=[model, activations, image_board, colorbar],
147
+ )
148
+ layer_index.change(
149
+ render_activations,
150
+ inputs=[board, activations, layer_index, channel_index],
151
+ outputs=[image_board, colorbar],
152
+ )
153
+ channel_index.change(
154
+ render_activations,
155
+ inputs=[board, activations, layer_index, channel_index],
156
+ outputs=[image_board, colorbar],
157
+ )
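Note: the helpers above can also be scripted outside Gradio. A minimal sketch mirroring the same lczerolens calls used in get_activations and render_activations (the ONNX filename is an assumption, taken from the nets that resolve-assets.sh downloads):

    import torch
    from lczerolens import LczeroBoard, LczeroModel, Lens

    # load one of the downloaded nets and analyse the starting position
    model = LczeroModel.from_onnx_path("demo/onnx-models/maia-1100.onnx")
    board = LczeroBoard()
    lens = Lens.from_name("activation", r"block\d/conv2/relu")
    with torch.no_grad():
        results = lens.analyse(model, board)
    # first channel of block 0, reshaped to the 64 squares and rendered as a heatmap
    heatmap = results["block0/conv2/relu_output"][0][0].view(64)
    board.render_heatmap(heatmap, save_to="demo/figures/activations.svg")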
app/board_interface.py β†’ demo/interfaces/board.py RENAMED
@@ -3,11 +3,12 @@ Gradio interface for plotting a board.
3
  """
4
 
5
  import chess
 
6
  import gradio as gr
7
 
8
- from demo import constants
9
  from lczerolens.board import LczeroBoard
10
 
 
11
 
12
  def make_board_plot(board_fen, arrows, square):
13
  try:
@@ -15,34 +16,8 @@ def make_board_plot(board_fen, arrows, square):
15
  except ValueError:
16
  board = LczeroBoard()
17
  gr.Warning("Invalid FEN, using starting position.")
18
- try:
19
- if arrows:
20
- arrows_list = arrows.split(" ")
21
- chess_arrows = []
22
- for arrow in arrows_list:
23
- from_square, to_square = arrow[:2], arrow[2:]
24
- chess_arrows.append(
25
- (
26
- chess.parse_square(from_square),
27
- chess.parse_square(to_square),
28
- )
29
- )
30
- else:
31
- chess_arrows = []
32
- except ValueError:
33
- chess_arrows = []
34
- gr.Warning("Invalid arrows, using none.")
35
-
36
- color_dict = {chess.parse_square(square): "#FF0000"} if square else {}
37
- svg_board = chess.svg.board(
38
- board,
39
- size=350,
40
- arrows=chess_arrows,
41
- fill=color_dict,
42
- )
43
- with open(f"{constants.FIGURE_DIRECTORY}/board.svg", "w") as f:
44
- f.write(svg_board)
45
- return f"{constants.FIGURE_DIRECTORY}/board.svg"
46
 
47
 
48
  with gr.Blocks() as interface:
@@ -76,6 +51,7 @@ with gr.Blocks() as interface:
76
  arrows,
77
  square,
78
  ]
 
79
  board_fen.submit(make_board_plot, inputs=inputs, outputs=image)
80
  arrows.submit(make_board_plot, inputs=inputs, outputs=image)
81
- interface.load(make_board_plot, inputs=inputs, outputs=image)
 
3
  """
4
 
5
  import chess
6
+ import chess.svg
7
  import gradio as gr
8
 
 
9
  from lczerolens.board import LczeroBoard
10
 
11
+ from ..utils import create_board_figure
12
 
13
  def make_board_plot(board_fen, arrows, square):
14
  try:
 
16
  except ValueError:
17
  board = LczeroBoard()
18
  gr.Warning("Invalid FEN, using starting position.")
19
+ filepath = create_board_figure(board, arrows=arrows, square=square, name="board")
20
+ return filepath
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  with gr.Blocks() as interface:
 
51
  arrows,
52
  square,
53
  ]
54
+ interface.load(make_board_plot, inputs=inputs, outputs=image)
55
  board_fen.submit(make_board_plot, inputs=inputs, outputs=image)
56
  arrows.submit(make_board_plot, inputs=inputs, outputs=image)
57
+ square.submit(make_board_plot, inputs=inputs, outputs=image)
demo/interfaces/encodings.py ADDED
@@ -0,0 +1,99 @@
1
+ """
2
+ Gradio interface for plotting board input encodings.
3
+ """
4
+
5
+ import chess
6
+ import chess.pgn
7
+ import io
8
+ import gradio as gr
9
+
10
+ from lczerolens.board import LczeroBoard
11
+
12
+ from ..constants import FIGURE_DIRECTORY
13
+
14
+ def make_render(game_pgn:str, board_fen:str, plane_index:int):
15
+ if game_pgn:
16
+ try:
17
+ board = LczeroBoard()
18
+ pgn = io.StringIO(game_pgn)
19
+ game = chess.pgn.read_game(pgn)
20
+ for move in game.mainline_moves():
21
+ board.push(move)
22
+ except Exception as e:
23
+ print(e)
24
+ gr.Warning("Error parsing PGN, using starting position.")
25
+ board = LczeroBoard()
26
+ else:
27
+ try:
28
+ board = LczeroBoard(board_fen)
29
+ except Exception as e:
30
+ print(e)
31
+ gr.Warning("Invalid FEN, using starting position.")
32
+ board = LczeroBoard()
33
+ return board, *make_board_plot(board, plane_index)
34
+
35
+ def make_board_plot(board:LczeroBoard, plane_index:int):
36
+ input_tensor = board.to_input_tensor()
37
+ board.render_heatmap(
38
+ input_tensor[plane_index].view(64),
39
+ save_to=f"{FIGURE_DIRECTORY}/encodings.svg",
40
+ vmin=0,
41
+ vmax=1,
42
+ )
43
+ return f"{FIGURE_DIRECTORY}/encodings_board.svg", f"{FIGURE_DIRECTORY}/encodings_colorbar.svg"
44
+
45
+ with gr.Blocks() as interface:
46
+ with gr.Row():
47
+ with gr.Column():
48
+ with gr.Group():
49
+ gr.Markdown(
50
+ "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
51
+ )
52
+ game_pgn = gr.Textbox(
53
+ label="Game PGN",
54
+ lines=1,
55
+ value="",
56
+ )
57
+ board_fen = gr.Textbox(
58
+ label="Board FEN",
59
+ lines=1,
60
+ max_lines=1,
61
+ value=chess.STARTING_FEN,
62
+ )
63
+ with gr.Group():
64
+ with gr.Row():
65
+ plane_index = gr.Slider(
66
+ label="Plane index",
67
+ minimum=0,
68
+ maximum=111,
69
+ step=1,
70
+ value=0,
71
+ )
72
+ with gr.Column():
73
+ image_board = gr.Image(label="Board", interactive=False)
74
+ colorbar = gr.Image(label="Colorbar", interactive=False)
75
+
76
+ state_board = gr.State(value=LczeroBoard())
77
+
78
+ render_inputs = [game_pgn, board_fen, plane_index]
79
+ render_outputs = [state_board, image_board, colorbar]
80
+ interface.load(
81
+ make_render,
82
+ inputs=render_inputs,
83
+ outputs=render_outputs,
84
+ )
85
+ game_pgn.submit(
86
+ make_render,
87
+ inputs=render_inputs,
88
+ outputs=render_outputs,
89
+ )
90
+ board_fen.submit(
91
+ make_render,
92
+ inputs=render_inputs,
93
+ outputs=render_outputs,
94
+ )
95
+ plane_index.change(
96
+ make_board_plot,
97
+ inputs=[state_board, plane_index],
98
+ outputs=[image_board, colorbar],
99
+ )
demo/interfaces/gradients.py ADDED
@@ -0,0 +1,174 @@
1
+ """
2
+ Gradio interface for plotting input gradients.
3
+ """
4
+
5
+ import chess
6
+ import chess.pgn
7
+ import io
8
+ import gradio as gr
9
+ import os
10
+
11
+ from lczerolens import LczeroBoard, LczeroModel, Lens
12
+
13
+ from .. import constants
14
+
15
+ def get_model(model_name: str):
16
+ return LczeroModel.from_onnx_path(os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name))
17
+
18
+ def get_gradients(model: LczeroModel, board: LczeroBoard, wdl_target: str):
19
+ lens = Lens.from_name("gradient")
20
+ wdl_index = {"win": 0, "draw": 1, "loss": 2}[wdl_target]
21
+
22
+ def init_target(model):
23
+ return getattr(model, "output/wdl").output[:, wdl_index]
24
+ results = lens.analyse(model, board, init_target=init_target)
25
+
26
+ return results["input_grad"]
27
+
28
+ def get_board(game_pgn:str, board_fen:str):
29
+ if game_pgn:
30
+ try:
31
+ board = LczeroBoard()
32
+ pgn = io.StringIO(game_pgn)
33
+ game = chess.pgn.read_game(pgn)
34
+ for move in game.mainline_moves():
35
+ board.push(move)
36
+ except Exception as e:
37
+ print(e)
38
+ gr.Warning("Error parsing PGN, using starting position.")
39
+ board = LczeroBoard()
40
+ else:
41
+ try:
42
+ board = LczeroBoard(board_fen)
43
+ except Exception as e:
44
+ print(e)
45
+ gr.Warning("Invalid FEN, using starting position.")
46
+ board = LczeroBoard()
47
+ return board
48
+
49
+ def render_gradients(board: LczeroBoard, gradients, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index:int):
50
+ if average_over_planes:
51
+ heatmap = gradients[0, begin_average_index:end_average_index].mean(dim=0).view(64)
52
+ else:
53
+ heatmap = gradients[0, plane_index].view(64)
54
+ board.render_heatmap(
55
+ heatmap,
56
+ save_to=f"{constants.FIGURE_DIRECTORY}/gradients.svg",
57
+ )
58
+ return f"{constants.FIGURE_DIRECTORY}/gradients_board.svg", f"{constants.FIGURE_DIRECTORY}/gradients_colorbar.svg"
59
+
60
+ def initial_load(model_name: str, board_fen: str, game_pgn: str, wdl_target: str, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index: int):
61
+ model = get_model(model_name)
62
+ board = get_board(game_pgn, board_fen)
63
+ gradients = get_gradients(model, board, wdl_target)
64
+ plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
65
+ return model, board, gradients, *plots
66
+
67
+ def on_board_change(model: LczeroModel, game_pgn: str, board_fen: str, wdl_target: str, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index: int):
68
+ board = get_board(game_pgn, board_fen)
69
+ gradients = get_gradients(model, board, wdl_target)
70
+ plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
71
+ return board, gradients, *plots
72
+
73
+ def on_model_change(model_name: str, board: LczeroBoard, wdl_target: str, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index: int):
74
+ model = get_model(model_name)
75
+ gradients = get_gradients(model, board, wdl_target)
76
+ plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
77
+ return model, gradients, *plots
78
+
79
+ def on_wdl_target_change(model: LczeroModel, board: LczeroBoard, wdl_target: str, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index: int):
80
+ gradients = get_gradients(model, board, wdl_target)
81
+ plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
82
+ return gradients, *plots
83
+
84
+ with gr.Blocks() as interface:
85
+ with gr.Row():
86
+ with gr.Column():
87
+ with gr.Group():
88
+ gr.Markdown(
89
+ "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
90
+ )
91
+ game_pgn = gr.Textbox(
92
+ label="Game PGN",
93
+ lines=1,
94
+ value="",
95
+ )
96
+ board_fen = gr.Textbox(
97
+ label="Board FEN",
98
+ lines=1,
99
+ max_lines=1,
100
+ value=chess.STARTING_FEN,
101
+ )
102
+ with gr.Group():
103
+ model_name = gr.Dropdown(
104
+ label="Model",
105
+ choices=constants.ONNX_MODEL_NAMES,
106
+ )
107
+ wdl_target = gr.Radio(
108
+ ["win", "draw", "loss"], label="WDL target",
109
+ value="win",
110
+ )
111
+ with gr.Group():
112
+ average_over_planes = gr.Checkbox(label="Average over Planes", value=False)
113
+ with gr.Accordion("Average over planes", open=False):
114
+ begin_average_index = gr.Slider(
115
+ label="Begin average index",
116
+ minimum=0,
117
+ maximum=111,
118
+ step=1,
119
+ value=0,
120
+ )
121
+ end_average_index = gr.Slider(
122
+ label="End average index",
123
+ minimum=0,
124
+ maximum=111,
125
+ step=1,
126
+ value=111,
127
+ )
128
+ plane_index = gr.Slider(
129
+ label="Plane index",
130
+ minimum=0,
131
+ maximum=111,
132
+ step=1,
133
+ value=0,
134
+ )
135
+
136
+ with gr.Column():
137
+ image_board = gr.Image(label="Board", interactive=False)
138
+ colorbar = gr.Image(label="Colorbar", interactive=False)
139
+
140
+ model = gr.State(value=None)
141
+ board = gr.State(value=None)
142
+ gradients = gr.State(value=None)
143
+
144
+ interface.load(
145
+ initial_load,
146
+ inputs=[model_name, game_pgn, board_fen, wdl_target, average_over_planes, begin_average_index, end_average_index, plane_index],
147
+ outputs=[model, board, gradients, image_board, colorbar],
148
+ )
149
+ game_pgn.submit(
150
+ on_board_change,
151
+ inputs=[model, game_pgn, board_fen, wdl_target, average_over_planes, begin_average_index, end_average_index, plane_index],
152
+ outputs=[board, gradients, image_board, colorbar],
153
+ )
154
+ board_fen.submit(
155
+ on_board_change,
156
+ inputs=[model, game_pgn, board_fen, wdl_target, average_over_planes, begin_average_index, end_average_index, plane_index],
157
+ outputs=[board, gradients, image_board, colorbar],
158
+ )
159
+ model_name.change(
160
+ on_model_change,
161
+ inputs=[model_name, board, wdl_target, average_over_planes, begin_average_index, end_average_index, plane_index],
162
+ outputs=[model, gradients, image_board, colorbar],
163
+ )
164
+ wdl_target.change(
165
+ on_wdl_target_change,
166
+ inputs=[model, board, wdl_target, average_over_planes, begin_average_index, end_average_index, plane_index],
167
+ outputs=[gradients, image_board, colorbar],
168
+ )
169
+ for render_arg in [average_over_planes, begin_average_index, end_average_index, plane_index]:
170
+ render_arg.change(
171
+ render_gradients,
172
+ inputs=[board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index],
173
+ outputs=[image_board, colorbar],
174
+ )
demo/interfaces/play.py ADDED
@@ -0,0 +1,140 @@
1
+ """Interface to play against the model.
2
+ """
3
+
4
+ import os
5
+
6
+ import chess
7
+ import chess.pgn
8
+ import random
9
+ import gradio as gr
10
+
11
+ from lczerolens import LczeroBoard, LczeroModel
12
+ from lczerolens.play import PolicySampler
13
+
14
+ from .. import constants
15
+ from ..utils import create_board_figure
16
+
17
+
18
+ def get_sampler(model_name: str):
19
+ model = LczeroModel.from_onnx_path(os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name))
20
+ return PolicySampler(model)
21
+
22
+ def get_pgn(board: LczeroBoard):
23
+ game = chess.pgn.Game()
24
+ # add the recorded moves as the game's main line (add_variation on the root would create sibling variations, not a sequence)
25
+ game.add_line(board.move_stack)
26
+ return str(game).split("\n")[-1]
27
+
28
+ def render_board(
29
+ board: LczeroBoard,
30
+ ):
31
+ player = board.turn
32
+ if len(board.move_stack) > 0:
33
+ last_move_uci = board.peek().uci()
34
+ else:
35
+ last_move_uci = None
36
+
37
+ if board.is_check():
38
+ check = chess.square_name(board.king(board.turn))
39
+ else:
40
+ check = None
41
+ filepath = create_board_figure(
42
+ board,
43
+ orientation=player,
44
+ arrows=last_move_uci,
45
+ square=check,
46
+ name="play_board",
47
+ )
48
+ return filepath
49
+
50
+ def gather_outputs(board: LczeroBoard, sampler: PolicySampler):
51
+ return sampler, board, board.fen(), get_pgn(board), render_board(board), ""
52
+
53
+ def get_init(model_name: str):
54
+ sampler = get_sampler(model_name)
55
+ is_ai_white = random.choice([True, False])
56
+ init_board = LczeroBoard()
57
+ if is_ai_white:
58
+ play_ai_move(init_board, sampler)
59
+ return gather_outputs(init_board, sampler)
60
+
61
+ def play_user_move_then_ai_move(
62
+ uci_move: str,
63
+ board: LczeroBoard,
64
+ sampler: PolicySampler,
65
+ ):
66
+ board.push_uci(uci_move)
67
+ play_ai_move(board, sampler)
68
+ return gather_outputs(board, sampler)
69
+
70
+
71
+ def play_ai_move(
72
+ board: LczeroBoard,
73
+ sampler: PolicySampler,
74
+ ):
75
+ move, _ = next(iter(sampler.get_next_moves([board])))
76
+ board.push(move)
77
+
78
+ with gr.Blocks() as interface:
79
+ with gr.Row():
80
+ with gr.Column():
81
+ current_fen = gr.Textbox(
82
+ label="Board FEN",
83
+ lines=1,
84
+ max_lines=1,
85
+ value=chess.STARTING_FEN,
86
+ )
87
+ current_pgn = gr.Textbox(
88
+ label="Action sequence",
89
+ lines=1,
90
+ value="",
91
+ )
92
+ with gr.Row():
93
+ move_to_play = gr.Textbox(
94
+ label="Move to play (UCI)",
95
+ lines=1,
96
+ max_lines=1,
97
+ value="",
98
+ )
99
+ with gr.Column():
100
+ model_name = gr.Dropdown(
101
+ label="Model",
102
+ choices=constants.ONNX_MODEL_NAMES,
103
+ )
104
+ play_button = gr.Button("Play")
105
+ reset_button = gr.Button("Reset")
106
+ with gr.Column():
107
+ image_board = gr.Image(label="Board", interactive=False)
108
+
109
+ sampler = gr.State(value=None)
110
+ board = gr.State(value=None)
111
+
112
+ outputs = [sampler, board, current_fen, current_pgn, image_board, move_to_play]
113
+
114
+ play_button.click(
115
+ play_user_move_then_ai_move,
116
+ inputs=[move_to_play, board, sampler],
117
+ outputs=outputs,
118
+ )
119
+ move_to_play.submit(
120
+ play_user_move_then_ai_move,
121
+ inputs=[move_to_play, board, sampler],
122
+ outputs=outputs,
123
+ )
124
+
125
+ model_name.change(
126
+ get_sampler,
127
+ inputs=[model_name],
128
+ outputs=[sampler],
129
+ )
130
+
131
+ reset_button.click(
132
+ get_init,
133
+ inputs=[model_name],
134
+ outputs=outputs,
135
+ )
136
+ interface.load(
137
+ get_init,
138
+ inputs=[model_name],
139
+ outputs=outputs,
140
+ )
{app/leela_models β†’ demo/leela-models}/.gitignore RENAMED
File without changes
{app/onnx_models β†’ demo/onnx-models}/.gitignore RENAMED
File without changes
demo/utils.py ADDED
@@ -0,0 +1,52 @@
1
+ """Board-rendering helpers shared by the demo interfaces."""
2
+ import gradio as gr
3
+ import chess.svg
4
+ from typing import Optional
5
+
6
+ from lczerolens.board import LczeroBoard
7
+
8
+ from . import constants
9
+
10
+
11
+ def create_board_figure(
12
+ board: LczeroBoard,
13
+ *,
14
+ orientation: bool = chess.WHITE,
15
+ arrows: str = "",
16
+ square: str = "",
17
+ name: str = "board",
18
+ ):
19
+ try:
20
+ if arrows:
21
+ arrows_list = arrows.split(" ")
22
+ chess_arrows = []
23
+ for arrow in arrows_list:
24
+ from_square, to_square = arrow[:2], arrow[2:]
25
+ chess_arrows.append(
26
+ (
27
+ chess.parse_square(from_square),
28
+ chess.parse_square(to_square),
29
+ )
30
+ )
31
+ else:
32
+ chess_arrows = []
33
+ except ValueError:
34
+ chess_arrows = []
35
+ gr.Warning("Invalid arrows, using none.")
36
+
37
+ try:
38
+ color_dict = {chess.parse_square(square): "#FF0000"} if square else {}
39
+ except ValueError:
40
+ color_dict = {}
41
+ gr.Warning("Invalid square, using none.")
42
+
43
+ svg_board = chess.svg.board(
44
+ board,
45
+ size=350,
46
+ orientation=orientation,
47
+ arrows=chess_arrows,
48
+ fill=color_dict,
49
+ )
50
+ with open(f"{constants.FIGURE_DIRECTORY}/{name}.svg", "w") as f:
51
+ f.write(svg_board)
52
+ return f"{constants.FIGURE_DIRECTORY}/{name}.svg"
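For reference, a minimal usage sketch of create_board_figure as the interfaces call it (arrows are space-separated UCI moves and square is a square name such as "e4"; the example name and output path are assumptions based on FIGURE_DIRECTORY):

    from lczerolens.board import LczeroBoard
    from demo.utils import create_board_figure

    board = LczeroBoard()
    # draw two arrows and highlight e4 in red; invalid input falls back to no annotation with a warning
    path = create_board_figure(board, arrows="e2e4 g1f3", square="e4", name="example")
    print(path)  # demo/figures/example.svg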
main.py ADDED
@@ -0,0 +1,40 @@
1
+ """
2
+ Gradio demo showcasing lczerolens features.
3
+ """
4
+
5
+ import gradio as gr
6
+ import subprocess
7
+
8
+ from demo.interfaces import (
9
+ board,
10
+ encodings,
11
+ gradients,
12
+ play,
13
+ activations,
14
+ )
15
+
16
+ demo = gr.TabbedInterface(
17
+ [
18
+ board.interface,
19
+ play.interface,
20
+ encodings.interface,
21
+ activations.interface,
22
+ gradients.interface,
23
+ ],
24
+ [
25
+ "Board",
26
+ "Play",
27
+ "Encodings",
28
+ "Activations",
29
+ "Gradients",
30
+ ],
31
+ title="Chess Project Demo",
32
+ analytics_enabled=False,
33
+ )
34
+
35
+ if __name__ == "__main__":
36
+ subprocess.run(["bash", "resolve-assets.sh"])
37
+ demo.launch(
38
+ server_port=8000,
39
+ server_name="0.0.0.0",
40
+ )
pyproject.toml ADDED
@@ -0,0 +1,11 @@
1
+ [project]
2
+ name = "lczerolens-demo"
3
+ version = "0.1.0"
4
+ description = "Demo lczerolens features."
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "gdown>=5.2.0",
9
+ "gradio>=5.20.1",
10
+ "lczerolens[viz]>=0.3.1",
11
+ ]
resolve-assets.sh ADDED
@@ -0,0 +1,5 @@
1
+ uv run gdown 1cxC8_8vw7akfPyc9cZxwaAbLG2Zl4XiT -O demo/onnx-models/lc0-10-4238.onnx
2
+ uv run gdown 15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X -O demo/onnx-models/lc0-19-1876.onnx
3
+ uv run gdown 1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd -O demo/onnx-models/lc0-19-4508.onnx
4
+ uv run gdown 1TI429e9mr2de7LjHp2IIl7ouMoUaDjjZ -O demo/onnx-models/maia-1100.onnx
5
+ uv run gdown 1-8IJ5WYMPpcxOsHfIKY8xKskwk2z_yrY -O demo/onnx-models/maia-1900.onnx
uv.lock ADDED
The diff for this file is too large to render. See raw diff