Commit 29624a8
Parent(s): 4c14fe1

pushing model

Files changed:
- DQN_baseline_gdrl_4.cleanrl_model +0 -0
- README.md +79 -0
- dqn.py +255 -0
- events.out.tfevents.1677090125.redi.243042.0 +3 -0
- poetry.lock +0 -0
- pyproject.toml +178 -0
- replay.mp4 +0 -0
- videos/CartPole-v1__DQN_baseline_gdrl_4__1__1677090121-eval/rl-video-episode-0.mp4 +0 -0
- videos/CartPole-v1__DQN_baseline_gdrl_4__1__1677090121-eval/rl-video-episode-1.mp4 +0 -0
- videos/CartPole-v1__DQN_baseline_gdrl_4__1__1677090121-eval/rl-video-episode-8.mp4 +0 -0
    	
DQN_baseline_gdrl_4.cleanrl_model (ADDED)
Binary file (276 kB).
    	
README.md (ADDED)
@@ -0,0 +1,79 @@
---
tags:
- CartPole-v1
- deep-reinforcement-learning
- reinforcement-learning
- custom-implementation
library_name: cleanrl
model-index:
- name: DQN
  results:
  - task:
      type: reinforcement-learning
      name: reinforcement-learning
    dataset:
      name: CartPole-v1
      type: CartPole-v1
    metrics:
    - type: mean_reward
      value: 500.00 +/- 0.00
      name: mean_reward
      verified: false
---

# (CleanRL) **DQN** Agent Playing **CartPole-v1**

This is a trained model of a DQN agent playing CartPole-v1.
The model was trained by using [CleanRL](https://github.com/vwxyzjn/cleanrl) and the most up-to-date training code can be
found [here](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/DQN_baseline_gdrl_4.py).

## Get Started

To use this model, please install the `cleanrl` package with the following command:

```
pip install "cleanrl[DQN_baseline_gdrl_4]"
python -m cleanrl_utils.enjoy --exp-name DQN_baseline_gdrl_4 --env-id CartPole-v1
```

Please refer to the [documentation](https://docs.cleanrl.dev/get-started/zoo/) for more detail.


## Command to reproduce the training

```bash
curl -OL https://huggingface.co/pfunk/CartPole-v1-DQN_baseline_gdrl_4-seed1/raw/main/dqn.py
curl -OL https://huggingface.co/pfunk/CartPole-v1-DQN_baseline_gdrl_4-seed1/raw/main/pyproject.toml
curl -OL https://huggingface.co/pfunk/CartPole-v1-DQN_baseline_gdrl_4-seed1/raw/main/poetry.lock
poetry install --all-extras
python dqn.py --exp-name DQN_baseline_gdrl_4 --track --wandb-entity pfunk --wandb-project-name dqpn --save-model true --upload-model true --hf-entity pfunk --total-timesteps 500000
```

# Hyperparameters

```python
{'batch_size': 128,
 'buffer_size': 100000,
 'capture_video': False,
 'cuda': True,
 'end_e': 0.1,
 'env_id': 'CartPole-v1',
 'exp_name': 'DQN_baseline_gdrl_4',
 'exploration_fraction': 0.2,
 'gamma': 1.0,
 'hf_entity': 'pfunk',
 'learning_rate': 0.0001,
 'learning_starts': 1000,
 'save_model': True,
 'seed': 1,
 'start_e': 1,
 'target_network_frequency': 50,
 'tau': 1.0,
 'torch_deterministic': True,
 'total_timesteps': 500000,
 'track': True,
 'train_frequency': 1,
 'upload_model': True,
 'wandb_entity': 'pfunk',
 'wandb_project_name': 'dqpn'}
```
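Editor's note: besides `cleanrl_utils.enjoy`, the checkpoint can also be loaded directly, since `dqn.py` below saves it with `torch.save(q_network.state_dict(), ...)`. The following is only an illustrative sketch (not part of the committed files); it assumes the `.cleanrl_model` file has been downloaded locally and mirrors the `QNetwork` architecture from `dqn.py`.

```python
# Illustrative sketch (not part of this commit): restore the saved state_dict into the
# same MLP architecture used by dqn.py (obs_dim -> 512 -> 128 -> n_actions) and act greedily.
import gym
import numpy as np
import torch
import torch.nn as nn


class QNetwork(nn.Module):
    def __init__(self, obs_dim, n_actions):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(obs_dim, 512), nn.ReLU(),
            nn.Linear(512, 128), nn.ReLU(),
            nn.Linear(128, n_actions),
        )

    def forward(self, x):
        return self.network(x)


env = gym.make("CartPole-v1")
model = QNetwork(int(np.prod(env.observation_space.shape)), env.action_space.n)
# assumes the checkpoint was downloaded next to this script
model.load_state_dict(torch.load("DQN_baseline_gdrl_4.cleanrl_model", map_location="cpu"))
model.eval()

obs = env.reset()
done = False
while not done:
    with torch.no_grad():
        q_values = model(torch.Tensor(obs).unsqueeze(0))
    action = torch.argmax(q_values, dim=1).item()
    obs, reward, done, info = env.step(action)
```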
    	
dqn.py (ADDED)
@@ -0,0 +1,255 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy
import argparse
import os
import random
import time
from distutils.util import strtobool

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")
    parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to save model into the `runs/{run_name}` folder")
    parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to upload the saved model to huggingface")
    parser.add_argument("--hf-entity", type=str, default="",
        help="the user or org name of the model repository from the Hugging Face Hub")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="CartPole-v1",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=500000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=0.0001,
        help="the learning rate of the optimizer")
    parser.add_argument("--buffer-size", type=int, default=100000,
        help="the replay memory buffer size")
    parser.add_argument("--gamma", type=float, default=1.,
        help="the discount factor gamma")
    parser.add_argument("--tau", type=float, default=1.,
        help="the target network update rate")
    parser.add_argument("--target-network-frequency", type=int, default=50,
        help="the timesteps it takes to update the target network")
    parser.add_argument("--batch-size", type=int, default=128,
        help="the batch size of samples from the replay memory")
    parser.add_argument("--start-e", type=float, default=1,
        help="the starting epsilon for exploration")
    parser.add_argument("--end-e", type=float, default=0.1,
        help="the ending epsilon for exploration")
    parser.add_argument("--exploration-fraction", type=float, default=0.2,
        help="the fraction of `total-timesteps` it takes to go from start-e to end-e")
    parser.add_argument("--learning-starts", type=int, default=1000,
        help="timestep to start learning")
    parser.add_argument("--train-frequency", type=int, default=1,
        help="the frequency of training")
    args = parser.parse_args()
    # fmt: on
    return args


def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    def __init__(self, env):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(np.array(env.single_observation_space.shape).prod(), 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, env.single_action_space.n),
        )

    def forward(self, x):
        return self.network(x)


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
    envs.seed(args.seed)

    q_network = QNetwork(envs).to(device)
    optimizer = optim.RMSprop(q_network.parameters(), lr=args.learning_rate)
    target_network = QNetwork(envs).to(device)
    target_network.load_state_dict(q_network.state_dict())

    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        handle_timeout_termination=True,
    )
    start_time = time.time()
    episode_returns = []

    # TRY NOT TO MODIFY: start the game
    obs = envs.reset()
    for global_step in range(args.total_timesteps):
        # ALGO LOGIC: put action logic here
        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            q_values = q_network(torch.Tensor(obs).to(device))
            actions = torch.argmax(q_values, dim=1).cpu().numpy()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, dones, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        for info in infos:
            if "episode" in info.keys():
                episode_returns.append(info['episode']['r'])
                episode_returns = episode_returns[-100:]
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_return_mean_10", np.mean(episode_returns[-10:]), global_step)
                writer.add_scalar("charts/episodic_return_std_10", np.std(episode_returns[-10:]), global_step)
                writer.add_scalar("charts/episodic_return_mean_100", np.mean(episode_returns), global_step)
                writer.add_scalar("charts/episodic_return_std_100", np.std(episode_returns), global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                writer.add_scalar("charts/epsilon", epsilon, global_step)
                break

        # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]["terminal_observation"]
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            if global_step % args.train_frequency == 0:
                data = rb.sample(args.batch_size)
                with torch.no_grad():
                    target_max, _ = target_network(data.next_observations).max(dim=1)
                    td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
                old_val = q_network(data.observations).gather(1, data.actions).squeeze()
                loss = F.mse_loss(td_target, old_val)

                if global_step % 100 == 0:
                    writer.add_scalar("losses/td_loss", loss, global_step)
                    writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
                    print("SPS:", int(global_step / (time.time() - start_time)))
                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

                # optimize the model
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # update target network
            if global_step % args.target_network_frequency == 0:
                for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
                    target_network_param.data.copy_(
                        args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data
                    )

    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        torch.save(q_network.state_dict(), model_path)
        print(f"model saved to {model_path}")
        from cleanrl_utils.evals.dqn_eval import evaluate

        episodic_returns = evaluate(
            model_path,
            make_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=QNetwork,
            device=device,
            epsilon=0.05,
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub

            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(args, episodic_returns, repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()
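Editor's note: with the defaults above (`start_e=1`, `end_e=0.1`, `exploration_fraction=0.2`, `total_timesteps=500000`), `linear_schedule` decays epsilon linearly from 1.0 to 0.1 over the first 100,000 steps and then holds it at 0.1. A quick standalone check, copying the function as defined above:

```python
# quick check of the epsilon schedule used by dqn.py (values follow its default arguments)
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)

duration = 0.2 * 500_000  # exploration_fraction * total_timesteps = 100_000 steps
for t in [0, 50_000, 100_000, 250_000]:
    print(t, round(linear_schedule(1.0, 0.1, duration, t), 3))
# 0 1.0
# 50000 0.55
# 100000 0.1
# 250000 0.1
```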
    	
events.out.tfevents.1677090125.redi.243042.0 (ADDED)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e08d353a4ea9acc64e7be912e653cde500c91373059356f9ee6ffa4300cf7f79
size 1870842
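Editor's note: this TensorBoard event file is stored as a Git LFS pointer (about 1.9 MB of data). Once the repository is fetched with LFS, the logged scalars written by `dqn.py` (e.g. `charts/episodic_return`, `losses/td_loss`) can be browsed locally; the commands below are only a sketch and assume `git-lfs` and `tensorboard` are installed.

```bash
# illustrative only: fetch the model repo (with LFS objects) and point TensorBoard at it
git lfs install
git clone https://huggingface.co/pfunk/CartPole-v1-DQN_baseline_gdrl_4-seed1
tensorboard --logdir CartPole-v1-DQN_baseline_gdrl_4-seed1
```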
    	
poetry.lock (ADDED)
The diff for this file is too large to render.
    	
pyproject.toml (ADDED)
@@ -0,0 +1,178 @@
[tool.poetry]
name = "cleanrl"
version = "1.1.0"
description = "High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features"
authors = ["Costa Huang <[email protected]>"]
packages = [
    { include = "cleanrl" },
    { include = "cleanrl_utils" },
]
keywords = ["reinforcement", "machine", "learning", "research"]
license = "MIT"
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.7.1,<3.10"
tensorboard = "^2.10.0"
wandb = "^0.13.6"
gym = "0.23.1"
torch = ">=1.12.1"
stable-baselines3 = "1.2.0"
gymnasium = "^0.26.3"
moviepy = "^1.0.3"
pygame = "2.1.0"
huggingface-hub = "^0.11.1"

ale-py = {version = "0.7.4", optional = true}
AutoROM = {extras = ["accept-rom-license"], version = "^0.4.2"}
opencv-python = {version = "^4.6.0.66", optional = true}
pybullet = {version = "3.1.8", optional = true}
procgen = {version = "^0.10.7", optional = true}
pytest = {version = "^7.1.3", optional = true}
mujoco = {version = "^2.2", optional = true}
imageio = {version = "^2.14.1", optional = true}
free-mujoco-py = {version = "^2.1.6", optional = true}
mkdocs-material = {version = "^8.4.3", optional = true}
markdown-include = {version = "^0.7.0", optional = true}
jax = {version = "^0.3.17", optional = true}
jaxlib = {version = "^0.3.15", optional = true}
flax = {version = "^0.6.0", optional = true}
optuna = {version = "^3.0.1", optional = true}
optuna-dashboard = {version = "^0.7.2", optional = true}
rich = {version = "<12.0", optional = true}
envpool = {version = "^0.6.4", optional = true}
PettingZoo = {version = "1.18.1", optional = true}
SuperSuit = {version = "3.4.0", optional = true}
multi-agent-ale-py = {version = "0.1.11", optional = true}
boto3 = {version = "^1.24.70", optional = true}
awscli = {version = "^1.25.71", optional = true}
shimmy = {version = "^0.1.0", optional = true}
dm-control = {version = "^1.0.8", optional = true}

[tool.poetry.group.dev.dependencies]
pre-commit = "^2.20.0"

[tool.poetry.group.atari]
optional = true
[tool.poetry.group.atari.dependencies]
ale-py = "0.7.4"
AutoROM = {extras = ["accept-rom-license"], version = "^0.4.2"}
opencv-python = "^4.6.0.66"

[tool.poetry.group.pybullet]
optional = true
[tool.poetry.group.pybullet.dependencies]
pybullet = "3.1.8"

[tool.poetry.group.procgen]
optional = true
[tool.poetry.group.procgen.dependencies]
procgen = "^0.10.7"

[tool.poetry.group.pytest]
optional = true
[tool.poetry.group.pytest.dependencies]
pytest = "^7.1.3"

[tool.poetry.group.mujoco]
optional = true
[tool.poetry.group.mujoco.dependencies]
mujoco = "^2.2"
imageio = "^2.14.1"

[tool.poetry.group.mujoco_py]
optional = true
[tool.poetry.group.mujoco_py.dependencies]
free-mujoco-py = "^2.1.6"

[tool.poetry.group.docs]
optional = true
[tool.poetry.group.docs.dependencies]
mkdocs-material = "^8.4.3"
markdown-include = "^0.7.0"

[tool.poetry.group.jax]
optional = true
[tool.poetry.group.jax.dependencies]
jax = "^0.3.17"
jaxlib = "^0.3.15"
flax = "^0.6.0"

[tool.poetry.group.optuna]
optional = true
[tool.poetry.group.optuna.dependencies]
optuna = "^3.0.1"
optuna-dashboard = "^0.7.2"
rich = "<12.0"

[tool.poetry.group.envpool]
optional = true
[tool.poetry.group.envpool.dependencies]
envpool = "^0.6.4"

[tool.poetry.group.pettingzoo]
optional = true
[tool.poetry.group.pettingzoo.dependencies]
PettingZoo = "1.18.1"
SuperSuit = "3.4.0"
multi-agent-ale-py = "0.1.11"

[tool.poetry.group.cloud]
optional = true
[tool.poetry.group.cloud.dependencies]
boto3 = "^1.24.70"
awscli = "^1.25.71"

[tool.poetry.group.isaacgym]
optional = true
[tool.poetry.group.isaacgym.dependencies]
isaacgymenvs = {git = "https://github.com/vwxyzjn/IsaacGymEnvs.git", rev = "poetry"}
isaacgym = {path = "cleanrl/ppo_continuous_action_isaacgym/isaacgym", develop = true}

[tool.poetry.group.dm_control]
optional = true
[tool.poetry.group.dm_control.dependencies]
shimmy = "^0.1.0"
dm-control = "^1.0.8"
mujoco = "^2.2"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.extras]
atari = ["ale-py", "AutoROM", "opencv-python"]
pybullet = ["pybullet"]
procgen = ["procgen"]
plot = ["pandas", "seaborn"]
pytest = ["pytest"]
mujoco = ["mujoco", "imageio"]
mujoco_py = ["free-mujoco-py"]
jax = ["jax", "jaxlib", "flax"]
docs = ["mkdocs-material", "markdown-include"]
envpool = ["envpool"]
optuna = ["optuna", "optuna-dashboard", "rich"]
pettingzoo = ["PettingZoo", "SuperSuit", "multi-agent-ale-py"]
cloud = ["boto3", "awscli"]
dm_control = ["shimmy", "dm-control", "mujoco"]

# dependencies for algorithm variant (useful when you want to run a specific algorithm)
dqn = []
dqn_atari = ["ale-py", "AutoROM", "opencv-python"]
dqn_jax = ["jax", "jaxlib", "flax"]
dqn_atari_jax = [
    "ale-py", "AutoROM", "opencv-python", # atari
    "jax", "jaxlib", "flax" # jax
]
c51 = []
c51_atari = ["ale-py", "AutoROM", "opencv-python"]
c51_jax = ["jax", "jaxlib", "flax"]
c51_atari_jax = [
    "ale-py", "AutoROM", "opencv-python", # atari
    "jax", "jaxlib", "flax" # jax
]
ppo_atari_envpool_xla_jax_scan = [
    "ale-py", "AutoROM", "opencv-python", # atari
    "jax", "jaxlib", "flax", # jax
    "envpool", # envpool
]
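Editor's note: the per-algorithm extras at the bottom mean this CartPole DQN run does not need the heavier optional groups (the `dqn` extra declares no additional dependencies beyond the base set). As a minimal alternative to `poetry install --all-extras` from the README, the following sketch should suffice, assuming Poetry is installed and this `pyproject.toml`/`poetry.lock` pair is in the working directory:

```bash
# minimal install for the CartPole DQN experiment (sketch, not part of the committed files)
poetry install -E dqn
poetry run python dqn.py --exp-name DQN_baseline_gdrl_4 --total-timesteps 500000
```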
    	
replay.mp4 (ADDED)
Binary file; no text diff to display.

videos/CartPole-v1__DQN_baseline_gdrl_4__1__1677090121-eval/rl-video-episode-0.mp4 (ADDED)
Binary file; no text diff to display.

videos/CartPole-v1__DQN_baseline_gdrl_4__1__1677090121-eval/rl-video-episode-1.mp4 (ADDED)
Binary file; no text diff to display.

videos/CartPole-v1__DQN_baseline_gdrl_4__1__1677090121-eval/rl-video-episode-8.mp4 (ADDED)
Binary file; no text diff to display.