第 03 篇:自动微分不神秘——梯度是怎么流动的
目录先从一个问题开始第一步requires_grad——谁需要被追踪第二步前向传播——计算值的同时悄悄建图整张计算图长这样第三步.backward()——沿着图把梯度传回来把反向传播的每一步拆开来看第四步叶节点与非叶节点的区别第五步梯度累积——一个你必须理解的坑第六步torch.no_grad()——关掉追踪的开关第七步计算图只能用一次第八步用 detach() 切断梯度流第九步参数更新之后计算图就消失了第十步当你用 nn.Module 和 optimizer 之后第十一步梯度消失和梯度爆炸是怎么回事第十二步自定义反向传播选读小结整条链路在第一篇里我们说过PyTorch 的动态图是边执行边构建的.backward()不是在计算梯度而是在沿着已经构建好的图做一次反向遍历。这一篇就来兑现这个承诺——把这条链路真正讲清楚。不用复杂的模型就用最简单的手写线性回归输入一个数输出一个数一个权重一个偏置。但我们会把这个过程拆解到足够细让你看清楚每一步背后发生了什么。搞懂这一篇之后你再去看任何关于梯度消失、梯度裁剪、自定义反向传播的内容都会理解其中的原理。先从一个问题开始假设你有一个极其简单的函数其中是输入是标签和是参数是损失预测值和真实值的误差的平方。你想用梯度下降更新和就需要手动推导不难但问题是你不想每次换个模型就重新推一遍导数。你希望框架能自动帮你算。PyTorch 的自动微分Autograd就是干这个的。但它不是在符号层面推导公式而是在计算图上做数值反向传播。这两件事的结果一样但方式完全不同——理解这个区别是理解 Autograd 的起点。第一步requires_grad——谁需要被追踪import torch # 数据不需要求导 x torch.tensor(2.0) y torch.tensor(5.0) # 真实标签预期输出 5 # 参数需要求导这是 PyTorch 需要追踪的 w torch.tensor(1.0, requires_gradTrue) b torch.tensor(0.0, requires_gradTrue) print(x.requires_grad) # False print(w.requires_grad) # True print(b.requires_grad) # Truerequires_gradTrue告诉 PyTorch这个 Tensor 是我们关心的参数请在所有涉及它的计算过程中构建计算图以便之后能对它求导。没有设置requires_gradTrue的 Tensor比如x和yPyTorch 不会追踪与它们相关的操作这样可以节省内存和计算。有一个非常重要的传播规则只要一个操作的任意输入有requires_gradTrue那么这个操作的输出也会自动有requires_gradTrue。a torch.tensor(3.0, requires_gradTrue) b torch.tensor(2.0) # 没有 requires_grad c a * b # 输入里 a 有 requires_grad print(c.requires_grad) # True ← 自动传播这个传播规则是整个自动微分系统能工作的基础——从参数出发所有经过它的计算结果都会自动被追踪。第二步前向传播——计算值的同时悄悄建图# 前向传播计算预测值和损失 pred w * x b # 预测值1.0 * 2.0 0.0 2.0 loss (pred - y) ** 2 # 损失(2.0 - 5.0)^2 9.0 print(f预测值: {pred.item()}) # 2.0 print(f损失: {loss.item()}) # 9.0这两行代码在你眼里是在做数学运算但在 PyTorch 背后还同时发生了另一件事计算图被悄悄建立起来了。你可以通过grad_fn属性窥探这张图print(w.grad_fn) # None ← 叶节点没有 grad_fn print(pred.grad_fn) # AddBackward0 object at ... print(loss.grad_fn) # PowBackward0 object at ...grad_fn是什么它是每个 Tensor 的出生证明记录了这个 Tensor 是由哪个操作产生的以及这个操作需要什么信息才能做反向传播。w和b是叶节点直接由用户创建没有grad_fnpred w * x b是加法操作的结果所以grad_fn是AddBackward0loss (pred - y) ** 2是幂操作的结果所以grad_fn是PowBackward0这些grad_fn对象不只是标签它们每个都存储了反向传播所需的中间数据并且知道收到上游梯度之后该把什么梯度传给自己的输入。每个节点知道两件事自己的输出值是什么前向传播已经算好了收到来自输出方向的梯度时应该怎么把梯度传给自己的输入反向传播就是从loss出发沿着这张图的反向边把梯度一路传回到w和b。第三步.backward()——沿着图把梯度传回来loss.backward()就这一行。PyTorch 会从loss开始初始梯度为 1.0损失对自身的导数是 1沿着计算图反向遍历对每个节点调用它的grad_fn计算并传播梯度直到到达叶节点w和b把最终梯度累积到它们的.grad属性里让我们看结果print(fw 的梯度: {w.grad}) # tensor(-12.) print(fb 的梯度: {b.grad}) # tensor(-6.)验证一下手动推导的结果和手动计算完全一致。PyTorch 做的事情和你用链式法则手推是等价的只是它是在计算图上自动完成的。为了真正理解梯度是怎么在图上流动的我们一步一步手动追踪第一步loss → pred 的梯度令则第二步pred → w 的梯度通过链式法则所以第三步pred → b 的梯度每个grad_fn的职责就是完成第二步和第三步这样的工作接收上游传来的梯度乘以自己这层的局部导数传给下游。这就是链式法则在计算图上的具体体现。第四步叶节点与非叶节点的区别print(fw 是叶节点吗: {w.is_leaf}) # True print(fb 是叶节点吗: {b.is_leaf}) # True print(fpred 是叶节点吗: {pred.is_leaf}) # False print(floss 是叶节点吗: {loss.is_leaf}) # False叶节点由用户直接创建的 Tensor不是任何运算的结果。模型的参数权重、偏置都是叶节点。它们的grad_fn是None。非叶节点由运算产生的 Tensor中间计算结果。它们有grad_fn记录了自己从哪里来。.backward()完成后只有叶节点的.grad会被保留非叶节点的梯度默认会被释放print(fw.grad: {w.grad}) # tensor(-12.) ← 保留了 print(fb.grad: {b.grad}) # tensor(-6.) ← 保留了 print(fpred.grad: {pred.grad}) # None ← 释放了 print(floss.grad: {loss.grad}) # None ← 释放了这是 PyTorch 的内存优化策略对于模型训练来说你只需要叶节点参数的梯度来做参数更新中间激活值的梯度是临时的算完就可以扔掉。如果你确实需要某个中间节点的梯度比如做梯度可视化或者某些特殊的优化可以在它上面调用.retain_grad()w torch.tensor(1.0, requires_gradTrue) b torch.tensor(0.0, requires_gradTrue) x torch.tensor(2.0) y torch.tensor(5.0) pred w * x b pred.retain_grad() # 告诉 PyTorch请保留 pred 的梯度 loss (pred - y) ** 2 loss.backward() print(fpred.grad: {pred.grad}) # tensor(-6.) ← 现在有了第五步梯度累积——一个你必须理解的坑.backward()做的事情是把梯度累加到.grad上而不是替换。w torch.tensor(1.0, requires_gradTrue) b torch.tensor(0.0, requires_gradTrue) x torch.tensor(2.0) y torch.tensor(5.0) # 第一次前向 反向 pred w * x b loss (pred - y) ** 2 loss.backward() print(f第一次 w.grad: {w.grad}) # tensor(-12.) # 第二次前向 反向没有清零 pred w * x b loss (pred - y) ** 2 loss.backward() print(f第二次 w.grad: {w.grad}) # tensor(-24.) ← -12 (-12) -24梯度被累加了。这就是为什么训练循环里每次迭代必须先清零梯度# 正确的训练循环写法 for epoch in range(100): # ① 清零梯度必须在前向传播之前或反向传播之后 w.grad None # 手动清零方式 b.grad None # ② 前向传播 pred w * x b loss (pred - y) ** 2 # ③ 反向传播 loss.backward() # ④ 参数更新 with torch.no_grad(): w - 0.01 * w.grad b - 0.01 * b.grad用 optimizer 的时候optimizer.zero_grad()就是在做清零这件事等价于对所有参数的.grad置零或者置 None。梯度累积并不总是坏事。当你因为显存限制无法使用大 batch size 时可以故意不清零跑几个小 batch、累积梯度再做一次参数更新——效果等价于跑了一个大 batch。这是一种常用的训练技巧叫做 Gradient Accumulation。第六步torch.no_grad()——关掉追踪的开关做参数更新的时候用w - 0.01 * w.grad这一步计算本身也会触发 PyTorch 的追踪因为w有requires_gradTrue。但这次更新不应该被追踪——它只是在修改参数值不是模型的前向计算。这就是为什么参数更新要包在torch.no_grad()里with torch.no_grad(): w - 0.01 * w.grad b - 0.01 * b.gradtorch.no_grad()是一个上下文管理器在它的代码块里所有操作都不会被追踪创建的 Tensor 的requires_grad强制为False不会构建计算图节省内存和计算同样的验证模型时也应该放在torch.no_grad()里原因在第一篇已经讲过验证阶段不需要反向传播不追踪计算图可以大幅节省显存和加快速度。还有一个类似的工具是torch.inference_mode()比no_grad()更激进连grad_fn都不记录推理性能更好但在它里面修改 Tensor 可能有副作用不能回到训练模式一般只在最终部署推理时使用。第七步计算图只能用一次w torch.tensor(1.0, requires_gradTrue) b torch.tensor(0.0, requires_gradTrue) x torch.tensor(2.0) y torch.tensor(5.0) pred w * x b loss (pred - y) ** 2 loss.backward() # 第一次反向传播正常 # loss.backward() ← 如果你再调用一次会报错 # RuntimeError: Trying to backward through the graph a second time...为什么因为.backward()在遍历计算图的过程中会释放中间节点存储的那些用于反向传播的中间数据比如前向传播的激活值以节省内存。图走完之后这些数据就消失了没法再走第二遍。如果你确实需要多次对同一个计算图做反向传播某些研究场景比如计算 Hessian 矩阵或者 MAML 这类元学习算法需要加retain_graphTrueloss.backward(retain_graphTrue) # 不释放中间数据 loss.backward() # 可以再走一遍注意retain_graphTrue会让内存占用翻倍甚至更多只在确实需要时使用。第八步用 detach() 切断梯度流有时候你希望某段计算不参与反向传播——比如你在做一个 GAN更新生成器的时候不希望梯度流回判别器或者你在 RL 里有一个 target network它的输出作为监督信号但本身不更新。这时候需要.detach()w torch.tensor(1.0, requires_gradTrue) b torch.tensor(0.0, requires_gradTrue) x torch.tensor(2.0) y torch.tensor(5.0) pred w * x b # pred 有 grad_fn是图的一部分 # detach 返回一个新的 Tensor和 pred 数值相同但从计算图里切断了 pred_detached pred.detach() print(pred_detached.requires_grad) # False print(pred_detached.grad_fn) # None ← 梯度流在这里断了 loss (pred_detached - y) ** 2 loss.backward() # w 和 b 的梯度是 None因为梯度流在 pred_detached 这里断了 print(w.grad) # None print(b.grad) # None.detach()常见的实际用途# RL 里的 target networktarget 不参与梯度计算 target_q target_net(next_obs).max(dim1).values.detach() loss F.mse_loss(current_q, target_q) # target_q 断开梯度 loss.backward() # 只更新 current network 的参数 # GAN 里更新判别器时不想让梯度流回生成器 fake_img generator(noise).detach() # 切断梯度 d_loss discriminator(fake_img) d_loss.backward() # 只更新判别器detach()和torch.no_grad()的区别torch.no_grad()是上下文管理器在其作用域内所有操作都不追踪出了这个块一切恢复正常.detach()是对某个特定 Tensor的永久切断这个 Tensor 在任何地方都不再携带梯度信息第九步参数更新之后计算图就消失了把前面所有知识组装成一个完整的线性回归训练循环同时展示计算图的生命周期import torch # 数据 x torch.tensor([1.0, 2.0, 3.0, 4.0]) y torch.tensor([3.0, 5.0, 7.0, 9.0]) # y 2x 1 # 参数随机初始化 w torch.tensor(0.0, requires_gradTrue) b torch.tensor(0.0, requires_gradTrue) lr 0.01 for epoch in range(200): # ① 前向传播同时建图 pred w * x b # 向量操作x 是 4 个样本 loss ((pred - y) ** 2).mean() # MSE loss # ② 清零梯度关键在反向传播之前清零上一次残留的梯度 if w.grad is not None: w.grad.zero_() # in-place 清零 if b.grad is not None: b.grad.zero_() # ③ 反向传播沿图传播梯度图在这之后被释放 loss.backward() # ④ 参数更新在 no_grad 里避免这步操作被追踪进图 with torch.no_grad(): w - lr * w.grad b - lr * b.grad if (epoch 1) % 50 0: print(fEpoch {epoch1:3d} | loss: {loss.item():.6f} | w: {w.item():.4f} | b: {b.item():.4f}) print(f\n最终结果w {w.item():.4f}b {b.item():.4f}) print(f期望结果w 2.0b 1.0)输出大概是Epoch 50 | loss: 0.193418 | w: 1.6484 | b: 0.8096 Epoch 100 | loss: 0.018273 | w: 1.9003 | b: 0.9546 Epoch 150 | loss: 0.001726 | w: 1.9727 | b: 0.9854 Epoch 200 | loss: 0.000163 | w: 1.9924 | b: 0.9950 最终结果w 1.9924b 0.9950 期望结果w 2.0b 1.0模型正在向w2, b1收敛一切正常。注意zero_()后面的下划线——这是 PyTorch 的in-place 操作命名约定。所有以_结尾的方法都是原地修改不创建新 Tensor。zero_()把 Tensor 里的所有数值置为 0。与之对应zeros_like(t)是创建一个新的全零 Tensor不是 in-place。第十步当你用 nn.Module 和 optimizer 之后上面手写的训练循环在实际使用中会用nn.Module和optimizer来替代但背后的逻辑是一模一样的import torch import torch.nn as nn import torch.optim as optim # 数据 x torch.tensor([[1.0], [2.0], [3.0], [4.0]]) y torch.tensor([[3.0], [5.0], [7.0], [9.0]]) # 模型一个线性层等价于 w*x b model nn.Linear(1, 1) # 优化器SGDlr0.01 # optimizer 内部持有模型参数的引用知道该更新哪些东西 optimizer optim.SGD(model.parameters(), lr0.01) loss_fn nn.MSELoss() for epoch in range(200): # ① 前向传播 pred model(x) loss loss_fn(pred, y) # ② 清零梯度optimizer.zero_grad() 等价于前面手写的清零 optimizer.zero_grad() # ③ 反向传播 loss.backward() # ④ 参数更新optimizer.step() 等价于前面手写的 w - lr * w.grad # optimizer 内部自动用 torch.no_grad() 包裹参数更新 optimizer.step() if (epoch 1) % 50 0: print(fEpoch {epoch1:3d} | loss: {loss.item():.6f})optimizer.zero_grad()、loss.backward()、optimizer.step()这三步是 PyTorch 训练循环的铁三角顺序不能错每步背后干的事情你现在都清楚了。一个值得注意的细节optimizer.zero_grad()可以写在optimizer.step()之后下一次迭代开始之前也可以写在前向传播之前。两者效果相同但有人偏好写在 step 之后理由是这样的代码更接近更新完参数立刻清理准备下一轮的语义。在带 Gradient Accumulation 的场景里清零时机需要更精细的控制但那是后面的话题。第十一步梯度消失和梯度爆炸是怎么回事理解了梯度是如何在图上流动的现在可以自然地解释这两个深度学习里最烦恼的问题。梯度在每一层都是通过链式法则相乘传播的这是项相乘。如果每一项都小于 1比如 sigmoid 函数的导数最大只有 0.25乘了次之后梯度就趋近于 0——梯度消失靠近输入的层几乎没有梯度根本学不动。反过来如果每一项都大于 1乘了次之后梯度会爆炸式增长——梯度爆炸参数更新步幅极大训练直接发散。# 用代码演示梯度消失 import torch import torch.nn as nn # 10 层网络用 sigmoid 激活 layers [] for i in range(10): layers.append(nn.Linear(10, 10)) layers.append(nn.Sigmoid()) # sigmoid 导数最大 0.25 model nn.Sequential(*layers) x torch.randn(1, 10) y torch.randn(1, 10) loss ((model(x) - y) ** 2).mean() loss.backward() # 看第一层和最后一层权重的梯度大小 first_layer_grad model[0].weight.grad.abs().mean().item() last_layer_grad model[-2].weight.grad.abs().mean().item() print(f最后一层梯度均值: {last_layer_grad:.6f}) print(f第一层梯度均值: {first_layer_grad:.6f}) # 第一层的梯度会比最后一层小几个数量级这就是为什么ReLU 取代 sigmoid 成为默认激活函数——ReLU 的导数是 1正区间不会持续缩小梯度ResNet 引入残差连接——给梯度提供了一条高速公路不用穿过所有层就能传回来梯度裁剪Gradient Clipping用于对抗梯度爆炸——在更新参数之前把梯度的范数限制在一个最大值内# 梯度裁剪的标准写法 loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm1.0) optimizer.step() # clip_grad_norm_ 会计算所有参数梯度的全局范数 # 如果超过 max_norm就等比例缩小所有梯度使总范数等于 max_norm第十二步自定义反向传播选读大多数情况下PyTorch 内置的grad_fn能自动处理一切。但偶尔你需要自己定义一个操作的反向传播——比如这个操作在数学上不可微但你知道一个合理的近似梯度或者你在实现一个论文里的自定义 loss。这时候用torch.autograd.Functionclass MySquare(torch.autograd.Function): staticmethod def forward(ctx, x): # ctx 是一个上下文对象用来在 forward 和 backward 之间传递信息 ctx.save_for_backward(x) # 把前向需要传给反向的值存起来 return x ** 2 staticmethod def backward(ctx, grad_output): # grad_output从输出方向传来的梯度即 dL/d_output x, ctx.saved_tensors # 取出前向存的值 # 本地梯度d(x^2)/dx 2x # 链式法则dL/dx dL/d_output * d_output/dx grad_output * 2x return grad_output * 2 * x # 使用自定义操作 x torch.tensor(3.0, requires_gradTrue) y MySquare.apply(x) # 注意用 .apply() 调用不是直接 () y.backward() print(x.grad) # tensor(6.) ← 2 * 3 6 ✓这个机制在以下场景里会用到实现 Straight-Through Estimator量化感知训练里绕过取整操作的不可微性实现某些特殊的激活函数或者 loss 函数性能优化用 CUDA kernel 实现自定义算子同时告诉 PyTorch 它的反向传播怎么算小结整条链路从头到尾自动微分的完整工作流是这样的1. 给参数设置 requires_gradTrue ↓ 2. 前向传播计算值 实时构建计算图 每个操作产生一个 grad_fn 节点记录操作类型和所需的中间数据 ↓ 3. loss.backward() 从 loss 出发沿图反向遍历 每个节点的 grad_fn 接收上游梯度用局部导数 * 上游梯度 本节点的梯度传给下游 中间节点的梯度随用随弃 叶节点参数的梯度累加到 .grad 里 ↓ 4. 用 .grad 更新参数在 no_grad 里做 ↓ 5. 清零 .gradzero_grad准备下一次迭代关键认知点汇总requires_grad 是追踪的开关只有设置了它的 Tensor 及其下游才会被追踪grad_fn 是每个非叶节点的出生证明存储了反向传播需要的中间信息计算图在 backward() 之后默认释放需要重复使用时加 retain_graphTrue梯度是累加的不是覆盖的训练循环每次迭代必须主动清零detach() 切断某个 Tensor 的梯度流no_grad() 关掉整个代码块的追踪叶节点的 .grad 会保留非叶节点的默认释放需要时用 retain_grad()梯度消失和梯度爆炸是链式法则连乘在深层网络里的必然结果残差连接和梯度裁剪是工程上的对策。下一篇我们进入数据管道——Dataset和DataLoader。你会看到即便是加载数据这件看似平凡的事在 PyTorch 里也有一套清晰的设计哲学值得理解透彻。