FlowerVLA 源码解析 1 - 网络部分

源码:https://github.com/intuitive-robots/flower_vla_pret/blob/main/flower_vla/agents/networks/transformers.py


1. 工具函数

# 描述:返回大于或等于 n,并且是 k 的倍数的最小数字。
# 示例:n = 10,k = 3,那么返回 12。
def find_multiple(n: int, k: int) -> int:
    """
    Returns the smallest number greater than or equal to n that is a multiple of k.
    """
    return n if n % k == 0 else n + k - (n % k)

# 描述:归一化 x,并且无需维护运行时统计。
def stateless_norm(x: torch.Tensor) -> torch.Tensor:
    """
    Normalizes x without maintaining running statistics.
    """
    # 在最后一个维度上计算均值,并且保持维度,方便广播。
    mean = x.mean(dim=-1, keepdim=True)
    # 在最后一个维度上计算方差,并且保持维度,方便广播。
    # σ² = Σ(xi - μ)² / N
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    # 归一化:减去均值,除以标准差。(1e-6 用于在方差为 0 或极小时,稳定分母)
    return (x - mean) / torch.sqrt(var + 1e-6)

# 根据 shift、scale 对 x 应用调制(由调制信号控制的仿射变换)。
# 调制公式:out = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)。
# 典型形状:x 为 [B,L,D],shift、scale 为 [B,D];unsqueeze(1) 得 [B,1,D],沿序列维 L 广播。scale=0 时缩放因子为 1(即恒等)。
# 示例:设 B=1,L=2,D=2,x = [[ [1,2], [3,4] ]];scale、shift 形状 [1,2],取值 [0,1] 与 [10,0]。
# 沿 L 广播后两行共用同一组 scale/shift:第一行得 [11,4],第二行得 [13,8],即 [[ [11,4], [13,8] ]],out 的形状仍为 [1,2,2]。
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Applies a modulation to x given shift and scale signals.
    The modulation formula: x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
    """
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

2. SwiGlu

# Transformer 前馈层中使用的 SwiGlu MLP 激活实现。
# 参数:
#   dim:输入维度。
#   hidden_dim:隐藏层维度;为 None 时默认为 4 * dim。
#   dropout:Dropout 概率。
#   output_dim:输出维度;为 None 时默认为 dim。
class SwiGlu(nn.Module):
    """
    An implementation of the SwiGlu MLP activation as used in transformer feedforward layers.
    
    Args:
        dim: Input dimension.
        hidden_dim: Dimension of the hidden layer. If None, defaults to 4 * dim.
        dropout: Dropout probability.
        output_dim: Output dimension. Defaults to dim.
    """
    def __init__(self, dim: int, hidden_dim: Optional[int] = None, dropout: float = 0.0, output_dim: Optional[int] = None) -> None:
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim
        # 隐藏层维度等于 hidden_dim 的 2/3,再四舍五入到 256 的倍数。
        n_hidden = int(2 * hidden_dim / 3)
        n_hidden = find_multiple(n_hidden, 256)
        if output_dim is None:
            output_dim = dim
        self.fc1 = nn.Linear(dim, n_hidden, bias=False)
        self.fc2 = nn.Linear(dim, n_hidden, bias=False)
        self.proj = nn.Linear(n_hidden, output_dim, bias=False)
        # Dropout 的主要作用是作为正则化方法:迫使网络不依赖少数固定神经元组合,削弱特征与共适应,提升对缺失扰动的鲁棒性;
        # 亦可视为对大量随机子网络做近似集成,从而降低过拟合风险。直观上,每轮在随机子网络上更新参数以达成上述效果。
        #
        # 训练时以概率 (p) 随机将一部分神经元输出置零,等价于暂时断开对应连接;
        # 被置零位置在计算图分支上梯度为 0(链式法则乘 0),该轮误差不沿此路径回传,与之相连的权重亦收不到来自被丢弃单元的梯度,其余路径照常反传。
        # 评估时关闭 dropout,使用完整网络推理。
        # 常见的 inverted dropout 在训练时对保留位置除以 (1-p),使保留单元输出相对 mask 前被放大,
        # 并且与推理阶段不做缩放时的期望尺度一致;
        # 若采用经典实现(训练时保留位置不缩放),则保留单元不放大,推理时常对权重或输出整体乘以 (1-p) 以对齐期望。
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through SwiGlu MLP.
        """
        # 对张量逐元素应用 SiLU(Sigmoid Linear Unit)激活。
        x1 = F.silu(self.fc1(x))
        x2 = self.fc2(x)
        # 计算 x1 和 x2 的哈达玛积
        x = x1 * x2
        x = self.dropout(x)
        x = self.proj(x)
        return x

