tf.unstack comes up when building recurrent neural networks (RNNs), so I am noting how it works here for easy reference.
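What it does: tf.unstack splits the input tensor value along the specified axis (dimensions are numbered from 0) and returns a Python list of num tensors. Each element has rank one less than value, because the specified dimension is removed. num must equal the size of the specified dimension. It can also be omitted, in which case TensorFlow infers it from the static shape of value.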
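To make the role of num concrete, here is a minimal NumPy sketch of the same splitting behaviour. It is an illustration under assumptions, not TensorFlow's implementation. The helper name unstack_np is hypothetical, and its check on num mimics the rule described above rather than TensorFlow's actual error messages.

```python
import numpy as np


def unstack_np(value, num=None, axis=0):
    """Hypothetical NumPy stand-in for tf.unstack (illustration only).

    Splits value along axis into a list of arrays, each with that axis
    removed. If num is given it must equal value.shape[axis]; if it is
    omitted, the length is simply taken from the shape.
    """
    value = np.asarray(value)
    size = value.shape[axis]  # negative axis wraps around, as in NumPy
    if num is not None and num != size:
        raise ValueError(f"num={num} does not match size {size} of axis {axis}")
    # Moving axis to the front lets us iterate over it; each item drops that axis.
    return list(np.moveaxis(value, axis, 0))


x = np.arange(24).reshape(2, 3, 4)
parts = unstack_np(x, axis=1)      # num omitted, so the length comes from the shape
print(len(parts), parts[0].shape)  # 3 (2, 4)

try:
    unstack_np(x, num=2, axis=1)   # wrong num for a size-3 axis
except ValueError as err:
    print(err)                     # num=2 does not match size 3 of axis 1
```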
For example, if value.shape is (2, 3, 4):
If axis=0, num must be 2. The result is a list of 2 tensors, each of shape (3, 4).
If axis=1, num must be 3. The result is a list of 3 tensors, each of shape (2, 4).
If axis=2, num must be 4. The result is a list of 4 tensors, each of shape (2, 3).
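The shape arithmetic behind these three cases can be written as a small pure-Python helper. Both the name unstack_shapes and the helper itself are hypothetical and only compute shapes. The wrap-around for negative axis values assumes the convention tf.unstack documents, where the valid range is [-R, R) for a rank-R input.

```python
def unstack_shapes(shape, axis=0):
    """Return (list length, element shape) for unstacking `shape` along `axis`.

    Hypothetical helper for illustration; it only does shape arithmetic.
    """
    rank = len(shape)
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    axis %= rank  # wrap negative axes, e.g. -1 -> rank - 1
    element_shape = tuple(shape[:axis]) + tuple(shape[axis + 1:])
    return shape[axis], element_shape


for axis in (0, 1, 2):
    print(axis, unstack_shapes((2, 3, 4), axis))
# 0 (2, (3, 4))
# 1 (3, (2, 4))
# 2 (4, (2, 3))
```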
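```python
import numpy as np
import tensorflow as tf

# TensorFlow 1.x graph-mode example. Under TensorFlow 2.x, use tf.compat.v1.Session
# and call tf.compat.v1.disable_eager_execution() before building the graph.

X = tf.constant(np.arange(24).reshape(2, 3, 4))  # shape (2, 3, 4), values 0..23

X0 = tf.unstack(X, num=2, axis=0)  # 2 tensors of shape (3, 4)
X1 = tf.unstack(X, num=3, axis=1)  # 3 tensors of shape (2, 4)
X2 = tf.unstack(X, num=4, axis=2)  # 4 tensors of shape (2, 3)

with tf.Session() as sess:
    results = sess.run([X, X0, X1, X2])
    # Print each symbolic tensor (or list of tensors), then its evaluated value.
    for t, x in zip([X, X0, X1, X2], results):
        print(t, '\n', x, '\n')
```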
The output, reformatted by hand for readability, is as follows. The four results appear in order: X, X0, X1 and X2.
```
Tensor("Const:0", shape=(2, 3, 4), dtype=int64)
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]

[<tf.Tensor 'unstack:0' shape=(3, 4) dtype=int64>,
 <tf.Tensor 'unstack:1' shape=(3, 4) dtype=int64>]
[array([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]]),
 array([[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]])]

[<tf.Tensor 'unstack_1:0' shape=(2, 4) dtype=int64>,
 <tf.Tensor 'unstack_1:1' shape=(2, 4) dtype=int64>,
 <tf.Tensor 'unstack_1:2' shape=(2, 4) dtype=int64>]
[array([[ 0,  1,  2,  3],
        [12, 13, 14, 15]]),
 array([[ 4,  5,  6,  7],
        [16, 17, 18, 19]]),
 array([[ 8,  9, 10, 11],
        [20, 21, 22, 23]])]

[<tf.Tensor 'unstack_2:0' shape=(2, 3) dtype=int64>,
 <tf.Tensor 'unstack_2:1' shape=(2, 3) dtype=int64>,
 <tf.Tensor 'unstack_2:2' shape=(2, 3) dtype=int64>,
 <tf.Tensor 'unstack_2:3' shape=(2, 3) dtype=int64>]
[array([[ 0,  4,  8],
        [12, 16, 20]]),
 array([[ 1,  5,  9],
        [13, 17, 21]]),
 array([[ 2,  6, 10],
        [14, 18, 22]]),
 array([[ 3,  7, 11],
        [15, 19, 23]])]
```
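One way to read these values is that element i of X1 is the slice X[:, i, :] and element i of X2 is X[:, :, i]. TensorFlow documents tf.stack as the opposite of tf.unstack, so stacking the pieces back along the same axis should recover X. The sketch below checks both points with NumPy. It uses np.take and np.stack as stand-ins for tf.unstack and tf.stack on the same 0..23 data and does not call TensorFlow.

```python
import numpy as np

X = np.arange(24).reshape(2, 3, 4)  # same values as the constant in the example

for axis in range(X.ndim):
    # NumPy analogue of tf.unstack(X, axis=axis): one slice per index on axis.
    pieces = [np.take(X, i, axis=axis) for i in range(X.shape[axis])]
    # np.stack along the same axis undoes the split (tf.stack plays this role in TF).
    assert np.array_equal(np.stack(pieces, axis=axis), X)

# The first elements of X1 and X2 in the printed output are X[:, 0, :] and X[:, :, 0].
assert np.array_equal(X[:, 0, :], [[0, 1, 2, 3], [12, 13, 14, 15]])
assert np.array_equal(X[:, :, 0], [[0, 4, 8], [12, 16, 20]])
print("round trip OK")
```

This split pattern is likely why tf.unstack shows up in RNN code. A batch shaped (batch, time_steps, features) unstacked along axis=1 gives a list of time_steps tensors, each of shape (batch, features). That list-of-inputs format is what TensorFlow 1.x's static_rnn (tf.nn.static_rnn, later tf.compat.v1.nn.static_rnn) takes.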