---
library_name: lerobot
tags:
  - model_hub_mixin
  - pytorch_model_hub_mixin
  - robotics
  - dot
license: apache-2.0
datasets:
  - lerobot/aloha_sim_insertion_human
pipeline_tag: robotics
---

Model Card for the "Decoder Only Transformer (DOT) Policy" on the ALOHA bimanual insertion task

Read more about the model and implementation details in the DOT Policy repository.

This model was trained with the LeRobot library and achieves state-of-the-art behavior-cloning results on the ALOHA bimanual insertion dataset: a 29.6% success rate versus 21% for the previous state-of-the-art model (ACT).

This result is achieved without checkpoint selection and is easy to reproduce.

To use this model, install LeRobot from this branch.

To train the model:

python lerobot/scripts/train.py \
    --policy.type=dot \
    --dataset.repo_id=lerobot/aloha_sim_insertion_human \
    --env.type=aloha \
    --env.task=AlohaInsertion-v0 \
    --env.episode_length=500 \
    --output_dir=outputs/train/pusht_aloha_insert \
    --batch_size=24 \
    --log_freq=1000 \
    --eval_freq=10000 \
    --save_freq=10000 \
    --offline.steps=100000 \
    --seed=100000 \
    --wandb.enable=true \
    --num_workers=24 \
    --use_amp=true \
    --device=cuda \
    --policy.optimizer_lr=0.00003 \
    --policy.optimizer_min_lr=0.00001 \
    --policy.optimizer_lr_cycle_steps=100000 \
    --policy.train_horizon=150 \
    --policy.inference_horizon=100 \
    --policy.lookback_obs_steps=30 \
    --policy.lookback_aug=5 \
    --policy.rescale_shape="[480,640]" \
    --policy.alpha=0.98 \
    --policy.train_alpha=0.99
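
The three learning-rate flags in the training command (`optimizer_lr`, `optimizer_min_lr`, `optimizer_lr_cycle_steps`) suggest a schedule that decays from 3e-5 to 1e-5 over 100,000 steps. This card does not state the shape of that schedule, so the sketch below assumes cosine annealing purely for illustration. `cosine_lr` is a hypothetical helper, not part of LeRobot, and holding the rate at the minimum after the cycle ends is also an assumption.

```python
import math

# Values taken from the training command in this card.
BASE_LR = 3e-5          # --policy.optimizer_lr
MIN_LR = 1e-5           # --policy.optimizer_min_lr
CYCLE_STEPS = 100_000   # --policy.optimizer_lr_cycle_steps


def cosine_lr(step: int, base_lr: float = BASE_LR, min_lr: float = MIN_LR,
              cycle_steps: int = CYCLE_STEPS) -> float:
    """Cosine-annealed learning rate (assumed shape, not confirmed by the card)."""
    # Clamp so the rate stays at min_lr once the cycle is over (an assumption).
    progress = min(step, cycle_steps) / cycle_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


# Print the rate at the start, midpoint, and end of the cycle.
for step in (0, 50_000, 100_000):
    print(step, f"{cosine_lr(step):.2e}")
```

If the actual scheduler differs (for example, linear decay or warm restarts), only the body of `cosine_lr` would need to change; the start and end values come directly from the flags.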

To evaluate the model:

python lerobot/scripts/eval.py \
    --policy.path=IliaLarchenko/dot_bimanual_insert \
    --env.type=aloha \
    --env.task=AlohaInsertion-v0 \
    --env.episode_length=500 \
    --eval.n_episodes=1000 \
    --eval.batch_size=100 \
    --seed=1000000
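
Even with 1000 evaluation episodes, a measured success rate carries sampling noise. As a rough sketch, the snippet below computes a 95% Wilson score interval, assuming the reported 29.6% corresponds to 296 successes out of 1000 independent episodes (the card does not say how many episodes the headline figure was measured on). `wilson_interval` is an illustrative helper, not a LeRobot function.

```python
import math


def wilson_interval(successes: int, n: int, z: float = 1.96) -> tuple[float, float]:
    """Wilson score interval for a binomial proportion (z = 1.96 is roughly 95%)."""
    p = successes / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half_width = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return center - half_width, center + half_width


# Assumption: 29.6% success over 1000 episodes, i.e. 296 successes.
low, high = wilson_interval(296, 1000)
print(f"95% interval: {low:.3f} to {high:.3f}")
```

Under that assumption the lower bound stays above the 21% reported for ACT, although ACT's own figure has its own uncertainty, which this calculation does not capture.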

Model size:

  • Total parameters: 14.1M
  • Trainable parameters: 2.9M
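
Figures like these are usually obtained by summing element counts over a model's parameters. The sketch below shows the common PyTorch idiom as a helper that accepts any `torch.nn.Module`-like object; `count_parameters` is a hypothetical name, and how the author actually produced the numbers is not stated in this card. The ratio at the end uses only the two values reported here.

```python
def count_parameters(model) -> tuple[int, int]:
    """Return (total, trainable) parameter counts for a torch.nn.Module-like object."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


# Figures reported in this card, used here only to compute the trainable share.
total_reported = 14.1e6
trainable_reported = 2.9e6
trainable_share = trainable_reported / total_reported
print(f"Trainable share: {trainable_share:.1%}")
```

With a loaded policy object `policy` (assumed to be a `torch.nn.Module`), `count_parameters(policy)` would return the two counts as integers.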