pushing model
Browse files
- README.md +4 -3
- cleanba_ppo_envpool_impala_atari_wrapper.cleanrl_model +2 -2
- cleanba_ppo_envpool_impala_atari_wrapper.py +56 -9
- events.out.tfevents.1676611831.ip-26-0-130-181.1151330.0 → events.out.tfevents.1678210197.ip-26-0-141-70 +2 -2
- poetry.lock +0 -0
- pyproject.toml +18 -162
- replay.mp4 +0 -0
- videos/Tutankham-v5__cleanba_ppo_envpool_impala_atari_wrapper__1__8e9eb61e-29b5-4771-b9d3-bd644ea96b7a-eval/0.mp4 +0 -0
- videos/Tutankham-v5__cleanba_ppo_envpool_impala_atari_wrapper__1__f396779d-340e-4192-8178-86eb90315d3f-eval/0.mp4 +0 -0
    	
        README.md
    CHANGED
    
    | @@ -16,7 +16,7 @@ model-index: | |
| 16 | 
             
                  type: Tutankham-v5
         | 
| 17 | 
             
                metrics:
         | 
| 18 | 
             
                - type: mean_reward
         | 
| 19 | 
            -
                  value:  | 
| 20 | 
             
                  name: mean_reward
         | 
| 21 | 
             
                  verified: false
         | 
| 22 | 
             
            ---
         | 
| @@ -46,7 +46,7 @@ curl -OL https://huggingface.co/cleanrl/Tutankham-v5-cleanba_ppo_envpool_impala_ | |
| 46 | 
             
            curl -OL https://huggingface.co/cleanrl/Tutankham-v5-cleanba_ppo_envpool_impala_atari_wrapper-seed1/raw/main/pyproject.toml
         | 
| 47 | 
             
            curl -OL https://huggingface.co/cleanrl/Tutankham-v5-cleanba_ppo_envpool_impala_atari_wrapper-seed1/raw/main/poetry.lock
         | 
| 48 | 
             
            poetry install --all-extras
         | 
| 49 | 
            -
            python cleanba_ppo_envpool_impala_atari_wrapper.py --distributed --learner-device-ids 1 2 3 --track --save-model --upload-model --hf-entity cleanrl --env-id Tutankham-v5 --seed 1
         | 
| 50 | 
             
            ```
         | 
| 51 |  | 
| 52 | 
             
            # Hyperparameters
         | 
| @@ -59,6 +59,7 @@ python cleanba_ppo_envpool_impala_atari_wrapper.py --distributed --learner-devic | |
| 59 | 
             
             'batch_size': 15360,
         | 
| 60 | 
             
             'capture_video': False,
         | 
| 61 | 
             
             'clip_coef': 0.1,
         | 
|  | |
| 62 | 
             
             'cuda': True,
         | 
| 63 | 
             
             'distributed': True,
         | 
| 64 | 
             
             'ent_coef': 0.01,
         | 
| @@ -99,7 +100,7 @@ python cleanba_ppo_envpool_impala_atari_wrapper.py --distributed --learner-devic | |
| 99 | 
             
             'upload_model': True,
         | 
| 100 | 
             
             'vf_coef': 0.5,
         | 
| 101 | 
             
             'wandb_entity': None,
         | 
| 102 | 
            -
             'wandb_project_name': ' | 
| 103 | 
             
             'world_size': 2}
         | 
| 104 | 
             
            ```
         | 
| 105 |  | 
|  | |
| 16 | 
             
                  type: Tutankham-v5
         | 
| 17 | 
             
                metrics:
         | 
| 18 | 
             
                - type: mean_reward
         | 
| 19 | 
            +
                  value: 279.00 +/- 39.06
         | 
| 20 | 
             
                  name: mean_reward
         | 
| 21 | 
             
                  verified: false
         | 
| 22 | 
             
            ---
         | 
|  | |
| 46 | 
             
            curl -OL https://huggingface.co/cleanrl/Tutankham-v5-cleanba_ppo_envpool_impala_atari_wrapper-seed1/raw/main/pyproject.toml
         | 
| 47 | 
             
            curl -OL https://huggingface.co/cleanrl/Tutankham-v5-cleanba_ppo_envpool_impala_atari_wrapper-seed1/raw/main/poetry.lock
         | 
| 48 | 
             
            poetry install --all-extras
         | 
| 49 | 
            +
            python cleanba_ppo_envpool_impala_atari_wrapper.py --distributed --learner-device-ids 1 2 3 --track --wandb-project-name cleanba --save-model --upload-model --hf-entity cleanrl --env-id Tutankham-v5 --seed 1
         | 
