
Restoring trained weights after modifying the model graph in TensorFlow

Divakar V

Updated: Jul 26, 2020

Background: Recently, I had a situation where my trained TensorFlow model could not run on a platform because a few nodes in the graph (in particular tf.nn.batch_normalization) were not supported by that platform. So, I had to modify the graph to make it compatible. I replaced some nodes with other supported variants provided by TensorFlow, and re-wrote the rest as functions of other supported nodes. Either way, the network graph was modified!


Objective: In this post, I'll cover how to restore the learnt weights after the graph has been modified.


Resuming a model's training from a saved checkpoint involves two broad steps, and issues occur when the two do not match:

  1. Loading the graph of the neural network model.

  2. Loading the weights learnt so far.

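The learnt weights are the values of the variables in the model, such as convolution kernels and biases. An important point to note here is that only variables are stored in a checkpoint: operations such as activations or additions hold no variables, so the checkpoint contains nothing for them. (Strictly speaking, tf.train.Saver() saves all global variables by default, so non-trainable variables such as a global step counter are stored as well; in this post, "non-trainable" nodes means nodes that carry no learnt weights.)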
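To see what a checkpoint actually contains, here is a minimal sketch, assuming TF 1.15 or TF 2.x used through tensorflow.compat.v1. The layer shapes and the global_step variable are illustrative assumptions, not part of the post's model. It saves a single Conv2D + BiasAdd + Relu layer and lists the checkpoint entries with tf.train.list_variables: the kernel, the bias and the non-trainable counter appear, while the Conv2D, BiasAdd and Relu ops do not.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 1.15, or TF 2.x through the v1 compat API

tf.disable_v2_behavior()  # graph mode and tf.Session, as in the post

tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 8, 8, 3], name='x')
with tf.variable_scope('model_vars'):
    w1 = tf.get_variable('w1', [3, 3, 3, 16])
    b1 = tf.get_variable('b1', [16], initializer=tf.zeros_initializer())
# Conv2D, BiasAdd and Relu are ops: they hold no state of their own.
y = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, w1, [1, 1, 1, 1], 'SAME'), b1))
# A non-trainable variable is still a global variable, so the default Saver keeps it too.
global_step = tf.Variable(0, trainable=False, name='global_step')

saver = tf.train.Saver()  # default var_list: all global variables
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    prefix = saver.save(sess, os.path.join(tempfile.mkdtemp(), 'model-original'))

# (name, shape) pairs for everything stored in the checkpoint
stored = dict(tf.train.list_variables(prefix))
for name in sorted(stored):
    print(name, stored[name])
```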


[TL;DR] As long as you are not messing with trainable variables, you can simply restore all the weights from the saved checkpoint. If you are adding new trainable variables to the network, see Case 4 below.


Let's define the base model on which we will perform our modifications. The image below shows the basic FCNN used for the experiments.

Original model

Keeping it simple, let's modify the original network in 4 basic ways:

Case 1: New_model = Original_model - Trainable_Vars
Case 2: New_model = Original_model - non_Trainable_Vars
Case 3: New_model = Original_model + non_Trainable_Vars
Case 4: New_model = Original_model + Trainable_Vars

Now, let's run an experiment for each case and see whether any issue occurs while restoring the model.


Case 1 and 2: Removing nodes

Firstly, be careful when removing nodes, as this can cause a dimension mismatch in neighboring layers or nodes. For instance, removing the convolution operation (Conv2D node) of the 2nd layer causes a dimension error in the bias addition (expected 16 but got 3).
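The mismatch can be reproduced in isolation. In this sketch (assuming TF 1.15 or TF 2.x via tensorflow.compat.v1), the hypothetical helper bias_add_fails feeds a tensor with in_channels channels into tf.nn.bias_add with a bias of length bias_len; the values 3 and 16 mirror the numbers in the error mentioned above. In graph mode the error is normally raised by shape inference while the graph is being built, so the helper catches both that ValueError and the run-time InvalidArgumentError.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.15, or TF 2.x through the v1 compat API

tf.disable_v2_behavior()


def bias_add_fails(in_channels, bias_len):
    """Hypothetical helper: True if adding a bias of length bias_len to an in_channels tensor errors."""
    tf.reset_default_graph()
    x = tf.placeholder(tf.float32, [None, 8, 8, in_channels])
    bias = tf.zeros([bias_len])
    try:
        # Graph-mode shape inference usually rejects the mismatch right here...
        y = tf.nn.bias_add(x, bias)
        # ...otherwise the kernel rejects it when the op actually runs.
        with tf.Session() as sess:
            sess.run(y, feed_dict={x: np.zeros((1, 8, 8, in_channels), np.float32)})
    except (ValueError, tf.errors.InvalidArgumentError) as err:
        print(type(err).__name__, ':', err)
        return True
    return False


# Layer 2 with its Conv2D: 16 channels meet a bias of length 16.
print(bias_add_fails(16, 16))
# Conv2D removed: a 3-channel tensor reaches BiasAdd directly.
print(bias_add_fails(3, 16))
```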

So, keeping that in mind, we'll remove the 3rd layer completely (Conv2D & BiasAdd).


Original (Top) vs Modified (Bottom)

As long as the network remains valid, removing nodes does not cause any issue when restoring weights into the modified network. This is because all of the (fewer) weights required by the modified network exist in the saved checkpoint; checkpoint entries that the new graph no longer uses are simply ignored.
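A minimal sketch of Case 1, assuming TF 1.15 or TF 2.x via tensorflow.compat.v1. The build() helper is a hypothetical stand-in for the post's network that only reuses the model_vars/w*, b* naming; training is skipped and the initial values play the role of learnt weights. The original three-layer model is saved, the graph is rebuilt with two layers, and a plain tf.train.Saver() restores it without complaint.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.15, or TF 2.x through the v1 compat API

tf.disable_v2_behavior()


def build(num_layers):
    """Hypothetical stand-in for the post's network: Conv2D + BiasAdd + Relu per layer."""
    tf.reset_default_graph()
    net = tf.placeholder(tf.float32, [None, 8, 8, 3], name='x')
    in_ch = 3
    with tf.variable_scope('model_vars'):
        for i in range(1, num_layers + 1):
            w = tf.get_variable('w%d' % i, [3, 3, in_ch, 16])
            b = tf.get_variable('b%d' % i, [16], initializer=tf.zeros_initializer())
            net = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(net, w, [1, 1, 1, 1], 'SAME'), b))
            in_ch = 16
    return {v.op.name: v for v in tf.global_variables()}


ckpt_dir = tempfile.mkdtemp()

# Original model: 3 layers (training skipped; initial values stand in for learnt weights).
vars_orig = build(num_layers=3)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    w2_saved = sess.run(vars_orig['model_vars/w2'])
    prefix = tf.train.Saver().save(sess, os.path.join(ckpt_dir, 'model-original'))

# Modified model: 3rd layer removed. The default Saver asks only for w1, b1, w2, b2,
# all of which are in the checkpoint; the extra w3 and b3 entries are ignored.
vars_mod = build(num_layers=2)
with tf.Session() as sess:
    tf.train.Saver().restore(sess, prefix)
    w2_restored = sess.run(vars_mod['model_vars/w2'])

print(sorted(vars_mod))
print(np.array_equal(w2_saved, w2_restored))
```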


Case 3: Adding nodes (No trainable variables)

I had a situation where the Elu (exponential linear unit) activation node was not supported on a platform. So, I replaced all occurrences of Elu nodes with Relu (the most common activation function), and I could proceed and complete my experiment.


To highlight this case, we'll make the following modifications:

Modify node: Convert all Relu nodes to leaky_relu nodes

Addition: Add a skip connection to the network and duplicate activation functions

Original (Top) vs Modified (Bottom)

Such modifications don't cause issues, since they add no variables, and you can resume training from the last checkpoint with the modified network.
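A sketch of Case 3, assuming TF 1.15 or TF 2.x via tensorflow.compat.v1, with a hypothetical two-layer build() helper standing in for the post's network: a checkpoint saved from a Relu network is restored into a graph that uses leaky_relu and adds a skip connection and a duplicated activation. Because these changes add operations but no variables, the default Saver finds every variable it needs.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.15, or TF 2.x through the v1 compat API

tf.disable_v2_behavior()


def build(activation, extra_nodes=False):
    """Hypothetical two-layer stand-in; extra_nodes adds a skip connection and a duplicated activation."""
    tf.reset_default_graph()
    x = tf.placeholder(tf.float32, [None, 8, 8, 16], name='x')
    with tf.variable_scope('model_vars'):
        w1 = tf.get_variable('w1', [3, 3, 16, 16])
        b1 = tf.get_variable('b1', [16], initializer=tf.zeros_initializer())
        w2 = tf.get_variable('w2', [3, 3, 16, 16])
        b2 = tf.get_variable('b2', [16], initializer=tf.zeros_initializer())
    h = activation(tf.nn.bias_add(tf.nn.conv2d(x, w1, [1, 1, 1, 1], 'SAME'), b1))
    out = tf.nn.bias_add(tf.nn.conv2d(h, w2, [1, 1, 1, 1], 'SAME'), b2)
    if extra_nodes:
        out = activation(out + x)  # skip connection (an Add op) plus an extra activation
    out = activation(out)
    return {v.op.name: v for v in tf.global_variables()}


ckpt_dir = tempfile.mkdtemp()

vars_orig = build(tf.nn.relu)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saved = sess.run(vars_orig)  # dict: variable name -> numpy array
    prefix = tf.train.Saver().save(sess, os.path.join(ckpt_dir, 'model-original'))

# Relu -> leaky_relu plus a skip connection: same variables, different ops.
vars_mod = build(tf.nn.leaky_relu, extra_nodes=True)
with tf.Session() as sess:
    tf.train.Saver().restore(sess, prefix)
    restored = sess.run(vars_mod)

print(sorted(restored) == sorted(saved))
print(all(np.array_equal(saved[k], restored[k]) for k in saved))
```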


Case 4: Adding nodes (containing trainable variables)

This is where things get different. Say you add a new Conv2D node. You cannot simply load the saved checkpoint directly, because the weights corresponding to the new node are not present in it. You will get an error like:

tensorflow.python.framework.errors_impl.NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:...

Let's modify the network by adding a convolution layer and a bias-addition layer at the end:

Original (Top) vs Modified (Bottom)
-----------List of Trainable Variables-------------
Original Model Weights  |   Modified Model Weights
(present in checkpoint) |
---------------------------------------------------
'model_vars/w1:0'       |      'model_vars/w1:0'
'model_vars/w2:0'       |      'model_vars/w2:0'
'model_vars/w3:0'       |      'model_vars/w3:0'
'model_vars/b1:0'       |      'model_vars/b1:0'
'model_vars/b2:0'       |      'model_vars/b2:0'
'model_vars/b3:0'       |      'model_vars/b3:0'
                        |      'model_vars/w_new:0'
                        |      'model_vars/b_new:0'

The model cannot restore the 'w_new' and 'b_new' variables because they are not present in the saved checkpoint; hence the error.
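The mismatch is a plain set difference between the two columns of the table. This short Python sketch (no TensorFlow needed) computes it; with a real checkpoint, the left-hand set could instead come from tf.train.list_variables, whose names lack the ':0' suffix.

```python
# Variable names copied from the table of trainable variables.
original_vars = {
    'model_vars/w1:0', 'model_vars/w2:0', 'model_vars/w3:0',
    'model_vars/b1:0', 'model_vars/b2:0', 'model_vars/b3:0',
}
modified_vars = original_vars | {'model_vars/w_new:0', 'model_vars/b_new:0'}

# A default Saver asks for every variable in the modified graph, so a restore can
# only succeed if none of them is missing from the checkpoint.
missing = sorted(modified_vars - original_vars)
restorable = sorted(modified_vars & original_vars)

print('missing from checkpoint:', missing)
print('restorable:', len(restorable), 'variables')
```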

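## [X] This will not work [X]
import tensorflow as tf  # TensorFlow 1.x API (tf.train.Saver, tf.Session)

# The modified graph (with w_new and b_new) has already been built at this point.
saver = tf.train.Saver()  # default var_list: every variable in the modified graph
with tf.Session() as sess:
    # Fails with NotFoundError: w_new and b_new are not in the checkpoint.
    saver.restore(sess, 'checkpoint/model-original')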

Solution:


Instead of trying to restore all the weights of the modified model, restore only the partial list of variables that you know are present in the original model. TensorFlow lets us do this through the var_list parameter of tf.train.Saver().

The remaining variables, which are not restored, still need to be initialized before they are used.
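A minimal sketch of this partial restore, assuming TF 1.15 or TF 2.x via tensorflow.compat.v1. The build() helper and its shapes are hypothetical stand-ins for the post's network; only the variable names follow the table above. Rather than hard-coding the list of old variables, this version reads the names stored in the checkpoint with tf.train.list_variables; writing the list by hand works equally well.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 1.15, or TF 2.x through the v1 compat API

tf.disable_v2_behavior()


def build(add_new_layer):
    """Hypothetical stand-in: layers 1-3, plus w_new/b_new at the end if requested."""
    tf.reset_default_graph()
    net = tf.placeholder(tf.float32, [None, 8, 8, 16], name='x')
    suffixes = ['1', '2', '3'] + (['_new'] if add_new_layer else [])
    with tf.variable_scope('model_vars'):
        for s in suffixes:
            w = tf.get_variable('w' + s, [3, 3, 16, 16])
            b = tf.get_variable('b' + s, [16], initializer=tf.zeros_initializer())
            net = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(net, w, [1, 1, 1, 1], 'SAME'), b))
    return net


ckpt_dir = tempfile.mkdtemp()
build(add_new_layer=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    prefix = tf.train.Saver().save(sess, os.path.join(ckpt_dir, 'model-original'))

build(add_new_layer=True)
# Names stored in the checkpoint (without the ':0' suffix).
in_ckpt = {name for name, _ in tf.train.list_variables(prefix)}
old_vars = [v for v in tf.global_variables() if v.op.name in in_ckpt]
new_vars = [v for v in tf.global_variables() if v.op.name not in in_ckpt]

saver_list = tf.train.Saver(var_list=old_vars)  # restores only what the checkpoint has
init_new = tf.variables_initializer(new_vars)   # fresh values for w_new and b_new
with tf.Session() as sess:
    saver_list.restore(sess, prefix)
    sess.run(init_new)
    # An empty result means every variable now has a value.
    uninitialized = sess.run(tf.report_uninitialized_variables())

print(sorted(v.op.name for v in new_vars))
print(len(uninitialized))
```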


One thing to note here: a Saver object created with a var_list (saver_list) will both load and save only the variables specified in that list. This means that if you train your modified model, saver_list will save only the weights of the specified variables.

So, you should create another Saver object to store all the weights of the modified network.
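A sketch of this two-Saver setup, assuming TF 1.15 or TF 2.x via tensorflow.compat.v1, with a hypothetical build() helper standing in for the modified network. Here saver_list restores the old weights and saver_all saves everything; saving with both and listing the results shows that only the checkpoint written by saver_all contains w_new and b_new.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 1.15, or TF 2.x through the v1 compat API

tf.disable_v2_behavior()


def build(add_new_layer):
    """Hypothetical stand-in: layers 1-3, plus w_new/b_new at the end if requested."""
    tf.reset_default_graph()
    net = tf.placeholder(tf.float32, [None, 8, 8, 16], name='x')
    suffixes = ['1', '2', '3'] + (['_new'] if add_new_layer else [])
    with tf.variable_scope('model_vars'):
        for s in suffixes:
            w = tf.get_variable('w' + s, [3, 3, 16, 16])
            b = tf.get_variable('b' + s, [16], initializer=tf.zeros_initializer())
            net = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(net, w, [1, 1, 1, 1], 'SAME'), b))
    return net


ckpt_dir = tempfile.mkdtemp()
build(add_new_layer=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    prefix = tf.train.Saver().save(sess, os.path.join(ckpt_dir, 'model-original'))

build(add_new_layer=True)
in_ckpt = {name for name, _ in tf.train.list_variables(prefix)}
saver_list = tf.train.Saver(
    var_list=[v for v in tf.global_variables() if v.op.name in in_ckpt])
saver_all = tf.train.Saver()  # default var_list: every global variable, w_new/b_new included

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # give every variable a value first...
    saver_list.restore(sess, prefix)             # ...then overwrite the ones in the checkpoint
    # (training of the modified model would happen here)
    partial = saver_list.save(sess, os.path.join(ckpt_dir, 'model-partial'))
    full = saver_all.save(sess, os.path.join(ckpt_dir, 'model-modified'))

partial_names = {name for name, _ in tf.train.list_variables(partial)}
full_names = {name for name, _ in tf.train.list_variables(full)}
print(sorted(full_names - partial_names))
```

Running tf.global_variables_initializer() first and then restoring is an alternative to initializing only the new variables: the restore simply overwrites the initial values of the variables that exist in the checkpoint.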



Link to GitHub repository:


Gotta modify 'em all!
