File size: 5,019 Bytes
2d736ae
 
 
 
 
 
 
 
 
 
 
d28b192
2d736ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d28b192
2d736ae
 
 
 
 
 
 
 
 
 
 
 
 
 
d28b192
 
2d736ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47944b0
 
 
 
 
 
2d736ae
 
 
 
 
47944b0
2d736ae
 
47944b0
2d736ae
 
 
 
 
 
 
 
 
 
 
 
 
dc3330f
 
 
a9575c1
 
2d736ae
 
 
 
 
 
d9b9935
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gradio as gr
import chess
import chess.svg
from collections import deque
from transformers import pipeline
import torch

# Verbose console tracing of prompts and raw generations.
DEBUG = False
device = "cuda" if torch.cuda.is_available() else "cpu"
# One GPT2-124M checkpoint serves as BOTH the policy ("P: ..." prompts)
# and the world model ("A: ..." prompts) — see generate_action/generate_state.
pipe = pipeline(
    "text-generation", 
    model="jrahn/RookWorld-LM-124M",
    torch_dtype=torch.bfloat16,
    device=device
)
# GPT2 has no pad token; reuse EOS and left-pad so generation continues
# from the end of the prompt.
pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
pipe.tokenizer.padding_side = "left"
# Shared decoding settings for both policy and environment generations.
sampling_args = {
    "do_sample": True,
    "temperature": 0.7,
    "top_k": 15,
    "truncation": True,
    "return_full_text": False,
    "pad_token_id": pipe.tokenizer.eos_token_id,
    "max_length": 186
}
# Standard chess starting position in FEN notation.
START_POSITION = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"

def generate_action(state):
    """Query the model as a policy: pick the next move for *state*.

    Parameters
    ----------
    state : str
        Current position in FEN notation.

    Returns
    -------
    str
        Move in UCI notation, or the null move "0000" when the model
        output could not be parsed.
    """
    prompt = f"P: {state} "
    if DEBUG: print(prompt)
    generation = pipe(prompt, **sampling_args)
    if DEBUG: print(generation)
    try:
        # The model emits chain-of-thought text ending in "B: <best move>";
        # keep only the final move token.
        action = generation[0]['generated_text'].split("B: ")[-1].strip()
        gr.Info(f"Policy generated move: {action}", duration=3)
        # TODO: display generated CoT
    except (KeyError, IndexError, TypeError, AttributeError):
        # Narrowed from a bare `except:` — only parse failures of the
        # generation structure should fall back to the null move.
        gr.Info(f"Policy generation invalid: {generation}", duration=None)
        action = "0000"
    if DEBUG: print(action)
    return action

def generate_state(state, action, history):
    """Query the model as a world model: apply *action* to *state*.

    Parameters
    ----------
    state : str
        Current position in FEN notation.
    action : str
        Move in UCI notation.
    history : list[str]
        Moves played so far (UCI); only the last 10 are put in the prompt
        to stay within the generation max_length.

    Returns
    -------
    tuple[str, str, str, str]
        (new_state, reward, terminated, truncated), all as strings. On a
        parse failure returns the start position with truncated="1".
    """
    if DEBUG: print(state, action, history)
    history_text = " ".join(history[-10:])
    prompt = f"A: {state}+{action}+{history_text}+"
    if DEBUG: print(prompt)
    generation = pipe(prompt, **sampling_args)
    if DEBUG: print(generation)
    try:
        parts = generation[0]['generated_text'].split("+")
        # Expected tail: ...+new_state+reward+terminated+truncated.
        # Strip each field so callers can safely int()/compare them
        # (model output may carry stray whitespace). Unpacking raises
        # ValueError when fewer than 4 fields came back.
        new_state, reward, terminated, truncated = (p.strip() for p in parts[-4:])
        #gr.Info(f"Environment generated state: {new_state}", duration=3)
    except (KeyError, IndexError, ValueError, TypeError, AttributeError):
        # Narrowed from a bare `except:`; truncated="1" signals the UI to reset.
        new_state, reward, terminated, truncated = START_POSITION, "0", "0", "1"
        gr.Info(f"Environment generation invalid: {generation}", duration=None)
    if DEBUG: print(new_state, reward, terminated, truncated)
    return new_state, reward, terminated, truncated