| 50 | 
             
            ```
         | 
| 51 |  | 
| 52 | 
             
            # Hyperparameters
         | 
|  | |
| 59 | 
             
             'batch_size': 15360,
         | 
| 60 | 
             
             'capture_video': False,
         | 
| 61 | 
             
             'clip_coef': 0.1,
         | 
| 62 | 
            +
             'concurrency': True,
         | 
| 63 | 
             
             'cuda': True,
         | 
| 64 | 
             
             'distributed': True,
         | 
| 65 | 
             
             'ent_coef': 0.01,
         | 
|  | |
| 100 | 
             
             'upload_model': True,
         | 
| 101 | 
             
             'vf_coef': 0.5,
         | 
| 102 | 
             
             'wandb_entity': None,
         | 
| 103 | 
            +
             'wandb_project_name': 'cleanba',
         | 
| 104 | 
             
             'world_size': 2}
         | 
| 105 | 
             
            ```
         | 
| 106 |  | 
    	
        cleanba_ppo_envpool_impala_atari_wrapper.cleanrl_model
    CHANGED
    
    | @@ -1,3 +1,3 @@ | |
| 1 | 
             
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            -
            oid sha256: | 
| 3 | 
            -
            size  | 
|  | |
| 1 | 
             
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:5aecff440c81fc86e2089d9539106bf0e9a7afe02d9551bb9fad02d731f4170e
         | 
| 3 | 
            +
            size 4368279
         | 
    	
        cleanba_ppo_envpool_impala_atari_wrapper.py
    CHANGED
    
    | @@ -1,4 +1,3 @@ | |
| 1 | 
            -
            # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_async_jax_scan_impalanet_machadopy
         | 
| 2 | 
             
            import argparse
         | 
| 3 | 
             
            import os
         | 
| 4 | 
             
            import random
         | 
| @@ -26,7 +25,7 @@ import numpy as np | |
| 26 | 
             
            import optax
         | 
| 27 | 
             
            from flax.linen.initializers import constant, orthogonal
         | 
| 28 | 
             
            from flax.training.train_state import TrainState
         | 
| 29 | 
            -
            from  | 
| 30 |  | 
| 31 |  | 
| 32 | 
             
            def parse_args():
         | 
| @@ -47,7 +46,7 @@ def parse_args(): | |
| 47 | 
             
                parser.add_argument("--wandb-entity", type=str, default=None,
         | 
| 48 | 
             
                    help="the entity (team) of wandb's project")
         | 
| 49 | 
             
                parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         | 
| 50 | 
            -
                    help=" | 
| 51 | 
             
                parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         | 
| 52 | 
             
                    help="whether to save model into the `runs/{run_name}` folder")
         | 
| 53 | 
             
                parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         | 
| @@ -97,6 +96,8 @@ def parse_args(): | |
| 97 | 
             
                    help="the device ids that learner workers will use")
         | 
| 98 | 
             
                parser.add_argument("--distributed", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         | 
| 99 | 
             
                    help="whether to use `jax.distirbuted`")
         | 
|  | |
|  | |
| 100 | 
             
                parser.add_argument("--profile", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         | 
| 101 | 
             
                    help="whether to call block_until_ready() for profiling")
         | 
