GR00T 系列 4 - 理解 GR00T 策略 API
理解 GR00T 策略 API
本指南说明如何使用 Gr00tPolicy 类加载已训练模型以及运行推理。训练完成后,可使用该 API 将模型集成到评估环境中。
加载策略
通过提供 Embodiment 标签、模型 Checkpoint 路径以及设备来初始化策略:
from gr00t.policy import Gr00tPolicy
from gr00t.data.embodiment_tags import EmbodimentTag
# Load your trained model
policy = Gr00tPolicy(
model_path="/path/to/your/checkpoint",
embodiment_tag=EmbodimentTag.NEW_EMBODIMENT, # or other embodiment tags
device="cuda:0", # or "cpu", or device index like 0
strict=True # Enable input/output validation (recommended during development)
)参数:
| 参数 | 类型 | 默认值 | 描述 |
|---|---|---|---|
embodiment_tag | EmbodimentTag | str | (必填) | 机器人类型;接受 Enum 或大小写不敏感的字符串(比如 "NEW_EMBODIMENT") |
model_path | str | (必填) | 模型 Checkpoint 目录路径(本地路径或 HuggingFace 模型 ID) |
device | str | int | (必填) | 推理设备:"cuda:0"、0 或 "cpu" |
strict | bool | True | 在运行时验证观测形状和数据类型。建议开发期间启用;生产环境可关闭以提升速度 |
推理参数指南
运行推理脚本时(比如 standalone_inference_script.py、open_loop_eval.py),关键参数如下:
--embodiment-tag
决定模型使用哪个模态配置(状态/动作键、归一化)。必须与数据集的机器人类型匹配。
该标签大小写不敏感,接受枚举名称或字符串值。
比如,--embodiment-tag OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT 和 --embodiment-tag LIBERO_PANDA 都可以正确解析。未知标签将产生错误,并且列出所有已知选项。
- 预训练标签(比如
OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT、XDOF、REAL_G1):用于在与预训练 Embodiment 匹配的数据集上进行 Zero-Shot 推理。模态配置从基础模型 Checkpoint 加载。
- 后训练标签(
OXE_DROID_RELATIVE_EEF_RELATIVE_JOINT、LIBERO_PANDA、SIMPLER_ENV_GOOGLE、SIMPLER_ENV_WIDOWX):需要微调后的 Checkpoint。将这些标签传给基础模型将产生错误。
NEW_EMBODIMENT:用于自定义机器人。微调期间需要--modality-config-path。微调后,配置将保存在 Checkpoint 中,并且在推理期间自动加载。- 每个 Python 进程只能注册一个
NEW_EMBODIMENT模态配置。比如examples/SO100/so100_config.py和examples/mask-guided-background-suppression/so101_config.py都注册到该标签;在同一进程中同时导入二者将失败。在正常 CLI 使用中,将只导入所选的--modality-config-path,因此这不是问题;只需避免在同一脚本中同时接入两个配置。
- 每个 Python 进程只能注册一个
--action-horizon
每次推理调用预测的未来动作步数。模型最大值为 16(来自模型配置)。常用值:
16:完整时域,用于开环评估
8:较短时域,常用于需要频繁重新规划的实时部署
该参数与机器人无关;同一值可用于不同数据集及 Embodiment。
--inference-mode
pytorch:标准 PyTorch 推理(默认值,无需设置)
tensorrt:使用 TensorRT 引擎的加速推理
理解观测格式
策略期望观测是包含三种模态的嵌套字典:
observation = {
"video": {
"camera_name": np.ndarray, # Shape: (B, T, H, W, 3), dtype: uint8
# ... one entry per camera
},
"state": {
"state_name": np.ndarray, # Shape: (B, T, D), dtype: float32
# ... one entry per state stream
},
"language": {
"task": [[str]], # Shape: (B, 1), list of lists of strings
}
}维度
B:批大小(并行环境数量)
T:时间时域(历史观测数量)
H, W:图像高度及宽度
D:状态维度
C:通道数(RGB 必须为 3)
数据类型要求
- Video 必须是
np.uint8数组,RGB 像素值范围为[0, 255]
- State 必须是
np.float32数组
- Language 指令是字符串列表的列表
重要说明
- 时间时域
T由模型训练配置决定
- 不同模态可能具有不同时间时域(通过
get_modality_config()查询)
- Language 指令通常是单个时间步(
T=1)
- 批次中的所有数组必须具有相同的批大小
B
理解动作格式
策略以类似的嵌套结构返回动作:
action = {
"action_name": np.ndarray, # Shape: (B, T, D), dtype: float32
# ... one entry per action stream
}维度
B:批大小(与输入批大小匹配)
T:动作时域(要预测的未来动作步数)
D:动作维度(比如机械臂关节为 7,夹爪为 1)
重要说明
- Action 以物理单位返回(比如以弧度表示的关节位置、以 rad/s 表示的速度)
- Action 未归一化,可直接发送到机器人控制器
- Action 时域
T允许预测多个未来步骤(对动作分块有用)
运行推理
使用 get_action() 方法从观测计算动作:
# Get action from current observation
action, info = policy.get_action(observation)
# Access the action array
arm_action = action["action_name"] # Shape: (B, T, D)
# Extract the first action to execute
next_action = arm_action[:, 0, :] # Shape: (B, D)该方法返回由以下内容组成的元组:
action:动作数组字典
info:附加信息字典(当前为空,保留供未来使用)
查询模态配置
要了解策略期望哪些观测以及将产生哪些动作,请查询模态配置:
# Get modality configs for your embodiment
modality_configs = policy.get_modality_config()
# Check what camera keys are expected
video_keys = modality_configs["video"].modality_keys
print(f"Expected cameras: {video_keys}")
# Check video temporal horizon
video_horizon = len(modality_configs["video"].delta_indices)
print(f"Video frames needed: {video_horizon}")
# Check state keys and horizon
state_keys = modality_configs["state"].modality_keys
state_horizon = len(modality_configs["state"].delta_indices)
print(f"Expected states: {state_keys}, horizon: {state_horizon}")
# Check action keys and horizon
action_keys = modality_configs["action"].modality_keys
action_horizon = len(modality_configs["action"].delta_indices)
print(f"Action outputs: {action_keys}, horizon: {action_horizon}")这在以下情况下特别有用:
- 不确定已训练模型期望哪些观测
- 需要验证每个模态的时间时域
- 正在调试观测/动作格式不匹配问题
重置策略
在 Episode 之间重置策略:
# Reset policy state (if any) between episodes
info = policy.reset()目前策略是无状态的,但调用 reset() 是面向未来兼容性的良好实践。
常见模式
批量推理
策略支持批量推理以提高效率:
# Run 4 environments in parallel
batch_size = 4
observation = {
"video": {"wrist_cam": np.zeros((batch_size, T_video, H, W, 3), dtype=np.uint8)},
"state": {"joints": np.zeros((batch_size, T_state, D_state), dtype=np.float32)},
"language": {"task": [["pick up the cube"]] * batch_size},
}
action, _ = policy.get_action(observation)
# action["action_name"] has shape (batch_size, action_horizon, action_dim)单环境推理
对于单环境,批大小使用 1:
# Add batch dimension (B=1)
observation = {
"video": {"wrist_cam": video[np.newaxis, ...]}, # (1, T, H, W, 3)
"state": {"joints": state[np.newaxis, ...]}, # (1, T, D)
"language": {"task": [["pick up the cube"]]}, # List of length 1
}
action, _ = policy.get_action(observation)
# Remove batch dimension
single_action = action["action_name"][0] # (action_horizon, action_dim)动作分块
当动作时域 T > 1 时,可以使用动作分块:
action, _ = policy.get_action(observation)
action_chunk = action["action_name"][:, :, :] # (B, T, D)
# Execute actions over multiple timesteps
for t in range(action_chunk.shape[1]):
env.step(action_chunk[:, t, :])训练数据加载优化
训练模型时,可以通过多个命令行参数在数据加载速度与内存使用量之间进行权衡优化。
示例:
uv run python gr00t/experiment/launch_finetune.py \
.... \
--num-shards-per-epoch 100 \
--dataloader-num-workers 2 \
--shard-size 512 \如果 VRAM 受限,可以降低上述所有数值以减少内存使用量。
在 Shard 采样期间,为确保更加 IID,可以将 episode_sampling_rate 降低到 0.05 或更低。
故障排查
- 开发期间启用严格模式:
strict=True
- 打印模态配置以了解期望格式
- 调用
get_action()前检查观测形状
- 使用参考包装器(
Gr00tSimPolicyWrapper)作为模板
- 增量验证:连接真实环境前,先使用 Dummy 观测进行测试