1. rf_loss()
# 计算 Rectified Flow 损失:
# 目标是学习速度场 v_theta(zt, t, cond),使其逼近真实速度 (z1 - actions)。
# 处理流程包括:
# 1)按采样策略为每个样本采样时间 t(形状 [B]);
# 2)按 Action Type 在各自有效动作维生成噪声 z1(形状 [B, T, A_max]);
# 3)构造线性插值 zt = (1 - t) * actions + t * z1;
# 4)前向传播得到 vtheta = dit_forward(zt, t, cond);
# 5)仅在有效动作维上计算 MSE,忽略无效/填充维。
# 维度约定:
# - actions 形状:[B, T, A_max](若输入为 [B, 1, T, A_max],将先 squeeze 到三维);
# - action_type 形状:[B];
# - t 形状:[B],texp 形状:[B, 1, 1],用于广播到动作张量;
# - 输出 loss 为标量,losses_dict 为日志统计字典。
def rf_loss(self, cond: dict, actions: torch.Tensor, dataset_idx: Any = None) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""
Computes the rectified flow loss.
Interpolates between actions and noise, then computes MSE only over valid dimensions.
"""
# 对齐模型默认精度,确保后续新建张量与模型参数 dtype 一致。
default_dtype = next(self.parameters()).dtype
# 从条件字典读取每个样本对应的动作空间类型索引。
# 形状为 [B],每个样本 1 个整型 Action Type。
action_type = cond['action_type']
# 兼容可能的四维输入:[B, 1, T, A_max] -> [B, T, A_max]。
if len(actions.shape) == 4:
actions = actions.squeeze(1)
b = actions.size(0)
device = actions.device
actions = actions.to(default_dtype)
# 按配置的 Sampling Type 采样时间 t,每个样本一个标量,范围在 [0, 1)。
# sampling_type 来源于 FlowerVLA 构造参数 sampling_type(默认值为 "ln"),
# 可通过 conf/trainer/agent/flower_vla.yaml 的 agent.sampling_type 进行覆盖。
if self.sampling_type == "pi_zero":
alpha, beta = 1.5, 1.0
t = torch.distributions.Beta(alpha, beta).sample((b,)).to(device).clamp(max=0.999)
elif self.sampling_type == "ln":
t = torch.sigmoid(torch.randn((b,), device=device)).clamp(max=0.999).to(default_dtype)
elif self.sampling_type == "uniform":
# eps 用于避免采样到 1.0,后续插值或数值计算时可减少边界问题。
eps = 1e-5
# 先生成 [0, 1) 的随机偏移量 rand,再加上等间隔网格 arange(b)/b,
# 得到长度为 b 的“分层均匀采样”时间点(Stratified Uniform Sampling)。
# 取模 (1 - eps) 可将越界值回卷到区间内,最终 t 的形状为 [B],范围约在 [0, 1-eps)。
t = (torch.rand(1, device=device) + torch.arange(b, device=device) / b) % (1 - eps)
# 与模型其余张量精度对齐。
t = t.to(default_dtype)
else:
raise NotImplementedError(f"Sampling type {self.sampling_type} not implemented")
# 将 t 扩展为可广播形状 [B, 1, 1](对于三维 actions)。
texp = t.view([b] + [1] * (actions.dim() - 1))
# z1 为噪声端点,初始化为全 0;后续仅在有效动作维写入高斯噪声。
z1 = torch.zeros_like(actions)
for action_name, action_idx in self.action_space_index.action_spaces.items():
mask = (action_type == action_idx)
if mask.any():
adim = self.action_space_index.get_action_dim(action_idx)
# 仅为当前 Action Space 的有效维 [:adim] 采样噪声。
# 维度对齐关系:
# - mask.sum() = B_group(当前 Action Type 选中的样本数);
# - noise_slice 形状为 [B_group, T, adim];
# - z1[mask, :, :adim] 先按 mask 选 Batch 子集,再取全部时间步与前 adim 维,
# 左值切片形状同样是 [B_group, T, adim],因此可直接赋值。
noise_slice = torch.randn((mask.sum(), actions.size(1), adim), dtype=actions.dtype, device=actions.device)
z1[mask, :, :adim] = noise_slice
# 线性插值到中间点 zt:
# - t=0 对应数据端(actions);
# - t=1 对应噪声端(z1)。
# 因此 zt = (1 - t) * actions + t * z1。
zt = (1 - texp) * actions + texp * z1
# 预测速度场 vtheta,形状与 actions/z1 一致:[B, T, A_max]。
vtheta = self.dit_forward(zt, t, cond)
# 构造有效维掩码:当前样本对应 Action Space 的前 adim 维为 True,其他维为 False。
valid_mask = torch.zeros_like(actions, dtype=torch.bool)
for action_name, action_idx in self.action_space_index.action_spaces.items():
mask = (action_type == action_idx)
if mask.any():
# 当前 Action Space 的有效动作维度。
adim = self.action_space_index.get_action_dim(action_idx)
# mask 原始形状为 [B],表示哪些样本属于当前 Action Type。
# 先 View 成 [B, 1, 1],再 Expand 到 [B, T, adim],使其与
# valid_mask 在“时间维 + 有效动作维”上的目标切片形状对齐。
mask_expanded = mask.view(-1, 1, 1).expand(-1, actions.size(1), adim).to(device)
# 左值 valid_mask[mask, :, :adim] 形状为 [B_group, T, adim];
# 右值 mask_expanded[mask] 形状同样为 [B_group, T, adim]。
# 赋值后,当前 Action Type 样本在前 adim 维被置为 True,其他位置保持 False。
valid_mask[mask, :, :adim] = mask_expanded[mask]
# 监督 Label(目标速度场)定义为 v* = z1 - actions,形状为 [B, T, A_max]。
# vtheta 形状同为 [B, T, A_max],因此 diff = v* - vtheta 形状也是 [B, T, A_max]。
diff = (z1 - actions) - vtheta
# 只在有效动作维取值:valid_diff 为一维展平向量 [N_valid],N_valid 为有效元素总数。
valid_diff = diff[valid_mask]
loss = (valid_diff ** 2).mean()
# 记录训练可观测统计量,便于排查数值范围与稳定性。
losses_dict = {
"diff_min": valid_diff.min().item(),
"diff_max": valid_diff.max().item(),
"diff_mean": valid_diff.mean().item(),
"loss": loss.item(),
}
if hasattr(self, 'accelerator') and self.accelerator is not None and wandb.run is not None:
if self.accelerator.is_main_process:
wandb.log(losses_dict)
return loss, losses_dict
2. dit_forward()
说明
- 通过
self.vlm.config.text_config.d_model,从 VLM 文本配置读取隐藏维度 d_model(文本/多模态 Token 的通道维)。
# 通过 DiT 块前向传播,编码动作,添加位置信息,以及应用条件。
#
# 输入与形状约定:
# - z:[B, T_act, A_max],噪声/中间动作序列;
# - t:[B],时间步;
# - cond_dict["features"]:[B, N_cond, D_vlm];
# - cond_dict["frequency_embeds"]:[B, D_dit];
# - cond_dict["action_type"]:[B];
# - cond_dict["proprio"](可选):[B, D_prop];
# - cond_dict["attention_mask"]:[B, N_cond]。
def dit_forward(self, z: torch.Tensor, t: torch.Tensor, cond_dict: dict) -> torch.Tensor:
"""
Forward pass through the DiT blocks.
Encodes actions, adds positional information, and applies conditioning.
"""
# z 形状拆解:Batch、动作序列长度、动作维。
B, t_seq, d = z.shape
# 读取模型参数默认 dtype。
dtype = next(self.parameters()).dtype
# 条件分支:features 从 [B, N_cond, D_vlm] 归一化并且线性映射到 [B, N_cond, D_dit]。
cond = self.cond_linear(self.cond_norm(cond_dict['features'].to(dtype)))
# Frequency Embeds:FreqEmbedder 输出 [B, D_dit],与 Time / Proprio 全局条件同形。
# 后续主要有三种用途:
# 1)作为条件输入,与 Time Embedding、Proprio Embeds 相加形成 t_emb;
# 2)在训练阶段参与 CFG Dropout(可按样本置零);
# 3)作为形状参考,比如 torch.zeros_like(freq_embeds),创建对齐张量。
freq_embeds = cond_dict['frequency_embeds'].squeeze(1).to(dtype)
# action_type 的形状是 [B],用于后续按 Batch 维进行布尔分组索引。
action_type = cond_dict['action_type'].to(self.device)
# proprio 用于生成本体条件向量:
# - 启用 use_proprio 时:读取 cond_dict["proprio"];
# - 否则:退化为与 freq_embeds 同形状的全 0 张量 [B, D_dit]。
proprio = cond_dict.get('proprio', torch.zeros_like(freq_embeds)).to(dtype) if self.use_proprio else torch.zeros_like(freq_embeds)
# encode_proprio 输出 [B, D_dit],与 freq_embeds、t_embed 可逐元素相加。
proprio_embeds = self.encode_proprio(proprio, action_type, freq_embeds.shape).to(dtype)
# 按 Action Type 编码动作:
# z:[B, T_act, A_max] -> [B, T_act, D_dit];
# valid_dims:[B, T_act, A_max],记录各样本有效动作维。
z, valid_dims = self.encode_actions(z, action_type)
# 若使用绝对位置编码(非 RoPE/NoPE),则与 z 逐元素相加。
# 要求 self.positional_encoding 可广播或同形到 [B, T_act, D_dit]。
# 该相加不会改变 z 的形状;若位置编码长度大于当前 T_act,PyTorch 不会自动截断,
# 通常会因维度不匹配报错,需要显式切片到 [:, :T_act, :] 才能对齐。
if not (self.use_rope or self.use_nope):
z += self.positional_encoding
# 仅对频率与本体条件执行 CFG Dropout,不改变 z 主分支:
# drop_mask:[B, 1],沿特征维广播到 [B, D_dit]。
if self.training and self.cfg_dropout > 0:
# 按样本生成二值掩码(1=Drop,0=Keep)。
drop_mask = (torch.rand(freq_embeds.size(0), device=freq_embeds.device) < self.cfg_dropout).float().unsqueeze(1)
freq_embeds = freq_embeds * (1 - drop_mask)
proprio_embeds = proprio_embeds * (1 - drop_mask)
# 时间与全局条件融合,t_embedder(t) 输出 [B, D_dit],三项相加后 t_emb 仍为 [B, D_dit]。
t_emb = sum(map(stateless_norm, [self.t_embedder(t), freq_embeds, proprio_embeds]))
# 计算全局条件 global_cond,供 AdaLN 使用:
# - cond:[B, N_cond, D_dit],readout 取 cond[:,0,:],否则对 N_cond 维计算 mean → 均为 [B, D_dit];
# - t_emb:[B, D_dit];二者同形,+= 为逐样本逐维相加。
# - use_adaln_cond=False 时仅用 t_emb。
if self.use_adaln_cond:
global_cond = cond[:, 0, :] if self.use_readout_token else cond.mean(dim=1)
global_cond += t_emb
else:
global_cond = t_emb
# Cross-Attention 上下文:
# - 开启时 context=cond([B, N_cond, D_dit]);
# - 关闭时 context=None。
context = cond if self.use_cross_attn else None
# AdaLN 调制信号:
# - action_type_adaln=False:共享调制;
# - action_type_adaln=True:按 Action Type 生成分组调制。
global_adaln = self.adaln(global_cond) if not self.action_type_adaln else self.action_specific_adaln(global_cond, action_type)
# 逐层 DiT Block 前向:
# 输入 z 始终保持 [B, T_act, D_dit];
# custom_cross_attn_mask 使用 cond_dict["attention_mask"]([B, N_cond])。
for layer in self.dit:
z = layer(z, global_cond, context=context, custom_attn_mask=None,
custom_cross_attn_mask=cond_dict['attention_mask'], is_causal=True, global_adaln=global_adaln)
# 解码回动作空间,同时应用有效维约束:
# [B, T_act, D_dit] -> [B, T_act, A_max]。
return self.decode_actions(z, action_type, valid_dims)
# 计算动作专属的 AdaLN 调制信号。
# 输入:
# - global_cond:全局条件张量,形状为 [B, D_dit];
# - action_type:每个样本对应的动作类型索引,形状为 [B];
# 输出:
# - List[Tensor],其中每个 Tensor 表示一条调制分量通路的全 Batch 结果。
def action_specific_adaln(self, global_cond: torch.Tensor, action_type: torch.Tensor) -> List[torch.Tensor]:
"""
Computes action-specific AdaLN modulation signals.
Returns a list of modulation tensors.
"""
dtype = next(self.parameters()).dtype
batch_size = global_cond.shape[0]
num_chunks = 9 if self.use_cross_attn else 6
# 先为每路 AdaLN 调制信号分量预分配 [B, D_dit] 的零张量;
# 后续将按动作类型用 mask 只覆盖对应样本位置。
mod_signals = [torch.zeros(batch_size, self.dit_dim, device=self.device, dtype=dtype) for _ in range(num_chunks)]
# 每次迭代针对一种动作类型,为该类型样本生成及回填对应的调制信号。
for action_idx in range(len(self.action_space_index.action_spaces)):
# 按当前 action_idx 取子批,mask 的形状为 [B],True 表示该样本属于当前动作类型。
mask = (action_type == action_idx)
# 仅在当前类型确有样本时才执行对应 AdaLN 控制器,避免空张量前向传播。
if mask.any():
action_name = self.action_space_index.get_action_name(action_idx)
# global_cond[mask] 的形状为 [B_g, D_dit](B_g 为该类型样本数);
# action_mod 是长度为 num_chunks 的列表,每一路的形状均为 [B_g, D_dit]。
action_mod = self.adaln[action_name](global_cond[mask])
# i 对应“第 i 路调制信号”,比如 shift/scale/gate 等不同分量,
# signal 的形状为 [B_g, D_dit],与当前动作类型子批一一对应。
for i, signal in enumerate(action_mod):
# 将当前动作类型子批的结果(signal,形状为 [B_g, D_dit]),
# 按 mask 回填到全 Batch 容器 mod_signals[i]([B, D_dit])的对应行。
# 非当前类型(mask=False)的行不会被改动,继续保留已有值。
mod_signals[i][mask] = signal
return mod_signals
3. 推理
# 推理阶段前向传播:
# 1) 将输入的观测 obs 与任务目标 goal 组装为 batch;
# 2) 调用 encode_observations 提取条件特征;
# 3) 以高斯噪声作为初始动作序列,通过 sample_actions 在条件约束下逐步采样,
# 最终得到模型预测的动作序列,形状通常为 [B, T, D]。
def forward(self, obs: Dict, goal: Dict) -> torch.Tensor:
"""
Inference forward pass.
Given observation and goal dictionaries, it encodes them and samples an action sequence.
"""
batch = {'observation': obs, 'task': goal}
features = self.encode_observations(batch)
# 构造推理初始噪声 z0 ~ N(0, I):
# 第一维 B=len(features['features']),表示当前 Batch 中每个样本各自对应一条初始噪声动作序列。
# 第二维为 act_window_size(动作序列长度)。
# 第三维为 action_dim,来自于 conf/trainer/agent/flower_vla.yaml 中的 action_dim: ${act_dim},
# act_dim 的定义在 conf/training.yaml。
# device 与特征保持一致,避免后续采样阶段发生跨设备张量运算错误。
noise = torch.randn(len(features['features']), self.act_window_size, self.action_dim,
device=features['features'].device)
return self.sample_actions(noise, features, inference=True)
# 从 DiT 模型中采样动作。
def sample_actions(self, z: torch.Tensor, cond: Dict[str, torch.Tensor], inference: bool = False) -> torch.Tensor:
"""
Samples actions from the DiT model.
Chooses between an adaptive ODE solver and fixed-step Euler integration.
"""
steps = self.num_sampling_steps if inference else 5
b = z.size(0)
# 按各自动作类型,将每个样本的无效动作维度清零。
action_type = cond['action_type']
for action_name, action_idx in self.action_space_index.action_spaces.items():
mask = (action_type == action_idx)
if mask.any():
adim = self.action_space_index.get_action_dim(action_idx)
z[mask, :, adim:] = 0.0
# 当前文件内 use_dopri5 仅在初始化时设为 False,且未被改写,
# 因此默认走 fixed-step 分支,除非外部代码在运行时手动修改该属性。
if hasattr(self, 'use_dopri5') and self.use_dopri5:
return self._sample_with_adaptive_solver(z, cond)
else:
return self._sample_with_fixed_steps(z, cond, inference)
# 使用固定步长的 Euler 积分采样动作。
def _sample_with_fixed_steps(self, z: torch.Tensor, cond: Dict[str, torch.Tensor], inference: bool = False) -> torch.Tensor:
"""
Samples actions using fixed-step Euler integration.
"""
steps = self.num_sampling_steps if inference else 5
b = z.size(0)
device = z.device
action_type = cond['action_type']
# 固定步长 Euler 的时间步长:将区间 [0, 1] 均分为 steps 份。
dt = 1.0 / steps
# 将标量 dt 扩展为可广播张量 [B, 1, ..., 1],
# 以便与速度项 vc(通常为 [B, T, D])逐元素相乘得到步进增量,再用于更新 z。
dt_tensor = torch.tensor([dt] * b, device=device).view([b] + [1] * (z.dim() - 1))
# 仅在推理阶段且 `cfg_lambda` != 1.0 时启用 CFG。
# `cfg_lambda` 在本类 `_init_flags` 中默认初始化为 1.0(默认不启用 CFG),
# 若需要启用 CFG,需在运行时将 `self.cfg_lambda` 设为其他值。
apply_cfg = inference and self.cfg_lambda != 1.0
# 在循环外仅构造一次 CFG 的“无文本条件分支”条件,避免每步重复创建。
if apply_cfg:
null_cond = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in cond.items()}
# 先取 `features` 的副本,再修改副本,最后写回。
features = null_cond['features'].clone()
# 计算文本特征起始位置:
# `features` 按 [Prompt Token][Image Token][Text Token] 的顺序拼接,
# 因此 `text_start = prompt_length + image_length`。
prompt_length = self.prompt_embeds.shape[1]
image_length = 50 if self.use_second_view is False else 100
text_start = prompt_length + image_length
# 仅将文本区段置零,保留 Prompt 和图像条件,构造“弱条件/近无条件”输入供 CFG 使用。
features[:, text_start:, :] = 0.0
# 将修改后的 `features` 写回 `null_cond`,供后续计算 `vu`(Unconditional Velocity)。
null_cond['features'] = features
for i in range(steps, 0, -1):
# 将当前步号 i 转成时间值 t=i/steps,作为本步的“时间标签”。
# 比如当 steps=4 时,t 依次为 1.0、0.75、0.5、0.25(从 1 向 0 方向推进)。
t_val = i / steps
# 将标量 t 扩展为 [B] 形状,使 Batch 内每个样本在该步使用同一时间输入。
t_tensor = torch.full((b,), t_val, device=device)
# 先用条件输入 cond 预测速度 vc(Conditional Velocity)。
vc = self.dit_forward(z, t_tensor, cond)
# 若启用 CFG,则再计算无条件速度 vu,然后做线性融合:
# vc <- vu + cfg_lambda * (vc - vu)。
# 其中 (vc - vu) 表示条件信息带来的引导方向,cfg_lambda 越大引导越强。
if apply_cfg:
vu = self.dit_forward(z, t_tensor, null_cond)
vc = vu + self.cfg_lambda * (vc - vu)
# Euler 反向积分更新:z <- z - dt * vc。
# 逐步将 z 从噪声端拉回到动作端。
z = z - dt_tensor * vc
# 每个 Euler 步更新后,按动作类型再次将无效动作维度清零,
# 防止这些维度在迭代过程中被数值更新“污染”。
for action_name, action_idx in self.action_space_index.action_spaces.items():
mask = (action_type == action_idx)
if mask.any():
adim = self.action_space_index.get_action_dim(action_idx)
z[mask, :, adim:] = 0.0
# 将输出限制在 [-1, 1],与训练时的动作归一化范围保持一致,
# 同时抑制采样积分产生的少量越界值,提升推理稳定性。
return z.clamp(-1, 1)