"""Train an agent on a Procgen environment with ``sppo2.learn`` under MPI.

Command-line arguments are parsed at import time into the module-level
``args`` namespace; ``main()`` reads them and forwards them to ``train_fn``.
"""
import argparse
import sys
import warnings

# Respect explicit -W options; otherwise silence every warning.
# ``warnings`` is imported unconditionally above so that the
# ``catch_warnings`` block below also works when -W options are given.
if not sys.warnoptions:
    warnings.simplefilter("ignore")

# Import and configure TensorFlow with warnings suppressed.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings("ignore", category=RuntimeWarning)
    warnings.simplefilter("ignore")
    import tensorflow as tf
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
    # NOTE(review): this looks like it resets the TF logger to INFO right
    # after the ERROR setting above -- confirm which level is intended.
    tf.get_logger().setLevel('INFO')
    tf.autograph.set_verbosity(0)

warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)

from baselines.ppo2 import sppo2
from baselines.common.models import build_skill_impala_cnn
from baselines.common.mpi_util import setup_mpi_gpus
from procgen import ProcgenEnv
from baselines.common.vec_env import (
    VecExtractDictObs,
    VecMonitor,
    VecNormalize
)
from baselines import logger
from mpi4py import MPI
# NOTE(review): no name from this wildcard import is referenced explicitly in
# this file; kept in case it is needed for side effects -- confirm.
from train_procgen.utils import *
from train_procgen.constants import EASY_GAME_RANGES

parser = argparse.ArgumentParser(description='Process procgen training arguments.')
parser.add_argument('--env_name', type=str, default='coinrun')
parser.add_argument('--num_envs', type=int, default=64)
parser.add_argument('--distribution_mode', type=str, default='easy',
                    choices=["easy", "hard", "exploration", "memory", "extreme"])
parser.add_argument('--num_levels', type=int, default=0)
parser.add_argument('--start_level', type=int, default=0)
parser.add_argument('--test_worker_interval', type=int, default=0)
parser.add_argument('--timesteps_per_proc', type=int, default=25_000_000)
parser.add_argument('--rand_seed', type=int, default=2022)
parser.add_argument('--num_embeddings', type=int, default=8)
parser.add_argument('--beta', type=float, default=0.0)
parser.add_argument('--alpha1', type=float, default=20)
parser.add_argument('--alpha2', type=float, default=20)

args = parser.parse_args()


