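# Gradio self-play demo for RookWorld-LM-124M: a single language model acts as both
# the chess policy (move generation with CoT) and the world model (environment response).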
import gradio as gr
import chess
import chess.svg
from transformers import pipeline
import torch


DEBUG = False

# Run inference on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipeline(
    "text-generation",
    model="jrahn/RookWorld-LM-124M",
    torch_dtype=torch.bfloat16,
    device=device,
)
# GPT-2 has no pad token; reuse EOS and pad on the left for generation.
pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
pipe.tokenizer.padding_side = "left"

# Sampling settings shared by the policy and world-model generations.
sampling_args = {
    "do_sample": True,
    "temperature": 0.7,
    "top_k": 15,
    "truncation": True,
    "return_full_text": False,
    "pad_token_id": pipe.tokenizer.eos_token_id,
    "max_length": 186,
}

START_POSITION = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"

def generate_action(state):
    # Policy prompt: "P: <FEN> "; the completion contains CoT and ends with "B: <uci move>".
    prompt = f"P: {state} "
    if DEBUG: print(prompt)
    generation = pipe(prompt, **sampling_args)
    if DEBUG: print(generation)
    try:
        action = generation[0]['generated_text'].split("B: ")[-1].strip()
        gr.Info(f"Policy generated move: {action}", duration=3)
    except Exception:
        gr.Info(f"Policy generation invalid: {generation}", duration=None)
        action = "0000"  # null move signals a failed policy generation
    if DEBUG: print(action)
    return action

def generate_state(state, action, history):
    if DEBUG: print(state, action, history)
    # World-model prompt: "A: <FEN>+<move>+<recent moves>+"; the completion is
    # "<new FEN>+<reward>+<terminated>+<truncated>" with "+" as the field separator.
    history_text = " ".join(history[-10:])
    prompt = f"A: {state}+{action}+{history_text}+"
    if DEBUG: print(prompt)
    generation = pipe(prompt, **sampling_args)
    if DEBUG: print(generation)
    try:
        parts = generation[0]['generated_text'].split("+")
        new_state, reward, terminated, truncated = parts[-4], parts[-3], parts[-2], parts[-1]
    except Exception:
        # Unparseable generation: reset to the start position and mark the episode truncated.
        new_state, reward, terminated, truncated = START_POSITION, "0", "0", "1"
        gr.Info(f"Environment generation invalid: {generation}", duration=None)
    if DEBUG: print(new_state, reward, terminated, truncated)
    return new_state, reward, terminated, truncated

def step_episode(inputs):
    state, history = inputs
    action = generate_action(state)
    if action == "0000":
        # Policy failed to produce a move: reset board, FEN, history and episode state.
        svg_string = create_chess_board_svg()
        return svg_string, START_POSITION, "", [START_POSITION, []]
    history.append(action)

    new_state, reward, terminated, truncated = generate_state(state, action, history)

    if int(terminated):
        player = "White" if state.split()[1] == 'w' else "Black"
        if reward == "-1":
            result_message = f"Environment ended game: {player} lost!"
        elif reward == "1":
            result_message = f"Environment ended game: {player} won!"
        elif reward == "0.5":
            result_message = "Environment ended game: It's a draw!"
        else:
            result_message = "Environment ended game: Unexpected outcome"

        gr.Info(result_message, duration=None)
        svg_string = create_chess_board_svg()
        return svg_string, START_POSITION, "", [START_POSITION, []]

    if int(truncated):
        gr.Info("Environment ended game: ILLEGAL_MOVE", duration=None)
        svg_string = create_chess_board_svg()
        return svg_string, START_POSITION, "", [START_POSITION, []]

    try:
        mv = chess.Move.from_uci(action)
        svg_string = create_chess_board_svg(new_state, lastmove=mv)
    except Exception:
        svg_string = create_chess_board_svg(new_state)
    if not svg_string:
        # Generated FEN could not be rendered: reset board, FEN, history and episode state.
        svg_string = create_chess_board_svg()
        return svg_string, START_POSITION, "", [START_POSITION, []]
    return svg_string, new_state, ", ".join(history), [new_state, history]

def create_chess_board_svg(fen=None, lastmove=None):
    try:
        board = chess.Board(fen) if fen else chess.Board()
        return chess.svg.board(board, lastmove=lastmove, size=400)
    except Exception:
        # python-chess rejects malformed FENs; show a notice and return an empty board string.
        gr.Info(f"python-chess board visualization cannot be rendered from FEN: {fen}", duration=None)
        return ""

board = gr.HTML(
    """
    <div style='height: 400px; width: 400px; background-color: white;'></div>
    """,
    label="Chess Board",
)
state_fen = gr.Textbox(label="FEN")
move_history = gr.Textbox(label="Move history")

demo = gr.Interface(
    fn=step_episode,
    inputs=gr.State(value=[START_POSITION, []]),
    outputs=[board, state_fen, move_history, gr.State()],
    title="♜ RookWorld-LM-124M Self-Play Demo",
    description="""♜ RookWorld-LM (GPT2-124M) Unified Policy & World Model

Both the *policy actions* (with generated CoT) and the *environment response* (World Model) are fully generated by a single language model.

Click the **Generate** button to generate a new move and environment response. On CPU this can take ~30 seconds per step.

[Project Details](https://huggingface.co/collections/jrahn/rookworld-and-rook-reasoning-over-organized-knowledge-679b511567f95e05d9c4a7e7)""",
    allow_flagging="never",
    analytics_enabled=False,
    submit_btn="Generate",
    clear_btn="Reset",
)

demo.launch()