用 SHAP 可视化解释机器学习模型实用指南

机器学习与数学

共 5938字,需浏览 12分钟

 ·

2021-10-13 15:15

大家好,我是云朵君!

导读: SHAP 是 Python 开发的一个"模型解释"包,是一种博弈论方法来解释任何机器学习模型的输出。本文重点介绍 11 种 shap 可视化图形来解释任何机器学习模型的使用方法。具体理论并不在本次内容内,需要了解模型理论的小伙伴,可参见文末参考文献。本文因篇幅限制,分为上下两篇,本篇介绍 shap 可视化特征重要性及特征效果。

👆点击关注|设为星标|干货速递👆


24b634480f6dcc1fd87e24392b49e59a.webp

SHAP(Shapley Additive exPlanations) 使用来自博弈论及其相关扩展的经典 Shapley value将最佳信用分配与局部解释联系起来,是一种基于游戏理论上最优的 Shapley value来解释个体预测的方法。

从博弈论的角度,把数据集中的每一个特征变量当成一个玩家,用该数据集去训练模型得到预测的结果,可以看成众多玩家合作完成一个项目的收益。Shapley value 通过考虑各个玩家做出的贡献,来公平的分配合作的收益。

数据集

标准的 UCI 成人收入数据集。

import shap
X,y = shap.datasets.adult()
X_display, y_display = shap.datasets.adult(display=True)
76e92d9bc3d33b3049ccadb1bc5b8a4b.webp

创建 Explainer 并计算 SHAP 值

在SHAP中进行模型解释需要先创建一个 explainer,SHAP 支持很多类型的explainer(例如 deep, gradient, kernel, linear, tree, sampling),本文使用支持常用的XGB、LGB、CatBoost 等树集成算法的 tree 为例。

  • deep:用于计算深度学习模型,基于DeepLIFT算法
  • gradient:用于深度学习模型,综合了SHAP、集成梯度、和SmoothGrad等思想,形成单一期望值方程
  • kernel:模型无关,适用于任何模型
  • linear:适用于特征独立不相关的线性模型
  • tree:适用于树模型和基于树模型的集成算法
  • sampling :基于特征独立性假设,当你想使用的后台数据集很大时,kenel的一个很好的替代方案
explainer = shap.TreeExplainer(model)  

然后计算shap_values值,计算非常简单,直接利用上面得到的解释器解释训练样本X,这里有两种形式:

输出 numpy.array 数组

shap_values = explainer.shap_values(X) 
bfc667f0ada7dc5a21727358da2df1ce.webp

输出 shap.Explanation 对象

shap_values2 = explainer(X) 
711ad46f29f6321ea946cc8d50d746fb.webp

模型自带特征重要性

关于模型解释性,除了线性模型和决策树这种天生就有很好解释性的模型以外,sklean/ xgboost 中有很多模型都有 importance 这一接口,可以查看特征的重要性。

model = xgboost.XGBClassifier(eval_metric='mlogloss').fit(X, y)
xgboost.plot_importance(model,height = .5
                        max_num_features=10,
                        show_values = False)
b7071be2544b7b962630cf4bbdecef07.webp

SHAP 特征重要性

Summary Plot

将 SHAP 值矩阵传递给条形图函数会创建一个全局特征重要性图,其中每个特征的全局重要性被视为该特征在所有给定样本中的平均绝对值。

shap.summary_plot(shap_values, X_display, 
                  plot_type="bar")
794a3b115d3384afd449835494084984.webp

在上面两图中,可以看到由 SHAP value 计算的特征重要性与使用 scikit-learn / xgboost计算的特征重要性之间的比较,它们看起来非常相似,但它们并不相同。

Bar plot

全局条形图

特征重要性的条形图还有另一种绘制方法。

shap.plots.bar(shap_values2)
7cced5b808f0f9857f6e983da4efed45.webp

同一个shap_values,不同的计算

summary_plot中的shap_values是numpy.array数组
plots.bar中的shap_values是shap.Explanation对象

当然shap.plots.bar()还可以按照需求修改参数,绘制不同的条形图。如通过max_display参数进行控制条形图最多显示条形树数。

局部条形图

