【TensorFlow】笔记:基础知识-张量操作(三)
共 3240字,需浏览 7分钟
·
2021-01-26 16:34
TensorFlow 使用 张量 (Tensor)作为数据的基本单位。
今天学习张量操作的形状、DTypes、广播、conver_to_tensor
01
操作形状
首先,我们创建一个张量
var_x = tf.Variable(tf.constant([[1], [2], [3]]))
print(var_x.shape)
# output
(3, 1)
我们也可以将该对象转换为Python列表
print(var_x.shape.as_list())
# output
[3, 1]
通过重构可以改变张量的形状。重构的速度很快,资源消耗很低,因为不需要复制底层数据。
reshaped = tf.reshape(var_x, [1, 3])
print(var_x.shape)
print(reshaped.shape)
# output
(3, 1)
(1, 3)
数据在内存中的布局保持不变,同时使用请求的形状创建一个指向同一数据的新张量。TensorFlow 采用 C 样式的“行优先”内存访问顺序,即最右侧的索引值递增对应于内存中的单步位移。
tensor = tf.constant([
[ ],
[ ]],
[ ],
[ ]],
[ ],
[ ]],])
print(tensor)
tf.Tensor(
[ ]
[ ]]
[ ]
[ ]]
[ ]
[3, 2, 5), dtype=int32) ]]], shape=(
如果展平张量,则可以看到它在内存中的排列顺序。
print(tf.reshape(tensor, [-1]))
# output
tf.Tensor(
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29], shape=(30,), dtype=int32)
一般来说,tf.reshape
唯一合理的用途是用于合并或拆分相邻轴(或添加/移除 1
)。
对于 3x2x5 张量,重构为 (3x2)x5 或 3x(2x5) 都合理,因为切片不会混淆:
print(tf.reshape(tensor, [3*2, 5]), "\n")
print(tf.reshape(tensor, [3, -1]))
# output
tf.Tensor(
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]
[20 21 22 23 24]
[25 26 27 28 29]], shape=(6, 5), dtype=int32)
tf.Tensor(
[[ 0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]], shape=(3, 10), dtype=int32)
重构可以处理总元素个数相同的任何新形状,但是如果不遵从轴的顺序,则不会发挥任何作用。
02
DTypes详解
使用 Tensor.dtype
属性可以检查 tf.Tensor
的数据类型。
从 Python 对象创建 tf.Tensor
时,您可以选择指定数据类型。
如果不指定,TensorFlow 会选择一个可以表示您的数据的数据类型。TensorFlow 将 Python 整数转换为 tf.int32
,将 Python 浮点数转换为 tf.float32
。另外,当转换为数组时,TensorFlow 会采用与 NumPy 相同的规则。
数据类型可以相互转换。
the_f64_tensor = tf.constant([2.2, 3.3, 4.4], dtype=tf.float64)
the_f16_tensor = tf.cast(the_f64_tensor, dtype=tf.float16)
# 现在,让我们转换为uint8并失去小数精度
the_u8_tensor = tf.cast(the_f16_tensor, dtype=tf.uint8)
print(the_u8_tensor)
# output
tf.Tensor([2 3 4], shape=(3,), dtype=uint8)
03
广播
广播是从 NumPy 中的等效功能借用的一个概念。简而言之,在一定条件下,对一组张量执行组合运算时,为了适应大张量,会对小张量进行“扩展”。
最简单和最常见的例子是尝试将张量与标量相乘或相加。在这种情况下会对标量进行广播,使其变成与其他参数相同的形状。
x = tf.constant([1, 2, 3])
y = tf.constant(2)
z = tf.constant([2, 2, 2])
# 这些都是相同的运算
print(tf.multiply(x, 2))
print(x * y)
print(x * z)
# output
tf.Tensor([2 4 6], shape=(3,), dtype=int32)
tf.Tensor([2 4 6], shape=(3,), dtype=int32)
tf.Tensor([2 4 6], shape=(3,), dtype=int32)
同样,可以扩展大小为 1 的维度,使其符合其他参数。在同一个计算中可以同时扩展两个参数。
在本例中,一个 3x1 的矩阵与一个 1x4 进行元素级乘法运算,从而产生一个 3x4 的矩阵。注意前导 1 是可选的:y 的形状是 [4]
。
x = tf.reshape(x,[3,1])
y = tf.range(1, 5)
print(x, "\n")
print(y, "\n")
print(tf.multiply(x, y))
# output
tf.Tensor(
[[1]
[2]
[3]], shape=(3, 1), dtype=int32)
tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)
tf.Tensor(
[[ 1 2 3 4]
[ 2 4 6 8]
[ 3 6 9 12]], shape=(3, 4), dtype=int32)
广播相加:[3, 1] 乘以 [1, 4]的结果是[3, 4]
04
tf.convert_to_tensor
大部分运算(如 tf.matmul
和 tf.reshape
)会使用 tf.Tensor
类的参数。不过,在上面的示例中,您会发现我们经常传递形状类似于张量的 Python 对象。
大部分(但并非全部)运算会在非张量参数上调用 convert_to_tensor
。我们提供了一个转换注册表,大多数对象类(如 NumPy 的 ndarray
、TensorShape
、Python 列表和 tf.Variable
)都可以自动转换。
参考文献:文档主要参考TensorFlow官网
点击上方“蓝字”关注本公众号
点击上方“蓝字”关注本公众号
END
扫码关注
微信号|sdxx_rmbj