FlowerVLA 源码解析 1 - 网络部分
1. 工具函数
# 描述:返回大于或等于 n,并且是 k 的倍数的最小数字。
# 示例:n = 10,k = 3,那么返回 12。
def find_multiple(n: int, k: int) -> int:
"""
Returns the smallest number greater than or equal to n that is a multiple of k.
"""
return n if n % k == 0 else n + k - (n % k)
# 描述:归一化 x,并且无需维护运行时统计。
def stateless_norm(x: torch.Tensor) -> torch.Tensor:
"""
Normalizes x without maintaining running statistics.
"""
# 在最后一个维度上计算均值,并且保持维度,方便广播。
mean = x.mean(dim=-1, keepdim=True)
# 在最后一个维度上计算方差,并且保持维度,方便广播。
# σ² = Σ(xi - μ)² / N
var = x.var(dim=-1, keepdim=True, unbiased=False)
# 归一化:减去均值,除以标准差。(1e-6 用于在方差为 0 或极小时,稳定分母)
return (x - mean) / torch.sqrt(var + 1e-6)
# 根据 shift、scale 对 x 应用调制(由调制信号控制的仿射变换)。
# 调制公式:out = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)。
# 典型形状:x 为 [B,L,D],shift、scale 为 [B,D];unsqueeze(1) 得 [B,1,D],沿序列维 L 广播。scale=0 时缩放因子为 1(即恒等)。
# 示例:设 B=1,L=2,D=2,x = [[ [1,2], [3,4] ]];scale、shift 形状 [1,2],取值 [0,1] 与 [10,0]。
# 沿 L 广播后两行共用同一组 scale/shift:第一行得 [11,4],第二行得 [13,8],即 [[ [11,4], [13,8] ]],out 的形状仍为 [1,2,2]。
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""
Applies a modulation to x given shift and scale signals.
The modulation formula: x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
"""
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)2. SwiGlu
# Transformer 前馈层中使用的 SwiGlu MLP 激活实现。
# 参数:
# dim:输入维度。
# hidden_dim:隐藏层维度;为 None 时默认为 4 * dim。
# dropout:Dropout 概率。
# output_dim:输出维度;为 None 时默认为 dim。
class SwiGlu(nn.Module):
"""
An implementation of the SwiGlu MLP activation as used in transformer feedforward layers.
Args:
dim: Input dimension.
hidden_dim: Dimension of the hidden layer. If None, defaults to 4 * dim.
dropout: Dropout probability.
output_dim: Output dimension. Defaults to dim.
"""
def __init__(self, dim: int, hidden_dim: Optional[int] = None, dropout: float = 0.0, output_dim: Optional[int] = None) -> None:
super().__init__()
if hidden_dim is None:
hidden_dim = 4 * dim
# 隐藏层维度等于 hidden_dim 的 2/3,再四舍五入到 256 的倍数。
n_hidden = int(2 * hidden_dim / 3)
n_hidden = find_multiple(n_hidden, 256)
if output_dim is None:
output_dim = dim
self.fc1 = nn.Linear(dim, n_hidden, bias=False)
self.fc2 = nn.Linear(dim, n_hidden, bias=False)
self.proj = nn.Linear(n_hidden, output_dim, bias=False)
# Dropout 的主要作用是作为正则化方法:迫使网络不依赖少数固定神经元组合,削弱特征与共适应,提升对缺失扰动的鲁棒性;
# 亦可视为对大量随机子网络做近似集成,从而降低过拟合风险。直观上,每轮在随机子网络上更新参数以达成上述效果。
#
# 训练时以概率 (p) 随机将一部分神经元输出置零,等价于暂时断开对应连接;
# 被置零位置在计算图分支上梯度为 0(链式法则乘 0),该轮误差不沿此路径回传,与之相连的权重亦收不到来自被丢弃单元的梯度,其余路径照常反传。
# 评估时关闭 dropout,使用完整网络推理。
# 常见的 inverted dropout 在训练时对保留位置除以 (1-p),使保留单元输出相对 mask 前被放大,
# 并且与推理阶段不做缩放时的期望尺度一致;
# 若采用经典实现(训练时保留位置不缩放),则保留单元不放大,推理时常对权重或输出整体乘以 (1-p) 以对齐期望。
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through SwiGlu MLP.
"""
# 对张量逐元素应用 SiLU(Sigmoid Linear Unit)激活。
x1 = F.silu(self.fc1(x))
x2 = self.fc2(x)
# 计算 x1 和 x2 的哈达玛积
x = x1 * x2
x = self.dropout(x)
x = self.proj(x)
return x3. 旋转位置编码(Rotary Positional Embedding)
# 预计算用于 1 维旋转嵌入的余弦和正弦频率矩阵。
# 参数:
# dim:Embedding 的维度。
# max_seq_len:最大序列长度。
# theta:频率缩放因子;默认 10000.0。
# 返回:
# (cosine, sine) - 余弦和正弦频率矩阵,形状均为 [max_seq_len, dim/2]。
def precompute_freqs_1d(dim: int, max_seq_len: int, theta: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Precomputes cosine and sine frequency matrices for 1D rotary embeddings.
Returns:
(cosine, sine): Tensors of shape [max_seq_len, dim/2].
"""
# Shape: [dim/2],比如 dim = 8,freqs = [0, 2, 4, 6]
freqs = torch.arange(0, dim, 2).float()
# Shape: [dim/2],比如 [10000^(-0/dim), 10000^(-2/dim), 10000^(-4/dim), 10000^(-6/dim)]
freqs = theta ** (-freqs / dim)
# Shape: [max_seq_len],比如 max_seq_len = 4,positions = [0, 1, 2, 3]
positions = torch.arange(max_seq_len).float()
# positions.unsqueeze(1):
# - Shape: [max_seq_len, 1]
# - 比如 [[0], [1], [2], [3]]
#
# freqs.unsqueeze(0):
# - Shape: [1, dim/2]
# - 比如 [[10000^(-0/dim), 10000^(-2/dim), 10000^(-4/dim), 10000^(-6/dim)]]
#
# angles:
# - Shape: [max_seq_len, dim/2]
# - 比如 [
# [0*10000^(-0/dim), 0*10000^(-2/dim), 0*10000^(-4/dim), 0*10000^(-6/dim)],
# [1*10000^(-0/dim), 1*10000^(-2/dim), 1*10000^(-4/dim), 1*10000^(-6/dim)],
# [2*10000^(-0/dim), 2*10000^(-2/dim), 2*10000^(-4/dim), 2*10000^(-6/dim)],
# [3*10000^(-0/dim), 3*10000^(-2/dim), 3*10000^(-4/dim), 3*10000^(-6/dim)]
# ]
angles = positions.unsqueeze(1) * freqs.unsqueeze(0)
return angles.cos(), angles.sin()
# 将旋转位置编码(RoPE)应用到 q 和 k。
# 参数:
# q:查询张量,形状为 [批量大小,注意力头的数量,序列长度,每个头的维度]。
# k:键张量,形状与 q 相同。
# cos, sin:Cosine 和 Sine 频率张量,形状为 [最大序列长度,每个头的维度/2]。
# position_ids:可选的位置索引张量;若为 None,则使用按顺序的位置(0、1、2、…、序列长度-1)。
# 返回:
# 施加旋转位置编码的 q 和 k。
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor,
cos: torch.Tensor, sin: torch.Tensor,
position_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Applies rotary positional embeddings to queries and keys.
Args:
q: Query tensor of shape [B, heads, seq_len, head_dim].
k: Key tensor with the same shape as q.
cos, sin: Cosine and sine frequency tensors of shape [max_seq_len, head_dim/2].
position_ids: Optional tensor with position indices; if None, uses sequential positions.
Returns:
A tuple (q_rot, k_rot) with rotary embeddings applied.
"""
seq_len = q.size(-2)
if position_ids is None:
position_ids = torch.arange(seq_len, device=q.device)
cos = cos[position_ids].unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, head_dim/2]
sin = sin[position_ids].unsqueeze(0).unsqueeze(0)
# 将 q 和 k 的最后一维一分为二,形状均为 [B, heads, seq_len, head_dim/2]
q1, q2 = q.chunk(2, dim=-1)
k1, k2 = k.chunk(2, dim=-1)
q_rot = torch.cat([q1 * cos - q2 * sin, q2 * cos + q1 * sin], dim=-1)
k_rot = torch.cat([k1 * cos - k2 * sin, k2 * cos + k1 * sin], dim=-1)
return q_rot, k_rot4. 注意力模块
4.1. 多头自注意力
# 多头自注意力模块,可选使用旋转位置编码。
# 参数:
# dim:输入维度,也叫 Transformer 的维度。
# n_heads:注意力头数。
# attn_pdrop:注意力概率的 Dropout 概率。
# resid_pdrop:输出投影的 Dropout 概率。
# use_rope:是否使用旋转位置编码。
# max_seq_len:用于预计算旋转频率的最大序列长度。
# rope_theta:旋转嵌入的 theta 值。
class FlowerAttention(nn.Module):
"""
Multi-head self-attention module with optional rotary positional embeddings.
Args:
dim: Input dimension.
n_heads: Number of attention heads.
attn_pdrop: Dropout rate on the attention probabilities.
resid_pdrop: Dropout rate on the output projection.
use_rope: Whether to apply rotary embeddings.
max_seq_len: Maximum sequence length for precomputed rotary frequencies.
rope_theta: Theta value for rotary embeddings.
"""
def __init__(self,
dim: int,
n_heads: int,
attn_pdrop: float = 0.1,
resid_pdrop: float = 0.1,
use_rope: bool = False,
max_seq_len: int = 120,
rope_theta: float = 32) -> None:
super().__init__()
# 输入维度必须是注意力头数的倍数。
assert dim % n_heads == 0, "Dimension must be divisible by number of heads."
self.n_heads = n_heads
self.head_dim = dim // n_heads
# 缩放点积注意力(SDPA)中的缩放因子,即 1/sqrt(head_dim)。
self.scale = self.head_dim ** -0.5
# 融和的 W_Q、W_K、W_V 矩阵。
self.qkv = nn.Linear(dim, 3 * dim, bias=False)
# W_O 投影矩阵。
self.proj = nn.Linear(dim, dim, bias=False)
self.attn_dropout = nn.Dropout(attn_pdrop)
self.resid_dropout = nn.Dropout(resid_pdrop)
# RmsNorm(Root Mean Square Layer Normalization)是 Layer Normalization 的一种变体,用于加速训练。
self.q_norm = RmsNorm(self.head_dim, eps=1e-6)
self.k_norm = RmsNorm(self.head_dim, eps=1e-6)
self.use_rope = use_rope
if use_rope:
self.rope_theta = rope_theta
cos, sin = precompute_freqs_1d(self.head_dim, max_seq_len, theta=rope_theta)
# register_buffer 是 nn.Module 中用于注册不参与梯度更新的持久化张量的方法。
# 这些张量随模型一起保存(state_dict)和移动设备(.to(device)),但不被优化器更新(requires_grad=False)。
self.register_buffer("cos", cos)
self.register_buffer("sin", sin)
self.max_seq_len = max_seq_len
# 自注意力的前向传播。
# 参数:
# x:输入张量,形状为 [B, seq_len, dim]。
# custom_attn_mask:可选的自定义注意力掩码。
# is_causal:如果为 True,则应用因果掩码。
# 因果掩码是一种注意力掩码:在自回归模型里,它规定每个位置在算注意力时,只能使用当前位置及之前的 Token 信息,
# 不能用到更后面的 Token,从而避免模型在训练或生成时“偷看未来”。
# 返回:
# 经过注意力计算和投影后的张量,形状为 [B, seq_len, dim]。
def forward(self, x: torch.Tensor,
custom_attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False) -> torch.Tensor:
"""
Forward pass for self-attention.
Args:
x: Input tensor of shape [B, seq_len, dim].
custom_attn_mask: Optional attention mask.
is_causal: If True, applies causal masking.
Returns:
Tensor of shape [B, seq_len, dim] after attention and projection.
"""
B, T, C = x.size()
# Fused QKV:self.qkv 为 Linear(dim, 3*dim),一次前向得到拼接的 Q/K/V,形状 [B, T, 3*dim]。
# reshape:末维 3*dim → 3 组 × (n_heads, head_dim),形状 [B,T,3,n_heads,head_dim]。
# permute(2,0,3,1,4):得到 [3,B,n_heads,T,head_dim],第 0 维即 Q/K/V,与下方 unbind(0) 对应。
qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
# unbind(0):沿第 0 维拆成三个张量,形状均为 [B,n_heads,T,head_dim]。
q, k, v = qkv.unbind(0)
q = self.q_norm(q)
k = self.k_norm(k)
if self.use_rope:
q, k = apply_rotary_pos_emb(q, k, self.cos, self.sin)
# 按需构建因果掩码。
if is_causal and custom_attn_mask is None:
# triu(..., diagonal=1):上三角(主对角线)之上为 True,即 (i,j) 在 j>i 时为 True,表示未来位需被挡住。
mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
# 扩成 [1,1,T,T],与 q/k 的 batch、head 维广播对齐,供后续 ~mask 传入 SDPA。
mask = mask.unsqueeze(0).unsqueeze(0)
elif custom_attn_mask is not None:
# custom_attn_mask 形状为 [B, T, T],缺少 head 维。
# unsqueeze(1) 在第 1 维插入一个大小为 1 的 head 维,变为 [B, 1, T, T]。
# expand(-1, n_heads, -1, -1) 在逻辑上将该维扩到 n_heads,形状变为 [B, n_heads, T, T]。
# 这里不会复制出 n_heads 份数据,而是所有 head 共享同一份掩码视图。
mask = custom_attn_mask.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
else:
mask = None
# SDPA 输入的 q/k/v 形状均为 [B, n_heads, T, head_dim],输出形状与之相同。
# 本代码里的 mask 约定是 True=屏蔽;但 SDPA 的 bool mask 约定是 True=可见,所以传入前要取反(~mask)。
# dropout_p 仅在训练时生效;
# scale 使用 head_dim**-0.5(在 __init__ 中预先计算)。
# is_causal 只在未提供 custom_attn_mask 时启用,避免同时叠加两套因果约束。
# 最终作用:在给定掩码规则下计算多头注意力,为每个 Token 生成融合上下文后的表示。
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None if mask is None else ~mask,
dropout_p=self.attn_dropout.p if self.training else 0.0,
scale=self.scale,
is_causal=is_causal if custom_attn_mask is None else False
)
# [B,n_heads,T,head_dim] -> transpose -> [B,T,n_heads,head_dim] -> reshape 合并 head,得 [B,T,dim]。
out = attn_output.transpose(1, 2).reshape(B, T, C)
# proj:输出投影 W_O;resid_dropout:投影后的 dropout。
out = self.resid_dropout(self.proj(out))
return out4.2. 多头交叉注意力
# 多头交叉注意力模块(Cross-Attention),可选使用旋转位置编码。
# 与 FlowerAttention(自注意力)不同:Q 来自当前序列 x,K/V 来自外部 context。
# 参数:
# dim:输入/输出维度。
# n_heads:注意力头数。
# attn_pdrop:注意力权重 Dropout 概率(仅训练时生效)。
# resid_pdrop:输出投影后的 Dropout 概率。
# use_rope:是否对 q、k 应用旋转位置编码。
# query_seq_len:query 侧(x)的最大长度,用于预计算 q 的 RoPE 频率。
# context_seq_len:context 侧的最大长度,用于预计算 k 的 RoPE 频率。
# rope_theta:query 侧 RoPE 的 theta。
# context_rope_theta:context 侧 RoPE 的 theta(可与 query 侧不同)。
class FlowerCrossAttention(nn.Module):
"""
Cross-attention module with optional rotary embeddings.
Args:
dim: Input and output dimension.
n_heads: Number of attention heads.
attn_pdrop: Dropout rate on the attention weights.
resid_pdrop: Dropout rate on the output.
use_rope: Whether to apply rotary embeddings.
query_seq_len: Maximum length for queries.
context_seq_len: Maximum length for context.
rope_theta: Theta for query rotary embeddings.
context_rope_theta: Theta for context rotary embeddings.
"""
def __init__(self,
dim: int,
n_heads: int,
attn_pdrop: float = 0.1,
resid_pdrop: float = 0.1,
use_rope: bool = False,
query_seq_len: int = 64,
context_seq_len: int = 384,
rope_theta: float = 32,
context_rope_theta: float = 1000.0) -> None:
super().__init__()
# 输入/输出维度必须是注意力头数的倍数。
assert dim % n_heads == 0, "Dimension must be divisible by number of heads."
self.n_heads = n_heads
self.head_dim = dim // n_heads
# SDPA 缩放系数:1/sqrt(head_dim)。
self.scale = self.head_dim ** -0.5
# 交叉注意力中 Q/K/V 分别由独立线性层产生:
# q_proj 作用于 x(当前序列),k_proj/v_proj 作用于 context(外部条件)。
self.q_proj = nn.Linear(dim, dim, bias=False)
self.k_proj = nn.Linear(dim, dim, bias=False)
self.v_proj = nn.Linear(dim, dim, bias=False)
# 输出投影 W_O。
self.proj = nn.Linear(dim, dim, bias=False)
self.attn_dropout = nn.Dropout(attn_pdrop)
self.resid_dropout = nn.Dropout(resid_pdrop)
# 按 head_dim 做 RMS 归一化,稳定 q/k 的尺度。
self.q_norm = RmsNorm(self.head_dim, eps=1e-6)
self.k_norm = RmsNorm(self.head_dim, eps=1e-6)
self.use_rope = use_rope
if use_rope:
# 分别为 query 与 context 预计算 RoPE 频率表;二者长度和 theta 可独立设置。
q_cos, q_sin = precompute_freqs_1d(self.head_dim, query_seq_len, theta=rope_theta)
k_cos, k_sin = precompute_freqs_1d(self.head_dim, context_seq_len, theta=context_rope_theta)
# register_buffer:随模型保存和迁移设备,但不参与梯度更新。
self.register_buffer("q_cos", q_cos)
self.register_buffer("q_sin", q_sin)
self.register_buffer("k_cos", k_cos)
self.register_buffer("k_sin", k_sin)
self.query_seq_len = query_seq_len
self.context_seq_len = context_seq_len
self.rope_theta = rope_theta
self.context_rope_theta = context_rope_theta
# 在 x(queries)与 context(keys 和 values)之间应用交叉注意力。
# 参数:
# x:查询张量(query),形状为 [B, seq_len, dim];线性投影后得到注意力中的 Q。
# context:上下文张量(context),形状为 [B, context_len, dim];线性投影后得到注意力中的 K/V。
# custom_attn_mask:可选注意力掩码;常见形状为 [B, context_len],用于控制 query 可见的 context 位置。
# 例子:B=2、context_len=5 时,mask 可为 [[1,1,1,0,0],[1,1,0,0,0]],
# 表示第 1 个样本仅可看前 3 个 context token,第 2 个样本仅可看前 2 个 context token。
# 返回:
# 形状为 [B, seq_len, dim] 的张量。
def forward(self, x: torch.Tensor, context: torch.Tensor,
custom_attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Applies cross-attention between x (queries) and context (keys and values).
Args:
x: Query tensor of shape [B, seq_len, dim].
context: Context tensor of shape [B, context_len, dim].
custom_attn_mask: Optional attention mask.
Returns:
Tensor of shape [B, seq_len, dim].
"""
# x: [B,T,C],context: [B,S,C];T 为 query 长度,S 为 context 长度。
B, T, C = x.size()
_, S, _ = context.size()
# 线性映射及分头(分两步):
# q:先 reshape [B,T,C] -> [B,T,n_heads,head_dim],再 permute -> [B,n_heads,T,head_dim]。
# k/v:先 reshape [B,S,C] -> [B,S,n_heads,head_dim],再 permute -> [B,n_heads,S,head_dim]。
q = self.q_proj(x).reshape(B, T, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
k = self.k_proj(context).reshape(B, S, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
v = self.v_proj(context).reshape(B, S, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
# 对 q/k 做按 head 的 RMS 归一化。
q = self.q_norm(q)
k = self.k_norm(k)
if self.use_rope:
# 给 q/k 添加 RoPE。
# 这里将同一张量同时作为 q/k 传入,仅复用 apply_rotary_pos_emb 的实现,实际只取第一个返回值。
q, _ = apply_rotary_pos_emb(q, q, self.q_cos, self.q_sin)
k, _ = apply_rotary_pos_emb(k, k, self.k_cos, self.k_sin)
if custom_attn_mask is not None:
# 先调整 mask 形状,使其与 q 的查询长度维度对齐。
#
# 典型输入 mask 为 [B,S](按 context 位置指示可见性)。
# 先扩成 [B,1,1,S],再沿 head 与 query 长度广播到 [B,n_heads,T,S],
# 以匹配 SDPA 所需的 [B,n_heads,Lq,Lk] 掩码形状。
mask = custom_attn_mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(-1, self.n_heads, q.size(2), -1)
# 交叉注意力通常不使用因果约束(is_causal=False),由 custom_attn_mask 控制可见性。
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=mask,
dropout_p=self.attn_dropout.p if self.training else 0.0,
scale=self.scale,
is_causal=False
)
else:
# 无自定义掩码时,query 可关注全部 context 位置。
attn_output = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_dropout.p if self.training else 0.0,
scale=self.scale,
is_causal=False
)
# transpose(1, 2):将 [B,n_heads,T,head_dim] 调整为 [B,T,n_heads,head_dim],把序列维放回前面。
# reshape(B, T, C):再把 n_heads 与 head_dim 合并为 C,得到 [B,T,C]。
out = attn_output.transpose(1, 2).reshape(B, T, C)
# 输出投影 + dropout。
out = self.resid_dropout(self.proj(out))
return out5. FlowBlock
# Flow-based diffusion 使用的 Transformer 基本块。
# 结构由三部分组成:自注意力(Self-Attention)+ 可选交叉注意力(Cross-Attention)+ SwiGlu MLP。
# 同时通过 AdaLN 调制(shift/scale/gate)让条件向量 c 动态控制每个子层的行为。
# 参数:
# dim:输入维度。
# heads:注意力头数。
# attn_pdrop:注意力 Dropout 概率。
# resid_pdrop:投影后的 Dropout 概率。
# mlp_pdrop:MLP 分支的 Dropout 概率。
# use_cross_attn:是否启用交叉注意力分支。
# use_rope:是否在自注意力中使用 RoPE。
# query_seq_len:自注意力序列最大长度(RoPE 预计算使用)。
# rope_theta:自注意力 RoPE 的 theta。
# lora_dim:AdaLN 调制器中的中间维度(先降维,再升维)。
# use_global_adaln:是否与外部全局 AdaLN 信号逐项相加融合。
class FlowBlock(nn.Module):
"""
A transformer block for flow-based diffusion. Combines self-attention,
(optional) cross-attention, and a SwiGlu MLP with adaptive layer normalization modulation.
Args:
dim: Input dimension.
heads: Number of attention heads.
attn_pdrop: Attention dropout rate.
resid_pdrop: Residual dropout rate.
mlp_pdrop: MLP dropout rate.
use_cross_attn: Whether to include a cross-attention layer.
use_rope: Whether to use rotary positional embeddings in self-attention.
query_seq_len: Maximum query sequence length.
rope_theta: Theta parameter for rotary embeddings.
lora_dim: Intermediate dimension for adaptive normalization modulation.
use_global_adaln: If True, combines global AdaLN modulation signals.
"""
def __init__(self,
dim: int,
heads: int = 8,
attn_pdrop: float = 0.1,
resid_pdrop: float = 0.1,
mlp_pdrop: float = 0.1,
use_cross_attn: bool = False,
use_rope: bool = False,
query_seq_len: int = 128,
rope_theta: float = 32,
lora_dim: int = 256,
use_global_adaln: bool = True) -> None:
super().__init__()
self.dim = dim
self.use_cross_attn = use_cross_attn
self.use_global_adaln = use_global_adaln
# 三个归一化层:
# - norm1:自注意力前;
# - norm2:cross-attention 前(若不开 cross-attn,则给后续 MLP 前使用);
# - norm3:仅在 use_cross_attn=True 时存在,专用于 MLP 前。
self.norm1 = RmsNorm(dim, eps=1e-6)
self.norm2 = RmsNorm(dim, eps=1e-6)
self.norm3 = RmsNorm(dim, eps=1e-6) if use_cross_attn else None
# 自注意力分支(可选 RoPE)。
self.self_attn = FlowerAttention(dim=dim, n_heads=heads,
attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop,
use_rope=use_rope, max_seq_len=query_seq_len, rope_theta=rope_theta)
if use_cross_attn:
# 可选交叉注意力分支:Q 来自主序列,K/V 来自 context。
self.cross_attn = FlowerCrossAttention(dim=dim, n_heads=heads,
attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop,
use_rope=False)
# 前馈网络分支。
self.mlp = SwiGlu(dim, dropout=mlp_pdrop)
# AdaLN 调制器:
# c: [B,dim] -> [B,lora_dim] -> [B,6*dim]。
# 最终切成 6 路信号:shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp。
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, lora_dim), # Down-project
nn.Linear(lora_dim, 6 * dim) # Up-project to produce 6 modulation signals
)
# 通过 FlowBlock 前向传播。
# 参数:
# cx:输入张量,比如动作 Latent 表示,形状为 [B, L, D]。
# c:条件张量(来自外部编码器)。
# context:可选的上下文张量,形状常见为 [B,S,D];仅在 use_cross_attn=True 时使用。
# custom_attn_mask:自注意力掩码(传给 self_attn)。
# custom_cross_attn_mask:交叉注意力掩码(传给 cross_attn)。
# is_causal:如果为 True,则使用因果自注意力。
# global_adaln:可选的全局 AdaLN 调制信号列表。
# 返回:
# 输出张量,形状为 [B,L,D]。
def forward(self, cx: torch.Tensor, c: torch.Tensor,
context: Optional[torch.Tensor] = None,
custom_attn_mask: Optional[torch.Tensor] = None,
custom_cross_attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
global_adaln: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
"""
Forward pass through the FlowBlock.
Args:
cx: Input tensor for the block (e.g. action latent representations) of shape [B, L, D].
c: Conditioning tensor (from external encoder).
context: Optional context tensor for cross-attention.
custom_attn_mask: Optional attention mask.
is_causal: If True, uses causal self-attention.
global_adaln: Optional list of global AdaLN modulation signals.
Returns:
Output tensor of shape [B, L, D].
"""
# B: batch size,L: sequence length,D: dimension。
B, L, D = cx.shape
# 记录输入残差分支。
residual = cx
# 计算 AdaLN 调制信号,并且拆分为 6 路。
# modulation: [B,6D];chunk 后每一路为 [B,D]。
modulation = self.adaLN_modulation(c)
signals = modulation.chunk(6, dim=1)
# 若启用全局 AdaLN,则逐路相加融合(局部 + 全局)。
if self.use_global_adaln and global_adaln is not None:
mod_signals = [signals[i] + global_adaln[i] for i in range(6)]
else:
mod_signals = signals
# 6 路调制信号分为两组:
# - *_msa 用于 Self-Attention 分支;*_mlp 用于 MLP 分支。
# - shift_* / scale_* 作用于 modulate(仿射调制);
# gate_* 控制该分支输出加回主干时的强度。
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod_signals
# Self-Attention 子层:
# 1) norm1;
# 2) 用 (shift_msa, scale_msa) 做条件调制;
# 3) 进入 self_attn;
# 4) 用 gate_msa 做门控后加回残差。
x_norm = self.norm1(cx)
x_mod = modulate(x_norm, shift_msa, scale_msa)
x_self = self.self_attn(x_mod, custom_attn_mask=custom_attn_mask, is_causal=is_causal)
# gate_msa 原始形状为 [B,D],unsqueeze(1) 后变为 [B,1,D],可在序列维 L 上广播到 [B,L,D] 与 x_self 对齐。
x_out = residual + gate_msa.unsqueeze(1) * x_self
# 可选的 Cross-Attention 子层:
# 仅当 use_cross_attn=True 时执行;若未提供 context 则报错。
# 该分支输出与 x_out 直接做残差相加。
if self.use_cross_attn:
if context is None:
raise ValueError("Context is required for cross-attention.")
x_norm = self.norm2(x_out)
x_cross = self.cross_attn(x_norm, context, custom_attn_mask=custom_cross_attn_mask)
x_out = x_out + x_cross
# MLP 子层:
# 若存在 cross-attn,则使用 norm3;否则复用 norm2。
# 与 self-attn 类似,先调制再过 MLP,最后用 gate_mlp 门控后残差相加。
norm_layer = self.norm3 if self.use_cross_attn else self.norm2
x_norm = norm_layer(x_out)
x_mod = modulate(x_norm, shift_mlp, scale_mlp)
mlp_out = self.mlp(x_mod)
x_final = x_out + gate_mlp.unsqueeze(1) * mlp_out
# 最终作用:将“自注意力 + 可选交叉注意力 + MLP”在 AdaLN 条件调制下串联,输出更新后的序列表示。
return x_final6. 编码器类
6.1. TimestepEmbedder
参考 6.3。
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = 1000 * torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half) / half
).to(t.device)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
# @torch.compile()
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
dtype=next(self.parameters()).dtype
)
t_emb = self.mlp(t_freq)
return t_emb6.2. SharedAdaLNController
# 根据提供的条件向量生成调制信号列表。调制信号可以控制 FlowBlock 的子层。
# 参数:
# dim:调制信号的维度。
# global_conddim:条件向量的维度。
# use_cross_attn:是否使用交叉注意力。
class SharedAdaLNController(nn.Module):
"""Shared Adaptive Layer Normalization controller for all DiT blocks"""
def __init__(self, dim, global_conddim, use_cross_attn=False):
super().__init__()
# Number of modulation signals needed
# 条件信号的数量与 FlowBlock 的结构相关。
num_mod_signals = 9 if use_cross_attn else 6
self.modCX = nn.Sequential(
nn.SiLU(),
nn.Linear(global_conddim, num_mod_signals * dim, bias=False),
)
self.use_cross_attn = use_cross_attn
# Zero initialize the final linear layer
nn.init.zeros_(self.modCX[-1].weight)
self.use_cross_attn = use_cross_attn
def forward(self, global_cond):
# global_cond: [B, global_conddim]
# mod_signals: [B, num_mod_signals * dim]
mod_signals = self.modCX(global_cond)
if self.use_cross_attn:
# Split into 9 parts for cross-attention path
# 例:若 dim=64,则返回 9 个 [B,64] 张量。
return mod_signals.chunk(9, dim=-1)
else:
# Split into 6 parts for self-attention only path
# 例:若 dim=64,则返回 6 个 [B,64] 张量。
return mod_signals.chunk(6, dim=-1)6.3. FreqEmbedder
# 将标量时间步嵌入为向量表示。
class FreqEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
# 创建正弦-余弦时间步嵌入。
# 参数:
# t:形状为 [N] 的时间步张量(N 为 batch_size),可为浮点数。
# dim:输出维度。
# max_period:控制嵌入的最小频率。
# 返回:
# 形状为 [N, dim] 的位置嵌入张量。
def timestep_embedding(self, t, dim, max_period=1000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
#
# 公式与常见的正余弦位置编码一致:
# freqs_i = exp(-log(max_period) * i / half)
# args = t * freqs
# embedding = [cos(args), sin(args)]
# 例:N=2, dim=6 -> half=3,输出形状 [2,6]。
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(
start=0, end=half, dtype=torch.float32, device=t.device) / half
)
# shape 对齐说明:
# - t 通常为 [N],t[:, None] 后为 [N,1]
# - freqs 为 [half],freqs[None] 后为 [1,half]
# - 二者按广播相乘得到 args: [N,half]
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
# 当 dim 为奇数时,cos/sin 拼接后只有 2*half (= dim-1) 维,
# 这里补一列 0 使最后一维精确等于 dim,便于与后续线性层输入维度对齐。
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
# t: [B] -> [B, frequency_embedding_size] -> [B, hidden_size]
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
6.4. ActionSpaceEmbedderParameter
# 使用可直接学习的参数,将离散的动作索引转换成 Embedding。
class ActionSpaceEmbedderParameter(nn.Module):
"""
Embeds discrete action indices using direct learnable parameters.
"""
def __init__(
self,
hidden_size,
max_actions=11, # 0-10 inclusive
embedding_size=256,
):
super().__init__()
# 用于每个动作的可直接学习参数
self.action_embeddings = nn.Parameter(
torch.randn(max_actions, embedding_size) * 0.02 # Small initialization
)
self.mlp = nn.Sequential(
nn.Linear(embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
# 最大动作数量
self.max_actions = max_actions
# 通过参数查表,将动作索引转换为 Embedding。
# 参数:
# action_indices:形状为 (batch_size,) 的整数张量,取值范围 [0, max_actions-1]。
def forward(self, action_indices):
"""
Convert action indices to embeddings using parameter lookup.
Args:
action_indices: tensor of shape (batch_size,) containing integers in [0, max_actions-1]
"""
# 索引进参数矩阵。
# 例子:action_indices=[2,0,2] 时,那么 embeddings 形状为 [3, embedding_size],
# 第 1、3 行都来自 action_embeddings[2]。
embeddings = self.action_embeddings[action_indices]
# Process through MLP
embeddings = embeddings
# output 形状为 [batch_size, hidden_size]。
output = self.mlp(embeddings)
return output
# 返回所有可能动作的 Embedding。
def get_all_embeddings(self):
"""Returns embeddings for all possible actions."""
return self.mlp(self.action_embeddings)6.5. ZeroEncoder
# 主要用于占位,保持接口一致。
class ZeroEncoder(nn.Module):
def __init__(self, dit_dim, device):
super(ZeroEncoder, self).__init__()
self.dit_dim = dit_dim
self.device = device
# 不管 x 的形状如何,永远返回全零矩阵,行数等于 x 的第 0 维的长度,列数等于 dit_dim。
def forward(self, x):
return torch.zeros((x.shape[0], self.dit_dim), device=self.device)