将一行 SHAP 值传递给条形图函数会创建一个局部特征重要性图,其中条形是每个特征的 SHAP 值。其中特征值是否显示,是通过参数show_data控制,默认 'auto' 特征值以灰色显示在特征名称的左侧。

shap.plots.bar(shap_values2[1], show_data=True)
04771b15f30d8e5a7d5d43fa8422dab5.webp

队列条形图

传递解释对象的字典将为解释对象表示的每个群组创建一个多条形图,其中包含一个条形类型。下面我们使用它来分别绘制男性和女性特征重要性的全局摘要。

sex = ["Women" if shap_values2[i,"Sex"].data == 0 
       else "Men" for i in range(shap_values2.shape[0])]
shap.plots.bar(shap_values2.cohorts(sex).abs.mean(0))
3cbe51c75d3ac5eb35ff4726200d37a1.webp

队列条形图还有另一个比较有意思的绘图,他使用 Explanation 对象的自动群组功能来使用决策树创建一个群组。调用Explanation.cohorts(N)将创建 N 个队列,使用 sklearn DecisionTreeRegressor 最佳地分离实例的 SHAP 值。

例如将其用于成人人口普查数据,则看到低资本收益与高资本收益之间的明显区别。括号中的数字是每个队列中的实例数。

v = shap_values2.cohorts(2).abs.mean(0)
shap.plots.bar(v)
480cf0b24add949964dfc1731e5cd2a3.webp

使用特征聚类

很多时候数据集中的特征存在冗余。这意味着模型可以使用任一特征并仍然获得相同的准确性。可以通过计算特征之间的相关矩阵,或使用聚类方法来找到这些特征。

在 SHAP 中通过模型损失比较来测量特征冗余。即使用shap.utils.hclust方法,并通过训练 XGBoost 模型来预测每对输入特征的结果来构建特征的层次聚类。与从无监督方法(如相关性)中获得的特征冗余相比。对典型的结构化数据集进行特征冗余度量,会更加准确。

计算聚类并传递给条形图,就可以同时可视化特征冗余结构和特征重要性。默认只会显示距离 < 0.5 的聚类部分。假设聚类中的距离大致在 0 和 1 之间缩放,其中 0 距离表示特征完全冗余,1 表示它们完全独立。

在下图中,我们看到只有关系和婚姻状况有超过 50% 的冗余,因此它们是条形图中分组的唯一特征:

clustering = shap.utils.hclust(X, y) 
shap.plots.bar(shap_values2, 
               clustering=clustering,
               clustering_cutoff=0.5)
696b3b981cc883230a73fab8a7303ef5.webp

Summary Plot

上面使用 Summary Plot 方法并设置参数plot_type="bar"绘制典型的特征重要性条形图,而他默认绘制 Summary_plot 图,他是结合了特征重要性和特征效果,取代了条形图。

Summary_plot 为每一个样本绘制其每个特征的 Shapley value,它说明哪些特征最重要,以及它们对数据集的影响范围。

y 轴上的位置由特征确定,x 轴上的位置由每 Shapley value 确定。颜色表示特征值(红色高,蓝色低),颜色使我们能够匹配特征值的变化如何影响风险的变化。重叠点在 y 轴方向抖动,因此我们可以了解每个特征的 Shapley value分布,并且这些特征是根据它们的重要性排序的。

shap.summary_plot(shap_values, X)
1fc3a9ee8a5ec3f04874e335cf3e6d94.webp

Beeswarm plot

同条形图一样 shap 也提供了另一个接口plots.beeswarm蜂群图。

蜂群图旨在显示数据集中的 TOP 特征如何影响模型输出的信息密集摘要。给定解释的每个实例由每个特征流上的一个点表示。点的 x 位置由该特征的 SHAP 值 ( shap_values.value[instance,feature]) 确定,并且点沿每个特征行“堆积”以显示密度。颜色用于显示特征的原始值 ( shap_values.data[instance,feature])。

在下图中,我们可以看到平均而言年龄是最重要的特征,与年轻(蓝色)人相比,收入超过 5 万美元的可能性较小。

8f81a29f35d991fa979b7d0c16c041ab.webp

同样可以使用max_display参数调整最多显示行数。

