jrahn committed on
Commit
2d736ae
·
verified ·
1 Parent(s): 1f794a3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import chess
3
+ import chess.svg
4
+ from collections import deque
5
+ from transformers import pipeline
6
+ import torch
7
+
8
# When True, prompts and raw generations are printed to stdout.
DEBUG = False

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# A single text-generation pipeline serves both the policy ("P: ...") and the
# environment ("A: ...") prompts built further down in this file.
pipe = pipeline(
    task="text-generation",
    model="jrahn/rookworld_7m_3e_gpt2_124M_hf",
    torch_dtype=torch.bfloat16,
    device=device,
)

# Pad on the left and reuse the end-of-sequence token as the padding token.
pipe.tokenizer.padding_side = "left"
pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id

# Keyword arguments forwarded to every ``pipe(...)`` call.
sampling_args = dict(
    do_sample=True,
    temperature=0.7,
    top_k=15,
    truncation=True,
    return_full_text=False,
    pad_token_id=pipe.tokenizer.eos_token_id,
    max_length=186,
)

# FEN of the standard chess starting position; also used as the reset state.
START_POSITION = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
28
+
29
def generate_action(state):
    """Sample a move for ``state`` from the policy prompt.

    The pipeline is prompted with ``"P: {state} "`` and the text after the
    last ``"B: "`` marker of the completion is taken as the move.

    Args:
        state: FEN string of the current position.

    Returns:
        The extracted move string, or the sentinel ``"0000"`` when no move
        could be extracted from the generation.
    """
    prompt = f"P: {state} "
    if DEBUG:
        print(prompt)
    generation = pipe(prompt, **sampling_args)
    if DEBUG:
        print(generation)

    marker = "B: "
    try:
        text = generation[0]["generated_text"]
        # Without the marker a plain split() would hand back the entire
        # completion as the "move", so treat that case as invalid instead.
        action = text.rsplit(marker, 1)[1].strip() if marker in text else ""
    except (IndexError, KeyError, TypeError, AttributeError):
        # Unexpected pipeline output shape; handled as an invalid generation.
        action = ""

    if action:
        gr.Info(f"Policy generated move: {action}", duration=3)
    else:
        gr.Info(f"Policy generation invalid: {generation}", duration=None)
        action = "0000"

    if DEBUG:
        print(action)
    return action
42
+
43
def generate_state(state, action, history):
    """Ask the model, acting as environment, for the outcome of a move.

    The pipeline is prompted with ``"A: {state}+{action}+{history}+"``, where
    the history part is the last ten entries of ``history`` joined by spaces.
    The completion is expected to contain exactly four ``+``-separated fields.

    Args:
        state: FEN string of the position before the move.
        action: Move string to apply.
        history: List of move strings played so far.

    Returns:
        A 4-tuple of strings ``(new_state, reward, terminated, truncated)``
        with surrounding whitespace removed. When the completion cannot be
        split into four fields the tuple is
        ``(START_POSITION, "0", "0", "1")``.
    """
    if DEBUG:
        print(state, action, history)
    history_text = " ".join(history[-10:])
    prompt = f"A: {state}+{action}+{history_text}+"
    if DEBUG:
        print(prompt)
    generation = pipe(prompt, **sampling_args)
    if DEBUG:
        print(generation)

    try:
        fields = generation[0]["generated_text"].split("+")
        # Strip each field so stray whitespace cannot defeat later string
        # comparisons; a wrong field count raises ValueError on unpacking.
        new_state, reward, terminated, truncated = (f.strip() for f in fields)
    except (IndexError, KeyError, TypeError, AttributeError, ValueError):
        new_state, reward, terminated, truncated = START_POSITION, "0", "0", "1"
        gr.Info(f"Environment generation invalid: {generation}", duration=None)

    if DEBUG:
        print(new_state, reward, terminated, truncated)
    return new_state, reward, terminated, truncated
58
+
59
def _reset_outputs():
    """Return the outputs that put the UI and session state back at the start."""
    return create_chess_board_svg(), START_POSITION, "", [START_POSITION, []]


