• 问答
  • 技术
  • 实践
  • 资源
YOLOX 在 MMDetection 中复现全流程解析
技术讨论

200条CV人的技术留言,学习路线/避坑指南(下拉 加载更多查看):https://bbs.cvmart.net/articles/5369


来源:OpenMMLab 欢迎关注

0 摘要

最近 YOLOX 火爆全网,速度和精度相比 YOLOv3、v4 都有了大幅提升,并且提出了很多通用性的 trick,同时提供了部署相关脚本,实用性极强。 MMDetection 开源团队成员也组织进行了相关复现。

在本次复现过程中,有5位社区成员参与贡献:

首先非常感谢几位社区成员的贡献!

通过协同开发,不仅让复现过程更加高效,而且社区成员在参与过程中可以不断熟悉算法,熟悉 MMDetection 开发模式。后续我们也会再次组织相关复现活动,让社区成员积极参与,共同成长,共同打造更加优异的目标检测框架

本文先简要介绍 YOLOX 算法,然后重点描述复现流程。

1 YOLOX 算法简介

官方开源地址:https://github.com/Megvii-BaseDetection/YOLOX

MMDetection 开源地址:GitHub - open-mmlab/mmdetection: OpenMMLab Detection Toolbox and Benchmark,欢迎 star

复现相关 projects:Support of YOLOX · open-mmlab/mmdetection

YOLOX 网上的解读非常多,详情可见官方解读:https://www.zhihu.com/question/473350307/answer/2021031747

YOLOX 的主要特性可以归纳为:

  1. Anchor-free,无需设计 anchor,更少先验,减少复杂超参,推理较高效
  2. 提出了 Decoupled Head,参考 FCOS 算法设计解耦了分类和回归分支,同时新增 objectness 回归分支
  3. 为了加快收敛以及提高性能,引入了 SimOTA 标签分配策略,其包括两部分 Multi positives 和 SimOTA,Multi positives 可以简单的增加每个 gt 所需的正样本数,SimOTA 基于 CVPR2021 最新的 OTA 算法,采用最优传输理论全局分配每个 gt 的正样本,考虑到 OTA 带来的额外训练代价,作者提出了简化版本的 OTA ,在长 epoch 训练任务中 SimOTA 和 OTA 性能相当,且可以极大的缩小训练代价
  4. 参考 YOLOV4 的数据增强策略,引入了 Mosaic 和 MixUp,并且在最后 15 个 epoch 时候关闭这两个数据增强操作,实验表明可以极大地提升性能
  5. 基于上述实践,参考 YOLOV5 网络设计思路,提出了 YOLOX-Nano、YOLOX-Tiny、YOLOX-S、YOLOX-M、YOLOX-L 和 YOLOX-X 不同参数量的模型

算法细节将在后续小结中会进行详细说明。

2 YOLOX 复现流程全解析

我们简单将 YOLOX 复现过程拆分为 3 个步骤,分别是:

  1. 推理精度对齐
  2. 训练精度对齐
  3. 重构

2.1 推理精度对齐

为了方便将官方开源权重迁移到 MMDetection 中,在推理精度对齐过程中,我们没有修改任何模型代码,而且简单的复制开源代码,分别插入 MMDetection 的 backbone 和 head 文件夹下,这样就只需要简单的替换模型 key 即可。

一个特别需要注意的点:BN 层参数不是默认值 eps=1e-5, momentum=0.1,而是 eps=1e-3, momentum=0.03,这个应该是直接参考 YOLOV5。

排除了模型方面的问题(后处理策略我们暂时也没有改),对齐推理精度核心就是分析图片前处理代码。其处理流程非常简单

# 前处理核心操作 
def preproc(image, input_size, mean, std, swap=(2, 0, 1)): 
    # 预定义输出图片大小 
    padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0 

    # 保持宽高比的 resize 
    img = np.array(image) 
    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) 
    resized_img = cv2.resize( 
        img, 
        (int(img.shape[1] * r), int(img.shape[0] * r)), 
        interpolation=cv2.INTER_LINEAR, 
    ).astype(np.float32) 

    # 右下 padding 
    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img 

    # bgr -> rgb 
    padded_img = padded_img[:, :, ::-1] 

    # 减均值,除方差 
    padded_img /= 255.0 
    if mean is not None: 
        padded_img -= mean 
    if std is not None: 
        padded_img /= std 
    padded_img = padded_img.transpose(swap) 
    return padded_img, r     

可以发现,其前处理流程比较简单:先采用保持宽高比的 resize,然后右下 padding 成指定大小输出,最后是归一化。

在 MMDetection 中可以直接通过修改配置文件来支持上述功能:

test_pipeline = [ 
    dict(type='LoadImageFromFile'), 
    dict( 
        type='MultiScaleFlipAug', 
        img_scale=img_scale, 
        flip=False, 
        transforms=[ 
            dict(type='Resize', keep_ratio=True), 
            dict(type='RandomFlip'), 
            dict(type='Pad', size=(640, 640), pad_val=114.0), 
            dict(type='Normalize', **img_norm_cfg), 
            dict(type='DefaultFormatBundle'), 
            dict(type='Collect', keys=['img']) 
        ]) 
] 

相关的后处理阈值为:

test_cfg=dict(   
 score_thr=0.001,  
 nms=dict(type='nms', iou_threshold=0.65))) 

分别对 YOLOX-S 和 YOLOX-Tiny 模型权重在官方源码和 MMDetection 中进行评估,验证是否对齐,结果如下:

注意:由于官方开源代码一直处于更新中,现在下载的最新权重,可能 mAP 不是上表中的值。

2.2 训练精度对齐

训练精度对齐相对来说复杂很多,初步观察源码,发现训练 trick 还是蛮多的,而 MMDetection 中暂时没有直接能复用的模块。训练精度对齐由 MMDetection 的开发团队和社区用户共同完成,我们将整体的对齐分解成如下若干模块,每个模块都有社区用户参与:

  1. 优化器和学习率调度器
  2. EMA 策略
  3. Dataset
  4. Loss
  5. 其他训练 trick

考虑到 YOLOX 训练需要 300 epoch,训练时长比较长,我们采用 YOLOX-Tiny 模型进行训练精度对齐,通过对比源码和复现版本的 log 进行判断。相比标准的 YOLOX-S 模型,其差别仅仅是没有使用 MixUp 以及额外的两个超参不一样而已。

2.2.1 优化器和学习率调度器

(1) 优化器

优化器部分相对来说容易实现,其流程是:将优化参数设置为3组,卷积 bias、BN 和卷积权重,其中只对卷积的权重进行 decay,优化器是 SGD+ nesterov+ momentum。在 MMDetection 中可通过修改配置直接实现上述功能:

