import random
import json
import ast
class QlearningAgent:
    def __init__(self, epsilon, alpha, discount_factor, train):
        self.q_table = {}                        # maps (state_tuple, action, piece) -> Q-value
        self.epsilon = epsilon                   # exploration rate for epsilon-greedy selection
        self.alpha = alpha                       # learning rate
        self.discount_factor = discount_factor   # gamma, weight given to future rewards
        self.train = train                       # explore only while training
    def save_agent_dict(self, file_name):
        # Convert tuple keys to strings
        q_table_str_keys = {str(key): value for key, value in self.q_table.items()}
        with open(file_name, 'w') as file_json:
            json.dump(q_table_str_keys, file_json)
    def load_agent_dict(self, file_name):
        try:
            with open(file_name, 'r') as file_json:
                json_data = json.load(file_json)
            # Convert string keys back to tuples
            q_table = {ast.literal_eval(key): value for key, value in json_data.items()}
            self.q_table = q_table
            print("Q-table loaded successfully.")
        except FileNotFoundError:
            print(f"File '{file_name}' not found. Q-table not loaded.")
    def get_q_value(self, state, action, piece):
        # Flatten the state into a tuple so it can be used as a dictionary key
        state_tuple = tuple(state.flatten())
        if (state_tuple, action, piece) not in self.q_table:
            self.q_table[(state_tuple, action, piece)] = 0.0
        return self.q_table[(state_tuple, action, piece)]
    def choose_move(self, state, available_moves, piece):
        q_values = []
        for action in available_moves:
            q_values.append(self.get_q_value(state, action, piece))
        # Epsilon-greedy: explore with probability epsilon while training, otherwise exploit
        if random.uniform(0, 1) < self.epsilon and self.train:
            return random.choice(available_moves)
        else:
            max_q_value = max(q_values)
            if q_values.count(max_q_value) > 1:
                # Break ties between equally valued moves at random
                best_moves = [i for i in range(len(available_moves)) if q_values[i] == max_q_value]
                i = random.choice(best_moves)
            else:
                i = q_values.index(max_q_value)
            return available_moves[i]
    def update_q_value(self, states, rewards):
        # Q(s) <- Q(s) + alpha * [reward of current state + gamma * max(next state) - Q(s)]
        for i, state in enumerate(states):
            if state not in self.q_table:
                self.q_table[state] = 0.0
            # Discounted return from step i onward (computed for reference, not used in the update below)
            rt = 0
            for ii in range(i, len(rewards)):
                rt += rewards[ii] * (self.discount_factor ** (ii - i))
            # The next step's reward stands in for the value of the next state
            if i == len(states) - 1:
                next_reward = 0
            else:
                next_reward = rewards[i + 1]
            q_formula = self.q_table[state] + self.alpha * (rewards[i] + self.discount_factor * next_reward - self.q_table[state])
            self.q_table[state] = q_formula
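
# --- Usage sketch (not part of the original Space code) ---
# A minimal example of how this agent could be driven by a game loop, assuming
# states are numpy arrays, moves and pieces are plain integers, and the caller
# builds the same (state_tuple, action, piece) keys that get_q_value uses.
# The board shape, move list, reward and file name below are hypothetical.
if __name__ == "__main__":
    import numpy as np

    agent = QlearningAgent(epsilon=0.1, alpha=0.5, discount_factor=0.9, train=True)

    # dtype=object keeps plain Python ints in the flattened tuples, so the
    # str()/ast.literal_eval round trip in save/load stays well-formed
    board = np.zeros((3, 3), dtype=object)
    available_moves = [0, 1, 2]
    piece = 1

    # Pick a move epsilon-greedily and record the visited (state, action, piece) key
    move = agent.choose_move(board, available_moves, piece)
    visited_key = (tuple(board.flatten()), move, piece)

    # After the episode ends, update the Q-table with the visited keys and rewards
    agent.update_q_value([visited_key], [1.0])

    # Persist the learned table and load it back
    agent.save_agent_dict("q_table.json")
    agent.load_agent_dict("q_table.json")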