# SAC + HER Agent for FetchPickAndPlace-v4

## Model Overview
This repository contains a Soft Actor-Critic (SAC) agent trained with Hindsight Experience Replay (HER) on the FetchPickAndPlace-v4 environment from gymnasium-robotics. The agent learns to pick and place objects using sparse or dense rewards, and is suitable for robotic manipulation research.

- Algorithm: Soft Actor-Critic (SAC)
- Replay Buffer: Hindsight Experience Replay (HER)
- Environment: FetchPickAndPlace-v4 (gymnasium-robotics)
- Framework: Stable Baselines3
## Training Details

- Total Timesteps: 500,000
- Evaluation Frequency: Every 2,000 steps (15 episodes per eval)
- Checkpoint Frequency: Every 50,000 steps (model + replay buffer)
- Seed: 42
- Dense Shaping: False (can be enabled with a wrapper; see the illustrative sketch after this list)
- Device: `auto` (CUDA if available, otherwise CPU)
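Dense shaping is off by default and the wrapper itself is not included in this repository. As a rough illustration only, here is a minimal sketch of such a wrapper, assuming the goal-conditioned dict observations of the Fetch environments and a negative goal-distance shaping term (both the class name and the exact shaping term are assumptions, not the shipped implementation):

```python
import numpy as np
import gymnasium as gym


class DenseRewardWrapper(gym.Wrapper):
    """Illustrative sketch: replace the sparse Fetch reward with the negative
    Euclidean distance between the achieved and desired goal."""

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Dense shaping: the closer the object is to the goal, the closer the reward is to 0.
        dense_reward = -float(np.linalg.norm(obs["achieved_goal"] - obs["desired_goal"]))
        return obs, dense_reward, terminated, truncated, info
```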
## Hyperparameters

| Parameter | Value |
|---|---|
| Algorithm | SAC |
| Policy | MultiInputPolicy |
| Replay Buffer | HER |
| n_sampled_goal | 4 |
| goal_selection_strategy | future |
| Batch Size | 512 |
| Buffer Size | 1,000,000 |
| Learning Rate | 1e-3 |
| Gamma | 0.95 |
| Tau | 0.05 |
| Entropy Coefficient | auto |
| Train Frequency | 1 step |
| Gradient Steps | 1 |
| Tensorboard Log | logs_pnp_sac_her/tb |
| Seed | 42 |
| Device | CUDA/Auto |
| Dense Shaping | False (default) |
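The original training script is not part of this repository, but these settings map directly onto Stable Baselines3's SAC and HerReplayBuffer APIs. The following is a minimal reconstruction sketch, assuming defaults for anything not listed above and reusing the file and directory names from the Training Details and Files sections:

```python
import gymnasium as gym
import gymnasium_robotics
from stable_baselines3 import SAC, HerReplayBuffer
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback

env = gym.make("FetchPickAndPlace-v4")
eval_env = gym.make("FetchPickAndPlace-v4")

model = SAC(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(n_sampled_goal=4, goal_selection_strategy="future"),
    batch_size=512,
    buffer_size=1_000_000,
    learning_rate=1e-3,
    gamma=0.95,
    tau=0.05,
    ent_coef="auto",
    train_freq=1,
    gradient_steps=1,
    tensorboard_log="logs_pnp_sac_her/tb",
    seed=42,
    device="auto",
)

callbacks = [
    # Evaluate every 2,000 steps over 15 episodes (matches Training Details).
    EvalCallback(eval_env, eval_freq=2_000, n_eval_episodes=15),
    # Save the model and replay buffer every 50,000 steps.
    CheckpointCallback(
        save_freq=50_000,
        save_path="logs_pnp_sac_her",
        name_prefix="ckpt_sac_her",
        save_replay_buffer=True,
    ),
]

model.learn(total_timesteps=500_000, callback=callbacks)
model.save("sac_her_pnp")
model.save_replay_buffer("replay_buffer.pkl")
```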
## Files

- `sac_her_pnp.zip`: Final trained SAC model
- `ckpt_sac_her_250000_steps.zip`: Latest checkpoint
- `replay_buffer.pkl`: Replay buffer for continued training
- `replay.mp4`: Replay video of agent performance (manual generation recommended)
- `README.md`: This model card
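These files can also be fetched programmatically with `huggingface_hub`; a small sketch, assuming the repo id taken from the citation URL below:

```python
from huggingface_hub import hf_hub_download

# Download the final model from the Hub (repo id inferred from the citation URL).
model_path = hf_hub_download(
    repo_id="IntelliGrow/FetchPickAndPlace-v4",
    filename="sac_her_pnp.zip",
)
print(model_path)
```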
## Usage

To load and use the model for inference:

```python
from stable_baselines3 import SAC
import gymnasium as gym
import gymnasium_robotics

env = gym.make("FetchPickAndPlace-v4", render_mode="rgb_array")
model = SAC.load("path/to/sac_her_pnp.zip", env=env)

obs, info = env.reset()
done = False
truncated = False
while not (done or truncated):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, truncated, info = env.step(action)
    env.render()
env.close()
```
## Evaluation

To evaluate the agent over multiple episodes:

```python
from stable_baselines3 import SAC
import gymnasium as gym
import gymnasium_robotics

env = gym.make("FetchPickAndPlace-v4", render_mode="human")
model = SAC.load("path/to/sac_her_pnp.zip", env=env)

num_episodes = 10
for ep in range(num_episodes):
    obs, info = env.reset()
    done = False
    truncated = False
    episode_reward = 0
    while not (done or truncated):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        env.render()
        episode_reward += reward
    print(f"Episode {ep+1} reward: {episode_reward}")
env.close()
```
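Alternatively, Stable Baselines3 provides an evaluation helper that returns the mean and standard deviation of episode rewards; a short sketch:

```python
from stable_baselines3.common.evaluation import evaluate_policy

# Run the deterministic policy for 10 episodes and aggregate episode rewards.
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}")
```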
## Replay Video

If `replay.mp4` is not present, you can generate it manually:

```python
import gymnasium as gym
import gymnasium_robotics
from stable_baselines3 import SAC
import moviepy.editor as mpy  # moviepy < 2.0; in 2.x use `from moviepy import ImageSequenceClip`

env = gym.make("FetchPickAndPlace-v4", render_mode="rgb_array")
model = SAC.load("path/to/sac_her_pnp.zip", env=env)

frames = []
obs, info = env.reset()
done = False
truncated = False
step = 0
max_steps = 1000
while not (done or truncated) and step < max_steps:
    frame = env.render()
    frames.append(frame)
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, truncated, info = env.step(action)
    step += 1
env.close()

clip = mpy.ImageSequenceClip(frames, fps=30)
clip.write_videofile("replay.mp4", codec="libx264")
```
## Continued Training

To continue training from a checkpoint:

```python
from stable_baselines3 import SAC
import gymnasium as gym
import gymnasium_robotics

env = gym.make("FetchPickAndPlace-v4", render_mode=None)
model = SAC.load("logs_pnp_sac_her/ckpt_sac_her_250000_steps.zip", env=env)
# Optionally restore the saved replay buffer so training resumes with past experience.
model.load_replay_buffer("replay_buffer.pkl")
model.learn(total_timesteps=500_000, reset_num_timesteps=False)
```
## Citation

If you use this model, please cite:

```bibtex
@misc{IntelliGrow_FetchPickAndPlace_SAC_HER,
  title={SAC + HER Agent for FetchPickAndPlace-v4},
  author={IntelliGrow},
  year={2025},
  howpublished={Hugging Face Hub},
  url={https://huggingface.co/IntelliGrow/FetchPickAndPlace-v4}
}
```
## License

MIT License

Contact: For questions or issues, please open an issue on the Hugging Face repository.
## Evaluation Results

- mean_reward on FetchPickAndPlace-v4 (self-reported): -9.70 +/- 4.17