optimizer = dict( 
    type='SGD', 
    lr=0.01,   
    momentum=0.9, 
    weight_decay=5e-4, 
    nesterov=True, 
    paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))  

核心是依靠 paramwise_cfg 参数,关闭 norm 和 bias 的weight decay。

注意: MMDetection 中的 lr 设置是针对总 batch 的,上面写的 lr =0.01 是指的 8 卡 x 8 bs 的情况,而不算单卡的,如果你的总 bs 不是 64,那么你需要手动进行线性缩放

(2) 学习率调度器

YOLOX 的学习率调度器是带有 warmup 策略的余弦调度策略,并且为了配合数据增强,在最后 15 个 epoch 会采用固定的最小学习率

MMDetection 中导入的 MMCV 已经实现了带有 warmup 策略的余弦调度策略,但是比较麻烦的是指数 warmup 策略公式不太一样,并且不具有在最后 15 个 epoch 采用固定最小学习率的功能。为了解决上述问题,且不更改依赖 MMCV 版本,我们是在 MMDetection 中继承了原先的 CosineAnnealingLrUpdaterHook,并重写了相关方法。

class YOLOXLrUpdaterHook(CosineAnnealingLrUpdaterHook): 
    def get_warmup_lr(self, cur_iters): 

        def _get_warmup_lr(cur_iters, regular_lr): 
 # 重写 warmup 策略 
            k = self.warmup_ratio * pow( 
                (cur_iters + 1) / float(self.warmup_iters), 2) 
            warmup_lr = [_lr * k for _lr in regular_lr] 
            return warmup_lr 

        ... 

    def get_lr(self, runner, base_lr): 
        last_iter = len(runner.data_loader) * self.num_last_epochs 

        progress = runner.iter 
        max_progress = runner.max_iters 

        progress += 1 

        target_lr = base_lr * self.min_lr_ratio 

        if progress >= max_progress - last_iter: 
  # 固定学习率策略 
            return target_lr 
        else 
            return annealing_cos( 
                base_lr, target_lr, (progress - self.warmup_iters) / 
                (max_progress - self.warmup_iters - last_iter))  

为了保证代码的正确性,我们单独写了脚本,运行官方源码和 MMDetection 复现代码,比较两边的 lr 曲线是否完全一致。

2.2.2 EMA 策略

模型的指数移动平均可以提升模型鲁棒性和性能,属于一个常用的 trick。其原理是:对模型额外维护一份指数移动平均模型 ema_model,然后在每次迭代模型参数更新后,利用 model 中的参数和 ema_model 计算更新后的 ema_model,评估阶段使用的是 ema_model

def update(self, model): 
    # Update EMA parameters 
    with torch.no_grad(): 
        self.updates += 1 
        # self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) 
        d = self.decay(self.updates) 

        msd = ( 
            model.module.state_dict() if is_parallel(model) else model.state_dict() 
        ) 
        for k, v in self.ema.state_dict().items(): 
            if v.dtype.is_floating_point: 
                v *= d 
                v += (1.0 - d) * msd[k].detach()  

MMCV 中已经通过 Hook 实现了 EMA 功能,但是没有考虑 BN buffer 的 EMA 操作,而且 decay 公式也不一样,MMCV 中是线性 decay 策略,而 YOLOX 中是指数 decay。为此,我们在兼容 MMCV EMA 功能的前提下,重新设计了 EMA Hook,目前暂时放在 MMDetection 中,会在后续版本迁移到 MMCV 中。

先分析下 MMCV 中 EMA 策略实现方式,然后再说明如何考虑兼容。

EMA 的实现是采用 Hook 实现的,下面贴核心代码:

@HOOKS.register_module() 
class EMAHook(Hook): 
   # 在模型运行前,将模型参数重新 copy 一份,然后重新作为 buffer 插入到模型中 
   # 也就是说此时 Model 里面有两份相同参数的模型了 
   def before_run(self, runner): 
  model = runner.model 
        if is_module_wrapper(model): 
            model = model.module 
        self.param_ema_buffer = {} 
        self.model_parameters = dict(model.named_parameters(recurse=True)) 
        for name, value in self.model_parameters.items(): 
            # "." is not allowed in module's buffer name 
            buffer_name = f"ema_{name.replace('.', '_')}" 
            self.param_ema_buffer[name] = buffer_name 
            model.register_buffer(buffer_name, value.data.clone()) 
        self.model_buffers = dict(model.named_buffers(recurse=True)) 
        # 这个很关键,后续会说明 
        if self.checkpoint is not None: 
            runner.resume(self.checkpoint) 

     # 实时计算更新 ema 模型参数  
     def after_train_iter(self, runner): 
  curr_step = runner.iter 
        # We warm up the momentum considering the instability at beginning 
        momentum = min(self.momentum, 
                       (1 + curr_step) / (self.warm_up + curr_step)) 
        if curr_step % self.interval != 0: 
            return 
        for name, parameter in self.model_parameters.items(): 
            buffer_name = self.param_ema_buffer[name] 
            buffer_parameter = self.model_buffers[buffer_name] 
            buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data)        
     # 非常关键,后续说明 
     def after_train_epoch(self, runner): 
 """We load parameter values from ema backup to model before the 
     EvalHook.""" 
  self._swap_ema_parameters() 

    def before_train_epoch(self, runner): 
 """We recover model's parameter from ema backup after last epoch's 
     EvalHook.""" 
  self._swap_ema_parameters() 

    def _swap_ema_parameters(self): 
 """Swap the parameter of model with parameter in ema_buffer.""" 
  for name, value in self.model_parameters.items(): 
            temp = value.data.clone() 
            ema_buffer = self.model_buffers[self.param_ema_buffer[name]] 
            value.data.copy_(ema_buffer.data) 
            ema_buffer.data.copy_(temp)         

为了后续方便保存和恢复模型,我们没有单独维护一个新的 ema_model,而是将参数重新 copy 一份,然后作为 buffer 插入到原先模型中,变成两份。然后在每次迭代训练后,利用 momentum 和更新后的模型参数来更新 ema 模型。

因为在评估时候采用的是 ema 模型,为了不影响后面的 evalhook 和 save checkpoint 相关逻辑,我们在每次开启 epoch 训练前和训练后都会交换一次模型参数,这样在评估过程就会自动使用 ema 模型,这是一个比较巧妙的 trick,需要特意强调的是 EMAHook 优先级一定要确保比 evalhook 和 save checkpoint 相关逻辑高,否则会出问题。因为整个训练过程流程是:

  1. 首先在外面构建模型参数和初始化模型
  2. 在 before_run 中新建一份 ema 模型参数的 buffer,并且插入到原先模型中
  3. 在开启 epoch 训练前 before_train_epoch 交换一次权重,即 model 的参数值是 ema 参数值,而 ema 参数值是 model,并且由于此时两者完全相等,所以可以认为没有交换
  4. 迭代的更新 model 参数,并且在 after_train_iter 中实时更新 ema 模型参数
  5. 由于 EMAHook 优先级较高,故会优先于 evalhook 和 save checkpoint 相关逻辑,其会先运行 after_train_epoch 交换一次参数,此时 model 的参数值是 ema 参数值,而 ema 参数值是 model,完成了一次真正的参数交换
  6. 运行 evalhook 和 save checkpoint 相关逻辑,其面对的 model 是 ema model,符合预期
  7. 在下一次循环迭代时候,重复 3-6 过程,不断的交换参数

