import chess
import chess.engine
import numpy as np
import tensorflow as tf
import time
import os
import datetime
import shutil
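
# This script trains a small AlphaZero-style chess engine end to end:
# a convolutional policy-value network, MCTS guided by that network,
# self-play game generation, gradient updates on the self-play data,
# and finally saving and zipping the trained weights for download.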


class PolicyValueNetwork(tf.keras.Model):
    """Shared convolutional trunk with a policy head and a value head."""

    def __init__(self, num_moves):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same')
        self.flatten = tf.keras.layers.Flatten()
        # Policy head: probability distribution over the move-index space.
        self.dense_policy = tf.keras.layers.Dense(num_moves, activation='softmax', name='policy_head')
        # Value head: scalar position evaluation in [-1, 1].
        self.dense_value = tf.keras.layers.Dense(1, activation='tanh', name='value_head')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.flatten(x)
        policy = self.dense_policy(x)
        value = self.dense_value(x)
        return policy, value
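

# Board encoding: 12 binary 8x8 planes, one per (piece type, color) pair --
# planes 0-5 hold White's pawn/knight/bishop/rook/queen/king, planes 6-11
# hold Black's, with a 1.0 on each square occupied by that piece.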
def board_to_input(board):
    piece_types = [chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING]
    input_planes = np.zeros((8, 8, 12), dtype=np.float32)

    for piece_type_index, piece_type in enumerate(piece_types):
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece is not None and piece.piece_type == piece_type:
                plane_index = piece_type_index if piece.color == chess.WHITE else piece_type_index + 6
                row, col = chess.square_rank(square), chess.square_file(square)
                input_planes[row, col, plane_index] = 1.0
    return input_planes


def get_legal_moves_mask(board):
    mask = np.zeros(NUM_POSSIBLE_MOVES, dtype=np.float32)
    for move in board.legal_moves:
        index = move_to_index(move)
        if 0 <= index < NUM_POSSIBLE_MOVES:
            mask[index] = 1.0
    return mask


# Total size of the move-index space output by the policy head.
NUM_POSSIBLE_MOVES = 4672
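

# Move-index layout used throughout:
#   0..4095    -> from_square * 64 + to_square for non-promotion moves
#   4096..4159 -> knight promotions, indexed by destination square
#   4160..4223 -> bishop promotions
#   4224..4287 -> rook promotions
#   4288..4351 -> queen promotions
# Indices 4352..4671 are unused padding within NUM_POSSIBLE_MOVES.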
def move_to_index(move):
    """Deterministic move-to-index conversion (from/to squares plus promotion blocks)."""
    if move.promotion is None:
        index = move.from_square * 64 + move.to_square
    # Promotion moves keep only the destination square (the from-square is dropped),
    # with one 64-slot block per promotion piece.
    elif move.promotion == chess.KNIGHT:
        index = 4096 + move.to_square
    elif move.promotion == chess.BISHOP:
        index = 4096 + 64 + move.to_square
    elif move.promotion == chess.ROOK:
        index = 4096 + 64 * 2 + move.to_square
    elif move.promotion == chess.QUEEN:
        index = 4096 + 64 * 3 + move.to_square
    else:
        raise ValueError(f"Unknown promotion piece type: {move.promotion}")

    return index


def index_to_move(index, board):
    """Deterministic index-to-move conversion (inverse of move_to_index)."""
    if 0 <= index < 4096:
        move = chess.Move(index // 64, index % 64)
        return move if move in board.legal_moves else None

    # Promotion blocks only encode the destination square and promotion piece,
    # so recover the matching legal move (if any) from the board itself.
    promotion_blocks = [
        (4096, chess.KNIGHT),
        (4096 + 64, chess.BISHOP),
        (4096 + 64 * 2, chess.ROOK),
        (4096 + 64 * 3, chess.QUEEN),
    ]
    for block_start, promotion in promotion_blocks:
        if block_start <= index < block_start + 64:
            to_square = index - block_start
            for move in board.legal_moves:
                if move.to_square == to_square and move.promotion == promotion:
                    return move
            return None

    return None


def get_game_result_value(board):
    """Game result from White's perspective: +1 White win, -1 Black win, 0 otherwise."""
    if board.is_checkmate():
        # The side to move is the side that has been checkmated.
        return 1 if board.turn == chess.BLACK else -1
    # Stalemate, insufficient material, 75-move rule, fivefold repetition,
    # variant draws, and unfinished games all score 0.
    return 0


class MCTSNode:
    def __init__(self, board, parent=None, prior_prob=0):
        self.board = board.copy()
        self.parent = parent
        self.children = {}
        self.visits = 0
        self.value_sum = 0
        self.prior_prob = prior_prob
        self.policy_prob = 0
        # Mean value of this node from the perspective of the side to move here.
        self.value = 0

    def select_child(self, exploration_constant=1.4):
        best_child = None
        best_ucb = -float('inf')
        for move, child in self.children.items():
            # PUCT: the child's value is stored from the opponent's point of view,
            # so negate it when scoring the move for the player at this node.
            ucb = -child.value + exploration_constant * child.prior_prob * np.sqrt(self.visits) / (1 + child.visits)
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child
        return best_child

    def expand(self, policy_probs):
        legal_moves = list(self.board.legal_moves)
        for move in legal_moves:
            move_index = move_to_index(move)
            prior_prob = policy_probs[move_index]
            # The child position must actually have the move played on it.
            child_board = self.board.copy()
            child_board.push(move)
            self.children[move] = MCTSNode(child_board, parent=self, prior_prob=prior_prob)

    def evaluate(self, policy_value_net):
        input_board = board_to_input(self.board)
        policy_output, value_output = policy_value_net(np.expand_dims(input_board, axis=0))
        policy_probs = policy_output.numpy()[0]
        value = value_output.numpy()[0][0]

        legal_moves_mask = get_legal_moves_mask(self.board)
        masked_policy_probs = policy_probs * legal_moves_mask
        if np.sum(masked_policy_probs) > 0:
            masked_policy_probs /= np.sum(masked_policy_probs)
        else:
            masked_policy_probs = legal_moves_mask / np.sum(legal_moves_mask)

        self.policy_prob = masked_policy_probs
        self.value = value
        return value, masked_policy_probs

    def backup(self, value):
        self.visits += 1
        self.value_sum += value
        self.value = self.value_sum / self.visits
        if self.parent:
            # The same outcome is worth the opposite amount to the other player.
            self.parent.backup(-value)
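

# One MCTS simulation: walk down the tree with select_child, expand and
# evaluate the leaf with the network (or score it directly if the game is
# over), then back the value up the visited path.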
def run_mcts(root_node, policy_value_net, num_simulations):
    for _ in range(num_simulations):
        node = root_node
        search_path = [node]

        while node.children and not node.board.is_game_over():
            node = node.select_child()
            search_path.append(node)

        leaf_node = search_path[-1]

        if not leaf_node.board.is_game_over():
            value, policy_probs = leaf_node.evaluate(policy_value_net)
            leaf_node.expand(policy_probs)
        else:
            # Terminal position: convert the White-perspective result into the
            # perspective of the side to move at the leaf, which is what backup expects.
            result = get_game_result_value(leaf_node.board)
            value = result if leaf_node.board.turn == chess.WHITE else -result

        leaf_node.backup(value)

    return choose_best_move_from_mcts(root_node)
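

# Move selection from root visit counts: temperature 0 plays the most-visited
# move greedily; a positive temperature samples moves in proportion to
# visits ** (1 / temperature), which keeps self-play games varied.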
def choose_best_move_from_mcts(root_node, temperature=0.0):
    if temperature == 0:
        best_move = max(root_node.children, key=lambda move: root_node.children[move].visits)
    else:
        visits = [root_node.children[move].visits for move in root_node.children]
        move_probs = np.array(visits) ** (1 / temperature)
        move_probs = move_probs / np.sum(move_probs)
        moves = list(root_node.children.keys())
        best_move = np.random.choice(moves, p=move_probs)
    return best_move


class RLEngine:
    def __init__(self, policy_value_net, num_simulations_per_move=100):
        self.policy_value_net = policy_value_net
        self.num_simulations_per_move = num_simulations_per_move

    def choose_move(self, board):
        root_node = MCTSNode(board)
        best_move = run_mcts(root_node, self.policy_value_net, self.num_simulations_per_move)
        return best_move
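

# Self-play: at every position run MCTS, record the position together with the
# visit-count policy target, play a move sampled with temperature 0.8, and once
# the game ends attach a value target to each recorded position.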
def self_play_game(engine, model, num_simulations):
    game_history = []
    board = chess.Board()
    while not board.is_game_over():
        root_node = MCTSNode(board)
        run_mcts(root_node, model, num_simulations)

        policy_targets = create_policy_targets_from_mcts_visits(root_node)
        game_history.append((board.fen(), policy_targets))

        best_move = choose_best_move_from_mcts(root_node, temperature=0.8)
        board.push(best_move)

    game_result = get_game_result_value(board)  # from White's perspective

    # Re-express the final result from the perspective of the side to move at
    # each recorded position (not the side to move at the end of the game).
    for i, (fen, policy_target) in enumerate(game_history):
        side_to_move = chess.Board(fen).turn
        value_target = game_result if side_to_move == chess.WHITE else -game_result
        game_history[i] = (fen, policy_target, value_target)
    return game_history


def create_policy_targets_from_mcts_visits(root_node):
    policy_targets = np.zeros(NUM_POSSIBLE_MOVES, dtype=np.float32)
    for move, child_node in root_node.children.items():
        move_index = move_to_index(move)
        policy_targets[move_index] = child_node.visits
    policy_targets /= np.sum(policy_targets)
    return policy_targets
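

# Training objective: categorical cross-entropy between the MCTS visit
# distribution and the policy head, plus mean squared error between the game
# outcome and the value head, optimized jointly.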
def train_step(model, board_inputs, policy_targets, value_targets, optimizer):
    with tf.GradientTape() as tape:
        policy_outputs, value_outputs = model(board_inputs)
        policy_loss = tf.keras.losses.CategoricalCrossentropy()(policy_targets, policy_outputs)
        value_loss = tf.keras.losses.MeanSquaredError()(value_targets, value_outputs)
        total_loss = policy_loss + value_loss
    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return total_loss, policy_loss, value_loss


def train_network(model, game_histories, optimizer, epochs=10, batch_size=32):
    all_board_inputs = []
    all_policy_targets = []
    all_value_targets = []

    for game_history in game_histories:
        for fen, policy_target, game_result in game_history:
            board = chess.Board(fen)
            all_board_inputs.append(board_to_input(board))
            all_policy_targets.append(policy_target)
            all_value_targets.append(np.array([game_result], dtype=np.float32))

    all_board_inputs = np.array(all_board_inputs)
    all_policy_targets = np.array(all_policy_targets)
    all_value_targets = np.array(all_value_targets)

    dataset = tf.data.Dataset.from_tensor_slices((all_board_inputs, all_policy_targets, all_value_targets))
    dataset = dataset.shuffle(buffer_size=len(all_board_inputs)).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        for batch_inputs, batch_policy_targets, batch_value_targets in dataset:
            loss, p_loss, v_loss = train_step(model, batch_inputs, batch_policy_targets, batch_value_targets, optimizer)
            # Convert the eager tensors to plain floats so they format cleanly.
            print(f"  Loss: {float(loss):.4f}, Policy Loss: {float(p_loss):.4f}, Value Loss: {float(v_loss):.4f}")


if __name__ == "__main__":

    if tf.config.list_physical_devices('GPU'):
        print("\n\nGPU is available and will be used for training.\n\n")
        device = '/GPU:0'
    else:
        print("\n\nGPU is not available. Training will use CPU (may be slow).\n\n")
        device = '/CPU:0'

    with tf.device(device):

        policy_value_net = PolicyValueNetwork(NUM_POSSIBLE_MOVES)
        engine = RLEngine(policy_value_net, num_simulations_per_move=100)
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

        num_self_play_games = 50
        epochs = 5

        game_histories = []
        start_time = time.time()

        MODEL_SAVE_DIR = "models_colab"
        os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
        for i in range(num_self_play_games):
            print(f"Self-play game {i+1}/{num_self_play_games} \n")
            game_history = self_play_game(engine, policy_value_net, num_simulations=50)
            game_histories.append(game_history)

        train_network(policy_value_net, game_histories, optimizer, epochs=epochs)

        end_time = time.time()
        training_time = end_time - start_time
        print(f"\n\n ---- Training completed in {training_time:.2f} seconds. ---- \n")

        current_datetime = datetime.datetime.now()
        model_version_str = current_datetime.strftime("%Y-%m-%d-%H%M")
        model_save_path = os.path.join(MODEL_SAVE_DIR, f"StockZero-{model_version_str}.weights.h5")
        policy_value_net.save_weights(model_save_path)
        print(f"Trained model weights saved to '{model_save_path}' in the '{MODEL_SAVE_DIR}' directory in Colab.")

        zip_file_path = f"StockZero-{model_version_str}"
        shutil.make_archive(zip_file_path, 'zip', MODEL_SAVE_DIR)
        print(f"Model directory zipped to '{zip_file_path}.zip'. Download this file.")
        from google.colab import files
        files.download(f"{zip_file_path}.zip")

    print("\n\n ----- Training finished. ------- \n\n")