【论文解读】NN如何在表格数据中战胜GBDT类模型!

共 3468字,需浏览 7分钟

 ·

2021-02-12 04:11

作者:一元,四品炼丹师


TabNet: Attentive Interpretable Tabular Learning(ArXiv2020)


01

背景

本文提出了一种高性能、可解释的规范深度表格数据学习结构TabNet。号称吊锤XGBoost和LightGBM等GBDT类模型。来吧,开学!

TabNet使用sequential的attention来选择在每个决策步骤中要推理的特征,使得学习被用于最显著的特征,从而实现可解释性和更有效的学习。我们证明了TabNet在广泛的非性能饱和表格数据集上优于其他变体,并产生了可解释的特征属性和对其全局行为的洞察。

最后,我们展示了表格数据的自监督学习,在未标记数据丰富的情况下显著提高了效果。

1. 决策树类模型在诸多的表格型问题中仍然具有非常大的优势:

  • 对于表格型数据中常见的具有近似超平面边界的决策流形,它们是表示有效的;
  • 它们的基本形式具有高度的可解释性(例如,通过跟踪决策节点),并且对于它们的集成形式有流行的事后可解释性方法;
  • 训练非常快;

2. DNN的优势:

  • 有效地编码多种数据类型,如图像和表格数据;
  • 减轻特征工程的需要,这是目前基于树的表格数据学习方法的一个关键方面;
  • 从流式数据中学习;
  • 端到端模型的表示学习,这使得许多有价值的应用场景能够实现,包括数据高效的域适配;

3. TabNet:

  • TabNet无需任何预处理即可输入原始表格数据,并使用基于梯度下降的优化方法进行训练,实现了端到端学习的灵活集成。

  • TabNet使用sequential attention来选择在每个决策步骤中从哪些特征中推理,从而实现可解释性和更好的学习,因为学习能力用于最显著的特征。这种特征选择是基于实例的,例如,对于每个输入,它可以是不同的,并且与其他基于实例的特征选择方法不同,TabNet采用了一种深度特征选择和推理的学习体系结构。

  • TabNet在不同领域的分类和回归问题的不同数据集上优于或等同于其他表格学习模型;

  • TabNet有两种可解释性:局部可解释性,用于可视化特征的重要性及其组合方式;全局可解释性,用于量化每个特征对训练模型的贡献。

  • 最后,对于表格数据,我们首次通过使用无监督预训练来预测掩蔽特征,得到了显著的性能提升;

02

TabNet


类似于DTs的DNN building blocks


 
  • 使用从数据中学习的稀疏实例特征选择;
  • 构造一个连续的多步骤体系结构,其中每个步骤有助于基于所选特征的决策的一部分;
  • 通过对所选特征的非线性处理来提高学习能力;
  • 通过更高的维度和更多的步骤来模拟融合。

TabNET的框架


 

我们使用所有的原始数值特征并且将类别特征转化为可以训练的embedding,我们并不考虑全局特征normalization。

在每一轮我们将D维度的特征传入,其中是batch size, TabNet的编码是基于序列化的多步处理, 有个决策过程。在第步我们输入第步的处理信息来决定使用哪些特征,并且输出处理过的特征表示来集成到整体的决策。


特征选择


我们使用可学习的mask, 用于显著特征的soft选择,通过最多的显著特征的稀疏选择,决策步的学习能力在不相关的上面不被浪费,从而使模型更具参数效率。masking是可乘的,,此处我们使用attentive transformer来获得使用在前面步骤中处理过的特征的masks,.

Sparsemax规范化通过将欧几里得投影映射到概率simplex上鼓励稀疏性,观察到概率simplex在性能上更优越,并与稀疏特征选择的目标一致,以便于解释。注意: , 是一个可以训练的函数。

是先验的scale项,表示一个特殊的特征之前被使用的多少,,其中是缩放参数。

  • 的时候,特征只会在第一个决策步被使用,当变大的时候, 更多的灵活性会在多个决策步被使用, 被初始化为1,,如果某个特征是没什么用处的,那么对应的就是0。