大家可以发现,此时保存的模型虽然同时保存了 ema model 和本身 model 参数,但是实际上是反的,也就是说保存的 model 参数实际上是 ema model 参数值,而 ema model 参数是 model 参数值。如此设计的原因是为了能够完全正确的 resume。在是 resume 阶段,以下步骤会依次执行:

  1. 首先在外面构建模型参数和初始化模型
  2. resume 模型时候,由于此时 ema 模型还没有构建,所以只能加载权重字典中的 model 字段,但实际上该 model 字段是 ema
  3. 在 before_run 中新建一份 ema 模型参数的 buffer,并且插入到原先模型中,此时就构建出了 ema 模型
  4. 如果传入了待 resume 的 checkpoint,此时会重新加载一遍,由于 ema 模型已经构建,所以 ema 模型和 model 都会被 resume,同时 ema 模型参数是 model,而 model 参数是 ema
  5. 在开启 epoch 训练前 before_train_epoch 交换一次权重,此时就正确了,也就是说到这一步就完全恢复了
  6. 继续正常的训练

需要强调:由于这种特殊的 resume 技巧,当你需要对模型进行 resume 时候,暂时不可以通过外部指定 resume 参数实现,必须要修改配置中的 resume_from 字段,否则 resume 过程是不正确的

上述逻辑可能比较绕,大家需要仔细思考。这么设计的原因是:1. 方便后续评估 2. 能够正确 resume。

考虑到 YOLOX 中的 ema 是需要同时平滑 buffer 的,为此我们重新进行了设计,在兼容的同时有更好的扩展性。

(1) 考虑要能够平滑 BN 的 buffer 参数,我们增加了 skip_buffers 参数

if self.skip_buffers: 
    # 如果跳过,那就直接用参数即可 
    self.model_parameters = dict(model.named_parameters()) 
else: 
    # 如果不跳过,则直接使用状态字典 
    self.model_parameters = model.state_dict()  

(2) 考虑会可能存在多种平滑曲线,我们设计了 BaseEMAHook,然后继承这个类进行扩展不同的平滑曲线

@HOOKS.register_module() 
class ExpMomentumEMAHook(BaseEMAHook): 

 def __init__(self, total_iter=2000, **kwargs): 
        super(ExpMomentumEMAHook, self).__init__(**kwargs) 
        self.momentum_fun = lambda x: (1 - self.momentum) * math.exp(-( 
            1 + x) / total_iter) + self.momentum 

@HOOKS.register_module() 
class LinearMomentumEMAHook(BaseEMAHook): 

 def __init__(self, warm_up=100, **kwargs): 
        super(LinearMomentumEMAHook, self).__init__(**kwargs) 
        self.momentum_fun = lambda x: min(self.momentum**self.interval, 
                                          (1 + x) / (warm_up + x))        

后续我们会将该 EMA hook 从 MMDetection 移动到 MMCV 中变成基础模块。

2.2.3 Dataset

dataset 部分最复杂,其包括 Mosaic 、MixUp、ColorJit 和动态 resize 等等操作。由于 MMDetection 中暂时都没有实现上述组件,并且超参很多,为此在我们第一版中实际上是直接复制了源码的 dataset,只对 dataset 输出进行包装,使其能够接入 MMDetection 训练过程,在训练精度对齐后进行重新设计。本小结先简要分析源码,然后再描述 MMDetection 实现过程。

(1) 源码解释

dataset = COCODataset(...) 

dataset = MosaicDetection( 
    dataset, 
    mosaic=not no_aug, 
    img_size=self.input_size, 
    preproc=TrainTransform( 
        rgb_means=(0.485, 0.456, 0.406), 
        std=(0.229, 0.224, 0.225), 
        max_labels=120, 
    ), 
    degrees=self.degrees, 
    translate=self.translate, 
    scale=self.scale, 
    shear=self.shear, 
    perspective=self.perspective, 
    enable_mixup=self.enable_mixup, 
)  

为了实现 Mosaic 和 Mixup,作者引入了 MosaicDetection 来包裹 COCODataset。其 dataset 流程为:

def __getitem__(self, idx): 
    # 是否进入 mosaic,默认前 285 个 epoch 都会进入 
    if self.enable_mosaic: 
         1 马赛克增强 
         2 几何变换增强 
         # 是否进入 mixup,nano 和 tiny 版本默认是关闭的 
         if self.enable_mixup and not len(mosaic_labels) == 0: 
             3 mixup 增强 
          4 图片后处理   
    else: 
       4 图片后处理 

总共分成上述 4 个步骤,整体流程如下图所示(前两步是马赛克增强,第三步是几何变换增强,第 4 步是 MixUp 增强)

上图由社区的小伙伴 HAOCHENYE 提供,感恩!

1) 马赛克增强

  1. 随机出 4 张图片在待输出图片中交接的中心点坐标
  2. 随机出另外 3 张图片的索引以及读取对应的标注
  3. 对每张图片采用保持宽高比的 resize 操作缩放到指定大小
  4. 按照上下左右规则,计算每张图片在待输出图片中应该放置的位置,因为图片可能出界故还需要计算裁剪坐标
  5. 利用裁剪坐标将缩放后的图片裁剪,然后贴到前面计算出的位置,其余位置全部补 114 像素值
  6. 对每张图片的标注也进行相应处理
  7. 由于拼接了 4 张图,所以输出图片大小会扩大 4 倍

2) 几何变换增强

random_perspective 包括平移、旋转、缩放、错切等增强,并且会将输入图片还原为 (640, 640),同时对增强后的标注进行处理,过滤规则是

  1. 增强后的 gt bbox 宽高要大于 wh_thr
  2. 增强后的 gt bbox 面积和增强前的 gt bbox 面积要大于 ar_thr,防止增强太严重
  3. 最大宽高比要小于 area_thr,防止宽高比改变太多

3) MixUp

Mixup 实现方法有多种,常见的做法是:要么 label 直接拼接起来,要么 label 也采用 alpha 混合,作者的做法非常简单,对 label 直接拼接即可,而图片也是采用固定的 0.5:0.5 混合方法。

