Model Training

This document describes how to train a machine learning model using the MLTK and Google Tensorflow.


This document focuses on the training aspect of model development. Refer to the tutorials for end-to-end guides on how to develop an ML model.

Quick Reference


The MLTK internally uses Google Tensorflow to train a model.
The basic sequence for training a model is:

  1. Create a model specification script

  2. Populate the model training and dataset parameters

  3. Define the model layout using the Keras API

  4. Invoke model training using the Command-Line or Python API

When training completes, a model archive file is generated in the same directory as the model specification script. It contains the trained model files and logs.

HINT: See Training via SSH for how to quickly train your model in the cloud.

Model Specification

All model training parameters are defined in the Model Specification script. This is a standard Python script that defines a MltkModel instance.

MltkModel Instance

All training parameters are configured in the MltkModel instance.
For example, the following might be added to the top of the model specification script:

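from mltk.core import MltkModel, TrainMixin, AudioDatasetMixin, EvaluateClassifierMixin

# Define a new MyModel class which inherits the
# MltkModel and several mixins
# @mltk_model
class MyModel(MltkModel, TrainMixin, AudioDatasetMixin, EvaluateClassifierMixin):
    """My Model's class object"""

# Instantiate the MyModel class
my_model = MyModel()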

Here we define our model’s class object: MyModel.

At a minimum, this custom class must inherit MltkModel, TrainMixin, and a dataset mixin (DatasetMixin or a subclass of it, such as AudioDatasetMixin).

Additionally, this class inherits other model “mixins” to aid model development.

After our model is instantiated, the rest of the model specification simply populates the various properties of MyModel, e.g.:

# General Settings
my_model.version = 1
my_model.description = 'My model is great!'

# Basic Training Settings
my_model.epochs = 100
my_model.batch_size = 64
my_model.optimizer = 'adam'

# Dataset Settings
# Assumes the speech_commands_v2 dataset module was imported earlier in the script
my_model.dataset = speech_commands_v2
my_model.class_mode = 'categorical'
my_model.classes = ['up', 'down', 'left', 'right']
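How a model class combines a root class with several mixins, and then has its properties populated, can be illustrated in plain Python. ToyModel, ToyTrainMixin, and ToyDatasetMixin are hypothetical stand-ins written for this sketch, not the MLTK classes; they only show how each mixin contributes its own properties to a single object through multiple inheritance.

```python
class ToyModel:
    """Hypothetical stand-in for a root model class."""
    def __init__(self):
        self.version = 1
        self.description = ''


class ToyTrainMixin:
    """Hypothetical mixin contributing training properties (with defaults)."""
    epochs = 10
    batch_size = 32
    optimizer = 'adam'


class ToyDatasetMixin:
    """Hypothetical mixin contributing dataset properties."""
    classes = ()

    @property
    def n_classes(self) -> int:
        # Derived from the class list, so it always stays in sync with it
        return len(self.classes)


class MyToyModel(ToyModel, ToyTrainMixin, ToyDatasetMixin):
    """Combines the root class with the mixins, similar to MyModel."""


my_toy_model = MyToyModel()
my_toy_model.epochs = 100                         # overrides the mixin default
my_toy_model.classes = ['up', 'down', 'left', 'right']
print(my_toy_model.optimizer, my_toy_model.epochs, my_toy_model.n_classes)  # adam 100 4
```

Because n_classes is derived from classes rather than stored separately, editing the class list automatically changes the value that a model-building function would read.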


The filename of the model specification script is the name given to the model. For example, if the script is named my_model_v1.py, the model name is my_model_v1.
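The naming rule can be expressed as a small helper. model_name_from_path is a hypothetical function written for illustration, not part of the MLTK API; it simply strips the directory and the .py extension from the script path.

```python
import os


def model_name_from_path(script_path: str) -> str:
    """Hypothetical helper: derive a model name from its specification script path."""
    filename = os.path.basename(script_path)  # e.g. 'my_model_v1.py'
    name, _ext = os.path.splitext(filename)   # drop the '.py' extension
    return name


print(model_name_from_path('models/my_model_v1.py'))  # my_model_v1
```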

Model Layout

An important property of the MyModel class example from above is TrainMixin.build_model_function. This should reference a function that builds the actual machine learning model using the Keras API.

For example:

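from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Flatten, Dense

def my_model_builder(my_model: MyModel):
    keras_model = Sequential(name=my_model.name)

    # The input shape comes from the model's dataset settings
    keras_model.add(Conv2D(filters=8, kernel_size=(10, 8), padding='same',
                           activation='relu', input_shape=my_model.input_shape))
    keras_model.add(Flatten())
    # One output unit per class in my_model.classes
    keras_model.add(Dense(my_model.n_classes, activation='softmax'))

    keras_model.compile(loss=my_model.loss, optimizer=my_model.optimizer,
                        metrics=my_model.metrics)
    return keras_model

# Set the model property to reference the model build function
my_model.build_model_function = my_model_builder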

Here, we define a function that builds a KerasModel and then set my_model.build_model_function to reference that function.

At model training time, the model building function is invoked and the built KerasModel is trained using Tensorflow.
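The pattern of storing a function reference on the model and calling it later can be sketched in plain Python. ToySpec, toy_builder, and toy_train are hypothetical names used only for illustration; a real training run also compiles, fits, and archives a Keras model, which this sketch omits by representing the "model" as a dictionary.

```python
from typing import Callable, Optional


class ToySpec:
    """Hypothetical model specification holding a reference to a builder function."""
    def __init__(self, classes):
        self.classes = classes
        self.build_model_function: Optional[Callable] = None


def toy_builder(spec: ToySpec) -> dict:
    # Output size is derived from the spec rather than hardcoded
    return {'layers': ['conv2d', 'flatten', 'dense'], 'output_units': len(spec.classes)}


def toy_train(spec: ToySpec) -> dict:
    """Hypothetical trainer: invokes the builder function at 'training time'."""
    if spec.build_model_function is None:
        raise ValueError('build_model_function is not set')
    model = spec.build_model_function(spec)
    # A real trainer would fit the built model here
    return model


spec = ToySpec(classes=['up', 'down', 'left', 'right'])
spec.build_model_function = toy_builder
print(toy_train(spec)['output_units'])  # 4
```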

Note about hardcoding model layer parameters

While not required, it is recommended to derive layer parameters from the my_model argument of the building function rather than hardcoding them, e.g.:

# Good:
# Dynamically determine the number of dense units based
# on the number of classes specified in the model properties
keras_model.add(Dense(my_model.n_classes, activation='softmax'))

# Bad:
# Hardcoding dense units
# If the number of classes changes,
# then training will likely fail
keras_model.add(Dense(4, activation='softmax'))

See the Model Specification documentation for more details.

Training Output

When training completes, a model archive file is generated in the same directory as the model specification script and contains the trained model files and logs.

Included in the model archive is a quantized .tflite model file. This is the file that is programmed into the embedded device and executed by Tensorflow-Lite Micro.

The .tflite file is generated by the Tensorflow-Lite Converter. The converter settings are defined in the model specification script using the TrainMixin.tflite_converter model property.

For example, the model specification script might have:

import tensorflow as tf

my_model.tflite_converter['optimizations'] = [tf.lite.Optimize.DEFAULT]
my_model.tflite_converter['supported_ops'] = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
my_model.tflite_converter['inference_input_type'] = tf.int8
my_model.tflite_converter['inference_output_type'] = tf.int8
my_model.tflite_converter['representative_dataset'] = 'generate'

These settings are used at the end of training to generate the .tflite file.
See Model Quantization for more details.
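The int8 settings above mean that floating-point values are mapped to 8-bit integers. The sketch below shows standard affine int8 quantization in plain Python: the range of some representative values (loosely, the role the representative dataset plays during conversion) determines a scale and zero point, and each real value r maps to q = round(r / scale) + zero_point, clamped to [-128, 127]. It is a simplified illustration of the general scheme, not the Tensorflow-Lite Converter's calibration code, and the function names are hypothetical.

```python
def calibrate(samples, qmin=-128, qmax=127):
    """Estimate a scale and zero point from representative float samples."""
    lo = min(min(samples), 0.0)  # the range must include 0.0 so zero is exactly representable
    hi = max(max(samples), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # avoid a zero scale for all-zero input
    zero_point = int(round(qmin - lo / scale))
    return scale, max(qmin, min(qmax, zero_point))


def quantize(r, scale, zero_point, qmin=-128, qmax=127):
    """Map a float to int8 using the affine scheme, clamping to the int8 range."""
    q = int(round(r / scale)) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Recover the approximate float value from its int8 representation."""
    return scale * (q - zero_point)


representative = [-1.0, -0.25, 0.0, 0.5, 1.0]
scale, zp = calibrate(representative)
for r in representative:
    q = quantize(r, scale, zp)
    print(r, q, round(dequantize(q, scale, zp), 3))
```

Each dequantized value differs from the original by at most about one quantization step (scale), which is the precision traded for the smaller, integer-only model.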

Command-Line

Model training from the command-line is done using the train operation.

For more details on the available command-line options, issue the command:

mltk train --help

HINT: See Training via SSH for how to quickly train your model in the cloud.

The following are examples of how training can be invoked from the command-line:

Example 1: Train as a “dry run”

Before fully training a model, it is sometimes useful to do a “dry run” to ensure everything is working. This can be done by appending -test to the end of the model name. With this, the model is trained for 1 epoch on a subset of the training data, and a model archive with -test appended to its name is generated.

mltk train tflite_micro_speech-test
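The -test convention can be sketched as a small parser. parse_model_arg and its returned settings are hypothetical, written only to illustrate the behavior described above (1 epoch on a subset of the data, with -test kept in the archive name); the 10% data fraction is an arbitrary assumption, not an MLTK value.

```python
def parse_model_arg(arg: str, default_epochs: int) -> dict:
    """Hypothetical: interpret a trailing '-test' on the model name as a dry run."""
    suffix = '-test'
    is_test = arg.endswith(suffix)
    name = arg[:-len(suffix)] if is_test else arg
    return {
        'model_name': name,
        'epochs': 1 if is_test else default_epochs,
        # Assumption: a dry run uses a small fraction of the training data
        'data_fraction': 0.1 if is_test else 1.0,
        'archive_name': f'{name}-test' if is_test else name,
    }


print(parse_model_arg('tflite_micro_speech-test', default_epochs=100))
```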

Example 2: Train for 100 epochs

The model specification typically contains the number of training epochs, i.e. TrainMixin.epochs.
Optionally, the --epochs option can be used to override this value.

mltk train audio_example1 --epochs 100
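The override rule (a command-line value wins; otherwise the specification's TrainMixin.epochs is used) can be sketched with argparse. resolve_epochs is a hypothetical parser, not the actual mltk command-line implementation; only a positional model argument and the --epochs option are modeled.

```python
import argparse
from typing import List


def resolve_epochs(argv: List[str], spec_epochs: int) -> int:
    """Hypothetical: return the epoch count, preferring --epochs over the spec value."""
    parser = argparse.ArgumentParser(prog='train-sketch')
    parser.add_argument('model')
    # default=None distinguishes "option not given" from any explicit value
    parser.add_argument('--epochs', type=int, default=None)
    args = parser.parse_args(argv)
    return args.epochs if args.epochs is not None else spec_epochs


print(resolve_epochs(['audio_example1', '--epochs', '100'], spec_epochs=50))  # 100
print(resolve_epochs(['audio_example1'], spec_epochs=50))                     # 50
```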

Example 3: Resume Training

If training does not fully complete, it can be resumed by adding the --resume option. This loads the weights from the last saved checkpoint and continues training from that checkpoint’s epoch. See TrainMixin.checkpoint for more details.

mltk train image_example1 --resume
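How a resume step might pick its starting point can be sketched as follows. The checkpoint filename pattern (weights-epoch-NNN.h5) and the latest_checkpoint helper are assumptions made for illustration; the actual checkpoint naming and location are configured through TrainMixin.checkpoint and may differ.

```python
import re
from typing import Iterable, Optional, Tuple

# Assumption: checkpoints are saved as 'weights-epoch-<N>.h5'
CHECKPOINT_RE = re.compile(r'^weights-epoch-(\d+)\.h5$')


def latest_checkpoint(filenames: Iterable[str]) -> Optional[Tuple[str, int]]:
    """Hypothetical: return (filename, epoch) of the newest checkpoint, or None."""
    best = None
    for fn in filenames:
        m = CHECKPOINT_RE.match(fn)
        if m:
            epoch = int(m.group(1))
            if best is None or epoch > best[1]:
                best = (fn, epoch)
    return best


files = ['weights-epoch-003.h5', 'weights-epoch-012.h5', 'log.txt', 'weights-epoch-010.h5']
found = latest_checkpoint(files)
if found is None:
    print('No checkpoint found; training starts from epoch 0')
else:
    filename, epoch = found
    # A real trainer would load the weights from `filename` and continue from `epoch`
    print(f'Resuming from {filename} at epoch {epoch}')  # Resuming from weights-epoch-012.h5 at epoch 12
```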

Python API

Model training is also accessible via the train_model Python API.

Examples using this API may be found in train_model.ipynb.