# cart-pole-ql / train.py
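"""Tabular Q-learning on Gymnasium's CartPole-v1.

The continuous observation is discretized into bins, a Q-table is trained with
an epsilon-greedy policy and saved to disk, and a greedy replay episode is
recorded as an MP4 video.
"""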
import gymnasium as gym
import numpy as np
import imageio
NUMBER_OF_EPISODES = 30000
LEARNING_RATE = 0.1
DISCOUNT_FACTOR = 0.98
EPSILON = 0.3
NUMBER_OF_BINS = 14
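# Bin edges for discretizing each dimension of the continuous observation:
# cart position, cart velocity, pole angle, and pole angular velocity.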
bins = {
'cart_pos_bins': np.linspace(-4.8, 4.8, num=NUMBER_OF_BINS),
'cart_vel_bins': np.linspace(-3.0, 3.0, num=NUMBER_OF_BINS),
'pole_angle_bins': np.linspace(-0.418, 0.418, num=NUMBER_OF_BINS),
'pole_angular_vel_bins': np.linspace(-2.0, 2.0, num=NUMBER_OF_BINS)
}
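# Convert a continuous observation into a tuple of bin indices that can be
# used to index the Q-table directly.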
def discretize_state(state, bins):
cart_pos, cart_vel, pole_angle, pole_angular_vel = state
    # np.digitize returns 0 for values below the first bin edge, so clip the
    # shifted indices into the valid Q-table range [0, NUMBER_OF_BINS - 1].
    discrete_cart_pos = np.clip(np.digitize(cart_pos, bins['cart_pos_bins']) - 1, 0, NUMBER_OF_BINS - 1)
    discrete_cart_vel = np.clip(np.digitize(cart_vel, bins['cart_vel_bins']) - 1, 0, NUMBER_OF_BINS - 1)
    discrete_pole_angle = np.clip(np.digitize(pole_angle, bins['pole_angle_bins']) - 1, 0, NUMBER_OF_BINS - 1)
    discrete_pole_angular_vel = np.clip(np.digitize(pole_angular_vel, bins['pole_angular_vel_bins']) - 1, 0, NUMBER_OF_BINS - 1)
return (discrete_cart_pos, discrete_cart_vel, discrete_pole_angle, discrete_pole_angular_vel)
def initialize_environment():
env = gym.make('CartPole-v1')
state_shape = env.observation_space.shape
action_size = env.action_space.n
print(f"State shape: {state_shape}, Action size: {action_size}")
return env, action_size
def initialize_q_table(state_size, action_size):
    # One Q-value per (discretized state, action) pair, all initialized to zero.
    return np.zeros((state_size, state_size, state_size, state_size, action_size))
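# Epsilon-greedy action selection: explore with a random action with
# probability epsilon, otherwise exploit the best known action.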
def epsilon_greedy_action_selection(state, qtable, env, epsilon):
if np.random.uniform(0, 1) < epsilon:
return env.action_space.sample()
else:
cart_pos, cart_vel, pole_angle, pole_angular_vel = discretize_state(state, bins)
return np.argmax(qtable[cart_pos, cart_vel, pole_angle, pole_angular_vel, :])
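# Q-learning update rule:
#   Q(s, a) <- Q(s, a) + lr * (r + gamma * max_a' Q(s', a') - Q(s, a))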
def update_q_value(current_state, action, reward, next_state, qtable, learning_rate, discount_factor):
    # Bootstrap target: the best achievable Q-value from the next state.
    next_cart_pos, next_cart_vel, next_pole_angle, next_pole_angular_vel = discretize_state(next_state, bins)
    future_q_value = np.max(qtable[next_cart_pos, next_cart_vel, next_pole_angle, next_pole_angular_vel, :])
    cart_pos, cart_vel, pole_angle, pole_angular_vel = discretize_state(current_state, bins)
    current_q_value = qtable[cart_pos, cart_vel, pole_angle, pole_angular_vel, action]
    new_q_value = current_q_value + learning_rate * (reward + discount_factor * future_q_value - current_q_value)
    qtable[cart_pos, cart_vel, pole_angle, pole_angular_vel, action] = new_q_value
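# Training loop: run num_episodes episodes, choosing actions epsilon-greedily
# and updating the Q-table after every step.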
def train_agent(env, qtable, num_episodes, learning_rate, discount_factor, epsilon):
for episode_nr in range(num_episodes):
current_state, _ = env.reset()
done = False
        while not done:
            action = epsilon_greedy_action_selection(current_state, qtable, env, epsilon)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            # Penalize failure (pole fell or cart left the track); a time-limit
            # truncation at 500 steps is a success and keeps its normal reward.
            if terminated:
                reward = -100
            update_q_value(current_state, action, reward, next_state, qtable, learning_rate, discount_factor)
            current_state = next_state
if episode_nr % 1000 == 0:
print(f"Episode {episode_nr} completed")
return qtable
def save_qtable(filename, qtable):
np.save(filename, qtable)
print(f"Q-table saved as {filename}")
def create_replay_video(env, qtable, filename="replay.mp4"):
frames = []
current_state, _ = env.reset()
done = False
    while not done:
        frames.append(env.render())
        cart_pos, cart_vel, pole_angle, pole_angular_vel = discretize_state(current_state, bins)
        # Always take the greedy action: no exploration during the replay.
        action = np.argmax(qtable[cart_pos, cart_vel, pole_angle, pole_angular_vel, :])
        next_state, _, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        current_state = next_state
env.close()
with imageio.get_writer(filename, fps=30) as video:
for frame in frames:
video.append_data(frame)
print(f"Video saved as {filename}")
def main():
env, action_size = initialize_environment()
qtable = initialize_q_table(NUMBER_OF_BINS, action_size)
qtable = train_agent(env, qtable, NUMBER_OF_EPISODES, LEARNING_RATE,
DISCOUNT_FACTOR, EPSILON)
save_qtable("cartpole_qtable.npy", qtable)
env = gym.make('CartPole-v1', render_mode="rgb_array")
create_replay_video(env, qtable)
if __name__ == "__main__":
main()