improve interface
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
|
|
3 |
import random
|
4 |
import chess
|
5 |
import chess.svg
|
6 |
-
from transformers import AutoModelForSequenceClassification, AutoTokenizer,
|
7 |
|
8 |
token = os.environ['auth_token']
|
9 |
|
@@ -44,38 +44,58 @@ def predict_move(fen, top_k=3):
|
|
44 |
# discard illegal moves (https://python-chess.readthedocs.io/en/latest/core.html#chess.Board.legal_moves), then select top_k
|
45 |
return p['label']
|
46 |
|
47 |
-
def
|
48 |
global board
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
move = predict_move(board.fen())
|
53 |
-
board.push_uci(move)
|
54 |
else:
|
55 |
-
|
56 |
-
|
57 |
-
board.push_uci(inp_move)
|
58 |
-
if inp_notation == 'SAN':
|
59 |
-
board.push_san(inp_move)
|
60 |
with open('board.svg', 'w') as f:
|
61 |
f.write(str(chess.svg.board(board)))
|
62 |
-
print(state)
|
63 |
return 'board.svg', board.fen()
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import random
|
4 |
import chess
|
5 |
import chess.svg
|
6 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
|
7 |
|
8 |
token = os.environ['auth_token']
|
9 |
|
|
|
44 |
# discard illegal moves (https://python-chess.readthedocs.io/en/latest/core.html#chess.Board.legal_moves), then select top_k
|
45 |
return p['label']
|
46 |
|
47 |
+
def btn_load(inp_fen):
|
48 |
global board
|
49 |
+
|
50 |
+
if inp_fen:
|
51 |
+
board = chess.Board(inp_fen)
|
|
|
|
|
52 |
else:
|
53 |
+
board.reset()
|
54 |
+
|
|
|
|
|
|
|
55 |
with open('board.svg', 'w') as f:
|
56 |
f.write(str(chess.svg.board(board)))
|
|
|
57 |
return 'board.svg', board.fen()
|
58 |
|
59 |
+
def btn_play(inp_move, inp_notation):
|
60 |
+
global board
|
61 |
+
|
62 |
+
if inp_move:
|
63 |
+
if inp_notation == 'UCI': mv = chess.Move.from_uci(inp_move) #board.push_uci(inp_move)
|
64 |
+
elif inp_notation == 'SAN': mv = board.parse_san(inp_move) #chess.Move.from_san(inp_move) #board.push_san(inp_move)
|
65 |
+
else:
|
66 |
+
mv = chess.Move.from_uci(predict_move(board.fen()))
|
67 |
+
|
68 |
+
board.push(mv)
|
69 |
+
|
70 |
+
with open('board.svg', 'w') as f:
|
71 |
+
f.write(str(chess.svg.board(board, lastmove=mv)))
|
72 |
+
|
73 |
+
return 'board.svg', board.fen(), ''
|
74 |
+
|
75 |
+
with gr.Blocks() as block:
|
76 |
+
gr.Markdown('# Play YoloChess - Policy Network v0.3')
|
77 |
+
with gr.Row() as row:
|
78 |
+
with gr.Column():
|
79 |
+
move = gr.Textbox(label='human player move')
|
80 |
+
notation = gr.Radio(["SAN", "UCI"], value="SAN", label='move notation')
|
81 |
+
fen = gr.Textbox(placeholder=board.fen(), label='FEN')
|
82 |
+
with gr.Row():
|
83 |
+
load_btn = gr.Button("Load")
|
84 |
+
play_btn = gr.Button("Play")
|
85 |
+
with gr.Column():
|
86 |
+
position_output = gr.Image(label='board')
|
87 |
+
|
88 |
+
load_btn.click(fn=btn_load, inputs=fen, outputs=[position_output, fen])
|
89 |
+
play_btn.click(fn=btn_play, inputs=[move, notation], outputs=[position_output, fen, move])
|
90 |
+
|
91 |
+
gr.Markdown(
|
92 |
+
'''
|
93 |
+
- Click "Load" button to start and reset board.
|
94 |
+
- Click "Play" button to get Engine move.
|
95 |
+
- Enter a "human player move" in UCI or SAN notation and click "Play" to move a piece.
|
96 |
+
- Output "ERROR" generally occurs on illegal moves (Human or Engine).
|
97 |
+
- Enter "FEN" to start from a custom position.
|
98 |
+
'''
|
99 |
+
)
|
100 |
+
|
101 |
+
block.launch()
|