Saving the model for ongoing use
To save the variables of a TensorFlow session for later use, you can create a tf.train.Saver() object. Let's start by creating a saver variable right after the writer variable:
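# Write summaries and the graph for TensorBoard
writer = tf.summary.FileWriter(log_location, session.graph)
# Keep at most the five most recent checkpoints; older ones are deleted
saver = tf.train.Saver(max_to_keep=5)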
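To make the effect of max_to_keep concrete, here is a minimal pure-Python sketch of a "keep the N most recent checkpoints" policy. The CheckpointRetention class is hypothetical and only models the observable behavior (old checkpoints are discarded as new ones arrive); it is not how tf.train.Saver is implemented. It assumes a positive max_to_keep and reuses the "prefix-step" naming that saving with global_step produces.

```python
from collections import deque


class CheckpointRetention:
    """Hypothetical toy model of a keep-the-N-most-recent checkpoint policy."""

    def __init__(self, max_to_keep=5):
        # Assumes a positive max_to_keep (illustration only)
        self.max_to_keep = max_to_keep
        self.kept = deque()   # checkpoint prefixes still on "disk", oldest first
        self.deleted = []     # prefixes the policy has discarded, in order

    def save(self, prefix, global_step):
        # Saving with a global step appends "-<step>" to the prefix
        path = "%s-%d" % (prefix, global_step)
        self.kept.append(path)
        # Discard the oldest checkpoints once more than max_to_keep are held
        while len(self.kept) > self.max_to_keep:
            self.deleted.append(self.kept.popleft())
        return path


retention = CheckpointRetention(max_to_keep=5)
for step in range(0, 700, 100):  # seven saves: steps 0, 100, ..., 600
    retention.save("model.ckpt", step)

print(list(retention.kept))
print(retention.deleted)
```

After seven saves only the last five prefixes (steps 200 to 600) remain, and the checkpoints for steps 0 and 100 have been dropped.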
Then, in the training loop, we add the following code to save the model every model_saving_step steps, and once more at the final step:
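# Save every model_saving_step steps and at the last step of range(num_steps)
if step % model_saving_step == 0 or step == num_steps - 1:
    # global_step appends "-<step>" to the file prefix, e.g. model.ckpt-100
    path = saver.save(session, os.path.join(log_location, "model.ckpt"), global_step=step)
    logmanager.logger.info('Model saved in file: %s' % path)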
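A quick way to check a saving condition is to list the steps on which it fires. The sketch below assumes the training loop is `for step in range(num_steps)`, so the final-step test must compare against num_steps - 1, the last value the loop actually produces. The helper checkpoint_steps is hypothetical and exists only for illustration.

```python
def checkpoint_steps(num_steps, model_saving_step):
    """Hypothetical helper: steps at which the loop would save a checkpoint.

    Assumes the loop is `for step in range(num_steps)`.
    """
    last_step = num_steps - 1  # last value produced by range(num_steps)
    return [
        step
        for step in range(num_steps)
        if step % model_saving_step == 0 or step == last_step
    ]


print(checkpoint_steps(num_steps=10, model_saving_step=4))
print(checkpoint_steps(num_steps=9, model_saving_step=4))
```

With 10 steps and a saving interval of 4, the model is saved at steps 0, 4 and 8, plus step 9 for the final state; with 9 steps the last step (8) is already a multiple of 4, so no extra save occurs.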
Later, whenever we want to restore the model from a saved checkpoint, we build the same graph again (so that its variables exist), create a new Saver() instance, and call its restore() method as follows:
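# Path prefix of the most recent checkpoint in log_location (None if there is none)
checkpoint_path = tf.train.latest_checkpoint(log_location)
restorer = tf.train.Saver()
with tf.Session() as sess:
    # Optional here: restore() below assigns values to every variable the saver manages
    sess.run(tf.global_variables_initializer())
    # Load the saved variable values from the checkpoint into the session
    restorer.restore(sess, checkpoint_path)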
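When training should resume from a checkpoint rather than start over, the step number can be recovered from the checkpoint prefix, since saving with global_step produces names such as model.ckpt-900, and tf.train.latest_checkpoint() returns None when no checkpoint exists. The helper resume_step below is a hypothetical sketch; it also assumes the checkpoint is written after that step's update, so training continues with the next step.

```python
import re


def resume_step(checkpoint_path):
    """Hypothetical helper: step to resume training from.

    checkpoint_path is a prefix such as "logs/model.ckpt-900", or None
    when no checkpoint has been saved yet.
    """
    if checkpoint_path is None:
        return 0  # nothing saved yet: start from scratch
    match = re.search(r"-(\d+)$", checkpoint_path)
    if match is None:
        raise ValueError("no global step suffix in %r" % checkpoint_path)
    # Assumes the saved step was fully trained, so continue with the next one
    return int(match.group(1)) + 1


print(resume_step(None))
print(resume_step("logs/model.ckpt-900"))
```

The loop would then run `for step in range(resume_step(checkpoint_path), num_steps)`, so a restarted job neither repeats nor skips steps.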
In...