为了控制选择特征的稀疏性,此处加入sparsity的正则来控制数值稳定性,

其中对于数值稳定性是一个很小的书,我们再最终的loss上加入稀疏的正则,对应的参数为.

特征处理


我们使用一个特征transformer来处理过滤的特征,然后拆分决策步骤输出和后续步骤信息,,其中, ,对于具有高容量的参数有效且鲁棒的学习,特征变换器应该包括在所有决策步骤之间共享的层(因为在不同的决策步骤之间输入相同的特征)以及决策步骤相关的层。上图展示了作为两个共享层和两个决策步骤相关层的级联的实现。

每个FC层后面是BN和gated线性单元(GLU)非线性,最终通过归一化连接到归一化残差连接。此处我们通过的正则来保证网络的方差以稳定学习。

为了快速的训练,此处我们使用带有BN的大的batch size,因此,除了应用到输入特征的,我们使用ghost BN形式,使用一个virtual batchsize 和momentum ,对于输入特征,我们观测到low-variance平均的好处,因此可以避免ghost BN,最终我们通过decision-tree形式的aggregation,我们构建整体的决策embedding, ,再使用线性mapping, 得到最终的输出。


解释性


此处我们可以使用特征选择的mask来捕捉在每一步的选择的特征,如果:

  • ,那么第个样本的第个特征对于我们的决策是没有任何帮助的;

如果是一个线性函数,的稀疏应该对应的二者重要性,尽管每次决策步使用一个非线性处理,他们的输出是以线性的方式组合,我们的目的是量化一个总体特征的重要性,除了分析每一步。组合不同步骤的Mask需要一个系数来衡量决策中每个步骤的相对重要性,我们提出:

  • 来表示在第步决策步对于第个样本的累计决策贡献。

直觉上,如果,那么在第个决策步的所有特征就应当对整体的决策没有任何帮助。当它的值增长的时候,它在整体线性的组合上会更为重要,在每次决策步的时候对决策mask进行缩放,,我们对特征重要性mask进行特征的集成, .


表格自监督学习


我们提出了一个解码器架构来从TabNet编码的表示中重建表格特征。解码器由特征变换器组成,每个判决步骤后面是FC层。将输出相加得到重构特征。我们提出了一个从其他特征列中预测缺失特征列的任务。考虑一个二进制掩码,

  • TabNet的encoder输入;
  • decoder输入重构特征, ;

我们在编码器中初始化, 这么做模型只重点关注已知的特征,解码器的最后一层FC层和进行相乘输出未知的特征,我们考虑在自监督阶段的重构损失,

使用真实值的标准偏差进行Normalization是有帮助的,因为特征可能有不同的ranges,我们在每次迭代时以概率从伯努利分布中独立采样;


03

实验

1. 基于实例的特征选择

  • TabNet比所有其他的模型都要好;
  • TabNet的效果与全局特征选择非常接近,它可以找到哪些特征是全局最优的;
  • 删除冗余特征之后,TabNet提升了全局特征选择;

2. 现实数据集上的表现

  • TabNet在多个数据集上的效果都取得了最好的效果;

3. 自监督学习

  • 无监督预训练显著提高了有监督分类任务的性能,特别是在未标记数据集比标记数据集大得多的情况下;
  • 如上图所示,在无监督的预训练下,模型收敛更快。快速收敛有助于持续学习和领域适应.

04

小结

本文我们提出了TabNet,一种新的用于表格学习的深度学习体系结构。TabNet使用一种顺序attention机制来选择语义上有意义的特征子集,以便在每个决策步骤中进行处理。基于实例的特征选择能够有效地进行学习,因为模型容量被充分地用于最显著的特征,并且通过选择模板的可视化产生更具解释性的决策。我们证明了TabNet在不同领域的表格数据集上的性能优于以前的工作。最后,我们展示了无监督预训练对于快速适应和提高模型的效果。

05

参考文献


  1. TabNet: https://arxiv.org/pdf/1908.07442.pdf

往期精彩回顾





本站知识星球“黄博的机器学习圈子”(92416895)

本站qq群704220115。

加入微信群请扫码:

浏览 82
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报