3. 旋转位置编码(Rotary Positional Embedding)

# 预计算用于 1 维旋转嵌入的余弦和正弦频率矩阵。
# 参数:
#   dim:Embedding 的维度。
#   max_seq_len:最大序列长度。
#   theta:频率缩放因子;默认 10000.0。
# 返回:
#   (cosine, sine) - 余弦和正弦频率矩阵,形状均为 [max_seq_len, dim/2]。
def precompute_freqs_1d(dim: int, max_seq_len: int, theta: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Precomputes cosine and sine frequency matrices for 1D rotary embeddings.
    Returns:
        (cosine, sine): Tensors of shape [max_seq_len, dim/2].
    """
    # Shape: [dim/2],比如 dim = 8,freqs = [0, 2, 4, 6]
    freqs = torch.arange(0, dim, 2).float()
    # Shape: [dim/2],比如 [10000^(-0/dim), 10000^(-2/dim), 10000^(-4/dim), 10000^(-6/dim)]
    freqs = theta ** (-freqs / dim)
    # Shape: [max_seq_len],比如 max_seq_len = 4,positions = [0, 1, 2, 3]
    positions = torch.arange(max_seq_len).float()
    # positions.unsqueeze(1):
    # - Shape: [max_seq_len, 1]
    # - 比如 [[0], [1], [2], [3]]
    #
    # freqs.unsqueeze(0):
    # - Shape: [1, dim/2]
    # - 比如 [[10000^(-0/dim), 10000^(-2/dim), 10000^(-4/dim), 10000^(-6/dim)]]
    #
    # angles:
    # - Shape: [max_seq_len, dim/2]
    # - 比如 [
    #   [0*10000^(-0/dim), 0*10000^(-2/dim), 0*10000^(-4/dim), 0*10000^(-6/dim)],
    #   [1*10000^(-0/dim), 1*10000^(-2/dim), 1*10000^(-4/dim), 1*10000^(-6/dim)],
    #   [2*10000^(-0/dim), 2*10000^(-2/dim), 2*10000^(-4/dim), 2*10000^(-6/dim)],
    #   [3*10000^(-0/dim), 3*10000^(-2/dim), 3*10000^(-4/dim), 3*10000^(-6/dim)]
    # ]
    angles = positions.unsqueeze(1) * freqs.unsqueeze(0)
    return angles.cos(), angles.sin()

# 将旋转位置编码(RoPE)应用到 q 和 k。
# 参数:
#   q:查询张量,形状为 [批量大小,注意力头的数量,序列长度,每个头的维度]。
#   k:键张量,形状与 q 相同。
#   cos, sin:Cosine 和 Sine 频率张量,形状为 [最大序列长度,每个头的维度/2]。
#   position_ids:可选的位置索引张量;若为 None,则使用按顺序的位置(0、1、2、…、序列长度-1)。
# 返回:
#   施加旋转位置编码的 q 和 k。
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor,
                         cos: torch.Tensor, sin: torch.Tensor,
                         position_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Applies rotary positional embeddings to queries and keys.
    
    Args:
        q: Query tensor of shape [B, heads, seq_len, head_dim].
        k: Key tensor with the same shape as q.
        cos, sin: Cosine and sine frequency tensors of shape [max_seq_len, head_dim/2].
        position_ids: Optional tensor with position indices; if None, uses sequential positions.
    
    Returns:
        A tuple (q_rot, k_rot) with rotary embeddings applied.
    """
    seq_len = q.size(-2)
    if position_ids is None:
        position_ids = torch.arange(seq_len, device=q.device)
    cos = cos[position_ids].unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, head_dim/2]
    sin = sin[position_ids].unsqueeze(0).unsqueeze(0)
    # 将 q 和 k 的最后一维一分为二,形状均为 [B, heads, seq_len, head_dim/2]
    q1, q2 = q.chunk(2, dim=-1)
    k1, k2 = k.chunk(2, dim=-1)
    q_rot = torch.cat([q1 * cos - q2 * sin, q2 * cos + q1 * sin], dim=-1)
    k_rot = torch.cat([k1 * cos - k2 * sin, k2 * cos + k1 * sin], dim=-1)
    return q_rot, k_rot

4. 注意力模块

4.1. 多头自注意力