def _reset_episode_outputs():
    """Outputs that reset the board, FEN box, history box and session state."""
    return create_chess_board_svg(), START_POSITION, "", [START_POSITION, []]

def _parse_flag(value):
    """Best-effort int conversion for model-generated '0'/'1' flag strings."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return 0

def step_episode(inputs):
    """Advance the self-play game by one half-move.

    Parameters
    ----------
    inputs : list
        [fen_string, move_history_list] carried in gr.State between clicks.

    Returns
    -------
    tuple
        (board_svg_html, fen_text, history_text, new_state) — must match
        the Interface outputs [board, state_fen, move_history, gr.State()].
    """
    state, history = inputs
    action = generate_action(state)
    if action == "0000":
        # Policy generation failed to parse; restart the episode.
        return _reset_episode_outputs()
    history.append(action)

    new_state, reward, terminated, truncated = generate_state(state, action, history)

    if _parse_flag(terminated):
        # The side to move in *state* is the player whose move ended the game.
        player = "White" if state.split()[1] == 'w' else "Black"
        if reward == "-1":
            result_message = f"Environment ended game: {player} lost!"
        elif reward == "1":
            result_message = f"Environment ended game: {player} won!"
        elif reward == "0.5":
            result_message = "Environment ended game: It's a draw!"
        else:
            result_message = "Environment ended game: Unexpected outcome"

        gr.Info(result_message, duration=None)
        return _reset_episode_outputs()

    if _parse_flag(truncated):
        gr.Info("Environment ended game: ILLEGAL_MOVE", duration=None)
        return _reset_episode_outputs()

    try:
        mv = chess.Move.from_uci(action)
        svg_string = create_chess_board_svg(new_state, lastmove=mv)
    except ValueError:
        # Action is not valid UCI; render without the last-move highlight.
        svg_string = create_chess_board_svg(new_state)
    if not svg_string:
        # Generated FEN was unrenderable: reset.
        # BUG FIX: the original returned (svg, state_list, fen, "") here —
        # wrong order — which put the state list into the FEN textbox and
        # corrupted the gr.State for the next step.
        return _reset_episode_outputs()
    return svg_string, new_state, ", ".join(history), [new_state, history]

def create_chess_board_svg(fen=None, lastmove=None):
    """Render a chess position as an SVG string.

    Parameters
    ----------
    fen : str | None
        Position in FEN notation; None renders the standard start position.
    lastmove : chess.Move | None
        Optional move to highlight on the board.

    Returns
    -------
    str
        SVG markup, or "" when *fen* cannot be parsed.
    """
    try:
        board = chess.Board(fen) if fen else chess.Board()
        return chess.svg.board(board, lastmove=lastmove, size=400)
    except ValueError:
        # chess.Board raises ValueError on malformed FEN; narrowed from a
        # bare `except:` so programming errors are no longer swallowed.
        gr.Info(f"Python-Chess board visualization cannot be rendered from FEN: {fen}", duration=None)
        return ""
    
# ---------------------------------------------------------------------------
# Gradio UI wiring. One click of the submit button runs a single self-play
# step; the episode (FEN + move list) round-trips through gr.State.
# ---------------------------------------------------------------------------

# Fixed 400x400 white placeholder keeps the layout stable until the first
# rendered board SVG replaces it.
board = gr.HTML("""
    <div style='height: 400px; width: 400px; background-color: white;'></div>
    """, label="Chess Board")

# Read-only text outputs for the current position and the move list.
state_fen = gr.Textbox(label="FEN")
move_history = gr.Textbox(label="Move history")

demo = gr.Interface(
    title="♜ RookWorld-LM-124M Self-Play Demo",
    description="""♜ RookWorld-LM (GPT2-124M) Unified Policy & World Model  
    Both the *policy actions* (with generated CoT) and the *environment response* (World Model) are fully generated by a single language model.
    Click the **Generate**-button to generate a new move and environment response. On CPU this can take ~30 seconds per step.
    [Project Details](https://huggingface.co/collections/jrahn/rookworld-and-rook-reasoning-over-organized-knowledge-679b511567f95e05d9c4a7e7)""",
    fn=step_episode,
    inputs=gr.State(value=[START_POSITION, []]),
    outputs=[board, state_fen, move_history, gr.State()],
    submit_btn="Generate",
    clear_btn="Reset",
    allow_flagging="never",
    analytics_enabled=False,
)

demo.launch()