Robotics
Transformers
Safetensors
English
VLA
Libero-Goal-2 / array_typing.py
Hume-vla's picture
Upload folder using huggingface_hub
72bf50e verified
from typing import Annotated, TypeAlias, TypedDict
import numpy as np
from jaxtyping import Float
from numpy.typing import NDArray
from torch import Tensor
class InferConfig(TypedDict):
"""Configuration for inference."""
replan_steps: int
s2_replan_steps: int
s2_candidates_num: int
noise_temp_lower_bound: float
noise_temp_upper_bound: float
time_temp_lower_bound: float
time_temp_upper_bound: float
post_process_action: bool
device: str
ImageArray: TypeAlias = Annotated[NDArray[np.uint8], "Shape[B, H, W, C]"]
StateArray: TypeAlias = Annotated[
NDArray[np.float32], "Shape[B, state_horizon, state_dim]"
]
ActionArray: TypeAlias = Annotated[NDArray[np.float32], "Shape[B, action_dim]"]
InferBatchObs = TypedDict(
"BatchObs",
{
"observation.images.image": ImageArray,
"observation.images.wrist_image": ImageArray,
"observation.state": StateArray,
"task": list[str],
},
)
class InferOutput(TypedDict):
noise_action: Float[Tensor, "batch s2_chunksize padded_action_dim"]
s1_action: Float[Tensor, "batch s1_chunksize unpadded_action_dim"]
s2_action: Float[Tensor, "batch s2_chunksize unpadded_action_dim"]
class CalQlBatch(TypedDict):
encoded_observations: Float[Tensor, "batch encoded_dim"]
encoded_next_observations: Float[Tensor, "batch encoded_dim"]
actions: Float[Tensor, "batch action_dim"]
rewards: Float[Tensor, " batch"]
mc_returns: Float[Tensor, " batch"]
masks: Float[Tensor, " batch"]
class EnvArgs(TypedDict):
"""Environment arguments."""
# necessary args
num_trials_per_task: int
num_steps_wait: int
task_suite_name: str
seed: int
ckpt_path: str | None
eval_name: str | None
class Request(TypedDict):
"""Environment receive message."""
frame_type: str # "init" | "action"
env_args: EnvArgs | None
action: ActionArray | None
class Response(TypedDict):
"""Environment send message."""
status: str # "new_episode" | "eval_finished" | "in_episode"
success_rate: float | None
observation: InferBatchObs | None