Files changed (1)
  1. app.py +191 -412
app.py CHANGED
@@ -1,412 +1,191 @@
- import os
- import json
- import requests
-
- import gradio as gr
- import pandas as pd
- from huggingface_hub import HfApi, hf_hub_download, snapshot_download
- from huggingface_hub.repocard import metadata_load
- from apscheduler.schedulers.background import BackgroundScheduler
-
- from tqdm.contrib.concurrent import thread_map
-
- from utils import *
-
- DATASET_REPO_URL = "https://huggingface.co/datasets/huggingface-projects/drlc-leaderboard-data"
- DATASET_REPO_ID = "huggingface-projects/drlc-leaderboard-data"
- HF_TOKEN = os.environ.get("HF_TOKEN")
-
- block = gr.Blocks()
- api = HfApi(token=HF_TOKEN)
-
- # Containing the data
- rl_envs = [
-     {
-         "rl_env_beautiful": "LunarLander-v2 πŸš€",
-         "rl_env": "LunarLander-v2",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "CartPole-v1",
-         "rl_env": "CartPole-v1",
-         "video_link": "https://huggingface.co/sb3/ppo-CartPole-v1/resolve/main/replay.mp4",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "FrozenLake-v1-4x4-no_slippery ❄️",
-         "rl_env": "FrozenLake-v1-4x4-no_slippery",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "FrozenLake-v1-8x8-no_slippery ❄️",
-         "rl_env": "FrozenLake-v1-8x8-no_slippery",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "FrozenLake-v1-4x4 ❄️",
-         "rl_env": "FrozenLake-v1-4x4",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "FrozenLake-v1-8x8 ❄️",
-         "rl_env": "FrozenLake-v1-8x8",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "Taxi-v3 πŸš–",
-         "rl_env": "Taxi-v3",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "CarRacing-v0 🏎️",
-         "rl_env": "CarRacing-v0",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "CarRacing-v2 🏎️",
-         "rl_env": "CarRacing-v2",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "MountainCar-v0 ⛰️",
-         "rl_env": "MountainCar-v0",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "SpaceInvadersNoFrameskip-v4 πŸ‘Ύ",
-         "rl_env": "SpaceInvadersNoFrameskip-v4",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "PongNoFrameskip-v4 🎾",
-         "rl_env": "PongNoFrameskip-v4",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "BreakoutNoFrameskip-v4 🧱",
-         "rl_env": "BreakoutNoFrameskip-v4",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "QbertNoFrameskip-v4 🐦",
-         "rl_env": "QbertNoFrameskip-v4",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "BipedalWalker-v3",
-         "rl_env": "BipedalWalker-v3",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "Walker2DBulletEnv-v0",
-         "rl_env": "Walker2DBulletEnv-v0",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "AntBulletEnv-v0",
-         "rl_env": "AntBulletEnv-v0",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "HalfCheetahBulletEnv-v0",
-         "rl_env": "HalfCheetahBulletEnv-v0",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "PandaReachDense-v2",
-         "rl_env": "PandaReachDense-v2",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "PandaReachDense-v3",
-         "rl_env": "PandaReachDense-v3",
-         "video_link": "",
-         "global": None
-     },
-     {
-         "rl_env_beautiful": "Pixelcopter-PLE-v0",
-         "rl_env": "Pixelcopter-PLE-v0",
-         "video_link": "",
-         "global": None
-     }
- ]
-
- def restart():
-     print("RESTART")
-     api.restart_space(repo_id="huggingface-projects/Deep-Reinforcement-Learning-Leaderboard")
-
- def get_metadata(model_id):
-     try:
-         readme_path = hf_hub_download(model_id, filename="README.md", etag_timeout=180)
-         return metadata_load(readme_path)
-     except requests.exceptions.HTTPError:
-         # 404 README.md not found
-         return None
-
- def parse_metrics_accuracy(meta):
-     if "model-index" not in meta:
-         return None
-     result = meta["model-index"][0]["results"]
-     metrics = result[0]["metrics"]
-     accuracy = metrics[0]["value"]
-     return accuracy
-
- # We keep the worst case episode
- def parse_rewards(accuracy):
-     default_std = -1000
-     default_reward=-1000
-     if accuracy != None:
-         accuracy = str(accuracy)
-         parsed = accuracy.split('+/-')
-         if len(parsed)>1:
-             mean_reward = float(parsed[0].strip())
-             std_reward = float(parsed[1].strip())
-         elif len(parsed)==1: #only mean reward
-             mean_reward = float(parsed[0].strip())
-             std_reward = float(0)
-         else:
-             mean_reward = float(default_std)
-             std_reward = float(default_reward)
-
-     else:
-         mean_reward = float(default_std)
-         std_reward = float(default_reward)
-     return mean_reward, std_reward
-
-
- def get_model_ids(rl_env):
-     api = HfApi()
-     models = api.list_models(filter=rl_env)
-     model_ids = [x.modelId for x in models]
-     return model_ids
-
- # Parralelized version
- def update_leaderboard_dataset_parallel(rl_env, path):
-     # Get model ids associated with rl_env
-     model_ids = get_model_ids(rl_env)
-
-     def process_model(model_id):
-         meta = get_metadata(model_id)
-         #LOADED_MODEL_METADATA[model_id] = meta if meta is not None else ''
-         if meta is None:
-             return None
-         user_id = model_id.split('/')[0]
-         row = {}
-         row["User"] = user_id
-         row["Model"] = model_id
-         accuracy = parse_metrics_accuracy(meta)
-         mean_reward, std_reward = parse_rewards(accuracy)
-         mean_reward = mean_reward if not pd.isna(mean_reward) else 0
-         std_reward = std_reward if not pd.isna(std_reward) else 0
-         row["Results"] = mean_reward - std_reward
-         row["Mean Reward"] = mean_reward
-         row["Std Reward"] = std_reward
-         return row
-
-     data = list(thread_map(process_model, model_ids, desc="Processing models"))
-
-     # Filter out None results (models with no metadata)
-     data = [row for row in data if row is not None]
-
-     ranked_dataframe = rank_dataframe(pd.DataFrame.from_records(data))
-     new_history = ranked_dataframe
-     file_path = path + "/" + rl_env + ".csv"
-     new_history.to_csv(file_path, index=False)
-
-     return ranked_dataframe
-
-
- def update_leaderboard_dataset(rl_env, path):
-     # Get model ids associated with rl_env
-     model_ids = get_model_ids(rl_env)
-     data = []
-     for model_id in model_ids:
-         """
-         readme_path = hf_hub_download(model_id, filename="README.md")
-         meta = metadata_load(readme_path)
-         """
-         meta = get_metadata(model_id)
-         #LOADED_MODEL_METADATA[model_id] = meta if meta is not None else ''
-         if meta is None:
-             continue
-         user_id = model_id.split('/')[0]
-         row = {}
-         row["User"] = user_id
-         row["Model"] = model_id
-         accuracy = parse_metrics_accuracy(meta)
-         mean_reward, std_reward = parse_rewards(accuracy)
-         mean_reward = mean_reward if not pd.isna(mean_reward) else 0
-         std_reward = std_reward if not pd.isna(std_reward) else 0
-         row["Results"] = mean_reward - std_reward
-         row["Mean Reward"] = mean_reward
-         row["Std Reward"] = std_reward
-         data.append(row)
-
-     ranked_dataframe = rank_dataframe(pd.DataFrame.from_records(data))
-     new_history = ranked_dataframe
-     file_path = path + "/" + rl_env + ".csv"
-     new_history.to_csv(file_path, index=False)
-
-     return ranked_dataframe
-
- def download_leaderboard_dataset():
-     path = snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset")
-     return path
-
- def get_data(rl_env, path) -> pd.DataFrame:
-     """
-     Get data from rl_env
-     :return: data as a pandas DataFrame
-     """
-     csv_path = path + "/" + rl_env + ".csv"
-     data = pd.read_csv(csv_path)
-
-     for index, row in data.iterrows():
-         user_id = row["User"]
-         data.loc[index, "User"] = make_clickable_user(user_id)
-         model_id = row["Model"]
-         data.loc[index, "Model"] = make_clickable_model(model_id)
-
-     return data
-
- def get_data_no_html(rl_env, path) -> pd.DataFrame:
-     """
-     Get data from rl_env
-     :return: data as a pandas DataFrame
-     """
-     csv_path = path + "/" + rl_env + ".csv"
-     data = pd.read_csv(csv_path)
-
-     return data
-
- def rank_dataframe(dataframe):
-     dataframe = dataframe.sort_values(by=['Results', 'User', 'Model'], ascending=False)
-     if not 'Ranking' in dataframe.columns:
-         dataframe.insert(0, 'Ranking', [i for i in range(1,len(dataframe)+1)])
-     else:
-         dataframe['Ranking'] = [i for i in range(1,len(dataframe)+1)]
-     return dataframe
-
-
- def run_update_dataset():
-     path_ = download_leaderboard_dataset()
-     for i in range(0, len(rl_envs)):
-         rl_env = rl_envs[i]
-         update_leaderboard_dataset_parallel(rl_env["rl_env"], path_)
-
-     api.upload_folder(
-         folder_path=path_,
-         repo_id="huggingface-projects/drlc-leaderboard-data",
-         repo_type="dataset",
-         commit_message="Update dataset")
-
- def filter_data(rl_env, path, user_id):
-     data_df = get_data_no_html(rl_env, path)
-     models = []
-     models = data_df[data_df["User"] == user_id]
-
-     for index, row in models.iterrows():
-         user_id = row["User"]
-         models.loc[index, "User"] = make_clickable_user(user_id)
-         model_id = row["Model"]
-         models.loc[index, "Model"] = make_clickable_model(model_id)
-
-
-     return models
-
- run_update_dataset()
-
- with block:
-     gr.Markdown(f"""
- # πŸ† The Deep Reinforcement Learning Course Leaderboard πŸ†
-
- This is the leaderboard of trained agents during the <a href="https://huggingface.co/learn/deep-rl-course/unit0/introduction?fw=pt">Deep Reinforcement Learning Course</a>. A free course from beginner to expert.
-
- ### We only display the best 100 models
- If you want to **find yours, type your user id and click on Search my models.**
- You **can click on the model's name** to be redirected to its model card, including documentation.
-
- ### How are the results calculated?
- We use **lower bound result to sort the models: mean_reward - std_reward.**
-
- ### I can't find my model 😭
- The leaderboard is **updated every two hours** if you can't find your models, just wait for the next update.
-
- ### The Deep RL Course
- πŸ€– You want to try to train your agents? <a href="https://huggingface.co/deep-rl-course/unit0/introduction?fw=pt" target="_blank"> Check the Hugging Face free Deep Reinforcement Learning Course πŸ€— </a>.
-
- πŸ”§ There is an **environment missing?** Please open an issue.
- """)
-     path_ = download_leaderboard_dataset()
-
-     for i in range(0, len(rl_envs)):
-         rl_env = rl_envs[i]
-         with gr.TabItem(rl_env["rl_env_beautiful"]) as rl_tab:
-             with gr.Row():
-                 markdown = """
-                 # {name_leaderboard}
-
-                 """.format(name_leaderboard = rl_env["rl_env_beautiful"], video_link = rl_env["video_link"])
-                 gr.Markdown(markdown)
-
-
-             with gr.Row():
-                 gr.Markdown("""
-                 ## Search your models
-                 Simply type your user id to find your models
-                 """)
-
-             with gr.Row():
-                 user_id = gr.Textbox(label= "Your user id")
-                 search_btn = gr.Button("Search my models πŸ”Ž")
-                 reset_btn = gr.Button("Clear my search")
-                 env = gr.State(rl_env["rl_env"])
-                 grpath = gr.State(path_)
-             with gr.Row():
-                 gr_dataframe = gr.components.Dataframe(value=get_data(rl_env["rl_env"], path_), headers=["Ranking πŸ†", "User πŸ€—", "Model id πŸ€–", "Results", "Mean Reward", "Std Reward"], datatype=["number", "markdown", "markdown", "number", "number", "number"], row_count=(100, 'fixed'))
-
-             with gr.Row():
-                 #gr_search_dataframe = gr.components.Dataframe(headers=["Ranking πŸ†", "User πŸ€—", "Model id πŸ€–", "Results", "Mean Reward", "Std Reward"], datatype=["number", "markdown", "markdown", "number", "number", "number"], visible=False)
-                 search_btn.click(fn=filter_data, inputs=[env, grpath, user_id], outputs=gr_dataframe, api_name="filter_data")
-
-             with gr.Row():
-                 search_btn.click(fn=filter_data, inputs=[env, grpath, user_id], outputs=gr_dataframe, api_name="filter_data")
-                 reset_btn.click(fn=get_data, inputs=[env, grpath], outputs=gr_dataframe, api_name="get_data")
- """
- block.load(
-     download_leaderboard_dataset,
-     inputs=[],
-     outputs=[
-         grpath
-     ],
- )
- """
-
-
- scheduler = BackgroundScheduler()
- # Refresh every hour
- #scheduler.add_job(func=run_update_dataset, trigger="interval", seconds=3600)
- #scheduler.add_job(download_leaderboard_dataset, 'interval', seconds=3600)
- #scheduler.add_job(run_update_dataset, 'interval', seconds=3600)
- scheduler.add_job(restart, 'interval', seconds=10800)
- scheduler.start()
-
- block.launch()
 
+ import os
+ import json
+ import requests
+ import gradio as gr
+ import pandas as pd
+ from huggingface_hub import HfApi, hf_hub_download, snapshot_download
+ from huggingface_hub.repocard import metadata_load
+ from apscheduler.schedulers.background import BackgroundScheduler
+ from tqdm.contrib.concurrent import thread_map
+ from utils import *
+
+ DATASET_REPO_URL = "https://huggingface.co/datasets/huggingface-projects/drlc-leaderboard-data"
+ DATASET_REPO_ID = "huggingface-projects/drlc-leaderboard-data"
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+ block = gr.Blocks()
+ api = HfApi(token=HF_TOKEN)
+
+ # Define RL environments
+ rl_envs = [
+     {"rl_env_beautiful": "LunarLander-v2 πŸš€", "rl_env": "LunarLander-v2", "video_link": "", "global": None},
+     {"rl_env_beautiful": "CartPole-v1", "rl_env": "CartPole-v1", "video_link": "https://huggingface.co/sb3/ppo-CartPole-v1/resolve/main/replay.mp4", "global": None},
+     {"rl_env_beautiful": "FrozenLake-v1-4x4-no_slippery ❄️", "rl_env": "FrozenLake-v1-4x4-no_slippery", "video_link": "", "global": None},
+     {"rl_env_beautiful": "FrozenLake-v1-8x8-no_slippery ❄️", "rl_env": "FrozenLake-v1-8x8-no_slippery", "video_link": "", "global": None},
+     {"rl_env_beautiful": "FrozenLake-v1-4x4 ❄️", "rl_env": "FrozenLake-v1-4x4", "video_link": "", "global": None},
+     {"rl_env_beautiful": "FrozenLake-v1-8x8 ❄️", "rl_env": "FrozenLake-v1-8x8", "video_link": "", "global": None},
+     {"rl_env_beautiful": "Taxi-v3 πŸš–", "rl_env": "Taxi-v3", "video_link": "", "global": None},
+     {"rl_env_beautiful": "CarRacing-v0 🏎️", "rl_env": "CarRacing-v0", "video_link": "", "global": None},
+     {"rl_env_beautiful": "CarRacing-v2 🏎️", "rl_env": "CarRacing-v2", "video_link": "", "global": None},
+     {"rl_env_beautiful": "MountainCar-v0 ⛰️", "rl_env": "MountainCar-v0", "video_link": "", "global": None},
+     {"rl_env_beautiful": "SpaceInvadersNoFrameskip-v4 πŸ‘Ύ", "rl_env": "SpaceInvadersNoFrameskip-v4", "video_link": "", "global": None},
+     {"rl_env_beautiful": "PongNoFrameskip-v4 🎾", "rl_env": "PongNoFrameskip-v4", "video_link": "", "global": None},
+     {"rl_env_beautiful": "BreakoutNoFrameskip-v4 🧱", "rl_env": "BreakoutNoFrameskip-v4", "video_link": "", "global": None},
+     {"rl_env_beautiful": "QbertNoFrameskip-v4 🐦", "rl_env": "QbertNoFrameskip-v4", "video_link": "", "global": None},
+     {"rl_env_beautiful": "BipedalWalker-v3", "rl_env": "BipedalWalker-v3", "video_link": "", "global": None},
+     {"rl_env_beautiful": "Walker2DBulletEnv-v0", "rl_env": "Walker2DBulletEnv-v0", "video_link": "", "global": None},
+     {"rl_env_beautiful": "AntBulletEnv-v0", "rl_env": "AntBulletEnv-v0", "video_link": "", "global": None},
+     {"rl_env_beautiful": "HalfCheetahBulletEnv-v0", "rl_env": "HalfCheetahBulletEnv-v0", "video_link": "", "global": None},
+     {"rl_env_beautiful": "PandaReachDense-v2", "rl_env": "PandaReachDense-v2", "video_link": "", "global": None},
+     {"rl_env_beautiful": "PandaReachDense-v3", "rl_env": "PandaReachDense-v3", "video_link": "", "global": None},
+     {"rl_env_beautiful": "Pixelcopter-PLE-v0", "rl_env": "Pixelcopter-PLE-v0", "video_link": "", "global": None}
+ ]
+
+ # -------------------- Utility Functions --------------------
+
+ def restart():
+     """Restart the Hugging Face Space."""
+     print("RESTARTING SPACE...")
+     api.restart_space(repo_id="huggingface-projects/Deep-Reinforcement-Learning-Leaderboard")
+
+ def download_leaderboard_dataset():
+     """Download leaderboard dataset once at startup."""
+     print("Downloading leaderboard dataset...")
+     return snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset")
+
+ def get_metadata(model_id):
+     """Fetch metadata for a given model from Hugging Face."""
+     try:
+         readme_path = hf_hub_download(model_id, filename="README.md", etag_timeout=180)
+         return metadata_load(readme_path)
+     except requests.exceptions.HTTPError:
+         return None  # 404 README.md not found
+
+ def parse_metrics_accuracy(meta):
+     """Extract accuracy metrics from metadata."""
+     if "model-index" not in meta:
+         return None
+     result = meta["model-index"][0]["results"]
+     metrics = result[0]["metrics"]
+     return metrics[0]["value"]
+
+ def parse_rewards(accuracy):
+     """Extract mean and std rewards from accuracy metrics."""
+     default_std = -1000
+     default_reward = -1000
+     if accuracy is not None:
+         parsed = str(accuracy).split('+/-')
+         mean_reward = float(parsed[0].strip()) if parsed[0] else default_reward
+         std_reward = float(parsed[1].strip()) if len(parsed) > 1 else 0
+     else:
+         mean_reward, std_reward = default_reward, default_std
+     return mean_reward, std_reward
+
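+ # Illustrative examples, assuming the "mean +/- std" string format that course
+ # model cards report for the metric:
+ #   parse_rewards("250.50 +/- 15.72") -> (250.5, 15.72)
+ #   parse_rewards("250.50")           -> (250.5, 0)
+ #   parse_rewards(None)               -> (-1000, -1000)
+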
+ def get_model_ids(rl_env):
+     """Retrieve models matching the given RL environment."""
+     return [x.modelId for x in api.list_models(filter=rl_env)]
+
+ def update_leaderboard_dataset_parallel(rl_env, path):
+     """Parallelized update of leaderboard dataset for a given RL environment."""
+     model_ids = get_model_ids(rl_env)
+
+     def process_model(model_id):
+         meta = get_metadata(model_id)
+         if not meta:
+             return None
+         user_id = model_id.split('/')[0]
+         row = {
+             "User": user_id,
+             "Model": model_id,
+             "Results": None,
+             "Mean Reward": None,
+             "Std Reward": None
+         }
+         accuracy = parse_metrics_accuracy(meta)
+         mean_reward, std_reward = parse_rewards(accuracy)
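+         # Keep the worst-case episode: rank by the lower bound mean_reward - std_reward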
+         row["Results"] = mean_reward - std_reward
+         row["Mean Reward"] = mean_reward
+         row["Std Reward"] = std_reward
+         return row
+
+     data = list(thread_map(process_model, model_ids, desc="Processing models"))
+     data = [row for row in data if row is not None]
+
+     ranked_dataframe = rank_dataframe(pd.DataFrame.from_records(data))
+     ranked_dataframe.to_csv(os.path.join(path, f"{rl_env}.csv"), index=False)
+
+     return ranked_dataframe
+
+ def rank_dataframe(dataframe):
+     """Sort models by results and assign ranking."""
+     dataframe = dataframe.sort_values(by=['Results', 'User', 'Model'], ascending=False)
+     dataframe.insert(0, 'Ranking', range(1, len(dataframe) + 1))
+     return dataframe
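+ # Ranking 1 is the best model: sorting 'Results' (the lower-bound score) in
+ # descending order puts the highest worst-case reward first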
+
+ def run_update_dataset():
+     """Update dataset periodically using the scheduler."""
+     path_ = download_leaderboard_dataset()
+     for env in rl_envs:
+         update_leaderboard_dataset_parallel(env["rl_env"], path_)
+
+     print("Uploading updated dataset...")
+     api.upload_folder(
+         folder_path=path_,
+         repo_id=DATASET_REPO_ID,
+         repo_type="dataset",
+         commit_message="Update dataset"
+     )
+
+ def filter_data(rl_env, path, user_id):
+     """Filter dataset for a specific user ID."""
+     data_df = pd.read_csv(os.path.join(path, f"{rl_env}.csv"))
+     return data_df[data_df["User"] == user_id]
+
+ # -------------------- Gradio UI --------------------
+
+ print("Initializing dataset...")
+ path_ = download_leaderboard_dataset()
+
+ with block:
+     gr.Markdown("""
+ # πŸ† Deep Reinforcement Learning Course Leaderboard πŸ†
+
+ This leaderboard displays trained agents from the [Deep Reinforcement Learning Course](https://huggingface.co/learn/deep-rl-course/unit0/introduction?fw=pt).
+
+ **Models are ranked using `mean_reward - std_reward`.**
+
+ If you can't find your model, please wait for the next update (every 2 hours).
+ """)
+
+     grpath = gr.State(path_)  # Store dataset path as a state variable
+
+     for env in rl_envs:
+         with gr.TabItem(env["rl_env_beautiful"]):
+             gr.Markdown(f"## {env['rl_env_beautiful']}")
+             user_id = gr.Textbox(label="Your user ID")
+             search_btn = gr.Button("Search πŸ”Ž")
+             reset_btn = gr.Button("Clear Search")
+             env_state = gr.State(env["rl_env"])  # Store environment name as a state variable
+
+             gr_dataframe = gr.Dataframe(
+                 value=pd.read_csv(os.path.join(path_, f"{env['rl_env']}.csv")),
+                 headers=["Ranking πŸ†", "User πŸ€—", "Model πŸ€–", "Results", "Mean Reward", "Std Reward"],
+                 datatype=["number", "markdown", "markdown", "number", "number", "number"],
+                 row_count=(100, "dynamic")  # "dynamic" allows displaying all rows, not just a fixed 100
+             )
+
+             # Pass the environment name and dataset path through gr.State so
+             # each tab's callback receives the values for its own environment
+             search_btn.click(fn=filter_data, inputs=[env_state, grpath, user_id], outputs=gr_dataframe)
+             # Bind the env name as a default argument: a bare closure over `env`
+             # would be late-bound, so every tab would reload the last environment's CSV
+             reset_btn.click(fn=lambda e=env["rl_env"]: pd.read_csv(os.path.join(path_, f"{e}.csv")), inputs=[], outputs=gr_dataframe)
+
+ # -------------------- Scheduler --------------------
+
+ scheduler = BackgroundScheduler()
+ scheduler.add_job(run_update_dataset, 'interval', hours=2)  # Update dataset every 2 hours
+ scheduler.add_job(restart, 'interval', hours=3)  # Restart the Space every 3 hours
+ scheduler.start()
+
+ block.launch()
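
Note on the reset callback: Python closures created in a loop are late-bound, so a bare `lambda: ... env ...` would see only the final loop value, and every tab would reset to the last environment's table. Binding `env["rl_env"]` as a default argument captures the per-iteration value instead. A minimal standalone sketch of the pitfall and the fix (toy code, not from app.py):

```python
envs = ["CartPole-v1", "Taxi-v3"]

# Late binding: all lambdas share the same loop variable,
# which holds "Taxi-v3" once the loop has finished.
callbacks = [lambda: name for name in envs]
print([f() for f in callbacks])  # ['Taxi-v3', 'Taxi-v3']

# Default arguments are evaluated when the lambda is defined,
# so each callback captures the value of its own iteration.
callbacks = [lambda n=name: n for name in envs]
print([f() for f in callbacks])  # ['CartPole-v1', 'Taxi-v3']
```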