其处理流程是:

  1. 随机出一张图片,必须要保证该图片不是空标注
  2. 对随机出的图片采用保持宽高比的 resize 操作缩放到指定大小
  3. 然后左上 padding 成指定大小,padding 值也是 114
  4. 对 padding 后的图片进行随机抖动增强
  5. 随机采用 flip 增强
  6. 如果处理后的图片比原图大,则还需要进行随机裁剪增强
  7. 对标签进行对应处理,并且采用和马赛克增强一样的过滤规则
  8. 如果过滤后还存在 gt bbox,则采用 0.5:0.5 的比例混合原图和处理后的图片,标签则直接拼接即可

4) 图片后处理

图片后处理操作也包括众多数据增强操作,如下所示:

  1. 随机 ColorJit,包括众多颜色相关增强
  2. 随机翻转增强
  3. 对随机后的图片采用保持宽高比的 resize 操作缩放到指定大小
  4. 对于宽高小于 8 像素的 gt bbox 直接删掉,因为网络输出的最小 stride 是 8
  5. Padding 成正方形图片输出

(2) MMDetection 实现

Dataset 部分涉及到的代码比较多,主要包括:

  1. 之前框架中还没有实现过类似 Mosaic 等需要再次利用 dataset 相关信息的代码
  2. 之前框架中还没有实现过类似在某一阶段关闭某些数据增强的操作

针对第一个问题,以 Mosaic 为例实现方式有多种,下面列一下暂时想到的方案:

  • 类似 RandomFlip 等,作为 pipeline 实现

参考:https://github.com/open-mmlab/mmdetection/blob/e41dc0cb26ea43302c6444f504c99a688fc93ff4/mmdet/datasets/pipelines/transforms.py#L1815

作为 pipeline 实现的时候,需要额外插入 dataset 对象或者相关信息,否则内部获取不到 dataset。这种做法对应的配置写法是:

mosaic_pipeline = [ 
    dict(type='LoadImageFromFile', to_float32=True), 
    dict(type='LoadAnnotations', with_bbox=True), 
    dict(type='PhotoMetricDistortion'), 
    dict( 
        type='Expand', 
        mean=img_norm_cfg['mean'], 
        to_rgb=img_norm_cfg['to_rgb'], 
        ratio_range=(1, 2)), 
] 

mosaic_data = dict( 
    type=dataset_type, 
    ann_file=data_root + 'annotations/instances_train2017.json', 
    img_prefix=data_root + 'train2017/', 
    pipeline=mosaic_pipeline) 

train_pipeline = [ 
    dict(type='LoadImageFromFile', to_float32=True), 
    dict(type='LoadAnnotations', with_bbox=True), 
    dict(type='PhotoMetricDistortion'), 
    dict( 
        type='Expand', 
        mean=img_norm_cfg['mean'], 
        to_rgb=img_norm_cfg['to_rgb'], 
        ratio_range=(1, 2)), 
 dict(type='Mosaic', size=(416, 416), dataset=mosaic_data, min_offset=0.2), 
    dict(type='Resize', img_scale=[(320, 320), (416, 416)], keep_ratio=True), 
    dict(type='RandomFlip', flip_ratio=0.5), 
    dict(type='Normalize', **img_norm_cfg), 
    dict(type='Pad', size_divisor=32), 
    dict(type='DefaultFormatBundle'), 
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 
]  

这种写法的好处是可以无缝插入任何 MMDetection 已经实现的算法中,配置改动最小。但是其有一个非常致命的缺点:dataset 需要在内部重新 build,这个开销比较大,特别是当你的数据集非常大的时候。而且如果有多个类似的混合数据增强,那么就需要 build 多次,为此这种写法实际上很难接受。

当然有其他折中办法,例如在 dataset 基类中传入 self 对象,例如 results['dataset']=self,因为 dataset 对象可以贯穿整个 pipeline 生命周期,但是其风险比较大,例如可能出现循环调用,同时由于 self 对象贯穿整个生命周期,一旦你在某个时候不小心修改了 dataset,那么产生的 bug 将会难以估计和排查,为此我们也不打算采用这种做法。

  • 类似 RandomFlip 等,作为 pipeline 实现,但是内部缓存 index

其核心就是:在内部维持一个固定大小的缓冲池,实时更新和缓存已经读取过的 index,只要缓存池足够大,那么理论上 Mosaic 效果应该非常接近。

这种做法可以避免第一种做法缺陷,优点也非常明显:即插即用,无其他要求,但是其效果是否和标准的 mosaic 效果是否一致,需要调研下。由于时间比较赶,我们没有去调研最终效果是否一致。

  • 类似 RepeatDataset 等,作为 dataset wrapper 实现

在 mmdet/datasets/dataset_wrappers.py 中实现 MultiImageMixDataset,其会对 dataset 进行包装,避免了多次 build 的性能开销问题。

其核心代码是:

def __getitem__(self, idx): 
    # 获取当前 idx 的图片信息 
    results = self.dataset[idx] 
    # 遍历 transorm,其中可以包括 mosaic 、mixup、flip 等各种  transform 
    for (transform, transform_type) in zip(self.pipeline, 
                                           self.pipeline_types): 
        # 考虑到某些训练阶段需要动态关闭掉部分数据增强,故引入   _skip_type_keys                         
        if self._skip_type_keys is not None and \ 
                transform_type in self._skip_type_keys: 
            continue 
        if hasattr(transform, 'get_indexes'): 
            # transform 如果额外提供了 get_indexes 方法,则表示需要进行混合数据增强 
            # 返回索引 
            indexes = transform.get_indexes(self.dataset) 
            if not isinstance(indexes, collections.abc.Sequence): 
                indexes = [indexes] 
            # 得到混合图片信息     
            mix_results = [ 
                copy.deepcopy(self.dataset[index]) for index in indexes 
            ] 
            results['mix_results'] = mix_results 
        # 动态尺度 resize 
        if self._dynamic_scale is not None: 
            # Used for subsequent pipeline to automatically change 
            # the output image size. E.g MixUp, Resize. 
            results['scale'] = self._dynamic_scale 
        # 数据增强 
        results = transform(results) 

        if 'mix_results' in results: 
            results.pop('mix_results') 
        if 'img_scale' in results: 
            results.pop('img_scale') 
    return results  

其核心是对于 Mosaic 或者 MixUp 等需要混合数据的增强操作,对应的 transform 需要额外提供 get_indexes 方式,内部返回 indexes 信息;然后 MultiImageMixDataset 会自动完成相关获取数据操作。

以 Mosaic 和 MixUp 为例,其作为 pipeline 的写法如下所示:

