对numpy模块中的axis的理解

Python算法之旅

共 5080字,需浏览 11分钟

 ·

2022-01-12 18:47

说在前面

关于axis参数的理解一直是个困扰初学者的地方。有时候axis=0代表按行操作,例如在ny.concatenate((a1, a2, ...), axis)函数中,axis=0就表示按行拼接;有时候axis=0代表按列操作,例如在np.sort(a, axis, kind, order)函数中,axis=0就表示按列排序
到底该如何去理解axis参数?是从行和列的角度去理解吗?又或者要从别的角度来看?本文就尝试来回答这个问题。

 一、axis是数组层级

1. 二维数组。如下列演示代码所示:

>>>import numpy as np>>> a =np.array([[1,2,3,7],[2,9,6,1],[3,8,5,4]])>>> aarray([[1, 2, 3,7],       [2, 9, 6, 1],       [3, 8, 5, 4]])>>>a.shape(3, 4)

二维数组a2个轴(axis),分别是axis=0axis=1。如下图所示,二维数组a的第0(axis=0)3个元素(左图),分别为a[0],a[1],a[2];第1(axis=1)4个元素(右图),其中第1个元素为[a[0][0],a[1][0],a[2][0]],第2个元素为[a[0][1],a[1][1],a[2][1]],以此类推。


2. 三维数组。如下列演示代码所示:

>>>import numpy as np>>> b =np.array([[[1,2,3,4],[1,3,4,5]],[[2,4,7,5],[8,4,3,5]],[[2,5,7,3],[1,5,3,7]]])>>> barray([[[1, 2,3, 4],        [1, 3, 4, 5]],        [[2, 4, 7, 5],        [8, 4, 3, 5]],        [[2, 5, 7, 3],        [1, 5, 3, 7]]])>>>b.shape(3, 2, 4)

我们可以把shape理解成数组在每个轴(axis)上的元素个数。三维数组b3个轴,分别是axis=0axis=1axis=2。如下图所示,我们可以把三维数组b想象成一个立方体结构,其中第0(axis=0)3个元素(黑框),分别为b[0],b[1],b[2],各占一层,共3层。

1(axis=1)2个元素(红框),表示每层都有2行,其中第1层第1行的元素为 b[0][0],第1层第2行的元素为 b[0][1],以此类推,第3层第2行的元素为 b[2][1]

2(axis=2)4个元素(绿框),表示每层都有4列,其中第1层第1列的元素为[b[0][0][0],b[0][1][0]],第1层第2列的元素为[b[0][0][1], b[0][1][1]],以此类推,第3层第4列的元素为[b[2][0][3],b[2][1][3]]


二、若axis=i,则沿着第i维的方向进行操作

3. 二维数组。如下列演示代码所示:

>>> aarray([[1, 2, 3, 7],       [2, 9, 6, 1],       [3, 8, 5, 4]])>>> np.max(a), np.max(a,axis=0), np.max(a, axis=1)(9, array([3, 9, 6, 7]), array([7,9, 8]))>>> np.sum(a), np.sum(a,axis=0), np.sum(a, axis=1)(51, array([ 6, 19, 14, 12]),array([13, 18, 20]))

由上述代码可知,程序沿着axis指定的轴进行相应的函数操作。如果不知道axis,则把数组展开成一维,然后再开始计算。

如上图所示,因为二维数组a34列,则当axis=0时,会将a按行分解,即分成3个元素(每个元素占1行),np.max(a,axis=0)的意思是求这3个元素的最大值。因为每个元素都是长度为4的一维数组(共4列),故求最大值时是分别对每一列求最大值;最后将各列的最大值重新组合成一个长度为4的一维数组[3, 9, 6, 7]。即np.max(a, axis=0)的功能是生成一个新的行,其长度与原数组的每一行均相同,元素值依次为各列的最大值。

axis=1时,会将a按列分解,即分成4个元素(每个元素占1列),np.max(a,axis=1)的意思是求这4个元素的最大值。因为每个元素都是长度为3的一维数组(共3行),故求最大值时是分别对每一行求最大值;最后将各行的最大值重新组合成一个长度为3的一维数组[7, 9, 8]。即np.max(a, axis=1)的功能是生成一个新的列,其长度与原数组的每一列均相同,元素值依次为各行的最大值。

同理,当调用函数np.sum()时,遵循相同规律。当axis=0时,生成一个新的行,其长度与原数组的每一行均相同,元素值依次为各列元素之和,即一维数组[ 6, 19, 14, 12];当axis=1时,生成一个新的列,其长度与原数组的每一列均相同,元素值依次为各行元素之和,即一维数组[13, 18, 20]

我们可以将该规律继续推广到min()mean()argmin()等函数中。


4. 三维数组。如下列演示代码所示:

