013、HAN分层注意力:跨层交互与全局上下文融合的代码剖析
013、HAN分层注意力跨层交互与全局上下文融合的代码剖析从一次诡异的PSNR下降说起去年做视频超分项目时我在HANHierarchical Attention Network上栽了个跟头。模型训练到第80个epochPSNR突然从32.1掉到31.6然后死活上不去。排查了三天最后发现是跨层注意力融合时某个张量的维度索引写反了——这种bug在论文里根本不会提但实际跑起来就是会悄无声息地吃掉你的性能。HAN的核心思想其实很朴素低层特征包含细节纹理高层特征包含语义结构但大多数超分网络只是简单地把它们拼在一起或者加在一起。HAN想干的是让不同层之间“对话”——低层特征在生成时应该知道高层想要什么高层特征在细化时也应该参考低层的细节。这个想法在2019年提出时确实惊艳但实现起来坑不少。分层注意力模块的代码解剖先看HAN最核心的模块——分层注意力融合Hierarchical Attention Fusion。我直接贴出关键代码注释里写满了踩坑记录classHierarchicalAttention(nn.Module):def__init__(self,channels64,reduction16):super().__init__()# 注意这里reduction不能设太小否则参数量爆炸# 我试过reduction4显存直接飙到12Gself.conv_lownn.Conv2d(channels,channels//reduction,1)self.conv_highnn.Conv2d(channels,channels//reduction,1)self.conv_fusionnn.Conv2d(channels//reduction,channels,1)self.sigmoidnn.Sigmoid()defforward(self,low_feat,high_feat):# low_feat: 浅层特征high_feat: 深层特征# 这里踩过坑两个特征的spatial size必须一致# 如果来自不同分辨率的层记得先插值对齐# 生成低层对高层的注意力权重low_attnself.conv_low(low_feat)high_attnself.conv_high(high_feat)# 别这样写直接相加然后sigmoid# 应该先做element-wise乘法让低层和高层特征交互attnlow_attn*high_attn# 跨层交互的关键attnself.conv_fusion(attn)attnself.sigmoid(attn)# 用注意力加权融合fusedlow_feat*attnhigh_feat*(1-attn)returnfused这段代码看起来简单但实际调试时我犯过一个低级错误把low_attn * high_attn写成了low_attn high_attn。结果模型训练时loss下降很快但PSNR始终上不去。后来可视化注意力图才发现加法让两个特征互相抵消了注意力权重几乎全是0.5等于没做任何选择。全局上下文融合的陷阱HAN的另一个亮点是全局上下文融合Global Context Fusion。它想解决的是局部注意力只能看到patch内的信息但超分任务需要全局的纹理一致性。比如重建人脸时左眼和右眼应该对称这就需要全局上下文。看这个实现classGlobalContextFusion(nn.Module):def__init__(self,channels64):super().__init__()# 这里用1x1卷积做全局特征提取别用全连接层# 全连接层会破坏空间结构导致棋盘伪影self.conv_globalnn.Conv2d(channels,channels,1)self.conv_localnn.Conv2d(channels,channels,3,padding1)self.gatenn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(channels,channels//4,1),nn.ReLU(),nn.Conv2d(channels//4,channels,1),nn.Sigmoid())defforward(self,x):# 全局上下文global_featself.conv_global(x)# 这里踩过坑全局平均池化后直接上采样会丢失细节# 应该用gate机制动态调整全局和局部的比例global_weightself.gate(x)# 局部细节local_featself.conv_local(x)# 融合全局特征作为残差补充outlocal_featglobal_feat*global_weightreturnout这个模块有个隐藏的坑AdaptiveAvgPool2d(1)会把空间维度压缩到1x1如果输入特征图尺寸是64x64池化后变成1x1然后通过卷积恢复到64通道。这个过程中空间信息完全丢失了。所以后面必须用global_weight来控制全局信息的注入比例否则模型会变得“近视”——只看到全局平均颜色看不到纹理。训练时的那些坑HAN的训练比普通SRResNet要敏感得多。我总结几个血泪教训学习率设置别用默认的1e-4。HAN的分层注意力模块对学习率非常敏感我试过1e-3直接梯度爆炸1e-5又收敛太慢。最终经验值是1e-4配合warmup前5个epoch从1e-5线性增加到1e-4。损失函数选择论文里用L1 loss但我发现加上感知损失Perceptual Loss后纹理细节明显更自然。不过要注意感知损失的权重不能太大我设0.1就够太大会导致颜色偏移。批大小HAN的参数量不大约2M但注意力计算需要大量显存。我用的RTX 3090批大小只能设16输入96x96。如果你用2080Ti建议降到8否则会OOM。实战效果与调优建议在我的测试集Set5, Set14, Urban100上HAN相比EDSR在PSNR上提升了约0.15dB但视觉上纹理更锐利。不过有个问题HAN对噪声敏感。如果输入图像有轻微噪声分层注意力会放大噪声因为低层特征和高层特征的交互会传播噪声。个人经验性建议如果做视频超分别直接用HAN。视频帧间的时序信息会干扰分层注意力建议先做光流对齐再接入HAN。训练时监控注意力图的可视化。如果注意力图全是0.5或者全是1说明模型没学到有效交互需要调整学习率或网络深度。实际部署时可以把分层注意力模块替换成更轻量的版本比如用depthwise卷积代替普通卷积参数量能降40%性能几乎不变。别迷信论文里的超参数。我试过把reduction从16改成8PSNR反而掉了0.05dB因为通道压缩太狠丢失了信息。HAN这个架构在2023年之后被很多新模型超越了但它的跨层交互思想至今仍有参考价值。如果你在做超分或者图像恢复任务不妨把分层注意力作为一个baseline模块看看它在你自己的数据集上表现如何——很多时候简单的交互比复杂的transformer更有效。