大模型数学速成(09):多头注意力——多个专家各看各的
第 05 篇我们用的是「一整块」768 维 Q/K/V 做一次 Attention。实际模型里常见 12、16、32 个头(head)——同一层里多组更小的 Attention 并行运行,再拼回去。为什么要拆头?维度怎么变?
这一篇讲清 分头、逐头 Attention、concat 与输出投影 $W_o$ 的全流程。
这是「大模型数学速成」系列的第 9 篇。建议先读 第 05 篇:Attention 与 Softmax、第 07 篇:RoPE。下一篇 GQA。
一、单头 vs 多头:陪审团类比
单头 Attention:一个「专家」用 768 维 Q/K/V 同时看所有语义侧面。
多头 Attention(MHA):$h$ 个「陪审员」,每人只拿 $d_\text{head}$ 维(如 64 维),独立做一遍 Attention,最后把 $h$ 份意见 拼接 成完整向量。
1 | 单头:1 人 768 维全包 |
直觉:不同 head 可学到不同模式——有的关注语法距离,有的关注指代,有的关注局部 n-gram……并行多视角,比单一 768 维点积更灵活。
二、符号与典型数字
| 符号 | 含义 | 典型值 |
|---|---|---|
| $d$ | 模型 hidden 维($n_\text{embd}$) | 768 |
| $h$ | head 个数 | 12 |
| $d_\text{head}$ | 每头维度 | $d / h = 64$ |
| $S$ | token 数 | 序列长度 |
约束:$d$ 必须能被 $h$ 整除。
本系列约定:矩阵 [行 = 特征,列 = token](第 01 篇)。
三、分头:从 Q/K/V 到「每头一份」
投影后得到 $Q, K, V$,形状均为 [d, S] = [768, S]。
按 head 切分特征行(实现上常 reshape / view,数学上等价):
1 | Q [768, S] → 12 个头,每个 Q_h [64, S] |
示意($h=3$ 缩小):
1 | 列: token0 token1 token2 |
RoPE(第 07 篇)在分头之后、Attention 之前,对每个 head 的 Q/K 独立施加(每头 $d_\text{head}$ 维上的旋转)。
四、每个 head 内:与单头 Attention 相同
对第 $k$ 个头($k = 0..h-1$):
$$
\text{head}k = \text{Attention}(Q_k, K_k, V_k)
= V_k \cdot \text{Softmax}!\left(\frac{Q_k^\top K_k}{\sqrt{d\text{head}}}\right)^\top
$$
| 量 | 形状 |
|---|---|
| $Q_k, K_k, V_k$ | [d_head, S] |
| $Q_k^\top K_k$ | [S, S] |
| $\text{head}_k$ 输出 | [d_head, S] |
缩放因子用 $\sqrt{d_\text{head}}$(64),不是 $\sqrt{d}$(768)——每头在自己的子空间里做点积。
$h$ 个头 各自 算一遍,互不共享 Attention 矩阵(权重 $W_q,W_k,W_v$ 在投影阶段已统一,分头是 reshape)。
五、concat:拼回 [d, S]
把 $h$ 个 [d_\text{head}, S] 在**行(特征维)**上拼接:
$$
\text{Concat} = \begin{bmatrix} \text{head}_0 \ \text{head}1 \ \vdots \ \text{head}{h-1} \end{bmatrix}
\quad \in \mathbb{R}^{d \times S}
$$
1 | head_0 [64×S] ─┐ |
列数 $S$ 始终不变;行数 $h \times d_\text{head} = d$。
六、输出投影 $W_o$
拼接结果再过一层线性变换(输出投影):
$$
O = W_o \cdot \text{Concat}, \quad W_o \in \mathbb{R}^{d \times d}
$$
- $O$ 形状
[d, S],与 Attention 输入同形,可 残差相加(第 04 篇) - $W_o$ 让模型在 head 之间混合信息——concat 只是简单并排,$W_o$ 学「如何综合各陪审员意见」
完整 MHA 公式:
$$
\text{MHA}(Q,K,V) = W_o \cdot \text{Concat}\big(\text{head}0, \ldots, \text{head}{h-1}\big)
$$
七、维度变化全流程(一层 Self-Attention)
以 $d=768, h=12, d_\text{head}=64, S=1024$ 为例:
| 步骤 | 张量 | 形状 |
|---|---|---|
| 输入(Norm 后) | $X$ | [768, 1024] |
| Q/K/V 投影 | $Q,K,V$ | [768, 1024] |
| 分头 | $Q_k$ 等 | $h$ × [64, 1024] |
| RoPE | $\tilde{Q}_k, \tilde{K}_k$ | [64, 1024] |
| 单头 Attention | $\text{head}_k$ | [64, 1024] |
| Concat | [768, 1024] |
|
| $W_o$ | $O$ | [768, 1024] |
| 残差 | $X + O$ | [768, 1024] |
参数量:$W_q,W_k,W_v,W_o$ 各约 $d \times d$(是否含 bias 视实现而定);分头 不增加 投影矩阵个数,主要是 reshape 与 $h$ 次较小 Attention 计算。
八、与 ViT / LLM 的关系
第 08 篇 对照表:ViT 与 LLM 都用 MHA(或变体),差别在掩码(因果 vs 双向)、RoPE、Norm 等,分头 concat 结构相同。
ViT 有时 head 数随层变化;LLM 通常每层 $h$ 固定(如 32)。读 config 时看 num_attention_heads 与 hidden_size。
九、计算与显存直觉
- FLOPs:约 $h$ 次单头 Attention,每头 $O(S^2 \cdot d_\text{head}) = O(S^2 \cdot d / h)$;总计仍 $O(S^2 \cdot d)$ 量级
- KV Cache(第 11 篇):推理时需缓存 每个 head 的 K/V(或 GQA 压缩后的 K/V)
- 下一篇 GQA:在 head 数不变的前提下,减少 K/V 的 head 数,降低 Cache 体积
十、小结
| 概念 | 要点 |
|---|---|
| 为何多头 | 多视角并行,每头 $d_\text{head}$ 维 |
| 分头 | [d,S] → $h$ × [d_head,S] |
| 单头 Attn | 与第 05 篇相同,缩放 $\sqrt{d_\text{head}}$ |
| Concat | 行拼接 → [d,S] |
| $W_o$ | 混合各 head 输出 |
大模型数学速成系列第 9 篇完。下一篇 GQA——K/V 头数少于 Q,为 KV Cache 省显存。
系列导航
| 篇号 | 标题 | 状态 |
|---|---|---|
| 08 | ViT vs LLM 对照 | ✅ |
| 09 | 多头注意力(本篇) | ✅ |
| 10 | GQA:分组查询注意力 | 下一篇 |
完整大纲见工作区 docs/MATH_SERIES_OUTLINE.md。











