Thomas Simonini committed
Commit 2900a24 · Parent(s): 1c62517

Update README.md

README.md CHANGED
@@ -4,4 +4,129 @@ tags:
- reinforcement-learning
- stable-baselines3
---
# PPO Agent playing QbertNoFrameskip-v4

This is a trained model of a **PPO agent playing QbertNoFrameskip-v4 using the [stable-baselines3 library](https://stable-baselines3.readthedocs.io/en/master/index.html)**.

<video src="https://huggingface.co/ThomasSimonini/ppo-QbertNoFrameskip-v4/resolve/main/output.mp4" controls autoplay loop></video>

## Evaluation Results
Mean_reward: `15685.00 +/- 115.217`

# Usage (with Stable-baselines3)
- You need to use `gym==0.19`, since it **includes the Atari ROMs**.
- The action space is 6, since we only use the **actions that are possible in this game** (see the sanity check below).
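Before loading the model, you can verify that the setup matches these two requirements. This is a minimal sketch (it assumes `gym==0.19` is already installed, as stated above) that prints the gym version and the size of the wrapped Atari action space:

```python
import gym
from stable_baselines3.common.env_util import make_atari_env

# Minimal sanity check (assumes gym==0.19 is installed, as required above)
env = make_atari_env("QbertNoFrameskip-v4", n_envs=1)
print(gym.__version__)       # expected: 0.19.x, which ships with the Atari ROMs
print(env.action_space.n)    # expected: 6 actions for Qbert
```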

Watch your agent interact:

```python
# Import the libraries
import os

import gym

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack

from huggingface_sb3 import load_from_hub, push_to_hub

# Load the model checkpoint from the Hugging Face Hub
checkpoint = load_from_hub("ThomasSimonini/ppo-QbertNoFrameskip-v4", "ppo-QbertNoFrameskip-v4.zip")

# This agent was trained with Python 3.8, while Colab uses 3.7:
# override the schedules to avoid pickle errors when loading.
custom_objects = {
    "learning_rate": 0.0,
    "lr_schedule": lambda _: 0.0,
    "clip_range": lambda _: 0.0,
}

model = PPO.load(checkpoint, custom_objects=custom_objects)

# Recreate the environment: one Atari env with 4 stacked frames
env = make_atari_env("QbertNoFrameskip-v4", n_envs=1)
env = VecFrameStack(env, n_stack=4)

obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
```
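To reproduce a figure like the mean reward reported above, you can run SB3's built-in evaluation helper on the loaded model. This is a minimal sketch reusing `model` and `env` from the snippet above (skip the infinite render loop); the number of evaluation episodes and the `deterministic` flag are assumptions, since the card does not state the exact protocol behind the reported number:

```python
from stable_baselines3.common.evaluation import evaluate_policy

# Evaluation protocol here (10 episodes, deterministic actions) is an assumption,
# not necessarily the one used for the reported Mean_reward.
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.3f}")
```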

## Training Code
```python
import wandb
import gym

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecVideoRecorder
from stable_baselines3.common.callbacks import CheckpointCallback

from wandb.integration.sb3 import WandbCallback

from huggingface_sb3 import load_from_hub, push_to_hub

config = {
    "env_name": "QbertNoFrameskip-v4",
    "num_envs": 8,
    "total_timesteps": int(10e6),
    "seed": 1194709219,
}

run = wandb.init(
    project="HFxSB3",
    config=config,
    sync_tensorboard=True,  # Auto-upload sb3's tensorboard metrics
    monitor_gym=True,       # Auto-upload the videos of the agent playing the game
    save_code=True,         # Save the code to W&B
)

# There already exists an environment generator
# that will make and wrap Atari environments correctly.
# We also use multi-worker training here (n_envs=8 => 8 environments).
env = make_atari_env(config["env_name"], n_envs=config["num_envs"], seed=config["seed"])  # QbertNoFrameskip-v4

print("ENV ACTION SPACE: ", env.action_space.n)

# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)
# Video recorder
env = VecVideoRecorder(env, "videos", record_video_trigger=lambda x: x % 100000 == 0, video_length=2000)

model = PPO(
    policy="CnnPolicy",
    env=env,
    batch_size=256,
    clip_range=0.1,
    ent_coef=0.01,
    gae_lambda=0.9,
    gamma=0.99,
    learning_rate=2.5e-4,
    max_grad_norm=0.5,
    n_epochs=4,
    n_steps=128,
    vf_coef=0.5,
    tensorboard_log="runs",
    verbose=1,
)

model.learn(
    total_timesteps=config["total_timesteps"],
    callback=[
        WandbCallback(
            gradient_save_freq=1000,
            model_save_path=f"models/{run.id}",
        ),
        CheckpointCallback(
            save_freq=10000,
            save_path="./qbert",
            name_prefix=config["env_name"],
        ),
    ],
)

model.save("ppo-QbertNoFrameskip-v4.zip")
push_to_hub(
    repo_id="ThomasSimonini/ppo-QbertNoFrameskip-v4",
    filename="ppo-QbertNoFrameskip-v4.zip",
    commit_message="Added Qbert trained agent",
)
```
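
Once the push succeeds, the uploaded checkpoint can be pulled back down and reloaded to confirm the artifact works end to end. This is a minimal sketch reusing the same `load_from_hub` call as in the usage section (add the `custom_objects` workaround from above if the loading Python version differs from the training one):

```python
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO

# Re-download the checkpoint that was just pushed and make sure it deserializes.
checkpoint = load_from_hub("ThomasSimonini/ppo-QbertNoFrameskip-v4", "ppo-QbertNoFrameskip-v4.zip")
model = PPO.load(checkpoint)
print(model.policy)
```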
