FlowerVLA 源码解析 2 - FlowerVLA Agent 初始设置及编码/解码方法

源码:https://github.com/intuitive-robots/flower_vla_pret/blob/main/flower_vla/agents/flower.py


1. 初始设置

  • 说明

    很多组件只看定义,很难形成深入认知;因此先不用纠结它们的具体作用。等读到 dit_forwardencoding_* 等实际使用这些组件的方法时,再回过头复看一遍,理解将更清晰。

1.1. 设置 VLM

    # 初始化及配置 VLM 相关组件。
    # 参数:
    # - vlm_path:预训练 Florence VLM 权重路径(本地或 HuggingFace 路径);
    # - freeze_vision_tower:是否冻结视觉塔(True=保持冻结,False=允许视觉塔参与训练);
    # - freeze_florence:是否冻结整套 VLM 参数(优先级最高);
    # - freeze_embeddings_only:仅冻结词嵌入层(当 freeze_florence=False 时生效)。
    def _setup_vlm(self, vlm_path: str, freeze_vision_tower: bool, freeze_florence: bool, freeze_embeddings_only: bool) -> None:
        """
        Loads the pretrained VLM, sets up the processor/tokenizer, adds a prompt token,
        and optionally freezes parameters.
        """
        # 加载预训练 VLM:
        # - AutoModelForCausalLM 根据 vlm_path 中的配置自动选择及实例化对应模型类。
        logger.info(f"Loading VLM from {vlm_path}")
        self.vlm = AutoModelForCausalLM.from_pretrained(vlm_path, trust_remote_code=True)
        self.train_vlm = not freeze_florence
        
        # 策略 A:冻结整套 VLM 参数。
        if freeze_florence:
            for param in self.vlm.parameters():
                param.requires_grad = False
        # 策略 B:仅冻结输入 Embedding(常用于保护词表语义稳定性)。
        elif freeze_embeddings_only:
            embedding_layer = self.vlm.get_input_embeddings()
            for param in embedding_layer.parameters():
                param.requires_grad = False
            # 在某些实现中 language_model.shared 与输入 Embedding 共享权重,需要同步冻结。
            if hasattr(self.vlm.language_model, 'shared'):
                for param in self.vlm.language_model.shared.parameters():
                    param.requires_grad = False
        
        # `vision_tower` 是 VLM 中负责图像编码的视觉分支(将图像转成 Visual Token Sequence/特征)。
        # 若显式要求训练视觉塔,则将该分支参数设为可训练;
        # 这一步可覆盖上面的全量冻结策略,实现“仅解冻视觉塔、其余分支继续冻结”的配置。
        if not freeze_vision_tower:
            for param in self.vlm.vision_tower.parameters():
                param.requires_grad = True
        
        # 初始化 Processor/Tokenizer:
        # - Processor:多模态预处理入口,统一封装图像/文本输入的预处理规则,比如 Resize、Normalize、Padding 等;
        # - Tokenizer:文本侧子组件,负责 Text -> Token ID 的切分与映射,也用于后续添加 Special Token,比如 "<Flow>"。
        self.processor = AutoProcessor.from_pretrained(vlm_path, trust_remote_code=True)
        self.tokenizer = self.processor.tokenizer
        # `prompt_embeds` 即新增加的 "<Flow>" Token 对应的 Embedding(冻结参数)。
        self.prompt_embeds = self._create_prompt_embed("<Flow>").to(self.device)
        # 删除语言生成分支,仅保留编码主干:
        # - `decoder`:自回归生成时用于逐步解码文本;
        # - `lm_head`:将隐藏状态投影到词表维度,输出每个 Token 的 Logit,用于下一词预测。
        # 本项目主要使用中间特征做条件编码,不走文本生成,因此可移除,以减少参数与显存占用。
        del self.vlm.language_model.model.decoder, self.vlm.language_model.lm_head
        # 编码阶段可选的 Token Dropout(正则化)。
        self.vlm_token_dropout = nn.Dropout(self.token_dropout)


    # 创建 Prompt Embedding:
    # 1)将 prompt_text 注册为 Tokenizer 的额外 Special Token;
    # 2)按扩容后的词表大小重置 VLM 的输入 Embedding 矩阵;
    # 3)取出该 Token 对应的 Embedding,并且以冻结参数返回(requires_grad=False)。
    # 参数:
    #   prompt_text:Prompt 字符串,当前文件中的调用为 _create_prompt_embed("<Flow>")。
    # 返回:
    #   初始 Token 向量为 [D],两次 unsqueeze 后为 [1, 1, D],后续可 Expand 到 [B, 1, D] 与其他特征拼接。
    def _create_prompt_embed(self, prompt_text: str) -> nn.Parameter:
        """
        Creates a prompt embedding. Adds the prompt token to the tokenizer
        and returns its embedding (frozen).
        """
        self.tokenizer.add_special_tokens({'additional_special_tokens': [prompt_text]})
        self.vlm.resize_token_embeddings(len(self.tokenizer))
        prompt_token_id = self.tokenizer.convert_tokens_to_ids(prompt_text)
        prompt_embed = nn.Parameter(
            self.vlm.get_input_embeddings()(torch.tensor(prompt_token_id)),
            requires_grad=False
        )
        return prompt_embed.unsqueeze(0).unsqueeze(0)