# 多头自注意力模块,可选使用旋转位置编码。
# 参数:
#   dim:输入维度,也叫 Transformer 的维度。
#   n_heads:注意力头数。
#   attn_pdrop:注意力概率的 Dropout 概率。
#   resid_pdrop:输出投影的 Dropout 概率。
#   use_rope:是否使用旋转位置编码。
#   max_seq_len:用于预计算旋转频率的最大序列长度。
#   rope_theta:旋转嵌入的 theta 值。
class FlowerAttention(nn.Module):
    """
    Multi-head self-attention module with optional rotary positional embeddings.
    
    Args:
        dim: Input dimension.
        n_heads: Number of attention heads.
        attn_pdrop: Dropout rate on the attention probabilities.
        resid_pdrop: Dropout rate on the output projection.
        use_rope: Whether to apply rotary embeddings.
        max_seq_len: Maximum sequence length for precomputed rotary frequencies.
        rope_theta: Theta value for rotary embeddings.
    """
    def __init__(self,
                 dim: int,
                 n_heads: int,
                 attn_pdrop: float = 0.1,
                 resid_pdrop: float = 0.1,
                 use_rope: bool = False,
                 max_seq_len: int = 120,
                 rope_theta: float = 32) -> None:
        super().__init__()
        # 输入维度必须是注意力头数的倍数。
        assert dim % n_heads == 0, "Dimension must be divisible by number of heads."
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        # 缩放点积注意力(SDPA)中的缩放因子,即 1/sqrt(head_dim)。
        self.scale = self.head_dim ** -0.5

        # 融和的 W_Q、W_K、W_V 矩阵。
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        # W_O 投影矩阵。
        self.proj = nn.Linear(dim, dim, bias=False)
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # RmsNorm(Root Mean Square Layer Normalization)是 Layer Normalization 的一种变体,用于加速训练。
        self.q_norm = RmsNorm(self.head_dim, eps=1e-6)
        self.k_norm = RmsNorm(self.head_dim, eps=1e-6)
        self.use_rope = use_rope
        if use_rope:
            self.rope_theta = rope_theta
            cos, sin = precompute_freqs_1d(self.head_dim, max_seq_len, theta=rope_theta)
            # register_buffer 是 nn.Module 中用于注册不参与梯度更新的持久化张量的方法。
            # 这些张量随模型一起保存(state_dict)和移动设备(.to(device)),但不被优化器更新(requires_grad=False)。
            self.register_buffer("cos", cos)
            self.register_buffer("sin", sin)
            self.max_seq_len = max_seq_len

    # 自注意力的前向传播。
    # 参数:
    #   x:输入张量,形状为 [B, seq_len, dim]。
    #   custom_attn_mask:可选的自定义注意力掩码。
    #   is_causal:如果为 True,则应用因果掩码。
    #     因果掩码是一种注意力掩码:在自回归模型里,它规定每个位置在算注意力时,只能使用当前位置及之前的 Token 信息,
    #     不能用到更后面的 Token,从而避免模型在训练或生成时“偷看未来”。
    # 返回:
    #   经过注意力计算和投影后的张量,形状为 [B, seq_len, dim]。
    def forward(self, x: torch.Tensor,
                custom_attn_mask: Optional[torch.Tensor] = None,
                is_causal: bool = False) -> torch.Tensor:
        """
        Forward pass for self-attention.
        
        Args:
            x: Input tensor of shape [B, seq_len, dim].
            custom_attn_mask: Optional attention mask.
            is_causal: If True, applies causal masking.
        
        Returns:
            Tensor of shape [B, seq_len, dim] after attention and projection.
        """
        B, T, C = x.size()
        # Fused QKV:self.qkv 为 Linear(dim, 3*dim),一次前向得到拼接的 Q/K/V,形状 [B, T, 3*dim]。
        # reshape:末维 3*dim → 3 组 × (n_heads, head_dim),形状 [B,T,3,n_heads,head_dim]。
        # permute(2,0,3,1,4):得到 [3,B,n_heads,T,head_dim],第 0 维即 Q/K/V,与下方 unbind(0) 对应。
        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        # unbind(0):沿第 0 维拆成三个张量,形状均为 [B,n_heads,T,head_dim]。
        q, k, v = qkv.unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)
        if self.use_rope:
            q, k = apply_rotary_pos_emb(q, k, self.cos, self.sin)
        # 按需构建因果掩码。
        if is_causal and custom_attn_mask is None:
            # triu(..., diagonal=1):上三角(主对角线)之上为 True,即 (i,j) 在 j>i 时为 True,表示未来位需被挡住。
            mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
            # 扩成 [1,1,T,T],与 q/k 的 batch、head 维广播对齐,供后续 ~mask 传入 SDPA。
            mask = mask.unsqueeze(0).unsqueeze(0)
        elif custom_attn_mask is not None:
            # custom_attn_mask 形状为 [B, T, T],缺少 head 维。
            # unsqueeze(1) 在第 1 维插入一个大小为 1 的 head 维,变为 [B, 1, T, T]。
            # expand(-1, n_heads, -1, -1) 在逻辑上将该维扩到 n_heads,形状变为 [B, n_heads, T, T]。
            # 这里不会复制出 n_heads 份数据,而是所有 head 共享同一份掩码视图。
            mask = custom_attn_mask.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
        else:
            mask = None
        # SDPA 输入的 q/k/v 形状均为 [B, n_heads, T, head_dim],输出形状与之相同。
        # 本代码里的 mask 约定是 True=屏蔽;但 SDPA 的 bool mask 约定是 True=可见,所以传入前要取反(~mask)。
        # dropout_p 仅在训练时生效;
        # scale 使用 head_dim**-0.5(在 __init__ 中预先计算)。
        # is_causal 只在未提供 custom_attn_mask 时启用,避免同时叠加两套因果约束。
        # 最终作用:在给定掩码规则下计算多头注意力,为每个 Token 生成融合上下文后的表示。
        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None if mask is None else ~mask,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            scale=self.scale,
            is_causal=is_causal if custom_attn_mask is None else False
        )
        # [B,n_heads,T,head_dim] -> transpose -> [B,T,n_heads,head_dim] -> reshape 合并 head,得 [B,T,dim]。
        out = attn_output.transpose(1, 2).reshape(B, T, C)
        # proj:输出投影 W_O;resid_dropout:投影后的 dropout。
        out = self.resid_dropout(self.proj(out))
        return out

