import gradio as gr
import chess
import chess.svg
import torch
from transformers import pipeline

DEBUG = False

# A single language model acts as BOTH the policy (move generation) and the
# world model (environment transition) — the core idea of RookWorld.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipeline(
    "text-generation",
    model="jrahn/RookWorld-LM-124M",
    torch_dtype=torch.bfloat16,
    device=device,
)
pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
pipe.tokenizer.padding_side = "left"

sampling_args = {
    "do_sample": True,
    "temperature": 0.7,
    "top_k": 15,
    "truncation": True,
    "return_full_text": False,
    "pad_token_id": pipe.tokenizer.eos_token_id,
    "max_length": 186,
}

START_POSITION = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"


def generate_action(state):
    """Sample a move from the policy head of the model for FEN ``state``.

    Returns the move as a UCI-style string, or the null move ``"0000"``
    when the generation cannot be parsed.
    """
    prompt = f"P: {state} "
    if DEBUG:
        print(prompt)
    generation = pipe(prompt, **sampling_args)
    if DEBUG:
        print(generation)
    try:
        # The policy completion ends with "B: <move>"; keep the final segment.
        action = generation[0]['generated_text'].split("B: ")[-1].strip()
        gr.Info(f"Policy generated move: {action}", duration=3)
        # TODO: display generated CoT
    except Exception:
        gr.Info(f"Policy generation invalid: {generation}", duration=None)
        action = "0000"
    if DEBUG:
        print(action)
    return action


def generate_state(state, action, history):
    """Ask the world-model head for the transition (state, action) -> next state.

    Returns a 4-tuple of strings ``(new_state, reward, terminated, truncated)``
    parsed from the "+"-separated completion. On a parse failure, resets to
    the start position with ``truncated="1"``.
    """
    if DEBUG:
        print(state, action, history)
    history_text = " ".join(history[-10:])  # last 10 moves as context
    prompt = f"A: {state}+{action}+{history_text}+"
    if DEBUG:
        print(prompt)
    generation = pipe(prompt, **sampling_args)
    if DEBUG:
        print(generation)
    try:
        parts = generation[0]['generated_text'].split("+")
        new_state, reward, terminated, truncated = parts[-4], parts[-3], parts[-2], parts[-1]
        #gr.Info(f"Environment generated state: {new_state}", duration=3)
    except Exception:
        new_state, reward, terminated, truncated = START_POSITION, "0", "0", "1"
        gr.Info(f"Environment generation invalid: {generation}", duration=None)
    if DEBUG:
        print(new_state, reward, terminated, truncated)
    return new_state, reward, terminated, truncated


def step_episode(inputs):
    """Run one self-play step: policy move, then world-model response.

    ``inputs`` is the Gradio state ``[fen, move_history_list]``. Always
    returns ``(board_svg, fen_text, history_text, new_state)`` matching the
    Interface outputs ``[board, state_fen, move_history, gr.State()]``.
    The episode resets to the start position on game end or any invalid
    generation.
    """
    state, history = inputs
    action = generate_action(state)
    if action == "0000":
        # Unparseable policy output: reset the episode.
        svg_string = create_chess_board_svg()
        return svg_string, START_POSITION, "", [START_POSITION, []]
    history.append(action)
    new_state, reward, terminated, truncated = generate_state(state, action, history)
    if int(terminated):
        # Reward is reported for the side that was to move in `state`.
        player = "White" if state.split()[1] == 'w' else "Black"
        result_message = ""
        if reward == "-1":
            result_message = f"Environment ended game: {player} lost!"
        elif reward == "1":
            result_message = f"Environment ended game: {player} won!"
        elif reward == "0.5":
            result_message = "Environment ended game: It's a draw!"
        else:
            result_message = "Environment ended game: Unexpected outcome"
        gr.Info(result_message, duration=None)
        svg_string = create_chess_board_svg()
        return svg_string, START_POSITION, "", [START_POSITION, []]
    if int(truncated):
        gr.Info(f"Environment ended game: ILLEGAL_MOVE", duration=None)
        svg_string = create_chess_board_svg()
        return svg_string, START_POSITION, "", [START_POSITION, []]
    try:
        # Highlight the last move when it parses as UCI.
        mv = chess.Move.from_uci(action)
        svg_string = create_chess_board_svg(new_state, lastmove=mv)
    except Exception:
        svg_string = create_chess_board_svg(new_state)
    if not svg_string:
        # Generated FEN could not be rendered: reset the episode.
        # FIX: the reset values were previously returned in the wrong order
        # (state list in the FEN slot); outputs must be (svg, fen, history, state).
        svg_string = create_chess_board_svg()
        return svg_string, START_POSITION, "", [START_POSITION, []]
    return svg_string, new_state, ", ".join(history), [new_state, history]


def create_chess_board_svg(fen=None, lastmove=None):
    """Render a 400px board SVG from ``fen`` (start position when None).

    Returns ``""`` and raises a UI notification when the FEN is invalid.
    """
    try:
        board = chess.Board(fen) if fen else chess.Board()
        return chess.svg.board(board, lastmove=lastmove, size=400)
    except Exception:
        gr.Info(f"Python-Chess board visualization cannot be rendered from FEN: {fen}", duration=None)
        return ""


# UI components: board SVG is injected as raw HTML; the two textboxes echo
# the FEN and the accumulated move list.
board = gr.HTML("""
""", label="Chess Board")
state_fen = gr.Textbox(label="FEN")
move_history = gr.Textbox(label="Move history")

demo = gr.Interface(
    fn=step_episode,
    inputs=gr.State(value=[START_POSITION, []]),
    outputs=[board, state_fen, move_history, gr.State()],
    title="♜ RookWorld-LM-124M Self-Play Demo",
    description="""♜ RookWorld-LM (GPT2-124M) Unified Policy & World Model
Both the *policy actions* (with generated CoT) and the *environment response* (World Model) are fully generated by a single language model.
Click the **Generate**-button to generate a new move and environment response. On CPU this can take ~30 seconds per step.
[Project Details](https://huggingface.co/collections/jrahn/rookworld-and-rook-reasoning-over-organized-knowledge-679b511567f95e05d9c4a7e7)""",
    allow_flagging="never",
    analytics_enabled=False,
    submit_btn="Generate",
    clear_btn="Reset",
)

demo.launch()