Model Training API Examples
This notebook demonstrates how to use the train_model API.
Refer to the Model Training guide for more details.
NOTE: Refer to the Notebook Examples Guide for how to run this example locally in VS Code.
Install MLTK Python Package
# Install the MLTK Python package (if necessary)
!pip install --upgrade silabs-mltk
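The "(if necessary)" in the install cell can be made explicit by checking whether the package is already importable. The following is a minimal sketch, not taken from the MLTK documentation: `is_installed` is a hypothetical helper, and it assumes the `silabs-mltk` distribution provides the top-level `mltk` package that this notebook imports.

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if a top-level package can be located on the import path."""
    # find_spec returns None when no module with that top-level name is found
    return importlib.util.find_spec(package_name) is not None


# Assumption: the silabs-mltk distribution installs a package named 'mltk'
if not is_installed('mltk'):
    print('mltk not found; run: pip install --upgrade silabs-mltk')
```

The check only looks the package up; it does not import it, so it stays fast even though MLTK pulls in TensorFlow.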
Import Python Packages
# Import the necessary MLTK APIs
from mltk.core import train_model
Example 1: Train as a “dry run”
Before fully training a model, it is sometimes useful to do a “dry run” to ensure everything is working.
This can be done by appending -test to the model name.
With this, the model is trained for only a few epochs on a subset of the training data (the run below shows three), and a model archive with -test appended to its name is generated.
NOTE: Internally, the load_mltk_model API is used to load the model.
See Model Search Path for how to update the model search path for your model.
# Train the model as a dry run by adding '-test' to the model name
training_results = train_model('image_example1-test')
Epoch 1/3
Epoch 2/3
Epoch 3/3
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 3 of 3). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: E:\tmp6ab8ka6k\assets
INFO:tensorflow:Assets written to: E:\tmp6ab8ka6k\assets
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 3 of 3). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: E:\tmpgclx2y4d\assets
INFO:tensorflow:Assets written to: E:\tmpgclx2y4d\assets
c:\Users\reed\workspace\silabs\mltk\.venv\lib\site-packages\tensorflow\lite\python\convert.py:766: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
warnings.warn("Statistics for quantized inputs were expected, but not "
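The naming convention used in this example can be wrapped in a small helper so that a dry run is requested consistently. This is only a sketch under stated assumptions: `dry_run_name` and `train_dry_run` are hypothetical names that are not part of MLTK, and the only MLTK call used is `train_model(<model name>)` as shown above.

```python
def dry_run_name(model_name: str) -> str:
    """Return the model name with the '-test' suffix, adding it only if missing."""
    suffix = '-test'
    return model_name if model_name.endswith(suffix) else model_name + suffix


def train_dry_run(model_name: str):
    """Hypothetical wrapper: train the given model as a dry run."""
    # Imported lazily so dry_run_name() can be used without MLTK installed
    from mltk.core import train_model
    return train_model(dry_run_name(model_name))
```

With this, `train_dry_run('image_example1')` would be equivalent to the `train_model('image_example1-test')` call in the cell above.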
Example 2: Train for 10 epochs
The model specification typically contains the number of training epochs.
Optionally, the epochs argument can be used to override the model specification.
NOTE: Internally, the load_mltk_model API is used to load the model.
See Model Search Path for how to update the model search path for your model.
# Train the model for 10 epochs then show the training results
training_results = train_model('image_example1', epochs=10, show=True)
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 3 of 3). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: E:\tmppbx5mcil\assets
INFO:tensorflow:Assets written to: E:\tmppbx5mcil\assets
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 3 of 3). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: E:\tmpeoo7jyna\assets
INFO:tensorflow:Assets written to: E:\tmpeoo7jyna\assets
c:\Users\reed\workspace\silabs\mltk\.venv\lib\site-packages\tensorflow\lite\python\convert.py:766: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
warnings.warn("Statistics for quantized inputs were expected, but not "
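The two examples can be combined into one workflow: run the quick dry run first, and start the longer training only if it completes without raising. The sketch below is an assumption-laden illustration, not an MLTK feature: `train_with_dry_run` and its `train_fn` parameter are hypothetical, and the only MLTK calls relied on are the two forms shown in this notebook, `train_model(name)` and `train_model(name, epochs=...)`.

```python
from typing import Any, Callable, Optional


def train_with_dry_run(
    model_name: str,
    epochs: int,
    train_fn: Optional[Callable[..., Any]] = None,
) -> Any:
    """Run a '-test' dry run, then train the real model for `epochs` epochs."""
    if train_fn is None:
        # Default to the MLTK API; imported lazily so a stand-in can be injected
        from mltk.core import train_model as train_fn

    # Step 1: dry run. Any exception raised here stops the workflow early.
    train_fn(f'{model_name}-test')

    # Step 2: full training, overriding the epoch count in the model specification
    return train_fn(model_name, epochs=epochs)
```

Calling `train_with_dry_run('image_example1', epochs=10)` would first train `image_example1-test` and then `image_example1` for 10 epochs. The injectable `train_fn` exists only so that the control flow can be exercised without running an actual training job.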