TensorFlow 2's default Eager Execution mode gives us flexibility and easy debugging. However, when we need more speed and higher performance, we still want to use Graph Execution, the default mode in TensorFlow 1.X. For this, TensorFlow 2 provides `tf.function`. Combined with the AutoGraph mechanism, it lets us run a model in graph execution mode simply by adding the `@tf.function` decorator.

### Implementation

Just wrap the code we want to run in graph execution mode in a function and add `@tf.function` before the function definition.

    import time

    import numpy as np
    import tensorflow as tf
    from matplotlib import pyplot as plt
    from tensorflow import keras

    np.random.seed(42)      # Set the NumPy random seed
    tf.random.set_seed(42)  # Set the TensorFlow random seed

    # Generate training data: y = x^2 + 1 + random noise
    x = np.linspace(-1, 1, 100)
    x = x.astype('float32')
    y = x * x + 1 + np.random.rand(100) * 0.1
    x_train = np.expand_dims(x, 1)  # Expand 1-D data to 2-D
    y_train = np.expand_dims(y, 1)  # Expand 1-D data to 2-D
    plt.plot(x, y, '.')  # Plot the training data


    def create_model():
        inputs = keras.Input((1,))
        x = keras.layers.Dense(10, activation='relu')(inputs)
        outputs = keras.layers.Dense(1)(x)
        model = keras.Model(inputs=inputs, outputs=outputs)
        return model


    model = create_model()                     # Create the model
    loss_fn = keras.losses.MeanSquaredError()  # Define the loss function
    optimizer = keras.optimizers.SGD()         # Define the optimizer
    # TensorBoard writer; the 'logs' directory name is an assumption
    summary_writer = tf.summary.create_file_writer('logs')


    @tf.function  # Run the training step in graph execution mode
    def train(step):
        with tf.GradientTape() as tape:
            y_pred = model(x_train, training=True)  # Forward pass; do not forget training=True
            loss = loss_fn(y_train, y_pred)         # Compute the loss
        with summary_writer.as_default():
            # step is passed in as a tensor; a Python global would be frozen at trace time
            tf.summary.scalar("loss", loss, step=step)  # Write the loss to TensorBoard
        grads = tape.gradient(loss, model.trainable_variables)            # Compute gradients
        optimizer.apply_gradients(zip(grads, model.trainable_variables))  # Update the weights
        return loss


    epochs = 1000
    begin_time = time.time()  # Training start time
    for epoch in range(epochs):
        loss = train(tf.constant(epoch + 1, dtype=tf.int64))
        print('epoch:', epoch + 1, '\t', 'loss:', loss.numpy())  # Print training information
    end_time = time.time()  # Training end time
    print("Training duration:", end_time - begin_time)

    # Predict
    y_pre = model.predict(x_train)

    # Plot the predicted values
    plt.plot(x, y_pre.squeeze())
    plt.show()

In this experiment, training with `@tf.function` was many times faster than training without it.

When a function decorated with `@tf.function` is executed, a computation graph is generated, and the operations inside the function become nodes of that graph. The next time the same function is called with parameters of the same type (for tensors, the same dtype and shape), the existing computation graph is used directly. If the function is different or the parameter types are different, a new computation graph is generated.
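The reuse rule above can be observed with a small sketch. The `trace_count` counter and the `square` function are hypothetical names introduced only for illustration. The counter is a Python side effect, so it increases only when TensorFlow traces a new graph, not when an existing graph is reused.

```python
import tensorflow as tf

trace_count = 0  # hypothetical counter, incremented only while tracing


@tf.function
def square(x):
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing, not on graph reuse
    return x * x


square(tf.constant([1.0, 2.0]))  # first call: traces a graph for float32, shape (2,)
square(tf.constant([3.0, 4.0]))  # same dtype and shape: the graph is reused
print(trace_count)               # 1

square(tf.constant([1, 2]))      # int32 input: a new graph is traced
print(trace_count)               # 2
```

Only the dtype and shape of the tensor arguments matter here; the values do not.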

### Notes

It is recommended to use only TensorFlow's native operations inside the function and to avoid overly complex Python statements. It is also best to pass only TensorFlow tensors or NumPy arrays as function parameters.

- Only TensorFlow's native operations produce nodes in the computation graph. (For example, Python's native `print()` function does not generate a node, while TensorFlow's `tf.print()` does.)
- For functions that take TensorFlow tensors or NumPy arrays as parameters, the previous computation graph can be reused as long as the type is the same. For native Python values, such as the integer 1 or the floating-point number 1.5, the parameter value must be exactly the same for the previous computation graph to be reused; otherwise, a new computation graph is created.
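Both points can be checked with a short sketch. The `add_one` function and the `trace_log` list are hypothetical names used only for illustration. The list records what kind of argument was seen each time a new graph was traced.

```python
import tensorflow as tf

trace_log = []  # hypothetical log, appended to only while tracing


@tf.function
def add_one(x):
    # Python statements run only when a new graph is traced
    trace_log.append("python" if isinstance(x, int) else "tensor")
    print("tracing with", x)     # Python print: no graph node, runs at trace time only
    tf.print("running with", x)  # tf.print: a graph node, runs on every call
    return x + 1


# Native Python values: each distinct value triggers a new trace
add_one(1)
add_one(2)
add_one(1)  # same value as the first call: the graph is reused

# Tensors: same dtype and shape share one graph regardless of value
add_one(tf.constant(1))
add_one(tf.constant(2))
add_one(tf.constant(3))

print(trace_log)  # ['python', 'python', 'tensor']
```

Six calls produce only three traces: two for the distinct Python integers and one shared by all three tensor calls.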

In addition, generally speaking, `@tf.function` brings a large improvement when the model consists of many small operations. When the model has few operations but each one is time-consuming, the performance gain from `@tf.function` will not be very large.
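One way to check this on your own machine is a rough timing sketch like the one below. The two workloads (`many_small_ops` and `one_big_op`), the tensor sizes, and the repetition count are all assumptions chosen for illustration, and the measured times will vary with hardware and TensorFlow version.

```python
import timeit

import tensorflow as tf


def many_small_ops(x):
    # 200 cheap element-wise steps; per-operation overhead matters most here
    for _ in range(200):
        x = x * 1.0001 + 0.0001
    return x


def one_big_op(a):
    # A single, relatively expensive matrix multiplication
    return tf.matmul(a, a)


# Graph versions of the same Python functions
many_small_ops_graph = tf.function(many_small_ops)
one_big_op_graph = tf.function(one_big_op)

x = tf.ones((10,))
a = tf.random.normal((512, 512))

# Warm up once so that tracing time is not included in the measurement
many_small_ops_graph(x)
one_big_op_graph(a)

for name, eager_fn, graph_fn, arg in [
    ("many small ops", many_small_ops, many_small_ops_graph, x),
    ("one big op", one_big_op, one_big_op_graph, a),
]:
    # .numpy() forces the result to be computed before the timer stops
    t_eager = timeit.timeit(lambda: eager_fn(arg).numpy(), number=20)
    t_graph = timeit.timeit(lambda: graph_fn(arg).numpy(), number=20)
    print(f"{name}: eager {t_eager:.4f}s, graph {t_graph:.4f}s")
```

Comparing the eager-to-graph ratio of the two workloads shows how much each one benefits from graph execution.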