>>> barray([[[1, 2,3, 4],        [1, 3, 4, 5]],        [[2, 4, 7, 5],        [8, 4, 3, 5]],        [[2, 5, 7, 3],        [1, 5, 3, 7]]])>>>np.max(b), np.max(b, axis=0), np.max(b, axis=1), np.max(b, axis=2)(8, array([[2,5, 7, 5],       [8, 5, 4, 7]]),array([[1, 3, 4,5],       [8, 4, 7, 5],       [2, 5, 7, 7]]),array([[4, 5],       [7, 8],       [7, 7]]))>>> np.sum(b),np.sum(b, axis=0), np.sum(b, axis=1), np.sum(b, axis=2)(94, array([[ 5,11, 17, 12],       [10, 12, 10, 17]]), array([[ 2, 5,  7,  9],       [10, 8, 10, 10],       [ 3, 10, 10, 10]]),array([[10, 13],       [18, 20],       [17, 16]]))

原理与二维数组相同,不做过多解释。可以结合代码自行理解。

 

三、axis参数的应用

5. np.sort()函数,返回输入数组的排序副本。函数格式如下:

numpy.sort(a,axis, kind, order)

参数说明:

a: 要排序的数组。

axis: 沿着它排序数组的轴,如果没有,数组会被展开,沿着最后的轴排序(即二维数组默认axis=1,三维数组默认axis=2,以此类推);对于二维数组,axis=0时按列排序,axis=1时按行排序。

kind: 默认为'quicksort'(快速排序)。

order: 如果数组包含字段名称,则是要排序的字段。

如下列演示代码所示:

>>>import numpy as np>>> a =np.array([[1,2,3,7],[2,9,6,1],[3,8,5,4]])>>> aarray([[1, 2, 3,7],       [2, 9, 6, 1],       [3, 8, 5, 4]])>>>np.sort(a)array([[1, 2, 3,7],       [1, 2, 6, 9],       [3, 4, 5, 8]])>>> aarray([[1, 2, 3,7],       [2, 9, 6, 1],       [3, 8, 5, 4]])>>>np.sort(a, axis=0)array([[1, 2, 3,1],       [2, 8, 5, 4],       [3, 9, 6, 7]])>>>np.sort(a, axis=1)array([[1, 2, 3,7],       [1, 2, 6, 9],       [3, 4, 5, 8]])

由如上演示代码可知,对于二维数组a,调用函数np.sort(a),相当于np.sort(a,axis=1),即默认axis=1;其功能为将二维数组a分成n列,分别对这n列的第i行排序,简称按行排序。

同理,np.sort(a, axis=0)的功能为将二维数组a分成m行,分别对这m行的第i列排序,简称按列排序。

 

6. np.concatenate()函数,用于沿指定轴连接相同形状的两个或多个数组,格式如下:

numpy.concatenate((a1,a2, ...), axis)

参数说明:

a1, a2, ...:相同类型的数组

axis:沿着它连接数组的轴,默认为0,即按行拼接。

如下列演示代码所示:

>>> a =np.array([[1,2],[3,4]])>>> b =np.array([[5,6],[7,8]])>>> a,b(array([[1, 2],       [3, 4]]),array([[5, 6],       [7, 8]]))>>>np.concatenate((a,b),axis=0)array([[1, 2],       [3, 4],       [5, 6],       [7, 8]])>>>np.concatenate((a,b),axis=1)array([[1, 2, 5,6],       [3, 4, 7, 8]])

由如上演示代码可知,对于二维数组a,调用函数np.concatenate((a,b),axis=0),其功能为将二维数组a分成m行,将二维数组b分成n行,然后依次把这些行拼接起来,组合成一个共(m+n)行的新数组。简称按行拼接。

同理,np.concatenate((a,b),axis=1),其功能为将二维数组a分成m列,将二维数组b分成n列,然后依次把这些列拼接起来,组合成一个共(m+n)列的新数组。简称按列拼接。

 

7. 为水果打分。现有四个同学分别对桔子、苹果、西瓜这三种水果打分,根据喜爱程度打1-10分。先将每位同学对三种水果的打分存储在一个一维数组中,再将4个同学的数据组合成一个二维数组a。演示代码如下:

>>> a =np.array([[2,5,8],[1,4,7],[2,3,6],[5,9,8]])>>> aarray([[2, 5,8],       [1, 4, 7],       [2, 3, 6],       [5, 9, 8]])

如果我们想看看哪个同学最喜欢吃水果,就把他对三种水果的打分求和,代码如下:

>>>a.sum(axis=1)array([15, 12,11, 22])

可以看出第四位同学最喜欢吃水果。

如果我们想看看哪种水果最受欢迎,就把四个同学给它的打分求和,代码如下:

>>>a.sum(axis=0)array([10, 21,29])

可以看出西瓜最受欢迎。


需要本文源代码和word文稿的,可以加入“Python算法之旅”知识星球参与讨论和下载文件,Python算法之旅”知识星球汇集了数量众多的同好,更多有趣的话题在这里讨论,更多有用的资料在这里分享。

我们专注Python算法,感兴趣就一起来!



相关优秀文章:

阅读代码和写更好的代码

最有效的学习方式

利用pandas模块处理学生成绩

利用pandas模块处理百家姓数据

整体操作numpy数组的方法

用apply()函数对DataFrame对象进行批量操作


浏览 59
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报