TensorFlow model saving, loading, and fine-tuning

1. Saving a TensorFlow model

1. Description of the saved files

A TensorFlow model consists mainly of the design of the network (the graph) and the values of the trained parameters. Accordingly, a saved TensorFlow model is made up of the following files:

1) graph.pbtxt: a text file that saves the structure of the model (the graph definition).

2) checkpoint: a small text file that stores the paths of the most recently saved checkpoints.

3) .ckpt-*.meta: serves the same purpose as graph.pbtxt above, saving the graph structure, but the meta file is binary.

4) .ckpt-*.index: a string table whose keys are tensor names and whose values are serialized BundleEntryProto messages. Each BundleEntryProto describes a tensor's metadata: which data file contains the tensor, its offset within that file, some auxiliary data, and so on.

5) model.ckpt-*.data-*: saves the values of all variables of the model (a TensorBundle collection).

6) events.out.tfevents.*: records values such as accuracy or loss at different steps; this is what TensorBoard reads.
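As an illustration of item 2) above, the checkpoint file is a small text protobuf listing the latest and all retained checkpoint paths; it might look like this (the paths and step numbers here are illustrative, not from the source):

```
model_checkpoint_path: "model.ckpt-1000"
all_model_checkpoint_paths: "model.ckpt-500"
all_model_checkpoint_paths: "model.ckpt-1000"
```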

2. Description of the saving code

To save the graph and the values of all parameters in TensorFlow, we create an instance of the tf.train.Saver() class.

If we don't pass any arguments to tf.train.Saver(), it saves all variables. If we don't want to save all variables but only some of them, we can specify the variables/collections to save by passing a list or dictionary of those variables when creating the tf.train.Saver instance.

# Save all variables
saver = tf.train.Saver()

# Save only the variables in a given scope
vgg_ref_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc')  # get the tensors in the specified scope
saver = tf.train.Saver(vgg_ref_vars)  # pass a var_list argument when constructing the Saver

# Save only selected variables
weight = [weights['wc1'], weights['wc2'], weights['wc3a']]
saver = tf.train.Saver(weight)  # create a Saver from a list of the dictionary's values

Arguments related to saving a model with save:

  • session is a Session object
  • "model_savedpath" is the "path + name" for your model
  • global_step=num appends the iteration number to the checkpoint name (for example, to mark a model saved after 1000 iterations: global_step=1000)
  • max_to_keep=m (passed to tf.train.Saver) keeps only the latest m checkpoints
  • keep_checkpoint_every_n_hours=n (passed to tf.train.Saver) additionally keeps one checkpoint for every n hours of training
  • write_meta_graph=False skips writing the network structure graph
saver.save(session, "model_savedpath", global_step=epoch)

Pay attention! Variables exist only within a Session environment; that is, variable values are only available inside a Session.
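Putting the pieces above together, the following is a minimal end-to-end save sketch, assuming TF 1.x-style graph execution via tf.compat.v1; the variable, directory, and checkpoint names are illustrative:

```python
import os
import tempfile

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

# A single variable stands in for a real model's parameters
w = tf.compat.v1.get_variable("w", shape=[2, 2],
                              initializer=tf.compat.v1.zeros_initializer())

# max_to_keep is a Saver constructor argument: keep only the 3 newest checkpoints
saver = tf.compat.v1.train.Saver(max_to_keep=3)

ckpt_dir = tempfile.mkdtemp()
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    # global_step=1000 is appended to the checkpoint name
    path = saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"), global_step=1000)

print(path)                        # ends with model.ckpt-1000
print(sorted(os.listdir(ckpt_dir)))  # checkpoint, .data-*, .index, .meta files
```

Running this produces the checkpoint, .index, .data-*, and .meta files described in section 1.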

2. Model loading

Description of loading the model and variables

1. Loading code always includes two parts: loading the network structure and loading the variable parameters.

(1) tf.train.import_meta_graph(path + "xxx.meta") loads the network structure.

(2) The restore(sess, path + "xxx/") method loads the variables. Here path + "xxx/" refers to the directory where the model was saved, and the most recently saved variable file is found automatically. The previously trained model parameters (weights, biases, etc.) are variable values and therefore depend on a Session, so construct a Session before loading them:

# Load the model structure
saver = tf.train.import_meta_graph(path + 'xxx/yyy.meta')
# Load the variable values; tf.train.latest_checkpoint() automatically finds the last saved checkpoint
# path + "xxx/" is the directory where the model was saved
saver.restore(sess, tf.train.latest_checkpoint(path + "xxx/"))

2. If you only want to read some of the variable values when loading:

  reader = tf.train.NewCheckpointReader(checkpoint_path)

(1) var = reader.get_variable_to_shape_map() returns a map of all variable names to their shapes.

(2) graph.get_tensor_by_name("variable name") retrieves the tensor corresponding to the saved variable name.

def read_checkpoint():
  checkpoint_path = 'path'
  reader = tf.train.NewCheckpointReader(checkpoint_path)
  var = reader.get_variable_to_shape_map()
  for key in var:
    if 'weights' in key and 'conv' in key and 'Mo' not in key:
      print('tensorname:', key)
      # print(reader.get_tensor(key))

op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

# Partial variable restoration
weight = [weights['wc1'], weights['wc2'], weights['wc3a']]
saver = tf.train.Saver(weight)  # create a Saver from a list of the dictionary's values
saver.restore(sess, model_filename)
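The function above needs a real checkpoint path to run. A self-contained sketch of NewCheckpointReader (assuming TF 1.x-style execution via tf.compat.v1; the variable name and directory are illustrative):

```python
import os
import tempfile

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

# Create and save one variable so there is a checkpoint to inspect
conv_w = tf.compat.v1.get_variable("conv1/weights", initializer=tf.constant([[3.0]]))
ckpt_dir = tempfile.mkdtemp()
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    ckpt = tf.compat.v1.train.Saver().save(sess, os.path.join(ckpt_dir, "model.ckpt"))

# Inspect the checkpoint without building a graph or Session
reader = tf.compat.v1.train.NewCheckpointReader(ckpt)
shapes = reader.get_variable_to_shape_map()  # {'conv1/weights': [1, 1]}
value = reader.get_tensor("conv1/weights")   # raw numpy value

print(shapes)
print(value)
```

Note that NewCheckpointReader reads tensor values directly from disk, so it works without any Session.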

Parameters that were not restored are uninitialized and must be initialized manually:

var = tf.get_variable(name, shape, initializer=tf.contrib.layers.xavier_initializer())
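One way to find and initialize exactly the variables that a partial restore did not cover is report_uninitialized_variables. A minimal sketch, assuming TF 1.x-style execution via tf.compat.v1; here initializing variable a by hand stands in for a saver.restore() that only covered a:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

a = tf.compat.v1.get_variable("a", initializer=tf.constant(1.0))  # "restored"
b = tf.compat.v1.get_variable("b", initializer=tf.constant(2.0))  # not restored

with tf.compat.v1.Session() as sess:
    sess.run(a.initializer)  # stands in for a partial saver.restore(...)

    # Ask the graph which variables still lack values
    uninit = sess.run(tf.compat.v1.report_uninitialized_variables())  # e.g. [b'b']
    names = {n.decode() for n in uninit}

    # Run the initializers of only those variables
    sess.run([v.initializer for v in tf.compat.v1.global_variables()
              if v.name.split(':')[0] in names])

    remaining = sess.run(tf.compat.v1.report_uninitialized_variables())

print(list(remaining))  # []
```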

3. Fine-tuning

Use the tf.stop_gradient() method to truncate backpropagation at a given tensor:

# pre-train and fine-tuning
fc2 = graph.get_tensor_by_name("fc2/add:0")
fc2 = tf.stop_gradient(fc2)  # freeze this part of the model
fc2_shape = fc2.get_shape().as_list()
# fine-tuning: add a new output layer on top of the frozen features
new_nums = 6
weights = tf.Variable(tf.truncated_normal([fc2_shape[1], new_nums], stddev=0.1), name="w")
biases = tf.Variable(tf.constant(0.1, shape=[new_nums]), name="b")
conv2 = tf.matmul(fc2, weights) + biases
output2 = tf.nn.softmax(conv2)
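The effect of tf.stop_gradient can be checked directly: gradients do not flow through it to the frozen weights. A small self-contained sketch, assuming TF 1.x-style graph execution via tf.compat.v1 (the variable names are illustrative):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

x = tf.constant([[1.0, 2.0]])
w_frozen = tf.compat.v1.get_variable("w_frozen",
                                     initializer=tf.constant([[1.0], [1.0]]))
fc = tf.matmul(x, w_frozen)
fc = tf.stop_gradient(fc)  # freeze everything below this point

w_new = tf.compat.v1.get_variable("w_new", initializer=tf.constant([[2.0]]))
out = tf.matmul(fc, w_new)

# Gradient of the output with respect to the frozen and the new weights
grads = tf.gradients(out, [w_frozen, w_new])
print(grads[0])  # None: no gradient reaches the frozen weights
```

An optimizer built on such a loss therefore only updates the new layer, which is exactly the fine-tuning setup above.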



Added by oeb on Fri, 18 Feb 2022 08:08:45 +0200