| 102 | 
             
                parser.add_argument("--test-actor-learner-throughput", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         | 
| @@ -213,7 +214,7 @@ class AgentParams: | |
| 213 |  | 
| 214 | 
             
            @partial(jax.jit, static_argnums=(3))
         | 
| 215 | 
             
            def get_action_and_value(
         | 
| 216 | 
            -
                params:  | 
| 217 | 
             
                next_obs: np.ndarray,
         | 
| 218 | 
             
                key: jax.random.PRNGKey,
         | 
| 219 | 
             
                action_dim: int,
         | 
| @@ -281,6 +282,20 @@ def prepare_data( | |
| 281 | 
             
                return b_obs, b_actions, b_logprobs, b_advantages, b_returns
         | 
| 282 |  | 
| 283 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 284 | 
             
            def rollout(
         | 
| 285 | 
             
                key: jax.random.PRNGKey,
         | 
| 286 | 
             
                args,
         | 
| @@ -289,7 +304,7 @@ def rollout( | |
| 289 | 
             
                writer,
         | 
| 290 | 
             
                learner_devices,
         | 
| 291 | 
             
            ):
         | 
| 292 | 
            -
                envs = make_env(args.env_id, args.seed, args.local_num_envs, args.async_batch_size)()
         | 
| 293 | 
             
                len_actor_device_ids = len(args.actor_device_ids)
         | 
| 294 | 
             
                global_step = 0
         | 
| 295 | 
             
                # TRY NOT TO MODIFY: start the game
         | 
| @@ -332,9 +347,13 @@ def rollout( | |
| 332 | 
             
                    # concurrently with the learning process. It also ensures the actor's policy version is only 1 step
         | 
| 333 | 
             
                    # behind the learner's policy version
         | 
| 334 | 
             
                    params_queue_get_time_start = time.time()
         | 
| 335 | 
            -
                    if  | 
| 336 | 
             
                        params = params_queue.get()
         | 
| 337 | 
             
                        actor_policy_version += 1
         | 
|  | |
|  | |
|  | |
|  | |
| 338 | 
             
                    params_queue_get_time.append(time.time() - params_queue_get_time_start)
         | 
| 339 | 
             
                    writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
         | 
| 340 | 
             
                    rollout_time_start = time.time()
         | 
| @@ -397,18 +416,29 @@ def rollout( | |
| 397 | 
             
                    writer.add_scalar("stats/inference_time", inference_time, global_step)
         | 
| 398 | 
             
                    writer.add_scalar("stats/storage_time", storage_time, global_step)
         | 
| 399 | 
             
                    writer.add_scalar("stats/env_send_time", env_send_time, global_step)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 400 |  | 
| 401 | 
             
                    payload = (
         | 
| 402 | 
             
                        global_step,
         | 
| 403 | 
             
                        actor_policy_version,
         | 
| 404 | 
             
                        update,
         | 
| 405 | 
             
                        obs,
         | 
| 406 | 
            -
                        dones,
         | 
| 407 | 
             
                        values,
         | 
| 408 | 
             
                        actions,
         | 
| 409 | 
             
                        logprobs,
         | 
|  | |
| 410 | 
             
                        env_ids,
         | 
| 411 | 
             
                        rewards,
         | 
|  | |
| 412 | 
             
                    )
         | 
| 413 | 
             
                    if update == 1 or not args.test_actor_learner_throughput:
         | 
| 414 | 
             
                        rollout_queue_put_time_start = time.time()
         | 
| @@ -717,15 +747,21 @@ if __name__ == "__main__": | |
| 717 | 
             
                            actor_policy_version,
         | 
| 718 | 
             
                            update,
         | 
| 719 | 
             
                            obs,
         | 
| 720 | 
            -
                            dones,
         | 
| 721 | 
             
                            values,
         | 
| 722 | 
             
                            actions,
         | 
| 723 | 
             
                            logprobs,
         | 
|  | |
| 724 | 
             
                            env_ids,
         | 
| 725 | 
             
                            rewards,
         | 
|  | |
| 726 | 
             
                        ) = rollout_queue.get()
         | 
| 727 | 
             
                        rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
         | 
| 728 | 
             
                        writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 729 |  | 
| 730 | 
             
                    data_transfer_time_start = time.time()
         | 
| 731 | 
             
                    b_obs, b_actions, b_logprobs, b_advantages, b_returns = prepare_data(
         | 
| @@ -780,11 +816,22 @@ if __name__ == "__main__": | |
| 780 | 
             
                        break
         | 
| 781 |  | 
| 782 | 
             
                if args.save_model and args.local_rank == 0:
         | 
|  | |
|  | |
| 783 | 
             
                    agent_state = flax.jax_utils.unreplicate(agent_state)
         | 
| 784 | 
             
                    model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
         | 
| 785 | 
             
                    with open(model_path, "wb") as f:
         | 
| 786 | 
             
                        f.write(
         | 
| 787 | 
            -
                            flax.serialization.to_bytes( | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 788 | 
             
                        )
         | 
| 789 | 
             
                    print(f"model saved to {model_path}")
         | 
| 790 | 
             
                    from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate
         | 
|  | |
|  | |
| 1 | 
             
            import argparse
         | 
| 2 | 
             
            import os
         | 
| 3 | 
             
            import random
         | 
|  | |
| 25 | 
             
            import optax
         | 
| 26 | 
             
            from flax.linen.initializers import constant, orthogonal
         | 
| 27 | 
             
            from flax.training.train_state import TrainState
         | 
| 28 | 
            +
            from tensorboardX import SummaryWriter
         | 
| 29 |  | 
| 30 |  | 
| 31 | 
             
            def parse_args():
         | 
|  | |
| 46 | 
             
                parser.add_argument("--wandb-entity", type=str, default=None,
         | 
| 47 | 
             
                    help="the entity (team) of wandb's project")
         | 
| 48 | 
             
                parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         | 
| 49 | 
            +
                    help="whether to capture videos of the agent performances (check out `videos` folder)")
         | 
| 50 | 
             
                parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         | 
| 51 | 
             
                    help="whether to save model into the `runs/{run_name}` folder")
         | 
| 52 | 
             
                parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         | 
|  | |
| 96 | 
             
                    help="the device ids that learner workers will use")
         | 
| 97 | 
             
                parser.add_argument("--distributed", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         | 
| 98 | 
             
                    help="whether to use `jax.distirbuted`")
         | 
| 99 | 
            +
                parser.add_argument("--concurrency", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
         | 
| 100 | 
            +
                    help="whether to run the actor and learner concurrently")
         | 
| 101 | 
             
                parser.add_argument("--profile", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         | 
| 102 | 
             
                    help="whether to call block_until_ready() for profiling")
         | 
| 103 | 
             
                parser.add_argument("--test-actor-learner-throughput", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         | 
|  | |
| 214 |  | 
| 215 | 
             
            @partial(jax.jit, static_argnums=(3))
         | 
| 216 | 
             
            def get_action_and_value(
         | 
| 217 | 
            +
                params: flax.core.FrozenDict,
         | 
| 218 | 
             
                next_obs: np.ndarray,
         | 
| 219 | 
             
                key: jax.random.PRNGKey,
         | 
| 220 | 
             
                action_dim: int,
         | 
|  | |
| 282 | 
             
                return b_obs, b_actions, b_logprobs, b_advantages, b_returns
         | 
| 283 |  | 
| 284 |  | 
| 285 | 
            +
            @jax.jit
         | 
| 286 | 
            +
            def make_bulk_array(
         | 
| 287 | 
            +
                obs: list,
         | 
| 288 | 
            +
                values: list,
         | 
| 289 | 
            +
                actions: list,
         | 
| 290 | 
            +
                logprobs: list,
         | 
| 291 | 
            +
            ):
         | 
| 292 | 
            +
                obs = jnp.asarray(obs)
         | 
| 293 | 
            +
                values = jnp.asarray(values)
         | 
| 294 | 
            +
                actions = jnp.asarray(actions)
         | 
| 295 | 
            +
                logprobs = jnp.asarray(logprobs)
         | 
| 296 | 
            +
                return obs, values, actions, logprobs
         | 
| 297 | 
            +
             | 
| 298 | 
            +
             | 
| 299 | 
             
            def rollout(
         | 
| 300 | 
             
                key: jax.random.PRNGKey,
         | 
| 301 | 
             
                args,
         | 
|  | |
| 304 | 
             
                writer,
         | 
| 305 | 
             
                learner_devices,
         | 
| 306 | 
             
            ):
         | 
| 307 | 
            +
                envs = make_env(args.env_id, args.seed + jax.process_index(), args.local_num_envs, args.async_batch_size)()
         | 
| 308 | 
             
                len_actor_device_ids = len(args.actor_device_ids)
         | 
| 309 | 
             
                global_step = 0
         | 
| 310 | 
             
                # TRY NOT TO MODIFY: start the game
         | 
|  | |
| 347 | 
             
                    # concurrently with the learning process. It also ensures the actor's policy version is only 1 step
         | 
| 348 | 
             
                    # behind the learner's policy version
         | 
| 349 | 
             
                    params_queue_get_time_start = time.time()
         | 
| 350 | 
            +
                    if not args.concurrency:
         | 
| 351 | 
             
                        params = params_queue.get()
         | 
| 352 | 
             
                        actor_policy_version += 1
         | 
| 353 | 
            +
                    else:
         | 
| 354 | 
            +
                        if update != 2:
         | 
| 355 | 
            +
                            params = params_queue.get()
         | 
| 356 | 
            +
                            actor_policy_version += 1
         | 
| 357 | 
             
                    params_queue_get_time.append(time.time() - params_queue_get_time_start)
         | 
| 358 | 
             
                    writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
         | 
| 359 | 
             
                    rollout_time_start = time.time()
         | 
|  | |
| 416 | 
             
                    writer.add_scalar("stats/inference_time", inference_time, global_step)
         | 
| 417 | 
             
                    writer.add_scalar("stats/storage_time", storage_time, global_step)
         | 
| 418 | 
             
                    writer.add_scalar("stats/env_send_time", env_send_time, global_step)
         | 
| 419 | 
            +
                    # `make_bulk_array` is actually important. It accumulates the data from the lists
         | 
| 420 | 
            +
                    # into single bulk arrays, which later makes transferring the data to the learner's
         | 
| 421 | 
            +
                    # device slightly faster. See https://wandb.ai/costa-huang/cleanRL/reports/data-transfer-optimization--VmlldzozNjU5MTg1
         | 
| 422 | 
            +
                    if args.learner_device_ids[0] != args.actor_device_ids[0]:
         | 
| 423 | 
            +
                        obs, values, actions, logprobs = make_bulk_array(
         | 
| 424 | 
            +
                            obs,
         | 
| 425 | 
            +
                            values,
         | 
| 426 | 
            +
                            actions,
         | 
| 427 | 
            +
                            logprobs,
         | 
| 428 | 
            +
                        )
         | 
| 429 |  | 
| 430 | 
             
                    payload = (
         | 
| 431 | 
             
                        global_step,
         | 
| 432 | 
             
                        actor_policy_version,
         | 
| 433 | 
             
                        update,
         | 
| 434 | 
             
                        obs,
         | 
|  | |
| 435 | 
             
                        values,
         | 
| 436 | 
             
                        actions,
         | 
| 437 | 
             
                        logprobs,
         | 
| 438 | 
            +
                        dones,
         | 
| 439 | 
             
                        env_ids,
         | 
| 440 | 
             
                        rewards,
         | 
| 441 | 
            +
                        np.mean(params_queue_get_time),
         | 
| 442 | 
             
                    )
         | 
| 443 | 
             
                    if update == 1 or not args.test_actor_learner_throughput:
         | 
| 444 | 
             
                        rollout_queue_put_time_start = time.time()
         | 
|  | |
| 747 | 
             
                            actor_policy_version,
         | 
| 748 | 
             
                            update,
         | 
| 749 | 
             
                            obs,
         | 
|  | |
| 750 | 
             
                            values,
         | 
| 751 | 
             
                            actions,
         | 
| 752 | 
             
                            logprobs,
         | 
| 753 | 
            +
                            dones,
         | 
| 754 | 
             
                            env_ids,
         | 
| 755 | 
             
                            rewards,
         | 
| 756 | 
            +
                            avg_params_queue_get_time,
         | 
| 757 | 
             
                        ) = rollout_queue.get()
         | 
| 758 | 
             
                        rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
         | 
| 759 | 
             
                        writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
         | 
| 760 | 
            +
                        writer.add_scalar(
         | 
| 761 | 
            +
                            "stats/rollout_params_queue_get_time_diff",
         | 
| 762 | 
            +
                            np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
         | 
| 763 | 
            +
                            global_step,
         | 
| 764 | 
            +
                        )
         | 
| 765 |  | 
| 766 | 
             
                    data_transfer_time_start = time.time()
         | 
| 767 | 
             
                    b_obs, b_actions, b_logprobs, b_advantages, b_returns = prepare_data(
         | 
|  | |
| 816 | 
             
                        break
         | 
| 817 |  | 
| 818 | 
             
                if args.save_model and args.local_rank == 0:
         | 
| 819 | 
            +
                    if args.distributed:
         | 
| 820 | 
            +
                        jax.distributed.shutdown()
         | 
| 821 | 
             
                    agent_state = flax.jax_utils.unreplicate(agent_state)
         | 
| 822 | 
             
                    model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
         | 
| 823 | 
             
                    with open(model_path, "wb") as f:
         | 
| 824 | 
             
                        f.write(
         | 
| 825 | 
            +
                            flax.serialization.to_bytes(
         | 
| 826 | 
            +
                                [
         | 
| 827 | 
            +
                                    vars(args),
         | 
| 828 | 
            +
                                    [
         | 
| 829 | 
            +
                                        agent_state.params.network_params,
         | 
| 830 | 
            +
                                        agent_state.params.actor_params,
         | 
| 831 | 
            +
                                        agent_state.params.critic_params,
         | 
| 832 | 
            +
                                    ],
         | 
| 833 | 
            +
                                ]
         | 
| 834 | 
            +
                            )
         | 
| 835 | 
             
                        )
         | 
| 836 | 
             
                    print(f"model saved to {model_path}")
         | 
| 837 | 
             
                    from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate
         | 
    	
        events.out.tfevents.1676611831.ip-26-0-130-181.1151330.0 → events.out.tfevents.1678210197.ip-26-0-141-70
    RENAMED
    
    | @@ -1,3 +1,3 @@ | |
| 1 | 
             
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            -
            oid sha256: | 
| 3 | 
            -
            size  | 
|  | |
| 1 | 
             
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:2fb9ad8c93aacb46dbf269fe778fb9847e7547f373457f24781dfb5729f52b5d
         | 
| 3 | 
            +
            size 5017750
         | 
    	
        poetry.lock
    CHANGED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        pyproject.toml
    CHANGED
    
    | @@ -1,178 +1,34 @@ | |
| 1 | 
             
            [tool.poetry]
         | 
| 2 | 
            -
            name = " | 
| 3 | 
            -
            version = " | 
| 4 | 
            -
            description = " | 
| 5 | 
             
            authors = ["Costa Huang <[email protected]>"]
         | 
|  | |
| 6 | 
             
            packages = [
         | 
| 7 | 
            -
                { include = " | 
| 8 | 
             
                { include = "cleanrl_utils" },
         | 
| 9 | 
             
            ]
         | 
| 10 | 
            -
            keywords = ["reinforcement", "machine", "learning", "research"]
         | 
| 11 | 
            -
            license="MIT"
         | 
| 12 | 
            -
            readme = "README.md"
         | 
| 13 |  | 
| 14 | 
             
            [tool.poetry.dependencies]
         | 
| 15 | 
            -
            python = " | 
| 16 | 
            -
            tensorboard = "^2. | 
| 17 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 18 | 
             
            gym = "0.23.1"
         | 
| 19 | 
            -
             | 
| 20 | 
            -
            stable-baselines3 = "1.2.0"
         | 
| 21 | 
            -
            gymnasium = "^0.26.3"
         | 
| 22 | 
             
            moviepy = "^1.0.3"
         | 
| 23 | 
            -
            pygame = "2.1.0"
         | 
| 24 | 
            -
            huggingface-hub = "^0.11.1"
         | 
| 25 |  | 
| 26 | 
            -
            ale-py = {version = "0.7.4", optional = true}
         | 
| 27 | 
            -
            AutoROM = {extras = ["accept-rom-license"], version = "^0.4.2"}
         | 
| 28 | 
            -
            opencv-python = {version = "^4.6.0.66", optional = true}
         | 
| 29 | 
            -
            pybullet = {version = "3.1.8", optional = true}
         | 
| 30 | 
            -
            procgen = {version = "^0.10.7", optional = true}
         | 
| 31 | 
            -
            pytest = {version = "^7.1.3", optional = true}
         | 
| 32 | 
            -
            mujoco = {version = "^2.2", optional = true}
         | 
| 33 | 
            -
            imageio = {version = "^2.14.1", optional = true}
         | 
| 34 | 
            -
            free-mujoco-py = {version = "^2.1.6", optional = true}
         | 
| 35 | 
            -
            mkdocs-material = {version = "^8.4.3", optional = true}
         | 
| 36 | 
            -
            markdown-include = {version = "^0.7.0", optional = true}
         | 
| 37 | 
            -
            jax = {version = "^0.3.17", optional = true}
         | 
| 38 | 
            -
            jaxlib = {version = "^0.3.15", optional = true}
         | 
| 39 | 
            -
            flax = {version = "^0.6.0", optional = true}
         | 
| 40 | 
            -
            optuna = {version = "^3.0.1", optional = true}
         | 
| 41 | 
            -
            optuna-dashboard = {version = "^0.7.2", optional = true}
         | 
| 42 | 
            -
            rich = {version = "<12.0", optional = true}
         | 
| 43 | 
            -
            envpool = {version = "^0.8.1", optional = true}
         | 
| 44 | 
            -
            PettingZoo = {version = "1.18.1", optional = true}
         | 
| 45 | 
            -
            SuperSuit = {version = "3.4.0", optional = true}
         | 
| 46 | 
            -
            multi-agent-ale-py = {version = "0.1.11", optional = true}
         | 
| 47 | 
            -
            boto3 = {version = "^1.24.70", optional = true}
         | 
| 48 | 
            -
            awscli = {version = "^1.25.71", optional = true}
         | 
| 49 | 
            -
            shimmy = {version = "^0.1.0", optional = true}
         | 
| 50 | 
            -
            dm-control = {version = "^1.0.8", optional = true}
         | 
| 51 |  | 
| 52 | 
             
            [tool.poetry.group.dev.dependencies]
         | 
| 53 | 
            -
            pre-commit = "^ | 
| 54 | 
            -
             | 
| 55 | 
            -
            [tool.poetry.group.atari]
         | 
| 56 | 
            -
            optional = true
         | 
| 57 | 
            -
            [tool.poetry.group.atari.dependencies]
         | 
| 58 | 
            -
            ale-py = "0.7.4"
         | 
| 59 | 
            -
            AutoROM = {extras = ["accept-rom-license"], version = "^0.4.2"}
         | 
| 60 | 
            -
            opencv-python = "^4.6.0.66"
         | 
| 61 | 
            -
             | 
| 62 | 
            -
            [tool.poetry.group.pybullet]
         | 
| 63 | 
            -
            optional = true
         | 
| 64 | 
            -
            [tool.poetry.group.pybullet.dependencies]
         | 
| 65 | 
            -
            pybullet = "3.1.8"
         | 
| 66 | 
            -
             | 
| 67 | 
            -
            [tool.poetry.group.procgen]
         | 
| 68 | 
            -
            optional = true
         | 
| 69 | 
            -
            [tool.poetry.group.procgen.dependencies]
         | 
| 70 | 
            -
            procgen = "^0.10.7"
         | 
| 71 | 
            -
             | 
| 72 | 
            -
            [tool.poetry.group.pytest]
         | 
| 73 | 
            -
            optional = true
         | 
| 74 | 
            -
            [tool.poetry.group.pytest.dependencies]
         | 
| 75 | 
            -
            pytest = "^7.1.3"
         | 
| 76 | 
            -
             | 
| 77 | 
            -
            [tool.poetry.group.mujoco]
         | 
| 78 | 
            -
            optional = true
         | 
| 79 | 
            -
            [tool.poetry.group.mujoco.dependencies]
         | 
| 80 | 
            -
            mujoco = "^2.2"
         | 
| 81 | 
            -
            imageio = "^2.14.1"
         | 
| 82 | 
            -
             | 
| 83 | 
            -
            [tool.poetry.group.mujoco_py]
         | 
| 84 | 
            -
            optional = true
         | 
| 85 | 
            -
            [tool.poetry.group.mujoco_py.dependencies]
         | 
| 86 | 
            -
            free-mujoco-py = "^2.1.6"
         | 
| 87 | 
            -
             | 
| 88 | 
            -
            [tool.poetry.group.docs]
         | 
| 89 | 
            -
            optional = true
         | 
| 90 | 
            -
            [tool.poetry.group.docs.dependencies]
         | 
| 91 | 
            -
            mkdocs-material = "^8.4.3"
         | 
| 92 | 
            -
            markdown-include = "^0.7.0"
         | 
| 93 | 
            -
             | 
| 94 | 
            -
            [tool.poetry.group.jax]
         | 
| 95 | 
            -
            optional = true
         | 
| 96 | 
            -
            [tool.poetry.group.jax.dependencies]
         | 
| 97 | 
            -
            jax = "^0.3.17"
         | 
| 98 | 
            -
            jaxlib = "^0.3.15"
         | 
| 99 | 
            -
            flax = "^0.6.0"
         | 
| 100 | 
            -
             | 
| 101 | 
            -
            [tool.poetry.group.optuna]
         | 
| 102 | 
            -
            optional = true
         | 
| 103 | 
            -
            [tool.poetry.group.optuna.dependencies]
         | 
| 104 | 
            -
            optuna = "^3.0.1"
         | 
| 105 | 
            -
            optuna-dashboard = "^0.7.2"
         | 
| 106 | 
            -
            rich = "<12.0"
         | 
| 107 | 
            -
             | 
| 108 | 
            -
            [tool.poetry.group.envpool]
         | 
| 109 | 
            -
            optional = true
         | 
| 110 | 
            -
            [tool.poetry.group.envpool.dependencies]
         | 
| 111 | 
            -
            envpool = "^0.8.1"
         | 
| 112 | 
            -
             | 
| 113 | 
            -
            [tool.poetry.group.pettingzoo]
         | 
| 114 | 
            -
            optional = true
         | 
| 115 | 
            -
            [tool.poetry.group.pettingzoo.dependencies]
         | 
| 116 | 
            -
            PettingZoo = "1.18.1"
         | 
| 117 | 
            -
            SuperSuit = "3.4.0"
         | 
| 118 | 
            -
            multi-agent-ale-py = "0.1.11"
         | 
| 119 | 
            -
             | 
| 120 | 
            -
            [tool.poetry.group.cloud]
         | 
| 121 | 
            -
            optional = true
         | 
| 122 | 
            -
            [tool.poetry.group.cloud.dependencies]
         | 
| 123 | 
            -
            boto3 = "^1.24.70"
         | 
| 124 | 
            -
            awscli = "^1.25.71"
         | 
| 125 | 
            -
             | 
| 126 | 
            -
            [tool.poetry.group.isaacgym]
         | 
| 127 | 
            -
            optional = true
         | 
| 128 | 
            -
            [tool.poetry.group.isaacgym.dependencies]
         | 
| 129 | 
            -
            isaacgymenvs = {git = "https://github.com/vwxyzjn/IsaacGymEnvs.git", rev = "poetry"}
         | 
| 130 | 
            -
            isaacgym = {path = "cleanrl/ppo_continuous_action_isaacgym/isaacgym", develop = true}
         | 
| 131 | 
            -
             | 
| 132 | 
            -
            [tool.poetry.group.dm_control]
         | 
| 133 | 
            -
            optional = true
         | 
| 134 | 
            -
            [tool.poetry.group.dm_control.dependencies]
         | 
| 135 | 
            -
            shimmy = "^0.1.0"
         | 
| 136 | 
            -
            dm-control = "^1.0.8"
         | 
| 137 | 
            -
            mujoco = "^2.2"
         | 
| 138 |  | 
| 139 | 
             
            [build-system]
         | 
| 140 | 
             
            requires = ["poetry-core"]
         | 
| 141 | 
             
            build-backend = "poetry.core.masonry.api"
         | 
| 142 | 
            -
             | 
| 143 | 
            -
            [tool.poetry.extras]
         | 
| 144 | 
            -
            atari = ["ale-py", "AutoROM", "opencv-python"]
         | 
| 145 | 
            -
            pybullet = ["pybullet"]
         | 
| 146 | 
            -
            procgen = ["procgen"]
         | 
| 147 | 
            -
            plot = ["pandas", "seaborn"]
         | 
| 148 | 
            -
            pytest = ["pytest"]
         | 
| 149 | 
            -
            mujoco = ["mujoco", "imageio"]
         | 
| 150 | 
            -
            mujoco_py = ["free-mujoco-py"]
         | 
| 151 | 
            -
            jax = ["jax", "jaxlib", "flax"]
         | 
| 152 | 
            -
            docs = ["mkdocs-material", "markdown-include"]
         | 
| 153 | 
            -
            envpool = ["envpool"]
         | 
| 154 | 
            -
            optuna = ["optuna", "optuna-dashboard", "rich"]
         | 
| 155 | 
            -
            pettingzoo = ["PettingZoo", "SuperSuit", "multi-agent-ale-py"]
         | 
| 156 | 
            -
            cloud = ["boto3", "awscli"]
         | 
| 157 | 
            -
            dm_control = ["shimmy", "dm-control", "mujoco"]
         | 
| 158 | 
            -
             | 
| 159 | 
            -
            # dependencies for algorithm variant (useful when you want to run a specific algorithm)
         | 
| 160 | 
            -
            dqn = []
         | 
| 161 | 
            -
            dqn_atari = ["ale-py", "AutoROM", "opencv-python"]
         | 
| 162 | 
            -
            dqn_jax = ["jax", "jaxlib", "flax"]
         | 
| 163 | 
            -
            dqn_atari_jax = [
         | 
| 164 | 
            -
                "ale-py", "AutoROM", "opencv-python", # atari
         | 
| 165 | 
            -
                "jax", "jaxlib", "flax" # jax
         | 
| 166 | 
            -
            ]
         | 
| 167 | 
            -
            c51 = []
         | 
| 168 | 
            -
            c51_atari = ["ale-py", "AutoROM", "opencv-python"]
         | 
| 169 | 
            -
            c51_jax = ["jax", "jaxlib", "flax"]
         | 
| 170 | 
            -
            c51_atari_jax = [
         | 
| 171 | 
            -
                "ale-py", "AutoROM", "opencv-python", # atari
         | 
| 172 | 
            -
                "jax", "jaxlib", "flax" # jax
         | 
| 173 | 
            -
            ]
         | 
| 174 | 
            -
            ppo_atari_envpool_xla_jax_scan = [
         | 
| 175 | 
            -
                "ale-py", "AutoROM", "opencv-python", # atari
         | 
| 176 | 
            -
                "jax", "jaxlib", "flax", # jax
         | 
| 177 | 
            -
                "envpool", # envpool
         | 
| 178 | 
            -
            ]
         | 
|  | |
| 1 | 
             
            [tool.poetry]
         | 
| 2 | 
            +
            name = "cleanba"
         | 
| 3 | 
            +
            version = "0.1.0"
         | 
| 4 | 
            +
            description = ""
         | 
| 5 | 
             
            authors = ["Costa Huang <[email protected]>"]
         | 
| 6 | 
            +
            readme = "README.md"
         | 
| 7 | 
             
            packages = [
         | 
| 8 | 
            +
                { include = "cleanba" },
         | 
| 9 | 
             
                { include = "cleanrl_utils" },
         | 
| 10 | 
             
            ]
         | 
|  | |
|  | |
|  | |
| 11 |  | 
| 12 | 
             
            [tool.poetry.dependencies]
         | 
| 13 | 
            +
            python = "^3.8"
         | 
| 14 | 
            +
            tensorboard = "^2.12.0"
         | 
| 15 | 
            +
            envpool = "^0.8.1"
         | 
| 16 | 
            +
            jax = "0.3.25"
         | 
| 17 | 
            +
            flax = "0.6.0"
         | 
| 18 | 
            +
            optax = "0.1.3"
         | 
| 19 | 
            +
            huggingface-hub = "^0.12.0"
         | 
| 20 | 
            +
            jaxlib = "0.3.25"
         | 
| 21 | 
            +
            wandb = "^0.13.10"
         | 
| 22 | 
            +
            tensorboardx = "^2.5.1"
         | 
| 23 | 
            +
            chex = "0.1.5"
         | 
| 24 | 
             
            gym = "0.23.1"
         | 
| 25 | 
            +
            opencv-python = "^4.7.0.68"
         | 
|  | |
|  | |
| 26 | 
             
            moviepy = "^1.0.3"
         | 
|  | |
|  | |
| 27 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 28 |  | 
| 29 | 
             
            [tool.poetry.group.dev.dependencies]
         | 
| 30 | 
            +
            pre-commit = "^3.0.4"
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 31 |  | 
| 32 | 
             
            [build-system]
         | 
| 33 | 
             
            requires = ["poetry-core"]
         | 
| 34 | 
             
            build-backend = "poetry.core.masonry.api"
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        replay.mp4
    CHANGED
    
    | Binary files a/replay.mp4 and b/replay.mp4 differ | 
|  | 
    	
        videos/Tutankham-v5__cleanba_ppo_envpool_impala_atari_wrapper__1__8e9eb61e-29b5-4771-b9d3-bd644ea96b7a-eval/0.mp4
    ADDED
    
    | Binary file (215 kB). View file | 
|  | 
    	
        videos/Tutankham-v5__cleanba_ppo_envpool_impala_atari_wrapper__1__f396779d-340e-4192-8178-86eb90315d3f-eval/0.mp4
    DELETED
    
    | Binary file (146 kB) | 
|  | 
