LeRobot and RLDS Conversion

1. Core Concepts

In robot learning (in particular imitation learning, reinforcement learning, and real-world robot datasets), the following concepts are central to building, using, and evaluating datasets.

1.1. Episode (Trajectory)

Definition

An Episode is the complete process in which a robot starts from some initial state and executes a sequence of actions until a termination condition is reached. Termination is typically task success, failure, a timeout, or human intervention.

Characteristics

  • In a dataset, an Episode is typically a time series containing sequences of observations, actions, rewards (for reinforcement learning), and so on.
  • For a door-opening task, for example, the whole process from the arm approaching the handle to the door being fully open is one Episode.

Common Structure

Each Episode is usually a dict or list containing (see the sketch after this list):

  • observations: sequences of images, proprioception (joint angles, end-effector pose), etc.
  • actions: sequences of actions such as joint positions, velocities, or force commands
  • rewards (optional): the reward at each timestep
  • terminated / truncated: whether the episode reached a terminal state (success/failure) or was cut off (e.g., by a timeout)
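
A minimal sketch of this layout in Python; the field names are common conventions used for illustration, not a fixed standard:

# Illustrative episode layout; real datasets name and nest fields differently.
episode = {
    "observations": [...],   # one observation per step (images, joint angles, ...)
    "actions": [...],        # one action per step
    "rewards": [...],        # optional, mainly for reinforcement learning
    "terminated": True,      # ended in a terminal state (success / failure)
    "truncated": False,      # cut off early (e.g., timeout)
}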

1.2. Step (Timestep)

Definition

A Step is the smallest unit of time within an Episode, corresponding to one "observation → action" interaction.

Relationship

  • An Episode consists of multiple Steps.
  • For example, a 200-frame door-opening trajectory contains 200 Steps (usually sampled at the control frequency, e.g., 10 Hz or 30 Hz).

Data Granularity

At training time, models usually learn a mapping at the Step level (e.g., from images to actions), but sequence models sometimes operate at the Episode level (e.g., Diffusion Policy and Transformer-based policies consume trajectory chunks).
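
A sketch of the two granularities, reusing the illustrative episode dict above (the index and horizon are arbitrary choices):

# Per-step supervision: one (observation, action) pair per training sample.
t = 0
x, y = episode["observations"][t], episode["actions"][t]

# Sequence modeling: policies like Diffusion Policy consume trajectory chunks.
H = 16  # horizon length, chosen arbitrarily here
obs_chunk = episode["observations"][t : t + H]
act_chunk = episode["actions"][t : t + H]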

1.3. Dataset Mix

Definition

Combining datasets from multiple sources, types, or tasks during training.

Common Forms

  • Multi-task mix: combine data from different tasks (grasping, insertion, switch flipping, ...) to train a general-purpose model.
  • Multi-source mix: combine simulation and real-robot data (Sim + Real), or data collected by different labs (e.g., Open X-Embodiment aggregates dozens of robot datasets).
  • Multi-modal mix: a single task includes multiple camera views, force sensing, language instructions, and so on.

Why Mix

  • Improves generalization
  • Reduces overfitting
  • Enables general-purpose models that span tasks and embodiments (different robot arms); a minimal mixing sketch follows this list
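
A minimal mixing sketch in plain Python; the task names, weights, and stand-in episode lists are all made up for illustration:

import random

# Stand-in per-task datasets: each one is just a list of episode dicts here.
datasets = {
    "grasp": [{"task": "grasp", "steps": []}],
    "insertion": [{"task": "insertion", "steps": []}],
    "switch": [{"task": "switch", "steps": []}],
}
weights = {"grasp": 0.5, "insertion": 0.3, "switch": 0.2}  # sampling mix

def sample_episode():
    # Draw a task according to the mix weights, then an episode from that task.
    name = random.choices(list(weights), weights=list(weights.values()))[0]
    return random.choice(datasets[name])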

1.4. Split (Dataset Partitioning)

Definition

The way a dataset is divided into a training set (Train), a validation set (Val), and a test set (Test).

Special Considerations in Robot Learning

  • Split by Episode: never put some Steps of an Episode in the training set and the rest in the test set; doing so leaks data (the model has already seen part of that trajectory).
  • Split by scene: a stricter scheme that assigns all Episodes from certain scenes (e.g., a particular table layout or object class) to the test set, to probe generalization.
  • Split by task or object: evaluate on objects or tasks unseen during training, to assess adaptation to novel settings.

Common Ratios

  • Typical splits: 70% / 15% / 15% or 80% / 10% / 10% (an episode-level version is sketched below)
  • On small real-robot datasets, cross-validation or leave-one-out schemes are sometimes used
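
A minimal sketch of an 80% / 10% / 10% split done at the Episode level (the episode IDs are stand-ins); splitting whole episodes rather than individual steps is what prevents the leakage described above:

import random

episode_ids = list(range(100))  # stand-in episode indices
random.seed(0)
random.shuffle(episode_ids)

n = len(episode_ids)
train_ids = episode_ids[: int(0.8 * n)]
val_ids = episode_ids[int(0.8 * n) : int(0.9 * n)]
test_ids = episode_ids[int(0.9 * n) :]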

When working with frameworks such as Open X-Embodiment, RLDS (Reinforcement Learning Datasets), or Hugging Face LeRobot, these concepts surface directly in the design of the data loaders (DataLoader) and dataset classes, as shown below.
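
For instance, LeRobot can load a chosen subset of episodes, and TFDS expresses splits with slicing syntax (the paths here are placeholders; the LeRobotDataset call mirrors the one used in the conversion script below):

from lerobot.datasets.lerobot_dataset import LeRobotDataset
import tensorflow_datasets as tfds

# LeRobot: load only selected episodes from a local dataset.
ds = LeRobotDataset("", "/path/to/lerobot_dataset", episodes=[0, 1, 2])

# TFDS / RLDS: split slicing selects fractions of a split.
builder = tfds.builder_from_directory("/path/to/rlds/dataset/0.1.0")
train_ds = builder.as_dataset(split="train[:80%]")
val_ds = builder.as_dataset(split="train[80%:]")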



2. LeRobot to RLDS

The code below comes from https://github.com/Tavish9/any4lerobot/blob/main/lerobot2rlds/lerobot2rlds.py:

import argparse
import logging
import os
from functools import partial
from pathlib import Path

import numpy as np
import tensorflow_datasets as tfds
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from tensorflow_datasets.core.file_adapters import FileFormat
from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
from tensorflow_datasets.rlds import rlds_base

os.environ["NO_GCE_CHECK"] = "true"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
tfds.core.utils.gcs_utils._is_gcs_disabled = True

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


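# Map LeRobot feature specs onto RLDS DatasetConfig fields: RGB image features
# become tfds Image features, depth and state features become Tensors, and the
# language instruction is carried as per-step metadata.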
def generate_config_from_features(features, encoding_format, **kwargs):
    action_info = {
        **{
            "_".join(k.split(".")[2:]) or k.split(".")[-1]: tfds.features.Tensor(
                shape=v["shape"], dtype=np.dtype(v["dtype"]), doc=v["names"]
            )
            for k, v in features.items()
            if "action" in k  # for compatibility with actions.action_key and action
        },
    }
    action_info = action_info if len(action_info) > 1 else action_info.popitem()[1]
    return dict(
        observation_info={
            **{
                k.split(".")[-1]: tfds.features.Image(
                    shape=v["shape"], dtype=np.uint8, encoding_format=encoding_format, doc=v["names"]
                )
                for k, v in features.items()
                if "observation.image" in k and "depth" not in k
            },
            **{
                k.split(".")[-1]: tfds.features.Tensor(shape=v["shape"][:-1], dtype=np.float32, doc=v["names"])
                for k, v in features.items()
                if "observation.image" in k and "depth" in k
            },
            **{
                "_".join(k.split(".")[2:]) or k.split(".")[-1]: tfds.features.Tensor(
                    shape=v["shape"], dtype=np.dtype(v["dtype"]), doc=v["names"]
                )
                for k, v in features.items()
                if "observation.state" in k  # for compatibility with observation.states.state_key and observation.state
            },
        },
        action_info=action_info,
        step_metadata_info={
            "language_instruction": tfds.features.Text(),
        },
        citation=kwargs.get("citation", ""),
        homepage=kwargs.get("homepage", ""),
        overall_description=kwargs.get("overall_description", ""),
        description=kwargs.get("description", ""),
    )


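# Convert one LeRobot frame into (observation, action, language_instruction).
# LeRobot stores images as (C, H, W) floats in [0, 1]; RLDS expects (H, W, C) uint8.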
def parse_step(data_item):
    observation_info = {
        **{
            # lerobot image is (C, H, W) and in range [0, 1]
            k.split(".")[-1]: np.array(v * 255, dtype=np.uint8).transpose(1, 2, 0)
            for k, v in data_item.items()
            if "observation.image" in k and "depth" not in k
        },
        **{
            # lerobot depth is (1, H, W) and in range [0, 1]
            k.split(".")[-1]: v.float().squeeze()
            for k, v in data_item.items()
            if "observation.image" in k and "depth" in k
        },
        **{"_".join(k.split(".")[2:]) or k.split(".")[-1]: v for k, v in data_item.items() if "observation.state" in k},
    }
    action_info = {
        **{"_".join(k.split(".")[2:]) or k.split(".")[-1]: v for k, v in data_item.items() if "action" in k},
    }
    action_info = action_info if len(action_info) > 1 else action_info.popitem()[1]

    return observation_info, action_info, data_item["task"]


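# A TFDS GeneratorBasedBuilder that reads episodes from a local LeRobot dataset
# and writes them back out as RLDS-format TFRecord shards.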
class DatasetBuilder(tfds.core.GeneratorBasedBuilder, skip_registration=True):
    def __init__(self, raw_dir, name, dataset_config, enable_beam, *, file_format=None, **kwargs):
        self.name = name
        self.VERSION = kwargs["version"]
        self.raw_dir = raw_dir
        self.dataset_config = dataset_config
        self.enable_beam = enable_beam
        self.__module__ = "lerobot2rlds"
        super().__init__(file_format=file_format, **kwargs)

    def _info(self) -> tfds.core.DatasetInfo:
        """Returns the dataset metadata."""
        return rlds_base.build_info(
            rlds_base.DatasetConfig(
                name=self.name,
                **self.dataset_config,
            ),
            self,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Returns SplitGenerators."""
        dl_manager._download_dir.rmtree(missing_ok=True)
        return {
            "train": self._generate_examples(),
        }

    def _generate_examples(self):
        """Yields examples."""

        def _generate_examples_beam(episode_index, raw_dir):
            episode = []
            dataset = LeRobotDataset("", raw_dir, episodes=[episode_index])
            logging.info(f"processing episode {episode_index}")
            for data_item in dataset:
                observation_info, action_info, language_instruction = parse_step(data_item)
                episode.append(
                    {
                        "observation": observation_info,
                        "action": action_info,
                        "language_instruction": language_instruction,
                        "is_first": data_item["frame_index"].item() == 0,
                        "is_last": data_item["frame_index"].item()
                        == dataset.meta.episodes[episode_index]["length"] - 1,
                        "is_terminal": data_item["frame_index"].item()
                        == dataset.meta.episodes[episode_index]["length"] - 1,
                    }
                )
            return episode_index, {"steps": episode}

        def _generate_examples_regular():
            dataset = LeRobotDataset("", self.raw_dir)
            episode = []
            current_episode_index = 0
            for data_item in dataset:
                if data_item["episode_index"] != current_episode_index:
                    episode[-1]["is_last"] = True
                    episode[-1]["is_terminal"] = True
                    yield f"{current_episode_index}", {"steps": episode}
                    current_episode_index = data_item["episode_index"]
                    episode.clear()

                observation_info, action_info, language_instruction = parse_step(data_item)
                episode.append(
                    {
                        "observation": observation_info,
                        "action": action_info,
                        "language_instruction": language_instruction,
                        "is_first": data_item["frame_index"].item() == 0,
                        "is_last": False,
                        "is_terminal": False,
                    }
                )
            episode[-1]["is_last"] = True
            episode[-1]["is_terminal"] = True
            yield f"{current_episode_index}", {"steps": episode}

        if self.enable_beam:
            metadata = LeRobotDatasetMetadata("", self.raw_dir)
            return beam.Create(list(metadata.episodes.keys())) | beam.Map(
                partial(_generate_examples_beam, raw_dir=self.raw_dir)
            )
        else:
            # NOTE: we should return a generator, not yield
            return _generate_examples_regular()


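# Build the RLDS dataset config from the LeRobot metadata, then run the TFDS
# build; Apache Beam optionally parallelizes the per-episode processing.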
def main(src_dir, output_dir, task_name, version, encoding_format, enable_beam, **kwargs):
    raw_dataset_meta = LeRobotDatasetMetadata("", root=src_dir)

    dataset_config = generate_config_from_features(raw_dataset_meta.features, encoding_format, **kwargs)

    dataset_builder = DatasetBuilder(
        raw_dir=src_dir,
        name=task_name,
        data_dir=output_dir,
        version=version,
        dataset_config=dataset_config,
        enable_beam=enable_beam,
        file_format=FileFormat.TFRECORD,
    )

    if enable_beam:
        logging.warning("beam processing is enabled. Some episodes might be lost, a bug with apache beam.")
        logging.warning("disable beam processing if your dataset is small or you want to save all episodes.")
        from apache_beam.options.pipeline_options import PipelineOptions
        from apache_beam.runners import create_runner

        if "threading" in kwargs["beam_run_mode"]:
            logging.warning("multi_threading might have issues when sharding and saving.")
            logging.warning("recommend using multi_processing instead.")

        beam_options = PipelineOptions(
            direct_running_mode=kwargs["beam_run_mode"],
            direct_num_workers=kwargs["beam_num_workers"],
        )
        beam_runner = create_runner("DirectRunner")
    else:
        beam_options = None
        beam_runner = None

    dataset_builder.download_and_prepare(
        download_config=tfds.download.DownloadConfig(
            try_download_gcs=False,
            verify_ssl=False,
            beam_options=beam_options,
            beam_runner=beam_runner,
        ),
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--src-dir", type=Path, help="Path to the local lerobot dataset.")
    parser.add_argument("--output-dir", type=Path, help="Path to the output directory.")
    parser.add_argument("--task-name", type=str, help="Task name.")
    parser.add_argument("--enable-beam", action="store_true", help="Enable beam processing.")
    parser.add_argument("--beam-run-mode", choices=["multi_threading", "multi_processing"], default="multi_processing")
    parser.add_argument("--beam-num-workers", type=int, default=5)
    parser.add_argument("--encoding-format", type=str, choices=["jpeg", "png"], default="jpeg")
    parser.add_argument("--version", type=str, help="x.y.z", default="0.1.0")
    parser.add_argument("--citation", type=str, help="Citation.", default="")
    parser.add_argument("--homepage", type=str, help="Homepage.", default="")
    parser.add_argument("--overall-description", type=str, help="Overall description.", default="")
    parser.add_argument("--description", type=str, help="Description.", default="")
    args = parser.parse_args()

    main(**vars(args))

Example run:

python lerobot2rlds.py --src-dir /Users/timchow/Desktop/libero --output-dir libero_rlds_test --task-name libero_rlds_test
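
The builder writes the TFRecord shards and dataset metadata under {output_dir}/{task_name}/{version} (here libero_rlds_test/libero_rlds_test/0.1.0), which is the directory the reading example in the next section points at.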

3. Reading the RLDS Dataset

import tensorflow_datasets as tfds


# Point at the directory that contains the dataset info
data_dir = "libero_rlds_test/libero_rlds_test/0.1.0"
builder = tfds.builder_from_directory(data_dir)
# Load the data
ds = builder.as_dataset(split="all")
num_episodes = builder.info.splits["all"].num_examples
print("total episodes:", num_episodes)

# Global step counter
total_steps = 0


def safe_numpy(step_dict, key, default=None):
    # Return step_dict[key] as a numpy value; fall back to default if absent.
    value = step_dict.get(key, default)
    return value.numpy() if hasattr(value, "numpy") else value


# Iterate over all episodes
for episode_idx, episode in enumerate(ds):
    # Some datasets have no episode_metadata; use the index as a fallback ID
    episode_metadata = episode.get('episode_metadata')
    if isinstance(episode_metadata, dict) and 'episode_id' in episode_metadata:
        episode_id = episode_metadata['episode_id'].numpy()
    else:
        episode_id = f"episode_{episode_idx + 1}"

    print(f"\n{'='*60}")
    print(f"Episode {episode_idx + 1}: ID = {episode_id}")
    print(f"{'='*60}")

    # Step counter for the current episode
    step_count = 0

    # Iterate over every step in this episode
    for step in episode["steps"]:
        step_count += 1
        total_steps += 1

        # Pull the key fields out of this step.
        # Note: the exact structure depends on the dataset; these are common fields
        observation = step["observation"]          # observation data (usually a dict)
        action = safe_numpy(step, "action")        # action data
        reward = safe_numpy(step, "reward")        # reward (some datasets omit it)
        is_first = safe_numpy(step, "is_first")    # first step of the episode?
        is_last = safe_numpy(step, "is_last")      # last step of the episode?

        # Print step info (only the first step, to avoid flooding the output)
        if step_count <= 1:
            if isinstance(observation, dict):
                observation_desc = {
                    k: v.numpy().shape if hasattr(v, "numpy") and hasattr(v.numpy(), "shape") else type(v).__name__
                    for k, v in observation.items()
                }
            else:
                observation_desc = observation.shape if hasattr(observation, "shape") else type(observation)

            print(f"  Step {step_count}:")
            print(f"    Observation: {observation_desc}")
            print(f"    Action shape: {action.shape if hasattr(action, 'shape') else type(action)}")
            print(f"    Reward: {reward}")
            print(f"    is_first: {is_first}, is_last: {is_last}")

    print(f"Total steps in this episode: {step_count}")

# Print the summary
print(f"\n{'='*60}")
print(f"SUMMARY")
print(f"{'='*60}")
print(f"Total episodes: {num_episodes}")
print(f"Total steps: {total_steps}")