LeRobot to RLDS Conversion
1. Core Concepts
In robot learning (especially imitation learning, reinforcement learning, and real-world robot datasets), the following concepts are central to building, using, and evaluating datasets.
1.1. Episode (Rollout / Trajectory)
Definition
An Episode is the complete process in which a robot starts from some initial state, executes a sequence of actions, and stops when a termination condition is reached. The termination condition is usually task success, failure, timeout, or human intervention.
Characteristics
- In a dataset, an Episode is typically a time series containing sequences of observations, actions, rewards (in reinforcement learning), and so on.
- For example, in a door-opening task, one Episode spans from the arm approaching the handle to the door being fully open.
Typical structure
Each Episode is usually a dict or list containing (see the sketch after this list):
- observations: sequences of images, proprioception (joint angles, end-effector pose), etc.
- actions: sequences of joint positions, velocities, force commands, etc.
- rewards (optional): the reward at each time step
- terminated/truncated: whether the episode ended in success or by timeout
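A minimal sketch of this layout in Python (all field names and shapes below are illustrative, not a fixed schema):

import numpy as np

# One hypothetical 200-step episode; shapes are invented for illustration.
episode = {
    "observations": {
        "image": np.zeros((200, 224, 224, 3), dtype=np.uint8),  # camera frames
        "state": np.zeros((200, 7), dtype=np.float32),          # joint angles
    },
    "actions": np.zeros((200, 7), dtype=np.float32),            # joint targets
    "rewards": np.zeros((200,), dtype=np.float32),              # optional, for RL
    "terminated": False,  # ended by success/failure
    "truncated": False,   # ended by timeout
}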
1.2. Step (Time Step)
Definition
A Step is the smallest time unit within an Episode, corresponding to one "observation → action" interaction.
Relationship
- An Episode consists of multiple Steps.
- For example, a 200-frame door-opening trajectory has 200 Steps (usually sampled at the control frequency, e.g. 10 Hz or 30 Hz).
Data granularity
During training, models usually learn a per-Step mapping (e.g. image to action), but sequence models sometimes consume Episode-level trajectory chunks (e.g. Diffusion Policy or Transformer-based policies), as the sketch below contrasts.
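A small sketch contrasting the two granularities (episode length, horizon, and shapes are arbitrary assumptions):

import numpy as np

# Hypothetical episode: 200 aligned observation/action pairs.
obs = np.random.rand(200, 7).astype(np.float32)
act = np.random.rand(200, 7).astype(np.float32)

# Step-level sample: one (observation -> action) pair, as used by per-step policies.
t = 42
step_sample = (obs[t], act[t])

# Chunk-level sample: a window of consecutive steps, as consumed by trajectory
# models such as Diffusion Policy (horizon chosen arbitrarily here).
horizon = 16
chunk_sample = (obs[t : t + horizon], act[t : t + horizon])
print(step_sample[0].shape, chunk_sample[0].shape)  # (7,) (16, 7)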
1.3. Dataset Mix
Definition
Combining datasets from multiple sources, types, or tasks during training.
Common forms
- Multi-task mix: combine data from different tasks (grasping, insertion, switch flipping, ...) to train a general-purpose model.
- Multi-source mix: combine simulation and real-robot data (Sim + Real), or data collected by different labs (e.g. Open X-Embodiment aggregates dozens of robot datasets).
- Multi-modal mix: within one task, include multiple camera viewpoints, force sensing, language instructions, etc.
Why mix
- Improves generalization
- Reduces overfitting
- Enables general models across tasks and embodiments (different arms); a weighted-sampling sketch follows below
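To make the idea concrete, the sketch below samples mixed batches from two stand-in datasets with fixed weights; the dataset names, sizes, and weights are all invented for illustration:

import random

# Stand-in "datasets"; in practice these would be LeRobotDataset / tf.data pipelines.
sim_ds = [f"sim_step_{i}" for i in range(10_000)]  # plentiful simulation data
real_ds = [f"real_step_{i}" for i in range(500)]   # scarce real-robot data

datasets = [sim_ds, real_ds]
weights = [0.7, 0.3]  # oversamples real data relative to its raw share

def sample_batch(batch_size=8):
    """Pick a dataset by weight for each element, then sample uniformly within it."""
    return [
        random.choice(random.choices(datasets, weights=weights, k=1)[0])
        for _ in range(batch_size)
    ]

print(sample_batch())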
1.4. Split (Dataset Split)
Definition
How a dataset is divided into a training set (Train), a validation set (Val), and a test set (Test).
Special considerations in robot learning
- Split by Episode: never put some Steps of an Episode in the training set and others in the test set; that leaks data (the model has already seen part of the trajectory).
- Split by scene: a stricter scheme that assigns all Episodes from certain scenes (e.g. a particular table layout or object class) to the test set, to measure generalization.
- Split by task or object: evaluate on objects or tasks unseen during training, to assess adaptation to novel settings.
Common ratios
- Typical splits: 70% / 15% / 15% or 80% / 10% / 10%
- On small real-robot datasets, cross-validation or leave-one-out splits are sometimes used; an episode-level split sketch follows
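A minimal episode-level split, assuming 50 recorded episodes and a 70/15/15 ratio (both numbers are placeholders):

import random

episode_ids = list(range(50))  # hypothetical episode indices
random.seed(0)                 # fixed seed so the split is reproducible
random.shuffle(episode_ids)

# Split whole episodes, never individual steps, to avoid leakage.
n = len(episode_ids)
train_ids = episode_ids[: int(0.70 * n)]
val_ids = episode_ids[int(0.70 * n) : int(0.85 * n)]
test_ids = episode_ids[int(0.85 * n) :]
print(len(train_ids), len(val_ids), len(test_ids))  # 35 7 8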
When working with frameworks such as Open X-Embodiment, RLDS (Reinforcement Learning Datasets), or Hugging Face LeRobot, these concepts map directly onto the design of the data loaders (DataLoader) and dataset classes.
2. LeRobot to RLDS
import argparse
import logging
import os
from functools import partial
from pathlib import Path
import numpy as np
import tensorflow_datasets as tfds
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from tensorflow_datasets.core.file_adapters import FileFormat
from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
from tensorflow_datasets.rlds import rlds_base
os.environ["NO_GCE_CHECK"] = "true"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
tfds.core.utils.gcs_utils._is_gcs_disabled = True
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
def generate_config_from_features(features, encoding_format, **kwargs):
action_info = {
**{
"_".join(k.split(".")[2:]) or k.split(".")[-1]: tfds.features.Tensor(
shape=v["shape"], dtype=np.dtype(v["dtype"]), doc=v["names"]
)
for k, v in features.items()
if "action" in k # for compatibility with actions.action_key and action
},
}
    # A single action key is unwrapped to a bare tensor feature; multiple keys stay a dict.
    action_info = action_info if len(action_info) > 1 else action_info.popitem()[1]
return dict(
observation_info={
**{
k.split(".")[-1]: tfds.features.Image(
shape=v["shape"], dtype=np.uint8, encoding_format=encoding_format, doc=v["names"]
)
for k, v in features.items()
if "observation.image" in k and "depth" not in k
},
**{
k.split(".")[-1]: tfds.features.Tensor(shape=v["shape"][:-1], dtype=np.float32, doc=v["names"])
for k, v in features.items()
if "observation.image" in k and "depth" in k
},
**{
"_".join(k.split(".")[2:]) or k.split(".")[-1]: tfds.features.Tensor(
shape=v["shape"], dtype=np.dtype(v["dtype"]), doc=v["names"]
)
for k, v in features.items()
if "observation.state" in k # for compatibility with observation.states.state_key and observation.state
},
},
action_info=action_info,
step_metadata_info={
"language_instruction": tfds.features.Text(),
},
citation=kwargs.get("citation", ""),
homepage=kwargs.get("homepage", ""),
overall_description=kwargs.get("overall_description", ""),
description=kwargs.get("description", ""),
)
def parse_step(data_item):
observation_info = {
**{
# lerobot image is (C, H, W) and in range [0, 1]
k.split(".")[-1]: np.array(v * 255, dtype=np.uint8).transpose(1, 2, 0)
for k, v in data_item.items()
if "observation.image" in k and "depth" not in k
},
**{
# lerobot depth is (1, H, W) and in range [0, 1]
k.split(".")[-1]: v.float().squeeze()
for k, v in data_item.items()
if "observation.image" in k and "depth" in k
},
**{"_".join(k.split(".")[2:]) or k.split(".")[-1]: v for k, v in data_item.items() if "observation.state" in k},
}
action_info = {
**{"_".join(k.split(".")[2:]) or k.split(".")[-1]: v for k, v in data_item.items() if "action" in k},
}
    # Mirror the feature config: unwrap the dict when there is only one action key.
    action_info = action_info if len(action_info) > 1 else action_info.popitem()[1]
return observation_info, action_info, data_item["task"]
class DatasetBuilder(tfds.core.GeneratorBasedBuilder, skip_registration=True):
def __init__(self, raw_dir, name, dataset_config, enable_beam, *, file_format=None, **kwargs):
self.name = name
self.VERSION = kwargs["version"]
self.raw_dir = raw_dir
self.dataset_config = dataset_config
self.enable_beam = enable_beam
self.__module__ = "lerobot2rlds"
super().__init__(file_format=file_format, **kwargs)
def _info(self) -> tfds.core.DatasetInfo:
"""Returns the dataset metadata."""
return rlds_base.build_info(
rlds_base.DatasetConfig(
name=self.name,
**self.dataset_config,
),
self,
)
def _split_generators(self, dl_manager: tfds.download.DownloadManager):
"""Returns SplitGenerators."""
        # Nothing is downloaded here; drop the download dir that TFDS auto-creates.
        dl_manager._download_dir.rmtree(missing_ok=True)
return {
"train": self._generate_examples(),
}
def _generate_examples(self):
"""Yields examples."""
def _generate_examples_beam(episode_index, raw_dir):
episode = []
dataset = LeRobotDataset("", raw_dir, episodes=[episode_index])
logging.info(f"processing episode {episode_index}")
for data_item in dataset:
observation_info, action_info, language_instruction = parse_step(data_item)
episode.append(
{
"observation": observation_info,
"action": action_info,
"language_instruction": language_instruction,
"is_first": data_item["frame_index"].item() == 0,
"is_last": data_item["frame_index"].item()
== dataset.meta.episodes[episode_index]["length"] - 1,
"is_terminal": data_item["frame_index"].item()
== dataset.meta.episodes[episode_index]["length"] - 1,
}
)
return episode_index, {"steps": episode}
def _generate_examples_regular():
dataset = LeRobotDataset("", self.raw_dir)
episode = []
current_episode_index = 0
for data_item in dataset:
if data_item["episode_index"] != current_episode_index:
episode[-1]["is_last"] = True
episode[-1]["is_terminal"] = True
yield f"{current_episode_index}", {"steps": episode}
current_episode_index = data_item["episode_index"]
episode.clear()
observation_info, action_info, language_instruction = parse_step(data_item)
episode.append(
{
"observation": observation_info,
"action": action_info,
"language_instruction": language_instruction,
"is_first": data_item["frame_index"].item() == 0,
"is_last": False,
"is_terminal": False,
}
)
episode[-1]["is_last"] = True
episode[-1]["is_terminal"] = True
yield f"{current_episode_index}", {"steps": episode}
if self.enable_beam:
metadata = LeRobotDatasetMetadata("", self.raw_dir)
return beam.Create(list(metadata.episodes.keys())) | beam.Map(
partial(_generate_examples_beam, raw_dir=self.raw_dir)
)
else:
# NOTE: we should return a generator, not yield
return _generate_examples_regular()
def main(src_dir, output_dir, task_name, version, encoding_format, enable_beam, **kwargs):
raw_dataset_meta = LeRobotDatasetMetadata("", root=src_dir)
dataset_config = generate_config_from_features(raw_dataset_meta.features, encoding_format, **kwargs)
dataset_builder = DatasetBuilder(
raw_dir=src_dir,
name=task_name,
data_dir=output_dir,
version=version,
dataset_config=dataset_config,
enable_beam=enable_beam,
file_format=FileFormat.TFRECORD,
)
if enable_beam:
logging.warning("beam processing is enabled. Some episodes might be lost, a bug with apache beam.")
logging.warning("disable beam processing if your dataset is small or you want to save all episodes.")
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners import create_runner
if "threading" in kwargs["beam_run_mode"]:
logging.warning("multi_threading might have issues when sharding and saving.")
logging.warning("recommend using multi_processing instead.")
beam_options = PipelineOptions(
direct_running_mode=kwargs["beam_run_mode"],
direct_num_workers=kwargs["beam_num_workers"],
)
beam_runner = create_runner("DirectRunner")
else:
beam_options = None
beam_runner = None
dataset_builder.download_and_prepare(
download_config=tfds.download.DownloadConfig(
try_download_gcs=False,
verify_ssl=False,
beam_options=beam_options,
beam_runner=beam_runner,
),
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--src-dir", type=Path, help="Path to the local lerobot dataset.")
parser.add_argument("--output-dir", type=Path, help="Path to the output directory.")
parser.add_argument("--task-name", type=str, help="Task name.")
parser.add_argument("--enable-beam", action="store_true", help="Enable beam processing.")
parser.add_argument("--beam-run-mode", choices=["multi_threading", "multi_processing"], default="multi_processing")
parser.add_argument("--beam-num-workers", type=int, default=5)
parser.add_argument("--encoding-format", type=str, choices=["jpeg", "png"], default="jpeg")
parser.add_argument("--version", type=str, help="x.y.z", default="0.1.0")
parser.add_argument("--citation", type=str, help="Citation.", default="")
parser.add_argument("--homepage", type=str, help="Homepage.", default="")
parser.add_argument("--overall-description", type=str, help="Overall description.", default="")
parser.add_argument("--description", type=str, help="Description.", default="")
args = parser.parse_args()
    main(**vars(args))

Run example:
python lerobot2rlds.py --src-dir /Users/timchow/Desktop/libero --output-dir libero_rlds_test --task-name libero_rlds_test

The converted TFRecords land under <output-dir>/<task-name>/<version>/ (here: libero_rlds_test/libero_rlds_test/0.1.0).

3. Reading the RLDS Dataset
import tensorflow_datasets as tfds
# Point to the directory containing the dataset info files
data_dir = "libero_rlds_test/libero_rlds_test/0.1.0"
builder = tfds.builder_from_directory(data_dir)
# Load the data
ds = builder.as_dataset(split="all")
num_episodes = builder.info.splits["all"].num_examples
print("total episodes:", num_episodes)
# Helper: fetch a key from a step dict and convert EagerTensors to numpy.
def safe_numpy(step_dict, key, default=None):
    value = step_dict.get(key, default)
    return value.numpy() if hasattr(value, "numpy") else value

# Step counter
total_steps = 0
# Iterate over all episodes
for episode_idx, episode in enumerate(ds):
    # Some datasets lack episode_metadata; fall back to the index as the identifier
episode_metadata = episode.get('episode_metadata')
if isinstance(episode_metadata, dict) and 'episode_id' in episode_metadata:
episode_id = episode_metadata['episode_id'].numpy()
else:
episode_id = f"episode_{episode_idx + 1}"
print(f"\n{'='*60}")
print(f"Episode {episode_idx + 1}: ID = {episode_id}")
print(f"{'='*60}")
    # Initialize the step counter for the current episode
step_count = 0
    # Iterate over all steps in this episode
    for step in episode["steps"]:
        step_count += 1
        total_steps += 1
        # Extract the key fields of this step.
        # Note: the exact structure depends on the dataset; these are common fields
        observation = step["observation"]  # observation data (usually a dict)
        action = safe_numpy(step, "action")  # action data
        reward = safe_numpy(step, "reward")  # reward (absent in some datasets)
        is_first = safe_numpy(step, "is_first")  # first-step flag
        is_last = safe_numpy(step, "is_last")  # last-step flag
        # Print step info (only the first Step, to avoid flooding the output)
        if step_count <= 1:
            if isinstance(observation, dict):
                observation_desc = {
                    k: v.numpy().shape if hasattr(v, "numpy") and hasattr(v.numpy(), "shape") else type(v).__name__
                    for k, v in observation.items()
                }
            else:
                observation_desc = observation.shape if hasattr(observation, "shape") else type(observation)
            print(f" Step {step_count}:")
            print(f" Observation: {observation_desc}")
            print(f" Action shape: {action.shape if hasattr(action, 'shape') else type(action)}")
            print(f" Reward: {reward}")
            print(f" is_first: {is_first}, is_last: {is_last}")
print(f"Total steps in this episode: {step_count}")
# Print the summary
print(f"\n{'='*60}")
print("SUMMARY")
print(f"{'='*60}")
print(f"Total episodes: {num_episodes}")
print(f"Total steps: {total_steps}")