|
import gymnasium as gym |
|
import numpy as np |
|
import imageio |
|
|
|
|
|
# Q-learning hyperparameters.
NUMBER_OF_EPISODES = 30000

LEARNING_RATE = 0.1

DISCOUNT_FACTOR = 0.98

# Probability of choosing a random action (epsilon-greedy exploration).
EPSILON = 0.3

# Number of bin edges per observation dimension used for discretization.
NUMBER_OF_BINS = 14


# Bin edges covering the approximate range of each CartPole observation
# component; consumed by discretize_state via np.digitize.
# NOTE(review): velocity ranges are clipped heuristically (the true
# observation space for velocities is unbounded) — out-of-range values
# fall into the outermost bins.
bins = {

    'cart_pos_bins': np.linspace(-4.8, 4.8, num=NUMBER_OF_BINS),

    'cart_vel_bins': np.linspace(-3.0, 3.0, num=NUMBER_OF_BINS),

    'pole_angle_bins': np.linspace(-0.418, 0.418, num=NUMBER_OF_BINS),

    'pole_angular_vel_bins': np.linspace(-2.0, 2.0, num=NUMBER_OF_BINS)

}
|
|
|
|
|
def discretize_state(state, bins):
    """Map a continuous CartPole observation to a tuple of table indices.

    Each of the four observation components is bucketed against its bin
    edges with np.digitize, then shifted to a 0-based index.

    Args:
        state: iterable of (cart_pos, cart_vel, pole_angle, pole_angular_vel).
        bins: dict with keys 'cart_pos_bins', 'cart_vel_bins',
            'pole_angle_bins', 'pole_angular_vel_bins', each an ascending
            array of NUMBER_OF_BINS edges.

    Returns:
        A 4-tuple of ints, each in [0, NUMBER_OF_BINS - 1].
    """
    keys = ('cart_pos_bins', 'cart_vel_bins',
            'pole_angle_bins', 'pole_angular_vel_bins')
    indices = []
    for value, key in zip(state, keys):
        edges = bins[key]
        idx = int(np.digitize(value, edges)) - 1
        # Bug fix: a value below the lowest edge produced index -1, which
        # silently wrapped to the LAST row of the Q-table. Clip into the
        # valid index range instead.
        indices.append(min(max(idx, 0), len(edges) - 1))
    return tuple(indices)
|
|
|
|
|
def initialize_environment():
    """Create the CartPole-v1 environment and report its spaces.

    Returns:
        Tuple of (environment, number of discrete actions).
    """
    environment = gym.make('CartPole-v1')
    obs_shape = environment.observation_space.shape
    n_actions = environment.action_space.n
    print(f"State shape: {obs_shape}, Action size: {n_actions}")
    return environment, n_actions
|
|
|
|
|
def initialize_q_table(state_size, action_size):
    """Return a zero-filled Q-table of shape (state_size,)*4 + (action_size,).

    Bug fix: the original ignored its ``state_size`` parameter and used the
    module-level ``NUMBER_OF_BINS`` constant instead; the table is now sized
    from the argument. Backward compatible — the existing caller passes
    ``NUMBER_OF_BINS``.

    Args:
        state_size: number of discrete bins per observation dimension.
        action_size: number of discrete actions.
    """
    return np.zeros((state_size,) * 4 + (action_size,))
|
|
|
|
|
def epsilon_greedy_action_selection(state, qtable, env, epsilon):
    """Choose an action: random with probability epsilon, greedy otherwise."""
    explore = np.random.uniform(0, 1) < epsilon
    if explore:
        return env.action_space.sample()
    state_idx = discretize_state(state, bins)
    # qtable[state_idx] selects the 1-D action-value vector for this state.
    return np.argmax(qtable[state_idx])
|
|
|
|
|
def update_q_value(current_state, action, reward, next_state, qtable, learning_rate, discount_factor):
    """Apply one Q-learning update to qtable in place.

    Uses the standard temporal-difference rule:
    Q(s, a) += lr * (r + gamma * max_a' Q(s', a') - Q(s, a)).
    """
    next_idx = discretize_state(next_state, bins)
    best_future_value = np.max(qtable[next_idx])
    current_idx = discretize_state(current_state, bins) + (action,)
    td_target = reward + discount_factor * best_future_value
    qtable[current_idx] += learning_rate * (td_target - qtable[current_idx])
