# gradio_tictactoe / classes / Qlearningagent.py
# Uploaded by Gustking ("Upload 8 files", commit b3accf9)
import random
import json
import ast
class QlearningAgent:
    """Tabular Q-learning agent.

    The Q-table maps ``(board_tuple, action, piece)`` keys to float values;
    unseen keys are lazily initialised to 0.0.
    """

    def __init__(self, epsilon, alpha, discount_factor, train):
        """Create an agent with an empty Q-table.

        Args:
            epsilon: Probability of picking a random move in ``choose_move``
                (only applied while ``train`` is true).
            alpha: Step size used by ``update_q_value``.
            discount_factor: Weight applied to the follow-up term in
                ``update_q_value``.
            train: When true, ``choose_move`` performs epsilon-greedy
                exploration; otherwise it always plays greedily.
        """
        self.q_table = {}
        self.epsilon = epsilon
        self.alpha = alpha
        self.discount_factor = discount_factor
        self.train = train

    def save_agent_dict(self, file_name):
        """Write the Q-table to ``file_name`` as JSON.

        JSON object keys must be strings, so each tuple key is stored as its
        ``str()`` form; ``load_agent_dict`` parses it back with
        ``ast.literal_eval``.
        """
        q_table_str_keys = {str(key): value for key, value in self.q_table.items()}
        with open(file_name, 'w', encoding='utf-8') as file_json:
            json.dump(q_table_str_keys, file_json)

    def load_agent_dict(self, file_name):
        """Replace the Q-table with the one stored in ``file_name``.

        A missing file is reported on stdout and leaves the current table
        untouched; other errors (e.g. malformed JSON) propagate.
        """
        try:
            with open(file_name, 'r', encoding='utf-8') as file_json:
                json_data = json.load(file_json)
            # literal_eval only accepts Python literals, so this cannot run code.
            self.q_table = {ast.literal_eval(key): value for key, value in json_data.items()}
            print("Q-table loaded successfully.")
        except FileNotFoundError:
            print(f"File '{file_name}' not found. Q-table not loaded.")

    def get_q_value(self, state, action, piece):
        """Return Q(state, action, piece), inserting 0.0 for unseen keys.

        Assumes ``state`` is a NumPy array (it is ``.flatten()``-ed) — confirm
        against callers. ``.tolist()`` converts NumPy scalars to plain Python
        values: they compare and hash equal to the NumPy ones, but their
        ``str()`` is a plain literal, which the JSON save/load round-trip via
        ``ast.literal_eval`` requires (NumPy >= 2 prints e.g. ``np.int64(1)``).
        """
        key = (tuple(state.flatten().tolist()), action, piece)
        return self.q_table.setdefault(key, 0.0)

    def choose_move(self, state, available_moves, piece):
        """Pick a move from ``available_moves`` (epsilon-greedy when training).

        Q-values for every candidate are looked up first, so unseen
        (state, move, piece) keys are always added to the table. Ties for the
        best value are broken uniformly at random.
        """
        q_values = [self.get_q_value(state, action, piece) for action in available_moves]
        # Note: a random number is drawn even when not training, because
        # uniform() is evaluated before self.train.
        if random.uniform(0, 1) < self.epsilon and self.train:
            return random.choice(available_moves)
        max_q_value = max(q_values)
        if q_values.count(max_q_value) > 1:
            best_moves = [i for i, q in enumerate(q_values) if q == max_q_value]
            i = random.choice(best_moves)
        else:
            i = q_values.index(max_q_value)
        return available_moves[i]

    def update_q_value(self, states, rewards):
        """Apply one backup pass over the Q-table keys visited in an episode.

        Each key ``states[i]`` is moved towards
        ``rewards[i] + discount_factor * rewards[i + 1]`` (the follow-up reward
        is 0 for the final key) with step size ``alpha``. Keys missing from
        the table start at 0.0.

        Args:
            states: Q-table keys in visiting order — presumably
                ``(board_tuple, action, piece)`` as built by ``get_q_value``.
            rewards: Rewards aligned with ``states``; must be at least as long.
        """
        # Intended rule (translated from the original comment):
        #   Q <- Q + alpha * [reward + gamma * max Q(next state) - Q]
        # NOTE(review): the code bootstraps from the *next reward*, not from the
        # next state's Q-value, so value only propagates one step back per
        # episode. Confirm whether this is deliberate before changing it.
        # (An unused discounted-return loop was removed: it used negative
        # exponents and raised ZeroDivisionError when discount_factor == 0.)
        last = len(states) - 1
        for i, state in enumerate(states):
            current = self.q_table.setdefault(state, 0.0)
            next_reward = 0 if i == last else rewards[i + 1]
            self.q_table[state] = current + (
                self.alpha * (rewards[i] + self.discount_factor * next_reward - current)
            )