FlowerVLA 源码解析 2 - FlowerVLA Agent 初始设置及编码/解码方法
1. 初始设置
说明
很多组件只看定义,很难形成深入认知;因此先不用纠结它们的具体作用。等读到
dit_forward、encoding_*等实际使用这些组件的方法时,再回过头复看一遍,理解将更清晰。
1.1. 设置 VLM
# 初始化及配置 VLM 相关组件。
# 参数:
# - vlm_path:预训练 Florence VLM 权重路径(本地或 HuggingFace 路径);
# - freeze_vision_tower:是否冻结视觉塔(True=保持冻结,False=允许视觉塔参与训练);
# - freeze_florence:是否冻结整套 VLM 参数(优先级最高);
# - freeze_embeddings_only:仅冻结词嵌入层(当 freeze_florence=False 时生效)。
def _setup_vlm(self, vlm_path: str, freeze_vision_tower: bool, freeze_florence: bool, freeze_embeddings_only: bool) -> None:
"""
Loads the pretrained VLM, sets up the processor/tokenizer, adds a prompt token,
and optionally freezes parameters.
"""
# 加载预训练 VLM:
# - AutoModelForCausalLM 根据 vlm_path 中的配置自动选择及实例化对应模型类。
logger.info(f"Loading VLM from {vlm_path}")
self.vlm = AutoModelForCausalLM.from_pretrained(vlm_path, trust_remote_code=True)
self.train_vlm = not freeze_florence
# 策略 A:冻结整套 VLM 参数。
if freeze_florence:
for param in self.vlm.parameters():
param.requires_grad = False
# 策略 B:仅冻结输入 Embedding(常用于保护词表语义稳定性)。
elif freeze_embeddings_only:
embedding_layer = self.vlm.get_input_embeddings()
for param in embedding_layer.parameters():
param.requires_grad = False
# 在某些实现中 language_model.shared 与输入 Embedding 共享权重,需要同步冻结。
if hasattr(self.vlm.language_model, 'shared'):
for param in self.vlm.language_model.shared.parameters():
param.requires_grad = False
# `vision_tower` 是 VLM 中负责图像编码的视觉分支(将图像转成 Visual Token Sequence/特征)。
# 若显式要求训练视觉塔,则将该分支参数设为可训练;
# 这一步可覆盖上面的全量冻结策略,实现“仅解冻视觉塔、其余分支继续冻结”的配置。
if not freeze_vision_tower:
for param in self.vlm.vision_tower.parameters():
param.requires_grad = True
# 初始化 Processor/Tokenizer:
# - Processor:多模态预处理入口,统一封装图像/文本输入的预处理规则,比如 Resize、Normalize、Padding 等;
# - Tokenizer:文本侧子组件,负责 Text -> Token ID 的切分与映射,也用于后续添加 Special Token,比如 "<Flow>"。
self.processor = AutoProcessor.from_pretrained(vlm_path, trust_remote_code=True)
self.tokenizer = self.processor.tokenizer
# `prompt_embeds` 即新增加的 "<Flow>" Token 对应的 Embedding(冻结参数)。
self.prompt_embeds = self._create_prompt_embed("<Flow>").to(self.device)
# 删除语言生成分支,仅保留编码主干:
# - `decoder`:自回归生成时用于逐步解码文本;
# - `lm_head`:将隐藏状态投影到词表维度,输出每个 Token 的 Logit,用于下一词预测。
# 本项目主要使用中间特征做条件编码,不走文本生成,因此可移除,以减少参数与显存占用。
del self.vlm.language_model.model.decoder, self.vlm.language_model.lm_head
# 编码阶段可选的 Token Dropout(正则化)。
self.vlm_token_dropout = nn.Dropout(self.token_dropout)
# 创建 Prompt Embedding:
# 1)将 prompt_text 注册为 Tokenizer 的额外 Special Token;
# 2)按扩容后的词表大小重置 VLM 的输入 Embedding 矩阵;
# 3)取出该 Token 对应的 Embedding,并且以冻结参数返回(requires_grad=False)。
# 参数:
# prompt_text:Prompt 字符串,当前文件中的调用为 _create_prompt_embed("<Flow>")。
# 返回:
# 初始 Token 向量为 [D],两次 unsqueeze 后为 [1, 1, D],后续可 Expand 到 [B, 1, D] 与其他特征拼接。
def _create_prompt_embed(self, prompt_text: str) -> nn.Parameter:
"""
Creates a prompt embedding. Adds the prompt token to the tokenizer
and returns its embedding (frozen).
"""
self.tokenizer.add_special_tokens({'additional_special_tokens': [prompt_text]})
self.vlm.resize_token_embeddings(len(self.tokenizer))
prompt_token_id = self.tokenizer.convert_tokens_to_ids(prompt_text)
prompt_embed = nn.Parameter(
self.vlm.get_input_embeddings()(torch.tensor(prompt_token_id)),
requires_grad=False
)
return prompt_embed.unsqueeze(0).unsqueeze(0)1.2. 设置 DiT 组件
# 设置 DiT 组件,包括动作专属的编码器、解码器,以及共享的条件组件。
# 参数:
# - dit_dim:DiT 主干与动作表示的统一隐藏维度;
# - n_heads:每个 Transformer Block 的多头注意力头数;
# - n_layers:DiT 堆叠的 FlowBlock 层数;
# - action_dim:未使用。
# - act_window_size:动作序列窗口长度;
# - hidden_dim:来自上游多模态编码器的条件特征维度;
# - attn_pdrop:注意力层 Dropout 概率;
# - resid_pdrop:投影后的 Dropout 概率;
# - mlp_pdrop:MLP 子层 Dropout 概率;
# - use_cross_attn:是否在 FlowBlock/AdaLN 中启用 Cross Attention;
# - use_rope:是否启用 RoPE;
# - use_nope:是否禁用位置编码(No Positional Encoding);
# - query_seq_len:RoPE 在查询侧预计算/支持的最大序列长度;
# - rope_theta:RoPE theta 参数。
def _setup_dit_components(
self,
dit_dim: int,
n_heads: int,
n_layers: int,
action_dim: int,
act_window_size: int,
hidden_dim: int,
attn_pdrop: float,
resid_pdrop: float,
mlp_pdrop: float,
use_cross_attn: bool,
use_rope: bool,
use_nope: bool,
query_seq_len: int,
rope_theta: float
) -> None:
"""
Sets up the Diffusion Transformer (DiT) components including action-specific
encoders/decoders and shared conditioning components.
"""
# action name -> action encoder
self.action_encoders = nn.ModuleDict()
# action name -> action decoder
self.action_decoders = nn.ModuleDict()
if self.use_proprio:
# action name -> proprio encoder
self.proprio_encoders = nn.ModuleDict()
# action_type_adaln=True 时,为每个动作类型维护独立的 AdaLN 控制器,
# 否则后面将创建共享的 AdaLN 控制器。
self.adaln = nn.ModuleDict() if self.action_type_adaln else None
# 逐个动作空间创建“编码-解码-条件调制”组件。
# action_space_index.action_spaces 的 key 是动作名称,value 是动作索引。
for action_name, action_idx in self.action_space_index.action_spaces.items():
# 每个动作空间有不同输入维度,需动态查询。
input_dim = self.action_space_index.get_action_dim(action_idx)
# 动作编码器:将原始动作向量提升到 dit_dim,以便与 DiT 主干对齐。
self.action_encoders[action_name] = Mlp(
in_features=input_dim,
hidden_features=dit_dim,
out_features=dit_dim,
bias=True
)
# 动作解码器:将 DiT 隐状态投影回该动作空间的原始维度。
self.action_decoders[action_name] = nn.Linear(dit_dim, input_dim).to(self.device)
# 动作专属 AdaLN:为不同动作类型提供独立条件调制参数,避免多任务/多动作空间间的统计干扰。
if self.action_type_adaln:
self.adaln[action_name] = SharedAdaLNController(
dit_dim,
global_conddim=dit_dim,
use_cross_attn=use_cross_attn
)
# 本体感觉编码器(可选):
# - bimanual_nav:使用 MLP 编码实际 proprio 输入;
# - 其他动作:使用 ZeroEncoder 占位,保持接口一致。
if self.use_proprio:
if action_name == 'bimanual_nav':
self.proprio_encoders[action_name] = Mlp(
input_dim,
dit_dim,
out_features=dit_dim,
drop=0.2
).to(self.device)
else:
self.proprio_encoders[action_name] = ZeroEncoder(
self.dit_dim,
device=self.device
)
# 若未启用动作专属 AdaLN,则构建全局共享的 AdaLN 控制器。
if not self.action_type_adaln:
self.adaln = SharedAdaLNController(
dit_dim,
global_conddim=dit_dim,
use_cross_attn=use_cross_attn
)
# 条件编码相关模块:
# - cond_norm = RmsNorm(hidden_dim):对 VLM 条件特征按最后一维做 RMS 归一化,稳定尺度;
# - cond_linear = Linear(hidden_dim -> dit_dim, bias=False):将条件通道映射到 DiT 隐空间维度;
self.cond_linear = nn.Linear(hidden_dim, dit_dim, bias=False)
self.t_embedder = TimestepEmbedder(dit_dim)
self.cond_norm = RmsNorm(hidden_dim)
self.frequency_embedder = FreqEmbedder(dit_dim)
self.action_space_embedder = ActionSpaceEmbedderParameter(
dit_dim,
max_actions=len(self.action_space_index.action_spaces)
)
# 绝对位置编码,仅在未启用 RoPE / NoPE 时生效:
# `torch.randn(1, act_window_size, dit_dim)` 的形状含义为:
# - 第 1 维 `1`:批维占位,便于在前向中与 `z`(形状 [B, T, D])按批次自动广播;
# 即同一组位置编码将共享给当前 Batch 中所有样本。
# - 第 2 维 `act_window_size`:动作序列长度 T,每个时间步有独立的位置向量。
# - 第 3 维 `dit_dim`:每个位置向量的通道维 D,与 DiT 隐状态维度对齐,才能做逐元素相加。
# 乘以 0.1 是缩小初始化幅度,避免训练初期位置项过大干扰主特征。
if not use_rope and not use_nope:
self.positional_encoding = nn.Parameter(
torch.randn(1, act_window_size, dit_dim) * 0.1
)
# 构建 DiT 主干:由 n_layers 个 FlowBlock 组成,所有层共享相同超参数。
self.dit = nn.ModuleList([
FlowBlock(
dim=dit_dim,
heads=n_heads,
attn_pdrop=attn_pdrop,
resid_pdrop=resid_pdrop,
mlp_pdrop=mlp_pdrop,
use_cross_attn=use_cross_attn,
use_rope=use_rope,
query_seq_len=query_seq_len,
rope_theta=rope_theta
) for _ in range(n_layers)
])2. 编码/解码方法
2.1. encode_observations()
# 编码主视角图片(以及可选的第二视角图片)和文本目标。
# 返回字典字段:
# - features:VLM Encoder 输出的条件特征;
# - frequency_embeds:Frequency Embeddings;
# - action_space_embeds:Action Space Embeddings;
# - action_type:Action Type 索引;
# - proprio:本体感觉数据(若存在);
# - attention_mask:与 features 对齐的 Attention Mask。
def encode_observations(self, batch: Dict) -> Dict[str, torch.Tensor]:
"""
Encodes primary (and optional second view) image observations and text goals.
Returns a dictionary with:
- 'features': Encoder outputs.
- 'frequency_embeds': Frequency embeddings.
- 'action_space_embeds': Action space embeddings.
- 'action_type': Action type indices.
- 'proprio': Proprioception data (if available).
- 'attention_mask': Attention mask.
"""
device = self.device
# 从模型参数中读取默认精度类型,比如 float32/bfloat16/float16。
default_dtype = next(self.parameters()).dtype
# 默认配置项为 conf/training.yaml 的 obs_modalities(Hydra 键:obs_modalities)="observation"。
image_tensor = batch[self.obs_modalities]['image_primary']
# T 表示时间维度,即每个样本的连续观测帧数 / 序列长度。
B, T, C, H, W = image_tensor.shape
# image_tensor: [B, T, C, H, W] -> view 后: [B*T, C, H, W]。
# vlm._encode_image 的作用是将每帧图像编码为视觉 Token 序列。
# vlm._encode_image 的输出的维度是 [B*T, N_img, D](N_img 是每帧的视觉 Token 数, D 是特征维度)。
image_features = self.vlm._encode_image(
image_tensor.view(-1, C, H, W).to(device).to(default_dtype)
)
# 还原 Batch 结构,合并序列维:将 [B*T, N_img, D] 视作 [B, T, N_img, D] 后,
# 将时间维 T 与每帧视觉 Token 维 N_img 展平为单一序列维,得到 [B, T*N_img, D]。
image_features = image_features.view(B, T * image_features.shape[1], -1)
# second_view_key 在代码中的默认值为 "image_wrist";
# 在训练配置 conf/trainer/agent/flower_vla.yaml 中,
# 通过 agent.second_view_key(Hydra 键:trainer.agent.second_view_key)默认覆盖为 "image_secondary"。
if self.use_second_view and self.second_view_key in batch[self.obs_modalities]:
image2_tensor = batch[self.obs_modalities][self.second_view_key]
image2_features = self.vlm._encode_image(
image2_tensor.view(-1, C, H, W).to(device).to(default_dtype)
)
image2_features = image2_features.view(B, T * image2_features.shape[1], -1)
# dim=1 表示沿序列长度维(Token 维)拼接两路视觉特征:
# [B, T*N_img, D] + [B, T*N_img2, D] -> [B, T*(N_img+N_img2), D]。
# 不改变 Batch 维(dim=0)和特征维 D(dim=2)。
image_features = torch.cat([image_features, image2_features], dim=1)
# goal_modalities 来自 conf/training.yaml(Hydra 键:goal_modalities)且取值为 "task";
# lang_modalities 来自 conf/training.yaml(Hydra 键:lang_modalities)且取值为 ["language_instruction"]。
# 这里调用 VLM 的输入 Embedding 层,将语言 input_ids 映射为 Embedding。
# input_ids 的常见形状为 [B, 1, L](或 [B, L]),映射后为 [B, 1, L, D_txt](或 [B, L, D_txt]);
# squeeze(1) 用于移除长度为 1 的中间维,将形状统一为 [B, L, D_txt]。
text_embeds = self.vlm.get_input_embeddings()(
batch[self.goal_modalities][self.lang_modalities[0]]['input_ids'].to(device)
).to(device).squeeze(1)
# 将 [1, 1, D] 的 Prompt Embedding 在 Batch 维复制为 [B, 1, D];
# 其中 Expand(B, -1, -1) 仅扩展视图,不复制 Token 维与特征维。
task_prompt = self.prompt_embeds.expand(B, -1, -1)
# dim=1 表示按序列维度拼接 Prompt、视觉与文本特征:
# [B, 1, D] + [B, N_vis, D] + [B, N_txt, D] -> [B, 1+N_vis+N_txt, D]。
merged_embeds = torch.cat([task_prompt.to(image_features.device), image_features, text_embeds.to(image_features.device)], dim=1)
# merged_embeds 的形状为 [B, N_total, D],其中 N_total = 1(Prompt)+ N_vis + N_txt。
# 注意:下面这行当前未被实际使用。
attention_mask = torch.ones(merged_embeds.shape[:2], device=merged_embeds.device)
# 读取语言 Attention Mask,通常形状为 [B, 1, N_txt];
# 经 squeeze(1) 后变为 [B, N_txt],与序列维对齐。
lang_attention_mask = batch[self.goal_modalities][self.lang_modalities[0]]['attention_mask'].to(device).squeeze(1)
# 视觉特征 image_features 的形状为 [B, N_vis, D],因此视觉 Mask 形状为 [B, N_vis],
# 这里用全 1,表示视觉 Token 默认全部有效。
vis_attention_mask = torch.ones(image_features.shape[:2], device=image_features.device)
# Prompt 只有 1 个 Token,对应 Mask 形状为 [B, 1]。
# 这里用全 0,表示 Prompt Token 在 Attention 中被显式屏蔽。
prompt_mask = torch.zeros(B, 1, dtype=torch.bool, device=image_features.device)
# 按序列维(dim=1)拼接三段 Mask:
# [B, 1](Prompt) + [B, N_vis](Vision) + [B, N_txt](Text) -> [B, 1+N_vis+N_txt]。
attention_mask = torch.cat([prompt_mask, vis_attention_mask, lang_attention_mask], dim=1)
# 1)通过 VLM Encoder 对拼接后的多模态序列做上下文编码。
# inputs_embeds 形状:[B, N_total, D],其中 N_total = 1(Prompt)+ N_vis + N_txt。
# attention_mask 形状:[B, N_total],与 inputs_embeds 的序列维一一对应。
# last_hidden_state 输出形状:[B, N_total, D],每个 Token 都变为上下文相关特征。
features = self.vlm.get_encoder()(
inputs_embeds=merged_embeds,
attention_mask=attention_mask,
).last_hidden_state
# 2)对 Encoder 输出做 Token 级 Dropout(仅训练时生效,Eval 时自动关闭)。
# 输入/输出形状保持不变:[B, N_total, D]。
features = self.vlm_token_dropout(features)
# 3)可选的 CFG 文本条件 Dropout:
# 仅在训练阶段且 cfg_dropout > 0 时,按“样本级”随机屏蔽文本段特征,
# 用于模拟无文本条件分支,提升 Classifier-Free Guidance 的稳定性。
# 对应配置项为 conf/trainer/agent/flower_vla.yaml 的 agent.cfg_dropout
# (Hydra 键为 trainer.agent.cfg_dropout,当前默认配置值为 0.0)。
if self.cfg_dropout > 0 and self.training:
# 三段长度:Prompt 段、Vision 段、Text 段。
prompt_length = task_prompt.shape[1]
image_length = image_features.shape[1]
# Text 段长度。
text_length = text_embeds.shape[1]
text_start = prompt_length + image_length
# 文本片段在 features 中的切片区间是 [text_start:text_end)。
text_end = text_start + text_length
# 为每个样本生成 1 个 Bernoulli 掩码,形状 [B, 1, 1],用于广播到文本段:
# 取值 1 表示“丢弃该样本文本条件”,取值 0 表示“保留文本条件”。
drop_mask = (torch.rand(B, device=device) < self.cfg_dropout).float().view(B, 1, 1)
# 仅对文本段应用掩码,维持 Prompt/Vision 段不变:
# features[:, text_start:text_end, :] *= (1 - drop_mask)。
features[:, text_start:text_end, :] = features[:, text_start:text_end, :] * (1 - drop_mask)
return {
'features': features,
'frequency_embeds': self.frequency_embedder(batch[self.goal_modalities]['frequency'].to(device).to(default_dtype)),
'action_space_embeds': self.action_space_embedder(batch[self.goal_modalities]['action_space_index'].to(device)),
'action_type': batch[self.goal_modalities]['action_space_index'],
'proprio': batch[self.obs_modalities]['proprio'].to(device).to(default_dtype) if self.use_proprio and 'proprio' in batch[self.obs_modalities] else None,
'attention_mask': attention_mask,
}2.2. encode_actions()
# 按样本的 Action Type 将原始动作序列编码到统一的 DiT 隐空间。
# 处理流程包括:
# 1)根据 action_type 为 Batch 内样本分组;
# 2)对每组样本仅截取其有效动作维 adim,送入对应的 Action Encoder;
# 3)将编码结果写回统一形状的 encoded(最后一维为 dit_dim);
# 4)同步构造 valid_dims 掩码,标记哪些维度是该样本的真实有效维度。
# 维度约定:
# - 输入 z 形状:[B, T, A_max](A_max 为最大动作维);
# - 输入 action_type 形状:[B],每个样本 1 个动作空间索引;
# - 输出 encoded 形状:[B, T, dit_dim];
# - 输出 valid_dims 形状:[B, T, A_max],有效维位置为 1,无效维位置为 0。
def encode_actions(self, z: torch.Tensor, action_type: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Encodes actions for each sample based on its action type.
Returns:
- Encoded actions (latent representations).
- A valid dimensions mask.
"""
# 对齐模型默认精度,避免后续创建的新张量与模型参数 dtype 不一致。
default_dtype = next(self.parameters()).dtype
# 将动作类型索引放到当前设备,便于后续布尔掩码索引。
action_type = action_type.to(self.device)
B = z.shape[0]
# 统一的编码输出容器:[B, T, dit_dim]。
encoded = torch.zeros(B, z.shape[1], self.dit_dim, device=self.device, dtype=default_dtype)
# 有效维掩码容器:[B, T, A_max];初始时全部为 0,后续仅将真实有效维标为 1。
valid_dims = torch.zeros_like(z, dtype=default_dtype)
# 遍历所有已注册动作空间;每个动作空间使用各自的 Encoder。
for action_name, action_idx in self.action_space_index.action_spaces.items():
# mask 的形状为 [B],用于选出当前动作空间对应的样本。
mask = (action_type == action_idx)
if mask.any():
# 当前动作空间的真实动作维度 adim(adim <= A_max)。
adim = self.action_space_index.get_action_dim(action_idx)
# valid_dims[mask, :, :adim] 的含义:
# - 第 1 维(Batch 维)用 mask 选出当前 Action Type 的样本;
# - 第 2 维(时间维)取全部 T;
# - 第 3 维(动作维)取前 adim 个有效维置 1。
valid_dims[mask, :, :adim] = 1
# 仅编码有效维输入 z[mask, :, :adim],输出写回统一容器 encoded[mask]。
encoded[mask] = self.action_encoders[action_name](z[mask, :, :adim])
return encoded, valid_dims2.3. decode_actions()
# 将隐空间表示解码回真实动作。
# 处理流程包括:
# 1)根据 action_type 为 Batch 内样本分组;
# 2)对每组样本使用对应的 Action Decoder,将 [B_group, T, dit_dim] 解码为动作预测;
# 3)仅将该动作空间有效维 [:adim] 写回 decoded,其余维保持 0;
# 4)用 valid_dims 进一步约束有效维,保证无效维不被误写入。
# 维度约定:
# - 输入 z 形状:[B, T, dit_dim];
# - 输入 action_type 形状:[B];
# - 输入 valid_dims 形状:[B, T, A_max](A_max 与 self.action_dim 对齐);
# - 输出 decoded 形状:[B, T, A_max]。
def decode_actions(self, z: torch.Tensor, action_type: torch.Tensor, valid_dims: torch.Tensor) -> torch.Tensor:
"""
Decodes latent representations into actual actions.
Only the dimensions corresponding to valid action spaces are active.
"""
# 对齐模型默认精度,避免新建张量 dtype 与模型参数不一致。
default_dtype = next(self.parameters()).dtype
B = z.shape[0]
# 动作维上限(A_max)。
max_action_dim = self.action_dim
# 解码输出容器:[B, T, A_max],初始为 0。
decoded = torch.zeros(B, z.shape[1], max_action_dim, device=self.device, dtype=default_dtype)
# 遍历所有已注册动作空间;每个动作空间使用各自的 Decoder。
for action_name, action_idx in self.action_space_index.action_spaces.items():
# mask 的形状为 [B],用于选出当前动作空间对应的样本子集。
mask = (action_type == action_idx)
if mask.any():
# 当前动作空间的真实动作维度 adim(adim <= A_max)。
adim = self.action_space_index.get_action_dim(action_idx)
# 对该样本子集执行解码,pred 形状通常为 [B_group, T, A_pred]。
pred = self.action_decoders[action_name](z[mask])
# 仅写回有效动作维 [:adim],并且乘以 valid_dims 做显式掩码约束:
# decoded[mask, :, :adim] <- pred[..., :adim] * valid_dims[mask, :, :adim]。
# 在当前实现下(decoded 初始全 0,且仅写回 :adim)这一乘法通常是冗余的安全保护。
decoded[mask, :, :adim] = pred[..., :adim] * valid_dims[mask, :, :adim]
return decoded2.4. encode_proprio()
# 按 Action Type 编码本体感觉数据。
# 输入约定:
# - proprio:形状为 [B, D_prop];
# - action_type:形状 [B],用于按 Batch 维分组选择对应 Encoder;
# - output_shape:调用方传入的对齐参考形状;
# 输出:
# - encoded:形状 [B, D_dit],可与 Timestep Embedding、Frequency Embedding 逐元素相加。
# 流程:
# 1)use_proprio=False:直接返回全 0 的 [B, D_dit];
# 2)use_proprio=True:按 Action Type 分组编码,并且写回统一容器。
def encode_proprio(self, proprio: torch.Tensor, action_type: torch.Tensor, output_shape) -> torch.Tensor:
"""
Encodes proprioceptive data based on action type.
Returns a tensor with shape [batch, dit_dim].
"""
# 从 output_shape 获取 Batch 大小 B。
batch_size, _ = output_shape
dtype = next(self.parameters()).dtype
# 不使用 Proprio 条件时,返回 [B, D_dit] 全 0 张量,保持与其他条件向量对齐。
if not self.use_proprio:
return torch.zeros(batch_size, self.dit_dim, device=self.device)
# 统一输出容器:[B, D_dit]。
encoded = torch.zeros(batch_size, self.dit_dim, device=self.device, dtype=dtype)
# 遍历动作空间,每种 Action Type 使用对应 Proprio Encoder。
for action_name, action_idx in self.action_space_index.action_spaces.items():
# mask 的形状为 [B],选出当前动作空间对应的样本。
mask = (action_type == action_idx)
if mask.any():
# proprio[mask] 的形状为 [B_group, D_prop];
# 当前 Encoder 实现返回 [B_group, D_dit],squeeze(1) 属于兼容性写法,
# 仅在存在长度为 1 的中间维时才会生效,最终写回 encoded[mask]。
encoded[mask] = self.proprio_encoders[action_name](proprio[mask]).squeeze(1).to(dtype)
return encoded