tf.TensorArray :TensorFlow 动态数组 *

在部分网络结构,尤其是涉及到时间序列的结构中,我们可能需要将一系列张量以数组的方式依次存放起来,以供进一步处理。当然,在Eager Execution下,你可以直接使用一个Python列表(List)存放数组。不过,如果你需要基于计算图的特性(例如使用 @tf.function 加速模型运行或者使用SavedModel导出模型),就无法使用这种方式了。因此,TensorFlow提供了 tf.TensorArray ,一种支持计算图特性的TensorFlow动态数组。

由于需要支持计算图, tf.TensorArray 的使用方式和一般编程语言中的列表/数组类型略有不同,包括4个方法:

  • TODO

一个简单的示例如下:

  1. import tensorflow as tf
  2.  
  3. @tf.function
  4. def array_write_and_read():
  5. arr = tf.TensorArray(dtype=tf.float32, size=3)
  6. arr = arr.write(0, tf.constant(0.0))
  7. arr = arr.write(1, tf.constant(1.0))
  8. arr = arr.write(2, tf.constant(2.0))
  9. arr_0 = arr.read(0)
  10. arr_1 = arr.read(1)
  11. arr_2 = arr.read(2)
  12. return arr_0, arr_1, arr_2
  13.  
  14. a, b, c = array_write_and_read()
  15. print(a, b, c)

输出:

  1. tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32) tf.Tensor(2.0, shape=(), dtype=float32)