@PIPELINES.register_module() 
class Mosaic: 
 def __init__(self, 
                 img_scale=(640, 640), 
                 center_ratio_range=(0.5, 1.5), 
                 pad_val=114): 
        assert isinstance(img_scale, tuple) 
        self.img_scale = img_scale 
        self.center_ratio_range = center_ratio_range 
        self.pad_val = pad_val 

    def __call__(self, results): 
 results = self._mosaic_transform(results) 
        return results 

    def get_indexes(self, dataset): 
 indexs = [random.randint(0, len(dataset)) for _ in range(3)] 
        return indexs 

@PIPELINES.register_module() 
class MixUp: 
 def __init__(self, 
                 img_scale=(640, 640), 
                 ratio_range=(0.5, 1.5), 
                 flip_ratio=0.5, 
                 pad_value=114, 
                 max_iters=15, 
                 min_bbox_size=5, 
                 min_area_ratio=0.2, 
                 max_aspect_ratio=20): 
        assert isinstance(img_scale, tuple) 
        self.dynamic_scale = img_scale 
        self.ratio_range = ratio_range 
        self.flip_ratio = flip_ratio 
        self.pad_value = pad_value 
        self.max_iters = max_iters 
        self.min_bbox_size = min_bbox_size 
        self.min_area_ratio = min_area_ratio 
        self.max_aspect_ratio = max_aspect_ratio 

    def __call__(self, results): 
 results = self._mixup_transform(results) 
        return results 

    # 必须要返回非空 gt bbox 数据索引 
    def get_indexes(self, dataset): 
 for i in range(self.max_iters): 
            index = random.randint(0, len(dataset)) 
            gt_bboxes_i = dataset.get_ann_info(index)['bboxes'] 
            if len(gt_bboxes_i) != 0: 
                break 

        return index  

对应的完整配置写法如下:

train_pipeline = [ 
    dict(type='Mosaic', img_scale=img_scale, pad_val=114.0), 
    dict( 
        type='RandomAffine', 
        scaling_ratio_range=(0.1, 2), 
        border=(-img_scale[0] // 2, -img_scale[1] // 2)), 
    dict(type='MixUp', img_scale=img_scale, ratio_range=(0.8, 1.6)), 
   # PhotoMetricDistortion 和 YOLOX 中实现的颜色增强不一样,PhotoMetricDistortion 增强更强一些,但是考虑随机操作可能对最终性能没有很大影响,故我们并没有对其进行修改 
    dict( 
        type='PhotoMetricDistortion', 
        brightness_delta=32, 
        contrast_range=(0.5, 1.5), 
        saturation_range=(0.5, 1.5), 
        hue_delta=18), 
    dict(type='RandomFlip', flip_ratio=0.5), 
    dict(type='Resize', keep_ratio=True), 
    dict(type='Pad', pad_to_square=True, pad_val=114.0), 
    dict(type='Normalize', **img_norm_cfg), 
    dict(type='DefaultFormatBundle'), 
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 
] 

train_dataset = dict( 
    type='MultiImageMixDataset', 
    dataset=dict( 
        type=dataset_type, 
        ann_file=data_root + 'annotations/instances_train2017.json', 
        img_prefix=data_root + 'train2017/', 
        pipeline=[ 
            dict(type='LoadImageFromFile', to_float32=True), 
            dict(type='LoadAnnotations', with_bbox=True) 
        ], 
        filter_empty_gt=False, 
    ), 
    pipeline=train_pipeline, 
    dynamic_scale=img_scale)  

这种写法虽然会对目前已经实现的算法配置写法有较大改动,但是可扩展强,不存在性能开销问题。

对于 MultiImageMixDataset 中实现的动态 scale 和关闭数据增强操作在 2.2.5 小结分析。需要说明的是: 由于 YOLOX 算法开发进度比较快,时间比较赶,上述方案可能不是最好的,我们后续也会慢慢改进的,使其在满足可扩展性前提下,易用性提高,出错率能够降低。

2.2.4 Loss

关于 Loss 部分,在第一版中我们是直接复制了源码。考虑到 loss 是 YOLOX 的核心,故先简要分析源码,然后再描述 MMDetection 重构版本。

(1) 源码解释

其网络输出包括 3 个 尺度, stride 分别是 8、16 和 32,每个输出尺度上又包括 3 个输出,分别是 bbox 输出分支、objectness 输出分支和 cls 类别输出分支。其 loss 计算流程为:

1) 计算 3 个输出层所需要的特征图尺度的坐标,用于 bbox 解码

yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)]) 
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)  

2) 对输出 bbox 进行解码还原到原图尺度

output = output.view(batch_size, 1, n_ch, hsize, wsize) 
output = output.permute(0, 1, 3, 4, 2).reshape( 
    batch_size, 1 * hsize * wsize, -1 
) 
grid = grid.view(1, -1, 2) 
# 还原 
output[..., :2] = (output[..., :2] + grid) * stride 
output[..., 2:4] = torch.exp(output[..., 2:4]) * stride  

通过上述解码公式,可以知道 bbox 输出预测值 cxcywh,分别代表 gt bbox 中心和当前网格左上标偏移以及 wh 的指数变换值,并且都基于当前 stride 进行了缩放。

3) 对每张图片单独计算匹配的正样本和对应的 target

因为后续匹配规则需要考虑中心区域,故提前计算每个 gt bbox 在指定范围内的中心区域。作者引入了超参 center_radius =2.5,其主要计算过程为:

  1. 基于 grid 和 stride 计算 anchor 点的中心坐标,其中 anchor 是个数为 1 的正方形 anchor
  2. 计算所有 gt bbox 的中心坐标
  3. 计算所有在 gt bbox 内部的 anchor 点的掩码 is_in_boxes_all
  4. 利用 center_radius 阈值重新计算在 gt bbox 中心 center_radius 范围内的 anchor 点的掩码 is_in_centers_all
  5. 两个掩码取并集得到在 gt bbox 内部或处于 center_radius 范围内的 anchor 点的掩码 is_in_boxes_anchor,同时可以取交集得到每个 gt bbox 和哪些 anchor 点符合 gt bbox 内部和处于 center_radius 范围内的 anchor is_in_boxes_and_center

4) 计算每张图片中 gt bbox 和候选预测框的匹配代价

# fg_mask : (n,) 如果某个位置是 True 代表该 anchor 点是前景即 
# 落在 gt bbox 内部或者在距离 gt bbox 中心 center_radius 半径范围内 
# is_in_boxes_and_center:(num_gt,n), 如果某个位置是 True 代表 
# 该 anchor 点落在 gt bbox 内部并且在距离 gt bbox 中心 center_radius 半径范围内 
# 提取对应值 
bboxes_preds_per_image = bboxes_preds_per_image[fg_mask] 
cls_preds_ = cls_preds[batch_idx][fg_mask] 
obj_preds_ = obj_preds[batch_idx][fg_mask] 
num_in_boxes_anchor = bboxes_preds_per_image.shape[0] 

