Pi0.5 系列 2 - 在自己的数据上微调基础模型
Pi0.5 系列 2 - 在自己的数据上微调基础模型
将自己的数据转换为 LeRobot 数据集(LeRobot v2.1)
下面以将 LeRobot v3.0 数据集转换为 LeRobot v2.1 为例进行说明。
安装 FFMPEG:
sudo apt install ffmpeg使用 any4lerobot 进行转换:
# 克隆 any4lerobot
git clone <https://github.com/Tavish9/any4lerobot>
cd any4lerobot/
# 创建虚拟环境 - 只需执行一次
uv venv --python 3.12
# 激活虚拟环境 - 每次运行时都需要执行
source .venv/bin/activate
python -m ensurepip --upgrade
python -m pip install --upgrade pip
# 参考:<https://github.com/Tavish9/any4lerobot/blob/main/ds_version_convert/v30_to_v21/README.md。>
# 1. 降级 datasets,因为自 4.0.0 起引入 List 和 Column。
python -m pip install jsonlines numpy tqdm pyarrow "datasets==2.19.2" "av==12.3.0"
# 2. 安装 v3.0 LeRobot
git clone <https://github.com/huggingface/lerobot.git>
cd lerobot/
# 切换分支
git checkout v0.5.1
python -m pip install -i <https://mirrors.aliyun.com/pypi/simple/> -e .
cd ..
# 3. 转换
cd ds_version_convert/v30_to_v21/
# 黄色部分按需修改
python convert_dataset_v30_to_v21.py --repo-id local/20260531 --root /opt/0615/merged_grasp_Coke_20260531定义训练配置
在自己的数据上微调基础模型时,需要为数据处理和训练定义配置。openpi 为 LIBERO 提供带详尽注释的示例配置,可供参考。
定义数据映射
定义从 LIBERO 环境到模型的数据映射,反之亦然。该映射将同时用于训练和推理。
import dataclasses
import einops
import numpy as np
from openpi import transforms
from openpi.models import model as _model
def make_libero_example() -> dict:
"""为 Libero 策略创建随机输入示例。"""
return {
"observation/state": np.random.rand(8),
"observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"prompt": "do something",
}
def _parse_image(image) -> np.ndarray:
image = np.asarray(image)
# 反归一化
if np.issubdtype(image.dtype, np.floating):
image = (255 * image).astype(np.uint8)
# 将图像形状转换为 HWC
if image.shape[0] == 3:
image = einops.rearrange(image, "c h w -> h w c")
return image
@dataclasses.dataclass(frozen=True)
class LiberoInputs(transforms.DataTransformFn):
"""
该类用于将输入转换为模型期望的格式。同时用于训练和推理。
对于自己的数据集,可以复制该类,根据下面的注释修改键,以便将数据集中正确的元素传入模型。
"""
# 决定使用哪个模型。不要修改。
model_type: _model.ModelType
def __call__(self, data: dict) -> dict:
# 可能需要将图像解析为 uint8 格式的 (H,W,C),因为 LeRobot 自动
# 存储为 float32 格式的 (C,H,W),*如果不转换,策略推理时将跳过这一路图像。*
# 请保留这部分给自己的数据集使用,但如果数据集将图像
# 存在不同于 "observation/image" 或 "observation/wrist_image" 的键下,
# 那么应该在下面修改对应键名。
# Pi0 模型目前支持三个图像输入:一个第三人称视角,
# 以及两个腕部视角(左腕和右腕)。如果数据集没有某一类
# 图像,比如腕部图像,可以在这里将其注释掉,同时像下面处理
# 右腕图像那样用零值替代。
base_image = _parse_image(data["observation/image"])
wrist_image = _parse_image(data["observation/wrist_image"])
# 创建输入字典。不要更改下面字典中的键。
inputs = {
"state": data["observation/state"],
"image": {
"base_0_rgb": base_image,
"left_wrist_0_rgb": wrist_image,
# 对于不存在的图像,用形状匹配的零数组进行填充。
"right_wrist_0_rgb": np.zeros_like(base_image),
},
"image_mask": {
"base_0_rgb": np.True_,
"left_wrist_0_rgb": np.True_,
*# 只有 pi0 模型需要对填充图像做 mask,pi0-FAST 不需要。针对自己的数据集也不要修改这里。*
"right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
},
}
# 将动作填充到模型动作维度。自己的数据集也需要保留这一步。
# 动作只在训练阶段可用。
if "actions" in data:
inputs["actions"] = data["actions"]
# 将 Prompt(即语言指令)传递给模型。
# 自己的数据集也需要保留这一步;但如果指令不是存储在 "prompt" 这个键下,
# 需要修改键名。输出字典中需要包含 "prompt" 这个键。
if "prompt" in data:
inputs["prompt"] = data["prompt"]
return inputs
@dataclasses.dataclass(frozen=True)
class LiberoOutputs(transforms.DataTransformFn):
"""
该类用于将模型输出转换回数据集特定的格式。仅用于推理阶段。
对于自己的数据集,可以复制该类,同时根据下面的注释修改动作维度。
"""
def __call__(self, data: dict) -> dict:
# 仅返回前 N 个动作维度。由于上文为匹配模型动作维度对动作进行填充,
# 此处需要从返回字典中解析出实际有效的动作维度。
# 对于 Libero 数据集,仅返回前 7 个动作维度(其余部分为填充)。
# 若使用自定义数据集,请将 `7` 替换为该数据集的动作维度。
return {"actions": np.asarray(data["actions"][:, :7])}处理原始数据
定义如何处理 LeRobot 数据集中的原始 LIBERO 数据,以用于训练。
@dataclasses.dataclass(frozen=True)
class LeRobotLiberoDataConfig(DataConfigFactory):
"""
该配置用于设置数据流水线中不同阶段所应用的变换。
若使用自定义数据集,可以复制该类,根据下方注释修改变换,使其与数据集匹配。
"""
extra_delta_transform: bool = False
@override
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
# repack 变换仅应用于来自数据集的数据,不用于推理阶段。
# 可通过该变换使数据集输入尽可能接近推理环境中的输入格式
#(比如匹配键名)。
# 下方将数据集中的键名(在数据转换脚本中定义)映射到
# 推理流水线中使用的键名(在 Libero 推理脚本中定义)。
# 若使用自定义数据集,应先确认环境传递给 Policy Server 的键名,
# 然后修改下方映射,使数据集中的键名与这些目标键名匹配。
# 此处的 repack 变换仅用于重映射键名。
repack_transform = _transforms.Group(
inputs=[
_transforms.RepackTransform(
{
"observation/image": "image",
"observation/wrist_image": "wrist_image",
"observation/state": "state",
"actions": "actions",
"prompt": "prompt",
}
)
]
)
# 数据变换同时应用于来自数据集的数据以及推理阶段的数据。
# 下方分别定义输入模型的数据变换(``inputs``)和模型输出数据的变换
#(``outputs``,仅在推理阶段使用)。
# 这些变换定义于 `libero_policy.py`。可查看其中的详细注释,
# 了解如何修改变换以适配自定义数据集。创建自定义变换后,
# 可将下方变换替换为自定义版本。
data_transforms = _transforms.Group(
inputs=[libero_policy.LiberoInputs(model_type=model_config.model_type)],
outputs=[libero_policy.LiberoOutputs()],
)
# 额外的数据变换:pi0 模型使用 delta 动作进行训练
# (相对于每个动作片段中的第一个状态)。如果数据中的动作是
# ``absolute`` 动作(比如目标关节角),可以取消注释下面这一行,
# 将动作转换为 delta 动作。唯一的例外是夹爪动作,其始终为 absolute。
# 在下面的示例中,delta 转换将应用于前 6 个动作维度(关节),
# 第 7 个动作维度(夹爪)保持不变,即仍为 absolute。
# 在 Libero 中,数据集内的原始动作已经是 delta 动作,因此不需要
# 应用单独的 delta 转换(这也是该行被注释掉的原因)。是否应用该变换,
# 应根据数据集原生使用的是 ``absolute`` 动作还是 ``delta`` 动作来决定。
# LIBERO 已经使用 delta 形式表示动作,但存在一些较早的 Pi0 Checkpoint,
# 其训练时使用这一额外的 delta 变换。
if self.extra_delta_transform:
delta_action_mask = _transforms.make_bool_mask(6, -1)
data_transforms = data_transforms.push(
inputs=[_transforms.DeltaActions(delta_action_mask)],
outputs=[_transforms.AbsoluteActions(delta_action_mask)],
)
# 模型变换包括对 Prompt 和动作目标进行分词等操作。
# 若使用自定义数据集,此处无需修改。
model_transforms = ModelTransformFactory()(model_config)
# 返回训练和推理所需的所有数据变换。此处无需修改。
return dataclasses.replace(
self.create_base_config(assets_dirs, model_config),
repack_transforms=repack_transform,
data_transforms=data_transforms,
model_transforms=model_transforms,
)make_bool_mask(*dims: int) -> tuple[bool, ...]
为给定的维度创建布尔掩码。
示例:
make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
make_bool_mask(2, 0, 2) == (True, True, True, True)参数:
- dims:为哪些维度生成掩码。
返回:
- 布尔值元组。
训练配置
定义微调超参数,数据配置,以及权重加载器。
# 如果代码中需要通过名称获取配置,那么使用 `get_config`。
_CONFIGS = [
TrainConfig(
name="pi05_libero",
model=pi0_config.Pi0Config(pi05=True, action_horizon=10, discrete_state_input=False),
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
base_config=DataConfig(prompt_from_task=True),
extra_delta_transform=False,
),
batch_size=256,
lr_schedule=_optimizer.CosineDecaySchedule(
warmup_steps=10_000,
peak_lr=5e-5,
decay_steps=1_000_000,
decay_lr=5e-5,
),
optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),
ema_decay=0.999,
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
pytorch_weight_path="/path/to/your/pytorch_weight_path",
num_train_steps=30_000,
),
]Pi0Config
pi05: bool
Pi05 与 Pi0 有两点不同:
- 状态输入是离散语言 Token 的一部分,而非作为
suffix一部分的连续输入。- 在 Pi05 里状态被离散化(256 个 bin)后拼进 Prompt 字符串里,与文本一起经过 PaliGemma Tokenizer,所以是"语言 Token 的一部分"。
suffix是 Pi0 架构里相对prefix(视觉 + 文本 Token)的对应术语,指动作专家处理的那段输入序列。
- 动作专家使用
adaRMSNorm注入 Flow Matching 的时间步。
discrete_state_input: bool
discrete_state_input 表示是否将 state 离散化、当作语言 Token 拼进 Prompt 里。它默认跟随 pi05 同步开关。
action_dim: int
动作空间维度。
action_horizon: int
动作序列长度。
pytorch_compile_mode: str | None = "max-autotune"
目前只影响 PyTorch 版推理采样。pytorch_compile_mode 允许 4 种 torch.compile 模式:
default
reduce-overhead
max-autotune
max-autotune-no-cudagraphs
None,表示不对sample_actions调用torch.compile
数据配置
repo_id: str:LeRobot Repo ID
assets: AssetsConfig:
决定 assets(比如 norm stats)的位置,这些 assets 用于设置数据流水线。
这些 assets 将被复制到 checkpoint 里的 assets/asset_id 目录下。
这个配置可以用于从另一个 checkpoint(比如基础模型 checkpoint)或其他集中存放的位置加载 assets。
比如在微调时,如果想从基础模型 checkpoint 加载 Trossen 机器人的归一化统计,可以这样写:
AssetsConfig(
assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
asset_id="trossen",
)base_config: tyro.conf.Suppress[DataConfig | None:prompt_from_task: bool:True 表示使用 LeRobot 数据集任务定义 Prompt。
action_sequence_keys: Sequence[str]:数据加载器使用这些 Key 名称生成 Action 序列。
这个序列的长度由模型配置里的
action_horizon字段决定。如果 LeRobot 数据集使用不同的 Key 来表示 Action,那么需要调整这里。
use_quantile_norm:如果为 True,则使用分位数归一化;否则使用普通的 z-score 归一化。非 Pi0 模型为 True。
weight_loader
权重加载器可以在模型初始化后,选择性地从磁盘加载权重;加载的权重也可以是不完整的、只包含一部分参数。默认值是 NoOpWeightLoader()。
pytorch_weight_path
可选的 PyTorch checkpoint 路径,用于加载模型权重。
计算归一化统计
首先,更新依赖包的版本:
uv pip install "datasets==3.6.0" "huggingface-hub>=0.30.0,<1.0"在运行训练前需要为训练数据计算归一化统计,使用训练配置名称运行下面的脚本:
uv run scripts/compute_norm_stats.py --config-name pi05_liberoPyTorch 支持
openpi 现在除原始的 JAX 版本外,也提供 π₀ 和 π₀.₅ 模型的 PyTorch 实现。PyTorch 实现已经在 LIBERO benchmark 上验证过,包括推理和微调。目前还有一些功能暂不支持,但未来可能变化:
- π₀-FAST 模型
- 混合精度训练
- FSDP 训练
- LoRA 训练
- EMA 权重训练
设置
- 确保已安装最新版本的所有依赖:
uv sync
- 确认已经安装
transformers 4.53.2:uv pip show transformers
- 应用
transformers库补丁:cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/
这将用必要的模型修改覆盖 transformers 库中的若干文件,包括:
- 支持 AdaRMS
- 正确控制激活值精度
- 允许使用 KV Cache,但不更新它
警告:使用 uv 默认的链接模式(hardlink)时,这将永久影响 uv 缓存中的 transformers 库。也就是说,即使重新安装 transformers,这些修改仍被保留,甚至可能传播到其他使用 transformers 的项目。若要完全撤销这个操作,必须运行 uv cache clean transformers。
将 JAX 模型转换为 PyTorch
要把 JAX 模型 checkpoint 转换为 PyTorch 格式:
uv run examples/convert_jax_model_to_pytorch.py \
--checkpoint_dir /path/to/jax/checkpoint \
--config_name <config name> \
--output_path /path/to/converted/pytorch/checkpoint使用 PyTorch 进行推理
PyTorch 实现使用和 JAX 版本相同的 API,只需将 checkpoint 路径改成转换后的 PyTorch 模型路径:
from openpi.training import config as _config
from openpi.policies import policy_config
from openpi.shared import download
config = _config.get_config("pi05_droid")
checkpoint_dir = "/path/to/converted/pytorch/checkpoint"
# Create a trained policy (automatically detects PyTorch format)
policy = policy_config.create_trained_policy(config, checkpoint_dir)
# Run inference (same API as JAX)
action_chunk = policy.infer(example)["actions"]使用 PyTorch 微调
安装依赖库:
sudo apt-get update
sudo apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev将 JAX 基础模型转换为 PyTorch 格式
uv run examples/convert_jax_model_to_pytorch.py --checkpoint_dir /root/.cache/openpi/openpi-assets/checkpoints/pi05_base --config_name pi05_XXX --output_path pi05_base_pytorch指定 PyTorch 模型路径
在配置中使用 pytorch_weight_path 指定转换后的 PyTorch 模型路径。
使用以下任一模式启动训练
# Single GPU training:
uv run scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
# Example:
uv run scripts/train_pytorch.py debug --exp_name pytorch_test
uv run scripts/train_pytorch.py debug --exp_name pytorch_test --resume # Resume from latest checkpoint
# Multi-GPU training (single node):
uv run torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
# Example:
uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
# Multi-Node Training:
uv run torchrun \
--nnodes=<num_nodes> \
--nproc_per_node=<gpus_per_node> \
--node_rank=<rank_of_node> \
--master_addr=<master_ip> \
--master_port=<port> \
scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>