diff --git a/lerobot/__init__.py b/lerobot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d61e4853e671ff02a88237960c9d7cafc3716d75 --- /dev/null +++ b/lerobot/__init__.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library. +We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables. + +Example: + ```python + import lerobot + print(lerobot.available_envs) + print(lerobot.available_tasks_per_env) + print(lerobot.available_datasets) + print(lerobot.available_datasets_per_env) + print(lerobot.available_real_world_datasets) + print(lerobot.available_policies) + print(lerobot.available_policies_per_env) + print(lerobot.available_robots) + print(lerobot.available_cameras) + print(lerobot.available_motors) + ``` + +When implementing a new dataset loadable with LeRobotDataset follow these steps: +- Update `available_datasets_per_env` in `lerobot/__init__.py` + +When implementing a new environment (e.g. `gym_aloha`), follow these steps: +- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py` + +When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps: +- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py` +- Set the required `name` class attribute. +- Update variables in `tests/test_available.py` by importing your new Policy class +""" + +import itertools + +from lerobot.__version__ import __version__ # noqa: F401 + +# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies` +# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to +# a yaml file AND a environment name. The difference should be more obvious. +available_tasks_per_env = { + "aloha": [ + "AlohaInsertion-v0", + "AlohaTransferCube-v0", + ], + "pusht": ["PushT-v0"], + "xarm": ["XarmLift-v0"], +} +available_envs = list(available_tasks_per_env.keys()) + +available_datasets_per_env = { + "aloha": [ + "lerobot/aloha_sim_insertion_human", + "lerobot/aloha_sim_insertion_scripted", + "lerobot/aloha_sim_transfer_cube_human", + "lerobot/aloha_sim_transfer_cube_scripted", + "lerobot/aloha_sim_insertion_human_image", + "lerobot/aloha_sim_insertion_scripted_image", + "lerobot/aloha_sim_transfer_cube_human_image", + "lerobot/aloha_sim_transfer_cube_scripted_image", + ], + # TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly + # coupled with tests. 
+ "pusht": ["lerobot/pusht", "lerobot/pusht_image"], + "xarm": [ + "lerobot/xarm_lift_medium", + "lerobot/xarm_lift_medium_replay", + "lerobot/xarm_push_medium", + "lerobot/xarm_push_medium_replay", + "lerobot/xarm_lift_medium_image", + "lerobot/xarm_lift_medium_replay_image", + "lerobot/xarm_push_medium_image", + "lerobot/xarm_push_medium_replay_image", + ], +} + +available_real_world_datasets = [ + "lerobot/aloha_mobile_cabinet", + "lerobot/aloha_mobile_chair", + "lerobot/aloha_mobile_elevator", + "lerobot/aloha_mobile_shrimp", + "lerobot/aloha_mobile_wash_pan", + "lerobot/aloha_mobile_wipe_wine", + "lerobot/aloha_static_battery", + "lerobot/aloha_static_candy", + "lerobot/aloha_static_coffee", + "lerobot/aloha_static_coffee_new", + "lerobot/aloha_static_cups_open", + "lerobot/aloha_static_fork_pick_up", + "lerobot/aloha_static_pingpong_test", + "lerobot/aloha_static_pro_pencil", + "lerobot/aloha_static_screw_driver", + "lerobot/aloha_static_tape", + "lerobot/aloha_static_thread_velcro", + "lerobot/aloha_static_towel", + "lerobot/aloha_static_vinh_cup", + "lerobot/aloha_static_vinh_cup_left", + "lerobot/aloha_static_ziploc_slide", + "lerobot/umi_cup_in_the_wild", + "lerobot/unitreeh1_fold_clothes", + "lerobot/unitreeh1_rearrange_objects", + "lerobot/unitreeh1_two_robot_greeting", + "lerobot/unitreeh1_warehouse", + "lerobot/nyu_rot_dataset", + "lerobot/utokyo_saytap", + "lerobot/imperialcollege_sawyer_wrist_cam", + "lerobot/utokyo_xarm_bimanual", + "lerobot/tokyo_u_lsmo", + "lerobot/utokyo_pr2_opening_fridge", + "lerobot/cmu_franka_exploration_dataset", + "lerobot/cmu_stretch", + "lerobot/asu_table_top", + "lerobot/utokyo_pr2_tabletop_manipulation", + "lerobot/utokyo_xarm_pick_and_place", + "lerobot/ucsd_kitchen_dataset", + "lerobot/austin_buds_dataset", + "lerobot/dlr_sara_grid_clamp", + "lerobot/conq_hose_manipulation", + "lerobot/columbia_cairlab_pusht_real", + "lerobot/dlr_sara_pour", + "lerobot/dlr_edan_shared_control", + "lerobot/ucsd_pick_and_place_dataset", + "lerobot/berkeley_cable_routing", + "lerobot/nyu_franka_play_dataset", + "lerobot/austin_sirius_dataset", + "lerobot/cmu_play_fusion", + "lerobot/berkeley_gnm_sac_son", + "lerobot/nyu_door_opening_surprising_effectiveness", + "lerobot/berkeley_fanuc_manipulation", + "lerobot/jaco_play", + "lerobot/viola", + "lerobot/kaist_nonprehensile", + "lerobot/berkeley_mvp", + "lerobot/uiuc_d3field", + "lerobot/berkeley_gnm_recon", + "lerobot/austin_sailor_dataset", + "lerobot/utaustin_mutex", + "lerobot/roboturk", + "lerobot/stanford_hydra_dataset", + "lerobot/berkeley_autolab_ur5", + "lerobot/stanford_robocook", + "lerobot/toto", + "lerobot/fmb", + "lerobot/droid_100", + "lerobot/berkeley_rpt", + "lerobot/stanford_kuka_multimodal_dataset", + "lerobot/iamlab_cmu_pickup_insert", + "lerobot/taco_play", + "lerobot/berkeley_gnm_cory_hall", + "lerobot/usc_cloth_sim", +] + +available_datasets = sorted( + set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)) +) + +# lists all available policies from `lerobot/common/policies` +available_policies = [ + "act", + "diffusion", + "tdmpc", + "vqbet", +] + +# lists all available robots from `lerobot/common/robot_devices/robots` +available_robots = [ + "koch", + "koch_bimanual", + "aloha", + "so100", + "moss", +] + +# lists all available cameras from `lerobot/common/robot_devices/cameras` +available_cameras = [ + "opencv", + "intelrealsense", +] + +# lists all available motors from `lerobot/common/robot_devices/motors` +available_motors = [ + "dynamixel", + 
"feetech", +] + +# keys and values refer to yaml files +available_policies_per_env = { + "aloha": ["act"], + "pusht": ["diffusion", "vqbet"], + "xarm": ["tdmpc"], + "koch_real": ["act_koch_real"], + "aloha_real": ["act_aloha_real"], +} + +env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks] +env_dataset_pairs = [ + (env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets +] +env_dataset_policy_triplets = [ + (env, dataset, policy) + for env, datasets in available_datasets_per_env.items() + for dataset in datasets + for policy in available_policies_per_env[env] +] diff --git a/lerobot/__version__.py b/lerobot/__version__.py new file mode 100644 index 0000000000000000000000000000000000000000..d12aafaa9e573f408ce0d06654b519ba97832738 --- /dev/null +++ b/lerobot/__version__.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""To enable `lerobot.__version__`""" + +from importlib.metadata import PackageNotFoundError, version + +try: + __version__ = version("lerobot") +except PackageNotFoundError: + __version__ = "unknown" diff --git a/lerobot/common/constants.py b/lerobot/common/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..973595cdfc0e06f38c5a43b287a0e773c4520dbb --- /dev/null +++ b/lerobot/common/constants.py @@ -0,0 +1,45 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# keys +import os +from pathlib import Path + +from huggingface_hub.constants import HF_HOME + +OBS_ENV = "observation.environment_state" +OBS_ROBOT = "observation.state" +OBS_IMAGE = "observation.image" +OBS_IMAGES = "observation.images" +ACTION = "action" + +# files & directories +CHECKPOINTS_DIR = "checkpoints" +LAST_CHECKPOINT_LINK = "last" +PRETRAINED_MODEL_DIR = "pretrained_model" +TRAINING_STATE_DIR = "training_state" +RNG_STATE = "rng_state.safetensors" +TRAINING_STEP = "training_step.json" +OPTIMIZER_STATE = "optimizer_state.safetensors" +OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json" +SCHEDULER_STATE = "scheduler_state.json" + +# cache dir +default_cache_path = Path(HF_HOME) / "lerobot" +HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser() + +if "LEROBOT_HOME" in os.environ: + raise ValueError( + f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n" + "'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead." + ) diff --git a/lerobot/common/datasets/backward_compatibility.py b/lerobot/common/datasets/backward_compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..cf8e31c4fb704c5bd9291be5fb155e91f67b463b --- /dev/null +++ b/lerobot/common/datasets/backward_compatibility.py @@ -0,0 +1,68 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import packaging.version + +V2_MESSAGE = """ +The dataset you requested ({repo_id}) is in {version} format. + +We introduced a new format since v2.0 which is not backward compatible with v1.x. +Please, use our conversion script. Modify the following command with your own task description: +``` +python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\ + --repo-id {repo_id} \\ + --single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\ +``` + +A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the +peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top +cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped +target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the +sweatshirt.", ... + +If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb) +or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose). +""" + +V21_MESSAGE = """ +The dataset you requested ({repo_id}) is in {version} format. +While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global +stats instead of per-episode stats. 
Update your dataset stats to the new format using this command: +``` +python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id} +``` + +If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb) +or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose). +""" + +FUTURE_MESSAGE = """ +The dataset you requested ({repo_id}) is only available in {version} format. +As we cannot ensure forward compatibility with it, please update your current version of lerobot. +""" + + +class CompatibilityError(Exception): ... + + +class BackwardCompatibilityError(CompatibilityError): + def __init__(self, repo_id: str, version: packaging.version.Version): + message = V2_MESSAGE.format(repo_id=repo_id, version=version) + super().__init__(message) + + +class ForwardCompatibilityError(CompatibilityError): + def __init__(self, repo_id: str, version: packaging.version.Version): + message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version) + super().__init__(message) diff --git a/lerobot/common/datasets/card_template.md b/lerobot/common/datasets/card_template.md new file mode 100644 index 0000000000000000000000000000000000000000..7ee27df95dfdbe6ecb5f65054937fc09e4113523 --- /dev/null +++ b/lerobot/common/datasets/card_template.md @@ -0,0 +1,27 @@ +--- +# For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1 +# Doc / guide: https://huggingface.co/docs/hub/datasets-cards +{{ card_data }} +--- + +This dataset was created using [LeRobot](https://github.com/huggingface/lerobot). + +## Dataset Description + +{{ dataset_description | default("", true) }} + +- **Homepage:** {{ url | default("[More Information Needed]", true)}} +- **Paper:** {{ paper | default("[More Information Needed]", true)}} +- **License:** {{ license | default("[More Information Needed]", true)}} + +## Dataset Structure + +{{ dataset_structure | default("[More Information Needed]", true)}} + +## Citation + +**BibTeX:** + +```bibtex +{{ citation_bibtex | default("[More Information Needed]", true)}} +``` diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..1149ec83ed1b564860b45090aad9996d7c94bcb9 --- /dev/null +++ b/lerobot/common/datasets/compute_stats.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +from lerobot.common.datasets.utils import load_image_as_numpy + + +def estimate_num_samples( + dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75 +) -> int: + """Heuristic to estimate the number of samples based on dataset size. + The power controls the sample growth relative to dataset size. + Lower the power for less number of samples. 
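As an editorial aside (not part of the original docstring), the heuristic boils down to clipping `dataset_len ** power` between the two bounds, which reproduces the reference values listed next:

```python
def n_samples(dataset_len: int, power: float = 0.75) -> int:
    # Same arithmetic as estimate_num_samples() with its default bounds.
    return max(min(100, dataset_len), min(int(dataset_len**power), 10_000))

for length in (1_000, 2_000, 5_000, 10_000, 20_000):
    print(length, n_samples(length))  # -> 177, 299, 594, 1000, 1681
```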
+ + For default arguments, we have: + - from 1 to ~500, num_samples=100 + - at 1000, num_samples=177 + - at 2000, num_samples=299 + - at 5000, num_samples=594 + - at 10000, num_samples=1000 + - at 20000, num_samples=1681 + """ + if dataset_len < min_num_samples: + min_num_samples = dataset_len + return max(min_num_samples, min(int(dataset_len**power), max_num_samples)) + + +def sample_indices(data_len: int) -> list[int]: + num_samples = estimate_num_samples(data_len) + return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist() + + +def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300): + _, height, width = img.shape + + if max(width, height) < max_size_threshold: + # no downsampling needed + return img + + downsample_factor = int(width / target_size) if width > height else int(height / target_size) + return img[:, ::downsample_factor, ::downsample_factor] + + +def sample_images(image_paths: list[str]) -> np.ndarray: + sampled_indices = sample_indices(len(image_paths)) + + images = None + for i, idx in enumerate(sampled_indices): + path = image_paths[idx] + # we load as uint8 to reduce memory usage + img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True) + img = auto_downsample_height_width(img) + + if images is None: + images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8) + + images[i] = img + + return images + + +def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]: + return { + "min": np.min(array, axis=axis, keepdims=keepdims), + "max": np.max(array, axis=axis, keepdims=keepdims), + "mean": np.mean(array, axis=axis, keepdims=keepdims), + "std": np.std(array, axis=axis, keepdims=keepdims), + "count": np.array([len(array)]), + } + + +def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict: + ep_stats = {} + for key, data in episode_data.items(): + if features[key]["dtype"] == "string": + continue # HACK: we should receive np.arrays of strings + elif features[key]["dtype"] in ["image", "video"]: + ep_ft_array = sample_images(data) # data is a list of image paths + axes_to_reduce = (0, 2, 3) # keep channel dim + keepdims = True + else: + ep_ft_array = data # data is already a np.ndarray + axes_to_reduce = 0 # compute stats over the first axis + keepdims = data.ndim == 1 # keep as np.array + + ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims) + + # finally, we normalize and remove batch dim for images + if features[key]["dtype"] in ["image", "video"]: + ep_stats[key] = { + k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items() + } + + return ep_stats + + +def _assert_type_and_shape(stats_list: list[dict[str, dict]]): + for i in range(len(stats_list)): + for fkey in stats_list[i]: + for k, v in stats_list[i][fkey].items(): + if not isinstance(v, np.ndarray): + raise ValueError( + f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead." 
+ ) + if v.ndim == 0: + raise ValueError("Number of dimensions must be at least 1, and is 0 instead.") + if k == "count" and v.shape != (1,): + raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.") + if "image" in fkey and k != "count" and v.shape != (3, 1, 1): + raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.") + + +def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: + """Aggregates stats for a single feature.""" + means = np.stack([s["mean"] for s in stats_ft_list]) + variances = np.stack([s["std"] ** 2 for s in stats_ft_list]) + counts = np.stack([s["count"] for s in stats_ft_list]) + total_count = counts.sum(axis=0) + + # Prepare weighted mean by matching number of dimensions + while counts.ndim < means.ndim: + counts = np.expand_dims(counts, axis=-1) + + # Compute the weighted mean + weighted_means = means * counts + total_mean = weighted_means.sum(axis=0) / total_count + + # Compute the variance using the parallel algorithm + delta_means = means - total_mean + weighted_variances = (variances + delta_means**2) * counts + total_variance = weighted_variances.sum(axis=0) / total_count + + return { + "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0), + "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0), + "mean": total_mean, + "std": np.sqrt(total_variance), + "count": total_count, + } + + +def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: + """Aggregate stats from multiple compute_stats outputs into a single set of stats. + + The final stats will have the union of all data keys from each of the stats dicts. + + For instance: + - new_min = min(min_dataset_0, min_dataset_1, ...) + - new_max = max(max_dataset_0, max_dataset_1, ...) + - new_mean = (mean of all data, weighted by counts) + - new_std = (std of all data) + """ + + _assert_type_and_shape(stats_list) + + data_keys = {key for stats in stats_list for key in stats} + aggregated_stats = {key: {} for key in data_keys} + + for key in data_keys: + stats_with_key = [stats[key] for stats in stats_list if key in stats] + aggregated_stats[key] = aggregate_feature_stats(stats_with_key) + + return aggregated_stats diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..38c01b42f848a99feaa96c90a30c4bb59df45b74 --- /dev/null +++ b/lerobot/common/datasets/factory.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
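The weighted mean/variance combination in `aggregate_feature_stats` above uses the standard parallel-variance identity: the total variance is the count-weighted sum of each episode's variance plus its squared mean offset, divided by the total count. A small editorial sanity check (not part of the diff, and assuming the `lerobot` package is importable) before moving on to the dataset factory below:

```python
import numpy as np

from lerobot.common.datasets.compute_stats import aggregate_feature_stats

rng = np.random.default_rng(0)
episodes = [rng.normal(size=(100, 6)), rng.normal(loc=2.0, size=(50, 6))]

def feature_stats(x: np.ndarray) -> dict[str, np.ndarray]:
    # Per-episode stats in the same layout produced by get_feature_stats().
    return {
        "min": x.min(axis=0),
        "max": x.max(axis=0),
        "mean": x.mean(axis=0),
        "std": x.std(axis=0),
        "count": np.array([len(x)]),
    }

aggregated = aggregate_feature_stats([feature_stats(x) for x in episodes])
concatenated = np.concatenate(episodes)
assert np.allclose(aggregated["mean"], concatenated.mean(axis=0))
assert np.allclose(aggregated["std"], concatenated.std(axis=0))
```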
+import logging +from pprint import pformat + +import torch + +from lerobot.common.datasets.lerobot_dataset import ( + LeRobotDataset, + LeRobotDatasetMetadata, + MultiLeRobotDataset, +) +from lerobot.common.datasets.transforms import ImageTransforms +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.train import TrainPipelineConfig + +IMAGENET_STATS = { + "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) + "std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1) +} + + +def resolve_delta_timestamps( + cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata +) -> dict[str, list] | None: + """Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig. + + Args: + cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from. + ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build + delta_timestamps against. + + Returns: + dict[str, list] | None: A dictionary of delta_timestamps, e.g.: + { + "observation.state": [-0.04, -0.02, 0] + "observation.action": [-0.02, 0, 0.02] + } + returns `None` if the the resulting dict is empty. + """ + delta_timestamps = {} + for key in ds_meta.features: + if key == "next.reward" and cfg.reward_delta_indices is not None: + delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices] + if key == "action" and cfg.action_delta_indices is not None: + delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] + if key.startswith("observation.") and cfg.observation_delta_indices is not None: + delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices] + + if len(delta_timestamps) == 0: + delta_timestamps = None + + return delta_timestamps + + +def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset: + """Handles the logic of setting up delta timestamps and image transforms before creating a dataset. + + Args: + cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig. + + Raises: + NotImplementedError: The MultiLeRobotDataset is currently deactivated. + + Returns: + LeRobotDataset | MultiLeRobotDataset + """ + image_transforms = ( + ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None + ) + + if isinstance(cfg.dataset.repo_id, str): + ds_meta = LeRobotDatasetMetadata( + cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision + ) + delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta) + dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + episodes=cfg.dataset.episodes, + delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + revision=cfg.dataset.revision, + video_backend=cfg.dataset.video_backend, + ) + else: + raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.") + dataset = MultiLeRobotDataset( + cfg.dataset.repo_id, + # TODO(aliberts): add proper support for multi dataset + # delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + video_backend=cfg.dataset.video_backend, + ) + logging.info( + "Multiple datasets were provided. 
Applied the following index mapping to the provided datasets: " + f"{pformat(dataset.repo_id_to_index, indent=2)}" + ) + + if cfg.dataset.use_imagenet_stats: + for key in dataset.meta.camera_keys: + for stats_type, stats in IMAGENET_STATS.items(): + dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) + + return dataset diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc0ee2f8adac77e61a1f64377626ae5b751a86b --- /dev/null +++ b/lerobot/common/datasets/image_writer.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import multiprocessing +import queue +import threading +from pathlib import Path + +import numpy as np +import PIL.Image +import torch + + +def safe_stop_image_writer(func): + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + dataset = kwargs.get("dataset") + image_writer = getattr(dataset, "image_writer", None) if dataset else None + if image_writer is not None: + print("Waiting for image writer to terminate...") + image_writer.stop() + raise e + + return wrapper + + +def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image: + # TODO(aliberts): handle 1 channel and 4 for depth images + if image_array.ndim != 3: + raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.") + + if image_array.shape[0] == 3: + # Transpose from pytorch convention (C, H, W) to (H, W, C) + image_array = image_array.transpose(1, 2, 0) + + elif image_array.shape[-1] != 3: + raise NotImplementedError( + f"The image has {image_array.shape[-1]} channels, but 3 is required for now." + ) + + if image_array.dtype != np.uint8: + if range_check: + max_ = image_array.max().item() + min_ = image_array.min().item() + if max_ > 1.0 or min_ < 0.0: + raise ValueError( + "The image data type is float, which requires values in the range [0.0, 1.0]. " + f"However, the provided range is [{min_}, {max_}]. Please adjust the range or " + "provide a uint8 image with values in the range [0, 255]." 
) + + image_array = (image_array * 255).astype(np.uint8) + + return PIL.Image.fromarray(image_array) + + +def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path): + try: + if isinstance(image, np.ndarray): + img = image_array_to_pil_image(image) + elif isinstance(image, PIL.Image.Image): + img = image + else: + raise TypeError(f"Unsupported image type: {type(image)}") + img.save(fpath) + except Exception as e: + print(f"Error writing image {fpath}: {e}") + + +def worker_thread_loop(queue: queue.Queue): + while True: + item = queue.get() + if item is None: + queue.task_done() + break + image_array, fpath = item + write_image(image_array, fpath) + queue.task_done() + + +def worker_process(queue: queue.Queue, num_threads: int): + threads = [] + for _ in range(num_threads): + t = threading.Thread(target=worker_thread_loop, args=(queue,)) + t.daemon = True + t.start() + threads.append(t) + for t in threads: + t.join() + + +class AsyncImageWriter: + """ + This class abstracts away the initialisation of processes and/or threads used to + save images on disk asynchronously, which is critical to control a robot and record data + at a high frame rate. + + When `num_processes=0`, it creates a thread pool of size `num_threads`. + When `num_processes>0`, it creates a pool of `num_processes` processes, where each subprocess starts + its own thread pool of size `num_threads`. + + The optimal number of processes and threads depends on your computer's capabilities. + We advise using 4 threads per camera with 0 processes. If the fps is not stable, try increasing or lowering + the number of threads. If it is still not stable, try using 1 subprocess, or more. + """ + + def __init__(self, num_processes: int = 0, num_threads: int = 1): + self.num_processes = num_processes + self.num_threads = num_threads + self.queue = None + self.threads = [] + self.processes = [] + self._stopped = False + + if num_threads <= 0 and num_processes <= 0: + raise ValueError("Number of threads and processes must be greater than zero.") + + if self.num_processes == 0: + # Use threading + self.queue = queue.Queue() + for _ in range(self.num_threads): + t = threading.Thread(target=worker_thread_loop, args=(self.queue,)) + t.daemon = True + t.start() + self.threads.append(t) + else: + # Use multiprocessing + self.queue = multiprocessing.JoinableQueue() + for _ in range(self.num_processes): + p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads)) + p.daemon = True + p.start() + self.processes.append(p) + + def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path): + if isinstance(image, torch.Tensor): + # Convert tensor to numpy array to minimize main process time + image = image.cpu().numpy() + self.queue.put((image, fpath)) + + def wait_until_done(self): + self.queue.join() + + def stop(self): + if self._stopped: + return + + if self.num_processes == 0: + for _ in self.threads: + self.queue.put(None) + for t in self.threads: + t.join() + else: + num_nones = self.num_processes * self.num_threads + for _ in range(num_nones): + self.queue.put(None) + for p in self.processes: + p.join() + if p.is_alive(): + p.terminate() + self.queue.close() + self.queue.join_thread() + + self._stopped = True diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d8da85d60d529840bcd31d542b44a14cd795cdd9 --- /dev/null +++ b/lerobot/common/datasets/lerobot_dataset.py @@ 
-0,0 +1,1217 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import logging +import shutil +from pathlib import Path +from typing import Callable + +import datasets +import numpy as np +import packaging.version +import PIL.Image +import torch +import torch.utils +from datasets import concatenate_datasets, load_dataset +from huggingface_hub import HfApi, snapshot_download +from huggingface_hub.constants import REPOCARD_NAME +from huggingface_hub.errors import RevisionNotFoundError + +from lerobot.common.constants import HF_LEROBOT_HOME +from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats +from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image +from lerobot.common.datasets.utils import ( + DEFAULT_FEATURES, + DEFAULT_IMAGE_PATH, + INFO_PATH, + TASKS_PATH, + append_jsonlines, + backward_compatible_episodes_stats, + check_delta_timestamps, + check_timestamps_sync, + check_version_compatibility, + create_empty_dataset_info, + create_lerobot_dataset_card, + embed_images, + get_delta_indices, + get_episode_data_index, + get_features_from_robot, + get_hf_features_from_features, + get_safe_version, + hf_transform_to_torch, + is_valid_version, + load_episodes, + load_episodes_stats, + load_info, + load_stats, + load_tasks, + validate_episode_buffer, + validate_frame, + write_episode, + write_episode_stats, + write_info, + write_json, +) +from lerobot.common.datasets.video_utils import ( + VideoFrame, + decode_video_frames, + encode_video_frames, + get_safe_default_codec, + get_video_info, +) +from lerobot.common.robot_devices.robots.utils import Robot + +CODEBASE_VERSION = "v2.1" + + +class LeRobotDatasetMetadata: + def __init__( + self, + repo_id: str, + root: str | Path | None = None, + revision: str | None = None, + force_cache_sync: bool = False, + ): + self.repo_id = repo_id + self.revision = revision if revision else CODEBASE_VERSION + self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + + try: + if force_cache_sync: + raise FileNotFoundError + self.load_metadata() + except (FileNotFoundError, NotADirectoryError): + if is_valid_version(self.revision): + self.revision = get_safe_version(self.repo_id, self.revision) + + (self.root / "meta").mkdir(exist_ok=True, parents=True) + self.pull_from_repo(allow_patterns="meta/") + self.load_metadata() + + def load_metadata(self): + self.info = load_info(self.root) + check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) + self.tasks, self.task_to_task_index = load_tasks(self.root) + self.episodes = load_episodes(self.root) + if self._version < packaging.version.parse("v2.1"): + self.stats = load_stats(self.root) + self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes) + else: + self.episodes_stats = load_episodes_stats(self.root) + self.stats = aggregate_stats(list(self.episodes_stats.values())) + + def 
pull_from_repo( + self, + allow_patterns: list[str] | str | None = None, + ignore_patterns: list[str] | str | None = None, + ) -> None: + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + local_dir=self.root, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + @property + def _version(self) -> packaging.version.Version: + """Codebase version used to create this dataset.""" + return packaging.version.parse(self.info["codebase_version"]) + + def get_data_file_path(self, ep_index: int) -> Path: + ep_chunk = self.get_episode_chunk(ep_index) + fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index) + return Path(fpath) + + def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + ep_chunk = self.get_episode_chunk(ep_index) + fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) + return Path(fpath) + + def get_episode_chunk(self, ep_index: int) -> int: + return ep_index // self.chunks_size + + @property + def data_path(self) -> str: + """Formattable string for the parquet files.""" + return self.info["data_path"] + + @property + def video_path(self) -> str | None: + """Formattable string for the video files.""" + return self.info["video_path"] + + @property + def robot_type(self) -> str | None: + """Robot type used in recording this dataset.""" + return self.info["robot_type"] + + @property + def fps(self) -> int: + """Frames per second used during data collection.""" + return self.info["fps"] + + @property + def features(self) -> dict[str, dict]: + """All features contained in the dataset.""" + return self.info["features"] + + @property + def image_keys(self) -> list[str]: + """Keys to access visual modalities stored as images.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "image"] + + @property + def video_keys(self) -> list[str]: + """Keys to access visual modalities stored as videos.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "video"] + + @property + def camera_keys(self) -> list[str]: + """Keys to access visual modalities (regardless of their storage method).""" + return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] + + @property + def names(self) -> dict[str, list | dict]: + """Names of the various dimensions of vector modalities.""" + return {key: ft["names"] for key, ft in self.features.items()} + + @property + def shapes(self) -> dict: + """Shapes for the different features.""" + return {key: tuple(ft["shape"]) for key, ft in self.features.items()} + + @property + def total_episodes(self) -> int: + """Total number of episodes available.""" + return self.info["total_episodes"] + + @property + def total_frames(self) -> int: + """Total number of frames saved in this dataset.""" + return self.info["total_frames"] + + @property + def total_tasks(self) -> int: + """Total number of different tasks performed in this dataset.""" + return self.info["total_tasks"] + + @property + def total_chunks(self) -> int: + """Total number of chunks (groups of episodes).""" + return self.info["total_chunks"] + + @property + def chunks_size(self) -> int: + """Max number of episodes per chunk.""" + return self.info["chunks_size"] + + def get_task_index(self, task: str) -> int | None: + """ + Given a task in natural language, returns its task_index if the task already exists in the dataset, + otherwise return None. 
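(Editorial addition) A minimal usage sketch for this metadata class, assuming network access to the Hub; 'lerobot/pusht' is only an illustrative repo id, and the queried task string is an example. Metadata is pulled and inspected without downloading any episode data or videos:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata

meta = LeRobotDatasetMetadata("lerobot/pusht")  # fetches only the meta/ files
print(meta.fps, meta.total_episodes, meta.total_frames)
print(meta.camera_keys)                  # image/video feature keys
print(meta.features["action"]["shape"])  # per-feature shape information
print(meta.get_task_index("Push the T-shaped block onto the T-shaped target."))  # index, or None if unknown
```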
+ """ + return self.task_to_task_index.get(task, None) + + def add_task(self, task: str): + """ + Given a task in natural language, add it to the dictionary of tasks. + """ + if task in self.task_to_task_index: + raise ValueError(f"The task '{task}' already exists and can't be added twice.") + + task_index = self.info["total_tasks"] + self.task_to_task_index[task] = task_index + self.tasks[task_index] = task + self.info["total_tasks"] += 1 + + task_dict = { + "task_index": task_index, + "task": task, + } + append_jsonlines(task_dict, self.root / TASKS_PATH) + + def save_episode( + self, + episode_index: int, + episode_length: int, + episode_tasks: list[str], + episode_stats: dict[str, dict], + ) -> None: + self.info["total_episodes"] += 1 + self.info["total_frames"] += episode_length + + chunk = self.get_episode_chunk(episode_index) + if chunk >= self.total_chunks: + self.info["total_chunks"] += 1 + + self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} + self.info["total_videos"] += len(self.video_keys) + if len(self.video_keys) > 0: + self.update_video_info() + + write_info(self.info, self.root) + + episode_dict = { + "episode_index": episode_index, + "tasks": episode_tasks, + "length": episode_length, + } + self.episodes[episode_index] = episode_dict + write_episode(episode_dict, self.root) + + self.episodes_stats[episode_index] = episode_stats + self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats + write_episode_stats(episode_index, episode_stats, self.root) + + def update_video_info(self) -> None: + """ + Warning: this function writes info from first episode videos, implicitly assuming that all videos have + been encoded the same way. Also, this means it assumes the first episode exists. + """ + for key in self.video_keys: + if not self.features[key].get("info", None): + video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) + self.info["features"][key]["info"] = get_video_info(video_path) + + def __repr__(self): + feature_keys = list(self.features) + return ( + f"{self.__class__.__name__}({{\n" + f" Repository ID: '{self.repo_id}',\n" + f" Total episodes: '{self.total_episodes}',\n" + f" Total frames: '{self.total_frames}',\n" + f" Features: '{feature_keys}',\n" + "})',\n" + ) + + @classmethod + def create( + cls, + repo_id: str, + fps: int, + root: str | Path | None = None, + robot: Robot | None = None, + robot_type: str | None = None, + features: dict | None = None, + use_videos: bool = True, + ) -> "LeRobotDatasetMetadata": + """Creates metadata for a LeRobotDataset.""" + obj = cls.__new__(cls) + obj.repo_id = repo_id + obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + + obj.root.mkdir(parents=True, exist_ok=False) + + if robot is not None: + features = get_features_from_robot(robot, use_videos) + robot_type = robot.robot_type + if not all(cam.fps == fps for cam in robot.cameras.values()): + logging.warning( + f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset." + "In this case, frames from lower fps cameras will be repeated to fill in the blanks." + ) + elif features is None: + raise ValueError( + "Dataset features must either come from a Robot or explicitly passed upon creation." 
+ ) + else: + # TODO(aliberts, rcadene): implement sanity check for features + features = {**features, **DEFAULT_FEATURES} + + # check if none of the features contains a "/" in their names, + # as this would break the dict flattening in the stats computation, which uses '/' as separator + for key in features: + if "/" in key: + raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.") + + features = {**features, **DEFAULT_FEATURES} + + obj.tasks, obj.task_to_task_index = {}, {} + obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {} + obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos) + if len(obj.video_keys) > 0 and not use_videos: + raise ValueError() + write_json(obj.info, obj.root / INFO_PATH) + obj.revision = None + return obj + + +class LeRobotDataset(torch.utils.data.Dataset): + def __init__( + self, + repo_id: str, + root: str | Path | None = None, + episodes: list[int] | None = None, + image_transforms: Callable | None = None, + delta_timestamps: dict[list[float]] | None = None, + tolerance_s: float = 1e-4, + revision: str | None = None, + force_cache_sync: bool = False, + download_videos: bool = True, + video_backend: str | None = None, + ): + """ + Two modes are available for instantiating this class, depending on two different use cases: + + 1. Your dataset already exists: + - On your local disk in the 'root' folder. This is typically the case when you recorded your + dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class + with 'root' will load your dataset directly from disk. This can happen while you're offline (no + internet connection). + + - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on + your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download + the dataset from that address and load it, provided your dataset is compliant with + codebase_version v2.0. If your dataset was created before this new format, you will be + prompted to convert it using our conversion script from v1.6 to v2.0, which you can find at + lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py. + + + 2. Your dataset doesn't already exist (either on local disk or on the Hub): you can create an empty + LeRobotDataset with the 'create' classmethod. This can be used to record a dataset or to port an + existing dataset to the LeRobotDataset format. + + + In terms of files, LeRobotDataset encapsulates 3 main things: + - metadata: + - info contains various information about the dataset like shapes, keys, fps etc. + - stats stores the dataset statistics of the different modalities for normalization + - tasks contains the prompts for each task of the dataset, which can be used for + task-conditioned training. + - hf_dataset (from datasets.Dataset), which will read any values from parquet files. + - videos (optional) from which frames are loaded to be synchronous with data from parquet files. + + A typical LeRobotDataset looks like this from its root path: + . + ├── data + │ ├── chunk-000 + │ │ ├── episode_000000.parquet + │ │ ├── episode_000001.parquet + │ │ ├── episode_000002.parquet + │ │ └── ... + │ ├── chunk-001 + │ │ ├── episode_001000.parquet + │ │ ├── episode_001001.parquet + │ │ ├── episode_001002.parquet + │ │ └── ... + │ └── ... 
+ ├── meta + │ ├── episodes.jsonl + │ ├── info.json + │ ├── stats.json + │ └── tasks.jsonl + └── videos + ├── chunk-000 + │ ├── observation.images.laptop + │ │ ├── episode_000000.mp4 + │ │ ├── episode_000001.mp4 + │ │ ├── episode_000002.mp4 + │ │ └── ... + │ ├── observation.images.phone + │ │ ├── episode_000000.mp4 + │ │ ├── episode_000001.mp4 + │ │ ├── episode_000002.mp4 + │ │ └── ... + ├── chunk-001 + └── ... + + Note that this file-based structure is designed to be as versatile as possible. The files are split by + episode, which allows more granular control over which episodes one wants to use and download. The + structure of the dataset is entirely described in the info.json file, which can be easily downloaded + or viewed directly on the hub before downloading any actual data. The file types used are very + simple and do not need complex tools to be read: only .parquet, .json and .mp4 files are used (plus .md + for the README). + + Args: + repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset + will be stored under root/repo_id. + root (Path | None, optional): Local directory to use for downloading/writing files. You can also + set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to + '~/.cache/huggingface/lerobot'. + episodes (list[int] | None, optional): If specified, this will only load the episodes specified by + their episode_index in this list. Defaults to None. + image_transforms (Callable | None, optional): You can pass standard v2 image transforms from + torchvision.transforms.v2 here which will be applied to visual modalities (whether they come + from videos or images). Defaults to None. + delta_timestamps (dict[list[float]] | None, optional): Mapping from data keys to lists of relative + timestamps (in seconds, as multiples of 1/fps) at which additional frames are loaded and + stacked for that key. Timestamps falling outside the episode are padded and flagged in a + corresponding '{key}_is_pad' boolean mask. Defaults to None. + tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in + sync with the fps value. It is used at the init of the dataset to make sure that each + timestamp is separated from the next by 1/fps +/- tolerance_s. This also applies to frames + decoded from video files. It is also used to check that `delta_timestamps` (when provided) are + multiples of 1/fps. Defaults to 1e-4. + revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a + commit hash. Defaults to current codebase version tag. + force_cache_sync (bool, optional): Flag to force re-syncing the local files with the hub before + loading. When False, files already present in the local cache are used as-is, which is faster + but means they might not be in sync with the version on the hub, especially if you specified + 'revision'. Defaults to False. + download_videos (bool, optional): Flag to download the videos. Note that when set to True but the + video files are already present on local disk, they won't be downloaded again. Defaults to + True. + video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available on the platform; otherwise, defaults to 'pyav'. + You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader', another Torchvision decoder. 
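Example (editorial addition, not part of the original docstring): a minimal usage sketch that assumes the 'lerobot/pusht' dataset (recorded at 10 fps) is reachable on the Hub.

```python
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Load two episodes and, for every frame, also fetch the next 7 actions
# (offsets must be multiples of 1/fps, here 1/10 s).
dataset = LeRobotDataset(
    "lerobot/pusht",
    episodes=[0, 1],
    delta_timestamps={"action": [t / 10 for t in range(8)]},
)
item = dataset[0]
print(item["action"].shape)   # (8, action_dim); steps past the episode end are padded
print(item["action_is_pad"])  # boolean mask marking those padded steps

loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))    # tensors gain a leading batch dimension
```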
+ """ + super().__init__() + self.repo_id = repo_id + self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + self.image_transforms = image_transforms + self.delta_timestamps = delta_timestamps + self.episodes = episodes + self.tolerance_s = tolerance_s + self.revision = revision if revision else CODEBASE_VERSION + self.video_backend = video_backend if video_backend else get_safe_default_codec() + self.delta_indices = None + + # Unused attributes + self.image_writer = None + self.episode_buffer = None + + self.root.mkdir(exist_ok=True, parents=True) + + # Load metadata + self.meta = LeRobotDatasetMetadata( + self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync + ) + if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"): + episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes] + self.stats = aggregate_stats(episodes_stats) + + # Load actual data + try: + if force_cache_sync: + raise FileNotFoundError + assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths()) + self.hf_dataset = self.load_hf_dataset() + except (AssertionError, FileNotFoundError, NotADirectoryError): + self.revision = get_safe_version(self.repo_id, self.revision) + self.download_episodes(download_videos) + self.hf_dataset = self.load_hf_dataset() + + self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) + + # Check timestamps + timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy() + episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy() + ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()} + check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s) + + # Setup delta_indices + if self.delta_timestamps is not None: + check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) + self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) + + def push_to_hub( + self, + branch: str | None = None, + tags: list | None = None, + license: str | None = "apache-2.0", + tag_version: bool = True, + push_videos: bool = True, + private: bool = False, + allow_patterns: list[str] | str | None = None, + upload_large_folder: bool = False, + **card_kwargs, + ) -> None: + ignore_patterns = ["images/"] + if not push_videos: + ignore_patterns.append("videos/") + + hub_api = HfApi() + hub_api.create_repo( + repo_id=self.repo_id, + private=private, + repo_type="dataset", + exist_ok=True, + ) + if branch: + hub_api.create_branch( + repo_id=self.repo_id, + branch=branch, + revision=self.revision, + repo_type="dataset", + exist_ok=True, + ) + + upload_kwargs = { + "repo_id": self.repo_id, + "folder_path": self.root, + "repo_type": "dataset", + "revision": branch, + "allow_patterns": allow_patterns, + "ignore_patterns": ignore_patterns, + } + if upload_large_folder: + hub_api.upload_large_folder(**upload_kwargs) + else: + hub_api.upload_folder(**upload_kwargs) + + if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch): + card = create_lerobot_dataset_card( + tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs + ) + card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch) + + if tag_version: + with contextlib.suppress(RevisionNotFoundError): + hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset") + hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + + def 
pull_from_repo( + self, + allow_patterns: list[str] | str | None = None, + ignore_patterns: list[str] | str | None = None, + ) -> None: + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + local_dir=self.root, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + def download_episodes(self, download_videos: bool = True) -> None: + """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this + will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole + dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present + in 'local_dir', they won't be downloaded again. + """ + # TODO(rcadene, aliberts): implement faster transfer + # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads + files = None + ignore_patterns = None if download_videos else "videos/" + if self.episodes is not None: + files = self.get_episodes_file_paths() + + self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) + + def get_episodes_file_paths(self) -> list[Path]: + episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes)) + fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes] + if len(self.meta.video_keys) > 0: + video_files = [ + str(self.meta.get_video_file_path(ep_idx, vid_key)) + for vid_key in self.meta.video_keys + for ep_idx in episodes + ] + fpaths += video_files + + return fpaths + + def load_hf_dataset(self) -> datasets.Dataset: + """hf_dataset contains all the observations, states, actions, rewards, etc.""" + if self.episodes is None: + path = str(self.root / "data") + hf_dataset = load_dataset("parquet", data_dir=path, split="train") + else: + files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes] + hf_dataset = load_dataset("parquet", data_files=files, split="train") + + # TODO(aliberts): hf_dataset.set_format("torch") + hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + + def create_hf_dataset(self) -> datasets.Dataset: + features = get_hf_features_from_features(self.features) + ft_dict = {col: [] for col in features} + hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train") + + # TODO(aliberts): hf_dataset.set_format("torch") + hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + + @property + def fps(self) -> int: + """Frames per second used during data collection.""" + return self.meta.fps + + @property + def num_frames(self) -> int: + """Number of frames in selected episodes.""" + return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames + + @property + def num_episodes(self) -> int: + """Number of episodes selected.""" + return len(self.episodes) if self.episodes is not None else self.meta.total_episodes + + @property + def features(self) -> dict[str, dict]: + return self.meta.features + + @property + def hf_features(self) -> datasets.Features: + """Features of the hf_dataset.""" + if self.hf_dataset is not None: + return self.hf_dataset.features + else: + return get_hf_features_from_features(self.features) + + def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: + ep_start = self.episode_data_index["from"][ep_idx] + ep_end = self.episode_data_index["to"][ep_idx] + query_indices = { + key: [max(ep_start.item(), min(ep_end.item() - 
1, idx + delta)) for delta in delta_idx] + for key, delta_idx in self.delta_indices.items() + } + padding = { # Pad values outside of current episode range + f"{key}_is_pad": torch.BoolTensor( + [(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx] + ) + for key, delta_idx in self.delta_indices.items() + } + return query_indices, padding + + def _get_query_timestamps( + self, + current_ts: float, + query_indices: dict[str, list[int]] | None = None, + ) -> dict[str, list[float]]: + query_timestamps = {} + for key in self.meta.video_keys: + if query_indices is not None and key in query_indices: + timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] + query_timestamps[key] = torch.stack(timestamps).tolist() + else: + query_timestamps[key] = [current_ts] + + return query_timestamps + + def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: + return { + key: torch.stack(self.hf_dataset.select(q_idx)[key]) + for key, q_idx in query_indices.items() + if key not in self.meta.video_keys + } + + def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: + """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function + in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a + Segmentation Fault. This probably happens because a memory reference to the video loader is created in + the main process and a subprocess fails to access it. + """ + item = {} + for vid_key, query_ts in query_timestamps.items(): + video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) + frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend) + item[vid_key] = frames.squeeze(0) + + return item + + def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict: + for key, val in padding.items(): + item[key] = torch.BoolTensor(val) + return item + + def __len__(self): + return self.num_frames + + def __getitem__(self, idx) -> dict: + item = self.hf_dataset[idx] + ep_idx = item["episode_index"].item() + + query_indices = None + if self.delta_indices is not None: + query_indices, padding = self._get_query_indices(idx, ep_idx) + query_result = self._query_hf_dataset(query_indices) + item = {**item, **padding} + for key, val in query_result.items(): + item[key] = val + + if len(self.meta.video_keys) > 0: + current_ts = item["timestamp"].item() + query_timestamps = self._get_query_timestamps(current_ts, query_indices) + video_frames = self._query_videos(query_timestamps, ep_idx) + item = {**video_frames, **item} + + if self.image_transforms is not None: + image_keys = self.meta.camera_keys + for cam in image_keys: + item[cam] = self.image_transforms(item[cam]) + + # Add task as a string + task_idx = item["task_index"].item() + item["task"] = self.meta.tasks[task_idx] + + return item + + def __repr__(self): + feature_keys = list(self.features) + return ( + f"{self.__class__.__name__}({{\n" + f" Repository ID: '{self.repo_id}',\n" + f" Number of selected episodes: '{self.num_episodes}',\n" + f" Number of selected samples: '{self.num_frames}',\n" + f" Features: '{feature_keys}',\n" + "})',\n" + ) + + def create_episode_buffer(self, episode_index: int | None = None) -> dict: + current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index + ep_buffer = {} + # size and task are special cases that are not in self.features + ep_buffer["size"] = 0 + 
ep_buffer["task"] = [] + for key in self.features: + ep_buffer[key] = current_ep_idx if key == "episode_index" else [] + return ep_buffer + + def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: + fpath = DEFAULT_IMAGE_PATH.format( + image_key=image_key, episode_index=episode_index, frame_index=frame_index + ) + return self.root / fpath + + def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None: + if self.image_writer is None: + if isinstance(image, torch.Tensor): + image = image.cpu().numpy() + write_image(image, fpath) + else: + self.image_writer.save_image(image=image, fpath=fpath) + + def add_frame(self, frame: dict) -> None: + """ + This function only adds the frame to the episode_buffer. Apart from images — which are written in a + temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method + then needs to be called. + """ + # Convert torch to numpy if needed + for name in frame: + if isinstance(frame[name], torch.Tensor): + frame[name] = frame[name].numpy() + + validate_frame(frame, self.features) + + if self.episode_buffer is None: + self.episode_buffer = self.create_episode_buffer() + + # Automatically add frame_index and timestamp to episode buffer + frame_index = self.episode_buffer["size"] + timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps + self.episode_buffer["frame_index"].append(frame_index) + self.episode_buffer["timestamp"].append(timestamp) + + # Add frame features to episode_buffer + for key in frame: + if key == "task": + # Note: we associate the task in natural language to its task index during `save_episode` + self.episode_buffer["task"].append(frame["task"]) + continue + + if key not in self.features: + raise ValueError( + f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'." + ) + + if self.features[key]["dtype"] in ["image", "video"]: + img_path = self._get_image_file_path( + episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index + ) + if frame_index == 0: + img_path.parent.mkdir(parents=True, exist_ok=True) + self._save_image(frame[key], img_path) + self.episode_buffer[key].append(str(img_path)) + else: + self.episode_buffer[key].append(frame[key]) + + self.episode_buffer["size"] += 1 + + def save_episode(self, episode_data: dict | None = None) -> None: + """ + This will save to disk the current episode in self.episode_buffer. + + Args: + episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will + save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to + None. 
+ """ + if not episode_data: + episode_buffer = self.episode_buffer + + validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features) + + # size and task are special cases that won't be added to hf_dataset + episode_length = episode_buffer.pop("size") + tasks = episode_buffer.pop("task") + episode_tasks = list(set(tasks)) + episode_index = episode_buffer["episode_index"] + + episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length) + episode_buffer["episode_index"] = np.full((episode_length,), episode_index) + + # Add new tasks to the tasks dictionary + for task in episode_tasks: + task_index = self.meta.get_task_index(task) + if task_index is None: + self.meta.add_task(task) + + # Given tasks in natural language, find their corresponding task indices + episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks]) + + for key, ft in self.features.items(): + # index, episode_index, task_index are already processed above, and image and video + # are processed separately by storing image path and frame info as meta data + if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: + continue + episode_buffer[key] = np.stack(episode_buffer[key]) + + self._wait_image_writer() + self._save_episode_table(episode_buffer, episode_index) + ep_stats = compute_episode_stats(episode_buffer, self.features) + + if len(self.meta.video_keys) > 0: + video_paths = self.encode_episode_videos(episode_index) + for key in self.meta.video_keys: + episode_buffer[key] = video_paths[key] + + # `meta.save_episode` be executed after encoding the videos + self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) + + ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index]) + ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()} + check_timestamps_sync( + episode_buffer["timestamp"], + episode_buffer["episode_index"], + ep_data_index_np, + self.fps, + self.tolerance_s, + ) + + video_files = list(self.root.rglob("*.mp4")) + assert len(video_files) == self.num_episodes * len(self.meta.video_keys) + + parquet_files = list(self.root.rglob("*.parquet")) + assert len(parquet_files) == self.num_episodes + + # delete images + img_dir = self.root / "images" + if img_dir.is_dir(): + shutil.rmtree(self.root / "images") + + if not episode_data: # Reset the buffer + self.episode_buffer = self.create_episode_buffer() + + def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None: + episode_dict = {key: episode_buffer[key] for key in self.hf_features} + ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train") + ep_dataset = embed_images(ep_dataset) + self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset]) + self.hf_dataset.set_transform(hf_transform_to_torch) + ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index) + ep_data_path.parent.mkdir(parents=True, exist_ok=True) + ep_dataset.to_parquet(ep_data_path) + + def clear_episode_buffer(self) -> None: + episode_index = self.episode_buffer["episode_index"] + if self.image_writer is not None: + for cam_key in self.meta.camera_keys: + img_dir = self._get_image_file_path( + episode_index=episode_index, image_key=cam_key, frame_index=0 + ).parent + if img_dir.is_dir(): + shutil.rmtree(img_dir) + + # Reset the buffer + self.episode_buffer = self.create_episode_buffer() + + def start_image_writer(self, num_processes: int = 
0, num_threads: int = 4) -> None: + if isinstance(self.image_writer, AsyncImageWriter): + logging.warning( + "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset." + ) + + self.image_writer = AsyncImageWriter( + num_processes=num_processes, + num_threads=num_threads, + ) + + def stop_image_writer(self) -> None: + """ + Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to + remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized. + """ + if self.image_writer is not None: + self.image_writer.stop() + self.image_writer = None + + def _wait_image_writer(self) -> None: + """Wait for asynchronous image writer to finish.""" + if self.image_writer is not None: + self.image_writer.wait_until_done() + + def encode_videos(self) -> None: + """ + Use ffmpeg to convert frames stored as png into mp4 videos. + Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, + since video encoding with ffmpeg is already using multithreading. + """ + for ep_idx in range(self.meta.total_episodes): + self.encode_episode_videos(ep_idx) + + def encode_episode_videos(self, episode_index: int) -> dict: + """ + Use ffmpeg to convert frames stored as png into mp4 videos. + Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, + since video encoding with ffmpeg is already using multithreading. + """ + video_paths = {} + for key in self.meta.video_keys: + video_path = self.root / self.meta.get_video_file_path(episode_index, key) + video_paths[key] = str(video_path) + if video_path.is_file(): + # Skip if video is already encoded. Could be the case when resuming data recording. + continue + img_dir = self._get_image_file_path( + episode_index=episode_index, image_key=key, frame_index=0 + ).parent + encode_video_frames(img_dir, video_path, self.fps, overwrite=True) + + return video_paths + + @classmethod + def create( + cls, + repo_id: str, + fps: int, + root: str | Path | None = None, + robot: Robot | None = None, + robot_type: str | None = None, + features: dict | None = None, + use_videos: bool = True, + tolerance_s: float = 1e-4, + image_writer_processes: int = 0, + image_writer_threads: int = 0, + video_backend: str | None = None, + ) -> "LeRobotDataset": + """Create a LeRobot Dataset from scratch in order to record data.""" + obj = cls.__new__(cls) + obj.meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=fps, + root=root, + robot=robot, + robot_type=robot_type, + features=features, + use_videos=use_videos, + ) + obj.repo_id = obj.meta.repo_id + obj.root = obj.meta.root + obj.revision = None + obj.tolerance_s = tolerance_s + obj.image_writer = None + + if image_writer_processes or image_writer_threads: + obj.start_image_writer(image_writer_processes, image_writer_threads) + + # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer + obj.episode_buffer = obj.create_episode_buffer() + + obj.episodes = None + obj.hf_dataset = obj.create_hf_dataset() + obj.image_transforms = None + obj.delta_timestamps = None + obj.delta_indices = None + obj.episode_data_index = None + obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() + return obj + + +class MultiLeRobotDataset(torch.utils.data.Dataset): + """A dataset consisting of multiple underlying `LeRobotDataset`s. 
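Before moving on to `MultiLeRobotDataset`, here is a hedged usage sketch of the recording workflow defined above (`create`, `add_frame`, `save_episode`). The repo_id, feature spec and values are illustrative placeholders, and it assumes the features-only creation path (no `Robot` object); camera features, together with the image writer, would be added to `features` in the same way.

```python
import numpy as np

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

features = {
    "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
    "action": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
}
dataset = LeRobotDataset.create("my_user/toy_dataset", fps=30, features=features, use_videos=False)

for _ in range(2):           # two toy episodes
    for _ in range(10):      # timestamps are derived from `fps` when not provided in the frame
        dataset.add_frame({
            "observation.state": np.zeros(2, dtype=np.float32),
            "action": np.zeros(2, dtype=np.float32),
            "task": "Reach the target.",
        })
    dataset.save_episode()   # writes the episode parquet file, stats and metadata, then resets the buffer
```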
+ + The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API + structure of `LeRobotDataset`. + """ + + def __init__( + self, + repo_ids: list[str], + root: str | Path | None = None, + episodes: dict | None = None, + image_transforms: Callable | None = None, + delta_timestamps: dict[str, list[float]] | None = None, + tolerances_s: dict | None = None, + download_videos: bool = True, + video_backend: str | None = None, + ): + super().__init__() + self.repo_ids = repo_ids + self.root = Path(root) if root else HF_LEROBOT_HOME + self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001) + # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which + # are handled by this class. + self._datasets = [ + LeRobotDataset( + repo_id, + root=self.root / repo_id, + episodes=episodes[repo_id] if episodes else None, + image_transforms=image_transforms, + delta_timestamps=delta_timestamps, + tolerance_s=self.tolerances_s[repo_id], + download_videos=download_videos, + video_backend=video_backend, + ) + for repo_id in repo_ids + ] + + # Disable any data keys that are not common across all of the datasets. Note: we may relax this + # restriction in future iterations of this class. For now, this is necessary at least for being able + # to use PyTorch's default DataLoader collate function. + self.disabled_features = set() + intersection_features = set(self._datasets[0].features) + for ds in self._datasets: + intersection_features.intersection_update(ds.features) + if len(intersection_features) == 0: + raise RuntimeError( + "Multiple datasets were provided but they had no keys common to all of them. " + "The multi-dataset functionality currently only keeps common keys." + ) + for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): + extra_keys = set(ds.features).difference(intersection_features) + logging.warning( + f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " + "other datasets." + ) + self.disabled_features.update(extra_keys) + + self.image_transforms = image_transforms + self.delta_timestamps = delta_timestamps + # TODO(rcadene, aliberts): We should not perform this aggregation for datasets + # with multiple robots of different ranges. Instead we should have one normalization + # per robot. + self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets]) + + @property + def repo_id_to_index(self): + """Return a mapping from dataset repo_id to a dataset index automatically created by this class. + + This index is incorporated as a data key in the dictionary returned by `__getitem__`. + """ + return {repo_id: i for i, repo_id in enumerate(self.repo_ids)} + + @property + def repo_index_to_id(self): + """Return the inverse mapping of repo_id_to_index.""" + return {v: k for k, v in self.repo_id_to_index.items()} + + @property + def fps(self) -> int: + """Frames per second used during data collection. + + NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info. + """ + return self._datasets[0].meta.info["fps"] + + @property + def video(self) -> bool: + """Returns True if this dataset loads video frames from mp4 files. + + Returns False if it only loads images from png files. + + NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
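A short usage sketch for the `MultiLeRobotDataset` defined here, assuming two datasets that share the same feature keys (the repo_ids are illustrative). Keys that are not common to every sub-dataset are dropped, and `__getitem__` below adds a `dataset_index` entry identifying which sub-dataset each sample came from.

```python
import torch

from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset

dataset = MultiLeRobotDataset(
    ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"]
)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
batch = next(iter(loader))
print(batch["dataset_index"])  # tensor telling which sub-dataset each of the 8 samples was drawn from
```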
+ """ + return self._datasets[0].meta.info.get("video", False) + + @property + def features(self) -> datasets.Features: + features = {} + for dataset in self._datasets: + features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features}) + return features + + @property + def camera_keys(self) -> list[str]: + """Keys to access image and video stream from cameras.""" + keys = [] + for key, feats in self.features.items(): + if isinstance(feats, (datasets.Image, VideoFrame)): + keys.append(key) + return keys + + @property + def video_frame_keys(self) -> list[str]: + """Keys to access video frames that requires to be decoded into images. + + Note: It is empty if the dataset contains images only, + or equal to `self.cameras` if the dataset contains videos only, + or can even be a subset of `self.cameras` in a case of a mixed image/video dataset. + """ + video_frame_keys = [] + for key, feats in self.features.items(): + if isinstance(feats, VideoFrame): + video_frame_keys.append(key) + return video_frame_keys + + @property + def num_frames(self) -> int: + """Number of samples/frames.""" + return sum(d.num_frames for d in self._datasets) + + @property + def num_episodes(self) -> int: + """Number of episodes.""" + return sum(d.num_episodes for d in self._datasets) + + @property + def tolerance_s(self) -> float: + """Tolerance in seconds used to discard loaded frames when their timestamps + are not close enough from the requested frames. It is only used when `delta_timestamps` + is provided or when loading video frames from mp4 files. + """ + # 1e-4 to account for possible numerical error + return 1 / self.fps - 1e-4 + + def __len__(self): + return self.num_frames + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + if idx >= len(self): + raise IndexError(f"Index {idx} out of bounds.") + # Determine which dataset to get an item from based on the index. + start_idx = 0 + dataset_idx = 0 + for dataset in self._datasets: + if idx >= start_idx + dataset.num_frames: + start_idx += dataset.num_frames + dataset_idx += 1 + continue + break + else: + raise AssertionError("We expect the loop to break out as long as the index is within bounds.") + item = self._datasets[dataset_idx][idx - start_idx] + item["dataset_index"] = torch.tensor(dataset_idx) + for data_key in self.disabled_features: + if data_key in item: + del item[data_key] + + return item + + def __repr__(self): + return ( + f"{self.__class__.__name__}(\n" + f" Repository IDs: '{self.repo_ids}',\n" + f" Number of Samples: {self.num_frames},\n" + f" Number of Episodes: {self.num_episodes},\n" + f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" + f" Recorded Frames per Second: {self.fps},\n" + f" Camera Keys: {self.camera_keys},\n" + f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" + f" Transformations: {self.image_transforms},\n" + f")" + ) diff --git a/lerobot/common/datasets/online_buffer.py b/lerobot/common/datasets/online_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..d907e46874f702b9d94313a0c7c80bd8fb661f72 --- /dev/null +++ b/lerobot/common/datasets/online_buffer.py @@ -0,0 +1,384 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""An online buffer for the online training loop in train.py + +Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should +consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much +faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it +supports in-place slicing and mutation which is very handy for a dynamic buffer. +""" + +import os +from pathlib import Path +from typing import Any + +import numpy as np +import torch + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + +def _make_memmap_safe(**kwargs) -> np.memmap: + """Make a numpy memmap with checks on available disk space first. + + Expected kwargs are: "filename", "dtype" (must by np.dtype), "mode" and "shape" + + For information on dtypes: + https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing + """ + if kwargs["mode"].startswith("w"): + required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes + stats = os.statvfs(Path(kwargs["filename"]).parent) + available_space = stats.f_bavail * stats.f_frsize # bytes + if required_space >= available_space * 0.8: + raise RuntimeError( + f"You're about to take up {required_space} of {available_space} bytes available." + ) + return np.memmap(**kwargs) + + +class OnlineBuffer(torch.utils.data.Dataset): + """FIFO data buffer for the online training loop in train.py. + + Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training + loop in the same way that a LeRobotDataset would be used. + + The underlying data structure will have data inserted in a circular fashion. Always insert after the + last index, and when you reach the end, wrap around to the start. + + The data is stored in a numpy memmap. + """ + + NEXT_INDEX_KEY = "_next_index" + OCCUPANCY_MASK_KEY = "_occupancy_mask" + INDEX_KEY = "index" + FRAME_INDEX_KEY = "frame_index" + EPISODE_INDEX_KEY = "episode_index" + TIMESTAMP_KEY = "timestamp" + IS_PAD_POSTFIX = "_is_pad" + + def __init__( + self, + write_dir: str | Path, + data_spec: dict[str, Any] | None, + buffer_capacity: int | None, + fps: float | None = None, + delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None, + ): + """ + The online buffer can be provided from scratch or you can load an existing online buffer by passing + a `write_dir` associated with an existing buffer. + + Args: + write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key. + Note that if the files already exist, they are opened in read-write mode (used for training + resumption.) + data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int], + "dtype": np.dtype}}. This should include all the data that you wish to record into the buffer, + but note that "index", "frame_index" and "episode_index" are already accounted for by this + class, so you don't need to include them. 
+ buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your + system's available disk space when choosing this. + fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the + delta_timestamps logic. You can pass None if you are not using delta_timestamps. + delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally + converted to dict[str, np.ndarray] for optimization purposes. + + """ + self.set_delta_timestamps(delta_timestamps) + self._fps = fps + # Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from + # the requested frames. It is only used when `delta_timestamps` is provided. + # minus 1e-4 to account for possible numerical error + self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None + self._buffer_capacity = buffer_capacity + data_spec = self._make_data_spec(data_spec, buffer_capacity) + Path(write_dir).mkdir(parents=True, exist_ok=True) + self._data = {} + for k, v in data_spec.items(): + self._data[k] = _make_memmap_safe( + filename=Path(write_dir) / k, + dtype=v["dtype"] if v is not None else None, + mode="r+" if (Path(write_dir) / k).exists() else "w+", + shape=tuple(v["shape"]) if v is not None else None, + ) + + @property + def delta_timestamps(self) -> dict[str, np.ndarray] | None: + return self._delta_timestamps + + def set_delta_timestamps(self, value: dict[str, list[float]] | None): + """Set delta_timestamps converting the values to numpy arrays. + + The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays + need to be converted into numpy arrays. + """ + if value is not None: + self._delta_timestamps = {k: np.array(v) for k, v in value.items()} + else: + self._delta_timestamps = None + + def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]: + """Makes the data spec for np.memmap.""" + if any(k.startswith("_") for k in data_spec): + raise ValueError( + "data_spec keys should not start with '_'. This prefix is reserved for internal logic." + ) + preset_keys = { + OnlineBuffer.INDEX_KEY, + OnlineBuffer.FRAME_INDEX_KEY, + OnlineBuffer.EPISODE_INDEX_KEY, + OnlineBuffer.TIMESTAMP_KEY, + } + if len(intersection := set(data_spec).intersection(preset_keys)) > 0: + raise ValueError( + f"data_spec should not contain any of {preset_keys} as these are handled internally. " + f"The provided data_spec has {intersection}." + ) + complete_data_spec = { + # _next_index will be a pointer to the next index that we should start filling from when we add + # more data. + OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()}, + # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied + # with real data rather than the dummy initialization. 
+ OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)}, + OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, + OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, + OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, + OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)}, + } + for k, v in data_spec.items(): + complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])} + return complete_data_spec + + def add_data(self, data: dict[str, np.ndarray]): + """Add new data to the buffer, which could potentially mean shifting old data out. + + The new data should contain all the frames (in order) of any number of episodes. The indices should + start from 0 (note to the developer: this can easily be generalized). See the `rollout` and + `eval_policy` functions in `eval.py` for more information on how the data is constructed. + + Shift the incoming data index and episode_index to continue on from the last frame. Note that this + will be done in place! + """ + if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0: + raise ValueError(f"Missing data keys: {missing_keys}") + new_data_length = len(data[self.data_keys[0]]) + if not all(len(data[k]) == new_data_length for k in self.data_keys): + raise ValueError("All data items should have the same length") + + next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY] + + # Sanity check to make sure that the new data indices start from 0. + assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0 + assert data[OnlineBuffer.INDEX_KEY][0].item() == 0 + + # Shift the incoming indices if necessary. + if self.num_frames > 0: + last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1] + last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1] + data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1 + data[OnlineBuffer.INDEX_KEY] += last_data_index + 1 + + # Insert the new data starting from next_index. It may be necessary to wrap around to the start. 
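+ # For example (illustrative numbers): with buffer_capacity=20, next_index=15 and new_data_length=8, + # n_surplus = 8 - (20 - 15) = 3, so the first 5 new frames fill slots 15..19, the remaining 3 wrap + # around into slots 0..2, and the next-index pointer becomes 3.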
+ n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index)) + for k in self.data_keys: + if n_surplus == 0: + slc = slice(next_index, next_index + new_data_length) + self._data[k][slc] = data[k] + self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True + else: + self._data[k][next_index:] = data[k][:-n_surplus] + self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True + self._data[k][:n_surplus] = data[k][-n_surplus:] + if n_surplus == 0: + self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length + else: + self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus + + @property + def data_keys(self) -> list[str]: + keys = set(self._data) + keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY) + keys.remove(OnlineBuffer.NEXT_INDEX_KEY) + return sorted(keys) + + @property + def fps(self) -> float | None: + return self._fps + + @property + def num_episodes(self) -> int: + return len( + np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]) + ) + + @property + def num_frames(self) -> int: + return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]) + + def __len__(self): + return self.num_frames + + def _item_to_tensors(self, item: dict) -> dict: + item_ = {} + for k, v in item.items(): + if isinstance(v, torch.Tensor): + item_[k] = v + elif isinstance(v, np.ndarray): + item_[k] = torch.from_numpy(v) + else: + item_[k] = torch.tensor(v) + return item_ + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + if idx >= len(self) or idx < -len(self): + raise IndexError + + item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")} + + if self.delta_timestamps is None: + return self._item_to_tensors(item) + + episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY] + current_ts = item[OnlineBuffer.TIMESTAMP_KEY] + episode_data_indices = np.where( + np.bitwise_and( + self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index, + self._data[OnlineBuffer.OCCUPANCY_MASK_KEY], + ) + )[0] + episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices] + + for data_key in self.delta_timestamps: + # Note: The logic in this loop is copied from `load_previous_and_future_frames`. + # Get timestamps used as query to retrieve data of previous/future frames. + query_ts = current_ts + self.delta_timestamps[data_key] + + # Compute distances between each query timestamp and all timestamps of all the frames belonging to + # the episode. + dist = np.abs(query_ts[:, None] - episode_timestamps[None, :]) + argmin_ = np.argmin(dist, axis=1) + min_ = dist[np.arange(dist.shape[0]), argmin_] + + is_pad = min_ > self.tolerance_s + + # Check violated query timestamps are all outside the episode range. + assert ( + (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad]) + ).all(), ( + f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}" + ") inside the episode range." + ) + + # Load frames for this data key. 
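+ # (argmin_ picks, for each query timestamp, the closest frame available in the episode; queries flagged + # in is_pad still receive that nearest frame, and consumers are expected to ignore them through the + # corresponding *_is_pad mask.)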
+ item[data_key] = self._data[data_key][episode_data_indices[argmin_]] + + item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad + + return self._item_to_tensors(item) + + def get_data_by_key(self, key: str) -> torch.Tensor: + """Returns all data for a given data key as a Tensor.""" + return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]) + + +def compute_sampler_weights( + offline_dataset: LeRobotDataset, + offline_drop_n_last_frames: int = 0, + online_dataset: OnlineBuffer | None = None, + online_sampling_ratio: float | None = None, + online_drop_n_last_frames: int = 0, +) -> torch.Tensor: + """Compute the sampling weights for the online training dataloader in train.py. + + Args: + offline_dataset: The LeRobotDataset used for offline pre-training. + offline_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode. + online_dataset: The OnlineBuffer used in online training. + online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an + online dataset is provided, this value must also be provided. + online_drop_n_last_frames: See `offline_drop_n_last_frames`. This is the same, but for the online + dataset. + Returns: + Tensor of weights for [offline_dataset; online_dataset], normalized to 1. + + Notes to maintainers: + - This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach. + - When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace + `EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature + is the ability to turn shuffling off. + - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not + included here to avoid adding complexity. + """ + if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0): + raise ValueError("At least one of `offline_dataset` or `online_dataset` should contain data.") + if (online_dataset is None) ^ (online_sampling_ratio is None): + raise ValueError( + "`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
+ ) + offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio + + weights = [] + + if len(offline_dataset) > 0: + offline_data_mask_indices = [] + for start_index, end_index in zip( + offline_dataset.episode_data_index["from"], + offline_dataset.episode_data_index["to"], + strict=True, + ): + offline_data_mask_indices.extend( + range(start_index.item(), end_index.item() - offline_drop_n_last_frames) + ) + offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool) + offline_data_mask[torch.tensor(offline_data_mask_indices)] = True + weights.append( + torch.full( + size=(len(offline_dataset),), + fill_value=offline_sampling_ratio / offline_data_mask.sum(), + ) + * offline_data_mask + ) + + if online_dataset is not None and len(online_dataset) > 0: + online_data_mask_indices = [] + episode_indices = online_dataset.get_data_by_key("episode_index") + for episode_idx in torch.unique(episode_indices): + where_episode = torch.where(episode_indices == episode_idx) + start_index = where_episode[0][0] + end_index = where_episode[0][-1] + 1 + online_data_mask_indices.extend( + range(start_index.item(), end_index.item() - online_drop_n_last_frames) + ) + online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool) + online_data_mask[torch.tensor(online_data_mask_indices)] = True + weights.append( + torch.full( + size=(len(online_dataset),), + fill_value=online_sampling_ratio / online_data_mask.sum(), + ) + * online_data_mask + ) + + weights = torch.cat(weights) + + if weights.sum() == 0: + weights += 1 / len(weights) + else: + weights /= weights.sum() + + return weights diff --git a/lerobot/common/datasets/push_dataset_to_hub/utils.py b/lerobot/common/datasets/push_dataset_to_hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ebcf87f77f1e2892d4c536b574ba2a15acaf82d3 --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/utils.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
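Before the next file, a hedged end-to-end sketch tying together `OnlineBuffer` and `compute_sampler_weights` from `online_buffer.py` above. The `data_spec`, capacity, toy episode and repo_id are illustrative; in the actual training loop the online data comes from policy rollouts.

```python
import tempfile

import numpy as np
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights

data_spec = {
    "observation.state": {"shape": (2,), "dtype": np.dtype("float32")},
    "action": {"shape": (2,), "dtype": np.dtype("float32")},
}
online_dataset = OnlineBuffer(tempfile.mkdtemp(), data_spec=data_spec, buffer_capacity=100, fps=10)

n = 5  # one toy 5-frame episode; indices must start at 0, they are shifted internally
online_dataset.add_data({
    "index": np.arange(n),
    "frame_index": np.arange(n),
    "episode_index": np.zeros(n, dtype=np.int64),
    "timestamp": np.arange(n) / 10,
    "observation.state": np.zeros((n, 2), dtype=np.float32),
    "action": np.zeros((n, 2), dtype=np.float32),
})

offline_dataset = LeRobotDataset("lerobot/pusht")
weights = compute_sampler_weights(offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.5)
sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=64, replacement=True)
```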
+import inspect +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Dict + +import datasets +import numpy +import PIL +import torch + +from lerobot.common.datasets.video_utils import encode_video_frames + + +def concatenate_episodes(ep_dicts): + data_dict = {} + + keys = ep_dicts[0].keys() + for key in keys: + if torch.is_tensor(ep_dicts[0][key][0]): + data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) + else: + if key not in data_dict: + data_dict[key] = [] + for ep_dict in ep_dicts: + for x in ep_dict[key]: + data_dict[key].append(x) + + total_frames = data_dict["frame_index"].shape[0] + data_dict["index"] = torch.arange(0, total_frames, 1) + return data_dict + + +def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4): + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + def save_image(img_array, i, out_dir): + img = PIL.Image.fromarray(img_array) + img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100) + + num_images = len(imgs_array) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)] + + +def get_default_encoding() -> dict: + """Returns the default ffmpeg encoding parameters used by `encode_video_frames`.""" + signature = inspect.signature(encode_video_frames) + return { + k: v.default + for k, v in signature.parameters.items() + if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"] + } + + +def check_repo_id(repo_id: str) -> None: + if len(repo_id.split("/")) != 2: + raise ValueError( + f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset + (e.g. 'lerobot/pusht'), but contains '{repo_id}'.""" + ) + + +# TODO(aliberts): remove +def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]: + """ + Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset. + + Parameters: + - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index. + + Returns: + - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys: + - "from": A tensor containing the starting index of each episode. + - "to": A tensor containing the ending index of each episode. + """ + episode_data_index = {"from": [], "to": []} + + current_episode = None + """ + The episode_index is a list of integers, each representing the episode index of the corresponding example. + For instance, the following is a valid episode_index: + [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2] + + Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and + ending index of each episode. 
For the episode_index above, the episode_data_index dictionary will look like this: + { + "from": [0, 3, 7], + "to": [3, 7, 12] + } + """ + if len(hf_dataset) == 0: + episode_data_index = { + "from": torch.tensor([]), + "to": torch.tensor([]), + } + return episode_data_index + for idx, episode_idx in enumerate(hf_dataset["episode_index"]): + if episode_idx != current_episode: + # We encountered a new episode, so we append its starting location to the "from" list + episode_data_index["from"].append(idx) + # If this is not the first episode, we append the ending location of the previous episode to the "to" list + if current_episode is not None: + episode_data_index["to"].append(idx) + # Let's keep track of the current episode index + current_episode = episode_idx + else: + # We are still in the same episode, so there is nothing for us to do here + pass + # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list + episode_data_index["to"].append(idx + 1) + + for k in ["from", "to"]: + episode_data_index[k] = torch.tensor(episode_data_index[k]) + + return episode_data_index diff --git a/lerobot/common/datasets/sampler.py b/lerobot/common/datasets/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6c15c1500055d197fe081401840e5ba7847479 --- /dev/null +++ b/lerobot/common/datasets/sampler.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Iterator, Union + +import torch + + +class EpisodeAwareSampler: + def __init__( + self, + episode_data_index: dict, + episode_indices_to_use: Union[list, None] = None, + drop_n_first_frames: int = 0, + drop_n_last_frames: int = 0, + shuffle: bool = False, + ): + """Sampler that optionally incorporates episode boundary information. + + Args: + episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode. + episode_indices_to_use: List of episode indices to use. If None, all episodes are used. + Assumes that episodes are indexed from 0 to N-1. + drop_n_first_frames: Number of frames to drop from the start of each episode. + drop_n_last_frames: Number of frames to drop from the end of each episode. + shuffle: Whether to shuffle the indices. 
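A usage sketch for this sampler; the repo_id and `drop_n_last_frames=7` are illustrative (for instance, a policy trained on 8-step action chunks should not sample start frames whose chunk would cross the episode boundary).

```python
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler

dataset = LeRobotDataset("lerobot/pusht")
sampler = EpisodeAwareSampler(dataset.episode_data_index, drop_n_last_frames=7, shuffle=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, sampler=sampler)
batch = next(iter(loader))
```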
+ """ + indices = [] + for episode_idx, (start_index, end_index) in enumerate( + zip(episode_data_index["from"], episode_data_index["to"], strict=True) + ): + if episode_indices_to_use is None or episode_idx in episode_indices_to_use: + indices.extend( + range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames) + ) + + self.indices = indices + self.shuffle = shuffle + + def __iter__(self) -> Iterator[int]: + if self.shuffle: + for i in torch.randperm(len(self.indices)): + yield self.indices[i] + else: + for i in self.indices: + yield i + + def __len__(self) -> int: + return len(self.indices) diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..720c939b8f15829b1aeaf4631ed1565f636c4782 --- /dev/null +++ b/lerobot/common/datasets/transforms.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +from dataclasses import dataclass, field +from typing import Any, Callable, Sequence + +import torch +from torchvision.transforms import v2 +from torchvision.transforms.v2 import Transform +from torchvision.transforms.v2 import functional as F # noqa: N812 + + +class RandomSubsetApply(Transform): + """Apply a random subset of N transformations from a list of transformations. + + Args: + transforms: list of transformations. + p: represents the multinomial probabilities (with no replacement) used for sampling the transform. + If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms + have the same probability. + n_subset: number of transformations to apply. If ``None``, all transforms are applied. + Must be in [1, len(transforms)]. + random_order: apply transformations in a random order. 
+ """ + + def __init__( + self, + transforms: Sequence[Callable], + p: list[float] | None = None, + n_subset: int | None = None, + random_order: bool = False, + ) -> None: + super().__init__() + if not isinstance(transforms, Sequence): + raise TypeError("Argument transforms should be a sequence of callables") + if p is None: + p = [1] * len(transforms) + elif len(p) != len(transforms): + raise ValueError( + f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}" + ) + + if n_subset is None: + n_subset = len(transforms) + elif not isinstance(n_subset, int): + raise TypeError("n_subset should be an int or None") + elif not (1 <= n_subset <= len(transforms)): + raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]") + + self.transforms = transforms + total = sum(p) + self.p = [prob / total for prob in p] + self.n_subset = n_subset + self.random_order = random_order + + self.selected_transforms = None + + def forward(self, *inputs: Any) -> Any: + needs_unpacking = len(inputs) > 1 + + selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset) + if not self.random_order: + selected_indices = selected_indices.sort().values + + self.selected_transforms = [self.transforms[i] for i in selected_indices] + + for transform in self.selected_transforms: + outputs = transform(*inputs) + inputs = outputs if needs_unpacking else (outputs,) + + return outputs + + def extra_repr(self) -> str: + return ( + f"transforms={self.transforms}, " + f"p={self.p}, " + f"n_subset={self.n_subset}, " + f"random_order={self.random_order}" + ) + + +class SharpnessJitter(Transform): + """Randomly change the sharpness of an image or video. + + Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly. + While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image, + SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of + augmentations as a result. + + A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness + by a factor of 2. + + If the input is a :class:`torch.Tensor`, + it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from + [max(0, 1 - sharpness), 1 + sharpness] or the given + [min, max]. Should be non negative numbers. 
+ """ + + def __init__(self, sharpness: float | Sequence[float]) -> None: + super().__init__() + self.sharpness = self._check_input(sharpness) + + def _check_input(self, sharpness): + if isinstance(sharpness, (int, float)): + if sharpness < 0: + raise ValueError("If sharpness is a single number, it must be non negative.") + sharpness = [1.0 - sharpness, 1.0 + sharpness] + sharpness[0] = max(sharpness[0], 0.0) + elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2: + sharpness = [float(v) for v in sharpness] + else: + raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.") + + if not 0.0 <= sharpness[0] <= sharpness[1]: + raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.") + + return float(sharpness[0]), float(sharpness[1]) + + def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: + sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item() + return {"sharpness_factor": sharpness_factor} + + def transform(self, inpt: Any, params: dict[str, Any]) -> Any: + sharpness_factor = params["sharpness_factor"] + return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) + + +@dataclass +class ImageTransformConfig: + """ + For each transform, the following parameters are available: + weight: This represents the multinomial probability (with no replacement) + used for sampling the transform. If the sum of the weights is not 1, + they will be normalized. + type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a + custom transform defined here. + kwargs: Lower & upper bound respectively used for sampling the transform's parameter + (following uniform distribution) when it's applied. + """ + + weight: float = 1.0 + type: str = "Identity" + kwargs: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ImageTransformsConfig: + """ + These transforms are all using standard torchvision.transforms.v2 + You can find out how these transformations affect images here: + https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html + We use a custom RandomSubsetApply container to sample them. + """ + + # Set this flag to `true` to enable transforms during training + enable: bool = False + # This is the maximum number of transforms (sampled from these below) that will be applied to each frame. + # It's an integer in the interval [1, number_of_available_transforms]. + max_num_transforms: int = 3 + # By default, transforms are applied in Torchvision's suggested order (shown below). + # Set this to True to apply them in a random order. 
+ random_order: bool = False + tfs: dict[str, ImageTransformConfig] = field( + default_factory=lambda: { + "brightness": ImageTransformConfig( + weight=1.0, + type="ColorJitter", + kwargs={"brightness": (0.8, 1.2)}, + ), + "contrast": ImageTransformConfig( + weight=1.0, + type="ColorJitter", + kwargs={"contrast": (0.8, 1.2)}, + ), + "saturation": ImageTransformConfig( + weight=1.0, + type="ColorJitter", + kwargs={"saturation": (0.5, 1.5)}, + ), + "hue": ImageTransformConfig( + weight=1.0, + type="ColorJitter", + kwargs={"hue": (-0.05, 0.05)}, + ), + "sharpness": ImageTransformConfig( + weight=1.0, + type="SharpnessJitter", + kwargs={"sharpness": (0.5, 1.5)}, + ), + } + ) + + +def make_transform_from_config(cfg: ImageTransformConfig): + if cfg.type == "Identity": + return v2.Identity(**cfg.kwargs) + elif cfg.type == "ColorJitter": + return v2.ColorJitter(**cfg.kwargs) + elif cfg.type == "SharpnessJitter": + return SharpnessJitter(**cfg.kwargs) + else: + raise ValueError(f"Transform '{cfg.type}' is not valid.") + + +class ImageTransforms(Transform): + """A class to compose image transforms based on configuration.""" + + def __init__(self, cfg: ImageTransformsConfig) -> None: + super().__init__() + self._cfg = cfg + + self.weights = [] + self.transforms = {} + for tf_name, tf_cfg in cfg.tfs.items(): + if tf_cfg.weight <= 0.0: + continue + + self.transforms[tf_name] = make_transform_from_config(tf_cfg) + self.weights.append(tf_cfg.weight) + + n_subset = min(len(self.transforms), cfg.max_num_transforms) + if n_subset == 0 or not cfg.enable: + self.tf = v2.Identity() + else: + self.tf = RandomSubsetApply( + transforms=list(self.transforms.values()), + p=self.weights, + n_subset=n_subset, + random_order=cfg.random_order, + ) + + def forward(self, *inputs: Any) -> Any: + return self.tf(*inputs) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9d8a54db14d60e82ddfd1726d96b515734ba37a0 --- /dev/null +++ b/lerobot/common/datasets/utils.py @@ -0,0 +1,813 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
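Before the dataset utilities below, a short sketch of how the transform classes from `transforms.py` above compose. The values are illustrative: enabling the config caps the number of sampled transforms at `max_num_transforms`, and `ImageTransforms` wires the weighted defaults into a `RandomSubsetApply`.

```python
import torch

from lerobot.common.datasets.transforms import ImageTransforms, ImageTransformsConfig

cfg = ImageTransformsConfig(enable=True, max_num_transforms=2)  # sample at most 2 of the 5 default transforms
tf = ImageTransforms(cfg)

frame = torch.rand(3, 224, 224)  # dummy channel-first float image with values in [0, 1]
augmented = tf(frame)
print(augmented.shape)           # torch.Size([3, 224, 224])
```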
+import contextlib +import importlib.resources +import json +import logging +from collections.abc import Iterator +from itertools import accumulate +from pathlib import Path +from pprint import pformat +from types import SimpleNamespace +from typing import Any + +import datasets +import jsonlines +import numpy as np +import packaging.version +import torch +from datasets.table import embed_table_storage +from huggingface_hub import DatasetCard, DatasetCardData, HfApi +from huggingface_hub.errors import RevisionNotFoundError +from PIL import Image as PILImage +from torchvision import transforms + +from lerobot.common.datasets.backward_compatibility import ( + V21_MESSAGE, + BackwardCompatibilityError, + ForwardCompatibilityError, +) +from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.utils.utils import is_valid_numpy_dtype_string +from lerobot.configs.types import DictLike, FeatureType, PolicyFeature + +DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk + +INFO_PATH = "meta/info.json" +EPISODES_PATH = "meta/episodes.jsonl" +STATS_PATH = "meta/stats.json" +EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" +TASKS_PATH = "meta/tasks.jsonl" + +DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" +DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet" +DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" + +DATASET_CARD_TEMPLATE = """ +--- +# Metadata will go there +--- +This dataset was created using [LeRobot](https://github.com/huggingface/lerobot). + +## {} + +""" + +DEFAULT_FEATURES = { + "timestamp": {"dtype": "float32", "shape": (1,), "names": None}, + "frame_index": {"dtype": "int64", "shape": (1,), "names": None}, + "episode_index": {"dtype": "int64", "shape": (1,), "names": None}, + "index": {"dtype": "int64", "shape": (1,), "names": None}, + "task_index": {"dtype": "int64", "shape": (1,), "names": None}, +} + + +def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: + """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator. 
+ + For example: + ``` + >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}` + >>> print(flatten_dict(dct)) + {"a/b": 1, "a/c/d": 2, "e": 3} + """ + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def unflatten_dict(d: dict, sep: str = "/") -> dict: + outdict = {} + for key, value in d.items(): + parts = key.split(sep) + d = outdict + for part in parts[:-1]: + if part not in d: + d[part] = {} + d = d[part] + d[parts[-1]] = value + return outdict + + +def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any: + split_keys = flattened_key.split(sep) + getter = obj[split_keys[0]] + if len(split_keys) == 1: + return getter + + for key in split_keys[1:]: + getter = getter[key] + + return getter + + +def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: + serialized_dict = {} + for key, value in flatten_dict(stats).items(): + if isinstance(value, (torch.Tensor, np.ndarray)): + serialized_dict[key] = value.tolist() + elif isinstance(value, np.generic): + serialized_dict[key] = value.item() + elif isinstance(value, (int, float)): + serialized_dict[key] = value + else: + raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.") + return unflatten_dict(serialized_dict) + + +def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: + # Embed image bytes into the table before saving to parquet + format = dataset.format + dataset = dataset.with_format("arrow") + dataset = dataset.map(embed_table_storage, batched=False) + dataset = dataset.with_format(**format) + return dataset + + +def load_json(fpath: Path) -> Any: + with open(fpath) as f: + return json.load(f) + + +def write_json(data: dict, fpath: Path) -> None: + fpath.parent.mkdir(exist_ok=True, parents=True) + with open(fpath, "w") as f: + json.dump(data, f, indent=4, ensure_ascii=False) + + +def load_jsonlines(fpath: Path) -> list[Any]: + with jsonlines.open(fpath, "r") as reader: + return list(reader) + + +def write_jsonlines(data: dict, fpath: Path) -> None: + fpath.parent.mkdir(exist_ok=True, parents=True) + with jsonlines.open(fpath, "w") as writer: + writer.write_all(data) + + +def append_jsonlines(data: dict, fpath: Path) -> None: + fpath.parent.mkdir(exist_ok=True, parents=True) + with jsonlines.open(fpath, "a") as writer: + writer.write(data) + + +def write_info(info: dict, local_dir: Path): + write_json(info, local_dir / INFO_PATH) + + +def load_info(local_dir: Path) -> dict: + info = load_json(local_dir / INFO_PATH) + for ft in info["features"].values(): + ft["shape"] = tuple(ft["shape"]) + return info + + +def write_stats(stats: dict, local_dir: Path): + serialized_stats = serialize_dict(stats) + write_json(serialized_stats, local_dir / STATS_PATH) + + +def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]: + stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} + return unflatten_dict(stats) + + +def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]: + if not (local_dir / STATS_PATH).exists(): + return None + stats = load_json(local_dir / STATS_PATH) + return cast_stats_to_numpy(stats) + + +def write_task(task_index: int, task: dict, local_dir: Path): + task_dict = { + "task_index": task_index, + "task": task, + } + append_jsonlines(task_dict, local_dir / TASKS_PATH) + + +def load_tasks(local_dir: Path) -> 
tuple[dict, dict]: + tasks = load_jsonlines(local_dir / TASKS_PATH) + tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} + task_to_task_index = {task: task_index for task_index, task in tasks.items()} + return tasks, task_to_task_index + + +def write_episode(episode: dict, local_dir: Path): + append_jsonlines(episode, local_dir / EPISODES_PATH) + + +def load_episodes(local_dir: Path) -> dict: + episodes = load_jsonlines(local_dir / EPISODES_PATH) + return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])} + + +def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path): + # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]` + # is a dictionary of stats and not an integer. + episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)} + append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH) + + +def load_episodes_stats(local_dir: Path) -> dict: + episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH) + return { + item["episode_index"]: cast_stats_to_numpy(item["stats"]) + for item in sorted(episodes_stats, key=lambda x: x["episode_index"]) + } + + +def backward_compatible_episodes_stats( + stats: dict[str, dict[str, np.ndarray]], episodes: list[int] +) -> dict[str, dict[str, np.ndarray]]: + return dict.fromkeys(episodes, stats) + + +def load_image_as_numpy( + fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True +) -> np.ndarray: + img = PILImage.open(fpath).convert("RGB") + img_array = np.array(img, dtype=dtype) + if channel_first: # (H, W, C) -> (C, H, W) + img_array = np.transpose(img_array, (2, 0, 1)) + if np.issubdtype(dtype, np.floating): + img_array /= 255.0 + return img_array + + +def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): + """Get a transform function that convert items from Hugging Face dataset (pyarrow) + to torch tensors. Importantly, images are converted from PIL, which corresponds to + a channel last representation (h w c) of uint8 type, to a torch image representation + with channel first (c h w) of float32 type in range [0,1]. 
+ """ + for key in items_dict: + first_item = items_dict[key][0] + if isinstance(first_item, PILImage.Image): + to_tensor = transforms.ToTensor() + items_dict[key] = [to_tensor(img) for img in items_dict[key]] + elif first_item is None: + pass + else: + items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] + return items_dict + + +def is_valid_version(version: str) -> bool: + try: + packaging.version.parse(version) + return True + except packaging.version.InvalidVersion: + return False + + +def check_version_compatibility( + repo_id: str, + version_to_check: str | packaging.version.Version, + current_version: str | packaging.version.Version, + enforce_breaking_major: bool = True, +) -> None: + v_check = ( + packaging.version.parse(version_to_check) + if not isinstance(version_to_check, packaging.version.Version) + else version_to_check + ) + v_current = ( + packaging.version.parse(current_version) + if not isinstance(current_version, packaging.version.Version) + else current_version + ) + if v_check.major < v_current.major and enforce_breaking_major: + raise BackwardCompatibilityError(repo_id, v_check) + elif v_check.minor < v_current.minor: + logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check)) + + +def get_repo_versions(repo_id: str) -> list[packaging.version.Version]: + """Returns available valid versions (branches and tags) on given repo.""" + api = HfApi() + repo_refs = api.list_repo_refs(repo_id, repo_type="dataset") + repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags] + repo_versions = [] + for ref in repo_refs: + with contextlib.suppress(packaging.version.InvalidVersion): + repo_versions.append(packaging.version.parse(ref)) + + return repo_versions + + +def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str: + """ + Returns the version if available on repo or the latest compatible one. + Otherwise, will throw a `CompatibilityError`. + """ + target_version = ( + packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version + ) + hub_versions = get_repo_versions(repo_id) + + if not hub_versions: + raise RevisionNotFoundError( + f"""Your dataset must be tagged with a codebase version. 
+ Assuming _version_ is the codebase_version value in the info.json, you can run this: + ```python + from huggingface_hub import HfApi + + hub_api = HfApi() + hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset") + ``` + """ + ) + + if target_version in hub_versions: + return f"v{target_version}" + + compatibles = [ + v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor + ] + if compatibles: + return_version = max(compatibles) + if return_version < target_version: + logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}") + return f"v{return_version}" + + lower_major = [v for v in hub_versions if v.major < target_version.major] + if lower_major: + raise BackwardCompatibilityError(repo_id, max(lower_major)) + + upper_versions = [v for v in hub_versions if v > target_version] + assert len(upper_versions) > 0 + raise ForwardCompatibilityError(repo_id, min(upper_versions)) + + +def get_hf_features_from_features(features: dict) -> datasets.Features: + hf_features = {} + for key, ft in features.items(): + if ft["dtype"] == "video": + continue + elif ft["dtype"] == "image": + hf_features[key] = datasets.Image() + elif ft["shape"] == (1,): + hf_features[key] = datasets.Value(dtype=ft["dtype"]) + elif len(ft["shape"]) == 1: + hf_features[key] = datasets.Sequence( + length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"]) + ) + elif len(ft["shape"]) == 2: + hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 3: + hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 4: + hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 5: + hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"]) + else: + raise ValueError(f"Corresponding feature is not valid: {ft}") + + return datasets.Features(hf_features) + + +def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict: + camera_ft = {} + if robot.cameras: + camera_ft = { + key: {"dtype": "video" if use_videos else "image", **ft} + for key, ft in robot.camera_features.items() + } + return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES} + + +def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: + # TODO(aliberts): Implement "type" in dataset features and simplify this + policy_features = {} + for key, ft in features.items(): + shape = ft["shape"] + if ft["dtype"] in ["image", "video"]: + type = FeatureType.VISUAL + if len(shape) != 3: + raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") + + names = ft["names"] + # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. 
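+            # For example, an image feature stored with shape (480, 640, 3) and names
+            # ("height", "width", "channels") is remapped below to the channel-first shape (3, 480, 640).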
+ if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) + shape = (shape[2], shape[0], shape[1]) + elif key == "observation.environment_state": + type = FeatureType.ENV + elif key.startswith("observation"): + type = FeatureType.STATE + elif key == "action": + type = FeatureType.ACTION + else: + continue + + policy_features[key] = PolicyFeature( + type=type, + shape=shape, + ) + + return policy_features + + +def create_empty_dataset_info( + codebase_version: str, + fps: int, + robot_type: str, + features: dict, + use_videos: bool, +) -> dict: + return { + "codebase_version": codebase_version, + "robot_type": robot_type, + "total_episodes": 0, + "total_frames": 0, + "total_tasks": 0, + "total_videos": 0, + "total_chunks": 0, + "chunks_size": DEFAULT_CHUNK_SIZE, + "fps": fps, + "splits": {}, + "data_path": DEFAULT_PARQUET_PATH, + "video_path": DEFAULT_VIDEO_PATH if use_videos else None, + "features": features, + } + + +def get_episode_data_index( + episode_dicts: dict[dict], episodes: list[int] | None = None +) -> dict[str, torch.Tensor]: + episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()} + if episodes is not None: + episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes} + + cumulative_lengths = list(accumulate(episode_lengths.values())) + return { + "from": torch.LongTensor([0] + cumulative_lengths[:-1]), + "to": torch.LongTensor(cumulative_lengths), + } + + +def check_timestamps_sync( + timestamps: np.ndarray, + episode_indices: np.ndarray, + episode_data_index: dict[str, np.ndarray], + fps: int, + tolerance_s: float, + raise_value_error: bool = True, +) -> bool: + """ + This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance + to account for possible numerical error. + + Args: + timestamps (np.ndarray): Array of timestamps in seconds. + episode_indices (np.ndarray): Array indicating the episode index for each timestamp. + episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to', + which identifies indices for the end of each episode. + fps (int): Frames per second. Used to check the expected difference between consecutive timestamps. + tolerance_s (float): Allowed deviation from the expected (1/fps) difference. + raise_value_error (bool): Whether to raise a ValueError if the check fails. + + Returns: + bool: True if all checked timestamp differences lie within tolerance, False otherwise. + + Raises: + ValueError: If the check fails and `raise_value_error` is True. + """ + if timestamps.shape != episode_indices.shape: + raise ValueError( + "timestamps and episode_indices should have the same shape. " + f"Found {timestamps.shape=} and {episode_indices.shape=}." 
+ ) + + # Consecutive differences + diffs = np.diff(timestamps) + within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s + + # Mask to ignore differences at the boundaries between episodes + mask = np.ones(len(diffs), dtype=bool) + ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode + mask[ignored_diffs] = False + filtered_within_tolerance = within_tolerance[mask] + + # Check if all remaining diffs are within tolerance + if not np.all(filtered_within_tolerance): + # Track original indices before masking + original_indices = np.arange(len(diffs)) + filtered_indices = original_indices[mask] + outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0] + outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices] + + outside_tolerances = [] + for idx in outside_tolerance_indices: + entry = { + "timestamps": [timestamps[idx], timestamps[idx + 1]], + "diff": diffs[idx], + "episode_index": episode_indices[idx].item() + if hasattr(episode_indices[idx], "item") + else episode_indices[idx], + } + outside_tolerances.append(entry) + + if raise_value_error: + raise ValueError( + f"""One or several timestamps unexpectedly violate the tolerance inside episode range. + This might be due to synchronization issues during data collection. + \n{pformat(outside_tolerances)}""" + ) + return False + + return True + + +def check_delta_timestamps( + delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True +) -> bool: + """This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance. + This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be + actual timestamps from the dataset. + """ + outside_tolerance = {} + for key, delta_ts in delta_timestamps.items(): + within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts] + if not all(within_tolerance): + outside_tolerance[key] = [ + ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within + ] + + if len(outside_tolerance) > 0: + if raise_value_error: + raise ValueError( + f""" + The following delta_timestamps are found outside of tolerance range. + Please make sure they are multiples of 1/{fps} +/- tolerance and adjust + their values accordingly. + \n{pformat(outside_tolerance)} + """ + ) + return False + + return True + + +def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: + delta_indices = {} + for key, delta_ts in delta_timestamps.items(): + delta_indices[key] = [round(d * fps) for d in delta_ts] + + return delta_indices + + +def cycle(iterable): + """The equivalent of itertools.cycle, but safe for Pytorch dataloaders. + + See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe. + """ + iterator = iter(iterable) + while True: + try: + yield next(iterator) + except StopIteration: + iterator = iter(iterable) + + +def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None: + """Create a branch on a existing Hugging Face repo. Delete the branch if it already + exists before creating it. 
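+
+    For example (hypothetical repo and branch names):
+
+    ```python
+    create_branch("lerobot/pusht", branch="v2.0", repo_type="dataset")
+    ```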
+ """ + api = HfApi() + + branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches + refs = [branch.ref for branch in branches] + ref = f"refs/heads/{branch}" + if ref in refs: + api.delete_branch(repo_id, repo_type=repo_type, branch=branch) + + api.create_branch(repo_id, repo_type=repo_type, branch=branch) + + +def create_lerobot_dataset_card( + tags: list | None = None, + dataset_info: dict | None = None, + **kwargs, +) -> DatasetCard: + """ + Keyword arguments will be used to replace values in ./lerobot/common/datasets/card_template.md. + Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses. + """ + card_tags = ["LeRobot"] + + if tags: + card_tags += tags + if dataset_info: + dataset_structure = "[meta/info.json](meta/info.json):\n" + dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n" + kwargs = {**kwargs, "dataset_structure": dataset_structure} + card_data = DatasetCardData( + license=kwargs.get("license"), + tags=card_tags, + task_categories=["robotics"], + configs=[ + { + "config_name": "default", + "data_files": "data/*/*.parquet", + } + ], + ) + + card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text() + + return DatasetCard.from_template( + card_data=card_data, + template_str=card_template, + **kwargs, + ) + + +class IterableNamespace(SimpleNamespace): + """ + A namespace object that supports both dictionary-like iteration and dot notation access. + Automatically converts nested dictionaries into IterableNamespaces. + + This class extends SimpleNamespace to provide: + - Dictionary-style iteration over keys + - Access to items via both dot notation (obj.key) and brackets (obj["key"]) + - Dictionary-like methods: items(), keys(), values() + - Recursive conversion of nested dictionaries + + Args: + dictionary: Optional dictionary to initialize the namespace + **kwargs: Additional keyword arguments passed to SimpleNamespace + + Examples: + >>> data = {"name": "Alice", "details": {"age": 25}} + >>> ns = IterableNamespace(data) + >>> ns.name + 'Alice' + >>> ns.details.age + 25 + >>> list(ns.keys()) + ['name', 'details'] + >>> for key, value in ns.items(): + ... 
print(f"{key}: {value}") + name: Alice + details: IterableNamespace(age=25) + """ + + def __init__(self, dictionary: dict[str, Any] = None, **kwargs): + super().__init__(**kwargs) + if dictionary is not None: + for key, value in dictionary.items(): + if isinstance(value, dict): + setattr(self, key, IterableNamespace(value)) + else: + setattr(self, key, value) + + def __iter__(self) -> Iterator[str]: + return iter(vars(self)) + + def __getitem__(self, key: str) -> Any: + return vars(self)[key] + + def items(self): + return vars(self).items() + + def values(self): + return vars(self).values() + + def keys(self): + return vars(self).keys() + + +def validate_frame(frame: dict, features: dict): + optional_features = {"timestamp"} + expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"} + actual_features = set(frame.keys()) + + error_message = validate_features_presence(actual_features, expected_features, optional_features) + + if "task" in frame: + error_message += validate_feature_string("task", frame["task"]) + + common_features = actual_features & (expected_features | optional_features) + for name in common_features - {"task"}: + error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) + + if error_message: + raise ValueError(error_message) + + +def validate_features_presence( + actual_features: set[str], expected_features: set[str], optional_features: set[str] +): + error_message = "" + missing_features = expected_features - actual_features + extra_features = actual_features - (expected_features | optional_features) + + if missing_features or extra_features: + error_message += "Feature mismatch in `frame` dictionary:\n" + if missing_features: + error_message += f"Missing features: {missing_features}\n" + if extra_features: + error_message += f"Extra features: {extra_features}\n" + + return error_message + + +def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str): + expected_dtype = feature["dtype"] + expected_shape = feature["shape"] + if is_valid_numpy_dtype_string(expected_dtype): + return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) + elif expected_dtype in ["image", "video"]: + return validate_feature_image_or_video(name, expected_shape, value) + elif expected_dtype == "string": + return validate_feature_string(name, value) + else: + raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") + + +def validate_feature_numpy_array( + name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray +): + error_message = "" + if isinstance(value, np.ndarray): + actual_dtype = value.dtype + actual_shape = value.shape + + if actual_dtype != np.dtype(expected_dtype): + error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n" + + if actual_shape != expected_shape: + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n" + else: + error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n" + + return error_message + + +def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image): + # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. 
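+    # Only the shape is validated here: numpy arrays are accepted in either channel-first (c, h, w)
+    # or channel-last (h, w, c) order, and PIL images are passed through as-is.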
+ error_message = "" + if isinstance(value, np.ndarray): + actual_shape = value.shape + c, h, w = expected_shape + if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" + elif isinstance(value, PILImage.Image): + pass + else: + error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n" + + return error_message + + +def validate_feature_string(name: str, value: str): + if not isinstance(value, str): + return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" + return "" + + +def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict): + if "size" not in episode_buffer: + raise ValueError("size key not found in episode_buffer") + + if "task" not in episode_buffer: + raise ValueError("task key not found in episode_buffer") + + if episode_buffer["episode_index"] != total_episodes: + # TODO(aliberts): Add option to use existing episode_index + raise NotImplementedError( + "You might have manually provided the episode_buffer with an episode_index that doesn't " + "match the total number of episodes already in the dataset. This is not supported for now." + ) + + if episode_buffer["size"] == 0: + raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.") + + buffer_keys = set(episode_buffer.keys()) - {"task", "size"} + if not buffer_keys == set(features): + raise ValueError( + f"Features from `episode_buffer` don't match the ones in `features`." + f"In episode_buffer not in features: {buffer_keys - set(features)}" + f"In features not in episode_buffer: {set(features) - buffer_keys}" + ) diff --git a/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..99ab2cbf6d5ad8a349064bd7e57d78ca59a35189 --- /dev/null +++ b/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py @@ -0,0 +1,884 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2. + +Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in +lerobot/configs/robot/aloha.yaml before running this script. 
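+
+The script takes no CLI arguments; it can then be run directly (typically from the root of the
+repository), for example:
+
+```bash
+python lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py
+```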
+""" + +import traceback +from pathlib import Path +from textwrap import dedent + +from lerobot import available_datasets +from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset +from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig + +LOCAL_DIR = Path("data/") + +# spellchecker:off +ALOHA_MOBILE_INFO = { + "robot_config": AlohaRobotConfig(), + "license": "mit", + "url": "https://mobile-aloha.github.io/", + "paper": "https://arxiv.org/abs/2401.02117", + "citation_bibtex": dedent(r""" + @inproceedings{fu2024mobile, + author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea}, + title = {Mobile ALOHA: Learning Bimanual Mobile Manipulation with Low-Cost Whole-Body Teleoperation}, + booktitle = {arXiv}, + year = {2024}, + }""").lstrip(), +} +ALOHA_STATIC_INFO = { + "robot_config": AlohaRobotConfig(), + "license": "mit", + "url": "https://tonyzhaozh.github.io/aloha/", + "paper": "https://arxiv.org/abs/2304.13705", + "citation_bibtex": dedent(r""" + @article{Zhao2023LearningFB, + title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware}, + author={Tony Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn}, + journal={RSS}, + year={2023}, + volume={abs/2304.13705}, + url={https://arxiv.org/abs/2304.13705} + }""").lstrip(), +} +PUSHT_INFO = { + "license": "mit", + "url": "https://diffusion-policy.cs.columbia.edu/", + "paper": "https://arxiv.org/abs/2303.04137v5", + "citation_bibtex": dedent(r""" + @article{chi2024diffusionpolicy, + author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song}, + title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion}, + journal = {The International Journal of Robotics Research}, + year = {2024}, + }""").lstrip(), +} +XARM_INFO = { + "license": "mit", + "url": "https://www.nicklashansen.com/td-mpc/", + "paper": "https://arxiv.org/abs/2203.04955", + "citation_bibtex": dedent(r""" + @inproceedings{Hansen2022tdmpc, + title={Temporal Difference Learning for Model Predictive Control}, + author={Nicklas Hansen and Xiaolong Wang and Hao Su}, + booktitle={ICML}, + year={2022} + } + """), +} +UNITREEH_INFO = { + "license": "apache-2.0", +} + +DATASETS = { + "aloha_mobile_cabinet": { + "single_task": "Open the top cabinet, store the pot inside it then close the cabinet.", + **ALOHA_MOBILE_INFO, + }, + "aloha_mobile_chair": { + "single_task": "Push the chairs in front of the desk to place them against it.", + **ALOHA_MOBILE_INFO, + }, + "aloha_mobile_elevator": { + "single_task": "Take the elevator to the 1st floor.", + **ALOHA_MOBILE_INFO, + }, + "aloha_mobile_shrimp": { + "single_task": "Sauté the raw shrimp on both sides, then serve it in the bowl.", + **ALOHA_MOBILE_INFO, + }, + "aloha_mobile_wash_pan": { + "single_task": "Pick up the pan, rinse it in the sink and then place it in the drying rack.", + **ALOHA_MOBILE_INFO, + }, + "aloha_mobile_wipe_wine": { + "single_task": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.", + **ALOHA_MOBILE_INFO, + }, + "aloha_static_battery": { + "single_task": "Place the battery into the slot of the remote controller.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO}, + "aloha_static_coffee": { + "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot 
Water' and 'Travel Mug' buttons.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_coffee_new": { + "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_cups_open": { + "single_task": "Pick up the plastic cup and open its lid.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_fork_pick_up": { + "single_task": "Pick up the fork and place it on the plate.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_pingpong_test": { + "single_task": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_pro_pencil": { + "single_task": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_screw_driver": { + "single_task": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_tape": { + "single_task": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_thread_velcro": { + "single_task": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_towel": { + "single_task": "Pick up a piece of paper towel and place it on the spilled liquid.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_vinh_cup": { + "single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_vinh_cup_left": { + "single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO}, + "aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO}, + "aloha_sim_insertion_scripted_image": { + "single_task": "Insert the peg into the socket.", + **ALOHA_STATIC_INFO, + }, + "aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO}, + "aloha_sim_insertion_human_image": { + "single_task": "Insert the peg into the socket.", + **ALOHA_STATIC_INFO, + }, + "aloha_sim_transfer_cube_scripted": { + "single_task": "Pick up the cube with the right arm and transfer it to the left arm.", + **ALOHA_STATIC_INFO, + }, + "aloha_sim_transfer_cube_scripted_image": { + "single_task": "Pick up the cube with the right arm and transfer it to the left arm.", + **ALOHA_STATIC_INFO, + }, + "aloha_sim_transfer_cube_human": { + "single_task": "Pick up the cube with the right arm and transfer it to the left arm.", + **ALOHA_STATIC_INFO, + }, + "aloha_sim_transfer_cube_human_image": { + "single_task": "Pick up the cube with the right arm and transfer it to the left arm.", + **ALOHA_STATIC_INFO, + }, + "pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO}, + "pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO}, + "unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO}, + "unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO}, + "unitreeh1_two_robot_greeting": { + "single_task": "Greet the other robot with a 
high five.", + **UNITREEH_INFO, + }, + "unitreeh1_warehouse": { + "single_task": "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", + **UNITREEH_INFO, + }, + "xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, + "xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, + "xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, + "xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, + "xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO}, + "xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO}, + "xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO}, + "xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO}, + "umi_cup_in_the_wild": { + "single_task": "Put the cup on the plate.", + "license": "apache-2.0", + }, + "asu_table_top": { + "tasks_col": "language_instruction", + "license": "mit", + "paper": "https://link.springer.com/article/10.1007/s10514-023-10129-1", + "citation_bibtex": dedent(r""" + @inproceedings{zhou2023modularity, + title={Modularity through Attention: Efficient Training and Transfer of Language-Conditioned Policies for Robot Manipulation}, + author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Stepputtis, Simon and Amor, Heni}, + booktitle={Conference on Robot Learning}, + pages={1684--1695}, + year={2023}, + organization={PMLR} + } + @article{zhou2023learning, + title={Learning modular language-conditioned robot policies through attention}, + author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Ben Amor, Heni and Stepputtis, Simon}, + journal={Autonomous Robots}, + pages={1--21}, + year={2023}, + publisher={Springer} + }""").lstrip(), + }, + "austin_buds_dataset": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://ut-austin-rpl.github.io/BUDS-website/", + "paper": "https://arxiv.org/abs/2109.13841", + "citation_bibtex": dedent(r""" + @article{zhu2022bottom, + title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation}, + author={Zhu, Yifeng and Stone, Peter and Zhu, Yuke}, + journal={IEEE Robotics and Automation Letters}, + volume={7}, + number={2}, + pages={4126--4133}, + year={2022}, + publisher={IEEE} + }""").lstrip(), + }, + "austin_sailor_dataset": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://ut-austin-rpl.github.io/sailor/", + "paper": "https://arxiv.org/abs/2210.11435", + "citation_bibtex": dedent(r""" + @inproceedings{nasiriany2022sailor, + title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning}, + author={Soroush Nasiriany and Tian Gao and Ajay Mandlekar and Yuke Zhu}, + booktitle={Conference on Robot Learning (CoRL)}, + year={2022} + }""").lstrip(), + }, + "austin_sirius_dataset": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://ut-austin-rpl.github.io/sirius/", + "paper": "https://arxiv.org/abs/2211.08416", + "citation_bibtex": dedent(r""" + @inproceedings{liu2022robot, + title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment}, + author = {Huihan Liu and Soroush Nasiriany and Lance Zhang and Zhiyao Bao and Yuke Zhu}, + booktitle = {Robotics: Science and Systems (RSS)}, + year = {2023} + }""").lstrip(), + }, + 
"berkeley_autolab_ur5": { + "tasks_col": "language_instruction", + "license": "cc-by-4.0", + "url": "https://sites.google.com/view/berkeley-ur5/home", + "citation_bibtex": dedent(r""" + @misc{BerkeleyUR5Website, + title = {Berkeley {UR5} Demonstration Dataset}, + author = {Lawrence Yunliang Chen and Simeon Adebola and Ken Goldberg}, + howpublished = {https://sites.google.com/view/berkeley-ur5/home}, + }""").lstrip(), + }, + "berkeley_cable_routing": { + "tasks_col": "language_instruction", + "license": "cc-by-4.0", + "url": "https://sites.google.com/view/cablerouting/home", + "paper": "https://arxiv.org/abs/2307.08927", + "citation_bibtex": dedent(r""" + @article{luo2023multistage, + author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine}, + title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning}, + journal = {arXiv pre-print}, + year = {2023}, + url = {https://arxiv.org/abs/2307.08927}, + }""").lstrip(), + }, + "berkeley_fanuc_manipulation": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://sites.google.com/berkeley.edu/fanuc-manipulation", + "citation_bibtex": dedent(r""" + @article{fanuc_manipulation2023, + title={Fanuc Manipulation: A Dataset for Learning-based Manipulation with FANUC Mate 200iD Robot}, + author={Zhu, Xinghao and Tian, Ran and Xu, Chenfeng and Ding, Mingyu and Zhan, Wei and Tomizuka, Masayoshi}, + year={2023}, + }""").lstrip(), + }, + "berkeley_gnm_cory_hall": { + "tasks_col": "language_instruction", + "license": "mit", + "paper": "https://arxiv.org/abs/1709.10489", + "citation_bibtex": dedent(r""" + @inproceedings{kahn2018self, + title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation}, + author={Kahn, Gregory and Villaflor, Adam and Ding, Bosen and Abbeel, Pieter and Levine, Sergey}, + booktitle={2018 IEEE international conference on robotics and automation (ICRA)}, + pages={5129--5136}, + year={2018}, + organization={IEEE} + }""").lstrip(), + }, + "berkeley_gnm_recon": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://sites.google.com/view/recon-robot", + "paper": "https://arxiv.org/abs/2104.05859", + "citation_bibtex": dedent(r""" + @inproceedings{shah2021rapid, + title={Rapid Exploration for Open-World Navigation with Latent Goal Models}, + author={Dhruv Shah and Benjamin Eysenbach and Nicholas Rhinehart and Sergey Levine}, + booktitle={5th Annual Conference on Robot Learning }, + year={2021}, + url={https://openreview.net/forum?id=d_SWJhyKfVw} + }""").lstrip(), + }, + "berkeley_gnm_sac_son": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://sites.google.com/view/SACSoN-review", + "paper": "https://arxiv.org/abs/2306.01874", + "citation_bibtex": dedent(r""" + @article{hirose2023sacson, + title={SACSoN: Scalable Autonomous Data Collection for Social Navigation}, + author={Hirose, Noriaki and Shah, Dhruv and Sridhar, Ajay and Levine, Sergey}, + journal={arXiv preprint arXiv:2306.01874}, + year={2023} + }""").lstrip(), + }, + "berkeley_mvp": { + "tasks_col": "language_instruction", + "license": "mit", + "paper": "https://arxiv.org/abs/2203.06173", + "citation_bibtex": dedent(r""" + @InProceedings{Radosavovic2022, + title = {Real-World Robot Learning with Masked Visual Pre-training}, + author = {Ilija Radosavovic and Tete Xiao and Stephen James and Pieter Abbeel and Jitendra Malik and Trevor Darrell}, + booktitle = {CoRL}, + year = 
{2022} + }""").lstrip(), + }, + "berkeley_rpt": { + "tasks_col": "language_instruction", + "license": "mit", + "paper": "https://arxiv.org/abs/2306.10007", + "citation_bibtex": dedent(r""" + @article{Radosavovic2023, + title={Robot Learning with Sensorimotor Pre-training}, + author={Ilija Radosavovic and Baifeng Shi and Letian Fu and Ken Goldberg and Trevor Darrell and Jitendra Malik}, + year={2023}, + journal={arXiv:2306.10007} + }""").lstrip(), + }, + "cmu_franka_exploration_dataset": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://human-world-model.github.io/", + "paper": "https://arxiv.org/abs/2308.10901", + "citation_bibtex": dedent(r""" + @inproceedings{mendonca2023structured, + title={Structured World Models from Human Videos}, + author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak}, + journal={RSS}, + year={2023} + }""").lstrip(), + }, + "cmu_play_fusion": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://play-fusion.github.io/", + "paper": "https://arxiv.org/abs/2312.04549", + "citation_bibtex": dedent(r""" + @inproceedings{chen2023playfusion, + title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play}, + author={Chen, Lili and Bahl, Shikhar and Pathak, Deepak}, + booktitle={CoRL}, + year={2023} + }""").lstrip(), + }, + "cmu_stretch": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://robo-affordances.github.io/", + "paper": "https://arxiv.org/abs/2304.08488", + "citation_bibtex": dedent(r""" + @inproceedings{bahl2023affordances, + title={Affordances from Human Videos as a Versatile Representation for Robotics}, + author={Bahl, Shikhar and Mendonca, Russell and Chen, Lili and Jain, Unnat and Pathak, Deepak}, + booktitle={CVPR}, + year={2023} + } + @article{mendonca2023structured, + title={Structured World Models from Human Videos}, + author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak}, + journal={CoRL}, + year={2023} + }""").lstrip(), + }, + "columbia_cairlab_pusht_real": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://diffusion-policy.cs.columbia.edu/", + "paper": "https://arxiv.org/abs/2303.04137v5", + "citation_bibtex": dedent(r""" + @inproceedings{chi2023diffusionpolicy, + title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion}, + author={Chi, Cheng and Feng, Siyuan and Du, Yilun and Xu, Zhenjia and Cousineau, Eric and Burchfiel, Benjamin and Song, Shuran}, + booktitle={Proceedings of Robotics: Science and Systems (RSS)}, + year={2023} + }""").lstrip(), + }, + "conq_hose_manipulation": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://sites.google.com/view/conq-hose-manipulation-dataset/home", + "citation_bibtex": dedent(r""" + @misc{ConqHoseManipData, + author={Peter Mitrano and Dmitry Berenson}, + title={Conq Hose Manipulation Dataset, v1.15.0}, + year={2024}, + howpublished={https://sites.google.com/view/conq-hose-manipulation-dataset} + }""").lstrip(), + }, + "dlr_edan_shared_control": { + "tasks_col": "language_instruction", + "license": "mit", + "paper": "https://ieeexplore.ieee.org/document/9341156", + "citation_bibtex": dedent(r""" + @inproceedings{vogel_edan_2020, + title = {EDAN - an EMG-Controlled Daily Assistant to Help People with Physical Disabilities}, + language = {en}, + booktitle = {2020 {IEEE}/{RSJ} {International} {Conference} on {Intelligent} {Robots} and {Systems} ({IROS})}, + author = {Vogel, Jörn and Hagengruber, Annette and Iskandar, Maged and 
Quere, Gabriel and Leipscher, Ulrike and Bustamante, Samuel and Dietrich, Alexander and Hoeppner, Hannes and Leidner, Daniel and Albu-Schäffer, Alin}, + year = {2020} + } + @inproceedings{quere_shared_2020, + address = {Paris, France}, + title = {Shared {Control} {Templates} for {Assistive} {Robotics}}, + language = {en}, + booktitle = {2020 {IEEE} {International} {Conference} on {Robotics} and {Automation} ({ICRA})}, + author = {Quere, Gabriel and Hagengruber, Annette and Iskandar, Maged and Bustamante, Samuel and Leidner, Daniel and Stulp, Freek and Vogel, Joern}, + year = {2020}, + pages = {7}, + }""").lstrip(), + }, + "dlr_sara_grid_clamp": { + "tasks_col": "language_instruction", + "license": "mit", + "paper": "https://www.researchsquare.com/article/rs-3289569/v1", + "citation_bibtex": dedent(r""" + @article{padalkar2023guided, + title={A guided reinforcement learning approach using shared control templates for learning manipulation skills in the real world}, + author={Padalkar, Abhishek and Quere, Gabriel and Raffin, Antonin and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek}, + journal={Research square preprint rs-3289569/v1}, + year={2023} + }""").lstrip(), + }, + "dlr_sara_pour": { + "tasks_col": "language_instruction", + "license": "mit", + "paper": "https://elib.dlr.de/193739/1/padalkar2023rlsct.pdf", + "citation_bibtex": dedent(r""" + @inproceedings{padalkar2023guiding, + title={Guiding Reinforcement Learning with Shared Control Templates}, + author={Padalkar, Abhishek and Quere, Gabriel and Steinmetz, Franz and Raffin, Antonin and Nieuwenhuisen, Matthias and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek}, + booktitle={40th IEEE International Conference on Robotics and Automation, ICRA 2023}, + year={2023}, + organization={IEEE} + }""").lstrip(), + }, + "droid_100": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://droid-dataset.github.io/", + "paper": "https://arxiv.org/abs/2403.12945", + "citation_bibtex": dedent(r""" + @article{khazatsky2024droid, + title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset}, + author = {Alexander Khazatsky and Karl Pertsch and Suraj Nair and Ashwin Balakrishna and Sudeep Dasari and Siddharth Karamcheti and Soroush Nasiriany and Mohan Kumar Srirama and Lawrence Yunliang Chen and Kirsty Ellis and Peter David Fagan and Joey Hejna and Masha Itkina and Marion Lepert and Yecheng Jason Ma and Patrick Tree Miller and Jimmy Wu and Suneel Belkhale and Shivin Dass and Huy Ha and Arhan Jain and Abraham Lee and Youngwoon Lee and Marius Memmel and Sungjae Park and Ilija Radosavovic and Kaiyuan Wang and Albert Zhan and Kevin Black and Cheng Chi and Kyle Beltran Hatch and Shan Lin and Jingpei Lu and Jean Mercat and Abdul Rehman and Pannag R Sanketi and Archit Sharma and Cody Simpson and Quan Vuong and Homer Rich Walke and Blake Wulfe and Ted Xiao and Jonathan Heewon Yang and Arefeh Yavary and Tony Z. Zhao and Christopher Agia and Rohan Baijal and Mateo Guaman Castro and Daphne Chen and Qiuyu Chen and Trinity Chung and Jaimyn Drake and Ethan Paul Foster and Jensen Gao and David Antonio Herrera and Minho Heo and Kyle Hsu and Jiaheng Hu and Donovon Jackson and Charlotte Le and Yunshuang Li and Kevin Lin and Roy Lin and Zehan Ma and Abhiram Maddukuri and Suvir Mirchandani and Daniel Morton and Tony Nguyen and Abigail O'Neill and Rosario Scalise and Derick Seale and Victor Son and Stephen Tian and Emi Tran and Andrew E. 
Wang and Yilin Wu and Annie Xie and Jingyun Yang and Patrick Yin and Yunchu Zhang and Osbert Bastani and Glen Berseth and Jeannette Bohg and Ken Goldberg and Abhinav Gupta and Abhishek Gupta and Dinesh Jayaraman and Joseph J Lim and Jitendra Malik and Roberto Martín-Martín and Subramanian Ramamoorthy and Dorsa Sadigh and Shuran Song and Jiajun Wu and Michael C. Yip and Yuke Zhu and Thomas Kollar and Sergey Levine and Chelsea Finn}, + year = {2024}, + }""").lstrip(), + }, + "fmb": { + "tasks_col": "language_instruction", + "license": "cc-by-4.0", + "url": "https://functional-manipulation-benchmark.github.io/", + "paper": "https://arxiv.org/abs/2401.08553", + "citation_bibtex": dedent(r""" + @article{luo2024fmb, + title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning}, + author={Luo, Jianlan and Xu, Charles and Liu, Fangchen and Tan, Liam and Lin, Zipeng and Wu, Jeffrey and Abbeel, Pieter and Levine, Sergey}, + journal={arXiv preprint arXiv:2401.08553}, + year={2024} + }""").lstrip(), + }, + "iamlab_cmu_pickup_insert": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://openreview.net/forum?id=WuBv9-IGDUA", + "paper": "https://arxiv.org/abs/2401.14502", + "citation_bibtex": dedent(r""" + @inproceedings{saxena2023multiresolution, + title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models}, + author={Saumya Saxena and Mohit Sharma and Oliver Kroemer}, + booktitle={7th Annual Conference on Robot Learning}, + year={2023}, + url={https://openreview.net/forum?id=WuBv9-IGDUA} + }""").lstrip(), + }, + "imperialcollege_sawyer_wrist_cam": { + "tasks_col": "language_instruction", + "license": "mit", + }, + "jaco_play": { + "tasks_col": "language_instruction", + "license": "cc-by-4.0", + "url": "https://github.com/clvrai/clvr_jaco_play_dataset", + "citation_bibtex": dedent(r""" + @software{dass2023jacoplay, + author = {Dass, Shivin and Yapeter, Jullian and Zhang, Jesse and Zhang, Jiahui + and Pertsch, Karl and Nikolaidis, Stefanos and Lim, Joseph J.}, + title = {CLVR Jaco Play Dataset}, + url = {https://github.com/clvrai/clvr_jaco_play_dataset}, + version = {1.0.0}, + year = {2023} + }""").lstrip(), + }, + "kaist_nonprehensile": { + "tasks_col": "language_instruction", + "license": "cc-by-4.0", + "url": "https://github.com/JaeHyung-Kim/rlds_dataset_builder", + "citation_bibtex": dedent(r""" + @article{kimpre, + title={Pre-and post-contact policy decomposition for non-prehensile manipulation with zero-shot sim-to-real transfer}, + author={Kim, Minchan and Han, Junhyek and Kim, Jaehyung and Kim, Beomjoon}, + booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)}, + year={2023}, + organization={IEEE} + }""").lstrip(), + }, + "nyu_door_opening_surprising_effectiveness": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://jyopari.github.io/VINN/", + "paper": "https://arxiv.org/abs/2112.01511", + "citation_bibtex": dedent(r""" + @misc{pari2021surprising, + title={The Surprising Effectiveness of Representation Learning for Visual Imitation}, + author={Jyothish Pari and Nur Muhammad Shafiullah and Sridhar Pandian Arunachalam and Lerrel Pinto}, + year={2021}, + eprint={2112.01511}, + archivePrefix={arXiv}, + primaryClass={cs.RO} + }""").lstrip(), + }, + "nyu_franka_play_dataset": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://play-to-policy.github.io/", + "paper": "https://arxiv.org/abs/2210.10047", + "citation_bibtex": dedent(r""" + 
@article{cui2022play, + title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data}, + author = {Cui, Zichen Jeff and Wang, Yibin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel}, + journal = {arXiv preprint arXiv:2210.10047}, + year = {2022} + }""").lstrip(), + }, + "nyu_rot_dataset": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://rot-robot.github.io/", + "paper": "https://arxiv.org/abs/2206.15469", + "citation_bibtex": dedent(r""" + @inproceedings{haldar2023watch, + title={Watch and match: Supercharging imitation with regularized optimal transport}, + author={Haldar, Siddhant and Mathur, Vaibhav and Yarats, Denis and Pinto, Lerrel}, + booktitle={Conference on Robot Learning}, + pages={32--43}, + year={2023}, + organization={PMLR} + }""").lstrip(), + }, + "roboturk": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://roboturk.stanford.edu/dataset_real.html", + "paper": "PAPER", + "citation_bibtex": dedent(r""" + @inproceedings{mandlekar2019scaling, + title={Scaling robot supervision to hundreds of hours with roboturk: Robotic manipulation dataset through human reasoning and dexterity}, + author={Mandlekar, Ajay and Booher, Jonathan and Spero, Max and Tung, Albert and Gupta, Anchit and Zhu, Yuke and Garg, Animesh and Savarese, Silvio and Fei-Fei, Li}, + booktitle={2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)}, + pages={1048--1055}, + year={2019}, + organization={IEEE} + }""").lstrip(), + }, + "stanford_hydra_dataset": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://sites.google.com/view/hydra-il-2023", + "paper": "https://arxiv.org/abs/2306.17237", + "citation_bibtex": dedent(r""" + @article{belkhale2023hydra, + title={HYDRA: Hybrid Robot Actions for Imitation Learning}, + author={Belkhale, Suneel and Cui, Yuchen and Sadigh, Dorsa}, + journal={arxiv}, + year={2023} + }""").lstrip(), + }, + "stanford_kuka_multimodal_dataset": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://sites.google.com/view/visionandtouch", + "paper": "https://arxiv.org/abs/1810.10191", + "citation_bibtex": dedent(r""" + @inproceedings{lee2019icra, + title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks}, + author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette}, + booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)}, + year={2019}, + url={https://arxiv.org/abs/1810.10191} + }""").lstrip(), + }, + "stanford_robocook": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://hshi74.github.io/robocook/", + "paper": "https://arxiv.org/abs/2306.14447", + "citation_bibtex": dedent(r""" + @article{shi2023robocook, + title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools}, + author={Shi, Haochen and Xu, Huazhe and Clarke, Samuel and Li, Yunzhu and Wu, Jiajun}, + journal={arXiv preprint arXiv:2306.14447}, + year={2023} + }""").lstrip(), + }, + "taco_play": { + "tasks_col": "language_instruction", + "license": "cc-by-4.0", + "url": "https://www.kaggle.com/datasets/oiermees/taco-robot", + "paper": "https://arxiv.org/abs/2209.08959, https://arxiv.org/abs/2210.01911", + "citation_bibtex": dedent(r""" + @inproceedings{rosete2022tacorl, + author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and 
Joschka Boedecker and Wolfram Burgard}, + title = {Latent Plans for Task Agnostic Offline Reinforcement Learning}, + journal = {Proceedings of the 6th Conference on Robot Learning (CoRL)}, + year = {2022} + } + @inproceedings{mees23hulc2, + title={Grounding Language with Visual Affordances over Unstructured Data}, + author={Oier Mees and Jessica Borja-Diaz and Wolfram Burgard}, + booktitle = {Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)}, + year={2023}, + address = {London, UK} + }""").lstrip(), + }, + "tokyo_u_lsmo": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "URL", + "paper": "https://arxiv.org/abs/2107.05842", + "citation_bibtex": dedent(r""" + @Article{Osa22, + author = {Takayuki Osa}, + journal = {The International Journal of Robotics Research}, + title = {Motion Planning by Learning the Solution Manifold in Trajectory Optimization}, + year = {2022}, + number = {3}, + pages = {291--311}, + volume = {41}, + }""").lstrip(), + }, + "toto": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://toto-benchmark.org/", + "paper": "https://arxiv.org/abs/2306.00942", + "citation_bibtex": dedent(r""" + @inproceedings{zhou2023train, + author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav}, + booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)}, + title={Train Offline, Test Online: A Real Robot Learning Benchmark}, + year={2023}, + }""").lstrip(), + }, + "ucsd_kitchen_dataset": { + "tasks_col": "language_instruction", + "license": "mit", + "citation_bibtex": dedent(r""" + @ARTICLE{ucsd_kitchens, + author = {Ge Yan, Kris Wu, and Xiaolong Wang}, + title = {{ucsd kitchens Dataset}}, + year = {2023}, + month = {August} + }""").lstrip(), + }, + "ucsd_pick_and_place_dataset": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://owmcorl.github.io/#", + "paper": "https://arxiv.org/abs/2310.16029", + "citation_bibtex": dedent(r""" + @preprint{Feng2023Finetuning, + title={Finetuning Offline World Models in the Real World}, + author={Yunhai Feng, Nicklas Hansen, Ziyan Xiong, Chandramouli Rajagopalan, Xiaolong Wang}, + year={2023} + }""").lstrip(), + }, + "uiuc_d3field": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://robopil.github.io/d3fields/", + "paper": "https://arxiv.org/abs/2309.16118", + "citation_bibtex": dedent(r""" + @article{wang2023d3field, + title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation}, + author={Wang, Yixuan and Li, Zhuoran and Zhang, Mingtong and Driggs-Campbell, Katherine and Wu, Jiajun and Fei-Fei, Li and Li, Yunzhu}, + journal={arXiv preprint arXiv:}, + year={2023}, + }""").lstrip(), + }, + "usc_cloth_sim": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://uscresl.github.io/dmfd/", + "paper": "https://arxiv.org/abs/2207.10148", + "citation_bibtex": dedent(r""" + @article{salhotra2022dmfd, + author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.}, + journal={IEEE Robotics and Automation Letters}, + title={Learning Deformable Object Manipulation From Expert Demonstrations}, + year={2022}, + volume={7}, + number={4}, + pages={8775-8782}, + doi={10.1109/LRA.2022.3187843} + }""").lstrip(), + }, + "utaustin_mutex": { + "tasks_col": 
"language_instruction", + "license": "mit", + "url": "https://ut-austin-rpl.github.io/MUTEX/", + "paper": "https://arxiv.org/abs/2309.14320", + "citation_bibtex": dedent(r""" + @inproceedings{shah2023mutex, + title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications}, + author={Rutav Shah and Roberto Mart{\'\i}n-Mart{\'\i}n and Yuke Zhu}, + booktitle={7th Annual Conference on Robot Learning}, + year={2023}, + url={https://openreview.net/forum?id=PwqiqaaEzJ} + }""").lstrip(), + }, + "utokyo_pr2_opening_fridge": { + "tasks_col": "language_instruction", + "license": "mit", + "citation_bibtex": dedent(r""" + @misc{oh2023pr2utokyodatasets, + author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka}, + title={X-Embodiment U-Tokyo PR2 Datasets}, + year={2023}, + url={https://github.com/ojh6404/rlds_dataset_builder}, + }""").lstrip(), + }, + "utokyo_pr2_tabletop_manipulation": { + "tasks_col": "language_instruction", + "license": "mit", + "citation_bibtex": dedent(r""" + @misc{oh2023pr2utokyodatasets, + author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka}, + title={X-Embodiment U-Tokyo PR2 Datasets}, + year={2023}, + url={https://github.com/ojh6404/rlds_dataset_builder}, + }""").lstrip(), + }, + "utokyo_saytap": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://saytap.github.io/", + "paper": "https://arxiv.org/abs/2306.07580", + "citation_bibtex": dedent(r""" + @article{saytap2023, + author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and + Tatsuya Harada}, + title = {SayTap: Language to Quadrupedal Locomotion}, + eprint = {arXiv:2306.07580}, + url = {https://saytap.github.io}, + note = {https://saytap.github.io}, + year = {2023} + }""").lstrip(), + }, + "utokyo_xarm_bimanual": { + "tasks_col": "language_instruction", + "license": "cc-by-4.0", + "citation_bibtex": dedent(r""" + @misc{matsushima2023weblab, + title={Weblab xArm Dataset}, + author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo}, + year={2023}, + }""").lstrip(), + }, + "utokyo_xarm_pick_and_place": { + "tasks_col": "language_instruction", + "license": "cc-by-4.0", + "citation_bibtex": dedent(r""" + @misc{matsushima2023weblab, + title={Weblab xArm Dataset}, + author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo}, + year={2023}, + }""").lstrip(), + }, + "viola": { + "tasks_col": "language_instruction", + "license": "mit", + "url": "https://ut-austin-rpl.github.io/VIOLA/", + "paper": "https://arxiv.org/abs/2210.11339", + "citation_bibtex": dedent(r""" + @article{zhu2022viola, + title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors}, + author={Zhu, Yifeng and Joshi, Abhishek and Stone, Peter and Zhu, Yuke}, + journal={6th Annual Conference on Robot Learning (CoRL)}, + year={2022} + }""").lstrip(), + }, +} +# spellchecker:on + + +def batch_convert(): + status = {} + logfile = LOCAL_DIR / "conversion_log.txt" + assert set(DATASETS) == {id_.split("/")[1] for id_ in available_datasets} + for num, (name, kwargs) in enumerate(DATASETS.items()): + repo_id = f"lerobot/{name}" + print(f"\nConverting {repo_id} ({num}/{len(DATASETS)})") + print("---------------------------------------------------------") + try: + convert_dataset(repo_id, LOCAL_DIR, **kwargs) + status = f"{repo_id}: success." 
+            with open(logfile, "a") as file:
+                file.write(status + "\n")
+        except Exception:
+            status = f"{repo_id}: failed\n    {traceback.format_exc()}"
+            with open(logfile, "a") as file:
+                file.write(status + "\n")
+            continue
+
+
+if __name__ == "__main__":
+    batch_convert()
diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..024576d709c7eda0c446a3e84b1fc9c83a12f4b1
--- /dev/null
+++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
@@ -0,0 +1,664 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
+2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
+for each of the tasks performed in the dataset. This will allow you to easily train models with task-conditioning.
+
+We support 3 different scenarios for these tasks (see instructions below):
+    1. Single task dataset: all episodes of your dataset have the same single task.
+    2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
+      one episode to the next.
+    3. Multi task episodes: episodes of your dataset may each contain several different tasks.
+
+
+You can also provide a robot config .yaml file (not mandatory) to this script via the option
+'--robot-config' so that it writes information about the robot (robot type, motor names) this dataset was
+recorded with. For now, only Aloha/Koch type robots are supported with this option.
+
+
+# 1. Single task dataset
+If your dataset contains a single task, you can simply provide it directly via the CLI with the
+'--single-task' option.
+
+Examples:
+
+```bash
+python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
+    --repo-id lerobot/aloha_sim_insertion_human_image \
+    --single-task "Insert the peg into the socket." \
+    --robot-config lerobot/configs/robot/aloha.yaml \
+    --local-dir data
+```
+
+```bash
+python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
+    --repo-id aliberts/koch_tutorial \
+    --single-task "Pick the Lego block and drop it in the box on the right." \
+    --robot-config lerobot/configs/robot/koch.yaml \
+    --local-dir data
+```
+
+
+# 2. Single task episodes
+If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
+
+- If your dataset already contains a language instruction column in its parquet file, you can simply provide
+  this column's name with the '--tasks-col' arg.
+ + Example: + + ```bash + python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \ + --repo-id lerobot/stanford_kuka_multimodal_dataset \ + --tasks-col "language_instruction" \ + --local-dir data + ``` + +- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the + '--tasks-path' arg. This file should have the following structure where keys correspond to each + episode_index in the dataset, and values are the language instruction for that episode. + + Example: + + ```json + { + "0": "Do something", + "1": "Do something else", + "2": "Do something", + "3": "Go there", + ... + } + ``` + +# 3. Multi task episodes +If you have multiple tasks per episodes, your dataset should contain a language instruction column in its +parquet file, and you must provide this column's name with the '--tasks-col' arg. + +Example: + +```bash +python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \ + --repo-id lerobot/stanford_kuka_multimodal_dataset \ + --tasks-col "language_instruction" \ + --local-dir data +``` +""" + +import argparse +import contextlib +import filecmp +import json +import logging +import math +import shutil +import subprocess +import tempfile +from pathlib import Path + +import datasets +import pyarrow.compute as pc +import pyarrow.parquet as pq +import torch +from datasets import Dataset +from huggingface_hub import HfApi +from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError +from safetensors.torch import load_file + +from lerobot.common.datasets.utils import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_PARQUET_PATH, + DEFAULT_VIDEO_PATH, + EPISODES_PATH, + INFO_PATH, + STATS_PATH, + TASKS_PATH, + create_branch, + create_lerobot_dataset_card, + flatten_dict, + get_safe_version, + load_json, + unflatten_dict, + write_json, + write_jsonlines, +) +from lerobot.common.datasets.video_utils import ( + VideoFrame, # noqa: F401 + get_image_pixel_channels, + get_video_info, +) +from lerobot.common.robot_devices.robots.configs import RobotConfig +from lerobot.common.robot_devices.robots.utils import make_robot_config + +V16 = "v1.6" +V20 = "v2.0" + +GITATTRIBUTES_REF = "aliberts/gitattributes_reference" +V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4" +V1_INFO_PATH = "meta_data/info.json" +V1_STATS_PATH = "meta_data/stats.safetensors" + + +def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]: + if robot_cfg.type in ["aloha", "koch"]: + state_names = [ + f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor + for arm in robot_cfg.follower_arms + for motor in robot_cfg.follower_arms[arm].motors + ] + action_names = [ + # f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"] + f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor + for arm in robot_cfg.leader_arms + for motor in robot_cfg.leader_arms[arm].motors + ] + # elif robot_cfg["robot_type"] == "stretch3": TODO + else: + raise NotImplementedError( + "Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()." 
+ ) + + return { + "robot_type": robot_cfg.type, + "names": { + "observation.state": state_names, + "observation.effort": state_names, + "action": action_names, + }, + } + + +def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None: + safetensor_path = v1_dir / V1_STATS_PATH + stats = load_file(safetensor_path) + serialized_stats = {key: value.tolist() for key, value in stats.items()} + serialized_stats = unflatten_dict(serialized_stats) + + json_path = v2_dir / STATS_PATH + json_path.parent.mkdir(exist_ok=True, parents=True) + with open(json_path, "w") as f: + json.dump(serialized_stats, f, indent=4) + + # Sanity check + with open(json_path) as f: + stats_json = json.load(f) + + stats_json = flatten_dict(stats_json) + stats_json = {key: torch.tensor(value) for key, value in stats_json.items()} + for key in stats: + torch.testing.assert_close(stats_json[key], stats[key]) + + +def get_features_from_hf_dataset( + dataset: Dataset, robot_config: RobotConfig | None = None +) -> dict[str, list]: + robot_config = parse_robot_config(robot_config) + features = {} + for key, ft in dataset.features.items(): + if isinstance(ft, datasets.Value): + dtype = ft.dtype + shape = (1,) + names = None + if isinstance(ft, datasets.Sequence): + assert isinstance(ft.feature, datasets.Value) + dtype = ft.feature.dtype + shape = (ft.length,) + motor_names = ( + robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)] + ) + assert len(motor_names) == shape[0] + names = {"motors": motor_names} + elif isinstance(ft, datasets.Image): + dtype = "image" + image = dataset[0][key] # Assuming first row + channels = get_image_pixel_channels(image) + shape = (image.height, image.width, channels) + names = ["height", "width", "channels"] + elif ft._type == "VideoFrame": + dtype = "video" + shape = None # Add shape later + names = ["height", "width", "channels"] + + features[key] = { + "dtype": dtype, + "shape": shape, + "names": names, + } + + return features + + +def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]: + df = dataset.to_pandas() + tasks = list(set(tasks_by_episodes.values())) + tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)} + episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()} + df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int) + + features = dataset.features + features["task_index"] = datasets.Value(dtype="int64") + dataset = Dataset.from_pandas(df, features=features, split="train") + return dataset, tasks + + +def add_task_index_from_tasks_col( + dataset: Dataset, tasks_col: str +) -> tuple[Dataset, dict[str, list[str]], list[str]]: + df = dataset.to_pandas() + + # HACK: This is to clean some of the instructions in our version of Open X datasets + prefix_to_clean = "tf.Tensor(b'" + suffix_to_clean = "', shape=(), dtype=string)" + df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean) + + # Create task_index col + tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict() + tasks = df[tasks_col].unique().tolist() + tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)} + df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int) + + # Build the dataset back from df + features = dataset.features + features["task_index"] = datasets.Value(dtype="int64") + dataset = Dataset.from_pandas(df, 
features=features, split="train") + dataset = dataset.remove_columns(tasks_col) + + return dataset, tasks, tasks_by_episode + + +def split_parquet_by_episodes( + dataset: Dataset, + total_episodes: int, + total_chunks: int, + output_dir: Path, +) -> list: + table = dataset.data.table + episode_lengths = [] + for ep_chunk in range(total_chunks): + ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk + ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes) + chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk) + (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True) + for ep_idx in range(ep_chunk_start, ep_chunk_end): + ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) + episode_lengths.insert(ep_idx, len(ep_table)) + output_file = output_dir / DEFAULT_PARQUET_PATH.format( + episode_chunk=ep_chunk, episode_index=ep_idx + ) + pq.write_table(ep_table, output_file) + + return episode_lengths + + +def move_videos( + repo_id: str, + video_keys: list[str], + total_episodes: int, + total_chunks: int, + work_dir: Path, + clean_gittatributes: Path, + branch: str = "main", +) -> None: + """ + HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git + commands to fetch git lfs video files references to move them into subdirectories without having to + actually download them. + """ + _lfs_clone(repo_id, work_dir, branch) + + videos_moved = False + video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")] + if len(video_files) == 0: + video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")] + videos_moved = True # Videos have already been moved + + assert len(video_files) == total_episodes * len(video_keys) + + lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files) + + current_gittatributes = work_dir / ".gitattributes" + if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False): + fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes) + + if lfs_untracked_videos: + fix_lfs_video_files_tracking(work_dir, video_files) + + if videos_moved: + return + + video_dirs = sorted(work_dir.glob("videos*/")) + for ep_chunk in range(total_chunks): + ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk + ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes) + for vid_key in video_keys: + chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format( + episode_chunk=ep_chunk, video_key=vid_key + ) + (work_dir / chunk_dir).mkdir(parents=True, exist_ok=True) + + for ep_idx in range(ep_chunk_start, ep_chunk_end): + target_path = DEFAULT_VIDEO_PATH.format( + episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx + ) + video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx) + if len(video_dirs) == 1: + video_path = video_dirs[0] / video_file + else: + for dir in video_dirs: + if (dir / video_file).is_file(): + video_path = dir / video_file + break + + video_path.rename(work_dir / target_path) + + commit_message = "Move video files into chunk subdirectories" + subprocess.run(["git", "add", "."], cwd=work_dir, check=True) + subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True) + subprocess.run(["git", "push"], cwd=work_dir, check=True) + + +def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None: + """ + HACK: This function fixes the tracking by git lfs which was not properly set on some repos. 
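As an aside on the layout produced by `split_parquet_by_episodes` above: each episode gets its own parquet file, grouped into fixed-size chunk directories. A minimal sketch of the index-to-path mapping, using a stand-in template and chunk size (the real values are `DEFAULT_PARQUET_PATH` and `DEFAULT_CHUNK_SIZE` from `lerobot.common.datasets.utils`):

```python
# Stand-ins for illustration only; the actual template and chunk size come from
# lerobot.common.datasets.utils (DEFAULT_PARQUET_PATH, DEFAULT_CHUNK_SIZE).
ASSUMED_CHUNK_SIZE = 1000
ASSUMED_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"


def episode_parquet_path(episode_index: int) -> str:
    # Episodes are grouped into chunk directories so no single folder grows unbounded.
    episode_chunk = episode_index // ASSUMED_CHUNK_SIZE
    return ASSUMED_PARQUET_PATH.format(episode_chunk=episode_chunk, episode_index=episode_index)


print(episode_parquet_path(0))     # data/chunk-000/episode_000000.parquet
print(episode_parquet_path(1042))  # data/chunk-001/episode_001042.parquet
```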
In that case, + there's no other option than to download the actual files and reupload them with lfs tracking. + """ + for i in range(0, len(lfs_untracked_videos), 100): + files = lfs_untracked_videos[i : i + 100] + try: + subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True) + except subprocess.CalledProcessError as e: + print("git rm --cached ERROR:") + print(e.stderr) + subprocess.run(["git", "add", *files], cwd=work_dir, check=True) + + commit_message = "Track video files with git lfs" + subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True) + subprocess.run(["git", "push"], cwd=work_dir, check=True) + + +def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None: + shutil.copyfile(clean_gittatributes, current_gittatributes) + subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True) + subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True) + subprocess.run(["git", "push"], cwd=work_dir, check=True) + + +def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None: + subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True) + repo_url = f"https://huggingface.co/datasets/{repo_id}" + env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files + subprocess.run( + ["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)], + check=True, + env=env, + ) + + +def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]: + lfs_tracked_files = subprocess.run( + ["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True + ) + lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines()) + return [f for f in video_files if f not in lfs_tracked_files] + + +def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict: + # Assumes first episode + video_files = [ + DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0) + for vid_key in video_keys + ] + hub_api = HfApi() + hub_api.snapshot_download( + repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files + ) + videos_info_dict = {} + for vid_key, vid_path in zip(video_keys, video_files, strict=True): + videos_info_dict[vid_key] = get_video_info(local_dir / vid_path) + + return videos_info_dict + + +def convert_dataset( + repo_id: str, + local_dir: Path, + single_task: str | None = None, + tasks_path: Path | None = None, + tasks_col: Path | None = None, + robot_config: RobotConfig | None = None, + test_branch: str | None = None, + **card_kwargs, +): + v1 = get_safe_version(repo_id, V16) + v1x_dir = local_dir / V16 / repo_id + v20_dir = local_dir / V20 / repo_id + v1x_dir.mkdir(parents=True, exist_ok=True) + v20_dir.mkdir(parents=True, exist_ok=True) + + hub_api = HfApi() + hub_api.snapshot_download( + repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/" + ) + branch = "main" + if test_branch: + branch = test_branch + create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset") + + metadata_v1 = load_json(v1x_dir / V1_INFO_PATH) + dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train") + features = get_features_from_hf_dataset(dataset, robot_config) + video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"] + + if single_task and "language_instruction" in 
dataset.column_names: + logging.warning( + "'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.", + ) + single_task = None + tasks_col = "language_instruction" + + # Episodes & chunks + episode_indices = sorted(dataset.unique("episode_index")) + total_episodes = len(episode_indices) + assert episode_indices == list(range(total_episodes)) + total_videos = total_episodes * len(video_keys) + total_chunks = total_episodes // DEFAULT_CHUNK_SIZE + if total_episodes % DEFAULT_CHUNK_SIZE != 0: + total_chunks += 1 + + # Tasks + if single_task: + tasks_by_episodes = dict.fromkeys(episode_indices, single_task) + dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes) + tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()} + elif tasks_path: + tasks_by_episodes = load_json(tasks_path) + tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()} + dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes) + tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()} + elif tasks_col: + dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col) + else: + raise ValueError + + assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks} + tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)] + write_jsonlines(tasks, v20_dir / TASKS_PATH) + features["task_index"] = { + "dtype": "int64", + "shape": (1,), + "names": None, + } + + # Videos + if video_keys: + assert metadata_v1.get("video", False) + dataset = dataset.remove_columns(video_keys) + clean_gitattr = Path( + hub_api.hf_hub_download( + repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes" + ) + ).absolute() + with tempfile.TemporaryDirectory() as tmp_video_dir: + move_videos( + repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch + ) + videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch) + for key in video_keys: + features[key]["shape"] = ( + videos_info[key].pop("video.height"), + videos_info[key].pop("video.width"), + videos_info[key].pop("video.channels"), + ) + features[key]["video_info"] = videos_info[key] + assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3) + if "encoding" in metadata_v1: + assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"] + else: + assert metadata_v1.get("video", 0) == 0 + videos_info = None + + # Split data into 1 parquet file by episode + episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir) + + if robot_config is not None: + robot_type = robot_config.type + repo_tags = [robot_type] + else: + robot_type = "unknown" + repo_tags = None + + # Episodes + episodes = [ + {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]} + for ep_idx in episode_indices + ] + write_jsonlines(episodes, v20_dir / EPISODES_PATH) + + # Assemble metadata v2.0 + metadata_v2_0 = { + "codebase_version": V20, + "robot_type": robot_type, + "total_episodes": total_episodes, + "total_frames": len(dataset), + "total_tasks": len(tasks), + "total_videos": total_videos, + "total_chunks": total_chunks, + "chunks_size": DEFAULT_CHUNK_SIZE, + "fps": metadata_v1["fps"], + "splits": {"train": f"0:{total_episodes}"}, + "data_path": DEFAULT_PARQUET_PATH, + "video_path": 
DEFAULT_VIDEO_PATH if video_keys else None, + "features": features, + } + write_json(metadata_v2_0, v20_dir / INFO_PATH) + convert_stats_to_json(v1x_dir, v20_dir) + card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs) + + with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): + hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch) + + with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): + hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch) + + with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): + hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch) + + hub_api.upload_folder( + repo_id=repo_id, + path_in_repo="data", + folder_path=v20_dir / "data", + repo_type="dataset", + revision=branch, + ) + hub_api.upload_folder( + repo_id=repo_id, + path_in_repo="meta", + folder_path=v20_dir / "meta", + repo_type="dataset", + revision=branch, + ) + + card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch) + + if not test_branch: + create_branch(repo_id=repo_id, branch=V20, repo_type="dataset") + + +def main(): + parser = argparse.ArgumentParser() + task_args = parser.add_mutually_exclusive_group(required=True) + + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).", + ) + task_args.add_argument( + "--single-task", + type=str, + help="A short but accurate description of the single task performed in the dataset.", + ) + task_args.add_argument( + "--tasks-col", + type=str, + help="The name of the column containing language instructions", + ) + task_args.add_argument( + "--tasks-path", + type=Path, + help="The path to a .json file containing one language instruction for each episode_index", + ) + parser.add_argument( + "--robot", + type=str, + default=None, + help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)", + ) + parser.add_argument( + "--local-dir", + type=Path, + default=None, + help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2", + ) + parser.add_argument( + "--license", + type=str, + default="apache-2.0", + help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.", + ) + parser.add_argument( + "--test-branch", + type=str, + default=None, + help="Repo branch to test your conversion first (e.g. 'v2.0.test')", + ) + + args = parser.parse_args() + if not args.local_dir: + args.local_dir = Path("/tmp/lerobot_dataset_v2") + + if args.robot is not None: + robot_config = make_robot_config(args.robot) + + del args.robot + + convert_dataset(**vars(args), robot_config=robot_config) + + +if __name__ == "__main__": + main() diff --git a/lerobot/common/datasets/v21/_remove_language_instruction.py b/lerobot/common/datasets/v21/_remove_language_instruction.py new file mode 100644 index 0000000000000000000000000000000000000000..643ddd3f20541bbd7dd809067fb37707a375d053 --- /dev/null +++ b/lerobot/common/datasets/v21/_remove_language_instruction.py @@ -0,0 +1,87 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import traceback +from pathlib import Path + +from datasets import get_dataset_config_info +from huggingface_hub import HfApi + +from lerobot import available_datasets +from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.common.datasets.utils import INFO_PATH, write_info +from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings + +LOCAL_DIR = Path("data/") + +hub_api = HfApi() + + +def fix_dataset(repo_id: str) -> str: + if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"): + return f"{repo_id}: skipped (not in {V20})." + + dataset_info = get_dataset_config_info(repo_id, "default") + with SuppressWarnings(): + lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True) + + meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"} + parquet_features = set(dataset_info.features) + + diff_parquet_meta = parquet_features - meta_features + diff_meta_parquet = meta_features - parquet_features + + if diff_parquet_meta: + raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}") + + if not diff_meta_parquet: + return f"{repo_id}: skipped (no diff)" + + if diff_meta_parquet: + logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}") + assert diff_meta_parquet == {"language_instruction"} + lerobot_metadata.features.pop("language_instruction") + write_info(lerobot_metadata.info, lerobot_metadata.root) + commit_info = hub_api.upload_file( + path_or_fileobj=lerobot_metadata.root / INFO_PATH, + path_in_repo=INFO_PATH, + repo_id=repo_id, + repo_type="dataset", + revision=V20, + commit_message="Remove 'language_instruction'", + create_pr=True, + ) + return f"{repo_id}: success - PR: {commit_info.pr_url}" + + +def batch_fix(): + status = {} + LOCAL_DIR.mkdir(parents=True, exist_ok=True) + logfile = LOCAL_DIR / "fix_features_v20.txt" + for num, repo_id in enumerate(available_datasets): + print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})") + print("---------------------------------------------------------") + try: + status = fix_dataset(repo_id) + except Exception: + status = f"{repo_id}: failed\n {traceback.format_exc()}" + + logging.info(status) + with open(logfile, "a") as file: + file.write(status + "\n") + + +if __name__ == "__main__": + batch_fix() diff --git a/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9272a83575cb46f1581ea802e4b93c6b787b64 --- /dev/null +++ b/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1. +""" + +import traceback +from pathlib import Path + +from huggingface_hub import HfApi + +from lerobot import available_datasets +from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset + +LOCAL_DIR = Path("data/") + + +def batch_convert(): + status = {} + LOCAL_DIR.mkdir(parents=True, exist_ok=True) + logfile = LOCAL_DIR / "conversion_log_v21.txt" + hub_api = HfApi() + for num, repo_id in enumerate(available_datasets): + print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})") + print("---------------------------------------------------------") + try: + if hub_api.revision_exists(repo_id, V21, repo_type="dataset"): + status = f"{repo_id}: success (already in {V21})." + else: + convert_dataset(repo_id) + status = f"{repo_id}: success." + except Exception: + status = f"{repo_id}: failed\n {traceback.format_exc()}" + + with open(logfile, "a") as file: + file.write(status + "\n") + + +if __name__ == "__main__": + batch_convert() diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py new file mode 100644 index 0000000000000000000000000000000000000000..176d16d0f33a29c13d17658e706ce2b0219f80af --- /dev/null +++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py @@ -0,0 +1,114 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to +2.1. It will: + +- Generate per-episodes stats and writes them in `episodes_stats.jsonl` +- Check consistency between these new stats and the old ones. +- Remove the deprecated `stats.json`. +- Update codebase_version in `info.json`. +- Push this new version to the hub on the 'main' branch and tags it with "v2.1". 
+ +Usage: + +```bash +python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \ + --repo-id=aliberts/koch_tutorial +``` + +""" + +import argparse +import logging + +from huggingface_hub import HfApi + +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset +from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info +from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats + +V20 = "v2.0" +V21 = "v2.1" + + +class SuppressWarnings: + def __enter__(self): + self.previous_level = logging.getLogger().getEffectiveLevel() + logging.getLogger().setLevel(logging.ERROR) + + def __exit__(self, exc_type, exc_val, exc_tb): + logging.getLogger().setLevel(self.previous_level) + + +def convert_dataset( + repo_id: str, + branch: str | None = None, + num_workers: int = 4, +): + with SuppressWarnings(): + dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True) + + if (dataset.root / EPISODES_STATS_PATH).is_file(): + (dataset.root / EPISODES_STATS_PATH).unlink() + + convert_stats(dataset, num_workers=num_workers) + ref_stats = load_stats(dataset.root) + check_aggregate_stats(dataset, ref_stats) + + dataset.meta.info["codebase_version"] = CODEBASE_VERSION + write_info(dataset.meta.info, dataset.root) + + dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/") + + # delete old stats.json file + if (dataset.root / STATS_PATH).is_file: + (dataset.root / STATS_PATH).unlink() + + hub_api = HfApi() + if hub_api.file_exists( + repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset" + ): + hub_api.delete_file( + path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset" + ) + + hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset " + "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).", + ) + parser.add_argument( + "--branch", + type=str, + default=None, + help="Repo branch to push your dataset. Defaults to the main branch.", + ) + parser.add_argument( + "--num-workers", + type=int, + default=4, + help="Number of workers for parallelizing stats compute. Defaults to 4.", + ) + + args = parser.parse_args() + convert_dataset(**vars(args)) diff --git a/lerobot/common/datasets/v21/convert_stats.py b/lerobot/common/datasets/v21/convert_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..4a20b4276d5eb0f30a41e66afbfec859ccac18a0 --- /dev/null +++ b/lerobot/common/datasets/v21/convert_stats.py @@ -0,0 +1,99 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
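For reference, the conversion entry point defined above can also be called directly from Python rather than through the CLI. A minimal sketch with a placeholder repo id; `branch=None` pushes to `main`, while passing a branch name lets you inspect the result on the hub before it lands on the default branch:

```python
# Minimal sketch: programmatic use of the v2.0 -> v2.1 converter defined above.
# The repo id is a placeholder; replace it with your own dataset.
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import convert_dataset

convert_dataset(
    repo_id="your-user/your_dataset",  # placeholder
    branch=None,       # None pushes to 'main'; pass a branch name to review the result first
    num_workers=8,     # parallel workers for per-episode stats computation
)
```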
+ +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np +from tqdm import tqdm + +from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import write_episode_stats + + +def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray: + ep_len = dataset.meta.episodes[episode_index]["length"] + sampled_indices = sample_indices(ep_len) + query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices}) + video_frames = dataset._query_videos(query_timestamps, episode_index) + return video_frames[ft_key].numpy() + + +def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int): + ep_start_idx = dataset.episode_data_index["from"][ep_idx] + ep_end_idx = dataset.episode_data_index["to"][ep_idx] + ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx)) + + ep_stats = {} + for key, ft in dataset.features.items(): + if ft["dtype"] == "video": + # We sample only for videos + ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key) + else: + ep_ft_data = np.array(ep_data[key]) + + axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0 + keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1 + ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims) + + if ft["dtype"] in ["image", "video"]: # remove batch dim + ep_stats[key] = { + k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items() + } + + dataset.meta.episodes_stats[ep_idx] = ep_stats + + +def convert_stats(dataset: LeRobotDataset, num_workers: int = 0): + assert dataset.episodes is None + print("Computing episodes stats") + total_episodes = dataset.meta.total_episodes + if num_workers > 0: + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = { + executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx + for ep_idx in range(total_episodes) + } + for future in tqdm(as_completed(futures), total=total_episodes): + future.result() + else: + for ep_idx in tqdm(range(total_episodes)): + convert_episode_stats(dataset, ep_idx) + + for ep_idx in tqdm(range(total_episodes)): + write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root) + + +def check_aggregate_stats( + dataset: LeRobotDataset, + reference_stats: dict[str, dict[str, np.ndarray]], + video_rtol_atol: tuple[float] = (1e-2, 1e-2), + default_rtol_atol: tuple[float] = (5e-6, 6e-5), +): + """Verifies that the aggregated stats from episodes_stats are close to reference stats.""" + agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values())) + for key, ft in dataset.features.items(): + # These values might need some fine-tuning + if ft["dtype"] == "video": + # to account for image sub-sampling + rtol, atol = video_rtol_atol + else: + rtol, atol = default_rtol_atol + + for stat, val in agg_stats[key].items(): + if key in reference_stats and stat in reference_stats[key]: + err_msg = f"feature='{key}' stats='{stat}'" + np.testing.assert_allclose( + val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg + ) diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c38d570ddf8debfa955287495f3c55a20d14004e --- /dev/null +++ b/lerobot/common/datasets/video_utils.py @@ -0,0 +1,432 @@ 
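The `check_aggregate_stats` function above verifies that stats aggregated from per-episode statistics stay close to the old global `stats.json` values. Below is a self-contained numpy sketch of that idea (count-weighted aggregation of per-episode means compared against stats computed over all frames at once); it illustrates the principle only, not the library's `aggregate_stats` implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
episodes = [rng.normal(loc=0.5, scale=0.1, size=(n, 6)) for n in (120, 80, 200)]

# Per-episode stats, analogous to what convert_episode_stats produces.
ep_stats = [{"mean": ep.mean(axis=0), "count": len(ep)} for ep in episodes]

# Count-weighted aggregation of the per-episode means.
total_count = sum(s["count"] for s in ep_stats)
agg_mean = sum(s["mean"] * s["count"] for s in ep_stats) / total_count

# Reference computed over all frames at once (the role of the old stats.json).
ref_mean = np.concatenate(episodes).mean(axis=0)

# Same default tolerances as check_aggregate_stats uses for non-video features.
np.testing.assert_allclose(agg_mean, ref_mean, rtol=5e-6, atol=6e-5)
```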
+#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import json +import logging +import subprocess +import warnings +from collections import OrderedDict +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, ClassVar + +import pyarrow as pa +import torch +import torchvision +from datasets.features.features import register_feature +from PIL import Image + + +def get_safe_default_codec(): + if importlib.util.find_spec("torchcodec"): + return "torchcodec" + else: + logging.warning( + "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder" + ) + return "pyav" + + +def decode_video_frames( + video_path: Path | str, + timestamps: list[float], + tolerance_s: float, + backend: str | None = None, +) -> torch.Tensor: + """ + Decodes video frames using the specified backend. + + Args: + video_path (Path): Path to the video file. + timestamps (list[float]): List of timestamps to extract frames. + tolerance_s (float): Allowed deviation in seconds for frame retrieval. + backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".. + + Returns: + torch.Tensor: Decoded frames. + + Currently supports torchcodec on cpu and pyav. + """ + if backend is None: + backend = get_safe_default_codec() + if backend == "torchcodec": + return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) + elif backend in ["pyav", "video_reader"]: + return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) + else: + raise ValueError(f"Unsupported video backend: {backend}") + + +def decode_video_frames_torchvision( + video_path: Path | str, + timestamps: list[float], + tolerance_s: float, + backend: str = "pyav", + log_loaded_timestamps: bool = False, +) -> torch.Tensor: + """Loads frames associated to the requested timestamps of a video + + The backend can be either "pyav" (default) or "video_reader". + "video_reader" requires installing torchvision from source, see: + https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst + (note that you need to compile against ffmpeg<4.3) + + While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup. + For more info on video decoding, see `benchmark/video/README.md` + + See torchvision doc for more info on these two backends: + https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend + + Note: Video benefits from inter-frame compression. Instead of storing every frame individually, + the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to + that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame, + and all subsequent frames until reaching the requested frame. 
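A quick usage sketch of the `decode_video_frames` dispatcher defined above (the video path is a placeholder); leaving `backend=None` picks `torchcodec` when it is installed and falls back to `pyav` otherwise:

```python
from lerobot.common.datasets.video_utils import decode_video_frames

fps = 30
frames = decode_video_frames(
    video_path="videos/chunk-000/observation.images.top/episode_000000.mp4",  # placeholder path
    timestamps=[0.0, 1 / fps, 2 / fps],  # timestamps of the frames to fetch
    tolerance_s=1e-4,                    # max allowed gap between requested and decoded timestamps
    backend=None,                        # auto-select: torchcodec if available, else pyav
)
print(frames.shape, frames.dtype)  # (3, C, H, W), torch.float32 in [0, 1]
```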
The number of key frames in a video + can be adjusted during encoding to take into account decoding time and video size in bytes. + """ + video_path = str(video_path) + + # set backend + keyframes_only = False + torchvision.set_video_backend(backend) + if backend == "pyav": + keyframes_only = True # pyav doesnt support accuracte seek + + # set a video stream reader + # TODO(rcadene): also load audio stream at the same time + reader = torchvision.io.VideoReader(video_path, "video") + + # set the first and last requested timestamps + # Note: previous timestamps are usually loaded, since we need to access the previous key frame + first_ts = min(timestamps) + last_ts = max(timestamps) + + # access closest key frame of the first requested frame + # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video) + # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek + reader.seek(first_ts, keyframes_only=keyframes_only) + + # load all frames until last requested frame + loaded_frames = [] + loaded_ts = [] + for frame in reader: + current_ts = frame["pts"] + if log_loaded_timestamps: + logging.info(f"frame loaded at timestamp={current_ts:.4f}") + loaded_frames.append(frame["data"]) + loaded_ts.append(current_ts) + if current_ts >= last_ts: + break + + if backend == "pyav": + reader.container.close() + + reader = None + + query_ts = torch.tensor(timestamps) + loaded_ts = torch.tensor(loaded_ts) + + # compute distances between each query timestamp and timestamps of all loaded frames + dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1) + min_, argmin_ = dist.min(1) + + is_within_tol = min_ < tolerance_s + assert is_within_tol.all(), ( + f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." + "It means that the closest frame that can be loaded from the video is too far away in time." + "This might be due to synchronization issues with timestamps during data collection." + "To be safe, we advise to ignore this item during training." + f"\nqueried timestamps: {query_ts}" + f"\nloaded timestamps: {loaded_ts}" + f"\nvideo: {video_path}" + f"\nbackend: {backend}" + ) + + # get closest frames to the query timestamps + closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) + closest_ts = loaded_ts[argmin_] + + if log_loaded_timestamps: + logging.info(f"{closest_ts=}") + + # convert to the pytorch format which is float32 in [0,1] range (and channel first) + closest_frames = closest_frames.type(torch.float32) / 255 + + assert len(timestamps) == len(closest_frames) + return closest_frames + + +def decode_video_frames_torchcodec( + video_path: Path | str, + timestamps: list[float], + tolerance_s: float, + device: str = "cpu", + log_loaded_timestamps: bool = False, +) -> torch.Tensor: + """Loads frames associated with the requested timestamps of a video using torchcodec. + + Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors. + + Note: Video benefits from inter-frame compression. Instead of storing every frame individually, + the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to + that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame, + and all subsequent frames until reaching the requested frame. 
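Both decoders end with the same nearest-timestamp matching step shown above: an L1 distance matrix between requested and decoded timestamps, a per-query argmin, and a tolerance check. It is isolated below with made-up timestamps so the mechanism is easy to follow:

```python
import torch

# Made-up data: 3 requested timestamps, 5 decoded-frame timestamps.
query_ts = torch.tensor([0.10, 0.50, 0.90])
loaded_ts = torch.tensor([0.00, 0.25, 0.50, 0.75, 1.00])

# L1 distance between every requested and every decoded timestamp,
# then the closest decoded frame for each request.
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
min_dist, argmin = dist.min(dim=1)

tolerance_s = 0.2
assert (min_dist < tolerance_s).all(), "a requested timestamp has no close-enough decoded frame"
print(argmin.tolist())  # [0, 2, 4] -> indices of the frames that would be stacked and returned
```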
The number of key frames in a video + can be adjusted during encoding to take into account decoding time and video size in bytes. + """ + + if importlib.util.find_spec("torchcodec"): + from torchcodec.decoders import VideoDecoder + else: + raise ImportError("torchcodec is required but not available.") + + # initialize video decoder + decoder = VideoDecoder(video_path, device=device, seek_mode="approximate") + loaded_frames = [] + loaded_ts = [] + # get metadata for frame information + metadata = decoder.metadata + average_fps = metadata.average_fps + + # convert timestamps to frame indices + frame_indices = [round(ts * average_fps) for ts in timestamps] + + # retrieve frames based on indices + frames_batch = decoder.get_frames_at(indices=frame_indices) + + for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False): + loaded_frames.append(frame) + loaded_ts.append(pts.item()) + if log_loaded_timestamps: + logging.info(f"Frame loaded at timestamp={pts:.4f}") + + query_ts = torch.tensor(timestamps) + loaded_ts = torch.tensor(loaded_ts) + + # compute distances between each query timestamp and loaded timestamps + dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1) + min_, argmin_ = dist.min(1) + + is_within_tol = min_ < tolerance_s + assert is_within_tol.all(), ( + f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." + "It means that the closest frame that can be loaded from the video is too far away in time." + "This might be due to synchronization issues with timestamps during data collection." + "To be safe, we advise to ignore this item during training." + f"\nqueried timestamps: {query_ts}" + f"\nloaded timestamps: {loaded_ts}" + f"\nvideo: {video_path}" + ) + + # get closest frames to the query timestamps + closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) + closest_ts = loaded_ts[argmin_] + + if log_loaded_timestamps: + logging.info(f"{closest_ts=}") + + # convert to float32 in [0,1] range (channel first) + closest_frames = closest_frames.type(torch.float32) / 255 + + assert len(timestamps) == len(closest_frames) + return closest_frames + + +def encode_video_frames( + imgs_dir: Path | str, + video_path: Path | str, + fps: int, + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", + g: int | None = 2, + crf: int | None = 30, + fast_decode: int = 0, + log_level: str | None = "error", + overwrite: bool = False, +) -> None: + """More info on ffmpeg arguments tuning on `benchmark/video/README.md`""" + video_path = Path(video_path) + imgs_dir = Path(imgs_dir) + video_path.parent.mkdir(parents=True, exist_ok=True) + + ffmpeg_args = OrderedDict( + [ + ("-f", "image2"), + ("-r", str(fps)), + ("-i", str(imgs_dir / "frame_%06d.png")), + ("-vcodec", vcodec), + ("-pix_fmt", pix_fmt), + ] + ) + + if g is not None: + ffmpeg_args["-g"] = str(g) + + if crf is not None: + ffmpeg_args["-crf"] = str(crf) + + if fast_decode: + key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune" + value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode" + ffmpeg_args[key] = value + + if log_level is not None: + ffmpeg_args["-loglevel"] = str(log_level) + + ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair] + if overwrite: + ffmpeg_args.append("-y") + + ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)] + # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal + subprocess.run(ffmpeg_cmd, check=True, 
stdin=subprocess.DEVNULL) + + if not video_path.exists(): + raise OSError( + f"Video encoding did not work. File not found: {video_path}. " + f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`" + ) + + +@dataclass +class VideoFrame: + # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo + """ + Provides a type for a dataset containing video frames. + + Example: + + ```python + data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}] + features = {"image": VideoFrame()} + Dataset.from_dict(data_dict, features=Features(features)) + ``` + """ + + pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()}) + _type: str = field(default="VideoFrame", init=False, repr=False) + + def __call__(self): + return self.pa_type + + +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + "'register_feature' is experimental and might be subject to breaking changes in the future.", + category=UserWarning, + ) + # to make VideoFrame available in HuggingFace `datasets` + register_feature(VideoFrame, "VideoFrame") + + +def get_audio_info(video_path: Path | str) -> dict: + ffprobe_audio_cmd = [ + "ffprobe", + "-v", + "error", + "-select_streams", + "a:0", + "-show_entries", + "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration", + "-of", + "json", + str(video_path), + ] + result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + if result.returncode != 0: + raise RuntimeError(f"Error running ffprobe: {result.stderr}") + + info = json.loads(result.stdout) + audio_stream_info = info["streams"][0] if info.get("streams") else None + if audio_stream_info is None: + return {"has_audio": False} + + # Return the information, defaulting to None if no audio stream is present + return { + "has_audio": True, + "audio.channels": audio_stream_info.get("channels", None), + "audio.codec": audio_stream_info.get("codec_name", None), + "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, + "audio.sample_rate": int(audio_stream_info["sample_rate"]) + if audio_stream_info.get("sample_rate") + else None, + "audio.bit_depth": audio_stream_info.get("bit_depth", None), + "audio.channel_layout": audio_stream_info.get("channel_layout", None), + } + + +def get_video_info(video_path: Path | str) -> dict: + ffprobe_video_cmd = [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt", + "-of", + "json", + str(video_path), + ] + result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + if result.returncode != 0: + raise RuntimeError(f"Error running ffprobe: {result.stderr}") + + info = json.loads(result.stdout) + video_stream_info = info["streams"][0] + + # Calculate fps from r_frame_rate + r_frame_rate = video_stream_info["r_frame_rate"] + num, denom = map(int, r_frame_rate.split("/")) + fps = num / denom + + pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"]) + + video_info = { + "video.fps": fps, + "video.height": video_stream_info["height"], + "video.width": video_stream_info["width"], + "video.channels": pixel_channels, + "video.codec": video_stream_info["codec_name"], + "video.pix_fmt": video_stream_info["pix_fmt"], + "video.is_depth_map": False, + **get_audio_info(video_path), + } + + return video_info + + +def get_video_pixel_channels(pix_fmt: str) -> int: + 
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt: + return 1 + elif "rgba" in pix_fmt or "yuva" in pix_fmt: + return 4 + elif "rgb" in pix_fmt or "yuv" in pix_fmt: + return 3 + else: + raise ValueError("Unknown format") + + +def get_image_pixel_channels(image: Image): + if image.mode == "L": + return 1 # Grayscale + elif image.mode == "LA": + return 2 # Grayscale + Alpha + elif image.mode == "RGB": + return 3 # RGB + elif image.mode == "RGBA": + return 4 # RGBA + else: + raise ValueError("Unknown format") diff --git a/lerobot/common/envs/__init__.py b/lerobot/common/envs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4977d11d9fee1f02e1652bd80f500eff2837f0bc --- /dev/null +++ b/lerobot/common/envs/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401 diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..cf90048a37e78cd283bccd011208c3cfb5a6a20b --- /dev/null +++ b/lerobot/common/envs/configs.py @@ -0,0 +1,156 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
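The two channel-count helpers that close `video_utils.py` above are easy to check in isolation. A small sketch using synthetic PIL images and ffprobe-style pixel format strings (no files needed):

```python
from PIL import Image

from lerobot.common.datasets.video_utils import get_image_pixel_channels, get_video_pixel_channels

# PIL mode -> channel count
for mode in ("L", "LA", "RGB", "RGBA"):
    img = Image.new(mode, (64, 48))
    print(mode, get_image_pixel_channels(img))  # 1, 2, 3, 4

# ffprobe pix_fmt -> channel count
print(get_video_pixel_channels("yuv420p"))  # 3 (typical color video)
print(get_video_pixel_channels("gray"))     # 1 (grayscale / depth-style streams)
```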
+ +import abc +from dataclasses import dataclass, field + +import draccus + +from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT +from lerobot.configs.types import FeatureType, PolicyFeature + + +@dataclass +class EnvConfig(draccus.ChoiceRegistry, abc.ABC): + task: str | None = None + fps: int = 30 + features: dict[str, PolicyFeature] = field(default_factory=dict) + features_map: dict[str, str] = field(default_factory=dict) + + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + @abc.abstractproperty + def gym_kwargs(self) -> dict: + raise NotImplementedError() + + +@EnvConfig.register_subclass("aloha") +@dataclass +class AlohaEnv(EnvConfig): + task: str = "AlohaInsertion-v0" + fps: int = 50 + episode_length: int = 400 + obs_type: str = "pixels_agent_pos" + render_mode: str = "rgb_array" + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "agent_pos": OBS_ROBOT, + "top": f"{OBS_IMAGE}.top", + "pixels/top": f"{OBS_IMAGES}.top", + } + ) + + def __post_init__(self): + if self.obs_type == "pixels": + self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) + elif self.obs_type == "pixels_agent_pos": + self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,)) + self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) + + @property + def gym_kwargs(self) -> dict: + return { + "obs_type": self.obs_type, + "render_mode": self.render_mode, + "max_episode_steps": self.episode_length, + } + + +@EnvConfig.register_subclass("pusht") +@dataclass +class PushtEnv(EnvConfig): + task: str = "PushT-v0" + fps: int = 10 + episode_length: int = 300 + obs_type: str = "pixels_agent_pos" + render_mode: str = "rgb_array" + visualization_width: int = 384 + visualization_height: int = 384 + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + "agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "agent_pos": OBS_ROBOT, + "environment_state": OBS_ENV, + "pixels": OBS_IMAGE, + } + ) + + def __post_init__(self): + if self.obs_type == "pixels_agent_pos": + self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3)) + elif self.obs_type == "environment_state_agent_pos": + self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,)) + + @property + def gym_kwargs(self) -> dict: + return { + "obs_type": self.obs_type, + "render_mode": self.render_mode, + "visualization_width": self.visualization_width, + "visualization_height": self.visualization_height, + "max_episode_steps": self.episode_length, + } + + +@EnvConfig.register_subclass("xarm") +@dataclass +class XarmEnv(EnvConfig): + task: str = "XarmLift-v0" + fps: int = 15 + episode_length: int = 200 + obs_type: str = "pixels_agent_pos" + render_mode: str = "rgb_array" + visualization_width: int = 384 + visualization_height: int = 384 + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)), + "pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)), + } + ) + features_map: dict[str, str] = field( + 
default_factory=lambda: { + "action": ACTION, + "agent_pos": OBS_ROBOT, + "pixels": OBS_IMAGE, + } + ) + + def __post_init__(self): + if self.obs_type == "pixels_agent_pos": + self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,)) + + @property + def gym_kwargs(self) -> dict: + return { + "obs_type": self.obs_type, + "render_mode": self.render_mode, + "visualization_width": self.visualization_width, + "visualization_height": self.visualization_height, + "max_episode_steps": self.episode_length, + } diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..8450f84b95f393c67abab9536f8ddf252f8c0efa --- /dev/null +++ b/lerobot/common/envs/factory.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib + +import gymnasium as gym + +from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv + + +def make_env_config(env_type: str, **kwargs) -> EnvConfig: + if env_type == "aloha": + return AlohaEnv(**kwargs) + elif env_type == "pusht": + return PushtEnv(**kwargs) + elif env_type == "xarm": + return XarmEnv(**kwargs) + else: + raise ValueError(f"Policy type '{env_type}' is not available.") + + +def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None: + """Makes a gym vector environment according to the config. + + Args: + cfg (EnvConfig): the config of the environment to instantiate. + n_envs (int, optional): The number of parallelized env to return. Defaults to 1. + use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to + False. + + Raises: + ValueError: if n_envs < 1 + ModuleNotFoundError: If the requested env package is not installed + + Returns: + gym.vector.VectorEnv: The parallelized gym.env instance. + """ + if n_envs < 1: + raise ValueError("`n_envs must be at least 1") + + package_name = f"gym_{cfg.type}" + + try: + importlib.import_module(package_name) + except ModuleNotFoundError as e: + print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`") + raise e + + gym_handle = f"{package_name}/{cfg.task}" + + # batched version of the env that returns an observation of shape (b, c) + env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv + env = env_cls( + [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)] + ) + + return env diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..83334f876df43ae40d2ebaebdfab425f3e062134 --- /dev/null +++ b/lerobot/common/envs/utils.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings +from typing import Any + +import einops +import gymnasium as gym +import numpy as np +import torch +from torch import Tensor + +from lerobot.common.envs.configs import EnvConfig +from lerobot.common.utils.utils import get_channel_first_image_shape +from lerobot.configs.types import FeatureType, PolicyFeature + + +def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]: + # TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding) + """Convert environment observation to LeRobot format observation. + Args: + observation: Dictionary of observation batches from a Gym vector environment. + Returns: + Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. + """ + # map to expected inputs for the policy + return_observations = {} + if "pixels" in observations: + if isinstance(observations["pixels"], dict): + imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} + else: + imgs = {"observation.image": observations["pixels"]} + + for imgkey, img in imgs.items(): + # TODO(aliberts, rcadene): use transforms.ToTensor()? + img = torch.from_numpy(img) + + # sanity check that images are channel last + _, h, w, c = img.shape + assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" + + # sanity check that images are uint8 + assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" + + # convert to channel first of type float32 in range [0,1] + img = einops.rearrange(img, "b h w c -> b c h w").contiguous() + img = img.type(torch.float32) + img /= 255 + + return_observations[imgkey] = img + + if "environment_state" in observations: + return_observations["observation.environment_state"] = torch.from_numpy( + observations["environment_state"] + ).float() + + # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing + # requirement for "agent_pos" + return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float() + return return_observations + + +def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: + # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is + # (need to also refactor preprocess_observation and externalize normalization from policies) + policy_features = {} + for key, ft in env_cfg.features.items(): + if ft.type is FeatureType.VISUAL: + if len(ft.shape) != 3: + raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})") + + shape = get_channel_first_image_shape(ft.shape) + feature = PolicyFeature(type=ft.type, shape=shape) + else: + feature = ft + + policy_key = env_cfg.features_map[key] + policy_features[policy_key] = feature + + return policy_features + + +def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool: + first_type = type(env.envs[0]) # Get type of first env + return all(type(e) is first_type 
for e in env.envs) # Fast type check + + +def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("once", UserWarning) # Apply filter only in this function + + if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")): + warnings.warn( + "The environment does not have 'task_description' and 'task'. Some policies require these features.", + UserWarning, + stacklevel=2, + ) + if not are_all_envs_same_type(env): + warnings.warn( + "The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.", + UserWarning, + stacklevel=2, + ) + + +def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]: + """Adds task feature to the observation dict with respect to the first environment attribute.""" + if hasattr(env.envs[0], "task_description"): + observation["task"] = env.call("task_description") + elif hasattr(env.envs[0], "task"): + observation["task"] = env.call("task") + else: # For envs without language instructions, e.g. aloha transfer cube and etc. + num_envs = observation[list(observation.keys())[0]].shape[0] + observation["task"] = ["" for _ in range(num_envs)] + return observation diff --git a/lerobot/common/mocks/__init__.py b/lerobot/common/mocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6f5848eccbe2f0ac0fca6b3a61d203cec3d9ccde --- /dev/null +++ b/lerobot/common/mocks/__init__.py @@ -0,0 +1 @@ +# Common mocks for robot devices and testing diff --git a/lerobot/common/mocks/cameras/__init__.py b/lerobot/common/mocks/cameras/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lerobot/common/mocks/cameras/mock_cv2.py b/lerobot/common/mocks/cameras/mock_cv2.py new file mode 100644 index 0000000000000000000000000000000000000000..eeaf859cc210addcb1a3e79b6c13cd40e06d212c --- /dev/null +++ b/lerobot/common/mocks/cameras/mock_cv2.py @@ -0,0 +1,101 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
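+"""Mocked classes and functions from cv2 to allow for continuous integration
+and testing code logic that requires hardware (e.g. cameras).
+
+Warning: this mock is minimalist and does not reproduce every behavior of the real
+OpenCV API; captured frames are random uint8 noise at the configured resolution.
+"""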
+from functools import cache + +import numpy as np + +CAP_V4L2 = 200 +CAP_DSHOW = 700 +CAP_AVFOUNDATION = 1200 +CAP_ANY = -1 + +CAP_PROP_FPS = 5 +CAP_PROP_FRAME_WIDTH = 3 +CAP_PROP_FRAME_HEIGHT = 4 +COLOR_RGB2BGR = 4 +COLOR_BGR2RGB = 4 + +ROTATE_90_COUNTERCLOCKWISE = 2 +ROTATE_90_CLOCKWISE = 0 +ROTATE_180 = 1 + + +@cache +def _generate_image(width: int, height: int): + return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8) + + +def cvtColor(color_image, color_conversion): # noqa: N802 + if color_conversion in [COLOR_RGB2BGR, COLOR_BGR2RGB]: + return color_image[:, :, [2, 1, 0]] + else: + raise NotImplementedError(color_conversion) + + +def rotate(color_image, rotation): + if rotation is None: + return color_image + elif rotation == ROTATE_90_CLOCKWISE: + return np.rot90(color_image, k=1) + elif rotation == ROTATE_180: + return np.rot90(color_image, k=2) + elif rotation == ROTATE_90_COUNTERCLOCKWISE: + return np.rot90(color_image, k=3) + else: + raise NotImplementedError(rotation) + + +class VideoCapture: + def __init__(self, *args, **kwargs): + self._mock_dict = { + CAP_PROP_FPS: 30, + CAP_PROP_FRAME_WIDTH: 640, + CAP_PROP_FRAME_HEIGHT: 480, + } + self._is_opened = True + + def isOpened(self): # noqa: N802 + return self._is_opened + + def set(self, propId: int, value: float) -> bool: # noqa: N803 + if not self._is_opened: + raise RuntimeError("Camera is not opened") + self._mock_dict[propId] = value + return True + + def get(self, propId: int) -> float: # noqa: N803 + if not self._is_opened: + raise RuntimeError("Camera is not opened") + value = self._mock_dict[propId] + if value == 0: + if propId == CAP_PROP_FRAME_HEIGHT: + value = 480 + elif propId == CAP_PROP_FRAME_WIDTH: + value = 640 + return value + + def read(self): + if not self._is_opened: + raise RuntimeError("Camera is not opened") + h = self.get(CAP_PROP_FRAME_HEIGHT) + w = self.get(CAP_PROP_FRAME_WIDTH) + ret = True + return ret, _generate_image(width=w, height=h) + + def release(self): + self._is_opened = False + + def __del__(self): + if self._is_opened: + self.release() diff --git a/lerobot/common/mocks/cameras/mock_pyrealsense2.py b/lerobot/common/mocks/cameras/mock_pyrealsense2.py new file mode 100644 index 0000000000000000000000000000000000000000..c477eb0626cefdb7c4ee29f9c98a548b901455a9 --- /dev/null +++ b/lerobot/common/mocks/cameras/mock_pyrealsense2.py @@ -0,0 +1,148 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
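+"""Mocked classes and functions from pyrealsense2 to allow for continuous integration
+and testing code logic that requires hardware (e.g. Intel RealSense cameras).
+
+Warning: this mock is minimalist and does not reproduce every behavior of the real
+pyrealsense2 API; color and depth frames contain constant dummy data.
+"""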
+import enum + +import numpy as np + + +class stream(enum.Enum): # noqa: N801 + color = 0 + depth = 1 + + +class format(enum.Enum): # noqa: N801 + rgb8 = 0 + z16 = 1 + + +class config: # noqa: N801 + def enable_device(self, device_id: str): + self.device_enabled = device_id + + def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None): + self.stream_type = stream_type + # Overwrite default values when possible + self.width = 848 if width is None else width + self.height = 480 if height is None else height + self.color_format = format.rgb8 if color_format is None else color_format + self.fps = 30 if fps is None else fps + + +class RSColorProfile: + def __init__(self, config): + self.config = config + + def fps(self): + return self.config.fps + + def width(self): + return self.config.width + + def height(self): + return self.config.height + + +class RSColorStream: + def __init__(self, config): + self.config = config + + def as_video_stream_profile(self): + return RSColorProfile(self.config) + + +class RSProfile: + def __init__(self, config): + self.config = config + + def get_stream(self, color_format): + del color_format # unused + return RSColorStream(self.config) + + +class pipeline: # noqa: N801 + def __init__(self): + self.started = False + self.config = None + + def start(self, config): + self.started = True + self.config = config + return RSProfile(self.config) + + def stop(self): + if not self.started: + raise RuntimeError("You need to start the camera before stop.") + self.started = False + self.config = None + + def wait_for_frames(self, timeout_ms=50000): + del timeout_ms # unused + return RSFrames(self.config) + + +class RSFrames: + def __init__(self, config): + self.config = config + + def get_color_frame(self): + return RSColorFrame(self.config) + + def get_depth_frame(self): + return RSDepthFrame(self.config) + + +class RSColorFrame: + def __init__(self, config): + self.config = config + + def get_data(self): + data = np.ones((self.config.height, self.config.width, 3), dtype=np.uint8) + # Create a difference between rgb and bgr + data[:, :, 0] = 2 + return data + + +class RSDepthFrame: + def __init__(self, config): + self.config = config + + def get_data(self): + return np.ones((self.config.height, self.config.width), dtype=np.uint16) + + +class RSDevice: + def __init__(self): + pass + + def get_info(self, camera_info) -> str: + del camera_info # unused + # return fake serial number + return "123456789" + + +class context: # noqa: N801 + def __init__(self): + pass + + def query_devices(self): + return [RSDevice()] + + +class camera_info: # noqa: N801 + # fake name + name = "Intel RealSense D435I" + + def __init__(self, serial_number): + del serial_number + pass diff --git a/lerobot/common/mocks/motors/__init__.py b/lerobot/common/mocks/motors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2184a5c23b4ae3ddf20adbc433ebbc214a1a7e --- /dev/null +++ b/lerobot/common/mocks/motors/__init__.py @@ -0,0 +1 @@ +# Mocks for motor modules diff --git a/lerobot/common/mocks/motors/mock_dynamixel_sdk.py b/lerobot/common/mocks/motors/mock_dynamixel_sdk.py new file mode 100644 index 0000000000000000000000000000000000000000..ee399f96de95f51ca42357955af396b4f1e8ea91 --- /dev/null +++ b/lerobot/common/mocks/motors/mock_dynamixel_sdk.py @@ -0,0 +1,107 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration +and testing code logic that requires hardware and devices (e.g. robot arms, cameras) + +Warning: These mocked versions are minimalist. They do not exactly mock every behaviors +from the original classes and functions (e.g. return types might be None instead of boolean). +""" + +# from dynamixel_sdk import COMM_SUCCESS + +DEFAULT_BAUDRATE = 9_600 +COMM_SUCCESS = 0 # tx or rx packet communication success + + +def convert_to_bytes(value, bytes): + # TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform + # `convert_bytes_to_value` + del bytes # unused + return value + + +def get_default_motor_values(motor_index): + return { + # Key (int) are from X_SERIES_CONTROL_TABLE + 7: motor_index, # ID + 8: DEFAULT_BAUDRATE, # Baud_rate + 10: 0, # Drive_Mode + 64: 0, # Torque_Enable + # Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144 + # For other joints, 2560 will be autocorrected to be in calibration range + 132: 2560, # Present_Position + } + + +class PortHandler: + def __init__(self, port): + self.port = port + # factory default baudrate + self.baudrate = DEFAULT_BAUDRATE + + def openPort(self): # noqa: N802 + return True + + def closePort(self): # noqa: N802 + pass + + def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802 + del timeout_ms # unused + + def getBaudRate(self): # noqa: N802 + return self.baudrate + + def setBaudRate(self, baudrate): # noqa: N802 + self.baudrate = baudrate + + +class PacketHandler: + def __init__(self, protocol_version): + del protocol_version # unused + # Use packet_handler.data to communicate across Read and Write + self.data = {} + + +class GroupSyncRead: + def __init__(self, port_handler, packet_handler, address, bytes): + self.packet_handler = packet_handler + + def addParam(self, motor_index): # noqa: N802 + # Initialize motor default values + if motor_index not in self.packet_handler.data: + self.packet_handler.data[motor_index] = get_default_motor_values(motor_index) + + def txRxPacket(self): # noqa: N802 + return COMM_SUCCESS + + def getData(self, index, address, bytes): # noqa: N802 + return self.packet_handler.data[index][address] + + +class GroupSyncWrite: + def __init__(self, port_handler, packet_handler, address, bytes): + self.packet_handler = packet_handler + self.address = address + + def addParam(self, index, data): # noqa: N802 + # Initialize motor default values + if index not in self.packet_handler.data: + self.packet_handler.data[index] = get_default_motor_values(index) + self.changeParam(index, data) + + def txPacket(self): # noqa: N802 + return COMM_SUCCESS + + def changeParam(self, index, data): # noqa: N802 + self.packet_handler.data[index][self.address] = data diff --git a/lerobot/common/mocks/motors/mock_scservo_sdk.py b/lerobot/common/mocks/motors/mock_scservo_sdk.py new file mode 100644 index 
0000000000000000000000000000000000000000..37f6d0d566fb3312f33d7f8bd621fdb1c349e140 --- /dev/null +++ b/lerobot/common/mocks/motors/mock_scservo_sdk.py @@ -0,0 +1,125 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration +and testing code logic that requires hardware and devices (e.g. robot arms, cameras) + +Warning: These mocked versions are minimalist. They do not exactly mock every behaviors +from the original classes and functions (e.g. return types might be None instead of boolean). +""" + +# from dynamixel_sdk import COMM_SUCCESS + +DEFAULT_BAUDRATE = 1_000_000 +COMM_SUCCESS = 0 # tx or rx packet communication success + + +def convert_to_bytes(value, bytes): + # TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform + # `convert_bytes_to_value` + del bytes # unused + return value + + +def get_default_motor_values(motor_index): + return { + # Key (int) are from SCS_SERIES_CONTROL_TABLE + 5: motor_index, # ID + 6: DEFAULT_BAUDRATE, # Baud_rate + 10: 0, # Drive_Mode + 21: 32, # P_Coefficient + 22: 32, # D_Coefficient + 23: 0, # I_Coefficient + 40: 0, # Torque_Enable + 41: 254, # Acceleration + 31: -2047, # Offset + 33: 0, # Mode + 55: 1, # Lock + # Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144 + # For other joints, 2560 will be autocorrected to be in calibration range + 56: 2560, # Present_Position + 58: 0, # Present_Speed + 69: 0, # Present_Current + 85: 150, # Maximum_Acceleration + } + + +class PortHandler: + def __init__(self, port): + self.port = port + # factory default baudrate + self.baudrate = DEFAULT_BAUDRATE + self.ser = SerialMock() + + def openPort(self): # noqa: N802 + return True + + def closePort(self): # noqa: N802 + pass + + def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802 + del timeout_ms # unused + + def getBaudRate(self): # noqa: N802 + return self.baudrate + + def setBaudRate(self, baudrate): # noqa: N802 + self.baudrate = baudrate + + +class PacketHandler: + def __init__(self, protocol_version): + del protocol_version # unused + # Use packet_handler.data to communicate across Read and Write + self.data = {} + + +class GroupSyncRead: + def __init__(self, port_handler, packet_handler, address, bytes): + self.packet_handler = packet_handler + + def addParam(self, motor_index): # noqa: N802 + # Initialize motor default values + if motor_index not in self.packet_handler.data: + self.packet_handler.data[motor_index] = get_default_motor_values(motor_index) + + def txRxPacket(self): # noqa: N802 + return COMM_SUCCESS + + def getData(self, index, address, bytes): # noqa: N802 + return self.packet_handler.data[index][address] + + +class GroupSyncWrite: + def __init__(self, port_handler, packet_handler, address, bytes): + self.packet_handler = packet_handler + self.address = address + + def addParam(self, index, data): # noqa: N802 + if 
index not in self.packet_handler.data: + self.packet_handler.data[index] = get_default_motor_values(index) + self.changeParam(index, data) + + def txPacket(self): # noqa: N802 + return COMM_SUCCESS + + def changeParam(self, index, data): # noqa: N802 + self.packet_handler.data[index][self.address] = data + + +class SerialMock: + def reset_output_buffer(self): + pass + + def reset_input_buffer(self): + pass diff --git a/lerobot/common/optim/__init__.py b/lerobot/common/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..de2c4c99651ba9c01137026bd35ccb155670c22c --- /dev/null +++ b/lerobot/common/optim/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .optimizers import OptimizerConfig as OptimizerConfig diff --git a/lerobot/common/optim/factory.py b/lerobot/common/optim/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..10ff3df73c3fa47078f8359be863dff050b265b5 --- /dev/null +++ b/lerobot/common/optim/factory.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler + +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.configs.train import TrainPipelineConfig + + +def make_optimizer_and_scheduler( + cfg: TrainPipelineConfig, policy: PreTrainedPolicy +) -> tuple[Optimizer, LRScheduler | None]: + """Generates the optimizer and scheduler based on configs. + + Args: + cfg (TrainPipelineConfig): The training config that contains optimizer and scheduler configs + policy (PreTrainedPolicy): The policy config from which parameters and presets must be taken from. + + Returns: + tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`. 
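+
+    Example (a minimal sketch, assuming a parsed `cfg` and an instantiated `policy`):
+        ```python
+        optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
+        for _ in range(cfg.steps):
+            ...  # forward pass, loss.backward(), optimizer.step(), optimizer.zero_grad()
+            if lr_scheduler is not None:
+                lr_scheduler.step()
+        ```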
+ """ + params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters() + optimizer = cfg.optimizer.build(params) + lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None + return optimizer, lr_scheduler diff --git a/lerobot/common/optim/optimizers.py b/lerobot/common/optim/optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..0cf4124ce6ec2d4c2045a21a98d2ca6807ed48d3 --- /dev/null +++ b/lerobot/common/optim/optimizers.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +from dataclasses import asdict, dataclass +from pathlib import Path + +import draccus +import torch +from safetensors.torch import load_file, save_file + +from lerobot.common.constants import ( + OPTIMIZER_PARAM_GROUPS, + OPTIMIZER_STATE, +) +from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json +from lerobot.common.utils.io_utils import deserialize_json_into_object + + +@dataclass +class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): + lr: float + weight_decay: float + grad_clip_norm: float + + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + @classmethod + def default_choice_name(cls) -> str | None: + return "adam" + + @abc.abstractmethod + def build(self) -> torch.optim.Optimizer: + raise NotImplementedError + + +@OptimizerConfig.register_subclass("adam") +@dataclass +class AdamConfig(OptimizerConfig): + lr: float = 1e-3 + betas: tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + weight_decay: float = 0.0 + grad_clip_norm: float = 10.0 + + def build(self, params: dict) -> torch.optim.Optimizer: + kwargs = asdict(self) + kwargs.pop("grad_clip_norm") + return torch.optim.Adam(params, **kwargs) + + +@OptimizerConfig.register_subclass("adamw") +@dataclass +class AdamWConfig(OptimizerConfig): + lr: float = 1e-3 + betas: tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + weight_decay: float = 1e-2 + grad_clip_norm: float = 10.0 + + def build(self, params: dict) -> torch.optim.Optimizer: + kwargs = asdict(self) + kwargs.pop("grad_clip_norm") + return torch.optim.AdamW(params, **kwargs) + + +@OptimizerConfig.register_subclass("sgd") +@dataclass +class SGDConfig(OptimizerConfig): + lr: float = 1e-3 + momentum: float = 0.0 + dampening: float = 0.0 + nesterov: bool = False + weight_decay: float = 0.0 + grad_clip_norm: float = 10.0 + + def build(self, params: dict) -> torch.optim.Optimizer: + kwargs = asdict(self) + kwargs.pop("grad_clip_norm") + return torch.optim.SGD(params, **kwargs) + + +def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None: + state = optimizer.state_dict() + param_groups = state.pop("param_groups") + flat_state = flatten_dict(state) + save_file(flat_state, save_dir / OPTIMIZER_STATE) + write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS) + + +def load_optimizer_state(optimizer: 
torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer: + current_state_dict = optimizer.state_dict() + flat_state = load_file(save_dir / OPTIMIZER_STATE) + state = unflatten_dict(flat_state) + loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}} + + if "param_groups" in current_state_dict: + param_groups = deserialize_json_into_object( + save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"] + ) + loaded_state_dict["param_groups"] = param_groups + + optimizer.load_state_dict(loaded_state_dict) + return optimizer diff --git a/lerobot/common/optim/schedulers.py b/lerobot/common/optim/schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..7e158394605cdfc42c53d9e472fb4e7a75fcb8af --- /dev/null +++ b/lerobot/common/optim/schedulers.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import math +from dataclasses import asdict, dataclass +from pathlib import Path + +import draccus +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR, LRScheduler + +from lerobot.common.constants import SCHEDULER_STATE +from lerobot.common.datasets.utils import write_json +from lerobot.common.utils.io_utils import deserialize_json_into_object + + +@dataclass +class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC): + num_warmup_steps: int + + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + @abc.abstractmethod + def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None: + raise NotImplementedError + + +@LRSchedulerConfig.register_subclass("diffuser") +@dataclass +class DiffuserSchedulerConfig(LRSchedulerConfig): + name: str = "cosine" + num_warmup_steps: int | None = None + + def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: + from diffusers.optimization import get_scheduler + + kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer} + return get_scheduler(**kwargs) + + +@LRSchedulerConfig.register_subclass("vqbet") +@dataclass +class VQBeTSchedulerConfig(LRSchedulerConfig): + num_warmup_steps: int + num_vqvae_training_steps: int + num_cycles: float = 0.5 + + def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: + def lr_lambda(current_step): + if current_step < self.num_vqvae_training_steps: + return float(1) + else: + adjusted_step = current_step - self.num_vqvae_training_steps + if adjusted_step < self.num_warmup_steps: + return float(adjusted_step) / float(max(1, self.num_warmup_steps)) + progress = float(adjusted_step - self.num_warmup_steps) / float( + max(1, num_training_steps - self.num_warmup_steps) + ) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress))) + + return LambdaLR(optimizer, lr_lambda, -1) + + +@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup") +@dataclass +class 
CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): + """Used by Physical Intelligence to train Pi0""" + + num_warmup_steps: int + num_decay_steps: int + peak_lr: float + decay_lr: float + + def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: + del num_training_steps + + def lr_lambda(current_step): + def linear_warmup_schedule(current_step): + if current_step <= 0: + return 1 / (self.num_warmup_steps + 1) + frac = 1 - current_step / self.num_warmup_steps + return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1 + + def cosine_decay_schedule(current_step): + step = min(current_step, self.num_decay_steps) + cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps)) + alpha = self.decay_lr / self.peak_lr + decayed = (1 - alpha) * cosine_decay + alpha + return decayed + + if current_step < self.num_warmup_steps: + return linear_warmup_schedule(current_step) + + return cosine_decay_schedule(current_step) + + return LambdaLR(optimizer, lr_lambda, -1) + + +def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None: + state_dict = scheduler.state_dict() + write_json(state_dict, save_dir / SCHEDULER_STATE) + + +def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler: + state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict()) + scheduler.load_state_dict(state_dict) + return scheduler diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b73ba5f4eeec900d6f38323171be099a6b736f67 --- /dev/null +++ b/lerobot/common/policies/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .act.configuration_act import ACTConfig as ACTConfig +from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig +from .pi0.configuration_pi0 import PI0Config as PI0Config +from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig +from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5819b7490c32b6ec9bba598c9f25aae837fe06 --- /dev/null +++ b/lerobot/common/policies/act/configuration_act.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python + +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass, field
+
+from lerobot.common.optim.optimizers import AdamWConfig
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import NormalizationMode
+
+
+@PreTrainedConfig.register_subclass("act")
+@dataclass
+class ACTConfig(PreTrainedConfig):
+    """Configuration class for the Action Chunking Transformers policy.
+
+    Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
+
+    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
+    Those are: `input_shapes` and `output_shapes`.
+
+    Notes on the inputs and outputs:
+        - Either:
+            - At least one key starting with "observation.image" is required as an input.
+              AND/OR
+            - The key "observation.environment_state" is required as input.
+        - If there are multiple keys beginning with "observation.images." they are treated as multiple camera
+          views. Right now we only support all images having the same shape.
+        - May optionally work without an "observation.state" key for the proprioceptive robot state.
+        - "action" is required as an output key.
+
+    Args:
+        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
+            current step and additional steps going back).
+        chunk_size: The size of the action prediction "chunks" in units of environment steps.
+        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
+            This should be no greater than the chunk size. For example, if the chunk size is 100, you may
+            set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
+            environment, and throws the other 50 out.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
+        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
+            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
+            which subtracts the mean and divides by the standard deviation, and "min_max" which rescales to a
+            [-1, 1] range.
+        output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
+            the original scale. Note that this is also used for normalizing the training targets.
+        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
+        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
+            `None` means no pretrained weights.
+        replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
+            convolution.
+ pre_norm: Whether to use "pre-norm" in the transformer blocks. + dim_model: The transformer blocks' main hidden dimension. + n_heads: The number of heads to use in the transformer blocks' multi-head attention. + dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward + layers. + feedforward_activation: The activation to use in the transformer block's feed-forward layers. + n_encoder_layers: The number of transformer layers to use for the transformer encoder. + n_decoder_layers: The number of transformer layers to use for the transformer decoder. + use_vae: Whether to use a variational objective during training. This introduces another transformer + which is used as the VAE's encoder (not to be confused with the transformer encoder - see + documentation in the policy class). + latent_dim: The VAE's latent dimension. + n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder. + temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal + ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be + 1 when using this feature, as inference needs to happen at every step to form an ensemble. For + more information on how ensembling works, please see `ACTTemporalEnsembler`. + dropout: Dropout to use in the transformer layers (see code for details). + kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective + is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`. + """ + + # Input / output structure. + n_obs_steps: int = 1 + chunk_size: int = 100 + n_action_steps: int = 100 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MEAN_STD, + } + ) + + # Architecture. + # Vision backbone. + vision_backbone: str = "resnet18" + pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1" + replace_final_stride_with_dilation: int = False + # Transformer layers. + pre_norm: bool = False + dim_model: int = 512 + n_heads: int = 8 + dim_feedforward: int = 3200 + feedforward_activation: str = "relu" + n_encoder_layers: int = 4 + # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code + # that means only the first layer is used. Here we match the original implementation by setting this to 1. + # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521. + n_decoder_layers: int = 1 + # VAE. + use_vae: bool = True + latent_dim: int = 32 + n_vae_encoder_layers: int = 4 + + # Inference. + # Note: the value used in ACT when temporal ensembling is enabled is 0.01. + temporal_ensemble_coeff: float | None = None + + # Training and loss computation. + dropout: float = 0.1 + kl_weight: float = 10.0 + + # Training preset + optimizer_lr: float = 1e-5 + optimizer_weight_decay: float = 1e-4 + optimizer_lr_backbone: float = 1e-5 + + def __post_init__(self): + super().__post_init__() + + """Input validation (not exhaustive).""" + if not self.vision_backbone.startswith("resnet"): + raise ValueError( + f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." + ) + if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1: + raise NotImplementedError( + "`n_action_steps` must be 1 when using temporal ensembling. 
This is " + "because the policy needs to be queried every step to compute the ensembled action." + ) + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + if self.n_obs_steps != 1: + raise ValueError( + f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" + ) + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self) -> None: + return None + + def validate_features(self) -> None: + if not self.image_features and not self.env_state_feature: + raise ValueError("You must provide at least one image or the environment state among the inputs.") + + @property + def observation_delta_indices(self) -> None: + return None + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py new file mode 100644 index 0000000000000000000000000000000000000000..72d4df03a2eedbc2e5604ec37e4db21f094c9f90 --- /dev/null +++ b/lerobot/common/policies/act/modeling_act.py @@ -0,0 +1,765 @@ +#!/usr/bin/env python + +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Action Chunking Transformer Policy + +As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). +The majority of changes here involve removing unused code, unifying naming, and adding helpful comments. +""" + +import math +from collections import deque +from itertools import chain +from typing import Callable + +import einops +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +import torchvision +from torch import Tensor, nn +from torchvision.models._utils import IntermediateLayerGetter +from torchvision.ops.misc import FrozenBatchNorm2d + +from lerobot.common.policies.act.configuration_act import ACTConfig +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy + + +class ACTPolicy(PreTrainedPolicy): + """ + Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost + Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) + """ + + config_class = ACTConfig + name = "act" + + def __init__( + self, + config: ACTConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. 
If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + super().__init__(config) + config.validate_features() + self.config = config + + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.model = ACT(config) + + if config.temporal_ensemble_coeff is not None: + self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size) + + self.reset() + + def get_optim_params(self) -> dict: + # TODO(aliberts, rcadene): As of now, lr_backbone == lr + # Should we remove this and just `return self.parameters()`? + return [ + { + "params": [ + p + for n, p in self.named_parameters() + if not n.startswith("model.backbone") and p.requires_grad + ] + }, + { + "params": [ + p + for n, p in self.named_parameters() + if n.startswith("model.backbone") and p.requires_grad + ], + "lr": self.config.optimizer_lr_backbone, + }, + ] + + def reset(self): + """This should be called whenever the environment is reset.""" + if self.config.temporal_ensemble_coeff is not None: + self.temporal_ensembler.reset() + else: + self._action_queue = deque([], maxlen=self.config.n_action_steps) + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + self.eval() + + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = [batch[key] for key in self.config.image_features] + + # If we are doing temporal ensembling, do online updates where we keep track of the number of actions + # we are ensembling over. + if self.config.temporal_ensemble_coeff is not None: + actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim) + actions = self.unnormalize_outputs({"action": actions})["action"] + action = self.temporal_ensembler.update(actions) + return action + + # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by + # querying the policy. + if len(self._action_queue) == 0: + actions = self.model(batch)[0][:, : self.config.n_action_steps] + + # TODO(rcadene): make _forward return output dictionary? + actions = self.unnormalize_outputs({"action": actions})["action"] + + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. 
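+            # For example (with hypothetical sizes batch_size=1, n_action_steps=50, action_dim=14):
+            # (1, 50, 14) -> transpose(0, 1) -> (50, 1, 14), so each popleft() below returns one
+            # (1, 14) action per environment step until the queue is empty and the model is queried again.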
+ self._action_queue.extend(actions.transpose(0, 1)) + return self._action_queue.popleft() + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss for training or validation.""" + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = [batch[key] for key in self.config.image_features] + + batch = self.normalize_targets(batch) + actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) + + l1_loss = ( + F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) + ).mean() + + loss_dict = {"l1_loss": l1_loss.item()} + if self.config.use_vae: + # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for + # each dimension independently, we sum over the latent dimension to get the total + # KL-divergence per batch element, then take the mean over the batch. + # (See App. B of https://arxiv.org/abs/1312.6114 for more details). + mean_kld = ( + (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() + ) + loss_dict["kld_loss"] = mean_kld.item() + loss = l1_loss + mean_kld * self.config.kl_weight + else: + loss = l1_loss + + return loss, loss_dict + + +class ACTTemporalEnsembler: + def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None: + """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705. + + The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action. + They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the + coefficient works: + - Setting it to 0 uniformly weighs all actions. + - Setting it positive gives more weight to older actions. + - Setting it negative gives more weight to newer actions. + NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This + results in older actions being weighed more highly than newer actions (the experiments documented in + https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be + detrimental: doing so aggressively may diminish the benefits of action chunking). + + Here we use an online method for computing the average rather than caching a history of actions in + order to compute the average offline. For a simple 1D sequence it looks something like: + + ``` + import torch + + seq = torch.linspace(8, 8.5, 100) + print(seq) + + m = 0.01 + exp_weights = torch.exp(-m * torch.arange(len(seq))) + print(exp_weights) + + # Calculate offline + avg = (exp_weights * seq).sum() / exp_weights.sum() + print("offline", avg) + + # Calculate online + for i, item in enumerate(seq): + if i == 0: + avg = item + continue + avg *= exp_weights[:i].sum() + avg += item * exp_weights[i] + avg /= exp_weights[:i+1].sum() + print("online", avg) + ``` + """ + self.chunk_size = chunk_size + self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) + self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0) + self.reset() + + def reset(self): + """Resets the online computation variables.""" + self.ensembled_actions = None + # (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence. 
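+        # It is lazily re-created, together with `ensembled_actions`, on the first `update()` call
+        # following a reset.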
+ self.ensembled_actions_count = None + + def update(self, actions: Tensor) -> Tensor: + """ + Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all + time steps, and pop/return the next batch of actions in the sequence. + """ + self.ensemble_weights = self.ensemble_weights.to(device=actions.device) + self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device) + if self.ensembled_actions is None: + # Initializes `self._ensembled_action` to the sequence of actions predicted during the first + # time step of the episode. + self.ensembled_actions = actions.clone() + # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor + # operations later. + self.ensembled_actions_count = torch.ones( + (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device + ) + else: + # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute + # the online update for those entries. + self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1] + self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count] + self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count] + self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size) + # The last action, which has no prior online average, needs to get concatenated onto the end. + self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1) + self.ensembled_actions_count = torch.cat( + [self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])] + ) + # "Consume" the first action. + action, self.ensembled_actions, self.ensembled_actions_count = ( + self.ensembled_actions[:, 0], + self.ensembled_actions[:, 1:], + self.ensembled_actions_count[1:], + ) + return action + + +class ACT(nn.Module): + """Action Chunking Transformer: The underlying neural network for ACTPolicy. + + Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows. + - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the + model that encodes the target data (a sequence of actions), and the condition (the robot + joint-space). + - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with + cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we + have an option to train this model without the variational objective (in which case we drop the + `vae_encoder` altogether, and nothing about this model has anything to do with a VAE). + + Transformer + Used alone for inference + (acts as VAE decoder + during training) + ┌───────────────────────┐ + │ Outputs │ + │ ▲ │ + │ ┌─────►┌───────┐ │ + ┌──────┐ │ │ │Transf.│ │ + │ │ │ ├─────►│decoder│ │ + ┌────┴────┐ │ │ │ │ │ │ + │ │ │ │ ┌───┴───┬─►│ │ │ + │ VAE │ │ │ │ │ └───────┘ │ + │ encoder │ │ │ │Transf.│ │ + │ │ │ │ │encoder│ │ + └───▲─────┘ │ │ │ │ │ + │ │ │ └▲──▲─▲─┘ │ + │ │ │ │ │ │ │ + inputs └─────┼──┘ │ image emb. │ + │ state emb. │ + └───────────────────────┘ + """ + + def __init__(self, config: ACTConfig): + # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence]. + # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). 
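+        # Those parameters are produced by `vae_encoder_latent_output_proj` below, which projects the
+        # cls token output to `2 * latent_dim` values.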
+ super().__init__() + self.config = config + + if self.config.use_vae: + self.vae_encoder = ACTEncoder(config, is_vae_encoder=True) + self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) + # Projection layer for joint-space configuration to hidden dimension. + if self.config.robot_state_feature: + self.vae_encoder_robot_state_input_proj = nn.Linear( + self.config.robot_state_feature.shape[0], config.dim_model + ) + # Projection layer for action (joint-space target) to hidden dimension. + self.vae_encoder_action_input_proj = nn.Linear( + self.config.action_feature.shape[0], + config.dim_model, + ) + # Projection layer from the VAE encoder's output to the latent distribution's parameter space. + self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2) + # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch + # dimension. + num_input_token_encoder = 1 + config.chunk_size + if self.config.robot_state_feature: + num_input_token_encoder += 1 + self.register_buffer( + "vae_encoder_pos_enc", + create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0), + ) + + # Backbone for image feature extraction. + if self.config.image_features: + backbone_model = getattr(torchvision.models, config.vision_backbone)( + replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], + weights=config.pretrained_backbone_weights, + norm_layer=FrozenBatchNorm2d, + ) + # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final + # feature map). + # Note: The forward method of this returns a dict: {"feature_map": output}. + self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) + + # Transformer (acts as VAE decoder when training with the variational objective). + self.encoder = ACTEncoder(config) + self.decoder = ACTDecoder(config) + + # Transformer encoder input projections. The tokens will be structured like + # [latent, (robot_state), (env_state), (image_feature_map_pixels)]. + if self.config.robot_state_feature: + self.encoder_robot_state_input_proj = nn.Linear( + self.config.robot_state_feature.shape[0], config.dim_model + ) + if self.config.env_state_feature: + self.encoder_env_state_input_proj = nn.Linear( + self.config.env_state_feature.shape[0], config.dim_model + ) + self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model) + if self.config.image_features: + self.encoder_img_feat_input_proj = nn.Conv2d( + backbone_model.fc.in_features, config.dim_model, kernel_size=1 + ) + # Transformer encoder positional embeddings. + n_1d_tokens = 1 # for the latent + if self.config.robot_state_feature: + n_1d_tokens += 1 + if self.config.env_state_feature: + n_1d_tokens += 1 + self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model) + if self.config.image_features: + self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) + + # Transformer decoder. + # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). + self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model) + + # Final action regression head on the output of the transformer's decoder. 
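+        # It maps each of the `chunk_size` decoder tokens to one action vector.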
+ self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0]) + + self._reset_parameters() + + def _reset_parameters(self): + """Xavier-uniform initialization of the transformer parameters as in the original code.""" + for p in chain(self.encoder.parameters(), self.decoder.parameters()): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: + """A forward pass through the Action Chunking Transformer (with optional VAE encoder). + + `batch` should have the following structure: + { + [robot_state_feature] (optional): (B, state_dim) batch of robot states. + + [image_features]: (B, n_cameras, C, H, W) batch of images. + AND/OR + [env_state_feature]: (B, env_dim) batch of environment states. + + [action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions. + } + + Returns: + (B, chunk_size, action_dim) batch of action sequences + Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the + latent dimension. + """ + if self.config.use_vae and self.training: + assert "action" in batch, ( + "actions must be provided when using the variational objective in training mode." + ) + + if "observation.images" in batch: + batch_size = batch["observation.images"][0].shape[0] + else: + batch_size = batch["observation.environment_state"].shape[0] + + # Prepare the latent for input to the transformer encoder. + if self.config.use_vae and "action" in batch: + # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. + cls_embed = einops.repeat( + self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size + ) # (B, 1, D) + if self.config.robot_state_feature: + robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) + robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) + action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) + + if self.config.robot_state_feature: + vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D) + else: + vae_encoder_input = [cls_embed, action_embed] + vae_encoder_input = torch.cat(vae_encoder_input, axis=1) + + # Prepare fixed positional embedding. + # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. + pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) + + # Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the + # sequence depending whether we use the input states or not (cls and robot state) + # False means not a padding token. + cls_joint_is_pad = torch.full( + (batch_size, 2 if self.config.robot_state_feature else 1), + False, + device=batch["observation.state"].device, + ) + key_padding_mask = torch.cat( + [cls_joint_is_pad, batch["action_is_pad"]], axis=1 + ) # (bs, seq+1 or 2) + + # Forward pass through VAE encoder to get the latent PDF parameters. + cls_token_out = self.vae_encoder( + vae_encoder_input.permute(1, 0, 2), + pos_embed=pos_embed.permute(1, 0, 2), + key_padding_mask=key_padding_mask, + )[0] # select the class token, with shape (B, D) + latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) + mu = latent_pdf_params[:, : self.config.latent_dim] + # This is 2log(sigma). Done this way to match the original implementation. 
+ log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :] + + # Sample the latent with the reparameterization trick. + latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu) + else: + # When not using the VAE encoder, we set the latent to be all zeros. + mu = log_sigma_x2 = None + # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer + latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( + batch["observation.state"].device + ) + + # Prepare transformer encoder inputs. + encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)] + encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)) + # Robot state token. + if self.config.robot_state_feature: + encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"])) + # Environment state token. + if self.config.env_state_feature: + encoder_in_tokens.append( + self.encoder_env_state_input_proj(batch["observation.environment_state"]) + ) + + # Camera observation features and positional embeddings. + if self.config.image_features: + all_cam_features = [] + all_cam_pos_embeds = [] + + # For a list of images, the H and W may vary but H*W is constant. + for img in batch["observation.images"]: + cam_features = self.backbone(img)["feature_map"] + cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) + cam_features = self.encoder_img_feat_input_proj(cam_features) + + # Rearrange features to (sequence, batch, dim). + cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c") + cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c") + + all_cam_features.append(cam_features) + all_cam_pos_embeds.append(cam_pos_embed) + + encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0)) + encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0)) + + # Stack all tokens along the sequence dimension. + encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0) + encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0) + + # Forward pass through the transformer modules. + encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed) + # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer + decoder_in = torch.zeros( + (self.config.chunk_size, batch_size, self.config.dim_model), + dtype=encoder_in_pos_embed.dtype, + device=encoder_in_pos_embed.device, + ) + decoder_out = self.decoder( + decoder_in, + encoder_out, + encoder_pos_embed=encoder_in_pos_embed, + decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1), + ) + + # Move back to (B, S, C). 
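The sampling step above is the standard reparameterization trick: z = μ + σ·ε, with σ recovered from the 2·log(σ) parameterization. For reference, a standalone sketch of that step together with the KL-to-N(0, I) penalty this parameterization is typically paired with during VAE training (the actual loss lives in `ACTPolicy`, outside this excerpt):

```python
import torch

B, latent_dim = 8, 32
mu = torch.randn(B, latent_dim)
log_sigma_x2 = torch.randn(B, latent_dim)  # 2 * log(sigma), as produced above

# Reparameterization: differentiable sample z = mu + sigma * eps, with eps ~ N(0, I).
z = mu + (log_sigma_x2 / 2).exp() * torch.randn_like(mu)

# KL( N(mu, sigma^2) || N(0, I) ), summed over the latent dim and averaged over the batch.
kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - log_sigma_x2.exp())).sum(-1).mean()
print(z.shape, kld.item())
```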
+ decoder_out = decoder_out.transpose(0, 1) + + actions = self.action_head(decoder_out) + + return actions, (mu, log_sigma_x2) + + +class ACTEncoder(nn.Module): + """Convenience module for running multiple encoder layers, maybe followed by normalization.""" + + def __init__(self, config: ACTConfig, is_vae_encoder: bool = False): + super().__init__() + self.is_vae_encoder = is_vae_encoder + num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers + self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)]) + self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() + + def forward( + self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None + ) -> Tensor: + for layer in self.layers: + x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask) + x = self.norm(x) + return x + + +class ACTEncoderLayer(nn.Module): + def __init__(self, config: ACTConfig): + super().__init__() + self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + + # Feed forward layers. + self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) + self.dropout = nn.Dropout(config.dropout) + self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) + + self.norm1 = nn.LayerNorm(config.dim_model) + self.norm2 = nn.LayerNorm(config.dim_model) + self.dropout1 = nn.Dropout(config.dropout) + self.dropout2 = nn.Dropout(config.dropout) + + self.activation = get_activation_fn(config.feedforward_activation) + self.pre_norm = config.pre_norm + + def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor: + skip = x + if self.pre_norm: + x = self.norm1(x) + q = k = x if pos_embed is None else x + pos_embed + x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask) + x = x[0] # note: [0] to select just the output, not the attention weights + x = skip + self.dropout1(x) + if self.pre_norm: + skip = x + x = self.norm2(x) + else: + x = self.norm1(x) + skip = x + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = skip + self.dropout2(x) + if not self.pre_norm: + x = self.norm2(x) + return x + + +class ACTDecoder(nn.Module): + def __init__(self, config: ACTConfig): + """Convenience module for running multiple decoder layers followed by normalization.""" + super().__init__() + self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]) + self.norm = nn.LayerNorm(config.dim_model) + + def forward( + self, + x: Tensor, + encoder_out: Tensor, + decoder_pos_embed: Tensor | None = None, + encoder_pos_embed: Tensor | None = None, + ) -> Tensor: + for layer in self.layers: + x = layer( + x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed + ) + if self.norm is not None: + x = self.norm(x) + return x + + +class ACTDecoderLayer(nn.Module): + def __init__(self, config: ACTConfig): + super().__init__() + self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + + # Feed forward layers. 
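One layout detail worth keeping in mind while reading these layers: `nn.MultiheadAttention` defaults to `batch_first=False` and returns an `(output, attention_weights)` tuple, which is why the ACT modules work with (sequence, batch, channel) tensors and index `[0]` after every attention call. A quick standalone check:

```python
import torch
from torch import nn

attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.1)

x = torch.randn(50, 2, 512)      # (sequence, batch, channel) because batch_first=False
out, weights = attn(x, x, x)     # tuple of output and (head-averaged) attention weights
print(out.shape, weights.shape)  # torch.Size([50, 2, 512]) torch.Size([2, 50, 50])
```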
+ self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) + self.dropout = nn.Dropout(config.dropout) + self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) + + self.norm1 = nn.LayerNorm(config.dim_model) + self.norm2 = nn.LayerNorm(config.dim_model) + self.norm3 = nn.LayerNorm(config.dim_model) + self.dropout1 = nn.Dropout(config.dropout) + self.dropout2 = nn.Dropout(config.dropout) + self.dropout3 = nn.Dropout(config.dropout) + + self.activation = get_activation_fn(config.feedforward_activation) + self.pre_norm = config.pre_norm + + def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: + return tensor if pos_embed is None else tensor + pos_embed + + def forward( + self, + x: Tensor, + encoder_out: Tensor, + decoder_pos_embed: Tensor | None = None, + encoder_pos_embed: Tensor | None = None, + ) -> Tensor: + """ + Args: + x: (Decoder Sequence, Batch, Channel) tensor of input tokens. + encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are + cross-attending with. + decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder). + encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder). + Returns: + (DS, B, C) tensor of decoder output features. + """ + skip = x + if self.pre_norm: + x = self.norm1(x) + q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) + x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights + x = skip + self.dropout1(x) + if self.pre_norm: + skip = x + x = self.norm2(x) + else: + x = self.norm1(x) + skip = x + x = self.multihead_attn( + query=self.maybe_add_pos_embed(x, decoder_pos_embed), + key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed), + value=encoder_out, + )[0] # select just the output, not the attention weights + x = skip + self.dropout2(x) + if self.pre_norm: + skip = x + x = self.norm3(x) + else: + x = self.norm2(x) + skip = x + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = skip + self.dropout3(x) + if not self.pre_norm: + x = self.norm3(x) + return x + + +def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor: + """1D sinusoidal positional embeddings as in Attention is All You Need. + + Args: + num_positions: Number of token positions required. + Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension). + + """ + + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + return torch.from_numpy(sinusoid_table).float() + + +class ACTSinusoidalPositionEmbedding2d(nn.Module): + """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need. + + The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H + for the vertical direction, and 1/W for the horizontal direction. + """ + + def __init__(self, dimension: int): + """ + Args: + dimension: The desired dimension of the embeddings. + """ + super().__init__() + self.dimension = dimension + self._two_pi = 2 * math.pi + self._eps = 1e-6 + # Inverse "common ratio" for the geometric progression in sinusoid frequencies. 
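A quick usage sketch for the 1D helper defined above (assuming `create_sinusoidal_pos_embedding` from this file is in scope; sizes are illustrative): each row is the embedding of one token position, with sines in the even slots and cosines in the odd slots.

```python
import torch

pe = create_sinusoidal_pos_embedding(num_positions=6, dimension=8)
print(pe.shape)  # torch.Size([6, 8])

# At position 0 all angles are 0, so the sine slots are 0 and the cosine slots are 1.
assert torch.allclose(pe[0, 0::2], torch.zeros(4))
assert torch.allclose(pe[0, 1::2], torch.ones(4))
```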
+ self._temperature = 10000 + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for. + Returns: + A (1, C, H, W) batch of corresponding sinusoidal positional embeddings. + """ + not_mask = torch.ones_like(x[0, :1]) # (1, H, W) + # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations + # they would be range(0, H) and range(0, W). Keeping it at as is to match the original code. + y_range = not_mask.cumsum(1, dtype=torch.float32) + x_range = not_mask.cumsum(2, dtype=torch.float32) + + # "Normalize" the position index such that it ranges in [0, 2π]. + # Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range + # are non-zero by construction. This is an artifact of the original code. + y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi + x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi + + inverse_frequency = self._temperature ** ( + 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension + ) + + x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) + y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) + + # Note: this stack then flatten operation results in interleaved sine and cosine terms. + # pos_embed_x and pos_embed_y are (1, H, W, C // 2). + pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3) + pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3) + pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W) + + return pos_embed + + +def get_activation_fn(activation: str) -> Callable: + """Return an activation function given a string.""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..e73c65fe9a4c9fddf79816e65bd222cc4845aa55 --- /dev/null +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python + +# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamConfig +from lerobot.common.optim.schedulers import DiffuserSchedulerConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + + +@PreTrainedConfig.register_subclass("diffusion") +@dataclass +class DiffusionConfig(PreTrainedConfig): + """Configuration class for DiffusionPolicy. 
+ + Defaults are configured for training with PushT providing proprioceptive and single camera observations. + + The parameters you will most likely need to change are the ones which depend on the environment / sensors. + Those are: `input_shapes` and `output_shapes`. + + Notes on the inputs and outputs: + - "observation.state" is required as an input key. + - Either: + - At least one key starting with "observation.image is required as an input. + AND/OR + - The key "observation.environment_state" is required as input. + - If there are multiple keys beginning with "observation.image" they are treated as multiple camera + views. Right now we only support all images having the same shape. + - "action" is required as an output key. + + Args: + n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the + current step and additional steps going back). + horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. + n_action_steps: The number of action steps to run in the environment for one invocation of the policy. + See `DiffusionPolicy.select_action` for more details. + input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents + the input data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], + indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't + include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents + the output data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. + Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. + input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), + and the value specifies the normalization mode to apply. The two available modes are "mean_std" + which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a + [-1, 1] range. + output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the + original scale. Note that this is also used for normalizing the training targets. + vision_backbone: Name of the torchvision resnet backbone to use for encoding images. + crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit + within the image size. If None, no cropping is done. + crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval + mode). + pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone. + `None` means no pretrained weights. + use_group_norm: Whether to replace batch normalization with group normalization in the backbone. + The group sizes are set to be about 16 (to be precise, feature_dim // 16). + spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax. + use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view. + down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet. + You may provide a variable number of dimensions, therefore also controlling the degree of + downsampling. 
+ kernel_size: The convolutional kernel size of the diffusion modeling Unet. + n_groups: Number of groups used in the group norm of the Unet's convolutional blocks. + diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear + network. This is the output dimension of that network, i.e., the embedding dimension. + use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning. + Bias modulation is used be default, while this parameter indicates whether to also use scale + modulation. + noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"]. + num_train_timesteps: Number of diffusion steps for the forward diffusion schedule. + beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers. + beta_start: Beta value for the first forward-diffusion step. + beta_end: Beta value for the last forward-diffusion step. + prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon" + or "sample". These have equivalent outcomes from a latent variable modeling perspective, but + "epsilon" has been shown to work better in many deep neural network settings. + clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each + denoising step at inference time. WARNING: you will need to make sure your action-space is + normalized to fit within this range. + clip_sample_range: The magnitude of the clipping range as described above. + num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly + spaced). If not provided, this defaults to be the same as `num_train_timesteps`. + do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See + `LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults + to False as the original Diffusion Policy implementation does the same. + """ + + # Inputs / output structure. + n_obs_steps: int = 2 + horizon: int = 16 + n_action_steps: int = 8 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MIN_MAX, + } + ) + + # The original implementation doesn't sample frames for the last 7 steps, + # which avoids excessive padding and leads to improved training results. + drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1 + + # Architecture / modeling. + # Vision backbone. + vision_backbone: str = "resnet18" + crop_shape: tuple[int, int] | None = (84, 84) + crop_is_random: bool = True + pretrained_backbone_weights: str | None = None + use_group_norm: bool = True + spatial_softmax_num_keypoints: int = 32 + use_separate_rgb_encoder_per_camera: bool = False + # Unet. + down_dims: tuple[int, ...] = (512, 1024, 2048) + kernel_size: int = 5 + n_groups: int = 8 + diffusion_step_embed_dim: int = 128 + use_film_scale_modulation: bool = True + # Noise scheduler. 
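The `down_dims` entry above interacts with `horizon`: each U-Net stage halves the temporal length, so the horizon must be divisible by 2^len(down_dims). This is the compatibility check enforced in `__post_init__` further down; a quick arithmetic sketch with the defaults:

```python
down_dims = (512, 1024, 2048)  # three downsampling stages, each halving the temporal length
horizon = 16

downsampling_factor = 2 ** len(down_dims)  # 8
assert horizon % downsampling_factor == 0  # 16 is divisible by 8, so this horizon is valid
```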
+ noise_scheduler_type: str = "DDPM" + num_train_timesteps: int = 100 + beta_schedule: str = "squaredcos_cap_v2" + beta_start: float = 0.0001 + beta_end: float = 0.02 + prediction_type: str = "epsilon" + clip_sample: bool = True + clip_sample_range: float = 1.0 + + # Inference + num_inference_steps: int | None = None + + # Loss computation + do_mask_loss_for_padding: bool = False + + # Training presets + optimizer_lr: float = 1e-4 + optimizer_betas: tuple = (0.95, 0.999) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-6 + scheduler_name: str = "cosine" + scheduler_warmup_steps: int = 500 + + def __post_init__(self): + super().__post_init__() + + """Input validation (not exhaustive).""" + if not self.vision_backbone.startswith("resnet"): + raise ValueError( + f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." + ) + + supported_prediction_types = ["epsilon", "sample"] + if self.prediction_type not in supported_prediction_types: + raise ValueError( + f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}." + ) + supported_noise_schedulers = ["DDPM", "DDIM"] + if self.noise_scheduler_type not in supported_noise_schedulers: + raise ValueError( + f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. " + f"Got {self.noise_scheduler_type}." + ) + + # Check that the horizon size and U-Net downsampling is compatible. + # U-Net downsamples by 2 with each stage. + downsampling_factor = 2 ** len(self.down_dims) + if self.horizon % downsampling_factor != 0: + raise ValueError( + "The horizon should be an integer multiple of the downsampling factor (which is determined " + f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}" + ) + + def get_optimizer_preset(self) -> AdamConfig: + return AdamConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self) -> DiffuserSchedulerConfig: + return DiffuserSchedulerConfig( + name=self.scheduler_name, + num_warmup_steps=self.scheduler_warmup_steps, + ) + + def validate_features(self) -> None: + if len(self.image_features) == 0 and self.env_state_feature is None: + raise ValueError("You must provide at least one image or the environment state among the inputs.") + + if self.crop_shape is not None: + for key, image_ft in self.image_features.items(): + if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]: + raise ValueError( + f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} " + f"for `crop_shape` and {image_ft.shape} for " + f"`{key}`." + ) + + # Check that all input images have the same shape. + first_image_key, first_image_ft = next(iter(self.image_features.items())) + for key, image_ft in self.image_features.items(): + if image_ft.shape != first_image_ft.shape: + raise ValueError( + f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." 
+ ) + + @property + def observation_delta_indices(self) -> list: + return list(range(1 - self.n_obs_steps, 1)) + + @property + def action_delta_indices(self) -> list: + return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..9ecadcb05b1990e73e4e9d3fa6f89abfcedd32f6 --- /dev/null +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -0,0 +1,765 @@ +#!/usr/bin/env python + +# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" + +TODO(alexander-soare): + - Remove reliance on diffusers for DDPMScheduler and LR scheduler. +""" + +import math +from collections import deque +from typing import Callable + +import einops +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +import torchvision +from diffusers.schedulers.scheduling_ddim import DDIMScheduler +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from torch import Tensor, nn + +from lerobot.common.constants import OBS_ENV, OBS_ROBOT +from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.utils import ( + get_device_from_parameters, + get_dtype_from_parameters, + get_output_shape, + populate_queues, +) + + +class DiffusionPolicy(PreTrainedPolicy): + """ + Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" + (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy). + """ + + config_class = DiffusionConfig + name = "diffusion" + + def __init__( + self, + config: DiffusionConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. 
+ """ + super().__init__(config) + config.validate_features() + self.config = config + + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + # queues are populated during rollout of the policy, they contain the n latest observations and actions + self._queues = None + + self.diffusion = DiffusionModel(config) + + self.reset() + + def get_optim_params(self) -> dict: + return self.diffusion.parameters() + + def reset(self): + """Clear observation and action queues. Should be called on `env.reset()`""" + self._queues = { + "observation.state": deque(maxlen=self.config.n_obs_steps), + "action": deque(maxlen=self.config.n_action_steps), + } + if self.config.image_features: + self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps) + if self.config.env_state_feature: + self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations. + + This method handles caching a history of observations and an action trajectory generated by the + underlying diffusion model. Here's how it works: + - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is + copied `n_obs_steps` times to fill the cache). + - The diffusion model generates `horizon` steps worth of actions. + - `n_action_steps` worth of actions are actually kept for execution, starting from the current step. + Schematically this looks like: + ---------------------------------------------------------------------------------------------- + (legend: o = n_obs_steps, h = horizon, a = n_action_steps) + |timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h | + |observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO | + |action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES | + |action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO | + ---------------------------------------------------------------------------------------------- + Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that + "horizon" may not the best name to describe what the variable actually means, because this period is + actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. + """ + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack( + [batch[key] for key in self.config.image_features], dim=-4 + ) + # Note: It's important that this happens after stacking the images into a single key. + self._queues = populate_queues(self._queues, batch) + + if len(self._queues["action"]) == 0: + # stack n latest observations from the queue + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} + actions = self.diffusion.generate_actions(batch) + + # TODO(rcadene): make above methods return output dictionary? 
+ actions = self.unnormalize_outputs({"action": actions})["action"] + + self._queues["action"].extend(actions.transpose(0, 1)) + + action = self._queues["action"].popleft() + return action + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]: + """Run the batch through the model and compute the loss for training or validation.""" + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack( + [batch[key] for key in self.config.image_features], dim=-4 + ) + batch = self.normalize_targets(batch) + loss = self.diffusion.compute_loss(batch) + # no output_dict so returning None + return loss, None + + +def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler: + """ + Factory for noise scheduler instances of the requested type. All kwargs are passed + to the scheduler. + """ + if name == "DDPM": + return DDPMScheduler(**kwargs) + elif name == "DDIM": + return DDIMScheduler(**kwargs) + else: + raise ValueError(f"Unsupported noise scheduler type {name}") + + +class DiffusionModel(nn.Module): + def __init__(self, config: DiffusionConfig): + super().__init__() + self.config = config + + # Build observation encoders (depending on which observations are provided). + global_cond_dim = self.config.robot_state_feature.shape[0] + if self.config.image_features: + num_images = len(self.config.image_features) + if self.config.use_separate_rgb_encoder_per_camera: + encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)] + self.rgb_encoder = nn.ModuleList(encoders) + global_cond_dim += encoders[0].feature_dim * num_images + else: + self.rgb_encoder = DiffusionRgbEncoder(config) + global_cond_dim += self.rgb_encoder.feature_dim * num_images + if self.config.env_state_feature: + global_cond_dim += self.config.env_state_feature.shape[0] + + self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps) + + self.noise_scheduler = _make_noise_scheduler( + config.noise_scheduler_type, + num_train_timesteps=config.num_train_timesteps, + beta_start=config.beta_start, + beta_end=config.beta_end, + beta_schedule=config.beta_schedule, + clip_sample=config.clip_sample, + clip_sample_range=config.clip_sample_range, + prediction_type=config.prediction_type, + ) + + if config.num_inference_steps is None: + self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps + else: + self.num_inference_steps = config.num_inference_steps + + # ========= inference ============ + def conditional_sample( + self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None + ) -> Tensor: + device = get_device_from_parameters(self) + dtype = get_dtype_from_parameters(self) + + # Sample prior. + sample = torch.randn( + size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]), + dtype=dtype, + device=device, + generator=generator, + ) + + self.noise_scheduler.set_timesteps(self.num_inference_steps) + + for t in self.noise_scheduler.timesteps: + # Predict model output. 
+ model_output = self.unet( + sample, + torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device), + global_cond=global_cond, + ) + # Compute previous image: x_t -> x_t-1 + sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample + + return sample + + def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor: + """Encode image features and concatenate them all together along with the state vector.""" + batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2] + global_cond_feats = [batch[OBS_ROBOT]] + # Extract image features. + if self.config.image_features: + if self.config.use_separate_rgb_encoder_per_camera: + # Combine batch and sequence dims while rearranging to make the camera index dimension first. + images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...") + img_features_list = torch.cat( + [ + encoder(images) + for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True) + ] + ) + # Separate batch and sequence dims back out. The camera index dim gets absorbed into the + # feature dim (effectively concatenating the camera features). + img_features = einops.rearrange( + img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps + ) + else: + # Combine batch, sequence, and "which camera" dims before passing to shared encoder. + img_features = self.rgb_encoder( + einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") + ) + # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the + # feature dim (effectively concatenating the camera features). + img_features = einops.rearrange( + img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps + ) + global_cond_feats.append(img_features) + + if self.config.env_state_feature: + global_cond_feats.append(batch[OBS_ENV]) + + # Concatenate features then flatten to (B, global_cond_dim). + return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1) + + def generate_actions(self, batch: dict[str, Tensor]) -> Tensor: + """ + This function expects `batch` to have: + { + "observation.state": (B, n_obs_steps, state_dim) + + "observation.images": (B, n_obs_steps, num_cameras, C, H, W) + AND/OR + "observation.environment_state": (B, environment_dim) + } + """ + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + assert n_obs_steps == self.config.n_obs_steps + + # Encode image features and concatenate them all together along with the state vector. + global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim) + + # run sampling + actions = self.conditional_sample(batch_size, global_cond=global_cond) + + # Extract `n_action_steps` steps worth of actions (from the current observation). + start = n_obs_steps - 1 + end = start + self.config.n_action_steps + actions = actions[:, start:end] + + return actions + + def compute_loss(self, batch: dict[str, Tensor]) -> Tensor: + """ + This function expects `batch` to have (at least): + { + "observation.state": (B, n_obs_steps, state_dim) + + "observation.images": (B, n_obs_steps, num_cameras, C, H, W) + AND/OR + "observation.environment_state": (B, environment_dim) + + "action": (B, horizon, action_dim) + "action_is_pad": (B, horizon) + } + """ + # Input validation. 
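As a concrete sizing example for the conditioning vector assembled by `_prepare_global_conditioning` (the 2-dimensional state and single camera below are illustrative assumptions, not values from this file):

```python
state_dim = 2                    # e.g. a planar end-effector position
num_cameras = 1
spatial_softmax_num_keypoints = 32
rgb_feature_dim = 2 * spatial_softmax_num_keypoints  # SpatialSoftmax yields an (x, y) pair per keypoint
n_obs_steps = 2

per_step = state_dim + num_cameras * rgb_feature_dim  # 66 features per observation step
global_cond_dim = per_step * n_obs_steps               # 132 after flatten(start_dim=1)
print(per_step, global_cond_dim)
```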
+ assert set(batch).issuperset({"observation.state", "action", "action_is_pad"}) + assert "observation.images" in batch or "observation.environment_state" in batch + n_obs_steps = batch["observation.state"].shape[1] + horizon = batch["action"].shape[1] + assert horizon == self.config.horizon + assert n_obs_steps == self.config.n_obs_steps + + # Encode image features and concatenate them all together along with the state vector. + global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim) + + # Forward diffusion. + trajectory = batch["action"] + # Sample noise to add to the trajectory. + eps = torch.randn(trajectory.shape, device=trajectory.device) + # Sample a random noising timestep for each item in the batch. + timesteps = torch.randint( + low=0, + high=self.noise_scheduler.config.num_train_timesteps, + size=(trajectory.shape[0],), + device=trajectory.device, + ).long() + # Add noise to the clean trajectories according to the noise magnitude at each timestep. + noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps) + + # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise). + pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond) + + # Compute the loss. + # The target is either the original trajectory, or the noise. + if self.config.prediction_type == "epsilon": + target = eps + elif self.config.prediction_type == "sample": + target = batch["action"] + else: + raise ValueError(f"Unsupported prediction type {self.config.prediction_type}") + + loss = F.mse_loss(pred, target, reduction="none") + + # Mask loss wherever the action is padded with copies (edges of the dataset trajectory). + if self.config.do_mask_loss_for_padding: + if "action_is_pad" not in batch: + raise ValueError( + "You need to provide 'action_is_pad' in the batch when " + f"{self.config.do_mask_loss_for_padding=}." + ) + in_episode_bound = ~batch["action_is_pad"] + loss = loss * in_episode_bound.unsqueeze(-1) + + return loss.mean() + + +class SpatialSoftmax(nn.Module): + """ + Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al. + (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation. + + At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass" + of activations of each channel, i.e., keypoints in the image space for the policy to focus on. + + Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2): + ----------------------------------------------------- + | (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) | + | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) | + | ... | ... | ... | ... | + | (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) | + ----------------------------------------------------- + This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot + product with the coordinates (120x2) to get expected points of maximal activation (512x2). + + The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally + provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable + linear mapping (in_channels, H, W) -> (num_kp, H, W). + """ + + def __init__(self, input_shape, num_kp=None): + """ + Args: + input_shape (list): (C, H, W) input feature map shape. + num_kp (int): number of keypoints in output. 
If None, output will have the same number of channels as input. + """ + super().__init__() + + assert len(input_shape) == 3 + self._in_c, self._in_h, self._in_w = input_shape + + if num_kp is not None: + self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1) + self._out_c = num_kp + else: + self.nets = None + self._out_c = self._in_c + + # we could use torch.linspace directly but that seems to behave slightly differently than numpy + # and causes a small degradation in pc_success of pre-trained models. + pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) + pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() + pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() + # register as buffer so it's moved to the correct device. + self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1)) + + def forward(self, features: Tensor) -> Tensor: + """ + Args: + features: (B, C, H, W) input feature maps. + Returns: + (B, K, 2) image-space coordinates of keypoints. + """ + if self.nets is not None: + features = self.nets(features) + + # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints + features = features.reshape(-1, self._in_h * self._in_w) + # 2d softmax normalization + attention = F.softmax(features, dim=-1) + # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions + expected_xy = attention @ self.pos_grid + # reshape to [B, K, 2] + feature_keypoints = expected_xy.view(-1, self._out_c, 2) + + return feature_keypoints + + +class DiffusionRgbEncoder(nn.Module): + """Encodes an RGB image into a 1D feature vector. + + Includes the ability to normalize and crop the image first. + """ + + def __init__(self, config: DiffusionConfig): + super().__init__() + # Set up optional preprocessing. + if config.crop_shape is not None: + self.do_crop = True + # Always use center crop for eval + self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) + if config.crop_is_random: + self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) + else: + self.maybe_random_crop = self.center_crop + else: + self.do_crop = False + + # Set up backbone. + backbone_model = getattr(torchvision.models, config.vision_backbone)( + weights=config.pretrained_backbone_weights + ) + # Note: This assumes that the layer4 feature map is children()[-3] + # TODO(alexander-soare): Use a safer alternative. + self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2])) + if config.use_group_norm: + if config.pretrained_backbone_weights: + raise ValueError( + "You can't replace BatchNorm in a pretrained model without ruining the weights!" + ) + self.backbone = _replace_submodules( + root_module=self.backbone, + predicate=lambda x: isinstance(x, nn.BatchNorm2d), + func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), + ) + + # Set up pooling and final layers. + # Use a dry run to get the feature map shape. + # The dummy input should take the number of image channels from `config.image_features` and it should + # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the + # height and width from `config.image_features`. + + # Note: we have a check in the config class to make sure all images have the same shape. 
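A standalone shape check for the keypoint pooling described above (assuming the `SpatialSoftmax` class defined in this file is in scope):

```python
import torch

pool = SpatialSoftmax(input_shape=(512, 7, 7), num_kp=32)
feature_map = torch.randn(8, 512, 7, 7)  # e.g. a ResNet-18 layer4 output
keypoints = pool(feature_map)
print(keypoints.shape)  # torch.Size([8, 32, 2]): one (x, y) in [-1, 1] per keypoint
```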
+ images_shape = next(iter(config.image_features.values())).shape + dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:] + dummy_shape = (1, images_shape[0], *dummy_shape_h_w) + feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:] + + self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) + self.feature_dim = config.spatial_softmax_num_keypoints * 2 + self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) + self.relu = nn.ReLU() + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: (B, C, H, W) image tensor with pixel values in [0, 1]. + Returns: + (B, D) image feature. + """ + # Preprocess: maybe crop (if it was set up in the __init__). + if self.do_crop: + if self.training: # noqa: SIM108 + x = self.maybe_random_crop(x) + else: + # Always use center crop for eval. + x = self.center_crop(x) + # Extract backbone feature. + x = torch.flatten(self.pool(self.backbone(x)), start_dim=1) + # Final linear layer with non-linearity. + x = self.relu(self.out(x)) + return x + + +def _replace_submodules( + root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] +) -> nn.Module: + """ + Args: + root_module: The module for which the submodules need to be replaced + predicate: Takes a module as an argument and must return True if the that module is to be replaced. + func: Takes a module as an argument and returns a new module to replace it with. + Returns: + The root module with its submodules replaced. + """ + if predicate(root_module): + return func(root_module) + + replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] + for *parents, k in replace_list: + parent_module = root_module + if len(parents) > 0: + parent_module = root_module.get_submodule(".".join(parents)) + if isinstance(parent_module, nn.Sequential): + src_module = parent_module[int(k)] + else: + src_module = getattr(parent_module, k) + tgt_module = func(src_module) + if isinstance(parent_module, nn.Sequential): + parent_module[int(k)] = tgt_module + else: + setattr(parent_module, k, tgt_module) + # verify that all BN are replaced + assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) + return root_module + + +class DiffusionSinusoidalPosEmb(nn.Module): + """1D sinusoidal positional embeddings as in Attention is All You Need.""" + + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x.unsqueeze(-1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class DiffusionConv1dBlock(nn.Module): + """Conv1d --> GroupNorm --> Mish""" + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + nn.GroupNorm(n_groups, out_channels), + nn.Mish(), + ) + + def forward(self, x): + return self.block(x) + + +class DiffusionConditionalUnet1d(nn.Module): + """A 1D convolutional UNet with FiLM modulation for conditioning. + + Note: this removes local conditioning as compared to the original diffusion policy code. 
+ """ + + def __init__(self, config: DiffusionConfig, global_cond_dim: int): + super().__init__() + + self.config = config + + # Encoder for the diffusion timestep. + self.diffusion_step_encoder = nn.Sequential( + DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim), + nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4), + nn.Mish(), + nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim), + ) + + # The FiLM conditioning dimension. + cond_dim = config.diffusion_step_embed_dim + global_cond_dim + + # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we + # just reverse these. + in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list( + zip(config.down_dims[:-1], config.down_dims[1:], strict=True) + ) + + # Unet encoder. + common_res_block_kwargs = { + "cond_dim": cond_dim, + "kernel_size": config.kernel_size, + "n_groups": config.n_groups, + "use_film_scale_modulation": config.use_film_scale_modulation, + } + self.down_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (len(in_out) - 1) + self.down_modules.append( + nn.ModuleList( + [ + DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs), + DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs), + # Downsample as long as it is not the last block. + nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), + ] + ) + ) + + # Processing in the middle of the auto-encoder. + self.mid_modules = nn.ModuleList( + [ + DiffusionConditionalResidualBlock1d( + config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs + ), + DiffusionConditionalResidualBlock1d( + config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs + ), + ] + ) + + # Unet decoder. + self.up_modules = nn.ModuleList([]) + for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])): + is_last = ind >= (len(in_out) - 1) + self.up_modules.append( + nn.ModuleList( + [ + # dim_in * 2, because it takes the encoder's skip connection as well + DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs), + DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs), + # Upsample as long as it is not the last block. + nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(), + ] + ) + ) + + self.final_conv = nn.Sequential( + DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size), + nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1), + ) + + def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor: + """ + Args: + x: (B, T, input_dim) tensor for input to the Unet. + timestep: (B,) tensor of (timestep_we_are_denoising_from - 1). + global_cond: (B, global_cond_dim) + output: (B, T, input_dim) + Returns: + (B, T, input_dim) diffusion model prediction. + """ + # For 1D convolutions we'll need feature dimension first. + x = einops.rearrange(x, "b t d -> b d t") + + timesteps_embed = self.diffusion_step_encoder(timestep) + + # If there is a global conditioning feature, concatenate it to the timestep embedding. + if global_cond is not None: + global_feature = torch.cat([timesteps_embed, global_cond], axis=-1) + else: + global_feature = timesteps_embed + + # Run encoder, keeping track of skip features to pass to the decoder. 
+ encoder_skip_features: list[Tensor] = [] + for resnet, resnet2, downsample in self.down_modules: + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + encoder_skip_features.append(x) + x = downsample(x) + + for mid_module in self.mid_modules: + x = mid_module(x, global_feature) + + # Run decoder, using the skip features from the encoder. + for resnet, resnet2, upsample in self.up_modules: + x = torch.cat((x, encoder_skip_features.pop()), dim=1) + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + x = upsample(x) + + x = self.final_conv(x) + + x = einops.rearrange(x, "b d t -> b t d") + return x + + +class DiffusionConditionalResidualBlock1d(nn.Module): + """ResNet style 1D convolutional block with FiLM modulation for conditioning.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + cond_dim: int, + kernel_size: int = 3, + n_groups: int = 8, + # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning + # FiLM just modulates bias). + use_film_scale_modulation: bool = False, + ): + super().__init__() + + self.use_film_scale_modulation = use_film_scale_modulation + self.out_channels = out_channels + + self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) + + # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale. + cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels + self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) + + self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups) + + # A final convolution for dimension matching the residual (if needed). + self.residual_conv = ( + nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() + ) + + def forward(self, x: Tensor, cond: Tensor) -> Tensor: + """ + Args: + x: (B, in_channels, T) + cond: (B, cond_dim) + Returns: + (B, out_channels, T) + """ + out = self.conv1(x) + + # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1). + cond_embed = self.cond_encoder(cond).unsqueeze(-1) + if self.use_film_scale_modulation: + # Treat the embedding as a list of scales and biases. + scale = cond_embed[:, : self.out_channels] + bias = cond_embed[:, self.out_channels :] + out = scale * out + bias + else: + # Treat the embedding as biases. + out = out + cond_embed + + out = self.conv2(out) + out = out + self.residual_conv(x) + return out diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..8def95a35c2d1587c3b9de08949fb9d55cbab3a4 --- /dev/null +++ b/lerobot/common/policies/factory.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
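For reference, the FiLM conditioning used in `DiffusionConditionalResidualBlock1d` above reduces to a per-channel affine transform predicted from the conditioning vector. A minimal standalone sketch with illustrative sizes and scale modulation enabled:

```python
import torch
from torch import nn

cond_dim, channels, T = 132, 512, 16
cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, channels * 2))  # scale and bias

out = torch.randn(4, channels, T)   # activations from the first conv block
cond = torch.randn(4, cond_dim)

cond_embed = cond_encoder(cond).unsqueeze(-1)   # (B, 2 * channels, 1), broadcast over time
scale, bias = cond_embed[:, :channels], cond_embed[:, channels:]
out = scale * out + bias                        # FiLM; bias-only modulation drops the scale term
print(out.shape)                                # torch.Size([4, 512, 16])
```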
+ +import logging + +from torch import nn + +from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.common.datasets.utils import dataset_to_policy_features +from lerobot.common.envs.configs import EnvConfig +from lerobot.common.envs.utils import env_to_policy_features +from lerobot.common.policies.act.configuration_act import ACTConfig +from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.common.policies.pi0.configuration_pi0 import PI0Config +from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig +from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType + + +def get_policy_class(name: str) -> PreTrainedPolicy: + """Get the policy's class and config class given a name (matching the policy class' `name` attribute).""" + if name == "tdmpc": + from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy + + return TDMPCPolicy + elif name == "diffusion": + from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy + + return DiffusionPolicy + elif name == "act": + from lerobot.common.policies.act.modeling_act import ACTPolicy + + return ACTPolicy + elif name == "vqbet": + from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy + + return VQBeTPolicy + elif name == "pi0": + from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy + + return PI0Policy + elif name == "pi0fast": + from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy + + return PI0FASTPolicy + else: + raise NotImplementedError(f"Policy with name {name} is not implemented.") + + +def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: + if policy_type == "tdmpc": + return TDMPCConfig(**kwargs) + elif policy_type == "diffusion": + return DiffusionConfig(**kwargs) + elif policy_type == "act": + return ACTConfig(**kwargs) + elif policy_type == "vqbet": + return VQBeTConfig(**kwargs) + elif policy_type == "pi0": + return PI0Config(**kwargs) + elif policy_type == "pi0fast": + return PI0FASTConfig(**kwargs) + else: + raise ValueError(f"Policy type '{policy_type}' is not available.") + + +def make_policy( + cfg: PreTrainedConfig, + ds_meta: LeRobotDatasetMetadata | None = None, + env_cfg: EnvConfig | None = None, +) -> PreTrainedPolicy: + """Make an instance of a policy class. + + This function exists because (for now) we need to parse features from either a dataset or an environment + in order to properly dimension and instantiate a policy for that dataset or environment. + + Args: + cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will + be loaded with the weights from that path. + ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and + statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None. + env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be + provided if ds_meta is not. Defaults to None. + + Raises: + ValueError: Either ds_meta or env and env_cfg must be provided. 
+ NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility) + + Returns: + PreTrainedPolicy: _description_ + """ + if bool(ds_meta) == bool(env_cfg): + raise ValueError("Either one of a dataset metadata or a sim env must be provided.") + + # NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error. + # TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies? + # NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If + # you want this op to be added in priority during the prototype phase of this feature, please comment on + # https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment + # variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be + # slower than running natively on MPS. + if cfg.type == "vqbet" and cfg.device == "mps": + raise NotImplementedError( + "Current implementation of VQBeT does not support `mps` backend. " + "Please use `cpu` or `cuda` backend." + ) + + policy_cls = get_policy_class(cfg.type) + + kwargs = {} + if ds_meta is not None: + features = dataset_to_policy_features(ds_meta.features) + kwargs["dataset_stats"] = ds_meta.stats + else: + if not cfg.pretrained_path: + logging.warning( + "You are instantiating a policy from scratch and its features are parsed from an environment " + "rather than a dataset. Normalization modules inside the policy will have infinite values " + "by default without stats from a dataset." + ) + features = env_to_policy_features(env_cfg) + + cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} + cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features} + kwargs["config"] = cfg + + if cfg.pretrained_path: + # Load a pretrained policy and override the config if needed (for example, if there are inference-time + # hyperparameters that we want to vary). + kwargs["pretrained_name_or_path"] = cfg.pretrained_path + policy = policy_cls.from_pretrained(**kwargs) + else: + # Make a fresh policy. + policy = policy_cls(**kwargs) + + policy.to(cfg.device) + assert isinstance(policy, nn.Module) + + # policy = torch.compile(policy, mode="reduce-overhead") + + return policy diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..b3255ec1069059f57d450e7e67a0f681bee5e66c --- /dev/null +++ b/lerobot/common/policies/normalize.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
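A minimal sketch of this factory in use; the dataset repo id is only an example (reused from the comparison script further below), and any LeRobotDataset works.

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.policies.factory import make_policy, make_policy_config

# Input/output shapes and normalization stats are parsed from the dataset metadata.
ds_meta = LeRobotDatasetMetadata("lerobot/aloha_sim_transfer_cube_human")
cfg = make_policy_config("act")  # one of: tdmpc, diffusion, act, vqbet, pi0, pi0fast
policy = make_policy(cfg, ds_meta=ds_meta)  # fresh weights; set cfg.pretrained_path to load a checkpoint
```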
+import numpy as np +import torch +from torch import Tensor, nn + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + + +def create_stats_buffers( + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, +) -> dict[str, dict[str, nn.ParameterDict]]: + """ + Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max + statistics. + + Args: (see Normalize and Unnormalize) + + Returns: + dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing + `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation. + """ + stats_buffers = {} + + for key, ft in features.items(): + norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + assert isinstance(norm_mode, NormalizationMode) + + shape = tuple(ft.shape) + + if ft.type is FeatureType.VISUAL: + # sanity checks + assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}" + c, h, w = shape + assert c < h and c < w, f"{key} is not channel first ({shape=})" + # override image shape to be invariant to height and width + shape = (c, 1, 1) + + # Note: we initialize mean, std, min, max to infinity. They should be overwritten + # downstream by `stats` or `policy.load_state_dict`, as expected. During forward, + # we assert they are not infinity anymore. + + buffer = {} + if norm_mode is NormalizationMode.MEAN_STD: + mean = torch.ones(shape, dtype=torch.float32) * torch.inf + std = torch.ones(shape, dtype=torch.float32) * torch.inf + buffer = nn.ParameterDict( + { + "mean": nn.Parameter(mean, requires_grad=False), + "std": nn.Parameter(std, requires_grad=False), + } + ) + elif norm_mode is NormalizationMode.MIN_MAX: + min = torch.ones(shape, dtype=torch.float32) * torch.inf + max = torch.ones(shape, dtype=torch.float32) * torch.inf + buffer = nn.ParameterDict( + { + "min": nn.Parameter(min, requires_grad=False), + "max": nn.Parameter(max, requires_grad=False), + } + ) + + # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch) + if stats: + if isinstance(stats[key]["mean"], np.ndarray): + if norm_mode is NormalizationMode.MEAN_STD: + buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) + buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32) + elif norm_mode is NormalizationMode.MIN_MAX: + buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) + buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) + elif isinstance(stats[key]["mean"], torch.Tensor): + # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated + # tensors anywhere (for example, when we use the same stats for normalization and + # unnormalization). See the logic here + # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. 
+ if norm_mode is NormalizationMode.MEAN_STD: + buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) + buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) + elif norm_mode is NormalizationMode.MIN_MAX: + buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) + buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) + else: + type_ = type(stats[key]["mean"]) + raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.") + + stats_buffers[key] = buffer + return stats_buffers + + +def _no_stats_error_str(name: str) -> str: + return ( + f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a " + "pretrained model." + ) + + +class Normalize(nn.Module): + """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training.""" + + def __init__( + self, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values + are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing + mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape + is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. + modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values + are their normalization modes among: + - "mean_std": subtract the mean and divide by standard deviation. + - "min_max": map to [-1, 1] range. + stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") + and values are dictionaries of statistic types and their values (e.g. + `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for + training the model for the first time, these statistics will overwrite the default buffers. If + not provided, as expected for finetuning or evaluation, the default buffers should to be + overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the + dataset is not needed to get the stats, since they are already in the policy state_dict. + """ + super().__init__() + self.features = features + self.norm_map = norm_map + self.stats = stats + stats_buffers = create_stats_buffers(features, norm_map, stats) + for key, buffer in stats_buffers.items(): + setattr(self, "buffer_" + key.replace(".", "_"), buffer) + + # TODO(rcadene): should we remove torch.no_grad? + @torch.no_grad + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + batch = dict(batch) # shallow copy avoids mutating the input batch + for key, ft in self.features.items(): + if key not in batch: + # FIXME(aliberts, rcadene): This might lead to silent fail! 
+ continue + + norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + buffer = getattr(self, "buffer_" + key.replace(".", "_")) + + if norm_mode is NormalizationMode.MEAN_STD: + mean = buffer["mean"] + std = buffer["std"] + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") + batch[key] = (batch[key] - mean) / (std + 1e-8) + elif norm_mode is NormalizationMode.MIN_MAX: + min = buffer["min"] + max = buffer["max"] + assert not torch.isinf(min).any(), _no_stats_error_str("min") + assert not torch.isinf(max).any(), _no_stats_error_str("max") + # normalize to [0,1] + batch[key] = (batch[key] - min) / (max - min + 1e-8) + # normalize to [-1, 1] + batch[key] = batch[key] * 2 - 1 + else: + raise ValueError(norm_mode) + return batch + + +class Unnormalize(nn.Module): + """ + Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their + original range used by the environment. + """ + + def __init__( + self, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values + are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing + mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape + is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. + modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values + are their normalization modes among: + - "mean_std": subtract the mean and divide by standard deviation. + - "min_max": map to [-1, 1] range. + stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") + and values are dictionaries of statistic types and their values (e.g. + `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for + training the model for the first time, these statistics will overwrite the default buffers. If + not provided, as expected for finetuning or evaluation, the default buffers should to be + overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the + dataset is not needed to get the stats, since they are already in the policy state_dict. + """ + super().__init__() + self.features = features + self.norm_map = norm_map + self.stats = stats + # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` + stats_buffers = create_stats_buffers(features, norm_map, stats) + for key, buffer in stats_buffers.items(): + setattr(self, "buffer_" + key.replace(".", "_"), buffer) + + # TODO(rcadene): should we remove torch.no_grad? 
+ @torch.no_grad + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + batch = dict(batch) # shallow copy avoids mutating the input batch + for key, ft in self.features.items(): + if key not in batch: + continue + + norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + buffer = getattr(self, "buffer_" + key.replace(".", "_")) + + if norm_mode is NormalizationMode.MEAN_STD: + mean = buffer["mean"] + std = buffer["std"] + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") + batch[key] = batch[key] * std + mean + elif norm_mode is NormalizationMode.MIN_MAX: + min = buffer["min"] + max = buffer["max"] + assert not torch.isinf(min).any(), _no_stats_error_str("min") + assert not torch.isinf(max).any(), _no_stats_error_str("max") + batch[key] = (batch[key] + 1) / 2 + batch[key] = batch[key] * (max - min) + min + else: + raise ValueError(norm_mode) + return batch diff --git a/lerobot/common/policies/pi0/configuration_pi0.py b/lerobot/common/policies/pi0/configuration_pi0.py new file mode 100644 index 0000000000000000000000000000000000000000..8c7cc1305ed1b9bfcf2cf7b0659e576e992f5102 --- /dev/null +++ b/lerobot/common/policies/pi0/configuration_pi0.py @@ -0,0 +1,149 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamWConfig +from lerobot.common.optim.schedulers import ( + CosineDecayWithWarmupSchedulerConfig, +) +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + + +@PreTrainedConfig.register_subclass("pi0") +@dataclass +class PI0Config(PreTrainedConfig): + # Input / output structure. + n_obs_steps: int = 1 + chunk_size: int = 50 + n_action_steps: int = 50 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MEAN_STD, + } + ) + + # Shorter state and action vectors will be padded + max_state_dim: int = 32 + max_action_dim: int = 32 + + # Image preprocessing + resize_imgs_with_padding: tuple[int, int] = (224, 224) + + # Add empty images. Used by pi0_aloha_sim which adds the empty + # left and right wrist cameras in addition to the top camera. + empty_cameras: int = 0 + + # Converts the joint and gripper values from the standard Aloha space to + # the space used by the pi internal runtime which was used to train the base model. + adapt_to_pi_aloha: bool = False + + # Converts joint dimensions to deltas with respect to the current state before passing to the model. + # Gripper dimensions will remain in absolute values. 
+ use_delta_joint_actions_aloha: bool = False + + # Tokenizer + tokenizer_max_length: int = 48 + + # Projector + proj_width: int = 1024 + + # Decoding + num_steps: int = 10 + + # Attention utils + use_cache: bool = True + attention_implementation: str = "eager" # or fa2, flex + + # Finetuning settings + freeze_vision_encoder: bool = True + train_expert_only: bool = False + train_state_proj: bool = True + + # Training presets + optimizer_lr: float = 2.5e-5 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-10 + + scheduler_warmup_steps: int = 1_000 + scheduler_decay_steps: int = 30_000 + scheduler_decay_lr: float = 2.5e-6 + + # TODO: Add EMA + + def __post_init__(self): + super().__post_init__() + + # TODO(Steven): Validate device and amp? in all policy configs? + """Input validation (not exhaustive).""" + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + if self.n_obs_steps != 1: + raise ValueError( + f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" + ) + + if self.use_delta_joint_actions_aloha: + raise NotImplementedError( + "`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot." + ) + + def validate_features(self) -> None: + # TODO: implement value error + # if not self.image_features and not self.env_state_feature: + # raise ValueError("You must provide at least one image or the environment state among the inputs.") + + for i in range(self.empty_cameras): + key = f"observation.images.empty_camera_{i}" + empty_camera = PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 480, 640), + ) + self.input_features[key] = empty_camera + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self): + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> None: + return None + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/pi0/conversion_scripts/benchmark.py b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..cb3c0e9baaf009540a53706fb9a3681efb1c7695 --- /dev/null +++ b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py @@ -0,0 +1,82 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.factory import make_policy +from lerobot.configs.policies import PreTrainedConfig + +torch.backends.cudnn.benchmark = True + + +def main(): + device = "cuda" + dataset_repo_id = "danaaubakirova/koch_test" + # model_name = "pi0_base" + # ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch" + ckpt_torch_dir = "lerobot/pi0" + + dataset = LeRobotDataset(dataset_repo_id, episodes=[0]) + + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=0, + batch_size=1, + ) + + batch = next(iter(dataloader)) + + # To device + for k in batch: + if isinstance(batch[k], torch.Tensor): + batch[k] = batch[k].to(device=device, dtype=torch.float32) + + cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir) + cfg.pretrained_path = ckpt_torch_dir + policy = make_policy(cfg, ds_meta=dataset.meta) + + # policy = torch.compile(policy, mode="reduce-overhead") + + warmup_iters = 10 + benchmark_iters = 30 + + # Warmup + for _ in range(warmup_iters): + torch.cuda.synchronize() + policy.select_action(batch) + policy.reset() + torch.cuda.synchronize() + + # Benchmark + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(benchmark_iters): + policy.select_action(batch) + policy.reset() + end_event.record() + + # Synchronize and measure time + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + + avg_time_per_iter = elapsed_time_ms / benchmark_iters + print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms") + + +if __name__ == "__main__": + with torch.inference_mode(): + main() diff --git a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd7c91f714979fcf668e493415c2d084df32353 --- /dev/null +++ b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py @@ -0,0 +1,131 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
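The warm-up followed by CUDA-event timing used above generalizes to any callable; here is a hedged sketch of the same pattern as a reusable helper (not part of the script).

```python
import torch

def benchmark_cuda_ms(fn, warmup_iters: int = 10, benchmark_iters: int = 30) -> float:
    """Average latency of `fn()` in milliseconds, measured with CUDA events."""
    for _ in range(warmup_iters):
        fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(benchmark_iters):
        fn()
    end.record()
    torch.cuda.synchronize()  # wait for all queued kernels before reading the timer
    return start.elapsed_time(end) / benchmark_iters
```

For instance, `benchmark_cuda_ms(lambda: policy.select_action(batch))` reproduces the measurement above.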
+ +import json +import pickle +from pathlib import Path + +import torch + +from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.common.policies.factory import make_policy +from lerobot.configs.policies import PreTrainedConfig + + +def display(tensor: torch.Tensor): + if tensor.dtype == torch.bool: + tensor = tensor.float() + print(f"Shape: {tensor.shape}") + print(f"Mean: {tensor.mean().item()}") + print(f"Std: {tensor.std().item()}") + print(f"Min: {tensor.min().item()}") + print(f"Max: {tensor.max().item()}") + + +def main(): + num_motors = 14 + device = "cuda" + # model_name = "pi0_aloha_towel" + model_name = "pi0_aloha_sim" + + if model_name == "pi0_aloha_towel": + dataset_repo_id = "lerobot/aloha_static_towel" + else: + dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human" + + ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch" + ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}" + save_dir = Path(f"../openpi/data/{model_name}/save") + + with open(save_dir / "example.pkl", "rb") as f: + example = pickle.load(f) + with open(save_dir / "outputs.pkl", "rb") as f: + outputs = pickle.load(f) + with open(save_dir / "noise.pkl", "rb") as f: + noise = pickle.load(f) + + with open(ckpt_jax_dir / "assets/norm_stats.json") as f: + norm_stats = json.load(f) + + # Override stats + dataset_meta = LeRobotDatasetMetadata(dataset_repo_id) + dataset_meta.stats["observation.state"]["mean"] = torch.tensor( + norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32 + ) + dataset_meta.stats["observation.state"]["std"] = torch.tensor( + norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32 + ) + + # Create LeRobot batch from Jax + batch = {} + for cam_key, uint_chw_array in example["images"].items(): + batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 + batch["observation.state"] = torch.from_numpy(example["state"]) + batch["action"] = torch.from_numpy(outputs["actions"]) + batch["task"] = example["prompt"] + + if model_name == "pi0_aloha_towel": + del batch["observation.images.cam_low"] + elif model_name == "pi0_aloha_sim": + batch["observation.images.top"] = batch["observation.images.cam_high"] + del batch["observation.images.cam_high"] + + # Batchify + for key in batch: + if isinstance(batch[key], torch.Tensor): + batch[key] = batch[key].unsqueeze(0) + elif isinstance(batch[key], str): + batch[key] = [batch[key]] + else: + raise ValueError(f"{key}, {batch[key]}") + + # To device + for k in batch: + if isinstance(batch[k], torch.Tensor): + batch[k] = batch[k].to(device=device, dtype=torch.float32) + + noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32) + + from lerobot.common import policies # noqa + + cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir) + cfg.pretrained_path = ckpt_torch_dir + policy = make_policy(cfg, dataset_meta) + + # loss_dict = policy.forward(batch, noise=noise, time=time_beta) + # loss_dict["loss"].backward() + # print("losses") + # display(loss_dict["losses_after_forward"]) + # print("pi_losses") + # display(pi_losses) + + actions = [] + for _ in range(50): + action = policy.select_action(batch, noise=noise) + actions.append(action) + + actions = torch.stack(actions, dim=1) + pi_actions = batch["action"] + print("actions") + display(actions) + print() + print("pi_actions") + display(pi_actions) + print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2)) + 
print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2)) + print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2)) + + +if __name__ == "__main__": + main() diff --git a/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py b/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8835da31efd5e6b05a179f8237d6ada029434963 --- /dev/null +++ b/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py @@ -0,0 +1,84 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import GemmaConfig, PaliGemmaConfig + + +def get_paligemma_config(precision: str): + config = { + "image_token_index": None, + "pad_token_id": 0, + "bos_token_id": 2, + "eos_token_id": 1, + } + + # image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896} + + image_size = 224 # image_sizes[variant] + patch_size = 14 + num_image_tokens = (image_size**2) // (patch_size**2) + + config["image_token_index"] = 257152 + text_config = { + "vocab_size": 257152, + "num_hidden_layers": 18, + "num_key_value_heads": 1, + "head_dim": 256, + "torch_dtype": precision, + "hidden_size": 2048, + "hidden_activation": "gelu_pytorch_tanh", + "num_attention_heads": 8, + "intermediate_size": 16384, + "is_encoder_decoder": False, + } + vision_config = { + "torch_dtype": precision, + "image_size": image_size, + "patch_size": patch_size, + "num_image_tokens": num_image_tokens, + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + "projector_hidden_act": "gelu_fast", + "vision_use_head": False, + } + final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config) + return final_config + + +def get_gemma_config(precision: str): + config = { + "image_token_index": None, + "pad_token_id": 0, + "bos_token_id": 2, + "eos_token_id": 1, + } + + config["image_token_index"] = 257152 + text_config = { + "vocab_size": 257152, + "num_hidden_layers": 18, + "num_key_value_heads": 1, + "head_dim": 256, + "torch_dtype": precision, + "hidden_size": 1024, + "hidden_activation": "gelu_pytorch_tanh", + "num_attention_heads": 8, + "intermediate_size": 4096, + "is_encoder_decoder": False, + } + final_config = GemmaConfig() + final_config.update(text_config) + return final_config diff --git a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py new file mode 100644 index 0000000000000000000000000000000000000000..73ff506ff86364d48321326de5f66841b8d0af23 --- /dev/null +++ b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py @@ -0,0 +1,437 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Convert pi0 parameters from Jax to Pytorch + +Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment +and install the required libraries. + +```bash +cd ~/code/openpi +source .venv/bin/activate +``` + +Example downloading parameters: +```bash +python +>>> import openpi.shared.download as download +>>> path='s3://openpi-assets/checkpoints/pi0_base/params' +>>> download.maybe_download(path) +``` + +Converting pi0_base: +```python +python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \ + --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \ + --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch +``` + +```python +python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \ + --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \ + --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch +``` +""" + +import argparse +import pathlib + +import jax +import numpy as np +import orbax.checkpoint as ocp +import torch +from jax.sharding import SingleDeviceSharding + +from lerobot.common.policies.pi0.configuration_pi0 import PI0Config +from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import ( + get_gemma_config, + get_paligemma_config, +) +from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy + +PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16} + + +def slice_paligemma_state_dict(state_dict, config): + suffix = "/value" if "img/embedding/kernel/value" in state_dict else "" + + # fmt: off + # patch embeddings + state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose( + 3, 2, 0, 1 + ) + state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}") + # positional embeddings + state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape( + -1, config.vision_config.hidden_size + ) + + # extract vision layers to be sliced at index 0. There are 27 layers in the base model. 
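+    # Each popped array below stacks all 27 vision encoder layers along a leading axis,
+    # so indexing with [i] in the loop selects a single layer's weights. Flax stores
+    # Linear kernels as (in_features, out_features), hence the transpose() calls when
+    # mapping onto PyTorch's (out_features, in_features) layout.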
+ encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}") + encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}") + encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}") + encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}") + + encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}") + encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}") + encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}") + encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}") + + encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}") + encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}") + encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}") + encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}") + encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}") + encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}") + encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}") + encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}") + + for i in range(config.vision_config.num_hidden_layers): + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i] + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i] + + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i] + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i] + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + 
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + + state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose() + state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}") + + # multimodal projector + + state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose() + state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}") + + # text decoder (gemma) + embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}") + state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector + + # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. + + llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}") + llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}") + llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}") + + llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}") + llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}") + # TODO verify correctness of layer norm loading + + llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}") + llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}") + + for i in range(config.text_config.num_hidden_layers): + # llm_attention_q_einsum[i].shape = (8, 2048, 256) + q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + + state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped + + # llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256) + k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() + state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped + # llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256) + v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() + state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped + + # output projection. 
+ + # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048) + o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + + state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped + # mlp layers + gate_proj_weight = llm_mlp_gating_einsum[i, 0] + state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() + up_proj_weight = llm_mlp_gating_einsum[i, 1] + state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() + state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() + state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] + state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] + + state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}") + state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied. + + # fmt: on + expert_dict = {} + final_state_dict = {} + for key, value in state_dict.items(): + if key not in [ + f"llm/final_norm_1/scale{suffix}", + f"llm/layers/attn/attn_vec_einsum_1/w{suffix}", + f"llm/layers/attn/kv_einsum_1/w{suffix}", + f"llm/layers/attn/q_einsum_1/w{suffix}", + f"llm/layers/mlp_1/gating_einsum{suffix}", + f"llm/layers/mlp_1/linear{suffix}", + f"llm/layers/pre_attention_norm_1/scale{suffix}", + f"llm/layers/pre_ffw_norm_1/scale{suffix}", + ]: + final_state_dict[key] = torch.from_numpy(value) + else: + expert_dict[key] = value + + return final_state_dict, expert_dict + + +def slice_gemma_state_dict(state_dict, config, num_expert=1): + # fmt: off + # text decoder (gemma) + # no embedding vector, the expert just has the decoder layers + + embedding_vector = torch.zeros([config.vocab_size, config.hidden_size]) + state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector + + # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. 
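+    # Keys carrying the `_{num_expert}` suffix hold the action expert's own Gemma decoder
+    # weights (set aside by `slice_paligemma_state_dict` above); they follow the same
+    # stacked per-layer layout as the PaliGemma language model.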
+ + suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else "" + + llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}") + llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}") + llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}") + + llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}") + llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}") + # TODO verify correctness of layer norm loading + + llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}") + llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}") + + for i in range(config.num_hidden_layers): + q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size) + + state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped + + k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() + state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped + v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() + state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped + + # output projection. + + # llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024) + o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0) + + state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped + # mlp layers + gate_proj_weight = llm_mlp_gating_einsum[i, 0] + state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() + up_proj_weight = llm_mlp_gating_einsum[i, 1] + state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() + state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() + state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] + state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] + + state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}") + state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here) + + # fmt: on + final_state_dict = {} + for key, value in state_dict.items(): + if not isinstance(value, torch.Tensor): + final_state_dict[key] = torch.from_numpy(value) + else: + final_state_dict[key] = value + return final_state_dict + + +def flatten_for_memory(tree, parent_key=""): + out = {} + for k, v in tree.items(): + new_key = f"{parent_key}/{k}" if parent_key else k + if isinstance(v, dict): + out.update(flatten_for_memory(v, new_key)) + else: + out[new_key] = np.array(v) # Ensure conversion to np.array for consistency + return out + + +def flatten_for_npz(tree, parent_key=""): + out = {} + for k, v in tree.items(): + new_key = f"{parent_key}/{k}" if parent_key else k + if isinstance(v, dict): + out.update(flatten_for_npz(v, new_key)) + else: + # bf16/f32 here? 
+ out[new_key] = np.array(v) + return out + + +def slice_initial_orbax_checkpoint(checkpoint_dir: str): + params_path = pathlib.Path(checkpoint_dir).resolve() + checkpointer = ocp.PyTreeCheckpointer() + + metadata = checkpointer.metadata(params_path) + print("Metadata keys:", list(metadata.keys())) + + params_name = "params" + + item = {params_name: metadata[params_name]} + device = jax.local_devices()[0] # Use the first local device + sharding = SingleDeviceSharding(device) + restored = checkpointer.restore( + params_path, + ocp.args.PyTreeRestore( + item=item, + restore_args=jax.tree_util.tree_map( + lambda _: ocp.ArrayRestoreArgs( + restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it + sharding=sharding, + ), + item, + ), + transforms={}, + ), + ) + params = restored[params_name] + + # get params for PaliGemma + pali_params = params["PaliGemma"] + del params["PaliGemma"] + pali_params_flat = flatten_for_npz(pali_params) + return {"paligemma_params": pali_params_flat, "projection_params": params} + + +def update_keys_with_prefix(d: dict, prefix: str) -> dict: + """Update dictionary keys by adding a prefix.""" + return {f"{prefix}{key}": value for key, value in d.items()} + + +def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str): + # Break down orbax ckpts - they are in OCDBT + initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir) + # process projection params + keys = [ + "state_proj", + "action_in_proj", + "action_out_proj", + "action_time_mlp_in", + "action_time_mlp_out", + ] + + projection_params = {} + for key in keys: + kernel_params = initial_params["projection_params"][key]["kernel"] + bias_params = initial_params["projection_params"][key]["bias"] + if isinstance(kernel_params, dict): + weight = kernel_params["value"] + bias = bias_params["value"] + else: + weight = kernel_params + bias = bias_params + projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T + projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias)) + + # Process PaliGemma weights + paligemma_config = get_paligemma_config(precision) + paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict( + initial_params["paligemma_params"], paligemma_config + ) + + # Process Gemma weights (at this stage they are unused) + gemma_config = get_gemma_config(precision) + gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config) + + # Instantiate model from configs + + if "pi0_aloha_sim" in checkpoint_dir: + pi0_config = PI0Config( + empty_cameras=2, + adapt_to_pi_aloha=True, + use_delta_joint_actions_aloha=False, + ) + elif "pi0_aloha_towel" in checkpoint_dir: + pi0_config = PI0Config( + adapt_to_pi_aloha=True, + use_delta_joint_actions_aloha=True, + ) + elif "pi0_base" in checkpoint_dir: + pi0_config = PI0Config( + empty_cameras=0, + adapt_to_pi_aloha=False, + use_delta_joint_actions_aloha=False, + ) + else: + raise ValueError() + + # gemma_config=gemma_config, paligemma_config=paligemma_config) + pi0_model = PI0Policy(pi0_config) + + paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.") + gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.") + projection_params = update_keys_with_prefix(projection_params, "model.") + + # load state dict + torch_dtype = PRECISIONS[precision] + pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params}) + pi0_model = pi0_model.to(torch_dtype) + # 
pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + + pi0_model.save_pretrained(output_path, safe_serialization=True) + # pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype) + + # assert that model loads properly + del pi0_model + PI0Policy.from_pretrained(output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint_dir", + default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params", + type=str, + help="Path to the ocdbt checkpoint", + ) + + parser.add_argument( + "--precision", + choices=["float32", "bfloat16", "float16"], + default="float32", + type=str, + help="Precision identifier for model conversion - should match the base checkpoint precision.", + ) + # tokenizer is identical to paligemma, it appears + + parser.add_argument( + "--tokenizer_hub_id", + default="google/paligemma-3b-pt-224", + type=str, + help="Hub path to the tokenizer to save", + ) + + parser.add_argument( + "--output_path", + required=True, + type=str, + help="Path to save converted weights to", + ) + + args = parser.parse_args() + convert_pi0_checkpoint( + checkpoint_dir=args.checkpoint_dir, + precision=args.precision, + tokenizer_id=args.tokenizer_hub_id, + output_path=args.output_path, + ) diff --git a/lerobot/common/policies/pi0/flex_attention.py b/lerobot/common/policies/pi0/flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..35628cddb40c0b8781090312e67bd348cb5930bb --- /dev/null +++ b/lerobot/common/policies/pi0/flex_attention.py @@ -0,0 +1,141 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F # noqa: N812 +from packaging.version import Version + +if Version(torch.__version__) > Version("2.5.0"): + # Ffex attention is only available from torch 2.5 onwards + from torch.nn.attention.flex_attention import ( + _mask_mod_signature, + _round_up_to_multiple, + create_block_mask, + create_mask, + flex_attention, + ) + + +# @torch.compile(dynamic=False) +def flex_attention_forward( + attention_mask: torch.Tensor, + batch_size: int, + head_dim: int, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + scaling=None, +): + """ + This is defined out of classes to make compile happy. 
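+    It expands the single key/value head to match the query heads (grouped-query
+    attention), pads the precomputed attention mask up to the flex-attention block size,
+    and builds a block mask before dispatching to torch's `flex_attention` kernel.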
+ """ + + original_dtype = query_states.dtype + num_att_heads = 8 + num_key_value_heads = 1 + num_key_value_groups = num_att_heads // num_key_value_heads + + key_states = key_states[:, :, :, None, :] + key_states = key_states.expand( + batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim + ) + key_states = key_states.reshape( + batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim + ) + + value_states = value_states[:, :, :, None, :] + value_states = value_states.expand( + batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim + ) + value_states = value_states.reshape( + batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim + ) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + query_states = query_states.to(torch.float32) + key_states = key_states.to(torch.float32) + value_states = value_states.to(torch.float32) + + causal_mask = attention_mask + if causal_mask is not None: + causal_mask = causal_mask[:, None, :, : key_states.shape[2]] + + if causal_mask.shape[1] == 1 and query_states.shape[1] > 1: + causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1) + + def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature: + def mask_mod(b, h, q_idx, kv_idx): + # Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs. + return precomputed_mask[b][h][q_idx][kv_idx] + + return mask_mod + + b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask + + block_size = 128 + q_len_rounded = _round_up_to_multiple(q_len, block_size) + kv_len_rounded = _round_up_to_multiple(kv_len, block_size) + + # *CRITICAL* we do need to expand here, else we get a CUDA index error + + pad_q = q_len_rounded - q_len + pad_k = kv_len_rounded - kv_len + + padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0) + mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask) + + mask_4d = create_mask( + mod_fn=mask_mod_fn_orig, + B=b_mask, + H=h_mask, + Q_LEN=q_len_rounded, + KV_LEN=kv_len_rounded, + device=causal_mask.device, + _compile=False, + ) + + mask_mod_fn_padded = precomputed_mask_factory(mask_4d) + block_mask = create_block_mask( + mask_mod=mask_mod_fn_padded, + B=b_mask, + H=h_mask, + Q_LEN=q_len_rounded, + KV_LEN=kv_len_rounded, + BLOCK_SIZE=block_size, + device=causal_mask.device, + _compile=False, + ) + + # mask is applied inside the kernel, ideally more efficiently than score_mod. 
+ attn_output, attention_weights = flex_attention( + query_states, + key_states, + value_states, + block_mask=block_mask, + enable_gqa=True, # because we shaped query/key states for GQA + scale=head_dim**-0.5 if scaling is None else scaling, + return_lse=True, + ) + + attn_output = attn_output.to(dtype=original_dtype) + attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim] + attn_output = attn_output.reshape( + batch_size, + -1, + attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim] + ) + return attn_output diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py new file mode 100644 index 0000000000000000000000000000000000000000..7599fa6354337a9443cb39c9a3107e7d9cc43ea1 --- /dev/null +++ b/lerobot/common/policies/pi0/modeling_pi0.py @@ -0,0 +1,732 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +π0: A Vision-Language-Action Flow Model for General Robot Control + +[Paper](https://www.physicalintelligence.company/download/pi0.pdf) +[Jax code](https://github.com/Physical-Intelligence/openpi) + +Designed by Physical Intelligence. Ported from Jax by Hugging Face. 
+ +Install pi0 extra dependencies: +```bash +pip install -e ".[pi0]" +``` + +Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`): +```bash +python lerobot/scripts/train.py \ +--policy.path=lerobot/pi0 \ +--dataset.repo_id=danaaubakirova/koch_test +``` + +Example of finetuning the pi0 neural network with PaliGemma and expert Gemma +pretrained with VLM default parameters before pi0 finetuning: +```bash +python lerobot/scripts/train.py \ +--policy.type=pi0 \ +--dataset.repo_id=danaaubakirova/koch_test +``` + +Example of using the pi0 pretrained model outside LeRobot training framework: +```python +policy = Pi0Policy.from_pretrained("lerobot/pi0") +``` + +""" + +import math +from collections import deque + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn +from transformers import AutoTokenizer + +from lerobot.common.constants import ACTION, OBS_ROBOT +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pi0.configuration_pi0 import PI0Config +from lerobot.common.policies.pi0.paligemma_with_expert import ( + PaliGemmaWithExpertConfig, + PaliGemmaWithExpertModel, +) +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.utils.utils import get_safe_dtype + + +def create_sinusoidal_pos_embedding( + time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" +) -> Tensor: + """Computes sine-cosine positional embedding vectors for scalar positions.""" + if dimension % 2 != 0: + raise ValueError(f"dimension ({dimension}) must be divisible by 2") + + if time.ndim != 1: + raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") + + dtype = get_safe_dtype(torch.float64, device.type) + fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) + period = min_period * (max_period / min_period) ** fraction + + # Compute the outer product + scaling_factor = 1.0 / period * 2 * math.pi + sin_input = scaling_factor[None, :] * time[:, None] + pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + return pos_emb + + +def sample_beta(alpha, beta, bsize, device): + gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) + gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) + return gamma1 / (gamma1 + gamma2) + + +def make_att_2d_masks(pad_masks, att_masks): + """Copied from big_vision. + + Tokens can attend to valid inputs tokens which have a cumulative mask_ar + smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to + setup several types of attention, for example: + + [[1 1 1 1 1 1]]: pure causal attention. + + [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between + themselves and the last 3 tokens have a causal attention. The first + entry could also be a 1 without changing behaviour. + + [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a + block can attend all previous blocks and all tokens on the same block. + + Args: + input_mask: bool[B, N] true if its part of the input, false if padding. + mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on + it and 0 where it shares the same attention mask as the previous token. 
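+    Example: with `pad_masks` all True and `att_masks = [[1, 0, 0, 1, 0]]`, the cumulative
+    sum is [1, 1, 1, 2, 2]: tokens 0-2 attend only among themselves (a prefix block),
+    while tokens 3-4 attend to all five tokens (block-causal attention with two blocks).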
+ """ + if att_masks.ndim != 2: + raise ValueError(att_masks.ndim) + if pad_masks.ndim != 2: + raise ValueError(pad_masks.ndim) + + cumsum = torch.cumsum(att_masks, dim=1) + att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + att_2d_masks = att_2d_masks & pad_2d_masks + return att_2d_masks + + +def resize_with_pad(img, width, height, pad_value=-1): + # assume no-op when width height fits already + if img.ndim != 4: + raise ValueError(f"(b,c,h,w) expected, but {img.shape}") + + cur_height, cur_width = img.shape[2:] + + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + resized_img = F.interpolate( + img, size=(resized_height, resized_width), mode="bilinear", align_corners=False + ) + + pad_height = max(0, int(height - resized_height)) + pad_width = max(0, int(width - resized_width)) + + # pad on left and top of image + padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) + return padded_img + + +def pad_vector(vector, new_dim): + """Can be (batch_size x sequence_length x features_dimension) + or (batch_size x features_dimension) + """ + if vector.shape[-1] == new_dim: + return vector + shape = list(vector.shape) + current_dim = shape[-1] + shape[-1] = new_dim + new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device) + new_vector[..., :current_dim] = vector + return new_vector + + +def normalize(x, min_val, max_val): + return (x - min_val) / (max_val - min_val) + + +def unnormalize(x, min_val, max_val): + return x * (max_val - min_val) + min_val + + +def safe_arcsin(value): + # This ensures that the input stays within + # [−1,1] to avoid invalid values for arcsin + return torch.arcsin(torch.clamp(value, -1.0, 1.0)) + + +def aloha_gripper_to_angular(value): + # Aloha transforms the gripper positions into a linear space. The following code + # reverses this transformation to be consistent with pi0 which is pretrained in + # angular space. + # + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED + value = unnormalize(value, min_val=0.01844, max_val=0.05800) + + # This is the inverse of the angular to linear transformation inside the Interbotix code. + def linear_to_radian(linear_position, arm_length, horn_radius): + value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) + return safe_arcsin(value) + + # The constants are taken from the Interbotix code. + value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) + + # Normalize to [0, 1]. + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + return normalize(value, min_val=0.4, max_val=1.5) + + +def aloha_gripper_from_angular(value): + # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha. + # Note that the units are still angular but the range is different. + + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + value = unnormalize(value, min_val=0.4, max_val=1.5) + + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE + return normalize(value, min_val=-0.6213, max_val=1.4910) + + +def aloha_gripper_from_angular_inv(value): + # Directly inverts the gripper_from_angular function. 
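+    # i.e. it maps a value from the Aloha gripper joint range back to the normalized
+    # range used by pi0, undoing the two rescalings of `aloha_gripper_from_angular`.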
+ value = unnormalize(value, min_val=-0.6213, max_val=1.4910) + return normalize(value, min_val=0.4, max_val=1.5) + + +class PI0Policy(PreTrainedPolicy): + """Wrapper class around PI0FlowMatching model to train and run inference within LeRobot.""" + + config_class = PI0Config + name = "pi0" + + def __init__( + self, + config: PI0Config, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + + super().__init__(config) + config.validate_features() + self.config = config + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + self.model = PI0FlowMatching(config) + + self.reset() + + def reset(self): + """This should be called whenever the environment is reset.""" + self._action_queue = deque([], maxlen=self.config.n_action_steps) + + def get_optim_params(self) -> dict: + return self.parameters() + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + self.eval() + + if self.config.adapt_to_pi_aloha: + batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + + batch = self.normalize_inputs(batch) + + # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by + # querying the policy. + if len(self._action_queue) == 0: + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + + actions = self.model.sample_actions( + images, img_masks, lang_tokens, lang_masks, state, noise=noise + ) + + # Unpad actions + original_action_dim = self.config.action_feature.shape[0] + actions = actions[:, :, :original_action_dim] + + actions = self.unnormalize_outputs({"action": actions})["action"] + + if self.config.adapt_to_pi_aloha: + actions = self._pi_aloha_encode_actions(actions) + + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. 
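+            # After the transpose, the queue holds n_action_steps entries of shape
+            # (batch_size, action_dim), so each popleft() below returns one action step.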
+ self._action_queue.extend(actions.transpose(0, 1)) + return self._action_queue.popleft() + + def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]: + """Do a full training forward pass to compute the loss""" + if self.config.adapt_to_pi_aloha: + batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) + + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + actions = self.prepare_action(batch) + actions_is_pad = batch.get("action_is_pad") + + loss_dict = {} + losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) + loss_dict["losses_after_forward"] = losses.clone() + + if actions_is_pad is not None: + in_episode_bound = ~actions_is_pad + losses = losses * in_episode_bound.unsqueeze(-1) + loss_dict["losses_after_in_ep_bound"] = losses.clone() + + # Remove padding + losses = losses[:, :, : self.config.max_action_dim] + loss_dict["losses_after_rm_padding"] = losses.clone() + + # For backward pass + loss = losses.mean() + # For logging + loss_dict["l2_loss"] = loss.item() + + return loss, loss_dict + + def prepare_images(self, batch): + """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and + convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP. + """ + images = [] + img_masks = [] + + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] + + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" + ) + + # Preprocess image features present in the batch + for key in present_img_keys: + img = batch[key] + + if self.config.resize_imgs_with_padding is not None: + img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0) + + # Normalize from range [0,1] to [-1,1] as expacted by siglip + img = img * 2.0 - 1.0 + + bsize = img.shape[0] + device = img.device + mask = torch.ones(bsize, dtype=torch.bool, device=device) + images.append(img) + img_masks.append(mask) + + # Create image features not present in the batch + # as fully 0 padded images. + for num_empty_cameras in range(len(missing_img_keys)): + if num_empty_cameras >= self.config.empty_cameras: + break + img = torch.ones_like(img) * -1 + mask = torch.zeros_like(mask) + images.append(img) + img_masks.append(mask) + + return images, img_masks + + def prepare_language(self, batch) -> tuple[Tensor, Tensor]: + """Tokenize the text input""" + device = batch[OBS_ROBOT].device + tasks = batch["task"] + + # PaliGemma prompt has to end with a new line + tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] + + tokenized_prompt = self.language_tokenizer.__call__( + tasks, + padding="max_length", + padding_side="right", + max_length=self.config.tokenizer_max_length, + return_tensors="pt", + ) + lang_tokens = tokenized_prompt["input_ids"].to(device=device) + lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) + + return lang_tokens, lang_masks + + def _pi_aloha_decode_state(self, state): + # Flip the joints. 
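+        # Indices 1, 2, 8 and 9 are the flipped arm joints; indices 6 and 13 (below) are
+        # the grippers, which additionally get the angular-space conversion.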
+ for motor_idx in [1, 2, 8, 9]: + state[:, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) + return state + + def _pi_aloha_encode_actions(self, actions): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) + return actions + + def _pi_aloha_encode_actions_inv(self, actions): + # Flip the joints again. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) + return actions + + def prepare_state(self, batch): + """Pad state""" + state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim) + return state + + def prepare_action(self, batch): + """Pad action""" + actions = pad_vector(batch[ACTION], self.config.max_action_dim) + return actions + + +class PI0FlowMatching(nn.Module): + """ + π0: A Vision-Language-Action Flow Model for General Robot Control + + [Paper](https://www.physicalintelligence.company/download/pi0.pdf) + [Jax code](https://github.com/Physical-Intelligence/openpi) + + Designed by Physical Intelligence. Ported from Jax by Hugging Face. + ┌──────────────────────────────┐ + │ actions │ + │ ▲ │ + │ ┌┴─────┐ │ + │ kv cache │Gemma │ │ + │ ┌──────────►│Expert│ │ + │ │ │ │ │ + │ ┌┴────────┐ │x 10 │ │ + │ │ │ └▲──▲──┘ │ + │ │PaliGemma│ │ │ │ + │ │ │ │ robot state │ + │ │ │ noise │ + │ └▲──▲─────┘ │ + │ │ │ │ + │ │ image(s) │ + │ language tokens │ + └──────────────────────────────┘ + """ + + def __init__(self, config): + super().__init__() + self.config = config + + paligemma_with_export_config = PaliGemmaWithExpertConfig( + freeze_vision_encoder=self.config.freeze_vision_encoder, + train_expert_only=self.config.train_expert_only, + attention_implementation=self.config.attention_implementation, + ) + self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config) + + # Projections are float32 + self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width) + self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width) + self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim) + + self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width) + self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width) + + self.set_requires_grad() + + def set_requires_grad(self): + for params in self.state_proj.parameters(): + params.requires_grad = self.config.train_state_proj + + def sample_noise(self, shape, device): + noise = torch.normal( + mean=0.0, + std=1.0, + size=shape, + dtype=torch.float32, + device=device, + ) + return noise + + def sample_time(self, bsize, device): + time_beta = sample_beta(1.5, 1.0, bsize, device) + time = time_beta * 0.999 + 0.001 + return time.to(dtype=torch.float32, device=device) + + def embed_prefix( + self, images, img_masks, lang_tokens, lang_masks + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed images with SigLIP and language tokens with embedding layer to prepare + for PaliGemma transformer processing. 
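+
+        Returns a tuple `(embs, pad_masks, att_masks)`. `att_masks` is all zeros, so every
+        prefix token (image and language) can attend to every other prefix token.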
+ """ + # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty + embs = [] + pad_masks = [] + att_masks = [] + + # TODO: remove for loop + for ( + img, + img_mask, + ) in zip(images, img_masks, strict=False): + img_emb = self.paligemma_with_expert.embed_image(img) + img_emb = img_emb.to(dtype=torch.bfloat16) + + # Normalize image embeddings + img_emb_dim = img_emb.shape[-1] + img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device) + + bsize, num_img_embs = img_emb.shape[:2] + img_mask = img_mask[:, None].expand(bsize, num_img_embs) + + embs.append(img_emb) + pad_masks.append(img_mask) + + # Create attention masks so that image tokens attend to each other + att_masks += [0] * num_img_embs + + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + + # Normalize language embeddings + lang_emb_dim = lang_emb.shape[-1] + lang_emb = lang_emb * math.sqrt(lang_emb_dim) + + embs.append(lang_emb) + pad_masks.append(lang_masks) + + # full attention between image and language inputs + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks + + def embed_suffix(self, state, noisy_actions, timestep): + """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" + embs = [] + pad_masks = [] + att_masks = [] + + # Embed state + state_emb = self.state_proj(state) + state_emb = state_emb.to(dtype=torch.bfloat16) + embs.append(state_emb[:, None, :]) + bsize = state_emb.shape[0] + dtype = state_emb.dtype + device = state_emb.device + + state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) + pad_masks.append(state_mask) + + # Set attention masks so that image and language inputs do not attend to state or actions + att_masks += [1] + + # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] + time_emb = create_sinusoidal_pos_embedding( + timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device + ) + time_emb = time_emb.type(dtype=dtype) + + # Fuse timestep + action information using an MLP + action_emb = self.action_in_proj(noisy_actions) + + time_emb = time_emb[:, None, :].expand_as(action_emb) + action_time_emb = torch.cat([action_emb, time_emb], dim=2) + + action_time_emb = self.action_time_mlp_in(action_time_emb) + action_time_emb = F.silu(action_time_emb) # swish == silu + action_time_emb = self.action_time_mlp_out(action_time_emb) + + # Add to input tokens + embs.append(action_time_emb) + + bsize, action_time_dim = action_time_emb.shape[:2] + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device) + pad_masks.append(action_time_mask) + + # Set attention masks so that image, language and state inputs do not attend to action tokens + att_masks += [1] + ([0] * (self.config.n_action_steps - 1)) + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks + + def forward( + self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None + ) -> Tensor: + """Do a full training forward pass and compute the loss (batch_size x 
num_steps x num_motors)""" + if noise is None: + noise = self.sample_noise(actions.shape, actions.device) + + if time is None: + time = self.sample_time(actions.shape[0], actions.device) + + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, lang_tokens, lang_masks + ) + suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time) + + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) + att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) + + att_2d_masks = make_att_2d_masks(pad_masks, att_masks) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + fill_kv_cache=False, + ) + suffix_out = suffix_out[:, -self.config.n_action_steps :] + # Original openpi code, upcast attention output + suffix_out = suffix_out.to(dtype=torch.float32) + v_t = self.action_out_proj(suffix_out) + + losses = F.mse_loss(u_t, v_t, reduction="none") + return losses + + def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor: + """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" + bsize = state.shape[0] + device = state.device + + if noise is None: + actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim) + noise = self.sample_noise(actions_shape, device) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, lang_tokens, lang_masks + ) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + # Compute image and language key value cache + _, past_key_values = self.paligemma_with_expert.forward( + attention_mask=prefix_att_2d_masks, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=self.config.use_cache, + fill_kv_cache=True, + ) + + dt = -1.0 / self.config.num_steps + dt = torch.tensor(dt, dtype=torch.float32, device=device) + + x_t = noise + time = torch.tensor(1.0, dtype=torch.float32, device=device) + while time >= -dt / 2: + expanded_time = time.expand(bsize) + v_t = self.denoise_step( + state, + prefix_pad_masks, + past_key_values, + x_t, + expanded_time, + ) + + # Euler step + x_t += dt * v_t + time += dt + return x_t + + def denoise_step( + self, + state, + prefix_pad_masks, + past_key_values, + x_t, + timestep, + ): + """Apply one denoising step of the noise `x_t` at a given timestep.""" + suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep) + + suffix_len = suffix_pad_masks.shape[1] + batch_size = prefix_pad_masks.shape[0] + prefix_len = prefix_pad_masks.shape[1] + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) + + suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) + + full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) + + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] + position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + outputs_embeds, _ = self.paligemma_with_expert.forward( + attention_mask=full_att_2d_masks, + 
position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=[None, suffix_embs],
+            use_cache=self.config.use_cache,
+            fill_kv_cache=False,
+        )
+        suffix_out = outputs_embeds[1]
+        suffix_out = suffix_out[:, -self.config.n_action_steps :]
+        suffix_out = suffix_out.to(dtype=torch.float32)
+        v_t = self.action_out_proj(suffix_out)
+        return v_t
diff --git a/lerobot/common/policies/pi0/paligemma_with_expert.py b/lerobot/common/policies/pi0/paligemma_with_expert.py
new file mode 100644
index 0000000000000000000000000000000000000000..76e2ce6005cdcde0e3c2730e2962f051598338bd
--- /dev/null
+++ b/lerobot/common/policies/pi0/paligemma_with_expert.py
@@ -0,0 +1,417 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import torch
+import torch.version
+from torch import nn
+from transformers import (
+    AutoConfig,
+    GemmaForCausalLM,
+    PaliGemmaForConditionalGeneration,
+    PretrainedConfig,
+    PreTrainedModel,
+)
+from transformers.cache_utils import Cache
+from transformers.models.auto import CONFIG_MAPPING
+
+from lerobot.common.policies.pi0.flex_attention import flex_attention_forward
+
+
+def apply_rope(x, positions, max_wavelength=10_000):
+    """
+    Applies RoPE positions [B, L] to x [B, L, H, D].
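+
+    Example (shapes only):
+        >>> x = torch.randn(2, 10, 8, 256)              # [B, L, H, D]
+        >>> positions = torch.arange(10).expand(2, 10)  # [B, L]
+        >>> apply_rope(x, positions).shape
+        torch.Size([2, 10, 8, 256])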
+    """
+    d_half = x.shape[-1] // 2
+    device = x.device
+    dtype = x.dtype
+    x = x.to(torch.float32)
+
+    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
+    timescale = max_wavelength**freq_exponents
+    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
+
+    radians = radians[..., None, :]
+
+    sin = torch.sin(radians)  # .to(dtype=dtype)
+    cos = torch.cos(radians)  # .to(dtype=dtype)
+
+    x1, x2 = x.split(d_half, dim=-1)
+    res = torch.empty_like(x)
+    res[..., :d_half] = x1 * cos - x2 * sin
+    res[..., d_half:] = x2 * cos + x1 * sin
+
+    return res.to(dtype)
+
+
+class PaliGemmaWithExpertConfig(PretrainedConfig):
+    model_type = "PaliGemmaWithExpertModel"
+    sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
+
+    def __init__(
+        self,
+        paligemma_config: dict | None = None,
+        gemma_expert_config: dict | None = None,
+        freeze_vision_encoder: bool = True,
+        train_expert_only: bool = True,
+        attention_implementation: str = "eager",
+        **kwargs,
+    ):
+        self.freeze_vision_encoder = freeze_vision_encoder
+        self.train_expert_only = train_expert_only
+        self.attention_implementation = attention_implementation
+
+        if paligemma_config is None:
+            # Default config from Pi0
+            self.paligemma_config = CONFIG_MAPPING["paligemma"](
+                transformers_version="4.48.1",
+                _vocab_size=257152,
+                bos_token_id=2,
+                eos_token_id=1,
+                hidden_size=2048,
+                image_token_index=257152,
+                model_type="paligemma",
+                pad_token_id=0,
+                projection_dim=2048,
+                text_config={
+                    "hidden_activation": "gelu_pytorch_tanh",
+                    "hidden_size": 2048,
+                    "intermediate_size": 16384,
+                    "model_type": "gemma",
+                    "num_attention_heads": 8,
+                    "num_hidden_layers": 18,
+                    "num_image_tokens": 256,
+                    "num_key_value_heads": 1,
+                    "torch_dtype": "float32",
+                    "vocab_size": 257152,
+                },
+                vision_config={
+                    "hidden_size": 1152,
+                    "intermediate_size": 4304,
+                    "model_type": "siglip_vision_model",
+                    "num_attention_heads": 16,
+                    "num_hidden_layers": 27,
+                    "num_image_tokens": 256,
+                    "patch_size": 14,
+                    "projection_dim": 2048,
+                    "projector_hidden_act": "gelu_fast",
+                    "torch_dtype": "float32",
+                    "vision_use_head": False,
+                },
+            )
+        elif isinstance(paligemma_config, dict):
+            # Override Pi0 default config for PaliGemma
+            if "model_type" not in paligemma_config:
+                paligemma_config["model_type"] = "paligemma"
+
+            cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
+            self.paligemma_config = cfg_cls(**paligemma_config)
+
+        if gemma_expert_config is None:
+            # Default config from Pi0
+            self.gemma_expert_config = CONFIG_MAPPING["gemma"](
+                attention_bias=False,
+                attention_dropout=0.0,
+                bos_token_id=2,
+                eos_token_id=1,
+                head_dim=256,
+                hidden_act="gelu_pytorch_tanh",
+                hidden_activation="gelu_pytorch_tanh",
+                hidden_size=1024,
+                initializer_range=0.02,
+                intermediate_size=4096,
+                max_position_embeddings=8192,
+                model_type="gemma",
+                num_attention_heads=8,
+                num_hidden_layers=18,
+                num_key_value_heads=1,
+                pad_token_id=0,
+                rms_norm_eps=1e-06,
+                rope_theta=10000.0,
+                torch_dtype="float32",
+                transformers_version="4.48.1",
+                use_cache=True,
+                vocab_size=257152,
+            )
+        elif isinstance(gemma_expert_config, dict):
+            # Override Pi0 default config for Gemma Expert
+            if "model_type" not in gemma_expert_config:
+                gemma_expert_config["model_type"] = "gemma"
+
+            cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
+            self.gemma_expert_config = cfg_cls(**gemma_expert_config)
+
+        super().__init__(**kwargs)
+
+    def __post_init__(self):
+        super().__post_init__()
+        if 
self.train_expert_only and not self.freeze_vision_encoder: + raise ValueError( + "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible." + ) + + if self.attention_implementation not in ["eager", "fa2", "flex"]: + raise ValueError( + f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'." + ) + + +class PaliGemmaWithExpertModel(PreTrainedModel): + config_class = PaliGemmaWithExpertConfig + + def __init__(self, config: PaliGemmaWithExpertConfig): + super().__init__(config=config) + self.config = config + self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config) + self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config) + # Remove unused embed_tokens + self.gemma_expert.model.embed_tokens = None + + self.to_bfloat16_like_physical_intelligence() + self.set_requires_grad() + + def set_requires_grad(self): + if self.config.freeze_vision_encoder: + self.paligemma.vision_tower.eval() + for params in self.paligemma.vision_tower.parameters(): + params.requires_grad = False + + if self.config.train_expert_only: + self.paligemma.eval() + for params in self.paligemma.parameters(): + params.requires_grad = False + + def train(self, mode: bool = True): + super().train(mode) + + if self.config.freeze_vision_encoder: + self.paligemma.vision_tower.eval() + + if self.config.train_expert_only: + self.paligemma.eval() + + def to_bfloat16_like_physical_intelligence(self): + self.paligemma = self.paligemma.to(dtype=torch.bfloat16) + + params_to_change_dtype = [ + "language_model.model.layers", + "gemma_expert.model.layers", + "vision_tower", + "multi_modal", + ] + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_change_dtype): + param.data = param.data.to(dtype=torch.bfloat16) + + def embed_image(self, image: torch.Tensor): + return self.paligemma.get_image_features(image) + + def embed_language_tokens(self, tokens: torch.Tensor): + return self.paligemma.language_model.model.embed_tokens(tokens) + + # TODO: break down this huge forward into modules or functions + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + inputs_embeds: List[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + fill_kv_cache: Optional[bool] = None, + ): + models = [self.paligemma.language_model.model, self.gemma_expert.model] + + for hidden_states in inputs_embeds: + # TODO this is very inefficient + # dtype is always the same, batch size too (if > 1 len) + # device could be trickier in multi gpu edge cases but that's it + if hidden_states is None: + continue + batch_size = hidden_states.shape[0] + + # RMSNorm + num_layers = self.paligemma.config.text_config.num_hidden_layers + head_dim = self.paligemma.config.text_config.head_dim + for layer_idx in range(num_layers): + query_states = [] + key_states = [] + value_states = [] + for i, hidden_states in enumerate(inputs_embeds): + if hidden_states is None: + continue + layer = models[i].layers[layer_idx] + # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype) + # hidden_states = hidden_states * normalizer + hidden_states = layer.input_layernorm(hidden_states) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + + hidden_states = 
hidden_states.to(dtype=torch.bfloat16) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) + + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + + # B,L,H,D with L sequence length, H number of heads, D head dim + # concatenate on the number of embeddings/tokens + query_states = torch.cat(query_states, dim=1) + key_states = torch.cat(key_states, dim=1) + value_states = torch.cat(value_states, dim=1) + + query_states = apply_rope(query_states, position_ids) + key_states = apply_rope(key_states, position_ids) + + if use_cache and past_key_values is None: + past_key_values = {} + + if use_cache: + if fill_kv_cache: + past_key_values[layer_idx] = { + "key_states": key_states, + "value_states": value_states, + } + else: + # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. + # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach + # the max len, then we (for instance) double the cache size. This implementation already exists + # in `transformers`. (molbap) + key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) + value_states = torch.cat( + [past_key_values[layer_idx]["value_states"], value_states], dim=1 + ) + + attention_interface = self.get_attention_interface() + att_output = attention_interface( + attention_mask, batch_size, head_dim, query_states, key_states, value_states + ) + att_output = att_output.to(dtype=torch.bfloat16) + + # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len]) + outputs_embeds = [] + start = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + + if hidden_states is not None: + end = start + hidden_states.shape[1] + + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start:end]) + + # TODO: first dropout (by default 0.0) + + # first residual + out_emb += hidden_states + after_first_residual = out_emb.clone() + + out_emb = layer.post_attention_layernorm(out_emb) + out_emb = layer.mlp(out_emb) + + # TODO: second dropout (by default 0.0) + + # second residual + out_emb += after_first_residual + + outputs_embeds.append(out_emb) + + start = end + else: + outputs_embeds.append(None) + + inputs_embeds = outputs_embeds + + # final norm + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + if hidden_states is not None: + out_emb = models[i].norm(hidden_states) + outputs_embeds.append(out_emb) + else: + outputs_embeds.append(None) + + return outputs_embeds, past_key_values + + def get_attention_interface(self): + if self.config.attention_implementation == "fa2": + attention_interface = self.flash_attention_forward + elif self.config.attention_implementation == "flex": + attention_interface = flex_attention_forward + else: + attention_interface = self.eager_attention_forward + return attention_interface + + def flash_attention_forward( + self, attention_mask, batch_size, head_dim, query_states, key_states, value_states + ): + raise NotImplementedError("FA2 is not implemented (yet)") + + def eager_attention_forward( + self, attention_mask, batch_size, head_dim, query_states, key_states, value_states + ): + num_att_heads = 
self.config.paligemma_config.text_config.num_attention_heads + num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads + num_key_value_groups = num_att_heads // num_key_value_heads + + # query_states: batch_size, sequence_length, num_att_head, head_dim + # key_states: batch_size, sequence_length, num_key_value_head, head_dim + # value_states: batch_size, sequence_length, num_key_value_head, head_dim + sequence_length = key_states.shape[1] + + key_states = key_states[:, :, :, None, :].expand( + batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + ) + key_states = key_states.reshape( + batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + ) + + value_states = value_states[:, :, :, None, :].expand( + batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + ) + value_states = value_states.reshape( + batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + ) + + # Attention here is upcasted to float32 to match the original eager implementation. + + query_states = query_states.to(dtype=torch.float32) + key_states = key_states.to(dtype=torch.float32) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + att_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + att_weights *= head_dim**-0.5 + big_neg = -2.3819763e38 # See gemma/modules.py + + masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg) + + probs = nn.functional.softmax(masked_att_weights, dim=-1) + probs = probs.to(dtype=value_states.dtype) + + # probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length + # value_states: batch_size, sequence_length, num_att_heads, head_dim + + att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3)) + + att_output = att_output.permute(0, 2, 1, 3) + # we use -1 because sequence length can change + att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim) + + return att_output diff --git a/lerobot/common/policies/pi0fast/configuration_pi0fast.py b/lerobot/common/policies/pi0fast/configuration_pi0fast.py new file mode 100644 index 0000000000000000000000000000000000000000..29c856e0645579ac75c067446054af298baec0bc --- /dev/null +++ b/lerobot/common/policies/pi0fast/configuration_pi0fast.py @@ -0,0 +1,136 @@ +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamWConfig +from lerobot.common.optim.schedulers import ( + CosineDecayWithWarmupSchedulerConfig, +) +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + + +@PreTrainedConfig.register_subclass("pi0fast") +@dataclass +class PI0FASTConfig(PreTrainedConfig): + # Input / output structure. + n_obs_steps: int = 1 + chunk_size: int = 10 + n_action_steps: int = 5 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MEAN_STD, + } + ) + + # Shorter state and action vectors will be padded + max_state_dim: int = 32 # 32 + max_action_dim: int = 32 # 32 + + # Image preprocessing + resize_imgs_with_padding: tuple[int, int] = (224, 224) + interpolate_like_pi: bool = False + + # Add empty images. Used by pi0_aloha_sim which adds the empty + # left and right wrist cameras in addition to the top camera. 
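+    # Missing camera slots (up to this count) are filled with dummy images of -1s at
+    # preprocessing time.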
+ empty_cameras: int = 0 + + # Converts the joint and gripper values from the standard Aloha space to + # the space used by the pi internal runtime which was used to train the base model. + adapt_to_pi_aloha: bool = False + + # Converts joint dimensions to deltas with respect to the current state before passing to the model. + # Gripper dimensions will remain in absolute values. + use_delta_joint_actions_aloha: bool = False + + # Tokenizer + tokenizer_max_length: int = 48 + + # Projector + proj_width: int = 1024 + + # Decoding + max_decoding_steps: int = 256 + fast_skip_tokens: int = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens + max_input_seq_len: int = 256 # 512 + + # Utils + use_cache: bool = True + + # Frozen parameters + freeze_vision_encoder: bool = True + freeze_lm_head: bool = True + + # Training presets + optimizer_lr: float = 1e-4 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-5 + + scheduler_warmup_steps: int = 1_000 + scheduler_decay_steps: int = 30_000 + scheduler_decay_lr: float = 2.5e-6 + + checkpoint_path: str = None + + padding_side: str = "right" + + precision: str = "bfloat16" + grad_clip_norm: float = 1 + + # Allows padding/truncation of generated action tokens during detokenization to ensure decoding. + # In the original version, tensors of 0s were generated if shapes didn't match for stable decoding. + relaxed_action_decoding: bool = True + + def __post_init__(self): + super().__post_init__() + + """Input validation (not exhaustive).""" + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + if self.n_obs_steps != 1: + raise ValueError( + f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" + ) + + def validate_features(self) -> None: + for i in range(self.empty_cameras): + key = f"observation.images.empty_camera_{i}" + empty_camera = PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 480, 640), + ) + self.input_features[key] = empty_camera + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.grad_clip_norm, + ) + + def get_scheduler_preset(self): + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> None: + return None + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py new file mode 100644 index 0000000000000000000000000000000000000000..36aafce94b09bf1fbbd63aa5a358a1cd47ba242b --- /dev/null +++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py @@ -0,0 +1,973 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models + +[Paper](https://arxiv.org/abs/2501.09747) +[Jax code](https://github.com/Physical-Intelligence/openpi) + +Designed by Physical Intelligence. Ported from Jax by Hugging Face. + +Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`): +```bash +python lerobot/scripts/train.py \ +--policy.path=lerobot/pi0fast_base \ +--dataset.repo_id=danaaubakirova/koch_test +``` + +Example of training the pi0+FAST neural network with from scratch: +```bash +python lerobot/scripts/train.py \ +--policy.type=pi0fast \ +--dataset.repo_id=danaaubakirova/koch_test +``` + +Example of using the pi0 pretrained model outside LeRobot training framework: +```python +policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base") +``` + +""" + +from collections import deque +from functools import partial + +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +from PIL import Image +from scipy.fft import idct +from torch import Tensor, nn +from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration +from transformers.cache_utils import HybridCache, StaticCache +from transformers.models.auto import CONFIG_MAPPING + +from lerobot.common.constants import ACTION, OBS_ROBOT +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig +from lerobot.common.policies.pretrained import PreTrainedPolicy + +PRECISION = { + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + + +def normalize(x, min_val, max_val): + return (x - min_val) / (max_val - min_val) + + +def unnormalize(x, min_val, max_val): + return x * (max_val - min_val) + min_val + + +def safe_arcsin(value): + # This ensures that the input stays within + # [−1,1] to avoid invalid values for arcsin + return torch.arcsin(torch.clamp(value, -1.0, 1.0)) + + +def aloha_gripper_to_angular(value): + # Aloha transforms the gripper positions into a linear space. The following code + # reverses this transformation to be consistent with pi0 which is pretrained in + # angular space. + # + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED + value = unnormalize(value, min_val=0.01844, max_val=0.05800) + + # This is the inverse of the angular to linear transformation inside the Interbotix code. + def linear_to_radian(linear_position, arm_length, horn_radius): + value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) + return safe_arcsin(value) + + # The constants are taken from the Interbotix code. + value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) + + # Normalize to [0, 1]. + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + return normalize(value, min_val=0.4, max_val=1.5) + + +def aloha_gripper_from_angular(value): + # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha. 
+ # Note that the units are still angular but the range is different. + + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + value = unnormalize(value, min_val=0.4, max_val=1.5) + + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE + return normalize(value, min_val=-0.6213, max_val=1.4910) + + +def aloha_gripper_from_angular_inv(value): + # Directly inverts the gripper_from_angular function. + value = unnormalize(value, min_val=-0.6213, max_val=1.4910) + return normalize(value, min_val=0.4, max_val=1.5) + + +class PI0FASTPolicy(PreTrainedPolicy): + """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot.""" + + config_class = PI0FASTConfig + name = "pi0fast" + + def __init__( + self, + config: PI0FASTConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + + super().__init__(config) + config.validate_features() + self.config = config + + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224") + self.model = PI0FAST(config) + + self.reset() + + def reset(self): + """This should be called whenever the environment is reset.""" + self._action_queue = deque([], maxlen=self.config.n_action_steps) + + def get_optim_params(self) -> dict: + return self.parameters() + + def _pi_aloha_decode_state(self, state): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + state[:, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) + return state + + def _pi_aloha_encode_actions(self, actions): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) + return actions + + def _pi_aloha_encode_actions_inv(self, actions): + # Flip the joints again. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) + return actions + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. 
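+
+        Note: the model generates a chunk of up to `chunk_size` actions, but only the first
+        `n_action_steps` of that chunk are queued and executed.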
+ """ + self.eval() + + if self.config.adapt_to_pi_aloha: + batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + + batch = self.normalize_inputs(batch) + + # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by + # querying the policy. + if len(self._action_queue) == 0: + actions = self.model.generate_actions(batch) + + actions = actions[:, : self.config.n_action_steps] + + original_action_dim = self.config.action_feature.shape[ + 0 + ] # self.config.max_action_dim # self.config.action_feature.shape[0] + actions = actions[:, :, :original_action_dim] + + actions = self.unnormalize_outputs({"action": actions})["action"] + + if self.config.adapt_to_pi_aloha: + actions = self._pi_aloha_encode_actions(actions) + + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. + self._action_queue.extend(actions.transpose(0, 1)) + return self._action_queue.popleft() + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + if self.config.adapt_to_pi_aloha: + batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + loss_dict = self.model.forward(batch) + return loss_dict["loss"], loss_dict + + +def block_causal_update_causal_mask( + attention_mask, + token_type_ids=None, + past_key_values=None, + cache_position=None, + input_tensor=None, + attn_implementation: str = "eager", + dtype: torch.dtype = "float32", +): + """ + Update the causal mask during training and generation. It can be customized to different attention masks. + """ + if attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + using_static_cache = isinstance(past_key_values, StaticCache) + min_dtype = torch.finfo(dtype).min + + if input_tensor is None: + input_tensor = attention_mask + + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + + if using_static_cache or isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length + 1 + ) + + # Handle precomputed attention masks + if attention_mask is not None and attention_mask.dim() == 4: + return attention_mask + + # Causal mask initialization + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + + # Standard causal masking (triu ensures tokens can only attend to past) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + + # Apply block causal mask + if token_type_ids is not None: + token_type_ids = token_type_ids.to(causal_mask.device).bool() + cumsum = torch.cumsum(token_type_ids, dim=1) + block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None] + + # Combine causal_mask with block-wise attention mask + causal_mask = torch.where(block_causal_mask, 0.0, causal_mask) + causal_mask = causal_mask[:, None, :, :] + else: + # Apply past cache position constraint + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 + ) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + else: + # Apply past cache 
position constraint + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 + ) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + + if attention_mask is not None: + causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits + mask_length = attention_mask.shape[-1] + + # Apply padding mask + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +def prepare_inputs_for_generation( + # self, + input_ids, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + pixel_values=None, + attention_mask=None, + token_type_ids=None, + use_cache=True, + num_logits_to_keep=None, + labels=None, + self=None, + **kwargs, +): + # create block causal attention + if cache_position[0] > 0 and input_ids.shape[1] > 0: + input_tensor = input_ids[:, -1:] + new_positions = ( + torch.ones( + (position_ids.shape[0], input_ids.shape[1]), + dtype=position_ids.dtype, + device=position_ids.device, + ).cumsum(-1) + + position_ids[:, -1:] + ) + position_ids = torch.cat([position_ids, new_positions], dim=-1) + else: + input_tensor = inputs_embeds + attention_mask = block_causal_update_causal_mask( + attention_mask=attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + input_tensor=input_tensor, + token_type_ids=token_type_ids, + dtype=self.dtype, + attn_implementation=self.config.text_config._attn_implementation, + ) + # Overwritten -- custom `position_ids` and `pixel_values` handling + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + use_cache=use_cache, + num_logits_to_keep=num_logits_to_keep, + token_type_ids=token_type_ids, + **kwargs, + ) + + # Position_ids in Paligemma are 1-indexed + if model_inputs.get("position_ids") is not None: + model_inputs["position_ids"] += 1 + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model. 
NOTE: use_cache=False needs pixel_values always + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + is_training = token_type_ids is not None and labels is not None + if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): + input_tensor = inputs_embeds if inputs_embeds is not None else input_ids + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training + ) + model_inputs["attention_mask"] = causal_mask + + return model_inputs + + +class PI0FAST(nn.Module): + def __init__(self, config: PI0FASTConfig): + super().__init__() + self.config = config + + # TODO: move tokenizers in Policy + fast_tokenizer_path = "physical-intelligence/fast" + pi0_paligemma_path = "google/paligemma-3b-pt-224" + self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path) + self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path) + self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True) + self.fast_skip_tokens = self.config.fast_skip_tokens + self.max_input_seq_len = self.config.max_input_seq_len + self.action_horizon = self.config.chunk_size + self.action_dim = self.config.action_feature.shape[ + 0 + ] # self.config.max_action_dim # self.config.action_feature.shape[0] + precision = config.precision + torch_precision = PRECISION.get(precision, torch.float32) + self.pad_token_id = ( + self.paligemma_tokenizer.pad_token_id + if hasattr(self.paligemma_tokenizer, "pad_token_id") + else self.paligemma_tokenizer.eos_token_id + ) + + paligemma_config = CONFIG_MAPPING["paligemma"]( + transformers_version="4.48.1", + _vocab_size=257152, + bos_token_id=2, + eos_token_id=1, + hidden_size=2048, + image_token_index=257152, + model_type="paligemma", + pad_token_id=0, + projection_dim=2048, + text_config={ + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 2048, + "intermediate_size": 16384, + "model_type": "gemma", + "num_attention_heads": 8, + "num_hidden_layers": 18, + "num_image_tokens": 256, + "num_key_value_heads": 1, + "torch_dtype": precision, + "vocab_size": 257152, + "_attn_implementation": "eager", + }, + vision_config={ + "hidden_size": 1152, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "num_image_tokens": 256, + "patch_size": 14, + "projection_dim": 2048, + "projector_hidden_act": "gelu_pytorch_tanh", + "torch_dtype": precision, + "vision_use_head": False, + }, + ) + self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config) + + self.pi0_paligemma.prepare_inputs_for_generation = partial( + prepare_inputs_for_generation, self=self.pi0_paligemma + ) + # change important stuff in bf16 + params_to_change_dtype = [ + "language_model", + "vision_tower", + "multi_modal", + ] + for name, param in self.pi0_paligemma.named_parameters(): + if any(selector in name for selector in params_to_change_dtype): + param.data = param.data.to(dtype=torch_precision) + self.set_requires_grad() + self.image_keys = self.config.image_features.keys() + self.ignore_index = self.pi0_paligemma.config.ignore_index + self.padding_side = self.config.padding_side + + def set_requires_grad(self): + if self.config.freeze_vision_encoder: + self.pi0_paligemma.vision_tower.eval() + for params in self.pi0_paligemma.vision_tower.parameters(): + params.requires_grad = False + # To avoid unused params issue with distributed training + if self.config.freeze_lm_head: + 
for name, params in self.pi0_paligemma.named_parameters(): + if "embed_tokens" in name: # lm heads and embedding layer are tied + params.requires_grad = False + + def embed_tokens(self, tokens: torch.Tensor): + return self.pi0_paligemma.language_model.model.embed_tokens(tokens) + + def prepare_inputs_for_generation(self, *args, **kwargs): + return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs) + + def prepare_images(self, batch): + """Preprocess LeRobot batch into Pi0 inputs""" + images = [] + img_masks = [] + present_img_keys = [key for key in self.image_keys if key in batch] + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" + ) + + # Preprocess image features present in the batch + num_empty_cameras = 0 + for key in self.image_keys: + if key in present_img_keys: + img = batch[key] + + if self.config.resize_imgs_with_padding is not None: + img = resize_with_pad( + img, + *self.config.resize_imgs_with_padding, + pad_value=0, + interpolate_like_pi=self.config.interpolate_like_pi, + ) + + # Normalize from range [0,1] to [-1,1] as expacted by siglip + img = img * 2.0 - 1.0 + + bsize = img.shape[0] + device = img.device + mask = torch.ones(bsize, dtype=torch.bool, device=device) + else: + if num_empty_cameras >= self.config.empty_cameras: + continue + img = torch.ones_like(img) * -1 + bsize = img.shape[0] + device = img.device + mask = torch.ones(bsize, dtype=torch.bool, device=device) + num_empty_cameras += 1 + + images.append(img) + img_masks.append(mask) + return images, img_masks + + def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor: + mins = actions.amin(dim=(1, 2), keepdim=True) # [0] + maxs = actions.amax(dim=(1, 2), keepdim=True) # [0] + return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1 + + def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor: + out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens + return out + + def fast_tokenizer_wrapper(self, actions_norm): + """ + A wrapper for self.fast_tokenizer that ensures batch processing, + conversion to PyTorch tensors, and returns a dictionary without padding. 
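+
+        Downstream code reads the `input_ids` and `attention_mask` entries of the returned
+        BatchEncoding (one row of action tokens per sample in the batch).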
+ """ + batch_tokens = self.fast_tokenizer(actions_norm) + fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt") + + return fast_out + + def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor: + token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool) + # Compute cumulative sum mask + cumsum_mask = (padded_mask != 0).cumsum(dim=1) + # Suffix block (everything after prefix_len) + suffix_mask = cumsum_mask > prefix_len + token_type_ids = suffix_mask + return token_type_ids + + def create_input_tokens(self, state, lang_text, actions=None): + bsize = state.shape[0] + device = state.device + bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1] + discretized = torch.bucketize(state, bins) - 1 + discretized = discretized[:, :32] + + prefix_texts = [] + state_text = [] + for txt, disc in zip(lang_text, discretized, strict=False): + cleaned = txt.lower().strip().replace("_", " ") + state_str = " ".join(str(val.item()) for val in disc) + prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n") + state_text.append(f"State: {state_str};\n") + + prefix_out = self.paligemma_tokenizer( + prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False + ) + prefix_ids = prefix_out["input_ids"].to(device) + prefix_mask = prefix_out["attention_mask"].to(device) + prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu() + + if actions is not None: + actions_norm = self.normalize_actions(actions) + actions_pad = F.pad( + actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0 + )[:, :, : self.config.max_action_dim] + fast_out = self.fast_tokenizer_wrapper( + actions_pad.cpu(), + ) + act_ids = fast_out["input_ids"] + act_mask = fast_out["attention_mask"].to(device) + + act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device) + # Replace action with 0 to pad tokens + act_ids = torch.where( + act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens, + self.pad_token_id, + act_ids, + ) + + eos_token = torch.tensor( + [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device + ).expand(bsize, -1) + eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1) + bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt") + bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device) + bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device) + act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1) + act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1) + act_mask = act_mask.to(device) + else: + act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device) + act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device) + final_ids = torch.cat([prefix_ids, act_ids], dim=1) + + final_mask = torch.cat([prefix_mask, act_mask], dim=1) + batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()} + + # Use tokenizer pad function + padded_output = self.paligemma_tokenizer.pad( + batch_inputs, padding="longest", max_length=180, return_tensors="pt" + ) + padded_mask = padded_output["attention_mask"] + + # define tensor of padding lengths + att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens + + token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens) + + padded_output["padded_mask"] = padded_output.pop("attention_mask") + padded_output["attention_mask"] = att_mask 
+ # loss is computed not on prefix, and not on padding + padded_output["loss_mask"] = att_mask & padded_output["padded_mask"] + padded_output["token_type_ids"] = token_type_ids + return padded_output + + def shift_padding_side( + self, + tokens: torch.Tensor, + ar_mask: torch.Tensor, + padding_mask: torch.Tensor, + loss_mask: torch.Tensor, + targets: torch.Tensor, + token_type_ids: torch.Tensor, + padding_side: str = "right", + ) -> tuple[torch.Tensor]: + if padding_side not in ["right", "left"]: + return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids + + new_tokens = torch.empty_like(tokens) + new_ar_masks = torch.empty_like(ar_mask) + new_padding_mask = torch.empty_like(padding_mask) + new_loss_mask = torch.empty_like(loss_mask) + new_targets = torch.empty_like(targets) + new_token_type_ids = torch.empty_like(token_type_ids) + batch_size = tokens.shape[0] + for i in range(batch_size): + padding_indices = torch.where(padding_mask[i] == 0)[0] + non_padding_indices = torch.where(padding_mask[i] == 1)[0] + if padding_side == "left": + new_indices = torch.cat((padding_indices, non_padding_indices), dim=0) + else: + new_indices = torch.cat((non_padding_indices, padding_indices), dim=0) + new_tokens[i] = tokens[i].index_select(0, new_indices) + new_ar_masks[i] = ar_mask[i].index_select(0, new_indices) + new_padding_mask[i] = padding_mask[i].index_select(0, new_indices) + new_loss_mask[i] = loss_mask[i].index_select(0, new_indices) + new_targets[i] = targets[i].index_select(0, new_indices) + new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices) + + return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids + + def forward(self, batch: dict[str, Tensor]): + device = batch[OBS_ROBOT].device + # TODO: keep like this or move to the policy .forward + images, img_masks = self.prepare_images(batch) + + padded_outs = self.create_input_tokens( + state=batch[OBS_ROBOT], + lang_text=batch["task"], + actions=batch[ACTION], + ) + + embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs( + images, + img_masks, + padded_outs["input_ids"], + padded_outs["padded_mask"], + padded_outs["attention_mask"], + padded_outs["loss_mask"], + padded_outs["token_type_ids"], + padding_side=self.padding_side, + ) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + token_type_ids = token_type_ids.to(dtype=torch.int64) + past_seen_tokens = 0 + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device) + pad_masks = block_causal_update_causal_mask( + attention_mask=pad_masks, + past_key_values=None, + cache_position=cache_position, + input_tensor=embs, + token_type_ids=token_type_ids, + dtype=self.pi0_paligemma.dtype, + attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation, + ) + outputs = self.pi0_paligemma.forward( + input_ids=None, + token_type_ids=None, + attention_mask=pad_masks, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=embs, + use_cache=False, + labels=None, + ) + + logits = outputs.logits + + loss_fct = nn.CrossEntropyLoss(reduction="none") + + # Shift left for next-step prediction + logits = logits[:, :-1, :] + targets = targets[:, 1:].to(device) # Shift targets + loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape + + # Compute per-token loss + token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1)) + + # Apply loss mask + token_loss = token_loss * loss_mask.reshape(-1) + + # Compute final loss 
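+        # Masked mean cross-entropy over the suffix tokens; clamp(min=1) avoids a division by zero in the
+        # degenerate case where the loss mask is empty.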
+ loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1) + + # Return loss dictionary + loss_dict = {"ce_loss": loss.item(), "loss": loss} + return loss_dict + + def decode_actions_with_fast( + self, + tokens: list[list[int]], + *, + time_horizon: int | None = None, + action_dim: int | None = None, + relaxed_decoding: bool = True, + ) -> np.array: + """ + Adapt original decoding in FAST to always return actions instead of zeros. + """ + self.time_horizon = ( + time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon + ) + self.action_dim = ( + action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim + ) + + # Cache the time horizon and action dimension for the next call + self.called_time_horizon = self.time_horizon + self.called_action_dim = self.action_dim + + assert self.time_horizon is not None and self.action_dim is not None, ( + "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim." + ) + + decoded_actions = [] + for token in tokens: + try: + decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token) + decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token + if relaxed_decoding: + # Expected sequence length + expected_seq_len = self.time_horizon * self.action_dim + diff = expected_seq_len - decoded_dct_coeff.shape[0] + # Apply truncation if too long + if diff < 0: + decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right + # Apply padding if too short + elif diff > 0: + decoded_dct_coeff = np.pad( + decoded_dct_coeff, (0, diff), mode="constant", constant_values=0 + ) + + decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim) + assert decoded_dct_coeff.shape == ( + self.time_horizon, + self.action_dim, + ), ( + f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})" + ) + except Exception as e: + print(f"Error decoding tokens: {e}") + print(f"Tokens: {token}") + decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim)) + decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho")) + return np.stack(decoded_actions) + + def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor: + """ + Extracts actions from predicted output tokens using the FAST model. + + Args: + tokens (torch.Tensor): The input tensor of tokenized outputs. + action_horizon (int): The number of timesteps for actions. + action_dim (int): The dimensionality of each action. + + Returns: + torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim). 
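+                For batched input tokens the leading dimension is the batch size, i.e. the result has
+                shape (batch, action_horizon, action_dim).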
+ """ + # Decode predicted output tokens + decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True) + cleaned_tokens = [ + tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip() + for tokens_sequence in decoded_tokens + ] + raw_action_tokens = [ + self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False) + for sample_tokens in cleaned_tokens + ] # something like this should be robust #looks good + action_tokens = [ + self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens + ] + # returns the tensor of decoded actions per sample in a list + decoded_actions = [ + torch.tensor( + self.decode_actions_with_fast( + tok.tolist(), + time_horizon=action_horizon, + action_dim=action_dim, + relaxed_decoding=self.config.relaxed_action_decoding, + ), + device=tokens.device, + ).squeeze(0) + for tok in action_tokens + ] + + return torch.stack( + decoded_actions, + dim=0, + ) + + def generate_actions(self, batch: dict[str, Tensor]): + # TODO: keep like this or move to the policy .forward + images, img_masks = self.prepare_images(batch) + + padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None) + embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs( + images, + img_masks, + padded_outs["input_ids"], + padded_outs["padded_mask"], + padded_outs["attention_mask"], + padded_outs["loss_mask"], + padded_outs["token_type_ids"], + padding_side="left", + ) + token_type_ids = token_type_ids.to(dtype=torch.int64) + prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1 + output_tokens = self.pi0_paligemma.generate( + input_ids=None, + attention_mask=pad_masks, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=embs, + use_cache=self.config.use_cache, + max_new_tokens=self.config.max_decoding_steps, + do_sample=False, + num_beams=1, + token_type_ids=token_type_ids, + ) + actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim) + return actions + + def embed_image(self, image: torch.Tensor): + return self.pi0_paligemma.get_image_features(image) + + def embed_inputs( + self, + images, + img_masks, + tokens, + pad_mask, + ar_mask, + loss_mask, + token_type_ids, + padding_side: str = "right", + ): + # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty + # images are a list of same size + # vectorizing everything! 
+ device = images[0].device + image_embedding_dim = images[0].shape[-1] # TODO should be from self.config + all_images = torch.stack(images, dim=1).to(device) + b, n, c, h, w = all_images.shape + all_images = all_images.view(b * n, c, h, w) + embedded = self.embed_image(all_images).to(device) + b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions + m = b_n // b # Compute the number of images per sample dynamically + + # Reshape dynamically + embedded = embedded.view(b, m, p, image_embedding_dim) + tokens_embs = self.embed_tokens(tokens.to(device)) + + img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device) + num_img_emb = embedded.shape[2] + img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1) + img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1) + + image_target_tokens = ( + torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id + ).reshape(b, -1) + image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1) + + embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D) + + embs = torch.cat([embedded, tokens_embs], dim=1).to(device) + pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1) + att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1) + loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1) + targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1) + token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1) + + # Shift pad tokens to the left (.generate()) or right (.train()) + embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side( + embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side + ) + + targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets) + return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids + + +def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True): + # assume no-op when width height fits already + if img.ndim != 4: + raise ValueError(f"(b,c,h,w) expected, but {img.shape}") + + cur_height, cur_width = img.shape[2:] + + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + + if interpolate_like_pi: + img = (img * 255.0).to(dtype=torch.uint8) + img = img.permute(0, 2, 3, 1) + original_device = img.device + img = img.to(device="cpu").numpy() + imgs = [] + for sub_img in img: + sub_img = Image.fromarray(sub_img) + resized_img = sub_img.resize((resized_width, resized_height), resample=2) + resized_img = torch.from_numpy(np.array(resized_img)) + imgs.append(resized_img) + img = torch.stack(imgs, dim=0) + img = img.permute(0, 3, 1, 2) + resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0 + else: + resized_img = F.interpolate( + img, size=(resized_height, resized_width), mode="bilinear", align_corners=False + ) + + pad_height = max(0, int(height - resized_height)) + pad_width = max(0, int(width - resized_width)) + + # pad on left and top of image + padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) + return padded_img diff --git a/lerobot/common/policies/pretrained.py b/lerobot/common/policies/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..da4ef1572195d83605df7d9e347450ad92c0ed32 --- /dev/null +++ 
b/lerobot/common/policies/pretrained.py @@ -0,0 +1,199 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import logging +import os +from pathlib import Path +from typing import Type, TypeVar + +import packaging +import safetensors +from huggingface_hub import hf_hub_download +from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE +from huggingface_hub.errors import HfHubHTTPError +from safetensors.torch import load_model as load_model_as_safetensor +from safetensors.torch import save_model as save_model_as_safetensor +from torch import Tensor, nn + +from lerobot.common.utils.hub import HubMixin +from lerobot.configs.policies import PreTrainedConfig + +T = TypeVar("T", bound="PreTrainedPolicy") + +DEFAULT_POLICY_CARD = """ +--- +# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 +# Doc / guide: https://huggingface.co/docs/hub/model-cards +{{ card_data }} +--- + +This policy has been pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot): +- Docs: {{ docs_url | default("[More Information Needed]", true) }} +""" + + +class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): + """ + Base class for policy models. + """ + + config_class: None + name: None + + def __init__(self, config: PreTrainedConfig, *inputs, **kwargs): + super().__init__() + if not isinstance(config, PreTrainedConfig): + raise ValueError( + f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class " + "`PreTrainedConfig`. To create a model from a pretrained model use " + f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.config = config + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if not getattr(cls, "config_class", None): + raise TypeError(f"Class {cls.__name__} must define 'config_class'") + if not getattr(cls, "name", None): + raise TypeError(f"Class {cls.__name__} must define 'name'") + + def _save_pretrained(self, save_directory: Path) -> None: + self.config._save_pretrained(save_directory) + model_to_save = self.module if hasattr(self, "module") else self + save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)) + + @classmethod + def from_pretrained( + cls: Type[T], + pretrained_name_or_path: str | Path, + *, + config: PreTrainedConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = False, + **kwargs, + ) -> T: + """ + The policy is set in evaluation mode by default using `policy.eval()` (dropout modules are + deactivated). To train it, you should first set it back in training mode with `policy.train()`. 
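+
+        Example (illustrative only; any concrete subclass works and the repo id is just a placeholder):
+        ```python
+        from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
+
+        policy = DiffusionPolicy.from_pretrained("lerobot/diffusion_pusht")
+        policy.train()  # switch back to training mode before fine-tuning
+        ```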
+ """ + if config is None: + config = PreTrainedConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) + model_id = str(pretrained_name_or_path) + instance = cls(config, **kwargs) + if os.path.isdir(model_id): + print("Loading weights from local directory") + model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE) + policy = cls._load_as_safetensor(instance, model_file, config.device, strict) + else: + try: + model_file = hf_hub_download( + repo_id=model_id, + filename=SAFETENSORS_SINGLE_FILE, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + policy = cls._load_as_safetensor(instance, model_file, config.device, strict) + except HfHubHTTPError as e: + raise FileNotFoundError( + f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}" + ) from e + + policy.to(config.device) + policy.eval() + return policy + + @classmethod + def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: + if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"): + load_model_as_safetensor(model, model_file, strict=strict) + if map_location != "cpu": + logging.warning( + "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors." + " This means that the model is loaded on 'cpu' first and then copied to the device." + " This leads to a slower loading time." + " Please update safetensors to version 0.4.3 or above for improved performance." + ) + model.to(map_location) + else: + safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) + return model + + # def generate_model_card(self, *args, **kwargs) -> ModelCard: + # card = ModelCard.from_template( + # card_data=self._hub_mixin_info.model_card_data, + # template_str=self._hub_mixin_info.model_card_template, + # repo_url=self._hub_mixin_info.repo_url, + # docs_url=self._hub_mixin_info.docs_url, + # **kwargs, + # ) + # return card + + @abc.abstractmethod + def get_optim_params(self) -> dict: + """ + Returns the policy-specific parameters dict to be passed on to the optimizer. + """ + raise NotImplementedError + + @abc.abstractmethod + def reset(self): + """To be called whenever the environment is reset. + + Does things like clearing caches. + """ + raise NotImplementedError + + # TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'? + @abc.abstractmethod + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]: + """_summary_ + + Args: + batch (dict[str, Tensor]): _description_ + + Returns: + tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which + is a Tensor, all other items should be logging-friendly, native Python types. + """ + raise NotImplementedError + + @abc.abstractmethod + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Return one action to run in the environment (potentially in batch mode). + + When the model uses a history of observations, or outputs a sequence of actions, this method deals + with caching. 
+ """ + raise NotImplementedError diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py new file mode 100644 index 0000000000000000000000000000000000000000..3fce01df9db3a55a7730e8b5e54069a1b2882716 --- /dev/null +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python + +# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + + +@PreTrainedConfig.register_subclass("tdmpc") +@dataclass +class TDMPCConfig(PreTrainedConfig): + """Configuration class for TDMPCPolicy. + + Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single + camera observations. + + The parameters you will most likely need to change are the ones which depend on the environment / sensors. + Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`. + + Args: + n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google + action repeats in Q-learning or ask your favorite chatbot) + horizon: Horizon for model predictive control. + n_action_steps: Number of action steps to take from the plan given by model predictive control. This + is an alternative to using action repeats. If this is set to more than 1, then we require + `n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this + approach of using multiple steps from the plan is not in the original implementation. + input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents + the input data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], + indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't + include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents + the output data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. + Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. + input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), + and the value specifies the normalization mode to apply. The two available modes are "mean_std" + which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a + [-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. 
This is to + match the original implementation. + output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the + original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping + to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max" + normalization mode here. + image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding. + state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding. + latent_dim: Observation's latent embedding dimension. + q_ensemble_size: Number of Q function estimators to use in an ensemble for uncertainty estimation. + mlp_dim: Hidden dimension of MLPs used for modelling the dynamics encoder, reward function, policy + (π), Q ensemble, and V. + discount: Discount factor (γ) to use for the reinforcement learning formalism. + use_mpc: Whether to use model predictive control. The alternative is to just sample the policy model + (π) for each step. + cem_iterations: Number of iterations for the MPPI/CEM loop in MPC. + max_std: Maximum standard deviation for actions sampled from the gaussian PDF in CEM. + min_std: Minimum standard deviation for noise applied to actions sampled from the policy model (π). + Doubles up as the minimum standard deviation for actions sampled from the gaussian PDF in CEM. + n_gaussian_samples: Number of samples to draw from the gaussian distribution every CEM iteration. Must + be non-zero. + n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can + be zero. + uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating + trajectory values (this is the λ coefficient in eqn 4 of FOWM). + n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration. + elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the + elites, when updating the gaussian parameters for CEM. + gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian + parameters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ. + max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the + image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation + is applied. Note that the input images are assumed to be square for this augmentation. + reward_coeff: Loss weighting coefficient for the reward regression loss. + expectile_weight: Weighting (τ) used in expectile regression for the state value function (V). + v_pred < v_target is weighted by τ and v_pred >= v_target is weighted by (1-τ). τ is expected to + be in [0, 1]. Setting τ closer to 1 results in a more "optimistic" V. This is sensible to do + because v_target is obtained by evaluating the learned state-action value functions (Q) with + in-sample actions that may not be always optimal. + value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state + value (V) expectile regression loss. + consistency_coeff: Loss weighting coefficient for the consistency loss. + advantage_scaling: A factor by which the advantages are scaled prior to exponentiation for advantage + weighted regression of the policy (π) estimator parameters. Note that the exponentiated advantages + are clamped at 100.0. 
+ pi_coeff: Loss weighting coefficient for the action regression loss. + temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time- + steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the + current time step. + target_model_momentum: Momentum (α) used for EMA updates of the target models. Updates are calculated + as ϕ ← αϕ + (1-α)θ where ϕ are the parameters of the target model and θ are the parameters of the + model being trained. + """ + + # Input / output structure. + n_obs_steps: int = 1 + n_action_repeats: int = 2 + horizon: int = 5 + n_action_steps: int = 1 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ENV": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.MIN_MAX, + } + ) + + # Architecture / modeling. + # Neural networks. + image_encoder_hidden_dim: int = 32 + state_encoder_hidden_dim: int = 256 + latent_dim: int = 50 + q_ensemble_size: int = 5 + mlp_dim: int = 512 + # Reinforcement learning. + discount: float = 0.9 + + # Inference. + use_mpc: bool = True + cem_iterations: int = 6 + max_std: float = 2.0 + min_std: float = 0.05 + n_gaussian_samples: int = 512 + n_pi_samples: int = 51 + uncertainty_regularizer_coeff: float = 1.0 + n_elites: int = 50 + elite_weighting_temperature: float = 0.5 + gaussian_mean_momentum: float = 0.1 + + # Training and loss computation. + max_random_shift_ratio: float = 0.0476 + # Loss coefficients. + reward_coeff: float = 0.5 + expectile_weight: float = 0.9 + value_coeff: float = 0.1 + consistency_coeff: float = 20.0 + advantage_scaling: float = 3.0 + pi_coeff: float = 0.5 + temporal_decay_coeff: float = 0.5 + # Target model. + target_model_momentum: float = 0.995 + + # Training presets + optimizer_lr: float = 3e-4 + + def __post_init__(self): + super().__post_init__() + + """Input validation (not exhaustive).""" + if self.n_gaussian_samples <= 0: + raise ValueError( + f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`" + ) + if self.normalization_mapping["ACTION"] is not NormalizationMode.MIN_MAX: + raise ValueError( + "TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly " + f"advised that you stick with the default. See {self.__class__.__name__} docstring for more " + "information." + ) + if self.n_obs_steps != 1: + raise ValueError( + f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" + ) + if self.n_action_steps > 1: + if self.n_action_repeats != 1: + raise ValueError( + "If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1." + ) + if not self.use_mpc: + raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.") + if self.n_action_steps > self.horizon: + raise ValueError("`n_action_steps` must be less than or equal to `horizon`.") + + def get_optimizer_preset(self) -> AdamConfig: + return AdamConfig(lr=self.optimizer_lr) + + def get_scheduler_preset(self) -> None: + return None + + def validate_features(self) -> None: + # There should only be one image key. + if len(self.image_features) > 1: + raise ValueError( + f"{self.__class__.__name__} handles at most one image for now. Got image keys {self.image_features}." 
+ ) + + if len(self.image_features) > 0: + image_ft = next(iter(self.image_features.values())) + if image_ft.shape[-2] != image_ft.shape[-1]: + # TODO(alexander-soare): This limitation is solely because of code in the random shift + # augmentation. It should be able to be removed. + raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.") + + @property + def observation_delta_indices(self) -> list: + return list(range(self.horizon + 1)) + + @property + def action_delta_indices(self) -> list: + return list(range(self.horizon)) + + @property + def reward_delta_indices(self) -> None: + return list(range(self.horizon)) diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py new file mode 100644 index 0000000000000000000000000000000000000000..b46ae9030bac56111c95df9ada2266dc1111fe0d --- /dev/null +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -0,0 +1,828 @@ +#!/usr/bin/env python + +# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of Finetuning Offline World Models in the Real World. + +The comments in this code may sometimes refer to these references: + TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955) + FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029) +""" + +# ruff: noqa: N806 + +from collections import deque +from copy import deepcopy +from functools import partial +from typing import Callable + +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from torch import Tensor + +from lerobot.common.constants import OBS_ENV, OBS_ROBOT +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig +from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues + + +class TDMPCPolicy(PreTrainedPolicy): + """Implementation of TD-MPC learning + inference. + + Please note several warnings for this policy. + - Evaluation of pretrained weights created with the original FOWM code + (https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a + model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across + to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter- + process communication to use the xarm environment from FOWM. This is because our xarm + environment uses newer dependencies and does not match the environment in FOWM. See + https://github.com/huggingface/lerobot/pull/103 for implementation details. + - We have NOT checked that training on LeRobot reproduces the results from FOWM. 
+ - Nevertheless, we have verified that we can train TD-MPC for PushT. See + `lerobot/configs/policy/tdmpc_pusht_keypoints.yaml`. + - Our current xarm datasets were generated using the environment from FOWM. Therefore they do not + match our xarm environment. + """ + + config_class = TDMPCConfig + name = "tdmpc" + + def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + super().__init__(config) + config.validate_features() + self.config = config + + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.model = TDMPCTOLD(config) + self.model_target = deepcopy(self.model) + for param in self.model_target.parameters(): + param.requires_grad = False + + self.reset() + + def get_optim_params(self) -> dict: + return self.parameters() + + def reset(self): + """ + Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be + called on `env.reset()` + """ + self._queues = { + "observation.state": deque(maxlen=1), + "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), + } + if self.config.image_features: + self._queues["observation.image"] = deque(maxlen=1) + if self.config.env_state_feature: + self._queues["observation.environment_state"] = deque(maxlen=1) + # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start + # CEM for the next step. + self._prev_mean: torch.Tensor | None = None + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations.""" + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.image"] = batch[next(iter(self.config.image_features))] + + self._queues = populate_queues(self._queues, batch) + + # When the action queue is depleted, populate it again by querying the policy. + if len(self._queues["action"]) == 0: + batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues} + + # Remove the time dimensions as it is not handled yet. + for key in batch: + assert batch[key].shape[1] == 1 + batch[key] = batch[key][:, 0] + + # NOTE: Order of observations matters here. + encode_keys = [] + if self.config.image_features: + encode_keys.append("observation.image") + if self.config.env_state_feature: + encode_keys.append("observation.environment_state") + encode_keys.append("observation.state") + z = self.model.encode({k: batch[k] for k in encode_keys}) + if self.config.use_mpc: # noqa: SIM108 + actions = self.plan(z) # (horizon, batch, action_dim) + else: + # Plan with the policy (π) alone. This always returns one action so unsqueeze to get a + # sequence dimension like in the MPC branch. 
+ actions = self.model.pi(z).unsqueeze(0) + + actions = torch.clamp(actions, -1, +1) + + actions = self.unnormalize_outputs({"action": actions})["action"] + + if self.config.n_action_repeats > 1: + for _ in range(self.config.n_action_repeats): + self._queues["action"].append(actions[0]) + else: + # Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action. + self._queues["action"].extend(actions[: self.config.n_action_steps]) + + action = self._queues["action"].popleft() + return action + + @torch.no_grad() + def plan(self, z: Tensor) -> Tensor: + """Plan sequence of actions using TD-MPC inference. + + Args: + z: (batch, latent_dim,) tensor for the initial state. + Returns: + (horizon, batch, action_dim,) tensor for the planned trajectory of actions. + """ + device = get_device_from_parameters(self) + + batch_size = z.shape[0] + + # Sample Nπ trajectories from the policy. + pi_actions = torch.empty( + self.config.horizon, + self.config.n_pi_samples, + batch_size, + self.config.action_feature.shape[0], + device=device, + ) + if self.config.n_pi_samples > 0: + _z = einops.repeat(z, "b d -> n b d", n=self.config.n_pi_samples) + for t in range(self.config.horizon): + # Note: Adding a small amount of noise here doesn't hurt during inference and may even be + # helpful for CEM. + pi_actions[t] = self.model.pi(_z, self.config.min_std) + _z = self.model.latent_dynamics(_z, pi_actions[t]) + + # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled + # trajectories. + z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples) + + # Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization + # algorithm. + # The initial mean and standard deviation for the cross-entropy method (CEM). + mean = torch.zeros( + self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device + ) + # Maybe warm start CEM with the mean from the previous step. + if self._prev_mean is not None: + mean[:-1] = self._prev_mean[1:] + std = self.config.max_std * torch.ones_like(mean) + + for _ in range(self.config.cem_iterations): + # Randomly sample action trajectories for the gaussian distribution. + std_normal_noise = torch.randn( + self.config.horizon, + self.config.n_gaussian_samples, + batch_size, + self.config.action_feature.shape[0], + device=std.device, + ) + gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1) + + # Compute elite actions. + actions = torch.cat([gaussian_actions, pi_actions], dim=1) + value = self.estimate_value(z, actions).nan_to_num_(0) + elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch) + elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch) + # (horizon, n_elites, batch, action_dim) + elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1) + + # Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites. + max_value = elite_value.max(0, keepdim=True)[0] # (1, batch) + # The weighting is a softmax over trajectory values. Note that this is not the same as the usage + # of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This + # makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²). 
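+            # Subtracting the per-batch max before exponentiating keeps the softmax numerically stable.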
+ score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value)) + score /= score.sum(axis=0, keepdim=True) + # (horizon, batch, action_dim) + _mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) + _std = torch.sqrt( + torch.sum( + einops.rearrange(score, "n b -> n b 1") + * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2, + dim=1, + ) + ) + # Update mean with an exponential moving average, and std with a direct replacement. + mean = ( + self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean + ) + std = _std.clamp_(self.config.min_std, self.config.max_std) + + # Keep track of the mean for warm-starting subsequent steps. + self._prev_mean = mean + + # Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax + # scores from the last iteration. + actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)] + + return actions + + @torch.no_grad() + def estimate_value(self, z: Tensor, actions: Tensor): + """Estimates the value of a trajectory as per eqn 4 of the FOWM paper. + + Args: + z: (batch, latent_dim) tensor of initial latent states. + actions: (horizon, batch, action_dim) tensor of action trajectories. + Returns: + (batch,) tensor of values. + """ + # Initialize return and running discount factor. + G, running_discount = 0, 1 + # Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics + # model. Keep track of return. + for t in range(actions.shape[0]): + # We will compute the reward in a moment. First compute the uncertainty regularizer from eqn 4 + # of the FOWM paper. + if self.config.uncertainty_regularizer_coeff > 0: + regularization = -( + self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0) + ) + else: + regularization = 0 + # Estimate the next state (latent) and reward. + z, reward = self.model.latent_dynamics_and_reward(z, actions[t]) + # Update the return and running discount. + G += running_discount * (reward + regularization) + running_discount *= self.config.discount + # Add the estimated value of the final state (using the minimum for a conservative estimate). + # Do so by predicting the next action, then taking a minimum over the ensemble of state-action value + # estimators. + # Note: This small amount of added noise seems to help a bit at inference time as observed by success + # metrics over 50 episodes of xarm_lift_medium_replay. + next_action = self.model.pi(z, self.config.min_std) # (batch, action_dim) + terminal_values = self.model.Qs(z, next_action) # (ensemble, batch) + # Randomly choose 2 of the Qs for terminal value estimation (as in App C. of the FOWM paper). + if self.config.q_ensemble_size > 2: + G += ( + running_discount + * torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[ + 0 + ] + ) + else: + G += running_discount * torch.min(terminal_values, dim=0)[0] + # Finally, also regularize the terminal value. + if self.config.uncertainty_regularizer_coeff > 0: + G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0) + return G + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss. + + Returns a dictionary with loss as a tensor, and other information as native floats. 
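+
+        The batch is expected to provide time-stacked observations, actions and rewards together with their
+        padding masks, e.g. "observation.state", "action", "next.reward" and the corresponding "*_is_pad"
+        keys.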
+ """ + device = get_device_from_parameters(self) + + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.image"] = batch[next(iter(self.config.image_features))] + batch = self.normalize_targets(batch) + + info = {} + + # (b, t) -> (t, b) + for key in batch: + if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1: + batch[key] = batch[key].transpose(1, 0) + + action = batch["action"] # (t, b, action_dim) + reward = batch["next.reward"] # (t, b) + observations = {k: v for k, v in batch.items() if k.startswith("observation.")} + + # Apply random image augmentations. + if self.config.image_features and self.config.max_random_shift_ratio > 0: + observations["observation.image"] = flatten_forward_unflatten( + partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio), + observations["observation.image"], + ) + + # Get the current observation for predicting trajectories, and all future observations for use in + # the latent consistency loss and TD loss. + current_observation, next_observations = {}, {} + for k in observations: + current_observation[k] = observations[k][0] + next_observations[k] = observations[k][1:] + horizon, batch_size = next_observations[ + "observation.image" if self.config.image_features else "observation.environment_state" + ].shape[:2] + + # Run latent rollout using the latent dynamics model and policy model. + # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action + # gives us a next `z`. + batch_size = batch["index"].shape[0] + z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device) + z_preds[0] = self.model.encode(current_observation) + reward_preds = torch.empty_like(reward, device=device) + for t in range(horizon): + z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t]) + + # Compute Q and V value predictions based on the latent rollout. + q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch) + v_preds = self.model.V(z_preds[:-1]) + info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()}) + + # Compute various targets with stopgrad. + with torch.no_grad(): + # Latent state consistency targets. + z_targets = self.model_target.encode(next_observations) + # State-action value targets (or TD targets) as in eqn 3 of the FOWM. Unlike TD-MPC which uses the + # learned state-action value function in conjunction with the learned policy: Q(z, π(z)), FOWM + # uses a learned state value function: V(z). This means the TD targets only depend on in-sample + # actions (not actions estimated by π). + # Note: Here we do not use self.model_target, but self.model. This is to follow the original code + # and the FOWM paper. + q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations)) + # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we + # are using them to compute loss for V. + v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True) + + # Compute losses. + # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the + # future have less impact on the loss. Note: unsqueeze will let us broadcast to (seq, batch). 
+ temporal_loss_coeffs = torch.pow( + self.config.temporal_decay_coeff, torch.arange(horizon, device=device) + ).unsqueeze(-1) + # Compute consistency loss as MSE loss between latents predicted from the rollout and latents + # predicted from the (target model's) observation encoder. + consistency_loss = ( + ( + temporal_loss_coeffs + * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1) + # `z_preds` depends on the current observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + # `z_targets` depends on the next observation. + * ~batch["observation.state_is_pad"][1:] + ) + .sum(0) + .mean() + ) + # Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset + # rewards. + reward_loss = ( + ( + temporal_loss_coeffs + * F.mse_loss(reward_preds, reward, reduction="none") + * ~batch["next.reward_is_pad"] + # `reward_preds` depends on the current observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ) + .sum(0) + .mean() + ) + # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. + q_value_loss = ( + ( + temporal_loss_coeffs + * F.mse_loss( + q_preds_ensemble, + einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]), + reduction="none", + ).sum(0) # sum over ensemble + # `q_preds_ensemble` depends on the first observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + # q_targets depends on the reward and the next observations. + * ~batch["next.reward_is_pad"] + * ~batch["observation.state_is_pad"][1:] + ) + .sum(0) + .mean() + ) + # Compute state value loss as in eqn 3 of FOWM. + diff = v_targets - v_preds + # Expectile loss penalizes: + # - `v_preds < v_targets` with weighting `expectile_weight` + # - `v_preds >= v_targets` with weighting `1 - expectile_weight` + raw_v_value_loss = torch.where( + diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight) + ) * (diff**2) + v_value_loss = ( + ( + temporal_loss_coeffs + * raw_v_value_loss + # `v_targets` depends on the first observation and the actions, as does `v_preds`. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ) + .sum(0) + .mean() + ) + + # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1. + # We won't need these gradients again so detach. + z_preds = z_preds.detach() + # Use stopgrad for the advantage calculation. + with torch.no_grad(): + advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V( + z_preds[:-1] + ) + info["advantage"] = advantage[0] + # (t, b) + exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0) + action_preds = self.model.pi(z_preds[:-1]) # (t, b, a) + # Calculate the MSE between the actions and the action predictions. + # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation + # gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to + # multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action + # dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop + # the 0.5 as we instead make a configuration parameter for it (see below where we compute the total + # loss). 
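+        # Concretely, -log N(a | a_pred, I) = 0.5 * ||a - a_pred||^2 + const, so the summed MSE below is
+        # the same objective up to the dropped 0.5 factor and constant offset mentioned above.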
+ mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b) + # NOTE: The original implementation does not take the sum over the temporal dimension like with the + # other losses. + # TODO(alexander-soare): Take the sum over the temporal dimension and check that training still works + # as well as expected. + pi_loss = ( + exp_advantage + * mse + * temporal_loss_coeffs + # `action_preds` depends on the first observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ).mean() + + loss = ( + self.config.consistency_coeff * consistency_loss + + self.config.reward_coeff * reward_loss + + self.config.value_coeff * q_value_loss + + self.config.value_coeff * v_value_loss + + self.config.pi_coeff * pi_loss + ) + + info.update( + { + "consistency_loss": consistency_loss.item(), + "reward_loss": reward_loss.item(), + "Q_value_loss": q_value_loss.item(), + "V_value_loss": v_value_loss.item(), + "pi_loss": pi_loss.item(), + "sum_loss": loss.item() * self.config.horizon, + } + ) + + # Undo (b, t) -> (t, b). + for key in batch: + if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1: + batch[key] = batch[key].transpose(1, 0) + + return loss, info + + def update(self): + """Update the target model's parameters with an EMA step.""" + # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA + # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code + # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995) + update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum) + + +class TDMPCTOLD(nn.Module): + """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" + + def __init__(self, config: TDMPCConfig): + super().__init__() + self.config = config + self._encoder = TDMPCObservationEncoder(config) + self._dynamics = nn.Sequential( + nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + self._reward = nn.Sequential( + nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, 1), + ) + self._pi = nn.Sequential( + nn.Linear(config.latent_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.action_feature.shape[0]), + ) + self._Qs = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Tanh(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.ELU(), + nn.Linear(config.mlp_dim, 1), + ) + for _ in range(config.q_ensemble_size) + ] + ) + self._V = nn.Sequential( + nn.Linear(config.latent_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Tanh(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.ELU(), + nn.Linear(config.mlp_dim, 1), + ) + self._init_weights() + + def _init_weights(self): + """Initialize model weights. 
+ + Orthogonal initialization for all linear and convolutional layers' weights (apart from final layers + of reward network and Q networks which get zero initialization). + Zero initialization for all linear and convolutional layers' biases. + """ + + def _apply_fn(m): + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + gain = nn.init.calculate_gain("relu") + nn.init.orthogonal_(m.weight.data, gain) + if m.bias is not None: + nn.init.zeros_(m.bias) + + self.apply(_apply_fn) + for m in [self._reward, *self._Qs]: + assert isinstance(m[-1], nn.Linear), ( + "Sanity check. The last linear layer needs 0 initialization on weights." + ) + nn.init.zeros_(m[-1].weight) + nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure + + def encode(self, obs: dict[str, Tensor]) -> Tensor: + """Encodes an observation into its latent representation.""" + return self._encoder(obs) + + def latent_dynamics_and_reward(self, z: Tensor, a: Tensor) -> tuple[Tensor, Tensor]: + """Predict the next state's latent representation and the reward given a current latent and action. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + a: (*, action_dim) tensor for the action to be applied. + Returns: + A tuple containing: + - (*, latent_dim) tensor for the next state's latent representation. + - (*,) tensor for the estimated reward. + """ + x = torch.cat([z, a], dim=-1) + return self._dynamics(x), self._reward(x).squeeze(-1) + + def latent_dynamics(self, z: Tensor, a: Tensor) -> Tensor: + """Predict the next state's latent representation given a current latent and action. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + a: (*, action_dim) tensor for the action to be applied. + Returns: + (*, latent_dim) tensor for the next state's latent representation. + """ + x = torch.cat([z, a], dim=-1) + return self._dynamics(x) + + def pi(self, z: Tensor, std: float = 0.0) -> Tensor: + """Samples an action from the learned policy. + + The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when + generating rollouts for online training. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + std: The standard deviation of the injected noise. + Returns: + (*, action_dim) tensor for the sampled action. + """ + action = torch.tanh(self._pi(z)) + if std > 0: + std = torch.ones_like(action) * std + action += torch.randn_like(action) * std + return action + + def V(self, z: Tensor) -> Tensor: # noqa: N802 + """Predict state value (V). + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + Returns: + (*,) tensor of estimated state values. + """ + return self._V(z).squeeze(-1) + + def Qs(self, z: Tensor, a: Tensor, return_min: bool = False) -> Tensor: # noqa: N802 + """Predict state-action value for all of the learned Q functions. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + a: (*, action_dim) tensor for the action to be applied. + return_min: Set to true for implementing the detail in App. C of the FOWM paper: randomly select + 2 of the Qs and return the minimum + Returns: + (q_ensemble, *) tensor for the value predictions of each learned Q function in the ensemble OR + (*,) tensor if return_min=True. 
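+
+        Example (illustrative shapes only, assuming `model = TDMPCTOLD(config)`):
+            z = torch.randn(8, config.latent_dim)
+            a = torch.randn(8, config.action_feature.shape[0])
+            q_all = model.Qs(z, a)                   # (config.q_ensemble_size, 8)
+            q_min = model.Qs(z, a, return_min=True)  # (8,)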
+ """ + x = torch.cat([z, a], dim=-1) + if not return_min: + return torch.stack([q(x).squeeze(-1) for q in self._Qs], dim=0) + else: + if len(self._Qs) > 2: # noqa: SIM108 + Qs = [self._Qs[i] for i in np.random.choice(len(self._Qs), size=2)] + else: + Qs = self._Qs + return torch.stack([q(x).squeeze(-1) for q in Qs], dim=0).min(dim=0)[0] + + +class TDMPCObservationEncoder(nn.Module): + """Encode image and/or state vector observations.""" + + def __init__(self, config: TDMPCConfig): + """ + Creates encoders for pixel and/or state modalities. + TODO(alexander-soare): The original work allows for multiple images by concatenating them along the + channel dimension. Re-implement this capability. + """ + super().__init__() + self.config = config + + if config.image_features: + self.image_enc_layers = nn.Sequential( + nn.Conv2d( + next(iter(config.image_features.values())).shape[0], + config.image_encoder_hidden_dim, + 7, + stride=2, + ), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.ReLU(), + ) + dummy_shape = (1, *next(iter(config.image_features.values())).shape) + out_shape = get_output_shape(self.image_enc_layers, dummy_shape)[1:] + self.image_enc_layers.extend( + nn.Sequential( + nn.Flatten(), + nn.Linear(np.prod(out_shape), config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + ) + + if config.robot_state_feature: + self.state_enc_layers = nn.Sequential( + nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim), + nn.ELU(), + nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + + if config.env_state_feature: + self.env_state_enc_layers = nn.Sequential( + nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim), + nn.ELU(), + nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + + def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: + """Encode the image and/or state vector. + + Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken + over all features. + """ + feat = [] + # NOTE: Order of observations matters here. + if self.config.image_features: + feat.append( + flatten_forward_unflatten( + self.image_enc_layers, obs_dict[next(iter(self.config.image_features))] + ) + ) + if self.config.env_state_feature: + feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV])) + if self.config.robot_state_feature: + feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT])) + return torch.stack(feat, dim=0).mean(0) + + +def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor: + """Randomly shifts images horizontally and vertically. 
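+
+    Each image in the batch receives an independent horizontal and vertical shift of up to
+    `max_random_shift_ratio` times the image size, using replicate padding and bilinear resampling
+    via `grid_sample`.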
+ + Adapted from https://github.com/facebookresearch/drqv2 + """ + b, _, h, w = x.size() + assert h == w, "non-square images not handled yet" + pad = int(round(max_random_shift_ratio * h)) + x = F.pad(x, tuple([pad] * 4), "replicate") + eps = 1.0 / (h + 2 * pad) + arange = torch.linspace( + -1.0 + eps, + 1.0 - eps, + h + 2 * pad, + device=x.device, + dtype=torch.float32, + )[:h] + arange = einops.repeat(arange, "w -> h w 1", h=h) + base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) + base_grid = einops.repeat(base_grid, "h w c -> b h w c", b=b) + # A random shift in units of pixels and within the boundaries of the padding. + shift = torch.randint( + 0, + 2 * pad + 1, + size=(b, 1, 1, 2), + device=x.device, + dtype=torch.float32, + ) + shift *= 2.0 / (h + 2 * pad) + grid = base_grid + shift + return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) + + +def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float): + """Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param.""" + for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True): + for (n_p_ema, p_ema), (n_p, p) in zip( + ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True + ): + assert n_p_ema == n_p, "Parameter names don't match for EMA model update" + if isinstance(p, dict): + raise RuntimeError("Dict parameter not supported") + if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad: + # Copy BatchNorm parameters, and non-trainable parameters directly. + p_ema.copy_(p.to(dtype=p_ema.dtype).data) + with torch.no_grad(): + p_ema.mul_(alpha) + p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha) + + +def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: + """Helper to temporarily flatten extra dims at the start of the image tensor. + + Args: + fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return + (B, *), where * is any number of dimensions. + image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions, generally + different from *. + Returns: + A return value from the callable reshaped to (**, *). + """ + if image_tensor.ndim == 4: + return fn(image_tensor) + start_dims = image_tensor.shape[:-3] + inp = torch.flatten(image_tensor, end_dim=-4) + flat_out = fn(inp) + return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:])) diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c06e620ba1cec10ce22d54c92ddf48764bf92738 --- /dev/null +++ b/lerobot/common/policies/utils.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
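The `lerobot/common/policies/utils.py` module introduced here collects small shape, dtype and queue helpers shared by the policies. As a quick illustration of the dry-run pattern used by the observation encoders above, this is how `get_output_shape` (defined a little further down in this file) is typically called; the toy convolution and the 84x84 input are illustrative only:

```python
from torch import nn

from lerobot.common.policies.utils import get_output_shape

# A toy convolutional stem; the dry run discovers its output feature-map shape
# without computing the convolution arithmetic by hand.
conv = nn.Sequential(nn.Conv2d(3, 16, kernel_size=5, stride=2), nn.ReLU())
print(get_output_shape(conv, (1, 3, 84, 84)))  # (1, 16, 40, 40)
```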
+ +import torch +from torch import nn + + +def populate_queues(queues, batch): + for key in batch: + # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the + # queues have the keys they want). + if key not in queues: + continue + if len(queues[key]) != queues[key].maxlen: + # initialize by copying the first observation several times until the queue is full + while len(queues[key]) != queues[key].maxlen: + queues[key].append(batch[key]) + else: + # add latest observation to the queue + queues[key].append(batch[key]) + return queues + + +def get_device_from_parameters(module: nn.Module) -> torch.device: + """Get a module's device by checking one of its parameters. + + Note: assumes that all parameters have the same device + """ + return next(iter(module.parameters())).device + + +def get_dtype_from_parameters(module: nn.Module) -> torch.dtype: + """Get a module's parameter dtype by checking one of its parameters. + + Note: assumes that all parameters have the same dtype. + """ + return next(iter(module.parameters())).dtype + + +def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple: + """ + Calculates the output shape of a PyTorch module given an input shape. + + Args: + module (nn.Module): a PyTorch module + input_shape (tuple): A tuple representing the input shape, e.g., (batch_size, channels, height, width) + + Returns: + tuple: The output shape of the module. + """ + dummy_input = torch.zeros(size=input_shape) + with torch.inference_mode(): + output = module(dummy_input) + return tuple(output.shape) diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py new file mode 100644 index 0000000000000000000000000000000000000000..28e9c433833ad2f314ffa4dc44a95c2286210552 --- /dev/null +++ b/lerobot/common/policies/vqbet/configuration_vqbet.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python + +# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru +# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamConfig +from lerobot.common.optim.schedulers import VQBeTSchedulerConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + + +@PreTrainedConfig.register_subclass("vqbet") +@dataclass +class VQBeTConfig(PreTrainedConfig): + """Configuration class for VQ-BeT. + + Defaults are configured for training with PushT providing proprioceptive and single camera observations. + + The parameters you will most likely need to change are the ones which depend on the environment / sensors. + Those are: `input_shapes` and `output_shapes`. + + Notes on the inputs and outputs: + - "observation.state" is required as an input key. + - At least one key starting with "observation.image is required as an input. 
+ - If there are multiple keys beginning with "observation.image" they are treated as multiple camera + views. Right now we only support all images having the same shape. + - "action" is required as an output key. + + Args: + n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the + current step and additional steps going back). + n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts. + action_chunk_size: Action chunk size of each action prediction token. + input_shapes: A dictionary defining the shapes of the input data for the policy. + The key represents the input data name, and the value is a list indicating the dimensions + of the corresponding data. For example, "observation.image" refers to an input from + a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. + Importantly, shapes doesnt include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. + The key represents the output data name, and the value is a list indicating the dimensions + of the corresponding data. For example, "action" refers to an output shape of [14], indicating + 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. + input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), + and the value specifies the normalization mode to apply. The two available modes are "mean_std" + which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a + [-1, 1] range. + output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the + original scale. Note that this is also used for normalizing the training targets. + vision_backbone: Name of the torchvision resnet backbone to use for encoding images. + crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit + within the image size. If None, no cropping is done. + crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval + mode). + pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone. + `None` means no pretrained weights. + use_group_norm: Whether to replace batch normalization with group normalization in the backbone. + The group sizes are set to be about 16 (to be precise, feature_dim // 16). + spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax. + n_vqvae_training_steps: Number of optimization steps for training Residual VQ. + vqvae_n_embed: Number of embedding vectors in the RVQ dictionary (each layer). + vqvae_embedding_dim: Dimension of each embedding vector in the RVQ dictionary. + vqvae_enc_hidden_dim: Size of hidden dimensions of Encoder / Decoder part of Residaul VQ-VAE + gpt_block_size: Max block size of minGPT (should be larger than the number of input tokens) + gpt_input_dim: Size of output input of GPT. This is also used as the dimension of observation features. + gpt_output_dim: Size of output dimension of GPT. This is also used as a input dimension of offset / bin prediction headers. 
+ gpt_n_layer: Number of layers of GPT + gpt_n_head: Number of headers of GPT + gpt_hidden_dim: Size of hidden dimensions of GPT + dropout: Dropout rate for GPT + mlp_hidden_dim: Size of hidden dimensions of offset header / bin prediction headers parts of VQ-BeT + offset_loss_weight: A constant that is multiplied to the offset loss + primary_code_loss_weight: A constant that is multiplied to the primary code prediction loss + secondary_code_loss_weight: A constant that is multiplied to the secondary code prediction loss + bet_softmax_temperature: Sampling temperature of code for rollout with VQ-BeT + sequentially_select: Whether select code of primary / secondary as sequentially (pick primary code, + and then select secodnary code), or at the same time. + """ + + # Inputs / output structure. + n_obs_steps: int = 5 + n_action_pred_token: int = 3 + action_chunk_size: int = 5 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MIN_MAX, + } + ) + + # Architecture / modeling. + # Vision backbone. + vision_backbone: str = "resnet18" + crop_shape: tuple[int, int] | None = (84, 84) + crop_is_random: bool = True + pretrained_backbone_weights: str | None = None + use_group_norm: bool = True + spatial_softmax_num_keypoints: int = 32 + # VQ-VAE + n_vqvae_training_steps: int = 20000 + vqvae_n_embed: int = 16 + vqvae_embedding_dim: int = 256 + vqvae_enc_hidden_dim: int = 128 + # VQ-BeT + gpt_block_size: int = 500 + gpt_input_dim: int = 512 + gpt_output_dim: int = 512 + gpt_n_layer: int = 8 + gpt_n_head: int = 8 + gpt_hidden_dim: int = 512 + dropout: float = 0.1 + mlp_hidden_dim: int = 1024 + offset_loss_weight: float = 10000.0 + primary_code_loss_weight: float = 5.0 + secondary_code_loss_weight: float = 0.5 + bet_softmax_temperature: float = 0.1 + sequentially_select: bool = False + + # Training presets + optimizer_lr: float = 1e-4 + optimizer_betas: tuple = (0.95, 0.999) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-6 + optimizer_vqvae_lr: float = 1e-3 + optimizer_vqvae_weight_decay: float = 1e-4 + scheduler_warmup_steps: int = 500 + + def __post_init__(self): + super().__post_init__() + + """Input validation (not exhaustive).""" + if not self.vision_backbone.startswith("resnet"): + raise ValueError( + f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." + ) + + def get_optimizer_preset(self) -> AdamConfig: + return AdamConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self) -> VQBeTSchedulerConfig: + return VQBeTSchedulerConfig( + num_warmup_steps=self.scheduler_warmup_steps, + num_vqvae_training_steps=self.n_vqvae_training_steps, + ) + + def validate_features(self) -> None: + # Note: this check was previously performed inside VQBeTRgbEncoder in the form of + # assert len(image_keys) == 1 + if not len(self.image_features) == 1: + raise ValueError("You must provide only one image among the inputs.") + + if self.crop_shape is not None: + for key, image_ft in self.image_features.items(): + if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]: + raise ValueError( + f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} " + f"for `crop_shape` and {image_ft.shape} for " + f"`{key}`." + ) + + # Check that all input images have the same shape. 
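+        # All camera views are encoded by a single shared `rgb_encoder`, so every image input
+        # must have the same shape (see also the single-image check above).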
+ first_image_key, first_image_ft = next(iter(self.image_features.items())) + for key, image_ft in self.image_features.items(): + if image_ft.shape != first_image_ft.shape: + raise ValueError( + f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." + ) + + @property + def observation_delta_indices(self) -> list: + return list(range(1 - self.n_obs_steps, 1)) + + @property + def action_delta_indices(self) -> list: + return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py new file mode 100644 index 0000000000000000000000000000000000000000..97a08e2f4fd4358635ec411f6adee5d9e6440b64 --- /dev/null +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -0,0 +1,911 @@ +#!/usr/bin/env python + +# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru +# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from collections import deque +from typing import Callable, List + +import einops +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +import torchvision +from torch import Tensor, nn + +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues +from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ + +# ruff: noqa: N806 + + +class VQBeTPolicy(PreTrainedPolicy): + """ + VQ-BeT Policy as per "Behavior Generation with Latent Actions" + """ + + config_class = VQBeTConfig + name = "vqbet" + + def __init__( + self, + config: VQBeTConfig | None = None, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. 
+ """ + super().__init__(config) + config.validate_features() + self.config = config + + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.vqbet = VQBeTModel(config) + + self.reset() + + def get_optim_params(self) -> dict: + vqvae_params = ( + list(self.vqbet.action_head.vqvae_model.encoder.parameters()) + + list(self.vqbet.action_head.vqvae_model.decoder.parameters()) + + list(self.vqbet.action_head.vqvae_model.vq_layer.parameters()) + ) + decay_params, no_decay_params = self.vqbet.policy.configure_parameters() + decay_params = ( + decay_params + + list(self.vqbet.rgb_encoder.parameters()) + + list(self.vqbet.state_projector.parameters()) + + list(self.vqbet.rgb_feature_projector.parameters()) + + [self.vqbet.action_token] + + list(self.vqbet.action_head.map_to_cbet_preds_offset.parameters()) + ) + + if self.config.sequentially_select: + decay_params = ( + decay_params + + list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()) + + list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()) + ) + else: + decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters()) + + return [ + { + "params": decay_params, + }, + { + "params": vqvae_params, + "weight_decay": self.config.optimizer_vqvae_weight_decay, + "lr": self.config.optimizer_vqvae_lr, + }, + { + "params": no_decay_params, + "weight_decay": 0.0, + }, + ] + + def reset(self): + """ + Clear observation and action queues. Should be called on `env.reset()` + queues are populated during rollout of the policy, they contain the n latest observations and actions + """ + self._queues = { + "observation.images": deque(maxlen=self.config.n_obs_steps), + "observation.state": deque(maxlen=self.config.n_obs_steps), + "action": deque(maxlen=self.config.action_chunk_size), + } + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + + batch = self.normalize_inputs(batch) + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + # Note: It's important that this happens after stacking the images into a single key. 
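+        # `populate_queues` only updates keys that already exist in `self._queues`; on the first call
+        # after `reset()` it repeats the initial observation until each observation queue is full.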
+ self._queues = populate_queues(self._queues, batch) + + if not self.vqbet.action_head.vqvae_model.discretized.item(): + warnings.warn( + "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ.", + stacklevel=1, + ) + + if len(self._queues["action"]) == 0: + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} + actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size] + + # the dimension of returned action is (batch_size, action_chunk_size, action_dim) + actions = self.unnormalize_outputs({"action": actions})["action"] + # since the data in the action queue's dimension is (action_chunk_size, batch_size, action_dim), we transpose the action and fill the queue + self._queues["action"].extend(actions.transpose(0, 1)) + + action = self._queues["action"].popleft() + return action + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss for training or validation.""" + batch = self.normalize_inputs(batch) + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + batch = self.normalize_targets(batch) + # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181) + if not self.vqbet.action_head.vqvae_model.discretized.item(): + # loss: total loss of training RVQ + # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`. + # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree). + loss, n_different_codes, n_different_combinations, recon_l1_error = ( + self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"]) + ) + return loss, { + "n_different_codes": n_different_codes, + "n_different_combinations": n_different_combinations, + "recon_l1_error": recon_l1_error, + } + # if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts. + _, loss_dict = self.vqbet(batch, rollout=False) + loss = loss_dict.pop("loss") + + return loss, loss_dict + + +class SpatialSoftmax(nn.Module): + """ + Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al. + (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation. + + At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass" + of activations of each channel, i.e., keypoints in the image space for the policy to focus on. + + Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2): + ----------------------------------------------------- + | (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) | + | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) | + | ... | ... | ... | ... | + | (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) 
| + ----------------------------------------------------- + This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot + product with the coordinates (120x2) to get expected points of maximal activation (512x2). + + The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally + provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable + linear mapping (in_channels, H, W) -> (num_kp, H, W). + """ + + def __init__(self, input_shape, num_kp=None): + """ + Args: + input_shape (list): (C, H, W) input feature map shape. + num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input. + """ + super().__init__() + + assert len(input_shape) == 3 + self._in_c, self._in_h, self._in_w = input_shape + + if num_kp is not None: + self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1) + self._out_c = num_kp + else: + self.nets = None + self._out_c = self._in_c + + # we could use torch.linspace directly but that seems to behave slightly differently than numpy + # and causes a small degradation in pc_success of pre-trained models. + pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) + pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() + pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() + # register as buffer so it's moved to the correct device. + self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1)) + + def forward(self, features: Tensor) -> Tensor: + """ + Args: + features: (B, C, H, W) input feature maps. + Returns: + (B, K, 2) image-space coordinates of keypoints. + """ + if self.nets is not None: + features = self.nets(features) + + # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints + features = features.reshape(-1, self._in_h * self._in_w) + # 2d softmax normalization + attention = F.softmax(features, dim=-1) + # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions + expected_xy = attention @ self.pos_grid + # reshape to [B, K, 2] + feature_keypoints = expected_xy.view(-1, self._out_c, 2) + + return feature_keypoints + + +class VQBeTModel(nn.Module): + """VQ-BeT: The underlying neural network for VQ-BeT + + Note: In this code we use the terms `rgb_encoder`, 'policy', `action_head`. The meanings are as follows. + - The `rgb_encoder` process rgb-style image observations to one-dimensional embedding vectors + - A `policy` is a minGPT architecture, that takes observation sequences and action query tokens to generate `features`. + - These `features` pass through the action head, which passes through the code prediction, offset prediction head, + and finally generates a prediction for the action chunks. + + -------------------------------** legend **------------------------------- + │ n = n_obs_steps, p = n_action_pred_token, c = action_chunk_size) │ + │ o_{t} : visual observation at timestep {t} │ + │ s_{t} : state observation at timestep {t} │ + │ a_{t} : action at timestep {t} │ + │ A_Q : action_query_token │ + -------------------------------------------------------------------------- + + + Training Phase 1. 
Discretize action using Residual VQ (for config.n_vqvae_training_steps steps) + + + ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ + │ │ │ │ │ │ + │ RVQ encoder │ ─► │ Residual │ ─► │ RVQ Decoder │ + │ (a_{t}~a_{t+p}) │ │ Code Quantizer │ │ │ + │ │ │ │ │ │ + └─────────────────┘ └─────────────────┘ └─────────────────┘ + + Training Phase 2. + + timestep {t-n+1} timestep {t-n+2} timestep {t} + ┌─────┴─────┐ ┌─────┴─────┐ ┌─────┴─────┐ + + o_{t-n+1} o_{t-n+2} ... o_{t} + │ │ │ + │ s_{t-n+1} │ s_{t-n+2} ... │ s_{t} p + │ │ │ │ │ │ ┌───────┴───────┐ + │ │ A_Q │ │ A_Q ... │ │ A_Q ... A_Q + │ │ │ │ │ │ │ │ │ │ + ┌───▼─────▼─────▼─────▼─────▼─────▼─────────────────▼─────▼─────▼───────────────▼───┐ + │ │ + │ GPT │ => policy + │ │ + └───────────────▼─────────────────▼─────────────────────────────▼───────────────▼───┘ + │ │ │ │ + ┌───┴───┐ ┌───┴───┐ ┌───┴───┐ ┌───┴───┐ + code offset code offset code offset code offset + ▼ │ ▼ │ ▼ │ ▼ │ => action_head + RVQ Decoder │ RVQ Decoder │ RVQ Decoder │ RVQ Decoder │ + └── + ──┘ └── + ──┘ └── + ──┘ └── + ──┘ + ▼ ▼ ▼ ▼ + action chunk action chunk action chunk action chunk + a_{t-n+1} ~ a_{t-n+2} ~ a_{t} ~ ... a_{t+p-1} ~ + a_{t-n+c} a_{t-n+c+1} a_{t+c-1} a_{t+p+c-1} + + ▼ + ONLY this chunk is used in rollout! + """ + + def __init__(self, config: VQBeTConfig): + super().__init__() + self.config = config + + self.rgb_encoder = VQBeTRgbEncoder(config) + self.num_images = len(self.config.image_features) + # This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above. + # Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results. + self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) + + # To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT. + self.state_projector = MLP( + config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim] + ) + self.rgb_feature_projector = MLP( + self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim] + ) + + # GPT part of VQ-BeT + self.policy = GPT(config) + # bin prediction head / offset prediction head part of VQ-BeT + self.action_head = VQBeTHead(config) + + # Action tokens for: each observation step, the current action token, and all future action tokens. + num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1 + self.register_buffer( + "select_target_actions_indices", + torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]), + ) + + def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]: + # Input validation. + assert set(batch).issuperset({"observation.state", "observation.images"}) + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + assert n_obs_steps == self.config.n_obs_steps + + # Extract image feature (first combine batch and sequence dims). + img_features = self.rgb_encoder( + einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") + ) + # Separate batch and sequence dims. + img_features = einops.rearrange( + img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images + ) + + # Arrange prior and current observation step tokens as shown in the class docstring. + # First project features to token dimension. 
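+        # Each observation step contributes one token per camera view, one state token and one action
+        # query token (A_Q); the stack/rearrange below interleaves these tokens per timestep.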
+ rgb_tokens = self.rgb_feature_projector( + img_features + ) # (batch, obs_step, number of different cameras, projection dims) + input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))] + input_tokens.append( + self.state_projector(batch["observation.state"]) + ) # (batch, obs_step, projection dims) + input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps)) + # Interleave tokens by stacking and rearranging. + input_tokens = torch.stack(input_tokens, dim=2) + input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d") + + len_additional_action_token = self.config.n_action_pred_token - 1 + future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1) + + # add additional action query tokens for predicting future action chunks + input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1) + + # get action features (pass through GPT) + features = self.policy(input_tokens) + # len(self.config.input_features) is the number of different observation modes. + # this line gets the index of action prompt tokens. + historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len( + self.config.input_features + ) + + # only extract the output tokens at the position of action query: + # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models, + # mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251). + # Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional). + if len_additional_action_token > 0: + features = torch.cat( + [features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1 + ) + else: + features = features[:, historical_act_pred_index] + # pass through action head + action_head_output = self.action_head(features) + # if rollout, VQ-BeT don't calculate loss + if rollout: + return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape( + batch_size, self.config.action_chunk_size, -1 + ) + # else, it calculate overall loss (bin prediction loss, and offset loss) + else: + output = batch["action"][:, self.select_target_actions_indices] + loss = self.action_head.loss_fn(action_head_output, output, reduction="mean") + return action_head_output, loss + + +class VQBeTHead(nn.Module): + def __init__(self, config: VQBeTConfig): + """ + VQBeTHead takes output of GPT layers, and pass the feature through bin prediction head (`self.map_to_cbet_preds_bin`), and offset prediction head (`self.map_to_cbet_preds_offset`) + + self.map_to_cbet_preds_bin: outputs probability of each code (for each layer). + The input dimension of `self.map_to_cbet_preds_bin` is same with the output of GPT, + and the output dimension of `self.map_to_cbet_preds_bin` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed`. + if the agent select the code sequentially, we use self.map_to_cbet_preds_primary_bin and self.map_to_cbet_preds_secondary_bin instead of self._map_to_cbet_preds_bin. + + self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers. 
+ The input dimension of ` self.map_to_cbet_preds_offset` is same with the output of GPT, + and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.action_feature.shape[0]`. + """ + + super().__init__() + self.config = config + # init vqvae + self.vqvae_model = VqVae(config) + if config.sequentially_select: + self.map_to_cbet_preds_primary_bin = MLP( + in_channels=config.gpt_output_dim, + hidden_channels=[self.config.vqvae_n_embed], + ) + self.map_to_cbet_preds_secondary_bin = MLP( + in_channels=config.gpt_output_dim + self.config.vqvae_n_embed, + hidden_channels=[self.config.vqvae_n_embed], + ) + else: + self.map_to_cbet_preds_bin = MLP( + in_channels=config.gpt_output_dim, + hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed], + ) + self.map_to_cbet_preds_offset = MLP( + in_channels=config.gpt_output_dim, + hidden_channels=[ + self.vqvae_model.vqvae_num_layers + * self.config.vqvae_n_embed + * config.action_chunk_size + * config.action_feature.shape[0], + ], + ) + # loss + self._focal_loss_fn = FocalLoss(gamma=2.0) + + def discretize(self, n_vqvae_training_steps, actions): + # Resize the action sequence data to fit the action chunk size using a sliding window approach. + actions = torch.cat( + [ + actions[:, j : j + self.config.action_chunk_size, :] + for j in range(actions.shape[1] + 1 - self.config.action_chunk_size) + ], + dim=0, + ) + # `actions` is a tensor of shape (new_batch, action_chunk_size, action_dim) where new_batch is the number of possible chunks created from the original sequences using the sliding window. + + loss, metric = self.vqvae_model.vqvae_forward(actions) + n_different_codes = sum( + [len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)] + ) + n_different_combinations = len(torch.unique(metric[2], dim=0)) + recon_l1_error = metric[0].detach().cpu().item() + self.vqvae_model.optimized_steps += 1 + # if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part. + if self.vqvae_model.optimized_steps >= n_vqvae_training_steps: + self.vqvae_model.discretized = torch.tensor(True) + self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True) + print("Finished discretizing action data!") + self.vqvae_model.eval() + for param in self.vqvae_model.vq_layer.parameters(): + param.requires_grad = False + return loss, n_different_codes, n_different_combinations, recon_l1_error + + def forward(self, x, **kwargs) -> dict: + # N is the batch size, and T is number of action query tokens, which are process through same GPT + N, T, _ = x.shape + # we calculate N and T side parallelly. 
Thus, the dimensions would be + # (batch size * number of action query tokens, action chunk size, action dimension) + x = einops.rearrange(x, "N T WA -> (N T) WA") + + # sample offsets + cbet_offsets = self.map_to_cbet_preds_offset(x) + cbet_offsets = einops.rearrange( + cbet_offsets, + "(NT) (G C WA) -> (NT) G C WA", + G=self.vqvae_model.vqvae_num_layers, + C=self.config.vqvae_n_embed, + ) + # if self.config.sequentially_select is True, bin prediction head first sample the primary code, and then sample secondary code + if self.config.sequentially_select: + cbet_primary_logits = self.map_to_cbet_preds_primary_bin(x) + + # select primary bin first + cbet_primary_probs = torch.softmax( + cbet_primary_logits / self.config.bet_softmax_temperature, dim=-1 + ) + NT, choices = cbet_primary_probs.shape + sampled_primary_centers = einops.rearrange( + torch.multinomial(cbet_primary_probs.view(-1, choices), num_samples=1), + "(NT) 1 -> NT", + NT=NT, + ) + + cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin( + torch.cat( + (x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)), + axis=1, + ) + ) + cbet_secondary_probs = torch.softmax( + cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1 + ) + sampled_secondary_centers = einops.rearrange( + torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1), + "(NT) 1 -> NT", + NT=NT, + ) + sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1) + cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1) + # if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once. + else: + cbet_logits = self.map_to_cbet_preds_bin(x) + cbet_logits = einops.rearrange( + cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers + ) + cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1) + NT, G, choices = cbet_probs.shape + sampled_centers = einops.rearrange( + torch.multinomial(cbet_probs.view(-1, choices), num_samples=1), + "(NT G) 1 -> NT G", + NT=NT, + ) + + device = get_device_from_parameters(self) + indices = ( + torch.arange(NT, device=device).unsqueeze(1), + torch.arange(self.vqvae_model.vqvae_num_layers, device=device).unsqueeze(0), + sampled_centers, + ) + # Use advanced indexing to sample the values (Extract the only offsets corresponding to the sampled codes.) + sampled_offsets = cbet_offsets[indices] + # Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction + sampled_offsets = sampled_offsets.sum(dim=1) + with torch.no_grad(): + # Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder + return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach() + # pass the centroids through decoder to get actions. 
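+            # (`get_action_from_latent` maps each latent back to a chunk of `action_chunk_size` actions)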
+ decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach() + # reshaped extracted offset to match with decoded centroids + sampled_offsets = einops.rearrange( + sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size + ) + # add offset and decoded centroids + predicted_action = decoded_action + sampled_offsets + predicted_action = einops.rearrange( + predicted_action, + "(N T) W A -> N T (W A)", + N=N, + T=T, + W=self.config.action_chunk_size, + ) + + return { + "cbet_logits": cbet_logits, + "predicted_action": predicted_action, + "sampled_centers": sampled_centers, + "decoded_action": decoded_action, + } + + def loss_fn(self, pred, target, **kwargs): + """ + for given ground truth action values (target), and prediction (pred) this function calculates the overall loss. + + predicted_action: predicted action chunk (offset + decoded centroids) + sampled_centers: sampled centroids (code of RVQ) + decoded_action: decoded action, which is produced by passing sampled_centers through RVQ decoder + NT: batch size * T + T: number of action query tokens, which are process through same GPT + cbet_logits: probability of all codes in each layer + """ + action_seq = target + predicted_action = pred["predicted_action"] + sampled_centers = pred["sampled_centers"] + decoded_action = pred["decoded_action"] + NT = predicted_action.shape[0] * predicted_action.shape[1] + + cbet_logits = pred["cbet_logits"] + + predicted_action = einops.rearrange( + predicted_action, "N T (W A) -> (N T) W A", W=self.config.action_chunk_size + ) + + action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A") + # Figure out the loss for the actions. + # First, we need to find the closest cluster center for each ground truth action. + with torch.no_grad(): + state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G + + # Now we can compute the loss. 
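+        # The total loss combines focal losses on the primary and secondary code logits (weighted by
+        # `primary_code_loss_weight` / `secondary_code_loss_weight`) with an L1 offset loss weighted by
+        # `offset_loss_weight`.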
+ + # offset loss is L1 distance between the predicted action and ground truth action + offset_loss = F.l1_loss(action_seq, predicted_action) + + # calculate primary code prediction loss + cbet_loss1 = self._focal_loss_fn( + cbet_logits[:, 0, :], + action_bins[:, 0], + ) + # calculate secondary code prediction loss + cbet_loss2 = self._focal_loss_fn( + cbet_logits[:, 1, :], + action_bins[:, 1], + ) + # add all the prediction loss + cbet_loss = ( + cbet_loss1 * self.config.primary_code_loss_weight + + cbet_loss2 * self.config.secondary_code_loss_weight + ) + + equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT) + equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT) + + action_mse_error = torch.mean((action_seq - predicted_action) ** 2) + vq_action_error = torch.mean(torch.abs(action_seq - decoded_action)) + offset_action_error = torch.mean(torch.abs(action_seq - predicted_action)) + action_error_max = torch.max(torch.abs(action_seq - predicted_action)) + + loss = cbet_loss + self.config.offset_loss_weight * offset_loss + + loss_dict = { + "loss": loss, + "classification_loss": cbet_loss.detach().cpu().item(), + "offset_loss": offset_loss.detach().cpu().item(), + "equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(), + "equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(), + "vq_action_error": vq_action_error.detach().cpu().item(), + "offset_action_error": offset_action_error.detach().cpu().item(), + "action_error_max": action_error_max.detach().cpu().item(), + "action_mse_error": action_mse_error.detach().cpu().item(), + } + return loss_dict + + +class VQBeTRgbEncoder(nn.Module): + """Encode an RGB image into a 1D feature vector. + + Includes the ability to normalize and crop the image first. + + Same with DiffusionRgbEncoder from modeling_diffusion.py + """ + + def __init__(self, config: VQBeTConfig): + super().__init__() + # Set up optional preprocessing. + if config.crop_shape is not None: + self.do_crop = True + # Always use center crop for eval + self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) + if config.crop_is_random: + self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) + else: + self.maybe_random_crop = self.center_crop + else: + self.do_crop = False + + # Set up backbone. + backbone_model = getattr(torchvision.models, config.vision_backbone)( + weights=config.pretrained_backbone_weights + ) + # Note: This assumes that the layer4 feature map is children()[-3] + # TODO(alexander-soare): Use a safer alternative. + self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2])) + if config.use_group_norm: + if config.pretrained_backbone_weights: + raise ValueError( + "You can't replace BatchNorm in a pretrained model without ruining the weights!" + ) + self.backbone = _replace_submodules( + root_module=self.backbone, + predicate=lambda x: isinstance(x, nn.BatchNorm2d), + func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), + ) + + # Set up pooling and final layers. + # Use a dry run to get the feature map shape. + # The dummy input should take the number of image channels from `config.image_features` and it should + # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the + # height and width from `config.image_features`. 
+ + images_shape = next(iter(config.image_features.values())).shape + dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:] + dummy_shape = (1, images_shape[0], *dummy_shape_h_w) + feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:] + + self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) + self.feature_dim = config.spatial_softmax_num_keypoints * 2 + self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) + self.relu = nn.ReLU() + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: (B, C, H, W) image tensor with pixel values in [0, 1]. + Returns: + (B, D) image feature. + """ + # Preprocess: maybe crop (if it was set up in the __init__). + if self.do_crop: + if self.training: # noqa: SIM108 + x = self.maybe_random_crop(x) + else: + # Always use center crop for eval. + x = self.center_crop(x) + # Extract backbone feature. + x = torch.flatten(self.pool(self.backbone(x)), start_dim=1) + # Final linear layer with non-linearity. + x = self.relu(self.out(x)) + return x + + +def _replace_submodules( + root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] +) -> nn.Module: + """ + Args: + root_module: The module for which the submodules need to be replaced + predicate: Takes a module as an argument and must return True if the that module is to be replaced. + func: Takes a module as an argument and returns a new module to replace it with. + Returns: + The root module with its submodules replaced. + """ + if predicate(root_module): + return func(root_module) + + replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] + for *parents, k in replace_list: + parent_module = root_module + if len(parents) > 0: + parent_module = root_module.get_submodule(".".join(parents)) + if isinstance(parent_module, nn.Sequential): + src_module = parent_module[int(k)] + else: + src_module = getattr(parent_module, k) + tgt_module = func(src_module) + if isinstance(parent_module, nn.Sequential): + parent_module[int(k)] = tgt_module + else: + setattr(parent_module, k, tgt_module) + # verify that all BN are replaced + assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) + return root_module + + +class VqVae(nn.Module): + def __init__( + self, + config: VQBeTConfig, + ): + """ + VQ-VAE is composed of three parts: encoder, vq_layer, and decoder. + Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively. + The vq_layer uses residual VQs. + + This class contains functions for training the encoder and decoder along with the residual VQ layer (for training phase 1), + as well as functions to help BeT training part in training phase 2. + """ + + super().__init__() + self.config = config + # 'discretized' indicates whether the Residual VQ part is trained or not. (After finishing the training, we set discretized=True) + self.register_buffer("discretized", torch.tensor(False)) + self.optimized_steps = 0 + # we use the fixed number of layers for Residual VQ across all environments. 
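+        # (two quantizer layers: a primary codebook plus one residual codebook)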
+ self.vqvae_num_layers = 2 + + self.vq_layer = ResidualVQ( + dim=config.vqvae_embedding_dim, + num_quantizers=self.vqvae_num_layers, + codebook_size=config.vqvae_n_embed, + ) + + self.encoder = MLP( + in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size, + hidden_channels=[ + config.vqvae_enc_hidden_dim, + config.vqvae_enc_hidden_dim, + config.vqvae_embedding_dim, + ], + ) + self.decoder = MLP( + in_channels=config.vqvae_embedding_dim, + hidden_channels=[ + config.vqvae_enc_hidden_dim, + config.vqvae_enc_hidden_dim, + self.config.action_feature.shape[0] * self.config.action_chunk_size, + ], + ) + + def get_embeddings_from_code(self, encoding_indices): + # This function gets code indices as inputs, and outputs embedding vectors corresponding to the code indices. + with torch.no_grad(): + z_embed = self.vq_layer.get_codebook_vector_from_indices(encoding_indices) + # since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination. + z_embed = z_embed.sum(dim=0) + return z_embed + + def get_action_from_latent(self, latent): + # given latent vector, this function outputs the decoded action. + output = self.decoder(latent) + if self.config.action_chunk_size == 1: + return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]) + else: + return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]) + + def get_code(self, state): + # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181) + # this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://arxiv.org/pdf/2403.03181) + state = einops.rearrange(state, "N T A -> N (T A)") + with torch.no_grad(): + state_rep = self.encoder(state) + state_rep_shape = state_rep.shape[:-1] + state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1)) + state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat) + state_vq = state_rep_flat.view(*state_rep_shape, -1) + vq_code = vq_code.view(*state_rep_shape, -1) + vq_loss_state = torch.sum(vq_loss_state) + return state_vq, vq_code + + def vqvae_forward(self, state): + # This function passes the given data through Residual VQ with Encoder and Decoder. Please refer to section 3.2 in the paper https://arxiv.org/pdf/2403.03181). + state = einops.rearrange(state, "N T A -> N (T A)") + # We start with passing action (or action chunk) at:t+n through the encoder ϕ. + state_rep = self.encoder(state) + state_rep_shape = state_rep.shape[:-1] + state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1)) + # The resulting latent embedding vector x = ϕ(at:t+n) is then mapped to an embedding vector in the codebook of the RVQ layers by the nearest neighbor look-up. + state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat) + state_vq = state_rep_flat.view(*state_rep_shape, -1) + vq_code = vq_code.view(*state_rep_shape, -1) + # since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination. + vq_loss_state = torch.sum(vq_loss_state) + # Then, the discretized vector zq(x) is reconstructed as ψ(zq(x)) by passing through the decoder ψ. 
+ dec_out = self.decoder(state_vq) + # Calculate L1 reconstruction loss + encoder_loss = (state - dec_out).abs().mean() + # add encoder reconstruction loss and commitment loss + rep_loss = encoder_loss + vq_loss_state * 5 + + metric = ( + encoder_loss.clone().detach(), + vq_loss_state.clone().detach(), + vq_code, + rep_loss.item(), + ) + return rep_loss, metric + + +class FocalLoss(nn.Module): + """ + From https://github.com/notmahi/miniBET/blob/main/behavior_transformer/bet.py + """ + + def __init__(self, gamma: float = 0, size_average: bool = True): + super().__init__() + self.gamma = gamma + self.size_average = size_average + + def forward(self, input, target): + if len(input.shape) == 3: + N, T, _ = input.shape + logpt = F.log_softmax(input, dim=-1) + logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T) + elif len(input.shape) == 2: + logpt = F.log_softmax(input, dim=-1) + logpt = logpt.gather(-1, target.view(-1, 1)).view(-1) + pt = logpt.exp() + + loss = -1 * (1 - pt) ** self.gamma * logpt + if self.size_average: + return loss.mean() + else: + return loss.sum() + + +class MLP(torch.nn.Sequential): + def __init__( + self, + in_channels: int, + hidden_channels: List[int], + ): + layers = [] + in_dim = in_channels + for hidden_dim in hidden_channels[:-1]: + layers.append(torch.nn.Linear(in_dim, hidden_dim)) + layers.append(torch.nn.ReLU()) + in_dim = hidden_dim + + layers.append(torch.nn.Linear(in_dim, hidden_channels[-1])) + + super().__init__(*layers) diff --git a/lerobot/common/policies/vqbet/vqbet_utils.py b/lerobot/common/policies/vqbet/vqbet_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..139d119edc4f5107a4a0ad4630386be7390504ee --- /dev/null +++ b/lerobot/common/policies/vqbet/vqbet_utils.py @@ -0,0 +1,1462 @@ +#!/usr/bin/env python + +# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru +# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from functools import partial +from math import ceil +from random import randrange +from typing import Callable + +import torch +import torch.distributed as distributed +import torch.nn.functional as F # noqa: N812 +from einops import pack, rearrange, reduce, repeat, unpack +from torch import einsum, nn +from torch.cuda.amp import autocast +from torch.optim import Optimizer + +from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig + +# ruff: noqa: N806 + +""" +This file is part of a VQ-BeT that utilizes code from the following repositories: + + - Vector Quantize PyTorch code is licensed under the MIT License: + Original source: https://github.com/lucidrains/vector-quantize-pytorch + + - nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch. + Original source: https://github.com/karpathy/nanoGPT + +We also made some changes to the original code to adapt it to our needs. The changes are described in the code below. 
+""" + +""" +This is a part for nanoGPT that utilizes code from the following repository: + + - Andrej Karpathy's nanoGPT implementation in PyTorch. + Original source: https://github.com/karpathy/nanoGPT + + - The nanoGPT code is licensed under the MIT License: + + MIT License + + Copyright (c) 2022 Andrej Karpathy + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + - We've made some changes to the original code to adapt it to our needs. + + Changed variable names: + - n_head -> gpt_n_head + - n_embd -> gpt_hidden_dim + - block_size -> gpt_block_size + - n_layer -> gpt_n_layer + + + class GPT(nn.Module): + - removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained` + - changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop. + - in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads). 
+ +""" + + +class CausalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + assert config.gpt_hidden_dim % config.gpt_n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.gpt_hidden_dim, 3 * config.gpt_hidden_dim) + # output projection + self.c_proj = nn.Linear(config.gpt_hidden_dim, config.gpt_hidden_dim) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "bias", + torch.tril(torch.ones(config.gpt_block_size, config.gpt_block_size)).view( + 1, 1, config.gpt_block_size, config.gpt_block_size + ), + ) + self.gpt_n_head = config.gpt_n_head + self.gpt_hidden_dim = config.gpt_hidden_dim + + def forward(self, x): + ( + B, + T, + C, + ) = x.size() # batch size, sequence length, embedding dimensionality (gpt_hidden_dim) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2) + k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class Block(nn.Module): + # causual self-attention block for GPT + def __init__(self, config): + super().__init__() + self.ln_1 = nn.LayerNorm(config.gpt_hidden_dim) + self.attn = CausalSelfAttention(config) + self.ln_2 = nn.LayerNorm(config.gpt_hidden_dim) + self.mlp = nn.Sequential( + nn.Linear(config.gpt_hidden_dim, 4 * config.gpt_hidden_dim), + nn.GELU(), + nn.Linear(4 * config.gpt_hidden_dim, config.gpt_hidden_dim), + nn.Dropout(config.dropout), + ) + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class GPT(nn.Module): + """ + Original comments: + Full definition of a GPT Language Model, all of it in this single file. + References: + 1) the official GPT-2 TensorFlow implementation released by OpenAI: + https://github.com/openai/gpt-2/blob/master/src/model.py + 2) huggingface/transformers PyTorch implementation: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py + """ + + def __init__(self, config: VQBeTConfig): + """ + GPT model gets hyperparameters from a config object. Please refer configuration_vqbet.py for more details. 
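+
+        A minimal usage sketch (assumes the default `VQBeTConfig` values; shapes are illustrative):
+        ```python
+        import torch
+
+        from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
+
+        config = VQBeTConfig()
+        gpt = GPT(config)
+        # input: (batch, sequence length <= gpt_block_size, gpt_input_dim)
+        dummy = torch.randn(1, config.gpt_block_size, config.gpt_input_dim)
+        logits = gpt(dummy)  # (1, gpt_block_size, gpt_output_dim)
+        ```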
+ """ + super().__init__() + assert config.gpt_output_dim is not None + assert config.gpt_block_size is not None + self.config = config + + self.transformer = nn.ModuleDict( + { + "wte": nn.Linear(config.gpt_input_dim, config.gpt_hidden_dim), + "wpe": nn.Embedding(config.gpt_block_size, config.gpt_hidden_dim), + "drop": nn.Dropout(config.dropout), + "h": nn.ModuleList([Block(config) for _ in range(config.gpt_n_layer)]), + "ln_f": nn.LayerNorm(config.gpt_hidden_dim), + } + ) + self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False) + # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper + self.apply(self._init_weights) + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)) + + # report number of parameters + n_params = sum(p.numel() for p in self.parameters()) + print("number of parameters: {:.2f}M".format(n_params / 1e6)) + + def forward(self, input, targets=None): + device = input.device + b, t, d = input.size() + assert t <= self.config.gpt_block_size, ( + f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}" + ) + + # positional encodings that are added to the input embeddings + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) + + # forward the GPT model itself + tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim) + x = self.transformer.drop(tok_emb + pos_emb) + for block in self.transformer.h: + x = block(x) + x = self.transformer.ln_f(x) + logits = self.lm_head(x) + return logits + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + torch.nn.init.zeros_(module.bias) + torch.nn.init.ones_(module.weight) + + def crop_block_size(self, gpt_block_size): + # model surgery to decrease the block size if necessary + # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) + # but want to use a smaller block size for some smaller, simpler model + assert gpt_block_size <= self.config.gpt_block_size + self.config.gpt_block_size = gpt_block_size + self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size]) + for block in self.transformer.h: + block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size] + + def configure_parameters(self): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 
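+
+        The two lists are meant to be passed to an external optimizer as parameter groups,
+        e.g. (illustrative hyperparameters):
+        ```python
+        decay_params, no_decay_params = gpt.configure_parameters()
+        optimizer = torch.optim.AdamW(
+            [
+                {"params": decay_params, "weight_decay": 1e-2},
+                {"params": no_decay_params, "weight_decay": 0.0},
+            ],
+            lr=1e-4,
+        )
+        ```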
+ """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear,) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, _p in m.named_parameters(): + fpn = "{}.{}".format(mn, pn) if mn else pn # full param name + if pn.endswith("bias"): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = dict(self.named_parameters()) + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( + str(inter_params) + ) + assert len(param_dict.keys() - union_params) == 0, ( + "parameters {} were not separated into either decay/no_decay set!".format( + str(param_dict.keys() - union_params), + ) + ) + + decay = [param_dict[pn] for pn in sorted(decay)] + no_decay = [param_dict[pn] for pn in sorted(no_decay)] + # return the parameters that require weight decay, and the parameters that don't separately. + return decay, no_decay + + +""" +This file is a part for Residual Vector Quantization that utilizes code from the following repository: + + - Phil Wang's vector-quantize-pytorch implementation in PyTorch. + Original source: https://github.com/lucidrains/vector-quantize-pytorch + + - The vector-quantize-pytorch code is licensed under the MIT License: + + MIT License + + Copyright (c) 2020 Phil Wang + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + - We've made some changes to the original code to adapt it to our needs. + + class ResidualVQ(nn.Module): + - added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method: + This enables the user to save an indicator whether the codebook is frozen or not. + - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`: + This is to make the function name more descriptive. + + class VectorQuantize(nn.Module): + - removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method: + These parameters are not used in the code. 
+ - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`: + This is to make the function name more descriptive. + +""" + + +class ResidualVQ(nn.Module): + """ + Residual VQ is composed of multiple VectorQuantize layers. + + Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + "Residual Vector Quantizer (a.k.a. multi-stage vector quantizer [36]) cascades Nq layers of VQ as follows. The unquantized input vector is + passed through a first VQ and quantization residuals are computed. The residuals are then iteratively quantized by a sequence of additional + Nq -1 vector quantizers, as described in Algorithm 1." + + + self.project_in: function for projecting input to codebook dimension + self.project_out: function for projecting codebook dimension to output dimension + self.layers: nn.ModuleList of VectorQuantize layers that contains Nq layers of VQ as described in the paper. + self.freeze_codebook: buffer to save an indicator whether the codebook is frozen or not. VQ-BeT will check this to determine whether to update the codebook or not. + """ + + def __init__( + self, + *, + dim, + num_quantizers, + codebook_dim=None, + shared_codebook=False, + heads=1, + quantize_dropout=False, + quantize_dropout_cutoff_index=0, + quantize_dropout_multiple_of=1, + accept_image_fmap=False, + **kwargs, + ): + super().__init__() + assert heads == 1, "residual vq is not compatible with multi-headed codes" + codebook_dim = codebook_dim if (codebook_dim is not None) else dim + codebook_input_dim = codebook_dim * heads + + requires_projection = codebook_input_dim != dim + self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + + self.num_quantizers = num_quantizers + + self.accept_image_fmap = accept_image_fmap + self.layers = nn.ModuleList( + [ + VectorQuantize( + dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs + ) + for _ in range(num_quantizers) + ] + ) + + self.quantize_dropout = quantize_dropout and num_quantizers > 1 + + assert quantize_dropout_cutoff_index >= 0 + + self.register_buffer("freeze_codebook", torch.tensor(False)) + self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index + self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4 + + if not shared_codebook: + return + + first_vq, *rest_vq = self.layers + codebook = first_vq._codebook + + for vq in rest_vq: + vq._codebook = codebook + + @property + def codebooks(self): + codebooks = [layer._codebook.embed for layer in self.layers] + codebooks = torch.stack(codebooks, dim=0) + codebooks = rearrange(codebooks, "q 1 c d -> q c d") + return codebooks + + def get_codebook_vector_from_indices(self, indices): + # this function will return the codes from all codebooks across layers corresponding to the indices + batch, quantize_dim = indices.shape[0], indices.shape[-1] + + # may also receive indices in the shape of 'b h w q' (accept_image_fmap) + + indices, ps = pack([indices], "b * q") + + # because of quantize dropout, one can pass in indices that are coarse + # and the network should be able to reconstruct + + if quantize_dim < self.num_quantizers: + assert self.quantize_dropout > 0.0, ( + "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations" + ) + indices = F.pad(indices, (0, 
self.num_quantizers - quantize_dim), value=-1) + + # get ready for gathering + + codebooks = repeat(self.codebooks, "q c d -> q b c d", b=batch) + gather_indices = repeat(indices, "b n q -> q b n d", d=codebooks.shape[-1]) + + # take care of quantizer dropout + + mask = gather_indices == -1.0 + gather_indices = gather_indices.masked_fill( + mask, 0 + ) # have it fetch a dummy code to be masked out later + + all_codes = codebooks.gather(2, gather_indices) # gather all codes + + # mask out any codes that were dropout-ed + + all_codes = all_codes.masked_fill(mask, 0.0) + + # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension) + + (all_codes,) = unpack(all_codes, ps, "q b * d") + + return all_codes + + def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None): + """ + For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss. + First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize. + The residual value of each layer is fed to the next layer. + """ + num_quant, quant_dropout_multiple_of, return_loss, device = ( + self.num_quantizers, + self.quantize_dropout_multiple_of, + (indices is not None), + x.device, + ) + + x = self.project_in(x) + + assert not (self.accept_image_fmap and (indices is not None)) + + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + if return_loss: + assert not torch.any(indices == -1), ( + "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss" + ) + ce_losses = [] + + should_quantize_dropout = self.training and self.quantize_dropout and not return_loss + + # sample a layer index at which to dropout further residual quantization + # also prepare null indices and loss + + if should_quantize_dropout: + rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant) + + if quant_dropout_multiple_of != 1: + rand_quantize_dropout_index = ( + ceil((rand_quantize_dropout_index + 1) / quant_dropout_multiple_of) + * quant_dropout_multiple_of + - 1 + ) + + null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2]) + null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long) + null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype) + + # go through the layers + + for quantizer_index, layer in enumerate(self.layers): + if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index: + all_indices.append(null_indices) + all_losses.append(null_loss) + continue + + layer_indices = None + if return_loss: + layer_indices = indices[..., quantizer_index] + + quantized, *rest = layer( + residual, + indices=layer_indices, + sample_codebook_temp=sample_codebook_temp, + freeze_codebook=self.freeze_codebook, + ) + + residual = residual - quantized.detach() + quantized_out = quantized_out + quantized + + if return_loss: + ce_loss = rest[0] + ce_losses.append(ce_loss) + continue + + embed_indices, loss = rest + + all_indices.append(embed_indices) + all_losses.append(loss) + + # project out, if needed + + quantized_out = self.project_out(quantized_out) + + # whether to early return the cross entropy loss + + if return_loss: + return quantized_out, sum(ce_losses) + + # stack all losses and indices + + all_losses, all_indices = map(partial(torch.stack, dim=-1), 
(all_losses, all_indices)) + + ret = (quantized_out, all_indices, all_losses) + + if return_all_codes: + # whether to return all codes from all codebooks across layers + all_codes = self.get_codebook_vector_from_indices(all_indices) + + # will return all codes in shape (quantizer, batch, sequence length, codebook dimension) + ret = (*ret, all_codes) + + return ret + + +class VectorQuantize(nn.Module): + def __init__( + self, + dim, + codebook_size, + codebook_dim=None, + heads=1, + separate_codebook_per_head=False, + decay=0.8, + eps=1e-5, + kmeans_init=False, + kmeans_iters=10, + sync_kmeans=True, + threshold_ema_dead_code=0, + channel_last=True, + accept_image_fmap=False, + commitment_weight=1.0, + commitment_use_cross_entropy_loss=False, + orthogonal_reg_weight=0.0, + orthogonal_reg_active_codes_only=False, + orthogonal_reg_max_codes=None, + stochastic_sample_codes=False, + sample_codebook_temp=1.0, + straight_through=False, + reinmax=False, # using reinmax for improved straight-through, assuming straight through helps at all + sync_codebook=None, + sync_affine_param=False, + ema_update=True, + learnable_codebook=False, + in_place_codebook_optimizer: Callable[ + ..., Optimizer + ] = None, # Optimizer used to update the codebook embedding if using learnable_codebook + affine_param=False, + affine_param_batch_decay=0.99, + affine_param_codebook_decay=0.9, + sync_update_v=0.0, # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf + ): + super().__init__() + self.dim = dim + self.heads = heads + self.separate_codebook_per_head = separate_codebook_per_head + + codebook_dim = codebook_dim if (codebook_dim is not None) else dim + codebook_input_dim = codebook_dim * heads + + requires_projection = codebook_input_dim != dim + self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + + self.eps = eps + self.commitment_weight = commitment_weight + self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss + + self.learnable_codebook = learnable_codebook + + has_codebook_orthogonal_loss = orthogonal_reg_weight > 0 + self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss + self.orthogonal_reg_weight = orthogonal_reg_weight + self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only + self.orthogonal_reg_max_codes = orthogonal_reg_max_codes + + assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update" + + assert 0 <= sync_update_v <= 1.0 + assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on" + + self.sync_update_v = sync_update_v + + gumbel_sample_fn = partial( + gumbel_sample, + stochastic=stochastic_sample_codes, + reinmax=reinmax, + straight_through=straight_through, + ) + + if sync_codebook is None: + sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1 + + codebook_kwargs = { + "dim": codebook_dim, + "num_codebooks": heads if separate_codebook_per_head else 1, + "codebook_size": codebook_size, + "kmeans_init": kmeans_init, + "kmeans_iters": kmeans_iters, + "sync_kmeans": sync_kmeans, + "decay": decay, + "eps": eps, + "threshold_ema_dead_code": threshold_ema_dead_code, + "use_ddp": sync_codebook, + "learnable_codebook": has_codebook_orthogonal_loss 
or learnable_codebook, + "sample_codebook_temp": sample_codebook_temp, + "gumbel_sample": gumbel_sample_fn, + "ema_update": ema_update, + } + + if affine_param: + codebook_kwargs = dict( + **codebook_kwargs, + affine_param=True, + sync_affine_param=sync_affine_param, + affine_param_batch_decay=affine_param_batch_decay, + affine_param_codebook_decay=affine_param_codebook_decay, + ) + + self._codebook = EuclideanCodebook(**codebook_kwargs) + + self.in_place_codebook_optimizer = ( + in_place_codebook_optimizer(self._codebook.parameters()) + if (in_place_codebook_optimizer is not None) + else None + ) + + self.codebook_size = codebook_size + + self.accept_image_fmap = accept_image_fmap + self.channel_last = channel_last + + @property + def codebook(self): + codebook = self._codebook.embed + + if self.separate_codebook_per_head: + return codebook + + return rearrange(codebook, "1 ... -> ...") + + @codebook.setter + def codebook(self, codes): + if not self.separate_codebook_per_head: + codes = rearrange(codes, "... -> 1 ...") + + self._codebook.embed.copy_(codes) + + def get_codebook_vector_from_indices(self, indices): + codebook = self.codebook + is_multiheaded = codebook.ndim > 2 + + if not is_multiheaded: + codes = codebook[indices] + return rearrange(codes, "... h d -> ... (h d)") + + indices, ps = pack_one(indices, "b * h") + indices = rearrange(indices, "b n h -> b h n") + + indices = repeat(indices, "b h n -> b h n d", d=codebook.shape[-1]) + codebook = repeat(codebook, "h n d -> b h n d", b=indices.shape[0]) + + codes = codebook.gather(2, indices) + codes = rearrange(codes, "b h n d -> b n (h d)") + codes = unpack_one(codes, ps, "b * d") + return codes + + def forward( + self, + x, + indices=None, + mask=None, + sample_codebook_temp=None, + freeze_codebook=False, + ): + orig_input = x + + only_one = x.ndim == 2 + + if only_one: + assert mask is None + x = rearrange(x, "b d -> b 1 d") + + shape, device, heads, is_multiheaded, _codebook_size, return_loss = ( + x.shape, + x.device, + self.heads, + self.heads > 1, + self.codebook_size, + (indices is not None), + ) + + need_transpose = not self.channel_last and not self.accept_image_fmap + should_inplace_optimize = self.in_place_codebook_optimizer is not None + + # rearrange inputs + + if self.accept_image_fmap: + height, width = x.shape[-2:] + x = rearrange(x, "b c h w -> b (h w) c") + + if need_transpose: + x = rearrange(x, "b d n -> b n d") + + # project input + + x = self.project_in(x) + + # handle multi-headed separate codebooks + + if is_multiheaded: + ein_rhs_eq = "h b n d" if self.separate_codebook_per_head else "1 (b h) n d" + x = rearrange(x, f"b n (h d) -> {ein_rhs_eq}", h=heads) + + # l2norm for cosine sim, otherwise identity + + x = self._codebook.transform_input(x) + + # codebook forward kwargs + + codebook_forward_kwargs = { + "sample_codebook_temp": sample_codebook_temp, + "mask": mask, + "freeze_codebook": freeze_codebook, + } + + # quantize + + quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs) + + # one step in-place update + + if should_inplace_optimize and self.training and not freeze_codebook: + if mask is not None: + loss = F.mse_loss(quantize, x.detach(), reduction="none") + + loss_mask = mask + if is_multiheaded: + loss_mask = repeat( + mask, + "b n -> c (b h) n", + c=loss.shape[0], + h=loss.shape[1] // mask.shape[0], + ) + + loss = loss[loss_mask].mean() + + else: + loss = F.mse_loss(quantize, x.detach()) + + loss.backward() + self.in_place_codebook_optimizer.step() + 
self.in_place_codebook_optimizer.zero_grad() + + # quantize again + + quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs) + + if self.training: + # determine code to use for commitment loss + maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity + + commit_quantize = maybe_detach(quantize) + + # straight through + + quantize = x + (quantize - x).detach() + + if self.sync_update_v > 0.0: + # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf + quantize = quantize + self.sync_update_v * (quantize - quantize.detach()) + + # function for calculating cross entropy loss to distance matrix + # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss + + def calculate_ce_loss(codes): + if not is_multiheaded: + dist_einops_eq = "1 b n l -> b l n" + elif self.separate_codebook_per_head: + dist_einops_eq = "c b n l -> b l n c" + else: + dist_einops_eq = "1 (b h) n l -> b l n h" + + ce_loss = F.cross_entropy( + rearrange(distances, dist_einops_eq, b=shape[0]), codes, ignore_index=-1 + ) + + return ce_loss + + # if returning cross entropy loss on codes that were passed in + + if return_loss: + return quantize, calculate_ce_loss(indices) + + # transform embedding indices + + if is_multiheaded: + if self.separate_codebook_per_head: + embed_ind = rearrange(embed_ind, "h b n -> b n h", h=heads) + else: + embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads) + + if self.accept_image_fmap: + embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width) + + if only_one: + embed_ind = rearrange(embed_ind, "b 1 -> b") + + # aggregate loss + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + if self.commitment_use_cross_entropy_loss: + if mask is not None: + ce_loss_mask = mask + if is_multiheaded: + ce_loss_mask = repeat(ce_loss_mask, "b n -> b n h", h=heads) + + embed_ind.masked_fill_(~ce_loss_mask, -1) + + commit_loss = calculate_ce_loss(embed_ind) + else: + if mask is not None: + # with variable lengthed sequences + commit_loss = F.mse_loss(commit_quantize, x, reduction="none") + + loss_mask = mask + if is_multiheaded: + loss_mask = repeat( + loss_mask, + "b n -> c (b h) n", + c=commit_loss.shape[0], + h=commit_loss.shape[1] // mask.shape[0], + ) + + commit_loss = commit_loss[loss_mask].mean() + else: + commit_loss = F.mse_loss(commit_quantize, x) + + loss = loss + commit_loss * self.commitment_weight + + if self.has_codebook_orthogonal_loss: + codebook = self._codebook.embed + + # only calculate orthogonal loss for the activated codes for this batch + + if self.orthogonal_reg_active_codes_only: + assert not (is_multiheaded and self.separate_codebook_per_head), ( + "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet" + ) + unique_code_ids = torch.unique(embed_ind) + codebook = codebook[:, unique_code_ids] + + num_codes = codebook.shape[-2] + + if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes: + rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes] + codebook = codebook[:, rand_ids] + + orthogonal_reg_loss = orthogonal_loss_fn(codebook) + loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight + + # handle multi-headed quantized embeddings + + if is_multiheaded: + if self.separate_codebook_per_head: + quantize = 
rearrange(quantize, "h b n d -> b n (h d)", h=heads) + else: + quantize = rearrange(quantize, "1 (b h) n d -> b n (h d)", h=heads) + + # project out + + quantize = self.project_out(quantize) + + # rearrange quantized embeddings + + if need_transpose: + quantize = rearrange(quantize, "b n d -> b d n") + + if self.accept_image_fmap: + quantize = rearrange(quantize, "b (h w) c -> b c h w", h=height, w=width) + + if only_one: + quantize = rearrange(quantize, "b 1 d -> b d") + + # if masking, only return quantized for where mask has True + + if mask is not None: + quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input) + + return quantize, embed_ind, loss + + +def noop(*args, **kwargs): + pass + + +def identity(t): + return t + + +def cdist(x, y): + x2 = reduce(x**2, "b n d -> b n", "sum") + y2 = reduce(y**2, "b n d -> b n", "sum") + xy = einsum("b i d, b j d -> b i j", x, y) * -2 + return (rearrange(x2, "b i -> b i 1") + rearrange(y2, "b j -> b 1 j") + xy).sqrt() + + +def log(t, eps=1e-20): + return torch.log(t.clamp(min=eps)) + + +def ema_inplace(old, new, decay): + is_mps = str(old.device).startswith("mps:") + + if not is_mps: + old.lerp_(new, 1 - decay) + else: + old.mul_(decay).add_(new * (1 - decay)) + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def uniform_init(*shape): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def gumbel_noise(t): + noise = torch.zeros_like(t).uniform_(0, 1) + return -log(-log(noise)) + + +def gumbel_sample( + logits, + temperature=1.0, + stochastic=False, + straight_through=False, + reinmax=False, + dim=-1, + training=True, +): + dtype, size = logits.dtype, logits.shape[dim] + + if training and stochastic and temperature > 0: + sampling_logits = (logits / temperature) + gumbel_noise(logits) + else: + sampling_logits = logits + + ind = sampling_logits.argmax(dim=dim) + one_hot = F.one_hot(ind, size).type(dtype) + + assert not (reinmax and not straight_through), ( + "reinmax can only be turned on if using straight through gumbel softmax" + ) + + if not straight_through or temperature <= 0.0 or not training: + return ind, one_hot + + # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612 + # algorithm 2 + + if reinmax: + π0 = logits.softmax(dim=dim) + π1 = (one_hot + (logits / temperature).softmax(dim=dim)) / 2 + π1 = ((log(π1) - logits).detach() + logits).softmax(dim=1) + π2 = 2 * π1 - 0.5 * π0 + one_hot = π2 - π2.detach() + one_hot + else: + π1 = (logits / temperature).softmax(dim=dim) + one_hot = one_hot + π1 - π1.detach() + + return ind, one_hot + + +def laplace_smoothing(x, n_categories, eps=1e-5, dim=-1): + denom = x.sum(dim=dim, keepdim=True) + return (x + eps) / (denom + n_categories * eps) + + +def sample_vectors(samples, num): + num_samples, device = samples.shape[0], samples.device + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def batched_sample_vectors(samples, num): + return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0) + + +def pad_shape(shape, size, dim=0): + return [size if i == dim else s for i, s in enumerate(shape)] + + +def sample_multinomial(total_count, probs): + device = probs.device + probs = probs.cpu() + + total_count = probs.new_full((), total_count) + remainder = probs.new_ones(()) + sample = 
torch.empty_like(probs, dtype=torch.long) + + for i, p in enumerate(probs): + s = torch.binomial(total_count, p / remainder) + sample[i] = s + total_count -= s + remainder -= p + + return sample.to(device) + + +def all_gather_sizes(x, dim): + size = torch.tensor(x.shape[dim], dtype=torch.long, device=x.device) + all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())] + distributed.all_gather(all_sizes, size) + return torch.stack(all_sizes) + + +def all_gather_variably_sized(x, sizes, dim=0): + rank = distributed.get_rank() + all_x = [] + + for i, size in enumerate(sizes): + t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim)) + distributed.broadcast(t, src=i, async_op=True) + all_x.append(t) + + distributed.barrier() + return all_x + + +def sample_vectors_distributed(local_samples, num): + local_samples = rearrange(local_samples, "1 ... -> ...") + + rank = distributed.get_rank() + all_num_samples = all_gather_sizes(local_samples, dim=0) + + if rank == 0: + samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) + else: + samples_per_rank = torch.empty_like(all_num_samples) + + distributed.broadcast(samples_per_rank, src=0) + samples_per_rank = samples_per_rank.tolist() + + local_samples = sample_vectors(local_samples, samples_per_rank[rank]) + all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim=0) + out = torch.cat(all_samples, dim=0) + + return rearrange(out, "... -> 1 ...") + + +def batched_bincount(x, *, minlength): + batch, dtype, device = x.shape[0], x.dtype, x.device + target = torch.zeros(batch, minlength, dtype=dtype, device=device) + values = torch.ones_like(x) + target.scatter_add_(-1, x, values) + return target + + +def kmeans( + samples, + num_clusters, + num_iters=10, + sample_fn=batched_sample_vectors, + all_reduce_fn=noop, +): + num_codebooks, dim, dtype, _device = ( + samples.shape[0], + samples.shape[-1], + samples.dtype, + samples.device, + ) + + means = sample_fn(samples, num_clusters) + + for _ in range(num_iters): + dists = -torch.cdist(samples, means, p=2) + + buckets = torch.argmax(dists, dim=-1) + bins = batched_bincount(buckets, minlength=num_clusters) + all_reduce_fn(bins) + + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype=dtype) + + new_means.scatter_add_(1, repeat(buckets, "h n -> h n d", d=dim), samples) + new_means = new_means / rearrange(bins_min_clamped, "... -> ... 1") + all_reduce_fn(new_means) + + means = torch.where(rearrange(zero_mask, "... -> ... 
1"), means, new_means) + + return means, bins + + +def batched_embedding(indices, embeds): + batch, dim = indices.shape[1], embeds.shape[-1] + indices = repeat(indices, "h b n -> h b n d", d=dim) + embeds = repeat(embeds, "h c d -> h b c d", b=batch) + return embeds.gather(2, indices) + + +def orthogonal_loss_fn(t): + # eq (2) from https://arxiv.org/abs/2112.00384 + h, n = t.shape[:2] + normed_codes = F.normalize(t, p=2, dim=-1) + cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes) + return (cosine_sim**2).sum() / (h * n**2) - (1 / n) + + +class EuclideanCodebook(nn.Module): + def __init__( + self, + dim, + codebook_size, + num_codebooks=1, + kmeans_init=False, + kmeans_iters=10, + sync_kmeans=True, + decay=0.8, + eps=1e-5, + threshold_ema_dead_code=2, + reset_cluster_size=None, + use_ddp=False, + learnable_codebook=False, + gumbel_sample=gumbel_sample, + sample_codebook_temp=1.0, + ema_update=True, + affine_param=False, + sync_affine_param=False, + affine_param_batch_decay=0.99, + affine_param_codebook_decay=0.9, + ): + super().__init__() + self.transform_input = identity + + self.decay = decay + self.ema_update = ema_update + + init_fn = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(num_codebooks, codebook_size, dim) + + self.codebook_size = codebook_size + self.num_codebooks = num_codebooks + + self.kmeans_iters = kmeans_iters + self.eps = eps + self.threshold_ema_dead_code = threshold_ema_dead_code + self.reset_cluster_size = ( + reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code + ) + + assert callable(gumbel_sample) + self.gumbel_sample = gumbel_sample + self.sample_codebook_temp = sample_codebook_temp + + assert not (use_ddp and num_codebooks > 1 and kmeans_init), ( + "kmeans init is not compatible with multiple codebooks in distributed environment for now" + ) + + self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors + self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop + self.all_reduce_fn = distributed.all_reduce if use_ddp else noop + + self.register_buffer("initted", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(num_codebooks, codebook_size)) + self.register_buffer("embed_avg", embed.clone()) + + self.learnable_codebook = learnable_codebook + if learnable_codebook: + self.embed = nn.Parameter(embed) + else: + self.register_buffer("embed", embed) + + # affine related params + + self.affine_param = affine_param + self.sync_affine_param = sync_affine_param + + if not affine_param: + return + + self.affine_param_batch_decay = affine_param_batch_decay + self.affine_param_codebook_decay = affine_param_codebook_decay + + self.register_buffer("batch_mean", None) + self.register_buffer("batch_variance", None) + + self.register_buffer("codebook_mean_needs_init", torch.Tensor([True])) + self.register_buffer("codebook_mean", torch.empty(num_codebooks, 1, dim)) + self.register_buffer("codebook_variance_needs_init", torch.Tensor([True])) + self.register_buffer("codebook_variance", torch.empty(num_codebooks, 1, dim)) + + @torch.jit.ignore + def init_embed_(self, data, mask=None): + if self.initted: + return + + if mask is not None: + c = data.shape[0] + data = rearrange(data[mask], "(c n) d -> c n d", c=c) + + embed, cluster_size = kmeans( + data, + self.codebook_size, + self.kmeans_iters, + sample_fn=self.sample_fn, + all_reduce_fn=self.kmeans_all_reduce_fn, + ) + + embed_sum = embed * 
rearrange(cluster_size, "... -> ... 1") + + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed_sum) + self.cluster_size.data.copy_(cluster_size) + self.initted.data.copy_(torch.Tensor([True])) + + @torch.jit.ignore + def update_with_decay(self, buffer_name, new_value, decay): + old_value = getattr(self, buffer_name) + + needs_init = getattr(self, buffer_name + "_needs_init", False) + + if needs_init: + self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False])) + + if not (old_value is not None) or needs_init: + self.register_buffer(buffer_name, new_value.detach()) + + return + + value = old_value * decay + new_value.detach() * (1 - decay) + self.register_buffer(buffer_name, value) + + @torch.jit.ignore + def update_affine(self, data, embed, mask=None): + assert self.affine_param + + var_fn = partial(torch.var, unbiased=False) + + # calculate codebook mean and variance + + embed = rearrange(embed, "h ... d -> h (...) d") + + if self.training: + self.update_with_decay( + "codebook_mean", + reduce(embed, "h n d -> h 1 d", "mean"), + self.affine_param_codebook_decay, + ) + self.update_with_decay( + "codebook_variance", + reduce(embed, "h n d -> h 1 d", var_fn), + self.affine_param_codebook_decay, + ) + + # prepare batch data, which depends on whether it has masking + + data = rearrange(data, "h ... d -> h (...) d") + + if mask is not None: + c = data.shape[0] + data = rearrange(data[mask], "(c n) d -> c n d", c=c) + + # calculate batch mean and variance + + if not self.sync_affine_param: + self.update_with_decay( + "batch_mean", + reduce(data, "h n d -> h 1 d", "mean"), + self.affine_param_batch_decay, + ) + self.update_with_decay( + "batch_variance", + reduce(data, "h n d -> h 1 d", var_fn), + self.affine_param_batch_decay, + ) + return + + num_vectors, device, dtype = data.shape[-2], data.device, data.dtype + + # number of vectors, for denominator + + num_vectors = torch.tensor([num_vectors], device=device, dtype=dtype) + distributed.all_reduce(num_vectors) + + # calculate distributed mean + + batch_sum = reduce(data, "h n d -> h 1 d", "sum") + distributed.all_reduce(batch_sum) + batch_mean = batch_sum / num_vectors + + self.update_with_decay("batch_mean", batch_mean, self.affine_param_batch_decay) + + # calculate distributed variance + + variance_number = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum") + distributed.all_reduce(variance_number) + batch_variance = variance_number / num_vectors + + self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay) + + def replace(self, batch_samples, batch_mask): + for ind, (samples, mask) in enumerate( + zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0), strict=False) + ): + if not torch.any(mask): + continue + + sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item()) + sampled = rearrange(sampled, "1 ... -> ...") + + self.embed.data[ind][mask] = sampled + + self.cluster_size.data[ind][mask] = self.reset_cluster_size + self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "h ... d -> h (...) 
d") + self.replace(batch_samples, batch_mask=expired_codes) + + @autocast(enabled=False) + def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False): + needs_codebook_dim = x.ndim < 4 + sample_codebook_temp = ( + sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp + ) + + x = x.float() + + if needs_codebook_dim: + x = rearrange(x, "... -> 1 ...") + + flatten, ps = pack_one(x, "h * d") + + if mask is not None: + mask = repeat( + mask, + "b n -> c (b h n)", + c=flatten.shape[0], + h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]), + ) + + self.init_embed_(flatten, mask=mask) + + if self.affine_param: + self.update_affine(flatten, self.embed, mask=mask) + + embed = self.embed if self.learnable_codebook else self.embed.detach() + + if self.affine_param: + codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt() + batch_std = self.batch_variance.clamp(min=1e-5).sqrt() + embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean + + dist = -cdist(flatten, embed) + + embed_ind, embed_onehot = self.gumbel_sample( + dist, dim=-1, temperature=sample_codebook_temp, training=self.training + ) + + embed_ind = unpack_one(embed_ind, ps, "h *") + + if self.training: + unpacked_onehot = unpack_one(embed_onehot, ps, "h * c") + quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed) + else: + quantize = batched_embedding(embed_ind, embed) + + if self.training and self.ema_update and not freeze_codebook: + if self.affine_param: + flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean + + if mask is not None: + embed_onehot[~mask] = 0.0 + + cluster_size = embed_onehot.sum(dim=1) + + self.all_reduce_fn(cluster_size) + ema_inplace(self.cluster_size.data, cluster_size, self.decay) + + embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot) + self.all_reduce_fn(embed_sum.contiguous()) + ema_inplace(self.embed_avg.data, embed_sum, self.decay) + + cluster_size = laplace_smoothing( + self.cluster_size, self.codebook_size, self.eps + ) * self.cluster_size.sum(dim=-1, keepdim=True) + + embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1") + self.embed.data.copy_(embed_normalized) + self.expire_codes_(x) + + if needs_codebook_dim: + quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)) + + dist = unpack_one(dist, ps, "h * d") + + return quantize, embed_ind, dist diff --git a/lerobot/common/robot_devices/cameras/configs.py b/lerobot/common/robot_devices/cameras/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..013419a9e770daa5db577c506249870f9a30b41c --- /dev/null +++ b/lerobot/common/robot_devices/cameras/configs.py @@ -0,0 +1,114 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import abc +from dataclasses import dataclass + +import draccus + + +@dataclass +class CameraConfig(draccus.ChoiceRegistry, abc.ABC): + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + +@CameraConfig.register_subclass("opencv") +@dataclass +class OpenCVCameraConfig(CameraConfig): + """ + Example of tested options for Intel Real Sense D405: + + ```python + OpenCVCameraConfig(0, 30, 640, 480) + OpenCVCameraConfig(0, 60, 640, 480) + OpenCVCameraConfig(0, 90, 640, 480) + OpenCVCameraConfig(0, 30, 1280, 720) + ``` + """ + + camera_index: int + fps: int | None = None + width: int | None = None + height: int | None = None + color_mode: str = "rgb" + channels: int | None = None + rotation: int | None = None + mock: bool = False + + def __post_init__(self): + if self.color_mode not in ["rgb", "bgr"]: + raise ValueError( + f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." + ) + + self.channels = 3 + + if self.rotation not in [-90, None, 90, 180]: + raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") + + +@CameraConfig.register_subclass("intelrealsense") +@dataclass +class IntelRealSenseCameraConfig(CameraConfig): + """ + Example of tested options for Intel Real Sense D405: + + ```python + IntelRealSenseCameraConfig(128422271347, 30, 640, 480) + IntelRealSenseCameraConfig(128422271347, 60, 640, 480) + IntelRealSenseCameraConfig(128422271347, 90, 640, 480) + IntelRealSenseCameraConfig(128422271347, 30, 1280, 720) + IntelRealSenseCameraConfig(128422271347, 30, 640, 480, use_depth=True) + IntelRealSenseCameraConfig(128422271347, 30, 640, 480, rotation=90) + ``` + """ + + name: str | None = None + serial_number: int | None = None + fps: int | None = None + width: int | None = None + height: int | None = None + color_mode: str = "rgb" + channels: int | None = None + use_depth: bool = False + force_hardware_reset: bool = True + rotation: int | None = None + mock: bool = False + + def __post_init__(self): + # bool is stronger than is None, since it works with empty strings + if bool(self.name) and bool(self.serial_number): + raise ValueError( + f"One of them must be set: name or serial_number, but {self.name=} and {self.serial_number=} provided." + ) + + if self.color_mode not in ["rgb", "bgr"]: + raise ValueError( + f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." + ) + + self.channels = 3 + + at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None + at_least_one_is_none = self.fps is None or self.width is None or self.height is None + if at_least_one_is_not_none and at_least_one_is_none: + raise ValueError( + "For `fps`, `width` and `height`, either all of them need to be set, or none of them, " + f"but {self.fps=}, {self.width=}, {self.height=} were provided." + ) + + if self.rotation not in [-90, None, 90, 180]: + raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") diff --git a/lerobot/common/robot_devices/cameras/intelrealsense.py b/lerobot/common/robot_devices/cameras/intelrealsense.py new file mode 100644 index 0000000000000000000000000000000000000000..607cef7852f01762f7bdd1d54ff1ef57159e1fe9 --- /dev/null +++ b/lerobot/common/robot_devices/cameras/intelrealsense.py @@ -0,0 +1,538 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file contains utilities for recording frames from Intel Realsense cameras. +""" + +import argparse +import concurrent.futures +import logging +import math +import shutil +import threading +import time +import traceback +from collections import Counter +from pathlib import Path +from threading import Thread + +import numpy as np +from PIL import Image + +from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig +from lerobot.common.robot_devices.utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceNotConnectedError, + busy_wait, +) +from lerobot.common.utils.utils import capture_timestamp_utc + +SERIAL_NUMBER_INDEX = 1 + + +def find_cameras(raise_when_empty=True, mock=False) -> list[dict]: + """ + Find the names and the serial numbers of the Intel RealSense cameras + connected to the computer. + """ + if mock: + import lerobot.common.mocks.cameras.mock_pyrealsense2 as rs + else: + import pyrealsense2 as rs + + cameras = [] + for device in rs.context().query_devices(): + serial_number = int(device.get_info(rs.camera_info(SERIAL_NUMBER_INDEX))) + name = device.get_info(rs.camera_info.name) + cameras.append( + { + "serial_number": serial_number, + "name": name, + } + ) + + if raise_when_empty and len(cameras) == 0: + raise OSError( + "Not a single camera was detected. Try re-plugging, or re-installing `librealsense` and its python wrapper `pyrealsense2`, or updating the firmware." + ) + + return cameras + + +def save_image(img_array, serial_number, frame_index, images_dir): + try: + img = Image.fromarray(img_array) + path = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png" + path.parent.mkdir(parents=True, exist_ok=True) + img.save(str(path), quality=100) + logging.info(f"Saved image: {path}") + except Exception as e: + logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}") + + +def save_images_from_cameras( + images_dir: Path, + serial_numbers: list[int] | None = None, + fps=None, + width=None, + height=None, + record_time_s=2, + mock=False, +): + """ + Initializes all the cameras and saves images to the directory. Useful to visually identify the camera + associated to a given serial number. 
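+
+    Example of usage:
+    ```bash
+    python lerobot/common/robot_devices/cameras/intelrealsense.py \
+        --images-dir outputs/images_from_intelrealsense_cameras
+    ```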
+ """ + if serial_numbers is None or len(serial_numbers) == 0: + camera_infos = find_cameras(mock=mock) + serial_numbers = [cam["serial_number"] for cam in camera_infos] + + if mock: + import lerobot.common.mocks.cameras.mock_cv2 as cv2 + else: + import cv2 + + print("Connecting cameras") + cameras = [] + for cam_sn in serial_numbers: + print(f"{cam_sn=}") + config = IntelRealSenseCameraConfig( + serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock + ) + camera = IntelRealSenseCamera(config) + camera.connect() + print( + f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})" + ) + cameras.append(camera) + + images_dir = Path(images_dir) + if images_dir.exists(): + shutil.rmtree( + images_dir, + ) + images_dir.mkdir(parents=True, exist_ok=True) + + print(f"Saving images to {images_dir}") + frame_index = 0 + start_time = time.perf_counter() + try: + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + while True: + now = time.perf_counter() + + for camera in cameras: + # If we use async_read when fps is None, the loop will go full speed, and we will end up + # saving the same images from the cameras multiple times until the RAM/disk is full. + image = camera.read() if fps is None else camera.async_read() + if image is None: + print("No Frame") + + bgr_converted_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + + executor.submit( + save_image, + bgr_converted_image, + camera.serial_number, + frame_index, + images_dir, + ) + + if fps is not None: + dt_s = time.perf_counter() - now + busy_wait(1 / fps - dt_s) + + if time.perf_counter() - start_time > record_time_s: + break + + print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") + + frame_index += 1 + finally: + print(f"Images have been saved to {images_dir}") + for camera in cameras: + camera.disconnect() + + +class IntelRealSenseCamera: + """ + The IntelRealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras: + - is instantiated with the serial number of the camera - won't randomly change as it can be the case of OpenCVCamera for Linux, + - can also be instantiated with the camera's name — if it's unique — using IntelRealSenseCamera.init_from_name(), + - depth map can be returned. + + To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera: + ```bash + python lerobot/common/robot_devices/cameras/intelrealsense.py --images-dir outputs/images_from_intelrealsense_cameras + ``` + + When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode + of the given camera will be used. 
+ + Example of instantiating with a serial number: + ```python + from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig + + config = IntelRealSenseCameraConfig(serial_number=128422271347) + camera = IntelRealSenseCamera(config) + camera.connect() + color_image = camera.read() + # when done using the camera, consider disconnecting + camera.disconnect() + ``` + + Example of instantiating with a name if it's unique: + ``` + config = IntelRealSenseCameraConfig(name="Intel RealSense D405") + ``` + + Example of changing default fps, width, height and color_mode: + ```python + config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720) + config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480) + config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr") + # Note: might error out upon `camera.connect()` if these settings are not compatible with the camera + ``` + + Example of returning depth: + ```python + config = IntelRealSenseCameraConfig(serial_number=128422271347, use_depth=True) + camera = IntelRealSenseCamera(config) + camera.connect() + color_image, depth_map = camera.read() + ``` + """ + + def __init__( + self, + config: IntelRealSenseCameraConfig, + ): + self.config = config + if config.name is not None: + self.serial_number = self.find_serial_number_from_name(config.name) + else: + self.serial_number = config.serial_number + + # Store the raw (capture) resolution from the config. + self.capture_width = config.width + self.capture_height = config.height + + # If rotated by ±90, swap width and height. + if config.rotation in [-90, 90]: + self.width = config.height + self.height = config.width + else: + self.width = config.width + self.height = config.height + + self.fps = config.fps + self.channels = config.channels + self.color_mode = config.color_mode + self.use_depth = config.use_depth + self.force_hardware_reset = config.force_hardware_reset + self.mock = config.mock + + self.camera = None + self.is_connected = False + self.thread = None + self.stop_event = None + self.color_image = None + self.depth_map = None + self.logs = {} + + if self.mock: + import lerobot.common.mocks.cameras.mock_cv2 as cv2 + else: + import cv2 + + self.rotation = None + if config.rotation == -90: + self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE + elif config.rotation == 90: + self.rotation = cv2.ROTATE_90_CLOCKWISE + elif config.rotation == 180: + self.rotation = cv2.ROTATE_180 + + def find_serial_number_from_name(self, name): + camera_infos = find_cameras() + camera_names = [cam["name"] for cam in camera_infos] + this_name_count = Counter(camera_names)[name] + if this_name_count > 1: + # TODO(aliberts): Test this with multiple identical cameras (Aloha) + raise ValueError( + f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them." + ) + + name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos} + cam_sn = name_to_serial_dict[name] + + return cam_sn + + def connect(self): + if self.is_connected: + raise RobotDeviceAlreadyConnectedError( + f"IntelRealSenseCamera({self.serial_number}) is already connected." 
+ ) + + if self.mock: + import lerobot.common.mocks.cameras.mock_pyrealsense2 as rs + else: + import pyrealsense2 as rs + + config = rs.config() + config.enable_device(str(self.serial_number)) + + if self.fps and self.capture_width and self.capture_height: + # TODO(rcadene): can we set rgb8 directly? + config.enable_stream( + rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps + ) + else: + config.enable_stream(rs.stream.color) + + if self.use_depth: + if self.fps and self.capture_width and self.capture_height: + config.enable_stream( + rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps + ) + else: + config.enable_stream(rs.stream.depth) + + self.camera = rs.pipeline() + try: + profile = self.camera.start(config) + is_camera_open = True + except RuntimeError: + is_camera_open = False + traceback.print_exc() + + # If the camera doesn't work, display the camera indices corresponding to + # valid cameras. + if not is_camera_open: + # Verify that the provided `serial_number` is valid before printing the traceback + camera_infos = find_cameras() + serial_numbers = [cam["serial_number"] for cam in camera_infos] + if self.serial_number not in serial_numbers: + raise ValueError( + f"`serial_number` is expected to be one of these available cameras {serial_numbers}, but {self.serial_number} is provided instead. " + "To find the serial number you should use, run `python lerobot/common/robot_devices/cameras/intelrealsense.py`." + ) + + raise OSError(f"Can't access IntelRealSenseCamera({self.serial_number}).") + + color_stream = profile.get_stream(rs.stream.color) + color_profile = color_stream.as_video_stream_profile() + actual_fps = color_profile.fps() + actual_width = color_profile.width() + actual_height = color_profile.height() + + # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30) + if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): + # Using `OSError` since it's a broad that encompasses issues related to device communication + raise OSError( + f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}." + ) + if self.capture_width is not None and self.capture_width != actual_width: + raise OSError( + f"Can't set {self.capture_width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}." + ) + if self.capture_height is not None and self.capture_height != actual_height: + raise OSError( + f"Can't set {self.capture_height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}." + ) + + self.fps = round(actual_fps) + self.capture_width = round(actual_width) + self.capture_height = round(actual_height) + + self.is_connected = True + + def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]: + """Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3) + of type `np.uint8`, contrarily to the pytorch format which is float channel first. + + When `use_depth=True`, returns a tuple `(color_image, depth_map)` with a depth map in the format + height x width (e.g. 480 x 640) of type np.uint16. + + Note: Reading a frame is done every `camera.fps` times per second, and it is blocking. + If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`. 
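+
+        For example, to get a single frame in BGR without changing the configured color mode
+        (a sketch, assuming the camera is connected and `use_depth=False`):
+        ```python
+        bgr_image = camera.read(temporary_color="bgr")
+        ```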
+ """ + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first." + ) + + if self.mock: + import lerobot.common.mocks.cameras.mock_cv2 as cv2 + else: + import cv2 + + start_time = time.perf_counter() + + frame = self.camera.wait_for_frames(timeout_ms=5000) + + color_frame = frame.get_color_frame() + + if not color_frame: + raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).") + + color_image = np.asanyarray(color_frame.get_data()) + + requested_color_mode = self.color_mode if temporary_color is None else temporary_color + if requested_color_mode not in ["rgb", "bgr"]: + raise ValueError( + f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided." + ) + + # IntelRealSense uses RGB format as default (red, green, blue). + if requested_color_mode == "bgr": + color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR) + + h, w, _ = color_image.shape + if h != self.capture_height or w != self.capture_width: + raise OSError( + f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead." + ) + + if self.rotation is not None: + color_image = cv2.rotate(color_image, self.rotation) + + # log the number of seconds it took to read the image + self.logs["delta_timestamp_s"] = time.perf_counter() - start_time + + # log the utc time at which the image was received + self.logs["timestamp_utc"] = capture_timestamp_utc() + + if self.use_depth: + depth_frame = frame.get_depth_frame() + if not depth_frame: + raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).") + + depth_map = np.asanyarray(depth_frame.get_data()) + + h, w = depth_map.shape + if h != self.capture_height or w != self.capture_width: + raise OSError( + f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead." + ) + + if self.rotation is not None: + depth_map = cv2.rotate(depth_map, self.rotation) + + return color_image, depth_map + else: + return color_image + + def read_loop(self): + while not self.stop_event.is_set(): + if self.use_depth: + self.color_image, self.depth_map = self.read() + else: + self.color_image = self.read() + + def async_read(self): + """Access the latest color image""" + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first." + ) + + if self.thread is None: + self.stop_event = threading.Event() + self.thread = Thread(target=self.read_loop, args=()) + self.thread.daemon = True + self.thread.start() + + num_tries = 0 + while self.color_image is None: + # TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here + num_tries += 1 + time.sleep(1 / self.fps) + if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()): + raise Exception( + "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called." + ) + + if self.use_depth: + return self.color_image, self.depth_map + else: + return self.color_image + + def disconnect(self): + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first." 
+ ) + + if self.thread is not None and self.thread.is_alive(): + # wait for the thread to finish + self.stop_event.set() + self.thread.join() + self.thread = None + self.stop_event = None + + self.camera.stop() + self.camera = None + + self.is_connected = False + + def __del__(self): + if getattr(self, "is_connected", False): + self.disconnect() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Save a few frames using `IntelRealSenseCamera` for all cameras connected to the computer, or a selected subset." + ) + parser.add_argument( + "--serial-numbers", + type=int, + nargs="*", + default=None, + help="List of serial numbers used to instantiate the `IntelRealSenseCamera`. If not provided, find and use all available camera indices.", + ) + parser.add_argument( + "--fps", + type=int, + default=30, + help="Set the number of frames recorded per seconds for all cameras. If not provided, use the default fps of each camera.", + ) + parser.add_argument( + "--width", + type=str, + default=640, + help="Set the width for all cameras. If not provided, use the default width of each camera.", + ) + parser.add_argument( + "--height", + type=str, + default=480, + help="Set the height for all cameras. If not provided, use the default height of each camera.", + ) + parser.add_argument( + "--images-dir", + type=Path, + default="outputs/images_from_intelrealsense_cameras", + help="Set directory to save a few frames for each camera.", + ) + parser.add_argument( + "--record-time-s", + type=float, + default=2.0, + help="Set the number of seconds used to record the frames. By default, 2 seconds.", + ) + args = parser.parse_args() + save_images_from_cameras(**vars(args)) diff --git a/lerobot/common/robot_devices/cameras/opencv.py b/lerobot/common/robot_devices/cameras/opencv.py new file mode 100644 index 0000000000000000000000000000000000000000..014841c71bda36e39335be53df4fbaf9bdbf8795 --- /dev/null +++ b/lerobot/common/robot_devices/cameras/opencv.py @@ -0,0 +1,518 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring. +""" + +import argparse +import concurrent.futures +import math +import platform +import shutil +import threading +import time +from pathlib import Path +from threading import Thread + +import numpy as np +from PIL import Image + +from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig +from lerobot.common.robot_devices.utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceNotConnectedError, + busy_wait, +) +from lerobot.common.utils.utils import capture_timestamp_utc + +# The maximum opencv device index depends on your operating system. For instance, +# if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case +# on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23. 
+# When you change the USB port or reboot the computer, the operating system might +# treat the same cameras as new devices. Thus we select a higher bound to search indices. +MAX_OPENCV_INDEX = 60 + + +def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]: + cameras = [] + if platform.system() == "Linux": + print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports") + possible_ports = [str(port) for port in Path("/dev").glob("video*")] + ports = _find_cameras(possible_ports, mock=mock) + for port in ports: + cameras.append( + { + "port": port, + "index": int(port.removeprefix("/dev/video")), + } + ) + else: + print( + "Mac or Windows detected. Finding available camera indices through " + f"scanning all indices from 0 to {MAX_OPENCV_INDEX}" + ) + possible_indices = range(max_index_search_range) + indices = _find_cameras(possible_indices, mock=mock) + for index in indices: + cameras.append( + { + "port": None, + "index": index, + } + ) + + return cameras + + +def _find_cameras( + possible_camera_ids: list[int | str], raise_when_empty=False, mock=False +) -> list[int | str]: + if mock: + import lerobot.common.mocks.cameras.mock_cv2 as cv2 + else: + import cv2 + + camera_ids = [] + for camera_idx in possible_camera_ids: + camera = cv2.VideoCapture(camera_idx) + is_open = camera.isOpened() + camera.release() + + if is_open: + print(f"Camera found at index {camera_idx}") + camera_ids.append(camera_idx) + + if raise_when_empty and len(camera_ids) == 0: + raise OSError( + "Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, " + "or your camera driver, or make sure your camera is compatible with opencv2." + ) + + return camera_ids + + +def is_valid_unix_path(path: str) -> bool: + """Note: if 'path' points to a symlink, this will return True only if the target exists""" + p = Path(path) + return p.is_absolute() and p.exists() + + +def get_camera_index_from_unix_port(port: Path) -> int: + return int(str(port.resolve()).removeprefix("/dev/video")) + + +def save_image(img_array, camera_index, frame_index, images_dir): + img = Image.fromarray(img_array) + path = images_dir / f"camera_{camera_index:02d}_frame_{frame_index:06d}.png" + path.parent.mkdir(parents=True, exist_ok=True) + img.save(str(path), quality=100) + + +def save_images_from_cameras( + images_dir: Path, + camera_ids: list | None = None, + fps=None, + width=None, + height=None, + record_time_s=2, + mock=False, +): + """ + Initializes all the cameras and saves images to the directory. Useful to visually identify the camera + associated to a given camera index. 
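+
+    Example (a minimal sketch; passing `camera_ids=None` detects and uses all available cameras):
+    ```python
+    save_images_from_cameras(
+        images_dir=Path("outputs/images_from_opencv_cameras"),
+        camera_ids=None,
+        record_time_s=2,
+    )
+    ```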
+    """
+    if camera_ids is None or len(camera_ids) == 0:
+        camera_infos = find_cameras(mock=mock)
+        camera_ids = [cam["index"] for cam in camera_infos]
+
+    print("Connecting cameras")
+    cameras = []
+    for cam_idx in camera_ids:
+        config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
+        camera = OpenCVCamera(config)
+        camera.connect()
+        print(
+            f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.capture_width}, "
+            f"height={camera.capture_height}, color_mode={camera.color_mode})"
+        )
+        cameras.append(camera)
+
+    images_dir = Path(images_dir)
+    if images_dir.exists():
+        shutil.rmtree(
+            images_dir,
+        )
+    images_dir.mkdir(parents=True, exist_ok=True)
+
+    print(f"Saving images to {images_dir}")
+    frame_index = 0
+    start_time = time.perf_counter()
+    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
+        while True:
+            now = time.perf_counter()
+
+            for camera in cameras:
+                # If we use async_read when fps is None, the loop will go full speed, and we will end up
+                # saving the same images from the cameras multiple times until the RAM/disk is full.
+                image = camera.read() if fps is None else camera.async_read()
+
+                executor.submit(
+                    save_image,
+                    image,
+                    camera.camera_index,
+                    frame_index,
+                    images_dir,
+                )
+
+            if fps is not None:
+                dt_s = time.perf_counter() - now
+                busy_wait(1 / fps - dt_s)
+
+            print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
+
+            if time.perf_counter() - start_time > record_time_s:
+                break
+
+            frame_index += 1
+
+    print(f"Images have been saved to {images_dir}")
+
+
+class OpenCVCamera:
+    """
+    The OpenCVCamera class allows you to efficiently record images from cameras. It relies on opencv2 to communicate
+    with the cameras. Most cameras are compatible. For more info, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
+
+    An OpenCVCamera instance requires a camera index (e.g. `OpenCVCamera(camera_index=0)`). When you only have one camera
+    like a laptop webcam, the camera index is expected to be 0, but it might also be very different, and the camera index
+    might change if you reboot your computer or re-plug your camera. This behavior depends on your operating system.
+
+    To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera:
+    ```bash
+    python lerobot/common/robot_devices/cameras/opencv.py --images-dir outputs/images_from_opencv_cameras
+    ```
+
+    When an OpenCVCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
+    of the given camera will be used.
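+
+    Example of listing the available camera indices programmatically, using the `find_cameras()`
+    helper defined in this file:
+    ```python
+    for cam in find_cameras():
+        print(cam["index"], cam["port"])
+    ```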
+ + Example of usage: + ```python + from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig + + config = OpenCVCameraConfig(camera_index=0) + camera = OpenCVCamera(config) + camera.connect() + color_image = camera.read() + # when done using the camera, consider disconnecting + camera.disconnect() + ``` + + Example of changing default fps, width, height and color_mode: + ```python + config = OpenCVCameraConfig(camera_index=0, fps=30, width=1280, height=720) + config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480) + config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480, color_mode="bgr") + # Note: might error out open `camera.connect()` if these settings are not compatible with the camera + ``` + """ + + def __init__(self, config: OpenCVCameraConfig): + self.config = config + self.camera_index = config.camera_index + self.port = None + + # Linux uses ports for connecting to cameras + if platform.system() == "Linux": + if isinstance(self.camera_index, int): + self.port = Path(f"/dev/video{self.camera_index}") + elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index): + self.port = Path(self.camera_index) + # Retrieve the camera index from a potentially symlinked path + self.camera_index = get_camera_index_from_unix_port(self.port) + else: + raise ValueError(f"Please check the provided camera_index: {self.camera_index}") + + # Store the raw (capture) resolution from the config. + self.capture_width = config.width + self.capture_height = config.height + + # If rotated by ±90, swap width and height. + if config.rotation in [-90, 90]: + self.width = config.height + self.height = config.width + else: + self.width = config.width + self.height = config.height + + self.fps = config.fps + self.channels = config.channels + self.color_mode = config.color_mode + self.mock = config.mock + + self.camera = None + self.is_connected = False + self.thread = None + self.stop_event = None + self.color_image = None + self.logs = {} + + if self.mock: + import lerobot.common.mocks.cameras.mock_cv2 as cv2 + else: + import cv2 + + self.rotation = None + if config.rotation == -90: + self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE + elif config.rotation == 90: + self.rotation = cv2.ROTATE_90_CLOCKWISE + elif config.rotation == 180: + self.rotation = cv2.ROTATE_180 + + def connect(self): + if self.is_connected: + raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.") + + if self.mock: + import lerobot.common.mocks.cameras.mock_cv2 as cv2 + else: + import cv2 + + # Use 1 thread to avoid blocking the main thread. Especially useful during data collection + # when other threads are used to save the images. + cv2.setNumThreads(1) + + backend = ( + cv2.CAP_V4L2 + if platform.system() == "Linux" + else cv2.CAP_DSHOW + if platform.system() == "Windows" + else cv2.CAP_AVFOUNDATION + if platform.system() == "Darwin" + else cv2.CAP_ANY + ) + + camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index + # First create a temporary camera trying to access `camera_index`, + # and verify it is a valid camera by calling `isOpened`. + tmp_camera = cv2.VideoCapture(camera_idx, backend) + is_camera_open = tmp_camera.isOpened() + # Release camera to make it accessible for `find_camera_indices` + tmp_camera.release() + del tmp_camera + + # If the camera doesn't work, display the camera indices corresponding to + # valid cameras. 
+ if not is_camera_open: + # Verify that the provided `camera_index` is valid before printing the traceback + cameras_info = find_cameras() + available_cam_ids = [cam["index"] for cam in cameras_info] + if self.camera_index not in available_cam_ids: + raise ValueError( + f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead. " + "To find the camera index you should use, run `python lerobot/common/robot_devices/cameras/opencv.py`." + ) + + raise OSError(f"Can't access OpenCVCamera({camera_idx}).") + + # Secondly, create the camera that will be used downstream. + # Note: For some unknown reason, calling `isOpened` blocks the camera which then + # needs to be re-created. + self.camera = cv2.VideoCapture(camera_idx, backend) + + if self.fps is not None: + self.camera.set(cv2.CAP_PROP_FPS, self.fps) + if self.capture_width is not None: + self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.capture_width) + if self.capture_height is not None: + self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.capture_height) + + actual_fps = self.camera.get(cv2.CAP_PROP_FPS) + actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH) + actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT) + + # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30) + if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): + # Using `OSError` since it's a broad that encompasses issues related to device communication + raise OSError( + f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}." + ) + if self.capture_width is not None and not math.isclose( + self.capture_width, actual_width, rel_tol=1e-3 + ): + raise OSError( + f"Can't set {self.capture_width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}." + ) + if self.capture_height is not None and not math.isclose( + self.capture_height, actual_height, rel_tol=1e-3 + ): + raise OSError( + f"Can't set {self.capture_height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}." + ) + + self.fps = round(actual_fps) + self.capture_width = round(actual_width) + self.capture_height = round(actual_height) + self.is_connected = True + + def read(self, temporary_color_mode: str | None = None) -> np.ndarray: + """Read a frame from the camera returned in the format (height, width, channels) + (e.g. 480 x 640 x 3), contrarily to the pytorch format which is channel first. + + Note: Reading a frame is done every `camera.fps` times per second, and it is blocking. + If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`. + """ + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first." + ) + + start_time = time.perf_counter() + + ret, color_image = self.camera.read() + + if not ret: + raise OSError(f"Can't capture color image from camera {self.camera_index}.") + + requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode + + if requested_color_mode not in ["rgb", "bgr"]: + raise ValueError( + f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided." + ) + + # OpenCV uses BGR format as default (blue, green, red) for all operations, including displaying images. 
+ # However, Deep Learning framework such as LeRobot uses RGB format as default to train neural networks, + # so we convert the image color from BGR to RGB. + if requested_color_mode == "rgb": + if self.mock: + import lerobot.common.mocks.cameras.mock_cv2 as cv2 + else: + import cv2 + + color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB) + + h, w, _ = color_image.shape + if h != self.capture_height or w != self.capture_width: + raise OSError( + f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead." + ) + + if self.rotation is not None: + color_image = cv2.rotate(color_image, self.rotation) + + # log the number of seconds it took to read the image + self.logs["delta_timestamp_s"] = time.perf_counter() - start_time + + # log the utc time at which the image was received + self.logs["timestamp_utc"] = capture_timestamp_utc() + + self.color_image = color_image + + return color_image + + def read_loop(self): + while not self.stop_event.is_set(): + try: + self.color_image = self.read() + except Exception as e: + print(f"Error reading in thread: {e}") + + def async_read(self): + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first." + ) + + if self.thread is None: + self.stop_event = threading.Event() + self.thread = Thread(target=self.read_loop, args=()) + self.thread.daemon = True + self.thread.start() + + num_tries = 0 + while True: + if self.color_image is not None: + return self.color_image + + time.sleep(1 / self.fps) + num_tries += 1 + if num_tries > self.fps * 2: + raise TimeoutError("Timed out waiting for async_read() to start.") + + def disconnect(self): + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first." + ) + + if self.thread is not None: + self.stop_event.set() + self.thread.join() # wait for the thread to finish + self.thread = None + self.stop_event = None + + self.camera.release() + self.camera = None + self.is_connected = False + + def __del__(self): + if getattr(self, "is_connected", False): + self.disconnect() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Save a few frames using `OpenCVCamera` for all cameras connected to the computer, or a selected subset." + ) + parser.add_argument( + "--camera-ids", + type=int, + nargs="*", + default=None, + help="List of camera indices used to instantiate the `OpenCVCamera`. If not provided, find and use all available camera indices.", + ) + parser.add_argument( + "--fps", + type=int, + default=None, + help="Set the number of frames recorded per seconds for all cameras. If not provided, use the default fps of each camera.", + ) + parser.add_argument( + "--width", + type=str, + default=None, + help="Set the width for all cameras. If not provided, use the default width of each camera.", + ) + parser.add_argument( + "--height", + type=str, + default=None, + help="Set the height for all cameras. If not provided, use the default height of each camera.", + ) + parser.add_argument( + "--images-dir", + type=Path, + default="outputs/images_from_opencv_cameras", + help="Set directory to save a few frames for each camera.", + ) + parser.add_argument( + "--record-time-s", + type=float, + default=4.0, + help="Set the number of seconds used to record the frames. 
By default, 2 seconds.", + ) + args = parser.parse_args() + save_images_from_cameras(**vars(args)) diff --git a/lerobot/common/robot_devices/cameras/utils.py b/lerobot/common/robot_devices/cameras/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c64316467ca71379718550a419b2574a0a0b3168 --- /dev/null +++ b/lerobot/common/robot_devices/cameras/utils.py @@ -0,0 +1,67 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +import numpy as np + +from lerobot.common.robot_devices.cameras.configs import ( + CameraConfig, + IntelRealSenseCameraConfig, + OpenCVCameraConfig, +) + + +# Defines a camera type +class Camera(Protocol): + def connect(self): ... + def read(self, temporary_color: str | None = None) -> np.ndarray: ... + def async_read(self) -> np.ndarray: ... + def disconnect(self): ... + + +def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[Camera]: + cameras = {} + + for key, cfg in camera_configs.items(): + if cfg.type == "opencv": + from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera + + cameras[key] = OpenCVCamera(cfg) + + elif cfg.type == "intelrealsense": + from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera + + cameras[key] = IntelRealSenseCamera(cfg) + else: + raise ValueError(f"The camera type '{cfg.type}' is not valid.") + + return cameras + + +def make_camera(camera_type, **kwargs) -> Camera: + if camera_type == "opencv": + from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera + + config = OpenCVCameraConfig(**kwargs) + return OpenCVCamera(config) + + elif camera_type == "intelrealsense": + from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera + + config = IntelRealSenseCameraConfig(**kwargs) + return IntelRealSenseCamera(config) + + else: + raise ValueError(f"The camera type '{camera_type}' is not valid.") diff --git a/lerobot/common/robot_devices/control_configs.py b/lerobot/common/robot_devices/control_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..cb558c7167ebabfa78a6849f0a71b9f5d0bcbaa2 --- /dev/null +++ b/lerobot/common/robot_devices/control_configs.py @@ -0,0 +1,134 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
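+
+"""
+Draccus-based configuration dataclasses for the robot control modes (calibrate, teleoperate, record,
+replay, remote_robot). These configs are normally populated from the command line via draccus, but for
+illustration they can also be constructed directly (a sketch; the values below are placeholders):
+```python
+teleop_cfg = TeleoperateControlConfig(fps=30, display_data=True)
+record_cfg = RecordControlConfig(
+    repo_id="lerobot/test",
+    single_task="Pick the Lego block and drop it in the box on the right.",
+    num_episodes=2,
+    episode_time_s=30,
+)
+```
+"""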
+ +from dataclasses import dataclass +from pathlib import Path + +import draccus + +from lerobot.common.robot_devices.robots.configs import RobotConfig +from lerobot.configs import parser +from lerobot.configs.policies import PreTrainedConfig + + +@dataclass +class ControlConfig(draccus.ChoiceRegistry): + pass + + +@ControlConfig.register_subclass("calibrate") +@dataclass +class CalibrateControlConfig(ControlConfig): + # List of arms to calibrate (e.g. `--arms='["left_follower","right_follower"]' left_leader`) + arms: list[str] | None = None + + +@ControlConfig.register_subclass("teleoperate") +@dataclass +class TeleoperateControlConfig(ControlConfig): + # Limit the maximum frames per second. By default, no limit. + fps: int | None = None + teleop_time_s: float | None = None + # Display all cameras on screen + display_data: bool = False + + +@ControlConfig.register_subclass("record") +@dataclass +class RecordControlConfig(ControlConfig): + # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). + repo_id: str + # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.") + single_task: str + # Root directory where the dataset will be stored (e.g. 'dataset/path'). + root: str | Path | None = None + policy: PreTrainedConfig | None = None + # Limit the frames per second. By default, uses the policy fps. + fps: int | None = None + # Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize. + warmup_time_s: int | float = 10 + # Number of seconds for data recording for each episode. + episode_time_s: int | float = 60 + # Number of seconds for resetting the environment after each episode. + reset_time_s: int | float = 60 + # Number of episodes to record. + num_episodes: int = 50 + # Encode frames in the dataset into video + video: bool = True + # Upload dataset to Hugging Face hub. + push_to_hub: bool = True + # Upload on private repository on the Hugging Face hub. + private: bool = False + # Add tags to your dataset on the hub. + tags: list[str] | None = None + # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; + # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes + # and threads depends on your system. We recommend 4 threads per camera with 0 processes. + # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses. + num_image_writer_processes: int = 0 + # Number of threads writing the frames as png images on disk, per camera. + # Too many threads might cause unstable teleoperation fps due to main thread being blocked. + # Not enough threads might cause low camera fps. + num_image_writer_threads_per_camera: int = 4 + # Display all cameras on screen + display_data: bool = False + # Use vocal synthesis to read events. + play_sounds: bool = True + # Resume recording on an existing dataset. + resume: bool = False + + def __post_init__(self): + # HACK: We parse again the cli args here to get the pretrained path if there was one. 
+ policy_path = parser.get_path_arg("control.policy") + if policy_path: + cli_overrides = parser.get_cli_overrides("control.policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + + +@ControlConfig.register_subclass("replay") +@dataclass +class ReplayControlConfig(ControlConfig): + # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). + repo_id: str + # Index of the episode to replay. + episode: int + # Root directory where the dataset will be stored (e.g. 'dataset/path'). + root: str | Path | None = None + # Limit the frames per second. By default, uses the dataset fps. + fps: int | None = None + # Use vocal synthesis to read events. + play_sounds: bool = True + + +@ControlConfig.register_subclass("remote_robot") +@dataclass +class RemoteRobotConfig(ControlConfig): + log_interval: int = 100 + # Display all cameras on screen + display_data: bool = False + # Rerun configuration for remote robot (https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun.connect_tcp) + viewer_ip: str | None = None + viewer_port: str | None = None + + +@dataclass +class ControlPipelineConfig: + robot: RobotConfig + control: ControlConfig + + @classmethod + def __get_path_fields__(cls) -> list[str]: + """This enables the parser to load config from the policy using `--policy.path=local/dir`""" + return ["control.policy"] diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4e42a9896c708ef8cc3a03baa686279b977d6b23 --- /dev/null +++ b/lerobot/common/robot_devices/control_utils.py @@ -0,0 +1,347 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
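+
+"""
+Helpers shared by the robot control scripts: control-loop timing logs, keyboard event handling,
+policy inference on robot observations, and the main `control_loop` used for teleoperation and recording.
+
+A minimal teleoperation sketch (assuming `robot` is an already configured `Robot` instance):
+```python
+listener, events = init_keyboard_listener()
+control_loop(robot, control_time_s=30, teleoperate=True, events=events, fps=30)
+```
+"""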
+ +######################################################################################## +# Utilities +######################################################################################## + + +import logging +import time +import traceback +from contextlib import nullcontext +from copy import copy +from functools import cache + +import rerun as rr +import torch +from deepdiff import DeepDiff +from termcolor import colored + +from lerobot.common.datasets.image_writer import safe_stop_image_writer +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import get_features_from_robot +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.robot_devices.utils import busy_wait +from lerobot.common.utils.utils import get_safe_torch_device, has_method + + +def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): + log_items = [] + if episode_index is not None: + log_items.append(f"ep:{episode_index}") + if frame_index is not None: + log_items.append(f"frame:{frame_index}") + + def log_dt(shortname, dt_val_s): + nonlocal log_items, fps + info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)" + if fps is not None: + actual_fps = 1 / dt_val_s + if actual_fps < fps - 1: + info_str = colored(info_str, "yellow") + log_items.append(info_str) + + # total step time displayed in milliseconds and its frequency + log_dt("dt", dt_s) + + # TODO(aliberts): move robot-specific logs logic in robot.print_logs() + if not robot.robot_type.startswith("stretch"): + for name in robot.leader_arms: + key = f"read_leader_{name}_pos_dt_s" + if key in robot.logs: + log_dt("dtRlead", robot.logs[key]) + + for name in robot.follower_arms: + key = f"write_follower_{name}_goal_pos_dt_s" + if key in robot.logs: + log_dt("dtWfoll", robot.logs[key]) + + key = f"read_follower_{name}_pos_dt_s" + if key in robot.logs: + log_dt("dtRfoll", robot.logs[key]) + + for name in robot.cameras: + key = f"read_camera_{name}_dt_s" + if key in robot.logs: + log_dt(f"dtR{name}", robot.logs[key]) + + info_str = " ".join(log_items) + logging.info(info_str) + + +@cache +def is_headless(): + """Detects if python is running without a monitor.""" + try: + import pynput # noqa + + return False + except Exception: + print( + "Error trying to import pynput. Switching to headless mode. " + "As a result, the video stream from the cameras won't be shown, " + "and you won't be able to change the control flow with keyboards. 
" + "For more info, see traceback below.\n" + ) + traceback.print_exc() + print() + return True + + +def predict_action(observation, policy, device, use_amp): + observation = copy(observation) + with ( + torch.inference_mode(), + torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), + ): + # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension + for name in observation: + if "image" in name: + observation[name] = observation[name].type(torch.float32) / 255 + observation[name] = observation[name].permute(2, 0, 1).contiguous() + observation[name] = observation[name].unsqueeze(0) + observation[name] = observation[name].to(device) + + # Compute the next action with the policy + # based on the current observation + action = policy.select_action(observation) + + # Remove batch dimension + action = action.squeeze(0) + + # Move to cpu, if not already the case + action = action.to("cpu") + + return action + + +def init_keyboard_listener(): + # Allow to exit early while recording an episode or resetting the environment, + # by tapping the right arrow key '->'. This might require a sudo permission + # to allow your terminal to monitor keyboard events. + events = {} + events["exit_early"] = False + events["rerecord_episode"] = False + events["stop_recording"] = False + + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + listener = None + return listener, events + + # Only import pynput if not in a headless environment + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.right: + print("Right arrow key pressed. Exiting loop...") + events["exit_early"] = True + elif key == keyboard.Key.left: + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.esc: + print("Escape key pressed. 
Stopping data recording...") + events["stop_recording"] = True + events["exit_early"] = True + except Exception as e: + print(f"Error handling key press: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + + return listener, events + + +def warmup_record( + robot, + events, + enable_teleoperation, + warmup_time_s, + display_data, + fps, +): + control_loop( + robot=robot, + control_time_s=warmup_time_s, + display_data=display_data, + events=events, + fps=fps, + teleoperate=enable_teleoperation, + ) + + +def record_episode( + robot, + dataset, + events, + episode_time_s, + display_data, + policy, + fps, + single_task, +): + control_loop( + robot=robot, + control_time_s=episode_time_s, + display_data=display_data, + dataset=dataset, + events=events, + policy=policy, + fps=fps, + teleoperate=policy is None, + single_task=single_task, + ) + + +@safe_stop_image_writer +def control_loop( + robot, + control_time_s=None, + teleoperate=False, + display_data=False, + dataset: LeRobotDataset | None = None, + events=None, + policy: PreTrainedPolicy = None, + fps: int | None = None, + single_task: str | None = None, +): + # TODO(rcadene): Add option to record logs + if not robot.is_connected: + robot.connect() + + if events is None: + events = {"exit_early": False} + + if control_time_s is None: + control_time_s = float("inf") + + if teleoperate and policy is not None: + raise ValueError("When `teleoperate` is True, `policy` should be None.") + + if dataset is not None and single_task is None: + raise ValueError("You need to provide a task as argument in `single_task`.") + + if dataset is not None and fps is not None and dataset.fps != fps: + raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") + + timestamp = 0 + start_episode_t = time.perf_counter() + while timestamp < control_time_s: + start_loop_t = time.perf_counter() + + if teleoperate: + observation, action = robot.teleop_step(record_data=True) + else: + observation = robot.capture_observation() + + if policy is not None: + pred_action = predict_action( + observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp + ) + # Action can eventually be clipped using `max_relative_target`, + # so action actually sent is saved in the dataset. 
+ action = robot.send_action(pred_action) + action = {"action": action} + + if dataset is not None: + frame = {**observation, **action, "task": single_task} + dataset.add_frame(frame) + + # TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon) + if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")): + for k, v in action.items(): + for i, vv in enumerate(v): + rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy())) + + image_keys = [key for key in observation if "image" in key] + for key in image_keys: + rr.log(key, rr.Image(observation[key].numpy()), static=True) + + if fps is not None: + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / fps - dt_s) + + dt_s = time.perf_counter() - start_loop_t + log_control_info(robot, dt_s, fps=fps) + + timestamp = time.perf_counter() - start_episode_t + if events["exit_early"]: + events["exit_early"] = False + break + + +def reset_environment(robot, events, reset_time_s, fps): + # TODO(rcadene): refactor warmup_record and reset_environment + if has_method(robot, "teleop_safety_stop"): + robot.teleop_safety_stop() + + control_loop( + robot=robot, + control_time_s=reset_time_s, + events=events, + fps=fps, + teleoperate=True, + ) + + +def stop_recording(robot, listener, display_data): + robot.disconnect() + + if not is_headless() and listener is not None: + listener.stop() + + +def sanity_check_dataset_name(repo_id, policy_cfg): + _, dataset_name = repo_id.split("/") + # either repo_id doesnt start with "eval_" and there is no policy + # or repo_id starts with "eval_" and there is a policy + + # Check if dataset_name starts with "eval_" but policy is missing + if dataset_name.startswith("eval_") and policy_cfg is None: + raise ValueError( + f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})." + ) + + # Check if dataset_name does not start with "eval_" but policy is provided + if not dataset_name.startswith("eval_") and policy_cfg is not None: + raise ValueError( + f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})." + ) + + +def sanity_check_dataset_robot_compatibility( + dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool +) -> None: + fields = [ + ("robot_type", dataset.meta.robot_type, robot.robot_type), + ("fps", dataset.fps, fps), + ("features", dataset.features, get_features_from_robot(robot, use_videos)), + ] + + mismatches = [] + for field, dataset_value, present_value in fields: + diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]) + if diff: + mismatches.append(f"{field}: expected {present_value}, got {dataset_value}") + + if mismatches: + raise ValueError( + "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches) + ) diff --git a/lerobot/common/robot_devices/motors/configs.py b/lerobot/common/robot_devices/motors/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..0bfbaf837588099e2c82f6d5d23ffeee81a5eff6 --- /dev/null +++ b/lerobot/common/robot_devices/motors/configs.py @@ -0,0 +1,41 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from dataclasses import dataclass + +import draccus + + +@dataclass +class MotorsBusConfig(draccus.ChoiceRegistry, abc.ABC): + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + +@MotorsBusConfig.register_subclass("dynamixel") +@dataclass +class DynamixelMotorsBusConfig(MotorsBusConfig): + port: str + motors: dict[str, tuple[int, str]] + mock: bool = False + + +@MotorsBusConfig.register_subclass("feetech") +@dataclass +class FeetechMotorsBusConfig(MotorsBusConfig): + port: str + motors: dict[str, tuple[int, str]] + mock: bool = False diff --git a/lerobot/common/robot_devices/motors/dynamixel.py b/lerobot/common/robot_devices/motors/dynamixel.py new file mode 100644 index 0000000000000000000000000000000000000000..9321172cb8d456b2b7018b5da9572bc5ff932ee1 --- /dev/null +++ b/lerobot/common/robot_devices/motors/dynamixel.py @@ -0,0 +1,873 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import logging +import math +import time +import traceback +from copy import deepcopy + +import numpy as np +import tqdm + +from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig +from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError +from lerobot.common.utils.utils import capture_timestamp_utc + +PROTOCOL_VERSION = 2.0 +BAUDRATE = 1_000_000 +TIMEOUT_MS = 1000 + +MAX_ID_RANGE = 252 + +# The following bounds define the lower and upper joints range (after calibration). +# For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees +# which corresponds to a half rotation on the left and half rotation on the right. +# Some joints might require higher range, so we allow up to [-270, 270] degrees until +# an error is raised. +LOWER_BOUND_DEGREE = -270 +UPPER_BOUND_DEGREE = 270 +# For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper), +# their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully +# closed, and 100% is fully open. To account for slight calibration issue, we allow up to +# [-10, 110] until an error is raised. 
+LOWER_BOUND_LINEAR = -10 +UPPER_BOUND_LINEAR = 110 + +HALF_TURN_DEGREE = 180 + +# https://emanual.robotis.com/docs/en/dxl/x/xl330-m077 +# https://emanual.robotis.com/docs/en/dxl/x/xl330-m288 +# https://emanual.robotis.com/docs/en/dxl/x/xl430-w250 +# https://emanual.robotis.com/docs/en/dxl/x/xm430-w350 +# https://emanual.robotis.com/docs/en/dxl/x/xm540-w270 +# https://emanual.robotis.com/docs/en/dxl/x/xc430-w150 + +# data_name: (address, size_byte) +X_SERIES_CONTROL_TABLE = { + "Model_Number": (0, 2), + "Model_Information": (2, 4), + "Firmware_Version": (6, 1), + "ID": (7, 1), + "Baud_Rate": (8, 1), + "Return_Delay_Time": (9, 1), + "Drive_Mode": (10, 1), + "Operating_Mode": (11, 1), + "Secondary_ID": (12, 1), + "Protocol_Type": (13, 1), + "Homing_Offset": (20, 4), + "Moving_Threshold": (24, 4), + "Temperature_Limit": (31, 1), + "Max_Voltage_Limit": (32, 2), + "Min_Voltage_Limit": (34, 2), + "PWM_Limit": (36, 2), + "Current_Limit": (38, 2), + "Acceleration_Limit": (40, 4), + "Velocity_Limit": (44, 4), + "Max_Position_Limit": (48, 4), + "Min_Position_Limit": (52, 4), + "Shutdown": (63, 1), + "Torque_Enable": (64, 1), + "LED": (65, 1), + "Status_Return_Level": (68, 1), + "Registered_Instruction": (69, 1), + "Hardware_Error_Status": (70, 1), + "Velocity_I_Gain": (76, 2), + "Velocity_P_Gain": (78, 2), + "Position_D_Gain": (80, 2), + "Position_I_Gain": (82, 2), + "Position_P_Gain": (84, 2), + "Feedforward_2nd_Gain": (88, 2), + "Feedforward_1st_Gain": (90, 2), + "Bus_Watchdog": (98, 1), + "Goal_PWM": (100, 2), + "Goal_Current": (102, 2), + "Goal_Velocity": (104, 4), + "Profile_Acceleration": (108, 4), + "Profile_Velocity": (112, 4), + "Goal_Position": (116, 4), + "Realtime_Tick": (120, 2), + "Moving": (122, 1), + "Moving_Status": (123, 1), + "Present_PWM": (124, 2), + "Present_Current": (126, 2), + "Present_Velocity": (128, 4), + "Present_Position": (132, 4), + "Velocity_Trajectory": (136, 4), + "Position_Trajectory": (140, 4), + "Present_Input_Voltage": (144, 2), + "Present_Temperature": (146, 1), +} + +X_SERIES_BAUDRATE_TABLE = { + 0: 9_600, + 1: 57_600, + 2: 115_200, + 3: 1_000_000, + 4: 2_000_000, + 5: 3_000_000, + 6: 4_000_000, +} + +CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"] +CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"] + +MODEL_CONTROL_TABLE = { + "x_series": X_SERIES_CONTROL_TABLE, + "xl330-m077": X_SERIES_CONTROL_TABLE, + "xl330-m288": X_SERIES_CONTROL_TABLE, + "xl430-w250": X_SERIES_CONTROL_TABLE, + "xm430-w350": X_SERIES_CONTROL_TABLE, + "xm540-w270": X_SERIES_CONTROL_TABLE, + "xc430-w150": X_SERIES_CONTROL_TABLE, +} + +MODEL_RESOLUTION = { + "x_series": 4096, + "xl330-m077": 4096, + "xl330-m288": 4096, + "xl430-w250": 4096, + "xm430-w350": 4096, + "xm540-w270": 4096, + "xc430-w150": 4096, +} + +MODEL_BAUDRATE_TABLE = { + "x_series": X_SERIES_BAUDRATE_TABLE, + "xl330-m077": X_SERIES_BAUDRATE_TABLE, + "xl330-m288": X_SERIES_BAUDRATE_TABLE, + "xl430-w250": X_SERIES_BAUDRATE_TABLE, + "xm430-w350": X_SERIES_BAUDRATE_TABLE, + "xm540-w270": X_SERIES_BAUDRATE_TABLE, + "xc430-w150": X_SERIES_BAUDRATE_TABLE, +} + +NUM_READ_RETRY = 10 +NUM_WRITE_RETRY = 10 + + +def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray: + """This function converts the degree range to the step range for indicating motors rotation. + It assumes a motor achieves a full rotation by going from -180 degree position to +180. + The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation. 
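+
+    For example, with a 4096-step model such as the xl330-m288, a half turn (180 degrees) maps to half
+    the resolution, following the `steps = degrees / 180 * resolution / 2` computation below:
+    ```python
+    convert_degrees_to_steps(180, ["xl330-m288"])  # -> array([2048])
+    convert_degrees_to_steps(90, ["xl330-m288"])   # -> array([1024])
+    ```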
+ """ + resolutions = [MODEL_RESOLUTION[model] for model in models] + steps = degrees / 180 * np.array(resolutions) / 2 + steps = steps.astype(int) + return steps + + +def convert_to_bytes(value, bytes, mock=False): + if mock: + return value + + import dynamixel_sdk as dxl + + # Note: No need to convert back into unsigned int, since this byte preprocessing + # already handles it for us. + if bytes == 1: + data = [ + dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), + ] + elif bytes == 2: + data = [ + dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), + dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), + ] + elif bytes == 4: + data = [ + dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), + dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), + dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)), + dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)), + ] + else: + raise NotImplementedError( + f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but " + f"{bytes} is provided instead." + ) + return data + + +def get_group_sync_key(data_name, motor_names): + group_key = f"{data_name}_" + "_".join(motor_names) + return group_key + + +def get_result_name(fn_name, data_name, motor_names): + group_key = get_group_sync_key(data_name, motor_names) + rslt_name = f"{fn_name}_{group_key}" + return rslt_name + + +def get_queue_name(fn_name, data_name, motor_names): + group_key = get_group_sync_key(data_name, motor_names) + queue_name = f"{fn_name}_{group_key}" + return queue_name + + +def get_log_name(var_name, fn_name, data_name, motor_names): + group_key = get_group_sync_key(data_name, motor_names) + log_name = f"{var_name}_{fn_name}_{group_key}" + return log_name + + +def assert_same_address(model_ctrl_table, motor_models, data_name): + all_addr = [] + all_bytes = [] + for model in motor_models: + addr, bytes = model_ctrl_table[model][data_name] + all_addr.append(addr) + all_bytes.append(bytes) + + if len(set(all_addr)) != 1: + raise NotImplementedError( + f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer." + ) + + if len(set(all_bytes)) != 1: + raise NotImplementedError( + f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer." + ) + + +class TorqueMode(enum.Enum): + ENABLED = 1 + DISABLED = 0 + + +class DriveMode(enum.Enum): + NON_INVERTED = 0 + INVERTED = 1 + + +class CalibrationMode(enum.Enum): + # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] + DEGREE = 0 + # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100] + LINEAR = 1 + + +class JointOutOfRangeError(Exception): + def __init__(self, message="Joint is out of range"): + self.message = message + super().__init__(self.message) + + +class DynamixelMotorsBus: + """ + The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on + the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20). + + A DynamixelMotorsBus instance requires a port (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`)). + To find the port, you can run our utility script: + ```bash + python lerobot/scripts/find_motors_bus_port.py + >>> Finding all available ports for the MotorBus. 
+ >>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] + >>> Remove the usb cable from your DynamixelMotorsBus and press Enter when done. + >>> The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751. + >>> Reconnect the usb cable. + ``` + + Example of usage for 1 motor connected to the bus: + ```python + motor_name = "gripper" + motor_index = 6 + motor_model = "xl330-m288" + + config = DynamixelMotorsBusConfig( + port="/dev/tty.usbmodem575E0031751", + motors={motor_name: (motor_index, motor_model)}, + ) + motors_bus = DynamixelMotorsBus(config) + motors_bus.connect() + + position = motors_bus.read("Present_Position") + + # move from a few motor steps as an example + few_steps = 30 + motors_bus.write("Goal_Position", position + few_steps) + + # when done, consider disconnecting + motors_bus.disconnect() + ``` + """ + + def __init__( + self, + config: DynamixelMotorsBusConfig, + ): + self.port = config.port + self.motors = config.motors + self.mock = config.mock + + self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE) + self.model_resolution = deepcopy(MODEL_RESOLUTION) + + self.port_handler = None + self.packet_handler = None + self.calibration = None + self.is_connected = False + self.group_readers = {} + self.group_writers = {} + self.logs = {} + + def connect(self): + if self.is_connected: + raise RobotDeviceAlreadyConnectedError( + f"DynamixelMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice." + ) + + if self.mock: + import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl + else: + import dynamixel_sdk as dxl + + self.port_handler = dxl.PortHandler(self.port) + self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION) + + try: + if not self.port_handler.openPort(): + raise OSError(f"Failed to open port '{self.port}'.") + except Exception: + traceback.print_exc() + print( + "\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n" + ) + raise + + # Allow to read and write + self.is_connected = True + + self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS) + + def reconnect(self): + if self.mock: + import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl + else: + import dynamixel_sdk as dxl + + self.port_handler = dxl.PortHandler(self.port) + self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION) + + if not self.port_handler.openPort(): + raise OSError(f"Failed to open port '{self.port}'.") + + self.is_connected = True + + def are_motors_configured(self): + # Only check the motor indices and not baudrate, since if the motor baudrates are incorrect, + # a ConnectionError will be raised anyway. + try: + return (self.motor_indices == self.read("ID")).all() + except ConnectionError as e: + print(e) + return False + + def find_motor_indices(self, possible_ids=None, num_retry=2): + if possible_ids is None: + possible_ids = range(MAX_ID_RANGE) + + indices = [] + for idx in tqdm.tqdm(possible_ids): + try: + present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0] + except ConnectionError: + continue + + if idx != present_idx: + # sanity check + raise OSError( + "Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged." + ) + indices.append(idx) + + return indices + + def set_bus_baudrate(self, baudrate): + present_bus_baudrate = self.port_handler.getBaudRate() + if present_bus_baudrate != baudrate: + print(f"Setting bus baud rate to {baudrate}. 
Previously {present_bus_baudrate}.") + self.port_handler.setBaudRate(baudrate) + + if self.port_handler.getBaudRate() != baudrate: + raise OSError("Failed to write bus baud rate.") + + @property + def motor_names(self) -> list[str]: + return list(self.motors.keys()) + + @property + def motor_models(self) -> list[str]: + return [model for _, model in self.motors.values()] + + @property + def motor_indices(self) -> list[int]: + return [idx for idx, _ in self.motors.values()] + + def set_calibration(self, calibration: dict[str, list]): + self.calibration = calibration + + def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None): + """This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct. + + For more info, see docstring of `apply_calibration` and `autocorrect_calibration`. + """ + try: + values = self.apply_calibration(values, motor_names) + except JointOutOfRangeError as e: + print(e) + self.autocorrect_calibration(values, motor_names) + values = self.apply_calibration(values, motor_names) + return values + + def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with + a "zero position" at 0 degree. + + Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor + rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range. + + Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation + when given a goal position that is + or - their resolution. For instance, dynamixel xl330-m077 have a resolution of 4096, and + at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830, + or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor. + To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work + in the centered nominal degree range ]-180, 180[. + """ + if motor_names is None: + motor_names = self.motor_names + + # Convert from unsigned int32 original range [0, 2**32] to signed float32 range + values = values.astype(np.float32) + + for i, name in enumerate(motor_names): + calib_idx = self.calibration["motor_names"].index(name) + calib_mode = self.calibration["calib_mode"][calib_idx] + + if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: + drive_mode = self.calibration["drive_mode"][calib_idx] + homing_offset = self.calibration["homing_offset"][calib_idx] + _, model = self.motors[name] + resolution = self.model_resolution[model] + + # Update direction of rotation of the motor to match between leader and follower. + # In fact, the motor of the leader for a given joint can be assembled in an + # opposite direction in term of rotation than the motor of the follower on the same joint. + if drive_mode: + values[i] *= -1 + + # Convert from range [-2**31, 2**31] to + # nominal range [-resolution//2, resolution//2] (e.g. [-2048, 2048]) + values[i] += homing_offset + + # Convert from range [-resolution//2, resolution//2] to + # universal float32 centered degree range [-180, 180] + # (e.g. 
2048 / (4096 // 2) * 180 = 180) + values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE + + if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE): + raise JointOutOfRangeError( + f"Wrong motor position range detected for {name}. " + f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), " + f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, " + f"but present value is {values[i]} degree. " + "This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. " + "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`" + ) + + elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: + start_pos = self.calibration["start_pos"][calib_idx] + end_pos = self.calibration["end_pos"][calib_idx] + + # Rescale the present position to a nominal range [0, 100] %, + # useful for joints with linear motions like Aloha gripper + values[i] = (values[i] - start_pos) / (end_pos - start_pos) * 100 + + if (values[i] < LOWER_BOUND_LINEAR) or (values[i] > UPPER_BOUND_LINEAR): + raise JointOutOfRangeError( + f"Wrong motor position range detected for {name}. " + f"Expected to be in nominal range of [0, 100] % (a full linear translation), " + f"with a maximum range of [{LOWER_BOUND_LINEAR}, {UPPER_BOUND_LINEAR}] % to account for some imprecision during calibration, " + f"but present value is {values[i]} %. " + "This might be due to a cable connection issue creating an artificial jump in motor values. " + "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`" + ) + + return values + + def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + """This function automatically detects issues with values of motors after calibration, and correct for these issues. + + Some motors might have values outside of expected maximum bounds after calibration. + For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given + a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position. + + Known issues: + #1: Motor value randomly shifts of a full turn, caused by hardware/connection errors. + #2: Motor internal homing offset is shifted by a full turn, caused by using default calibration (e.g Aloha). + #3: motor internal homing offset is shifted by less or more than a full turn, caused by using default calibration + or by human error during manual calibration. + + Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn. + Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`, + that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue. + + Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. 
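+
+        For illustration (made-up numbers, assuming a resolution of 4096): if a raw value plus the current
+        homing offset maps to 310 degrees, it falls outside [-270, 270], so this function shifts
+        `homing_offset` by -4096 steps (one full turn) and the same raw value then maps to
+        310 - 360 = -50 degrees, back inside the nominal range.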
+ """ + if motor_names is None: + motor_names = self.motor_names + + # Convert from unsigned int32 original range [0, 2**32] to signed float32 range + values = values.astype(np.float32) + + for i, name in enumerate(motor_names): + calib_idx = self.calibration["motor_names"].index(name) + calib_mode = self.calibration["calib_mode"][calib_idx] + + if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: + drive_mode = self.calibration["drive_mode"][calib_idx] + homing_offset = self.calibration["homing_offset"][calib_idx] + _, model = self.motors[name] + resolution = self.model_resolution[model] + + # Update direction of rotation of the motor to match between leader and follower. + # In fact, the motor of the leader for a given joint can be assembled in an + # opposite direction in term of rotation than the motor of the follower on the same joint. + if drive_mode: + values[i] *= -1 + + # Convert from initial range to range [-180, 180] degrees + calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE + in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE) + + # Solve this inequality to find the factor to shift the range into [-180, 180] degrees + # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE + # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE + # (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution + low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution + upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution + + elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: + start_pos = self.calibration["start_pos"][calib_idx] + end_pos = self.calibration["end_pos"][calib_idx] + + # Convert from initial range to range [0, 100] in % + calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100 + in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR) + + # Solve this inequality to find the factor to shift the range into [0, 100] % + # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100 + # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 + # 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100 + # (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution + low_factor = (start_pos - values[i]) / resolution + upp_factor = (end_pos - values[i]) / resolution + + if not in_range: + # Get first integer between the two bounds + if low_factor < upp_factor: + factor = math.ceil(low_factor) + + if factor > upp_factor: + raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + else: + factor = math.ceil(upp_factor) + + if factor > low_factor: + raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + + if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: + out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" + in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" + elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: + out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < 
{UPPER_BOUND_LINEAR} %" + + logging.warning( + f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " + f"from '{out_of_range_str}' to '{in_range_str}'." + ) + + # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. + self.calibration["homing_offset"][calib_idx] += resolution * factor + + def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + """Inverse of `apply_calibration`.""" + if motor_names is None: + motor_names = self.motor_names + + for i, name in enumerate(motor_names): + calib_idx = self.calibration["motor_names"].index(name) + calib_mode = self.calibration["calib_mode"][calib_idx] + + if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: + drive_mode = self.calibration["drive_mode"][calib_idx] + homing_offset = self.calibration["homing_offset"][calib_idx] + _, model = self.motors[name] + resolution = self.model_resolution[model] + + # Convert from nominal 0-centered degree range [-180, 180] to + # 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096) + values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2) + + # Subtract the homing offsets to come back to actual motor range of values + # which can be arbitrary. + values[i] -= homing_offset + + # Remove drive mode, which is the rotation direction of the motor, to come back to + # actual motor rotation direction which can be arbitrary. + if drive_mode: + values[i] *= -1 + + elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: + start_pos = self.calibration["start_pos"][calib_idx] + end_pos = self.calibration["end_pos"][calib_idx] + + # Convert from nominal lnear range of [0, 100] % to + # actual motor range of values which can be arbitrary. + values[i] = values[i] / 100 * (end_pos - start_pos) + start_pos + + values = np.round(values).astype(np.int32) + return values + + def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): + if self.mock: + import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl + else: + import dynamixel_sdk as dxl + + return_list = True + if not isinstance(motor_ids, list): + return_list = False + motor_ids = [motor_ids] + + assert_same_address(self.model_ctrl_table, self.motor_models, data_name) + addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] + group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes) + for idx in motor_ids: + group.addParam(idx) + + for _ in range(num_retry): + comm = group.txRxPacket() + if comm == dxl.COMM_SUCCESS: + break + + if comm != dxl.COMM_SUCCESS: + raise ConnectionError( + f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: " + f"{self.packet_handler.getTxRxResult(comm)}" + ) + + values = [] + for idx in motor_ids: + value = group.getData(idx, addr, bytes) + values.append(value) + + if return_list: + return values + else: + return values[0] + + def read(self, data_name, motor_names: str | list[str] | None = None): + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." 
+ ) + + start_time = time.perf_counter() + + if self.mock: + import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl + else: + import dynamixel_sdk as dxl + + if motor_names is None: + motor_names = self.motor_names + + if isinstance(motor_names, str): + motor_names = [motor_names] + + motor_ids = [] + models = [] + for name in motor_names: + motor_idx, model = self.motors[name] + motor_ids.append(motor_idx) + models.append(model) + + assert_same_address(self.model_ctrl_table, models, data_name) + addr, bytes = self.model_ctrl_table[model][data_name] + group_key = get_group_sync_key(data_name, motor_names) + + if data_name not in self.group_readers: + # create new group reader + self.group_readers[group_key] = dxl.GroupSyncRead( + self.port_handler, self.packet_handler, addr, bytes + ) + for idx in motor_ids: + self.group_readers[group_key].addParam(idx) + + for _ in range(NUM_READ_RETRY): + comm = self.group_readers[group_key].txRxPacket() + if comm == dxl.COMM_SUCCESS: + break + + if comm != dxl.COMM_SUCCESS: + raise ConnectionError( + f"Read failed due to communication error on port {self.port} for group_key {group_key}: " + f"{self.packet_handler.getTxRxResult(comm)}" + ) + + values = [] + for idx in motor_ids: + value = self.group_readers[group_key].getData(idx, addr, bytes) + values.append(value) + + values = np.array(values) + + # Convert to signed int to use range [-2048, 2048] for our motor positions. + if data_name in CONVERT_UINT32_TO_INT32_REQUIRED: + values = values.astype(np.int32) + + if data_name in CALIBRATION_REQUIRED and self.calibration is not None: + values = self.apply_calibration_autocorrect(values, motor_names) + + # log the number of seconds it took to read the data from the motors + delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) + self.logs[delta_ts_name] = time.perf_counter() - start_time + + # log the utc time at which the data was received + ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names) + self.logs[ts_utc_name] = capture_timestamp_utc() + + return values + + def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): + if self.mock: + import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl + else: + import dynamixel_sdk as dxl + + if not isinstance(motor_ids, list): + motor_ids = [motor_ids] + if not isinstance(values, list): + values = [values] + + assert_same_address(self.model_ctrl_table, motor_models, data_name) + addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] + group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes) + for idx, value in zip(motor_ids, values, strict=True): + data = convert_to_bytes(value, bytes, self.mock) + group.addParam(idx, data) + + for _ in range(num_retry): + comm = group.txPacket() + if comm == dxl.COMM_SUCCESS: + break + + if comm != dxl.COMM_SUCCESS: + raise ConnectionError( + f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: " + f"{self.packet_handler.getTxRxResult(comm)}" + ) + + def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None): + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." 
+ ) + + start_time = time.perf_counter() + + if self.mock: + import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl + else: + import dynamixel_sdk as dxl + + if motor_names is None: + motor_names = self.motor_names + + if isinstance(motor_names, str): + motor_names = [motor_names] + + if isinstance(values, (int, float, np.integer)): + values = [int(values)] * len(motor_names) + + values = np.array(values) + + motor_ids = [] + models = [] + for name in motor_names: + motor_idx, model = self.motors[name] + motor_ids.append(motor_idx) + models.append(model) + + if data_name in CALIBRATION_REQUIRED and self.calibration is not None: + values = self.revert_calibration(values, motor_names) + + values = values.tolist() + + assert_same_address(self.model_ctrl_table, models, data_name) + addr, bytes = self.model_ctrl_table[model][data_name] + group_key = get_group_sync_key(data_name, motor_names) + + init_group = data_name not in self.group_readers + if init_group: + self.group_writers[group_key] = dxl.GroupSyncWrite( + self.port_handler, self.packet_handler, addr, bytes + ) + + for idx, value in zip(motor_ids, values, strict=True): + data = convert_to_bytes(value, bytes, self.mock) + if init_group: + self.group_writers[group_key].addParam(idx, data) + else: + self.group_writers[group_key].changeParam(idx, data) + + comm = self.group_writers[group_key].txPacket() + if comm != dxl.COMM_SUCCESS: + raise ConnectionError( + f"Write failed due to communication error on port {self.port} for group_key {group_key}: " + f"{self.packet_handler.getTxRxResult(comm)}" + ) + + # log the number of seconds it took to write the data to the motors + delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names) + self.logs[delta_ts_name] = time.perf_counter() - start_time + + # TODO(rcadene): should we log the time before sending the write command? + # log the utc time when the write has been completed + ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names) + self.logs[ts_utc_name] = capture_timestamp_utc() + + def disconnect(self): + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"DynamixelMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first." + ) + + if self.port_handler is not None: + self.port_handler.closePort() + self.port_handler = None + + self.packet_handler = None + self.group_readers = {} + self.group_writers = {} + self.is_connected = False + + def __del__(self): + if getattr(self, "is_connected", False): + self.disconnect() diff --git a/lerobot/common/robot_devices/motors/feetech.py b/lerobot/common/robot_devices/motors/feetech.py new file mode 100644 index 0000000000000000000000000000000000000000..3268f3439b490d3e138a2e58b62ac86602da1978 --- /dev/null +++ b/lerobot/common/robot_devices/motors/feetech.py @@ -0,0 +1,898 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import enum +import logging +import math +import time +import traceback +from copy import deepcopy + +import numpy as np +import tqdm + +from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig +from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError +from lerobot.common.utils.utils import capture_timestamp_utc + +PROTOCOL_VERSION = 0 +BAUDRATE = 1_000_000 +TIMEOUT_MS = 1000 + +MAX_ID_RANGE = 252 + +# The following bounds define the lower and upper joints range (after calibration). +# For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees +# which corresponds to a half rotation on the left and half rotation on the right. +# Some joints might require higher range, so we allow up to [-270, 270] degrees until +# an error is raised. +LOWER_BOUND_DEGREE = -270 +UPPER_BOUND_DEGREE = 270 +# For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper), +# their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully +# closed, and 100% is fully open. To account for slight calibration issue, we allow up to +# [-10, 110] until an error is raised. +LOWER_BOUND_LINEAR = -10 +UPPER_BOUND_LINEAR = 110 + +HALF_TURN_DEGREE = 180 + + +# See this link for STS3215 Memory Table: +# https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true +# data_name: (address, size_byte) +SCS_SERIES_CONTROL_TABLE = { + "Model": (3, 2), + "ID": (5, 1), + "Baud_Rate": (6, 1), + "Return_Delay": (7, 1), + "Response_Status_Level": (8, 1), + "Min_Angle_Limit": (9, 2), + "Max_Angle_Limit": (11, 2), + "Max_Temperature_Limit": (13, 1), + "Max_Voltage_Limit": (14, 1), + "Min_Voltage_Limit": (15, 1), + "Max_Torque_Limit": (16, 2), + "Phase": (18, 1), + "Unloading_Condition": (19, 1), + "LED_Alarm_Condition": (20, 1), + "P_Coefficient": (21, 1), + "D_Coefficient": (22, 1), + "I_Coefficient": (23, 1), + "Minimum_Startup_Force": (24, 2), + "CW_Dead_Zone": (26, 1), + "CCW_Dead_Zone": (27, 1), + "Protection_Current": (28, 2), + "Angular_Resolution": (30, 1), + "Offset": (31, 2), + "Mode": (33, 1), + "Protective_Torque": (34, 1), + "Protection_Time": (35, 1), + "Overload_Torque": (36, 1), + "Speed_closed_loop_P_proportional_coefficient": (37, 1), + "Over_Current_Protection_Time": (38, 1), + "Velocity_closed_loop_I_integral_coefficient": (39, 1), + "Torque_Enable": (40, 1), + "Acceleration": (41, 1), + "Goal_Position": (42, 2), + "Goal_Time": (44, 2), + "Goal_Speed": (46, 2), + "Torque_Limit": (48, 2), + "Lock": (55, 1), + "Present_Position": (56, 2), + "Present_Speed": (58, 2), + "Present_Load": (60, 2), + "Present_Voltage": (62, 1), + "Present_Temperature": (63, 1), + "Status": (65, 1), + "Moving": (66, 1), + "Present_Current": (69, 2), + # Not in the Memory Table + "Maximum_Acceleration": (85, 2), +} + +SCS_SERIES_BAUDRATE_TABLE = { + 0: 1_000_000, + 1: 500_000, + 2: 250_000, + 3: 128_000, + 4: 115_200, + 5: 57_600, + 6: 38_400, + 7: 19_200, +} + +CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"] +CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"] + + +MODEL_CONTROL_TABLE = { + "scs_series": SCS_SERIES_CONTROL_TABLE, + "sts3215": SCS_SERIES_CONTROL_TABLE, +} + +MODEL_RESOLUTION = { + "scs_series": 4096, + "sts3215": 4096, +} + +MODEL_BAUDRATE_TABLE = { + "scs_series": SCS_SERIES_BAUDRATE_TABLE, + "sts3215": SCS_SERIES_BAUDRATE_TABLE, +} + +# High number of retries is 
needed for feetech compared to dynamixel motors. +NUM_READ_RETRY = 20 +NUM_WRITE_RETRY = 20 + + +def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray: + """This function converts the degree range to the step range for indicating motors rotation. + It assumes a motor achieves a full rotation by going from -180 degree position to +180. + The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation. + """ + resolutions = [MODEL_RESOLUTION[model] for model in models] + steps = degrees / 180 * np.array(resolutions) / 2 + steps = steps.astype(int) + return steps + + +def convert_to_bytes(value, bytes, mock=False): + if mock: + return value + + import scservo_sdk as scs + + # Note: No need to convert back into unsigned int, since this byte preprocessing + # already handles it for us. + if bytes == 1: + data = [ + scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), + ] + elif bytes == 2: + data = [ + scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), + scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), + ] + elif bytes == 4: + data = [ + scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), + scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), + scs.SCS_LOBYTE(scs.SCS_HIWORD(value)), + scs.SCS_HIBYTE(scs.SCS_HIWORD(value)), + ] + else: + raise NotImplementedError( + f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but " + f"{bytes} is provided instead." + ) + return data + + +def get_group_sync_key(data_name, motor_names): + group_key = f"{data_name}_" + "_".join(motor_names) + return group_key + + +def get_result_name(fn_name, data_name, motor_names): + group_key = get_group_sync_key(data_name, motor_names) + rslt_name = f"{fn_name}_{group_key}" + return rslt_name + + +def get_queue_name(fn_name, data_name, motor_names): + group_key = get_group_sync_key(data_name, motor_names) + queue_name = f"{fn_name}_{group_key}" + return queue_name + + +def get_log_name(var_name, fn_name, data_name, motor_names): + group_key = get_group_sync_key(data_name, motor_names) + log_name = f"{var_name}_{fn_name}_{group_key}" + return log_name + + +def assert_same_address(model_ctrl_table, motor_models, data_name): + all_addr = [] + all_bytes = [] + for model in motor_models: + addr, bytes = model_ctrl_table[model][data_name] + all_addr.append(addr) + all_bytes.append(bytes) + + if len(set(all_addr)) != 1: + raise NotImplementedError( + f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer." + ) + + if len(set(all_bytes)) != 1: + raise NotImplementedError( + f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer." + ) + + +class TorqueMode(enum.Enum): + ENABLED = 1 + DISABLED = 0 + + +class DriveMode(enum.Enum): + NON_INVERTED = 0 + INVERTED = 1 + + +class CalibrationMode(enum.Enum): + # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] + DEGREE = 0 + # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100] + LINEAR = 1 + + +class JointOutOfRangeError(Exception): + def __init__(self, message="Joint is out of range"): + self.message = message + super().__init__(self.message) + + +class FeetechMotorsBus: + """ + The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on + the python feetech sdk to communicate with the motors. 
For more info, see the [feetech SDK Documentation](https://emanual.robotis.com/docs/en/software/feetech/feetech_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20). + + A FeetechMotorsBus instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)). + To find the port, you can run our utility script: + ```bash + python lerobot/scripts/find_motors_bus_port.py + >>> Finding all available ports for the MotorsBus. + >>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] + >>> Remove the usb cable from your FeetechMotorsBus and press Enter when done. + >>> The port of this FeetechMotorsBus is /dev/tty.usbmodem575E0031751. + >>> Reconnect the usb cable. + ``` + + Example of usage for 1 motor connected to the bus: + ```python + motor_name = "gripper" + motor_index = 6 + motor_model = "sts3215" + + config = FeetechMotorsBusConfig( + port="/dev/tty.usbmodem575E0031751", + motors={motor_name: (motor_index, motor_model)}, + ) + motors_bus = FeetechMotorsBus(config) + motors_bus.connect() + + position = motors_bus.read("Present_Position") + + # move from a few motor steps as an example + few_steps = 30 + motors_bus.write("Goal_Position", position + few_steps) + + # when done, consider disconnecting + motors_bus.disconnect() + ``` + """ + + def __init__( + self, + config: FeetechMotorsBusConfig, + ): + self.port = config.port + self.motors = config.motors + self.mock = config.mock + + self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE) + self.model_resolution = deepcopy(MODEL_RESOLUTION) + + self.port_handler = None + self.packet_handler = None + self.calibration = None + self.is_connected = False + self.group_readers = {} + self.group_writers = {} + self.logs = {} + + self.track_positions = {} + + def connect(self): + if self.is_connected: + raise RobotDeviceAlreadyConnectedError( + f"FeetechMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice." + ) + + if self.mock: + import lerobot.common.mocks.motors.mock_scservo_sdk as scs + else: + import scservo_sdk as scs + + self.port_handler = scs.PortHandler(self.port) + self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION) + + try: + if not self.port_handler.openPort(): + raise OSError(f"Failed to open port '{self.port}'.") + except Exception: + traceback.print_exc() + print( + "\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n" + ) + raise + + # Allow to read and write + self.is_connected = True + + self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS) + + def reconnect(self): + if self.mock: + import lerobot.common.mocks.motors.mock_scservo_sdk as scs + else: + import scservo_sdk as scs + + self.port_handler = scs.PortHandler(self.port) + self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION) + + if not self.port_handler.openPort(): + raise OSError(f"Failed to open port '{self.port}'.") + + self.is_connected = True + + def are_motors_configured(self): + # Only check the motor indices and not baudrate, since if the motor baudrates are incorrect, + # a ConnectionError will be raised anyway. 
+ try: + return (self.motor_indices == self.read("ID")).all() + except ConnectionError as e: + print(e) + return False + + def find_motor_indices(self, possible_ids=None, num_retry=2): + if possible_ids is None: + possible_ids = range(MAX_ID_RANGE) + + indices = [] + for idx in tqdm.tqdm(possible_ids): + try: + present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0] + except ConnectionError: + continue + + if idx != present_idx: + # sanity check + raise OSError( + "Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged." + ) + indices.append(idx) + + return indices + + def set_bus_baudrate(self, baudrate): + present_bus_baudrate = self.port_handler.getBaudRate() + if present_bus_baudrate != baudrate: + print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") + self.port_handler.setBaudRate(baudrate) + + if self.port_handler.getBaudRate() != baudrate: + raise OSError("Failed to write bus baud rate.") + + @property + def motor_names(self) -> list[str]: + return list(self.motors.keys()) + + @property + def motor_models(self) -> list[str]: + return [model for _, model in self.motors.values()] + + @property + def motor_indices(self) -> list[int]: + return [idx for idx, _ in self.motors.values()] + + def set_calibration(self, calibration: dict[str, list]): + self.calibration = calibration + + def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None): + """This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct. + + For more info, see docstring of `apply_calibration` and `autocorrect_calibration`. + """ + try: + values = self.apply_calibration(values, motor_names) + except JointOutOfRangeError as e: + print(e) + self.autocorrect_calibration(values, motor_names) + values = self.apply_calibration(values, motor_names) + return values + + def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with + a "zero position" at 0 degree. + + Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor + rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range. + + Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation + when given a goal position that is + or - their resolution. For instance, feetech xl330-m077 have a resolution of 4096, and + at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830, + or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor. + To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work + in the centered nominal degree range ]-180, 180[. 
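+
+        For illustration (made-up numbers for an "sts3215" with a resolution of 4096 and `drive_mode` 0):
+        with a homing offset of -2048, a raw `Present_Position` of 3072 becomes 3072 - 2048 = 1024 steps,
+        which maps to 1024 / 2048 * 180 = 90.0 degrees in the nominal range.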
+ """ + if motor_names is None: + motor_names = self.motor_names + + # Convert from unsigned int32 original range [0, 2**32] to signed float32 range + values = values.astype(np.float32) + + for i, name in enumerate(motor_names): + calib_idx = self.calibration["motor_names"].index(name) + calib_mode = self.calibration["calib_mode"][calib_idx] + + if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: + drive_mode = self.calibration["drive_mode"][calib_idx] + homing_offset = self.calibration["homing_offset"][calib_idx] + _, model = self.motors[name] + resolution = self.model_resolution[model] + + # Update direction of rotation of the motor to match between leader and follower. + # In fact, the motor of the leader for a given joint can be assembled in an + # opposite direction in term of rotation than the motor of the follower on the same joint. + if drive_mode: + values[i] *= -1 + + # Convert from range [-2**31, 2**31[ to + # nominal range ]-resolution, resolution[ (e.g. ]-2048, 2048[) + values[i] += homing_offset + + # Convert from range ]-resolution, resolution[ to + # universal float32 centered degree range ]-180, 180[ + values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE + + if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE): + raise JointOutOfRangeError( + f"Wrong motor position range detected for {name}. " + f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), " + f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, " + f"but present value is {values[i]} degree. " + "This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. " + "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`" + ) + + elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: + start_pos = self.calibration["start_pos"][calib_idx] + end_pos = self.calibration["end_pos"][calib_idx] + + # Rescale the present position to a nominal range [0, 100] %, + # useful for joints with linear motions like Aloha gripper + values[i] = (values[i] - start_pos) / (end_pos - start_pos) * 100 + + if (values[i] < LOWER_BOUND_LINEAR) or (values[i] > UPPER_BOUND_LINEAR): + raise JointOutOfRangeError( + f"Wrong motor position range detected for {name}. " + f"Expected to be in nominal range of [0, 100] % (a full linear translation), " + f"with a maximum range of [{LOWER_BOUND_LINEAR}, {UPPER_BOUND_LINEAR}] % to account for some imprecision during calibration, " + f"but present value is {values[i]} %. " + "This might be due to a cable connection issue creating an artificial jump in motor values. " + "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`" + ) + + return values + + def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + """This function automatically detects issues with values of motors after calibration, and correct for these issues. + + Some motors might have values outside of expected maximum bounds after calibration. + For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given + a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position. + + Known issues: + #1: Motor value randomly shifts of a full turn, caused by hardware/connection errors. 
+ #2: Motor internal homing offset is shifted of a full turn, caused by using default calibration (e.g Aloha). + #3: motor internal homing offset is shifted of less or more than a full turn, caused by using default calibration + or by human error during manual calibration. + + Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn. + Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`, + that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue. + + Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. + """ + if motor_names is None: + motor_names = self.motor_names + + # Convert from unsigned int32 original range [0, 2**32] to signed float32 range + values = values.astype(np.float32) + + for i, name in enumerate(motor_names): + calib_idx = self.calibration["motor_names"].index(name) + calib_mode = self.calibration["calib_mode"][calib_idx] + + if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: + drive_mode = self.calibration["drive_mode"][calib_idx] + homing_offset = self.calibration["homing_offset"][calib_idx] + _, model = self.motors[name] + resolution = self.model_resolution[model] + + if drive_mode: + values[i] *= -1 + + # Convert from initial range to range [-180, 180] degrees + calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE + in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE) + + # Solve this inequality to find the factor to shift the range into [-180, 180] degrees + # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE + # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE + # (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution + low_factor = ( + -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset + ) / resolution + upp_factor = ( + HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset + ) / resolution + + elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: + start_pos = self.calibration["start_pos"][calib_idx] + end_pos = self.calibration["end_pos"][calib_idx] + + # Convert from initial range to range [0, 100] in % + calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100 + in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR) + + # Solve this inequality to find the factor to shift the range into [0, 100] % + # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100 + # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 + # 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100 + # (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution + low_factor = (start_pos - values[i]) / resolution + upp_factor = (end_pos - values[i]) / resolution + + if not in_range: + # Get first integer between the two bounds + if low_factor < upp_factor: + factor = math.ceil(low_factor) + + if factor > upp_factor: + raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + 
else: + factor = math.ceil(upp_factor) + + if factor > low_factor: + raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + + if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: + out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" + in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" + elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: + out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + + logging.warning( + f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " + f"from '{out_of_range_str}' to '{in_range_str}'." + ) + + # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. + self.calibration["homing_offset"][calib_idx] += resolution * factor + + def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + """Inverse of `apply_calibration`.""" + if motor_names is None: + motor_names = self.motor_names + + for i, name in enumerate(motor_names): + calib_idx = self.calibration["motor_names"].index(name) + calib_mode = self.calibration["calib_mode"][calib_idx] + + if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: + drive_mode = self.calibration["drive_mode"][calib_idx] + homing_offset = self.calibration["homing_offset"][calib_idx] + _, model = self.motors[name] + resolution = self.model_resolution[model] + + # Convert from nominal 0-centered degree range [-180, 180] to + # 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096) + values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2) + + # Subtract the homing offsets to come back to actual motor range of values + # which can be arbitrary. + values[i] -= homing_offset + + # Remove drive mode, which is the rotation direction of the motor, to come back to + # actual motor rotation direction which can be arbitrary. + if drive_mode: + values[i] *= -1 + + elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: + start_pos = self.calibration["start_pos"][calib_idx] + end_pos = self.calibration["end_pos"][calib_idx] + + # Convert from nominal lnear range of [0, 100] % to + # actual motor range of values which can be arbitrary. 
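+                # e.g. with hypothetical start_pos=2048 and end_pos=3072, a command of 50 (%)
+                # maps back to 0.5 * 1024 + 2048 = 2560 motor steps.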
+ values[i] = values[i] / 100 * (end_pos - start_pos) + start_pos + + values = np.round(values).astype(np.int32) + return values + + def avoid_rotation_reset(self, values, motor_names, data_name): + if data_name not in self.track_positions: + self.track_positions[data_name] = { + "prev": [None] * len(self.motor_names), + # Assume False at initialization + "below_zero": [False] * len(self.motor_names), + "above_max": [False] * len(self.motor_names), + } + + track = self.track_positions[data_name] + + if motor_names is None: + motor_names = self.motor_names + + for i, name in enumerate(motor_names): + idx = self.motor_names.index(name) + + if track["prev"][idx] is None: + track["prev"][idx] = values[i] + continue + + # Detect a full rotation occurred + if abs(track["prev"][idx] - values[i]) > 2048: + # Position went below 0 and got reset to 4095 + if track["prev"][idx] < values[i]: + # So we set negative value by adding a full rotation + values[i] -= 4096 + + # Position went above 4095 and got reset to 0 + elif track["prev"][idx] > values[i]: + # So we add a full rotation + values[i] += 4096 + + track["prev"][idx] = values[i] + + return values + + def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): + if self.mock: + import lerobot.common.mocks.motors.mock_scservo_sdk as scs + else: + import scservo_sdk as scs + + return_list = True + if not isinstance(motor_ids, list): + return_list = False + motor_ids = [motor_ids] + + assert_same_address(self.model_ctrl_table, self.motor_models, data_name) + addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] + group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes) + for idx in motor_ids: + group.addParam(idx) + + for _ in range(num_retry): + comm = group.txRxPacket() + if comm == scs.COMM_SUCCESS: + break + + if comm != scs.COMM_SUCCESS: + raise ConnectionError( + f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: " + f"{self.packet_handler.getTxRxResult(comm)}" + ) + + values = [] + for idx in motor_ids: + value = group.getData(idx, addr, bytes) + values.append(value) + + if return_list: + return values + else: + return values[0] + + def read(self, data_name, motor_names: str | list[str] | None = None): + if self.mock: + import lerobot.common.mocks.motors.mock_scservo_sdk as scs + else: + import scservo_sdk as scs + + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." + ) + + start_time = time.perf_counter() + + if motor_names is None: + motor_names = self.motor_names + + if isinstance(motor_names, str): + motor_names = [motor_names] + + motor_ids = [] + models = [] + for name in motor_names: + motor_idx, model = self.motors[name] + motor_ids.append(motor_idx) + models.append(model) + + assert_same_address(self.model_ctrl_table, models, data_name) + addr, bytes = self.model_ctrl_table[model][data_name] + group_key = get_group_sync_key(data_name, motor_names) + + if data_name not in self.group_readers: + # Very Important to flush the buffer! 
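+            # Resetting the input/output buffers discards any stale bytes (e.g. left over from a previously
+            # timed-out transaction) before the new GroupSyncRead is set up.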
+ self.port_handler.ser.reset_output_buffer() + self.port_handler.ser.reset_input_buffer() + + # create new group reader + self.group_readers[group_key] = scs.GroupSyncRead( + self.port_handler, self.packet_handler, addr, bytes + ) + for idx in motor_ids: + self.group_readers[group_key].addParam(idx) + + for _ in range(NUM_READ_RETRY): + comm = self.group_readers[group_key].txRxPacket() + if comm == scs.COMM_SUCCESS: + break + + if comm != scs.COMM_SUCCESS: + raise ConnectionError( + f"Read failed due to communication error on port {self.port} for group_key {group_key}: " + f"{self.packet_handler.getTxRxResult(comm)}" + ) + + values = [] + for idx in motor_ids: + value = self.group_readers[group_key].getData(idx, addr, bytes) + values.append(value) + + values = np.array(values) + + # Convert to signed int to use range [-2048, 2048] for our motor positions. + if data_name in CONVERT_UINT32_TO_INT32_REQUIRED: + values = values.astype(np.int32) + + if data_name in CALIBRATION_REQUIRED: + values = self.avoid_rotation_reset(values, motor_names, data_name) + + if data_name in CALIBRATION_REQUIRED and self.calibration is not None: + values = self.apply_calibration_autocorrect(values, motor_names) + + # log the number of seconds it took to read the data from the motors + delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) + self.logs[delta_ts_name] = time.perf_counter() - start_time + + # log the utc time at which the data was received + ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names) + self.logs[ts_utc_name] = capture_timestamp_utc() + + return values + + def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): + if self.mock: + import lerobot.common.mocks.motors.mock_scservo_sdk as scs + else: + import scservo_sdk as scs + + if not isinstance(motor_ids, list): + motor_ids = [motor_ids] + if not isinstance(values, list): + values = [values] + + assert_same_address(self.model_ctrl_table, motor_models, data_name) + addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] + group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes) + for idx, value in zip(motor_ids, values, strict=True): + data = convert_to_bytes(value, bytes, self.mock) + group.addParam(idx, data) + + for _ in range(num_retry): + comm = group.txPacket() + if comm == scs.COMM_SUCCESS: + break + + if comm != scs.COMM_SUCCESS: + raise ConnectionError( + f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: " + f"{self.packet_handler.getTxRxResult(comm)}" + ) + + def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None): + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." 
+ ) + + start_time = time.perf_counter() + + if self.mock: + import lerobot.common.mocks.motors.mock_scservo_sdk as scs + else: + import scservo_sdk as scs + + if motor_names is None: + motor_names = self.motor_names + + if isinstance(motor_names, str): + motor_names = [motor_names] + + if isinstance(values, (int, float, np.integer)): + values = [int(values)] * len(motor_names) + + values = np.array(values) + + motor_ids = [] + models = [] + for name in motor_names: + motor_idx, model = self.motors[name] + motor_ids.append(motor_idx) + models.append(model) + + if data_name in CALIBRATION_REQUIRED and self.calibration is not None: + values = self.revert_calibration(values, motor_names) + + values = values.tolist() + + assert_same_address(self.model_ctrl_table, models, data_name) + addr, bytes = self.model_ctrl_table[model][data_name] + group_key = get_group_sync_key(data_name, motor_names) + + init_group = data_name not in self.group_readers + if init_group: + self.group_writers[group_key] = scs.GroupSyncWrite( + self.port_handler, self.packet_handler, addr, bytes + ) + + for idx, value in zip(motor_ids, values, strict=True): + data = convert_to_bytes(value, bytes, self.mock) + if init_group: + self.group_writers[group_key].addParam(idx, data) + else: + self.group_writers[group_key].changeParam(idx, data) + + comm = self.group_writers[group_key].txPacket() + if comm != scs.COMM_SUCCESS: + raise ConnectionError( + f"Write failed due to communication error on port {self.port} for group_key {group_key}: " + f"{self.packet_handler.getTxRxResult(comm)}" + ) + + # log the number of seconds it took to write the data to the motors + delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names) + self.logs[delta_ts_name] = time.perf_counter() - start_time + + # TODO(rcadene): should we log the time before sending the write command? + # log the utc time when the write has been completed + ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names) + self.logs[ts_utc_name] = capture_timestamp_utc() + + def disconnect(self): + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"FeetechMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first." + ) + + if self.port_handler is not None: + self.port_handler.closePort() + self.port_handler = None + + self.packet_handler = None + self.group_readers = {} + self.group_writers = {} + self.is_connected = False + + def __del__(self): + if getattr(self, "is_connected", False): + self.disconnect() diff --git a/lerobot/common/robot_devices/motors/utils.py b/lerobot/common/robot_devices/motors/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bd86f4c64ef621776300dba7041bc0b3f98ec66e --- /dev/null +++ b/lerobot/common/robot_devices/motors/utils.py @@ -0,0 +1,67 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
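+
+# Factory helpers to build motor bus instances from their configs. A minimal usage sketch (the port and
+# motor values below are placeholders):
+#
+#     from lerobot.common.robot_devices.motors.utils import make_motors_bus
+#
+#     bus = make_motors_bus("feetech", port="/dev/ttyACM0", motors={"gripper": [6, "sts3215"]})
+#     bus.connect()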
+ +from typing import Protocol + +from lerobot.common.robot_devices.motors.configs import ( + DynamixelMotorsBusConfig, + FeetechMotorsBusConfig, + MotorsBusConfig, +) + + +class MotorsBus(Protocol): + def motor_names(self): ... + def set_calibration(self): ... + def apply_calibration(self): ... + def revert_calibration(self): ... + def read(self): ... + def write(self): ... + + +def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]: + motors_buses = {} + + for key, cfg in motors_bus_configs.items(): + if cfg.type == "dynamixel": + from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus + + motors_buses[key] = DynamixelMotorsBus(cfg) + + elif cfg.type == "feetech": + from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus + + motors_buses[key] = FeetechMotorsBus(cfg) + + else: + raise ValueError(f"The motor type '{cfg.type}' is not valid.") + + return motors_buses + + +def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus: + if motor_type == "dynamixel": + from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus + + config = DynamixelMotorsBusConfig(**kwargs) + return DynamixelMotorsBus(config) + + elif motor_type == "feetech": + from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus + + config = FeetechMotorsBusConfig(**kwargs) + return FeetechMotorsBus(config) + + else: + raise ValueError(f"The motor type '{motor_type}' is not valid.") diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..591801cc582c76d8bd0ea4018a0fd8154ec9c882 --- /dev/null +++ b/lerobot/common/robot_devices/robots/configs.py @@ -0,0 +1,613 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from dataclasses import dataclass, field +from typing import Sequence + +import draccus + +from lerobot.common.robot_devices.cameras.configs import ( + CameraConfig, + IntelRealSenseCameraConfig, + OpenCVCameraConfig, +) +from lerobot.common.robot_devices.motors.configs import ( + DynamixelMotorsBusConfig, + FeetechMotorsBusConfig, + MotorsBusConfig, +) + + +@dataclass +class RobotConfig(draccus.ChoiceRegistry, abc.ABC): + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + +# TODO(rcadene, aliberts): remove ManipulatorRobotConfig abstraction +@dataclass +class ManipulatorRobotConfig(RobotConfig): + leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) + follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) + cameras: dict[str, CameraConfig] = field(default_factory=lambda: {}) + + # Optionally limit the magnitude of the relative positional target vector for safety purposes. 
+    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length
+    # as the number of motors in your follower arms (assumes all follower arms have the same number of
+    # motors).
+    max_relative_target: list[float] | float | None = None
+
+    # Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
+    # possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
+    # gripper is not put in torque mode.
+    gripper_open_degree: float | None = None
+
+    mock: bool = False
+
+    def __post_init__(self):
+        if self.mock:
+            for arm in self.leader_arms.values():
+                if not arm.mock:
+                    arm.mock = True
+            for arm in self.follower_arms.values():
+                if not arm.mock:
+                    arm.mock = True
+            for cam in self.cameras.values():
+                if not cam.mock:
+                    cam.mock = True
+
+        if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
+            for name in self.follower_arms:
+                if len(self.follower_arms[name].motors) != len(self.max_relative_target):
+                    raise ValueError(
+                        f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
+                        f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
+                        f"`max_relative_target` list has as many values as there are motors per arm. "
+                        "Note: This feature does not yet work with robots where different follower arms have "
+                        "different numbers of motors."
+                    )
+
+
+@RobotConfig.register_subclass("aloha")
+@dataclass
+class AlohaRobotConfig(ManipulatorRobotConfig):
+    # Specific to Aloha, LeRobot comes with default calibration files. Assuming the motors have been
+    # properly assembled, no manual calibration step is expected. If you need to run manual calibration,
+    # simply update this path to ".cache/calibration/aloha"
+    calibration_dir: str = ".cache/calibration/aloha_default"
+
+    # /!\ FOR SAFETY, READ THIS /!\
+    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
+    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
+    # the number of motors in your follower arms.
+    # For Aloha, every goal position request caps motor rotations at 5 degrees by default.
+    # When you feel more confident with teleoperation or running the policy, you can extend
+    # this safety limit and even remove it by setting it to `None`.
+ # Also, everything is expected to work safely out-of-the-box, but we highly advise to + # first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml), + # then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully + max_relative_target: int | None = 5 + + leader_arms: dict[str, MotorsBusConfig] = field( + default_factory=lambda: { + "left": DynamixelMotorsBusConfig( + # window_x + port="/dev/ttyDXL_leader_left", + motors={ + # name: (index, model) + "waist": [1, "xm430-w350"], + "shoulder": [2, "xm430-w350"], + "shoulder_shadow": [3, "xm430-w350"], + "elbow": [4, "xm430-w350"], + "elbow_shadow": [5, "xm430-w350"], + "forearm_roll": [6, "xm430-w350"], + "wrist_angle": [7, "xm430-w350"], + "wrist_rotate": [8, "xl430-w250"], + "gripper": [9, "xc430-w150"], + }, + ), + "right": DynamixelMotorsBusConfig( + # window_x + port="/dev/ttyDXL_leader_right", + motors={ + # name: (index, model) + "waist": [1, "xm430-w350"], + "shoulder": [2, "xm430-w350"], + "shoulder_shadow": [3, "xm430-w350"], + "elbow": [4, "xm430-w350"], + "elbow_shadow": [5, "xm430-w350"], + "forearm_roll": [6, "xm430-w350"], + "wrist_angle": [7, "xm430-w350"], + "wrist_rotate": [8, "xl430-w250"], + "gripper": [9, "xc430-w150"], + }, + ), + } + ) + + follower_arms: dict[str, MotorsBusConfig] = field( + default_factory=lambda: { + "left": DynamixelMotorsBusConfig( + port="/dev/ttyDXL_follower_left", + motors={ + # name: (index, model) + "waist": [1, "xm540-w270"], + "shoulder": [2, "xm540-w270"], + "shoulder_shadow": [3, "xm540-w270"], + "elbow": [4, "xm540-w270"], + "elbow_shadow": [5, "xm540-w270"], + "forearm_roll": [6, "xm540-w270"], + "wrist_angle": [7, "xm540-w270"], + "wrist_rotate": [8, "xm430-w350"], + "gripper": [9, "xm430-w350"], + }, + ), + "right": DynamixelMotorsBusConfig( + port="/dev/ttyDXL_follower_right", + motors={ + # name: (index, model) + "waist": [1, "xm540-w270"], + "shoulder": [2, "xm540-w270"], + "shoulder_shadow": [3, "xm540-w270"], + "elbow": [4, "xm540-w270"], + "elbow_shadow": [5, "xm540-w270"], + "forearm_roll": [6, "xm540-w270"], + "wrist_angle": [7, "xm540-w270"], + "wrist_rotate": [8, "xm430-w350"], + "gripper": [9, "xm430-w350"], + }, + ), + } + ) + + # Troubleshooting: If one of your IntelRealSense cameras freeze during + # data recording due to bandwidth limit, you might need to plug the camera + # on another USB hub or PCIe card. + cameras: dict[str, CameraConfig] = field( + default_factory=lambda: { + "cam_high": IntelRealSenseCameraConfig( + serial_number=128422271347, + fps=30, + width=640, + height=480, + ), + "cam_low": IntelRealSenseCameraConfig( + serial_number=130322270656, + fps=30, + width=640, + height=480, + ), + "cam_left_wrist": IntelRealSenseCameraConfig( + serial_number=218622272670, + fps=30, + width=640, + height=480, + ), + "cam_right_wrist": IntelRealSenseCameraConfig( + serial_number=130322272300, + fps=30, + width=640, + height=480, + ), + } + ) + + mock: bool = False + + +@RobotConfig.register_subclass("koch") +@dataclass +class KochRobotConfig(ManipulatorRobotConfig): + calibration_dir: str = ".cache/calibration/koch" + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. 
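+    # For example (illustrative values): `max_relative_target = 10` caps every joint to a relative move
+    # of 10 per command, while `max_relative_target = [5, 5, 5, 5, 5, 10]` sets one cap per motor, in the
+    # order the motors are declared in `follower_arms` below.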
+ max_relative_target: int | None = None + + leader_arms: dict[str, MotorsBusConfig] = field( + default_factory=lambda: { + "main": DynamixelMotorsBusConfig( + port="/dev/tty.usbmodem585A0085511", + motors={ + # name: (index, model) + "shoulder_pan": [1, "xl330-m077"], + "shoulder_lift": [2, "xl330-m077"], + "elbow_flex": [3, "xl330-m077"], + "wrist_flex": [4, "xl330-m077"], + "wrist_roll": [5, "xl330-m077"], + "gripper": [6, "xl330-m077"], + }, + ), + } + ) + + follower_arms: dict[str, MotorsBusConfig] = field( + default_factory=lambda: { + "main": DynamixelMotorsBusConfig( + port="/dev/tty.usbmodem585A0076891", + motors={ + # name: (index, model) + "shoulder_pan": [1, "xl430-w250"], + "shoulder_lift": [2, "xl430-w250"], + "elbow_flex": [3, "xl330-m288"], + "wrist_flex": [4, "xl330-m288"], + "wrist_roll": [5, "xl330-m288"], + "gripper": [6, "xl330-m288"], + }, + ), + } + ) + + cameras: dict[str, CameraConfig] = field( + default_factory=lambda: { + "laptop": OpenCVCameraConfig( + camera_index=0, + fps=30, + width=640, + height=480, + ), + "phone": OpenCVCameraConfig( + camera_index=1, + fps=30, + width=640, + height=480, + ), + } + ) + + # ~ Koch specific settings ~ + # Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible + # to squeeze the gripper and have it spring back to an open position on its own. + gripper_open_degree: float = 35.156 + + mock: bool = False + + +@RobotConfig.register_subclass("koch_bimanual") +@dataclass +class KochBimanualRobotConfig(ManipulatorRobotConfig): + calibration_dir: str = ".cache/calibration/koch_bimanual" + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. 
+ max_relative_target: int | None = None + + leader_arms: dict[str, MotorsBusConfig] = field( + default_factory=lambda: { + "left": DynamixelMotorsBusConfig( + port="/dev/tty.usbmodem585A0085511", + motors={ + # name: (index, model) + "shoulder_pan": [1, "xl330-m077"], + "shoulder_lift": [2, "xl330-m077"], + "elbow_flex": [3, "xl330-m077"], + "wrist_flex": [4, "xl330-m077"], + "wrist_roll": [5, "xl330-m077"], + "gripper": [6, "xl330-m077"], + }, + ), + "right": DynamixelMotorsBusConfig( + port="/dev/tty.usbmodem575E0031751", + motors={ + # name: (index, model) + "shoulder_pan": [1, "xl330-m077"], + "shoulder_lift": [2, "xl330-m077"], + "elbow_flex": [3, "xl330-m077"], + "wrist_flex": [4, "xl330-m077"], + "wrist_roll": [5, "xl330-m077"], + "gripper": [6, "xl330-m077"], + }, + ), + } + ) + + follower_arms: dict[str, MotorsBusConfig] = field( + default_factory=lambda: { + "left": DynamixelMotorsBusConfig( + port="/dev/tty.usbmodem585A0076891", + motors={ + # name: (index, model) + "shoulder_pan": [1, "xl430-w250"], + "shoulder_lift": [2, "xl430-w250"], + "elbow_flex": [3, "xl330-m288"], + "wrist_flex": [4, "xl330-m288"], + "wrist_roll": [5, "xl330-m288"], + "gripper": [6, "xl330-m288"], + }, + ), + "right": DynamixelMotorsBusConfig( + port="/dev/tty.usbmodem575E0032081", + motors={ + # name: (index, model) + "shoulder_pan": [1, "xl430-w250"], + "shoulder_lift": [2, "xl430-w250"], + "elbow_flex": [3, "xl330-m288"], + "wrist_flex": [4, "xl330-m288"], + "wrist_roll": [5, "xl330-m288"], + "gripper": [6, "xl330-m288"], + }, + ), + } + ) + + cameras: dict[str, CameraConfig] = field( + default_factory=lambda: { + "laptop": OpenCVCameraConfig( + camera_index=0, + fps=30, + width=640, + height=480, + ), + "phone": OpenCVCameraConfig( + camera_index=1, + fps=30, + width=640, + height=480, + ), + } + ) + + # ~ Koch specific settings ~ + # Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible + # to squeeze the gripper and have it spring back to an open position on its own. + gripper_open_degree: float = 35.156 + + mock: bool = False + + +@RobotConfig.register_subclass("moss") +@dataclass +class MossRobotConfig(ManipulatorRobotConfig): + calibration_dir: str = ".cache/calibration/moss" + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. 
+ max_relative_target: int | None = None + + leader_arms: dict[str, MotorsBusConfig] = field( + default_factory=lambda: { + "main": FeetechMotorsBusConfig( + port="/dev/tty.usbmodem58760431091", + motors={ + # name: (index, model) + "shoulder_pan": [1, "sts3215"], + "shoulder_lift": [2, "sts3215"], + "elbow_flex": [3, "sts3215"], + "wrist_flex": [4, "sts3215"], + "wrist_roll": [5, "sts3215"], + "gripper": [6, "sts3215"], + }, + ), + } + ) + + follower_arms: dict[str, MotorsBusConfig] = field( + default_factory=lambda: { + "main": FeetechMotorsBusConfig( + port="/dev/tty.usbmodem585A0076891", + motors={ + # name: (index, model) + "shoulder_pan": [1, "sts3215"], + "shoulder_lift": [2, "sts3215"], + "elbow_flex": [3, "sts3215"], + "wrist_flex": [4, "sts3215"], + "wrist_roll": [5, "sts3215"], + "gripper": [6, "sts3215"], + }, + ), + } + ) + + cameras: dict[str, CameraConfig] = field( + default_factory=lambda: { + "laptop": OpenCVCameraConfig( + camera_index=0, + fps=30, + width=640, + height=480, + ), + "phone": OpenCVCameraConfig( + camera_index=1, + fps=30, + width=640, + height=480, + ), + } + ) + + mock: bool = False + + +@RobotConfig.register_subclass("so100") +@dataclass +class So100RobotConfig(ManipulatorRobotConfig): + calibration_dir: str = ".cache/calibration/so100" + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. + max_relative_target: int | None = None + + leader_arms: dict[str, MotorsBusConfig] = field( + default_factory=lambda: { + "main": FeetechMotorsBusConfig( + port="/dev/tty.usbmodem58760429101", + motors={ + # name: (index, model) + "shoulder_pan": [1, "sts3215"], + "shoulder_lift": [2, "sts3215"], + "elbow_flex": [3, "sts3215"], + "wrist_flex": [4, "sts3215"], + "wrist_roll": [5, "sts3215"], + "gripper": [6, "sts3215"], + }, + ), + } + ) + + follower_arms: dict[str, MotorsBusConfig] = field( + default_factory=lambda: { + "main": FeetechMotorsBusConfig( + port="/dev/tty.usbmodem58760435821", + motors={ + # name: (index, model) + "shoulder_pan": [1, "sts3215"], + "shoulder_lift": [2, "sts3215"], + "elbow_flex": [3, "sts3215"], + "wrist_flex": [4, "sts3215"], + "wrist_roll": [5, "sts3215"], + "gripper": [6, "sts3215"], + }, + ), + } + ) + + cameras: dict[str, CameraConfig] = field( + default_factory=lambda: { + "laptop": OpenCVCameraConfig( + camera_index=0, + fps=30, + width=640, + height=480, + ), + "phone": OpenCVCameraConfig( + camera_index=1, + fps=30, + width=640, + height=480, + ), + } + ) + + mock: bool = False + + +@RobotConfig.register_subclass("stretch") +@dataclass +class StretchRobotConfig(RobotConfig): + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. 
+ max_relative_target: int | None = None + + cameras: dict[str, CameraConfig] = field( + default_factory=lambda: { + "navigation": OpenCVCameraConfig( + camera_index="/dev/hello-nav-head-camera", + fps=10, + width=1280, + height=720, + rotation=-90, + ), + "head": IntelRealSenseCameraConfig( + name="Intel RealSense D435I", + fps=30, + width=640, + height=480, + rotation=90, + ), + "wrist": IntelRealSenseCameraConfig( + name="Intel RealSense D405", + fps=30, + width=640, + height=480, + ), + } + ) + + mock: bool = False + + +@RobotConfig.register_subclass("lekiwi") +@dataclass +class LeKiwiRobotConfig(RobotConfig): + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. + max_relative_target: int | None = None + + # Network Configuration + ip: str = "192.168.0.193" + port: int = 5555 + video_port: int = 5556 + + cameras: dict[str, CameraConfig] = field( + default_factory=lambda: { + "front": OpenCVCameraConfig( + camera_index="/dev/video0", fps=30, width=640, height=480, rotation=90 + ), + "wrist": OpenCVCameraConfig( + camera_index="/dev/video2", fps=30, width=640, height=480, rotation=180 + ), + } + ) + + calibration_dir: str = ".cache/calibration/lekiwi" + + leader_arms: dict[str, MotorsBusConfig] = field( + default_factory=lambda: { + "main": FeetechMotorsBusConfig( + port="/dev/tty.usbmodem585A0077581", + motors={ + # name: (index, model) + "shoulder_pan": [1, "sts3215"], + "shoulder_lift": [2, "sts3215"], + "elbow_flex": [3, "sts3215"], + "wrist_flex": [4, "sts3215"], + "wrist_roll": [5, "sts3215"], + "gripper": [6, "sts3215"], + }, + ), + } + ) + + follower_arms: dict[str, MotorsBusConfig] = field( + default_factory=lambda: { + "main": FeetechMotorsBusConfig( + port="/dev/ttyACM0", + motors={ + # name: (index, model) + "shoulder_pan": [1, "sts3215"], + "shoulder_lift": [2, "sts3215"], + "elbow_flex": [3, "sts3215"], + "wrist_flex": [4, "sts3215"], + "wrist_roll": [5, "sts3215"], + "gripper": [6, "sts3215"], + "left_wheel": (7, "sts3215"), + "back_wheel": (8, "sts3215"), + "right_wheel": (9, "sts3215"), + }, + ), + } + ) + + teleop_keys: dict[str, str] = field( + default_factory=lambda: { + # Movement + "forward": "w", + "backward": "s", + "left": "a", + "right": "d", + "rotate_left": "z", + "rotate_right": "x", + # Speed control + "speed_up": "r", + "speed_down": "f", + # quit teleop + "quit": "q", + } + ) + + mock: bool = False diff --git a/lerobot/common/robot_devices/robots/dynamixel_calibration.py b/lerobot/common/robot_devices/robots/dynamixel_calibration.py new file mode 100644 index 0000000000000000000000000000000000000000..98fe8754f8723b33300693d4602e32571e8787a3 --- /dev/null +++ b/lerobot/common/robot_devices/robots/dynamixel_calibration.py @@ -0,0 +1,144 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
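+# Illustrative arithmetic for the calibration below (assuming a motor with 4096 steps per turn, i.e.
+# 2048 steps per half turn as described in `run_arm_calibration`):
+#   - ROTATED_POSITION_DEGREE = 90 converts to 4096 * 90 / 360 = 1024 steps.
+#   - A motor reading 3100 steps in the "zero" pose is rounded to the nearest quarter-turn multiple,
+#     3072, so its homing offset becomes 0 - 3072 = -3072 and its shifted reading lands near 0.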
+ +"""Logic to calibrate a robot arm built with dynamixel motors""" +# TODO(rcadene, aliberts): move this logic into the robot code when refactoring + +import numpy as np + +from lerobot.common.robot_devices.motors.dynamixel import ( + CalibrationMode, + TorqueMode, + convert_degrees_to_steps, +) +from lerobot.common.robot_devices.motors.utils import MotorsBus + +URL_TEMPLATE = ( + "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" +) + +# The following positions are provided in nominal degree range ]-180, +180[ +# For more info on these constants, see comments in the code where they get used. +ZERO_POSITION_DEGREE = 0 +ROTATED_POSITION_DEGREE = 90 + + +def assert_drive_mode(drive_mode): + # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted. + if not np.all(np.isin(drive_mode, [0, 1])): + raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})") + + +def apply_drive_mode(position, drive_mode): + assert_drive_mode(drive_mode) + # Convert `drive_mode` from [0, 1] with 0 indicates original rotation direction and 1 inverted, + # to [-1, 1] with 1 indicates original rotation direction and -1 inverted. + signed_drive_mode = -(drive_mode * 2 - 1) + position *= signed_drive_mode + return position + + +def compute_nearest_rounded_position(position, models): + delta_turn = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, models) + nearest_pos = np.round(position.astype(float) / delta_turn) * delta_turn + return nearest_pos.astype(position.dtype) + + +def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): + """This function ensures that a neural network trained on data collected on a given robot + can work on another robot. For instance before calibration, setting a same goal position + for each motor of two different robots will get two very different positions. But after calibration, + the two robots will move to the same position.To this end, this function computes the homing offset + and the drive mode for each motor of a given robot. + + Homing offset is used to shift the motor position to a ]-2048, +2048[ nominal range (when the motor uses 2048 steps + to complete a half a turn). This range is set around an arbitrary "zero position" corresponding to all motor positions + being 0. During the calibration process, you will need to manually move the robot to this "zero position". + + Drive mode is used to invert the rotation direction of the motor. This is useful when some motors have been assembled + in the opposite orientation for some robots. During the calibration process, you will need to manually move the robot + to the "rotated position". + + After calibration, the homing offsets and drive modes are stored in a cache. + + Example of usage: + ```python + run_arm_calibration(arm, "koch", "left", "follower") + ``` + """ + if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): + raise ValueError("To run calibration, the torque must be disabled on all motors.") + + print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") + + print("\nMove arm to zero position") + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")) + input("Press Enter to continue...") + + # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed. + # It is easy to identify and all motors are in a "quarter turn" position. 
Once calibration is done, this position will + # correspond to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position. + zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models) + + # Compute homing offset so that `present_position + homing_offset ~= target_position`. + zero_pos = arm.read("Present_Position") + zero_nearest_pos = compute_nearest_rounded_position(zero_pos, arm.motor_models) + homing_offset = zero_target_pos - zero_nearest_pos + + # The rotated target position corresponds to a rotation of a quarter turn from the zero position. + # This allows to identify the rotation direction of each motor. + # For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction + # is inverted. However, for the calibration being successful, we need everyone to follow the same target position. + # Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which + # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view + # of the previous motor in the kinetic chain. + print("\nMove arm to rotated target position") + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) + input("Press Enter to continue...") + + rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models) + + # Find drive mode by rotating each motor by a quarter of a turn. + # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0). + rotated_pos = arm.read("Present_Position") + drive_mode = (rotated_pos < zero_pos).astype(np.int32) + + # Re-compute homing offset to take into account drive mode + rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode) + rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models) + homing_offset = rotated_target_pos - rotated_nearest_pos + + print("\nMove arm to rest position") + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")) + input("Press Enter to continue...") + print() + + # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] + calib_mode = [CalibrationMode.DEGREE.name] * len(arm.motor_names) + + # TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml? + if robot_type in ["aloha"] and "gripper" in arm.motor_names: + # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100] + calib_idx = arm.motor_names.index("gripper") + calib_mode[calib_idx] = CalibrationMode.LINEAR.name + + calib_data = { + "homing_offset": homing_offset.tolist(), + "drive_mode": drive_mode.tolist(), + "start_pos": zero_pos.tolist(), + "end_pos": rotated_pos.tolist(), + "calib_mode": calib_mode, + "motor_names": arm.motor_names, + } + return calib_data diff --git a/lerobot/common/robot_devices/robots/feetech_calibration.py b/lerobot/common/robot_devices/robots/feetech_calibration.py new file mode 100644 index 0000000000000000000000000000000000000000..2c1e7180e8d28d6ac3b3c95d9e907a89f24c6c71 --- /dev/null +++ b/lerobot/common/robot_devices/robots/feetech_calibration.py @@ -0,0 +1,498 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Logic to calibrate a robot arm built with feetech motors""" +# TODO(rcadene, aliberts): move this logic into the robot code when refactoring + +import time + +import numpy as np + +from lerobot.common.robot_devices.motors.feetech import ( + CalibrationMode, + TorqueMode, + convert_degrees_to_steps, +) +from lerobot.common.robot_devices.motors.utils import MotorsBus + +URL_TEMPLATE = ( + "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" +) + +# The following positions are provided in nominal degree range ]-180, +180[ +# For more info on these constants, see comments in the code where they get used. +ZERO_POSITION_DEGREE = 0 +ROTATED_POSITION_DEGREE = 90 + + +def assert_drive_mode(drive_mode): + # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted. + if not np.all(np.isin(drive_mode, [0, 1])): + raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})") + + +def apply_drive_mode(position, drive_mode): + assert_drive_mode(drive_mode) + # Convert `drive_mode` from [0, 1] with 0 indicates original rotation direction and 1 inverted, + # to [-1, 1] with 1 indicates original rotation direction and -1 inverted. + signed_drive_mode = -(drive_mode * 2 - 1) + position *= signed_drive_mode + return position + + +def move_until_block(arm, motor_name, positive_direction=True, while_move_hook=None): + count = 0 + while True: + present_pos = arm.read("Present_Position", motor_name) + if positive_direction: + # Move +100 steps every time. Lower the steps to lower the speed at which the arm moves. 
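+            # The loop exits in the stall check further down: once the motor reports zero speed with
+            # a current draw above 40 for ~100 consecutive iterations (or, while stalled, a current
+            # spike above 300), the joint is considered blocked and its present position is returned.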
+ arm.write("Goal_Position", present_pos + 100, motor_name) + else: + arm.write("Goal_Position", present_pos - 100, motor_name) + + if while_move_hook is not None: + while_move_hook() + + present_pos = arm.read("Present_Position", motor_name).item() + present_speed = arm.read("Present_Speed", motor_name).item() + present_current = arm.read("Present_Current", motor_name).item() + # present_load = arm.read("Present_Load", motor_name).item() + # present_voltage = arm.read("Present_Voltage", motor_name).item() + # present_temperature = arm.read("Present_Temperature", motor_name).item() + + # print(f"{present_pos=}") + # print(f"{present_speed=}") + # print(f"{present_current=}") + # print(f"{present_load=}") + # print(f"{present_voltage=}") + # print(f"{present_temperature=}") + + if present_speed == 0 and present_current > 40: + count += 1 + if count > 100 or present_current > 300: + return present_pos + else: + count = 0 + + +def move_to_calibrate( + arm, + motor_name, + invert_drive_mode=False, + positive_first=True, + in_between_move_hook=None, + while_move_hook=None, +): + initial_pos = arm.read("Present_Position", motor_name) + + if positive_first: + p_present_pos = move_until_block( + arm, motor_name, positive_direction=True, while_move_hook=while_move_hook + ) + else: + n_present_pos = move_until_block( + arm, motor_name, positive_direction=False, while_move_hook=while_move_hook + ) + + if in_between_move_hook is not None: + in_between_move_hook() + + if positive_first: + n_present_pos = move_until_block( + arm, motor_name, positive_direction=False, while_move_hook=while_move_hook + ) + else: + p_present_pos = move_until_block( + arm, motor_name, positive_direction=True, while_move_hook=while_move_hook + ) + + zero_pos = (n_present_pos + p_present_pos) / 2 + + calib_data = { + "initial_pos": initial_pos, + "homing_offset": zero_pos if invert_drive_mode else -zero_pos, + "invert_drive_mode": invert_drive_mode, + "drive_mode": -1 if invert_drive_mode else 0, + "zero_pos": zero_pos, + "start_pos": n_present_pos if invert_drive_mode else p_present_pos, + "end_pos": p_present_pos if invert_drive_mode else n_present_pos, + } + return calib_data + + +def apply_offset(calib, offset): + calib["zero_pos"] += offset + if calib["drive_mode"]: + calib["homing_offset"] += offset + else: + calib["homing_offset"] -= offset + return calib + + +def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): + if robot_type == "so100": + return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type) + elif robot_type == "moss": + return run_arm_auto_calibration_moss(arm, robot_type, arm_name, arm_type) + else: + raise ValueError(robot_type) + + +def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): + """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" + if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): + raise ValueError("To run calibration, the torque must be disabled on all motors.") + + if not (robot_type == "so100" and arm_type == "follower"): + raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.") + + print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") + + print("\nMove arm to initial position") + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")) + input("Press Enter to continue...") + + # Lower the acceleration of the motors (in [0,254]) + 
initial_acceleration = arm.read("Acceleration") + arm.write("Lock", 0) + arm.write("Acceleration", 10) + time.sleep(1) + + arm.write("Torque_Enable", TorqueMode.ENABLED.value) + + print(f'{arm.read("Present_Position", "elbow_flex")=}') + + calib = {} + + init_wf_pos = arm.read("Present_Position", "wrist_flex") + init_sl_pos = arm.read("Present_Position", "shoulder_lift") + init_ef_pos = arm.read("Present_Position", "elbow_flex") + arm.write("Goal_Position", init_wf_pos - 800, "wrist_flex") + arm.write("Goal_Position", init_sl_pos + 150 + 1024, "shoulder_lift") + arm.write("Goal_Position", init_ef_pos - 2048, "elbow_flex") + time.sleep(2) + + print("Calibrate shoulder_pan") + calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan") + arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan") + time.sleep(1) + + print("Calibrate gripper") + calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True) + time.sleep(1) + + print("Calibrate wrist_flex") + calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex") + calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80) + + def in_between_move_hook(): + nonlocal arm, calib + time.sleep(2) + ef_pos = arm.read("Present_Position", "elbow_flex") + sl_pos = arm.read("Present_Position", "shoulder_lift") + arm.write("Goal_Position", ef_pos + 1024, "elbow_flex") + arm.write("Goal_Position", sl_pos - 1024, "shoulder_lift") + time.sleep(2) + + print("Calibrate elbow_flex") + calib["elbow_flex"] = move_to_calibrate( + arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook + ) + calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024) + + arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex") + time.sleep(1) + + def in_between_move_hook(): + nonlocal arm, calib + arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex") + + print("Calibrate shoulder_lift") + calib["shoulder_lift"] = move_to_calibrate( + arm, + "shoulder_lift", + invert_drive_mode=True, + positive_first=False, + in_between_move_hook=in_between_move_hook, + ) + # add an 30 steps as offset to align with body + calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50) + + def while_move_hook(): + nonlocal arm, calib + positions = { + "shoulder_lift": round(calib["shoulder_lift"]["zero_pos"] - 1600), + "elbow_flex": round(calib["elbow_flex"]["zero_pos"] + 1700), + "wrist_flex": round(calib["wrist_flex"]["zero_pos"] + 800), + "gripper": round(calib["gripper"]["end_pos"]), + } + arm.write("Goal_Position", list(positions.values()), list(positions.keys())) + + arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift") + time.sleep(2) + arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex") + time.sleep(2) + arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex") + time.sleep(2) + arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper") + time.sleep(2) + + print("Calibrate wrist_roll") + calib["wrist_roll"] = move_to_calibrate( + arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook + ) + + arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll") + time.sleep(1) + arm.write("Goal_Position", calib["gripper"]["start_pos"], "gripper") + time.sleep(1) + arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex") + time.sleep(1) + 
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex") + arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift") + time.sleep(1) + arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan") + time.sleep(1) + + calib_modes = [] + for name in arm.motor_names: + if name == "gripper": + calib_modes.append(CalibrationMode.LINEAR.name) + else: + calib_modes.append(CalibrationMode.DEGREE.name) + + calib_dict = { + "homing_offset": [calib[name]["homing_offset"] for name in arm.motor_names], + "drive_mode": [calib[name]["drive_mode"] for name in arm.motor_names], + "start_pos": [calib[name]["start_pos"] for name in arm.motor_names], + "end_pos": [calib[name]["end_pos"] for name in arm.motor_names], + "calib_mode": calib_modes, + "motor_names": arm.motor_names, + } + + # Re-enable original accerlation + arm.write("Lock", 0) + arm.write("Acceleration", initial_acceleration) + time.sleep(1) + + return calib_dict + + +def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): + """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" + if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): + raise ValueError("To run calibration, the torque must be disabled on all motors.") + + if not (robot_type == "moss" and arm_type == "follower"): + raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.") + + print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") + + print("\nMove arm to initial position") + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")) + input("Press Enter to continue...") + + # Lower the acceleration of the motors (in [0,254]) + initial_acceleration = arm.read("Acceleration") + arm.write("Lock", 0) + arm.write("Acceleration", 10) + time.sleep(1) + + arm.write("Torque_Enable", TorqueMode.ENABLED.value) + + sl_pos = arm.read("Present_Position", "shoulder_lift") + arm.write("Goal_Position", sl_pos - 1024 - 450, "shoulder_lift") + ef_pos = arm.read("Present_Position", "elbow_flex") + arm.write("Goal_Position", ef_pos + 1024 + 450, "elbow_flex") + time.sleep(2) + + calib = {} + + print("Calibrate shoulder_pan") + calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan") + arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan") + time.sleep(1) + + print("Calibrate gripper") + calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True) + time.sleep(1) + + print("Calibrate wrist_flex") + calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex", invert_drive_mode=True) + calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=-210 + 1024) + + wr_pos = arm.read("Present_Position", "wrist_roll") + arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex") + time.sleep(1) + arm.write("Goal_Position", wr_pos - 1024, "wrist_roll") + time.sleep(1) + arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 2048, "wrist_flex") + time.sleep(1) + arm.write("Goal_Position", calib["gripper"]["end_pos"], "gripper") + time.sleep(1) + + print("Calibrate wrist_roll") + calib["wrist_roll"] = move_to_calibrate(arm, "wrist_roll", invert_drive_mode=True) + calib["wrist_roll"] = apply_offset(calib["wrist_roll"], offset=790) + + arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"] - 1024, "wrist_roll") + arm.write("Goal_Position", calib["gripper"]["start_pos"], "gripper") + 
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex") + time.sleep(1) + arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll") + arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 2048, "wrist_flex") + + def in_between_move_elbow_flex_hook(): + nonlocal arm, calib + arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex") + + print("Calibrate elbow_flex") + calib["elbow_flex"] = move_to_calibrate( + arm, + "elbow_flex", + invert_drive_mode=True, + in_between_move_hook=in_between_move_elbow_flex_hook, + ) + arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex") + + def in_between_move_shoulder_lift_hook(): + nonlocal arm, calib + sl = arm.read("Present_Position", "shoulder_lift") + arm.write("Goal_Position", sl - 1500, "shoulder_lift") + time.sleep(1) + arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1536, "elbow_flex") + time.sleep(1) + arm.write("Goal_Position", calib["wrist_flex"]["start_pos"], "wrist_flex") + time.sleep(1) + + print("Calibrate shoulder_lift") + calib["shoulder_lift"] = move_to_calibrate( + arm, "shoulder_lift", in_between_move_hook=in_between_move_shoulder_lift_hook + ) + calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=-1024) + + arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex") + time.sleep(1) + arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift") + arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex") + time.sleep(2) + + calib_modes = [] + for name in arm.motor_names: + if name == "gripper": + calib_modes.append(CalibrationMode.LINEAR.name) + else: + calib_modes.append(CalibrationMode.DEGREE.name) + + calib_dict = { + "homing_offset": [calib[name]["homing_offset"] for name in arm.motor_names], + "drive_mode": [calib[name]["drive_mode"] for name in arm.motor_names], + "start_pos": [calib[name]["start_pos"] for name in arm.motor_names], + "end_pos": [calib[name]["end_pos"] for name in arm.motor_names], + "calib_mode": calib_modes, + "motor_names": arm.motor_names, + } + + # Re-enable original accerlation + arm.write("Lock", 0) + arm.write("Acceleration", initial_acceleration) + time.sleep(1) + + return calib_dict + + +def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): + """This function ensures that a neural network trained on data collected on a given robot + can work on another robot. For instance before calibration, setting a same goal position + for each motor of two different robots will get two very different positions. But after calibration, + the two robots will move to the same position.To this end, this function computes the homing offset + and the drive mode for each motor of a given robot. + + Homing offset is used to shift the motor position to a ]-2048, +2048[ nominal range (when the motor uses 2048 steps + to complete a half a turn). This range is set around an arbitrary "zero position" corresponding to all motor positions + being 0. During the calibration process, you will need to manually move the robot to this "zero position". + + Drive mode is used to invert the rotation direction of the motor. This is useful when some motors have been assembled + in the opposite orientation for some robots. During the calibration process, you will need to manually move the robot + to the "rotated position". + + After calibration, the homing offsets and drive modes are stored in a cache. 
+
+    Example of usage:
+    ```python
+    run_arm_manual_calibration(arm, "so100", "left", "follower")
+    ```
+    """
+    if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
+        raise ValueError("To run calibration, the torque must be disabled on all motors.")
+
+    print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
+
+    print("\nMove arm to zero position")
+    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
+    input("Press Enter to continue...")
+
+    # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
+    # It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will
+    # correspond to every motor angle being 0. If you set all 0 as Goal Position, the arm will move to this position.
+    zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models)
+
+    # Compute homing offset so that `present_position + homing_offset ~= target_position`.
+    zero_pos = arm.read("Present_Position")
+    homing_offset = zero_target_pos - zero_pos
+
+    # The rotated target position corresponds to a rotation of a quarter turn from the zero position.
+    # This allows identifying the rotation direction of each motor.
+    # For instance, if a motor rotates by 90 degrees and its value is -90 after applying the homing offset, then we know its
+    # rotation direction is inverted. However, for the calibration to be successful, every arm needs to follow the same target
+    # position. Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only
+    # one direction which corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate
+    # clockwise from the point of view of the previous motor in the kinematic chain.
+    print("\nMove arm to rotated target position")
+    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
+    input("Press Enter to continue...")
+
+    rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
+
+    # Find drive mode by rotating each motor by a quarter of a turn.
+    # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
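+    # Illustrative example: if a motor's raw reading decreases between the "zero" and "rotated" poses
+    # (e.g. from 2048 down to 1024), then `rotated_pos < zero_pos` holds, its drive_mode is set to 1, and
+    # its reading is sign-flipped by `apply_drive_mode` before the homing offset is recomputed below.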
+ rotated_pos = arm.read("Present_Position") + drive_mode = (rotated_pos < zero_pos).astype(np.int32) + + # Re-compute homing offset to take into account drive mode + rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode) + homing_offset = rotated_target_pos - rotated_drived_pos + + print("\nMove arm to rest position") + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")) + input("Press Enter to continue...") + print() + + # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] + calib_modes = [] + for name in arm.motor_names: + if name == "gripper": + calib_modes.append(CalibrationMode.LINEAR.name) + else: + calib_modes.append(CalibrationMode.DEGREE.name) + + calib_dict = { + "homing_offset": homing_offset.tolist(), + "drive_mode": drive_mode.tolist(), + "start_pos": zero_pos.tolist(), + "end_pos": rotated_pos.tolist(), + "calib_mode": calib_modes, + "motor_names": arm.motor_names, + } + return calib_dict diff --git a/lerobot/common/robot_devices/robots/lekiwi_remote.py b/lerobot/common/robot_devices/robots/lekiwi_remote.py new file mode 100644 index 0000000000000000000000000000000000000000..7bf52d21d236db888c3797fbe9ea5d57fdaaee3a --- /dev/null +++ b/lerobot/common/robot_devices/robots/lekiwi_remote.py @@ -0,0 +1,224 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json +import threading +import time +from pathlib import Path + +import cv2 +import zmq + +from lerobot.common.robot_devices.robots.mobile_manipulator import LeKiwi + + +def setup_zmq_sockets(config): + context = zmq.Context() + cmd_socket = context.socket(zmq.PULL) + cmd_socket.setsockopt(zmq.CONFLATE, 1) + cmd_socket.bind(f"tcp://*:{config.port}") + + video_socket = context.socket(zmq.PUSH) + video_socket.setsockopt(zmq.CONFLATE, 1) + video_socket.bind(f"tcp://*:{config.video_port}") + + return context, cmd_socket, video_socket + + +def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event): + while not stop_event.is_set(): + local_dict = {} + for name, cam in cameras.items(): + frame = cam.async_read() + ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]) + if ret: + local_dict[name] = base64.b64encode(buffer).decode("utf-8") + else: + local_dict[name] = "" + with images_lock: + latest_images_dict.update(local_dict) + time.sleep(0.01) + + +def calibrate_follower_arm(motors_bus, calib_dir_str): + """ + Calibrates the follower arm. Attempts to load an existing calibration file; + if not found, runs manual calibration and saves the result. + """ + calib_dir = Path(calib_dir_str) + calib_dir.mkdir(parents=True, exist_ok=True) + calib_file = calib_dir / "main_follower.json" + try: + from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration + except ImportError: + print("[WARNING] Calibration function not available. 
Skipping calibration.") + return + + if calib_file.exists(): + with open(calib_file) as f: + calibration = json.load(f) + print(f"[INFO] Loaded calibration from {calib_file}") + else: + print("[INFO] Calibration file not found. Running manual calibration...") + calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower") + print(f"[INFO] Calibration complete. Saving to {calib_file}") + with open(calib_file, "w") as f: + json.dump(calibration, f) + try: + motors_bus.set_calibration(calibration) + print("[INFO] Applied calibration for follower arm.") + except Exception as e: + print(f"[WARNING] Could not apply calibration: {e}") + + +def run_lekiwi(robot_config): + """ + Runs the LeKiwi robot: + - Sets up cameras and connects them. + - Initializes the follower arm motors. + - Calibrates the follower arm if necessary. + - Creates ZeroMQ sockets for receiving commands and streaming observations. + - Processes incoming commands (arm and wheel commands) and sends back sensor and camera data. + """ + # Import helper functions and classes + from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs + from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode + + # Initialize cameras from the robot configuration. + cameras = make_cameras_from_configs(robot_config.cameras) + for cam in cameras.values(): + cam.connect() + + # Initialize the motors bus using the follower arm configuration. + motor_config = robot_config.follower_arms.get("main") + if motor_config is None: + print("[ERROR] Follower arm 'main' configuration not found.") + return + motors_bus = FeetechMotorsBus(motor_config) + motors_bus.connect() + + # Calibrate the follower arm. + calibrate_follower_arm(motors_bus, robot_config.calibration_dir) + + # Create the LeKiwi robot instance. + robot = LeKiwi(motors_bus) + + # Define the expected arm motor IDs. + arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"] + + # Disable torque for each arm motor. + for motor in arm_motor_ids: + motors_bus.write("Torque_Enable", TorqueMode.DISABLED.value, motor) + + # Set up ZeroMQ sockets. + context, cmd_socket, video_socket = setup_zmq_sockets(robot_config) + + # Start the camera capture thread. + latest_images_dict = {} + images_lock = threading.Lock() + stop_event = threading.Event() + cam_thread = threading.Thread( + target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True + ) + cam_thread.start() + + last_cmd_time = time.time() + print("LeKiwi robot server started. Waiting for commands...") + + try: + while True: + loop_start_time = time.time() + + # Process incoming commands (non-blocking). + while True: + try: + msg = cmd_socket.recv_string(zmq.NOBLOCK) + except zmq.Again: + break + try: + data = json.loads(msg) + # Process arm position commands. + if "arm_positions" in data: + arm_positions = data["arm_positions"] + if not isinstance(arm_positions, list): + print(f"[ERROR] Invalid arm_positions: {arm_positions}") + elif len(arm_positions) < len(arm_motor_ids): + print( + f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}" + ) + else: + for motor, pos in zip(arm_motor_ids, arm_positions, strict=False): + motors_bus.write("Goal_Position", pos, motor) + # Process wheel (base) commands. + if "raw_velocity" in data: + raw_command = data["raw_velocity"] + # Expect keys: "left_wheel", "back_wheel", "right_wheel". 
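+                        # Illustrative payload (values are made up): the teleop client is expected to
+                        # send JSON such as
+                        #   {"arm_positions": [2048, 2048, 2048, 2048, 2048, 2048],
+                        #    "raw_velocity": {"left_wheel": 100, "back_wheel": 0, "right_wheel": -100}}
+                        # where the arm positions follow the order of `arm_motor_ids` defined above.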
+ command_speeds = [ + int(raw_command.get("left_wheel", 0)), + int(raw_command.get("back_wheel", 0)), + int(raw_command.get("right_wheel", 0)), + ] + robot.set_velocity(command_speeds) + last_cmd_time = time.time() + except Exception as e: + print(f"[ERROR] Parsing message failed: {e}") + + # Watchdog: stop the robot if no command is received for over 0.5 seconds. + now = time.time() + if now - last_cmd_time > 0.5: + robot.stop() + last_cmd_time = now + + # Read current wheel speeds from the robot. + current_velocity = robot.read_velocity() + + # Read the follower arm state from the motors bus. + follower_arm_state = [] + for motor in arm_motor_ids: + try: + pos = motors_bus.read("Present_Position", motor) + # Convert the position to a float (or use as is if already numeric). + follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos) + except Exception as e: + print(f"[ERROR] Reading motor {motor} failed: {e}") + + # Get the latest camera images. + with images_lock: + images_dict_copy = dict(latest_images_dict) + + # Build the observation dictionary. + observation = { + "images": images_dict_copy, + "present_speed": current_velocity, + "follower_arm_state": follower_arm_state, + } + # Send the observation over the video socket. + video_socket.send_string(json.dumps(observation)) + + # Ensure a short sleep to avoid overloading the CPU. + elapsed = time.time() - loop_start_time + time.sleep( + max(0.033 - elapsed, 0) + ) # If robot jitters increase the sleep and monitor cpu load with `top` in cmd + except KeyboardInterrupt: + print("Shutting down LeKiwi server.") + finally: + stop_event.set() + cam_thread.join() + robot.stop() + motors_bus.disconnect() + cmd_socket.close() + video_socket.close() + context.term() diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py new file mode 100644 index 0000000000000000000000000000000000000000..9173abc628810d35a7eda1eb9568da018f8d316c --- /dev/null +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -0,0 +1,627 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains logic to instantiate a robot, read information from its motors and cameras, +and send orders to its motors. +""" +# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated +# calibration procedure, to make it easy for people to add their own robot. 
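+# Rough illustration of the safety clamp implemented by `ensure_safe_goal_position` below (numbers are
+# made up): with `max_relative_target=5`, a follower joint currently at 10.0 that is asked to jump to
+# 35.0 is only commanded to 15.0, i.e. the relative move is clipped to +/-5 per call.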
+ +import json +import logging +import time +import warnings +from pathlib import Path + +import numpy as np +import torch + +from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs +from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs +from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig +from lerobot.common.robot_devices.robots.utils import get_arm_id +from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError + + +def ensure_safe_goal_position( + goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float] +): + # Cap relative action target magnitude for safety. + diff = goal_pos - present_pos + max_relative_target = torch.tensor(max_relative_target) + safe_diff = torch.minimum(diff, max_relative_target) + safe_diff = torch.maximum(safe_diff, -max_relative_target) + safe_goal_pos = present_pos + safe_diff + + if not torch.allclose(goal_pos, safe_goal_pos): + logging.warning( + "Relative goal position magnitude had to be clamped to be safe.\n" + f" requested relative goal position target: {diff}\n" + f" clamped relative goal position target: {safe_diff}" + ) + + return safe_goal_pos + + +class ManipulatorRobot: + # TODO(rcadene): Implement force feedback + """This class allows to control any manipulator robot of various number of motors. + + Non exhaustive list of robots: + - [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, developed + by Alexander Koch from [Tau Robotics](https://tau-robotics.com) + - [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss + - [Aloha](https://www.trossenrobotics.com/aloha-kits) developed by Trossen Robotics + + Example of instantiation, a pre-defined robot config is required: + ```python + robot = ManipulatorRobot(KochRobotConfig()) + ``` + + Example of overwriting motors during instantiation: + ```python + # Defines how to communicate with the motors of the leader and follower arms + leader_arms = { + "main": DynamixelMotorsBusConfig( + port="/dev/tty.usbmodem575E0031751", + motors={ + # name: (index, model) + "shoulder_pan": (1, "xl330-m077"), + "shoulder_lift": (2, "xl330-m077"), + "elbow_flex": (3, "xl330-m077"), + "wrist_flex": (4, "xl330-m077"), + "wrist_roll": (5, "xl330-m077"), + "gripper": (6, "xl330-m077"), + }, + ), + } + follower_arms = { + "main": DynamixelMotorsBusConfig( + port="/dev/tty.usbmodem575E0032081", + motors={ + # name: (index, model) + "shoulder_pan": (1, "xl430-w250"), + "shoulder_lift": (2, "xl430-w250"), + "elbow_flex": (3, "xl330-m288"), + "wrist_flex": (4, "xl330-m288"), + "wrist_roll": (5, "xl330-m288"), + "gripper": (6, "xl330-m288"), + }, + ), + } + robot_config = KochRobotConfig(leader_arms=leader_arms, follower_arms=follower_arms) + robot = ManipulatorRobot(robot_config) + ``` + + Example of overwriting cameras during instantiation: + ```python + # Defines how to communicate with 2 cameras connected to the computer. + # Here, the webcam of the laptop and the phone (connected in USB to the laptop) + # can be reached respectively using the camera indices 0 and 1. These indices can be + # arbitrary. See the documentation of `OpenCVCamera` to find your own camera indices. 
+ cameras = { + "laptop": OpenCVCamera(camera_index=0, fps=30, width=640, height=480), + "phone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480), + } + robot = ManipulatorRobot(KochRobotConfig(cameras=cameras)) + ``` + + Once the robot is instantiated, connect motors buses and cameras if any (Required): + ```python + robot.connect() + ``` + + Example of highest frequency teleoperation, which doesn't require cameras: + ```python + while True: + robot.teleop_step() + ``` + + Example of highest frequency data collection from motors and cameras (if any): + ```python + while True: + observation, action = robot.teleop_step(record_data=True) + ``` + + Example of controlling the robot with a policy: + ```python + while True: + # Uses the follower arms and cameras to capture an observation + observation = robot.capture_observation() + + # Assumes a policy has been instantiated + with torch.inference_mode(): + action = policy.select_action(observation) + + # Orders the robot to move + robot.send_action(action) + ``` + + Example of disconnecting which is not mandatory since we disconnect when the object is deleted: + ```python + robot.disconnect() + ``` + """ + + def __init__( + self, + config: ManipulatorRobotConfig, + ): + self.config = config + self.robot_type = self.config.type + self.calibration_dir = Path(self.config.calibration_dir) + self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms) + self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms) + self.cameras = make_cameras_from_configs(self.config.cameras) + self.is_connected = False + self.logs = {} + + def get_motor_names(self, arm: dict[str, MotorsBus]) -> list: + return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors] + + @property + def camera_features(self) -> dict: + cam_ft = {} + for cam_key, cam in self.cameras.items(): + key = f"observation.images.{cam_key}" + cam_ft[key] = { + "shape": (cam.height, cam.width, cam.channels), + "names": ["height", "width", "channels"], + "info": None, + } + return cam_ft + + @property + def motor_features(self) -> dict: + action_names = self.get_motor_names(self.leader_arms) + state_names = self.get_motor_names(self.leader_arms) + return { + "action": { + "dtype": "float32", + "shape": (len(action_names),), + "names": action_names, + }, + "observation.state": { + "dtype": "float32", + "shape": (len(state_names),), + "names": state_names, + }, + } + + @property + def features(self): + return {**self.motor_features, **self.camera_features} + + @property + def has_camera(self): + return len(self.cameras) > 0 + + @property + def num_cameras(self): + return len(self.cameras) + + @property + def available_arms(self): + available_arms = [] + for name in self.follower_arms: + arm_id = get_arm_id(name, "follower") + available_arms.append(arm_id) + for name in self.leader_arms: + arm_id = get_arm_id(name, "leader") + available_arms.append(arm_id) + return available_arms + + def connect(self): + if self.is_connected: + raise RobotDeviceAlreadyConnectedError( + "ManipulatorRobot is already connected. Do not run `robot.connect()` twice." + ) + + if not self.leader_arms and not self.follower_arms and not self.cameras: + raise ValueError( + "ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class." 
+ ) + + # Connect the arms + for name in self.follower_arms: + print(f"Connecting {name} follower arm.") + self.follower_arms[name].connect() + for name in self.leader_arms: + print(f"Connecting {name} leader arm.") + self.leader_arms[name].connect() + + if self.robot_type in ["koch", "koch_bimanual", "aloha"]: + from lerobot.common.robot_devices.motors.dynamixel import TorqueMode + elif self.robot_type in ["so100", "moss", "lekiwi"]: + from lerobot.common.robot_devices.motors.feetech import TorqueMode + + # We assume that at connection time, arms are in a rest position, and torque can + # be safely disabled to run calibration and/or set robot preset configurations. + for name in self.follower_arms: + self.follower_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value) + for name in self.leader_arms: + self.leader_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value) + + self.activate_calibration() + + # Set robot preset (e.g. torque in leader gripper for Koch v1.1) + if self.robot_type in ["koch", "koch_bimanual"]: + self.set_koch_robot_preset() + elif self.robot_type == "aloha": + self.set_aloha_robot_preset() + elif self.robot_type in ["so100", "moss", "lekiwi"]: + self.set_so100_robot_preset() + + # Enable torque on all motors of the follower arms + for name in self.follower_arms: + print(f"Activating torque on {name} follower arm.") + self.follower_arms[name].write("Torque_Enable", 1) + + if self.config.gripper_open_degree is not None: + if self.robot_type not in ["koch", "koch_bimanual"]: + raise NotImplementedError( + f"{self.robot_type} does not support position AND current control in the handle, which is require to set the gripper open." + ) + # Set the leader arm in torque mode with the gripper motor set to an angle. This makes it possible + # to squeeze the gripper and have it spring back to an open position on its own. + for name in self.leader_arms: + self.leader_arms[name].write("Torque_Enable", 1, "gripper") + self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper") + + # Check both arms can be read + for name in self.follower_arms: + self.follower_arms[name].read("Present_Position") + for name in self.leader_arms: + self.leader_arms[name].read("Present_Position") + + # Connect the cameras + for name in self.cameras: + self.cameras[name].connect() + + self.is_connected = True + + def activate_calibration(self): + """After calibration all motors function in human interpretable ranges. + Rotations are expressed in degrees in nominal range of [-180, 180], + and linear motions (like gripper of Aloha) in nominal range of [0, 100]. + """ + + def load_or_run_calibration_(name, arm, arm_type): + arm_id = get_arm_id(name, arm_type) + arm_calib_path = self.calibration_dir / f"{arm_id}.json" + + if arm_calib_path.exists(): + with open(arm_calib_path) as f: + calibration = json.load(f) + else: + # TODO(rcadene): display a warning in __init__ if calibration file not available + print(f"Missing calibration file '{arm_calib_path}'") + + if self.robot_type in ["koch", "koch_bimanual", "aloha"]: + from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration + + calibration = run_arm_calibration(arm, self.robot_type, name, arm_type) + + elif self.robot_type in ["so100", "moss", "lekiwi"]: + from lerobot.common.robot_devices.robots.feetech_calibration import ( + run_arm_manual_calibration, + ) + + calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type) + + print(f"Calibration is done! 
Saving calibration file '{arm_calib_path}'") + arm_calib_path.parent.mkdir(parents=True, exist_ok=True) + with open(arm_calib_path, "w") as f: + json.dump(calibration, f) + + return calibration + + for name, arm in self.follower_arms.items(): + calibration = load_or_run_calibration_(name, arm, "follower") + arm.set_calibration(calibration) + for name, arm in self.leader_arms.items(): + calibration = load_or_run_calibration_(name, arm, "leader") + arm.set_calibration(calibration) + + def set_koch_robot_preset(self): + def set_operating_mode_(arm): + from lerobot.common.robot_devices.motors.dynamixel import TorqueMode + + if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): + raise ValueError("To run set robot preset, the torque must be disabled on all motors.") + + # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't + # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm, + # you could end up with a servo with a position 0 or 4095 at a crucial point See [ + # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] + all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"] + if len(all_motors_except_gripper) > 0: + # 4 corresponds to Extended Position on Koch motors + arm.write("Operating_Mode", 4, all_motors_except_gripper) + + # Use 'position control current based' for gripper to be limited by the limit of the current. + # For the follower gripper, it means it can grasp an object without forcing too much even tho, + # it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). + # For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger + # to make it move, and it will move back to its original target position when we release the force. + # 5 corresponds to Current Controlled Position on Koch gripper motors "xl330-m077, xl330-m288" + arm.write("Operating_Mode", 5, "gripper") + + for name in self.follower_arms: + set_operating_mode_(self.follower_arms[name]) + + # Set better PID values to close the gap between recorded states and actions + # TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor + self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex") + self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex") + self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex") + + if self.config.gripper_open_degree is not None: + for name in self.leader_arms: + set_operating_mode_(self.leader_arms[name]) + + # Enable torque on the gripper of the leader arms, and move it to 45 degrees, + # so that we can use it as a trigger to close the gripper of the follower arms. + self.leader_arms[name].write("Torque_Enable", 1, "gripper") + self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper") + + def set_aloha_robot_preset(self): + def set_shadow_(arm): + # Set secondary/shadow ID for shoulder and elbow. These joints have two motors. + # As a result, if only one of them is required to move to a certain position, + # the other will follow. This is to avoid breaking the motors. 
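As an aside, the clamping performed by `ensure_safe_goal_position` (defined at the top of this module) is easy to exercise in isolation. A minimal sketch with made-up joint values, assuming the lerobot robot-device dependencies are importable:

```python
import torch

from lerobot.common.robot_devices.robots.manipulator import ensure_safe_goal_position

present_pos = torch.tensor([0.0, 10.0, 20.0])
goal_pos = torch.tensor([50.0, 12.0, -40.0])

# With a scalar cap, each motor may move at most 5 units per step; a per-motor
# list such as [5.0, 5.0, 10.0] is also accepted. A warning is logged whenever
# the requested relative move gets clamped.
safe = ensure_safe_goal_position(goal_pos, present_pos, max_relative_target=5.0)
print(safe)  # tensor([ 5., 12., 15.])
```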
+ if "shoulder_shadow" in arm.motor_names: + shoulder_idx = arm.read("ID", "shoulder") + arm.write("Secondary_ID", shoulder_idx, "shoulder_shadow") + + if "elbow_shadow" in arm.motor_names: + elbow_idx = arm.read("ID", "elbow") + arm.write("Secondary_ID", elbow_idx, "elbow_shadow") + + for name in self.follower_arms: + set_shadow_(self.follower_arms[name]) + + for name in self.leader_arms: + set_shadow_(self.leader_arms[name]) + + for name in self.follower_arms: + # Set a velocity limit of 131 as advised by Trossen Robotics + self.follower_arms[name].write("Velocity_Limit", 131) + + # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't + # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm, + # you could end up with a servo with a position 0 or 4095 at a crucial point See [ + # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] + all_motors_except_gripper = [ + name for name in self.follower_arms[name].motor_names if name != "gripper" + ] + if len(all_motors_except_gripper) > 0: + # 4 corresponds to Extended Position on Aloha motors + self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper) + + # Use 'position control current based' for follower gripper to be limited by the limit of the current. + # It can grasp an object without forcing too much even tho, + # it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). + # 5 corresponds to Current Controlled Position on Aloha gripper follower "xm430-w350" + self.follower_arms[name].write("Operating_Mode", 5, "gripper") + + # Note: We can't enable torque on the leader gripper since "xc430-w150" doesn't have + # a Current Controlled Position mode. + + if self.config.gripper_open_degree is not None: + warnings.warn( + f"`gripper_open_degree` is set to {self.config.gripper_open_degree}, but None is expected for Aloha instead", + stacklevel=1, + ) + + def set_so100_robot_preset(self): + for name in self.follower_arms: + # Mode=0 for Position Control + self.follower_arms[name].write("Mode", 0) + # Set P_Coefficient to lower value to avoid shakiness (Default is 32) + self.follower_arms[name].write("P_Coefficient", 16) + # Set I_Coefficient and D_Coefficient to default value 0 and 32 + self.follower_arms[name].write("I_Coefficient", 0) + self.follower_arms[name].write("D_Coefficient", 32) + # Close the write lock so that Maximum_Acceleration gets written to EPROM address, + # which is mandatory for Maximum_Acceleration to take effect after rebooting. + self.follower_arms[name].write("Lock", 0) + # Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of + # the motors. Note: this configuration is not in the official STS3215 Memory Table + self.follower_arms[name].write("Maximum_Acceleration", 254) + self.follower_arms[name].write("Acceleration", 254) + + def teleop_step( + self, record_data=False + ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + if not self.is_connected: + raise RobotDeviceNotConnectedError( + "ManipulatorRobot is not connected. You need to run `robot.connect()`." 
+ ) + + # Prepare to assign the position of the leader to the follower + leader_pos = {} + for name in self.leader_arms: + before_lread_t = time.perf_counter() + leader_pos[name] = self.leader_arms[name].read("Present_Position") + leader_pos[name] = torch.from_numpy(leader_pos[name]) + self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t + + # Send goal position to the follower + follower_goal_pos = {} + for name in self.follower_arms: + before_fwrite_t = time.perf_counter() + goal_pos = leader_pos[name] + + # Cap goal position when too far away from present position. + # Slower fps expected due to reading from the follower. + if self.config.max_relative_target is not None: + present_pos = self.follower_arms[name].read("Present_Position") + present_pos = torch.from_numpy(present_pos) + goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) + + # Used when record_data=True + follower_goal_pos[name] = goal_pos + + goal_pos = goal_pos.numpy().astype(np.float32) + self.follower_arms[name].write("Goal_Position", goal_pos) + self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t + + # Early exit when recording data is not requested + if not record_data: + return + + # TODO(rcadene): Add velocity and other info + # Read follower position + follower_pos = {} + for name in self.follower_arms: + before_fread_t = time.perf_counter() + follower_pos[name] = self.follower_arms[name].read("Present_Position") + follower_pos[name] = torch.from_numpy(follower_pos[name]) + self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t + + # Create state by concatenating follower current position + state = [] + for name in self.follower_arms: + if name in follower_pos: + state.append(follower_pos[name]) + state = torch.cat(state) + + # Create action by concatenating follower goal position + action = [] + for name in self.follower_arms: + if name in follower_goal_pos: + action.append(follower_goal_pos[name]) + action = torch.cat(action) + + # Capture images from cameras + images = {} + for name in self.cameras: + before_camread_t = time.perf_counter() + images[name] = self.cameras[name].async_read() + images[name] = torch.from_numpy(images[name]) + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] + self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + + # Populate output dictionaries + obs_dict, action_dict = {}, {} + obs_dict["observation.state"] = state + action_dict["action"] = action + for name in self.cameras: + obs_dict[f"observation.images.{name}"] = images[name] + + return obs_dict, action_dict + + def capture_observation(self): + """The returned observations do not have a batch dimension.""" + if not self.is_connected: + raise RobotDeviceNotConnectedError( + "ManipulatorRobot is not connected. You need to run `robot.connect()`." 
+ ) + + # Read follower position + follower_pos = {} + for name in self.follower_arms: + before_fread_t = time.perf_counter() + follower_pos[name] = self.follower_arms[name].read("Present_Position") + follower_pos[name] = torch.from_numpy(follower_pos[name]) + self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t + + # Create state by concatenating follower current position + state = [] + for name in self.follower_arms: + if name in follower_pos: + state.append(follower_pos[name]) + state = torch.cat(state) + + # Capture images from cameras + images = {} + for name in self.cameras: + before_camread_t = time.perf_counter() + images[name] = self.cameras[name].async_read() + images[name] = torch.from_numpy(images[name]) + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] + self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + + # Populate output dictionaries and format to pytorch + obs_dict = {} + obs_dict["observation.state"] = state + for name in self.cameras: + obs_dict[f"observation.images.{name}"] = images[name] + return obs_dict + + def send_action(self, action: torch.Tensor) -> torch.Tensor: + """Command the follower arms to move to a target joint configuration. + + The relative action magnitude may be clipped depending on the configuration parameter + `max_relative_target`. In this case, the action sent differs from original action. + Thus, this function always returns the action actually sent. + + Args: + action: tensor containing the concatenated goal positions for the follower arms. + """ + if not self.is_connected: + raise RobotDeviceNotConnectedError( + "ManipulatorRobot is not connected. You need to run `robot.connect()`." + ) + + from_idx = 0 + to_idx = 0 + action_sent = [] + for name in self.follower_arms: + # Get goal position of each follower arm by splitting the action vector + to_idx += len(self.follower_arms[name].motor_names) + goal_pos = action[from_idx:to_idx] + from_idx = to_idx + + # Cap goal position when too far away from present position. + # Slower fps expected due to reading from the follower. + if self.config.max_relative_target is not None: + present_pos = self.follower_arms[name].read("Present_Position") + present_pos = torch.from_numpy(present_pos) + goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) + + # Save tensor to concat and return + action_sent.append(goal_pos) + + # Send goal position to each follower + goal_pos = goal_pos.numpy().astype(np.float32) + self.follower_arms[name].write("Goal_Position", goal_pos) + + return torch.cat(action_sent) + + def print_logs(self): + pass + # TODO(aliberts): move robot-specific logs logic here + + def disconnect(self): + if not self.is_connected: + raise RobotDeviceNotConnectedError( + "ManipulatorRobot is not connected. You need to run `robot.connect()` before disconnecting." 
+ ) + + for name in self.follower_arms: + self.follower_arms[name].disconnect() + + for name in self.leader_arms: + self.leader_arms[name].disconnect() + + for name in self.cameras: + self.cameras[name].disconnect() + + self.is_connected = False + + def __del__(self): + if getattr(self, "is_connected", False): + self.disconnect() diff --git a/lerobot/common/robot_devices/robots/mobile_manipulator.py b/lerobot/common/robot_devices/robots/mobile_manipulator.py new file mode 100644 index 0000000000000000000000000000000000000000..385e218bed59942dd09744ac8a39f15af519ca61 --- /dev/null +++ b/lerobot/common/robot_devices/robots/mobile_manipulator.py @@ -0,0 +1,703 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json +import os +import sys +from pathlib import Path + +import cv2 +import numpy as np +import torch +import zmq + +from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs +from lerobot.common.robot_devices.motors.feetech import TorqueMode +from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs +from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig +from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration +from lerobot.common.robot_devices.robots.utils import get_arm_id +from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError + +PYNPUT_AVAILABLE = True +try: + # Only import if there's a valid X server or if we're not on a Pi + if ("DISPLAY" not in os.environ) and ("linux" in sys.platform): + print("No DISPLAY set. Skipping pynput import.") + raise ImportError("pynput blocked intentionally due to no display.") + + from pynput import keyboard +except ImportError: + keyboard = None + PYNPUT_AVAILABLE = False +except Exception as e: + keyboard = None + PYNPUT_AVAILABLE = False + print(f"Could not import pynput: {e}") + + +class MobileManipulator: + """ + MobileManipulator is a class for connecting to and controlling a remote mobile manipulator robot. + The robot includes a three omniwheel mobile base and a remote follower arm. + The leader arm is connected locally (on the laptop) and its joint positions are recorded and then + forwarded to the remote follower arm (after applying a safety clamp). + In parallel, keyboard teleoperation is used to generate raw velocity commands for the wheels. + """ + + def __init__(self, config: LeKiwiRobotConfig): + """ + Expected keys in config: + - ip, port, video_port for the remote connection. + - calibration_dir, leader_arms, follower_arms, max_relative_target, etc. 
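        Example of instantiation and connection (an illustrative sketch: it assumes the default
        `LeKiwiRobotConfig` field values, notably `ip`, `port` and `video_port`, match the actual
        remote host):

        ```python
        from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
        from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator

        config = LeKiwiRobotConfig()  # override ip/port/cameras here if the defaults don't fit your setup
        robot = MobileManipulator(config)
        robot.connect()  # calibrates the local leader arm, then opens the ZMQ command and video sockets

        observation, action = robot.teleop_step(record_data=True)
        robot.disconnect()
        ```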
+ """ + self.robot_type = config.type + self.config = config + self.remote_ip = config.ip + self.remote_port = config.port + self.remote_port_video = config.video_port + self.calibration_dir = Path(self.config.calibration_dir) + self.logs = {} + + self.teleop_keys = self.config.teleop_keys + + # For teleoperation, the leader arm (local) is used to record the desired arm pose. + self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms) + + self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms) + + self.cameras = make_cameras_from_configs(self.config.cameras) + + self.is_connected = False + + self.last_frames = {} + self.last_present_speed = {} + self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32) + + # Define three speed levels and a current index + self.speed_levels = [ + {"xy": 0.1, "theta": 30}, # slow + {"xy": 0.2, "theta": 60}, # medium + {"xy": 0.3, "theta": 90}, # fast + ] + self.speed_index = 0 # Start at slow + + # ZeroMQ context and sockets. + self.context = None + self.cmd_socket = None + self.video_socket = None + + # Keyboard state for base teleoperation. + self.running = True + self.pressed_keys = { + "forward": False, + "backward": False, + "left": False, + "right": False, + "rotate_left": False, + "rotate_right": False, + } + + if PYNPUT_AVAILABLE: + print("pynput is available - enabling local keyboard listener.") + self.listener = keyboard.Listener( + on_press=self.on_press, + on_release=self.on_release, + ) + self.listener.start() + else: + print("pynput not available - skipping local keyboard listener.") + self.listener = None + + def get_motor_names(self, arms: dict[str, MotorsBus]) -> list: + return [f"{arm}_{motor}" for arm, bus in arms.items() for motor in bus.motors] + + @property + def camera_features(self) -> dict: + cam_ft = {} + for cam_key, cam in self.cameras.items(): + key = f"observation.images.{cam_key}" + cam_ft[key] = { + "shape": (cam.height, cam.width, cam.channels), + "names": ["height", "width", "channels"], + "info": None, + } + return cam_ft + + @property + def motor_features(self) -> dict: + follower_arm_names = [ + "shoulder_pan", + "shoulder_lift", + "elbow_flex", + "wrist_flex", + "wrist_roll", + "gripper", + ] + observations = ["x_mm", "y_mm", "theta"] + combined_names = follower_arm_names + observations + return { + "action": { + "dtype": "float32", + "shape": (len(combined_names),), + "names": combined_names, + }, + "observation.state": { + "dtype": "float32", + "shape": (len(combined_names),), + "names": combined_names, + }, + } + + @property + def features(self): + return {**self.motor_features, **self.camera_features} + + @property + def has_camera(self): + return len(self.cameras) > 0 + + @property + def num_cameras(self): + return len(self.cameras) + + @property + def available_arms(self): + available = [] + for name in self.leader_arms: + available.append(get_arm_id(name, "leader")) + for name in self.follower_arms: + available.append(get_arm_id(name, "follower")) + return available + + def on_press(self, key): + try: + # Movement + if key.char == self.teleop_keys["forward"]: + self.pressed_keys["forward"] = True + elif key.char == self.teleop_keys["backward"]: + self.pressed_keys["backward"] = True + elif key.char == self.teleop_keys["left"]: + self.pressed_keys["left"] = True + elif key.char == self.teleop_keys["right"]: + self.pressed_keys["right"] = True + elif key.char == self.teleop_keys["rotate_left"]: + self.pressed_keys["rotate_left"] = True + elif key.char == 
self.teleop_keys["rotate_right"]: + self.pressed_keys["rotate_right"] = True + + # Quit teleoperation + elif key.char == self.teleop_keys["quit"]: + self.running = False + return False + + # Speed control + elif key.char == self.teleop_keys["speed_up"]: + self.speed_index = min(self.speed_index + 1, 2) + print(f"Speed index increased to {self.speed_index}") + elif key.char == self.teleop_keys["speed_down"]: + self.speed_index = max(self.speed_index - 1, 0) + print(f"Speed index decreased to {self.speed_index}") + + except AttributeError: + # e.g., if key is special like Key.esc + if key == keyboard.Key.esc: + self.running = False + return False + + def on_release(self, key): + try: + if hasattr(key, "char"): + if key.char == self.teleop_keys["forward"]: + self.pressed_keys["forward"] = False + elif key.char == self.teleop_keys["backward"]: + self.pressed_keys["backward"] = False + elif key.char == self.teleop_keys["left"]: + self.pressed_keys["left"] = False + elif key.char == self.teleop_keys["right"]: + self.pressed_keys["right"] = False + elif key.char == self.teleop_keys["rotate_left"]: + self.pressed_keys["rotate_left"] = False + elif key.char == self.teleop_keys["rotate_right"]: + self.pressed_keys["rotate_right"] = False + except AttributeError: + pass + + def connect(self): + if not self.leader_arms: + raise ValueError("MobileManipulator has no leader arm to connect.") + for name in self.leader_arms: + print(f"Connecting {name} leader arm.") + self.calibrate_leader() + + # Set up ZeroMQ sockets to communicate with the remote mobile robot. + self.context = zmq.Context() + self.cmd_socket = self.context.socket(zmq.PUSH) + connection_string = f"tcp://{self.remote_ip}:{self.remote_port}" + self.cmd_socket.connect(connection_string) + self.cmd_socket.setsockopt(zmq.CONFLATE, 1) + self.video_socket = self.context.socket(zmq.PULL) + video_connection = f"tcp://{self.remote_ip}:{self.remote_port_video}" + self.video_socket.connect(video_connection) + self.video_socket.setsockopt(zmq.CONFLATE, 1) + print( + f"[INFO] Connected to remote robot at {connection_string} and video stream at {video_connection}." + ) + self.is_connected = True + + def load_or_run_calibration_(self, name, arm, arm_type): + arm_id = get_arm_id(name, arm_type) + arm_calib_path = self.calibration_dir / f"{arm_id}.json" + + if arm_calib_path.exists(): + with open(arm_calib_path) as f: + calibration = json.load(f) + else: + print(f"Missing calibration file '{arm_calib_path}'") + calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type) + print(f"Calibration is done! 
Saving calibration file '{arm_calib_path}'") + arm_calib_path.parent.mkdir(parents=True, exist_ok=True) + with open(arm_calib_path, "w") as f: + json.dump(calibration, f) + + return calibration + + def calibrate_leader(self): + for name, arm in self.leader_arms.items(): + # Connect the bus + arm.connect() + + # Disable torque on all motors + for motor_id in arm.motors: + arm.write("Torque_Enable", TorqueMode.DISABLED.value, motor_id) + + # Now run calibration + calibration = self.load_or_run_calibration_(name, arm, "leader") + arm.set_calibration(calibration) + + def calibrate_follower(self): + for name, bus in self.follower_arms.items(): + bus.connect() + + # Disable torque on all motors + for motor_id in bus.motors: + bus.write("Torque_Enable", 0, motor_id) + + # Then filter out wheels + arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")} + if not arm_only_dict: + continue + + original_motors = bus.motors + bus.motors = arm_only_dict + + calibration = self.load_or_run_calibration_(name, bus, "follower") + bus.set_calibration(calibration) + + bus.motors = original_motors + + def _get_data(self): + """ + Polls the video socket for up to 15 ms. If data arrives, decode only + the *latest* message, returning frames, speed, and arm state. If + nothing arrives for any field, use the last known values. + """ + frames = {} + present_speed = {} + remote_arm_state_tensor = torch.zeros(6, dtype=torch.float32) + + # Poll up to 15 ms + poller = zmq.Poller() + poller.register(self.video_socket, zmq.POLLIN) + socks = dict(poller.poll(15)) + if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN: + # No new data arrived → reuse ALL old data + return (self.last_frames, self.last_present_speed, self.last_remote_arm_state) + + # Drain all messages, keep only the last + last_msg = None + while True: + try: + obs_string = self.video_socket.recv_string(zmq.NOBLOCK) + last_msg = obs_string + except zmq.Again: + break + + if not last_msg: + # No new message → also reuse old + return (self.last_frames, self.last_present_speed, self.last_remote_arm_state) + + # Decode only the final message + try: + observation = json.loads(last_msg) + + images_dict = observation.get("images", {}) + new_speed = observation.get("present_speed", {}) + new_arm_state = observation.get("follower_arm_state", None) + + # Convert images + for cam_name, image_b64 in images_dict.items(): + if image_b64: + jpg_data = base64.b64decode(image_b64) + np_arr = np.frombuffer(jpg_data, dtype=np.uint8) + frame_candidate = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + if frame_candidate is not None: + frames[cam_name] = frame_candidate + + # If remote_arm_state is None and frames is None there is no message then use the previous message + if new_arm_state is not None and frames is not None: + self.last_frames = frames + + remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32) + self.last_remote_arm_state = remote_arm_state_tensor + + present_speed = new_speed + self.last_present_speed = new_speed + else: + frames = self.last_frames + + remote_arm_state_tensor = self.last_remote_arm_state + + present_speed = self.last_present_speed + + except Exception as e: + print(f"[DEBUG] Error decoding video message: {e}") + # If decode fails, fall back to old data + return (self.last_frames, self.last_present_speed, self.last_remote_arm_state) + + return frames, present_speed, remote_arm_state_tensor + + def _process_present_speed(self, present_speed: dict) -> torch.Tensor: + state_tensor = 
torch.zeros(3, dtype=torch.int32) + if present_speed: + decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()} + if "1" in decoded: + state_tensor[0] = decoded["1"] + if "2" in decoded: + state_tensor[1] = decoded["2"] + if "3" in decoded: + state_tensor[2] = decoded["3"] + return state_tensor + + def teleop_step( + self, record_data: bool = False + ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + if not self.is_connected: + raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.") + + speed_setting = self.speed_levels[self.speed_index] + xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4 + theta_speed = speed_setting["theta"] # e.g. 30, 60, or 90 + + # Prepare to assign the position of the leader to the follower + arm_positions = [] + for name in self.leader_arms: + pos = self.leader_arms[name].read("Present_Position") + pos_tensor = torch.from_numpy(pos).float() + arm_positions.extend(pos_tensor.tolist()) + + y_cmd = 0.0 # m/s forward/backward + x_cmd = 0.0 # m/s lateral + theta_cmd = 0.0 # deg/s rotation + if self.pressed_keys["forward"]: + y_cmd += xy_speed + if self.pressed_keys["backward"]: + y_cmd -= xy_speed + if self.pressed_keys["left"]: + x_cmd += xy_speed + if self.pressed_keys["right"]: + x_cmd -= xy_speed + if self.pressed_keys["rotate_left"]: + theta_cmd += theta_speed + if self.pressed_keys["rotate_right"]: + theta_cmd -= theta_speed + + wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd) + + message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions} + self.cmd_socket.send_string(json.dumps(message)) + + if not record_data: + return + + obs_dict = self.capture_observation() + + arm_state_tensor = torch.tensor(arm_positions, dtype=torch.float32) + + wheel_velocity_tuple = self.wheel_raw_to_body(wheel_commands) + wheel_velocity_mm = ( + wheel_velocity_tuple[0] * 1000.0, + wheel_velocity_tuple[1] * 1000.0, + wheel_velocity_tuple[2], + ) + wheel_tensor = torch.tensor(wheel_velocity_mm, dtype=torch.float32) + action_tensor = torch.cat([arm_state_tensor, wheel_tensor]) + action_dict = {"action": action_tensor} + + return obs_dict, action_dict + + def capture_observation(self) -> dict: + """ + Capture observations from the remote robot: current follower arm positions, + present wheel speeds (converted to body-frame velocities: x, y, theta), + and a camera frame. + """ + if not self.is_connected: + raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.") + + frames, present_speed, remote_arm_state_tensor = self._get_data() + + body_state = self.wheel_raw_to_body(present_speed) + + body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s + wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32) + combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0) + + obs_dict = {"observation.state": combined_state_tensor} + + # Loop over each configured camera + for cam_name, cam in self.cameras.items(): + frame = frames.get(cam_name, None) + if frame is None: + # Create a black image using the camera's configured width, height, and channels + frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8) + obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame) + + return obs_dict + + def send_action(self, action: torch.Tensor) -> torch.Tensor: + if not self.is_connected: + raise RobotDeviceNotConnectedError("Not connected. 
Run `connect()` first.") + + # Ensure the action tensor has at least 9 elements: + # - First 6: arm positions. + # - Last 3: base commands. + if action.numel() < 9: + # Pad with zeros if there are not enough elements. + padded = torch.zeros(9, dtype=action.dtype) + padded[: action.numel()] = action + action = padded + + # Extract arm and base actions. + arm_actions = action[:6].flatten() + base_actions = action[6:].flatten() + + x_cmd_mm = base_actions[0].item() # mm/s + y_cmd_mm = base_actions[1].item() # mm/s + theta_cmd = base_actions[2].item() # deg/s + + # Convert mm/s to m/s for the kinematics calculations. + x_cmd = x_cmd_mm / 1000.0 # m/s + y_cmd = y_cmd_mm / 1000.0 # m/s + + # Compute wheel commands from body commands. + wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd) + + arm_positions_list = arm_actions.tolist() + + message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions_list} + self.cmd_socket.send_string(json.dumps(message)) + + return action + + def print_logs(self): + pass + + def disconnect(self): + if not self.is_connected: + raise RobotDeviceNotConnectedError("Not connected.") + if self.cmd_socket: + stop_cmd = { + "raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0}, + "arm_positions": {}, + } + self.cmd_socket.send_string(json.dumps(stop_cmd)) + self.cmd_socket.close() + if self.video_socket: + self.video_socket.close() + if self.context: + self.context.term() + if PYNPUT_AVAILABLE: + self.listener.stop() + self.is_connected = False + print("[INFO] Disconnected from remote robot.") + + def __del__(self): + if getattr(self, "is_connected", False): + self.disconnect() + if PYNPUT_AVAILABLE: + self.listener.stop() + + @staticmethod + def degps_to_raw(degps: float) -> int: + steps_per_deg = 4096.0 / 360.0 + speed_in_steps = abs(degps) * steps_per_deg + speed_int = int(round(speed_in_steps)) + if speed_int > 0x7FFF: + speed_int = 0x7FFF + if degps < 0: + return speed_int | 0x8000 + else: + return speed_int & 0x7FFF + + @staticmethod + def raw_to_degps(raw_speed: int) -> float: + steps_per_deg = 4096.0 / 360.0 + magnitude = raw_speed & 0x7FFF + degps = magnitude / steps_per_deg + if raw_speed & 0x8000: + degps = -degps + return degps + + def body_to_wheel_raw( + self, + x_cmd: float, + y_cmd: float, + theta_cmd: float, + wheel_radius: float = 0.05, + base_radius: float = 0.125, + max_raw: int = 3000, + ) -> dict: + """ + Convert desired body-frame velocities into wheel raw commands. + + Parameters: + x_cmd : Linear velocity in x (m/s). + y_cmd : Linear velocity in y (m/s). + theta_cmd : Rotational velocity (deg/s). + wheel_radius: Radius of each wheel (meters). + base_radius : Distance from the center of rotation to each wheel (meters). + max_raw : Maximum allowed raw command (ticks) per wheel. + + Returns: + A dictionary with wheel raw commands: + {"left_wheel": value, "back_wheel": value, "right_wheel": value}. + + Notes: + - Internally, the method converts theta_cmd to rad/s for the kinematics. + - The raw command is computed from the wheels angular speed in deg/s + using degps_to_raw(). If any command exceeds max_raw, all commands + are scaled down proportionally. + """ + # Convert rotational velocity from deg/s to rad/s. + theta_rad = theta_cmd * (np.pi / 180.0) + # Create the body velocity vector [x, y, theta_rad]. 
+ velocity_vector = np.array([x_cmd, y_cmd, theta_rad]) + + # Define the wheel mounting angles (defined from y axis cw) + angles = np.radians(np.array([300, 180, 60])) + # Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed. + # The third column (base_radius) accounts for the effect of rotation. + m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles]) + + # Compute each wheel’s linear speed (m/s) and then its angular speed (rad/s). + wheel_linear_speeds = m.dot(velocity_vector) + wheel_angular_speeds = wheel_linear_speeds / wheel_radius + + # Convert wheel angular speeds from rad/s to deg/s. + wheel_degps = wheel_angular_speeds * (180.0 / np.pi) + + # Scaling + steps_per_deg = 4096.0 / 360.0 + raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps] + max_raw_computed = max(raw_floats) + if max_raw_computed > max_raw: + scale = max_raw / max_raw_computed + wheel_degps = wheel_degps * scale + + # Convert each wheel’s angular speed (deg/s) to a raw integer. + wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps] + + return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]} + + def wheel_raw_to_body( + self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125 + ) -> tuple: + """ + Convert wheel raw command feedback back into body-frame velocities. + + Parameters: + wheel_raw : Dictionary with raw wheel commands (keys: "left_wheel", "back_wheel", "right_wheel"). + wheel_radius: Radius of each wheel (meters). + base_radius : Distance from the robot center to each wheel (meters). + + Returns: + A tuple (x_cmd, y_cmd, theta_cmd) where: + x_cmd : Linear velocity in x (m/s). + y_cmd : Linear velocity in y (m/s). + theta_cmd : Rotational velocity in deg/s. + """ + # Extract the raw values in order. + raw_list = [ + int(wheel_raw.get("left_wheel", 0)), + int(wheel_raw.get("back_wheel", 0)), + int(wheel_raw.get("right_wheel", 0)), + ] + + # Convert each raw command back to an angular speed in deg/s. + wheel_degps = np.array([MobileManipulator.raw_to_degps(r) for r in raw_list]) + # Convert from deg/s to rad/s. + wheel_radps = wheel_degps * (np.pi / 180.0) + # Compute each wheel’s linear speed (m/s) from its angular speed. + wheel_linear_speeds = wheel_radps * wheel_radius + + # Define the wheel mounting angles (defined from y axis cw) + angles = np.radians(np.array([300, 180, 60])) + m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles]) + + # Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds. + m_inv = np.linalg.inv(m) + velocity_vector = m_inv.dot(wheel_linear_speeds) + x_cmd, y_cmd, theta_rad = velocity_vector + theta_cmd = theta_rad * (180.0 / np.pi) + return (x_cmd, y_cmd, theta_cmd) + + +class LeKiwi: + def __init__(self, motor_bus): + """ + Initializes the LeKiwi with Feetech motors bus. + """ + self.motor_bus = motor_bus + self.motor_ids = ["left_wheel", "back_wheel", "right_wheel"] + + # Initialize motors in velocity mode. + self.motor_bus.write("Lock", 0) + self.motor_bus.write("Mode", [1, 1, 1], self.motor_ids) + self.motor_bus.write("Lock", 1) + print("Motors set to velocity mode.") + + def read_velocity(self): + """ + Reads the raw speeds for all wheels. 
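To make the wheel kinematics above easier to verify, here is a self-contained numpy sketch (not part of the library) that rebuilds the same 3-omniwheel matrix used by `body_to_wheel_raw` and `wheel_raw_to_body`, re-implements the sign-magnitude raw encoding of `degps_to_raw`/`raw_to_degps`, and checks that a body command survives the round trip up to raw quantization:

```python
import numpy as np

wheel_radius = 0.05  # m
base_radius = 0.125  # m
angles = np.radians([300, 180, 60])  # wheel mounting angles, clockwise from the y axis
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])


def degps_to_raw(degps: float) -> int:
    # Sign-magnitude encoding on 16 bits, mirroring MobileManipulator.degps_to_raw.
    steps = min(int(round(abs(degps) * 4096 / 360)), 0x7FFF)
    return steps | 0x8000 if degps < 0 else steps


def raw_to_degps(raw: int) -> float:
    degps = (raw & 0x7FFF) * 360 / 4096
    return -degps if raw & 0x8000 else degps


# Body command: 0.2 m/s forward (y), no lateral motion (x), 45 deg/s rotation.
body = np.array([0.0, 0.2, np.radians(45)])

wheel_degps = np.degrees(m @ body / wheel_radius)  # per-wheel angular speed, deg/s
raw = [degps_to_raw(w) for w in wheel_degps]       # what would be sent to the motors
decoded = np.array([raw_to_degps(r) for r in raw])

# Inverse kinematics: wheel speeds back to (x m/s, y m/s, theta rad/s).
body_back = np.linalg.inv(m) @ (np.radians(decoded) * wheel_radius)
print(body)
print(body_back)  # equal up to the ~0.09 deg/s quantization of the raw encoding
```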
Returns a dictionary with motor names: + """ + raw_speeds = self.motor_bus.read("Present_Speed", self.motor_ids) + return { + "left_wheel": int(raw_speeds[0]), + "back_wheel": int(raw_speeds[1]), + "right_wheel": int(raw_speeds[2]), + } + + def set_velocity(self, command_speeds): + """ + Sends raw velocity commands (16-bit encoded values) directly to the motor bus. + The order of speeds must correspond to self.motor_ids. + """ + self.motor_bus.write("Goal_Speed", command_speeds, self.motor_ids) + + def stop(self): + """Stops the robot by setting all motor speeds to zero.""" + self.motor_bus.write("Goal_Speed", [0, 0, 0], self.motor_ids) + print("Motors stopped.") diff --git a/lerobot/common/robot_devices/robots/stretch.py b/lerobot/common/robot_devices/robots/stretch.py new file mode 100644 index 0000000000000000000000000000000000000000..9cfe6e49053fc87f4dfa2a44e0f221046e8ea5c4 --- /dev/null +++ b/lerobot/common/robot_devices/robots/stretch.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from dataclasses import replace + +import torch +from stretch_body.gamepad_teleop import GamePadTeleop +from stretch_body.robot import Robot as StretchAPI +from stretch_body.robot_params import RobotParams + +from lerobot.common.robot_devices.robots.configs import StretchRobotConfig + + +class StretchRobot(StretchAPI): + """Wrapper of stretch_body.robot.Robot""" + + def __init__(self, config: StretchRobotConfig | None = None, **kwargs): + super().__init__() + if config is None: + self.config = StretchRobotConfig(**kwargs) + else: + # Overwrite config arguments using kwargs + self.config = replace(config, **kwargs) + + self.robot_type = self.config.type + self.cameras = self.config.cameras + self.is_connected = False + self.teleop = None + self.logs = {} + + # TODO(aliberts): test this + RobotParams.set_logging_level("WARNING") + RobotParams.set_logging_formatter("brief_console_formatter") + + self.state_keys = None + self.action_keys = None + + def connect(self) -> None: + self.is_connected = self.startup() + if not self.is_connected: + print("Another process is already using Stretch. 
Try running 'stretch_free_robot_process.py'") + raise ConnectionError() + + for name in self.cameras: + self.cameras[name].connect() + self.is_connected = self.is_connected and self.cameras[name].is_connected + + if not self.is_connected: + print("Could not connect to the cameras, check that all cameras are plugged-in.") + raise ConnectionError() + + self.run_calibration() + + def run_calibration(self) -> None: + if not self.is_homed(): + self.home() + + def teleop_step( + self, record_data=False + ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + # TODO(aliberts): return ndarrays instead of torch.Tensors + if not self.is_connected: + raise ConnectionError() + + if self.teleop is None: + self.teleop = GamePadTeleop(robot_instance=False) + self.teleop.startup(robot=self) + + before_read_t = time.perf_counter() + state = self.get_state() + action = self.teleop.gamepad_controller.get_state() + self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t + + before_write_t = time.perf_counter() + self.teleop.do_motion(robot=self) + self.push_command() + self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t + + if self.state_keys is None: + self.state_keys = list(state) + + if not record_data: + return + + state = torch.as_tensor(list(state.values())) + action = torch.as_tensor(list(action.values())) + + # Capture images from cameras + images = {} + for name in self.cameras: + before_camread_t = time.perf_counter() + images[name] = self.cameras[name].async_read() + images[name] = torch.from_numpy(images[name]) + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] + self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + + # Populate output dictionaries + obs_dict, action_dict = {}, {} + obs_dict["observation.state"] = state + action_dict["action"] = action + for name in self.cameras: + obs_dict[f"observation.images.{name}"] = images[name] + + return obs_dict, action_dict + + def get_state(self) -> dict: + status = self.get_status() + return { + "head_pan.pos": status["head"]["head_pan"]["pos"], + "head_tilt.pos": status["head"]["head_tilt"]["pos"], + "lift.pos": status["lift"]["pos"], + "arm.pos": status["arm"]["pos"], + "wrist_pitch.pos": status["end_of_arm"]["wrist_pitch"]["pos"], + "wrist_roll.pos": status["end_of_arm"]["wrist_roll"]["pos"], + "wrist_yaw.pos": status["end_of_arm"]["wrist_yaw"]["pos"], + "gripper.pos": status["end_of_arm"]["stretch_gripper"]["pos"], + "base_x.vel": status["base"]["x_vel"], + "base_y.vel": status["base"]["y_vel"], + "base_theta.vel": status["base"]["theta_vel"], + } + + def capture_observation(self) -> dict: + # TODO(aliberts): return ndarrays instead of torch.Tensors + before_read_t = time.perf_counter() + state = self.get_state() + self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t + + if self.state_keys is None: + self.state_keys = list(state) + + state = torch.as_tensor(list(state.values())) + + # Capture images from cameras + images = {} + for name in self.cameras: + before_camread_t = time.perf_counter() + images[name] = self.cameras[name].async_read() + images[name] = torch.from_numpy(images[name]) + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] + self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + + # Populate output dictionaries + obs_dict = {} + obs_dict["observation.state"] = state + for name in self.cameras: + obs_dict[f"observation.images.{name}"] = 
images[name] + + return obs_dict + + def send_action(self, action: torch.Tensor) -> torch.Tensor: + # TODO(aliberts): return ndarrays instead of torch.Tensors + if not self.is_connected: + raise ConnectionError() + + if self.teleop is None: + self.teleop = GamePadTeleop(robot_instance=False) + self.teleop.startup(robot=self) + + if self.action_keys is None: + dummy_action = self.teleop.gamepad_controller.get_state() + self.action_keys = list(dummy_action.keys()) + + action_dict = dict(zip(self.action_keys, action.tolist(), strict=True)) + + before_write_t = time.perf_counter() + self.teleop.do_motion(state=action_dict, robot=self) + self.push_command() + self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t + + # TODO(aliberts): return action_sent when motion is limited + return action + + def print_logs(self) -> None: + pass + # TODO(aliberts): move robot-specific logs logic here + + def teleop_safety_stop(self) -> None: + if self.teleop is not None: + self.teleop._safety_stop(robot=self) + + def disconnect(self) -> None: + self.stop() + if self.teleop is not None: + self.teleop.gamepad_controller.stop() + self.teleop.stop() + + if len(self.cameras) > 0: + for cam in self.cameras.values(): + cam.disconnect() + + self.is_connected = False + + def __del__(self): + self.disconnect() diff --git a/lerobot/common/robot_devices/robots/utils.py b/lerobot/common/robot_devices/robots/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dab514d5ec824ce4d99fa2250135fb1605c63232 --- /dev/null +++ b/lerobot/common/robot_devices/robots/utils.py @@ -0,0 +1,86 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from lerobot.common.robot_devices.robots.configs import ( + AlohaRobotConfig, + KochBimanualRobotConfig, + KochRobotConfig, + LeKiwiRobotConfig, + ManipulatorRobotConfig, + MossRobotConfig, + RobotConfig, + So100RobotConfig, + StretchRobotConfig, +) + + +def get_arm_id(name, arm_type): + """Returns the string identifier of a robot arm. For instance, for a bimanual manipulator + like Aloha, it could be left_follower, right_follower, left_leader, or right_leader. + """ + return f"{name}_{arm_type}" + + +class Robot(Protocol): + # TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes + robot_type: str + features: dict + + def connect(self): ... + def run_calibration(self): ... + def teleop_step(self, record_data=False): ... + def capture_observation(self): ... + def send_action(self, action): ... + def disconnect(self): ... 
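Any robot satisfying this `Robot` protocol can be driven in a simple fixed-rate loop. A minimal sketch (robot type, rate and duration are illustrative placeholders; `make_robot` is defined just below, and `busy_wait` lives in `lerobot.common.robot_devices.utils`):

```python
import time

from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.common.robot_devices.utils import busy_wait

robot = make_robot("so100")  # "so100" is one of the robot types handled by make_robot_config below
robot.connect()

fps = 30
for _ in range(10 * fps):  # roughly 10 seconds of teleoperation
    start = time.perf_counter()
    observation, action = robot.teleop_step(record_data=True)
    # Wait out the remainder of the control period; busy_wait spins on macOS for accuracy.
    busy_wait(1 / fps - (time.perf_counter() - start))

robot.disconnect()
```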
+ + +def make_robot_config(robot_type: str, **kwargs) -> RobotConfig: + if robot_type == "aloha": + return AlohaRobotConfig(**kwargs) + elif robot_type == "koch": + return KochRobotConfig(**kwargs) + elif robot_type == "koch_bimanual": + return KochBimanualRobotConfig(**kwargs) + elif robot_type == "moss": + return MossRobotConfig(**kwargs) + elif robot_type == "so100": + return So100RobotConfig(**kwargs) + elif robot_type == "stretch": + return StretchRobotConfig(**kwargs) + elif robot_type == "lekiwi": + return LeKiwiRobotConfig(**kwargs) + else: + raise ValueError(f"Robot type '{robot_type}' is not available.") + + +def make_robot_from_config(config: RobotConfig): + if isinstance(config, ManipulatorRobotConfig): + from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot + + return ManipulatorRobot(config) + elif isinstance(config, LeKiwiRobotConfig): + from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator + + return MobileManipulator(config) + else: + from lerobot.common.robot_devices.robots.stretch import StretchRobot + + return StretchRobot(config) + + +def make_robot(robot_type: str, **kwargs) -> Robot: + config = make_robot_config(robot_type, **kwargs) + return make_robot_from_config(config) diff --git a/lerobot/common/robot_devices/utils.py b/lerobot/common/robot_devices/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..837c9d2eb2a908ed4501655b28aa55ff538dcb00 --- /dev/null +++ b/lerobot/common/robot_devices/utils.py @@ -0,0 +1,65 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform +import time + + +def busy_wait(seconds): + if platform.system() == "Darwin": + # On Mac, `time.sleep` is not accurate and we need to use this while loop trick, + # but it consumes CPU cycles. + # TODO(rcadene): find an alternative: from python 11, time.sleep is precise + end_time = time.perf_counter() + seconds + while time.perf_counter() < end_time: + pass + else: + # On Linux time.sleep is accurate + if seconds > 0: + time.sleep(seconds) + + +def safe_disconnect(func): + # TODO(aliberts): Allow to pass custom exceptions + # (e.g. ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError) + def wrapper(robot, *args, **kwargs): + try: + return func(robot, *args, **kwargs) + except Exception as e: + if robot.is_connected: + robot.disconnect() + raise e + + return wrapper + + +class RobotDeviceNotConnectedError(Exception): + """Exception raised when the robot device is not connected.""" + + def __init__( + self, message="This robot device is not connected. Try calling `robot_device.connect()` first." + ): + self.message = message + super().__init__(self.message) + + +class RobotDeviceAlreadyConnectedError(Exception): + """Exception raised when the robot device is already connected.""" + + def __init__( + self, + message="This robot device is already connected. 
Try not calling `robot_device.connect()` twice.", + ): + self.message = message + super().__init__(self.message) diff --git a/lerobot/common/utils/benchmark.py b/lerobot/common/utils/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..4b08e6f6d8e987cb77c88f6ec35781e0a2a9a707 --- /dev/null +++ b/lerobot/common/utils/benchmark.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading +import time +from contextlib import ContextDecorator + + +class TimeBenchmark(ContextDecorator): + """ + Measures execution time using a context manager or decorator. + + This class supports both context manager and decorator usage, and is thread-safe for multithreaded + environments. + + Args: + print: If True, prints the elapsed time upon exiting the context or completing the function. Defaults + to False. + + Examples: + + Using as a context manager: + + >>> benchmark = TimeBenchmark() + >>> with benchmark: + ... time.sleep(1) + >>> print(f"Block took {benchmark.result:.4f} seconds") + Block took approximately 1.0000 seconds + + Using with multithreading: + + ```python + import threading + + benchmark = TimeBenchmark() + + def context_manager_example(): + with benchmark: + time.sleep(0.01) + print(f"Block took {benchmark.result_ms:.2f} milliseconds") + + threads = [] + for _ in range(3): + t1 = threading.Thread(target=context_manager_example) + threads.append(t1) + + for t in threads: + t.start() + + for t in threads: + t.join() + ``` + Expected output: + Block took approximately 10.00 milliseconds + Block took approximately 10.00 milliseconds + Block took approximately 10.00 milliseconds + """ + + def __init__(self, print=False): + self.local = threading.local() + self.print_time = print + + def __enter__(self): + self.local.start_time = time.perf_counter() + return self + + def __exit__(self, *exc): + self.local.end_time = time.perf_counter() + self.local.elapsed_time = self.local.end_time - self.local.start_time + if self.print_time: + print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds") + return False + + @property + def result(self): + return getattr(self.local, "elapsed_time", None) + + @property + def result_ms(self): + return self.result * 1e3 diff --git a/lerobot/common/utils/hub.py b/lerobot/common/utils/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..df7435c0fcfcc691e66190093699d07b0558c5c1 --- /dev/null +++ b/lerobot/common/utils/hub.py @@ -0,0 +1,202 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Any, Type, TypeVar + +from huggingface_hub import HfApi +from huggingface_hub.utils import validate_hf_hub_args + +T = TypeVar("T", bound="HubMixin") + + +class HubMixin: + """ + A Mixin containing the functionality to push an object to the hub. + + This is similar to huggingface_hub.ModelHubMixin but is lighter and makes less assumptions about its + subclasses (in particular, the fact that it's not necessarily a model). + + The inheriting classes must implement '_save_pretrained' and 'from_pretrained'. + """ + + def save_pretrained( + self, + save_directory: str | Path, + *, + repo_id: str | None = None, + push_to_hub: bool = False, + card_kwargs: dict[str, Any] | None = None, + **push_to_hub_kwargs, + ) -> str | None: + """ + Save object in local directory. + + Args: + save_directory (`str` or `Path`): + Path to directory in which the object will be saved. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your object to the Huggingface Hub after saving it. + repo_id (`str`, *optional*): + ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if + not provided. + card_kwargs (`Dict[str, Any]`, *optional*): + Additional arguments passed to the card template to customize the card. + push_to_hub_kwargs: + Additional key word arguments passed along to the [`~HubMixin.push_to_hub`] method. + Returns: + `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise. + """ + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + + # save object (weights, files, etc.) + self._save_pretrained(save_directory) + + # push to the Hub if required + if push_to_hub: + if repo_id is None: + repo_id = save_directory.name # Defaults to `save_directory` name + return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs) + return None + + def _save_pretrained(self, save_directory: Path) -> None: + """ + Overwrite this method in subclass to define how to save your object. + + Args: + save_directory (`str` or `Path`): + Path to directory in which the object files will be saved. + """ + raise NotImplementedError + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls: Type[T], + pretrained_name_or_path: str | Path, + *, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + **kwargs, + ) -> T: + """ + Download the object from the Huggingface Hub and instantiate it. + + Args: + pretrained_name_or_path (`str`, `Path`): + - Either the `repo_id` (string) of the object hosted on the Hub, e.g. `lerobot/diffusion_pusht`. + - Or a path to a `directory` containing the object files saved using `.save_pretrained`, + e.g., `../path/to/my_model_directory/`. + revision (`str`, *optional*): + Revision on the Hub. Can be a branch name, a git tag or any commit id. 
+ Defaults to the latest commit on `main` branch. + force_download (`bool`, *optional*, defaults to `False`): + Whether to force (re-)downloading the files from the Hub, overriding the existing cache. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. By default, it will use the token + cached when running `huggingface-cli login`. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the local cached file if it exists. + kwargs (`Dict`, *optional*): + Additional kwargs to pass to the object during initialization. + """ + raise NotImplementedError + + @validate_hf_hub_args + def push_to_hub( + self, + repo_id: str, + *, + commit_message: str | None = None, + private: bool | None = None, + token: str | None = None, + branch: str | None = None, + create_pr: bool | None = None, + allow_patterns: list[str] | str | None = None, + ignore_patterns: list[str] | str | None = None, + delete_patterns: list[str] | str | None = None, + card_kwargs: dict[str, Any] | None = None, + ) -> str: + """ + Upload model checkpoint to the Hub. + + Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use + `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more + details. + + Args: + repo_id (`str`): + ID of the repository to push to (example: `"username/my-model"`). + commit_message (`str`, *optional*): + Message to commit while pushing. + private (`bool`, *optional*): + Whether the repository created should be private. + If `None` (default), the repo will be public unless the organization's default is private. + token (`str`, *optional*): + The token to use as HTTP bearer authorization for remote files. By default, it will use the token + cached when running `huggingface-cli login`. + branch (`str`, *optional*): + The git branch on which to push the model. This defaults to `"main"`. + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`. + allow_patterns (`List[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are pushed. + ignore_patterns (`List[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not pushed. + delete_patterns (`List[str]` or `str`, *optional*): + If provided, remote files matching any of the patterns will be deleted from the repo. + card_kwargs (`Dict[str, Any]`, *optional*): + Additional arguments passed to the card template to customize the card. + + Returns: + The url of the commit of your object in the given repository. 
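+
+        Example (a minimal sketch; `MyPolicy` is a hypothetical `HubMixin` subclass that
+        implements `_save_pretrained`, used here purely for illustration):
+
+        ```python
+        policy = MyPolicy()
+        url = policy.push_to_hub("username/my-policy", private=True, commit_message="Initial upload")
+        print(url)  # commit reference, as described under Returns
+        ```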
+ """ + api = HfApi(token=token) + repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id + + if commit_message is None: + if "Policy" in self.__class__.__name__: + commit_message = "Upload policy" + elif "Config" in self.__class__.__name__: + commit_message = "Upload config" + else: + commit_message = f"Upload {self.__class__.__name__}" + + # Push the files to the repo in a single commit + with TemporaryDirectory(ignore_cleanup_errors=True) as tmp: + saved_path = Path(tmp) / repo_id + self.save_pretrained(saved_path, card_kwargs=card_kwargs) + return api.upload_folder( + repo_id=repo_id, + repo_type="model", + folder_path=saved_path, + commit_message=commit_message, + revision=branch, + create_pr=create_pr, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + delete_patterns=delete_patterns, + ) diff --git a/lerobot/common/utils/import_utils.py b/lerobot/common/utils/import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cd5f82450221ba7b3f707721b65531433b50021c --- /dev/null +++ b/lerobot/common/utils/import_utils.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import logging + + +def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: + """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py + Check if the package spec exists and grab its version to avoid importing a local directory. + **Note:** this doesn't work for all packages. 
+ """ + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + # Primary method to get the package version + package_version = importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: + # Fallback method: Only for "torch" and versions containing "dev" + if pkg_name == "torch": + try: + package = importlib.import_module(pkg_name) + temp_version = getattr(package, "__version__", "N/A") + # Check if the version contains "dev" + if "dev" in temp_version: + package_version = temp_version + package_exists = True + else: + package_exists = False + except ImportError: + # If the package can't be imported, it's not available + package_exists = False + else: + # For packages other than "torch", don't attempt the fallback and set as not available + package_exists = False + logging.debug(f"Detected {pkg_name} version: {package_version}") + if return_version: + return package_exists, package_version + else: + return package_exists + + +_torch_available, _torch_version = is_package_available("torch", return_version=True) +_gym_xarm_available = is_package_available("gym_xarm") +_gym_aloha_available = is_package_available("gym_aloha") +_gym_pusht_available = is_package_available("gym_pusht") diff --git a/lerobot/common/utils/io_utils.py b/lerobot/common/utils/io_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..da0be1c771997ec9395dc570d99f5cfd3a3d7ca5 --- /dev/null +++ b/lerobot/common/utils/io_utils.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import warnings +from pathlib import Path +from typing import TypeVar + +import imageio + +JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...] +T = TypeVar("T", bound=JsonLike) + + +def write_video(video_path, stacked_frames, fps): + # Filter out DeprecationWarnings raised from pkg_resources + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning + ) + imageio.mimsave(video_path, stacked_frames, fps=fps) + + +def deserialize_json_into_object(fpath: Path, obj: T) -> T: + """ + Loads the JSON data from `fpath` and recursively fills `obj` with the + corresponding values (strictly matching structure and types). + Tuples in `obj` are expected to be lists in the JSON data, which will be + converted back into tuples. + """ + with open(fpath, encoding="utf-8") as f: + data = json.load(f) + + def _deserialize(target, source): + """ + Recursively overwrite the structure in `target` with data from `source`, + performing strict checks on structure and type. + Returns the updated version of `target` (especially important for tuples). + """ + + # If the target is a dictionary, source must be a dictionary as well. 
+ if isinstance(target, dict): + if not isinstance(source, dict): + raise TypeError(f"Type mismatch: expected dict, got {type(source)}") + + # Check that they have exactly the same set of keys. + if target.keys() != source.keys(): + raise ValueError( + f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}" + ) + + # Recursively update each key. + for k in target: + target[k] = _deserialize(target[k], source[k]) + + return target + + # If the target is a list, source must be a list as well. + elif isinstance(target, list): + if not isinstance(source, list): + raise TypeError(f"Type mismatch: expected list, got {type(source)}") + + # Check length + if len(target) != len(source): + raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}") + + # Recursively update each element. + for i in range(len(target)): + target[i] = _deserialize(target[i], source[i]) + + return target + + # If the target is a tuple, the source must be a list in JSON, + # which we'll convert back to a tuple. + elif isinstance(target, tuple): + if not isinstance(source, list): + raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}") + + if len(target) != len(source): + raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}") + + # Convert each element, forming a new tuple. + converted_items = [] + for t_item, s_item in zip(target, source, strict=False): + converted_items.append(_deserialize(t_item, s_item)) + + # Return a brand new tuple (tuples are immutable in Python). + return tuple(converted_items) + + # Otherwise, we're dealing with a "primitive" (int, float, str, bool, None). + else: + # Check the exact type. If these must match 1:1, do: + if type(target) is not type(source): + raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}") + return source + + # Perform the in-place/recursive deserialization + updated_obj = _deserialize(obj, data) + return updated_obj diff --git a/lerobot/common/utils/logging_utils.py b/lerobot/common/utils/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56c9abb237625ab8a8db72b509b227ae73f9d90b --- /dev/null +++ b/lerobot/common/utils/logging_utils.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
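+
+# Lightweight metric tracking for training loops: `AverageMeter` keeps running averages and
+# `MetricsTracker` derives sample/episode/epoch counters from the training step.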
+from typing import Any + +from lerobot.common.utils.utils import format_big_number + + +class AverageMeter: + """ + Computes and stores the average and current value + Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py + """ + + def __init__(self, name: str, fmt: str = ":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self) -> None: + self.val = 0.0 + self.avg = 0.0 + self.sum = 0.0 + self.count = 0.0 + + def update(self, val: float, n: int = 1) -> None: + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name}:{avg" + self.fmt + "}" + return fmtstr.format(**self.__dict__) + + +class MetricsTracker: + """ + A helper class to track and log metrics over time. + + Usage pattern: + + ```python + # initialize, potentially with non-zero initial step (e.g. if resuming run) + metrics = {"loss": AverageMeter("loss", ":.3f")} + train_metrics = MetricsTracker(cfg, dataset, metrics, initial_step=step) + + # update metrics derived from step (samples, episodes, epochs) at each training step + train_metrics.step() + + # update various metrics + loss = policy.forward(batch) + train_metrics.loss = loss + + # display current metrics + logging.info(train_metrics) + + # export for wandb + wandb.log(train_metrics.to_dict()) + + # reset averages after logging + train_metrics.reset_averages() + ``` + """ + + __keys__ = [ + "_batch_size", + "_num_frames", + "_avg_samples_per_ep", + "metrics", + "steps", + "samples", + "episodes", + "epochs", + ] + + def __init__( + self, + batch_size: int, + num_frames: int, + num_episodes: int, + metrics: dict[str, AverageMeter], + initial_step: int = 0, + ): + self.__dict__.update(dict.fromkeys(self.__keys__)) + self._batch_size = batch_size + self._num_frames = num_frames + self._avg_samples_per_ep = num_frames / num_episodes + self.metrics = metrics + + self.steps = initial_step + # A sample is an (observation,action) pair, where observation and action + # can be on multiple timestamps. In a batch, we have `batch_size` number of samples. + self.samples = self.steps * self._batch_size + self.episodes = self.samples / self._avg_samples_per_ep + self.epochs = self.samples / self._num_frames + + def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any: + if name in self.__dict__: + return self.__dict__[name] + elif name in self.metrics: + return self.metrics[name] + else: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + def __setattr__(self, name: str, value: Any) -> None: + if name in self.__dict__: + super().__setattr__(name, value) + elif name in self.metrics: + self.metrics[name].update(value) + else: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + def step(self) -> None: + """ + Updates metrics that depend on 'step' for one step. 
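+        Increments `steps` by one and recomputes the derived `samples`, `episodes` and `epochs`
+        counters from the batch size and dataset statistics.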
+ """ + self.steps += 1 + self.samples += self._batch_size + self.episodes = self.samples / self._avg_samples_per_ep + self.epochs = self.samples / self._num_frames + + def __str__(self) -> str: + display_list = [ + f"step:{format_big_number(self.steps)}", + # number of samples seen during training + f"smpl:{format_big_number(self.samples)}", + # number of episodes seen during training + f"ep:{format_big_number(self.episodes)}", + # number of time all unique samples are seen + f"epch:{self.epochs:.2f}", + *[str(m) for m in self.metrics.values()], + ] + return " ".join(display_list) + + def to_dict(self, use_avg: bool = True) -> dict[str, int | float]: + """ + Returns the current metric values (or averages if `use_avg=True`) as a dict. + """ + return { + "steps": self.steps, + "samples": self.samples, + "episodes": self.episodes, + "epochs": self.epochs, + **{k: m.avg if use_avg else m.val for k, m in self.metrics.items()}, + } + + def reset_averages(self) -> None: + """Resets average meters.""" + for m in self.metrics.values(): + m.reset() diff --git a/lerobot/common/utils/random_utils.py b/lerobot/common/utils/random_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3d9bf4dd80a6d1dcdc47070e5b83fded7cec3904 --- /dev/null +++ b/lerobot/common/utils/random_utils.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Generator + +import numpy as np +import torch +from safetensors.torch import load_file, save_file + +from lerobot.common.constants import RNG_STATE +from lerobot.common.datasets.utils import flatten_dict, unflatten_dict + + +def serialize_python_rng_state() -> dict[str, torch.Tensor]: + """ + Returns the rng state for `random` in the form of a flat dict[str, torch.Tensor] to be saved using + `safetensors.save_file()` or `torch.save()`. + """ + py_state = random.getstate() + return { + "py_rng_version": torch.tensor([py_state[0]], dtype=torch.int64), + "py_rng_state": torch.tensor(py_state[1], dtype=torch.int64), + } + + +def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None: + """ + Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`. + """ + py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None) + random.setstate(py_state) + + +def serialize_numpy_rng_state() -> dict[str, torch.Tensor]: + """ + Returns the rng state for `numpy` in the form of a flat dict[str, torch.Tensor] to be saved using + `safetensors.save_file()` or `torch.save()`. 
+ """ + np_state = np.random.get_state() + # Ensure no breaking changes from numpy + assert np_state[0] == "MT19937" + return { + "np_rng_state_values": torch.tensor(np_state[1], dtype=torch.int64), + "np_rng_state_index": torch.tensor([np_state[2]], dtype=torch.int64), + "np_rng_has_gauss": torch.tensor([np_state[3]], dtype=torch.int64), + "np_rng_cached_gaussian": torch.tensor([np_state[4]], dtype=torch.float32), + } + + +def deserialize_numpy_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None: + """ + Restores the rng state for `numpy` from a dictionary produced by `serialize_numpy_rng_state()`. + """ + np_state = ( + "MT19937", + rng_state_dict["np_rng_state_values"].numpy(), + rng_state_dict["np_rng_state_index"].item(), + rng_state_dict["np_rng_has_gauss"].item(), + rng_state_dict["np_rng_cached_gaussian"].item(), + ) + np.random.set_state(np_state) + + +def serialize_torch_rng_state() -> dict[str, torch.Tensor]: + """ + Returns the rng state for `torch` in the form of a flat dict[str, torch.Tensor] to be saved using + `safetensors.save_file()` or `torch.save()`. + """ + torch_rng_state_dict = {"torch_rng_state": torch.get_rng_state()} + if torch.cuda.is_available(): + torch_rng_state_dict["torch_cuda_rng_state"] = torch.cuda.get_rng_state() + return torch_rng_state_dict + + +def deserialize_torch_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None: + """ + Restores the rng state for `torch` from a dictionary produced by `serialize_torch_rng_state()`. + """ + torch.set_rng_state(rng_state_dict["torch_rng_state"]) + if torch.cuda.is_available() and "torch_cuda_rng_state" in rng_state_dict: + torch.cuda.set_rng_state(rng_state_dict["torch_cuda_rng_state"]) + + +def serialize_rng_state() -> dict[str, torch.Tensor]: + """ + Returns the rng state for `random`, `numpy`, and `torch`, in the form of a flat + dict[str, torch.Tensor] to be saved using `safetensors.save_file()` `torch.save()`. + """ + py_rng_state_dict = serialize_python_rng_state() + np_rng_state_dict = serialize_numpy_rng_state() + torch_rng_state_dict = serialize_torch_rng_state() + + return { + **py_rng_state_dict, + **np_rng_state_dict, + **torch_rng_state_dict, + } + + +def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None: + """ + Restores the rng state for `random`, `numpy`, and `torch` from a dictionary produced by + `serialize_rng_state()`. 
+ """ + py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")} + np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")} + torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")} + + deserialize_python_rng_state(py_rng_state_dict) + deserialize_numpy_rng_state(np_rng_state_dict) + deserialize_torch_rng_state(torch_rng_state_dict) + + +def save_rng_state(save_dir: Path) -> None: + rng_state_dict = serialize_rng_state() + flat_rng_state_dict = flatten_dict(rng_state_dict) + save_file(flat_rng_state_dict, save_dir / RNG_STATE) + + +def load_rng_state(save_dir: Path) -> None: + flat_rng_state_dict = load_file(save_dir / RNG_STATE) + rng_state_dict = unflatten_dict(flat_rng_state_dict) + deserialize_rng_state(rng_state_dict) + + +def get_rng_state() -> dict[str, Any]: + """Get the random state for `random`, `numpy`, and `torch`.""" + random_state_dict = { + "random_state": random.getstate(), + "numpy_random_state": np.random.get_state(), + "torch_random_state": torch.random.get_rng_state(), + } + if torch.cuda.is_available(): + random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state() + return random_state_dict + + +def set_rng_state(random_state_dict: dict[str, Any]): + """Set the random state for `random`, `numpy`, and `torch`. + + Args: + random_state_dict: A dictionary of the form returned by `get_rng_state`. + """ + random.setstate(random_state_dict["random_state"]) + np.random.set_state(random_state_dict["numpy_random_state"]) + torch.random.set_rng_state(random_state_dict["torch_random_state"]) + if torch.cuda.is_available(): + torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"]) + + +def set_seed(seed) -> None: + """Set seed for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +@contextmanager +def seeded_context(seed: int) -> Generator[None, None, None]: + """Set the seed when entering a context, and restore the prior random state at exit. + + Example usage: + + ``` + a = random.random() # produces some random number + with seeded_context(1337): + b = random.random() # produces some other random number + c = random.random() # produces yet another random number, but the same it would have if we never made `b` + ``` + """ + random_state_dict = get_rng_state() + set_seed(seed) + yield None + set_rng_state(random_state_dict) diff --git a/lerobot/common/utils/train_utils.py b/lerobot/common/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a79983128a11081bf70cb0a3845759d9d47ceeb5 --- /dev/null +++ b/lerobot/common/utils/train_utils.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +from pathlib import Path + +from termcolor import colored +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler + +from lerobot.common.constants import ( + CHECKPOINTS_DIR, + LAST_CHECKPOINT_LINK, + PRETRAINED_MODEL_DIR, + TRAINING_STATE_DIR, + TRAINING_STEP, +) +from lerobot.common.datasets.utils import load_json, write_json +from lerobot.common.optim.optimizers import load_optimizer_state, save_optimizer_state +from lerobot.common.optim.schedulers import load_scheduler_state, save_scheduler_state +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.utils.random_utils import load_rng_state, save_rng_state +from lerobot.configs.train import TrainPipelineConfig + + +def log_output_dir(out_dir): + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}") + + +def get_step_identifier(step: int, total_steps: int) -> str: + num_digits = max(6, len(str(total_steps))) + return f"{step:0{num_digits}d}" + + +def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path: + """Returns the checkpoint sub-directory corresponding to the step number.""" + step_identifier = get_step_identifier(step, total_steps) + return output_dir / CHECKPOINTS_DIR / step_identifier + + +def save_training_step(step: int, save_dir: Path) -> None: + write_json({"step": step}, save_dir / TRAINING_STEP) + + +def load_training_step(save_dir: Path) -> int: + training_step = load_json(save_dir / TRAINING_STEP) + return training_step["step"] + + +def update_last_checkpoint(checkpoint_dir: Path) -> Path: + last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK + if last_checkpoint_dir.is_symlink(): + last_checkpoint_dir.unlink() + relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent) + last_checkpoint_dir.symlink_to(relative_target) + + +def save_checkpoint( + checkpoint_dir: Path, + step: int, + cfg: TrainPipelineConfig, + policy: PreTrainedPolicy, + optimizer: Optimizer, + scheduler: LRScheduler | None = None, +) -> None: + """This function creates the following directory structure: + + 005000/ # training step at checkpoint + ├── pretrained_model/ + │ ├── config.json # policy config + │ ├── model.safetensors # policy weights + │ └── train_config.json # train config + └── training_state/ + ├── optimizer_param_groups.json # optimizer param groups + ├── optimizer_state.safetensors # optimizer state + ├── rng_state.safetensors # rng states + ├── scheduler_state.json # scheduler state + └── training_step.json # training step + + Args: + cfg (TrainPipelineConfig): The training config used for this run. + step (int): The training step at that checkpoint. + policy (PreTrainedPolicy): The policy to save. + optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None. + scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. + """ + pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR + policy.save_pretrained(pretrained_dir) + cfg.save_pretrained(pretrained_dir) + save_training_state(checkpoint_dir, step, optimizer, scheduler) + + +def save_training_state( + checkpoint_dir: Path, + train_step: int, + optimizer: Optimizer | None = None, + scheduler: LRScheduler | None = None, +) -> None: + """ + Saves the training step, optimizer state, scheduler state, and rng state. + + Args: + save_dir (Path): The directory to save artifacts to. + train_step (int): Current training step. 
+ optimizer (Optimizer | None, optional): The optimizer from which to save the state_dict. + Defaults to None. + scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict. + Defaults to None. + """ + save_dir = checkpoint_dir / TRAINING_STATE_DIR + save_dir.mkdir(parents=True, exist_ok=True) + save_training_step(train_step, save_dir) + save_rng_state(save_dir) + if optimizer is not None: + save_optimizer_state(optimizer, save_dir) + if scheduler is not None: + save_scheduler_state(scheduler, save_dir) + + +def load_training_state( + checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None +) -> tuple[int, Optimizer, LRScheduler | None]: + """ + Loads the training step, optimizer state, scheduler state, and rng state. + This is used to resume a training run. + + Args: + checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir. + optimizer (Optimizer): The optimizer to load the state_dict to. + scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None). + + Raises: + NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir + + Returns: + tuple[int, Optimizer, LRScheduler | None]: training step, optimizer and scheduler with their + state_dict loaded. + """ + training_state_dir = checkpoint_dir / TRAINING_STATE_DIR + if not training_state_dir.is_dir(): + raise NotADirectoryError(training_state_dir) + + load_rng_state(training_state_dir) + step = load_training_step(training_state_dir) + optimizer = load_optimizer_state(optimizer, training_state_dir) + if scheduler is not None: + scheduler = load_scheduler_state(scheduler, training_state_dir) + + return step, optimizer, scheduler diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..563a7b813424b37d33d29bc2e77adcd4a6d98b02 --- /dev/null +++ b/lerobot/common/utils/utils.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import os.path as osp +import platform +import subprocess +from copy import copy +from datetime import datetime, timezone +from pathlib import Path + +import numpy as np +import torch + + +def none_or_int(value): + if value == "None": + return None + return int(value) + + +def inside_slurm(): + """Check whether the python process was launched through slurm""" + # TODO(rcadene): return False for interactive mode `--pty bash` + return "SLURM_JOB_ID" in os.environ + + +def auto_select_torch_device() -> torch.device: + """Tries to select automatically a torch device.""" + if torch.cuda.is_available(): + logging.info("Cuda backend detected, using cuda.") + return torch.device("cuda") + elif torch.backends.mps.is_available(): + logging.info("Metal backend detected, using cuda.") + return torch.device("mps") + else: + logging.warning("No accelerated backend detected. 
Using default cpu, this will be slow.") + return torch.device("cpu") + + +# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level +def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device: + """Given a string, return a torch.device with checks on whether the device is available.""" + try_device = str(try_device) + match try_device: + case "cuda": + assert torch.cuda.is_available() + device = torch.device("cuda") + case "mps": + assert torch.backends.mps.is_available() + device = torch.device("mps") + case "cpu": + device = torch.device("cpu") + if log: + logging.warning("Using CPU, this will be slow.") + case _: + device = torch.device(try_device) + if log: + logging.warning(f"Using custom {try_device} device.") + + return device + + +def get_safe_dtype(dtype: torch.dtype, device: str | torch.device): + """ + mps is currently not compatible with float64 + """ + if isinstance(device, torch.device): + device = device.type + if device == "mps" and dtype == torch.float64: + return torch.float32 + else: + return dtype + + +def is_torch_device_available(try_device: str) -> bool: + try_device = str(try_device) # Ensure try_device is a string + if try_device == "cuda": + return torch.cuda.is_available() + elif try_device == "mps": + return torch.backends.mps.is_available() + elif try_device == "cpu": + return True + else: + raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.") + + +def is_amp_available(device: str): + if device in ["cuda", "cpu"]: + return True + elif device == "mps": + return False + else: + raise ValueError(f"Unknown device '{device}.") + + +def init_logging(): + def custom_format(record): + dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + fnameline = f"{record.pathname}:{record.lineno}" + message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" + return message + + logging.basicConfig(level=logging.INFO) + + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + + formatter = logging.Formatter() + formatter.format = custom_format + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logging.getLogger().addHandler(console_handler) + + +def format_big_number(num, precision=0): + suffixes = ["", "K", "M", "B", "T", "Q"] + divisor = 1000.0 + + for suffix in suffixes: + if abs(num) < divisor: + return f"{num:.{precision}f}{suffix}" + num /= divisor + + return num + + +def _relative_path_between(path1: Path, path2: Path) -> Path: + """Returns path1 relative to path2.""" + path1 = path1.absolute() + path2 = path2.absolute() + try: + return path1.relative_to(path2) + except ValueError: # most likely because path1 is not a subpath of path2 + common_parts = Path(osp.commonpath([path1, path2])).parts + return Path( + "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) + ) + + +def print_cuda_memory_usage(): + """Use this function to locate and debug memory leak.""" + import gc + + gc.collect() + # Also clear the cache if you want to fully release the memory + torch.cuda.empty_cache() + print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2)) + print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2)) + print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2)) + print("Maximum GPU Memory Reserved: {:.2f} 
MB".format(torch.cuda.max_memory_reserved(0) / 1024**2)) + + +def capture_timestamp_utc(): + return datetime.now(timezone.utc) + + +def say(text, blocking=False): + system = platform.system() + + if system == "Darwin": + cmd = ["say", text] + + elif system == "Linux": + cmd = ["spd-say", text] + if blocking: + cmd.append("--wait") + + elif system == "Windows": + cmd = [ + "PowerShell", + "-Command", + "Add-Type -AssemblyName System.Speech; " + f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')", + ] + + else: + raise RuntimeError("Unsupported operating system for text-to-speech.") + + if blocking: + subprocess.run(cmd, check=True) + else: + subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0) + + +def log_say(text, play_sounds, blocking=False): + logging.info(text) + + if play_sounds: + say(text, blocking) + + +def get_channel_first_image_shape(image_shape: tuple) -> tuple: + shape = copy(image_shape) + if shape[2] < shape[0] and shape[2] < shape[1]: # (h, w, c) -> (c, h, w) + shape = (shape[2], shape[0], shape[1]) + elif not (shape[0] < shape[1] and shape[0] < shape[2]): + raise ValueError(image_shape) + + return shape + + +def has_method(cls: object, method_name: str) -> bool: + return hasattr(cls, method_name) and callable(getattr(cls, method_name)) + + +def is_valid_numpy_dtype_string(dtype_str: str) -> bool: + """ + Return True if a given string can be converted to a numpy dtype. + """ + try: + # Attempt to convert the string to a numpy dtype + np.dtype(dtype_str) + return True + except TypeError: + # If a TypeError is raised, the string is not a valid dtype + return False diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3fe241d412e418fa1400fd2d12ae73e14233c8be --- /dev/null +++ b/lerobot/common/utils/wandb_utils.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import re +from glob import glob +from pathlib import Path + +from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE +from termcolor import colored + +from lerobot.common.constants import PRETRAINED_MODEL_DIR +from lerobot.configs.train import TrainPipelineConfig + + +def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str: + """Return a group name for logging. Optionally returns group name as list.""" + lst = [ + f"policy:{cfg.policy.type}", + f"dataset:{cfg.dataset.repo_id}", + f"seed:{cfg.seed}", + ] + if cfg.env is not None: + lst.append(f"env:{cfg.env.type}") + return lst if return_list else "-".join(lst) + + +def get_wandb_run_id_from_filesystem(log_dir: Path) -> str: + # Get the WandB run ID. 
+ paths = glob(str(log_dir / "wandb/latest-run/run-*")) + if len(paths) != 1: + raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.") + match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1]) + if match is None: + raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.") + wandb_run_id = match.groups(0)[0] + return wandb_run_id + + +def get_safe_wandb_artifact_name(name: str): + """WandB artifacts don't accept ":" or "/" in their name.""" + return name.replace(":", "_").replace("/", "_") + + +class WandBLogger: + """A helper class to log object using wandb.""" + + def __init__(self, cfg: TrainPipelineConfig): + self.cfg = cfg.wandb + self.log_dir = cfg.output_dir + self.job_name = cfg.job_name + self.env_fps = cfg.env.fps if cfg.env else None + self._group = cfg_to_group(cfg) + + # Set up WandB. + os.environ["WANDB_SILENT"] = "True" + import wandb + + wandb_run_id = ( + cfg.wandb.run_id + if cfg.wandb.run_id + else get_wandb_run_id_from_filesystem(self.log_dir) + if cfg.resume + else None + ) + wandb.init( + id=wandb_run_id, + project=self.cfg.project, + entity=self.cfg.entity, + name=self.job_name, + notes=self.cfg.notes, + tags=cfg_to_group(cfg, return_list=True), + dir=self.log_dir, + config=cfg.to_dict(), + # TODO(rcadene): try set to True + save_code=False, + # TODO(rcadene): split train and eval, and run async eval with job_type="eval" + job_type="train_eval", + resume="must" if cfg.resume else None, + mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online", + ) + print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) + logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") + self._wandb = wandb + + def log_policy(self, checkpoint_dir: Path): + """Checkpoints the policy to wandb.""" + if self.cfg.disable_artifact: + return + + step_id = checkpoint_dir.name + artifact_name = f"{self._group}-{step_id}" + artifact_name = get_safe_wandb_artifact_name(artifact_name) + artifact = self._wandb.Artifact(artifact_name, type="model") + artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE) + self._wandb.log_artifact(artifact) + + def log_dict(self, d: dict, step: int, mode: str = "train"): + if mode not in {"train", "eval"}: + raise ValueError(mode) + + for k, v in d.items(): + if not isinstance(v, (int, float, str)): + logging.warning( + f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.' + ) + continue + self._wandb.log({f"{mode}/{k}": v}, step=step) + + def log_video(self, video_path: str, step: int, mode: str = "train"): + if mode not in {"train", "eval"}: + raise ValueError(mode) + + wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4") + self._wandb.log({f"{mode}/video": wandb_video}, step=step) diff --git a/lerobot/configs/default.py b/lerobot/configs/default.py new file mode 100644 index 0000000000000000000000000000000000000000..ce72466a832f24bfa9ddc907d5b887633e8029f8 --- /dev/null +++ b/lerobot/configs/default.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common import ( + policies, # noqa: F401 +) +from lerobot.common.datasets.transforms import ImageTransformsConfig +from lerobot.common.datasets.video_utils import get_safe_default_codec + + +@dataclass +class DatasetConfig: + # You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data + # keys common between the datasets are kept. Each dataset gets and additional transform that inserts the + # "dataset_index" into the returned item. The index mapping is made according to the order in which the + # datasets are provided. + repo_id: str + # Root directory where the dataset will be stored (e.g. 'dataset/path'). + root: str | None = None + episodes: list[int] | None = None + image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) + revision: str | None = None + use_imagenet_stats: bool = True + video_backend: str = field(default_factory=get_safe_default_codec) + + +@dataclass +class WandBConfig: + enable: bool = False + # Set to true to disable saving an artifact despite training.save_checkpoint=True + disable_artifact: bool = False + project: str = "lerobot" + entity: str | None = None + notes: str | None = None + run_id: str | None = None + mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online' + + +@dataclass +class EvalConfig: + n_episodes: int = 50 + # `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv. + batch_size: int = 50 + # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing). + use_async_envs: bool = False + + def __post_init__(self): + if self.batch_size > self.n_episodes: + raise ValueError( + "The eval batch size is greater than the number of eval episodes " + f"({self.batch_size} > {self.n_episodes}). As a result, {self.batch_size} " + f"eval environments will be instantiated, but only {self.n_episodes} will be used. " + "This might significantly slow down evaluation. To fix this, you should update your command " + f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), " + f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)." + ) diff --git a/lerobot/configs/eval.py b/lerobot/configs/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..16b35291360c8088ec956e6c0155fbdeeea498fe --- /dev/null +++ b/lerobot/configs/eval.py @@ -0,0 +1,65 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime as dt +import logging +from dataclasses import dataclass, field +from pathlib import Path + +from lerobot.common import envs, policies # noqa: F401 +from lerobot.configs import parser +from lerobot.configs.default import EvalConfig +from lerobot.configs.policies import PreTrainedConfig + + +@dataclass +class EvalPipelineConfig: + # Either the repo ID of a model hosted on the Hub or a path to a directory containing weights + # saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch + # (useful for debugging). This argument is mutually exclusive with `--config`. + env: envs.EnvConfig + eval: EvalConfig = field(default_factory=EvalConfig) + policy: PreTrainedConfig | None = None + output_dir: Path | None = None + job_name: str | None = None + seed: int | None = 1000 + + def __post_init__(self): + # HACK: We parse again the cli args here to get the pretrained path if there was one. + policy_path = parser.get_path_arg("policy") + if policy_path: + cli_overrides = parser.get_cli_overrides("policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + + else: + logging.warning( + "No pretrained path was provided, evaluated policy will be built from scratch (random weights)." + ) + + if not self.job_name: + if self.env is None: + self.job_name = f"{self.policy.type}" + else: + self.job_name = f"{self.env.type}_{self.policy.type}" + + if not self.output_dir: + now = dt.datetime.now() + eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}" + self.output_dir = Path("outputs/eval") / eval_dir + + @classmethod + def __get_path_fields__(cls) -> list[str]: + """This enables the parser to load config from the policy using `--policy.path=local/dir`""" + return ["policy"] diff --git a/lerobot/configs/parser.py b/lerobot/configs/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..39e315152b18ff71ef8982e7264287a06e91b246 --- /dev/null +++ b/lerobot/configs/parser.py @@ -0,0 +1,232 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import inspect +import pkgutil +import sys +from argparse import ArgumentError +from functools import wraps +from pathlib import Path +from typing import Sequence + +import draccus + +from lerobot.common.utils.utils import has_method + +PATH_KEY = "path" +PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path" +draccus.set_config_type("json") + + +def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None: + """Parses arguments from cli at a given nested attribute level. 
+ + For example, supposing the main script was called with: + python myscript.py --arg1=1 --arg2.subarg1=abc --arg2.subarg2=some/path + + If called during execution of myscript.py, get_cli_overrides("arg2") will return: + ["--subarg1=abc" "--subarg2=some/path"] + """ + if args is None: + args = sys.argv[1:] + attr_level_args = [] + detect_string = f"--{field_name}." + exclude_strings = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=") + for arg in args: + if arg.startswith(detect_string) and not arg.startswith(exclude_strings): + denested_arg = f"--{arg.removeprefix(detect_string)}" + attr_level_args.append(denested_arg) + + return attr_level_args + + +def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None: + if args is None: + args = sys.argv[1:] + prefix = f"--{arg_name}=" + for arg in args: + if arg.startswith(prefix): + return arg[len(prefix) :] + return None + + +def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict: + """Parse plugin-related arguments from command-line arguments. + + This function extracts arguments from command-line arguments that match a specified suffix pattern. + It processes arguments in the format '--key=value' and returns them as a dictionary. + + Args: + plugin_arg_suffix (str): The suffix to identify plugin-related arguments. + cli_args (Sequence[str]): A sequence of command-line arguments to parse. + + Returns: + dict: A dictionary containing the parsed plugin arguments where: + - Keys are the argument names (with '--' prefix removed if present) + - Values are the corresponding argument values + + Example: + >>> args = ['--env.discover_packages_path=my_package', + ... '--other_arg=value'] + >>> parse_plugin_args('discover_packages_path', args) + {'env.discover_packages_path': 'my_package'} + """ + plugin_args = {} + for arg in args: + if "=" in arg and plugin_arg_suffix in arg: + key, value = arg.split("=", 1) + # Remove leading '--' if present + if key.startswith("--"): + key = key[2:] + plugin_args[key] = value + return plugin_args + + +class PluginLoadError(Exception): + """Raised when a plugin fails to load.""" + + +def load_plugin(plugin_path: str) -> None: + """Load and initialize a plugin from a given Python package path. + + This function attempts to load a plugin by importing its package and any submodules. + Plugin registration is expected to happen during package initialization, i.e. when + the package is imported the gym environment should be registered and the config classes + registered with their parents using the `register_subclass` decorator. + + Args: + plugin_path (str): The Python package path to the plugin (e.g. "mypackage.plugins.myplugin") + + Raises: + PluginLoadError: If the plugin cannot be loaded due to import errors or if the package path is invalid. + + Examples: + >>> load_plugin("external_plugin.core") # Loads plugin from external package + + Notes: + - The plugin package should handle its own registration during import + - All submodules in the plugin package will be imported + - Implementation follows the plugin discovery pattern from Python packaging guidelines + + See Also: + https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ + """ + try: + package_module = importlib.import_module(plugin_path, __package__) + except (ImportError, ModuleNotFoundError) as e: + raise PluginLoadError( + f"Failed to load plugin '{plugin_path}'. 
Verify the path and installation: {str(e)}" + ) from e + + def iter_namespace(ns_pkg): + return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".") + + try: + for _finder, pkg_name, _ispkg in iter_namespace(package_module): + importlib.import_module(pkg_name) + except ImportError as e: + raise PluginLoadError( + f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}" + ) from e + + +def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None: + return parse_arg(f"{field_name}.{PATH_KEY}", args) + + +def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None: + return parse_arg(f"{field_name}.{draccus.CHOICE_TYPE_KEY}", args) + + +def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]: + return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")] + + +def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]: + """ + Filters command-line arguments related to fields with specific path arguments. + + Args: + fields_to_filter (str | list[str]): A single str or a list of str whose arguments need to be filtered. + args (Sequence[str] | None): The sequence of command-line arguments to be filtered. + Defaults to None. + + Returns: + list[str]: A filtered list of arguments, with arguments related to the specified + fields removed. + + Raises: + ArgumentError: If both a path argument (e.g., `--field_name.path`) and a type + argument (e.g., `--field_name.type`) are specified for the same field. + """ + if isinstance(fields_to_filter, str): + fields_to_filter = [fields_to_filter] + + filtered_args = args + for field in fields_to_filter: + if get_path_arg(field, args): + if get_type_arg(field, args): + raise ArgumentError( + argument=None, + message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}", + ) + filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")] + + return filtered_args + + +def wrap(config_path: Path | None = None): + """ + HACK: Similar to draccus.wrap but does three additional things: + - Will remove '.path' arguments from CLI in order to process them later on. + - If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will + initialize it from there to allow to fetch configs from the hub directly + - Will load plugins specified in the CLI arguments. 
These plugins will typically register + their own subclasses of config classes, so that draccus can find the right class to instantiate + from the CLI '.type' arguments + """ + + def wrapper_outer(fn): + @wraps(fn) + def wrapper_inner(*args, **kwargs): + argspec = inspect.getfullargspec(fn) + argtype = argspec.annotations[argspec.args[0]] + if len(args) > 0 and type(args[0]) is argtype: + cfg = args[0] + args = args[1:] + else: + cli_args = sys.argv[1:] + plugin_args = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, cli_args) + for plugin_cli_arg, plugin_path in plugin_args.items(): + try: + load_plugin(plugin_path) + except PluginLoadError as e: + # add the relevant CLI arg to the error message + raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e + cli_args = filter_arg(plugin_cli_arg, cli_args) + config_path_cli = parse_arg("config_path", cli_args) + if has_method(argtype, "__get_path_fields__"): + path_fields = argtype.__get_path_fields__() + cli_args = filter_path_args(path_fields, cli_args) + if has_method(argtype, "from_pretrained") and config_path_cli: + cli_args = filter_arg("config_path", cli_args) + cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args) + else: + cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args) + response = fn(cfg, *args, **kwargs) + return response + + return wrapper_inner + + return wrapper_outer diff --git a/lerobot/configs/policies.py b/lerobot/configs/policies.py new file mode 100644 index 0000000000000000000000000000000000000000..022d1fb5293860cadba3ca011da495fbe4e408ca --- /dev/null +++ b/lerobot/configs/policies.py @@ -0,0 +1,176 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import logging +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Type, TypeVar + +import draccus +from huggingface_hub import hf_hub_download +from huggingface_hub.constants import CONFIG_NAME +from huggingface_hub.errors import HfHubHTTPError + +from lerobot.common.optim.optimizers import OptimizerConfig +from lerobot.common.optim.schedulers import LRSchedulerConfig +from lerobot.common.utils.hub import HubMixin +from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + +# Generic variable that is either PreTrainedConfig or a subclass thereof +T = TypeVar("T", bound="PreTrainedConfig") + + +@dataclass +class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): + """ + Base configuration class for policy models. + + Args: + n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the + current step and additional steps going back). + input_shapes: A dictionary defining the shapes of the input data for the policy. + output_shapes: A dictionary defining the shapes of the output data for the policy. 
+ input_normalization_modes: A dictionary with key representing the modality and the value specifies the + normalization mode to apply. + output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to + the original scale. + """ + + n_obs_steps: int = 1 + normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict) + + input_features: dict[str, PolicyFeature] = field(default_factory=dict) + output_features: dict[str, PolicyFeature] = field(default_factory=dict) + + device: str | None = None # cuda | cpu | mp + # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP, + # automatic gradient scaling is used. + use_amp: bool = False + + def __post_init__(self): + self.pretrained_path = None + if not self.device or not is_torch_device_available(self.device): + auto_device = auto_select_torch_device() + logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.") + self.device = auto_device.type + + # Automatically deactivate AMP if necessary + if self.use_amp and not is_amp_available(self.device): + logging.warning( + f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP." + ) + self.use_amp = False + + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + @abc.abstractproperty + def observation_delta_indices(self) -> list | None: + raise NotImplementedError + + @abc.abstractproperty + def action_delta_indices(self) -> list | None: + raise NotImplementedError + + @abc.abstractproperty + def reward_delta_indices(self) -> list | None: + raise NotImplementedError + + @abc.abstractmethod + def get_optimizer_preset(self) -> OptimizerConfig: + raise NotImplementedError + + @abc.abstractmethod + def get_scheduler_preset(self) -> LRSchedulerConfig | None: + raise NotImplementedError + + @abc.abstractmethod + def validate_features(self) -> None: + raise NotImplementedError + + @property + def robot_state_feature(self) -> PolicyFeature | None: + for _, ft in self.input_features.items(): + if ft.type is FeatureType.STATE: + return ft + return None + + @property + def env_state_feature(self) -> PolicyFeature | None: + for _, ft in self.input_features.items(): + if ft.type is FeatureType.ENV: + return ft + return None + + @property + def image_features(self) -> dict[str, PolicyFeature]: + return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL} + + @property + def action_feature(self) -> PolicyFeature | None: + for _, ft in self.output_features.items(): + if ft.type is FeatureType.ACTION: + return ft + return None + + def _save_pretrained(self, save_directory: Path) -> None: + with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"): + draccus.dump(self, f, indent=4) + + @classmethod + def from_pretrained( + cls: Type[T], + pretrained_name_or_path: str | Path, + *, + force_download: bool = False, + resume_download: bool = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + **policy_kwargs, + ) -> T: + model_id = str(pretrained_name_or_path) + config_file: str | None = None + if Path(model_id).is_dir(): + if CONFIG_NAME in os.listdir(model_id): + config_file = os.path.join(model_id, CONFIG_NAME) + else: + print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}") + else: + try: + config_file = hf_hub_download( + 
repo_id=model_id, + filename=CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + except HfHubHTTPError as e: + raise FileNotFoundError( + f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}" + ) from e + + # HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus + # something like --policy.path (in addition to --policy.type) + cli_overrides = policy_kwargs.pop("cli_overrides", []) + return draccus.parse(cls, config_file, args=cli_overrides) diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py new file mode 100644 index 0000000000000000000000000000000000000000..7a787b83e15f77da3b794adcbad3d1cb4e5256c7 --- /dev/null +++ b/lerobot/configs/train.py @@ -0,0 +1,175 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime as dt +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Type + +import draccus +from huggingface_hub import hf_hub_download +from huggingface_hub.errors import HfHubHTTPError + +from lerobot.common import envs +from lerobot.common.optim import OptimizerConfig +from lerobot.common.optim.schedulers import LRSchedulerConfig +from lerobot.common.utils.hub import HubMixin +from lerobot.configs import parser +from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig +from lerobot.configs.policies import PreTrainedConfig + +TRAIN_CONFIG_NAME = "train_config.json" + + +@dataclass +class TrainPipelineConfig(HubMixin): + dataset: DatasetConfig + env: envs.EnvConfig | None = None + policy: PreTrainedConfig | None = None + # Set `dir` to where you would like to save all of the run outputs. If you run another training session + # with the same value for `dir` its contents will be overwritten unless you set `resume` to true. + output_dir: Path | None = None + job_name: str | None = None + # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure + # `dir` is the directory of an existing run with at least one checkpoint in it. + # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint, + # regardless of what's provided with the training command at the time of resumption. + resume: bool = False + # `seed` is used for training (eg: model initialization, dataset shuffling) + # AND for the evaluation environments. + seed: int | None = 1000 + # Number of workers for the dataloader. + num_workers: int = 4 + batch_size: int = 8 + steps: int = 100_000 + eval_freq: int = 20_000 + log_freq: int = 200 + save_checkpoint: bool = True + # Checkpoint is saved every `save_freq` training iterations and after the last training step. 
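Stepping back to the `PreTrainedConfig` feature properties defined above (`robot_state_feature`, `env_state_feature`, `image_features`, `action_feature`): a minimal sketch of the `input_features` / `output_features` dictionaries they inspect. The dictionary keys are illustrative conventions and not taken from this diff; `PolicyFeature` and `FeatureType` come from `lerobot/configs/types.py`, added later in this patch.

```python
from lerobot.configs.types import FeatureType, PolicyFeature

# Hypothetical feature dictionaries; the dict keys are example names only.
input_features = {
    "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(6,)),
    "observation.images.top": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 480, 640)),
}
output_features = {
    "action": PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
}

# On a config instance carrying these dicts, the properties resolve as:
#   cfg.robot_state_feature -> the STATE entry (shape (6,))
#   cfg.env_state_feature   -> None (no ENV feature present)
#   cfg.image_features      -> {"observation.images.top": <VISUAL feature>}
#   cfg.action_feature      -> the ACTION entry
```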
+ save_freq: int = 20_000 + use_policy_training_preset: bool = True + optimizer: OptimizerConfig | None = None + scheduler: LRSchedulerConfig | None = None + eval: EvalConfig = field(default_factory=EvalConfig) + wandb: WandBConfig = field(default_factory=WandBConfig) + + def __post_init__(self): + self.checkpoint_path = None + + def validate(self): + # HACK: We parse again the cli args here to get the pretrained paths if there was some. + policy_path = parser.get_path_arg("policy") + if policy_path: + # Only load the policy config + cli_overrides = parser.get_cli_overrides("policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + elif self.resume: + # The entire train config is already loaded, we just need to get the checkpoint dir + config_path = parser.parse_arg("config_path") + if not config_path: + raise ValueError( + f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}" + ) + if not Path(config_path).resolve().exists(): + raise NotADirectoryError( + f"{config_path=} is expected to be a local path. " + "Resuming from the hub is not supported for now." + ) + policy_path = Path(config_path).parent + self.policy.pretrained_path = policy_path + self.checkpoint_path = policy_path.parent + + if not self.job_name: + if self.env is None: + self.job_name = f"{self.policy.type}" + else: + self.job_name = f"{self.env.type}_{self.policy.type}" + + if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir(): + raise FileExistsError( + f"Output directory {self.output_dir} already exists and resume is {self.resume}. " + f"Please change your output directory so that {self.output_dir} is not overwritten." + ) + elif not self.output_dir: + now = dt.datetime.now() + train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}" + self.output_dir = Path("outputs/train") / train_dir + + if isinstance(self.dataset.repo_id, list): + raise NotImplementedError("LeRobotMultiDataset is not currently implemented.") + + if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None): + raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.") + elif self.use_policy_training_preset and not self.resume: + self.optimizer = self.policy.get_optimizer_preset() + self.scheduler = self.policy.get_scheduler_preset() + + @classmethod + def __get_path_fields__(cls) -> list[str]: + """This enables the parser to load config from the policy using `--policy.path=local/dir`""" + return ["policy"] + + def to_dict(self) -> dict: + return draccus.encode(self) + + def _save_pretrained(self, save_directory: Path) -> None: + with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"): + draccus.dump(self, f, indent=4) + + @classmethod + def from_pretrained( + cls: Type["TrainPipelineConfig"], + pretrained_name_or_path: str | Path, + *, + force_download: bool = False, + resume_download: bool = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + **kwargs, + ) -> "TrainPipelineConfig": + model_id = str(pretrained_name_or_path) + config_file: str | None = None + if Path(model_id).is_dir(): + if TRAIN_CONFIG_NAME in os.listdir(model_id): + config_file = os.path.join(model_id, TRAIN_CONFIG_NAME) + else: + print(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}") + 
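+        # Otherwise, treat model_id as either a direct path to a config file or a hub repo id
+        # from which TRAIN_CONFIG_NAME can be fetched.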
elif Path(model_id).is_file(): + config_file = model_id + else: + try: + config_file = hf_hub_download( + repo_id=model_id, + filename=TRAIN_CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + except HfHubHTTPError as e: + raise FileNotFoundError( + f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}" + ) from e + + cli_args = kwargs.pop("cli_args", []) + cfg = draccus.parse(cls, config_file, args=cli_args) + + return cfg diff --git a/lerobot/configs/types.py b/lerobot/configs/types.py new file mode 100644 index 0000000000000000000000000000000000000000..6b3d92e80d52fc52f8103b44d7d04533f481a408 --- /dev/null +++ b/lerobot/configs/types.py @@ -0,0 +1,41 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Note: We subclass str so that serialization is straightforward +# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json +from dataclasses import dataclass +from enum import Enum +from typing import Any, Protocol + + +class FeatureType(str, Enum): + STATE = "STATE" + VISUAL = "VISUAL" + ENV = "ENV" + ACTION = "ACTION" + + +class NormalizationMode(str, Enum): + MIN_MAX = "MIN_MAX" + MEAN_STD = "MEAN_STD" + IDENTITY = "IDENTITY" + + +class DictLike(Protocol): + def __getitem__(self, key: Any) -> Any: ... + + +@dataclass +class PolicyFeature: + type: FeatureType + shape: tuple diff --git a/lerobot/scripts/configure_motor.py b/lerobot/scripts/configure_motor.py new file mode 100644 index 0000000000000000000000000000000000000000..b0dc8a97d1234f3cb3a01ab77df1facf2162f41d --- /dev/null +++ b/lerobot/scripts/configure_motor.py @@ -0,0 +1,176 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script configure a single motor at a time to a given ID and baudrate. 
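+It scans the supported baudrates to find the single connected motor, then writes the desired ID and baudrate to it.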
+ +Example of usage: +```bash +python lerobot/scripts/configure_motor.py \ + --port /dev/tty.usbmodem585A0080521 \ + --brand feetech \ + --model sts3215 \ + --baudrate 1000000 \ + --ID 1 +``` +""" + +import argparse +import time + + +def get_motor_bus_cls(brand: str) -> tuple: + if brand == "feetech": + from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig + from lerobot.common.robot_devices.motors.feetech import ( + MODEL_BAUDRATE_TABLE, + SCS_SERIES_BAUDRATE_TABLE, + FeetechMotorsBus, + ) + + return FeetechMotorsBusConfig, FeetechMotorsBus, MODEL_BAUDRATE_TABLE, SCS_SERIES_BAUDRATE_TABLE + + elif brand == "dynamixel": + from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig + from lerobot.common.robot_devices.motors.dynamixel import ( + MODEL_BAUDRATE_TABLE, + X_SERIES_BAUDRATE_TABLE, + DynamixelMotorsBus, + ) + + return DynamixelMotorsBusConfig, DynamixelMotorsBus, MODEL_BAUDRATE_TABLE, X_SERIES_BAUDRATE_TABLE + + else: + raise ValueError( + f"Currently we do not support this motor brand: {brand}. We currently support feetech and dynamixel motors." + ) + + +def configure_motor(port, brand, model, motor_idx_des, baudrate_des): + motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = get_motor_bus_cls( + brand + ) + + # Check if the provided model exists in the model_baud_rate_table + if model not in model_baudrate_table: + raise ValueError( + f"Invalid model '{model}' for brand '{brand}'. Supported models: {list(model_baudrate_table.keys())}" + ) + + # Setup motor names, indices, and models + motor_name = "motor" + motor_index_arbitrary = motor_idx_des # Use the motor ID passed via argument + motor_model = model # Use the motor model passed via argument + + config = motor_bus_config_cls(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)}) + + # Initialize the MotorBus with the correct port and motor configurations + motor_bus = motor_bus_cls(config=config) + + # Try to connect to the motor bus and handle any connection-specific errors + try: + motor_bus.connect() + print(f"Connected on port {motor_bus.port}") + except OSError as e: + print(f"Error occurred when connecting to the motor bus: {e}") + return + + # Motor bus is connected, proceed with the rest of the operations + try: + print("Scanning all baudrates and motor indices") + all_baudrates = set(series_baudrate_table.values()) + motor_index = -1 # Set the motor index to an out-of-range value. + + for baudrate in all_baudrates: + motor_bus.set_bus_baudrate(baudrate) + present_ids = motor_bus.find_motor_indices(list(range(1, 10))) + if len(present_ids) > 1: + raise ValueError( + "Error: More than one motor ID detected. This script is designed to only handle one motor at a time. Please disconnect all but one motor." + ) + + if len(present_ids) == 1: + if motor_index != -1: + raise ValueError( + "Error: More than one motor ID detected. This script is designed to only handle one motor at a time. Please disconnect all but one motor." + ) + motor_index = present_ids[0] + break + + if motor_index == -1: + raise ValueError("No motors detected. 
Please ensure you have one motor connected.") + + print(f"Motor index found at: {motor_index}") + + if brand == "feetech": + # Allows ID and BAUDRATE to be written in memory + motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0) + + if baudrate != baudrate_des: + print(f"Setting its baudrate to {baudrate_des}") + baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des) + + # The write can fail, so we allow retries + motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx) + time.sleep(0.5) + motor_bus.set_bus_baudrate(baudrate_des) + present_baudrate_idx = motor_bus.read_with_motor_ids( + motor_bus.motor_models, motor_index, "Baud_Rate", num_retry=2 + ) + + if present_baudrate_idx != baudrate_idx: + raise OSError("Failed to write baudrate.") + + print(f"Setting its index to desired index {motor_idx_des}") + if brand == "feetech": + motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0) + motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des) + + present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2) + if present_idx != motor_idx_des: + raise OSError("Failed to write index.") + + if brand == "feetech": + # Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of + # the motors. Note: this configuration is not in the official STS3215 Memory Table + motor_bus.write("Lock", 0) + motor_bus.write("Maximum_Acceleration", 254) + + motor_bus.write("Goal_Position", 2048) + time.sleep(4) + print("Present Position", motor_bus.read("Present_Position")) + + motor_bus.write("Offset", 0) + time.sleep(4) + print("Offset", motor_bus.read("Offset")) + + except Exception as e: + print(f"Error occurred during motor configuration: {e}") + + finally: + motor_bus.disconnect() + print("Disconnected from motor bus.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)") + parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)") + parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)") + parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)") + parser.add_argument( + "--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)" + ) + args = parser.parse_args() + + configure_motor(args.port, args.brand, args.model, args.ID, args.baudrate) diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py new file mode 100644 index 0000000000000000000000000000000000000000..3daea98d3b557a7dfd202c5d7f73540ec0a37144 --- /dev/null +++ b/lerobot/scripts/control_robot.py @@ -0,0 +1,437 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
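Before moving on to `control_robot.py`, a hypothetical Python-level call that mirrors the CLI example in the `configure_motor.py` docstring above; the serial port value is illustrative and depends on your system:

```python
from lerobot.scripts.configure_motor import configure_motor

# Same parameters as the CLI example in the script's docstring; adjust the port to
# whatever your motor controller enumerates as on your machine.
configure_motor(
    port="/dev/tty.usbmodem585A0080521",
    brand="feetech",
    model="sts3215",
    motor_idx_des=1,
    baudrate_des=1_000_000,
)
```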
+""" +Utilities to control a robot. + +Useful to record a dataset, replay a recorded episode, run the policy on your robot +and record an evaluation dataset, and to recalibrate your robot if needed. + +Examples of usage: + +- Recalibrate your robot: +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=so100 \ + --control.type=calibrate +``` + +- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C: +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=so100 \ + --robot.cameras='{}' \ + --control.type=teleoperate + +# Add the cameras from the robot definition to visualize them: +python lerobot/scripts/control_robot.py \ + --robot.type=so100 \ + --control.type=teleoperate +``` + +- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency: +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=so100 \ + --control.type=teleoperate \ + --control.fps=30 +``` + +- Record one episode in order to test replay: +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=so100 \ + --control.type=record \ + --control.fps=30 \ + --control.single_task="Grasp a lego block and put it in the bin." \ + --control.repo_id=$USER/koch_test \ + --control.num_episodes=1 \ + --control.push_to_hub=True +``` + +- Visualize dataset: +```bash +python lerobot/scripts/visualize_dataset.py \ + --repo-id $USER/koch_test \ + --episode-index 0 +``` + +- Replay this test episode: +```bash +python lerobot/scripts/control_robot.py replay \ + --robot.type=so100 \ + --control.type=replay \ + --control.fps=30 \ + --control.repo_id=$USER/koch_test \ + --control.episode=0 +``` + +- Record a full dataset in order to train a policy, with 2 seconds of warmup, +30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes: +```bash +python lerobot/scripts/control_robot.py record \ + --robot.type=so100 \ + --control.type=record \ + --control.fps 30 \ + --control.repo_id=$USER/koch_pick_place_lego \ + --control.num_episodes=50 \ + --control.warmup_time_s=2 \ + --control.episode_time_s=30 \ + --control.reset_time_s=10 +``` + +- For remote controlled robots like LeKiwi, run this script on the robot edge device (e.g. RaspBerryPi): +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=lekiwi \ + --control.type=remote_robot +``` + +**NOTE**: You can use your keyboard to control data recording flow. +- Tap right arrow key '->' to early exit while recording an episode and go to resseting the environment. +- Tap right arrow key '->' to early exit while resetting the environment and got to recording the next episode. +- Tap left arrow key '<-' to early exit and re-record the current episode. +- Tap escape key 'esc' to stop the data recording. +This might require a sudo permission to allow your terminal to monitor keyboard events. + +**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--control.resume=true`. 
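As an aid to reading the control loop below, here is an illustrative sketch (not part of this file) of the event flags that `init_keyboard_listener()` exposes and that the recording code polls; each flag corresponds to one of the keys described in the NOTE above:

```python
# Illustrative only: the dict returned alongside the listener by init_keyboard_listener().
events = {
    "exit_early": False,        # right arrow: leave the current recording/reset phase early
    "rerecord_episode": False,  # left arrow: discard the episode buffer and re-record the episode
    "stop_recording": False,    # escape: stop after the current episode
}
```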
+ +- Train on this dataset with the ACT policy: +```bash +python lerobot/scripts/train.py \ + --dataset.repo_id=${HF_USER}/koch_pick_place_lego \ + --policy.type=act \ + --output_dir=outputs/train/act_koch_pick_place_lego \ + --job_name=act_koch_pick_place_lego \ + --device=cuda \ + --wandb.enable=true +``` + +- Run the pretrained policy on the robot: +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=so100 \ + --control.type=record \ + --control.fps=30 \ + --control.single_task="Grasp a lego block and put it in the bin." \ + --control.repo_id=$USER/eval_act_koch_pick_place_lego \ + --control.num_episodes=10 \ + --control.warmup_time_s=2 \ + --control.episode_time_s=30 \ + --control.reset_time_s=10 \ + --control.push_to_hub=true \ + --control.policy.path=outputs/train/act_koch_pick_place_lego/checkpoints/080000/pretrained_model +``` +""" + +import logging +import os +import time +from dataclasses import asdict +from pprint import pformat + +import rerun as rr + +# from safetensors.torch import load_file, save_file +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.factory import make_policy +from lerobot.common.robot_devices.control_configs import ( + CalibrateControlConfig, + ControlConfig, + ControlPipelineConfig, + RecordControlConfig, + RemoteRobotConfig, + ReplayControlConfig, + TeleoperateControlConfig, +) +from lerobot.common.robot_devices.control_utils import ( + control_loop, + init_keyboard_listener, + is_headless, + log_control_info, + record_episode, + reset_environment, + sanity_check_dataset_name, + sanity_check_dataset_robot_compatibility, + stop_recording, + warmup_record, +) +from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config +from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect +from lerobot.common.utils.utils import has_method, init_logging, log_say +from lerobot.configs import parser + +######################################################################################## +# Control modes +######################################################################################## + + +@safe_disconnect +def calibrate(robot: Robot, cfg: CalibrateControlConfig): + # TODO(aliberts): move this code in robots' classes + if robot.robot_type.startswith("stretch"): + if not robot.is_connected: + robot.connect() + if not robot.is_homed(): + robot.home() + return + + arms = robot.available_arms if cfg.arms is None else cfg.arms + unknown_arms = [arm_id for arm_id in arms if arm_id not in robot.available_arms] + available_arms_str = " ".join(robot.available_arms) + unknown_arms_str = " ".join(unknown_arms) + + if arms is None or len(arms) == 0: + raise ValueError( + "No arm provided. Use `--arms` as argument with one or more available arms.\n" + f"For instance, to recalibrate all arms add: `--arms {available_arms_str}`" + ) + + if len(unknown_arms) > 0: + raise ValueError( + f"Unknown arms provided ('{unknown_arms_str}'). Available arms are `{available_arms_str}`." 
+ ) + + for arm_id in arms: + arm_calib_path = robot.calibration_dir / f"{arm_id}.json" + if arm_calib_path.exists(): + print(f"Removing '{arm_calib_path}'") + arm_calib_path.unlink() + else: + print(f"Calibration file not found '{arm_calib_path}'") + + if robot.is_connected: + robot.disconnect() + + if robot.robot_type.startswith("lekiwi") and "main_follower" in arms: + print("Calibrating only the lekiwi follower arm 'main_follower'...") + robot.calibrate_follower() + return + + if robot.robot_type.startswith("lekiwi") and "main_leader" in arms: + print("Calibrating only the lekiwi leader arm 'main_leader'...") + robot.calibrate_leader() + return + + # Calling `connect` automatically runs calibration + # when the calibration file is missing + robot.connect() + robot.disconnect() + print("Calibration is done! You can now teleoperate and record datasets!") + + +@safe_disconnect +def teleoperate(robot: Robot, cfg: TeleoperateControlConfig): + control_loop( + robot, + control_time_s=cfg.teleop_time_s, + fps=cfg.fps, + teleoperate=True, + display_data=cfg.display_data, + ) + + +@safe_disconnect +def record( + robot: Robot, + cfg: RecordControlConfig, +) -> LeRobotDataset: + # TODO(rcadene): Add option to record logs + if cfg.resume: + dataset = LeRobotDataset( + cfg.repo_id, + root=cfg.root, + ) + if len(robot.cameras) > 0: + dataset.start_image_writer( + num_processes=cfg.num_image_writer_processes, + num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), + ) + sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video) + else: + # Create empty dataset or load existing saved episodes + sanity_check_dataset_name(cfg.repo_id, cfg.policy) + dataset = LeRobotDataset.create( + cfg.repo_id, + cfg.fps, + root=cfg.root, + robot=robot, + use_videos=cfg.video, + image_writer_processes=cfg.num_image_writer_processes, + image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), + ) + + # Load pretrained policy + policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + + if not robot.is_connected: + robot.connect() + + listener, events = init_keyboard_listener() + + # Execute a few seconds without recording to: + # 1. teleoperate the robot to move it in starting position if no policy provided, + # 2. give times to the robot devices to connect and start synchronizing, + # 3. place the cameras windows on screen + enable_teleoperation = policy is None + log_say("Warmup record", cfg.play_sounds) + warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps) + + if has_method(robot, "teleop_safety_stop"): + robot.teleop_safety_stop() + + recorded_episodes = 0 + while True: + if recorded_episodes >= cfg.num_episodes: + break + + log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) + record_episode( + robot=robot, + dataset=dataset, + events=events, + episode_time_s=cfg.episode_time_s, + display_data=cfg.display_data, + policy=policy, + fps=cfg.fps, + single_task=cfg.single_task, + ) + + # Execute a few seconds without recording to give time to manually reset the environment + # Current code logic doesn't allow to teleoperate during this time. 
+ # TODO(rcadene): add an option to enable teleoperation during reset + # Skip reset for the last episode to be recorded + if not events["stop_recording"] and ( + (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment", cfg.play_sounds) + reset_environment(robot, events, cfg.reset_time_s, cfg.fps) + + if events["rerecord_episode"]: + log_say("Re-record episode", cfg.play_sounds) + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + recorded_episodes += 1 + + if events["stop_recording"]: + break + + log_say("Stop recording", cfg.play_sounds, blocking=True) + stop_recording(robot, listener, cfg.display_data) + + if cfg.push_to_hub: + dataset.push_to_hub(tags=cfg.tags, private=cfg.private) + + log_say("Exiting", cfg.play_sounds) + return dataset + + +@safe_disconnect +def replay( + robot: Robot, + cfg: ReplayControlConfig, +): + # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset + # TODO(rcadene): Add option to record logs + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode]) + actions = dataset.hf_dataset.select_columns("action") + + if not robot.is_connected: + robot.connect() + + log_say("Replaying episode", cfg.play_sounds, blocking=True) + for idx in range(dataset.num_frames): + start_episode_t = time.perf_counter() + + action = actions[idx]["action"] + robot.send_action(action) + + dt_s = time.perf_counter() - start_episode_t + busy_wait(1 / cfg.fps - dt_s) + + dt_s = time.perf_counter() - start_episode_t + log_control_info(robot, dt_s, fps=cfg.fps) + + +def _init_rerun(control_config: ControlConfig, session_name: str = "lerobot_control_loop") -> None: + """Initializes the Rerun SDK for visualizing the control loop. + + Args: + control_config: Configuration determining data display and robot type. + session_name: Rerun session name. Defaults to "lerobot_control_loop". + + Raises: + ValueError: If viewer IP is missing for non-remote configurations with display enabled. + """ + if (control_config.display_data and not is_headless()) or ( + control_config.display_data and isinstance(control_config, RemoteRobotConfig) + ): + # Configure Rerun flush batch size default to 8KB if not set + batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000") + os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size + + # Initialize Rerun based on configuration + rr.init(session_name) + if isinstance(control_config, RemoteRobotConfig): + viewer_ip = control_config.viewer_ip + viewer_port = control_config.viewer_port + if not viewer_ip or not viewer_port: + raise ValueError( + "Viewer IP & Port are required for remote config. Set via config file/CLI or disable control_config.display_data." 
+ ) + logging.info(f"Connecting to viewer at {viewer_ip}:{viewer_port}") + rr.connect_tcp(f"{viewer_ip}:{viewer_port}") + else: + # Get memory limit for rerun viewer parameters + memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%") + rr.spawn(memory_limit=memory_limit) + + +@parser.wrap() +def control_robot(cfg: ControlPipelineConfig): + init_logging() + logging.info(pformat(asdict(cfg))) + + robot = make_robot_from_config(cfg.robot) + + # TODO(Steven): Blueprint for fixed window size + + if isinstance(cfg.control, CalibrateControlConfig): + calibrate(robot, cfg.control) + elif isinstance(cfg.control, TeleoperateControlConfig): + _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_teleop") + teleoperate(robot, cfg.control) + elif isinstance(cfg.control, RecordControlConfig): + _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_record") + record(robot, cfg.control) + elif isinstance(cfg.control, ReplayControlConfig): + replay(robot, cfg.control) + elif isinstance(cfg.control, RemoteRobotConfig): + from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi + + _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_remote") + run_lekiwi(cfg.robot) + + if robot.is_connected: + # Disconnect manually to avoid a "Core dump" during process + # termination due to camera threads not properly exiting. + robot.disconnect() + + +if __name__ == "__main__": + control_robot() diff --git a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py new file mode 100644 index 0000000000000000000000000000000000000000..5347822c8b664579f023205bfeda8ac69b79fe1b --- /dev/null +++ b/lerobot/scripts/control_sim_robot.py @@ -0,0 +1,561 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utilities to control a robot in simulation. + +Useful to record a dataset, replay a recorded episode and record an evaluation dataset. + +Examples of usage: + + +- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency. + You can modify this value depending on how fast your simulation can run: +```bash +python lerobot/scripts/control_robot.py teleoperate \ + --fps 30 \ + --robot-path lerobot/configs/robot/your_robot_config.yaml \ + --sim-config lerobot/configs/env/your_sim_config.yaml +``` + +- Record one episode in order to test replay: +```bash +python lerobot/scripts/control_sim_robot.py record \ + --robot-path lerobot/configs/robot/your_robot_config.yaml \ + --sim-config lerobot/configs/env/your_sim_config.yaml \ + --fps 30 \ + --repo-id $USER/robot_sim_test \ + --num-episodes 1 \ + --run-compute-stats 0 +``` + +Enable the --push-to-hub 1 to push the recorded dataset to the huggingface hub. 
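The `control_robot()` entry point above is a typical consumer of the `parser.wrap()` decorator added earlier in this diff. As a minimal standalone sketch of that pattern (the `DemoConfig` dataclass and its fields are hypothetical, and it is assumed that a plain dataclass is parsable by draccus):

```python
from dataclasses import dataclass

from lerobot.configs import parser


@dataclass
class DemoConfig:
    steps: int = 10
    job_name: str = "demo"


@parser.wrap()
def main(cfg: DemoConfig):
    # cfg is built from the CLI, e.g. `python demo.py --steps=20 --job_name=test`
    print(cfg.steps, cfg.job_name)


if __name__ == "__main__":
    main()
```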
+ +- Visualize dataset: +```bash +python lerobot/scripts/visualize_dataset.py \ + --repo-id $USER/robot_sim_test \ + --episode-index 0 +``` + +- Replay a sequence of test episodes: +```bash +python lerobot/scripts/control_sim_robot.py replay \ + --robot-path lerobot/configs/robot/your_robot_config.yaml \ + --sim-config lerobot/configs/env/your_sim_config.yaml \ + --fps 30 \ + --repo-id $USER/robot_sim_test \ + --episode 0 +``` +Note: The seed is saved, therefore, during replay we can load the same environment state as the one during collection. + +- Record a full dataset in order to train a policy, +30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes: +```bash +python lerobot/scripts/control_sim_robot.py record \ + --robot-path lerobot/configs/robot/your_robot_config.yaml \ + --sim-config lerobot/configs/env/your_sim_config.yaml \ + --fps 30 \ + --repo-id $USER/robot_sim_test \ + --num-episodes 50 \ + --episode-time-s 30 \ +``` + +**NOTE**: You can use your keyboard to control data recording flow. +- Tap right arrow key '->' to early exit while recording an episode and go to resetting the environment. +- Tap right arrow key '->' to early exit while resetting the environment and got to recording the next episode. +- Tap left arrow key '<-' to early exit and re-record the current episode. +- Tap escape key 'esc' to stop the data recording. +This might require a sudo permission to allow your terminal to monitor keyboard events. + +**NOTE**: You can resume/continue data recording by running the same data recording command twice. +""" + +import argparse +import importlib +import logging +import time +from pathlib import Path + +import cv2 +import gymnasium as gym +import numpy as np +import torch + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.robot_devices.control_utils import ( + init_keyboard_listener, + init_policy, + is_headless, + log_control_info, + predict_action, + sanity_check_dataset_name, + sanity_check_dataset_robot_compatibility, + stop_recording, +) +from lerobot.common.robot_devices.robots.utils import Robot, make_robot +from lerobot.common.robot_devices.utils import busy_wait +from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say + +raise NotImplementedError("This script is currently deactivated") + +DEFAULT_FEATURES = { + "next.reward": { + "dtype": "float32", + "shape": (1,), + "names": None, + }, + "next.success": { + "dtype": "bool", + "shape": (1,), + "names": None, + }, + "seed": { + "dtype": "int64", + "shape": (1,), + "names": None, + }, + "timestamp": { + "dtype": "float32", + "shape": (1,), + "names": None, + }, +} + + +######################################################################################## +# Utilities +######################################################################################## +def none_or_int(value): + if value == "None": + return None + return int(value) + + +def init_sim_calibration(robot, cfg): + # Constants necessary for transforming the joint pos of the real robot to the sim + # depending on the robot description used in that sim. 
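+    # start_pos comes from the leader arm calibration, while axis_directions and offsets
+    # are read from the sim config passed in as `cfg`.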
+ start_pos = np.array(robot.leader_arms.main.calibration["start_pos"]) + axis_directions = np.array(cfg.get("axis_directions", [1])) + offsets = np.array(cfg.get("offsets", [0])) * np.pi + + return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets} + + +def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets): + """Counts - starting position -> radians -> align axes -> offset""" + return axis_directions * (real_positions - start_pos) * 2.0 * np.pi / 4096 + offsets + + +######################################################################################## +# Control modes +######################################################################################## + + +def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None): + env = env() + env.reset() + start_teleop_t = time.perf_counter() + while True: + leader_pos = robot.leader_arms.main.read("Present_Position") + action = process_action_fn(leader_pos) + env.step(np.expand_dims(action, 0)) + if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s: + print("Teleoperation processes finished.") + break + + +def record( + env, + robot: Robot, + process_action_from_leader, + root: Path, + repo_id: str, + task: str, + fps: int | None = None, + tags: list[str] | None = None, + pretrained_policy_name_or_path: str = None, + policy_overrides: bool | None = None, + episode_time_s: int = 30, + num_episodes: int = 50, + video: bool = True, + push_to_hub: bool = True, + num_image_writer_processes: int = 0, + num_image_writer_threads_per_camera: int = 4, + display_cameras: bool = False, + play_sounds: bool = True, + resume: bool = False, + local_files_only: bool = False, + run_compute_stats: bool = True, +) -> LeRobotDataset: + # Load pretrained policy + policy = None + if pretrained_policy_name_or_path is not None: + policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) + + if fps is None: + fps = policy_fps + logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).") + + if policy is None and process_action_from_leader is None: + raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.") + + # initialize listener before sim env + listener, events = init_keyboard_listener() + + # create sim env + env = env() + + # Create empty dataset or load existing saved episodes + num_cameras = sum([1 if "image" in key else 0 for key in env.observation_space]) + + # get image keys + image_keys = [key for key in env.observation_space if "image" in key] + state_keys_dict = env_cfg.state_keys + + if resume: + dataset = LeRobotDataset( + repo_id, + root=root, + local_files_only=local_files_only, + ) + dataset.start_image_writer( + num_processes=num_image_writer_processes, + num_threads=num_image_writer_threads_per_camera * num_cameras, + ) + sanity_check_dataset_robot_compatibility(dataset, robot, fps, video) + else: + features = DEFAULT_FEATURES + # add image keys to features + for key in image_keys: + shape = env.observation_space[key].shape + if not key.startswith("observation.image."): + key = "observation.image." 
+ key + features[key] = {"dtype": "video", "names": ["channels", "height", "width"], "shape": shape} + + for key, obs_key in state_keys_dict.items(): + features[key] = { + "dtype": "float32", + "names": None, + "shape": env.observation_space[obs_key].shape, + } + + features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None} + + # Create empty dataset or load existing saved episodes + sanity_check_dataset_name(repo_id, policy) + dataset = LeRobotDataset.create( + repo_id, + fps, + root=root, + features=features, + use_videos=video, + image_writer_processes=num_image_writer_processes, + image_writer_threads=num_image_writer_threads_per_camera * num_cameras, + ) + + recorded_episodes = 0 + while True: + log_say(f"Recording episode {dataset.num_episodes}", play_sounds) + + if events is None: + events = {"exit_early": False} + + if episode_time_s is None: + episode_time_s = float("inf") + + timestamp = 0 + start_episode_t = time.perf_counter() + + seed = np.random.randint(0, 1e5) + observation, info = env.reset(seed=seed) + + while timestamp < episode_time_s: + start_loop_t = time.perf_counter() + + if policy is not None: + action = predict_action(observation, policy, device, use_amp) + else: + leader_pos = robot.leader_arms.main.read("Present_Position") + action = process_action_from_leader(leader_pos) + + observation, reward, terminated, _, info = env.step(action) + + success = info.get("is_success", False) + env_timestamp = info.get("timestamp", dataset.episode_buffer["size"] / fps) + + frame = { + "action": torch.from_numpy(action), + "next.reward": reward, + "next.success": success, + "seed": seed, + "timestamp": env_timestamp, + } + + for key in image_keys: + if not key.startswith("observation.image"): + frame["observation.image." 
+ key] = observation[key] + else: + frame[key] = observation[key] + + for key, obs_key in state_keys_dict.items(): + frame[key] = torch.from_numpy(observation[obs_key]) + + dataset.add_frame(frame) + + if display_cameras and not is_headless(): + for key in image_keys: + cv2.imshow(key, cv2.cvtColor(observation[key], cv2.COLOR_RGB2BGR)) + cv2.waitKey(1) + + if fps is not None: + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / fps - dt_s) + + dt_s = time.perf_counter() - start_loop_t + log_control_info(robot, dt_s, fps=fps) + + timestamp = time.perf_counter() - start_episode_t + if events["exit_early"] or terminated: + events["exit_early"] = False + break + + if events["rerecord_episode"]: + log_say("Re-record episode", play_sounds) + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode(task=task) + recorded_episodes += 1 + + if events["stop_recording"] or recorded_episodes >= num_episodes: + break + else: + logging.info("Waiting for a few seconds before starting next episode recording...") + busy_wait(3) + + log_say("Stop recording", play_sounds, blocking=True) + stop_recording(robot, listener, display_cameras) + + if run_compute_stats: + logging.info("Computing dataset statistics") + dataset.consolidate(run_compute_stats) + + if push_to_hub: + dataset.push_to_hub(tags=tags) + + log_say("Exiting", play_sounds) + return dataset + + +def replay( + env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True +): + env = env() + + local_dir = Path(root) / repo_id + if not local_dir.exists(): + raise ValueError(local_dir) + + dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only) + items = dataset.hf_dataset.select_columns("action") + seeds = dataset.hf_dataset.select_columns("seed")["seed"] + + from_idx = dataset.episode_data_index["from"][episode].item() + to_idx = dataset.episode_data_index["to"][episode].item() + env.reset(seed=seeds[from_idx].item()) + logging.info("Replaying episode") + log_say("Replaying episode", play_sounds=True) + for idx in range(from_idx, to_idx): + start_episode_t = time.perf_counter() + action = items[idx]["action"] + env.step(action.unsqueeze(0).numpy()) + dt_s = time.perf_counter() - start_episode_t + busy_wait(1 / fps - dt_s) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="mode", required=True) + + # Set common options for all the subparsers + base_parser = argparse.ArgumentParser(add_help=False) + base_parser.add_argument( + "--robot-path", + type=str, + default="lerobot/configs/robot/koch.yaml", + help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.", + ) + + base_parser.add_argument( + "--sim-config", + help="Path to a yaml config you want to use for initializing a sim environment based on gym ", + ) + + parser_record = subparsers.add_parser("teleoperate", parents=[base_parser]) + + parser_record = subparsers.add_parser("record", parents=[base_parser]) + parser_record.add_argument( + "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" + ) + parser_record.add_argument( + "--root", + type=Path, + default=None, + help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').", + ) + parser_record.add_argument( + "--repo-id", + type=str, + default="lerobot/test", + help="Dataset identifier. 
By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", + ) + parser_record.add_argument( + "--episode-time-s", + type=int, + default=60, + help="Number of seconds for data recording for each episode.", + ) + parser_record.add_argument( + "--task", + type=str, + required=True, + help="A description of the task preformed during recording that can be used as a language instruction.", + ) + parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.") + parser_record.add_argument( + "--run-compute-stats", + type=int, + default=1, + help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.", + ) + parser_record.add_argument( + "--push-to-hub", + type=int, + default=1, + help="Upload dataset to Hugging Face hub.", + ) + parser_record.add_argument( + "--tags", + type=str, + nargs="*", + help="Add tags to your dataset on the hub.", + ) + parser_record.add_argument( + "--num-image-writer-processes", + type=int, + default=0, + help=( + "Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; " + "set to ≥1 to use subprocesses, each using threads to write images. The best number of processes " + "and threads depends on your system. We recommend 4 threads per camera with 0 processes. " + "If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses." + ), + ) + parser_record.add_argument( + "--num-image-writer-threads-per-camera", + type=int, + default=4, + help=( + "Number of threads writing the frames as png images on disk, per camera. " + "Too much threads might cause unstable teleoperation fps due to main thread being blocked. " + "Not enough threads might cause low camera fps." + ), + ) + parser_record.add_argument( + "--display-cameras", + type=int, + default=0, + help="Visualize image observations with opencv.", + ) + parser_record.add_argument( + "--resume", + type=int, + default=0, + help="Resume recording on an existing dataset.", + ) + parser_replay = subparsers.add_parser("replay", parents=[base_parser]) + parser_replay.add_argument( + "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" + ) + parser_replay.add_argument( + "--root", + type=Path, + default=None, + help="Root directory where the dataset will be stored locally (e.g. 'data/hf_username/dataset_name'). By default, stored in cache folder.", + ) + parser_replay.add_argument( + "--repo-id", + type=str, + default="lerobot/test", + help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. 
`lerobot/test`).", + ) + parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.") + + args = parser.parse_args() + + init_logging() + + control_mode = args.mode + robot_path = args.robot_path + env_config_path = args.sim_config + kwargs = vars(args) + del kwargs["mode"] + del kwargs["robot_path"] + del kwargs["sim_config"] + + # make gym env + env_cfg = init_hydra_config(env_config_path) + importlib.import_module(f"gym_{env_cfg.env.type}") + + def env_constructor(): + return gym.make(env_cfg.env.handle, disable_env_checker=True, **env_cfg.env.gym) + + robot = None + process_leader_actions_fn = None + + if control_mode in ["teleoperate", "record"]: + # make robot + robot_overrides = ["~cameras", "~follower_arms"] + # TODO(rcadene): remove + robot_cfg = init_hydra_config(robot_path, robot_overrides) + robot = make_robot(robot_cfg) + robot.connect() + + calib_kwgs = init_sim_calibration(robot, env_cfg.calibration) + + def process_leader_actions_fn(action): + return real_positions_to_sim(action, **calib_kwgs) + + robot.leader_arms.main.calibration = None + + if control_mode == "teleoperate": + teleoperate(env_constructor, robot, process_leader_actions_fn) + + elif control_mode == "record": + record(env_constructor, robot, process_leader_actions_fn, **kwargs) + + elif control_mode == "replay": + replay(env_constructor, **kwargs) + + else: + raise ValueError( + f"Invalid control mode: '{control_mode}', only valid modes are teleoperate, record and replay." + ) + + if robot and robot.is_connected: + # Disconnect manually to avoid a "Core dump" during process + # termination due to camera threads not properly exiting. + robot.disconnect() diff --git a/lerobot/scripts/display_sys_info.py b/lerobot/scripts/display_sys_info.py new file mode 100644 index 0000000000000000000000000000000000000000..4d3cc291f30c07287b2f0b836d34506b991baf10 --- /dev/null +++ b/lerobot/scripts/display_sys_info.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Use this script to get a quick summary of your system config. +It should be able to run without any of LeRobot's dependencies or LeRobot itself installed. 
+""" + +import platform + +HAS_HF_HUB = True +HAS_HF_DATASETS = True +HAS_NP = True +HAS_TORCH = True +HAS_LEROBOT = True + +try: + import huggingface_hub +except ImportError: + HAS_HF_HUB = False + +try: + import datasets +except ImportError: + HAS_HF_DATASETS = False + +try: + import numpy as np +except ImportError: + HAS_NP = False + +try: + import torch +except ImportError: + HAS_TORCH = False + +try: + import lerobot +except ImportError: + HAS_LEROBOT = False + + +lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A" +hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A" +hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A" +np_version = np.__version__ if HAS_NP else "N/A" + +torch_version = torch.__version__ if HAS_TORCH else "N/A" +torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A" +cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A" + + +# TODO(aliberts): refactor into an actual command `lerobot env` +def display_sys_info() -> dict: + """Run this to get basic system info to help for tracking issues & bugs.""" + info = { + "`lerobot` version": lerobot_version, + "Platform": platform.platform(), + "Python version": platform.python_version(), + "Huggingface_hub version": hf_hub_version, + "Dataset version": hf_datasets_version, + "Numpy version": np_version, + "PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})", + "Cuda version": cuda_version, + "Using GPU in script?": "", + # "Using distributed or parallel set-up in script?": "", + } + print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n") + print(format_dict(info)) + return info + + +def format_dict(d: dict) -> str: + return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" + + +if __name__ == "__main__": + display_sys_info() diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..9790f8b317157969a177519f30b221bed63674ae --- /dev/null +++ b/lerobot/scripts/eval.py @@ -0,0 +1,506 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evaluate a policy on an environment by running rollouts and computing metrics. + +Usage examples: + +You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/diffusion_pusht) +for 10 episodes. + +``` +python lerobot/scripts/eval.py \ + --policy.path=lerobot/diffusion_pusht \ + --env.type=pusht \ + --eval.batch_size=10 \ + --eval.n_episodes=10 \ + --use_amp=false \ + --device=cuda +``` + +OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes. 
+``` +python lerobot/scripts/eval.py \ + --policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \ + --env.type=pusht \ + --eval.batch_size=10 \ + --eval.n_episodes=10 \ + --use_amp=false \ + --device=cuda +``` + +Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files. + +You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py +""" + +import json +import logging +import threading +import time +from contextlib import nullcontext +from copy import deepcopy +from dataclasses import asdict +from pathlib import Path +from pprint import pformat +from typing import Callable + +import einops +import gymnasium as gym +import numpy as np +import torch +from termcolor import colored +from torch import Tensor, nn +from tqdm import trange + +from lerobot.common.envs.factory import make_env +from lerobot.common.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.utils import get_device_from_parameters +from lerobot.common.utils.io_utils import write_video +from lerobot.common.utils.random_utils import set_seed +from lerobot.common.utils.utils import ( + get_safe_torch_device, + init_logging, + inside_slurm, +) +from lerobot.configs import parser +from lerobot.configs.eval import EvalPipelineConfig + + +def rollout( + env: gym.vector.VectorEnv, + policy: PreTrainedPolicy, + seeds: list[int] | None = None, + return_observations: bool = False, + render_callback: Callable[[gym.vector.VectorEnv], None] | None = None, +) -> dict: + """Run a batched policy rollout once through a batch of environments. + + Note that all environments in the batch are run until the last environment is done. This means some + data will probably need to be discarded (for environments that aren't the first one to be done). + + The return dictionary contains: + (optional) "observation": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation + keys. NOTE the that this has an extra sequence element relative to the other keys in the + dictionary. This is because an extra observation is included for after the environment is + terminated or truncated. + "action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not + including the last observations). + "reward": A (batch, sequence) tensor of rewards received for applying the actions. + "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon + environment termination/truncation). + "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element, + the first True is followed by True's all the way till the end. This can be used for masking + extraneous elements from the sequences above. + + Args: + env: The batch of environments. + policy: The policy. Must be a PyTorch nn module. + seeds: The environments are seeded once at the start of the rollout. If provided, this argument + specifies the seeds for each of the environments. + return_observations: Whether to include all observations in the returned rollout data. Observations + are returned optionally because they typically take more memory to cache. Defaults to False. + render_callback: Optional rendering callback to be used after the environments are reset, and after + every step. 
+ Returns: + The dictionary described above. + """ + assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module." + device = get_device_from_parameters(policy) + + # Reset the policy and environments. + policy.reset() + observation, info = env.reset(seed=seeds) + if render_callback is not None: + render_callback(env) + + all_observations = [] + all_actions = [] + all_rewards = [] + all_successes = [] + all_dones = [] + + step = 0 + # Keep track of which environments are done. + done = np.array([False] * env.num_envs) + max_steps = env.call("_max_episode_steps")[0] + progbar = trange( + max_steps, + desc=f"Running rollout with at most {max_steps} steps", + disable=inside_slurm(), # we dont want progress bar when we use slurm, since it clutters the logs + leave=False, + ) + check_env_attributes_and_types(env) + while not np.all(done): + # Numpy array to tensor and changing dictionary keys to LeRobot policy format. + observation = preprocess_observation(observation) + if return_observations: + all_observations.append(deepcopy(observation)) + + observation = { + key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation + } + + # Infer "task" from attributes of environments. + # TODO: works with SyncVectorEnv but not AsyncVectorEnv + observation = add_envs_task(env, observation) + + with torch.inference_mode(): + action = policy.select_action(observation) + + # Convert to CPU / numpy. + action = action.to("cpu").numpy() + assert action.ndim == 2, "Action dimensions should be (batch, action_dim)" + + # Apply the next action. + observation, reward, terminated, truncated, info = env.step(action) + if render_callback is not None: + render_callback(env) + + # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't + # available of none of the envs finished. + if "final_info" in info: + successes = [info["is_success"] if info is not None else False for info in info["final_info"]] + else: + successes = [False] * env.num_envs + + # Keep track of which environments are done so far. + done = terminated | truncated | done + + all_actions.append(torch.from_numpy(action)) + all_rewards.append(torch.from_numpy(reward)) + all_dones.append(torch.from_numpy(done)) + all_successes.append(torch.tensor(successes)) + + step += 1 + running_success_rate = ( + einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean() + ) + progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}) + progbar.update() + + # Track the final observation. + if return_observations: + observation = preprocess_observation(observation) + all_observations.append(deepcopy(observation)) + + # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors. 
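+    # For example, with 2 environments and 3 steps, stacking three (2, action_dim) action tensors along
+    # dim=1 yields a single (2, 3, action_dim) tensor, while "reward", "success" and "done" become (2, 3).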
+ ret = { + "action": torch.stack(all_actions, dim=1), + "reward": torch.stack(all_rewards, dim=1), + "success": torch.stack(all_successes, dim=1), + "done": torch.stack(all_dones, dim=1), + } + if return_observations: + stacked_observations = {} + for key in all_observations[0]: + stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) + ret["observation"] = stacked_observations + + if hasattr(policy, "use_original_modules"): + policy.use_original_modules() + + return ret + + +def eval_policy( + env: gym.vector.VectorEnv, + policy: PreTrainedPolicy, + n_episodes: int, + max_episodes_rendered: int = 0, + videos_dir: Path | None = None, + return_episode_data: bool = False, + start_seed: int | None = None, +) -> dict: + """ + Args: + env: The batch of environments. + policy: The policy. + n_episodes: The number of episodes to evaluate. + max_episodes_rendered: Maximum number of episodes to render into videos. + videos_dir: Where to save rendered videos. + return_episode_data: Whether to return episode data for online training. Incorporates the data into + the "episodes" key of the returned dictionary. + start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the + seed is incremented by 1. If not provided, the environments are not manually seeded. + Returns: + Dictionary with metrics and data regarding the rollouts. + """ + if max_episodes_rendered > 0 and not videos_dir: + raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.") + + if not isinstance(policy, PreTrainedPolicy): + raise ValueError( + f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided." + ) + + start = time.time() + policy.eval() + + # Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly + # divisible by env.num_envs we end up discarding some data in the last batch. + n_batches = n_episodes // env.num_envs + int((n_episodes % env.num_envs) != 0) + + # Keep track of some metrics. + sum_rewards = [] + max_rewards = [] + all_successes = [] + all_seeds = [] + threads = [] # for video saving threads + n_episodes_rendered = 0 # for saving the correct number of videos + + # Callback for visualization. + def render_frame(env: gym.vector.VectorEnv): + # noqa: B023 + if n_episodes_rendered >= max_episodes_rendered: + return + n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs) + if isinstance(env, gym.vector.SyncVectorEnv): + ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023 + elif isinstance(env, gym.vector.AsyncVectorEnv): + # Here we must render all frames and discard any we don't need. + ep_frames.append(np.stack(env.call("render")[:n_to_render_now])) + + if max_episodes_rendered > 0: + video_paths: list[str] = [] + + if return_episode_data: + episode_data: dict | None = None + + # we dont want progress bar when we use slurm, since it clutters the logs + progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm()) + for batch_ix in progbar: + # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout + # step. 
+ if max_episodes_rendered > 0: + ep_frames: list[np.ndarray] = [] + + if start_seed is None: + seeds = None + else: + seeds = range( + start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs) + ) + rollout_data = rollout( + env, + policy, + seeds=list(seeds) if seeds else None, + return_observations=return_episode_data, + render_callback=render_frame if max_episodes_rendered > 0 else None, + ) + + # Figure out where in each rollout sequence the first done condition was encountered (results after + # this won't be included). + n_steps = rollout_data["done"].shape[1] + # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker. + done_indices = torch.argmax(rollout_data["done"].to(int), dim=1) + + # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done + # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step. + mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int() + # Extend metrics. + batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum") + sum_rewards.extend(batch_sum_rewards.tolist()) + batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max") + max_rewards.extend(batch_max_rewards.tolist()) + batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any") + all_successes.extend(batch_successes.tolist()) + if seeds: + all_seeds.extend(seeds) + else: + all_seeds.append(None) + + # FIXME: episode_data is either None or it doesn't exist + if return_episode_data: + this_episode_data = _compile_episode_data( + rollout_data, + done_indices, + start_episode_index=batch_ix * env.num_envs, + start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)), + fps=env.unwrapped.metadata["render_fps"], + ) + if episode_data is None: + episode_data = this_episode_data + else: + # Some sanity checks to make sure we are correctly compiling the data. + assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0] + assert episode_data["index"][-1] + 1 == this_episode_data["index"][0] + # Concatenate the episode data. + episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data} + + # Maybe render video for visualization. + if max_episodes_rendered > 0 and len(ep_frames) > 0: + batch_stacked_frames = np.stack(ep_frames, axis=1) # (b, t, *) + for stacked_frames, done_index in zip( + batch_stacked_frames, done_indices.flatten().tolist(), strict=False + ): + if n_episodes_rendered >= max_episodes_rendered: + break + + videos_dir.mkdir(parents=True, exist_ok=True) + video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4" + video_paths.append(str(video_path)) + thread = threading.Thread( + target=write_video, + args=( + str(video_path), + stacked_frames[: done_index + 1], # + 1 to capture the last observation + env.unwrapped.metadata["render_fps"], + ), + ) + thread.start() + threads.append(thread) + n_episodes_rendered += 1 + + progbar.set_postfix( + {"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"} + ) + + # Wait till all video rendering threads are done. + for thread in threads: + thread.join() + + # Compile eval info. 
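+    # "per_episode" holds one record per evaluated episode (reward sum/max, success flag, seed), while
+    # "aggregated" averages them and adds wall-clock timings; any extra batch elements beyond n_episodes
+    # are dropped by the [:n_episodes] slices below.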
+ info = { + "per_episode": [ + { + "episode_ix": i, + "sum_reward": sum_reward, + "max_reward": max_reward, + "success": success, + "seed": seed, + } + for i, (sum_reward, max_reward, success, seed) in enumerate( + zip( + sum_rewards[:n_episodes], + max_rewards[:n_episodes], + all_successes[:n_episodes], + all_seeds[:n_episodes], + strict=True, + ) + ) + ], + "aggregated": { + "avg_sum_reward": float(np.nanmean(sum_rewards[:n_episodes])), + "avg_max_reward": float(np.nanmean(max_rewards[:n_episodes])), + "pc_success": float(np.nanmean(all_successes[:n_episodes]) * 100), + "eval_s": time.time() - start, + "eval_ep_s": (time.time() - start) / n_episodes, + }, + } + + if return_episode_data: + info["episodes"] = episode_data + + if max_episodes_rendered > 0: + info["video_paths"] = video_paths + + return info + + +def _compile_episode_data( + rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float +) -> dict: + """Convenience function for `eval_policy(return_episode_data=True)` + + Compiles all the rollout data into a Hugging Face dataset. + + Similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`). + """ + ep_dicts = [] + total_frames = 0 + for ep_ix in range(rollout_data["action"].shape[0]): + # + 2 to include the first done frame and the last observation frame. + num_frames = done_indices[ep_ix].item() + 2 + total_frames += num_frames + + # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet. + ep_dict = { + "action": rollout_data["action"][ep_ix, : num_frames - 1], + "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)), + "frame_index": torch.arange(0, num_frames - 1, 1), + "timestamp": torch.arange(0, num_frames - 1, 1) / fps, + "next.done": rollout_data["done"][ep_ix, : num_frames - 1], + "next.success": rollout_data["success"][ep_ix, : num_frames - 1], + "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32), + } + + # For the last observation frame, all other keys will just be copy padded. 
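+        # e.g. an "action" tensor of shape (num_frames - 1, action_dim) gets its last row repeated once,
+        # so every key below ends up with num_frames entries, matching the observation keys added next.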
+ for k in ep_dict: + ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]]) + + for key in rollout_data["observation"]: + ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames] + + ep_dicts.append(ep_dict) + + data_dict = {} + for key in ep_dicts[0]: + data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + + data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1) + + return data_dict + + +@parser.wrap() +def eval_main(cfg: EvalPipelineConfig): + logging.info(pformat(asdict(cfg))) + + # Check device is available + device = get_safe_torch_device(cfg.policy.device, log=True) + + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + set_seed(cfg.seed) + + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + + logging.info("Making environment.") + env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) + + logging.info("Making policy.") + + policy = make_policy( + cfg=cfg.policy, + env_cfg=cfg.env, + ) + policy.eval() + + with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): + info = eval_policy( + env, + policy, + cfg.eval.n_episodes, + max_episodes_rendered=10, + videos_dir=Path(cfg.output_dir) / "videos", + start_seed=cfg.seed, + ) + print(info["aggregated"]) + + # Save info + with open(Path(cfg.output_dir) / "eval_info.json", "w") as f: + json.dump(info, f, indent=2) + + env.close() + + logging.info("End of eval") + + +if __name__ == "__main__": + init_logging() + eval_main() diff --git a/lerobot/scripts/find_motors_bus_port.py b/lerobot/scripts/find_motors_bus_port.py new file mode 100644 index 0000000000000000000000000000000000000000..68f2315d7c31477a52edb5557942013e4132d03a --- /dev/null +++ b/lerobot/scripts/find_motors_bus_port.py @@ -0,0 +1,55 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
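As a side note on `eval_main` above: the full `info` dictionary is written to `eval_info.json` inside the configured output directory. A minimal sketch of reading those metrics back afterwards, assuming an illustrative output directory of `outputs/eval/diffusion_pusht` (use whatever `--output_dir` was actually passed to `eval.py`):

```python
import json
from pathlib import Path

# Illustrative path: point this at the output_dir used when running eval.py.
eval_info_path = Path("outputs/eval/diffusion_pusht") / "eval_info.json"

with open(eval_info_path) as f:
    info = json.load(f)

# "aggregated" mirrors what eval.py prints at the end of the run.
print(f"Success rate: {info['aggregated']['pc_success']:.1f}%")
print(f"Average sum reward: {info['aggregated']['avg_sum_reward']:.3f}")

# Per-episode records carry the seed used for each rollout, handy for reproducing a failure.
failed = [ep for ep in info["per_episode"] if not ep["success"]]
print(f"{len(failed)} episodes failed; seeds: {[ep['seed'] for ep in failed]}")
```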
+import os +import time +from pathlib import Path + +from serial.tools import list_ports # Part of pyserial library + + +def find_available_ports(): + if os.name == "nt": # Windows + # List COM ports using pyserial + ports = [port.device for port in list_ports.comports()] + else: # Linux/macOS + # List /dev/tty* ports for Unix-based systems + ports = [str(path) for path in Path("/dev").glob("tty*")] + return ports + + +def find_port(): + print("Finding all available ports for the MotorsBus.") + ports_before = find_available_ports() + print("Ports before disconnecting:", ports_before) + + print("Remove the USB cable from your MotorsBus and press Enter when done.") + input() # Wait for user to disconnect the device + + time.sleep(0.5) # Allow some time for port to be released + ports_after = find_available_ports() + ports_diff = list(set(ports_before) - set(ports_after)) + + if len(ports_diff) == 1: + port = ports_diff[0] + print(f"The port of this MotorsBus is '{port}'") + print("Reconnect the USB cable.") + elif len(ports_diff) == 0: + raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).") + else: + raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).") + + +if __name__ == "__main__": + # Helper to find the USB port associated with your MotorsBus. + find_port() diff --git a/lerobot/scripts/push_pretrained.py b/lerobot/scripts/push_pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..e3c683f96f17794e97c5a64387140975630ed7b2 --- /dev/null +++ b/lerobot/scripts/push_pretrained.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Once you have trained a policy with our training script (lerobot/scripts/train.py), use this script to push it +to the hub. 
+ +Example: + +```bash +python lerobot/scripts/push_pretrained.py \ + --pretrained_path=outputs/train/act_aloha_sim_transfer_cube_human/checkpoints/last/pretrained_model \ + --repo_id=lerobot/act_aloha_sim_transfer_cube_human +``` +""" + +from dataclasses import dataclass +from pathlib import Path + +import draccus +from huggingface_hub import HfApi + + +@dataclass +class PushPreTrainedConfig: + pretrained_path: Path + repo_id: str + branch: str | None = None + private: bool = False + exist_ok: bool = False + + +@draccus.wrap() +def main(cfg: PushPreTrainedConfig): + hub_api = HfApi() + hub_api.create_repo( + repo_id=cfg.repo_id, + private=cfg.private, + repo_type="model", + exist_ok=cfg.exist_ok, + ) + if cfg.branch: + hub_api.create_branch( + repo_id=cfg.repo_id, + branch=cfg.branch, + repo_type="model", + exist_ok=cfg.exist_ok, + ) + + hub_api.upload_folder( + repo_id=cfg.repo_id, + folder_path=cfg.pretrained_path, + repo_type="model", + revision=cfg.branch, + ) + + +if __name__ == "__main__": + main() diff --git a/lerobot/scripts/server/async_inference.proto b/lerobot/scripts/server/async_inference.proto new file mode 100644 index 0000000000000000000000000000000000000000..8eac7ef90d1dd64d7efa8a2b7df8a19f273c65b9 --- /dev/null +++ b/lerobot/scripts/server/async_inference.proto @@ -0,0 +1,60 @@ +// fmt: off +// flake8: noqa +// !/usr/bin/env python + +// Copyright 2024 The HuggingFace Inc. team. +// All rights reserved. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto3"; + +package async_inference; + +// AsyncInference: from Robot perspective +// Robot send observations to & executes action received from a remote Policy server +service AsyncInference { + // Robot -> Policy to share observations with a remote inference server + // Policy -> Robot to share actions predicted for given observations + rpc SendObservations(stream Observation) returns (Empty); + rpc StreamActions(Empty) returns (stream Action); + rpc SendPolicyInstructions(PolicySetup) returns (Empty); + rpc Ready(Empty) returns (Empty); +} + +enum TransferState { + TRANSFER_UNKNOWN = 0; + TRANSFER_BEGIN = 1; + TRANSFER_MIDDLE = 2; + TRANSFER_END = 3; +} + +// Messages +message Observation { + // sent by Robot, to remote Policy + TransferState transfer_state = 1; + bytes data = 2; +} + +message Action { + // sent by remote Policy, to Robot + TransferState transfer_state = 1; + bytes data = 2; +} + +message PolicySetup { + // sent by Robot to remote server, to init Policy + TransferState transfer_state = 1; + bytes data = 2; +} + +message Empty {} diff --git a/lerobot/scripts/server/async_inference_pb2.py b/lerobot/scripts/server/async_inference_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..e2d18d6f7a0147cd62009cff605a9f0e636a730c --- /dev/null +++ b/lerobot/scripts/server/async_inference_pb2.py @@ -0,0 +1,48 @@ +# fmt: off +# flake8: noqa +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# NO CHECKED-IN PROTOBUF GENCODE +# source: async_inference.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'async_inference.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"N\n\x06\x41\x63tion\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x0bPolicySetup\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xa9\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12\x42\n\rStreamActions\x12\x16.async_inference.Empty\x1a\x17.async_inference.Action0\x01\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_TRANSFERSTATE']._serialized_start=301 + _globals['_TRANSFERSTATE']._serialized_end=397 + _globals['_OBSERVATION']._serialized_start=42 + _globals['_OBSERVATION']._serialized_end=125 + _globals['_ACTION']._serialized_start=127 + _globals['_ACTION']._serialized_end=205 + _globals['_POLICYSETUP']._serialized_start=207 + _globals['_POLICYSETUP']._serialized_end=290 + _globals['_EMPTY']._serialized_start=292 + _globals['_EMPTY']._serialized_end=299 + _globals['_ASYNCINFERENCE']._serialized_start=400 + _globals['_ASYNCINFERENCE']._serialized_end=697 +# @@protoc_insertion_point(module_scope) diff --git a/lerobot/scripts/server/async_inference_pb2_grpc.py b/lerobot/scripts/server/async_inference_pb2_grpc.py new file mode 100644 index 0000000000000000000000000000000000000000..b0ab0f50a4a936bc06a13624817269fb4e0de75b --- /dev/null +++ b/lerobot/scripts/server/async_inference_pb2_grpc.py @@ -0,0 +1,236 @@ +# fmt: off +# flake8: noqa +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
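+# (Bindings like this are typically regenerated from async_inference.proto with grpcio-tools, e.g.
+# `python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. async_inference.proto`.)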
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +import async_inference_pb2 as async__inference__pb2 + +GRPC_GENERATED_VERSION = '1.71.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in async_inference_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class AsyncInferenceStub: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendObservations = channel.stream_unary( + '/async_inference.AsyncInference/SendObservations', + request_serializer=async__inference__pb2.Observation.SerializeToString, + response_deserializer=async__inference__pb2.Empty.FromString, + _registered_method=True) + self.StreamActions = channel.unary_stream( + '/async_inference.AsyncInference/StreamActions', + request_serializer=async__inference__pb2.Empty.SerializeToString, + response_deserializer=async__inference__pb2.Action.FromString, + _registered_method=True) + self.SendPolicyInstructions = channel.unary_unary( + '/async_inference.AsyncInference/SendPolicyInstructions', + request_serializer=async__inference__pb2.PolicySetup.SerializeToString, + response_deserializer=async__inference__pb2.Empty.FromString, + _registered_method=True) + self.Ready = channel.unary_unary( + '/async_inference.AsyncInference/Ready', + request_serializer=async__inference__pb2.Empty.SerializeToString, + response_deserializer=async__inference__pb2.Empty.FromString, + _registered_method=True) + + +class AsyncInferenceServicer: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + def SendObservations(self, request_iterator, context): + """Robot -> Policy to share observations with a remote inference server + Policy -> Robot to share actions predicted for given observations + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def StreamActions(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendPolicyInstructions(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Ready(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_AsyncInferenceServicer_to_server(servicer, server): + 
rpc_method_handlers = { + 'SendObservations': grpc.stream_unary_rpc_method_handler( + servicer.SendObservations, + request_deserializer=async__inference__pb2.Observation.FromString, + response_serializer=async__inference__pb2.Empty.SerializeToString, + ), + 'StreamActions': grpc.unary_stream_rpc_method_handler( + servicer.StreamActions, + request_deserializer=async__inference__pb2.Empty.FromString, + response_serializer=async__inference__pb2.Action.SerializeToString, + ), + 'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler( + servicer.SendPolicyInstructions, + request_deserializer=async__inference__pb2.PolicySetup.FromString, + response_serializer=async__inference__pb2.Empty.SerializeToString, + ), + 'Ready': grpc.unary_unary_rpc_method_handler( + servicer.Ready, + request_deserializer=async__inference__pb2.Empty.FromString, + response_serializer=async__inference__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'async_inference.AsyncInference', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class AsyncInference: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + @staticmethod + def SendObservations(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/async_inference.AsyncInference/SendObservations', + async__inference__pb2.Observation.SerializeToString, + async__inference__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def StreamActions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/async_inference.AsyncInference/StreamActions', + async__inference__pb2.Empty.SerializeToString, + async__inference__pb2.Action.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendPolicyInstructions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/async_inference.AsyncInference/SendPolicyInstructions', + async__inference__pb2.PolicySetup.SerializeToString, + async__inference__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Ready(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/async_inference.AsyncInference/Ready', + async__inference__pb2.Empty.SerializeToString, + 
async__inference__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/lerobot/scripts/server/policy_server.py b/lerobot/scripts/server/policy_server.py new file mode 100644 index 0000000000000000000000000000000000000000..8a04639e1ccf30f35f1ea7105677126a024c9bdc --- /dev/null +++ b/lerobot/scripts/server/policy_server.py @@ -0,0 +1,341 @@ +import itertools +import logging +import logging.handlers +import os +import pickle # nosec +import time +from concurrent import futures +from queue import Queue +from typing import Generator, List, Optional + +import async_inference_pb2 # type: ignore +import async_inference_pb2_grpc # type: ignore +import grpc +import torch +from datasets import load_dataset + +from lerobot.common.policies.factory import get_policy_class +from lerobot.scripts.server.robot_client import ( + TimedAction, + TimedObservation, + TinyPolicyConfig, + environment_dt, +) + +# Create logs directory if it doesn't exist +os.makedirs("logs", exist_ok=True) + +# Set up logging with both console and file output +logger = logging.getLogger("policy_server") +logger.setLevel(logging.INFO) + +# Console handler +console_handler = logging.StreamHandler() +console_handler.setFormatter( + logging.Formatter("%(asctime)s [SERVER] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") +) +logger.addHandler(console_handler) + +# File handler - creates a new log file for each run +file_handler = logging.handlers.RotatingFileHandler( + f"logs/policy_server_{int(time.time())}.log", + maxBytes=10 * 1024 * 1024, # 10MB + backupCount=5, +) +file_handler.setFormatter( + logging.Formatter("%(asctime)s [SERVER] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") +) +logger.addHandler(file_handler) + +inference_latency = 1 / 3 +idle_wait = 0.1 + +supported_policies = ["act"] + + +class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): + def __init__(self): + # Initialize dataset action generator + self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset()) + + self._setup_server() + + self.actions_per_chunk = 20 + self.actions_overlap = 10 + + def _setup_server(self) -> None: + """Flushes server state when new client connects.""" + # only running inference on the latest observation received by the server + self.observation_queue = Queue(maxsize=1) + + def Ready(self, request, context): # noqa: N802 + client_id = context.peer() + logger.info(f"Client {client_id} connected and ready") + self._setup_server() + + return async_inference_pb2.Empty() + + def SendPolicyInstructions(self, request, context): # noqa: N802 + """Receive policy instructions from the robot client""" + client_id = context.peer() + logger.debug(f"Receiving policy instructions from {client_id}") + + policy_specs = pickle.loads(request.data) # nosec + assert isinstance(policy_specs, TinyPolicyConfig), ( + f"Policy specs must be a TinyPolicyConfig. Got {type(policy_specs)}" + ) + + logger.info( + f"Policy type: {policy_specs.policy_type} | " + f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | " + f"Device: {policy_specs.device}" + ) + + assert policy_specs.policy_type in supported_policies, ( + f"Policy type {policy_specs.policy_type} not supported. 
Supported policies: {supported_policies}" + ) + + self.device = policy_specs.device + policy_class = get_policy_class(policy_specs.policy_type) + + start = time.time() + self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path) + self.policy.to(self.device) + end = time.time() + + logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds") + + return async_inference_pb2.Empty() + + def SendObservations(self, request_iterator, context): # noqa: N802 + """Receive observations from the robot client""" + client_id = context.peer() + logger.debug(f"Receiving observations from {client_id}") + + for observation in request_iterator: + receive_time = time.time() + timed_observation = pickle.loads(observation.data) # nosec + deserialize_time = time.time() + + # If queue is full, get the old observation to make room + if self.observation_queue.full(): + # pops from queue + _ = self.observation_queue.get_nowait() + logger.debug("Observation queue was full, removed oldest observation") + + # Now put the new observation (never blocks as queue is non-full here) + self.observation_queue.put(timed_observation) + queue_time = time.time() + + obs_timestep = timed_observation.get_timestep() + obs_timestamp = timed_observation.get_timestamp() + + logger.info( + f"Received observation #{obs_timestep} | " + f"Client timestamp: {obs_timestamp:.6f} | " + f"Server timestamp: {receive_time:.6f} | " + f"Network latency: {receive_time - obs_timestamp:.6f}s | " + f"Deserialization time: {deserialize_time - receive_time:.6f}s | " + f"Queue time: {queue_time - deserialize_time:.6f}s" + ) + + return async_inference_pb2.Empty() + + def StreamActions(self, request, context): # noqa: N802 + """Stream actions to the robot client""" + client_id = context.peer() + logger.debug(f"Client {client_id} connected for action streaming") + + # Generate action based on the most recent observation and its timestep + start_time = time.time() + try: + obs = self.observation_queue.get() + get_time = time.time() + logger.info( + f"Running inference for observation #{obs.get_timestep()} | Queue get time: {get_time - start_time:.6f}s" + ) + + if obs: + action = self._predict_action_chunk(obs) + inference_end_time = time.time() + logger.info( + f"Action chunk #{obs.get_timestep()} generated | " + f"Total inference time: {inference_end_time - get_time:.6f}s" + ) + yield action + yield_time = time.time() + logger.info( + f"Action chunk #{obs.get_timestep()} sent | Send time: {yield_time - inference_end_time:.6f}s" + ) + else: + logger.warning("No observation in queue yet!") + time.sleep(idle_wait) + except Exception as e: + logger.error(f"Error in StreamActions: {e}") + + return async_inference_pb2.Empty() + + def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]: + """Turn a chunk of actions into a list of TimedAction instances, + with the first action corresponding to t_0 and the rest corresponding to + t_0 + i*environment_dt for i in range(len(action_chunk)) + """ + return [ + TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk) + ] + + @torch.no_grad() + def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor: + # NOTE: This temporary function only works for ACT policies (Pi0-like models are *not* supported just yet) + """Get an action chunk from the policy""" + start_time = time.time() + + # prepare observation for policy forward pass + batch = 
self.policy.normalize_inputs(observation) + normalize_time = time.time() + logger.debug(f"Observation normalization time: {normalize_time - start_time:.6f}s") + + if self.policy.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = [batch[key] for key in self.policy.config.image_features] + prep_time = time.time() + logger.debug(f"Observation image preparation time: {prep_time - normalize_time:.6f}s") + + # forward pass outputs up to policy.config.n_action_steps != actions_per_chunk + forward_start = time.time() + actions = self.policy.model(batch)[0][:, : self.actions_per_chunk] + forward_end = time.time() + logger.debug(f"Policy forward pass time: {forward_end - forward_start:.6f}s") + + actions = self.policy.unnormalize_outputs({"action": actions})["action"] + unnormalize_end = time.time() + logger.debug(f"Action unnormalization time: {unnormalize_end - forward_end:.6f}s") + + end_time = time.time() + logger.info(f"Action chunk generation total time: {end_time - start_time:.6f}s") + + return actions + + def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]: + """Predict an action based on the observation""" + start_time = time.time() + observation = {} + for k, v in observation_t.get_observation().items(): + if "image" in k: + observation[k] = v.permute(2, 0, 1).unsqueeze(0).to(self.device) + else: + observation[k] = v.unsqueeze(0).to(self.device) + + prep_time = time.time() + logger.debug(f"Observation preparation time: {prep_time - start_time:.6f}s") + + # normalize observation + observation = self.policy.normalize_inputs(observation) + + # Remove batch dimension + action_tensor = self._get_action_chunk(observation) + action_tensor = action_tensor.squeeze(0) + + post_inference_time = time.time() + logger.debug(f"Post-inference processing start: {post_inference_time - prep_time:.6f}s") + + if action_tensor.dim() == 1: + # No chunk dimension, so repeat action to create a (dummy) chunk of actions + action_tensor = action_tensor.cpu().repeat(self.actions_per_chunk, 1) + + action_chunk = self._time_action_chunk( + observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep() + ) + + chunk_time = time.time() + logger.debug(f"Action chunk creation time: {chunk_time - post_inference_time:.6f}s") + + action_bytes = pickle.dumps(action_chunk) # nosec + serialize_time = time.time() + logger.debug(f"Action serialization time: {serialize_time - chunk_time:.6f}s") + + # Create and return the Action message + action = async_inference_pb2.Action(transfer_state=observation_t.transfer_state, data=action_bytes) + + end_time = time.time() + logger.info( + f"Total action prediction time: {end_time - start_time:.6f}s | " + f"Observation #{observation_t.get_timestep()} | " + f"Action chunk size: {len(action_chunk)}" + ) + + return action + + def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]: + """Stream chunks of actions from a prerecorded dataset. + + Returns: + Generator that yields chunks of actions from the dataset + """ + dataset = load_dataset("fracapuano/so100_test", split="train").with_format("torch") + + # 1. Select the action column only, where you will find tensors with 6 elements + actions = dataset["action"] + action_indices = torch.arange(len(actions)) + + # 2. 
Chunk the iterable of tensors into chunks with 10 elements each + # sending only first element for debugging + indices_chunks = action_indices.unfold( + 0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap + ) + + for idx_chunk in indices_chunks: + yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :] + + def _read_action_chunk(self, observation: Optional[TimedObservation] = None): + """Dummy function for predicting action chunk given observation. + + Instead of computing actions on-the-fly, this method streams + actions from a prerecorded dataset. + """ + import warnings + + warnings.warn( + "This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2 + ) + + if not observation: + observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0) + transfer_state = 0 + else: + transfer_state = observation.transfer_state + + # Get chunk of actions from the generator + actions_chunk = next(self.action_generator) + + # Return a list of TimedActions, with timestamps starting from the observation timestamp + action_data = self._time_action_chunk( + observation.get_timestamp(), actions_chunk, observation.get_timestep() + ) + action_bytes = pickle.dumps(action_data) # nosec + + # Create and return the Action message + action = async_inference_pb2.Action(transfer_state=transfer_state, data=action_bytes) + + time.sleep(inference_latency) # slow action generation, emulates inference time + + return action + + +def serve(): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(PolicyServer(), server) + server.add_insecure_port("[::]:50051") + server.start() + logger.info("PolicyServer started on port 50051") + + try: + while True: + time.sleep(86400) # Sleep for a day, or until interrupted + except KeyboardInterrupt: + server.stop(0) + logger.info("Server stopped") + + +if __name__ == "__main__": + serve() diff --git a/lerobot/scripts/server/robot_client.py b/lerobot/scripts/server/robot_client.py new file mode 100644 index 0000000000000000000000000000000000000000..9a6a088e8742ede6d5453ac4754c09157517f009 --- /dev/null +++ b/lerobot/scripts/server/robot_client.py @@ -0,0 +1,566 @@ +import logging +import logging.handlers +import os +import pickle # nosec +import threading +import time +from queue import Empty, Queue +from typing import Any, Optional + +import async_inference_pb2 # type: ignore +import async_inference_pb2_grpc # type: ignore +import grpc +import torch + +from lerobot.common.robot_devices.robots.utils import make_robot + +# Create logs directory if it doesn't exist +os.makedirs("logs", exist_ok=True) + +# Set up logging with both console and file output +logger = logging.getLogger("robot_client") +logger.setLevel(logging.INFO) + +# Console handler +console_handler = logging.StreamHandler() +console_handler.setFormatter( + logging.Formatter("%(asctime)s [CLIENT] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") +) +logger.addHandler(console_handler) + +# File handler - creates a new log file for each run +file_handler = logging.handlers.RotatingFileHandler( + f"logs/robot_client_{int(time.time())}.log", + maxBytes=10 * 1024 * 1024, # 10MB + backupCount=5, +) +file_handler.setFormatter( + logging.Formatter("%(asctime)s [CLIENT] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") +) +logger.addHandler(file_handler) + +environment_dt = 1 / 30 +idle_wait = 0.1 + + +class TimedData: + def __init__(self, timestamp: float, data: Any, 
timestep: int): + """Initialize a TimedData object. + + Args: + timestamp: Unix timestamp relative to data's creation. + data: The actual data to wrap a timestamp around. + """ + self.timestamp = timestamp + self.data = data + self.timestep = timestep + + def get_data(self): + return self.data + + def get_timestamp(self): + return self.timestamp + + def get_timestep(self): + return self.timestep + + +class TimedAction(TimedData): + def __init__(self, timestamp: float, action: torch.Tensor, timestep: int): + super().__init__(timestamp=timestamp, data=action, timestep=timestep) + + def get_action(self): + return self.get_data() + + +class TimedObservation(TimedData): + def __init__( + self, timestamp: float, observation: dict[str, torch.Tensor], timestep: int, transfer_state: int = 0 + ): + super().__init__(timestamp=timestamp, data=observation, timestep=timestep) + self.transfer_state = transfer_state + + def get_observation(self): + return self.get_data() + + +class TinyPolicyConfig: + def __init__( + self, + policy_type: str = "act", + pretrained_name_or_path: str = "fracapuano/act_so100_test", + device: str = "cpu", + ): + self.policy_type = policy_type + self.pretrained_name_or_path = pretrained_name_or_path + self.device = device + + +class RobotClient: + def __init__( + self, + server_address="localhost:50051", + policy_type: str = "act", # "pi0" + pretrained_name_or_path: str = "fracapuano/act_so100_test", # "lerobot/pi0" + policy_device: str = "mps", + ): + self.policy_config = TinyPolicyConfig(policy_type, pretrained_name_or_path, policy_device) + self.channel = grpc.insecure_channel(server_address) + self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel) + logger.info(f"Initializing client to connect to server at {server_address}") + + self.running = False + self.first_observation_sent = False + self.latest_action = 0 + self.action_chunk_size = 20 + + self.action_queue = Queue() + self.start_barrier = threading.Barrier( + 3 + ) # 3 threads: observation sender, action receiver, action executor + + # Create a lock for robot access + self.robot_lock = threading.Lock() + + # Stats for logging + self.obs_sent_count = 0 + self.actions_received_count = 0 + self.actions_executed_count = 0 + self.last_obs_sent_time = 0 + self.last_action_received_time = 0 + + start_time = time.time() + self.robot = make_robot("so100", mock=True) + self.robot.connect() + connect_time = time.time() + logger.info(f"Robot connection time: {connect_time - start_time:.4f}s") + + time.sleep(idle_wait) # sleep waiting for cameras to activate + logger.info("Robot connected and ready") + + def timestamps(self): + """Get the timestamps of the actions in the queue""" + return sorted([action.get_timestep() for action in self.action_queue.queue]) + + def start(self): + """Start the robot client and connect to the policy server""" + try: + # client-server handshake + start_time = time.time() + self.stub.Ready(async_inference_pb2.Empty()) + end_time = time.time() + logger.info(f"Connected to policy server in {end_time - start_time:.4f}s") + + # send policy instructions + policy_config_bytes = pickle.dumps(self.policy_config) + policy_setup = async_inference_pb2.PolicySetup( + transfer_state=async_inference_pb2.TRANSFER_BEGIN, data=policy_config_bytes + ) + + logger.info("Sending policy instructions to policy server") + logger.info( + f"Policy type: {self.policy_config.policy_type} | " + f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | " + f"Device: {self.policy_config.device}" + ) + 
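+            # The TinyPolicyConfig above is pickled into the PolicySetup message's data bytes; the server
+            # unpickles it and loads the pretrained policy before any observations are streamed.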
+ self.stub.SendPolicyInstructions(policy_setup) + + self.running = True + + return True + + except grpc.RpcError as e: + logger.error(f"Failed to connect to policy server: {e}") + return False + + def stop(self): + """Stop the robot client""" + self.running = False + + self.robot.disconnect() + logger.info("Robot disconnected") + + self.channel.close() + logger.info("Client stopped, channel closed") + + # Log final stats + logger.info( + f"Session stats - Observations sent: {self.obs_sent_count}, " + f"Action chunks received: {self.actions_received_count}, " + f"Actions executed: {self.actions_executed_count}" + ) + + def send_observation( + self, + obs: TimedObservation, + transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE, + ) -> bool: + """Send observation to the policy server. + Returns True if the observation was sent successfully, False otherwise.""" + if not self.running: + logger.warning("Client not running") + return False + + assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!" + + start_time = time.time() + observation_bytes = pickle.dumps(obs) + serialize_time = time.time() + logger.debug(f"Observation serialization time: {serialize_time - start_time:.6f}s") + + observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes) + + try: + send_start = time.time() + _ = self.stub.SendObservations(iter([observation])) + send_end = time.time() + + self.obs_sent_count += 1 + obs_timestep = obs.get_timestep() + + logger.info( + f"Sent observation #{obs_timestep} | " + f"Serialize time: {serialize_time - start_time:.6f}s | " + f"Network time: {send_end - send_start:.6f}s | " + f"Total time: {send_end - start_time:.6f}s" + ) + + if transfer_state == async_inference_pb2.TRANSFER_BEGIN: + self.first_observation_sent = True + + self.last_obs_sent_time = send_end + return True + + except grpc.RpcError as e: + logger.error(f"Error sending observation #{obs.get_timestep()}: {e}") + return False + + def _validate_action(self, action: TimedAction): + """Received actions are keps only when they have been produced for now or later, never before""" + return not action.get_timestamp() < self.latest_action + + def _validate_action_chunk(self, actions: list[TimedAction]): + assert len(actions) == self.action_chunk_size, ( + f"Action batch size must match action chunk!size: {len(actions)} != {self.action_chunk_size}" + ) + assert all(self._validate_action(action) for action in actions), "Invalid action in chunk" + + return True + + def _inspect_action_queue(self): + queue_size = self.action_queue.qsize() + timestamps = sorted([action.get_timestep() for action in self.action_queue.queue]) + logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}") + return queue_size, timestamps + + def _clear_queue(self): + """Clear the existing queue""" + start_time = time.time() + old_size = self.action_queue.qsize() + + while not self.action_queue.empty(): + try: + self.action_queue.get_nowait() + except Empty: + break + + end_time = time.time() + logger.debug(f"Queue cleared: {old_size} items removed in {end_time - start_time:.6f}s") + + def _fill_action_queue(self, actions: list[TimedAction]): + """Fill the action queue with incoming valid actions""" + start_time = time.time() + valid_count = 0 + + for action in actions: + if self._validate_action(action): + self.action_queue.put(action) + valid_count += 1 + + end_time = time.time() + logger.debug( + f"Queue filled: 
{valid_count}/{len(actions)} valid actions added in {end_time - start_time:.6f}s" + ) + + def _clear_and_fill_action_queue(self, actions: list[TimedAction]): + """Clear the existing queue and fill it with new actions. + This is a higher-level function that combines clearing and filling operations. + + Args: + actions: List of TimedAction instances to queue + """ + start_time = time.time() + logger.info(f"Current latest action: {self.latest_action}") + + # Get queue state before changes + old_size, old_timesteps = self._inspect_action_queue() + + # Log incoming actions + incoming_timesteps = [a.get_timestep() for a in actions] + logger.info(f"Incoming actions: {len(actions)} items with timesteps {incoming_timesteps}") + + # Clear and fill + clear_start = time.time() + self._clear_queue() + clear_end = time.time() + + fill_start = time.time() + self._fill_action_queue(actions) + fill_end = time.time() + + # Get queue state after changes + new_size, new_timesteps = self._inspect_action_queue() + + end_time = time.time() + logger.info( + f"Queue update complete | " + f"Before: {old_size} items | " + f"After: {new_size} items | " + f"Previous content: {old_timesteps} | " + f"Incoming content: {incoming_timesteps} | " + f"Current contents: {new_timesteps}" + ) + + logger.info( + f"Clear time: {clear_end - clear_start:.6f}s | " + f"Fill time: {fill_end - fill_start:.6f}s | " + f"Total time: {end_time - start_time:.6f}s" + ) + + def receive_actions(self): + """Receive actions from the policy server""" + # Wait at barrier for synchronized start + self.start_barrier.wait() + logger.info("Action receiving thread starting") + + while self.running: + try: + # Use StreamActions to get a stream of actions from the server + for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()): + receive_time = time.time() + + # Deserialize bytes back into list[TimedAction] + deserialize_start = time.time() + timed_actions = pickle.loads(actions_chunk.data) # nosec + deserialize_end = time.time() + + # Calculate network latency if we have matching observations + if len(timed_actions) > 0: + first_action_timestep = timed_actions[0].get_timestep() + server_to_client_latency = receive_time - self.last_obs_sent_time + + logger.info( + f"Received action chunk for step #{first_action_timestep} | " + f"Network latency (server->client): {server_to_client_latency:.6f}s | " + f"Deserialization time: {deserialize_end - deserialize_start:.6f}s" + ) + + # Update action queue + _ = time.time() + self._clear_and_fill_action_queue(timed_actions) + update_end = time.time() + + self.actions_received_count += 1 + self.last_action_received_time = receive_time + + logger.info( + f"Action chunk processed | " + f"Total processing time: {update_end - receive_time:.6f}s | " + f"Round-trip time since observation sent: {receive_time - self.last_obs_sent_time:.6f}s" + ) + + except grpc.RpcError as e: + logger.error(f"Error receiving actions: {e}") + time.sleep(idle_wait) # Avoid tight loop on error + + def _get_next_action(self) -> Optional[TimedAction]: + """Get the next action from the queue""" + try: + action = self.action_queue.get_nowait() + return action + + except Empty: + return None + + def execute_actions(self): + """Continuously execute actions from the queue""" + # Wait at barrier for synchronized start + self.start_barrier.wait() + logger.info("Action execution thread starting") + + while self.running: + # Get the next action from the queue + cycle_start = time.time() + time.sleep(environment_dt) + + get_start = 
time.time() + timed_action = self._get_next_action() + get_end = time.time() + + if timed_action is not None: + # self.latest_action = timed_action.get_timestep() + _ = self.latest_action + self.latest_action = timed_action.get_timestamp() + + action_timestep = timed_action.get_timestep() + + # Convert action to tensor and send to robot - Acquire lock before accessing the robot + lock_start = time.time() + if self.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock + lock_acquired = time.time() + try: + send_start = time.time() + self.robot.send_action(timed_action.get_action()) + send_end = time.time() + + self.actions_executed_count += 1 + logger.info( + f"Executed action #{action_timestep} | " + f"Queue get time: {get_end - get_start:.6f}s | " + f"Lock wait time: {lock_acquired - lock_start:.6f}s | " + f"Action send time: {send_end - send_start:.6f}s | " + f"Total execution time: {send_end - cycle_start:.6f}s | " + f"Action latency: {send_end - timed_action.get_timestamp():.6f}s | " + f"Queue size: {self.action_queue.qsize()}" + ) + finally: + # Always release the lock in a finally block to ensure it's released + self.robot_lock.release() + else: + logger.warning("Could not acquire robot lock for action execution, retrying next cycle") + else: + if get_end - get_start > 0.001: # Only log if there was a measurable delay + logger.debug(f"No action available, get time: {get_end - get_start:.6f}s") + time.sleep(idle_wait) + + def stream_observations(self, get_observation_fn): + """Continuously stream observations to the server""" + # Wait at barrier for synchronized start + self.start_barrier.wait() + logger.info("Observation streaming thread starting") + + first_observation = True + while self.running: + try: + # Get serialized observation bytes from the function + cycle_start = time.time() + time.sleep(environment_dt) + + get_start = time.time() + observation = get_observation_fn() + get_end = time.time() + + # Skip if observation is None (couldn't acquire lock) + if observation is None: + logger.warning("Failed to get observation, skipping cycle") + continue + + # Set appropriate transfer state + if first_observation: + state = async_inference_pb2.TRANSFER_BEGIN + first_observation = False + else: + state = async_inference_pb2.TRANSFER_MIDDLE + + obs_timestep = observation.get_timestep() + logger.debug(f"Got observation #{obs_timestep} in {get_end - get_start:.6f}s, sending...") + + send_start = time.time() + self.send_observation(observation, state) + send_end = time.time() + + logger.info( + f"Observation #{obs_timestep} cycle complete | " + f"Get time: {get_end - get_start:.6f}s | " + f"Send time: {send_end - send_start:.6f}s | " + f"Total cycle time: {send_end - cycle_start:.6f}s" + ) + + except Exception as e: + logger.error(f"Error in observation sender: {e}") + time.sleep(idle_wait) + + +def async_client(): + # Example of how to use the RobotClient + client = RobotClient() + + if client.start(): + # Function to generate mock observations + def get_observation(): + # Create a counter attribute if it doesn't exist + if not hasattr(get_observation, "counter"): + get_observation.counter = 0 + + # Acquire lock before accessing the robot + start_time = time.time() + observation_content = None + if client.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock + lock_time = time.time() + try: + capture_start = time.time() + observation_content = client.robot.capture_observation() + capture_end = time.time() + logger.debug( + f"Observation capture 
| " + f"Lock acquisition: {lock_time - start_time:.6f}s | " + f"Capture time: {capture_end - capture_start:.6f}s" + ) + finally: + # Always release the lock in a finally block to ensure it's released + client.robot_lock.release() + else: + logger.warning("Could not acquire robot lock for observation capture, skipping this cycle") + return None # Return None to indicate no observation was captured + + current_time = time.time() + observation = TimedObservation( + timestamp=current_time, observation=observation_content, timestep=get_observation.counter + ) + + # Increment counter for next call + get_observation.counter += 1 + + end_time = time.time() + logger.debug( + f"Observation #{observation.get_timestep()} prepared | " + f"Total time: {end_time - start_time:.6f}s" + ) + + return observation + + logger.info("Starting all threads...") + + # Create and start observation sender thread + obs_thread = threading.Thread(target=client.stream_observations, args=(get_observation,)) + obs_thread.daemon = True + + # Create and start action receiver thread + action_receiver_thread = threading.Thread(target=client.receive_actions) + action_receiver_thread.daemon = True + + # Create action execution thread + action_execution_thread = threading.Thread(target=client.execute_actions) + action_execution_thread.daemon = True + + # Start all threads + obs_thread.start() + action_receiver_thread.start() + action_execution_thread.start() + + try: + # Main thread just keeps everything alive + while client.running: + time.sleep(idle_wait) + + except KeyboardInterrupt: + pass + + finally: + client.stop() + logger.info("Client stopped") + + +if __name__ == "__main__": + async_client() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..0de247be9aa775318beec59e656e85f5d0721d1e --- /dev/null +++ b/lerobot/scripts/train.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import time +from contextlib import nullcontext +from pprint import pformat +from typing import Any + +import torch +from termcolor import colored +from torch.amp import GradScaler +from torch.optim import Optimizer + +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.sampler import EpisodeAwareSampler +from lerobot.common.datasets.utils import cycle +from lerobot.common.envs.factory import make_env +from lerobot.common.optim.factory import make_optimizer_and_scheduler +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.utils import get_device_from_parameters +from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker +from lerobot.common.utils.random_utils import set_seed +from lerobot.common.utils.train_utils import ( + get_step_checkpoint_dir, + get_step_identifier, + load_training_state, + save_checkpoint, + update_last_checkpoint, +) +from lerobot.common.utils.utils import ( + format_big_number, + get_safe_torch_device, + has_method, + init_logging, +) +from lerobot.common.utils.wandb_utils import WandBLogger +from lerobot.configs import parser +from lerobot.configs.train import TrainPipelineConfig +from lerobot.scripts.eval import eval_policy + + +def update_policy( + train_metrics: MetricsTracker, + policy: PreTrainedPolicy, + batch: Any, + optimizer: Optimizer, + grad_clip_norm: float, + grad_scaler: GradScaler, + lr_scheduler=None, + use_amp: bool = False, + lock=None, +) -> tuple[MetricsTracker, dict]: + start_time = time.perf_counter() + device = get_device_from_parameters(policy) + policy.train() + with torch.autocast(device_type=device.type) if use_amp else nullcontext(): + loss, output_dict = policy.forward(batch) + # TODO(rcadene): policy.unnormalize_outputs(out_dict) + grad_scaler.scale(loss).backward() + + # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**. + grad_scaler.unscale_(optimizer) + + grad_norm = torch.nn.utils.clip_grad_norm_( + policy.parameters(), + grad_clip_norm, + error_if_nonfinite=False, + ) + + # Optimizer's gradients are already unscaled, so scaler.step does not unscale them, + # although it still skips optimizer.step() if the gradients contain infs or NaNs. + with lock if lock is not None else nullcontext(): + grad_scaler.step(optimizer) + # Updates the scale for next iteration. + grad_scaler.update() + + optimizer.zero_grad() + + # Step through pytorch scheduler at every batch instead of epoch + if lr_scheduler is not None: + lr_scheduler.step() + + if has_method(policy, "update"): + # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC). 
+ policy.update() + + train_metrics.loss = loss.item() + train_metrics.grad_norm = grad_norm.item() + train_metrics.lr = optimizer.param_groups[0]["lr"] + train_metrics.update_s = time.perf_counter() - start_time + return train_metrics, output_dict + + +@parser.wrap() +def train(cfg: TrainPipelineConfig): + cfg.validate() + logging.info(pformat(cfg.to_dict())) + + if cfg.wandb.enable and cfg.wandb.project: + wandb_logger = WandBLogger(cfg) + else: + wandb_logger = None + logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + + if cfg.seed is not None: + set_seed(cfg.seed) + + # Check device is available + device = get_safe_torch_device(cfg.policy.device, log=True) + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + + logging.info("Creating dataset") + dataset = make_dataset(cfg) + + # Create environment used for evaluating checkpoints during training on simulation data. + # On real-world data, no need to create an environment as evaluations are done outside train.py, + # using the eval.py instead, with gym_dora environment and dora-rs. + eval_env = None + if cfg.eval_freq > 0 and cfg.env is not None: + logging.info("Creating env") + eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) + + logging.info("Creating policy") + policy = make_policy( + cfg=cfg.policy, + ds_meta=dataset.meta, + ) + + logging.info("Creating optimizer and scheduler") + optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) + grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp) + + step = 0 # number of policy updates (forward + backward + optim) + + if cfg.resume: + step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) + + num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) + num_total_params = sum(p.numel() for p in policy.parameters()) + + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + if cfg.env is not None: + logging.info(f"{cfg.env.task=}") + logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") + logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") + logging.info(f"{dataset.num_episodes=}") + logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") + logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") + + # create dataloader for offline training + if hasattr(cfg.policy, "drop_n_last_frames"): + shuffle = False + sampler = EpisodeAwareSampler( + dataset.episode_data_index, + drop_n_last_frames=cfg.policy.drop_n_last_frames, + shuffle=True, + ) + else: + shuffle = True + sampler = None + + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=cfg.num_workers, + batch_size=cfg.batch_size, + shuffle=shuffle, + sampler=sampler, + pin_memory=device.type != "cpu", + drop_last=False, + ) + dl_iter = cycle(dataloader) + + policy.train() + + train_metrics = { + "loss": AverageMeter("loss", ":.3f"), + "grad_norm": AverageMeter("grdn", ":.3f"), + "lr": AverageMeter("lr", ":0.1e"), + "update_s": AverageMeter("updt_s", ":.3f"), + "dataloading_s": AverageMeter("data_s", ":.3f"), + } + + train_tracker = MetricsTracker( + cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step + ) + + logging.info("Start offline training on a fixed dataset") + for _ in range(step, cfg.steps): + start_time = time.perf_counter() + batch = next(dl_iter) + 
train_tracker.dataloading_s = time.perf_counter() - start_time + + for key in batch: + if isinstance(batch[key], torch.Tensor): + batch[key] = batch[key].to(device, non_blocking=True) + + train_tracker, output_dict = update_policy( + train_tracker, + policy, + batch, + optimizer, + cfg.optimizer.grad_clip_norm, + grad_scaler=grad_scaler, + lr_scheduler=lr_scheduler, + use_amp=cfg.policy.use_amp, + ) + + # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we + # increment `step` here. + step += 1 + train_tracker.step() + is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 + is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps + is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 + + if is_log_step: + logging.info(train_tracker) + if wandb_logger: + wandb_log_dict = train_tracker.to_dict() + if output_dict: + wandb_log_dict.update(output_dict) + wandb_logger.log_dict(wandb_log_dict, step) + train_tracker.reset_averages() + + if cfg.save_checkpoint and is_saving_step: + logging.info(f"Checkpoint policy after step {step}") + checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) + save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler) + update_last_checkpoint(checkpoint_dir) + if wandb_logger: + wandb_logger.log_policy(checkpoint_dir) + + if cfg.env and is_eval_step: + step_id = get_step_identifier(step, cfg.steps) + logging.info(f"Eval policy at step {step}") + with ( + torch.no_grad(), + torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), + ): + eval_info = eval_policy( + eval_env, + policy, + cfg.eval.n_episodes, + videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", + max_episodes_rendered=4, + start_seed=cfg.seed, + ) + + eval_metrics = { + "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), + "pc_success": AverageMeter("success", ":.1f"), + "eval_s": AverageMeter("eval_s", ":.3f"), + } + eval_tracker = MetricsTracker( + cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step + ) + eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s") + eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward") + eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success") + logging.info(eval_tracker) + if wandb_logger: + wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} + wandb_logger.log_dict(wandb_log_dict, step, mode="eval") + wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval") + + if eval_env: + eval_env.close() + logging.info("End of training") + + +if __name__ == "__main__": + init_logging() + train() diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cdfea6b8b9070e428a2a4eebe47e2e7080dd5589 --- /dev/null +++ b/lerobot/scripts/visualize_dataset.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset. + +Note: The last frame of the episode doesn't always correspond to a final state. +That's because our datasets are composed of transition from state to state up to +the antepenultimate state associated to the ultimate action to arrive in the final state. +However, there might not be a transition from a final state to another state. + +Note: This script aims to visualize the data used to train the neural networks. +~What you see is what you get~. When visualizing image modality, it is often expected to observe +lossy compression artifacts since these images have been decoded from compressed mp4 videos to +save disk space. The compression factor applied has been tuned to not affect success rate. + +Examples: + +- Visualize data stored on a local machine: +``` +local$ python lerobot/scripts/visualize_dataset.py \ + --repo-id lerobot/pusht \ + --episode-index 0 +``` + +- Visualize data stored on a distant machine with a local viewer: +``` +distant$ python lerobot/scripts/visualize_dataset.py \ + --repo-id lerobot/pusht \ + --episode-index 0 \ + --save 1 \ + --output-dir path/to/directory + +local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd . +local$ rerun lerobot_pusht_episode_0.rrd +``` + +- Visualize data stored on a distant machine through streaming: +(You need to forward the websocket port to the distant machine, with +`ssh -L 9087:localhost:9087 username@remote-host`) +``` +distant$ python lerobot/scripts/visualize_dataset.py \ + --repo-id lerobot/pusht \ + --episode-index 0 \ + --mode distant \ + --ws-port 9087 + +local$ rerun ws://localhost:9087 +``` + +""" + +import argparse +import gc +import logging +import time +from pathlib import Path +from typing import Iterator + +import numpy as np +import rerun as rr +import torch +import torch.utils.data +import tqdm + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + +class EpisodeSampler(torch.utils.data.Sampler): + def __init__(self, dataset: LeRobotDataset, episode_index: int): + from_idx = dataset.episode_data_index["from"][episode_index].item() + to_idx = dataset.episode_data_index["to"][episode_index].item() + self.frame_ids = range(from_idx, to_idx) + + def __iter__(self) -> Iterator: + return iter(self.frame_ids) + + def __len__(self) -> int: + return len(self.frame_ids) + + +def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: + assert chw_float32_torch.dtype == torch.float32 + assert chw_float32_torch.ndim == 3 + c, h, w = chw_float32_torch.shape + assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}" + hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() + return hwc_uint8_numpy + + +def visualize_dataset( + dataset: LeRobotDataset, + episode_index: int, + batch_size: int = 32, + num_workers: int = 0, + mode: str = "local", + web_port: int = 9090, + ws_port: int = 9087, + save: bool = False, + output_dir: Path | None = None, +) -> Path | None: + if save: + assert output_dir is not None, ( + "Set an output directory where to write .rrd files with `--output-dir path/to/directory`." 
+ ) + + repo_id = dataset.repo_id + + logging.info("Loading dataloader") + episode_sampler = EpisodeSampler(dataset, episode_index) + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=num_workers, + batch_size=batch_size, + sampler=episode_sampler, + ) + + logging.info("Starting Rerun") + + if mode not in ["local", "distant"]: + raise ValueError(mode) + + spawn_local_viewer = mode == "local" and not save + rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer) + + # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush + # when iterating on a dataloader with `num_workers` > 0 + # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix + gc.collect() + + if mode == "distant": + rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port) + + logging.info("Logging to Rerun") + + for batch in tqdm.tqdm(dataloader, total=len(dataloader)): + # iterate over the batch + for i in range(len(batch["index"])): + rr.set_time_sequence("frame_index", batch["frame_index"][i].item()) + rr.set_time_seconds("timestamp", batch["timestamp"][i].item()) + + # display each camera image + for key in dataset.meta.camera_keys: + # TODO(rcadene): add `.compress()`? is it lossless? + rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) + + # display each dimension of action space (e.g. actuators command) + if "action" in batch: + for dim_idx, val in enumerate(batch["action"][i]): + rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) + + # display each dimension of observed state space (e.g. agent position in joint space) + if "observation.state" in batch: + for dim_idx, val in enumerate(batch["observation.state"][i]): + rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) + + if "next.done" in batch: + rr.log("next.done", rr.Scalar(batch["next.done"][i].item())) + + if "next.reward" in batch: + rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item())) + + if "next.success" in batch: + rr.log("next.success", rr.Scalar(batch["next.success"][i].item())) + + if mode == "local" and save: + # save .rrd locally + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + repo_id_str = repo_id.replace("/", "_") + rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd" + rr.save(rrd_path) + return rrd_path + + elif mode == "distant": + # stop the process from exiting since it is serving the websocket connection + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print("Ctrl-C received. Exiting.") + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).", + ) + parser.add_argument( + "--episode-index", + type=int, + required=True, + help="Episode to visualize.", + ) + parser.add_argument( + "--root", + type=Path, + default=None, + help="Root directory for the dataset stored locally (e.g. `--root data`). 
By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=None, + help="Directory path to write a .rrd file when `--save 1` is set.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=32, + help="Batch size loaded by DataLoader.", + ) + parser.add_argument( + "--num-workers", + type=int, + default=4, + help="Number of processes of Dataloader for loading the data.", + ) + parser.add_argument( + "--mode", + type=str, + default="local", + help=( + "Mode of viewing between 'local' or 'distant'. " + "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. " + "'distant' creates a server on the distant machine where the data is stored. " + "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." + ), + ) + parser.add_argument( + "--web-port", + type=int, + default=9090, + help="Web port for rerun.io when `--mode distant` is set.", + ) + parser.add_argument( + "--ws-port", + type=int, + default=9087, + help="Web socket port for rerun.io when `--mode distant` is set.", + ) + parser.add_argument( + "--save", + type=int, + default=0, + help=( + "Save a .rrd file in the directory provided by `--output-dir`. " + "It also deactivates the spawning of a viewer. " + "Visualize the data by running `rerun path/to/file.rrd` on your local machine." + ), + ) + + parser.add_argument( + "--tolerance-s", + type=float, + default=1e-4, + help=( + "Tolerance in seconds used to ensure data timestamps respect the dataset fps value" + "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument" + "If not given, defaults to 1e-4." + ), + ) + + args = parser.parse_args() + kwargs = vars(args) + repo_id = kwargs.pop("repo_id") + root = kwargs.pop("root") + tolerance_s = kwargs.pop("tolerance_s") + + logging.info("Loading dataset") + dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s) + + visualize_dataset(dataset, **vars(args)) + + +if __name__ == "__main__": + main() diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py new file mode 100644 index 0000000000000000000000000000000000000000..0fc21a8f122985c9c259d726853466c247ccc15d --- /dev/null +++ b/lerobot/scripts/visualize_dataset_html.py @@ -0,0 +1,479 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset. + +Note: The last frame of the episode doesnt always correspond to a final state. +That's because our datasets are composed of transition from state to state up to +the antepenultimate state associated to the ultimate action to arrive in the final state. +However, there might not be a transition from a final state to another state. 
+ +Note: This script aims to visualize the data used to train the neural networks. +~What you see is what you get~. When visualizing image modality, it is often expected to observe +lossly compression artifacts since these images have been decoded from compressed mp4 videos to +save disk space. The compression factor applied has been tuned to not affect success rate. + +Example of usage: + +- Visualize data stored on a local machine: +```bash +local$ python lerobot/scripts/visualize_dataset_html.py \ + --repo-id lerobot/pusht + +local$ open http://localhost:9090 +``` + +- Visualize data stored on a distant machine with a local viewer: +```bash +distant$ python lerobot/scripts/visualize_dataset_html.py \ + --repo-id lerobot/pusht + +local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel +local$ open http://localhost:9090 +``` + +- Select episodes to visualize: +```bash +python lerobot/scripts/visualize_dataset_html.py \ + --repo-id lerobot/pusht \ + --episodes 7 3 5 1 4 +``` +""" + +import argparse +import csv +import json +import logging +import re +import shutil +import tempfile +from io import StringIO +from pathlib import Path + +import numpy as np +import pandas as pd +import requests +from flask import Flask, redirect, render_template, request, url_for + +from lerobot import available_datasets +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import IterableNamespace +from lerobot.common.utils.utils import init_logging + + +def run_server( + dataset: LeRobotDataset | IterableNamespace | None, + episodes: list[int] | None, + host: str, + port: str, + static_folder: Path, + template_folder: Path, +): + app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve()) + app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache + + @app.route("/") + def hommepage(dataset=dataset): + if dataset: + dataset_namespace, dataset_name = dataset.repo_id.split("/") + return redirect( + url_for( + "show_episode", + dataset_namespace=dataset_namespace, + dataset_name=dataset_name, + episode_id=0, + ) + ) + + dataset_param, episode_param = None, None + all_params = request.args + if "dataset" in all_params: + dataset_param = all_params["dataset"] + if "episode" in all_params: + episode_param = int(all_params["episode"]) + + if dataset_param: + dataset_namespace, dataset_name = dataset_param.split("/") + return redirect( + url_for( + "show_episode", + dataset_namespace=dataset_namespace, + dataset_name=dataset_name, + episode_id=episode_param if episode_param is not None else 0, + ) + ) + + featured_datasets = [ + "lerobot/aloha_static_cups_open", + "lerobot/columbia_cairlab_pusht_real", + "lerobot/taco_play", + ] + return render_template( + "visualize_dataset_homepage.html", + featured_datasets=featured_datasets, + lerobot_datasets=available_datasets, + ) + + @app.route("//") + def show_first_episode(dataset_namespace, dataset_name): + first_episode_id = 0 + return redirect( + url_for( + "show_episode", + dataset_namespace=dataset_namespace, + dataset_name=dataset_name, + episode_id=first_episode_id, + ) + ) + + @app.route("///episode_") + def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes): + repo_id = f"{dataset_namespace}/{dataset_name}" + try: + if dataset is None: + dataset = get_dataset_info(repo_id) + except FileNotFoundError: + return ( + "Make sure to convert your LeRobotDataset to v2 & above. 
See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461", + 400, + ) + dataset_version = ( + str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version + ) + match = re.search(r"v(\d+)\.", dataset_version) + if match: + major_version = int(match.group(1)) + if major_version < 2: + return "Make sure to convert your LeRobotDataset to v2 & above." + + episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id) + dataset_info = { + "repo_id": f"{dataset_namespace}/{dataset_name}", + "num_samples": dataset.num_frames + if isinstance(dataset, LeRobotDataset) + else dataset.total_frames, + "num_episodes": dataset.num_episodes + if isinstance(dataset, LeRobotDataset) + else dataset.total_episodes, + "fps": dataset.fps, + } + if isinstance(dataset, LeRobotDataset): + video_paths = [ + dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys + ] + videos_info = [ + {"url": url_for("static", filename=video_path), "filename": video_path.parent.name} + for video_path in video_paths + ] + tasks = dataset.meta.episodes[episode_id]["tasks"] + else: + video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"] + videos_info = [ + { + "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + + dataset.video_path.format( + episode_chunk=int(episode_id) // dataset.chunks_size, + video_key=video_key, + episode_index=episode_id, + ), + "filename": video_key, + } + for video_key in video_keys + ] + + response = requests.get( + f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5 + ) + response.raise_for_status() + # Split into lines and parse each line as JSON + tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()] + + filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id] + tasks = filtered_tasks_jsonl[0]["tasks"] + + videos_info[0]["language_instruction"] = tasks + + if episodes is None: + episodes = list( + range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes) + ) + + return render_template( + "visualize_dataset_template.html", + episode_id=episode_id, + episodes=episodes, + dataset_info=dataset_info, + videos_info=videos_info, + episode_data_csv_str=episode_data_csv_str, + columns=columns, + ignored_columns=ignored_columns, + ) + + app.run(host=host, port=port) + + +def get_ep_csv_fname(episode_id: int): + ep_csv_fname = f"episode_{episode_id}.csv" + return ep_csv_fname + + +def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index): + """Get a csv str containing timeseries data of an episode (e.g. state and action). 
+ This file will be loaded by Dygraph javascript to plot data in real time.""" + columns = [] + + selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]] + selected_columns.remove("timestamp") + + ignored_columns = [] + for column_name in selected_columns: + shape = dataset.features[column_name]["shape"] + shape_dim = len(shape) + if shape_dim > 1: + selected_columns.remove(column_name) + ignored_columns.append(column_name) + + # init header of csv with state and action names + header = ["timestamp"] + + for column_name in selected_columns: + dim_state = ( + dataset.meta.shapes[column_name][0] + if isinstance(dataset, LeRobotDataset) + else dataset.features[column_name].shape[0] + ) + + if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]: + column_names = dataset.features[column_name]["names"] + while not isinstance(column_names, list): + column_names = list(column_names.values())[0] + else: + column_names = [f"{column_name}_{i}" for i in range(dim_state)] + columns.append({"key": column_name, "value": column_names}) + + header += column_names + + selected_columns.insert(0, "timestamp") + + if isinstance(dataset, LeRobotDataset): + from_idx = dataset.episode_data_index["from"][episode_index] + to_idx = dataset.episode_data_index["to"][episode_index] + data = ( + dataset.hf_dataset.select(range(from_idx, to_idx)) + .select_columns(selected_columns) + .with_format("pandas") + ) + else: + repo_id = dataset.repo_id + + url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format( + episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index + ) + df = pd.read_parquet(url) + data = df[selected_columns] # Select specific columns + + rows = np.hstack( + ( + np.expand_dims(data["timestamp"], axis=1), + *[np.vstack(data[col]) for col in selected_columns[1:]], + ) + ).tolist() + + # Convert data to CSV string + csv_buffer = StringIO() + csv_writer = csv.writer(csv_buffer) + # Write header + csv_writer.writerow(header) + # Write data rows + csv_writer.writerows(rows) + csv_string = csv_buffer.getvalue() + + return csv_string, columns, ignored_columns + + +def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]: + # get first frame of episode (hack to get video_path of the episode) + first_frame_idx = dataset.episode_data_index["from"][ep_index].item() + return [ + dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] + for key in dataset.meta.video_keys + ] + + +def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]: + # check if the dataset has language instructions + if "language_instruction" not in dataset.features: + return None + + # get first frame index + first_frame_idx = dataset.episode_data_index["from"][ep_index].item() + + language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"] + # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored + # with the tf.tensor appearing in the string + return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)") + + +def get_dataset_info(repo_id: str) -> IterableNamespace: + response = requests.get( + f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5 + ) + response.raise_for_status() # Raises an HTTPError for bad responses + dataset_info = response.json() + dataset_info["repo_id"] = 
repo_id + return IterableNamespace(dataset_info) + + +def visualize_dataset_html( + dataset: LeRobotDataset | None, + episodes: list[int] | None = None, + output_dir: Path | None = None, + serve: bool = True, + host: str = "127.0.0.1", + port: int = 9090, + force_override: bool = False, +) -> Path | None: + init_logging() + + template_dir = Path(__file__).resolve().parent.parent / "templates" + + if output_dir is None: + # Create a temporary directory that will be automatically cleaned up + output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_") + + output_dir = Path(output_dir) + if output_dir.exists(): + if force_override: + shutil.rmtree(output_dir) + else: + logging.info(f"Output directory already exists. Loading from it: '{output_dir}'") + + output_dir.mkdir(parents=True, exist_ok=True) + + static_dir = output_dir / "static" + static_dir.mkdir(parents=True, exist_ok=True) + + if dataset is None: + if serve: + run_server( + dataset=None, + episodes=None, + host=host, + port=port, + static_folder=static_dir, + template_folder=template_dir, + ) + else: + # Create a simlink from the dataset video folder containing mp4 files to the output directory + # so that the http server can get access to the mp4 files. + if isinstance(dataset, LeRobotDataset): + ln_videos_dir = static_dir / "videos" + if not ln_videos_dir.exists(): + ln_videos_dir.symlink_to((dataset.root / "videos").resolve()) + + if serve: + run_server(dataset, episodes, host, port, static_dir, template_dir) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--repo-id", + type=str, + default=None, + help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).", + ) + parser.add_argument( + "--root", + type=Path, + default=None, + help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.", + ) + parser.add_argument( + "--load-from-hf-hub", + type=int, + default=0, + help="Load videos and parquet files from HF Hub rather than local system.", + ) + parser.add_argument( + "--episodes", + type=int, + nargs="*", + default=None, + help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=None, + help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.", + ) + parser.add_argument( + "--serve", + type=int, + default=1, + help="Launch web server.", + ) + parser.add_argument( + "--host", + type=str, + default="127.0.0.1", + help="Web host used by the http server.", + ) + parser.add_argument( + "--port", + type=int, + default=9090, + help="Web port used by the http server.", + ) + parser.add_argument( + "--force-override", + type=int, + default=0, + help="Delete the output directory if it exists already.", + ) + + parser.add_argument( + "--tolerance-s", + type=float, + default=1e-4, + help=( + "Tolerance in seconds used to ensure data timestamps respect the dataset fps value" + "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument" + "If not given, defaults to 1e-4." 
+ ), + ) + + args = parser.parse_args() + kwargs = vars(args) + repo_id = kwargs.pop("repo_id") + load_from_hf_hub = kwargs.pop("load_from_hf_hub") + root = kwargs.pop("root") + tolerance_s = kwargs.pop("tolerance_s") + + dataset = None + if repo_id: + dataset = ( + LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s) + if not load_from_hf_hub + else get_dataset_info(repo_id) + ) + + visualize_dataset_html(dataset, **vars(args)) + + +if __name__ == "__main__": + main() diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..80935d327df5b13ab2f0b1ef35fe0a080f8f1279 --- /dev/null +++ b/lerobot/scripts/visualize_image_transforms.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Visualize effects of image transforms for a given configuration. + +This script will generate examples of transformed images as they are output by LeRobot dataset. +Additionally, each individual transform can be visualized separately as well as examples of combined transforms + +Example: +```bash +python lerobot/scripts/visualize_image_transforms.py \ + --repo_id=lerobot/pusht \ + --episodes='[0]' \ + --image_transforms.enable=True +``` +""" + +import logging +from copy import deepcopy +from dataclasses import replace +from pathlib import Path + +import draccus +from torchvision.transforms import ToPILImage + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.transforms import ( + ImageTransforms, + ImageTransformsConfig, + make_transform_from_config, +) +from lerobot.configs.default import DatasetConfig + +OUTPUT_DIR = Path("outputs/image_transforms") +to_pil = ToPILImage() + + +def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples): + output_dir_all = output_dir / "all" + output_dir_all.mkdir(parents=True, exist_ok=True) + + tfs = ImageTransforms(cfg) + for i in range(1, n_examples + 1): + transformed_frame = tfs(original_frame) + to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100) + + print("Combined transforms examples saved to:") + print(f" {output_dir_all}") + + +def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples): + if not cfg.enable: + logging.warning( + "No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`." 
+ ) + return + + print("Individual transforms examples saved to:") + for tf_name, tf_cfg in cfg.tfs.items(): + # Apply a few transformation with random value in min_max range + output_dir_single = output_dir / tf_name + output_dir_single.mkdir(parents=True, exist_ok=True) + + tf = make_transform_from_config(tf_cfg) + for i in range(1, n_examples + 1): + transformed_frame = tf(original_frame) + to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100) + + # Apply min, max, average transformations + tf_cfg_kwgs_min = deepcopy(tf_cfg.kwargs) + tf_cfg_kwgs_max = deepcopy(tf_cfg.kwargs) + tf_cfg_kwgs_avg = deepcopy(tf_cfg.kwargs) + + for key, (min_, max_) in tf_cfg.kwargs.items(): + avg = (min_ + max_) / 2 + tf_cfg_kwgs_min[key] = [min_, min_] + tf_cfg_kwgs_max[key] = [max_, max_] + tf_cfg_kwgs_avg[key] = [avg, avg] + + tf_min = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min})) + tf_max = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max})) + tf_avg = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg})) + + tf_frame_min = tf_min(original_frame) + tf_frame_max = tf_max(original_frame) + tf_frame_avg = tf_avg(original_frame) + + to_pil(tf_frame_min).save(output_dir_single / "min.png", quality=100) + to_pil(tf_frame_max).save(output_dir_single / "max.png", quality=100) + to_pil(tf_frame_avg).save(output_dir_single / "mean.png", quality=100) + + print(f" {output_dir_single}") + + +@draccus.wrap() +def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5): + dataset = LeRobotDataset( + repo_id=cfg.repo_id, + episodes=cfg.episodes, + revision=cfg.revision, + video_backend=cfg.video_backend, + ) + + output_dir = output_dir / cfg.repo_id.split("/")[-1] + output_dir.mkdir(parents=True, exist_ok=True) + + # Get 1st frame from 1st camera of 1st episode + original_frame = dataset[0][dataset.meta.camera_keys[0]] + to_pil(original_frame).save(output_dir / "original_frame.png", quality=100) + print("\nOriginal frame saved to:") + print(f" {output_dir / 'original_frame.png'}.") + + save_all_transforms(cfg.image_transforms, original_frame, output_dir, n_examples) + save_each_transform(cfg.image_transforms, original_frame, output_dir, n_examples) + + +if __name__ == "__main__": + visualize_image_transforms() diff --git a/lerobot/templates/visualize_dataset_homepage.html b/lerobot/templates/visualize_dataset_homepage.html new file mode 100644 index 0000000000000000000000000000000000000000..19613afb5d9cc28996321adc51adf27617aad504 --- /dev/null +++ b/lerobot/templates/visualize_dataset_homepage.html @@ -0,0 +1,68 @@ + + + + + + Interactive Video Background Page + + + + +
[visualize_dataset_homepage.html body: the HTML markup is not recoverable from this extraction; the visible content is a "LeRobot Dataset Visualizer" heading, a "create & train your own robots" link, an "Example Datasets:" list rendered with {% for dataset in featured_datasets %} ... {{ dataset }} ... {% endfor %}, and a collapsible "More example datasets" list rendered with {% for dataset in lerobot_datasets %} ... {{ dataset }} ... {% endfor %}.]
+ + diff --git a/lerobot/templates/visualize_dataset_template.html b/lerobot/templates/visualize_dataset_template.html new file mode 100644 index 0000000000000000000000000000000000000000..cf9d40f1d0076bab302d35d69ccc1055f975d980 --- /dev/null +++ b/lerobot/templates/visualize_dataset_template.html @@ -0,0 +1,546 @@ + + + + + + + + + + + {{ dataset_info.repo_id }} episode {{ episode_id }} + + + + + + + +
[visualize_dataset_template.html body: the HTML markup is not recoverable from this extraction; the visible content is a sidebar showing {{ dataset_info.repo_id }}, "Number of samples/frames: {{ dataset_info.num_samples }}", "Number of episodes: {{ dataset_info.num_episodes }}", "Frames per second: {{ dataset_info.fps }}" and an "Episodes:" list, plus a main pane titled "Episode {{ episode_id }}" with a "filter videos" dropdown, one video per entry of {% for video_info in videos_info %} labelled {{ video_info.filename }}, an optional "Language Instruction: {{ videos_info[0].language_instruction }}" banner, playback controls ("0:00 / 0:00", "Time: 0.00s"), and a notice shown when {% if ignored_columns|length > 0 %}: "Columns {{ ignored_columns }} are NOT shown since the visualizer currently does not support 2D or 3D data."]
+ + + + + + + + + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..2dd766c4e7d2cf33b19dceac81cbb4c04384a22b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,138 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +[project.urls] +homepage = "https://github.com/huggingface/lerobot" +issues = "https://github.com/huggingface/lerobot/issues" +discord = "https://discord.gg/s3KuuzsPFb" + +[project] +name = "lerobot" +version = "0.1.0" +description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" +authors = [ + { name = "Rémi Cadène", email = "re.cadene@gmail.com" }, + { name = "Simon Alibert", email = "alibert.sim@gmail.com" }, + { name = "Alexander Soare", email = "alexander.soare159@gmail.com" }, + { name = "Quentin Gallouédec", email = "quentin.gallouedec@ec-lyon.fr" }, + { name = "Adil Zouitine", email = "adilzouitinegm@gmail.com" }, + { name = "Thomas Wolf", email = "thomaswolfcontact@gmail.com" }, + { name = "Steven Palma", email = "imstevenpmwork@ieee.org" }, +] +readme = "README.md" +license = { text = "Apache-2.0" } +requires-python = ">=3.10" +keywords = ["robotics", "deep learning", "pytorch"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "Topic :: Software Development :: Build Tools", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.10", +] +dependencies = [ + "cmake>=3.29.0.1", + "datasets>=2.19.0", + "deepdiff>=7.0.1", + "diffusers>=0.27.2", + "draccus>=0.10.0", + "einops>=0.8.0", + "flask>=3.0.3", + "gdown>=5.1.0", + "gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work + "h5py>=3.10.0", + "huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'", + "imageio[ffmpeg]>=2.34.0", + "jsonlines>=4.0.0", + "numba>=0.59.0", + "omegaconf>=2.3.0", + "opencv-python-headless>=4.9.0", + "packaging>=24.2", + "av>=12.0.5", + "pymunk>=6.6.0", + "pynput>=1.7.7", + "pyzmq>=26.2.1", + "rerun-sdk>=0.21.0", + "termcolor>=2.4.0", + "torch>=2.2.1", + "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", + "torchvision>=0.21.0", + "wandb>=0.16.3", + "zarr>=2.17.0", + "grpcio>=1.71.0", +] + +[project.optional-dependencies] +aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"] +dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"] +dora = [ + "gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'", +] +dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"] +feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"] +intelrealsense = ["pyrealsense2>=2.55.1.6486 ; 
sys_platform != 'darwin'"] +pi0 = ["transformers>=4.48.0"] +pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] +stretch = [ + "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'", + "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'", + "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", + "pynput>=1.7.7", +] +test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"] +umi = ["imagecodecs>=2024.1.1"] +video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] +xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"] + +[tool.poetry] +requires-poetry = ">=2.1" + +[tool.ruff] +line-length = 110 +target-version = "py310" +exclude = ["tests/artifacts/**/*.safetensors"] + +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"] + +[tool.bandit] +exclude_dirs = [ + "tests", + "benchmarks", + "lerobot/common/datasets/push_dataset_to_hub", + "lerobot/common/datasets/v2/convert_dataset_v1_to_v2", + "lerobot/common/policies/pi0/conversion_scripts", + "lerobot/scripts/push_dataset_to_hub.py", +] +skips = ["B101", "B311", "B404", "B603"] + +[tool.typos] +default.extend-ignore-re = [ + "(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line + "(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # spellchecker: +] +default.extend-ignore-identifiers-re = [ + # Add individual words here to ignore them + "2nd", + "pn", + "ser", + "ein", +] + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api"
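As a closing illustration of how the pieces above fit together outside of Rerun or the web visualizer, the sketch below reuses `EpisodeSampler` and `to_hwc_uint8_numpy` from `lerobot/scripts/visualize_dataset.py` to iterate the frames of a single episode with a plain PyTorch DataLoader. The repo id, episode index and batch size are illustrative, and the batch keys shown are the same ones the visualizer logs:

```python
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.visualize_dataset import EpisodeSampler, to_hwc_uint8_numpy

# Illustrative values; any LeRobotDataset repo id works the same way.
dataset = LeRobotDataset("lerobot/pusht")
sampler = EpisodeSampler(dataset, episode_index=0)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)

for batch in loader:
    # Camera frames come back as float32 CHW tensors; convert them like the visualizer does.
    for cam_key in dataset.meta.camera_keys:
        frame_hwc = to_hwc_uint8_numpy(batch[cam_key][0])
    # Low-dimensional signals plotted by the visualizer (present for this dataset).
    action = batch["action"][0]
    state = batch["observation.state"][0]
    timestamp = batch["timestamp"][0].item()
    break  # one batch is enough for this sketch
```

Because the sampler restricts indices to a single episode, no shuffling is used and the frames arrive in temporal order, which is exactly what the Rerun logging loop in `visualize_dataset()` relies on.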