mltk.core.TrainMixin

class TrainMixin[source]

Provides training properties and methods to the base MltkModel

Refer to the Model Training guide for more details.

Properties

batch_size

Number of samples per gradient update

build_model_function

Function that builds and returns a compiled mltk.core.KerasModel instance

checkpoint

Callback to save the Keras model or model weights at some frequency

checkpoints_dir

Return path to directory containing training checkpoints

checkpoints_enabled

If true, enable saving a checkpoint after each training epoch.

early_stopping

Stop training when a monitored metric has stopped improving

epochs

Number of epochs to train the model.

loss

String (name of objective function) or objective function

lr_schedule

Learning rate scheduler

metrics

List of metrics to be evaluated by the model during training and testing

on_save_keras_model

Callback to be invoked after the model has been trained to save the KerasModel.

on_training_complete

Callback to be invoked after the model has been successfully trained

optimizer

String (name of optimizer) or optimizer instance

reduce_lr_on_plateau

Reduce learning rate when a metric has stopped improving

tensorboard

Enable visualizations for TensorBoard

tflite_converter

Converts a TensorFlow model into TensorFlow Lite model

train_callbacks

List of keras.callbacks.Callback instances.

train_kwargs

Additional arguments to pass to the model fit API (https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit).

weights_dir

Return path to directory containing training weights

weights_file_format

Return the file format used to generate model weights files during training

Methods

__init__

get_checkpoint_path

Return the file path to the checkpoint weights for the given epoch

get_weights_path

Return the path to a Keras .h5 weights file

property build_model_function

Function that builds and returns a compiled mltk.core.KerasModel instance

Your model definition MUST provide this setting.

# Import the MLTK and Keras classes used below
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Flatten, Dense
from mltk.core import (
    MltkModel,
    TrainMixin,
    ImageDatasetMixin,
    EvaluateClassifierMixin
)

# Create a MltkModel instance with the 'train' mixin
class MyModel(
    MltkModel,
    TrainMixin,
    ImageDatasetMixin,
    EvaluateClassifierMixin
):
    pass
mltk_model = MyModel()

# Define the model build function
def my_model_builder(mltk_model):
    keras_model = Sequential()
    keras_model.add(Conv2D(8, kernel_size=(3,3), padding='valid', input_shape=mltk_model.input_shape))
    keras_model.add(Flatten())
    keras_model.add(Dense(mltk_model.n_classes, activation='softmax'))

    # Compile using the loss, optimizer and metrics configured on the MltkModel
    keras_model.compile(loss=mltk_model.loss, optimizer=mltk_model.optimizer, metrics=mltk_model.metrics)

    return keras_model

# Set the MltkModel's build_model function
mltk_model.build_model_function = my_model_builder

property on_training_complete

Callback to be invoked after the model has been successfully trained

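from mltk.core import TrainingResults

def _on_training_completed(results: TrainingResults):
    ...

# mltk_model is the MltkModel instance created in your model definition
mltk_model.on_training_complete = _on_training_completed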

Note

This is invoked after the Keras and .tflite model files are saved

property on_save_keras_model

Callback to be invoked after the model has been trained to save the KerasModel.

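This callback may be used to modify the KerasModel that gets saved, e.g. to remove layers of the model that were only needed during training.

import logging
from mltk.core import MltkModel, KerasModel

def _on_save_keras_model(mltk_model: MltkModel, keras_model: KerasModel, logger: logging.Logger) -> KerasModel:
    # Modify keras_model here as needed; the returned model is the one that gets saved
    ...
    return keras_model

# Assign the callback to the MltkModel instance created in your model definition
mltk_model.on_save_keras_model = _on_save_keras_model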

Note

This is invoked before the model is quantized. Quantization will use the KerasModel returned by this callback.

property epochs

Number of epochs to train the model.

Default: 100

An epoch is an iteration over the entire x and y data provided. Note that epochs is to be understood as “final epoch”. The model is not trained for a number of iterations given by epochs, but merely until the epoch of index epochs is reached.

If this is set to -1, then epochs will be set to an arbitrarily large value. In this case, the early_stopping callback should be used to determine when to stop training the model.

Note

The larger this value is, the longer the model will take to train
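As a rough illustration of the "final epoch" semantics, the following pure-Python sketch mirrors how a Keras-style fit() loop iterates from initial_epoch up to, but not including, epochs. The helper name epochs_run is hypothetical; initial_epoch is the standard Model.fit() argument used when a training session is resumed.

```python
from typing import List


def epochs_run(epochs: int, initial_epoch: int = 0) -> List[int]:
    """Return the epoch indices a Keras-style fit() loop would execute.

    Hypothetical helper: fit() runs epochs initial_epoch .. epochs-1,
    so `epochs` is the index at which training ends, not a count of
    additional iterations.
    """
    return list(range(initial_epoch, epochs))


# Fresh training session: epochs=100 runs indices 0..99
print(len(epochs_run(100)))                    # 100

# Resumed session starting at epoch 60: only 40 more epochs run
print(len(epochs_run(100, initial_epoch=60)))  # 40
```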

property batch_size

Number of samples per gradient update

Default: 32

Typical values are: 16, 32, 64.

Typically, the larger this value is, the more RAM is required during training.
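Because one gradient update is performed per batch, batch_size also determines how many updates happen in each epoch. A minimal sketch, assuming the full training set is consumed every epoch and the final partial batch is kept; steps_per_epoch is a hypothetical helper name.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of gradient updates per epoch (final partial batch included)."""
    return math.ceil(num_samples / batch_size)


for bs in (16, 32, 64):
    print(bs, steps_per_epoch(10_000, bs))
# 16 625
# 32 313
# 64 157
```

Larger batches mean fewer, but more memory-intensive, gradient updates per epoch.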

property optimizer

String (name of optimizer) or optimizer instance

Default: adam

property metrics

List of metrics to be evaluated by the model during training and testing

Default: ['accuracy']

property loss

String (name of objective function) or objective function

Default: categorical_crossentropy

property checkpoints_enabled

If true, enable saving a checkpoint after each training epoch.

Default: True

This is useful as it allows for resuming training sessions with the --resume argument to the train command.

Note

This is independent of checkpoint. This saves each epoch’s weights to the logdir/train/checkpoints directory regardless of what is configured in checkpoint.

property train_callbacks

List of keras.callbacks.Callback instances.

Default: []

List of callbacks to apply during training.

Note

If a callback is found in this list, then the corresponding callback setting is ignored, e.g. if a LearningRateScheduler callback is found in this list, then lr_schedule is ignored.
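The precedence rule above can be modeled in plain Python. This is a hypothetical sketch, not MLTK's implementation: only the LearningRateScheduler to lr_schedule pairing is stated in this documentation; the other pairings are assumed from the Keras callback each setting describes, and callbacks are represented by their class names so the example runs without TensorFlow.

```python
from typing import Dict, List, Optional

# Assumed pairing of TrainMixin settings to the Keras callback each one configures.
# Only lr_schedule -> LearningRateScheduler is confirmed by the documentation.
SETTING_TO_CALLBACK = {
    "lr_schedule": "LearningRateScheduler",
    "reduce_lr_on_plateau": "ReduceLROnPlateau",
    "early_stopping": "EarlyStopping",
    "tensorboard": "TensorBoard",
    "checkpoint": "ModelCheckpoint",
}


def effective_settings(
    settings: Dict[str, Optional[dict]], train_callbacks: List[str]
) -> Dict[str, Optional[dict]]:
    """Return the settings that would still be applied.

    A setting is ignored (treated as None) when a callback of the
    corresponding type is already present in train_callbacks.
    """
    result = {}
    for name, value in settings.items():
        callback_name = SETTING_TO_CALLBACK.get(name)
        result[name] = None if callback_name in train_callbacks else value
    return result


settings = {"lr_schedule": {"verbose": 1}, "early_stopping": {"patience": 25}}
print(effective_settings(settings, ["LearningRateScheduler"]))
# {'lr_schedule': None, 'early_stopping': {'patience': 25}}
```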

property lr_schedule

Learning rate scheduler

Default: None

dict(
    schedule, # a function that takes an epoch index (integer, indexed from 0)
              # and current learning rate (float) as inputs and returns a new learning rate as output (float).

    verbose=0 # int. 0: quiet, 1: update messages.
)

Note

Set to None to disable

At the beginning of every epoch, this callback gets the updated learning rate value from the provided schedule function, using the current epoch and current learning rate, and applies the updated learning rate to the optimizer.
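For example, a step-decay schedule that halves the learning rate every 10 epochs could be written as follows. This is a minimal sketch: step_decay is a hypothetical name, and the function only has to follow the (epoch, lr) -> new_lr contract described above.

```python
def step_decay(epoch: int, lr: float) -> float:
    """Halve the current learning rate at epochs 10, 20, 30, ...

    `epoch` is 0-indexed and `lr` is the optimizer's current learning rate,
    matching the signature expected by the `schedule` entry.
    """
    if epoch > 0 and epoch % 10 == 0:
        return lr * 0.5
    return lr


# Settings dict in the documented form, e.g. for mltk_model.lr_schedule
lr_schedule = dict(schedule=step_decay, verbose=1)

# Simulate 25 epochs starting from lr=0.001 (halved at epochs 10 and 20)
lr = 0.001
for epoch in range(25):
    lr = step_decay(epoch, lr)
print(lr)  # 0.00025
```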


property reduce_lr_on_plateau

Reduce learning rate when a metric has stopped improving

Default: None

Possible values:

dict(
    monitor="val_loss",   # quantity to be monitored.

    factor=0.1,           # factor by which the learning rate will be reduced. new_lr = lr * factor.

    patience=10,          # number of epochs with no improvement after which learning rate will be reduced.

    mode="auto",          # one of {'auto', 'min', 'max'}. In 'min' mode, the learning rate will be reduced
                          # when the quantity monitored has stopped decreasing; in 'max' mode it will be reduced
                          # when the quantity monitored has stopped increasing; in 'auto' mode, the direction is
                          # automatically inferred from the name of the monitored quantity.

    min_delta=0.0001,     # threshold for measuring the new optimum, to only focus on significant changes.

    cooldown=0,           # number of epochs to wait before resuming normal operation after lr has been reduced.

    min_lr=0,             # lower bound on the learning rate.

    verbose=1,            # int. 0: quiet, 1: update messages.
)

Models often benefit from reducing the learning rate by a factor of 2-10 once learning stagnates. This callback monitors a quantity and if no improvement is seen for a ‘patience’ number of epochs, the learning rate is reduced.

Note

  • Set to None to disable this callback

  • If lr_schedule is enabled then this callback is automatically disabled
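The plateau logic described above can be sketched in plain Python. This is a simplified model of the documented behavior for mode="min" (e.g. monitoring val_loss) with verbose output omitted; PlateauSketch is a hypothetical class, not the Keras ReduceLROnPlateau implementation.

```python
class PlateauSketch:
    """Simplified reduce-on-plateau logic for a metric that should decrease."""

    def __init__(self, lr, factor=0.1, patience=10, min_delta=0.0001,
                 cooldown=0, min_lr=0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_delta = min_delta
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.best = float("inf")
        self.wait = 0              # epochs since the last significant improvement
        self.cooldown_counter = 0  # epochs left before normal operation resumes

    def on_epoch_end(self, value: float) -> float:
        """Update state with this epoch's monitored value; return the new lr."""
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
            self.wait = 0
        if value < self.best - self.min_delta:
            self.best = value
            self.wait = 0
        elif self.cooldown_counter == 0:
            self.wait += 1
            if self.wait >= self.patience and self.lr > self.min_lr:
                # new_lr = lr * factor, but never below min_lr
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.cooldown_counter = self.cooldown
                self.wait = 0
        return self.lr


plateau = PlateauSketch(lr=0.01, factor=0.5, patience=2)
for val_loss in [1.0, 0.9, 0.9, 0.9]:
    print(plateau.on_epoch_end(val_loss))
# prints 0.01, 0.01, 0.01, then 0.005
```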

property tensorboard

Enable visualizations for TensorBoard

Default: None

Possible values:

dict(
     histogram_freq=1,       # frequency (in epochs) at which to compute activation and weight histograms
                             # for the layers of the model. If set to 0, histograms won't be computed.
                             # Validation data (or split) must be specified for histogram visualizations.

     write_graph=True,       # whether to visualize the graph in TensorBoard. The log file can become quite large when write_graph is set to True.

     write_images=False,     # whether to write model weights to visualize as image in TensorBoard.

     update_freq="epoch",    # 'batch' or 'epoch' or integer. When using 'batch', writes the losses and metrics
                             # to TensorBoard after each batch. The same applies for 'epoch'.
                             # If using an integer, let's say 1000, the callback will write the metrics and losses
                             # to TensorBoard every 1000 batches. Note that writing too frequently to
                             # TensorBoard can slow down your training.

     profile_batch=2,        # Profile the batch(es) to sample compute characteristics.
                             # profile_batch must be a non-negative integer or a tuple of integers.
                             # A pair of positive integers signify a range of batches to profile.
                             # By default, it will profile the second batch. Set profile_batch=0 to disable profiling.
 )

This callback logs events for TensorBoard, including:

  • Metrics summary plots

  • Training graph visualization

  • Activation histograms

  • Sampled profiling


property checkpoint

Callback to save the Keras model or model weights at some frequency

Default:

dict(
     monitor="val_accuracy",   # The metric name to monitor. Typically the metrics are set by the Model.compile method.
                               # Note:
                               # - Prefix the name with "val_" to monitor validation metrics.
                               # - Use "loss" or "val_loss" to monitor the model's total loss.
                               # - If you specify metrics as strings, like "accuracy", pass the same string (with or without the "val_" prefix).
                               # - If you pass metrics.Metric objects, monitor should be set to metric.name
                               # - If you're not sure about the metric names you can check the contents of the history.history dictionary returned by history = model.fit()
                               # - Multi-output models set additional prefixes on the metric names.

     save_best_only=True,      # if save_best_only=True, it only saves when the model is considered the "best"
                               # and the latest best model according to the quantity monitored will not be overwritten.
                               # If filepath doesn't contain formatting options like {epoch} then filepath will be overwritten by each new better model.

     save_weights_only=True,   # if True, then only the model's weights will be saved (model.save_weights(filepath)),
                               # else the full model is saved (model.save(filepath)).

     mode="auto",              # one of {'auto', 'min', 'max'}. If save_best_only=True, the decision to overwrite
                               # the current save file is made based on either the maximization or the minimization of the
                               # monitored quantity. For val_acc, this should be max, for val_loss this should be min, etc.
                               # In auto mode, the direction is automatically inferred from the name of the monitored quantity.

     save_freq="epoch",        # 'epoch' or integer. When using 'epoch', the callback saves the model after each epoch.
                               # When using integer, the callback saves the model at end of this many batches.
                               # If the Model is compiled with steps_per_execution=N, then the saving criteria will be
                               # checked every Nth batch. Note that if the saving isn't aligned to epochs,
                               # the monitored metric may potentially be less reliable (it could reflect as little
                               # as 1 batch, since the metrics get reset every epoch). Defaults to 'epoch'.

     options=None,             # Optional tf.train.CheckpointOptions object if save_weights_only is true or optional
                               # tf.saved_model.SaveOptions object if save_weights_only is false.

     verbose=0,                # verbosity mode, 0 or 1.
 )

The ModelCheckpoint callback is used with model.fit() to save a model or its weights (in a checkpoint file) at some interval, so that the model or weights can be loaded later to continue training from the saved state.

Note

  • Set to None to disable this callback

  • Files saved by this callback are written to MltkModel.log_dir/train/weights

  • This is independent of checkpoints_enabled.
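The interaction of mode and save_best_only can be illustrated with a small pure-Python sketch. It assumes a simplified "auto" rule (monitor names containing "acc" are maximized, all others minimized); infer_mode and epochs_saved are hypothetical helpers, not part of MLTK or Keras.

```python
from typing import List


def infer_mode(monitor: str) -> str:
    """Simplified 'auto' rule: accuracy-like metrics are maximized, others minimized."""
    return "max" if "acc" in monitor else "min"


def epochs_saved(monitor: str, values: List[float]) -> List[int]:
    """Return the epoch indices at which save_best_only=True would write a file."""
    mode = infer_mode(monitor)
    best = None
    saved = []
    for epoch, value in enumerate(values):
        improved = best is None or (value > best if mode == "max" else value < best)
        if improved:
            best = value
            saved.append(epoch)
    return saved


print(epochs_saved("val_accuracy", [0.60, 0.72, 0.70, 0.81]))  # [0, 1, 3]
print(epochs_saved("val_loss", [0.90, 0.85, 0.88, 0.80]))      # [0, 1, 3]
```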

property early_stopping

Stop training when a monitored metric has stopped improving

Default: None

Possible values:

dict(
    monitor="val_accuracy",     # Quantity to be monitored.

    min_delta=0,                # Minimum change in the monitored quantity to qualify as an improvement,
                                # i.e. an absolute change of less than min_delta, will count as no improvement.

    patience=25,                # Number of epochs with no improvement after which training will be stopped.

    mode="auto",                # One of {"auto", "min", "max"}. In min mode, training will stop when the quantity
                                # monitored has stopped decreasing; in "max" mode it will stop when the quantity monitored
                                # has stopped increasing; in "auto" mode, the direction is automatically inferred from
                                # the name of the monitored quantity.

    baseline=None,              # Baseline value for the monitored quantity. Training will stop if
                                # the model doesn't show improvement over the baseline.

    restore_best_weights=True,  # Whether to restore model weights from the epoch with the best value of the monitored quantity.
                                # If False, the model weights obtained at the last step of training are used.

    verbose=1,                  # verbosity mode.
)

For example, assume the goal of training is to minimize the loss. In that case, the metric to monitor would be ‘loss’ and mode would be ‘min’. A model.fit() training loop checks at the end of every epoch whether the loss is no longer decreasing, taking min_delta and patience into account. Once the loss is found to be no longer decreasing, model.stop_training is set to True and training terminates.

Note

  • Set to None to disable this callback

  • Set epochs to -1 to always train until early stopping is triggered
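The stopping rule can be traced with a pure-Python sketch for mode="min". It is a simplified model that ignores baseline; early_stop_epoch is a hypothetical helper. With restore_best_weights=True, the weights from the returned best epoch are the ones that would be restored.

```python
from typing import List, Tuple


def early_stop_epoch(losses: List[float], patience: int,
                     min_delta: float = 0.0) -> Tuple[int, int]:
    """Return (last_epoch_run, best_epoch) for a monitored loss in 'min' mode.

    An epoch counts as an improvement only if the loss drops by more than
    min_delta below the best value seen so far.
    """
    best = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # stop_training would be set here
    return len(losses) - 1, best_epoch  # ran out of epochs first


val_losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.75, 0.69]
print(early_stop_epoch(val_losses, patience=3))  # (5, 2)
```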

property tflite_converter

Converts a TensorFlow model into TensorFlow Lite model

Default:

dict(
    optimizations = [tf.lite.Optimize.DEFAULT],             # Experimental flag, subject to change.
                                                            # A list of optimizations to apply when converting the model. E.g. [Optimize.DEFAULT]

    supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8],  # Experimental flag, subject to change. Set of OpsSet options supported by the device.
                                                            # Add to the 'target_spec' option
                                                            # https://www.tensorflow.org/api_docs/python/tf/lite/TargetSpec

    inference_input_type = tf.float32,                      # Data type of the input layer. Note that integer types (tf.int8 and tf.uint8) are
                                                            # currently only supported for post training integer quantization and quantization aware training.
                                                            # (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})

    inference_output_type = tf.float32,                     # Data type of the output layer. Note that integer types (tf.int8 and tf.uint8) are currently only
                                                            # supported for post training integer quantization and quantization aware training.
                                                            # (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})

    representative_dataset = 'generate',                    # A representative dataset that can be used to generate input and output samples
                                                            # for the model. The converter can use the dataset to evaluate different optimizations.
                                                            # Note that this is an optional attribute but it is necessary if INT8 is the only
                                                            # supported builtin ops in target ops.
                                                            # If the keyword 'generate' is used, then use up to "representative_dataset_max_samples" samples from the model's
                                                            # validation dataset as the representative dataset

    representative_dataset_max_samples = 1000,              # The maximum number of samples to use when representative_dataset == 'generate'

    allow_custom_ops = False,                               # Boolean indicating whether to allow custom operations. When False, any unknown operation is an error.
                                                            # When True, custom ops are created for any op that is unknown. The developer needs to provide these to the
                                                            # TensorFlow Lite runtime with a custom resolver. (default False)

    experimental_new_converter = True,                      # Experimental flag, subject to change. Enables MLIR-based conversion instead of TOCO conversion. (default True)

    experimental_new_quantizer = True,                      # Experimental flag, subject to change. Enables MLIR-based quantization conversion instead of Flatbuffer-based conversion. (default True)

    experimental_enable_resource_variables = False,         # Experimental flag, subject to change. Enables resource variables to be converted by this converter.
                                                            # This is only allowed if from_saved_model interface is used. (default False)

    generate_unquantized = True,                            # Also generate a float32/unquantized .tflite model in addition to the quantized .tflite model

    generate_quantization_report = False                    # Generate a quantization report. This can help determine if layers have quantization errors compared to the float32 model
                                                            # See:
                                                            # https://www.tensorflow.org/lite/performance/quantization_debugger
                                                            # https://siliconlabs.github.io/mltk/mltk/tutorials/model_quantization_tips.html
)

This is used after the model finishes training: the trained Keras .h5 model file is converted to a .tflite file by the TFLiteConverter, using the settings specified by this field.

If generate_unquantized=True, then both a quantized .tflite model file AND an unquantized .tflite model file will be generated. If you ONLY want to generate an unquantized model, then set supported_ops = TFLITE_BUILTINS (i.e. tf.lite.OpsSet.TFLITE_BUILTINS).

Note

See on_training_complete to invoke a custom callback which may be used to perform custom quantization
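TFLiteConverter expects representative_dataset to be a zero-argument callable returning a generator that yields one list of input values per sample. The sketch below approximates the documented 'generate' behavior by taking up to representative_dataset_max_samples validation samples. make_representative_dataset is a hypothetical helper; plain Python lists stand in for the float32 arrays a real model would use.

```python
import itertools
from typing import Callable, Iterator, List, Sequence


def make_representative_dataset(
    validation_samples: Sequence[Sequence[float]], max_samples: int = 1000
) -> Callable[[], Iterator[List[Sequence[float]]]]:
    """Return a generator function in the shape TFLiteConverter expects.

    validation_samples must be re-iterable (e.g. a list), because the
    converter may call the returned function more than once. Each yielded
    item is a list with one entry per model input (a single input here).
    """
    def representative_dataset() -> Iterator[List[Sequence[float]]]:
        for sample in itertools.islice(validation_samples, max_samples):
            yield [sample]

    return representative_dataset


val = [[float(i)] * 4 for i in range(5000)]
rep = make_representative_dataset(val, max_samples=1000)
print(sum(1 for _ in rep()))  # 1000
# With TensorFlow installed: converter.representative_dataset = rep
```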

property checkpoints_dir

Return path to directory containing training checkpoints

get_checkpoint_path(epoch=None)[source]

Return the file path to the checkpoint weights for the given epoch

If no epoch is provided, the path to the best checkpoint weights file is returned. Returns None if no checkpoint is found.

Note

Checkpoints are only generated if checkpoints_enabled is True.

Return type:

str

property weights_dir

Return path to directory containing training weights

property weights_file_format

Return the file format used to generate model weights files during training

get_weights_path(filename=None)[source]

Return the path to a Keras .h5 weights file

Return type:

str

Parameters:

filename (str) –

property train_kwargs

Additional arguments to pass to the model fit API (https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit). These keyword arguments will override the other model properties passed to fit().
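The override rule can be modeled as a plain dictionary merge in which train_kwargs is applied last. This is a hypothetical sketch: build_fit_kwargs and the exact set of properties MLTK forwards to fit() are assumptions; batch_size and shuffle are standard Model.fit() arguments.

```python
from typing import Any, Dict


def build_fit_kwargs(epochs: int, batch_size: int,
                     train_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Combine model properties with train_kwargs; train_kwargs wins on conflicts."""
    fit_kwargs = dict(epochs=epochs, batch_size=batch_size)
    fit_kwargs.update(train_kwargs)
    return fit_kwargs


print(build_fit_kwargs(100, 32, dict(batch_size=64, shuffle=False)))
# {'epochs': 100, 'batch_size': 64, 'shuffle': False}
```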