Pytorch中的5个非常有用的张量操作

共 3614字,需浏览 8分钟

 ·

2021-06-03 17:42

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

本文转自:AI公园

导读

虽然也有其他方式可以实现相同的效果,但是这几个操作可以让使用更加方便。


PyTorch是一个基于Python的科学包,用于使用一种称为张量的特殊数据类型执行高级操作。张量是具有规则形状和相同数据类型的数字、向量、矩阵或多维数组。PyTorch是NumPy包的另一种选择,它可以在GPU下使用。它也被用作进行深度学习研究的框架。

这5个操作是:

  • expand()
  • permute()
  • tolist()
  • narrow()
  • where()
1. expand()


将现有张量沿着值为1的维度扩展到新的维度。张量可以同时沿着任意一维或多维展开。如果你不想沿着一个特定的维度展开张量,你可以设置它的参数值为-1。

注意:只能扩展单个维度

# Example 1 - working 
a=torch.tensor([[[1,2,3],[4,5,6]]])
a.size()
>>torch.Size([123])

a.expand(2,2,3)
>>tensor([[[123],
         [456]],

        [[123],
         [456]]])

在这个例子中,张量的原始维数是[1,2,3]。它被扩展到[2,2,3]。

2. permute()


这个函数返回一个张量的视图,原始张量的维数根据我们的选择而改变。例如,如果原来的维数是[1,2,3],我们可以将它改为[3,2,1]。该函数以所需的维数顺序作为参数。

# Example 1 - working
a=torch.tensor([[[1,2,3],[4,5,6]]])
a.size()
>>torch.Size([123])

a.permute(2,1,0).size()
>>torch.Size([321])

a.permute(2,1,0)
>>tensor([[[1],
         [4]],

        [[2],
         [5]],

        [[3],
         [6]]])

在这个例子中,原始张量的维度是[1,2,3]。使用permuting,我将顺序设置为(2,1,0),这意味着新的维度应该是[3,2,1]。如图所示,张量的新视图重新排列了数字,使得张量的维度为[3,2,1]

当我们想要对不同维数的张量进行重新排序,或者用不同阶数的矩阵进行矩阵乘法时,可以使用这个函数。

3. tolist()


这个函数以Python数字、列表或嵌套列表的形式返回张量。在此之后,我们可以对它执行任何python逻辑和操作。

# Example 1 - working
a=torch.tensor([[1,2,3],[4,5,6]])
a.tolist()
>> [[123], [456]]

在这个例子中,张量以嵌套列表的形式返回。

4. narrow()


这个函数返回一个新的张量,这个张量是原来张量的缩小版。这个函数的参数是输入张量、要缩小的维数、起始索引和新张量沿该维数的长度。它返回从索引start到索引(start+length-1)中的元素。

# Example 1 - working
a=torch.tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12],[14,15,16,17]])
torch.narrow(a,1,2,2)
>> tensor([[ 3,  4],
        [ 7,  8],
        [1112],
        [1617]])

在这个例子中,张量要沿着第2维,也就是最里面的维度缩小。它接受列表中的元素,从索引2开始,到索引3(=2+2 -1,即start+length-1)。

Narrow()的工作原理类似于高级索引。例如,在一个2D张量中,使用[:,0:5]选择列0到5中的所有行。同样的,可以使用torch.narrow(1,0,5)。然而,在高维张量中,对于每个维度都使用range操作是很麻烦的。使用narrow()可以更快更方便地实现这一点。

5. where()


这个函数返回一个新的张量,其值在每个索引处都根据给定条件改变。这个函数的参数有:条件,第一个张量和第二个张量。在每个张量的值上检查条件(在条件中使用),如果为真,就用第一个张量中相同位置的值代替,如果为假,就用第二个张量中相同位置的值代替。

# Example 1 - working
a=torch.tensor([[[1,2,3],[4,5,6]]]).to(torch.float32)
b=torch.zeros(1,2,3)
torch.where(a%2==0,b,a)
>>tensor([[[1.0.3.],
         [0.5.0.]]])

这里,它检查张量a的值是否是偶数。如果是,则用张量b中的值替换,b中的值都是0,否则还是和原来一样。

此函数可用于设定阈值。如果张量中的值大于或小于某一数值,它们可以很容易地被替换。


下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


浏览 34
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报