Rethinking “Batch” in BatchNorm

共 5371字,需浏览 11分钟

 ·

2021-05-22 18:10

这篇很有趣且很有用,激动的赶紧把文章看了一遍,不愧是FAIR,实验看的太爽了。

之前对于Norm的研究主要在于改变Norm的维度,然后衍生出了BatchNorm、GroupNorm、InstanceNorm和LayerNorm等方法,但是除了BN外的其他Norm含义是确定,而BN的batch却可以有多种采样方式,本文就是为了探讨BN的batch使用不同的采样方式会有什么影响堪称BatchNorm圣经,建议全文背诵(ps:GN也是吴育昕的作品)。


本文总共4大核心实验,每个核心实验有多个子结论。


01

Motivation


BatchNorm现在已经广泛的应用于CNN中。但是BN针对不同的场景使用时有许多细微的差异,如果选择不当会降低模型的性能。BatchNorm相对于其他算子来说,主要的不同在于BN是对batch数据进行操作的。BN在batch数据中进行统计量计算,而其他算子一般都是独立处理单个样本的。因此影响BN的输出不仅仅取决于单个样本的性质,还取决于batch的采样方式

如图所示,左右各举例了三种batch采样方式。其中左图三种batch采样方式分别为entire dataset、mini-batches和subset of mini-batches,右图三种batch采样方式分别为entire domain、each domain和mixture of each domain。

本文实验证明了使用BN时不考虑batch的采样方式会在许多方面产生负面影响,合理使用batch采样方式会改善模型性能。


02

A Review of BatchNorm


简单回顾一下BN的计算形式,这里以CNN中的BN为例。假设BN的输入feature维度为  ,逐通道统计量mean和std为  ,那么BN的输出y为:

  

假设batch的大小为N,mini-batch X的维度为  ,那么mini-batch 统计量为  ,定义为:

  

推理的时候,  用训练集计算得到的统计量,定义为  。


不同于默认的BN设置,因为batch采样方式主要影响的是统计量mean和std,本文将mean和std看成是一个逐通道分开计算的仿射变换(可以等价为一个1x1的depth-wise layer)。


03

Whole Population as aBatch


BN中统计量的计算默认使用EMA方法,但是作者实验发现EMA会导致模型性能次优,然后提出了PreciseBN方法,近似将整个训练集统计量作为一个batch。

Inaccuracy of EMA

EMA是指数滑动平均的缩写,为了统计  ,EMA在训练过程中对统计量进行更新:

  

EMA方法导致次优解的原因有两点:

1.当  太大时,统计量收敛速度变慢。

2.当  太小时,最近几个mini-batches影响更大,统计量无法表示整个训练集的统计量。

Towards Precise Population Statistics

为了得到整个训练集更加精确的统计量,PreciseBN采用了两点小技巧:

1.将相同模型用于多个mini-batches来收集batch统计量

2.将多个batch收集的统计量聚合成一个population统计量

比如有N个样本需要通过数量为的Bmini-batch进行PreciseBN统计量计算,那么需要计算  次,统计量聚合公式为:

  

  

相比于EMA,PreciseBN有两点重要的属性:

1.PreciseBN的统计量是通过相同模型计算得到的,而EMA是通过多个历史模型计算得到的。

2.PreciseBN的所有样本的权重是相同的,而EMA不同样本的权重是不同的。

100 samples of batch mean意思是相同epoch下模型对100个随机batch统计量的结果。如图所示,在训练早期EMA的统计量不精确,会导致最终模型性能次优。由于滑动平均的计算方式导致EMA的统计量滞后于PrciseBN。


4个主要结论:

1.推理时使用PreciseBN会更加稳定。

2.大batch训练对EMA影响更大。

3.PreciseBN只需要10^3~10^4个样本可以得到近似最优。

4.小batch会产生统计量积累错误。



04

Batch in Training and Testing


BN在训练和测试中行为不一致:训练时,BN的统计量来自mini-batch;测试时,BN的统计量来自population。这部分主要探讨了BN行为不一致对模型性能的影响,并且提出消除不一致的方法提升模型性能。

Effect of Normalization Batch Size

为了避免混淆,将SGD batch size或者total batch size定义为所有GPU上总的batch size大小,将normalization batch   size定义为单个GPU上的batch size大小。

normalization batch size对training noise和train-test inconsistency有着直接影响:使用更大的batch,mini-batch统计量越接近population统计量,从而降低training noise和train-test inconsistency。


以下实验的SGD batch size固定使用1024大小。

为了便于分析,作者观察了3种不同评估方法的错误率:

1.在训练集上对mini-batch统计量进行评估

2.在验证集上对mini-batch统计量进行评估

3.在验证集上对population统计量进行评估

Training noise:当normalization batch size非常小时,单个样本会受到同一个min-batch样本的严重影响,导致训练精度较差,优化困难。

