第 05 篇我们用的是「一整块」768 维 Q/K/V 做一次 Attention。实际模型里常见 12、16、32 个头(head)——同一层里多组更小的 Attention 并行运行,再拼回去。为什么要拆头?维度怎么变?

这一篇讲清 分头、逐头 Attention、concat 与输出投影 $W_o$ 的全流程。

这是「大模型数学速成」系列的第 9 篇。建议先读 第 05 篇:Attention 与 Softmax第 07 篇:RoPE。下一篇 GQA

一、单头 vs 多头:陪审团类比

单头 Attention:一个「专家」用 768 维 Q/K/V 同时看所有语义侧面。

多头 Attention(MHA):$h$ 个「陪审员」,每人只拿 $d_\text{head}$ 维(如 64 维),独立做一遍 Attention,最后把 $h$ 份意见 拼接 成完整向量。

1
2
单头:1 人 768 维全包
多头:12 人各 64 维,各看各的角度,再合并报告

直觉:不同 head 可学到不同模式——有的关注语法距离,有的关注指代,有的关注局部 n-gram……并行多视角,比单一 768 维点积更灵活。

二、符号与典型数字

符号 含义 典型值
$d$ 模型 hidden 维($n_\text{embd}$) 768
$h$ head 个数 12
$d_\text{head}$ 每头维度 $d / h = 64$
$S$ token 数 序列长度

约束:$d$ 必须能被 $h$ 整除

本系列约定:矩阵 [行 = 特征,列 = token]第 01 篇)。

三、分头:从 Q/K/V 到「每头一份」

投影后得到 $Q, K, V$,形状均为 [d, S] = [768, S]

按 head 切分特征行(实现上常 reshape / view,数学上等价):

1
2
3
Q [768, S]  →  12 个头,每个 Q_h [64, S]
K [768, S] → 每个 K_h [64, S]
V [768, S] → 每个 V_h [64, S]

示意($h=3$ 缩小):

1
2
3
4
5
        列: token0  token1  token2
头0行: [64 维 ] [64 维 ] [64 维 ] ← Q_0
头1行: [64 维 ] [64 维 ] [64 维 ] ← Q_1
...
头11: [64 维 ] [64 维 ] [64 维 ] ← Q_11

RoPE第 07 篇)在分头之后、Attention 之前,对每个 head 的 Q/K 独立施加(每头 $d_\text{head}$ 维上的旋转)。

四、每个 head 内:与单头 Attention 相同

对第 $k$ 个头($k = 0..h-1$):

$$
\text{head}k = \text{Attention}(Q_k, K_k, V_k)
= V_k \cdot \text{Softmax}!\left(\frac{Q_k^\top K_k}{\sqrt{d
\text{head}}}\right)^\top
$$

形状
$Q_k, K_k, V_k$ [d_head, S]
$Q_k^\top K_k$ [S, S]
$\text{head}_k$ 输出 [d_head, S]

缩放因子用 $\sqrt{d_\text{head}}$(64),不是 $\sqrt{d}$(768)——每头在自己的子空间里做点积。

$h$ 个头 各自 算一遍,互不共享 Attention 矩阵(权重 $W_q,W_k,W_v$ 在投影阶段已统一,分头是 reshape)。

五、concat:拼回 [d, S]

把 $h$ 个 [d_\text{head}, S] 在**行(特征维)**上拼接:

$$
\text{Concat} = \begin{bmatrix} \text{head}_0 \ \text{head}1 \ \vdots \ \text{head}{h-1} \end{bmatrix}
\quad \in \mathbb{R}^{d \times S}
$$

1
2
3
4
5
6
head_0 [64×S]  ─┐
head_1 [64×S] ─┤ concat 行
... ─┤
head_11[64×S] ─┘

Concat [768×S]

列数 $S$ 始终不变;行数 $h \times d_\text{head} = d$。

六、输出投影 $W_o$

拼接结果再过一层线性变换(输出投影):

$$
O = W_o \cdot \text{Concat}, \quad W_o \in \mathbb{R}^{d \times d}
$$

  • $O$ 形状 [d, S],与 Attention 输入同形,可 残差相加第 04 篇
  • $W_o$ 让模型在 head 之间混合信息——concat 只是简单并排,$W_o$ 学「如何综合各陪审员意见」

完整 MHA 公式:

$$
\text{MHA}(Q,K,V) = W_o \cdot \text{Concat}\big(\text{head}0, \ldots, \text{head}{h-1}\big)
$$

七、维度变化全流程(一层 Self-Attention)

以 $d=768, h=12, d_\text{head}=64, S=1024$ 为例:

步骤 张量 形状
输入(Norm 后) $X$ [768, 1024]
Q/K/V 投影 $Q,K,V$ [768, 1024]
分头 $Q_k$ 等 $h$ × [64, 1024]
RoPE $\tilde{Q}_k, \tilde{K}_k$ [64, 1024]
单头 Attention $\text{head}_k$ [64, 1024]
Concat [768, 1024]
$W_o$ $O$ [768, 1024]
残差 $X + O$ [768, 1024]

参数量:$W_q,W_k,W_v,W_o$ 各约 $d \times d$(是否含 bias 视实现而定);分头 不增加 投影矩阵个数,主要是 reshape 与 $h$ 次较小 Attention 计算。

八、与 ViT / LLM 的关系

第 08 篇 对照表:ViT 与 LLM 都用 MHA(或变体),差别在掩码(因果 vs 双向)、RoPE、Norm 等,分头 concat 结构相同

ViT 有时 head 数随层变化;LLM 通常每层 $h$ 固定(如 32)。读 config 时看 num_attention_headshidden_size

九、计算与显存直觉

  • FLOPs:约 $h$ 次单头 Attention,每头 $O(S^2 \cdot d_\text{head}) = O(S^2 \cdot d / h)$;总计仍 $O(S^2 \cdot d)$ 量级
  • KV Cache(第 11 篇):推理时需缓存 每个 head 的 K/V(或 GQA 压缩后的 K/V)
  • 下一篇 GQA:在 head 数不变的前提下,减少 K/V 的 head 数,降低 Cache 体积

十、小结

概念 要点
为何多头 多视角并行,每头 $d_\text{head}$ 维
分头 [d,S] → $h$ × [d_head,S]
单头 Attn 与第 05 篇相同,缩放 $\sqrt{d_\text{head}}$
Concat 行拼接 → [d,S]
$W_o$ 混合各 head 输出

大模型数学速成系列第 9 篇完。下一篇 GQA——K/V 头数少于 Q,为 KV Cache 省显存。

系列导航

篇号 标题 状态
08 ViT vs LLM 对照
09 多头注意力(本篇)
10 GQA:分组查询注意力 下一篇

完整大纲见工作区 docs/MATH_SERIES_OUTLINE.md