【数据竞赛】图像赛排行榜拉开100名差距的技巧

机器学习初学者

共 2177字,需浏览 5分钟

 ·

2021-02-06 16:37

作者:  尘沙风尘


Kaggle图像赛上分技巧之TTA: Test Time Aug


  • 1  TTA(Test Time Aug)

    • 1.1  简介

    • 1.2  案例(keras)

      • 1.2.1  导入适合当前问题的预测器(ClassPredictor用于分类,SegPredictor用于分割)

      • 1.2.2  用配置和所需的任何参数实例化类

      • 1.2.3  对图片进行预测

      • 1.2.4  实验结果

    • 1.3  小结

  • 2  参考文献

简介


我们都知道对我们的训练数据进行翻转,平移,缩放等扩充的操作往往可以获得一个训练更好的网络模型,这些扩充操作往往可以帮助我们的模型更好的挖掘到那些对于位置,光照等信息不敏感的信息,从而具有更好的泛化性,得到更好的预测结果。那既然训练集的数据可以扩充,测试集呢?

Bingo!没错,测试集也是可以采取类似的操作。而这种操作我们称之为TTA(Test Time Augmentation),顾名思义就是在测试的阶段对数据进行扩充。

TTA是一个非常通用的Trick,目前几乎绝大多数图像相关的竞赛都会使用到,而且基本是99%都能带来线上排行榜的提升。那么究竟是怎么做的呢?其实很简单:

就是在模型测试时,对原始的测试图像进行各种策略的扩充,例如:

  • 图像裁剪;
  • 图像缩放;
  • 图像旋转;
  • 图像平移;
  • ...

然后我们将预测的结果进行某种程度的融合,最常见的就是取平均值,然后将该分数作为最终的预测分数。

TTA操作较早出现在2015年ICLR的论文"Very Deep Convolutional Networks for Large-Scale Image Recognition"

We also augment the test set by horizontal flipping of the images; the soft-max class posteriors of the original and flipped images are averaged to obtain the final scores for the image.


案例(基于Keras)


以kaggle Dogs VS Cats为例, edafa (TTA package)


1. 导入适合当前问题的预测器(ClassPredictor用于分类,SegPredictor用于分割)

from edafa import ClassPredictor 

2.继承预测器类并实现主函数:predict_patches(self,patches)

class myPredictor(ClassPredictor):
    def __init__(self,model,*args,**kwargs):
        super().__init__(*args,**kwargs)
        self.model = model

    def predict_patches(self,patches):
        return self.model.predict(patches)

3. 用配置和所需的任何参数实例化类

conf = '{"augs":["NO",\
                "FLIP_LR"],\
        "mean":"ARITH"}'

4. 对图片进行预测

p = myPredictor(model,conf)
y_pred_aug = p.predict_images(X_val)
y_pred_aug = [(y[0]>=0.5).astype(np.uint8) for y in y_pred_aug ]
print('Accuracy with TTA:',np.mean((y_val==y_pred_aug)))

5. 实验结果

  • 使用TTA:Accuracy with TTA: 0.7892
  • 不适用TTA:Accuracy without TTA: 0.7852571428571429
小结


TTA技术目前是各大图像相关的数据竞赛的必备技能之一,它能为最终的成绩带来非常大的帮助,也是目前图像赛的必备技能之一,赶紧收藏吧!

参考文献


  1. Kaggle小技巧:TTA(test time augmentation)测试时加强:https://www.shangmayuan.com/a/0e4942dc496047bb95c5806c.html
  2. https://github.com/qubvel/ttach
  3. https://www.kaggle.com/andrewkh/test-time-augmentation-tta-worth-it


往期精彩回顾





本站知识星球“黄博的机器学习圈子”(92416895)

本站qq群704220115。

加入微信群请扫码:

浏览 46
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报