# Multi-Head Attention

## I. Overall Structure

This implementation is a masked multi-head self-attention, the variant used in GPT-style autoregressive models. The input is $x \in \mathbb{R}^{B \times T \times C}$, where:

- $B$: batch size
- $T$: sequence length (`block_size`)
- $C$: embedding dimension (`n_embd`)

## II. Core Flow

### 1. Q/K/V linear projections

```python
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
```

Corresponding formula:

$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$

### 2. Splitting into heads

```python
q = q.view(B, T, n_heads, C // n_heads).transpose(1, 2)
```

This yields shape $(B, n\_heads, T, d_k)$, where $d_k = \frac{C}{n\_heads}$.

### 3. Attention scores

```python
attn = (q @ k.transpose(-2, -1)) * (1 / math.sqrt(d_k))
```

$$\text{score} = \frac{QK^T}{\sqrt{d_k}}$$

Shape: $(B, n\_heads, T, T)$.

### 4. Mask (the key point)

```python
self.mask = torch.tril(torch.ones(T, T))
attn = attn.masked_fill(self.mask == 0, float('-inf'))
```

Purpose: each position can attend only to itself and earlier positions, which prevents "information leakage" in the autoregressive setting. In matrix form:

$$\begin{bmatrix} 1 & 0 & 0 \\ 1 & 1 & 0 \\ 1 & 1 & 1 \end{bmatrix}$$

### 5. Softmax to obtain attention weights

```python
attn = F.softmax(attn, dim=-1)
```

$$\alpha = \text{softmax}(\text{score})$$

### 6. Dropout (regularization)

```python
attn = self.attn_drop(attn)
```

### 7. Weighted sum

```python
x = attn @ v
```

$$\text{head} = \alpha V$$

### 8. Concatenating the heads

```python
x = x.transpose(1, 2).contiguous().view(B, T, C)
```

### 9. Output projection

```python
x = self.out_proj(x)
x = self.out_drop(x)
```

$$\text{output} = \text{Concat}(\text{head}_i)\,W^O$$

## III. Shape Changes (a frequent interview topic)

| Step | Shape |
| --- | --- |
| Input | $(B, T, C)$ |
| Q/K/V | $(B, T, C)$ |
| Split into heads | $(B, n\_heads, T, d_k)$ |
| Attention matrix | $(B, n\_heads, T, T)$ |
| Weighted output | $(B, n\_heads, T, d_k)$ |
| Concatenation | $(B, T, C)$ |

## IV. Key Design Points

1. **Why divide by $\sqrt{d_k}$?** Without scaling, the dot products grow with $d_k$ and push softmax into its saturated region, where gradients vanish (demonstrated in the first sketch after the full implementation below).
2. **Why the mask?** Autoregressive tasks (e.g., GPT) must guarantee that the current position cannot see future tokens (see the second sketch below).
3. **What `register_buffer` does.** `self.register_buffer("mask", ...)` stores a tensor that is not a trainable parameter, is saved and loaded with the model's state dict, and moves with the module on `.to(device)` / `.cuda()`.
4. **Why `contiguous().view()`?** After `transpose(...)` the tensor's memory is no longer contiguous, and `view(...)` requires contiguous memory, so `contiguous()` must be called first (see the third sketch below).

## V. Complexity Analysis

- Time: $O(T^2 \cdot C)$
- Space: $O(T^2)$ per head, for the attention matrix

## VI. Complete Implementation (Summary Version)

```python
# NOTE: multi-head attention
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, n_heads, block_size, bias=True, drop_rate=0.1):
        super().__init__()
        self.n_embd = n_embd          # e.g. 768
        self.n_heads = n_heads        # e.g. 8
        self.block_size = block_size  # e.g. 10
        self.q_proj = nn.Linear(n_embd, n_embd, bias=bias)
        self.k_proj = nn.Linear(n_embd, n_embd, bias=bias)
        self.v_proj = nn.Linear(n_embd, n_embd, bias=bias)
        self.out_proj = nn.Linear(n_embd, n_embd, bias=bias)
        self.attn_drop = nn.Dropout(drop_rate)
        self.out_drop = nn.Dropout(drop_rate)
        # Causal mask, shape (1, 1, block_size, block_size)
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size))
                 .view(1, 1, block_size, block_size),
        )

    def forward(self, x):
        B, T, C = x.shape
        # Project, then split into heads: (B, n_heads, T, d_k)
        q = self.q_proj(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        # Scaled dot-product scores: (B, n_heads, T, T)
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        attn = attn.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        # Weighted sum, then merge heads back to (B, T, C)
        x = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x


input = torch.rand(10, 10, 768)
attention = MultiHeadAttention(n_embd=768, n_heads=8, block_size=10)
print(attention(input).shape)  # torch.Size([10, 10, 768])
```
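## VII. Supplementary Sketches

Design point 1 can be made concrete with a small standalone experiment (an added illustration with arbitrary sizes, not part of the original code): for random vectors, the unscaled dot product $q \cdot k$ has variance roughly $d_k$, so the logits are large and softmax typically collapses onto a single position; dividing by $\sqrt{d_k}$ brings the variance back to roughly 1.

```python
# Effect of the 1/sqrt(d_k) scaling on softmax sharpness.
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 96                          # head dimension, as in the 768/8 example
q = torch.randn(1, d_k)
k = torch.randn(16, d_k)          # 16 key positions

logits = q @ k.T                  # entries have variance ~ d_k
scaled = logits / math.sqrt(d_k)  # entries have variance ~ 1

print(F.softmax(logits, dim=-1).max().item())  # typically near 1: almost one-hot
print(F.softmax(scaled, dim=-1).max().item())  # noticeably flatter distribution
```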
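As a quick check of design point 2, the sketch below (my own test harness, assuming the `MultiHeadAttention` class above is in scope) perturbs only the last position of the input and verifies that the outputs at all earlier positions are unchanged; `eval()` disables dropout so the comparison is deterministic.

```python
# Causality check: changing a future position must not affect earlier outputs.
import torch

torch.manual_seed(0)
attn_layer = MultiHeadAttention(n_embd=768, n_heads=8, block_size=10)
attn_layer.eval()  # disable dropout for a deterministic comparison

x1 = torch.rand(1, 10, 768)
x2 = x1.clone()
x2[:, -1, :] = torch.rand(768)  # perturb only the last position

with torch.no_grad():
    y1 = attn_layer(x1)
    y2 = attn_layer(x2)

# Positions 0..8 attend only to themselves and earlier positions,
# so changing position 9 must leave them untouched.
print(torch.allclose(y1[:, :-1, :], y2[:, :-1, :]))  # True
```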
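Design point 4 in one toy example (again an added illustration): `transpose` only swaps strides and leaves the underlying storage order untouched, so a subsequent `view` fails until the data is physically rearranged with `contiguous()`.

```python
# Why contiguous() must precede view() after a transpose.
import torch

x = torch.rand(2, 3, 4, 5)
t = x.transpose(1, 2)      # shape (2, 4, 3, 5), strides swapped, same storage
print(t.is_contiguous())   # False
# t.view(2, 4, 15)         # raises RuntimeError: view size is not compatible
print(t.contiguous().view(2, 4, 15).shape)  # torch.Size([2, 4, 15])
```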
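For reference, PyTorch 2.x ships a fused equivalent of steps 3 through 7: `torch.nn.functional.scaled_dot_product_attention`. The sketch below (assuming torch >= 2.0, with arbitrary shapes) shows how it would replace the manual scale/mask/softmax/dropout/matmul sequence inside `forward`; during training one would pass `dropout_p=drop_rate` to mirror `self.attn_drop`.

```python
# Fused alternative to steps 3-7 (requires PyTorch >= 2.0).
import torch
import torch.nn.functional as F

B, n_heads, T, d_k = 2, 8, 10, 96
q = torch.rand(B, n_heads, T, d_k)
k = torch.rand(B, n_heads, T, d_k)
v = torch.rand(B, n_heads, T, d_k)

# Scaling by 1/sqrt(d_k), causal masking, softmax, optional dropout,
# and the weighted sum all happen inside one kernel.
y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
print(y.shape)  # torch.Size([2, 8, 10, 96])
```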