# 计算预测框和 gt bbox 的配对 iou 
pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False) 
gt_cls_per_image = ( 
    F.one_hot(gt_classes.to(torch.int64), self.num_classes) 
    .float() 
    .unsqueeze(1) 
    .repeat(1, num_in_boxes_anchor, 1) 
) 
# iou 越大,匹配度越高,所以需要取负号 
pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) 

cls_preds_ = ( 
    cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() 
    * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() 
) 
# 配对的分类 Loss,包括了 iou 分支预测值 
pair_wise_cls_loss = F.binary_cross_entropy( 
    cls_preds_.sqrt_(), gt_cls_per_image, reduction="none" 
).sum(-1) 
del cls_preds_ 

# 计算每个 gt bbox 和选择出来的候选预测框的分类 loss + 坐标 loss + 中心点和半径约束 
# 值越小,表示匹配度越高 
# (num_gt,n) 
cost = ( 
    pair_wise_cls_loss 
    + 3.0 * pair_wise_ious_loss 
    + 100000.0 * (~is_in_boxes_and_center) 
)  
  1. fg_mask 就是前面计算出的 is_in_boxes_anchor,如果某个位置是 True 代表该 anchor 点是前景即落在 gt bbox 内部或者在距离 gt bbox 中心 center_radius 半径范围内,这些 True 位置就是正样本候选点
  2. 利用 fg_mask 提取对应的预测信息,假设 num_gt 是 3,一共提取了 800 个候选预测位置,则每个 gt bbox 都会提取出 800 个候选位置
  3. 计算候选预测框和 gt bbox 的配对 iou,然后加 log 和负数,变成 iou 的代价函数
  4. 计算候选预测框和 gt bbox 的配对分类代价值,同时考虑了 objectness 预测分支,并且其分类 cost 在 binary_cross_entropy 前有开根号的训练 trick
  5. is_in_boxes_and_center shape 是 (3, 800), 如果某个位置是 True 表示该 anchor 点落在 gt bbox 内部并且在距离 gt bbox 中心 center_radius 半径范围内。在计算代价函数时候,如果该预测点是 False,表示不再交集内部,那么应该不太可能是候选点,所以给予一个非常大的代价权重 100000.0,该操作可以保证每个 gt bbox 最终选择的候选点不会在交集外部

上述计算出的代价值充分考虑了各个分支预测值,也考虑了中心先验,有利于训练稳定和收敛,同时也为后续的动态匹配提供了全局信息。

5) 为每个 gt bbox 动态选择 k 个 候选预测值,作为匹配正样本

def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask): 
    # 匹配矩阵初始化为 0 
    matching_matrix = torch.zeros_like(cost) 
    ious_in_boxes_matrix = pair_wise_ious 
    # 每个 gt bbox 选择的候选预测点不超过 10 个 
    n_candidate_k = min(10, ious_in_boxes_matrix.size(1)) 
    topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1) 
    # 每个 gt bbox 的动态 k 
    dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1) 
    for gt_idx in range(num_gt): 
        _, pos_idx = torch.topk( 
            cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False 
        ) 
        # 匹配上位置设置为 1 
        matching_matrix[gt_idx][pos_idx] = 1.0 
    del topk_ious, dynamic_ks, pos_idx 
    # n, 表示该候选点有没有匹配到 gt bbox  
    anchor_matching_gt = matching_matrix.sum(0) 
    if (anchor_matching_gt > 1).sum() > 0: 
        # 如果某个候选点匹配了多个 gt bbox,则选择代价最小的,保证每个候选点只能匹配一个 gt bbox  
        _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0) 
        matching_matrix[:, anchor_matching_gt > 1] *= 0.0 
        matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0 

    # 每个候选点的匹配情况 
    fg_mask_inboxes = matching_matrix.sum(0) > 0.0 
    # 总共有多少候选点 
    num_fg = fg_mask_inboxes.sum().item() 
    # 更新前景掩码,在前面中心先验的前提下进一步筛选正样本 
    fg_mask[fg_mask.clone()] = fg_mask_inboxes 

    # 该候选框匹配到哪个 gt bbox  
    matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0) 
    gt_matched_classes = gt_classes[matched_gt_inds] 

    # 提取对应的预测点和gt bbox 的 iou 
    pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[ 
        fg_mask_inboxes 
    ] 
    return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds  
  1. 初始化 gt bbox 和候选点的匹配矩阵为全 0,表示全部不匹配
  2. 每个 gt bbox 选择的候选预测点不超过 10 个
  3. 利用前面的匹配代价,给每个 gt bbox 计算动态 k
  4. 遍历每个 gt bbox,提取代价为前动态 k 个位置,表示匹配上
  5. 如果某个候选点匹配了多个 gt bbox,则选择代价最小的,保证每个候选点只能匹配一个 gt bbox
  6. 返回总共有多少候选点 num_fg、每个候选点匹配上的 gt bbox 信息 gt_matched_classes、每个候选点匹配上的 和 gt bbox 计算的 IoU 值、匹配上的 gt bbox 索引、更新后的前景掩码 fg_mask,其长度和预测点个数相同,其中 1 表示正样本点,0 表示负样本点。

6) 计算 loss

分类分支和 objectness 分支采用 bce loss,bbox 预测分支采用 IoU Loss。

  1. 分类分支仅仅考虑正样本即 fg_mask 为 1 的位置,其 label 是同时考虑了预测值和 gt bbox 的 IoU 值,用于加强各分支间的一致性
  2. objectness 分支需要同时考虑正负样本,其起到抑制背景的作用,其 label 就是上述 fg_mask,非 0 即 1,即所有候选点
  3. bbox 分支也仅仅考虑正样本,其 label 就是正样本候选点所对应的解码后的预测值

7) 附加 L1 Loss

在最后 15 个 epoch 后,作者加入了额外的 L1 Loss。其作用的对象是原始没有解码的正样本 bbox 预测值,和对应的 gt bbox。

从以上分析可知: 分类分支不考虑背景,背景预测功能由 objectness 分支提供,而 bbox 分支联合采用了 IoU Loss 和 L1 Loss,其最大改进在于动态匹配。

(2) MMDetection 重构版本

Loss 写法的重构,主要改动是采用 MMDetection 中默认的构建方式即 prior_generator + bbox assign + bbox encode decode + loss。

YOLOX 是 anchor-free 算法,但是依然需要特征图上每个预测点的 point 信息,为此我们直接采用 RepPoints 中的 MlvlPointGenerator 类用于生成 anchor-point 坐标;bbox assign 部分则重写并新建了 SimOTAAssigner 类,用于对预测点进行分配正负样本;其他地方则完善了命名规范,和简化了部分写法,使其更加容易理解,整体逻辑没有改动。

2.2.5 其他训练 trick

