| """ | |
| @author: bvk1ng (Adityam Ghosh) | |
| Date: 12/28/2023 | |
| """ | |
| from typing import Callable, List, Tuple, Any, Dict, Union | |
| import numpy as np | |
| import gymnasium as gym | |
def update_state(state: np.ndarray, obs_small: np.ndarray) -> np.ndarray:
    """Push the newest frame onto the frame stack and drop the oldest one (FIFO).

    Args:
        state: Stacked frames, with the stack running along axis 2 and the
            oldest frame at index 0 of that axis.
        obs_small: The newly preprocessed frame; it gains a length-1 axis at
            position 2 before being appended.

    Returns:
        A new array: ``state`` without its first slice along axis 2, followed by
        ``obs_small``. ``state`` itself is not modified.
    """
    newest_frame = np.expand_dims(obs_small, axis=2)
    retained_frames = state[:, :, 1:]
    return np.concatenate((retained_frames, newest_frame), axis=2)
def play_atari_game(
    env: gym.Env, model: Any, img_transform: Any, *, n_frames: int = 4
) -> float:
    """Play one episode greedily with ``model``, print and return the total reward.

    The first preprocessed frame is repeated ``n_frames`` times to build the
    initial frame stack (stacked along axis 2). After every step the newest
    preprocessed frame is pushed onto the stack via ``update_state``.

    Args:
        env: A Gymnasium environment using the 5-tuple ``step`` API
            ``(obs, reward, terminated, truncated, info)``.
        model: Object exposing ``predict(batch)``. Presumably it returns per-action
            scores of shape ``(1, n_actions)`` for a batch holding one stacked
            state; the greedy action is the argmax over axis 1.
        img_transform: Object exposing ``transform(obs)``, which preprocesses a
            raw observation into a single frame.
        n_frames: Number of frames kept in the stack. Defaults to 4, the value
            previously hard-coded.

    Returns:
        The undiscounted sum of rewards collected during the episode.

    Raises:
        ValueError: If ``n_frames`` is less than 1.
    """
    if n_frames < 1:
        raise ValueError(f"n_frames must be >= 1, got {n_frames}")

    obs, _ = env.reset()
    state = np.stack([img_transform.transform(obs)] * n_frames, axis=2)

    terminated, truncated = False, False
    episode_reward = 0
    while not (terminated or truncated):
        scores = model.predict(np.expand_dims(state, axis=0))
        # Keras `Model.predict` already returns a NumPy array, which has no
        # `.numpy()` method; convert only when a tensor-like object comes back.
        if hasattr(scores, "numpy"):
            scores = scores.numpy()
        action = int(np.argmax(scores, axis=1)[0])

        obs, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        state = update_state(state=state, obs_small=img_transform.transform(obs))

    print(f"Total reward earned: {episode_reward}")
    return episode_reward