PPO playing AntBulletEnv-v0, from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c (commit 1d51343).
# Allow PyTorch to fall back to CPU for ops not implemented on Apple's MPS
# backend (https://pytorch.org/docs/stable/notes/mps.html). The variable must
# be exported before any module that imports torch, so set it first.
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from multiprocessing import Pool

from rl_algo_impls.runner.running_utils import base_parser
from rl_algo_impls.runner.train import TrainArgs
from rl_algo_impls.runner.train import train as runner_train
def train() -> None:
    """Parse CLI arguments and launch one or more training jobs.

    Extends the base argument parser with WandB reporting options, a
    ``--pool-size`` to run several jobs concurrently, and an optional headless
    virtual display. A single job is run in-process; multiple jobs are fanned
    out to a fresh-process-per-job multiprocessing pool.
    """
    parser = base_parser()
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls",
        help="WandB project name to upload training data to. If none, won't upload.",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        default=None,
        help="WandB team of project. None uses default entity",
    )
    parser.add_argument(
        "--wandb-tags", type=str, nargs="*", help="WandB tags to add to run"
    )
    parser.add_argument(
        "--pool-size", type=int, default=1, help="Simultaneous training jobs to run"
    )
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    args = parser.parse_args()
    print(args)

    if args.virtual_display:
        # Imported lazily: pyvirtualdisplay is only needed for headless runs.
        from pyvirtualdisplay.display import Display

        virtual_display = Display(visible=False, size=(1400, 900))
        virtual_display.start()
    # virtual_display isn't a TrainArg so must be removed
    delattr(args, "virtual_display")

    # Never spawn more workers than there are jobs.
    # NOTE(review): assumes args.seed is a sequence (base_parser presumably
    # declares it with nargs) — confirm against base_parser.
    pool_size = min(args.pool_size, len(args.seed))
    # pool_size isn't a TrainArg so must be removed from args
    delattr(args, "pool_size")

    train_args = TrainArgs.expand_from_dict(vars(args))
    if len(train_args) == 1:
        runner_train(train_args[0])
    else:
        # Force a new process for each job to get around wandb not allowing more
        # than one wandb.tensorboard.patch call per process.
        with Pool(pool_size, maxtasksperchild=1) as p:
            p.map(runner_train, train_args)
# Script entry point: run the CLI launcher when executed directly.
if __name__ == "__main__":
    train()