首先了解一些基本概念以Llama13B为例首先是输入输出这里的2是因为每个值都是float16占两个字节然后转换为MB输入输出相加为20MB所占显存大小和其他部分相比可以忽略不计这里的2是因为每个值都是float16占两个字节1B和1GB大致相当都是float32存储的为什么优化器要存模型参数从归属上看模型参数属于 Model优化器属于 Optimizer。从物理内存上看优化器不复制模型参数而是通过引用直接修改它们但优化器会为每个参数分配额外的状态缓存如动量缓冲池。在大模型显存规划中评估优化器带来的显存压力时必须将这部分“辅助状态”计算在内例如 Adam 需要额外增加约 8~16 GB/十亿参数的显存消耗具体取决于精度格式。为什么平滑值不能用float16因为会丢失精度梯度很小学习率更小在反向传播中会用到前向传播中的激活值https://zhuanlan.zhihu.com/p/673916177关于激活值显存占用更详细可以参考上面这个链接具体的 34 是一个经验估算值或特定实现下的精确计数涵盖了 LayerNorm 的统计量、MLP 层的多个线性变换输入输出缓存等。这里的系数 5 可能对应Q, K, V, Score, Output 这 5 个主要张量的保存需求。激活值计算好像漏乘了2FP16占两个字节计算 QKT。其中 Q 和 K 的 shape 都是[b, a, s, h/a]。矩阵乘法后得到的分数矩阵 shape 为[b, a, s, s]。显存占用需要保存 Q 和 K 用于反向传播大小为 bsh。分数矩阵本身大小为 bs^2a。在计算总显存时Attention模块与序列长度相关的主要二次方项来自于 bs^2a 将sbh提取出括号后得到 as/h参考视频RethinkFun投稿视频-RethinkFun视频分享-哔哩哔哩视频