Flash Attention 技术原理让 Transformer 更快更高效前言Flash Attention 是 Transformer 架构的一项革命性优化技术。它通过重新设计注意力计算的内存访问模式显著减少了内存带宽消耗同时保持了计算精度。我在多个项目中使用过 Flash Attention性能提升非常明显。今天分享这项技术的原理和实现。传统 Attention 的问题内存瓶颈# 传统 Attention 计算 def naive_attention(Q, K, V): 朴素的注意力计算 # Q, K, V: (batch, heads, seq_len, head_dim) # 计算注意力分数 scores Q K.transpose(-2, -1) / math.sqrt(Q.size(-1)) # (B, H, N, N) # Softmax attn F.softmax(scores, dim-1) # (B, H, N, N) # 输出 output attn V # (B, H, N, D) return output # 问题scores 和 attn 都是 O(N²) 的内存占用 # 对于 N2048这意味着约 160MB 每头FP16Flash Attention 的解决方案class FlashAttention: Flash Attention 核心思想 def __init__(self, block_size128): self.block_size block_size def flash_mha(self, Q, K, V): 分块计算注意力 B, H, N, D Q.shape output torch.zeros_like(Q) # 按块处理 for i in range(0, N, self.block_size): # 处理输出块 Q_block Q[:, :, i:iself.block_size, :] accumulator torch.zeros_like(Q_block) scale 1.0 for j in range(0, N, self.block_size): # 加载 K, V 块 K_block K[:, :, j:jself.block_size, :] V_block V[:, :, j:jself.block_size, :] # 计算局部注意力 scores Q_block K_block.transpose(-2, -1) / math.sqrt(D) # 数值稳定的 Softmax scores_max scores.max(dim-1, keepdimTrue).values exp_scores torch.exp(scores - scores_max) exp_sum exp_scores.sum(dim-1, keepdimTrue) # 累加 accumulator accumulator * scale (exp_scores V_block) scale scale * exp_sum # 最终归一化 output[:, :, i:iself.block_size, :] accumulator / scale return output技术细节分块策略def flash_attention_blocked(Q, K, V, block_m128, block_n128): block_m: 输出维度的块大小 block_n: K/V 维度的块大小 B, H, N, D Q.shape # 初始化输出和累加器 O torch.zeros(B, H, N, D, deviceQ.device) l torch.zeros(B, H, N, 1, deviceQ.device) # 累加和 m torch.full((B, H, N, 1), float(-inf), deviceQ.device) # 最大值 # 遍历输出块 for i in range(0, N, block_m): # 当前输出块的 Q Q_i Q[:, :, i:iblock_m, :] # 遍历 K/V 块 for j in range(0, N, block_n): # 当前 K/V 块 K_j K[:, :, j:jblock_n, :] V_j V[:, :, j:jblock_n, :] # 计算注意力分数 scores_ij Q_i K_j.transpose(-2, -1) / math.sqrt(D) # (B, H, block_m, block_n) # 更新最大值 m_new torch.max(m[:, :, i:iblock_m, :], scores_ij.max(dim-1, keepdimTrue).values) # 计算缩放因子 P torch.exp(scores_ij - m_new) l_new torch.exp(m[:, :, i:iblock_m, :] - m_new) * l[:, :, i:iblock_m, :] P.sum(dim-1, keepdimTrue) # 更新输出 O[:, :, i:iblock_m, :] torch.exp(m[:, :, i:iblock_m, :] - m_new) * O[:, :, i:iblock_m, :] / l_new (P V_j) / l_new # 更新状态 m[:, :, i:iblock_m, :] m_new l[:, :, i:iblock_m, :] l_new return O数值稳定性def stable_softmax(scores): 数值稳定的 Softmax # 减去最大值防止溢出 scores_max scores.max(dim-1, keepdimTrue).values exp_scores torch.exp(scores - scores_max) # 归一化 return exp_scores / exp_scores.sum(dim-1, keepdimTrue)性能对比import time def benchmark_attention(Q, K, V, iterations10): 基准测试 # 预热 _ F.scaled_dot_product_attention(Q, K, V) # 计时 start time.time() for _ in range(iterations): _ F.scaled_dot_product_attention(Q, K, V) end time.time() avg_time (end - start) / iterations print(fFlash Attention: {avg_time*1000:.2f} ms) # 朴素实现仅用于对比 start time.time() for _ in range(iterations): _ naive_attention(Q, K, V) end time.time() avg_time (end - start) / iterations print(fNaive Attention: {avg_time*1000:.2f} ms) # 测试 Q torch.randn(1, 12, 2048, 64, devicecuda) K torch.randn(1, 12, 2048, 64, devicecuda) V torch.randn(1, 12, 2048, 64, devicecuda) benchmark_attention(Q, K, V)在 Transformers 中使用from transformers import AutoModelForCausalLM # 使用 Flash Attention 2 model AutoModelForCausalLM.from_pretrained( Qwen/Qwen2-7B, attn_implementationflash_attention_2, torch_dtypetorch.bfloat16, device_mapauto ) # 验证 print(fAttention implementation: {model.config.attn_implementation})总结Flash Attention 的优势内存效率高O(N) 内存复杂度而非 O(N²)速度快减少内存访问提高计算效率精度保持数值稳定的算法设计关键要点分块处理减少内存带宽消耗利用 GPU 的高带宽 SRAMPyTorch 2.0 原生支持性能提升可达 2-4x