PyTorch实战5步搞定监督对比学习SupCon损失函数实现监督对比学习Supervised Contrastive Learning作为对比学习在监督场景下的扩展正在计算机视觉、自然语言处理等领域展现出强大的特征提取能力。与传统的交叉熵损失相比SupCon通过显式拉近同类样本、推远异类样本的特征表示能够学习到更具判别性的嵌入空间。本文将聚焦PyTorch实现用5个关键步骤带你从零实现SupCon损失函数。1. 理解监督对比学习的核心思想SupCon的核心创新在于巧妙利用了监督信息来定义正负样本。假设我们有一个batch中包含N个样本每个样本经过两次不同的数据增强得到两个视图views那么正样本与锚样本anchor类别相同的所有样本包括不同视图负样本与锚样本类别不同的所有样本这种定义方式比自监督对比学习更直接地利用了标签信息。从数学上看SupCon损失函数可以表示为$$ \mathcal{L}{sup} \sum{i\in I}\frac{-1}{|P(i)|}\sum_{p\in P(i)}\log \frac{\exp(z_i \cdot z_p/\tau)}{\sum_{a\in A(i)}\exp(z_i \cdot z_a/\tau)} $$其中$P(i)$是与样本$i$同类的正样本集合$A(i)$是除$i$本身外的所有样本集合$\tau$是温度系数控制分布的尖锐程度温度系数$\tau$的选择很关键值太大会导致所有样本相似度趋同太小则会使模型难以收敛。实践中通常设置在0.05到0.2之间。2. 准备数据与特征编码器在实现损失函数前我们需要准备数据加载器和特征编码器。这里以CIFAR-10为例import torch import torchvision from torch import nn # 数据增强策略 train_transform torchvision.transforms.Compose([ torchvision.transforms.RandomResizedCrop(32), torchvision.transforms.RandomHorizontalFlip(), torchvision.transforms.ColorJitter(0.4, 0.4, 0.4, 0.1), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( mean[0.4914, 0.4822, 0.4465], std[0.2023, 0.1994, 0.2010]) ]) # 加载CIFAR-10数据集 train_dataset torchvision.datasets.CIFAR10( root./data, trainTrue, transformtrain_transform, downloadTrue) # 简单的CNN编码器 class Encoder(nn.Module): def __init__(self): super().__init__() self.net nn.Sequential( nn.Conv2d(3, 32, 3, padding1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(32, 64, 3, padding1), nn.ReLU(), nn.MaxPool2d(2), nn.Flatten(), nn.Linear(64*8*8, 128) ) def forward(self, x): return self.net(x)关键点在于数据增强策略的选择。SupCon的性能很大程度上依赖于使用的数据增强组合常见的包括随机裁剪和大小调整颜色抖动亮度、对比度、饱和度、色调随机灰度化高斯模糊3. 实现SupCon损失函数现在我们来逐步实现SupCon损失函数。完整的实现需要考虑以下几个技术细节相似度矩阵的高效计算正负样本掩码mask的构建数值稳定性处理多视图支持class SupConLoss(nn.Module): def __init__(self, temperature0.07, contrast_modeall): super().__init__() self.temperature temperature self.contrast_mode contrast_mode def forward(self, features, labelsNone): device features.device # 特征维度处理 if len(features.shape) 3: raise ValueError(特征需要是[bsz, n_views, ...]格式) batch_size features.shape[0] # 构建标签掩码 labels labels.view(-1, 1) mask torch.eq(labels, labels.T).float().to(device) # 处理多视图特征 contrast_feature torch.cat(torch.unbind(features, dim1), dim0) if self.contrast_mode all: anchor_feature contrast_feature anchor_count features.shape[1] else: raise ValueError(不支持的对比模式) # 计算相似度矩阵 anchor_dot_contrast torch.matmul( anchor_feature, contrast_feature.T) / self.temperature # 数值稳定性处理 logits_max, _ torch.max(anchor_dot_contrast, dim1, keepdimTrue) logits anchor_dot_contrast - logits_max.detach() # 构建排除自身的掩码 logits_mask torch.scatter( torch.ones_like(mask), 1, torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 0 ) mask mask.repeat(anchor_count, anchor_count) * logits_mask # 计算log概率 exp_logits torch.exp(logits) * logits_mask log_prob logits - torch.log(exp_logits.sum(1, keepdimTrue)) # 计算正样本的平均log概率 mean_log_prob_pos (mask * log_prob).sum(1) / mask.sum(1) # 最终损失 loss -mean_log_prob_pos.mean() return loss实现细节使用torch.scatter构建的对角线掩码可以高效地排除样本与自身的对比这是对比学习中常见的技巧。4. 训练流程与技巧有了损失函数后我们需要设计完整的训练流程。以下是关键训练步骤前向传播对每个样本生成两个增强视图特征提取通过编码器获取特征表示损失计算使用SupCon损失函数反向传播更新模型参数def train_one_epoch(model, loss_fn, optimizer, loader, device): model.train() total_loss 0 for images, labels in loader: images torch.cat([images[0], images[1]], dim0).to(device) labels labels.to(device) # 获取特征 features model(images) features features.view(2, -1, features.size(-1)).permute(1, 0, 2) # 计算损失 loss loss_fn(features, labels) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() total_loss loss.item() return total_loss / len(loader)训练SupCon模型时有几个实用技巧值得注意学习率预热前几个epoch使用较小的学习率然后逐步增大大batch size对比学习通常需要较大的batch size256以上以获得足够的负样本投影头在编码器后添加一个小型MLP如两层感知机作为投影头可以提升性能5. 评估与应用训练完成后我们可以评估学习到的特征表示质量。常见评估方式包括线性评估协议冻结特征提取器只训练线性分类器最近邻分类在特征空间中使用k-NN分类可视化分析使用t-SNE或UMAP降维可视化特征分布def evaluate(model, test_loader, device): model.eval() total_correct 0 with torch.no_grad(): for images, labels in test_loader: images images.to(device) labels labels.to(device) features model(images) preds features.argmax(dim1) total_correct (preds labels).sum().item() return total_correct / len(test_loader.dataset)在实际项目中SupCon学习到的特征可以用于少样本学习Few-shot Learning迁移学习任务数据增强效果有限的场景需要鲁棒特征表示的应用监督对比学习的优势在于它结合了监督信号的明确性和对比学习的表征能力。在实践中我发现合理调整温度系数和选择合适的投影头结构对最终性能影响很大。对于计算资源有限的场景可以考虑使用memory bank等技术来增加有效的负样本数量。