Commit 1dc3470 · Parent: 31d2530

pushing model

Files changed:
- DQN_baseline.cleanrl_model +0 -0
- README.md +16 -15
- dqn.py +81 -29
- events.out.tfevents.1676244619.wycliffeduncan-Victus-by-HP-Gaming-Laptop-15-fa0xxx.28826.0 → events.out.tfevents.1678647860.portal.3094185.0 +2 -2
- replay.mp4 +0 -0
- videos/CartPole-v1__DQN_baseline__1__1676244613-eval/rl-video-episode-0.mp4 +0 -0
- videos/CartPole-v1__DQN_baseline__1__1676244613-eval/rl-video-episode-1.mp4 +0 -0
- videos/CartPole-v1__DQN_baseline__1__1676244613-eval/rl-video-episode-8.mp4 +0 -0
- videos/CartPole-v1__DQN_baseline__1__1678647857-eval/rl-video-episode-0.mp4 +0 -0
- videos/CartPole-v1__DQN_baseline__1__1678647857-eval/rl-video-episode-1.mp4 +0 -0
- videos/CartPole-v1__DQN_baseline__1__1678647857-eval/rl-video-episode-8.mp4 +0 -0
    	
DQN_baseline.cleanrl_model (CHANGED)

Binary files a/DQN_baseline.cleanrl_model and b/DQN_baseline.cleanrl_model differ.
    	
README.md (CHANGED)

@@ -16,7 +16,7 @@ model-index:
       type: CartPole-v1
     metrics:
     - type: mean_reward
-      value: …
+      value: 500.00 +/- 0.00
       name: mean_reward
       verified: false
 ---

@@ -46,32 +46,33 @@ curl -OL https://huggingface.co/pfunk/CartPole-v1-DQN_baseline-seed1/raw/main/dq
 curl -OL https://huggingface.co/pfunk/CartPole-v1-DQN_baseline-seed1/raw/main/pyproject.toml
 curl -OL https://huggingface.co/pfunk/CartPole-v1-DQN_baseline-seed1/raw/main/poetry.lock
 poetry install --all-extras
-python dqn.py --exp-name DQN_baseline --track --wandb-entity pfunk --wandb-project-name dqpn --save-model true --upload-model true --hf-entity pfunk
+python dqn.py --exp-name DQN_baseline --seed 1 --track --wandb-entity pfunk --wandb-project-name dqpn --capture-video true --save-model true --upload-model true --hf-entity pfunk
 ```
 
 # Hyperparameters
 ```python
-{'…
- '…
- '…
+{'alg_type': 'dqn.py',
+ 'batch_size': 256,
+ 'buffer_size': 300000,
+ 'capture_video': True,
  'cuda': True,
- 'end_e': 0.…
+ 'end_e': 0.1,
  'env_id': 'CartPole-v1',
  'exp_name': 'DQN_baseline',
- 'exploration_fraction': 0.…
- 'gamma': 0…
+ 'exploration_fraction': 0.2,
+ 'gamma': 1.0,
  'hf_entity': 'pfunk',
- 'learning_rate': 0.…
- 'learning_starts': …
+ 'learning_rate': 0.0001,
+ 'learning_starts': 1000,
  'save_model': True,
  'seed': 1,
- 'start_e': 1,
- 'target_network_frequency': …
- '…
+ 'start_e': 1.0,
+ 'target_network_frequency': 100,
+ 'target_tau': 1.0,
  'torch_deterministic': True,
- 'total_timesteps': …
+ 'total_timesteps': 500000,
  'track': True,
- 'train_frequency': …
+ 'train_frequency': 1,
  'upload_model': True,
  'wandb_entity': 'pfunk',
  'wandb_project_name': 'dqpn'}
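For context on the exploration keys above: start_e, end_e, and exploration_fraction feed the linear_schedule call visible in the dqn.py diff below. A minimal sketch of that schedule, matching CleanRL's helper, with this commit's values plugged in:

    def linear_schedule(start_e: float, end_e: float, duration: float, t: int) -> float:
        # Anneal linearly from start_e to end_e over `duration` steps, then hold at end_e.
        slope = (end_e - start_e) / duration
        return max(slope * t + start_e, end_e)

    # exploration_fraction 0.2 of total_timesteps 500000 gives a 100000-step anneal.
    print(linear_schedule(1.0, 0.1, 0.2 * 500000, 50000))   # 0.55: halfway down
    print(linear_schedule(1.0, 0.1, 0.2 * 500000, 200000))  # 0.1: floor reached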
    	
dqn.py (CHANGED)

@@ -46,27 +46,27 @@ def parse_args():
         help="the id of the environment")
     parser.add_argument("--total-timesteps", type=int, default=500000,
         help="total timesteps of the experiments")
-    parser.add_argument("--learning-rate", type=float, default=…
+    parser.add_argument("--learning-rate", type=float, default=0.0001,
         help="the learning rate of the optimizer")
-    parser.add_argument("--buffer-size", type=int, default=…
+    parser.add_argument("--buffer-size", type=int, default=300000,
         help="the replay memory buffer size")
-    parser.add_argument("--gamma", type=float, default=0…
+    parser.add_argument("--gamma", type=float, default=1.0,
         help="the discount factor gamma")
-    parser.add_argument("--tau", type=float, default=1…
+    parser.add_argument("--target-tau", type=float, default=1.0,
         help="the target network update rate")
-    parser.add_argument("--target-network-frequency", type=int, default=…
+    parser.add_argument("--target-network-frequency", type=int, default=100,
         help="the timesteps it takes to update the target network")
-    parser.add_argument("--batch-size", type=int, default=…
+    parser.add_argument("--batch-size", type=int, default=256,
         help="the batch size of sample from the reply memory")
-    parser.add_argument("--start-e", type=float, default=1,
+    parser.add_argument("--start-e", type=float, default=1.0,
         help="the starting epsilon for exploration")
-    parser.add_argument("--end-e", type=float, default=0.…
+    parser.add_argument("--end-e", type=float, default=0.1,
         help="the ending epsilon for exploration")
-    parser.add_argument("--exploration-fraction", type=float, default=0.…
+    parser.add_argument("--exploration-fraction", type=float, default=0.2,
         help="the fraction of `total-timesteps` it takes from start-e to go end-e")
-    parser.add_argument("--learning-starts", type=int, default=…
+    parser.add_argument("--learning-starts", type=int, default=1000,
         help="timestep to start learning")
-    parser.add_argument("--train-frequency", type=int, default=…
+    parser.add_argument("--train-frequency", type=int, default=1,
         help="the frequency of training")
     args = parser.parse_args()
     # fmt: on
@@ -93,11 +93,11 @@ class QNetwork(nn.Module):
     def __init__(self, env):
         super().__init__()
         self.network = nn.Sequential(
-            nn.Linear(np.array(env.single_observation_space.shape).prod(), …
+            nn.Linear(np.array(env.single_observation_space.shape).prod(), 512),
             nn.ReLU(),
-            nn.Linear(…
+            nn.Linear(512, 128),
             nn.ReLU(),
-            nn.Linear(…
+            nn.Linear(128, env.single_action_space.n),
         )
 
     def forward(self, x):
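For readers skimming the hunk, the new value network is obs_dim -> 512 -> 128 -> n_actions. Here it is assembled as a self-contained module; the layer sizes come straight from the diff above, while the (4,) observation shape and 2 actions in the usage line are CartPole-v1's, and the simplified constructor signature (no env argument) is an adaptation for this sketch:

    import numpy as np
    import torch
    import torch.nn as nn

    class QNetwork(nn.Module):
        """The MLP this commit introduces: obs_dim -> 512 -> 128 -> n_actions."""

        def __init__(self, obs_shape, n_actions):
            super().__init__()
            self.network = nn.Sequential(
                nn.Linear(int(np.array(obs_shape).prod()), 512),
                nn.ReLU(),
                nn.Linear(512, 128),
                nn.ReLU(),
                nn.Linear(128, n_actions),
            )

        def forward(self, x):
            return self.network(x)

    # CartPole-v1: 4-dimensional observations, 2 discrete actions.
    q = QNetwork((4,), 2)
    print(q(torch.zeros(1, 4)).shape)  # torch.Size([1, 2])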
@@ -115,14 +115,16 @@ if __name__ == "__main__":
     if args.track:
         import wandb
 
-        …
+        args.alg_type = os.path.basename(__file__)
+        wandb_sess = wandb.init(
             project=args.wandb_project_name,
             entity=args.wandb_entity,
-            sync_tensorboard=True,
             config=vars(args),
+            save_code=True,
+            # group='string',
             name=run_name,
+            sync_tensorboard=False,
             monitor_gym=True,
-            save_code=True,
         )
     writer = SummaryWriter(f"runs/{run_name}")
     writer.add_text(

@@ -130,6 +132,10 @@ if __name__ == "__main__":
         "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
     )
 
+    def log_value(name: str, x: float, y: int):
+        # writer.add_scalar(name, x, y)
+        wandb.log({name: x, "global_step": y})
+
     # TRY NOT TO MODIFY: seeding
     random.seed(args.seed)
     np.random.seed(args.seed)
@@ -141,9 +147,10 @@ if __name__ == "__main__":
     # env setup
     envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
     assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
+    envs.seed(args.seed)
 
     q_network = QNetwork(envs).to(device)
-    optimizer = optim.…
+    optimizer = optim.RMSprop(q_network.parameters(), lr=args.learning_rate)
     target_network = QNetwork(envs).to(device)
     target_network.load_state_dict(q_network.state_dict())
 
@@ -152,15 +159,19 @@ if __name__ == "__main__":
         envs.single_observation_space,
         envs.single_action_space,
         device,
+        optimize_memory_usage=True,
         handle_timeout_termination=True,
     )
     start_time = time.time()
+    policy_update_counter = 0
+    episode_returns = []
 
     # TRY NOT TO MODIFY: start the game
     obs = envs.reset()
     for global_step in range(args.total_timesteps):
         # ALGO LOGIC: put action logic here
         epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
+
         if random.random() < epsilon:
             actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
         else:
@@ -173,10 +184,14 @@ if __name__ == "__main__":
         # TRY NOT TO MODIFY: record rewards for plotting purposes
         for info in infos:
             if "episode" in info.keys():
-                …
-                …
-                …
-                …
+                episode_returns.append(info['episode']['r'])
+                episode_returns = episode_returns[-100:]
+                print(f"step={global_step}, return={info['episode']['r']}, sps={int(global_step / (time.time() - start_time))}")
+                log_value("perf/episodic_return", info["episode"]["r"], global_step)
+                log_value("perf/episodic_return_mean_100", np.mean(episode_returns), global_step)
+                log_value("perf/episodic_return_std_100", np.std(episode_returns), global_step)
+                log_value("debug/episodic_length", info["episode"]["l"], global_step)
+                log_value("ex2/epsilon", epsilon, global_step)
                 break
 
         # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
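Since episode_returns is trimmed to its last 100 entries after every finished episode, the *_mean_100 and *_std_100 series above are rolling statistics over at most the 100 most recent episodes. A toy illustration of the window:

    import numpy as np

    episode_returns = []
    for r in [10.0, 20.0, 30.0]:                  # stand-in episodic returns
        episode_returns.append(r)
        episode_returns = episode_returns[-100:]  # keep only the rolling window
    print(np.mean(episode_returns), np.std(episode_returns))  # 20.0 8.1649...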
@@ -200,10 +215,43 @@ if __name__ == "__main__":
                 loss = F.mse_loss(td_target, old_val)
 
                 if global_step % 100 == 0:
-                    …
-                    …
-                    …
-                    …
+
+                    prev = old_val.detach().cpu().numpy()
+                    new = td_target.detach().cpu().numpy()
+                    diff, a_diff = new - prev, np.abs(new - prev)
+
+                    mean, a_mean = np.mean(diff), np.mean(a_diff)
+                    median, a_median = np.median(diff), np.median(a_diff)
+                    maximum, a_maximum = np.max(diff), np.max(a_diff)
+                    minimum, a_minimum = np.min(diff), np.min(a_diff)
+                    std, a_std = np.std(diff), np.std(a_diff)
+                    below, a_below = mean - std, a_mean - a_std
+                    above, a_above = mean + std, a_mean + a_std
+                    pu_scalar, a_pu_scalar = 2 * mean / maximum, 2 * a_mean / a_maximum
+                    policy_frequency_scalar_ratio = 1.0 * pu_scalar
+                    a_policy_frequency_scalar_ratio = 1.0 * a_pu_scalar
+
+                    log_value("losses/td_loss", loss, global_step)
+                    log_value("losses/q_values", old_val.mean().item(), global_step)
+                    log_value("td/mean", mean, global_step)
+                    log_value("td/a_mean", a_mean, global_step)
+                    log_value("td/median", median, global_step)
+                    log_value("td/a_median", a_median, global_step)
+                    log_value("td/max", maximum, global_step)
+                    log_value("td/a_max", a_maximum, global_step)
+                    log_value("td/min", minimum, global_step)
+                    log_value("td/a_min", a_minimum, global_step)
+                    log_value("td/std", std, global_step)
+                    log_value("td/a_std", a_std, global_step)
+                    log_value("td/below", below, global_step)
+                    log_value("td/a_below", a_below, global_step)
+                    log_value("td/above", above, global_step)
+                    log_value("td/a_above", a_above, global_step)
+                    log_value("pu/pu_scalar", pu_scalar, global_step)
+                    log_value("pu/a_pu_scalar", a_pu_scalar, global_step)
+                    log_value("pu/policy_frequency_scalar_ratio", policy_frequency_scalar_ratio, global_step)
+                    log_value("pu/a_policy_frequency_scalar_ratio", a_policy_frequency_scalar_ratio, global_step)
+                    log_value("debug/steps_per_second", int(global_step / (time.time() - start_time)), global_step)
 
                 # optimize the model
                 optimizer.zero_grad()
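The new pu_* scalars compress the TD-update distribution into a single ratio, 2 * mean / max, computed over both the signed gaps (diff) and absolute gaps (a_diff) between td_target and old_val. A toy check of the arithmetic with made-up values:

    import numpy as np

    prev = np.array([1.0, 2.0, 3.0, 4.0])  # stands in for old_val (current Q estimates)
    new = np.array([1.5, 1.0, 3.0, 6.0])   # stands in for td_target

    diff, a_diff = new - prev, np.abs(new - prev)
    pu_scalar = 2 * np.mean(diff) / np.max(diff)        # signed variant -> 0.375
    a_pu_scalar = 2 * np.mean(a_diff) / np.max(a_diff)  # absolute variant -> 0.875
    print(pu_scalar, a_pu_scalar)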
@@ -214,8 +262,11 @@ if __name__ == "__main__":
             if global_step % args.target_network_frequency == 0:
                 for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
                     target_network_param.data.copy_(
-                        args.…
+                        args.target_tau * q_network_param.data + (1.0 - args.target_tau) * target_network_param.data
                     )
+            policy_update_counter += 1
+            if global_step % 100 == 0:
+                log_value("pu/n_policy_update", policy_update_counter, global_step)
 
     if args.save_model:
         model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
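Note what the renamed --target-tau means in the copy above: with target_tau = 1.0 the Polyak blend degenerates to a hard copy of the online weights every target_network_frequency (here 100) steps, while values below 1.0 would make the target track slowly. A minimal sketch of that blend:

    import torch

    def polyak_(target: torch.Tensor, online: torch.Tensor, tau: float) -> None:
        # tau = 1.0 -> hard copy; tau << 1.0 -> slow exponential tracking.
        target.data.copy_(tau * online.data + (1.0 - tau) * target.data)

    target, online = torch.zeros(3), torch.ones(3)
    polyak_(target, online, tau=1.0)
    print(target)  # tensor([1., 1., 1.]): identical to the online weights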
@@ -234,14 +285,15 @@ if __name__ == "__main__":
             epsilon=0.05,
         )
         for idx, episodic_return in enumerate(episodic_returns):
-            …
+            log_value("eval/episodic_return", episodic_return, idx)
 
         if args.upload_model:
             from cleanrl_utils.huggingface import push_to_hub
 
             repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
             repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
-            push_to_hub(args, …
+            push_to_hub(args, np.mean(episode_returns), repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")
 
+    wandb_sess.finish()
     envs.close()
     writer.close()
    	
events.out.tfevents.1676244619.wycliffeduncan-Victus-by-HP-Gaming-Laptop-15-fa0xxx.28826.0 → events.out.tfevents.1678647860.portal.3094185.0 (RENAMED)

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
-size …
+oid sha256:9423c8373dbc5b7a9a951d2520c14b332177a31609844d931aa8865fe2761480
+size 634
    	
replay.mp4 (CHANGED)

Binary files a/replay.mp4 and b/replay.mp4 differ.
    	
videos/CartPole-v1__DQN_baseline__1__1676244613-eval/rl-video-episode-0.mp4 (DELETED)

Binary file (23.7 kB).

videos/CartPole-v1__DQN_baseline__1__1676244613-eval/rl-video-episode-1.mp4 (DELETED)

Binary file (20.1 kB).

videos/CartPole-v1__DQN_baseline__1__1676244613-eval/rl-video-episode-8.mp4 (DELETED)

Binary file (22.1 kB).

videos/CartPole-v1__DQN_baseline__1__1678647857-eval/rl-video-episode-0.mp4 (ADDED)

Binary file (43.5 kB).

videos/CartPole-v1__DQN_baseline__1__1678647857-eval/rl-video-episode-1.mp4 (ADDED)

Binary file (41.1 kB).

videos/CartPole-v1__DQN_baseline__1__1678647857-eval/rl-video-episode-8.mp4 (ADDED)

Binary file (42.7 kB).