Pi0.5 系列 3 - 推理

Pi0.5 系列 3 - 推理

Checkpoint 目录结构

checkpoints/
└── pi05_XXX
    └── pi05_XXX_test_1
        └── 19999
            ├── assets
            │   └── YYY
            │       └── norm_stats.json
            ├── metadata.pt
            ├── model.safetensors
            └── optimizer.pt

推理

from openpi.training import config as _config
from openpi.policies import policy_config, XXX_policy


config = _config.get_config("pi05_XXX")
checkpoint_dir = "/root/openpi/checkpoints/pi05_XXX/pi05_XXX_test_1/19999"

# Create a trained policy (automatically detects PyTorch format)
policy = policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference (same API as JAX)
action_chunk = policy.infer(XXX_policy.make_XXX_example())["actions"]

print(action_chunk)

示例输入:

def make_XXX_example() -> dict:
    """为 XXX 策略创建随机输入示例。"""
    state = np.random.rand(16)
    # 7 和 15 分别是右手和左手
    state[[7, 15]] = np.random.randint(2, size=2)

    return {
        "observation.state": state,
        # HWC
        "observation.images.front": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
        "observation.images.left": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
        "observation.images.right": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
        "prompt": "do something",
    }

避免在线下载

在 openpi 中的大多数地方 paligemma_tokenizer.model 的地址都是硬编码的,比如:

class PaligemmaTokenizer:
    def __init__(self, max_len: int = 48):
        self._max_len = max_len

        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())

默认的缓存路径在 ~/.cache/openpi,但可以通过环境变量 OPENPI_DATA_HOME 指定缓存路径。


开启 Torch Compile

import importlib
import os
import statistics
import time


def setup_torch_compile_cache():
    # 这些环境变量需要在导入 openpi.policy_config 前设置,因为它间接导入 torch。
    _COMPILE_CACHE_DIR = "/root/openpi/.torch_compile_cache"
    # 指定 PyTorch Inductor 的编译产物缓存目录,用于复用 torch.compile 生成的代码。
    os.environ.setdefault(
        "TORCHINDUCTOR_CACHE_DIR",
        f"{_COMPILE_CACHE_DIR}/inductor"
    )
    # 指定 Triton kernel 缓存目录;Inductor 生成的 GPU kernel 将经过 Triton 编译。
    os.environ.setdefault("TRITON_CACHE_DIR", f"{_COMPILE_CACHE_DIR}/triton")
    # 开启 FX Graph 级别缓存,帮助后续进程复用 torch.compile 捕获和编译后的图。
    os.environ.setdefault("TORCHINDUCTOR_FX_GRAPH_CACHE", "1")
    # 打印重新编译和 graph break 日志,用于观察 torch.compile 是否稳定命中同一张图。
    os.environ.setdefault("TORCH_LOGS", "recompiles,graph_breaks")
    # 输出更详细的 Dynamo 日志,和 TORCH_LOGS 配合排查编译过程。
    os.environ.setdefault("TORCHDYNAMO_VERBOSE", "1")
    # 生成 torch_compile_debug 调试目录,里面包含编译图和 Inductor 生成代码等信息。
    os.environ.setdefault("TORCH_COMPILE_DEBUG", "1")


setup_torch_compile_cache()

# 在缓存环境变量设置完成后再动态导入 openpi,避免 autopep8 将 import 移到文件顶部。
_config = importlib.import_module("openpi.training.config")
policy_config = importlib.import_module("openpi.policies.policy_config")
XXX_policy = importlib.import_module("openpi.policies.XXX_policy")


config = _config.get_config("pi05_XXX")
checkpoint_dir = "/root/openpi/checkpoints/pi05_XXX/pi05_XXX_test_1/19999"

# 创建训练好的 Policy,自动检测是否为 PyTorch 格式。
policy = policy_config.create_trained_policy(config, checkpoint_dir)

example = XXX_policy.make_XXX_example()

# 用真实推理相同的输入形状和 dtype 预热 torch.compile。
for i in range(2):
    print(f"开始 warmup {i + 1}...")
    start = time.perf_counter()
    result = policy.infer(example)
    print(
        f"warmup {i + 1}: {time.perf_counter() - start:.3f}s,"
        f"model={result['policy_timing']['infer_ms']:.1f}ms"
    )

print("开始正式推理...")
# 正式推理,API 与 JAX 版本一致。循环统计总耗时和模型段耗时。
infer_times = []
model_times = []
for _ in range(100):
    start = time.perf_counter()
    result = policy.infer(example)
    infer_times.append((time.perf_counter() - start) * 1000)
    model_times.append(result["policy_timing"]["infer_ms"])

print(
    "infer total ms: "
    f"min={min(infer_times):.1f}, max={max(infer_times):.1f}, mean={statistics.mean(infer_times):.1f}"
)
print(
    "infer model ms: "
    f"min={min(model_times):.1f}, max={max(model_times):.1f}, mean={statistics.mean(model_times):.1f}"
)