Model Training API Examples

This example demonstrates how to use the train_model API.

Refer to the Model Training guide for more details.

NOTES:

  • Click the Open In Colab link to run this example interactively in your browser

  • Refer to the Notebook Examples Guide for how to run this example locally in VSCode

Install MLTK Python Package

# Install the MLTK Python package (if necessary)
!pip install --upgrade silabs-mltk

Import Python Packages

# Import the necessary MLTK APIs
from mltk.core import train_model

Example 1: Train as a “dry run”

Before fully training a model, it is sometimes useful to do a “dry run” to ensure everything is working.
This can be done by appending -test to the model name. The model is then trained for only a small number of epochs (the output below shows three) on a subset of the training data, and a model archive with -test appended to its name is generated.
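The naming convention above can be pictured with a small sketch. The helpers parse_model_name and plan_training_run are hypothetical and are not part of the MLTK API; the assumption that the base name (without -test) selects the model specification, and the ".mltk.zip" archive extension, are illustrative assumptions rather than a description of MLTK internals.

```python
from typing import Dict, Tuple, Union

# Hypothetical illustration of the "-test" naming convention (not an MLTK API)
TEST_SUFFIX = "-test"


def parse_model_name(name: str) -> Tuple[str, bool]:
    """Split a model name into (base_name, is_dry_run)."""
    if name.endswith(TEST_SUFFIX):
        return name[: -len(TEST_SUFFIX)], True
    return name, False


def plan_training_run(name: str) -> Dict[str, Union[str, bool]]:
    """Describe what a given model name implies for a training run."""
    base, is_test = parse_model_name(name)
    return {
        "spec_name": base,               # assumed: specification looked up by base name
        "dry_run": is_test,              # few epochs on a subset of the data
        "archive": f"{name}.mltk.zip",   # assumed extension; archive keeps the suffix
    }


print(plan_training_run("image_example1-test"))
# {'spec_name': 'image_example1', 'dry_run': True, 'archive': 'image_example1-test.mltk.zip'}
```

The point of the sketch is that the same model specification serves both runs; only the suffix changes how much training is done and which archive name is produced.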

NOTE: Internally, the load_mltk_model API is used to load the model.
See Model Search Path for how to update the model search path for your model.
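As a rough mental model of a search path lookup, the sketch below checks an ordered list of directories for a "<name>.py" specification file and returns the first match. The function find_model_spec is hypothetical; the real load_mltk_model search may differ (for example, it may search directories recursively or consult additional locations), so treat this only as an illustration of the ordering idea.

```python
import tempfile
from pathlib import Path
from typing import Iterable, Optional


def find_model_spec(name: str, search_paths: Iterable[str]) -> Optional[Path]:
    """Hypothetical lookup: return the first '<name>.py' found, checking paths in order."""
    # Strip a dry-run suffix so 'image_example1-test' resolves to 'image_example1.py'
    base = name[: -len("-test")] if name.endswith("-test") else name
    for directory in search_paths:
        candidate = Path(directory) / f"{base}.py"
        if candidate.is_file():
            return candidate
    return None


# Usage: two temporary directories stand in for entries on the search path
with tempfile.TemporaryDirectory() as first, tempfile.TemporaryDirectory() as second:
    (Path(second) / "image_example1.py").write_text("# model specification\n")
    found = find_model_spec("image_example1-test", [first, second])
    found_name = found.name if found else None

print(found_name)  # image_example1.py
```

Because the first match wins, placing your own directory earlier in the search path is what lets a custom specification take precedence.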

# Train the model as a dry run by adding '-test' to the model name
training_results = train_model('image_example1-test')
Epoch 1/3
Epoch 2/3
Epoch 3/3
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 3 of 3). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: E:\tmp6ab8ka6k\assets
INFO:tensorflow:Assets written to: E:\tmp6ab8ka6k\assets
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 3 of 3). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: E:\tmpgclx2y4d\assets
INFO:tensorflow:Assets written to: E:\tmpgclx2y4d\assets
c:\Users\reed\workspace\silabs\mltk\.venv\lib\site-packages\tensorflow\lite\python\convert.py:766: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn("Statistics for quantized inputs were expected, but not "

Example 2: Train for 10 epochs

The model specification typically defines the number of training epochs. Optionally, the epochs argument can be used to override the value in the model specification.
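The override rule can be expressed as a one-line precedence check. This is a hypothetical sketch of the behavior described above, not MLTK source code: resolve_epochs is an illustrative name, and the spec value of 100 is made up for the example.

```python
from typing import Optional


def resolve_epochs(spec_epochs: int, epochs: Optional[int] = None) -> int:
    """Hypothetical: an explicit epochs argument takes precedence over the spec value."""
    return spec_epochs if epochs is None else epochs


print(resolve_epochs(spec_epochs=100))             # 100
print(resolve_epochs(spec_epochs=100, epochs=10))  # 10
```

Leaving epochs unset keeps the specification authoritative, which is why passing epochs=10 below changes only this run and not the model specification itself.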

NOTE: Internally, the load_mltk_model API is used to load the model.
See Model Search Path for how to update the model search path for your model.

# Train the model for 10 epochs then show the training results
training_results = train_model('image_example1', epochs=10, show=True)
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
[Figure: training results plot displayed by show=True]
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 3 of 3). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: E:\tmppbx5mcil\assets
INFO:tensorflow:Assets written to: E:\tmppbx5mcil\assets
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 3 of 3). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: E:\tmpeoo7jyna\assets
INFO:tensorflow:Assets written to: E:\tmpeoo7jyna\assets
c:\Users\reed\workspace\silabs\mltk\.venv\lib\site-packages\tensorflow\lite\python\convert.py:766: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn("Statistics for quantized inputs were expected, but not "