# RookWorld-LM self-play demo: a single language model acts as both the
# chess policy and the environment (world model), wired into a Gradio UI.
import gradio as gr
import chess
import chess.svg
from collections import deque
from transformers import pipeline
import torch
DEBUG = False
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipeline(
"text-generation",
model="jrahn/RookWorld-LM-124M",
torch_dtype=torch.bfloat16,
device=device
)
pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
pipe.tokenizer.padding_side = "left"
sampling_args = {
"do_sample": True,
"temperature": 0.7,
"top_k": 15,
"truncation": True,
"return_full_text": False,
"pad_token_id": pipe.tokenizer.eos_token_id,
"max_length": 186
}
START_POSITION = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
def generate_action(state):
prompt = f"P: {state} "
if DEBUG: print(prompt)
generation = pipe(prompt, **sampling_args)
if DEBUG: print(generation)
try:
action = generation[0]['generated_text'].split("B: ")[-1].strip()
gr.Info(f"Policy generated move: {action}", duration=3)
# TODO: display generated CoT
except:
gr.Info(f"Policy generation invalid: {generation}", duration=None)
action = "0000"
if DEBUG: print(action)
return action
def generate_state(state, action, history):
if DEBUG: print(state, action, history)
history_text = " ".join(history[-10:])
prompt = f"A: {state}+{action}+{history_text}+"
if DEBUG: print(prompt)
generation = pipe(prompt, **sampling_args)
if DEBUG: print(generation)
try:
parts = generation[0]['generated_text'].split("+")
new_state, reward, terminated, truncated = parts[-4], parts[-3], parts[-2], parts[-1]
#gr.Info(f"Environment generated state: {new_state}", duration=3)
except:
new_state, reward, terminated, truncated = START_POSITION, "0", "0", "1"
gr.Info(f"Environment generation invalid: {generation}", duration=None)
if DEBUG: print(new_state, reward, terminated, truncated)
return new_state, reward, terminated, truncated
def step_episode(inputs):
state, history = inputs
action = generate_action(state)
if action == "0000":
svg_string = create_chess_board_svg()
return svg_string, START_POSITION, "", [START_POSITION, []]
history.append(action)
new_state, reward, terminated, truncated = generate_state(state, action, history)
if int(terminated):
player = "White" if state.split()[1] == 'w' else "Black"
result_message = ""
if reward == "-1":
result_message = f"Environment ended game: {player} lost!"
elif reward == "1":
result_message = f"Environment ended game: {player} won!"
elif reward == "0.5":
result_message = "Environment ended game: It's a draw!"
else:
result_message = "Environment ended game: Unexpected outcome"
gr.Info(result_message, duration=None)
svg_string = create_chess_board_svg()
return svg_string, START_POSITION, "", [START_POSITION, []]
if int(truncated):
gr.Info(f"Environment ended game: ILLEGAL_MOVE", duration=None)
svg_string = create_chess_board_svg()
return svg_string, START_POSITION, "", [START_POSITION, []]
try:
mv = chess.Move.from_uci(action)
svg_string = create_chess_board_svg(new_state, lastmove=mv)
except:
svg_string = create_chess_board_svg(new_state)
if not svg_string:
svg_string = create_chess_board_svg()
return svg_string, [START_POSITION, []], START_POSITION, ""
return svg_string, new_state, ", ".join(history), [new_state, history]
def create_chess_board_svg(fen=None, lastmove=None):
try:
board = chess.Board(fen) if fen else chess.Board()
return chess.svg.board(board, lastmove=lastmove, size=400)
except:
gr.Info(f"Python-Chess board visualization cannot be rendered from FEN: {fen}", duration=None)
return ""
board = gr.HTML("""
<div style='height: 400px; width: 400px; background-color: white;'></div>
""", label="Chess Board")
state_fen = gr.Textbox(label="FEN")
move_history = gr.Textbox(label="Move history")
demo = gr.Interface(
fn=step_episode,
inputs=gr.State(value=[START_POSITION, []]),
outputs=[board, state_fen, move_history, gr.State()],
title="♜ RookWorld-LM-124M Self-Play Demo",
description="""♜ RookWorld-LM (GPT2-124M) Unified Policy & World Model
Both the *policy actions* (with generated CoT) and the *environment response* (World Model) are fully generated by a single language model.
Click the **Generate**-button to generate a new move and environment response. On CPU this can take ~30 seconds per step.
[Project Details](https://huggingface.co/collections/jrahn/rookworld-and-rook-reasoning-over-organized-knowledge-679b511567f95e05d9c4a7e7)""",
allow_flagging="never",
analytics_enabled=False,
submit_btn="Generate",
clear_btn="Reset",
)
demo.launch() |