PyTorch实现CIFAR-10图像分类的CNN模型详解
1. 项目概述CIFAR-10图像分类任务是深度学习领域的经典入门项目。这个32x32像素的彩色图像数据集包含10个类别共6万张图片5万训练1万测试。相比MNIST手写数字识别CIFAR-10的识别难度更高主要体现在彩色图像3通道比灰度图像1通道信息更复杂物体可能出现在图片的任何位置背景干扰因素更多同类物体的形态差异更大我使用的开发环境是Python 3.10.19和PyTorch 2.10.0在NVIDIA GPU上运行。下面将详细介绍从数据准备到模型训练的全过程。2. 环境配置与数据准备2.1 GPU环境设置在深度学习项目中GPU加速至关重要。PyTorch中可以通过以下代码检查并设置计算设备import torch device torch.device(cuda if torch.cuda.is_available() else cpu) print(fUsing device: {device})提示如果使用Colab等云平台需要确保已启用GPU加速。本地开发时建议安装对应CUDA版本的PyTorch以获得最佳性能。2.2 数据集加载与处理CIFAR-10数据集可以通过torchvision直接加载import torchvision from torchvision import transforms # 定义数据转换 transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # 加载训练集和测试集 train_ds torchvision.datasets.CIFAR10( data, trainTrue, transformtransform, downloadTrue ) test_ds torchvision.datasets.CIFAR10( data, trainFalse, transformtransform, downloadTrue )这里有几个关键点需要注意ToTensor()将PIL图像转换为PyTorch张量并自动将像素值缩放到[0,1]范围Normalize()对每个通道进行标准化参数分别是均值(0.5)和标准差(0.5)下载的数据会保存在data目录下2.3 数据加载器配置使用DataLoader可以方便地进行批量数据加载和打乱batch_size 32 train_dl torch.utils.data.DataLoader( train_ds, batch_sizebatch_size, shuffleTrue ) test_dl torch.utils.data.DataLoader( test_ds, batch_sizebatch_size )选择batch_size时需要考虑GPU内存大小训练速度模型收敛稳定性32是一个常用的起始值可以根据实际情况调整。3. 模型架构设计3.1 CNN基础结构我们的CNN模型包含以下层次import torch.nn as nn import torch.nn.functional as F class CIFAR10Model(nn.Module): def __init__(self): super().__init__() self.conv1 nn.Conv2d(3, 64, kernel_size3, padding1) self.pool1 nn.MaxPool2d(2, 2) self.conv2 nn.Conv2d(64, 64, kernel_size3, padding1) self.pool2 nn.MaxPool2d(2, 2) self.conv3 nn.Conv2d(64, 128, kernel_size3, padding1) self.pool3 nn.MaxPool2d(2, 2) self.fc1 nn.Linear(128 * 4 * 4, 256) self.fc2 nn.Linear(256, 10) def forward(self, x): x self.pool1(F.relu(self.conv1(x))) x self.pool2(F.relu(self.conv2(x))) x self.pool3(F.relu(self.conv3(x))) x x.view(-1, 128 * 4 * 4) x F.relu(self.fc1(x)) x self.fc2(x) return x3.2 关键设计选择卷积层配置使用3x3小卷积核平衡特征提取能力和参数数量逐步增加通道数(64→64→128)提取更复杂的特征添加padding1保持特征图尺寸池化策略采用2x2最大池化每次将特征图尺寸减半在三个卷积层后都进行池化全连接层第一个全连接层将特征展平并降维到256最终输出10维对应10个类别3.3 参数数量分析使用torchsummary查看模型参数from torchinfo import summary model CIFAR10Model().to(device) summary(model, input_size(batch_size, 3, 32, 32))输出显示总参数约24.6万这对于CIFAR-10任务是一个适中的规模。4. 模型训练与评估4.1 训练配置loss_fn nn.CrossEntropyLoss() optimizer torch.optim.SGD(model.parameters(), lr0.01, momentum0.9) epochs 10选择交叉熵损失函数因为它非常适合多分类问题。优化器使用带动量的SGD初始学习率设为0.01。4.2 训练循环实现def train_epoch(model, train_loader, loss_fn, optimizer): model.train() total_loss, total_correct 0, 0 for X, y in train_loader: X, y X.to(device), y.to(device) optimizer.zero_grad() outputs model(X) loss loss_fn(outputs, y) loss.backward() optimizer.step() total_loss loss.item() total_correct (outputs.argmax(1) y).sum().item() avg_loss total_loss / len(train_loader) accuracy total_correct / len(train_loader.dataset) return accuracy, avg_loss4.3 测试评估实现def evaluate(model, test_loader, loss_fn): model.eval() total_loss, total_correct 0, 0 with torch.no_grad(): for X, y in test_loader: X, y X.to(device), y.to(device) outputs model(X) loss loss_fn(outputs, y) total_loss loss.item() total_correct (outputs.argmax(1) y).sum().item() avg_loss total_loss / len(test_loader) accuracy total_correct / len(test_loader.dataset) return accuracy, avg_loss4.4 完整训练流程train_accs, train_losses [], [] test_accs, test_losses [], [] for epoch in range(epochs): train_acc, train_loss train_epoch(model, train_dl, loss_fn, optimizer) test_acc, test_loss evaluate(model, test_dl, loss_fn) train_accs.append(train_acc) train_losses.append(train_loss) test_accs.append(test_acc) test_losses.append(test_loss) print(fEpoch {epoch1}/{epochs}) print(fTrain Acc: {train_acc:.2%}, Loss: {train_loss:.4f}) print(fTest Acc: {test_acc:.2%}, Loss: {test_loss:.4f}\n)5. 结果分析与改进方向5.1 训练结果经过10个epoch的训练典型结果如下Epoch 1/10 Train Acc: 13.52%, Loss: 2.2834 Test Acc: 20.90%, Loss: 2.1952 Epoch 10/10 Train Acc: 58.20%, Loss: 1.1843 Test Acc: 54.00%, Loss: 1.33705.2 性能可视化import matplotlib.pyplot as plt plt.figure(figsize(12, 4)) plt.subplot(1, 2, 1) plt.plot(range(epochs), train_accs, labelTrain) plt.plot(range(epochs), test_accs, labelTest) plt.title(Accuracy) plt.legend() plt.subplot(1, 2, 2) plt.plot(range(epochs), train_losses, labelTrain) plt.plot(range(epochs), test_losses, labelTest) plt.title(Loss) plt.legend() plt.show()5.3 改进建议数据增强transform_train transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding4), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ])学习率调度scheduler torch.optim.lr_scheduler.StepLR(optimizer, step_size5, gamma0.1)模型优化增加批归一化层尝试更深的网络结构使用ResNet等先进架构正则化技术Dropout权重衰减早停法6. 关键问题与解决方案6.1 过拟合问题现象训练准确率明显高于测试准确率解决方案增加数据增强添加Dropout层使用L2正则化减少模型复杂度6.2 训练不稳定现象损失值波动大解决方案适当减小学习率增加批量大小使用梯度裁剪尝试不同的优化器(如Adam)6.3 类别不平衡现象某些类别准确率明显低于其他解决方案在损失函数中添加类别权重过采样少数类使用Focal Loss在实际项目中我通常会保存多个检查点方便后续分析和模型选择torch.save({ epoch: epoch, model_state_dict: model.state_dict(), optimizer_state_dict: optimizer.state_dict(), loss: loss, }, fcheckpoint_epoch{epoch}.pth)这个基础CNN模型在CIFAR-10上能达到约54%的测试准确率虽然不算很高但完整展示了深度学习项目的工作流程。后续可以通过更复杂的模型架构和训练技巧进一步提升性能。