请谨慎使用预训练的深度学习模型

共 3839字,需浏览 8分钟

 ·

2022-06-01 11:00

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

导读

预训练模型用起来非常容易,但是你是否忽略了可能影响模型性能的细节呢?


你运行过多少次下面的代码:

  1. import torchvision.models as models

  2. inception = models.inception_v3(pretrained=True)

或者是这个

  1. from keras.applications.inception_v3 import InceptionV3

  2. base_model = InceptionV3(weights='imagenet', include_top=False)

看起来使用这些预训练的模型已经成为行业最佳实践的新标准。毕竟,有一个经过大量数据和计算训练的模型,你为什么不利用呢?

预训练模型万岁!
 

利用预训练的模型有几个重要的好处:

  • 合并超级简单

  • 快速实现稳定(相同或更好)的模型性能

  • 不需要太多的标签数据

  • 迁移学习、预测和特征提取的通用用例

NLP领域的进步也鼓励使用预训练的语言模型,如GPT和GPT-2、AllenNLP的ELMo、谷歌的BERT、Sebastian Ruder和Jeremy Howard的ULMFiT。

利用预训练模型的一种常见技术是特征提取,在此过程中检索由预训练模型生成的中间表示,并将这些表示用作新模型的输入。通常假定这些最终的全连接层得到的是信息与解决新任务相关的。

每个人都参与其中

每一个主流框架,如Tensorflow,Keras,PyTorch,MXNet等,都提供了预先训练好的模型,如Inception V3,ResNet,AlexNet等,带有权重:

  • Keras Applications

  • PyTorch torchvision.models

  • Tensorflow Official Models (and now TensorFlow Hubs)

  • MXNet Model Zoo

  • Fast.ai Applications

很简单,是不是?

但是,这些benchmarks可以复现吗?
 

这篇文章的灵感来自Curtis Northcutt,他是麻省理工学院计算机科学博士研究生。他的文章‘Towards Reproducibility: Benchmarking Keras and PyTorch’ 提出了几个有趣的观点

  1. resnet结构在PyTorch中执行得更好, inception结构在Keras中执行得更好

  2. 在Keras应用程序上不能复现Keras Applications上的已发布的基准测试,即使完全复制示例代码也是如此。事实上,他们报告的准确率(截至2019年2月)通常高于实际的准确率。

  3. 当部署在服务器上或与其他Keras模型按顺序运行时,一些预先训练好的Keras模型会产生不一致或较低的精度。

  4. 使用batch normalization的Keras模型可能不可靠。对于某些模型,前向传递计算(假定梯度为off)仍然会导致在推理时权重发生变化。

你可能会想:这怎么可能?这些不是相同的模型吗?如果在相同的条件下训练,它们不应该有相同的性能吗?

并不是只有你这么想,Curtis的文章也在Twitter上引发了一些反应:

关于这些差异的原因有一些有趣的见解:

了解(并信任)这些基准测试非常重要,因为它们允许你根据要使用的框架做出明智的决策,并且通常用作研究和实现的基线。

那么,当你利用这些预先训练好的模型时,需要注意什么呢?

使用预训练模型的注意事项

1、你的任务有多相似?你的数据有多相似?

对于你的新x射线数据集,你使用Keras Xception模型,你是不是期望0.945的验证精度?首先,你需要检查你的数据与模型所训练的原始数据集(在本例中为ImageNet)有多相似。你还需要知道特征是从何处(网络的底部、中部或顶部)迁移的,因为任务相似性会影响模型性能。

阅读CS231n — Transfer Learning and ‘How transferable are features in deep neural networks?’

2、你如何预处理数据?

你的模型的预处理应该与原始模型相同。几乎所有的torchvision模型都使用相同的预处理值。对于Keras模型,你应该始终为相应的模型级模块使用 preprocess_input函数。例如:

  1. # VGG16

  2. keras.applications.vgg16.preprocess_input

  3. # InceptionV3

  4. keras.applications.inception_v3.preprocess_input

  5. #ResNet50

  6. keras.applications.resnet50.preprocess_input

3、你的backend是什么?

有一些关于HackerNews的传言称,将Keras的后端从Tensorflow更改为CNTK (Microsoft Cognitive toolkit)提高了性能。由于Keras是一个模型级库,它不处理诸如张量积、卷积等较低级别的操作,所以它依赖于其他张量操作框架,比如TensorFlow后端和Theano后端。

Max Woolf提供了一个优秀的基准测试项目,发现CNTK和Tensorflow之间的准确性是相同的,但CNTK在LSTMs和多层感知(MLPs)方面更快,而Tensorflow在CNNs和embeddings方面更快。

Woolf的文章是2017年发表的,所以如果能得到一个更新的比较结果,其中还包括Theano和MXNet作为后端,那将是非常有趣的(尽管Theano现在已经被废弃了)。

还有一些人声称,Theano的某些版本可能会忽略你的种子。

4、你的硬件是什么?

你使用的是Amazon EC2 NVIDIA Tesla K80还是Google的NVIDIA Tesla P100?甚至可能是TPU?😜看看这些不同的pretrained模型的有用的基准参考资料。

  • Apache MXNet’s GluonNLP 0.6:Closing the Gap in Reproducible Research with BERT

  • Caleb Robinson’s ‘How to reproduce ImageNet validation results’ (and of course, again, Curtis’ benchmarking post)

  • DL Bench

  • Stanford DAWNBench

  • TensorFlow’s performance benchmarks

5、你的学习率是什么?

在实践中,你应该保持预训练的参数不变(即,使用预训练好的模型作为特征提取器),或者用一个相当小的学习率来调整它们,以便不忘记原始模型中的所有内容。

6、在使用batch normalization或dropout等优化时,特别是在训练模式和推理模式之间,有什么不同吗?

正如Curtis的帖子所说:

使用batch normalization的Keras模型可能不可靠。对于某些模型,前向传递计算(假定梯度为off)仍然会导致在推断时权重发生变化。

但是为什么会这样呢?

Expedia的首席数据科学家Vasilis Vryniotis首先发现了Keras中的冻结batch normalization层的问题。

Keras当前实现的问题是,当冻结批处理规范化(BN)层时,它在训练期间还是会继续使用mini-batch的统计信息。我认为当BN被冻结时,更好的方法是使用它在训练中学习到的移动平均值和方差。为什么?由于同样的原因,在冻结层时不应该更新mini-batch的统计数据:它可能导致较差的结果,因为下一层没有得到适当的训练。

Vasilis还引用了这样的例子,当Keras模型从训练模式切换到测试模式时,这种差异导致模型性能显著下降(从100%下降到50%)。

好消息! 

小白学视觉知识星球

开始面向外开放啦👇👇👇




下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


浏览 36
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报