Pi0.5 系列 3 - 推理
Pi0.5 系列 3 - 推理
Checkpoint 目录结构
checkpoints/
└── pi05_XXX
└── pi05_XXX_test_1
└── 19999
├── assets
│ └── YYY
│ └── norm_stats.json
├── metadata.pt
├── model.safetensors
└── optimizer.pt推理
from openpi.training import config as _config
from openpi.policies import policy_config, XXX_policy
config = _config.get_config("pi05_XXX")
checkpoint_dir = "/root/openpi/checkpoints/pi05_XXX/pi05_XXX_test_1/19999"
# Create a trained policy (automatically detects PyTorch format)
policy = policy_config.create_trained_policy(config, checkpoint_dir)
# Run inference (same API as JAX)
action_chunk = policy.infer(XXX_policy.make_XXX_example())["actions"]
print(action_chunk)示例输入:
def make_XXX_example() -> dict:
"""为 XXX 策略创建随机输入示例。"""
state = np.random.rand(16)
# 7 和 15 分别是右手和左手
state[[7, 15]] = np.random.randint(2, size=2)
return {
"observation.state": state,
# HWC
"observation.images.front": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
"observation.images.left": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
"observation.images.right": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
"prompt": "do something",
}避免在线下载
在 openpi 中的大多数地方 paligemma_tokenizer.model 的地址都是硬编码的,比如:
class PaligemmaTokenizer:
def __init__(self, max_len: int = 48):
self._max_len = max_len
path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
with path.open("rb") as f:
self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())默认的缓存路径在 ~/.cache/openpi,但可以通过环境变量 OPENPI_DATA_HOME 指定缓存路径。
开启 Torch Compile
import importlib
import os
import statistics
import time
def setup_torch_compile_cache():
# 这些环境变量需要在导入 openpi.policy_config 前设置,因为它间接导入 torch。
_COMPILE_CACHE_DIR = "/root/openpi/.torch_compile_cache"
# 指定 PyTorch Inductor 的编译产物缓存目录,用于复用 torch.compile 生成的代码。
os.environ.setdefault(
"TORCHINDUCTOR_CACHE_DIR",
f"{_COMPILE_CACHE_DIR}/inductor"
)
# 指定 Triton kernel 缓存目录;Inductor 生成的 GPU kernel 将经过 Triton 编译。
os.environ.setdefault("TRITON_CACHE_DIR", f"{_COMPILE_CACHE_DIR}/triton")
# 开启 FX Graph 级别缓存,帮助后续进程复用 torch.compile 捕获和编译后的图。
os.environ.setdefault("TORCHINDUCTOR_FX_GRAPH_CACHE", "1")
# 打印重新编译和 graph break 日志,用于观察 torch.compile 是否稳定命中同一张图。
os.environ.setdefault("TORCH_LOGS", "recompiles,graph_breaks")
# 输出更详细的 Dynamo 日志,和 TORCH_LOGS 配合排查编译过程。
os.environ.setdefault("TORCHDYNAMO_VERBOSE", "1")
# 生成 torch_compile_debug 调试目录,里面包含编译图和 Inductor 生成代码等信息。
os.environ.setdefault("TORCH_COMPILE_DEBUG", "1")
setup_torch_compile_cache()
# 在缓存环境变量设置完成后再动态导入 openpi,避免 autopep8 将 import 移到文件顶部。
_config = importlib.import_module("openpi.training.config")
policy_config = importlib.import_module("openpi.policies.policy_config")
XXX_policy = importlib.import_module("openpi.policies.XXX_policy")
config = _config.get_config("pi05_XXX")
checkpoint_dir = "/root/openpi/checkpoints/pi05_XXX/pi05_XXX_test_1/19999"
# 创建训练好的 Policy,自动检测是否为 PyTorch 格式。
policy = policy_config.create_trained_policy(config, checkpoint_dir)
example = XXX_policy.make_XXX_example()
# 用真实推理相同的输入形状和 dtype 预热 torch.compile。
for i in range(2):
print(f"开始 warmup {i + 1}...")
start = time.perf_counter()
result = policy.infer(example)
print(
f"warmup {i + 1}: {time.perf_counter() - start:.3f}s,"
f"model={result['policy_timing']['infer_ms']:.1f}ms"
)
print("开始正式推理...")
# 正式推理,API 与 JAX 版本一致。循环统计总耗时和模型段耗时。
infer_times = []
model_times = []
for _ in range(100):
start = time.perf_counter()
result = policy.infer(example)
infer_times.append((time.perf_counter() - start) * 1000)
model_times.append(result["policy_timing"]["infer_ms"])
print(
"infer total ms: "
f"min={min(infer_times):.1f}, max={max(infer_times):.1f}, mean={statistics.mean(infer_times):.1f}"
)
print(
"infer model ms: "
f"min={min(model_times):.1f}, max={max(model_times):.1f}, mean={statistics.mean(model_times):.1f}"
)