除了上述核心部件,作者还引入了其他训练 trick,如下所示:

  1. 在最后 15 个 epoch 关闭 Mosaic 和 MixUp 增强,并且增加额外的 L1 loss
  2. 每隔一定间隔,改变输出图片尺寸,并且保证多卡之间的图片尺寸相同
  3. 每隔一定间隔,对 BN 参数进行多卡同步,保证评估时候不同卡的权重性能一致

(1) 关闭 Mosaic 和 MixUp 增强,增加额外的 L1 loss

def before_epoch(self): 

    if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs or self.no_aug: 
        logger.info("--->No mosaic aug now!") 
        self.train_loader.close_mosaic() 
        logger.info("--->Add additional L1 loss now!") 
         self.model.head.use_l1 = True  

MMDetection 是通过 hook 实现上述功能的

@HOOKS.register_module() 
class YOLOXModeSwitchHook(Hook): 

 def __init__(self, num_last_epochs=15): 
        self.num_last_epochs = num_last_epochs 

    def before_train_epoch(self, runner): 
 """Close mosaic and mixup augmentation and switches to use L1 loss.""" 
 epoch = runner.epoch 
        train_loader = runner.data_loader 
        model = runner.model 
        if is_parallel(model): 
            model = model.module 
        if (epoch + 1) == runner.max_epochs - self.num_last_epochs: 
            runner.logger.info('No mosaic and mixup aug now!') 
            train_loader.dataset.update_skip_type_keys( 
                ['Mosaic', 'RandomAffine', 'MixUp']) 
            runner.logger.info('Add additional L1 loss now!') 
            model.bbox_head.use_l1 = True  

核心就是 train_loader.dataset.update_skip_type_keys,将需要排除的 pipeline 对应的类名写入即可。 Loss 切换也是同理。

(2) 改变输出图片尺寸

# 在每次训练迭代后,判断 
if self.exp.random_size is not None and (self.progress_in_iter + 1) % 10 == 0: 
    self.input_size = self.exp.random_resize( 
        self.train_loader, self.epoch, self.rank, self.is_distributed 
    ) 

def random_resize(self, data_loader, epoch, rank, is_distributed): 
    tensor = torch.LongTensor(2).cuda() 

    if rank == 0: 
        # 随机采样一个新的 size 
        size_factor = self.input_size[1] * 1.0 / self.input_size[0] 
        size = random.randint(*self.random_size) 
        size = (int(32 * size), 32 * int(size * size_factor)) 
        tensor[0] = size[0] 
        tensor[1] = size[1] 

    if is_distributed: 
        dist.barrier() 
        dist.broadcast(tensor, 0) # 广播到其余卡 

    # 改变图片 size 
    input_size = data_loader.change_input_dim( 
        multiple=(tensor[0].item(), tensor[1].item()), random_range=None 
    ) 
    return input_size  

并且利用 broadcast 将 tensor 广播给其余卡,从而实现不同卡间的输入图片尺寸相同功能。

MMDetection 是通过 hook 实现上述功能的

@HOOKS.register_module() 
class SyncRandomSizeHook(Hook): 
 """Change and synchronize the random image size across ranks, currently 
    used in YOLOX. 

    Args: 
        ratio_range (tuple[int]): Random ratio range. It will be multiplied 
            by 32, and then change the dataset output image size. 
            Default: (14, 26). 
        img_scale (tuple[int]): Size of input image. Default: (640, 640). 
        interval (int): The interval of change image size. Default: 10. 
    """ 

 def __init__(self, 
                 ratio_range=(14, 26), 
                 img_scale=(640, 640), 
                 interval=10): 
        self.rank, world_size = get_dist_info() 
        self.is_distributed = world_size > 1 
        self.ratio_range = ratio_range 
        self.img_scale = img_scale 
        self.interval = interval 

    def after_train_iter(self, runner): 
 """Change the dataset output image size.""" 
 if self.ratio_range is not None and (runner.iter + 
                                             1) % self.interval == 0: 
            # 同步动态 resize  

核心是 runner.data_loader.dataset.update_dynamic_scale,该操作会实时改变 MultiImageMixDataset 中的 _dynamic_scale 属性,进而改变对应 pipeline 的 scale 字段值例如 Resize 和 MixUp,实现每个几个 iter 就改变 size 的多尺度训练功能。

(3) BN 参数进行多卡同步

在每次评估前,会同步下不同卡的 BN 参数,保证不同卡间参数的一致性,其中字典对象的同步过程参考了 detectron2 中的代码实现。

def all_reduce_norm(module): 
 """ 
    All reduce norm statistics in different devices. 
    """ 
 states = get_async_norm_states(module) 
    states = all_reduce(states, op="mean") 
    module.load_state_dict(states, strict=False)  

MMDetection 是通过 hook 实现上述功能的,并且考虑多卡同步字典操作是通用的,为此将其封装到 get_norm_states(module) 中。

@HOOKS.register_module() 
class SyncNormHook(Hook): 

 def __init__(self, interval=1): 
        self.interval = interval 

    def after_train_epoch(self, runner): 
 """Synchronizing norm.""" 
 epoch = runner.epoch 
        module = runner.model 
        if (epoch + 1) % self.interval == 0: 
            _, world_size = get_dist_info() 
            if world_size == 1: 
                return 
            norm_states = get_norm_states(module) 
            norm_states = all_reduce_dict(norm_states, op='mean') 
            module.load_state_dict(norm_states, strict=False)  

(4) Hook 优先级注意事项

通过前面分析,可以发现 YOLOX 重构插入了 5 个 hook,此时就需要特别考虑下 hook 优先级,特别需要注意的是 EMAHook 优先级一定要高于 evalhook 和保存权重操作,其余几个 hook 优先级只要比 EMAHook 高就行,其配置如下所示:

custom_hooks = [ 
    dict(type='YOLOXModeSwitchHook', num_last_epochs=15, priority=48), 
    dict( 
        type='SyncRandomSizeHook', 
        ratio_range=(14, 26), 
        img_scale=img_scale, 
        interval=interval, 
        priority=48), 
    dict(type='SyncNormHook', interval=interval, priority=48), 
    dict(type='ExpMomentumEMAHook', resume_from=resume_from, priority=49) 
]  

evalhook 和保存权重操作的优先级是 50,优先级值越小优先级越高,为此我们将 ExpMomentumEMAHook 优先级调整为 49,其余三个 hook,优先级设置为 48 就行,这三个 hook 执行顺序没有特别要求,优先级相同则会依据插入顺序执行。

2.3 重构

在训练精度对齐后需要将上述代码重构,使其在兼容 MMDetection 规范前提下代码可读性提升。重构部分也是并行进行,社区人员和维护者负责不同的部分,并且经过 review 后合并。

重构部分的核心是模型、后处理、loss、dataset 四个大模块,其中 loss 和 dataset 以及其他模块重构已经在 2.2 小结已经分析过了,故本小结只分析模型和后处理本身的重构。

