# RookWorld / app.py
import gradio as gr
import chess
import chess.svg
from collections import deque
from transformers import pipeline
import torch
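
# RookWorld-LM-124M self-play demo: a single GPT-2 (124M) language model acts both as
# the chess policy (generating moves with chain-of-thought) and as the world model /
# environment (generating the next position, reward, and termination flags).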
DEBUG = False
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipeline(
    "text-generation",
    model="jrahn/RookWorld-LM-124M",
    torch_dtype=torch.bfloat16,
    device=device
)
pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
pipe.tokenizer.padding_side = "left"
sampling_args = {
    "do_sample": True,
    "temperature": 0.7,
    "top_k": 15,
    "truncation": True,
    "return_full_text": False,
    "pad_token_id": pipe.tokenizer.eos_token_id,
    "max_length": 186
}
START_POSITION = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
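
# Prompt schemes (as parsed by the functions below):
#   Policy:      "P: <FEN> "                       -> move = text after the last "B: " in the completion
#   World model: "A: <FEN>+<move>+<move history>+" -> last four "+"-separated fields of the completion
#                                                     = new FEN, reward, terminated, truncated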
def generate_action(state):
    prompt = f"P: {state} "
    if DEBUG: print(prompt)
    generation = pipe(prompt, **sampling_args)
    if DEBUG: print(generation)
    try:
        action = generation[0]['generated_text'].split("B: ")[-1].strip()
        gr.Info(f"Policy generated move: {action}", duration=3)
        # TODO: display generated CoT
    except Exception:
        gr.Info(f"Policy generation invalid: {generation}", duration=None)
        action = "0000"  # null move, used to signal a failed generation
    if DEBUG: print(action)
    return action

def generate_state(state, action, history):
    if DEBUG: print(state, action, history)
    history_text = " ".join(history[-10:])  # condition on the last 10 moves only
    prompt = f"A: {state}+{action}+{history_text}+"
    if DEBUG: print(prompt)
    generation = pipe(prompt, **sampling_args)
    if DEBUG: print(generation)
    try:
        parts = generation[0]['generated_text'].split("+")
        new_state, reward, terminated, truncated = parts[-4], parts[-3], parts[-2], parts[-1]
        # gr.Info(f"Environment generated state: {new_state}", duration=3)
    except Exception:
        # Unparseable generation: reset to the start position and mark the episode truncated.
        new_state, reward, terminated, truncated = START_POSITION, "0", "0", "1"
        gr.Info(f"Environment generation invalid: {generation}", duration=None)
    if DEBUG: print(new_state, reward, terminated, truncated)
    return new_state, reward, terminated, truncated

def step_episode(inputs):
    state, history = inputs
    action = generate_action(state)
    if action == "0000":
        # Policy failed to produce a move: reset the episode.
        svg_string = create_chess_board_svg()
        return svg_string, START_POSITION, "", [START_POSITION, []]
    history.append(action)
    new_state, reward, terminated, truncated = generate_state(state, action, history)
    if int(terminated):
        # "state" is the position before the move, so its side-to-move field is the player who just moved.
        player = "White" if state.split()[1] == 'w' else "Black"
        result_message = ""
        if reward == "-1":
            result_message = f"Environment ended game: {player} lost!"
        elif reward == "1":
            result_message = f"Environment ended game: {player} won!"
        elif reward == "0.5":
            result_message = "Environment ended game: It's a draw!"
        else:
            result_message = "Environment ended game: Unexpected outcome"
        gr.Info(result_message, duration=None)
        svg_string = create_chess_board_svg()
        return svg_string, START_POSITION, "", [START_POSITION, []]
    if int(truncated):
        gr.Info("Environment ended game: ILLEGAL_MOVE", duration=None)
        svg_string = create_chess_board_svg()
        return svg_string, START_POSITION, "", [START_POSITION, []]
    try:
        mv = chess.Move.from_uci(action)
        svg_string = create_chess_board_svg(new_state, lastmove=mv)
    except Exception:
        svg_string = create_chess_board_svg(new_state)
    if not svg_string:
        # The generated FEN could not be rendered: reset the episode.
        svg_string = create_chess_board_svg()
        return svg_string, START_POSITION, "", [START_POSITION, []]
    return svg_string, new_state, ", ".join(history), [new_state, history]

def create_chess_board_svg(fen=None, lastmove=None):
    try:
        board = chess.Board(fen) if fen else chess.Board()
        return chess.svg.board(board, lastmove=lastmove, size=400)
    except Exception:
        gr.Info(f"Python-Chess board visualization cannot be rendered from FEN: {fen}", duration=None)
        return ""

board = gr.HTML("""
<div style='height: 400px; width: 400px; background-color: white;'></div>
""", label="Chess Board")
state_fen = gr.Textbox(label="FEN")
move_history = gr.Textbox(label="Move history")
demo = gr.Interface(
    fn=step_episode,
    inputs=gr.State(value=[START_POSITION, []]),
    outputs=[board, state_fen, move_history, gr.State()],
    title="♜ RookWorld-LM-124M Self-Play Demo",
    description="""♜ RookWorld-LM (GPT2-124M) Unified Policy & World Model
Both the *policy actions* (with generated CoT) and the *environment response* (World Model) are fully generated by a single language model.
Click the **Generate** button to generate a new move and environment response. On CPU this can take ~30 seconds per step.
[Project Details](https://huggingface.co/collections/jrahn/rookworld-and-rook-reasoning-over-organized-knowledge-679b511567f95e05d9c4a7e7)""",
    allow_flagging="never",
    analytics_enabled=False,
    submit_btn="Generate",
    clear_btn="Reset",
)
demo.launch()