This article is part of the "Newcomer Creation Ceremony" event on Juejin, marking the start of my writing journey there.

What can you gain?

By reading this post, you will learn how to save accuracy and loss metrics during TensorFlow training, how to save and load models in TensorFlow, and how to resume training from where the previous run left off.

Recently, while training neural networks, I needed to save data produced during training so that the next run could resume from the results of the previous one. TensorFlow 2's official documentation describes callback functions that can be passed to model.fit for exactly this purpose. This post covers parameter saving and resumable (breakpoint) training in TensorFlow; a complete sample is given at the end for reference.

1. Save training data

First, how do we save the data from the training process: the epoch count, training-set accuracy and loss, and validation-set accuracy and loss? The TensorFlow documentation provides a callback for saving training data: tf.keras.callbacks.CSVLogger. Using it is as simple as specifying a save path and adding the callback to the callbacks list in model.fit. Example code:

# Parameter description
# append: whether to append to the file if it already exists (False overwrites it)
csv_logger = CSVLogger('training.log', append=False)
model.fit(X_train, Y_train, callbacks=[csv_logger])
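To make the result concrete: CSVLogger writes a plain CSV file with an epoch column followed by the logged metrics. The sketch below reproduces that layout with the standard library only; the metric names and values are made-up illustrations, not the output of a real training run.

```python
import csv
import io

# Build a CSV with the same shape CSVLogger produces: a header row,
# then one row per epoch. Metric values here are invented for illustration.
buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=['epoch', 'accuracy', 'loss'])
writer.writeheader()
for epoch, (acc, loss) in enumerate([(0.51, 1.2), (0.63, 0.9)]):
    writer.writerow({'epoch': epoch, 'accuracy': acc, 'loss': loss})

print(buf.getvalue())
# epoch,accuracy,loss
# 0,0.51,1.2
# 1,0.63,0.9
```

Because the file is an ordinary CSV, it can be opened in any spreadsheet tool or parsed with the csv module, which is what the resume logic in section 3 relies on.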

2. Save and load the model

Models can be saved with tf.keras.callbacks.ModelCheckpoint(). The interface is shown below. In practice, only three parameters are usually needed: filepath, save_best_only, and save_weights_only. filepath specifies where to save the file, save_best_only specifies whether to save only the best model, and save_weights_only specifies whether to save only the model weights.

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch',
    options=None, initial_value_threshold=None, **kwargs
)

Once a model has been saved, we also need a way to load it back, which is done with model.load_weights(filepath).

Here is a sample for reference; adapt it to your own requirements.


model = TestModel()
model.compile(...)

# Load previously saved model weights, if a checkpoint exists
checkpoint_save_path = './checkpoint/Baseline.ckpt'
checkpoint_save_best_path = './checkpoint_best/Baseline.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('------load the model------')
    model.load_weights(checkpoint_save_path)

# Save only the weights (so the next run can pick up where this one left off)
cp_callback_save = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path, save_weights_only=True)
# Save only the best model (so the best weights seen during training are kept)
cp_callback_save_best = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_best_path, save_weights_only=True, save_best_only=True)

model.fit(..., callbacks=[cp_callback_save, cp_callback_save_best])

3. Sample code

We have now covered saving parameters and models. Returning to the original problem: we want to resume the previous training run and keep saving data parameters. First, to resume, we need to know how many epochs were already completed, which is easy to work out from the parameter data file. Then we can pass that number to the initial_epoch parameter of model.fit so that training continues from where it stopped. Note that epochs in model.fit is the total (final) epoch count, not the number of additional epochs: with epochs=100 and initial_epoch=30, fit runs 70 more epochs. The reference code is as follows:

import csv

# Return the number of completed training epochs.
# filename: path to the file written by CSVLogger
def get_init_epoch(filename):
    with open(filename) as f:
        f_csv = csv.DictReader(f)
        count = 0
        for row in f_csv:
            count = count + 1
        return count
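Counting rows works when the log holds exactly one uninterrupted run. If the same file has been appended to across several resumed runs, the row count can drift from the true epoch number, so a slightly more robust variant, sketched below under that assumption, reads the last value of the epoch column instead. get_init_epoch_last is my own name for this helper, not part of the code above.

```python
import csv
import os
import tempfile

def get_init_epoch_last(filename):
    # Read the last 'epoch' value logged by CSVLogger; the next run should
    # start at that value + 1 (CSVLogger counts epochs from 0).
    last_epoch = -1
    with open(filename) as f:
        for row in csv.DictReader(f):
            last_epoch = int(row['epoch'])
    return last_epoch + 1

# Demonstration with a hand-written log in CSVLogger's format
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, 'training.log')
    with open(path, 'w') as f:
        f.write('epoch,accuracy,loss\n0,0.5,1.2\n1,0.6,0.9\n2,0.7,0.7\n')
    print(get_init_epoch_last(path))  # 3
```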

init_epoch = 0  # starting epoch
if os.path.exists(filename):
    init_epoch = get_init_epoch(filename)

model = Test()
model.compile(...)

checkpoint_save_path = './checkpoint/Baseline.ckpt'
checkpoint_save_best_path = './checkpoint_best/Baseline.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('------load the model------')
    model.load_weights(checkpoint_save_path)

cp_callback_save = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path, save_weights_only=True)
cp_callback_save_best = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_best_path, save_weights_only=True, save_best_only=True)

# append=True continues writing to the existing log file instead of overwriting it
csv_logger = CSVLogger('training_log', append=True)

model.fit(..., initial_epoch=init_epoch,
          callbacks=[csv_logger, cp_callback_save_best, cp_callback_save])

If you need to customize the callback function, you can refer to the following resources

All callbacks subclass the keras.callbacks.Callback class and override a set of methods that are called at various stages of training, testing, and prediction. Callbacks are useful for inspecting the internal state and statistics of the model during training.

Overview of callback methods

Global methods:
- on_(train|test|predict)_begin(self, logs=None): called at the start of fit/evaluate/predict.
- on_(train|test|predict)_end(self, logs=None): called at the end of fit/evaluate/predict.

Batch-level methods (training/testing/predicting):
- on_(train|test|predict)_batch_begin(self, batch, logs=None): called just before a batch is processed during training/testing/prediction.
- on_(train|test|predict)_batch_end(self, batch, logs=None): called at the end of a training/testing/prediction batch. In this method, logs is a dict containing the metric results.

Epoch-level methods (training only):
- on_epoch_begin(self, epoch, logs=None): called at the start of an epoch during training.
- on_epoch_end(self, epoch, logs=None): called at the end of an epoch during training.

Let’s look at a concrete example.

class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("... Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("... Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("... Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("... Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("... Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("... Predicting: end of batch {}; got log keys: {}".format(batch, keys))

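In real use, the class above is simply passed to fit, e.g. model.fit(x, y, callbacks=[CustomCallback()]). To make the hook order visible without building a model, the sketch below drives the same method names by hand with a minimal stand-in class; run_fake_fit is my own illustration of the calling sequence, not part of Keras.

```python
calls = []

class LoggingCallback:
    # Same hook names as keras.callbacks.Callback, recorded instead of printed
    def on_train_begin(self, logs=None):
        calls.append('train_begin')
    def on_epoch_begin(self, epoch, logs=None):
        calls.append('epoch_begin:%d' % epoch)
    def on_train_batch_end(self, batch, logs=None):
        calls.append('batch_end:%d' % batch)
    def on_epoch_end(self, epoch, logs=None):
        calls.append('epoch_end:%d' % epoch)
    def on_train_end(self, logs=None):
        calls.append('train_end')

def run_fake_fit(cb, epochs=2, batches=2):
    # Fire the hooks in the order model.fit fires them during training
    cb.on_train_begin({})
    for epoch in range(epochs):
        cb.on_epoch_begin(epoch, {})
        for batch in range(batches):
            cb.on_train_batch_end(batch, {'loss': 0.0})
        cb.on_epoch_end(epoch, {'loss': 0.0})
    cb.on_train_end({})

run_fake_fit(LoggingCallback())
print(calls[0], calls[1], calls[-1])  # train_begin epoch_begin:0 train_end
```

This shows why on_train_batch_end is the place to read per-batch metrics: the logs dict passed there is refreshed on every batch, while on_train_begin and on_train_end bracket the whole run.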