1.2. 设置 DiT 组件

    # 设置 DiT 组件,包括动作专属的编码器、解码器,以及共享的条件组件。
    # 参数:
    # - dit_dim:DiT 主干与动作表示的统一隐藏维度;
    # - n_heads:每个 Transformer Block 的多头注意力头数;
    # - n_layers:DiT 堆叠的 FlowBlock 层数;
    # - action_dim:未使用。
    # - act_window_size:动作序列窗口长度;
    # - hidden_dim:来自上游多模态编码器的条件特征维度;
    # - attn_pdrop:注意力层 Dropout 概率;
    # - resid_pdrop:投影后的 Dropout 概率;
    # - mlp_pdrop:MLP 子层 Dropout 概率;
    # - use_cross_attn:是否在 FlowBlock/AdaLN 中启用 Cross Attention;
    # - use_rope:是否启用 RoPE;
    # - use_nope:是否禁用位置编码(No Positional Encoding);
    # - query_seq_len:RoPE 在查询侧预计算/支持的最大序列长度;
    # - rope_theta:RoPE theta 参数。
    def _setup_dit_components(
        self,
        dit_dim: int,
        n_heads: int,
        n_layers: int,
        action_dim: int,
        act_window_size: int,
        hidden_dim: int,
        attn_pdrop: float,
        resid_pdrop: float,
        mlp_pdrop: float,
        use_cross_attn: bool,
        use_rope: bool,
        use_nope: bool,
        query_seq_len: int,
        rope_theta: float
    ) -> None:
        """
        Sets up the Diffusion Transformer (DiT) components including action-specific 
        encoders/decoders and shared conditioning components.
        """
        # action name -> action encoder
        self.action_encoders = nn.ModuleDict()
        # action name -> action decoder
        self.action_decoders = nn.ModuleDict()
        if self.use_proprio:
            # action name -> proprio encoder
            self.proprio_encoders = nn.ModuleDict()
        # action_type_adaln=True 时,为每个动作类型维护独立的 AdaLN 控制器,
        # 否则后面将创建共享的 AdaLN 控制器。
        self.adaln = nn.ModuleDict() if self.action_type_adaln else None

        # 逐个动作空间创建“编码-解码-条件调制”组件。
        # action_space_index.action_spaces 的 key 是动作名称,value 是动作索引。
        for action_name, action_idx in self.action_space_index.action_spaces.items():
            # 每个动作空间有不同输入维度,需动态查询。
            input_dim = self.action_space_index.get_action_dim(action_idx)
            
            # 动作编码器:将原始动作向量提升到 dit_dim,以便与 DiT 主干对齐。
            self.action_encoders[action_name] = Mlp(
                in_features=input_dim,
                hidden_features=dit_dim,
                out_features=dit_dim,
                bias=True
            )
            # 动作解码器:将 DiT 隐状态投影回该动作空间的原始维度。
            self.action_decoders[action_name] = nn.Linear(dit_dim, input_dim).to(self.device)

            # 动作专属 AdaLN:为不同动作类型提供独立条件调制参数,避免多任务/多动作空间间的统计干扰。
            if self.action_type_adaln:
                self.adaln[action_name] = SharedAdaLNController(
                    dit_dim,
                    global_conddim=dit_dim,
                    use_cross_attn=use_cross_attn
                )

            # 本体感觉编码器(可选):
            # - bimanual_nav:使用 MLP 编码实际 proprio 输入;
            # - 其他动作:使用 ZeroEncoder 占位,保持接口一致。
            if self.use_proprio:
                if action_name == 'bimanual_nav':
                    self.proprio_encoders[action_name] = Mlp(
                        input_dim,
                        dit_dim,
                        out_features=dit_dim,
                        drop=0.2
                    ).to(self.device)
                else:
                    self.proprio_encoders[action_name] = ZeroEncoder(
                        self.dit_dim,
                        device=self.device
                    )

        # 若未启用动作专属 AdaLN,则构建全局共享的 AdaLN 控制器。
        if not self.action_type_adaln:
            self.adaln = SharedAdaLNController(
                dit_dim,
                global_conddim=dit_dim,
                use_cross_attn=use_cross_attn
            )

        # 条件编码相关模块:
        # - cond_norm = RmsNorm(hidden_dim):对 VLM 条件特征按最后一维做 RMS 归一化,稳定尺度;
        # - cond_linear = Linear(hidden_dim -> dit_dim, bias=False):将条件通道映射到 DiT 隐空间维度;
        self.cond_linear = nn.Linear(hidden_dim, dit_dim, bias=False)
        self.t_embedder = TimestepEmbedder(dit_dim)
        self.cond_norm = RmsNorm(hidden_dim)
        self.frequency_embedder = FreqEmbedder(dit_dim)
        self.action_space_embedder = ActionSpaceEmbedderParameter(
            dit_dim,
            max_actions=len(self.action_space_index.action_spaces)
        )

        # 绝对位置编码,仅在未启用 RoPE / NoPE 时生效:
        # `torch.randn(1, act_window_size, dit_dim)` 的形状含义为:
        # - 第 1 维 `1`:批维占位,便于在前向中与 `z`(形状 [B, T, D])按批次自动广播;
        #   即同一组位置编码将共享给当前 Batch 中所有样本。
        # - 第 2 维 `act_window_size`:动作序列长度 T,每个时间步有独立的位置向量。
        # - 第 3 维 `dit_dim`:每个位置向量的通道维 D,与 DiT 隐状态维度对齐,才能做逐元素相加。
        # 乘以 0.1 是缩小初始化幅度,避免训练初期位置项过大干扰主特征。
        if not use_rope and not use_nope:
            self.positional_encoding = nn.Parameter(
                torch.randn(1, act_window_size, dit_dim) * 0.1
            )

        # 构建 DiT 主干:由 n_layers 个 FlowBlock 组成,所有层共享相同超参数。
        self.dit = nn.ModuleList([
            FlowBlock(
                dim=dit_dim,
                heads=n_heads,
                attn_pdrop=attn_pdrop,
                resid_pdrop=resid_pdrop,
                mlp_pdrop=mlp_pdrop,
                use_cross_attn=use_cross_attn,
                use_rope=use_rope,
                query_seq_len=query_seq_len,
                rope_theta=rope_theta
            ) for _ in range(n_layers)
        ])

