Update Model Parameters API Examples

This example demonstrates how to use the update_model_parameters API, which updates the model parameters stored in the metadata section of a .tflite model file.

Refer to the Model Parameters guide for more details.

NOTES:

  • Click here: Open In Colab to run this example interactively in your browser

  • Refer to the Notebook Examples Guide for how to run this example locally in VSCode

Install MLTK Python Package

# Install the MLTK Python package (if necessary)
!pip install --upgrade silabs-mltk

Import Python Packages

# Import the necessary MLTK APIs
from mltk.core import update_model_parameters, summarize_model

Example 1: Update model specification

The most common use case of the update_model_parameters API is:

  1. Fully train a model

  2. Later modify the model specification script with additional parameters

  3. Run the update_model_parameters API to update the .tflite model file in the model archive.

In this example, it’s assumed that the MltkModel.model_parameters settings in the tflite_micro_speech model specification script were modified after the model was trained, e.g.:

# In the tflite_micro_speech model specification script
my_model.model_parameters['average_window_duration_ms'] = 1000
my_model.model_parameters['detection_threshold'] = 185
my_model.model_parameters['suppression_ms'] = 1500
my_model.model_parameters['minimum_count'] = 3
my_model.model_parameters['volume_db'] = 5.0
my_model.model_parameters['latency_ms'] = 0
my_model.model_parameters['log_level'] = 'info'

After update_model_parameters() completes, the tflite_micro_speech.mltk.zip model archive contains a new tflite_micro_speech.tflite model file.
Only the parameters in the .tflite file’s metadata section are modified; the model weights and layers are untouched.

# Update the model parameters
update_model_parameters('tflite_micro_speech')

# Generate a summary of the updated model with new parameters
print(summarize_model('tflite_micro_speech'))
Updating c:/users/reed/workspace/silabs/mltk/mltk/models/examples/tflite_micro_speech.mltk.zip
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 25, 20, 8)         648       
_________________________________________________________________
batch_normalization (BatchNo (None, 25, 20, 8)         32        
_________________________________________________________________
activation (Activation)      (None, 25, 20, 8)         0         
_________________________________________________________________
dropout (Dropout)            (None, 25, 20, 8)         0         
_________________________________________________________________
flatten (Flatten)            (None, 4000)              0         
_________________________________________________________________
dense (Dense)                (None, 4)                 16004     
=================================================================
Total params: 16,684
Trainable params: 16,668
Non-trainable params: 16
_________________________________________________________________

Total MACs: 336.000 k
Total OPs: 684.004 k
Name: tflite_micro_speech
Version: 1
Description: TFLite-micro speech
Classes: yes, no, _unknown_, _silence_
hash: None
date: None
average_window_duration_ms: 1000
detection_threshold: 185
suppression_ms: 1500
minimum_count: 3
volume_db: 5.0
latency_ms: 0
log_level: info
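Rather than checking the printed summary by eye, the new values can be compared programmatically. The following is a minimal sketch that assumes each parameter is printed on its own line as "key: value", as in the output above. parse_summary_parameters() and find_mismatches() are hypothetical helpers, not part of the MLTK API, and an excerpt of the output above stands in for the summarize_model() return value so the snippet runs without the MLTK installed.

```python
# Hypothetical helpers (not part of the MLTK API): pull "key: value" lines
# out of the text returned by summarize_model() and compare them with the
# values set in the model specification script.


def parse_summary_parameters(summary: str, keys) -> dict:
    """Return {key: value_string} for each requested key found in the summary.

    Assumes each parameter is printed on its own line as "key: value".
    """
    wanted = set(keys)
    found = {}
    for line in summary.splitlines():
        key, sep, value = line.partition(': ')
        key = key.strip()
        if sep and key in wanted:
            found[key] = value.strip()
    return found


def find_mismatches(summary: str, expected: dict) -> dict:
    """Return {key: (expected_string, actual_string_or_None)} for each difference."""
    actual = parse_summary_parameters(summary, expected.keys())
    return {
        key: (str(value), actual.get(key))
        for key, value in expected.items()
        if actual.get(key) != str(value)
    }


# Values set in the model specification script (Example 1)
expected = {
    'average_window_duration_ms': 1000,
    'detection_threshold': 185,
    'suppression_ms': 1500,
    'minimum_count': 3,
    'volume_db': 5.0,
    'latency_ms': 0,
    'log_level': 'info',
}

# In practice: summary = summarize_model('tflite_micro_speech')
# Here an excerpt of the output shown above stands in for it.
summary = """Name: tflite_micro_speech
average_window_duration_ms: 1000
detection_threshold: 185
suppression_ms: 1500
minimum_count: 3
volume_db: 5.0
latency_ms: 0
log_level: info
"""

print(find_mismatches(summary, expected))  # {} when every value matches
```

Because the comparison is done on the printed text, an expected value of 5 would not match a printed 5.0; write the expected values in the same form the summary prints them.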

Example 2: Update with argument

The update_model_parameters API also works with .tflite model files generated outside of the MLTK.
In this mode, the model parameters are supplied via the params API argument, e.g.:

import os
import tempfile
import urllib.request  # urlopen() lives in urllib.request, which must be imported explicitly
import shutil

# Use the .tflite model found here:
# https://github.com/mlcommons/tiny/tree/master/benchmark/training/keyword_spotting/trained_models
# NOTE: Update this URL to point to your model if necessary
TFLITE_MODEL_URL = 'https://github.com/mlcommons/tiny/raw/master/benchmark/training/keyword_spotting/trained_models/kws_ref_model.tflite'

# Download the .tflite file and save to the temp dir
external_tflite_path = os.path.normpath(f'{tempfile.gettempdir()}/kws_ref_model.tflite')
with open(external_tflite_path, 'wb') as dst:
    with urllib.request.urlopen(TFLITE_MODEL_URL) as src:
        shutil.copyfileobj(src, dst)

# Set the parameters in a Python dictionary
parameters = {
    "volume": 10.0,
    "log_level": "debug",
    "threshold": 43
}
# Update the model parameters
update_model_parameters(external_tflite_path, params=parameters)

# Generate a summary of the updated model with new parameters
print(summarize_model(external_tflite_path))
+-------+-------------------+----------------+----------------+-------------------------------------------------------+
| Index | OpCode            | Input(s)       | Output(s)      | Config                                                |
+-------+-------------------+----------------+----------------+-------------------------------------------------------+
| 0     | conv_2d           | 49x10x1 (int8) | 25x5x64 (int8) | Padding:same stride:2x2 activation:relu               |
|       |                   | 10x4x1 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 1     | depthwise_conv_2d | 25x5x64 (int8) | 25x5x64 (int8) | Multipler:1 padding:same stride:1x1 activation:relu   |
|       |                   | 3x3x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 2     | conv_2d           | 25x5x64 (int8) | 25x5x64 (int8) | Padding:same stride:1x1 activation:relu               |
|       |                   | 1x1x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 3     | depthwise_conv_2d | 25x5x64 (int8) | 25x5x64 (int8) | Multipler:1 padding:same stride:1x1 activation:relu   |
|       |                   | 3x3x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 4     | conv_2d           | 25x5x64 (int8) | 25x5x64 (int8) | Padding:same stride:1x1 activation:relu               |
|       |                   | 1x1x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 5     | depthwise_conv_2d | 25x5x64 (int8) | 25x5x64 (int8) | Multipler:1 padding:same stride:1x1 activation:relu   |
|       |                   | 3x3x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 6     | conv_2d           | 25x5x64 (int8) | 25x5x64 (int8) | Padding:same stride:1x1 activation:relu               |
|       |                   | 1x1x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 7     | depthwise_conv_2d | 25x5x64 (int8) | 25x5x64 (int8) | Multipler:1 padding:same stride:1x1 activation:relu   |
|       |                   | 3x3x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 8     | conv_2d           | 25x5x64 (int8) | 25x5x64 (int8) | Padding:same stride:1x1 activation:relu               |
|       |                   | 1x1x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 9     | average_pool_2d   | 25x5x64 (int8) | 1x1x64 (int8)  | Padding:valid stride:5x25 filter:5x25 activation:none |
| 10    | reshape           | 1x1x64 (int8)  | 64 (int8)      | BuiltinOptionsType=0                                  |
|       |                   | 2 (int32)      |                |                                                       |
| 11    | fully_connected   | 64 (int8)      | 12 (int8)      | Activation:none                                       |
|       |                   | 64 (int8)      |                |                                                       |
|       |                   | 12 (int32)     |                |                                                       |
| 12    | softmax           | 12 (int8)      | 12 (int8)      | BuiltinOptionsType=9                                  |
+-------+-------------------+----------------+----------------+-------------------------------------------------------+
Total MACs: 2.657 M
Total OPs: 5.394 M
Name: summarize_model
Version: 1
Description: Generated by Silicon Lab's MLTK Python package
classes: []
hash: cda11c4380f044289cd9b73ccc2c20cc
date: 2021-10-19T18:53:57.429Z
volume: 10.0
log_level: debug
threshold: 43
.tflite file size: 54.0kB
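The download step in this example fetches the file every time it runs. As a minimal sketch, a hypothetical download_if_missing() helper (not part of the MLTK API) can reuse an existing copy in the temp directory, and writes to a temporary .part file first so that an interrupted download does not leave a truncated .tflite behind:

```python
import os
import shutil
import urllib.request


def download_if_missing(url: str, dst_path: str) -> str:
    """Download url to dst_path unless dst_path already exists.

    Hypothetical helper; not part of the MLTK API.
    """
    if os.path.exists(dst_path):
        return dst_path  # reuse the previously downloaded copy

    tmp_path = dst_path + '.part'
    with urllib.request.urlopen(url) as src, open(tmp_path, 'wb') as dst:
        shutil.copyfileobj(src, dst)

    # Expose the file under its final name only once it is complete
    os.replace(tmp_path, dst_path)
    return dst_path
```

It could replace the download block in the example as external_tflite_path = download_if_missing(TFLITE_MODEL_URL, external_tflite_path). The cached copy is the same file that update_model_parameters() modified in place, so it may already contain parameters from an earlier run.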

Example 3: Return TfliteModel instance

Rather than modifying the given model file, the update_model_parameters API can also return a TfliteModel instance that contains the updated parameters, leaving the input file unchanged.

This is done by specifying the output='tflite_model' API argument:

# Set the parameters in a Python dictionary
parameters = {
    "volume": 10.0,
    "log_level": "debug",
    "threshold": 43
}
# Generate a TfliteModel instance with the given parameters
# NOTE: The input external_tflite_path file is NOT modified
tflite_model = update_model_parameters(external_tflite_path, params=parameters, output='tflite_model')

# Generate a summary of the returned TfliteModel instance
print(summarize_model(tflite_model))
+-------+-------------------+----------------+----------------+-------------------------------------------------------+
| Index | OpCode            | Input(s)       | Output(s)      | Config                                                |
+-------+-------------------+----------------+----------------+-------------------------------------------------------+
| 0     | conv_2d           | 49x10x1 (int8) | 25x5x64 (int8) | Padding:same stride:2x2 activation:relu               |
|       |                   | 10x4x1 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 1     | depthwise_conv_2d | 25x5x64 (int8) | 25x5x64 (int8) | Multipler:1 padding:same stride:1x1 activation:relu   |
|       |                   | 3x3x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 2     | conv_2d           | 25x5x64 (int8) | 25x5x64 (int8) | Padding:same stride:1x1 activation:relu               |
|       |                   | 1x1x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 3     | depthwise_conv_2d | 25x5x64 (int8) | 25x5x64 (int8) | Multipler:1 padding:same stride:1x1 activation:relu   |
|       |                   | 3x3x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 4     | conv_2d           | 25x5x64 (int8) | 25x5x64 (int8) | Padding:same stride:1x1 activation:relu               |
|       |                   | 1x1x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 5     | depthwise_conv_2d | 25x5x64 (int8) | 25x5x64 (int8) | Multipler:1 padding:same stride:1x1 activation:relu   |
|       |                   | 3x3x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 6     | conv_2d           | 25x5x64 (int8) | 25x5x64 (int8) | Padding:same stride:1x1 activation:relu               |
|       |                   | 1x1x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 7     | depthwise_conv_2d | 25x5x64 (int8) | 25x5x64 (int8) | Multipler:1 padding:same stride:1x1 activation:relu   |
|       |                   | 3x3x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 8     | conv_2d           | 25x5x64 (int8) | 25x5x64 (int8) | Padding:same stride:1x1 activation:relu               |
|       |                   | 1x1x64 (int8)  |                |                                                       |
|       |                   | 64 (int32)     |                |                                                       |
| 9     | average_pool_2d   | 25x5x64 (int8) | 1x1x64 (int8)  | Padding:valid stride:5x25 filter:5x25 activation:none |
| 10    | reshape           | 1x1x64 (int8)  | 64 (int8)      | BuiltinOptionsType=0                                  |
|       |                   | 2 (int32)      |                |                                                       |
| 11    | fully_connected   | 64 (int8)      | 12 (int8)      | Activation:none                                       |
|       |                   | 64 (int8)      |                |                                                       |
|       |                   | 12 (int32)     |                |                                                       |
| 12    | softmax           | 12 (int8)      | 12 (int8)      | BuiltinOptionsType=9                                  |
+-------+-------------------+----------------+----------------+-------------------------------------------------------+
Total MACs: 2.657 M
Total OPs: 5.394 M
Name: summarize_model
Version: 1
Description: Generated by Silicon Lab's MLTK Python package
classes: []
hash: cda11c4380f044289cd9b73ccc2c20cc
date: 2021-10-19T18:53:57.565Z
volume: 10.0
log_level: debug
threshold: 43
.tflite file size: 54.0kB
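The NOTE in the Example 3 code states that external_tflite_path is not modified when output='tflite_model' is used. A minimal sketch for checking this yourself is to hash the file before and after the call with the standard library's hashlib. The helper name file_sha256() is hypothetical. The MLTK calls are left as comments because they need the MLTK package and the file downloaded in Example 2.

```python
import hashlib


def file_sha256(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the SHA-256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        # Read fixed-size chunks until f.read() returns b'' (end of file)
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


# Usage (requires the MLTK package and the file from Example 2):
# before = file_sha256(external_tflite_path)
# tflite_model = update_model_parameters(external_tflite_path, params=parameters, output='tflite_model')
# assert file_sha256(external_tflite_path) == before, 'input .tflite was modified'
```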