先看效果:多头注意力的不同视角

切换不同的注意力头,观察同一句话中不同头关注的模式——有的关注相邻词,有的捕捉长距离依赖。

悬停 token 查看注意力权重,线条粗细表示权重大小

01 大白话讲 Transformer

RNN 的问题:只能一步步走

RNN 处理句子就像在读一本书,必须从第一个字读到最后一个字,前一步没读完就不能读下一步。这导致两个问题:

!
无法并行

GPU 有成千上万个核心,但 RNN 每次只用一个,浪费严重

!
长距离遗忘

句子开头的信息要经过几十步才能影响结尾,沿途不断被稀释

Transformer 的解法:所有位置同时互相看

Transformer 直接用 Self-Attention 让句子中每个词同时关注所有其他词,没有时序依赖,整句话一次并行处理完。

代价是:必须人工告诉模型词的顺序——这就是位置编码的作用。

位置编码:给词打上"座位号"

Transformer 用正弦/余弦函数生成每个位置的唯一编码,叠加到词向量上:

1
低维振动快,高维振动慢

就像时钟的秒针、分针、时针——组合起来可以精确表示任意时刻(任意位置)

2
相对位置可计算

任意两个位置的编码之差是固定的,让模型能感知相对距离,不只是绝对位置

多头注意力:多角度理解

单个 Attention 一次只能关注一种模式。多头注意力并行运行多个 Attention,每个"头"学会关注不同类型的关系:

Head 1

关注相邻词(局部语法)

Head 2

关注同类词(语义相似)

Head 3

关注句法依存(主谓宾)

Head 4

关注长距离指代(代词消歧)

最后把所有头的输出拼接起来,得到更丰富的表示。

为什么 Transformer 这么强?

训练速度快

全并行,充分利用 GPU——GPT-3 这种千亿参数模型只有 Transformer 才能训出来

长程依赖一跳直达

任意两词之间路径长度为 1,不再有梯度消失问题

通用架构

GPT(纯解码器)、BERT(纯编码器)、T5(编码器+解码器)都是 Transformer 变体

一步步构建位置编码

Transformer 没有循环结构,靠位置编码注入序列顺序信息。

第一步 位置编码公式

偶数维用 sin、奇数维用 cos,频率随维度增大而降低。

第二步 构建完整 PE 矩阵

生成 [MAX_LEN × D_MODEL] 的位置编码矩阵,加到词嵌入上。

第三步 热力图可视化

用热力图观察每个位置、每个维度的编码值——越靠右的维度变化越慢。

02 代码

03 学术性讲解

Transformer(Vaswani et al., 2017 "Attention Is All You Need")完全基于自注意力机制构建,彻底抛弃循环和卷积,通过多头注意力、残差连接和层归一化实现高效的序列建模。

整体架构:编码器–解码器

编码器 Encoder 多头自注意力 Add & Norm 前馈网络 FFN Add & Norm ×N 输入 + 位置编码 K, V → 解码器 解码器 Decoder 掩码自注意力 Add & Norm 交叉注意力(K,V←编码器) FFN + Add & Norm ×N 输出 + 位置编码 Linear + Softmax → 词 K, V

核心组件详解

1
位置编码(Positional Encoding)

PE(pos, 2i) = sin(pos / 10000^(2i/d)),PE(pos, 2i+1) = cos(...)。不同频率的正弦波叠加,使模型能通过线性变换计算相对位置差

2
多头注意力(Multi-Head Attention)

将 Q/K/V 分别投影到 h 个低维子空间,各自做 Attention 后拼接:MultiHead(Q,K,V) = Concat(head₁,...,headₕ)·W_O,其中 headᵢ = Attention(QWᵢ_Q, KWᵢ_K, VWᵢ_V)

3
残差连接 + 层归一化(Add & Norm)

每个子层输出为 LayerNorm(x + Sublayer(x))。残差防止深层网络梯度消失,层归一化稳定训练

4
掩码注意力(Masked Attention)

解码器在生成第 t 个词时,只能看到位置 1~t-1 的词(未来词被掩盖),保证自回归生成的因果性

5
交叉注意力(Cross Attention)

解码器的 Q 来自自身,K/V 来自编码器输出——这是解码器"查阅"源序列信息的通道

与 RNN 对比

最长路径

RNN O(n);Transformer O(1)——任意两词直接相连

并行度

RNN 顺序依赖无法并行;Transformer 全并行

计算量

RNN O(nd²);Transformer O(n²d)——短序列 Transformer 更快,超长序列反之

应用

GPT/BERT/T5 均基于 Transformer,已成为 NLP 默认架构

总结

并行

抛弃循环,全局并行

位置编码

正弦波注入顺序信息

多头

多角度捕捉依赖关系

残差+归一化

稳定深层网络训练