2. 编码/解码方法

2.1. encode_observations()

    # 编码主视角图片(以及可选的第二视角图片)和文本目标。
    # 返回字典字段:
    # - features:VLM Encoder 输出的条件特征;
    # - frequency_embeds:Frequency Embeddings;
    # - action_space_embeds:Action Space Embeddings;
    # - action_type:Action Type 索引;
    # - proprio:本体感觉数据(若存在);
    # - attention_mask:与 features 对齐的 Attention Mask。
    def encode_observations(self, batch: Dict) -> Dict[str, torch.Tensor]:
        """
        Encodes primary (and optional second view) image observations and text goals.
        Returns a dictionary with:
            - 'features': Encoder outputs.
            - 'frequency_embeds': Frequency embeddings.
            - 'action_space_embeds': Action space embeddings.
            - 'action_type': Action type indices.
            - 'proprio': Proprioception data (if available).
            - 'attention_mask': Attention mask.
        """
        device = self.device
        # 从模型参数中读取默认精度类型,比如 float32/bfloat16/float16。
        default_dtype = next(self.parameters()).dtype
        # 默认配置项为 conf/training.yaml 的 obs_modalities(Hydra 键:obs_modalities)="observation"。
        image_tensor = batch[self.obs_modalities]['image_primary']
        # T 表示时间维度,即每个样本的连续观测帧数 / 序列长度。
        B, T, C, H, W = image_tensor.shape
        # image_tensor: [B, T, C, H, W] -> view 后: [B*T, C, H, W]。
        # vlm._encode_image 的作用是将每帧图像编码为视觉 Token 序列。
        # vlm._encode_image 的输出的维度是 [B*T, N_img, D](N_img 是每帧的视觉 Token 数, D 是特征维度)。
        image_features = self.vlm._encode_image(
            image_tensor.view(-1, C, H, W).to(device).to(default_dtype)
        )
        # 还原 Batch 结构,合并序列维:将 [B*T, N_img, D] 视作 [B, T, N_img, D] 后,
        # 将时间维 T 与每帧视觉 Token 维 N_img 展平为单一序列维,得到 [B, T*N_img, D]。
        image_features = image_features.view(B, T * image_features.shape[1], -1)


        # second_view_key 在代码中的默认值为 "image_wrist";
        # 在训练配置 conf/trainer/agent/flower_vla.yaml 中,
        # 通过 agent.second_view_key(Hydra 键:trainer.agent.second_view_key)默认覆盖为 "image_secondary"。
        if self.use_second_view and self.second_view_key in batch[self.obs_modalities]:
            image2_tensor = batch[self.obs_modalities][self.second_view_key]
            image2_features = self.vlm._encode_image(
                image2_tensor.view(-1, C, H, W).to(device).to(default_dtype)
            )
            image2_features = image2_features.view(B, T * image2_features.shape[1], -1)
            # dim=1 表示沿序列长度维(Token 维)拼接两路视觉特征:
            # [B, T*N_img, D] + [B, T*N_img2, D] -> [B, T*(N_img+N_img2), D]。
            # 不改变 Batch 维(dim=0)和特征维 D(dim=2)。
            image_features = torch.cat([image_features, image2_features], dim=1)


        # goal_modalities 来自 conf/training.yaml(Hydra 键:goal_modalities)且取值为 "task";
        # lang_modalities 来自 conf/training.yaml(Hydra 键:lang_modalities)且取值为 ["language_instruction"]。
        # 这里调用 VLM 的输入 Embedding 层,将语言 input_ids 映射为 Embedding。
        # input_ids 的常见形状为 [B, 1, L](或 [B, L]),映射后为 [B, 1, L, D_txt](或 [B, L, D_txt]);
        # squeeze(1) 用于移除长度为 1 的中间维,将形状统一为 [B, L, D_txt]。
        text_embeds = self.vlm.get_input_embeddings()(
            batch[self.goal_modalities][self.lang_modalities[0]]['input_ids'].to(device)
        ).to(device).squeeze(1)

        # 将 [1, 1, D] 的 Prompt Embedding 在 Batch 维复制为 [B, 1, D];
        # 其中 Expand(B, -1, -1) 仅扩展视图,不复制 Token 维与特征维。
        task_prompt = self.prompt_embeds.expand(B, -1, -1)
        # dim=1 表示按序列维度拼接 Prompt、视觉与文本特征:
        # [B, 1, D] + [B, N_vis, D] + [B, N_txt, D] -> [B, 1+N_vis+N_txt, D]。
        merged_embeds = torch.cat([task_prompt.to(image_features.device), image_features, text_embeds.to(image_features.device)], dim=1)

        # merged_embeds 的形状为 [B, N_total, D],其中 N_total = 1(Prompt)+ N_vis + N_txt。
        # 注意:下面这行当前未被实际使用。
        attention_mask = torch.ones(merged_embeds.shape[:2], device=merged_embeds.device)
        # 读取语言 Attention Mask,通常形状为 [B, 1, N_txt];
        # 经 squeeze(1) 后变为 [B, N_txt],与序列维对齐。
        lang_attention_mask = batch[self.goal_modalities][self.lang_modalities[0]]['attention_mask'].to(device).squeeze(1)
        # 视觉特征 image_features 的形状为 [B, N_vis, D],因此视觉 Mask 形状为 [B, N_vis],
        # 这里用全 1,表示视觉 Token 默认全部有效。
        vis_attention_mask = torch.ones(image_features.shape[:2], device=image_features.device)
        # Prompt 只有 1 个 Token,对应 Mask 形状为 [B, 1]。
        # 这里用全 0,表示 Prompt Token 在 Attention 中被显式屏蔽。
        prompt_mask = torch.zeros(B, 1, dtype=torch.bool, device=image_features.device)
        # 按序列维(dim=1)拼接三段 Mask:
        # [B, 1](Prompt) + [B, N_vis](Vision) + [B, N_txt](Text) -> [B, 1+N_vis+N_txt]。
        attention_mask = torch.cat([prompt_mask, vis_attention_mask, lang_attention_mask], dim=1)

        # 1)通过 VLM Encoder 对拼接后的多模态序列做上下文编码。
        # inputs_embeds 形状:[B, N_total, D],其中 N_total = 1(Prompt)+ N_vis + N_txt。
        # attention_mask 形状:[B, N_total],与 inputs_embeds 的序列维一一对应。
        # last_hidden_state 输出形状:[B, N_total, D],每个 Token 都变为上下文相关特征。
        features = self.vlm.get_encoder()(
            inputs_embeds=merged_embeds, 
            attention_mask=attention_mask,
        ).last_hidden_state

        # 2)对 Encoder 输出做 Token 级 Dropout(仅训练时生效,Eval 时自动关闭)。
        # 输入/输出形状保持不变:[B, N_total, D]。
        features = self.vlm_token_dropout(features)

        # 3)可选的 CFG 文本条件 Dropout:
        # 仅在训练阶段且 cfg_dropout > 0 时,按“样本级”随机屏蔽文本段特征,
        # 用于模拟无文本条件分支,提升 Classifier-Free Guidance 的稳定性。
        # 对应配置项为 conf/trainer/agent/flower_vla.yaml 的 agent.cfg_dropout
        # (Hydra 键为 trainer.agent.cfg_dropout,当前默认配置值为 0.0)。
        if self.cfg_dropout > 0 and self.training:
            # 三段长度:Prompt 段、Vision 段、Text 段。
            prompt_length = task_prompt.shape[1]
            image_length = image_features.shape[1]
            # Text 段长度。
            text_length = text_embeds.shape[1]
            text_start = prompt_length + image_length
            # 文本片段在 features 中的切片区间是 [text_start:text_end)。
            text_end = text_start + text_length
            # 为每个样本生成 1 个 Bernoulli 掩码,形状 [B, 1, 1],用于广播到文本段:
            # 取值 1 表示“丢弃该样本文本条件”,取值 0 表示“保留文本条件”。
            drop_mask = (torch.rand(B, device=device) < self.cfg_dropout).float().view(B, 1, 1)
            # 仅对文本段应用掩码,维持 Prompt/Vision 段不变:
            # features[:, text_start:text_end, :] *= (1 - drop_mask)。
            features[:, text_start:text_end, :] = features[:, text_start:text_end, :] * (1 - drop_mask)

        return {
            'features': features,
            'frequency_embeds': self.frequency_embedder(batch[self.goal_modalities]['frequency'].to(device).to(default_dtype)),
            'action_space_embeds': self.action_space_embedder(batch[self.goal_modalities]['action_space_index'].to(device)),
            'action_type': batch[self.goal_modalities]['action_space_index'],
            'proprio': batch[self.obs_modalities]['proprio'].to(device).to(default_dtype) if self.use_proprio and 'proprio' in batch[self.obs_modalities] else None,
            'attention_mask': attention_mask,
        }