|
|
|
|
|
def train_agent(env, qtable, num_episodes, learning_rate, discount_factor, epsilon):
    """Train the tabular Q-learning agent for num_episodes episodes.

    Args:
        env: a Gymnasium environment (CartPole-v1).
        qtable: Q-table of shape (bins, bins, bins, bins, n_actions),
            updated in place.
        num_episodes: number of training episodes.
        learning_rate: Q-update step size.
        discount_factor: reward discount gamma.
        epsilon: exploration probability for epsilon-greedy selection.

    Returns:
        The (mutated) qtable, for convenience.
    """
    for episode_nr in range(num_episodes):
        current_state, _ = env.reset()
        done = False

        while not done:
            action = epsilon_greedy_action_selection(current_state, qtable, env, epsilon)
            next_state, reward, terminated, truncated, _ = env.step(action)
            # Bug fix: the original bound `done` to `terminated` only and
            # discarded `truncated`, so time-limit-truncated episodes
            # (500 steps in CartPole-v1) never ended the loop.
            done = terminated or truncated

            # Penalize the failure terminal state (pole fell / cart left
            # the track). Truncation is a success, not a failure, so it is
            # deliberately not penalized.
            if terminated and reward == 1:
                reward = -100

            update_q_value(current_state, action, reward, next_state, qtable, learning_rate, discount_factor)
            current_state = next_state

        if episode_nr % 1000 == 0:
            print(f"Episode {episode_nr} completed")

    return qtable
|
|
|
|
|
def save_qtable(filename, qtable):
    """Persist the Q-table to disk in NumPy .npy format.

    Args:
        filename: destination path (np.save appends '.npy' if missing).
        qtable: the array to save.
    """
    np.save(filename, qtable)
    # Bug fix: the f-string had lost its placeholder and printed the
    # literal text "(unknown)" instead of the actual filename.
    print(f"Q-table saved as {filename}")
|
|
|
|
|
def create_replay_video(env, qtable, filename="replay.mp4"):
    """Run one greedy episode, capture frames, and write them to a video.

    Args:
        env: a Gymnasium CartPole environment created with
            render_mode="rgb_array".
        qtable: trained Q-table used for greedy action selection.
        filename: output video path.
    """
    frames = []
    current_state, _ = env.reset()
    done = False

    while not done:
        frames.append(env.render())
        # Typo fix: was `art_pos` — renamed for clarity; greedy action
        # from the Q-values of the discretized state.
        state_idx = discretize_state(current_state, bins)
        action = np.argmax(qtable[state_idx])
        next_state, _, terminated, truncated, _ = env.step(action)
        # Bug fix: the original ignored `truncated`, so a well-trained
        # agent would loop past the 500-step time limit indefinitely.
        done = terminated or truncated
        current_state = next_state

    env.close()

    with imageio.get_writer(filename, fps=30) as video:
        for frame in frames:
            video.append_data(frame)

    # Bug fix: the f-string had lost its placeholder and printed the
    # literal text "(unknown)" instead of the actual filename.
    print(f"Video saved as {filename}")
|
|
|
|
|
def main():
    """Train a tabular Q-learning agent on CartPole-v1, save the Q-table,
    then record a greedy-policy replay video."""
    env, action_size = initialize_environment()
    qtable = initialize_q_table(NUMBER_OF_BINS, action_size)
    qtable = train_agent(
        env, qtable, NUMBER_OF_EPISODES, LEARNING_RATE, DISCOUNT_FACTOR, EPSILON
    )
    save_qtable("cartpole_qtable.npy", qtable)

    # Re-create the environment with rendering enabled so frames can be
    # captured for the replay video.
    env = gym.make('CartPole-v1', render_mode="rgb_array")
    create_replay_video(env, qtable)
|
|
|
|
|
# Script entry point: train, save the Q-table, and record a replay.
if __name__ == "__main__":

    main()
|
|