Transformer 中的核心机制 - Attention

1. 解析注意力机制

下文将解析 Transformer 模型中的核心 - 注意力机制,特别是其最核心的 “缩放点积注意力” 和 “多头注意力”

1.1. 核心思想与动机

在传统序列模型(比如 RNN)中,模型需要逐步处理序列,并且难以建立远距离依赖关系。注意力机制的核心理念是:让模型在处理序列中的任何一个位置(比如一个词)时,能够“直接看到”并且“加权考虑”序列中所有其他位置的信息。

简单比喻:人在阅读一句话时,理解某个词的含义将依赖句子中的其他关键词。比如,理解“它”这个词,需要注意力指向其指代的前文名词。Transformer 将此过程数学化、并行化

1.2. 核心组件:Query、Key、Value

这是理解注意力的关键类比,灵感来源于信息检索系统:

  • Query:当前正在关注的“问题”或“目标”。比如,当前要处理的词向量。
  • Key:序列中所有位置的“标识”或“索引”,用于与 Query 进行匹配,计算相关度。
  • Value:序列中所有位置所携带的“实际信息”或“内容”。

注意力机制的过程是:对于给定的 Query,计算它与所有 Key 的匹配度(相似度),然后将这个匹配度作为权重,对所有的 Value 进行加权求和。输出就是该融合全局信息的加权和。

1.3. 缩放点积注意力

这是 Transformer 中注意力计算的基本单元。

假设输入序列包含 n 个词,每个词的向量维度是 d_model。将所有词向量堆叠成矩阵 X(形状为 [n, d_model])。

第 1 步:线性投影生成 Q、 K、 V

通过三个不同的可学习权重矩阵 W^QW^KW^VW^QW^K 的形状为 [d_model, d_k]W^V 的形状为 [d_model, d_v],通常 d_k = d_v),将输入 X 投影到三个空间:

Q = X * W^Q  // Query 矩阵
K = X * W^K  // Key 矩阵
V = X * W^V  // Value 矩阵

Q、 KV 的形状都是 [n, d_k]这一步的目的是将输入转换到适合进行相似度计算和提取信息的空间。

第 2 步:计算注意力分数(匹配度)

计算 Query 和所有 Key 的点积,衡量其相关性。对于第 i 个 Query 向量 q_i 和第 j 个 Key 向量 k_j,它们的分数为:

score_{ij} = q_i · k_j^T

矩阵形式为计算 Q 和 K 的转置的乘积:

Scores = Q * K^T  // 形状 [n, n]

这个 [n, n] 的矩阵 Scores 的第 i 行第 j 列表示第 i 个位置(作为 Query)对第 j 个位置(作为 Key)的注意力原始分数。

第 3 步:缩放与归一化(得到注意力权重)

点积的结果在维度 d_k 较大时,其方差将变大,导致经过 Softmax 后梯度非常小(极端值)。因此引入缩放因子 sqrt(d_k) 稳定梯度。

ScaledScores = Scores / sqrt(d_k)

然后,对每一行(即每个 Query)应用 Softmax 函数,使得该行所有数值变为和为 1 的概率分布(注意力权重):

AttentionWeights = softmax(ScaledScores, dim=-1)  // 形状 [n, n],每行和为 1

这里 AttentionWeights[i, j] 表示第 i 个位置给予第 j 个位置的注意力权重。

第 4 步:加权求和得到输出

用得到的注意力权重矩阵对 V 矩阵进行加权求和:

Z = AttentionWeights * V  // 形状 [n, d_v]

输出矩阵 Z 的每一行是对应位置的 Query 在“关注”序列所有位置的信息后得到的新表示。

完整公式(一步到位):

Attention(Q, K, V) = softmax( (Q * K^T) / sqrt(d_k) ) * V

1.4. 多头注意力

单一的注意力机制可能只关注到一种模式的信息(比如语法依赖)。为让模型能够同时关注来自不同表示子空间的信息,Transformer 引入“多头”机制。

核心思想:将 Q、 K、 V 在最后一个维度(d_model)上“分割”成 h 个头,每个头有自己的投影矩阵,独立进行注意力计算,最后将结果拼接并且映射。