4.2. 多头交叉注意力

# 多头交叉注意力模块(Cross-Attention),可选使用旋转位置编码。
# 与 FlowerAttention(自注意力)不同:Q 来自当前序列 x,K/V 来自外部 context。
# 参数:
#   dim:输入/输出维度。
#   n_heads:注意力头数。
#   attn_pdrop:注意力权重 Dropout 概率(仅训练时生效)。
#   resid_pdrop:输出投影后的 Dropout 概率。
#   use_rope:是否对 q、k 应用旋转位置编码。
#   query_seq_len:query 侧(x)的最大长度,用于预计算 q 的 RoPE 频率。
#   context_seq_len:context 侧的最大长度,用于预计算 k 的 RoPE 频率。
#   rope_theta:query 侧 RoPE 的 theta。
#   context_rope_theta:context 侧 RoPE 的 theta(可与 query 侧不同)。
class FlowerCrossAttention(nn.Module):
    """
    Cross-attention module with optional rotary embeddings.
    
    Args:
        dim: Input and output dimension.
        n_heads: Number of attention heads.
        attn_pdrop: Dropout rate on the attention weights.
        resid_pdrop: Dropout rate on the output.
        use_rope: Whether to apply rotary embeddings.
        query_seq_len: Maximum length for queries.
        context_seq_len: Maximum length for context.
        rope_theta: Theta for query rotary embeddings.
        context_rope_theta: Theta for context rotary embeddings.
    """
    def __init__(self,
                 dim: int,
                 n_heads: int,
                 attn_pdrop: float = 0.1,
                 resid_pdrop: float = 0.1,
                 use_rope: bool = False,
                 query_seq_len: int = 64,
                 context_seq_len: int = 384,
                 rope_theta: float = 32,
                 context_rope_theta: float = 1000.0) -> None:
        super().__init__()
        # 输入/输出维度必须是注意力头数的倍数。
        assert dim % n_heads == 0, "Dimension must be divisible by number of heads."
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        # SDPA 缩放系数:1/sqrt(head_dim)。
        self.scale = self.head_dim ** -0.5

        # 交叉注意力中 Q/K/V 分别由独立线性层产生:
        # q_proj 作用于 x(当前序列),k_proj/v_proj 作用于 context(外部条件)。
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        # 输出投影 W_O。
        self.proj = nn.Linear(dim, dim, bias=False)
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # 按 head_dim 做 RMS 归一化,稳定 q/k 的尺度。
        self.q_norm = RmsNorm(self.head_dim, eps=1e-6)
        self.k_norm = RmsNorm(self.head_dim, eps=1e-6)
        self.use_rope = use_rope
        if use_rope:
            # 分别为 query 与 context 预计算 RoPE 频率表;二者长度和 theta 可独立设置。
            q_cos, q_sin = precompute_freqs_1d(self.head_dim, query_seq_len, theta=rope_theta)
            k_cos, k_sin = precompute_freqs_1d(self.head_dim, context_seq_len, theta=context_rope_theta)
            # register_buffer:随模型保存和迁移设备,但不参与梯度更新。
            self.register_buffer("q_cos", q_cos)
            self.register_buffer("q_sin", q_sin)
            self.register_buffer("k_cos", k_cos)
            self.register_buffer("k_sin", k_sin)
            self.query_seq_len = query_seq_len
            self.context_seq_len = context_seq_len
            self.rope_theta = rope_theta
            self.context_rope_theta = context_rope_theta

    # 在 x(queries)与 context(keys 和 values)之间应用交叉注意力。
    # 参数:
    #   x:查询张量(query),形状为 [B, seq_len, dim];线性投影后得到注意力中的 Q。
    #   context:上下文张量(context),形状为 [B, context_len, dim];线性投影后得到注意力中的 K/V。
    #   custom_attn_mask:可选注意力掩码;常见形状为 [B, context_len],用于控制 query 可见的 context 位置。
    #     例子:B=2、context_len=5 时,mask 可为 [[1,1,1,0,0],[1,1,0,0,0]],
    #     表示第 1 个样本仅可看前 3 个 context token,第 2 个样本仅可看前 2 个 context token。
    # 返回:
    #   形状为 [B, seq_len, dim] 的张量。
    def forward(self, x: torch.Tensor, context: torch.Tensor,
                custom_attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Applies cross-attention between x (queries) and context (keys and values).
        
        Args:
            x: Query tensor of shape [B, seq_len, dim].
            context: Context tensor of shape [B, context_len, dim].
            custom_attn_mask: Optional attention mask.
        
        Returns:
            Tensor of shape [B, seq_len, dim].
        """
        # x: [B,T,C],context: [B,S,C];T 为 query 长度,S 为 context 长度。
        B, T, C = x.size()
        _, S, _ = context.size()
        # 线性映射及分头(分两步):
        # q:先 reshape [B,T,C] -> [B,T,n_heads,head_dim],再 permute -> [B,n_heads,T,head_dim]。
        # k/v:先 reshape [B,S,C] -> [B,S,n_heads,head_dim],再 permute -> [B,n_heads,S,head_dim]。
        q = self.q_proj(x).reshape(B, T, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.k_proj(context).reshape(B, S, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.v_proj(context).reshape(B, S, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # 对 q/k 做按 head 的 RMS 归一化。
        q = self.q_norm(q)
        k = self.k_norm(k)
        if self.use_rope:
            # 给 q/k 添加 RoPE。
            # 这里将同一张量同时作为 q/k 传入,仅复用 apply_rotary_pos_emb 的实现,实际只取第一个返回值。
            q, _ = apply_rotary_pos_emb(q, q, self.q_cos, self.q_sin)
            k, _ = apply_rotary_pos_emb(k, k, self.k_cos, self.k_sin)
        if custom_attn_mask is not None:
            # 先调整 mask 形状,使其与 q 的查询长度维度对齐。
            #
            # 典型输入 mask 为 [B,S](按 context 位置指示可见性)。
            # 先扩成 [B,1,1,S],再沿 head 与 query 长度广播到 [B,n_heads,T,S],
            # 以匹配 SDPA 所需的 [B,n_heads,Lq,Lk] 掩码形状。
            mask = custom_attn_mask.unsqueeze(1).unsqueeze(2)
            mask = mask.expand(-1, self.n_heads, q.size(2), -1)
            # 交叉注意力通常不使用因果约束(is_causal=False),由 custom_attn_mask 控制可见性。
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                scale=self.scale,
                is_causal=False
            )
        else:
            # 无自定义掩码时,query 可关注全部 context 位置。
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                scale=self.scale,
                is_causal=False
            )
        # transpose(1, 2):将 [B,n_heads,T,head_dim] 调整为 [B,T,n_heads,head_dim],把序列维放回前面。
        # reshape(B, T, C):再把 n_heads 与 head_dim 合并为 C,得到 [B,T,C]。
        out = attn_output.transpose(1, 2).reshape(B, T, C)
        # 输出投影 + dropout。
        out = self.resid_dropout(self.proj(out))
        return out

5. FlowBlock

# Flow-based diffusion 使用的 Transformer 基本块。
# 结构由三部分组成:自注意力(Self-Attention)+ 可选交叉注意力(Cross-Attention)+ SwiGlu MLP。
# 同时通过 AdaLN 调制(shift/scale/gate)让条件向量 c 动态控制每个子层的行为。
# 参数:
#   dim:输入维度。
#   heads:注意力头数。
#   attn_pdrop:注意力 Dropout 概率。
#   resid_pdrop:投影后的 Dropout 概率。
#   mlp_pdrop:MLP 分支的 Dropout 概率。
#   use_cross_attn:是否启用交叉注意力分支。
#   use_rope:是否在自注意力中使用 RoPE。
#   query_seq_len:自注意力序列最大长度(RoPE 预计算使用)。
#   rope_theta:自注意力 RoPE 的 theta。
#   lora_dim:AdaLN 调制器中的中间维度(先降维,再升维)。
#   use_global_adaln:是否与外部全局 AdaLN 信号逐项相加融合。
class FlowBlock(nn.Module):
    """
    A transformer block for flow-based diffusion. Combines self-attention,
    (optional) cross-attention, and a SwiGlu MLP with adaptive layer normalization modulation.
    
    Args:
        dim: Input dimension.
        heads: Number of attention heads.
        attn_pdrop: Attention dropout rate.
        resid_pdrop: Residual dropout rate.
        mlp_pdrop: MLP dropout rate.
        use_cross_attn: Whether to include a cross-attention layer.
        use_rope: Whether to use rotary positional embeddings in self-attention.
        query_seq_len: Maximum query sequence length.
        rope_theta: Theta parameter for rotary embeddings.
        lora_dim: Intermediate dimension for adaptive normalization modulation.
        use_global_adaln: If True, combines global AdaLN modulation signals.
    """
    def __init__(self,
                 dim: int,
                 heads: int = 8,
                 attn_pdrop: float = 0.1,
                 resid_pdrop: float = 0.1,
                 mlp_pdrop: float = 0.1,
                 use_cross_attn: bool = False,
                 use_rope: bool = False,
                 query_seq_len: int = 128,
                 rope_theta: float = 32,
                 lora_dim: int = 256,
                 use_global_adaln: bool = True) -> None:
        super().__init__()
        self.dim = dim
        self.use_cross_attn = use_cross_attn
        self.use_global_adaln = use_global_adaln

        # 三个归一化层:
        # - norm1:自注意力前;
        # - norm2:cross-attention 前(若不开 cross-attn,则给后续 MLP 前使用);
        # - norm3:仅在 use_cross_attn=True 时存在,专用于 MLP 前。
        self.norm1 = RmsNorm(dim, eps=1e-6)
        self.norm2 = RmsNorm(dim, eps=1e-6)
        self.norm3 = RmsNorm(dim, eps=1e-6) if use_cross_attn else None

        # 自注意力分支(可选 RoPE)。
        self.self_attn = FlowerAttention(dim=dim, n_heads=heads,
                                           attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop,
                                           use_rope=use_rope, max_seq_len=query_seq_len, rope_theta=rope_theta)
        if use_cross_attn:
            # 可选交叉注意力分支:Q 来自主序列,K/V 来自 context。
            self.cross_attn = FlowerCrossAttention(dim=dim, n_heads=heads,
                                                     attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop,
                                                     use_rope=False)
        # 前馈网络分支。
        self.mlp = SwiGlu(dim, dropout=mlp_pdrop)
        # AdaLN 调制器:
        # c: [B,dim] -> [B,lora_dim] -> [B,6*dim]。
        # 最终切成 6 路信号:shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp。
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, lora_dim),  # Down-project
            nn.Linear(lora_dim, 6 * dim)  # Up-project to produce 6 modulation signals
        )

    # 通过 FlowBlock 前向传播。
    # 参数:
    #   cx:输入张量,比如动作 Latent 表示,形状为 [B, L, D]。
    #   c:条件张量(来自外部编码器)。
    #   context:可选的上下文张量,形状常见为 [B,S,D];仅在 use_cross_attn=True 时使用。
    #   custom_attn_mask:自注意力掩码(传给 self_attn)。
    #   custom_cross_attn_mask:交叉注意力掩码(传给 cross_attn)。
    #   is_causal:如果为 True,则使用因果自注意力。
    #   global_adaln:可选的全局 AdaLN 调制信号列表。
    # 返回:
    #   输出张量,形状为 [B,L,D]。
    def forward(self, cx: torch.Tensor, c: torch.Tensor,
                context: Optional[torch.Tensor] = None,
                custom_attn_mask: Optional[torch.Tensor] = None,
                custom_cross_attn_mask: Optional[torch.Tensor] = None,
                is_causal: bool = False,
                global_adaln: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """
        Forward pass through the FlowBlock.
        
        Args:
            cx: Input tensor for the block (e.g. action latent representations) of shape [B, L, D].
            c: Conditioning tensor (from external encoder).
            context: Optional context tensor for cross-attention.
            custom_attn_mask: Optional attention mask.
            is_causal: If True, uses causal self-attention.
            global_adaln: Optional list of global AdaLN modulation signals.
        
        Returns:
            Output tensor of shape [B, L, D].
        """
        # B: batch size,L: sequence length,D: dimension。
        B, L, D = cx.shape
        # 记录输入残差分支。
        residual = cx

        # 计算 AdaLN 调制信号,并且拆分为 6 路。
        # modulation: [B,6D];chunk 后每一路为 [B,D]。
        modulation = self.adaLN_modulation(c)
        signals = modulation.chunk(6, dim=1)
        # 若启用全局 AdaLN,则逐路相加融合(局部 + 全局)。
        if self.use_global_adaln and global_adaln is not None:
            mod_signals = [signals[i] + global_adaln[i] for i in range(6)]
        else:
            mod_signals = signals
        # 6 路调制信号分为两组:
        # - *_msa 用于 Self-Attention 分支;*_mlp 用于 MLP 分支。
        # - shift_* / scale_* 作用于 modulate(仿射调制);
        # gate_* 控制该分支输出加回主干时的强度。
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod_signals

        # Self-Attention 子层:
        # 1) norm1;
        # 2) 用 (shift_msa, scale_msa) 做条件调制;
        # 3) 进入 self_attn;
        # 4) 用 gate_msa 做门控后加回残差。
        x_norm = self.norm1(cx)
        x_mod = modulate(x_norm, shift_msa, scale_msa)
        x_self = self.self_attn(x_mod, custom_attn_mask=custom_attn_mask, is_causal=is_causal)
        # gate_msa 原始形状为 [B,D],unsqueeze(1) 后变为 [B,1,D],可在序列维 L 上广播到 [B,L,D] 与 x_self 对齐。
        x_out = residual + gate_msa.unsqueeze(1) * x_self

        # 可选的 Cross-Attention 子层:
        # 仅当 use_cross_attn=True 时执行;若未提供 context 则报错。
        # 该分支输出与 x_out 直接做残差相加。
        if self.use_cross_attn:
            if context is None:
                raise ValueError("Context is required for cross-attention.")
            x_norm = self.norm2(x_out)
            x_cross = self.cross_attn(x_norm, context, custom_attn_mask=custom_cross_attn_mask)
            x_out = x_out + x_cross

        # MLP 子层:
        # 若存在 cross-attn,则使用 norm3;否则复用 norm2。
        # 与 self-attn 类似,先调制再过 MLP,最后用 gate_mlp 门控后残差相加。
        norm_layer = self.norm3 if self.use_cross_attn else self.norm2
        x_norm = norm_layer(x_out)
        x_mod = modulate(x_norm, shift_mlp, scale_mlp)
        mlp_out = self.mlp(x_mod)
        x_final = x_out + gate_mlp.unsqueeze(1) * mlp_out

        # 最终作用:将“自注意力 + 可选交叉注意力 + MLP”在 AdaLN 条件调制下串联,输出更新后的序列表示。
        return x_final

6. 编码器类

6.1. TimestepEmbedder

参考 6.3。

class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = 1000 * torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half) / half
        ).to(t.device)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    # @torch.compile()
    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
            dtype=next(self.parameters()).dtype
        )
        t_emb = self.mlp(t_freq)
        return t_emb

6.2. SharedAdaLNController

# 根据提供的条件向量生成调制信号列表。调制信号可以控制 FlowBlock 的子层。
# 参数:
#   dim:调制信号的维度。
#   global_conddim:条件向量的维度。
#   use_cross_attn:是否使用交叉注意力。
class SharedAdaLNController(nn.Module):
    """Shared Adaptive Layer Normalization controller for all DiT blocks"""
    def __init__(self, dim, global_conddim, use_cross_attn=False):
        super().__init__()
        # Number of modulation signals needed
        # 条件信号的数量与 FlowBlock 的结构相关。
        num_mod_signals = 9 if use_cross_attn else 6
        self.modCX = nn.Sequential(
            nn.SiLU(),
            nn.Linear(global_conddim, num_mod_signals * dim, bias=False),
        )
        self.use_cross_attn = use_cross_attn

        # Zero initialize the final linear layer
        nn.init.zeros_(self.modCX[-1].weight)
        self.use_cross_attn = use_cross_attn

    def forward(self, global_cond):
        # global_cond: [B, global_conddim]
        # mod_signals: [B, num_mod_signals * dim]
        mod_signals = self.modCX(global_cond)
        if self.use_cross_attn:
            # Split into 9 parts for cross-attention path
            # 例:若 dim=64,则返回 9 个 [B,64] 张量。
            return mod_signals.chunk(9, dim=-1)
        else:
            # Split into 6 parts for self-attention only path
            # 例:若 dim=64,则返回 6 个 [B,64] 张量。
            return mod_signals.chunk(6, dim=-1)

6.3. FreqEmbedder

# 将标量时间步嵌入为向量表示。
class FreqEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    # 创建正弦-余弦时间步嵌入。
    # 参数:
    #   t:形状为 [N] 的时间步张量(N 为 batch_size),可为浮点数。
    #   dim:输出维度。
    #   max_period:控制嵌入的最小频率。
    # 返回:
    #   形状为 [N, dim] 的位置嵌入张量。
    def timestep_embedding(self, t, dim, max_period=1000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        #
        # 公式与常见的正余弦位置编码一致:
        #   freqs_i = exp(-log(max_period) * i / half)
        #   args = t * freqs
        #   embedding = [cos(args), sin(args)]
        # 例:N=2, dim=6 -> half=3,输出形状 [2,6]。
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(
                start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        # shape 对齐说明:
        # - t 通常为 [N],t[:, None] 后为 [N,1]
        # - freqs 为 [half],freqs[None] 后为 [1,half]
        # - 二者按广播相乘得到 args: [N,half]
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # 当 dim 为奇数时,cos/sin 拼接后只有 2*half (= dim-1) 维,
            # 这里补一列 0 使最后一维精确等于 dim,便于与后续线性层输入维度对齐。
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        # t: [B] -> [B, frequency_embedding_size] -> [B, hidden_size]
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb

6.4. ActionSpaceEmbedderParameter

# 使用可直接学习的参数,将离散的动作索引转换成 Embedding。
class ActionSpaceEmbedderParameter(nn.Module):
    """
    Embeds discrete action indices using direct learnable parameters.
    """
    def __init__(
        self,
        hidden_size,
        max_actions=11,  # 0-10 inclusive
        embedding_size=256,
    ):
        super().__init__()
        # 用于每个动作的可直接学习参数
        self.action_embeddings = nn.Parameter(
            torch.randn(max_actions, embedding_size) * 0.02  # Small initialization
        )
        self.mlp = nn.Sequential(
            nn.Linear(embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        # 最大动作数量
        self.max_actions = max_actions
        
    # 通过参数查表,将动作索引转换为 Embedding。
    # 参数:
    #   action_indices:形状为 (batch_size,) 的整数张量,取值范围 [0, max_actions-1]。
    def forward(self, action_indices):
        """
        Convert action indices to embeddings using parameter lookup.
        
        Args:
            action_indices: tensor of shape (batch_size,) containing integers in [0, max_actions-1]
        """
        # 索引进参数矩阵。
        # 例子:action_indices=[2,0,2] 时,那么 embeddings 形状为 [3, embedding_size],
        # 第 1、3 行都来自 action_embeddings[2]。
        embeddings = self.action_embeddings[action_indices]
        
        # Process through MLP
        embeddings = embeddings
        # output 形状为 [batch_size, hidden_size]。
        output = self.mlp(embeddings)
        
        return output

		# 返回所有可能动作的 Embedding。
    def get_all_embeddings(self):
        """Returns embeddings for all possible actions."""
        return self.mlp(self.action_embeddings)

6.5. ZeroEncoder

# 主要用于占位,保持接口一致。
class ZeroEncoder(nn.Module):
    def __init__(self, dit_dim, device):
        super(ZeroEncoder, self).__init__()
        self.dit_dim = dit_dim
        self.device = device

    # 不管 x 的形状如何,永远返回全零矩阵,行数等于 x 的第 0 维的长度,列数等于 dit_dim。
    def forward(self, x):
        return torch.zeros((x.shape[0], self.dit_dim), device=self.device)