def train_fn(env_name,
             num_envs,
             distribution_mode,
             num_levels,
             start_level,
             timesteps_per_proc,
             num_embeddings,
             beta,
             alpha1,
             alpha2,
             is_test_worker=False,
             log_dir='./train_procgen/checkpoints',
             comm=None,
             rand_seed=None):
    """Build the Procgen environment and run ``sppo2.learn`` on it.

    Args:
        env_name: Procgen game name; also used as the key into
            ``EASY_GAME_RANGES`` to obtain ``highest_score``.
        num_envs: Number of parallel environments in the vectorized env.
        distribution_mode: Procgen distribution mode passed to ``ProcgenEnv``.
        num_levels: Number of levels passed to ``ProcgenEnv``; forced to 0
            for test workers.
        start_level: First level seed passed to ``ProcgenEnv``.
        timesteps_per_proc: Passed to ``sppo2.learn`` as ``total_timesteps``.
        num_embeddings: Passed to ``build_skill_impala_cnn`` and
            ``sppo2.learn``.
        beta: Passed to ``build_skill_impala_cnn``.
        alpha1: Passed through to ``sppo2.learn``.
        alpha2: Passed through to ``sppo2.learn``.
        is_test_worker: When true, this worker gets ``mpi_rank_weight`` 0,
            ``num_levels`` 0, and a separate logging communicator.
        log_dir: Directory handed to ``logger.configure``; ``None`` skips
            logger configuration.
        comm: MPI communicator forwarded to ``sppo2.learn``. When ``None``,
            ``MPI.COMM_WORLD`` is used for the logging split only.
        rand_seed: Seed passed to the environment, the network builder and
            ``sppo2.learn``.
    """
    learning_rate = 5e-4
    ent_coef = .01
    gamma = .999
    lam = .95
    nsteps = 256
    nminibatches = 8
    ppo_epochs = 3
    clip_range = .2
    use_vf_clipping = True

    mpi_rank_weight = 0 if is_test_worker else 1
    num_levels = 0 if is_test_worker else num_levels

    if log_dir is not None:
        # Fall back to the world communicator so the default ``comm=None``
        # no longer fails on ``None.Split``.
        split_comm = comm if comm is not None else MPI.COMM_WORLD
        # Train and test workers get separate logging communicators; only
        # rank 0 of each writes output.
        log_comm = split_comm.Split(1 if is_test_worker else 0, 0)
        format_strs = ['csv', 'stdout'] if log_comm.Get_rank() == 0 else []
        logger.configure(comm=log_comm, dir=log_dir, format_strs=format_strs)

    logger.info("creating environment")
    venv = ProcgenEnv(num_envs=num_envs,
                      env_name=env_name,
                      num_levels=num_levels,
                      start_level=start_level,
                      distribution_mode=distribution_mode,
                      rand_seed=rand_seed)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(
        venv=venv,
        filename=None,
        keep_buf=100,
    )
    venv = VecNormalize(venv=venv, ob=False)

    logger.info("creating tf session")
    setup_mpi_gpus()
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101

    def conv_fn(x):
        """Build the network on input ``x`` with this run's settings."""
        return build_skill_impala_cnn(x,
                                      depths=[16, 32, 32],
                                      emb_dim=256,
                                      num_embeddings=num_embeddings,
                                      beta=beta,
                                      seed=rand_seed)

    # Enter the session as a context manager so it is exited after training
    # (or on error) instead of being left entered for the process lifetime.
    with tf.compat.v1.Session(config=config):
        logger.info("training")
        sppo2.learn(
            env=venv,
            network=conv_fn,
            total_timesteps=timesteps_per_proc,
            save_interval=0,
            nsteps=nsteps,
            nminibatches=nminibatches,
            lam=lam,
            gamma=gamma,
            noptepochs=ppo_epochs,
            log_interval=1,
            ent_coef=ent_coef,
            mpi_rank_weight=mpi_rank_weight,
            clip_vf=use_vf_clipping,
            comm=comm,
            lr=learning_rate,
            cliprange=clip_range,
            update_fn=None,
            init_fn=None,
            vf_coef=0.5,
            max_grad_norm=0.5,
            seed=rand_seed,
            num_embeddings=num_embeddings,
            # NOTE(review): the EASY ranges are used regardless of
            # ``distribution_mode`` -- confirm this is intended.
            highest_score=EASY_GAME_RANGES[env_name][1],
            alpha1=alpha1,
            alpha2=alpha2,
        )


def main():
    """Select this rank's worker role and start training from ``args``."""
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # With a positive interval N, the last rank of every group of N ranks is
    # a test worker; otherwise every rank trains.
    is_test_worker = False
    test_worker_interval = args.test_worker_interval
    if test_worker_interval > 0:
        is_test_worker = rank % test_worker_interval == (test_worker_interval - 1)

    saved_dir = (
        f"./train_procgen/checkpoints/sppo-{args.env_name}"
        f"_{args.distribution_mode}_{args.num_levels}"
        f"_{args.start_level}_{args.rand_seed}"
    )

    train_fn(
        args.env_name,
        args.num_envs,
        args.distribution_mode,
        args.num_levels,
        args.start_level,
        args.timesteps_per_proc,
        is_test_worker=is_test_worker,
        comm=comm,
        log_dir=saved_dir,
        rand_seed=args.rand_seed,
        num_embeddings=args.num_embeddings,
        beta=args.beta,
        alpha1=args.alpha1,
        alpha2=args.alpha2,
    )


if __name__ == '__main__':
    main()