def _parse_flag(flag):
    """Read a generated "0"/"1" field as a bool; return None if it is not an integer."""
    try:
        return bool(int(flag))
    except (TypeError, ValueError):
        return None


def _result_message(state, reward):
    """Build the end-of-game message for the side to move in ``state``."""
    parts = state.split()
    # A FEN without a side-to-move field is treated as White to move.
    side = parts[1] if len(parts) > 1 else "w"
    player = "White" if side == "w" else "Black"
    outcomes = {
        "-1": f"Environment ended game: {player} lost!",
        "1": f"Environment ended game: {player} won!",
        "0.5": "Environment ended game: It's a draw!",
    }
    return outcomes.get(reward.strip(), "Environment ended game: Unexpected outcome")


def step_episode(inputs):
    """Play one policy move and one environment response.

    Args:
        inputs: Two-item sequence ``[state, history]`` holding the current
            FEN string and the list of moves played so far.

    Returns:
        A 4-tuple ``(board_svg, fen, move_history_text, session_state)`` in
        the order of the interface outputs. Whenever the episode ends or a
        generation is unusable, the tuple describes the start position with
        an empty history.
    """
    state, history = inputs
    action = generate_action(state)
    if action == "0000":
        return _reset_outputs()

    # Extend a copy so the list held by the incoming state is not mutated.
    history = [*history, action]

    new_state, reward, terminated, truncated = generate_state(state, action, history)

    is_terminated = _parse_flag(terminated)
    if is_terminated:
        gr.Info(_result_message(state, reward), duration=None)
        return _reset_outputs()

    is_truncated = _parse_flag(truncated)
    if is_terminated is None or is_truncated is None:
        # Non-numeric flags used to raise ValueError from int(); reset instead.
        gr.Info(
            "Environment generation invalid: "
            f"terminated={terminated!r}, truncated={truncated!r}",
            duration=None,
        )
        return _reset_outputs()

    if is_truncated:
        gr.Info("Environment ended game: ILLEGAL_MOVE", duration=None)
        return _reset_outputs()

    svg_string = create_chess_board_svg(new_state)
    if not svg_string:
        # The FEN could not be rendered. The reset tuple must follow the same
        # (board, FEN, history text, state) order as every other return.
        return _reset_outputs()
    return svg_string, new_state, ", ".join(history), [new_state, history]
95
+
96
def create_chess_board_svg(fen=None):
    """Render a chess position as an SVG string.

    Args:
        fen: FEN string of the position to draw. A falsy value draws the
            default ``chess.Board()`` position.

    Returns:
        The SVG markup produced by ``chess.svg.board`` at size 400, or an
        empty string when the board could not be built or rendered.
    """
    try:
        board = chess.Board(fen) if fen else chess.Board()
        return chess.svg.board(board, size=400)
    except Exception:
        # Best-effort rendering for the UI: report and return "" so callers
        # can fall back. ``Exception`` rather than a bare ``except`` lets
        # KeyboardInterrupt and SystemExit propagate.
        gr.Info(
            f"Python-Chess board visualization cannot be rendered from FEN: {fen}",
            duration=None,
        )
        return ""
103
+
104
# Blank 400x400 placeholder shown in the board slot before the first render.
board = gr.HTML("""
    <div style='height: 400px; width: 400px; background-color: white;'></div>
""", label="Chess Board")
state_fen = gr.Textbox(label="FEN")
move_history = gr.Textbox(label="Move history")

# The session state is a two-item list: [current FEN, list of moves played].
# NOTE(review): this relies on gr.Interface pairing the single State input
# with the single State output - confirm against the Gradio version in use.
demo = gr.Interface(
    fn=step_episode,
    inputs=gr.State(value=[START_POSITION, []]),
    outputs=[board, state_fen, move_history, gr.State()],
    title="RookWorld Self-Play Demo",
    # The original description stopped mid-sentence ("... this can take ").
    description=(
        "Click the Generate-Button to generate a new move and environment "
        "response. On CPU this can take a while."
    ),
    allow_flagging="never",
    analytics_enabled=False,
    submit_btn="Generate",
    clear_btn="Reset",
)

if __name__ == "__main__":
    demo.launch()