Generalization gap:随着normalization batch size的增加,mini-batch的验证集和训练集的之间的泛化误差会增大,这可能是由于training noise和train-test inconsistency没有正则化。

Train-test inconsistency:在小batch下,mini-batch统计量和population统计量的不一致是影响性能的主要因素。当normalization batch size增大时,细微的不一致可以提供正则化效果减少验证误差。在mini-batch为32~128之间时,正则化达到平衡,模型性能最优。


为了保持train和test的BN统计量一致,作者提出了两种方法来解决不一致问题,一种是推理的时候使用mini-batch统计量,另一种是训练的时候使用population batch统计量。

Use Mini-batch in Inference

作者在Mask R-CNN上进行实验,mini-batch的结果超过了population的结果,证明了在推理中使用mini-batch可以有效的缓解训练测试不一致。(ps:不使用norm效果略差,使用GN效果更好)

Use Population Batch in Training

为了在训练阶段使用population统计量,作者采用FrozenBN的方法,FrozenBN使用population统计量。具体地,作者先选择第80个epoch模型,然后将所有BN替换成FrozenBN,然后训练20个epoch。

FrozenBN可以有效缓解训练测试不一致,即使在小normalization batch size,也能达到比较好的性能。但是随着normalization batch size增大,作者提出的两种缓解不一致的方法都不如常规BN的结果。


05

Batch from Different Domains


BN的训练过程可以看成是两个独立的阶段:第一个阶段是通过SGD学习features,第二个阶段是由这些features得到population统计量。两个阶段分别称为SGD training和population statistics training。

由于BN多了一个population统计阶段,导致训练和测试之间的domain shift。当数据来自多个doman时,SGD training、population statistics training和testing三个步骤的domain gap都会对泛化性造成影响。

实验主要探究了两种使用场景:第一种,模型在一个domain上进行训练,然后在其他domain上进行测试;第二种,模型在多个domain上进行训练。

Domain to Compute Population Statistics

作者实验发现,当存在显著的domain shift时,模型使用评估domain的population统计量会得到更好的结果,可以缓解训练测试的不一致。

BatchNorm in Multi-Domain Training

为了对多个domain的情况进行实验,作者将RetinaNet head中的BN统计量进行实验设计。RetinaNet的head是5个feature层共享的,这意味着会接收来自5个不同分布或者domain的输入进行训练。

左图的训练形式非常简单,head独立作用于不同的feature层,都有自己独立的统计量。右图将所有输入特征flatten然后concat在一起,统一进行统计量计算。两种不同计算统计量的方式称为domain-specific statistics和shared statistics。

最终实验表明,SGD training、population statistics training和testing保持一致是非常重要的,并且全部使用domain-specific能取得最好的效果。(ps:不使用norm效果略差,使用GN效果更好)


06

Information Leakage within a Batch


BN在使用中还存在一种information leakage现象,因为BN是对mini-batch的样本计算统计量的,导致在样本进行独立预测时,会利用mini-batch内其他样本的统计信息。

Exploit Patterns in Mini-batches

作者实验发现,当使用random采样的mini-batch统计量时,验证误差会增加,当使用population统计量时,验证误差会随着epoch的增加逐渐增大,验证了BN信息泄露问题的存在。

为了处理信息泄露问题,之前常见的作法是使用SyncBN,来弱化mini-batch内样本之间的相关性。另一种解决方法是在进入head之前在GPU之间随机打乱RoI features,这给每个GPU分配了一个随机的样本子集来进行归一化,同时也削弱了min-batch样本之间的相关性,如上图所示。

实验结果表明,shuffling和SyncBN都能有效地处理信息泄漏,使得head在测试时能够很好地泛化。在速度方面,我们注意到shuffling需要更少的跨gpu同步,但是shuffling每次传输的数据比SyncBN多。因此,shuffling和SyncBN的相对效率跟具体模型架构相关。

Cheating in Contrastive Learning

在对比学习和度量学习时,训练目标通常是在mini-batch下进行比较的,这种情况下BN也会造成信息泄露,导致模型在训练期间作弊,之前的研究提出了很多不同方法来针对性解决对比学习和度量学习的信息泄露问题。



07

总结


本文从多个角度探讨了BN的batch使用不同的采样方式会有什么影响,并且做了非常详尽的对比试验,堪称BatchNorm圣经,建议全文背诵。

另外,看完后最大的感触是,BN不会用就别用,GN yyds。


Reference

[1] Rethinking “Batch” in BatchNorm

长按扫描下方二维码加入交流群,群里博士大佬云集,每日讨论话题有目标检测、语义分割、超分辨率、模型部署、数学基础知识、算法面试题分享的等等内容,当然也少不了搬砖人的扯犊子

长按扫描下方二维码添加小助手。

可以一起讨论遇到的问题

声明:转载请说明出处

扫描下方二维码关注【集智书童】公众号,获取更多实践项目源码和论文解读,非常期待你我的相遇,让我们以梦为马,砥砺前行!

浏览 22
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报