步骤与公式:

第 1 步:线性投影并且分头

使用 h 组不同的投影矩阵 W_i^Q、W_i^K、W_i^Vi 从 1 到 h),将输入的 Q、K、V 投影到低维。通常 d_k = d_v = d_model / h

head_i = Attention(Q * W_i^Q, K * W_i^K, V * W_i^V)

每个 head_i 的形状为 [n, d_model/h]

第 2 步:拼接各头输出

将所有 h 个头的输出在最后一个维度上拼接起来。

MultiHeadOutput = Concat(head_1, head_2, ..., head_h)  // 形状 [n, d_model]

第 3 步:最终线性投影

将拼接后的结果通过可学习的输出权重矩阵 W^O(形状 [d_model, d_model])进行投影,整合信息。

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W^O

完整公式:

MultiHead(Q, K, V) = Concat(head_1, head_2, ..., head_h) * W^O
where head_i = Attention(Q * W_i^Q, K * W_i^K, V * W_i^V)

1.5. Transformer 在三个地方使用多头注意力机制,但 Q、K、V 的来源不同

  1. 编码器自注意力Q、K、V 均来自前一编码器层的输出。每个位置可以关注同一序列中的所有位置,用于学习丰富的上下文表示。
  1. 解码器自注意力(又称掩码自注意力):Q、K、V 均来自前一解码器层的输出。为防止在训练时“看到未来”的信息,它在计算注意力权重时,将 i 位置之后的所有位置(j > i)的分数屏蔽(mask),设置为负无穷,这样 Softmax 后权重为 0。
  1. 编码器-解码器注意力(又称交叉注意力):Q 来自前一解码器层的输出,而 K, V 来自最后一个编码器层的输出。这使得解码器的每个位置可以关注输入序列的相关部分,类似于传统 Seq2Seq 模型中的注意力。

2. 为什么要缩放?

2.1. Softmax 的梯度

2.1.1. 定义与符号

设向量 z = [z₁, z₂, ..., zₙ] ∈ ℝⁿ,其 Softmax 函数定义为:

其中,令

(所有指数的和)。

目标是计算偏导数:

这表示第 i 个输出对第 j 个输入的偏导数。

2.1.2. 分情况推导

情况 1:当 i = j 时(对角元素)

计算:

使用商的求导法则:

这里:

  • \( u = e^{z_i} \),所以 \( u' = e^{z_i} \)
  • \( v = S \),所以 \( v' = \frac{\partial S}{\partial z_i} = e^{z_i} \)(因为 S 中只有一项 \( e^{z_i} \) 依赖于 \( z_i \))

代入公式:

简化:

将其拆分为:

现在用 Softmax 表示:

  • \( \frac{e^{z_i}}{S} = \text{softmax}(z_i) \)
  • \( \frac{S - e^{z_i}}{S} = 1 - \frac{e^{z_i}}{S} = 1 - \text{softmax}(z_i) \)

因此:

情况 2:当 i ≠ j 时(非对角元素)

计算:

注意:这里 \( e^{z_i} \) 不依赖于 \( z_j \)(因为 i ≠ j),所以 \( e^{z_i} \) 对 \( z_j \) 来说是常数。

因此:

计算 \( \frac{\partial}{\partial z_j} \left( \frac{1}{S} \right) \):

而 \( \frac{\partial S}{\partial z_j} = e^{z_j} \)(因为 S 中只有 \( e^{z_j} \) 依赖于 \( z_j \))

所以:

代回原式:

现在用 Softmax 表示:

  • \( \frac{e^{z_i}}{S} = \text{softmax}(z_i) \)
  • \( \frac{e^{z_j}}{S} = \text{softmax}(z_j) \)

因此:

2.1.3. 统一公式

使用克罗内克 δ 函数统一这两种情况:

克罗内克 δ 函数定义为:

统一公式为:

2.2. 为什么点积的方差随 d_k 增大而增大?

2.2.1. 假设条件

假设 Q 和 K 矩阵中的每个元素都是独立同分布的随机变量,且:

  • 均值为 0
  • 方差为 σ²(通常 σ² = 1/d_model 或 1)

2.2.2. 点积计算

对于第 i 个 Query 向量 q_i 和第 j 个 Key 向量 k_j

点积 s = q_i · k_j = Σ_{m=1}^{d_k} q_{im} × k_{jm}

其中每个 q_{im} 和 k_{jm} 都是随机变量。

2.2.3. 方差计算

首先需要两个重要的统计学性质:

性质 1:独立随机变量之和的方差

如果随机变量 X1、X2、...、Xn 相互独立,那么:

这是方差的可加性,但仅当变量相互独立时成立

性质 2:协方差为 0 时的方差可加性

更一般地,对于任意随机变量:

当所有随机变量相互独立时,协方差项为 0,即为性质 1。

协方差衡量的是两个随机变量“协同变化”的趋势和方向。

可以将其想象成两个变量之间的关系描述:

  • 同向变化:当一个变量高于其平均值时,另一个也倾向于高于其平均值 → 正协方差
  • 反向变化:当一个变量高于其平均值时,另一个却倾向于低于其平均值 → 负协方差
  • 无规律变化:一个变量的高低与另一个变量的高低没有系统性关系 → 协方差接近零

根据上述性质:

Var(s) = Var(Σ_{m=1}^{d_k} q_{im} × k_{jm}) = Σ_{m=1}^{d_k} Var(q_{im} × k_{jm})

由于 q_{im} 和 k_{jm} 相互独立:

Var(q_{im} × k_{jm}) = E[(q_{im} × k_{jm})^2] - (E[q_{im}×k_{jm}])^2
                     = E[q_{im}^2] × E[k_{jm}^2] - (E[q_{im}]×E[k_{jm}])^2

假设 E[q_{im}] = E[k_{jm}] = 0,E[q_{im}^2] = E[k_{jm}^2] = σ²:

Var(q_{im} × k_{jm}) = σ² × σ² - 0 = σ⁴

因此:

Var(s) = d_k × σ⁴

点积 s 的方差与 d_k 成正比。当 d_k 很大时(如 512 或 1024),方差非常大。

2.3. 大方差对 Softmax 的影响

Softmax 函数特性

softmax(z_i) = exp(z_i) / Σ_j exp(z_j)

当输入向量 z 的某些元素非常大时:

  1. 数值不稳定:exp(很大的数) 可能导致数值溢出(infinity)
  1. 梯度消失
    • Softmax 的梯度为:∂softmax(z_i)/∂z_j = softmax(z_i) × (δ_ij - softmax(z_j))
    • 当某个 z_i 远大于其他时,softmax(z_i) ≈ 1,其他 softmax(z_j) ≈ 0
    • 梯度矩阵将非常稀疏,大部分梯度接近于 0

具体例子

假设 d_k = 256,σ² = 1:

  • 理论方差:Var(s) = 256 × 1 = 256
  • 标准差:std(s) = √256 = 16

这意味着点积值通常分布在 ±16 的范围内。但 Softmax 对输入的尺度非常敏感:

# 不同尺度输入对 Softmax 的影响
输入1: [1.0, 2.0, 3.0]        -> softmax: [0.0900, 0.2447, 0.6652]  # 分布合理
输入2: [10.0, 20.0, 30.0]     -> softmax: [0.0000, 0.0000, 1.0000]  # 几乎 one-hot
输入3: [100.0, 200.0, 300.0]  -> softmax: [0.0000, 0.0000, 1.0000]  # 数值溢出风险

2.4. 为什么缩放因子是 √d_k?

缩放的目标

希望缩放后的点积方差为 1(或一个固定常数),避免随 d_k 增长。

计算缩放因子

设缩放因子为 α,缩放后:

s_scaled = s / α
Var(s_scaled) = Var(s) / α² = (d_k × σ⁴) / α²

希望 Var(s_scaled) = 1:

(d_k × σ⁴) / α² = 1
α² = d_k × σ⁴
α = √(d_k × σ⁴) = √d_k × σ²

标准假设

在标准初始化中(比如 Xavier 初始化),通常 σ² = 1/d_k 或 1。为简化,假设 σ² = 1:

α = √d_k

这就是为什么使用 1/√d_k 作为缩放因子。