2.2. encode_actions()

    # 按样本的 Action Type 将原始动作序列编码到统一的 DiT 隐空间。
    # 处理流程包括:
    # 1)根据 action_type 为 Batch 内样本分组;
    # 2)对每组样本仅截取其有效动作维 adim,送入对应的 Action Encoder;
    # 3)将编码结果写回统一形状的 encoded(最后一维为 dit_dim);
    # 4)同步构造 valid_dims 掩码,标记哪些维度是该样本的真实有效维度。
    # 维度约定:
    # - 输入 z 形状:[B, T, A_max](A_max 为最大动作维);
    # - 输入 action_type 形状:[B],每个样本 1 个动作空间索引;
    # - 输出 encoded 形状:[B, T, dit_dim];
    # - 输出 valid_dims 形状:[B, T, A_max],有效维位置为 1,无效维位置为 0。
    def encode_actions(self, z: torch.Tensor, action_type: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes actions for each sample based on its action type.
        Returns:
            - Encoded actions (latent representations).
            - A valid dimensions mask.
        """
        # 对齐模型默认精度,避免后续创建的新张量与模型参数 dtype 不一致。
        default_dtype = next(self.parameters()).dtype
        # 将动作类型索引放到当前设备,便于后续布尔掩码索引。
        action_type = action_type.to(self.device)
        B = z.shape[0]
        # 统一的编码输出容器:[B, T, dit_dim]。
        encoded = torch.zeros(B, z.shape[1], self.dit_dim, device=self.device, dtype=default_dtype)
        # 有效维掩码容器:[B, T, A_max];初始时全部为 0,后续仅将真实有效维标为 1。
        valid_dims = torch.zeros_like(z, dtype=default_dtype)
        # 遍历所有已注册动作空间;每个动作空间使用各自的 Encoder。
        for action_name, action_idx in self.action_space_index.action_spaces.items():
            # mask 的形状为 [B],用于选出当前动作空间对应的样本。
            mask = (action_type == action_idx)
            if mask.any():
                # 当前动作空间的真实动作维度 adim(adim <= A_max)。
                adim = self.action_space_index.get_action_dim(action_idx)
                # valid_dims[mask, :, :adim] 的含义:
                # - 第 1 维(Batch 维)用 mask 选出当前 Action Type 的样本;
                # - 第 2 维(时间维)取全部 T;
                # - 第 3 维(动作维)取前 adim 个有效维置 1。
                valid_dims[mask, :, :adim] = 1
                # 仅编码有效维输入 z[mask, :, :adim],输出写回统一容器 encoded[mask]。
                encoded[mask] = self.action_encoders[action_name](z[mask, :, :adim])
        return encoded, valid_dims

2.3. decode_actions()

    # 将隐空间表示解码回真实动作。
    # 处理流程包括:
    # 1)根据 action_type 为 Batch 内样本分组;
    # 2)对每组样本使用对应的 Action Decoder,将 [B_group, T, dit_dim] 解码为动作预测;
    # 3)仅将该动作空间有效维 [:adim] 写回 decoded,其余维保持 0;
    # 4)用 valid_dims 进一步约束有效维,保证无效维不被误写入。
    # 维度约定:
    # - 输入 z 形状:[B, T, dit_dim];
    # - 输入 action_type 形状:[B];
    # - 输入 valid_dims 形状:[B, T, A_max](A_max 与 self.action_dim 对齐);
    # - 输出 decoded 形状:[B, T, A_max]。
    def decode_actions(self, z: torch.Tensor, action_type: torch.Tensor, valid_dims: torch.Tensor) -> torch.Tensor:
        """
        Decodes latent representations into actual actions.
        Only the dimensions corresponding to valid action spaces are active.
        """
        # 对齐模型默认精度,避免新建张量 dtype 与模型参数不一致。
        default_dtype = next(self.parameters()).dtype
        B = z.shape[0]
        # 动作维上限(A_max)。
        max_action_dim = self.action_dim
        # 解码输出容器:[B, T, A_max],初始为 0。
        decoded = torch.zeros(B, z.shape[1], max_action_dim, device=self.device, dtype=default_dtype)
        # 遍历所有已注册动作空间;每个动作空间使用各自的 Decoder。
        for action_name, action_idx in self.action_space_index.action_spaces.items():
            # mask 的形状为 [B],用于选出当前动作空间对应的样本子集。
            mask = (action_type == action_idx)
            if mask.any():
                # 当前动作空间的真实动作维度 adim(adim <= A_max)。
                adim = self.action_space_index.get_action_dim(action_idx)
                # 对该样本子集执行解码,pred 形状通常为 [B_group, T, A_pred]。
                pred = self.action_decoders[action_name](z[mask])
                # 仅写回有效动作维 [:adim],并且乘以 valid_dims 做显式掩码约束:
                # decoded[mask, :, :adim] <- pred[..., :adim] * valid_dims[mask, :, :adim]。
                # 在当前实现下(decoded 初始全 0,且仅写回 :adim)这一乘法通常是冗余的安全保护。
                decoded[mask, :, :adim] = pred[..., :adim] * valid_dims[mask, :, :adim]
        return decoded

2.4. encode_proprio()

    # 按 Action Type 编码本体感觉数据。
    # 输入约定:
    # - proprio:形状为 [B, D_prop];
    # - action_type:形状 [B],用于按 Batch 维分组选择对应 Encoder;
    # - output_shape:调用方传入的对齐参考形状;
    # 输出:
    # - encoded:形状 [B, D_dit],可与 Timestep Embedding、Frequency Embedding 逐元素相加。
    # 流程:
    # 1)use_proprio=False:直接返回全 0 的 [B, D_dit];
    # 2)use_proprio=True:按 Action Type 分组编码,并且写回统一容器。
    def encode_proprio(self, proprio: torch.Tensor, action_type: torch.Tensor, output_shape) -> torch.Tensor:
        """
        Encodes proprioceptive data based on action type.
        Returns a tensor with shape [batch, dit_dim].
        """
        # 从 output_shape 获取 Batch 大小 B。
        batch_size, _ = output_shape
        dtype = next(self.parameters()).dtype
        
        # 不使用 Proprio 条件时,返回 [B, D_dit] 全 0 张量,保持与其他条件向量对齐。
        if not self.use_proprio:
            return torch.zeros(batch_size, self.dit_dim, device=self.device)
        
        # 统一输出容器:[B, D_dit]。
        encoded = torch.zeros(batch_size, self.dit_dim, device=self.device, dtype=dtype)
        # 遍历动作空间,每种 Action Type 使用对应 Proprio Encoder。
        for action_name, action_idx in self.action_space_index.action_spaces.items():
            # mask 的形状为 [B],选出当前动作空间对应的样本。
            mask = (action_type == action_idx)
            if mask.any():
                # proprio[mask] 的形状为 [B_group, D_prop];
                # 当前 Encoder 实现返回 [B_group, D_dit],squeeze(1) 属于兼容性写法,
                # 仅在存在长度为 1 的中间维时才会生效,最终写回 encoded[mask]。
                encoded[mask] = self.proprio_encoders[action_name](proprio[mask]).squeeze(1).to(dtype)
        
        return encoded