DETR 技术详解(Detection Transformer)
DETR 技术详解(Detection Transformer)一、前言DETR(DEtection TRansformer)是 Facebook AI 在 2020 年提出的一种端到端目标检测模型,它首次将 Transformer 结构引入目标检测任务中,取代了传统的 anchor-based 和 NMS 后处理机制。它的核心思想是:“使用 Transformer 的 Set Prediction 框架,直接预测一组 bounding boxes,并通过匈牙利匹配机制进行 loss 计算。”本文将围绕 DETR 的模型结构、训练流程、损失函数、推理过程等进行详细讲解。二、DETR 的完整模型结构流程图(输入图像:800×800×3)Input Image (800x800x3) │ ├— Backbone: ResNet-50 / Swin Transformer → 提取多尺度特征 P3/P4/P5 │ ├— 输出 feature map `[B, C, H, W]`(如 `[1, 2048, 25, 25]`) │ ├— Neck: Feature Pyramid Network(FPN)→ 可选模块(部分变体启用) │ ├— 上采样 + Concatenate(增强小目标识别能力) │ ├— Positional Encoding → 添加位置信息给 feature map │ ├— Transformer Encoder → 自注意力建模全局特征 │ ├— Transformer Decoder + Learnable Queries → 解码器生成 object queries │ └— Detection Head: ├— Bounding Box Reg Branch(回归 `(x_center, y_center, width, height)`) └— Class Confidence Branch(分类置信度)三、DETR 的完整模型结构详解1. 主干网络(Backbone)使用标准的 CNN 或 Vision Transformer 提取图像特征;常见 backbone:ResNet-50(默认)ResNet-101Swin-Tiny/Swin-Base(DETR-DC5 / Deformable DETR 中使用)输出为 feature map:[B, C, H, W]例如:[1, 2048, 25, 25](ResNet-50 输出)2. 特征编码(Positional Encoding)将 feature map 展平为[B, C, HW];添加可学习的位置编码(positional encoding);输入给 Transformer encoder;pos_encoding=PositionEmbeddingSine()flatten_feature_map=feature_map.flatten(2).permute(2,0,1)# [HW, B, C]input_with_pos=flatten_feature_map+pos_encoding(flatten_feature_map)3. Transformer Encoder标准 Transformer 编码器;对 feature map 进行全局自注意力建模;输出为[HW, B, C]形式的编码后特征;classTransformerEncoder(nn.Module):def__init__(self,encoder_layer,num_layers):super().__init__()self.layers=_get_clones(encoder_layer,num_layers)defforward(self,src):forlayerinself.layers:src=layer(src)returnsrc4. Learnable Object Queries(解码器输入)初始化为 learnable embeddings;数量通常设为 100(支持最多 100 个 objects);作为 decoder 的初始输入;query_embed=nn.Embedding(num_queries=100,embedding_dim=d_model)queries=query_embed.weight.unsqueeze(1).repeat(1,batch_size,1)5. Transformer Decoder使用 cross-attention 查询 encoder 输出;输出为[Q, B, D]形式的 object embeddings;Q 为 object queries 数量(如 100);classTransformerDecoder(nn.Module):def__init__(self,decoder_layer,num_layers):super().__init__()self.layers=_get_clones(decoder_layer,num_layers)defforward(self,tgt,memory,...):forlayerinself.layers:tgt=layer(tgt,memory,...)returntgt6. Detection Head(边界框 + 分类分支)每个 object query 由两个 head 预测:reg_head: 回归 bounding box;cls_head: 分类 confidence;classBBoxHead(nn.Module):def__init__(self,d_model,num_classes=91):super().__init__()self.bbox_embed=MLP(d_model,d_model,4,3)# 回归头self.class_embed=nn.Linear(d_model,num_classes)