用PyTorch复现PGD攻击:手把手教你生成能“骗过”LeNet的对抗样本(附完整代码)
用PyTorch实战PGD对抗攻击从零构建可欺骗LeNet的对抗样本在计算机视觉领域对抗样本正成为模型安全性的重要测试手段。想象一下当你在手机银行应用中手写数字进行验证时攻击者可能通过精心设计的微小扰动让系统将3识别为8。这种看似魔法的攻击背后正是PGDProjected Gradient Descent这类算法的威力。本文将带你用PyTorch从零实现PGD攻击目标是在MNIST数据集上成功欺骗经典的LeNet模型。1. 环境准备与基础配置1.1 安装必要依赖确保你的Python环境已安装以下核心库pip install torch torchvision matplotlib tqdm numpy对于GPU加速用户建议安装CUDA版本的PyTorchpip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu1131.2 数据加载与预处理MNIST数据集作为入门级计算机视觉基准其28x28的手写数字图像非常适合快速验证对抗攻击效果。我们使用PyTorch内置的数据加载器from torchvision import datasets, transforms def load_mnist(batch_size64): transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) # MNIST标准归一化参数 ]) testset datasets.MNIST(root./data, trainFalse, downloadTrue, transformtransform) return torch.utils.data.DataLoader(testset, batch_sizebatch_size, shuffleFalse)注意保持测试集shuffleFalse以确保结果可复现这对对抗攻击实验尤为重要。2. LeNet模型架构解析2.1 经典LeNet实现LeNet作为卷积神经网络的先驱其结构简洁却效果显著。以下是PyTorch实现import torch.nn as nn class LeNet(nn.Module): def __init__(self): super().__init__() self.conv1 nn.Conv2d(1, 6, 5, padding2) self.conv2 nn.Conv2d(6, 16, 5) self.fc1 nn.Linear(16*5*5, 120) self.fc2 nn.Linear(120, 84) self.fc3 nn.Linear(84, 10) def forward(self, x): x nn.functional.max_pool2d(nn.functional.relu(self.conv1(x)), 2) x nn.functional.max_pool2d(nn.functional.relu(self.conv2(x)), 2) x torch.flatten(x, 1) x nn.functional.relu(self.fc1(x)) x nn.functional.relu(self.fc2(x)) return self.fc3(x)2.2 模型性能基准测试在开始攻击前我们需要确认原始模型的准确率def test_accuracy(model, loader, device): correct 0 total 0 with torch.no_grad(): for images, labels in loader: images, labels images.to(device), labels.to(device) outputs model(images) _, predicted torch.max(outputs.data, 1) total labels.size(0) correct (predicted labels).sum().item() return 100 * correct / total典型情况下训练良好的LeNet在MNIST测试集上应达到98%以上的准确率。3. PGD攻击核心实现3.1 攻击算法数学原理PGD攻击的核心迭代公式$$ x_{t1} \Pi_{x \pm \epsilon} \left( x_t \alpha \cdot \text{sign}(\nabla_{x_t} J(x_t, y)) \right) $$其中关键参数参数说明典型值ε扰动上限8/255α单步扰动强度2/255T迭代次数7-103.2 PyTorch实现细节完整PGD攻击函数实现def pgd_attack(model, images, labels, epsilon8/255, alpha2/255, iterations10): # 随机初始化扰动 delta torch.empty_like(images).uniform_(-epsilon, epsilon) adv_images torch.clamp(images delta, 0, 1).detach() for _ in range(iterations): adv_images.requires_grad True # 计算损失 outputs model(adv_images) loss nn.functional.cross_entropy(outputs, labels) # 梯度计算 model.zero_grad() loss.backward() grad adv_images.grad.data # 更新对抗样本 adv_images adv_images.detach() alpha * grad.sign() # 投影到ε邻域并保持有效像素范围 delta torch.clamp(adv_images - images, -epsilon, epsilon) adv_images torch.clamp(images delta, 0, 1) return adv_images提示使用.detach()切断计算图可以节省内存这对大规模攻击很重要。4. 攻击效果可视化与分析4.1 对抗样本生成执行攻击并收集结果def evaluate_attack(model, loader, device, epsilon): correct 0 adv_examples [] for images, labels in loader: images, labels images.to(device), labels.to(device) # 生成对抗样本 adv_images pgd_attack(model, images, labels, epsilonepsilon) # 测试模型表现 outputs model(adv_images) _, pred torch.max(outputs, 1) correct (pred labels).sum().item() # 保存示例 if len(adv_examples) 5: adv_ex adv_images[0].squeeze().cpu().numpy() orig_ex images[0].squeeze().cpu().numpy() adv_examples.append((orig_ex, adv_ex, labels[0].item(), pred[0].item())) accuracy 100 * correct / len(loader.dataset) print(fEpsilon: {epsilon:.4f}, Test Accuracy: {accuracy:.2f}%) return adv_examples4.2 结果可视化使用matplotlib对比原始图像与对抗样本import matplotlib.pyplot as plt def plot_examples(examples): plt.figure(figsize(10, 5)) for i, (orig, adv, true, pred) in enumerate(examples): plt.subplot(2, 5, i1) plt.imshow(orig, cmapgray) plt.title(fTrue: {true}) plt.axis(off) plt.subplot(2, 5, i6) plt.imshow(adv, cmapgray) plt.title(fPred: {pred}) plt.axis(off) plt.tight_layout() plt.show()典型输出会显示人眼几乎无法区分的微小扰动却导致模型完全错误的预测结果。5. 高级技巧与实战建议5.1 参数调优指南PGD攻击效果受多个参数影响ε (epsilon): 控制最大扰动幅度太小攻击可能失败太大扰动变得明显建议范围4/255到16/255α (alpha): 单步扰动强度经验法则α ≈ ε/4太大可能导致振荡太小收敛缓慢5.2 多步攻击策略增加迭代次数通常能提高攻击成功率但边际效益递减迭代次数攻击成功率计算成本165%低589%中1093%高2095%很高5.3 防御措施初探了解攻击后我们可以尝试简单防御def defensive_denoise(images, threshold0.1): # 简单去噪防御 return torch.clamp(images threshold * torch.randn_like(images), 0, 1)在实际项目中更强大的防御需要对抗训练或专用检测机制。