1. 小样本学习与Prototypical Network基础当你第一次听说小样本学习时可能会觉得这是个遥不可及的高深概念。其实它的核心思想很简单就像人类能通过少量例子快速学习新事物一样让AI模型也具备这种能力。想象一下你给孩子看几张不同品种的鸟的图片他很快就能在野外认出这些鸟。Prototypical Network正是实现这种能力的经典方法之一。在实际应用中小样本学习特别适合那些数据稀缺的场景。比如医疗影像中罕见病症的诊断、工业质检中的缺陷识别或者保护生物学中的濒危物种监测。传统深度学习需要大量标注数据而Prototypical Network只需要每个类别几张图片就能达到不错的效果。Prototypical Network属于度量学习(Metric Learning)的范畴它的核心思路是学习一个特征空间在这个空间中同类样本彼此靠近异类样本相互远离。具体来说它会为每个类别计算一个原型(prototype)也就是该类样本在特征空间中的平均位置。对新样本分类时只需看它离哪个原型最近即可。2. Prototypical Network核心原理拆解2.1 原型计算类别的中心点Prototypical Network最核心的概念就是原型。举个例子假设我们要识别三种鸟类红雀、蓝鸦和黄鹂。对于每个类别我们有几张示例图片支持集。网络会用编码器比如ResNet提取每张图片的特征向量对同一类别的所有特征向量取平均值这个平均值就是该类别的原型用数学公式表示就是# 假设support_features是形状为[n_way, k_shot, feature_dim]的张量 prototypes torch.mean(support_features, dim1) # 得到形状为[n_way, feature_dim]的原型2.2 距离度量如何定义相似得到原型后我们需要度量查询样本与各个原型的相似度。最常用的方法是欧式距离的平方def euclidean_distance(query_features, prototypes): # query_features: [n_query, feature_dim] # prototypes: [n_way, feature_dim] return torch.cdist(query_features, prototypes, p2)**2这个距离越小说明查询样本与该原型越相似。在实际实现时我们通常会先对特征进行L2归一化这样欧式距离就和余弦相似度等价了。2.3 损失函数推动原型分离训练时使用的损失函数是负对数似然Negative Log-Likelihood。对于每个查询样本我们计算它与所有原型的距离用softmax将距离转换为概率分布最小化正确类别的负对数概率代码实现如下def compute_loss(distances, targets): # distances: [n_query, n_way] # targets: [n_query] log_p_y F.log_softmax(-distances, dim1) loss F.nll_loss(log_p_y, targets) return loss3. PyTorch完整实现指南3.1 数据准备与Episode采样小样本学习与传统监督学习最大的不同在于数据组织形式。我们需要实现一个Episode采样器class EpisodeSampler: def __init__(self, dataset, n_way, k_shot, q_query): self.dataset dataset self.n_way n_way self.k_shot k_shot self.q_query q_query self.classes list(set(dataset.targets)) def __iter__(self): while True: # 随机选择n_way个类别 selected_classes random.sample(self.classes, self.n_way) support [] query [] for class_idx in selected_classes: # 获取该类所有样本 samples [i for i, (_, y) in enumerate(self.dataset) if y class_idx] # 随机选择k_shot q_query个样本 selected random.sample(samples, self.k_shot self.q_query) support.extend(selected[:self.k_shot]) query.extend(selected[self.k_shot:]) yield support, query3.2 模型架构设计一个完整的Prototypical Network包含两个主要部分编码器(Encoder)负责将图像映射到特征空间。可以使用预训练的CNNclass Encoder(nn.Module): def __init__(self): super().__init__() self.conv1 nn.Conv2d(3, 64, 3, padding1) self.bn1 nn.BatchNorm2d(64) self.conv2 nn.Conv2d(64, 64, 3, padding1) self.bn2 nn.BatchNorm2d(64) self.conv3 nn.Conv2d(64, 128, 3, padding1) self.bn3 nn.BatchNorm2d(128) self.conv4 nn.Conv2d(128, 128, 3, padding1) self.bn4 nn.BatchNorm2d(128) self.fc nn.Linear(128*8*8, 256) # 假设输入是84x84图像 def forward(self, x): x F.relu(self.bn1(self.conv1(x))) x F.max_pool2d(x, 2) x F.relu(self.bn2(self.conv2(x))) x F.max_pool2d(x, 2) x F.relu(self.bn3(self.conv3(x))) x F.max_pool2d(x, 2) x F.relu(self.bn4(self.conv4(x))) x x.view(x.size(0), -1) x self.fc(x) return x原型网络(PrototypicalNetwork)实现原型计算和分类逻辑class PrototypicalNetwork(nn.Module): def __init__(self, encoder): super().__init__() self.encoder encoder def forward(self, support_x, support_y, query_x): # 提取支持集和查询集特征 support_features self.encoder(support_x) query_features self.encoder(query_x) # 计算每个类别的原型 prototypes [] for class_idx in torch.unique(support_y): mask support_y class_idx class_prototype support_features[mask].mean(dim0) prototypes.append(class_prototype) prototypes torch.stack(prototypes) # 计算查询样本与各原型的距离 distances torch.cdist(query_features, prototypes, p2)**2 return distances3.3 训练循环实现训练过程需要特别注意episode的组织方式def train(model, optimizer, sampler, device, epochs100): model.train() for epoch in range(epochs): total_loss 0 correct 0 total 0 # 每个epoch使用固定数量的episode for _ in range(100): support_idx, query_idx next(sampler) support_x torch.stack([dataset[i][0] for i in support_idx]).to(device) support_y torch.tensor([dataset[i][1] for i in support_idx]).to(device) query_x torch.stack([dataset[i][0] for i in query_idx]).to(device) query_y torch.tensor([dataset[i][1] for i in query_idx]).to(device) optimizer.zero_grad() # 前向传播 distances model(support_x, support_y, query_x) # 计算损失 loss F.cross_entropy(-distances, query_y) # 反向传播 loss.backward() optimizer.step() # 统计准确率 _, predicted torch.min(distances, 1) correct (predicted query_y).sum().item() total query_y.size(0) total_loss loss.item() print(fEpoch {epoch1}, Loss: {total_loss/100:.4f}, Acc: {correct/total:.4f})4. 实战技巧与性能优化4.1 特征编码器的选择编码器的选择对性能影响巨大。在实践中我发现浅层网络对于简单数据集如Omniglot3-4层CNN就足够ResNet变体对于较复杂数据集如miniImageNetResNet-12或ResNet-18效果更好预训练模型如果领域相近使用在ImageNet预训练的模型能显著提升性能一个实用的技巧是在编码器最后加入可学习的缩放层class ScaledEncoder(nn.Module): def __init__(self, base_encoder): super().__init__() self.encoder base_encoder self.scale nn.Parameter(torch.tensor(1.0)) def forward(self, x): features self.encoder(x) return features * self.scale4.2 训练策略优化学习率调度使用余弦退火学习率能稳定训练optimizer torch.optim.Adam(model.parameters(), lr1e-3) scheduler torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max100)特征归一化对编码特征进行L2归一化能提升距离度量的效果def forward(self, x): features self.encoder(x) return F.normalize(features, p2, dim1)数据增强适当的数据增强能显著提升小样本学习性能train_transform transforms.Compose([ transforms.RandomResizedCrop(84), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness0.4, contrast0.4, saturation0.4), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ])4.3 处理类别不平衡问题在实际应用中不同类别的样本数量可能差异很大。我们可以通过以下方式改进加权采样在构造episode时对样本少的类别增加采样概率class_weight 1.0 / np.bincount(dataset.targets) sample_prob class_weight[dataset.targets] sample_prob / sample_prob.sum()原型修正对样本少的类别使用更大的温度系数class PrototypicalNetworkWithTemperature(nn.Module): def __init__(self, encoder, class_counts): super().__init__() self.encoder encoder self.temperature nn.Parameter(torch.ones(len(class_counts))) def forward(self, support_x, support_y, query_x): # ... 计算原型和距离 ... scaled_distances distances / self.temperature[support_y] return scaled_distances5. 进阶应用与扩展思路5.1 跨域小样本学习当训练数据和测试数据来自不同领域时比如用自然图像训练用于医学图像分类常规Prototypical Network性能会下降。解决方法包括特征解耦将特征空间分为领域共享部分和领域特定部分class DomainDisentangle(nn.Module): def __init__(self, encoder): super().__init__() self.encoder encoder self.domain_proj nn.Linear(256, 128) self.class_proj nn.Linear(256, 128) def forward(self, x): features self.encoder(x) domain_feat self.domain_proj(features) class_feat self.class_proj(features) return torch.cat([domain_feat, class_feat], dim1)元迁移学习先在多个源域上进行元训练再在目标域上微调5.2 半监督Prototypical Network当支持集中部分样本没有标签时可以采用半监督方法标签传播基于特征相似度传播标签def label_propagation(features, labeled_idx, labels, k5): # 计算所有样本间的相似度 sim_matrix torch.mm(features, features.t()) # 对每个未标注样本用k近邻的标签加权平均 # ... 实现细节省略 ... return propagated_labels一致性正则对同一图像的不同增强版本强制预测一致5.3 与其他方法的结合Prototypical Network可以与其他小样本学习方法结合与MAML结合先用MAML进行参数初始化再用Prototypical Network进行分类与Transformer结合用Transformer编码器替代CNN利用自注意力机制class TransformerEncoder(nn.Module): def __init__(self): super().__init__() self.patch_embed nn.Conv2d(3, 128, kernel_size16, stride16) self.transformer nn.TransformerEncoderLayer(d_model128, nhead8) def forward(self, x): patches self.patch_embed(x).flatten(2).transpose(1, 2) features self.transformer(patches) return features.mean(dim1)在实际项目中我通常会先用标准的Prototypical Network建立baseline然后根据具体问题和数据特点逐步引入这些进阶技术。记住模型复杂度增加的同时过拟合风险也会上升特别是在小样本场景下。