默认使用每个特征的 SHAP 值的平均绝对值shap_values.abs.mean(0) 对特征排序。然而,这个顺序更强调广泛的平均影响,而不是罕见但高强度的影响。如果我们想找到对个人影响较大的特征,可以按最大绝对值排序。

shap.plots.beeswarm(shap_values2, 
                    order=shap_values.abs.max(0))

另外,在绘图之前,就对 shap_values 取绝对值,得到与条形图类似的图形,但比条形图具有更丰富的平行线,因为条形图只是绘制蜂群图中点的平均值。

# 蜂群图
shap.plots.beeswarm(shap_values2.abs, 
                    color="shap_red")
# 条形图
shap.plots.bar(shap_values2.abs.mean(0))

还可以自定义颜色,默认使用shap.plots.colors.red_blue颜色图。

import matplotlib.pyplot as plt
shap.plots.beeswarm(shap_values, 
                    color=plt.get_cmap("cool"))

在 Summary_plot 图中,首先看到了特征值与对预测的影响之间关系的迹象,但是要查看这种关系的确切形式,还必须查看 SHAP Dependence Plot 图。

Dependence Plot

SHAP Partial dependence plot (PDP or PD plot) 依赖图显示了一个或两个特征对机器学习模型的预测结果的边际效应,它可以显示目标和特征之间的关系是线性的、单调的还是更复杂的。他们在许多样本中绘制了一个特征的值与该特征的 SHAP 值。

PDP 是一种全局方法:该方法考虑所有实例并给出关于特征与预测结果的全局关系。PDP 的一个假设是第一个特征与第二个特征不相关。如果违反此假设,则 PDP 计算的平均值将包括极不可能甚至不可能的数据点。

为了显示哪个特征可能会驱动这些交互效应,可以通过第二个特征为我们的年龄依赖性散点图着色(默认第二个特征是自动选择的,尝试挑选出与 Age 交互作用最强的特征列)。也可以通过参数interaction_index设置交互项。如果另一个特征与正在绘制的特征之间存在交互作用,它将显示为不同的垂直着色模式。

shap.dependence_plot('Age', shap_values, X, 
                     display_features=X_display,
                     interaction_index='Capital Gain')
5b3c489dc598314294cc7ea53fbd738b.webp

Dependence plot 是一个散点图,显示单个特征对整个数据集的影响。

  • 每个点都是来自数据集的单个预测(行)。
  • x 轴是数据集中的实际值。(来自 X 矩阵,存储在 中shap_values.data)。
  • y 轴是该特征的 SHAP 值(存储在 中shap_values.values),它表示该特征值对该预测的模型输出的改变程度。

Scatter plot

同样,散点图绘图依赖图,这与上面 dependence_plot 绘制基本一样。

在显示方面有些许不同,plots scatter 图底部的浅灰色区域是显示数据值分布的直方图。

在交互颜色方面。dependence_plot 默认而散点图则需要将整个 Explanation 对象传递给 color 参数。

另外,有时候在输入模型之前是字符串,为输入到模型,需要将其设置为分类编码,此时绘图,并不能很直观地显示内容。此时可以将.display_data Explanation 对象的属性设置为我们希望在图中显示的原始数据类型。

shap_values2.display_data = X_display.values
shap.plots.scatter(shap_values2[:, "Age"], 
                   color=shap_values2[:,"Workclass"])
d7d694e1bdf3695684d97856b9ec3a54.webp

使用全局特征重要性排序

在只想绘制最重要的特征,却不知道其特征名或索引,此时可以使用 Explanation 对象的点链功能来计算全局特征重要性的度量,按该度量(降序)排序,然后挑选出顶部特征。

# 平均绝对均值的特征
ind_mean = shap_values2.abs.mean(0).argsort[-1]
# 平均绝对值最大的特征
ind_max = shap_values.abs.max(0).argsort[-1]
# 95% 绝对值对特征进行排序
ind_perc = shap_values.abs.percentile(950).argsort[-1]
shap.plots.scatter(shap_values2[:, ind_mean])

另外还可以自定义图形属性,详情可参加官方文档。敬请期待下篇。

参考文章  
[1] https://shap.readthedocs.io/en/latest/index.html
[2] https://www.bilibili.com/read/cv11622011

浏览 288
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报