Pytorch | Tensor张量
机器学习与生成对抗网络
共 3423字,需浏览 7分钟
·
2021-09-02 11:32
点击上方“机器学习与生成对抗网络”,关注星标
获取有趣、好玩的前沿干货!
01
02
import torch
#torch.where
a = torch.rand(4, 4)
b = torch.rand(4, 4)
print(a)
print(b)
out = torch.where(a > 0.5, a, b)
print(out)
print("torch.index_select")
a = torch.rand(4, 4)
print(a)
out = torch.index_select(a, dim=0,
index=torch.tensor([0, 3, 2]))
#dim=0按列,index取的是行
print(out, out.shape)
print("torch.gather")
a = torch.linspace(1, 16, 16).view(4, 4)
print(a)
out = torch.gather(a, dim=0,
index=torch.tensor([[0, 1, 1, 1],
[ ],
[ ]]))
print(out)
print(out.shape)
print("torch.masked_index")
a = torch.linspace(1, 16, 16).view(4, 4)
mask = torch.gt(a, 8)
print(a)
print(mask)
out = torch.masked_select(a, mask)
print(out)
print("torch.take")
a = torch.linspace(1, 16, 16).view(4, 4)
b = torch.take(a, index=torch.tensor([0, 15, 13, 10]))
print(b)
#torch.nonzero
print("torch.take")
a = torch.tensor([[0, 1, 2, 0], [2, 3, 0, 1]])
out = torch.nonzero(a)
print(out)
#稀疏表示
03
print("torch.stack")
a = torch.linspace(1, 6, 6).view(2, 3)
b = torch.linspace(7, 12, 6).view(2, 3)
print(a, b)
out = torch.stack((a, b), dim=2)
print(out)
print(out.shape)
print(out[:, :, 0])
print(out[:, :, 1])
04
05
import torch
a = torch.rand(2, 3)
print(a)
out = torch.reshape(a, (3, 2))
print(out)
print(a)
print(torch.flip(a, dims=[2, 1]))
print(a)
print(a.shape)
out = torch.rot90(a, -1, dims=[0, 2]) #顺时针旋转90°
print(out)
print(out.shape)
06
Tensor的填充操作
07
猜您喜欢:
CVPR 2021 | GAN的说话人驱动、3D人脸论文汇总
附下载 |《TensorFlow 2.0 深度学习算法实战》
评论