Swin Transformer与拉普拉斯近似:宇宙学参数推断中的不确定性量化实践
1. 项目概述当Swin Transformer遇见拉普拉斯近似在宇宙学参数推断这个领域里我们一直在和数据中的“不确定性”作斗争。无论是来自宇宙微波背景辐射CMB的观测还是来自21厘米信号的模拟数据本身都充满了噪声和非高斯结构。传统的统计方法比如马尔可夫链蒙特卡洛MCMC虽然能给出漂亮的后验分布但计算成本高得吓人面对高维数据和复杂模型时常常力不从心。最近几年深度学习尤其是视觉Transformer架构在特征提取方面展现了惊人的能力但一个核心问题始终悬而未决我们如何相信一个“黑箱”模型给出的点估计它的预测到底有多可靠这正是不确定性量化Uncertainty Quantification, UQ要解决的问题。它不只是给个误差棒那么简单而是要为模型的每一个预测提供一个完整的概率分布描述告诉我们“这个预测在多大程度上是可信的”。在科学计算中没有可靠的不确定性估计再精确的点预测也像是空中楼阁。拉普拉斯近似Laplace Approximation, LA作为一种经典的贝叶斯近似方法因其在深度神经网络上的可扩展性和相对较低的计算开销重新回到了研究者的视野。它本质上是在模型参数的最大后验估计点附近用一个高斯分布来近似真实的后验分布从而为我们提供预测均值和方差。我最近的工作就是尝试将这两者结合起来解决一个具体的宇宙学问题从模拟的动力学Sunyaev-Zel‘dovich效应kSZ地图中推断宇宙再电离时期的光学深度τ。这个参数至关重要它描述了光子从最后散射面传播到今天被自由电子散射的概率直接关联到宇宙第一代恒星和星系何时形成。我们采用了Swin Transformer作为特征提取的骨干网络因为它能高效捕捉图像中的长程依赖关系和层次化特征非常适合kSZ地图中复杂的、非高斯的电离气泡结构。核心的对比实验在于应用拉普拉斯近似的两种策略后验拉普拉斯近似和在线拉普拉斯近似。前者是在一个已经训练好的、性能最优的点估计模型最大后验估计MAP上“套用”拉普拉斯近似后者则是将网络权重和拉普拉斯近似的超参数如先验精度进行联合优化。这篇文章我就来详细拆解这个项目的完整流程从为什么选择这个技术栈到数据如何准备、模型如何训练、两种拉普拉斯近似具体怎么实现再到最后的性能对比和物理可解释性分析。我会分享在实现过程中踩过的坑、调参的经验以及为什么最终后验方法会胜出。无论你是对宇宙学中的机器学习应用感兴趣还是想在自己的领域里实践不确定性量化希望这篇深度复盘能给你带来一些直接的参考。2. 核心思路与技术选型解析2.1 问题定义从kSZ地图到光学深度τ首先我们得明确要解决的是一个什么样的回归问题。输入是一张模拟的kSZ效应温度涨落地图∆T/µK输出是一个标量值光学深度τ。kSZ效应是CMB光子穿过宇宙大尺度结构中的电离气体时由于电子团的整体运动体速度而产生的二次多普勒频移。这个效应留下的印记直接编码了再电离时期电离气体的分布和速度场信息。因此kSZ地图本质上是再电离时期宇宙结构的“指纹”。然而这个指纹非常微弱且淹没在原始CMB起伏、其他次级效应如热SZ效应以及前景噪声之中。即使在我们使用的理想化模拟中暂不考虑仪器噪声和天体物理前景从这种高度非高斯、结构复杂的二维图像中直接回归出一个全局物理参数也是一个极具挑战性的任务。它要求模型不仅能识别局部特征如电离气泡的边缘还要能理解这些特征的空间分布和统计规律所蕴含的全局信息。2.2 骨干网络为什么是Swin Transformer在骨干网络的选择上我们放弃了传统的卷积神经网络CNN而选择了Swin Transformer。这个决定基于几个关键的考量长程依赖建模kSZ地图中的电离气泡尺度不一且气泡间的关联可能蕴含重要信息。CNN的感受野受限于卷积核大小即使通过堆叠层数来扩大其捕捉长程依赖的效率也不及基于自注意力机制的Transformer。Swin Transformer通过引入移位窗口Shifted Windows机制在计算效率和建模能力之间取得了绝佳的平衡。它允许跨窗口的信息交互从而能够有效建模图像中任意两个像素点之间的关系这对于理解气泡的整体拓扑结构至关重要。层次化特征提取Swin Transformer的架构天然是层次化的包含多个“阶段”Stage每个阶段会进行patch merging来降低分辨率、增加通道数。这类似于CNN的下采样过程但底层是基于自注意力的特征变换。这种设计使得网络能够同时学习到低层次的边缘、纹理特征对应电离前沿和高层次的语义、结构特征对应气泡的分布模式非常适合我们的多尺度物理问题。对非高斯特征的敏感性kSZ信号是非高斯的其统计特性与传统天体物理图像不同。Transformer的自注意力机制本质上是在计算所有特征之间的相关性权重它不依赖于像CNN那样预设的平移不变性先验尽管kSZ地图在一定程度上具有统计均匀性因此可能更灵活地适应数据本身的复杂统计特性。实操心得在项目初期我们也尝试过ResNet、DenseNet等经典CNN架构。它们确实也能工作但验证集上的损失平台期来得更早且最终性能有约5%-10%的差距。切换到Swin Transformer后最直观的感受是模型收敛更快并且对学习率等超参数不那么敏感鲁棒性更好。当然Transformer类模型对计算资源的需求也更高这是需要权衡的。2.3 不确定性量化拉普拉斯近似的两种玩法确定了特征提取器接下来就是核心的不确定性量化模块——拉普拉斯近似。其核心思想非常直观对于一个训练好的神经网络我们找到了其参数θ的最大后验估计MAP。在这个最优参数点θ_MAP附近我们对对数后验概率log p(θ|D)进行二阶泰勒展开。由于在极值点一阶导数为零展开式主要取决于海森矩阵Hessian MatrixH。这样复杂的真实后验分布p(θ|D)就被近似为一个高斯分布N(θ_MAP, H^{-1})。然而对于大型神经网络计算和存储完整的海森矩阵是天文数字级别的开销。因此实践中我们使用高斯-牛顿矩阵Gauss-Newton Matrix或费雪信息矩阵Fisher Information Matrix来近似海森矩阵。更重要的是我们只对最后一层或最后几层的参数应用拉普拉斯近似这被称为“拉普拉斯最后一层”Laplace Last Layer, LLL方法在精度和计算成本之间是一个很好的折中。我们的对比聚焦于应用这个近似的时机和策略后验拉普拉斯近似这是“两步走”策略。第一步我们像训练一个普通的回归网络一样使用大量的超参数搜索学习率、权重衰减、优化器等找到一个在验证集上损失最低的、性能最优的MAP模型。这个模型本身已经是一个强大的点估计预测器。第二步在这个训练好的、固定的MAP模型上“嫁接”拉普拉斯近似。我们冻结所有网络权重只优化拉普拉斯近似相关的超参数比如先验精度prior precision来拟合后验的高斯近似。在线拉普拉斯近似这是“一步到位”策略。我们不先找一个最优的MAP模型而是将网络权重θ和拉普拉斯超参数同样比如先验精度视为需要联合优化的变量。训练目标不再是简单的负对数似然而是包含了先验项的边际似然或证据下界ELBO的某种近似。这更像是一个完整的虽然是近似的贝叶斯推理过程。关键选择解析为什么设计这个对比后验方法代表了当前实践中更常见的思路先利用成熟的深度学习技巧得到一个好模型再为其“添加”不确定性估计。它逻辑清晰易于实现和调试。在线方法则更具贝叶斯纯粹性理论上可能找到权重和不确定性之间的最优平衡点。我们的实验旨在回答一个实际问题在这种复杂的科学回归任务中为了获得“校准良好”的不确定性我们是否需要付出联合优化的额外代价和复杂度还是说一个精心调优的点估计模型加上后验近似就足够了3. 数据准备与模型实现细节3.1 模拟数据生成与预处理我们的数据来源于宇宙学数值模拟。使用诸如21cmFAST或CROC等模拟代码生成不同再电离历史下的宇宙体积并从中提取出对应的kSZ温度涨落二维投影图。每个模拟样本对应一个真实的全局光学深度τ值τ的范围大约在0.05到0.07之间这与当前观测约束如Planck卫星结果相符。数据生成流程输入一套宇宙学参数如物质密度参数Ω_m 哈勃常数h等和再电离参数化模型。模拟过程运行模拟得到三维的电子密度和速度场分布。视线积分沿某条视线方向对电子密度和速度场的乘积进行积分生成二维的kSZ温度涨落图∆T map。标签生成同一个模拟体积通过积分电子密度场计算出该视线方向上的总光学深度τ作为回归目标。关键预处理步骤归一化对每张kSZ地图我们进行全局的标准化即减去所有像素的均值除以所有像素的标准差。这有助于稳定训练。需要注意的是我们是对整个训练集计算统一的均值和标准差而不是每张图单独归一化以保持样本间的一致性。数据增强为了提升模型的泛化能力我们对训练数据施加了简单的增强包括随机水平/垂直翻转和随机90度旋转。由于kSZ信号在统计上是各向同性的这些增强是物理合理的。我们没有使用裁剪或颜色抖动因为这会破坏信号的全局统计特性。数据集划分我们生成了约10,000个独立样本按70%/15%/15%的比例划分为训练集、验证集和测试集。确保划分是随机的且来自同一套模拟参数的样本不会同时出现在训练和测试集中以避免数据泄露。避坑指南数据归一化的方式至关重要。早期我们尝试过每张图单独做“减均值除标准差”结果模型完全无法收敛。后来意识到kSZ信号的绝对值大小本身就携带了τ的信息总体而言τ越大信号幅度可能越强每张图单独归一化会抹掉这个关键特征。改用全局归一化后训练立刻稳定下来。3.2 Swin Transformer模型架构与训练我们基于PyTorch和timm库实现了Swin Transformer Tiny版本。这个版本在计算效率和表达能力之间取得了很好的平衡。模型配置输入单通道的kSZ灰度图分辨率调整为224x224像素。Patch Partition将图像划分为4x4的非重叠块每个块通过线性嵌入投影到96维。Stage 1-4包含4个阶段逐步下采样通道数从96增加到768。每个阶段由多个Swin Transformer Block堆叠而成核心是基于窗口的多头自注意力和移位窗口多头自注意力的交替。头部在最终的特征图7x7x768上执行全局平均池化得到一个768维的特征向量然后通过一个简单的线性层映射到1维输出即预测的τ值。训练超参数与技巧优化器使用AdamW优化器其内置的权重衰减对于Transformer类模型防止过拟合非常有效。学习率采用余弦退火学习率调度初始学习率设为5e-5。对于Swin Transformer较低的学习率通常更稳定。损失函数均方误差损失。这是一个标准回归任务。早停我们设置了严格的早停策略监控验证集损失如果连续15个epoch没有下降则停止训练并回滚到验证损失最低的模型 checkpoint。这是防止过拟合、确保获得高质量MAP模型的关键。从结果图中的训练曲线可以看到验证损失在平稳后早停机制及时触发。超参数搜索对于后验方法中的MAP模型训练我们使用了贝叶斯优化工具如Optuna对学习率、权重衰减系数、Dropout率等进行了超过100次的搜索以找到验证损失最低的配置。3.3 拉普拉斯近似的具体实现我们使用了laplace这个PyTorch库来实现拉普拉斯近似它提供了非常清晰的接口。后验拉普拉斯近似实现步骤# 伪代码示例 # 1. 加载训练好的最优MAP模型 map_model load_best_swin_model(‘checkpoints/best_map_model.pth’) map_model.eval() # 2. 定义拉普拉斯近似器这里我们选择只对最后一层分类头应用 from laplace import Laplace la Laplace(modelmap_model, likelihood‘regression’, subset_of_weights‘last_layer’, # 关键只近似最后一层 hessian_structure‘kron’) # 使用Kronecker因子化的费雪信息矩阵节省内存 # 3. 用训练数据拟合拉普拉斯后验 # 这里需要传入训练数据加载器 la.fit(train_loader) # 4. 设置先验精度正则化强度。这是一个超参数可以通过验证集优化。 # 例如在验证集上优化边际似然 la.optimize_prior_precision(method‘marglik’, val_loaderval_loader) # 5. 预测时可以得到预测均值和方差 predictive_mean, predictive_var la(x_test) predictive_std predictive_var.sqrt() # 这就是我们需要的误差条在线拉普拉斯近似实现步骤 在线方法的实现更为复杂因为需要修改训练循环。核心思想是在训练损失中加入一个基于当前拉普拉斯后验的正则化项或直接优化边际似然的近似。# 伪代码概念 # 在线训练循环概览 for epoch in epochs: for x_batch, y_batch in train_loader: # 前向传播 output model(x_batch) # 计算负对数似然损失 nll_loss F.mse_loss(output, y_batch) # 关键计算基于当前参数θ的拉普拉斯对数后验 # 这需要计算 log p(θ|D) ∝ log p(D|θ) log p(θ) # log p(θ) 是高斯先验的对数概率 log_prior -0.5 * prior_precision * (model.last_layer.weight ** 2).sum() # 在线目标最大化边际似然 p(D) 的近似等价于最小化负的 log p(θ|D) online_loss nll_loss - log_prior / num_data # 注意符号和归一化 optimizer.zero_grad() online_loss.backward() optimizer.step() # 在线更新拉普拉斯近似的超参数如先验精度 # 这通常需要在每个epoch或每个batch后用验证集进行一步优化 if epoch % 10 0: current_la.optimize_prior_precision(val_loader)实际上我们采用了laplace库中提供的OnlineLaplace类它封装了这部分逻辑但需要仔细调整学习率和先验精度的优化频率。实现难点在线方法最大的挑战是训练的稳定性。网络权重和先验精度在同时更新很容易导致优化过程震荡甚至发散。我们花了大量时间调整学习率调度通常需要更小的学习率和更慢的退火以及先验精度优化的频率不能太频繁。相比之下后验方法因为将两步解耦每一步都可以独立地精细调优稳定得多。4. 实验结果深度分析与对比4.1 预测性能与不确定性校准的量化评估在完成了超参数优化并选出各自的最佳模型后我们在独立的测试集上进行了全面的评估。评估指标分为两大类预测精度指标和不确定性校准指标。预测精度指标平均绝对误差衡量预测误差的平均幅度。后验LA为0.0012在线LA为0.0017。后验方法误差更小。均方根误差对大的预测误差更敏感。后验LA为0.0015在线LA为0.0021。同样后验方法更优。R²分数解释方差的比例越接近1越好。后验LA达到了0.93意味着模型能够解释数据中93%的方差而在线LA为0.86。这是一个非常显著的差距。皮尔逊相关系数衡量预测值与真实值的线性相关程度。后验LA为0.96在线LA为0.93都表现出很强的相关性但后验方法更优。不确定性校准指标卡方统计量这是评估不确定性的核心指标。对于一个完美校准的模型每个预测的标准化残差(y_true - y_pred) / σ_pred应服从标准正态分布。因此所有测试样本的标准化残差平方和即χ²值应近似等于测试集的数据点数量N_test。我们的测试集有约100个样本。后验LA的χ² 59.27在线LA的χ² 42.45两者都偏离了期望值100。后验LA的χ²值更接近N_test这表明它的不确定性估计虽然仍欠校准倾向于低估不确定性因为χ² N但比在线LA更接近理想情况。在线LA的χ²值更小说明它可能过度估计了不确定性给出的误差棒过大。结果解读 从量化指标看后验拉普拉斯近似在几乎所有预测精度指标上都显著优于在线方法。更重要的是在不确定性校准方面后验方法也表现更好。这个结果有点反直觉因为在线方法理论上更“贝叶斯”。我们的分析是后验方法受益于一个高度优化的点估计模型。我们通过大量的超参数搜索为Swin Transformer找到了一个在MAP估计上表现极佳的配置。在这个坚实的基础上应用拉普拉斯近似相当于在一个“正确”的点附近进行高斯近似。而在线方法的联合优化过程可能陷入了一个局部最优解这个解在边际似然上可能不错但网络权重本身并没有达到最好的预测性能。换句话说在线方法为了“平衡”权重和不确定性可能牺牲了点估计的精度。4.2 可视化分析散点图与误差条数字是冰冷的图表更能说明问题。我们绘制了预测值与真实值的散点图。左图无误差条可以清晰看到后验LA图a的点几乎都紧密分布在yx的对角线两侧形成一条很窄的带状区域。而在线LA图c的点则分散得多尤其是在τ值两端出现了更多偏离对角线的点。这直观地印证了R²分数和相关系数的差异。右图带误差条我们为每个预测点画出了±1σ的误差棒。理想情况下大约68%的数据点其误差棒应穿过对角线。在后验LA的图b中误差棒普遍较短且大多数能覆盖真实值。在线LA的图d中误差棒明显更长这与其较小的χ²值过度估计不确定性相符。虽然长的误差棒更“安全”但它们在科学上提供的信息量更少不够精确。图表解读心得看这种回归任务的散点图我习惯先看点的“紧致度”再看误差棒的“覆盖率”。后验LA的图呈现出我们最希望看到的样子点准、误差棒窄且覆盖好。在线LA的图则给人一种“模型自己也不太确定”的感觉。在科学应用中我们当然需要保守的不确定性估计但前提是点估计本身要足够准确。一个不准但很“自信”的模型是危险的而一个不准且“犹豫”的模型实用价值也会打折扣。4.3 物理可解释性模型到底学到了什么对于一个用于科学发现的模型我们不能只相信它的输出还必须理解它做出决策的依据。我们使用了积分梯度结合噪声隧道的方法来生成显著性图以可视化输入kSZ地图中哪些区域对模型的预测贡献最大。分析方法积分梯度计算输入像素相对于输出τ的梯度。为了得到更稳健的归因我们从一张基线图像如全黑图像到真实输入图像之间构造一条路径对路径上的梯度进行积分。噪声隧道为了平滑掉梯度中的高频噪声我们生成多张如50张在输入图像上添加了随机高斯噪声的样本分别计算它们的积分梯度然后取平均。这能得到更干净、更稳定的显著性图。关键发现 从生成的显著性图如原文图7中我们观察到两个一致且令人振奋的模式电离前沿显著性最高的区域图中最亮的区域清晰地与kSZ地图中电离气泡的边界重合。这些边界是中性氢和电离氢的过渡区电子密度梯度最大因此对积分光学深度τ的贡献也最大。模型准确地抓住了这个最关键的物理特征。速度相干结构模型并非只关注孤立的像素点而是会高亮一些延伸的、具有相干性的结构。这对应于电离气体的整体速度场。kSZ效应依赖于视线方向上的速度分量因此具有相干速度的大片区域也会对信号产生显著贡献。结论Swin Transformer并没有去学习数据中的一些虚假统计关联或噪声模式而是真正学会了识别与光学深度τ物理上相关的形态学特征——电离气泡的大小、分布以及边界。这极大地增强了我们对该模型预测结果的物理信任度。5. 讨论、局限与未来展望5.1 为什么后验方法赢了—— 优化景观的视角我们的核心结论是对于此类复杂回归任务先集中精力找到一个最优的点估计模型再为其添加不确定性量化是更有效且更稳健的策略。这背后的原因可以从优化景观来理解。深度神经网络的损失函数通常非常复杂存在大量局部极小值。超参数搜索如贝叶斯优化的目标就是找到那个能通向“更优”局部极小值的配置。后验方法将这个搜索过程专注于最小化预测损失这是一个相对明确且易于评估的目标。而在线方法的目标函数是边际似然它同时权衡了数据拟合优度和模型复杂度。在高度非凸的景观中联合优化网络权重和先验精度可能引导优化器走向一个在边际似然意义上不错、但网络权重本身并非最优预测器的区域。换句话说它可能找到了一个“简单”但“平庸”的模型其预测精度一般但因其简单而获得了较高的边际似然奥卡姆剃刀原理。在我们的任务中预测精度是首要目标因此后验方法的策略更胜一筹。个人体会这有点像装修房子。后验方法是先请最好的设计师和施工队把房子的主体结构和功能做到极致MAP模型然后再请专业的监理来评估每个部分的质量和潜在风险拉普拉斯近似。在线方法则是设计师和监理从一开始就一起工作互相妥协可能最后监理报告很漂亮不确定性校准好但房子住起来未必最舒服预测精度稍差。在科研中我们往往更需要那个“住起来舒服”的房子。5.2 当前工作的局限性与挑战理想化模拟我们使用的是洁净的、无噪声的模拟kSZ地图。真实的观测数据将包含复杂的仪器噪声、来自银河系和河外源的前景辐射、以及其他宇宙学信号的污染如原始CMB、热SZ效应、引力透镜效应等。模型能否在如此嘈杂的环境中保持性能是下一个巨大的挑战。拉普拉斯近似的局限性拉普拉斯近似本质上是局部的高斯近似。它假设真实后验分布在MAP点附近是单峰的、近似高斯的。如果后验分布是多峰的或高度非高斯这个近似就会失效。对于极度非凸的神经网络损失函数这个假设不一定总是成立。计算成本虽然拉普拉斯近似比MCMC快得多但计算高斯-牛顿矩阵的逆即使是Kronecker因子化形式对于超大型模型来说仍然有压力。我们只对最后一层应用这是一个权衡。如果需要对所有参数进行不确定性量化计算和存储开销会急剧上升。先验的选择拉普拉斯近似的结果对先验分布特别是先验精度很敏感。我们通过验证集优化了这个超参数但这仍然是一个点估计。一个更贝叶斯的方式是对先验也设置超先验但这会进一步增加复杂性。5.3 未来可行的改进方向面向真实数据的预处理未来的工作必须引入真实的噪声和前景。一种思路是使用维纳滤波等技术从CMB总图中分离出再电离时期的kSZ信号。但这也带来了新问题比如滤波过程会改变信号的统计特性以及如何分离晚期kSZ的污染。我们需要在模拟中就从端到端地模拟这些过程让模型学会在噪声中提取信号。探索其他UQ方法虽然拉普拉斯近似在这里表现不错但值得对比其他方法如蒙特卡洛Dropout、深度集成以及更近似的变分推理。深度集成训练多个独立模型通常能给出很好的不确定性估计但成本是N倍的计算量。我们需要在精度、校准度和计算成本之间做更系统的基准测试。应用到实际观测数据最终目标是应用于如西蒙斯天文台和CMB-S4等下一代CMB巡天项目的数据。这需要将我们的管道与数据处理流程整合并仔细评估系统误差。扩展到其他宇宙学参数本框架不应局限于τ。理论上它可以扩展到同时推断多个再电离参数如再电离中点红移、持续时间等甚至与其他探针如21厘米观测进行联合分析。5.4 给实践者的建议如果你正在自己的科学或工程项目中考虑引入不确定性量化基于这次项目的经验我建议从后验拉普拉斯近似开始它实现简单易于集成到现有训练流程中并且能在一个好的点估计模型上提供快速且通常不错的不确定性估计。这是一个非常实用的基线方法。不要忽视点估计模型的质量不确定性估计的质量高度依赖于底层点估计模型的质量。花时间做好超参数调优、使用早停、确保模型不过拟合这比纠结用哪种UQ方法更重要。务必进行校准评估像计算χ²统计量、绘制可靠性曲线这样的校准检查必不可少。一个给出不确定性但未校准的模型其输出可能具有误导性。结合可解释性工具特别是当你的模型用于科学发现时使用如积分梯度、遮挡测试等方法去理解模型的决策依据。这不仅能增加可信度还可能帮助你发现数据或模型中的新问题。这个项目让我深刻体会到将前沿的深度学习架构与严谨的贝叶斯思想结合确实能为宇宙学这类数据复杂、理论驱动的研究领域开辟新的道路。它不仅仅是得到一个预测值和一个误差棒更是搭建了一座连接数据驱动模型与物理理解的桥梁。虽然前路还有诸多挑战但每一步扎实的对比和验证都让我们离更可靠地解读宇宙的奥秘更近了一点。