""" Example usage of the Pi-0 Bolt Nut Sort model """ from openpi.policies import policy_config from openpi.training import config import numpy as np def load_model(checkpoint_path: str): """Load the Pi-0 bolt nut sort model.""" train_config = config.get_config("pi0_bns") policy = policy_config.create_trained_policy( train_config, checkpoint_path, default_prompt="sort the bolts and the nuts into separate baskets" ) return policy def create_observation(images, joint_positions): """Create observation dict for the model.""" return { "images": { "cam_high": images["high"], # [224, 224, 3] uint8 "cam_left_wrist": images["left_wrist"], # [224, 224, 3] uint8 "cam_right_wrist": images["right_wrist"], # [224, 224, 3] uint8 }, "state": joint_positions, # [14] float32 "prompt": "sort the bolts and the nuts into separate baskets" } # Example usage if __name__ == "__main__": # Load model policy = load_model("./checkpoint") # Create dummy observation images = { "high": np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8), "left_wrist": np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8), "right_wrist": np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8), } joint_positions = np.random.randn(14).astype(np.float32) obs = create_observation(images, joint_positions) # Get actions result = policy.infer(obs) actions = result["actions"] # [50, 14] - 50 timesteps of 14-DoF actions print(f"Generated actions shape: {actions.shape}")