FlowerVLA 源码解析 3 - FlowerVLA Agent 训练及推理

源码:https://github.com/intuitive-robots/flower_vla_pret/blob/main/flower_vla/agents/flower.py


1. rf_loss()

    # 计算 Rectified Flow 损失:
    # 目标是学习速度场 v_theta(zt, t, cond),使其逼近真实速度 (z1 - actions)。
    # 处理流程包括:
    # 1)按采样策略为每个样本采样时间 t(形状 [B]);
    # 2)按 Action Type 在各自有效动作维生成噪声 z1(形状 [B, T, A_max]);
    # 3)构造线性插值 zt = (1 - t) * actions + t * z1;
    # 4)前向传播得到 vtheta = dit_forward(zt, t, cond);
    # 5)仅在有效动作维上计算 MSE,忽略无效/填充维。
    # 维度约定:
    # - actions 形状:[B, T, A_max](若输入为 [B, 1, T, A_max],将先 squeeze 到三维);
    # - action_type 形状:[B];
    # - t 形状:[B],texp 形状:[B, 1, 1],用于广播到动作张量;
    # - 输出 loss 为标量,losses_dict 为日志统计字典。
    def rf_loss(self, cond: dict, actions: torch.Tensor, dataset_idx: Any = None) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Computes the rectified flow loss.
        Interpolates between actions and noise, then computes MSE only over valid dimensions.
        """
        # 对齐模型默认精度,确保后续新建张量与模型参数 dtype 一致。
        default_dtype = next(self.parameters()).dtype
        # 从条件字典读取每个样本对应的动作空间类型索引。
        # 形状为 [B],每个样本 1 个整型 Action Type。
        action_type = cond['action_type']
        # 兼容可能的四维输入:[B, 1, T, A_max] -> [B, T, A_max]。
        if len(actions.shape) == 4:
            actions = actions.squeeze(1)
        b = actions.size(0)
        device = actions.device
        actions = actions.to(default_dtype)

        # 按配置的 Sampling Type 采样时间 t,每个样本一个标量,范围在 [0, 1)。
        # sampling_type 来源于 FlowerVLA 构造参数 sampling_type(默认值为 "ln"),
        # 可通过 conf/trainer/agent/flower_vla.yaml 的 agent.sampling_type 进行覆盖。
        if self.sampling_type == "pi_zero":
            alpha, beta = 1.5, 1.0
            t = torch.distributions.Beta(alpha, beta).sample((b,)).to(device).clamp(max=0.999)
        elif self.sampling_type == "ln":
            t = torch.sigmoid(torch.randn((b,), device=device)).clamp(max=0.999).to(default_dtype)
        elif self.sampling_type == "uniform":
            # eps 用于避免采样到 1.0,后续插值或数值计算时可减少边界问题。
            eps = 1e-5
            # 先生成 [0, 1) 的随机偏移量 rand,再加上等间隔网格 arange(b)/b,
            # 得到长度为 b 的“分层均匀采样”时间点(Stratified Uniform Sampling)。
            # 取模 (1 - eps) 可将越界值回卷到区间内,最终 t 的形状为 [B],范围约在 [0, 1-eps)。
            t = (torch.rand(1, device=device) + torch.arange(b, device=device) / b) % (1 - eps)
            # 与模型其余张量精度对齐。
            t = t.to(default_dtype)
        else:
            raise NotImplementedError(f"Sampling type {self.sampling_type} not implemented")
        # 将 t 扩展为可广播形状 [B, 1, 1](对于三维 actions)。
        texp = t.view([b] + [1] * (actions.dim() - 1))
        # z1 为噪声端点,初始化为全 0;后续仅在有效动作维写入高斯噪声。
        z1 = torch.zeros_like(actions)
        for action_name, action_idx in self.action_space_index.action_spaces.items():
            mask = (action_type == action_idx)
            if mask.any():
                adim = self.action_space_index.get_action_dim(action_idx)
                # 仅为当前 Action Space 的有效维 [:adim] 采样噪声。
                # 维度对齐关系:
                # - mask.sum() = B_group(当前 Action Type 选中的样本数);
                # - noise_slice 形状为 [B_group, T, adim];
                # - z1[mask, :, :adim] 先按 mask 选 Batch 子集,再取全部时间步与前 adim 维,
                #   左值切片形状同样是 [B_group, T, adim],因此可直接赋值。
                noise_slice = torch.randn((mask.sum(), actions.size(1), adim), dtype=actions.dtype, device=actions.device)
                z1[mask, :, :adim] = noise_slice
        # 线性插值到中间点 zt:
        # - t=0 对应数据端(actions);
        # - t=1 对应噪声端(z1)。
        # 因此 zt = (1 - t) * actions + t * z1。
        zt = (1 - texp) * actions + texp * z1
        # 预测速度场 vtheta,形状与 actions/z1 一致:[B, T, A_max]。
        vtheta = self.dit_forward(zt, t, cond)
        # 构造有效维掩码:当前样本对应 Action Space 的前 adim 维为 True,其他维为 False。
        valid_mask = torch.zeros_like(actions, dtype=torch.bool)
        for action_name, action_idx in self.action_space_index.action_spaces.items():
            mask = (action_type == action_idx)
            if mask.any():
                # 当前 Action Space 的有效动作维度。
                adim = self.action_space_index.get_action_dim(action_idx)
                # mask 原始形状为 [B],表示哪些样本属于当前 Action Type。
                # 先 View 成 [B, 1, 1],再 Expand 到 [B, T, adim],使其与
                # valid_mask 在“时间维 + 有效动作维”上的目标切片形状对齐。
                mask_expanded = mask.view(-1, 1, 1).expand(-1, actions.size(1), adim).to(device)
                # 左值 valid_mask[mask, :, :adim] 形状为 [B_group, T, adim];
                # 右值 mask_expanded[mask] 形状同样为 [B_group, T, adim]。
                # 赋值后,当前 Action Type 样本在前 adim 维被置为 True,其他位置保持 False。
                valid_mask[mask, :, :adim] = mask_expanded[mask]
        # 监督 Label(目标速度场)定义为 v* = z1 - actions,形状为 [B, T, A_max]。
        # vtheta 形状同为 [B, T, A_max],因此 diff = v* - vtheta 形状也是 [B, T, A_max]。
        diff = (z1 - actions) - vtheta
        # 只在有效动作维取值:valid_diff 为一维展平向量 [N_valid],N_valid 为有效元素总数。
        valid_diff = diff[valid_mask]
        loss = (valid_diff ** 2).mean()
        # 记录训练可观测统计量,便于排查数值范围与稳定性。
        losses_dict = {
            "diff_min": valid_diff.min().item(),
            "diff_max": valid_diff.max().item(),
            "diff_mean": valid_diff.mean().item(),
            "loss": loss.item(),
        }
        if hasattr(self, 'accelerator') and self.accelerator is not None and wandb.run is not None:
            if self.accelerator.is_main_process:
                wandb.log(losses_dict)
        return loss, losses_dict

2. dit_forward()

    # 通过 DiT 块前向传播,编码动作,添加位置信息,以及应用条件。
    #
    # 输入与形状约定:
    # - z:[B, T_act, A_max],噪声/中间动作序列;
    # - t:[B],时间步;
    # - cond_dict["features"]:[B, N_cond, D_vlm];
    # - cond_dict["frequency_embeds"]:[B, D_dit];
    # - cond_dict["action_type"]:[B];
    # - cond_dict["proprio"](可选):[B, D_prop];
    # - cond_dict["attention_mask"]:[B, N_cond]。
    def dit_forward(self, z: torch.Tensor, t: torch.Tensor, cond_dict: dict) -> torch.Tensor:
        """
        Forward pass through the DiT blocks.
        Encodes actions, adds positional information, and applies conditioning.
        """
        # z 形状拆解:Batch、动作序列长度、动作维。
        B, t_seq, d = z.shape
        # 读取模型参数默认 dtype。
        dtype = next(self.parameters()).dtype
        # 条件分支:features 从 [B, N_cond, D_vlm] 归一化并且线性映射到 [B, N_cond, D_dit]。
        cond = self.cond_linear(self.cond_norm(cond_dict['features'].to(dtype)))
        # Frequency Embeds:FreqEmbedder 输出 [B, D_dit],与 Time / Proprio 全局条件同形。
        # 后续主要有三种用途:
        # 1)作为条件输入,与 Time Embedding、Proprio Embeds 相加形成 t_emb;
        # 2)在训练阶段参与 CFG Dropout(可按样本置零);
        # 3)作为形状参考,比如 torch.zeros_like(freq_embeds),创建对齐张量。
        freq_embeds = cond_dict['frequency_embeds'].squeeze(1).to(dtype)
        # action_type 的形状是 [B],用于后续按 Batch 维进行布尔分组索引。
        action_type = cond_dict['action_type'].to(self.device)
        # proprio 用于生成本体条件向量:
        # - 启用 use_proprio 时:读取 cond_dict["proprio"];
        # - 否则:退化为与 freq_embeds 同形状的全 0 张量 [B, D_dit]。
        proprio = cond_dict.get('proprio', torch.zeros_like(freq_embeds)).to(dtype) if self.use_proprio else torch.zeros_like(freq_embeds)
        # encode_proprio 输出 [B, D_dit],与 freq_embeds、t_embed 可逐元素相加。
        proprio_embeds = self.encode_proprio(proprio, action_type, freq_embeds.shape).to(dtype)
        
        # 按 Action Type 编码动作:
        # z:[B, T_act, A_max] -> [B, T_act, D_dit];
        # valid_dims:[B, T_act, A_max],记录各样本有效动作维。
        z, valid_dims = self.encode_actions(z, action_type)
        # 若使用绝对位置编码(非 RoPE/NoPE),则与 z 逐元素相加。
        # 要求 self.positional_encoding 可广播或同形到 [B, T_act, D_dit]。
        # 该相加不会改变 z 的形状;若位置编码长度大于当前 T_act,PyTorch 不会自动截断,
        # 通常会因维度不匹配报错,需要显式切片到 [:, :T_act, :] 才能对齐。
        if not (self.use_rope or self.use_nope):
            z += self.positional_encoding
        
        # 仅对频率与本体条件执行 CFG Dropout,不改变 z 主分支:
        # drop_mask:[B, 1],沿特征维广播到 [B, D_dit]。
        if self.training and self.cfg_dropout > 0:
            # 按样本生成二值掩码(1=Drop,0=Keep)。
            drop_mask = (torch.rand(freq_embeds.size(0), device=freq_embeds.device) < self.cfg_dropout).float().unsqueeze(1)
            freq_embeds = freq_embeds * (1 - drop_mask)
            proprio_embeds = proprio_embeds * (1 - drop_mask)
        # 时间与全局条件融合,t_embedder(t) 输出 [B, D_dit],三项相加后 t_emb 仍为 [B, D_dit]。
        t_emb = sum(map(stateless_norm, [self.t_embedder(t), freq_embeds, proprio_embeds]))
        
        # 计算全局条件 global_cond,供 AdaLN 使用:
        # - cond:[B, N_cond, D_dit],readout 取 cond[:,0,:],否则对 N_cond 维计算 mean → 均为 [B, D_dit];
        # - t_emb:[B, D_dit];二者同形,+= 为逐样本逐维相加。
        # - use_adaln_cond=False 时仅用 t_emb。
        if self.use_adaln_cond:
            global_cond = cond[:, 0, :] if self.use_readout_token else cond.mean(dim=1)
            global_cond += t_emb
        else:
            global_cond = t_emb
        
        # Cross-Attention 上下文:
        # - 开启时 context=cond([B, N_cond, D_dit]);
        # - 关闭时 context=None。
        context = cond if self.use_cross_attn else None
        
        # AdaLN 调制信号:
        # - action_type_adaln=False:共享调制;
        # - action_type_adaln=True:按 Action Type 生成分组调制。
        global_adaln = self.adaln(global_cond) if not self.action_type_adaln else self.action_specific_adaln(global_cond, action_type)

        # 逐层 DiT Block 前向:
        # 输入 z 始终保持 [B, T_act, D_dit];
        # custom_cross_attn_mask 使用 cond_dict["attention_mask"]([B, N_cond])。
        for layer in self.dit:
            z = layer(z, global_cond, context=context, custom_attn_mask=None,
                    custom_cross_attn_mask=cond_dict['attention_mask'], is_causal=True, global_adaln=global_adaln)
        
        # 解码回动作空间,同时应用有效维约束:
        # [B, T_act, D_dit] -> [B, T_act, A_max]。
        return self.decode_actions(z, action_type, valid_dims)
    
    # 计算动作专属的 AdaLN 调制信号。
    # 输入:
    # - global_cond:全局条件张量,形状为 [B, D_dit];
    # - action_type:每个样本对应的动作类型索引,形状为 [B];
    # 输出:
    # - List[Tensor],其中每个 Tensor 表示一条调制分量通路的全 Batch 结果。
    def action_specific_adaln(self, global_cond: torch.Tensor, action_type: torch.Tensor) -> List[torch.Tensor]:
        """
        Computes action-specific AdaLN modulation signals.
        Returns a list of modulation tensors.
        """
        dtype = next(self.parameters()).dtype
        batch_size = global_cond.shape[0]
        num_chunks = 9 if self.use_cross_attn else 6
        
        # 先为每路 AdaLN 调制信号分量预分配 [B, D_dit] 的零张量;
        # 后续将按动作类型用 mask 只覆盖对应样本位置。
        mod_signals = [torch.zeros(batch_size, self.dit_dim, device=self.device, dtype=dtype) for _ in range(num_chunks)]
        
        # 每次迭代针对一种动作类型,为该类型样本生成及回填对应的调制信号。
        for action_idx in range(len(self.action_space_index.action_spaces)):
            # 按当前 action_idx 取子批,mask 的形状为 [B],True 表示该样本属于当前动作类型。
            mask = (action_type == action_idx)
            # 仅在当前类型确有样本时才执行对应 AdaLN 控制器,避免空张量前向传播。
            if mask.any():
                action_name = self.action_space_index.get_action_name(action_idx)
                # global_cond[mask] 的形状为 [B_g, D_dit](B_g 为该类型样本数);
                # action_mod 是长度为 num_chunks 的列表,每一路的形状均为 [B_g, D_dit]。
                action_mod = self.adaln[action_name](global_cond[mask])
                # i 对应“第 i 路调制信号”,比如 shift/scale/gate 等不同分量,
                # signal 的形状为 [B_g, D_dit],与当前动作类型子批一一对应。
                for i, signal in enumerate(action_mod):
                    # 将当前动作类型子批的结果(signal,形状为 [B_g, D_dit]),
                    # 按 mask 回填到全 Batch 容器 mod_signals[i]([B, D_dit])的对应行。
                    # 非当前类型(mask=False)的行不会被改动,继续保留已有值。
                    mod_signals[i][mask] = signal
        
        return mod_signals

3. 推理

    # 推理阶段前向传播:
    # 1) 将输入的观测 obs 与任务目标 goal 组装为 batch;
    # 2) 调用 encode_observations 提取条件特征;
    # 3) 以高斯噪声作为初始动作序列,通过 sample_actions 在条件约束下逐步采样,
    #    最终得到模型预测的动作序列,形状通常为 [B, T, D]。
    def forward(self, obs: Dict, goal: Dict) -> torch.Tensor:
        """
        Inference forward pass.
        Given observation and goal dictionaries, it encodes them and samples an action sequence.
        """
        batch = {'observation': obs, 'task': goal}
        features = self.encode_observations(batch)
        # 构造推理初始噪声 z0 ~ N(0, I):
        # 第一维 B=len(features['features']),表示当前 Batch 中每个样本各自对应一条初始噪声动作序列。
        # 第二维为 act_window_size(动作序列长度)。
        # 第三维为 action_dim,来自于 conf/trainer/agent/flower_vla.yaml 中的 action_dim: ${act_dim},
        # act_dim 的定义在 conf/training.yaml。
        # device 与特征保持一致,避免后续采样阶段发生跨设备张量运算错误。
        noise = torch.randn(len(features['features']), self.act_window_size, self.action_dim,
                              device=features['features'].device)
        return self.sample_actions(noise, features, inference=True)

    # 从 DiT 模型中采样动作。
    def sample_actions(self, z: torch.Tensor, cond: Dict[str, torch.Tensor], inference: bool = False) -> torch.Tensor:
        """
        Samples actions from the DiT model.
        Chooses between an adaptive ODE solver and fixed-step Euler integration.
        """
        steps = self.num_sampling_steps if inference else 5
        b = z.size(0)
        # 按各自动作类型,将每个样本的无效动作维度清零。
        action_type = cond['action_type']
        for action_name, action_idx in self.action_space_index.action_spaces.items():
            mask = (action_type == action_idx)
            if mask.any():
                adim = self.action_space_index.get_action_dim(action_idx)
                z[mask, :, adim:] = 0.0
        # 当前文件内 use_dopri5 仅在初始化时设为 False,且未被改写,
        # 因此默认走 fixed-step 分支,除非外部代码在运行时手动修改该属性。
        if hasattr(self, 'use_dopri5') and self.use_dopri5:
            return self._sample_with_adaptive_solver(z, cond)
        else:
            return self._sample_with_fixed_steps(z, cond, inference)

    # 使用固定步长的 Euler 积分采样动作。
    def _sample_with_fixed_steps(self, z: torch.Tensor, cond: Dict[str, torch.Tensor], inference: bool = False) -> torch.Tensor:
        """
        Samples actions using fixed-step Euler integration.
        """
        steps = self.num_sampling_steps if inference else 5
        b = z.size(0)
        device = z.device
        action_type = cond['action_type']
        # 固定步长 Euler 的时间步长:将区间 [0, 1] 均分为 steps 份。
        dt = 1.0 / steps
        # 将标量 dt 扩展为可广播张量 [B, 1, ..., 1],
        # 以便与速度项 vc(通常为 [B, T, D])逐元素相乘得到步进增量,再用于更新 z。
        dt_tensor = torch.tensor([dt] * b, device=device).view([b] + [1] * (z.dim() - 1))
        
        # 仅在推理阶段且 `cfg_lambda` != 1.0 时启用 CFG。
        # `cfg_lambda` 在本类 `_init_flags` 中默认初始化为 1.0(默认不启用 CFG),
        # 若需要启用 CFG,需在运行时将 `self.cfg_lambda` 设为其他值。
        apply_cfg = inference and self.cfg_lambda != 1.0
        
        # 在循环外仅构造一次 CFG 的“无文本条件分支”条件,避免每步重复创建。
        if apply_cfg:
            null_cond = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in cond.items()}
            
            # 先取 `features` 的副本,再修改副本,最后写回。
            features = null_cond['features'].clone()
            
            # 计算文本特征起始位置:
            # `features` 按 [Prompt Token][Image Token][Text Token] 的顺序拼接,
            # 因此 `text_start = prompt_length + image_length`。
            prompt_length = self.prompt_embeds.shape[1]
            image_length = 50 if self.use_second_view is False else 100
            
            text_start = prompt_length + image_length
            # 仅将文本区段置零,保留 Prompt 和图像条件,构造“弱条件/近无条件”输入供 CFG 使用。
            features[:, text_start:, :] = 0.0
            
            # 将修改后的 `features` 写回 `null_cond`,供后续计算 `vu`(Unconditional Velocity)。
            null_cond['features'] = features
        
        for i in range(steps, 0, -1):
            # 将当前步号 i 转成时间值 t=i/steps,作为本步的“时间标签”。
            # 比如当 steps=4 时,t 依次为 1.0、0.75、0.5、0.25(从 1 向 0 方向推进)。
            t_val = i / steps
            # 将标量 t 扩展为 [B] 形状,使 Batch 内每个样本在该步使用同一时间输入。
            t_tensor = torch.full((b,), t_val, device=device)
            
            # 先用条件输入 cond 预测速度 vc(Conditional Velocity)。
            vc = self.dit_forward(z, t_tensor, cond)
            
            # 若启用 CFG,则再计算无条件速度 vu,然后做线性融合:
            # vc <- vu + cfg_lambda * (vc - vu)。
            # 其中 (vc - vu) 表示条件信息带来的引导方向,cfg_lambda 越大引导越强。
            if apply_cfg:
                vu = self.dit_forward(z, t_tensor, null_cond)
                vc = vu + self.cfg_lambda * (vc - vu)
            
            # Euler 反向积分更新:z <- z - dt * vc。
            # 逐步将 z 从噪声端拉回到动作端。
            z = z - dt_tensor * vc
            
            # 每个 Euler 步更新后,按动作类型再次将无效动作维度清零,
            # 防止这些维度在迭代过程中被数值更新“污染”。
            for action_name, action_idx in self.action_space_index.action_spaces.items():
                mask = (action_type == action_idx)
                if mask.any():
                    adim = self.action_space_index.get_action_dim(action_idx)
                    z[mask, :, adim:] = 0.0
                    
        # 将输出限制在 [-1, 1],与训练时的动作归一化范围保持一致,
        # 同时抑制采样积分产生的少量越界值,提升推理稳定性。
        return z.clamp(-1, 1)