diff --git a/Python/PPOVizDoom.py b/Python/PPOVizDoom.py index 29087d6..6a61571 100644 --- a/Python/PPOVizDoom.py +++ b/Python/PPOVizDoom.py @@ -40,62 +40,62 @@ def parse_vizdoom_cfg(argv=None, evaluation=False): return final_cfg -## Start the training, this should take around 15 minutes -register_vizdoom_components() - -# The scenario we train on today is health gathering -# other scenarios include "doom_basic", "doom_two_colors_easy", "doom_dm", "doom_dwango5", "doom_my_way_home", "doom_deadly_corridor", "doom_defend_the_center", "doom_defend_the_line" -env = "doom_health_gathering_supreme" -cfg = parse_vizdoom_cfg( - argv=[f"--env={env}", "--num_workers=8", "--num_envs_per_worker=4", "--train_for_env_steps=4000000"] -) - -status = run_rl(cfg) - - -from sample_factory.enjoy import enjoy - -cfg = parse_vizdoom_cfg( - argv=[f"--env={env}", "--num_workers=1", "--save_video", "--no_render", "--max_num_episodes=10"], evaluation=True -) -status = enjoy(cfg) - - -# from base64 import b64encode -# from IPython.display import HTML - -# mp4 = open("/content/train_dir/default_experiment/replay.mp4", "rb").read() -# data_url = "data:video/mp4;base64," + b64encode(mp4).decode() -# HTML( -# """ -# -# """ -# % data_url -# ) - -from huggingface_hub import notebook_login -notebook_login() - -# !git config --global credential.helper store - -from sample_factory.enjoy import enjoy - -hf_username = "ThomasSimonini" # insert your HuggingFace username here - -cfg = parse_vizdoom_cfg( - argv=[ - f"--env={env}", - "--num_workers=1", - "--save_video", - "--no_render", - "--max_num_episodes=10", - "--max_num_frames=100000", - "--push_to_hub", - f"--hf_repository={hf_username}/rl_course_vizdoom_health_gathering_supreme", - ], - evaluation=True, -) -status = enjoy(cfg) +if __name__ == '__main__': + ## Start the training, this should take around 15 minutes + register_vizdoom_components() + + # The scenario we train on today is health gathering + # other scenarios include "doom_basic", "doom_two_colors_easy", "doom_dm", "doom_dwango5", "doom_my_way_home", "doom_deadly_corridor", "doom_defend_the_center", "doom_defend_the_line" + env = "doom_health_gathering_supreme" + cfg = parse_vizdoom_cfg( + argv=[f"--env={env}", "--num_workers=8", "--num_envs_per_worker=4", "--train_for_env_steps=4000000"] + ) + status = run_rl(cfg) + + + from sample_factory.enjoy import enjoy + + cfg = parse_vizdoom_cfg( + argv=[f"--env={env}", "--num_workers=1", "--save_video", "--no_render", "--max_num_episodes=10"], evaluation=True + ) + status = enjoy(cfg) + + + # from base64 import b64encode + # from IPython.display import HTML + + # mp4 = open("/content/train_dir/default_experiment/replay.mp4", "rb").read() + # data_url = "data:video/mp4;base64," + b64encode(mp4).decode() + # HTML( + # """ + # + # """ + # % data_url + # ) + + from huggingface_hub import notebook_login + notebook_login() + + # !git config --global credential.helper store + + from sample_factory.enjoy import enjoy + + hf_username = "togu6669" # insert your HuggingFace username here + + cfg = parse_vizdoom_cfg( + argv=[ + f"--env={env}", + "--num_workers=1", + "--save_video", + "--no_render", + "--max_num_episodes=10", + "--max_num_frames=100000", + "--push_to_hub", + f"--hf_repository={hf_username}/rl_course_vizdoom_health_gathering_supreme", + ], + evaluation=True, + ) + status = enjoy(cfg)