(1) 模型相关重构

YOLOX 系列模型典型结构是 CSPDarknet+SPPBottleneck+PAFPN+Head,网络设计参考了 YOLOV5,Head 部分进行了改进:

采用了解耦 Head,其完整结构如下所示:

因为 YOLOX 模型是参考 YOLOV5 的,我们在重构前期讨论了两种模型构建方式,分别是:

1) 参考 ResNet 写法,arch_settings 按 stage 写,逐个 stage build

arch_settings = { 
    'P5': [[64, 128, 3, True, False], 
           [128, 256, 9, True, False], 
           [256, 512, 9, True, False], 
           [512, 1024, 3, False, True]], 

    'P6': [[64, 128, 3, True, False], 
           [128, 256, 9, True, False], 
           [256, 512, 9, True, False], 
           [512, 768, 3, True, False], 
           [768, 1024, 3, False, True]] 
}  

2) 参考 YOLOV5 写法,arch_settings 按 YOLOv5 config 的风格写,逐个 layer build

backbone: 
  # [from, number, module, args] 
  [ [ -1, 1, Focus, [ 64, 3 ] ],  # 0-P1/2 
    [ -1, 1, Conv, [ 128, 3, 2 ] ],  # 1-P2/4 
    [ -1, 3, C3, [ 128 ] ], 
    [ -1, 1, Conv, [ 256, 3, 2 ] ],  # 3-P3/8 
    [ -1, 9, C3, [ 256 ] ], 
    [ -1, 1, Conv, [ 512, 3, 2 ] ],  # 5-P4/16 
    [ -1, 9, C3, [ 512 ] ], 
    [ -1, 1, Conv, [ 768, 3, 2 ] ],  # 7-P5/32 
    [ -1, 3, C3, [ 768 ] ], 
    [ -1, 1, Conv, [ 1024, 3, 2 ] ],  # 9-P6/64 
    [ -1, 1, SPP, [ 1024, [ 3, 5, 7 ] ] ], 
    [ -1, 3, C3, [ 1024, False ] ],  # 11 
  ]  

第一种写法优点是清晰易懂,缺点是灵活度较低,第二种写法正好相反,灵活度较高,但是代码难以理解。权衡之下,并结合 MMDetection 风格,我们最终采用了第一种方案。

如上图所示,SPPBottleneck 包含在了 CSPDarknet 中了,backbone 输出 3 个分支,stride 分别是 [8, 16, 32],然后接一个标准的 PAFPN 模块,也是输出 3 个分支,通道数都是 256,最后接 3 个不共享权重的 Head 模块,分别输出类别、bbox 和 objectness 预测信息。

有个细节:在重构模型结构后重新训练后发现,初始时候 Loss 相比原版代码有较大差距,经过排查发现是由于卷积参数初始化导致的。MMDetection 在构建模型时候都会采用 ConvModule 来搭建 Conv+BN+Act 结构,但是该模块内部会修改默认的卷积参数初始化过程。为了能够和源码完全对齐,我们需要在初始化后重置所有卷积层的参数初始化过程。

幸好有 init_cfg,这个参数可以通过配置文件来修改任意位置的参数初始化过程,而不需要改动任何代码。例如要实现上述功能,则只需要在构建模型时候,传入特定的 init_cfg 就行,例如 backbone 中要重置所有卷积层的参数初始化

@BACKBONES.register_module() 
class CSPDarknet(BaseModule): 

    def __init__(self, 
                 arch='P5', 
                 ... 
 init_cfg=dict( 
                     type='Kaiming', 
                     layer='Conv2d', 
                     a=math.sqrt(5), 
                     distribution='uniform', 
                     mode='fan_in', 
                     nonlinearity='leaky_relu')):  

由于 init_cfg 功能非常强大,后续我们会推出关于 init_cfg 的专题解读和用法。

(2) 后处理相关重构

YOLOX 的后处理策略非常简单,只需要 conf_thr 和 nms 阈值就可以。

  • 基于输出特征图宽高和 stride,生成 grid,然后基于 grid 对预测框进行解码还原到原图尺度
def decode_outputs(self, outputs, dtype): 
    grids = [] 
    strides = [] 
    for (hsize, wsize), stride in zip(self.hw, self.strides): 
        yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)]) 
        grid = torch.stack((xv, yv), 2).view(1, -1, 2) 
        grids.append(grid) 
        shape = grid.shape[:2] 
        strides.append(torch.full((*shape, 1), stride)) 
    grids = torch.cat(grids, dim=1).type(dtype) 
    strides = torch.cat(strides, dim=1).type(dtype) 
    # 解码还原 
    outputs[..., :2] = (outputs[..., :2] + grids) * strides 
    outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides 

    return outputs 
  • 利用 conf_thre 过滤 bbox
# 提取最大分值对应的类别和预测分值 
class_conf, class_pred = torch.max( 
    image_pred[:, 5 : 5 + num_classes], 1, keepdim=True 
) 
# obj 分值和 class_conf 相乘,然后利用 conf_thre 进行过滤 
conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze() 

# 拼接 
detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1) 
detections = detections[conf_mask]  
  • NMS 后处理
nms_out_index = torchvision.ops.batched_nms( 
    detections[:, :4], 
    detections[:, 4] * detections[:, 5], 
    detections[:, 6], 
    nms_thre, 
) 
detections = detections[nms_out_index]  

可以发现后处理策略非常简单,没有 nms_pre 参数、没有两次过滤准则、没有 max_num_pre_img 参数

MMDetection 后处理重构整体没有改变(因为已经很简单了),只不过我们统一用 prior_generator 来生成 grid,其他地方统一了参数命名方式和代码风格。

3 总结

本文详细分析了如何重头复现一个新算法,高效的复现过程离不开 MMDetection 开发团队和社区小伙伴们的不懈努力,再次表示感谢!我们也希望在这样的合作开发模式中,大家在快速理解算法本身的同时,社区小伙伴们也可以进一步的理解了 MMDetection 设计理念、代码规范和开发要求等等,共同成长和进步。

我们希望后续能够进一步推广这种开发模式,让更多的社区用户参与进来,共同打造最好用的目标检测框架,打造更好用的 OpenMMLab 开源库。


【极市开学礼,社区踩楼直接送】小米 11、海信电视(43英寸)、服务器、手环 5、200 京东卡(数张)、键鼠套装、漫步者头戴蓝牙耳机等 已准备就绪(下拉 加载更多查看):https://bbs.cvmart.net/articles/5369

  • 2
  • 0
  • 3355
收藏
暂无评论
shijie

华南理工

  • 18

    关注
  • 51

    获赞
  • 2

    精选文章
近期动态
  • 目标检测
文章专栏
  • shijie的专栏